diff --git a/stackit/internal/services/loadbalancer/loadbalancer/resource.go b/stackit/internal/services/loadbalancer/loadbalancer/resource.go index 26eb8781..9fd83907 100644 --- a/stackit/internal/services/loadbalancer/loadbalancer/resource.go +++ b/stackit/internal/services/loadbalancer/loadbalancer/resource.go @@ -3,10 +3,12 @@ package loadbalancer import ( "context" "fmt" + "strings" "github.com/hashicorp/terraform-plugin-framework-validators/setvalidator" "github.com/hashicorp/terraform-plugin-framework-validators/stringvalidator" "github.com/hashicorp/terraform-plugin-framework/attr" + "github.com/hashicorp/terraform-plugin-framework/diag" "github.com/hashicorp/terraform-plugin-framework/resource" "github.com/hashicorp/terraform-plugin-framework/resource/schema" "github.com/hashicorp/terraform-plugin-framework/resource/schema/planmodifier" @@ -16,6 +18,7 @@ import ( "github.com/hashicorp/terraform-plugin-framework/types/basetypes" "github.com/hashicorp/terraform-plugin-log/tflog" "github.com/stackitcloud/stackit-sdk-go/core/config" + "github.com/stackitcloud/stackit-sdk-go/core/utils" "github.com/stackitcloud/stackit-sdk-go/services/loadbalancer" "github.com/stackitcloud/stackit-sdk-go/services/loadbalancer/wait" "github.com/stackitcloud/terraform-provider-stackit/stackit/internal/core" @@ -44,7 +47,6 @@ type Model struct { // Struct corresponding to each Model.Listener type Listener struct { DisplayName types.String `tfsdk:"display_name"` - Name types.String `tfsdk:"name"` Port types.Int64 `tfsdk:"port"` Protocol types.String `tfsdk:"protocol"` TargetPool types.String `tfsdk:"target_pool"` @@ -160,7 +162,6 @@ func (r *projectResource) Schema(_ context.Context, _ resource.SchemaRequest, re "project_id": "STACKIT project ID to which the Load Balancer is associated.", "external_address": "External Load Balancer IP address where this Load Balancer is exposed.", "listeners": "List of all listeners which will accept traffic. Limited to 20.", - "listeners.name": "Will be used to reference a listener and will replace display name in the future.", "port": "Port number where we listen for traffic.", "protocol": "Protocol is the highest network protocol we understand to load balance.", "target_pool": "Reference target pool by target pool name.", @@ -221,10 +222,6 @@ func (r *projectResource) Schema(_ context.Context, _ resource.SchemaRequest, re Optional: true, Computed: true, }, - "name": schema.StringAttribute{ - Description: descriptions["listeners.display_name"], - Computed: true, - }, "port": schema.Int64Attribute{ Description: descriptions["port"], Optional: true, @@ -398,6 +395,8 @@ func (r *projectResource) Create(ctx context.Context, req resource.CreateRequest core.LogAndAddError(ctx, &resp.Diagnostics, "Error getting status of load balancer functionality", fmt.Sprintf("Calling API: %v", err)) return } + + // If load balancer functionality is not enabled, enable it if *statusResp.Status != wait.FunctionalityStatusReady { _, err = r.client.EnableLoadBalancing(ctx, projectId).Execute() if err != nil { @@ -413,21 +412,129 @@ func (r *projectResource) Create(ctx context.Context, req resource.CreateRequest } // Generate API request body from model - _, err = toCreatePayload(ctx, &model) + payload, err := toCreatePayload(ctx, &model) if err != nil { - core.LogAndAddError(ctx, &resp.Diagnostics, "Error creating instance", fmt.Sprintf("Creating API payload: %v", err)) + core.LogAndAddError(ctx, &resp.Diagnostics, "Error creating load balancer", fmt.Sprintf("Creating API payload: %v", err)) return } + + // Create a new load balancer + createResp, err := r.client.CreateLoadBalancer(ctx, projectId).CreateLoadBalancerPayload(*payload).Execute() + if err != nil { + core.LogAndAddError(ctx, &resp.Diagnostics, "Error creating load balancer", fmt.Sprintf("Calling API: %v", err)) + return + } + + waitResp, err := wait.CreateLoadBalancerWaitHandler(ctx, r.client, projectId, *createResp.Name).WaitWithContext(ctx) + if err != nil { + core.LogAndAddError(ctx, &resp.Diagnostics, "Error creating load balancer", fmt.Sprintf("Load balancer creation waiting: %v", err)) + return + } + + // Map response body to schema + err = mapFields(ctx, waitResp, &model) + if err != nil { + core.LogAndAddError(ctx, &resp.Diagnostics, "Error creating load balancer", fmt.Sprintf("Processing API payload: %v", err)) + return + } + + // Set state to fully populated data + diags = resp.State.Set(ctx, model) + resp.Diagnostics.Append(diags...) + if resp.Diagnostics.HasError() { + return + } + + tflog.Info(ctx, "Load balancer created") } // Read refreshes the Terraform state with the latest data. -func (r *projectResource) Read(_ context.Context, _ resource.ReadRequest, _ *resource.ReadResponse) { // nolint:gocritic // function signature required by Terraform +func (r *projectResource) Read(ctx context.Context, req resource.ReadRequest, resp *resource.ReadResponse) { // nolint:gocritic // function signature required by Terraform + var model Model + diags := req.State.Get(ctx, &model) + resp.Diagnostics.Append(diags...) + if resp.Diagnostics.HasError() { + return + } + projectId := model.ProjectId.ValueString() + name := model.Name.ValueString() + ctx = tflog.SetField(ctx, "project_id", projectId) + ctx = tflog.SetField(ctx, "name", name) + lbResp, err := r.client.GetLoadBalancer(ctx, projectId, name).Execute() + if err != nil { + core.LogAndAddError(ctx, &resp.Diagnostics, "Error reading load balancer", err.Error()) + return + } + + // Map response body to schema + err = mapFields(ctx, lbResp, &model) + if err != nil { + core.LogAndAddError(ctx, &resp.Diagnostics, "Error reading load balancer", fmt.Sprintf("Processing API payload: %v", err)) + return + } + + // Set refreshed state + diags = resp.State.Set(ctx, model) + resp.Diagnostics.Append(diags...) + if resp.Diagnostics.HasError() { + return + } + tflog.Info(ctx, "Load balancer read") } // Update updates the resource and sets the updated Terraform state on success. -func (r *projectResource) Update(_ context.Context, _ resource.UpdateRequest, _ *resource.UpdateResponse) { // nolint:gocritic // function signature required by Terraform +func (r *projectResource) Update(ctx context.Context, req resource.UpdateRequest, resp *resource.UpdateResponse) { // nolint:gocritic // function signature required by Terraform + // Retrieve values from plan + var model Model + diags := req.Plan.Get(ctx, &model) + resp.Diagnostics.Append(diags...) + if resp.Diagnostics.HasError() { + return + } + projectId := model.ProjectId.ValueString() + name := model.Name.ValueString() + ctx = tflog.SetField(ctx, "project_id", projectId) + ctx = tflog.SetField(ctx, "name", name) + for _, targetPool := range model.TargetPools { + // Generate API request body from model + payload, err := toTargetPoolUpdatePayload(ctx, utils.Ptr(targetPool)) + if err != nil { + core.LogAndAddError(ctx, &resp.Diagnostics, "Error updating load balancer", fmt.Sprintf("Creating API payload: %v", err)) + return + } + + // Update target pool + _, err = r.client.UpdateTargetPool(ctx, projectId, name, targetPool.Name.ValueString()).UpdateTargetPoolPayload(*payload).Execute() + if err != nil { + core.LogAndAddError(ctx, &resp.Diagnostics, "Error updating load balancer", fmt.Sprintf("Calling API: %v", err)) + return + } + } + + // Get updated load balancer + getResp, err := r.client.GetLoadBalancer(ctx, projectId, name).Execute() + if err != nil { + core.LogAndAddError(ctx, &resp.Diagnostics, "Error updating load balancer", fmt.Sprintf("Calling API: %v", err)) + return + } + + // Map response body to schema + err = mapFields(ctx, getResp, &model) + if err != nil { + core.LogAndAddError(ctx, &resp.Diagnostics, "Error creating load balancer", fmt.Sprintf("Processing API payload: %v", err)) + return + } + + // Set state to fully populated data + diags = resp.State.Set(ctx, model) + resp.Diagnostics.Append(diags...) + if resp.Diagnostics.HasError() { + return + } + + tflog.Info(ctx, "Load balancer updated") } // Delete deletes the resource and removes the Terraform state on success. @@ -535,38 +642,232 @@ func toTargetPoolsPayload(ctx context.Context, model *Model) (*[]loadbalancer.Ta var targetPools []loadbalancer.TargetPool for _, targetPool := range model.TargetPools { - var activeHealthCheck *loadbalancer.ActiveHealthCheck - if !(targetPool.ActiveHealthCheck.IsNull() || targetPool.ActiveHealthCheck.IsUnknown()) { - var activeHealthCheckModel ActiveHealthCheck - diags := targetPool.ActiveHealthCheck.As(ctx, &activeHealthCheckModel, basetypes.ObjectAsOptions{}) - if diags.HasError() { - return nil, fmt.Errorf("converting active health check: %w", core.DiagsToError(diags)) - } - - activeHealthCheck = &loadbalancer.ActiveHealthCheck{ - HealthyThreshold: activeHealthCheckModel.HealthyThreshold.ValueInt64Pointer(), - Interval: activeHealthCheckModel.Interval.ValueStringPointer(), - IntervalJitter: activeHealthCheckModel.IntervalJitter.ValueStringPointer(), - Timeout: activeHealthCheckModel.Timeout.ValueStringPointer(), - UnhealthyThreshold: activeHealthCheckModel.UnhealthyThreshold.ValueInt64Pointer(), - } + activeHealthCheck, err := toActiveHealthCheckPayload(ctx, utils.Ptr(targetPool)) + if err != nil { + return nil, fmt.Errorf("converting target pool: %w", err) } - var targets []loadbalancer.Target - for _, target := range targetPool.Targets { - targets = append(targets, loadbalancer.Target{ - DisplayName: target.DisplayName.ValueStringPointer(), - Ip: target.Ip.ValueStringPointer(), - }) + targets := toTargetsPayload(utils.Ptr(targetPool)) + if err != nil { + return nil, fmt.Errorf("converting target pool: %w", err) } targetPools = append(targetPools, loadbalancer.TargetPool{ ActiveHealthCheck: activeHealthCheck, Name: targetPool.Name.ValueStringPointer(), TargetPort: targetPool.TargetPort.ValueInt64Pointer(), - Targets: &targets, + Targets: targets, }) } return &targetPools, nil } + +func toTargetPoolUpdatePayload(ctx context.Context, targetPool *TargetPool) (*loadbalancer.UpdateTargetPoolPayload, error) { + if targetPool == nil { + return nil, fmt.Errorf("nil target pool") + } + + activeHealthCheck, err := toActiveHealthCheckPayload(ctx, targetPool) + if err != nil { + return nil, fmt.Errorf("converting target pool: %w", err) + } + + targets := toTargetsPayload(targetPool) + + return &loadbalancer.UpdateTargetPoolPayload{ + ActiveHealthCheck: activeHealthCheck, + Name: targetPool.Name.ValueStringPointer(), + TargetPort: targetPool.TargetPort.ValueInt64Pointer(), + Targets: targets, + }, nil +} + +func toActiveHealthCheckPayload(ctx context.Context, targetPool *TargetPool) (*loadbalancer.ActiveHealthCheck, error) { + if targetPool.ActiveHealthCheck.IsNull() { + return nil, nil + } + + var activeHealthCheckModel ActiveHealthCheck + diags := targetPool.ActiveHealthCheck.As(ctx, &activeHealthCheckModel, basetypes.ObjectAsOptions{}) + if diags.HasError() { + return nil, fmt.Errorf("converting active health check: %w", core.DiagsToError(diags)) + } + + return &loadbalancer.ActiveHealthCheck{ + HealthyThreshold: activeHealthCheckModel.HealthyThreshold.ValueInt64Pointer(), + Interval: activeHealthCheckModel.Interval.ValueStringPointer(), + IntervalJitter: activeHealthCheckModel.IntervalJitter.ValueStringPointer(), + Timeout: activeHealthCheckModel.Timeout.ValueStringPointer(), + UnhealthyThreshold: activeHealthCheckModel.UnhealthyThreshold.ValueInt64Pointer(), + }, nil +} + +func toTargetsPayload(targetPool *TargetPool) *[]loadbalancer.Target { + if targetPool.Targets == nil { + return nil + } + + var targets []loadbalancer.Target + for _, target := range targetPool.Targets { + targets = append(targets, loadbalancer.Target{ + DisplayName: target.DisplayName.ValueStringPointer(), + Ip: target.Ip.ValueStringPointer(), + }) + } + + return &targets +} + +func mapFields(ctx context.Context, lb *loadbalancer.LoadBalancer, m *Model) error { + if lb == nil { + return fmt.Errorf("response input is nil") + } + if m == nil { + return fmt.Errorf("model input is nil") + } + + var name string + if m.Name.ValueString() != "" { + name = m.Name.ValueString() + } else if lb.Name != nil { + name = *lb.Name + } else { + return fmt.Errorf("name not present") + } + m.Name = types.StringValue(name) + idParts := []string{ + m.ProjectId.ValueString(), + name, + } + m.Id = types.StringValue( + strings.Join(idParts, core.Separator), + ) + + m.ExternalAddress = types.StringPointerValue(lb.ExternalAddress) + m.PrivateAddress = types.StringPointerValue(lb.PrivateAddress) + + mapListeners(lb, m) + mapNetworks(lb, m) + err := mapOptions(ctx, lb, m) + if err != nil { + return fmt.Errorf("mapping options: %w", err) + } + err = mapTargetPools(lb, m) + if err != nil { + return fmt.Errorf("mapping target pools: %w", err) + } + + return nil +} + +func mapListeners(lb *loadbalancer.LoadBalancer, m *Model) { + if lb.Listeners == nil { + return + } + + var listeners []Listener + for _, listener := range *lb.Listeners { + listeners = append(listeners, Listener{ + DisplayName: types.StringPointerValue(listener.DisplayName), + Port: types.Int64PointerValue(listener.Port), + Protocol: types.StringPointerValue(listener.Protocol), + TargetPool: types.StringPointerValue(listener.TargetPool), + }) + } + m.Listeners = listeners +} + +func mapNetworks(lb *loadbalancer.LoadBalancer, m *Model) { + if lb.Networks == nil { + return + } + + var networks []Network + for _, network := range *lb.Networks { + networks = append(networks, Network{ + NetworkId: types.StringPointerValue(network.NetworkId), + Role: types.StringPointerValue(network.Role), + }) + } + m.Networks = networks +} + +func mapOptions(ctx context.Context, lb *loadbalancer.LoadBalancer, m *Model) error { + if lb.Options == nil { + return nil + } + + var diags diag.Diagnostics + acl := types.ListNull(types.StringType) + if lb.Options.AccessControl != nil && lb.Options.AccessControl.AllowedSourceRanges != nil { + acl, diags = types.ListValueFrom(ctx, types.StringType, *lb.Options.AccessControl.AllowedSourceRanges) + if diags != nil { + return fmt.Errorf("converting acl: %w", core.DiagsToError(diags)) + } + } + privateNetworkOnly := types.BoolNull() + if lb.Options.PrivateNetworkOnly != nil { + privateNetworkOnly = types.BoolValue(*lb.Options.PrivateNetworkOnly) + } + if acl.IsNull() && privateNetworkOnly.IsNull() { + return nil + } + + optionsValues := map[string]attr.Value{ + "acl": acl, + "private_network_only": privateNetworkOnly, + } + options, diags := types.ObjectValue(optionsTypes, optionsValues) + if diags != nil { + return fmt.Errorf("converting options: %w", core.DiagsToError(diags)) + } + m.Options = options + + return nil +} + +func mapTargetPools(lb *loadbalancer.LoadBalancer, m *Model) error { + if lb.TargetPools == nil { + return nil + } + + var diags diag.Diagnostics + var targetPools []TargetPool + for _, targetPool := range *lb.TargetPools { + var activeHealthCheck basetypes.ObjectValue + if targetPool.ActiveHealthCheck != nil { + activeHealthCheckValues := map[string]attr.Value{ + "healthy_threshold": types.Int64Value(*targetPool.ActiveHealthCheck.HealthyThreshold), + "interval": types.StringValue(*targetPool.ActiveHealthCheck.Interval), + "interval_jitter": types.StringValue(*targetPool.ActiveHealthCheck.IntervalJitter), + "timeout": types.StringValue(*targetPool.ActiveHealthCheck.Timeout), + "unhealthy_threshold": types.Int64Value(*targetPool.ActiveHealthCheck.UnhealthyThreshold), + } + activeHealthCheck, diags = types.ObjectValue(activeHealthCheckTypes, activeHealthCheckValues) + if diags != nil { + return fmt.Errorf("converting active health check: %w", core.DiagsToError(diags)) + } + } + + var targets []Target + if targetPool.Targets != nil { + for _, target := range *targetPool.Targets { + targets = append(targets, Target{ + DisplayName: types.StringPointerValue(target.DisplayName), + Ip: types.StringPointerValue(target.Ip), + }) + } + } + + targetPools = append(targetPools, TargetPool{ + ActiveHealthCheck: activeHealthCheck, + Name: types.StringPointerValue(targetPool.Name), + TargetPort: types.Int64Value(*targetPool.TargetPort), + Targets: targets, + }) + } + m.TargetPools = targetPools + + return nil +} diff --git a/stackit/internal/services/loadbalancer/loadbalancer/resource_test.go b/stackit/internal/services/loadbalancer/loadbalancer/resource_test.go index 832fff8c..2be4f975 100644 --- a/stackit/internal/services/loadbalancer/loadbalancer/resource_test.go +++ b/stackit/internal/services/loadbalancer/loadbalancer/resource_test.go @@ -165,3 +165,256 @@ func TestToCreatePayload(t *testing.T) { }) } } + +func TestToTargetPoolUpdatePayload(t *testing.T) { + tests := []struct { + description string + input *TargetPool + expected *loadbalancer.UpdateTargetPoolPayload + isValid bool + }{ + { + "default_values_ok", + &TargetPool{}, + &loadbalancer.UpdateTargetPoolPayload{}, + true, + }, + { + "simple_values_ok", + &TargetPool{ + ActiveHealthCheck: types.ObjectValueMust( + activeHealthCheckTypes, + map[string]attr.Value{ + "healthy_threshold": types.Int64Value(1), + "interval": types.StringValue("2s"), + "interval_jitter": types.StringValue("3s"), + "timeout": types.StringValue("4s"), + "unhealthy_threshold": types.Int64Value(5), + }, + ), + Name: types.StringValue("name"), + TargetPort: types.Int64Value(80), + Targets: []Target{ + { + DisplayName: types.StringValue("display_name"), + Ip: types.StringValue("ip"), + }, + }, + }, + &loadbalancer.UpdateTargetPoolPayload{ + ActiveHealthCheck: utils.Ptr(loadbalancer.ActiveHealthCheck{ + HealthyThreshold: utils.Ptr(int64(1)), + Interval: utils.Ptr("2s"), + IntervalJitter: utils.Ptr("3s"), + Timeout: utils.Ptr("4s"), + UnhealthyThreshold: utils.Ptr(int64(5)), + }), + Name: utils.Ptr("name"), + TargetPort: utils.Ptr(int64(80)), + Targets: utils.Ptr([]loadbalancer.Target{ + { + DisplayName: utils.Ptr("display_name"), + Ip: utils.Ptr("ip"), + }, + }), + }, + true, + }, + { + "nil_target_pool", + nil, + nil, + false, + }, + } + for _, tt := range tests { + t.Run(tt.description, func(t *testing.T) { + output, err := toTargetPoolUpdatePayload(context.Background(), tt.input) + if !tt.isValid && err == nil { + t.Fatalf("Should have failed") + } + if tt.isValid && err != nil { + t.Fatalf("Should not have failed: %v", err) + } + if tt.isValid { + diff := cmp.Diff(output, tt.expected) + if diff != "" { + t.Fatalf("Data does not match: %s", diff) + } + } + }) + } +} + +func TestMapFields(t *testing.T) { + tests := []struct { + description string + input *loadbalancer.LoadBalancer + expected *Model + isValid bool + }{ + { + "default_values_ok", + &loadbalancer.LoadBalancer{ + ExternalAddress: nil, + Listeners: nil, + Name: utils.Ptr("name"), + Networks: nil, + Options: &loadbalancer.LoadBalancerOptions{ + AccessControl: &loadbalancer.LoadbalancerOptionAccessControl{ + AllowedSourceRanges: nil, + }, + PrivateNetworkOnly: nil, + }, + TargetPools: nil, + }, + &Model{ + Id: types.StringValue("pid,name"), + ProjectId: types.StringValue("pid"), + Name: types.StringValue("name"), + }, + true, + }, + + { + "simple_values_ok", + &loadbalancer.LoadBalancer{ + ExternalAddress: utils.Ptr("external_address"), + Listeners: utils.Ptr([]loadbalancer.Listener{ + { + DisplayName: utils.Ptr("display_name"), + Port: utils.Ptr(int64(80)), + Protocol: utils.Ptr("protocol"), + TargetPool: utils.Ptr("target_pool"), + }, + }), + Name: utils.Ptr("name"), + Networks: utils.Ptr([]loadbalancer.Network{ + { + NetworkId: utils.Ptr("network_id"), + Role: utils.Ptr("role"), + }, + { + NetworkId: utils.Ptr("network_id_2"), + Role: utils.Ptr("role_2"), + }, + }), + Options: utils.Ptr(loadbalancer.LoadBalancerOptions{ + AccessControl: &loadbalancer.LoadbalancerOptionAccessControl{ + AllowedSourceRanges: utils.Ptr([]string{"cidr"}), + }, + PrivateNetworkOnly: utils.Ptr(true), + }), + TargetPools: utils.Ptr([]loadbalancer.TargetPool{ + { + ActiveHealthCheck: utils.Ptr(loadbalancer.ActiveHealthCheck{ + HealthyThreshold: utils.Ptr(int64(1)), + Interval: utils.Ptr("2s"), + IntervalJitter: utils.Ptr("3s"), + Timeout: utils.Ptr("4s"), + UnhealthyThreshold: utils.Ptr(int64(5)), + }), + Name: utils.Ptr("name"), + TargetPort: utils.Ptr(int64(80)), + Targets: utils.Ptr([]loadbalancer.Target{ + { + DisplayName: utils.Ptr("display_name"), + Ip: utils.Ptr("ip"), + }, + }), + }, + }), + }, + &Model{ + Id: types.StringValue("pid,name"), + ProjectId: types.StringValue("pid"), + Name: types.StringValue("name"), + ExternalAddress: types.StringValue("external_address"), + Listeners: []Listener{ + { + DisplayName: types.StringValue("display_name"), + Port: types.Int64Value(80), + Protocol: types.StringValue("protocol"), + TargetPool: types.StringValue("target_pool"), + }, + }, + Networks: []Network{ + { + NetworkId: types.StringValue("network_id"), + Role: types.StringValue("role"), + }, + { + NetworkId: types.StringValue("network_id_2"), + Role: types.StringValue("role_2"), + }, + }, + Options: types.ObjectValueMust( + optionsTypes, + map[string]attr.Value{ + "acl": types.ListValueMust( + types.StringType, + + []attr.Value{types.StringValue("cidr")}), + "private_network_only": types.BoolValue(true), + }, + ), + TargetPools: []TargetPool{ + { + ActiveHealthCheck: types.ObjectValueMust( + activeHealthCheckTypes, + map[string]attr.Value{ + "healthy_threshold": types.Int64Value(1), + "interval": types.StringValue("2s"), + "interval_jitter": types.StringValue("3s"), + "timeout": types.StringValue("4s"), + + "unhealthy_threshold": types.Int64Value(5), + }, + ), + Name: types.StringValue("name"), + TargetPort: types.Int64Value(80), + Targets: []Target{ + { + DisplayName: types.StringValue("display_name"), + Ip: types.StringValue("ip"), + }, + }, + }, + }, + }, + true, + }, + { + "nil_response", + nil, + &Model{}, + false, + }, + { + "no_name", + &loadbalancer.LoadBalancer{}, + &Model{}, + false, + }, + } + for _, tt := range tests { + t.Run(tt.description, func(t *testing.T) { + model := &Model{ + ProjectId: tt.expected.ProjectId, + } + err := mapFields(context.Background(), tt.input, model) + if !tt.isValid && err == nil { + t.Fatalf("Should have failed") + } + if tt.isValid && err != nil { + t.Fatalf("Should not have failed: %v", err) + } + if tt.isValid { + diff := cmp.Diff(model, tt.expected) + if diff != "" { + t.Fatalf("Data does not match: %s", diff) + } + } + }) + } +}