diff --git a/pkg/sqlserverflexalpha/wait/wait.go b/pkg/sqlserverflexalpha/wait/wait.go new file mode 100644 index 00000000..5f1563be --- /dev/null +++ b/pkg/sqlserverflexalpha/wait/wait.go @@ -0,0 +1,101 @@ +package wait + +import ( + "context" + "errors" + "fmt" + "net/http" + "strings" + "time" + + "github.com/stackitcloud/stackit-sdk-go/core/oapierror" + "github.com/stackitcloud/stackit-sdk-go/core/wait" + "github.com/stackitcloud/stackit-sdk-go/services/sqlserverflex" +) + +const ( + InstanceStateEmpty = "" + InstanceStateProcessing = "Progressing" + InstanceStateUnknown = "Unknown" + InstanceStateSuccess = "Ready" + InstanceStateFailed = "Failed" +) + +// Interface needed for tests +type APIClientInstanceInterface interface { + GetInstanceExecute(ctx context.Context, projectId, instanceId, region string) (*sqlserverflex.GetInstanceResponse, error) +} + +// CreateInstanceWaitHandler will wait for instance creation +func CreateInstanceWaitHandler(ctx context.Context, a APIClientInstanceInterface, projectId, instanceId, region string) *wait.AsyncActionHandler[sqlserverflex.GetInstanceResponse] { + handler := wait.New(func() (waitFinished bool, response *sqlserverflex.GetInstanceResponse, err error) { + s, err := a.GetInstanceExecute(ctx, projectId, instanceId, region) + if err != nil { + return false, nil, err + } + if s == nil || s.Item == nil || s.Item.Id == nil || *s.Item.Id != instanceId || s.Item.Status == nil { + return false, nil, nil + } + switch strings.ToLower(*s.Item.Status) { + case strings.ToLower(InstanceStateSuccess): + return true, s, nil + case strings.ToLower(InstanceStateUnknown), strings.ToLower(InstanceStateFailed): + return true, s, fmt.Errorf("create failed for instance with id %s", instanceId) + default: + return false, s, nil + } + }) + handler.SetTimeout(45 * time.Minute) + handler.SetSleepBeforeWait(5 * time.Second) + return handler +} + +// UpdateInstanceWaitHandler will wait for instance update +func UpdateInstanceWaitHandler(ctx context.Context, a APIClientInstanceInterface, projectId, instanceId, region string) *wait.AsyncActionHandler[sqlserverflex.GetInstanceResponse] { + handler := wait.New(func() (waitFinished bool, response *sqlserverflex.GetInstanceResponse, err error) { + s, err := a.GetInstanceExecute(ctx, projectId, instanceId, region) + if err != nil { + return false, nil, err + } + if s == nil || s.Item == nil || s.Item.Id == nil || *s.Item.Id != instanceId || s.Item.Status == nil { + return false, nil, nil + } + switch strings.ToLower(*s.Item.Status) { + case strings.ToLower(InstanceStateSuccess): + return true, s, nil + case strings.ToLower(InstanceStateUnknown), strings.ToLower(InstanceStateFailed): + return true, s, fmt.Errorf("update failed for instance with id %s", instanceId) + default: + return false, s, nil + } + }) + handler.SetSleepBeforeWait(2 * time.Second) + handler.SetTimeout(45 * time.Minute) + return handler +} + +// PartialUpdateInstanceWaitHandler will wait for instance update +func PartialUpdateInstanceWaitHandler(ctx context.Context, a APIClientInstanceInterface, projectId, instanceId, region string) *wait.AsyncActionHandler[sqlserverflex.GetInstanceResponse] { + return UpdateInstanceWaitHandler(ctx, a, projectId, instanceId, region) +} + +// DeleteInstanceWaitHandler will wait for instance deletion +func DeleteInstanceWaitHandler(ctx context.Context, a APIClientInstanceInterface, projectId, instanceId, region string) *wait.AsyncActionHandler[struct{}] { + handler := wait.New(func() (waitFinished bool, response *struct{}, err error) { + _, err = a.GetInstanceExecute(ctx, projectId, instanceId, region) + if err == nil { + return false, nil, nil + } + var oapiErr *oapierror.GenericOpenAPIError + ok := errors.As(err, &oapiErr) + if !ok { + return false, nil, fmt.Errorf("could not convert error to oapierror.GenericOpenAPIError") + } + if oapiErr.StatusCode != http.StatusNotFound { + return false, nil, err + } + return true, nil, nil + }) + handler.SetTimeout(15 * time.Minute) + return handler +} diff --git a/pkg/sqlserverflexalpha/wait/wait_test.go b/pkg/sqlserverflexalpha/wait/wait_test.go new file mode 100644 index 00000000..737f4e89 --- /dev/null +++ b/pkg/sqlserverflexalpha/wait/wait_test.go @@ -0,0 +1,242 @@ +package wait + +import ( + "context" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/stackitcloud/stackit-sdk-go/core/oapierror" + "github.com/stackitcloud/stackit-sdk-go/core/utils" + "github.com/stackitcloud/stackit-sdk-go/services/sqlserverflex" +) + +// Used for testing instance operations +type apiClientInstanceMocked struct { + instanceId string + instanceState string + instanceIsDeleted bool + instanceGetFails bool +} + +func (a *apiClientInstanceMocked) GetInstanceExecute(_ context.Context, _, _, _ string) (*sqlserverflex.GetInstanceResponse, error) { + if a.instanceGetFails { + return nil, &oapierror.GenericOpenAPIError{ + StatusCode: 500, + } + } + + if a.instanceIsDeleted { + return nil, &oapierror.GenericOpenAPIError{ + StatusCode: 404, + } + } + + return &sqlserverflex.GetInstanceResponse{ + Item: &sqlserverflex.Instance{ + Id: &a.instanceId, + Status: &a.instanceState, + }, + }, nil +} +func TestCreateInstanceWaitHandler(t *testing.T) { + tests := []struct { + desc string + instanceGetFails bool + instanceState string + usersGetErrorStatus int + wantErr bool + wantResp bool + }{ + { + desc: "create_succeeded", + instanceGetFails: false, + instanceState: InstanceStateSuccess, + wantErr: false, + wantResp: true, + }, + { + desc: "create_failed", + instanceGetFails: false, + instanceState: InstanceStateFailed, + wantErr: true, + wantResp: true, + }, + { + desc: "create_failed_2", + instanceGetFails: false, + instanceState: InstanceStateEmpty, + wantErr: true, + wantResp: true, + }, + { + desc: "instance_get_fails", + instanceGetFails: true, + wantErr: true, + wantResp: false, + }, + { + desc: "timeout", + instanceGetFails: false, + instanceState: InstanceStateProcessing, + wantErr: true, + wantResp: true, + }, + } + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + instanceId := "foo-bar" + + apiClient := &apiClientInstanceMocked{ + instanceId: instanceId, + instanceState: tt.instanceState, + instanceGetFails: tt.instanceGetFails, + } + + var wantRes *sqlserverflex.GetInstanceResponse + if tt.wantResp { + wantRes = &sqlserverflex.GetInstanceResponse{ + Item: &sqlserverflex.Instance{ + Id: &instanceId, + Status: utils.Ptr(tt.instanceState), + }, + } + } + + handler := CreateInstanceWaitHandler(context.Background(), apiClient, "", instanceId, "") + + gotRes, err := handler.SetTimeout(10 * time.Millisecond).SetSleepBeforeWait(1 * time.Millisecond).WaitWithContext(context.Background()) + + if (err != nil) != tt.wantErr { + t.Fatalf("handler error = %v, wantErr %v", err, tt.wantErr) + } + if !cmp.Equal(gotRes, wantRes) { + t.Fatalf("handler gotRes = %v, want %v", gotRes, wantRes) + } + }) + } +} + +func TestUpdateInstanceWaitHandler(t *testing.T) { + tests := []struct { + desc string + instanceGetFails bool + instanceState string + wantErr bool + wantResp bool + }{ + { + desc: "update_succeeded", + instanceGetFails: false, + instanceState: InstanceStateSuccess, + wantErr: false, + wantResp: true, + }, + { + desc: "update_failed", + instanceGetFails: false, + instanceState: InstanceStateFailed, + wantErr: true, + wantResp: true, + }, + { + desc: "update_failed_2", + instanceGetFails: false, + instanceState: InstanceStateEmpty, + wantErr: true, + wantResp: true, + }, + { + desc: "get_fails", + instanceGetFails: true, + wantErr: true, + wantResp: false, + }, + { + desc: "timeout", + instanceGetFails: false, + instanceState: InstanceStateProcessing, + wantErr: true, + wantResp: true, + }, + } + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + instanceId := "foo-bar" + + apiClient := &apiClientInstanceMocked{ + instanceId: instanceId, + instanceState: tt.instanceState, + instanceGetFails: tt.instanceGetFails, + } + + var wantRes *sqlserverflex.GetInstanceResponse + if tt.wantResp { + wantRes = &sqlserverflex.GetInstanceResponse{ + Item: &sqlserverflex.Instance{ + Id: &instanceId, + Status: utils.Ptr(tt.instanceState), + }, + } + } + + handler := UpdateInstanceWaitHandler(context.Background(), apiClient, "", instanceId, "") + + gotRes, err := handler.SetTimeout(10 * time.Millisecond).SetSleepBeforeWait(1 * time.Millisecond).WaitWithContext(context.Background()) + + if (err != nil) != tt.wantErr { + t.Fatalf("handler error = %v, wantErr %v", err, tt.wantErr) + } + if !cmp.Equal(gotRes, wantRes) { + t.Fatalf("handler gotRes = %v, want %v", gotRes, wantRes) + } + }) + } +} + +func TestDeleteInstanceWaitHandler(t *testing.T) { + tests := []struct { + desc string + instanceGetFails bool + instanceState string + wantErr bool + }{ + { + desc: "delete_succeeded", + instanceGetFails: false, + instanceState: InstanceStateSuccess, + wantErr: false, + }, + { + desc: "delete_failed", + instanceGetFails: false, + instanceState: InstanceStateFailed, + wantErr: true, + }, + { + desc: "get_fails", + instanceGetFails: true, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + instanceId := "foo-bar" + + apiClient := &apiClientInstanceMocked{ + instanceGetFails: tt.instanceGetFails, + instanceIsDeleted: tt.instanceState == InstanceStateSuccess, + instanceId: instanceId, + instanceState: tt.instanceState, + } + + handler := DeleteInstanceWaitHandler(context.Background(), apiClient, "", instanceId, "") + + _, err := handler.SetTimeout(10 * time.Millisecond).WaitWithContext(context.Background()) + + if (err != nil) != tt.wantErr { + t.Fatalf("handler error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +}