Onboard Load Balancer (part 2: add update payload and API response map helpers) (#111)

* Add mapFields and tests

* Extend Create function

* Adjust function signature

* Fix toCreatePayload

* Add toTargetPoolUpdatePayload and tests

* Wait for creation

* Use waiter response for mapFields

* Adjust after dependency update
This commit is contained in:
João Palet 2023-10-27 18:48:31 +02:00 committed by GitHub
parent fc8d2663cd
commit 19a679e0bc
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 586 additions and 32 deletions

View file

@ -3,10 +3,12 @@ package loadbalancer
import (
"context"
"fmt"
"strings"
"github.com/hashicorp/terraform-plugin-framework-validators/setvalidator"
"github.com/hashicorp/terraform-plugin-framework-validators/stringvalidator"
"github.com/hashicorp/terraform-plugin-framework/attr"
"github.com/hashicorp/terraform-plugin-framework/diag"
"github.com/hashicorp/terraform-plugin-framework/resource"
"github.com/hashicorp/terraform-plugin-framework/resource/schema"
"github.com/hashicorp/terraform-plugin-framework/resource/schema/planmodifier"
@ -16,6 +18,7 @@ import (
"github.com/hashicorp/terraform-plugin-framework/types/basetypes"
"github.com/hashicorp/terraform-plugin-log/tflog"
"github.com/stackitcloud/stackit-sdk-go/core/config"
"github.com/stackitcloud/stackit-sdk-go/core/utils"
"github.com/stackitcloud/stackit-sdk-go/services/loadbalancer"
"github.com/stackitcloud/stackit-sdk-go/services/loadbalancer/wait"
"github.com/stackitcloud/terraform-provider-stackit/stackit/internal/core"
@ -44,7 +47,6 @@ type Model struct {
// Struct corresponding to each Model.Listener
type Listener struct {
DisplayName types.String `tfsdk:"display_name"`
Name types.String `tfsdk:"name"`
Port types.Int64 `tfsdk:"port"`
Protocol types.String `tfsdk:"protocol"`
TargetPool types.String `tfsdk:"target_pool"`
@ -160,7 +162,6 @@ func (r *projectResource) Schema(_ context.Context, _ resource.SchemaRequest, re
"project_id": "STACKIT project ID to which the Load Balancer is associated.",
"external_address": "External Load Balancer IP address where this Load Balancer is exposed.",
"listeners": "List of all listeners which will accept traffic. Limited to 20.",
"listeners.name": "Will be used to reference a listener and will replace display name in the future.",
"port": "Port number where we listen for traffic.",
"protocol": "Protocol is the highest network protocol we understand to load balance.",
"target_pool": "Reference target pool by target pool name.",
@ -221,10 +222,6 @@ func (r *projectResource) Schema(_ context.Context, _ resource.SchemaRequest, re
Optional: true,
Computed: true,
},
"name": schema.StringAttribute{
Description: descriptions["listeners.display_name"],
Computed: true,
},
"port": schema.Int64Attribute{
Description: descriptions["port"],
Optional: true,
@ -398,6 +395,8 @@ func (r *projectResource) Create(ctx context.Context, req resource.CreateRequest
core.LogAndAddError(ctx, &resp.Diagnostics, "Error getting status of load balancer functionality", fmt.Sprintf("Calling API: %v", err))
return
}
// If load balancer functionality is not enabled, enable it
if *statusResp.Status != wait.FunctionalityStatusReady {
_, err = r.client.EnableLoadBalancing(ctx, projectId).Execute()
if err != nil {
@ -413,21 +412,129 @@ func (r *projectResource) Create(ctx context.Context, req resource.CreateRequest
}
// Generate API request body from model
_, err = toCreatePayload(ctx, &model)
payload, err := toCreatePayload(ctx, &model)
if err != nil {
core.LogAndAddError(ctx, &resp.Diagnostics, "Error creating instance", fmt.Sprintf("Creating API payload: %v", err))
core.LogAndAddError(ctx, &resp.Diagnostics, "Error creating load balancer", fmt.Sprintf("Creating API payload: %v", err))
return
}
// Create a new load balancer
createResp, err := r.client.CreateLoadBalancer(ctx, projectId).CreateLoadBalancerPayload(*payload).Execute()
if err != nil {
core.LogAndAddError(ctx, &resp.Diagnostics, "Error creating load balancer", fmt.Sprintf("Calling API: %v", err))
return
}
waitResp, err := wait.CreateLoadBalancerWaitHandler(ctx, r.client, projectId, *createResp.Name).WaitWithContext(ctx)
if err != nil {
core.LogAndAddError(ctx, &resp.Diagnostics, "Error creating load balancer", fmt.Sprintf("Load balancer creation waiting: %v", err))
return
}
// Map response body to schema
err = mapFields(ctx, waitResp, &model)
if err != nil {
core.LogAndAddError(ctx, &resp.Diagnostics, "Error creating load balancer", fmt.Sprintf("Processing API payload: %v", err))
return
}
// Set state to fully populated data
diags = resp.State.Set(ctx, model)
resp.Diagnostics.Append(diags...)
if resp.Diagnostics.HasError() {
return
}
tflog.Info(ctx, "Load balancer created")
}
// Read refreshes the Terraform state with the latest data.
func (r *projectResource) Read(_ context.Context, _ resource.ReadRequest, _ *resource.ReadResponse) { // nolint:gocritic // function signature required by Terraform
func (r *projectResource) Read(ctx context.Context, req resource.ReadRequest, resp *resource.ReadResponse) { // nolint:gocritic // function signature required by Terraform
var model Model
diags := req.State.Get(ctx, &model)
resp.Diagnostics.Append(diags...)
if resp.Diagnostics.HasError() {
return
}
projectId := model.ProjectId.ValueString()
name := model.Name.ValueString()
ctx = tflog.SetField(ctx, "project_id", projectId)
ctx = tflog.SetField(ctx, "name", name)
lbResp, err := r.client.GetLoadBalancer(ctx, projectId, name).Execute()
if err != nil {
core.LogAndAddError(ctx, &resp.Diagnostics, "Error reading load balancer", err.Error())
return
}
// Map response body to schema
err = mapFields(ctx, lbResp, &model)
if err != nil {
core.LogAndAddError(ctx, &resp.Diagnostics, "Error reading load balancer", fmt.Sprintf("Processing API payload: %v", err))
return
}
// Set refreshed state
diags = resp.State.Set(ctx, model)
resp.Diagnostics.Append(diags...)
if resp.Diagnostics.HasError() {
return
}
tflog.Info(ctx, "Load balancer read")
}
// Update updates the resource and sets the updated Terraform state on success.
func (r *projectResource) Update(_ context.Context, _ resource.UpdateRequest, _ *resource.UpdateResponse) { // nolint:gocritic // function signature required by Terraform
func (r *projectResource) Update(ctx context.Context, req resource.UpdateRequest, resp *resource.UpdateResponse) { // nolint:gocritic // function signature required by Terraform
// Retrieve values from plan
var model Model
diags := req.Plan.Get(ctx, &model)
resp.Diagnostics.Append(diags...)
if resp.Diagnostics.HasError() {
return
}
projectId := model.ProjectId.ValueString()
name := model.Name.ValueString()
ctx = tflog.SetField(ctx, "project_id", projectId)
ctx = tflog.SetField(ctx, "name", name)
for _, targetPool := range model.TargetPools {
// Generate API request body from model
payload, err := toTargetPoolUpdatePayload(ctx, utils.Ptr(targetPool))
if err != nil {
core.LogAndAddError(ctx, &resp.Diagnostics, "Error updating load balancer", fmt.Sprintf("Creating API payload: %v", err))
return
}
// Update target pool
_, err = r.client.UpdateTargetPool(ctx, projectId, name, targetPool.Name.ValueString()).UpdateTargetPoolPayload(*payload).Execute()
if err != nil {
core.LogAndAddError(ctx, &resp.Diagnostics, "Error updating load balancer", fmt.Sprintf("Calling API: %v", err))
return
}
}
// Get updated load balancer
getResp, err := r.client.GetLoadBalancer(ctx, projectId, name).Execute()
if err != nil {
core.LogAndAddError(ctx, &resp.Diagnostics, "Error updating load balancer", fmt.Sprintf("Calling API: %v", err))
return
}
// Map response body to schema
err = mapFields(ctx, getResp, &model)
if err != nil {
core.LogAndAddError(ctx, &resp.Diagnostics, "Error creating load balancer", fmt.Sprintf("Processing API payload: %v", err))
return
}
// Set state to fully populated data
diags = resp.State.Set(ctx, model)
resp.Diagnostics.Append(diags...)
if resp.Diagnostics.HasError() {
return
}
tflog.Info(ctx, "Load balancer updated")
}
// Delete deletes the resource and removes the Terraform state on success.
@ -535,38 +642,232 @@ func toTargetPoolsPayload(ctx context.Context, model *Model) (*[]loadbalancer.Ta
var targetPools []loadbalancer.TargetPool
for _, targetPool := range model.TargetPools {
var activeHealthCheck *loadbalancer.ActiveHealthCheck
if !(targetPool.ActiveHealthCheck.IsNull() || targetPool.ActiveHealthCheck.IsUnknown()) {
var activeHealthCheckModel ActiveHealthCheck
diags := targetPool.ActiveHealthCheck.As(ctx, &activeHealthCheckModel, basetypes.ObjectAsOptions{})
if diags.HasError() {
return nil, fmt.Errorf("converting active health check: %w", core.DiagsToError(diags))
}
activeHealthCheck = &loadbalancer.ActiveHealthCheck{
HealthyThreshold: activeHealthCheckModel.HealthyThreshold.ValueInt64Pointer(),
Interval: activeHealthCheckModel.Interval.ValueStringPointer(),
IntervalJitter: activeHealthCheckModel.IntervalJitter.ValueStringPointer(),
Timeout: activeHealthCheckModel.Timeout.ValueStringPointer(),
UnhealthyThreshold: activeHealthCheckModel.UnhealthyThreshold.ValueInt64Pointer(),
}
activeHealthCheck, err := toActiveHealthCheckPayload(ctx, utils.Ptr(targetPool))
if err != nil {
return nil, fmt.Errorf("converting target pool: %w", err)
}
var targets []loadbalancer.Target
for _, target := range targetPool.Targets {
targets = append(targets, loadbalancer.Target{
DisplayName: target.DisplayName.ValueStringPointer(),
Ip: target.Ip.ValueStringPointer(),
})
targets := toTargetsPayload(utils.Ptr(targetPool))
if err != nil {
return nil, fmt.Errorf("converting target pool: %w", err)
}
targetPools = append(targetPools, loadbalancer.TargetPool{
ActiveHealthCheck: activeHealthCheck,
Name: targetPool.Name.ValueStringPointer(),
TargetPort: targetPool.TargetPort.ValueInt64Pointer(),
Targets: &targets,
Targets: targets,
})
}
return &targetPools, nil
}
func toTargetPoolUpdatePayload(ctx context.Context, targetPool *TargetPool) (*loadbalancer.UpdateTargetPoolPayload, error) {
if targetPool == nil {
return nil, fmt.Errorf("nil target pool")
}
activeHealthCheck, err := toActiveHealthCheckPayload(ctx, targetPool)
if err != nil {
return nil, fmt.Errorf("converting target pool: %w", err)
}
targets := toTargetsPayload(targetPool)
return &loadbalancer.UpdateTargetPoolPayload{
ActiveHealthCheck: activeHealthCheck,
Name: targetPool.Name.ValueStringPointer(),
TargetPort: targetPool.TargetPort.ValueInt64Pointer(),
Targets: targets,
}, nil
}
func toActiveHealthCheckPayload(ctx context.Context, targetPool *TargetPool) (*loadbalancer.ActiveHealthCheck, error) {
if targetPool.ActiveHealthCheck.IsNull() {
return nil, nil
}
var activeHealthCheckModel ActiveHealthCheck
diags := targetPool.ActiveHealthCheck.As(ctx, &activeHealthCheckModel, basetypes.ObjectAsOptions{})
if diags.HasError() {
return nil, fmt.Errorf("converting active health check: %w", core.DiagsToError(diags))
}
return &loadbalancer.ActiveHealthCheck{
HealthyThreshold: activeHealthCheckModel.HealthyThreshold.ValueInt64Pointer(),
Interval: activeHealthCheckModel.Interval.ValueStringPointer(),
IntervalJitter: activeHealthCheckModel.IntervalJitter.ValueStringPointer(),
Timeout: activeHealthCheckModel.Timeout.ValueStringPointer(),
UnhealthyThreshold: activeHealthCheckModel.UnhealthyThreshold.ValueInt64Pointer(),
}, nil
}
func toTargetsPayload(targetPool *TargetPool) *[]loadbalancer.Target {
if targetPool.Targets == nil {
return nil
}
var targets []loadbalancer.Target
for _, target := range targetPool.Targets {
targets = append(targets, loadbalancer.Target{
DisplayName: target.DisplayName.ValueStringPointer(),
Ip: target.Ip.ValueStringPointer(),
})
}
return &targets
}
func mapFields(ctx context.Context, lb *loadbalancer.LoadBalancer, m *Model) error {
if lb == nil {
return fmt.Errorf("response input is nil")
}
if m == nil {
return fmt.Errorf("model input is nil")
}
var name string
if m.Name.ValueString() != "" {
name = m.Name.ValueString()
} else if lb.Name != nil {
name = *lb.Name
} else {
return fmt.Errorf("name not present")
}
m.Name = types.StringValue(name)
idParts := []string{
m.ProjectId.ValueString(),
name,
}
m.Id = types.StringValue(
strings.Join(idParts, core.Separator),
)
m.ExternalAddress = types.StringPointerValue(lb.ExternalAddress)
m.PrivateAddress = types.StringPointerValue(lb.PrivateAddress)
mapListeners(lb, m)
mapNetworks(lb, m)
err := mapOptions(ctx, lb, m)
if err != nil {
return fmt.Errorf("mapping options: %w", err)
}
err = mapTargetPools(lb, m)
if err != nil {
return fmt.Errorf("mapping target pools: %w", err)
}
return nil
}
func mapListeners(lb *loadbalancer.LoadBalancer, m *Model) {
if lb.Listeners == nil {
return
}
var listeners []Listener
for _, listener := range *lb.Listeners {
listeners = append(listeners, Listener{
DisplayName: types.StringPointerValue(listener.DisplayName),
Port: types.Int64PointerValue(listener.Port),
Protocol: types.StringPointerValue(listener.Protocol),
TargetPool: types.StringPointerValue(listener.TargetPool),
})
}
m.Listeners = listeners
}
func mapNetworks(lb *loadbalancer.LoadBalancer, m *Model) {
if lb.Networks == nil {
return
}
var networks []Network
for _, network := range *lb.Networks {
networks = append(networks, Network{
NetworkId: types.StringPointerValue(network.NetworkId),
Role: types.StringPointerValue(network.Role),
})
}
m.Networks = networks
}
func mapOptions(ctx context.Context, lb *loadbalancer.LoadBalancer, m *Model) error {
if lb.Options == nil {
return nil
}
var diags diag.Diagnostics
acl := types.ListNull(types.StringType)
if lb.Options.AccessControl != nil && lb.Options.AccessControl.AllowedSourceRanges != nil {
acl, diags = types.ListValueFrom(ctx, types.StringType, *lb.Options.AccessControl.AllowedSourceRanges)
if diags != nil {
return fmt.Errorf("converting acl: %w", core.DiagsToError(diags))
}
}
privateNetworkOnly := types.BoolNull()
if lb.Options.PrivateNetworkOnly != nil {
privateNetworkOnly = types.BoolValue(*lb.Options.PrivateNetworkOnly)
}
if acl.IsNull() && privateNetworkOnly.IsNull() {
return nil
}
optionsValues := map[string]attr.Value{
"acl": acl,
"private_network_only": privateNetworkOnly,
}
options, diags := types.ObjectValue(optionsTypes, optionsValues)
if diags != nil {
return fmt.Errorf("converting options: %w", core.DiagsToError(diags))
}
m.Options = options
return nil
}
func mapTargetPools(lb *loadbalancer.LoadBalancer, m *Model) error {
if lb.TargetPools == nil {
return nil
}
var diags diag.Diagnostics
var targetPools []TargetPool
for _, targetPool := range *lb.TargetPools {
var activeHealthCheck basetypes.ObjectValue
if targetPool.ActiveHealthCheck != nil {
activeHealthCheckValues := map[string]attr.Value{
"healthy_threshold": types.Int64Value(*targetPool.ActiveHealthCheck.HealthyThreshold),
"interval": types.StringValue(*targetPool.ActiveHealthCheck.Interval),
"interval_jitter": types.StringValue(*targetPool.ActiveHealthCheck.IntervalJitter),
"timeout": types.StringValue(*targetPool.ActiveHealthCheck.Timeout),
"unhealthy_threshold": types.Int64Value(*targetPool.ActiveHealthCheck.UnhealthyThreshold),
}
activeHealthCheck, diags = types.ObjectValue(activeHealthCheckTypes, activeHealthCheckValues)
if diags != nil {
return fmt.Errorf("converting active health check: %w", core.DiagsToError(diags))
}
}
var targets []Target
if targetPool.Targets != nil {
for _, target := range *targetPool.Targets {
targets = append(targets, Target{
DisplayName: types.StringPointerValue(target.DisplayName),
Ip: types.StringPointerValue(target.Ip),
})
}
}
targetPools = append(targetPools, TargetPool{
ActiveHealthCheck: activeHealthCheck,
Name: types.StringPointerValue(targetPool.Name),
TargetPort: types.Int64Value(*targetPool.TargetPort),
Targets: targets,
})
}
m.TargetPools = targetPools
return nil
}

View file

@ -165,3 +165,256 @@ func TestToCreatePayload(t *testing.T) {
})
}
}
func TestToTargetPoolUpdatePayload(t *testing.T) {
tests := []struct {
description string
input *TargetPool
expected *loadbalancer.UpdateTargetPoolPayload
isValid bool
}{
{
"default_values_ok",
&TargetPool{},
&loadbalancer.UpdateTargetPoolPayload{},
true,
},
{
"simple_values_ok",
&TargetPool{
ActiveHealthCheck: types.ObjectValueMust(
activeHealthCheckTypes,
map[string]attr.Value{
"healthy_threshold": types.Int64Value(1),
"interval": types.StringValue("2s"),
"interval_jitter": types.StringValue("3s"),
"timeout": types.StringValue("4s"),
"unhealthy_threshold": types.Int64Value(5),
},
),
Name: types.StringValue("name"),
TargetPort: types.Int64Value(80),
Targets: []Target{
{
DisplayName: types.StringValue("display_name"),
Ip: types.StringValue("ip"),
},
},
},
&loadbalancer.UpdateTargetPoolPayload{
ActiveHealthCheck: utils.Ptr(loadbalancer.ActiveHealthCheck{
HealthyThreshold: utils.Ptr(int64(1)),
Interval: utils.Ptr("2s"),
IntervalJitter: utils.Ptr("3s"),
Timeout: utils.Ptr("4s"),
UnhealthyThreshold: utils.Ptr(int64(5)),
}),
Name: utils.Ptr("name"),
TargetPort: utils.Ptr(int64(80)),
Targets: utils.Ptr([]loadbalancer.Target{
{
DisplayName: utils.Ptr("display_name"),
Ip: utils.Ptr("ip"),
},
}),
},
true,
},
{
"nil_target_pool",
nil,
nil,
false,
},
}
for _, tt := range tests {
t.Run(tt.description, func(t *testing.T) {
output, err := toTargetPoolUpdatePayload(context.Background(), tt.input)
if !tt.isValid && err == nil {
t.Fatalf("Should have failed")
}
if tt.isValid && err != nil {
t.Fatalf("Should not have failed: %v", err)
}
if tt.isValid {
diff := cmp.Diff(output, tt.expected)
if diff != "" {
t.Fatalf("Data does not match: %s", diff)
}
}
})
}
}
func TestMapFields(t *testing.T) {
tests := []struct {
description string
input *loadbalancer.LoadBalancer
expected *Model
isValid bool
}{
{
"default_values_ok",
&loadbalancer.LoadBalancer{
ExternalAddress: nil,
Listeners: nil,
Name: utils.Ptr("name"),
Networks: nil,
Options: &loadbalancer.LoadBalancerOptions{
AccessControl: &loadbalancer.LoadbalancerOptionAccessControl{
AllowedSourceRanges: nil,
},
PrivateNetworkOnly: nil,
},
TargetPools: nil,
},
&Model{
Id: types.StringValue("pid,name"),
ProjectId: types.StringValue("pid"),
Name: types.StringValue("name"),
},
true,
},
{
"simple_values_ok",
&loadbalancer.LoadBalancer{
ExternalAddress: utils.Ptr("external_address"),
Listeners: utils.Ptr([]loadbalancer.Listener{
{
DisplayName: utils.Ptr("display_name"),
Port: utils.Ptr(int64(80)),
Protocol: utils.Ptr("protocol"),
TargetPool: utils.Ptr("target_pool"),
},
}),
Name: utils.Ptr("name"),
Networks: utils.Ptr([]loadbalancer.Network{
{
NetworkId: utils.Ptr("network_id"),
Role: utils.Ptr("role"),
},
{
NetworkId: utils.Ptr("network_id_2"),
Role: utils.Ptr("role_2"),
},
}),
Options: utils.Ptr(loadbalancer.LoadBalancerOptions{
AccessControl: &loadbalancer.LoadbalancerOptionAccessControl{
AllowedSourceRanges: utils.Ptr([]string{"cidr"}),
},
PrivateNetworkOnly: utils.Ptr(true),
}),
TargetPools: utils.Ptr([]loadbalancer.TargetPool{
{
ActiveHealthCheck: utils.Ptr(loadbalancer.ActiveHealthCheck{
HealthyThreshold: utils.Ptr(int64(1)),
Interval: utils.Ptr("2s"),
IntervalJitter: utils.Ptr("3s"),
Timeout: utils.Ptr("4s"),
UnhealthyThreshold: utils.Ptr(int64(5)),
}),
Name: utils.Ptr("name"),
TargetPort: utils.Ptr(int64(80)),
Targets: utils.Ptr([]loadbalancer.Target{
{
DisplayName: utils.Ptr("display_name"),
Ip: utils.Ptr("ip"),
},
}),
},
}),
},
&Model{
Id: types.StringValue("pid,name"),
ProjectId: types.StringValue("pid"),
Name: types.StringValue("name"),
ExternalAddress: types.StringValue("external_address"),
Listeners: []Listener{
{
DisplayName: types.StringValue("display_name"),
Port: types.Int64Value(80),
Protocol: types.StringValue("protocol"),
TargetPool: types.StringValue("target_pool"),
},
},
Networks: []Network{
{
NetworkId: types.StringValue("network_id"),
Role: types.StringValue("role"),
},
{
NetworkId: types.StringValue("network_id_2"),
Role: types.StringValue("role_2"),
},
},
Options: types.ObjectValueMust(
optionsTypes,
map[string]attr.Value{
"acl": types.ListValueMust(
types.StringType,
[]attr.Value{types.StringValue("cidr")}),
"private_network_only": types.BoolValue(true),
},
),
TargetPools: []TargetPool{
{
ActiveHealthCheck: types.ObjectValueMust(
activeHealthCheckTypes,
map[string]attr.Value{
"healthy_threshold": types.Int64Value(1),
"interval": types.StringValue("2s"),
"interval_jitter": types.StringValue("3s"),
"timeout": types.StringValue("4s"),
"unhealthy_threshold": types.Int64Value(5),
},
),
Name: types.StringValue("name"),
TargetPort: types.Int64Value(80),
Targets: []Target{
{
DisplayName: types.StringValue("display_name"),
Ip: types.StringValue("ip"),
},
},
},
},
},
true,
},
{
"nil_response",
nil,
&Model{},
false,
},
{
"no_name",
&loadbalancer.LoadBalancer{},
&Model{},
false,
},
}
for _, tt := range tests {
t.Run(tt.description, func(t *testing.T) {
model := &Model{
ProjectId: tt.expected.ProjectId,
}
err := mapFields(context.Background(), tt.input, model)
if !tt.isValid && err == nil {
t.Fatalf("Should have failed")
}
if tt.isValid && err != nil {
t.Fatalf("Should not have failed: %v", err)
}
if tt.isValid {
diff := cmp.Diff(model, tt.expected)
if diff != "" {
t.Fatalf("Data does not match: %s", diff)
}
}
})
}
}