fix: remove unused delay field from RetryRoundTripper and update error handling
This commit is contained in:
parent
e5a40f8d9b
commit
bbdb85adfe
2 changed files with 254 additions and 4 deletions
|
|
@ -28,7 +28,6 @@ var (
|
||||||
type RetryRoundTripper struct {
|
type RetryRoundTripper struct {
|
||||||
next http.RoundTripper
|
next http.RoundTripper
|
||||||
maxRetries int
|
maxRetries int
|
||||||
delay time.Duration
|
|
||||||
initialDelay time.Duration
|
initialDelay time.Duration
|
||||||
maxDelay time.Duration
|
maxDelay time.Duration
|
||||||
perTryTimeout time.Duration
|
perTryTimeout time.Duration
|
||||||
|
|
@ -38,13 +37,12 @@ type RetryRoundTripper struct {
|
||||||
func NewRetryRoundTripper(
|
func NewRetryRoundTripper(
|
||||||
next http.RoundTripper,
|
next http.RoundTripper,
|
||||||
maxRetries int,
|
maxRetries int,
|
||||||
initialDelay, delay, maxDelay, perTryTimeout time.Duration,
|
initialDelay, maxDelay, perTryTimeout time.Duration,
|
||||||
) *RetryRoundTripper {
|
) *RetryRoundTripper {
|
||||||
return &RetryRoundTripper{
|
return &RetryRoundTripper{
|
||||||
next: next,
|
next: next,
|
||||||
maxRetries: maxRetries,
|
maxRetries: maxRetries,
|
||||||
initialDelay: initialDelay,
|
initialDelay: initialDelay,
|
||||||
delay: delay,
|
|
||||||
maxDelay: maxDelay,
|
maxDelay: maxDelay,
|
||||||
perTryTimeout: perTryTimeout,
|
perTryTimeout: perTryTimeout,
|
||||||
}
|
}
|
||||||
|
|
@ -132,7 +130,7 @@ func (rrt *RetryRoundTripper) logRetryAttempt(
|
||||||
"attempt": attempt,
|
"attempt": attempt,
|
||||||
"max_attempts": rrt.maxRetries,
|
"max_attempts": rrt.maxRetries,
|
||||||
"delay": delay,
|
"delay": delay,
|
||||||
"error": err.Error(),
|
"error": err,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
252
stackit/internal/core/retry_round_tripper_test.go
Normal file
252
stackit/internal/core/retry_round_tripper_test.go
Normal file
|
|
@ -0,0 +1,252 @@
|
||||||
|
package core
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type mockRoundTripper struct {
|
||||||
|
roundTripFunc func(req *http.Request) (*http.Response, error)
|
||||||
|
callCount int32
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
|
atomic.AddInt32(&m.callCount, 1)
|
||||||
|
|
||||||
|
return m.roundTripFunc(req)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockRoundTripper) CallCount() int32 {
|
||||||
|
return atomic.LoadInt32(&m.callCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRetryRoundTripper_RoundTrip(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
testRetryConfig := func(next http.RoundTripper) *RetryRoundTripper {
|
||||||
|
return NewRetryRoundTripper(
|
||||||
|
next,
|
||||||
|
3,
|
||||||
|
1*time.Millisecond,
|
||||||
|
10*time.Millisecond,
|
||||||
|
50*time.Millisecond,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
noRetryTests := []struct {
|
||||||
|
name string
|
||||||
|
mockStatusCode int
|
||||||
|
expectedStatusCode int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "should succeed on the first try",
|
||||||
|
mockStatusCode: http.StatusOK,
|
||||||
|
expectedStatusCode: http.StatusOK,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "should not retry on a non-retryable status code like 400",
|
||||||
|
mockStatusCode: http.StatusBadRequest,
|
||||||
|
expectedStatusCode: http.StatusBadRequest,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, testCase := range noRetryTests {
|
||||||
|
t.Run(
|
||||||
|
testCase.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
mock := &mockRoundTripper{
|
||||||
|
roundTripFunc: func(req *http.Request) (*http.Response, error) {
|
||||||
|
return &http.Response{
|
||||||
|
StatusCode: testCase.mockStatusCode,
|
||||||
|
Body: io.NopCloser(nil),
|
||||||
|
Request: req,
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
tripper := testRetryConfig(mock)
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
|
||||||
|
resp, err := tripper.RoundTrip(req)
|
||||||
|
if resp != nil {
|
||||||
|
defer func() {
|
||||||
|
if closeErr := resp.Body.Close(); closeErr != nil {
|
||||||
|
t.Errorf("failed to close response body: %v", closeErr)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected no error, got %v", err)
|
||||||
|
}
|
||||||
|
if resp.StatusCode != testCase.expectedStatusCode {
|
||||||
|
t.Fatalf("expected status code %d, got %d", testCase.expectedStatusCode, resp.StatusCode)
|
||||||
|
}
|
||||||
|
if mock.CallCount() != 1 {
|
||||||
|
t.Fatalf("expected 1 call, got %d", mock.CallCount())
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run(
|
||||||
|
"should retry on retryable status code (503) and eventually fail", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
mock := &mockRoundTripper{
|
||||||
|
roundTripFunc: func(req *http.Request) (*http.Response, error) {
|
||||||
|
return &http.Response{
|
||||||
|
StatusCode: http.StatusServiceUnavailable,
|
||||||
|
Body: io.NopCloser(nil),
|
||||||
|
Request: req,
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
tripper := testRetryConfig(mock)
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
|
||||||
|
resp, err := tripper.RoundTrip(req)
|
||||||
|
if resp != nil {
|
||||||
|
defer func() {
|
||||||
|
if closeErr := resp.Body.Close(); closeErr != nil {
|
||||||
|
t.Errorf("failed to close response body: %v", closeErr)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected an error, but got nil")
|
||||||
|
}
|
||||||
|
expectedErrorMsg := "last retry attempt failed with status code 503"
|
||||||
|
if !strings.Contains(err.Error(), expectedErrorMsg) {
|
||||||
|
t.Fatalf("expected error to contain %q, got %q", expectedErrorMsg, err.Error())
|
||||||
|
}
|
||||||
|
if mock.CallCount() != 4 { // 1 initial + 3 retries
|
||||||
|
t.Fatalf("expected 4 calls, got %d", mock.CallCount())
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
t.Run(
|
||||||
|
"should succeed after one retry", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
mock := &mockRoundTripper{}
|
||||||
|
mock.roundTripFunc = func(req *http.Request) (*http.Response, error) {
|
||||||
|
if mock.CallCount() < 2 {
|
||||||
|
return &http.Response{
|
||||||
|
StatusCode: http.StatusServiceUnavailable,
|
||||||
|
Body: io.NopCloser(nil),
|
||||||
|
Request: req,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Body: io.NopCloser(nil),
|
||||||
|
Request: req,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
tripper := testRetryConfig(mock)
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
|
||||||
|
resp, err := tripper.RoundTrip(req)
|
||||||
|
if resp != nil {
|
||||||
|
defer func() {
|
||||||
|
if closeErr := resp.Body.Close(); closeErr != nil {
|
||||||
|
t.Errorf("failed to close response body: %v", closeErr)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected no error, got %v", err)
|
||||||
|
}
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("expected status code %d, got %d", http.StatusOK, resp.StatusCode)
|
||||||
|
}
|
||||||
|
if mock.CallCount() != 2 {
|
||||||
|
t.Fatalf("expected 2 calls, got %d", mock.CallCount())
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
t.Run(
|
||||||
|
"should retry on network error", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
mockErr := errors.New("simulated network error")
|
||||||
|
|
||||||
|
mock := &mockRoundTripper{
|
||||||
|
roundTripFunc: func(req *http.Request) (*http.Response, error) {
|
||||||
|
return nil, mockErr
|
||||||
|
},
|
||||||
|
}
|
||||||
|
tripper := testRetryConfig(mock)
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
|
||||||
|
resp, err := tripper.RoundTrip(req)
|
||||||
|
if resp != nil {
|
||||||
|
defer func() {
|
||||||
|
if closeErr := resp.Body.Close(); closeErr != nil {
|
||||||
|
t.Errorf("failed to close response body: %v", closeErr)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
if !errors.Is(err, mockErr) {
|
||||||
|
t.Fatalf("expected error to be %v, got %v", mockErr, err)
|
||||||
|
}
|
||||||
|
if mock.CallCount() != 4 { // 1 initial + 3 retries
|
||||||
|
t.Fatalf("expected 4 calls, got %d", mock.CallCount())
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
t.Run(
|
||||||
|
"should abort retries if the main context is cancelled", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
mock := &mockRoundTripper{
|
||||||
|
roundTripFunc: func(req *http.Request) (*http.Response, error) {
|
||||||
|
select {
|
||||||
|
case <-time.After(100 * time.Millisecond):
|
||||||
|
return nil, errors.New("this should not be returned")
|
||||||
|
case <-req.Context().Done():
|
||||||
|
return nil, req.Context().Err()
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
tripper := testRetryConfig(mock)
|
||||||
|
baseCtx := context.Background()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(baseCtx, 20*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil).WithContext(ctx)
|
||||||
|
|
||||||
|
resp, err := tripper.RoundTrip(req)
|
||||||
|
if resp != nil {
|
||||||
|
defer func() {
|
||||||
|
if closeErr := resp.Body.Close(); closeErr != nil {
|
||||||
|
t.Errorf("failed to close response body: %v", closeErr)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
if !errors.Is(err, context.DeadlineExceeded) {
|
||||||
|
t.Fatalf("expected error to be context.DeadlineExceeded, got %v", err)
|
||||||
|
}
|
||||||
|
if mock.CallCount() != 1 {
|
||||||
|
t.Fatalf("expected 1 call, got %d", mock.CallCount())
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
Loading…
Add table
Add a link
Reference in a new issue