fix: remove unused delay field from RetryRoundTripper and update error handling
This commit is contained in:
parent
e5a40f8d9b
commit
bbdb85adfe
2 changed files with 254 additions and 4 deletions
|
|
@ -28,7 +28,6 @@ var (
|
|||
type RetryRoundTripper struct {
|
||||
next http.RoundTripper
|
||||
maxRetries int
|
||||
delay time.Duration
|
||||
initialDelay time.Duration
|
||||
maxDelay time.Duration
|
||||
perTryTimeout time.Duration
|
||||
|
|
@ -38,13 +37,12 @@ type RetryRoundTripper struct {
|
|||
func NewRetryRoundTripper(
|
||||
next http.RoundTripper,
|
||||
maxRetries int,
|
||||
initialDelay, delay, maxDelay, perTryTimeout time.Duration,
|
||||
initialDelay, maxDelay, perTryTimeout time.Duration,
|
||||
) *RetryRoundTripper {
|
||||
return &RetryRoundTripper{
|
||||
next: next,
|
||||
maxRetries: maxRetries,
|
||||
initialDelay: initialDelay,
|
||||
delay: delay,
|
||||
maxDelay: maxDelay,
|
||||
perTryTimeout: perTryTimeout,
|
||||
}
|
||||
|
|
@ -132,7 +130,7 @@ func (rrt *RetryRoundTripper) logRetryAttempt(
|
|||
"attempt": attempt,
|
||||
"max_attempts": rrt.maxRetries,
|
||||
"delay": delay,
|
||||
"error": err.Error(),
|
||||
"error": err,
|
||||
},
|
||||
)
|
||||
}
|
||||
|
|
|
|||
252
stackit/internal/core/retry_round_tripper_test.go
Normal file
252
stackit/internal/core/retry_round_tripper_test.go
Normal file
|
|
@ -0,0 +1,252 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
type mockRoundTripper struct {
|
||||
roundTripFunc func(req *http.Request) (*http.Response, error)
|
||||
callCount int32
|
||||
}
|
||||
|
||||
func (m *mockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
atomic.AddInt32(&m.callCount, 1)
|
||||
|
||||
return m.roundTripFunc(req)
|
||||
}
|
||||
|
||||
func (m *mockRoundTripper) CallCount() int32 {
|
||||
return atomic.LoadInt32(&m.callCount)
|
||||
}
|
||||
|
||||
func TestRetryRoundTripper_RoundTrip(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testRetryConfig := func(next http.RoundTripper) *RetryRoundTripper {
|
||||
return NewRetryRoundTripper(
|
||||
next,
|
||||
3,
|
||||
1*time.Millisecond,
|
||||
10*time.Millisecond,
|
||||
50*time.Millisecond,
|
||||
)
|
||||
}
|
||||
|
||||
noRetryTests := []struct {
|
||||
name string
|
||||
mockStatusCode int
|
||||
expectedStatusCode int
|
||||
}{
|
||||
{
|
||||
name: "should succeed on the first try",
|
||||
mockStatusCode: http.StatusOK,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "should not retry on a non-retryable status code like 400",
|
||||
mockStatusCode: http.StatusBadRequest,
|
||||
expectedStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
}
|
||||
|
||||
for _, testCase := range noRetryTests {
|
||||
t.Run(
|
||||
testCase.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mock := &mockRoundTripper{
|
||||
roundTripFunc: func(req *http.Request) (*http.Response, error) {
|
||||
return &http.Response{
|
||||
StatusCode: testCase.mockStatusCode,
|
||||
Body: io.NopCloser(nil),
|
||||
Request: req,
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
tripper := testRetryConfig(mock)
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
|
||||
resp, err := tripper.RoundTrip(req)
|
||||
if resp != nil {
|
||||
defer func() {
|
||||
if closeErr := resp.Body.Close(); closeErr != nil {
|
||||
t.Errorf("failed to close response body: %v", closeErr)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
if resp.StatusCode != testCase.expectedStatusCode {
|
||||
t.Fatalf("expected status code %d, got %d", testCase.expectedStatusCode, resp.StatusCode)
|
||||
}
|
||||
if mock.CallCount() != 1 {
|
||||
t.Fatalf("expected 1 call, got %d", mock.CallCount())
|
||||
}
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
t.Run(
|
||||
"should retry on retryable status code (503) and eventually fail", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mock := &mockRoundTripper{
|
||||
roundTripFunc: func(req *http.Request) (*http.Response, error) {
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusServiceUnavailable,
|
||||
Body: io.NopCloser(nil),
|
||||
Request: req,
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
tripper := testRetryConfig(mock)
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
|
||||
resp, err := tripper.RoundTrip(req)
|
||||
if resp != nil {
|
||||
defer func() {
|
||||
if closeErr := resp.Body.Close(); closeErr != nil {
|
||||
t.Errorf("failed to close response body: %v", closeErr)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
if err == nil {
|
||||
t.Fatal("expected an error, but got nil")
|
||||
}
|
||||
expectedErrorMsg := "last retry attempt failed with status code 503"
|
||||
if !strings.Contains(err.Error(), expectedErrorMsg) {
|
||||
t.Fatalf("expected error to contain %q, got %q", expectedErrorMsg, err.Error())
|
||||
}
|
||||
if mock.CallCount() != 4 { // 1 initial + 3 retries
|
||||
t.Fatalf("expected 4 calls, got %d", mock.CallCount())
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
t.Run(
|
||||
"should succeed after one retry", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mock := &mockRoundTripper{}
|
||||
mock.roundTripFunc = func(req *http.Request) (*http.Response, error) {
|
||||
if mock.CallCount() < 2 {
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusServiceUnavailable,
|
||||
Body: io.NopCloser(nil),
|
||||
Request: req,
|
||||
}, nil
|
||||
}
|
||||
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(nil),
|
||||
Request: req,
|
||||
}, nil
|
||||
}
|
||||
tripper := testRetryConfig(mock)
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
|
||||
resp, err := tripper.RoundTrip(req)
|
||||
if resp != nil {
|
||||
defer func() {
|
||||
if closeErr := resp.Body.Close(); closeErr != nil {
|
||||
t.Errorf("failed to close response body: %v", closeErr)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected status code %d, got %d", http.StatusOK, resp.StatusCode)
|
||||
}
|
||||
if mock.CallCount() != 2 {
|
||||
t.Fatalf("expected 2 calls, got %d", mock.CallCount())
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
t.Run(
|
||||
"should retry on network error", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mockErr := errors.New("simulated network error")
|
||||
|
||||
mock := &mockRoundTripper{
|
||||
roundTripFunc: func(req *http.Request) (*http.Response, error) {
|
||||
return nil, mockErr
|
||||
},
|
||||
}
|
||||
tripper := testRetryConfig(mock)
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
|
||||
resp, err := tripper.RoundTrip(req)
|
||||
if resp != nil {
|
||||
defer func() {
|
||||
if closeErr := resp.Body.Close(); closeErr != nil {
|
||||
t.Errorf("failed to close response body: %v", closeErr)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
if !errors.Is(err, mockErr) {
|
||||
t.Fatalf("expected error to be %v, got %v", mockErr, err)
|
||||
}
|
||||
if mock.CallCount() != 4 { // 1 initial + 3 retries
|
||||
t.Fatalf("expected 4 calls, got %d", mock.CallCount())
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
t.Run(
|
||||
"should abort retries if the main context is cancelled", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mock := &mockRoundTripper{
|
||||
roundTripFunc: func(req *http.Request) (*http.Response, error) {
|
||||
select {
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
return nil, errors.New("this should not be returned")
|
||||
case <-req.Context().Done():
|
||||
return nil, req.Context().Err()
|
||||
}
|
||||
},
|
||||
}
|
||||
tripper := testRetryConfig(mock)
|
||||
baseCtx := context.Background()
|
||||
|
||||
ctx, cancel := context.WithTimeout(baseCtx, 20*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil).WithContext(ctx)
|
||||
|
||||
resp, err := tripper.RoundTrip(req)
|
||||
if resp != nil {
|
||||
defer func() {
|
||||
if closeErr := resp.Body.Close(); closeErr != nil {
|
||||
t.Errorf("failed to close response body: %v", closeErr)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
if !errors.Is(err, context.DeadlineExceeded) {
|
||||
t.Fatalf("expected error to be context.DeadlineExceeded, got %v", err)
|
||||
}
|
||||
if mock.CallCount() != 1 {
|
||||
t.Fatalf("expected 1 call, got %d", mock.CallCount())
|
||||
}
|
||||
},
|
||||
)
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue