feat: implement RetryRoundTripper for automatic request retries
This commit is contained in:
parent
20e9b3ca4c
commit
e5a40f8d9b
1 changed files with 241 additions and 0 deletions
241
stackit/internal/core/retry_round_tripper.go
Normal file
241
stackit/internal/core/retry_round_tripper.go
Normal file
|
|
@ -0,0 +1,241 @@
|
||||||
|
package core
|
||||||
|
|
||||||
|
import (
	"context"
	"crypto/rand"
	"errors"
	"fmt"
	"io"
	"math/big"
	"net/http"
	"time"

	"github.com/hashicorp/terraform-plugin-log/tflog"
)
|
||||||
|
|
||||||
|
const (
	// backoffMultiplier is the factor by which the delay is multiplied for
	// exponential backoff (delay doubles after every retry attempt).
	backoffMultiplier = 2
	// jitterFactor is the divisor used to calculate the jitter bound: the
	// random jitter added to a backoff wait is in [0, baseDelay/jitterFactor).
	jitterFactor = 2
)
|
||||||
|
|
||||||
|
var (
	// ErrRequestFailedAfterRetries is returned when a request fails after all
	// retry attempts. Callers can detect the condition with errors.Is.
	ErrRequestFailedAfterRetries = errors.New("request failed after all retry attempts")
)
|
||||||
|
|
||||||
|
// RetryRoundTripper implements an http.RoundTripper that adds automatic retry logic for failed requests.
type RetryRoundTripper struct {
	// next is the wrapped transport that actually executes each HTTP attempt.
	next http.RoundTripper
	// maxRetries is the number of retry attempts made after the initial request.
	maxRetries int
	// delay is stored from the constructor but is not consumed by the visible
	// retry logic in this file — NOTE(review): confirm its intended use.
	delay time.Duration
	// initialDelay seeds the exponential backoff for the first retry.
	initialDelay time.Duration
	// maxDelay caps the exponentially growing backoff delay.
	maxDelay time.Duration
	// perTryTimeout bounds the duration of each individual request attempt.
	perTryTimeout time.Duration
}
|
||||||
|
|
||||||
|
// NewRetryRoundTripper creates a new instance of the RetryRoundTripper with the specified configuration.
|
||||||
|
func NewRetryRoundTripper(
|
||||||
|
next http.RoundTripper,
|
||||||
|
maxRetries int,
|
||||||
|
initialDelay, delay, maxDelay, perTryTimeout time.Duration,
|
||||||
|
) *RetryRoundTripper {
|
||||||
|
return &RetryRoundTripper{
|
||||||
|
next: next,
|
||||||
|
maxRetries: maxRetries,
|
||||||
|
initialDelay: initialDelay,
|
||||||
|
delay: delay,
|
||||||
|
maxDelay: maxDelay,
|
||||||
|
perTryTimeout: perTryTimeout,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RoundTrip executes the request and retries on failure.
|
||||||
|
func (rrt *RetryRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
|
resp, err := rrt.executeRequest(req)
|
||||||
|
if !rrt.shouldRetry(resp, err) {
|
||||||
|
if err != nil {
|
||||||
|
return resp, fmt.Errorf("initial request failed, not retrying: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return rrt.retryLoop(req, resp, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// executeRequest performs a single HTTP request with a per-try timeout.
|
||||||
|
func (rrt *RetryRoundTripper) executeRequest(req *http.Request) (*http.Response, error) {
|
||||||
|
ctx, cancel := context.WithTimeout(req.Context(), rrt.perTryTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
resp, err := rrt.next.RoundTrip(req.WithContext(ctx))
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, context.DeadlineExceeded) {
|
||||||
|
return resp, fmt.Errorf("per-try timeout of %v exceeded: %w", rrt.perTryTimeout, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return resp, fmt.Errorf("http roundtrip failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// retryLoop handles the retry logic for a failed request.
|
||||||
|
func (rrt *RetryRoundTripper) retryLoop(
|
||||||
|
req *http.Request,
|
||||||
|
initialResp *http.Response,
|
||||||
|
initialErr error,
|
||||||
|
) (*http.Response, error) {
|
||||||
|
var (
|
||||||
|
lastErr = initialErr
|
||||||
|
resp = initialResp
|
||||||
|
currentDelay = rrt.initialDelay
|
||||||
|
)
|
||||||
|
|
||||||
|
ctx := req.Context()
|
||||||
|
|
||||||
|
for attempt := 1; attempt <= rrt.maxRetries; attempt++ {
|
||||||
|
rrt.logRetryAttempt(ctx, attempt, currentDelay, lastErr)
|
||||||
|
|
||||||
|
waitDuration := rrt.calculateWaitDurationWithJitter(ctx, currentDelay)
|
||||||
|
if err := rrt.waitForDelay(ctx, waitDuration); err != nil {
|
||||||
|
return nil, err // Context was cancelled during wait.
|
||||||
|
}
|
||||||
|
|
||||||
|
// Exponential backoff for the next potential retry.
|
||||||
|
currentDelay = rrt.updateCurrentDelay(currentDelay)
|
||||||
|
|
||||||
|
// Retry attempt.
|
||||||
|
resp, lastErr = rrt.executeRequest(req)
|
||||||
|
if !rrt.shouldRetry(resp, lastErr) {
|
||||||
|
if lastErr != nil {
|
||||||
|
return resp, fmt.Errorf("request failed on retry attempt %d: %w", attempt, lastErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, rrt.handleFinalError(ctx, resp, lastErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// logRetryAttempt logs the details of a retry attempt.
|
||||||
|
func (rrt *RetryRoundTripper) logRetryAttempt(
|
||||||
|
ctx context.Context,
|
||||||
|
attempt int,
|
||||||
|
delay time.Duration,
|
||||||
|
err error,
|
||||||
|
) {
|
||||||
|
tflog.Info(
|
||||||
|
ctx, "Request failed, retrying...", map[string]interface{}{
|
||||||
|
"attempt": attempt,
|
||||||
|
"max_attempts": rrt.maxRetries,
|
||||||
|
"delay": delay,
|
||||||
|
"error": err.Error(),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// updateCurrentDelay calculates the next delay for exponential backoff.
|
||||||
|
func (rrt *RetryRoundTripper) updateCurrentDelay(currentDelay time.Duration) time.Duration {
|
||||||
|
currentDelay *= backoffMultiplier
|
||||||
|
if currentDelay > rrt.maxDelay {
|
||||||
|
return rrt.maxDelay
|
||||||
|
}
|
||||||
|
|
||||||
|
return currentDelay
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleFinalError constructs and returns the final error after all retries have been exhausted.
|
||||||
|
func (rrt *RetryRoundTripper) handleFinalError(
|
||||||
|
ctx context.Context,
|
||||||
|
resp *http.Response,
|
||||||
|
lastErr error,
|
||||||
|
) error {
|
||||||
|
if resp != nil {
|
||||||
|
if err := resp.Body.Close(); err != nil {
|
||||||
|
|
||||||
|
tflog.Warn(
|
||||||
|
ctx, "Failed to close response body", map[string]interface{}{
|
||||||
|
"error": err.Error(),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if lastErr != nil {
|
||||||
|
return fmt.Errorf("%w: %w", ErrRequestFailedAfterRetries, lastErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// This case occurs if shouldRetry was true due to a retryable status code,
|
||||||
|
// but all retries failed with similar status codes.
|
||||||
|
if resp != nil {
|
||||||
|
return fmt.Errorf(
|
||||||
|
"%w: last retry attempt failed with status code %d",
|
||||||
|
ErrRequestFailedAfterRetries,
|
||||||
|
resp.StatusCode,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Errorf("%w: no response received", ErrRequestFailedAfterRetries)
|
||||||
|
}
|
||||||
|
|
||||||
|
// shouldRetry determines if a request should be retried based on the response or an error.
|
||||||
|
func (rrt *RetryRoundTripper) shouldRetry(resp *http.Response, err error) bool {
|
||||||
|
if err != nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp != nil {
|
||||||
|
if resp.StatusCode == http.StatusBadGateway ||
|
||||||
|
resp.StatusCode == http.StatusServiceUnavailable ||
|
||||||
|
resp.StatusCode == http.StatusGatewayTimeout {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// calculateWaitDurationWithJitter calculates the backoff duration for the next retry,
|
||||||
|
// adding a random jitter to prevent thundering herd issues.
|
||||||
|
func (rrt *RetryRoundTripper) calculateWaitDurationWithJitter(
|
||||||
|
ctx context.Context,
|
||||||
|
baseDelay time.Duration,
|
||||||
|
) time.Duration {
|
||||||
|
if baseDelay <= 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
maxJitter := int64(baseDelay / jitterFactor)
|
||||||
|
if maxJitter <= 0 {
|
||||||
|
return baseDelay
|
||||||
|
}
|
||||||
|
|
||||||
|
random, err := rand.Int(rand.Reader, big.NewInt(maxJitter))
|
||||||
|
if err != nil {
|
||||||
|
tflog.Warn(
|
||||||
|
ctx, "Failed to generate random jitter, proceeding without it.", map[string]interface{}{
|
||||||
|
"error": err.Error(),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
return baseDelay
|
||||||
|
}
|
||||||
|
|
||||||
|
jitter := time.Duration(random.Int64())
|
||||||
|
|
||||||
|
return baseDelay + jitter
|
||||||
|
}
|
||||||
|
|
||||||
|
// waitForDelay pauses execution for a given duration or until the context is canceled.
|
||||||
|
func (rrt *RetryRoundTripper) waitForDelay(ctx context.Context, delay time.Duration) error {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return fmt.Errorf("context cancelled during backoff wait: %w", ctx.Err())
|
||||||
|
case <-time.After(delay):
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
Loading…
Add table
Add a link
Reference in a new issue