feat: implement RetryRoundTripper for automatic request retries
This commit is contained in:
parent
20e9b3ca4c
commit
e5a40f8d9b
1 changed files with 241 additions and 0 deletions
241
stackit/internal/core/retry_round_tripper.go
Normal file
241
stackit/internal/core/retry_round_tripper.go
Normal file
|
|
@ -0,0 +1,241 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/terraform-plugin-log/tflog"
|
||||
)
|
||||
|
||||
const (
|
||||
// backoffMultiplier is the factor by which the delay is multiplied for exponential backoff.
|
||||
backoffMultiplier = 2
|
||||
// jitterFactor is the divisor used to calculate jitter (e.g., half of the base delay).
|
||||
jitterFactor = 2
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrRequestFailedAfterRetries is returned when a request fails after all retry attempts.
|
||||
ErrRequestFailedAfterRetries = errors.New("request failed after all retry attempts")
|
||||
)
|
||||
|
||||
// RetryRoundTripper implements an http.RoundTripper that adds automatic retry logic for failed requests.
|
||||
type RetryRoundTripper struct {
|
||||
next http.RoundTripper
|
||||
maxRetries int
|
||||
delay time.Duration
|
||||
initialDelay time.Duration
|
||||
maxDelay time.Duration
|
||||
perTryTimeout time.Duration
|
||||
}
|
||||
|
||||
// NewRetryRoundTripper creates a new instance of the RetryRoundTripper with the specified configuration.
|
||||
func NewRetryRoundTripper(
|
||||
next http.RoundTripper,
|
||||
maxRetries int,
|
||||
initialDelay, delay, maxDelay, perTryTimeout time.Duration,
|
||||
) *RetryRoundTripper {
|
||||
return &RetryRoundTripper{
|
||||
next: next,
|
||||
maxRetries: maxRetries,
|
||||
initialDelay: initialDelay,
|
||||
delay: delay,
|
||||
maxDelay: maxDelay,
|
||||
perTryTimeout: perTryTimeout,
|
||||
}
|
||||
}
|
||||
|
||||
// RoundTrip executes the request and retries on failure.
|
||||
func (rrt *RetryRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
resp, err := rrt.executeRequest(req)
|
||||
if !rrt.shouldRetry(resp, err) {
|
||||
if err != nil {
|
||||
return resp, fmt.Errorf("initial request failed, not retrying: %w", err)
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
return rrt.retryLoop(req, resp, err)
|
||||
}
|
||||
|
||||
// executeRequest performs a single HTTP request with a per-try timeout.
|
||||
func (rrt *RetryRoundTripper) executeRequest(req *http.Request) (*http.Response, error) {
|
||||
ctx, cancel := context.WithTimeout(req.Context(), rrt.perTryTimeout)
|
||||
defer cancel()
|
||||
|
||||
resp, err := rrt.next.RoundTrip(req.WithContext(ctx))
|
||||
if err != nil {
|
||||
if errors.Is(err, context.DeadlineExceeded) {
|
||||
return resp, fmt.Errorf("per-try timeout of %v exceeded: %w", rrt.perTryTimeout, err)
|
||||
}
|
||||
|
||||
return resp, fmt.Errorf("http roundtrip failed: %w", err)
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// retryLoop handles the retry logic for a failed request.
|
||||
func (rrt *RetryRoundTripper) retryLoop(
|
||||
req *http.Request,
|
||||
initialResp *http.Response,
|
||||
initialErr error,
|
||||
) (*http.Response, error) {
|
||||
var (
|
||||
lastErr = initialErr
|
||||
resp = initialResp
|
||||
currentDelay = rrt.initialDelay
|
||||
)
|
||||
|
||||
ctx := req.Context()
|
||||
|
||||
for attempt := 1; attempt <= rrt.maxRetries; attempt++ {
|
||||
rrt.logRetryAttempt(ctx, attempt, currentDelay, lastErr)
|
||||
|
||||
waitDuration := rrt.calculateWaitDurationWithJitter(ctx, currentDelay)
|
||||
if err := rrt.waitForDelay(ctx, waitDuration); err != nil {
|
||||
return nil, err // Context was cancelled during wait.
|
||||
}
|
||||
|
||||
// Exponential backoff for the next potential retry.
|
||||
currentDelay = rrt.updateCurrentDelay(currentDelay)
|
||||
|
||||
// Retry attempt.
|
||||
resp, lastErr = rrt.executeRequest(req)
|
||||
if !rrt.shouldRetry(resp, lastErr) {
|
||||
if lastErr != nil {
|
||||
return resp, fmt.Errorf("request failed on retry attempt %d: %w", attempt, lastErr)
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, rrt.handleFinalError(ctx, resp, lastErr)
|
||||
}
|
||||
|
||||
// logRetryAttempt logs the details of a retry attempt.
|
||||
func (rrt *RetryRoundTripper) logRetryAttempt(
|
||||
ctx context.Context,
|
||||
attempt int,
|
||||
delay time.Duration,
|
||||
err error,
|
||||
) {
|
||||
tflog.Info(
|
||||
ctx, "Request failed, retrying...", map[string]interface{}{
|
||||
"attempt": attempt,
|
||||
"max_attempts": rrt.maxRetries,
|
||||
"delay": delay,
|
||||
"error": err.Error(),
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
// updateCurrentDelay calculates the next delay for exponential backoff.
|
||||
func (rrt *RetryRoundTripper) updateCurrentDelay(currentDelay time.Duration) time.Duration {
|
||||
currentDelay *= backoffMultiplier
|
||||
if currentDelay > rrt.maxDelay {
|
||||
return rrt.maxDelay
|
||||
}
|
||||
|
||||
return currentDelay
|
||||
}
|
||||
|
||||
// handleFinalError constructs and returns the final error after all retries have been exhausted.
|
||||
func (rrt *RetryRoundTripper) handleFinalError(
|
||||
ctx context.Context,
|
||||
resp *http.Response,
|
||||
lastErr error,
|
||||
) error {
|
||||
if resp != nil {
|
||||
if err := resp.Body.Close(); err != nil {
|
||||
|
||||
tflog.Warn(
|
||||
ctx, "Failed to close response body", map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
},
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
if lastErr != nil {
|
||||
return fmt.Errorf("%w: %w", ErrRequestFailedAfterRetries, lastErr)
|
||||
}
|
||||
|
||||
// This case occurs if shouldRetry was true due to a retryable status code,
|
||||
// but all retries failed with similar status codes.
|
||||
if resp != nil {
|
||||
return fmt.Errorf(
|
||||
"%w: last retry attempt failed with status code %d",
|
||||
ErrRequestFailedAfterRetries,
|
||||
resp.StatusCode,
|
||||
)
|
||||
}
|
||||
|
||||
return fmt.Errorf("%w: no response received", ErrRequestFailedAfterRetries)
|
||||
}
|
||||
|
||||
// shouldRetry determines if a request should be retried based on the response or an error.
|
||||
func (rrt *RetryRoundTripper) shouldRetry(resp *http.Response, err error) bool {
|
||||
if err != nil {
|
||||
return true
|
||||
}
|
||||
|
||||
if resp != nil {
|
||||
if resp.StatusCode == http.StatusBadGateway ||
|
||||
resp.StatusCode == http.StatusServiceUnavailable ||
|
||||
resp.StatusCode == http.StatusGatewayTimeout {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
|
||||
}
|
||||
|
||||
// calculateWaitDurationWithJitter calculates the backoff duration for the next retry,
|
||||
// adding a random jitter to prevent thundering herd issues.
|
||||
func (rrt *RetryRoundTripper) calculateWaitDurationWithJitter(
|
||||
ctx context.Context,
|
||||
baseDelay time.Duration,
|
||||
) time.Duration {
|
||||
if baseDelay <= 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
maxJitter := int64(baseDelay / jitterFactor)
|
||||
if maxJitter <= 0 {
|
||||
return baseDelay
|
||||
}
|
||||
|
||||
random, err := rand.Int(rand.Reader, big.NewInt(maxJitter))
|
||||
if err != nil {
|
||||
tflog.Warn(
|
||||
ctx, "Failed to generate random jitter, proceeding without it.", map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
},
|
||||
)
|
||||
|
||||
return baseDelay
|
||||
}
|
||||
|
||||
jitter := time.Duration(random.Int64())
|
||||
|
||||
return baseDelay + jitter
|
||||
}
|
||||
|
||||
// waitForDelay pauses execution for a given duration or until the context is canceled.
|
||||
func (rrt *RetryRoundTripper) waitForDelay(ctx context.Context, delay time.Duration) error {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return fmt.Errorf("context cancelled during backoff wait: %w", ctx.Err())
|
||||
case <-time.After(delay):
|
||||
return nil
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue