diff --git a/stackit/internal/core/retry_round_tripper.go b/stackit/internal/core/retry_round_tripper.go new file mode 100644 index 00000000..945a3061 --- /dev/null +++ b/stackit/internal/core/retry_round_tripper.go @@ -0,0 +1,239 @@ +package core + +import ( + "context" + "crypto/rand" + "errors" + "fmt" + "math/big" + "net/http" + "time" + + "github.com/hashicorp/terraform-plugin-log/tflog" +) + +const ( + // backoffMultiplier is the factor by which the delay is multiplied for exponential backoff. + backoffMultiplier = 2 + // jitterFactor is the divisor used to calculate jitter (e.g., half of the base delay). + jitterFactor = 2 +) + +var ( + // ErrRequestFailedAfterRetries is returned when a request fails after all retry attempts. + ErrRequestFailedAfterRetries = errors.New("request failed after all retry attempts") +) + +// RetryRoundTripper implements an http.RoundTripper that adds automatic retry logic for failed requests. +type RetryRoundTripper struct { + next http.RoundTripper + maxRetries int + initialDelay time.Duration + maxDelay time.Duration + perTryTimeout time.Duration +} + +// NewRetryRoundTripper creates a new instance of the RetryRoundTripper with the specified configuration. +func NewRetryRoundTripper( + next http.RoundTripper, + maxRetries int, + initialDelay, maxDelay, perTryTimeout time.Duration, +) *RetryRoundTripper { + return &RetryRoundTripper{ + next: next, + maxRetries: maxRetries, + initialDelay: initialDelay, + maxDelay: maxDelay, + perTryTimeout: perTryTimeout, + } +} + +// RoundTrip executes the request and retries on failure. +func (rrt *RetryRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + resp, err := rrt.executeRequest(req) + if !rrt.shouldRetry(resp, err) { + if err != nil { + return resp, fmt.Errorf("initial request failed, not retrying: %w", err) + } + + return resp, nil + } + + return rrt.retryLoop(req, resp, err) +} + +// executeRequest performs a single HTTP request with a per-try timeout. +func (rrt *RetryRoundTripper) executeRequest(req *http.Request) (*http.Response, error) { + ctx, cancel := context.WithTimeout(req.Context(), rrt.perTryTimeout) + defer cancel() + + resp, err := rrt.next.RoundTrip(req.WithContext(ctx)) + if err != nil { + if errors.Is(err, context.DeadlineExceeded) { + return resp, fmt.Errorf("per-try timeout of %v exceeded: %w", rrt.perTryTimeout, err) + } + + return resp, fmt.Errorf("http roundtrip failed: %w", err) + } + + return resp, nil +} + +// retryLoop handles the retry logic for a failed request. +func (rrt *RetryRoundTripper) retryLoop( + req *http.Request, + initialResp *http.Response, + initialErr error, +) (*http.Response, error) { + var ( + lastErr = initialErr + resp = initialResp + currentDelay = rrt.initialDelay + ) + + ctx := req.Context() + + for attempt := 1; attempt <= rrt.maxRetries; attempt++ { + rrt.logRetryAttempt(ctx, attempt, currentDelay, lastErr) + + waitDuration := rrt.calculateWaitDurationWithJitter(ctx, currentDelay) + if err := rrt.waitForDelay(ctx, waitDuration); err != nil { + return nil, err // Context was cancelled during wait. + } + + // Exponential backoff for the next potential retry. + currentDelay = rrt.updateCurrentDelay(currentDelay) + + // Retry attempt. + resp, lastErr = rrt.executeRequest(req) + if !rrt.shouldRetry(resp, lastErr) { + if lastErr != nil { + return resp, fmt.Errorf("request failed on retry attempt %d: %w", attempt, lastErr) + } + + return resp, nil + } + } + + return nil, rrt.handleFinalError(ctx, resp, lastErr) +} + +// logRetryAttempt logs the details of a retry attempt. +func (rrt *RetryRoundTripper) logRetryAttempt( + ctx context.Context, + attempt int, + delay time.Duration, + err error, +) { + tflog.Info( + ctx, "Request failed, retrying...", map[string]interface{}{ + "attempt": attempt, + "max_attempts": rrt.maxRetries, + "delay": delay, + "error": err, + }, + ) +} + +// updateCurrentDelay calculates the next delay for exponential backoff. +func (rrt *RetryRoundTripper) updateCurrentDelay(currentDelay time.Duration) time.Duration { + currentDelay *= backoffMultiplier + if currentDelay > rrt.maxDelay { + return rrt.maxDelay + } + + return currentDelay +} + +// handleFinalError constructs and returns the final error after all retries have been exhausted. +func (rrt *RetryRoundTripper) handleFinalError( + ctx context.Context, + resp *http.Response, + lastErr error, +) error { + if resp != nil { + if err := resp.Body.Close(); err != nil { + + tflog.Warn( + ctx, "Failed to close response body", map[string]interface{}{ + "error": err.Error(), + }, + ) + } + } + + if lastErr != nil { + return fmt.Errorf("%w: %w", ErrRequestFailedAfterRetries, lastErr) + } + + // This case occurs if shouldRetry was true due to a retryable status code, + // but all retries failed with similar status codes. + if resp != nil { + return fmt.Errorf( + "%w: last retry attempt failed with status code %d", + ErrRequestFailedAfterRetries, + resp.StatusCode, + ) + } + + return fmt.Errorf("%w: no response received", ErrRequestFailedAfterRetries) +} + +// shouldRetry determines if a request should be retried based on the response or an error. +func (rrt *RetryRoundTripper) shouldRetry(resp *http.Response, err error) bool { + if err != nil { + return true + } + + if resp != nil { + if resp.StatusCode == http.StatusBadGateway || + resp.StatusCode == http.StatusServiceUnavailable || + resp.StatusCode == http.StatusGatewayTimeout { + return true + } + } + + return false + +} + +// calculateWaitDurationWithJitter calculates the backoff duration for the next retry, +// adding a random jitter to prevent thundering herd issues. +func (rrt *RetryRoundTripper) calculateWaitDurationWithJitter( + ctx context.Context, + baseDelay time.Duration, +) time.Duration { + if baseDelay <= 0 { + return 0 + } + + maxJitter := int64(baseDelay / jitterFactor) + if maxJitter <= 0 { + return baseDelay + } + + random, err := rand.Int(rand.Reader, big.NewInt(maxJitter)) + if err != nil { + tflog.Warn( + ctx, "Failed to generate random jitter, proceeding without it.", map[string]interface{}{ + "error": err.Error(), + }, + ) + + return baseDelay + } + + jitter := time.Duration(random.Int64()) + + return baseDelay + jitter +} + +// waitForDelay pauses execution for a given duration or until the context is canceled. +func (rrt *RetryRoundTripper) waitForDelay(ctx context.Context, delay time.Duration) error { + select { + case <-ctx.Done(): + return fmt.Errorf("context cancelled during backoff wait: %w", ctx.Err()) + case <-time.After(delay): + return nil + } +} diff --git a/stackit/internal/core/retry_round_tripper_test.go b/stackit/internal/core/retry_round_tripper_test.go new file mode 100644 index 00000000..e19c553c --- /dev/null +++ b/stackit/internal/core/retry_round_tripper_test.go @@ -0,0 +1,252 @@ +package core + +import ( + "context" + "errors" + "io" + "net/http" + "net/http/httptest" + "strings" + "sync/atomic" + "testing" + "time" +) + +type mockRoundTripper struct { + roundTripFunc func(req *http.Request) (*http.Response, error) + callCount int32 +} + +func (m *mockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + atomic.AddInt32(&m.callCount, 1) + + return m.roundTripFunc(req) +} + +func (m *mockRoundTripper) CallCount() int32 { + return atomic.LoadInt32(&m.callCount) +} + +func TestRetryRoundTripper_RoundTrip(t *testing.T) { + t.Parallel() + + testRetryConfig := func(next http.RoundTripper) *RetryRoundTripper { + return NewRetryRoundTripper( + next, + 3, + 1*time.Millisecond, + 10*time.Millisecond, + 50*time.Millisecond, + ) + } + + noRetryTests := []struct { + name string + mockStatusCode int + expectedStatusCode int + }{ + { + name: "should succeed on the first try", + mockStatusCode: http.StatusOK, + expectedStatusCode: http.StatusOK, + }, + { + name: "should not retry on a non-retryable status code like 400", + mockStatusCode: http.StatusBadRequest, + expectedStatusCode: http.StatusBadRequest, + }, + } + + for _, testCase := range noRetryTests { + t.Run( + testCase.name, func(t *testing.T) { + t.Parallel() + + mock := &mockRoundTripper{ + roundTripFunc: func(req *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: testCase.mockStatusCode, + Body: io.NopCloser(nil), + Request: req, + }, nil + }, + } + tripper := testRetryConfig(mock) + req := httptest.NewRequest(http.MethodGet, "/", nil) + + resp, err := tripper.RoundTrip(req) + if resp != nil { + defer func() { + if closeErr := resp.Body.Close(); closeErr != nil { + t.Errorf("failed to close response body: %v", closeErr) + } + }() + } + + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if resp.StatusCode != testCase.expectedStatusCode { + t.Fatalf("expected status code %d, got %d", testCase.expectedStatusCode, resp.StatusCode) + } + if mock.CallCount() != 1 { + t.Fatalf("expected 1 call, got %d", mock.CallCount()) + } + }, + ) + } + + t.Run( + "should retry on retryable status code (503) and eventually fail", func(t *testing.T) { + t.Parallel() + + mock := &mockRoundTripper{ + roundTripFunc: func(req *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusServiceUnavailable, + Body: io.NopCloser(nil), + Request: req, + }, nil + }, + } + tripper := testRetryConfig(mock) + req := httptest.NewRequest(http.MethodGet, "/", nil) + + resp, err := tripper.RoundTrip(req) + if resp != nil { + defer func() { + if closeErr := resp.Body.Close(); closeErr != nil { + t.Errorf("failed to close response body: %v", closeErr) + } + }() + } + + if err == nil { + t.Fatal("expected an error, but got nil") + } + expectedErrorMsg := "last retry attempt failed with status code 503" + if !strings.Contains(err.Error(), expectedErrorMsg) { + t.Fatalf("expected error to contain %q, got %q", expectedErrorMsg, err.Error()) + } + if mock.CallCount() != 4 { // 1 initial + 3 retries + t.Fatalf("expected 4 calls, got %d", mock.CallCount()) + } + }, + ) + + t.Run( + "should succeed after one retry", func(t *testing.T) { + t.Parallel() + + mock := &mockRoundTripper{} + mock.roundTripFunc = func(req *http.Request) (*http.Response, error) { + if mock.CallCount() < 2 { + return &http.Response{ + StatusCode: http.StatusServiceUnavailable, + Body: io.NopCloser(nil), + Request: req, + }, nil + } + + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(nil), + Request: req, + }, nil + } + tripper := testRetryConfig(mock) + req := httptest.NewRequest(http.MethodGet, "/", nil) + + resp, err := tripper.RoundTrip(req) + if resp != nil { + defer func() { + if closeErr := resp.Body.Close(); closeErr != nil { + t.Errorf("failed to close response body: %v", closeErr) + } + }() + } + + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected status code %d, got %d", http.StatusOK, resp.StatusCode) + } + if mock.CallCount() != 2 { + t.Fatalf("expected 2 calls, got %d", mock.CallCount()) + } + }, + ) + + t.Run( + "should retry on network error", func(t *testing.T) { + t.Parallel() + + mockErr := errors.New("simulated network error") + + mock := &mockRoundTripper{ + roundTripFunc: func(req *http.Request) (*http.Response, error) { + return nil, mockErr + }, + } + tripper := testRetryConfig(mock) + req := httptest.NewRequest(http.MethodGet, "/", nil) + + resp, err := tripper.RoundTrip(req) + if resp != nil { + defer func() { + if closeErr := resp.Body.Close(); closeErr != nil { + t.Errorf("failed to close response body: %v", closeErr) + } + }() + } + + if !errors.Is(err, mockErr) { + t.Fatalf("expected error to be %v, got %v", mockErr, err) + } + if mock.CallCount() != 4 { // 1 initial + 3 retries + t.Fatalf("expected 4 calls, got %d", mock.CallCount()) + } + }, + ) + + t.Run( + "should abort retries if the main context is cancelled", func(t *testing.T) { + t.Parallel() + + mock := &mockRoundTripper{ + roundTripFunc: func(req *http.Request) (*http.Response, error) { + select { + case <-time.After(100 * time.Millisecond): + return nil, errors.New("this should not be returned") + case <-req.Context().Done(): + return nil, req.Context().Err() + } + }, + } + tripper := testRetryConfig(mock) + baseCtx := context.Background() + + ctx, cancel := context.WithTimeout(baseCtx, 20*time.Millisecond) + defer cancel() + + req := httptest.NewRequest(http.MethodGet, "/", nil).WithContext(ctx) + + resp, err := tripper.RoundTrip(req) + if resp != nil { + defer func() { + if closeErr := resp.Body.Close(); closeErr != nil { + t.Errorf("failed to close response body: %v", closeErr) + } + }() + } + + if !errors.Is(err, context.DeadlineExceeded) { + t.Fatalf("expected error to be context.DeadlineExceeded, got %v", err) + } + if mock.CallCount() != 1 { + t.Fatalf("expected 1 call, got %d", mock.CallCount()) + } + }, + ) +} diff --git a/stackit/provider.go b/stackit/provider.go index b90d7375..ab3dd060 100644 --- a/stackit/provider.go +++ b/stackit/provider.go @@ -6,6 +6,7 @@ import ( "context" "fmt" "strings" + "time" "github.com/hashicorp/terraform-plugin-framework-validators/stringvalidator" "github.com/hashicorp/terraform-plugin-framework/datasource" @@ -45,6 +46,17 @@ var ( _ provider.Provider = &Provider{} ) +const ( + // maxRetries is the maximum number of retries for a failed HTTP request. + maxRetries = 3 + // initialDelay is the initial delay before the first retry attempt. + initialDelay = 2 * time.Second + // maxDelay is the maximum delay between retry attempts. + maxDelay = 90 * time.Second + // perTryTimeout is the timeout for each individual HTTP request attempt. + perTryTimeout = 30 * time.Second +) + // Provider is the provider implementation. type Provider struct { version string @@ -466,7 +478,7 @@ func (p *Provider) Configure(ctx context.Context, req provider.ConfigureRequest, providerData.Experiments = experimentValues } - roundTripper, err := sdkauth.SetupAuth(sdkConfig) + baseRoundTripper, err := sdkauth.SetupAuth(sdkConfig) if err != nil { core.LogAndAddError( ctx, @@ -477,6 +489,14 @@ func (p *Provider) Configure(ctx context.Context, req provider.ConfigureRequest, return } + roundTripper := core.NewRetryRoundTripper( + baseRoundTripper, + maxRetries, + initialDelay, + maxDelay, + perTryTimeout, + ) + // Make round tripper and custom endpoints available during DataSource and Resource // type Configure methods. providerData.RoundTripper = roundTripper