chore: #64 add system hardening with retry logic for client (#68)
All checks were successful
Publish / Check GoReleaser config (push) Successful in 5s
Publish / Publish provider (push) Successful in 12m49s

- implement RetryRoundTripper

Refs: #64

Reviewed-on: #68
Reviewed-by: Marcel_Henselin <marcel.henselin@stackit.cloud>
Co-authored-by: Andre Harms <andre.harms@stackit.cloud>
Co-committed-by: Andre Harms <andre.harms@stackit.cloud>
This commit is contained in:
Andre_Harms 2026-02-16 09:35:21 +00:00 committed by Marcel_Henselin
parent 20e9b3ca4c
commit d5644ec27f
Signed by: tf-provider.git.onstackit.cloud
GPG key ID: 6D7E8A1ED8955A9C
3 changed files with 512 additions and 1 deletions

View file

@ -0,0 +1,239 @@
package core
import (
"context"
"crypto/rand"
"errors"
"fmt"
"math/big"
"net/http"
"time"
"github.com/hashicorp/terraform-plugin-log/tflog"
)
const (
// backoffMultiplier is the factor by which the delay is multiplied for exponential backoff.
backoffMultiplier = 2
// jitterFactor is the divisor used to calculate jitter (e.g., half of the base delay).
jitterFactor = 2
)
var (
// ErrRequestFailedAfterRetries is returned when a request fails after all retry attempts.
ErrRequestFailedAfterRetries = errors.New("request failed after all retry attempts")
)
// RetryRoundTripper implements an http.RoundTripper that adds automatic retry logic for failed requests.
type RetryRoundTripper struct {
next http.RoundTripper
maxRetries int
initialDelay time.Duration
maxDelay time.Duration
perTryTimeout time.Duration
}
// NewRetryRoundTripper creates a new instance of the RetryRoundTripper with the specified configuration.
func NewRetryRoundTripper(
next http.RoundTripper,
maxRetries int,
initialDelay, maxDelay, perTryTimeout time.Duration,
) *RetryRoundTripper {
return &RetryRoundTripper{
next: next,
maxRetries: maxRetries,
initialDelay: initialDelay,
maxDelay: maxDelay,
perTryTimeout: perTryTimeout,
}
}
// RoundTrip executes the request and retries on failure.
func (rrt *RetryRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
resp, err := rrt.executeRequest(req)
if !rrt.shouldRetry(resp, err) {
if err != nil {
return resp, fmt.Errorf("initial request failed, not retrying: %w", err)
}
return resp, nil
}
return rrt.retryLoop(req, resp, err)
}
// executeRequest performs a single HTTP request with a per-try timeout.
func (rrt *RetryRoundTripper) executeRequest(req *http.Request) (*http.Response, error) {
ctx, cancel := context.WithTimeout(req.Context(), rrt.perTryTimeout)
defer cancel()
resp, err := rrt.next.RoundTrip(req.WithContext(ctx))
if err != nil {
if errors.Is(err, context.DeadlineExceeded) {
return resp, fmt.Errorf("per-try timeout of %v exceeded: %w", rrt.perTryTimeout, err)
}
return resp, fmt.Errorf("http roundtrip failed: %w", err)
}
return resp, nil
}
// retryLoop handles the retry logic for a failed request.
func (rrt *RetryRoundTripper) retryLoop(
req *http.Request,
initialResp *http.Response,
initialErr error,
) (*http.Response, error) {
var (
lastErr = initialErr
resp = initialResp
currentDelay = rrt.initialDelay
)
ctx := req.Context()
for attempt := 1; attempt <= rrt.maxRetries; attempt++ {
rrt.logRetryAttempt(ctx, attempt, currentDelay, lastErr)
waitDuration := rrt.calculateWaitDurationWithJitter(ctx, currentDelay)
if err := rrt.waitForDelay(ctx, waitDuration); err != nil {
return nil, err // Context was cancelled during wait.
}
// Exponential backoff for the next potential retry.
currentDelay = rrt.updateCurrentDelay(currentDelay)
// Retry attempt.
resp, lastErr = rrt.executeRequest(req)
if !rrt.shouldRetry(resp, lastErr) {
if lastErr != nil {
return resp, fmt.Errorf("request failed on retry attempt %d: %w", attempt, lastErr)
}
return resp, nil
}
}
return nil, rrt.handleFinalError(ctx, resp, lastErr)
}
// logRetryAttempt logs the details of a retry attempt.
func (rrt *RetryRoundTripper) logRetryAttempt(
ctx context.Context,
attempt int,
delay time.Duration,
err error,
) {
tflog.Info(
ctx, "Request failed, retrying...", map[string]interface{}{
"attempt": attempt,
"max_attempts": rrt.maxRetries,
"delay": delay,
"error": err,
},
)
}
// updateCurrentDelay calculates the next delay for exponential backoff.
func (rrt *RetryRoundTripper) updateCurrentDelay(currentDelay time.Duration) time.Duration {
currentDelay *= backoffMultiplier
if currentDelay > rrt.maxDelay {
return rrt.maxDelay
}
return currentDelay
}
// handleFinalError constructs and returns the final error after all retries have been exhausted.
func (rrt *RetryRoundTripper) handleFinalError(
ctx context.Context,
resp *http.Response,
lastErr error,
) error {
if resp != nil {
if err := resp.Body.Close(); err != nil {
tflog.Warn(
ctx, "Failed to close response body", map[string]interface{}{
"error": err.Error(),
},
)
}
}
if lastErr != nil {
return fmt.Errorf("%w: %w", ErrRequestFailedAfterRetries, lastErr)
}
// This case occurs if shouldRetry was true due to a retryable status code,
// but all retries failed with similar status codes.
if resp != nil {
return fmt.Errorf(
"%w: last retry attempt failed with status code %d",
ErrRequestFailedAfterRetries,
resp.StatusCode,
)
}
return fmt.Errorf("%w: no response received", ErrRequestFailedAfterRetries)
}
// shouldRetry determines if a request should be retried based on the response or an error.
func (rrt *RetryRoundTripper) shouldRetry(resp *http.Response, err error) bool {
if err != nil {
return true
}
if resp != nil {
if resp.StatusCode == http.StatusBadGateway ||
resp.StatusCode == http.StatusServiceUnavailable ||
resp.StatusCode == http.StatusGatewayTimeout {
return true
}
}
return false
}
// calculateWaitDurationWithJitter calculates the backoff duration for the next retry,
// adding a random jitter to prevent thundering herd issues.
func (rrt *RetryRoundTripper) calculateWaitDurationWithJitter(
ctx context.Context,
baseDelay time.Duration,
) time.Duration {
if baseDelay <= 0 {
return 0
}
maxJitter := int64(baseDelay / jitterFactor)
if maxJitter <= 0 {
return baseDelay
}
random, err := rand.Int(rand.Reader, big.NewInt(maxJitter))
if err != nil {
tflog.Warn(
ctx, "Failed to generate random jitter, proceeding without it.", map[string]interface{}{
"error": err.Error(),
},
)
return baseDelay
}
jitter := time.Duration(random.Int64())
return baseDelay + jitter
}
// waitForDelay pauses execution for a given duration or until the context is canceled.
func (rrt *RetryRoundTripper) waitForDelay(ctx context.Context, delay time.Duration) error {
select {
case <-ctx.Done():
return fmt.Errorf("context cancelled during backoff wait: %w", ctx.Err())
case <-time.After(delay):
return nil
}
}

View file

@ -0,0 +1,252 @@
package core
import (
"context"
"errors"
"io"
"net/http"
"net/http/httptest"
"strings"
"sync/atomic"
"testing"
"time"
)
type mockRoundTripper struct {
roundTripFunc func(req *http.Request) (*http.Response, error)
callCount int32
}
func (m *mockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
atomic.AddInt32(&m.callCount, 1)
return m.roundTripFunc(req)
}
func (m *mockRoundTripper) CallCount() int32 {
return atomic.LoadInt32(&m.callCount)
}
func TestRetryRoundTripper_RoundTrip(t *testing.T) {
t.Parallel()
testRetryConfig := func(next http.RoundTripper) *RetryRoundTripper {
return NewRetryRoundTripper(
next,
3,
1*time.Millisecond,
10*time.Millisecond,
50*time.Millisecond,
)
}
noRetryTests := []struct {
name string
mockStatusCode int
expectedStatusCode int
}{
{
name: "should succeed on the first try",
mockStatusCode: http.StatusOK,
expectedStatusCode: http.StatusOK,
},
{
name: "should not retry on a non-retryable status code like 400",
mockStatusCode: http.StatusBadRequest,
expectedStatusCode: http.StatusBadRequest,
},
}
for _, testCase := range noRetryTests {
t.Run(
testCase.name, func(t *testing.T) {
t.Parallel()
mock := &mockRoundTripper{
roundTripFunc: func(req *http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: testCase.mockStatusCode,
Body: io.NopCloser(nil),
Request: req,
}, nil
},
}
tripper := testRetryConfig(mock)
req := httptest.NewRequest(http.MethodGet, "/", nil)
resp, err := tripper.RoundTrip(req)
if resp != nil {
defer func() {
if closeErr := resp.Body.Close(); closeErr != nil {
t.Errorf("failed to close response body: %v", closeErr)
}
}()
}
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
if resp.StatusCode != testCase.expectedStatusCode {
t.Fatalf("expected status code %d, got %d", testCase.expectedStatusCode, resp.StatusCode)
}
if mock.CallCount() != 1 {
t.Fatalf("expected 1 call, got %d", mock.CallCount())
}
},
)
}
t.Run(
"should retry on retryable status code (503) and eventually fail", func(t *testing.T) {
t.Parallel()
mock := &mockRoundTripper{
roundTripFunc: func(req *http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: http.StatusServiceUnavailable,
Body: io.NopCloser(nil),
Request: req,
}, nil
},
}
tripper := testRetryConfig(mock)
req := httptest.NewRequest(http.MethodGet, "/", nil)
resp, err := tripper.RoundTrip(req)
if resp != nil {
defer func() {
if closeErr := resp.Body.Close(); closeErr != nil {
t.Errorf("failed to close response body: %v", closeErr)
}
}()
}
if err == nil {
t.Fatal("expected an error, but got nil")
}
expectedErrorMsg := "last retry attempt failed with status code 503"
if !strings.Contains(err.Error(), expectedErrorMsg) {
t.Fatalf("expected error to contain %q, got %q", expectedErrorMsg, err.Error())
}
if mock.CallCount() != 4 { // 1 initial + 3 retries
t.Fatalf("expected 4 calls, got %d", mock.CallCount())
}
},
)
t.Run(
"should succeed after one retry", func(t *testing.T) {
t.Parallel()
mock := &mockRoundTripper{}
mock.roundTripFunc = func(req *http.Request) (*http.Response, error) {
if mock.CallCount() < 2 {
return &http.Response{
StatusCode: http.StatusServiceUnavailable,
Body: io.NopCloser(nil),
Request: req,
}, nil
}
return &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(nil),
Request: req,
}, nil
}
tripper := testRetryConfig(mock)
req := httptest.NewRequest(http.MethodGet, "/", nil)
resp, err := tripper.RoundTrip(req)
if resp != nil {
defer func() {
if closeErr := resp.Body.Close(); closeErr != nil {
t.Errorf("failed to close response body: %v", closeErr)
}
}()
}
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
if resp.StatusCode != http.StatusOK {
t.Fatalf("expected status code %d, got %d", http.StatusOK, resp.StatusCode)
}
if mock.CallCount() != 2 {
t.Fatalf("expected 2 calls, got %d", mock.CallCount())
}
},
)
t.Run(
"should retry on network error", func(t *testing.T) {
t.Parallel()
mockErr := errors.New("simulated network error")
mock := &mockRoundTripper{
roundTripFunc: func(req *http.Request) (*http.Response, error) {
return nil, mockErr
},
}
tripper := testRetryConfig(mock)
req := httptest.NewRequest(http.MethodGet, "/", nil)
resp, err := tripper.RoundTrip(req)
if resp != nil {
defer func() {
if closeErr := resp.Body.Close(); closeErr != nil {
t.Errorf("failed to close response body: %v", closeErr)
}
}()
}
if !errors.Is(err, mockErr) {
t.Fatalf("expected error to be %v, got %v", mockErr, err)
}
if mock.CallCount() != 4 { // 1 initial + 3 retries
t.Fatalf("expected 4 calls, got %d", mock.CallCount())
}
},
)
t.Run(
"should abort retries if the main context is cancelled", func(t *testing.T) {
t.Parallel()
mock := &mockRoundTripper{
roundTripFunc: func(req *http.Request) (*http.Response, error) {
select {
case <-time.After(100 * time.Millisecond):
return nil, errors.New("this should not be returned")
case <-req.Context().Done():
return nil, req.Context().Err()
}
},
}
tripper := testRetryConfig(mock)
baseCtx := context.Background()
ctx, cancel := context.WithTimeout(baseCtx, 20*time.Millisecond)
defer cancel()
req := httptest.NewRequest(http.MethodGet, "/", nil).WithContext(ctx)
resp, err := tripper.RoundTrip(req)
if resp != nil {
defer func() {
if closeErr := resp.Body.Close(); closeErr != nil {
t.Errorf("failed to close response body: %v", closeErr)
}
}()
}
if !errors.Is(err, context.DeadlineExceeded) {
t.Fatalf("expected error to be context.DeadlineExceeded, got %v", err)
}
if mock.CallCount() != 1 {
t.Fatalf("expected 1 call, got %d", mock.CallCount())
}
},
)
}

View file

@ -6,6 +6,7 @@ import (
"context"
"fmt"
"strings"
"time"
"github.com/hashicorp/terraform-plugin-framework-validators/stringvalidator"
"github.com/hashicorp/terraform-plugin-framework/datasource"
@ -45,6 +46,17 @@ var (
_ provider.Provider = &Provider{}
)
const (
// maxRetries is the maximum number of retries for a failed HTTP request.
maxRetries = 3
// initialDelay is the initial delay before the first retry attempt.
initialDelay = 2 * time.Second
// maxDelay is the maximum delay between retry attempts.
maxDelay = 90 * time.Second
// perTryTimeout is the timeout for each individual HTTP request attempt.
perTryTimeout = 30 * time.Second
)
// Provider is the provider implementation.
type Provider struct {
version string
@ -466,7 +478,7 @@ func (p *Provider) Configure(ctx context.Context, req provider.ConfigureRequest,
providerData.Experiments = experimentValues
}
roundTripper, err := sdkauth.SetupAuth(sdkConfig)
baseRoundTripper, err := sdkauth.SetupAuth(sdkConfig)
if err != nil {
core.LogAndAddError(
ctx,
@ -477,6 +489,14 @@ func (p *Provider) Configure(ctx context.Context, req provider.ConfigureRequest,
return
}
roundTripper := core.NewRetryRoundTripper(
baseRoundTripper,
maxRetries,
initialDelay,
maxDelay,
perTryTimeout,
)
// Make round tripper and custom endpoints available during DataSource and Resource
// type Configure methods.
providerData.RoundTripper = roundTripper