From bbdb85adfe45b1b736b377b4992f54763401c50a Mon Sep 17 00:00:00 2001 From: Andre Harms Date: Mon, 16 Feb 2026 09:54:44 +0100 Subject: [PATCH] fix: remove unused delay field from RetryRoundTripper and update error handling --- stackit/internal/core/retry_round_tripper.go | 6 +- .../internal/core/retry_round_tripper_test.go | 252 ++++++++++++++++++ 2 files changed, 254 insertions(+), 4 deletions(-) create mode 100644 stackit/internal/core/retry_round_tripper_test.go diff --git a/stackit/internal/core/retry_round_tripper.go b/stackit/internal/core/retry_round_tripper.go index 8495b803..945a3061 100644 --- a/stackit/internal/core/retry_round_tripper.go +++ b/stackit/internal/core/retry_round_tripper.go @@ -28,7 +28,6 @@ var ( type RetryRoundTripper struct { next http.RoundTripper maxRetries int - delay time.Duration initialDelay time.Duration maxDelay time.Duration perTryTimeout time.Duration @@ -38,13 +37,12 @@ type RetryRoundTripper struct { func NewRetryRoundTripper( next http.RoundTripper, maxRetries int, - initialDelay, delay, maxDelay, perTryTimeout time.Duration, + initialDelay, maxDelay, perTryTimeout time.Duration, ) *RetryRoundTripper { return &RetryRoundTripper{ next: next, maxRetries: maxRetries, initialDelay: initialDelay, - delay: delay, maxDelay: maxDelay, perTryTimeout: perTryTimeout, } @@ -132,7 +130,7 @@ func (rrt *RetryRoundTripper) logRetryAttempt( "attempt": attempt, "max_attempts": rrt.maxRetries, "delay": delay, - "error": err.Error(), + "error": err, }, ) } diff --git a/stackit/internal/core/retry_round_tripper_test.go b/stackit/internal/core/retry_round_tripper_test.go new file mode 100644 index 00000000..e19c553c --- /dev/null +++ b/stackit/internal/core/retry_round_tripper_test.go @@ -0,0 +1,252 @@ +package core + +import ( + "context" + "errors" + "io" + "net/http" + "net/http/httptest" + "strings" + "sync/atomic" + "testing" + "time" +) + +type mockRoundTripper struct { + roundTripFunc func(req *http.Request) (*http.Response, error) + callCount int32 +} + +func (m *mockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + atomic.AddInt32(&m.callCount, 1) + + return m.roundTripFunc(req) +} + +func (m *mockRoundTripper) CallCount() int32 { + return atomic.LoadInt32(&m.callCount) +} + +func TestRetryRoundTripper_RoundTrip(t *testing.T) { + t.Parallel() + + testRetryConfig := func(next http.RoundTripper) *RetryRoundTripper { + return NewRetryRoundTripper( + next, + 3, + 1*time.Millisecond, + 10*time.Millisecond, + 50*time.Millisecond, + ) + } + + noRetryTests := []struct { + name string + mockStatusCode int + expectedStatusCode int + }{ + { + name: "should succeed on the first try", + mockStatusCode: http.StatusOK, + expectedStatusCode: http.StatusOK, + }, + { + name: "should not retry on a non-retryable status code like 400", + mockStatusCode: http.StatusBadRequest, + expectedStatusCode: http.StatusBadRequest, + }, + } + + for _, testCase := range noRetryTests { + t.Run( + testCase.name, func(t *testing.T) { + t.Parallel() + + mock := &mockRoundTripper{ + roundTripFunc: func(req *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: testCase.mockStatusCode, + Body: io.NopCloser(nil), + Request: req, + }, nil + }, + } + tripper := testRetryConfig(mock) + req := httptest.NewRequest(http.MethodGet, "/", nil) + + resp, err := tripper.RoundTrip(req) + if resp != nil { + defer func() { + if closeErr := resp.Body.Close(); closeErr != nil { + t.Errorf("failed to close response body: %v", closeErr) + } + }() + } + + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if resp.StatusCode != testCase.expectedStatusCode { + t.Fatalf("expected status code %d, got %d", testCase.expectedStatusCode, resp.StatusCode) + } + if mock.CallCount() != 1 { + t.Fatalf("expected 1 call, got %d", mock.CallCount()) + } + }, + ) + } + + t.Run( + "should retry on retryable status code (503) and eventually fail", func(t *testing.T) { + t.Parallel() + + mock := &mockRoundTripper{ + roundTripFunc: func(req *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusServiceUnavailable, + Body: io.NopCloser(nil), + Request: req, + }, nil + }, + } + tripper := testRetryConfig(mock) + req := httptest.NewRequest(http.MethodGet, "/", nil) + + resp, err := tripper.RoundTrip(req) + if resp != nil { + defer func() { + if closeErr := resp.Body.Close(); closeErr != nil { + t.Errorf("failed to close response body: %v", closeErr) + } + }() + } + + if err == nil { + t.Fatal("expected an error, but got nil") + } + expectedErrorMsg := "last retry attempt failed with status code 503" + if !strings.Contains(err.Error(), expectedErrorMsg) { + t.Fatalf("expected error to contain %q, got %q", expectedErrorMsg, err.Error()) + } + if mock.CallCount() != 4 { // 1 initial + 3 retries + t.Fatalf("expected 4 calls, got %d", mock.CallCount()) + } + }, + ) + + t.Run( + "should succeed after one retry", func(t *testing.T) { + t.Parallel() + + mock := &mockRoundTripper{} + mock.roundTripFunc = func(req *http.Request) (*http.Response, error) { + if mock.CallCount() < 2 { + return &http.Response{ + StatusCode: http.StatusServiceUnavailable, + Body: io.NopCloser(nil), + Request: req, + }, nil + } + + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(nil), + Request: req, + }, nil + } + tripper := testRetryConfig(mock) + req := httptest.NewRequest(http.MethodGet, "/", nil) + + resp, err := tripper.RoundTrip(req) + if resp != nil { + defer func() { + if closeErr := resp.Body.Close(); closeErr != nil { + t.Errorf("failed to close response body: %v", closeErr) + } + }() + } + + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected status code %d, got %d", http.StatusOK, resp.StatusCode) + } + if mock.CallCount() != 2 { + t.Fatalf("expected 2 calls, got %d", mock.CallCount()) + } + }, + ) + + t.Run( + "should retry on network error", func(t *testing.T) { + t.Parallel() + + mockErr := errors.New("simulated network error") + + mock := &mockRoundTripper{ + roundTripFunc: func(req *http.Request) (*http.Response, error) { + return nil, mockErr + }, + } + tripper := testRetryConfig(mock) + req := httptest.NewRequest(http.MethodGet, "/", nil) + + resp, err := tripper.RoundTrip(req) + if resp != nil { + defer func() { + if closeErr := resp.Body.Close(); closeErr != nil { + t.Errorf("failed to close response body: %v", closeErr) + } + }() + } + + if !errors.Is(err, mockErr) { + t.Fatalf("expected error to be %v, got %v", mockErr, err) + } + if mock.CallCount() != 4 { // 1 initial + 3 retries + t.Fatalf("expected 4 calls, got %d", mock.CallCount()) + } + }, + ) + + t.Run( + "should abort retries if the main context is cancelled", func(t *testing.T) { + t.Parallel() + + mock := &mockRoundTripper{ + roundTripFunc: func(req *http.Request) (*http.Response, error) { + select { + case <-time.After(100 * time.Millisecond): + return nil, errors.New("this should not be returned") + case <-req.Context().Done(): + return nil, req.Context().Err() + } + }, + } + tripper := testRetryConfig(mock) + baseCtx := context.Background() + + ctx, cancel := context.WithTimeout(baseCtx, 20*time.Millisecond) + defer cancel() + + req := httptest.NewRequest(http.MethodGet, "/", nil).WithContext(ctx) + + resp, err := tripper.RoundTrip(req) + if resp != nil { + defer func() { + if closeErr := resp.Body.Close(); closeErr != nil { + t.Errorf("failed to close response body: %v", closeErr) + } + }() + } + + if !errors.Is(err, context.DeadlineExceeded) { + t.Fatalf("expected error to be context.DeadlineExceeded, got %v", err) + } + if mock.CallCount() != 1 { + t.Fatalf("expected 1 call, got %d", mock.CallCount()) + } + }, + ) +}