302 lines
8.2 KiB
Go
302 lines
8.2 KiB
Go
package routing
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"log/slog"
|
|
"sync/atomic"
|
|
"time"
|
|
)
|
|
|
|
// Provider is the narrowest contract a routable backend has to satisfy: it
// only has to report its name. Keeping the interface this small lets any
// concrete provider type take part in routing; callers type-assert inside
// their ExecuteFunc to reach provider-specific methods.
type Provider interface {
	// Name identifies the provider. The routing code uses it for log
	// fields, cooldown bookkeeping, ExemptProviders lookups, and
	// ExecuteResult.Provider.
	Name() string
}
|
|
|
|
// ExemptProviders is the set of provider names (as returned by
// Provider.Name) that are never put into cooldown, however often they fail.
// Members are typically pay-per-use backends without rate limits, where
// backing off buys nothing.
//
// The fallback executor checks this set before reporting a failure to the
// CooldownTracker. This is deliberately package-level business policy rather
// than a per-request ExecuteConfig option.
//
// NOTE(review): the map is read without any synchronization; treat it as
// read-only once the package has been initialized.
var ExemptProviders = map[string]bool{
	"laozhang": true,
}
|
|
|
|
// ExecuteConfig configures the fallback executor.
|
|
type ExecuteConfig struct {
|
|
// Strategy for routing requests to providers.
|
|
// Required. Use StrategyFallback for production workloads.
|
|
Strategy Strategy
|
|
|
|
// Cooldown tracker for managing provider cooldowns.
|
|
// Optional. If nil, cooldowns are not tracked.
|
|
Cooldown CooldownTracker
|
|
|
|
// Logger for debug and operational logging.
|
|
// Optional. If nil, uses slog.Default().
|
|
Logger *slog.Logger
|
|
|
|
// RoundRobinIndex is an atomic counter for round-robin distribution.
|
|
// Required only for StrategyRoundRobin. Pass a shared pointer.
|
|
RoundRobinIndex *atomic.Uint64
|
|
}
|
|
|
|
// ExecuteResult reports the outcome of routing one call across providers,
// together with metadata about which provider served it.
type ExecuteResult[T any] struct {
	// Response is the value returned by the provider that succeeded.
	Response T

	// Provider holds the Name() of the provider that succeeded.
	Provider string

	// Latency is how long the successful call took.
	Latency time.Duration

	// AttemptNum is the 1-based number of the attempt that succeeded:
	// 1 means the first attempted provider answered, 2 means one failure
	// preceded the success, and so on.
	AttemptNum int

	// WasTerminus is set when the success came from the terminus (last)
	// provider, which can signal that every earlier provider is struggling.
	WasTerminus bool
}
|
|
|
|
// ExecuteFunc is the function signature for provider execution.
|
|
// T is the response type (e.g., *ImageResponse, *TextResponse).
|
|
type ExecuteFunc[T any] func(ctx context.Context, provider Provider) (T, error)
|
|
|
|
// Execute runs the given function across providers using the configured strategy.
|
|
//
|
|
// MANDATORY: All provider routing in the codebase MUST use this function.
|
|
// Do NOT implement custom fallback loops elsewhere.
|
|
//
|
|
// # Terminus Semantics (StrategyFallback)
|
|
//
|
|
// The LAST provider in the list is the "terminus" and will ALWAYS be attempted
|
|
// regardless of cooldown state. This ensures there's always a fallback of last resort.
|
|
//
|
|
// Example: With providers [Gemini, Grok, LaoZhang]:
|
|
// - Gemini: Checked against cooldown, may be skipped
|
|
// - Grok: Checked against cooldown, may be skipped
|
|
// - LaoZhang (terminus): ALWAYS tried, even if in cooldown
|
|
//
|
|
// # Exempt Providers
|
|
//
|
|
// Providers listed in ExemptProviders (like "laozhang") never enter cooldown,
|
|
// even if they fail. This is appropriate for pay-per-use providers.
|
|
//
|
|
// # Type Safety
|
|
//
|
|
// The generic type parameter T allows compile-time type checking for the response.
|
|
// Usage:
|
|
//
|
|
// result, err := routing.Execute(ctx, providers, config,
|
|
// func(ctx context.Context, p routing.Provider) (*MyResponse, error) {
|
|
// return p.(MyProvider).DoSomething(ctx, req)
|
|
// })
|
|
func Execute[T any](
|
|
ctx context.Context,
|
|
providers []Provider,
|
|
config ExecuteConfig,
|
|
fn ExecuteFunc[T],
|
|
) (*ExecuteResult[T], error) {
|
|
if len(providers) == 0 {
|
|
var zero T
|
|
return &ExecuteResult[T]{Response: zero}, ErrNoProviders
|
|
}
|
|
|
|
logger := config.Logger
|
|
if logger == nil {
|
|
logger = slog.Default()
|
|
}
|
|
|
|
switch config.Strategy {
|
|
case StrategyPrimaryOnly, "":
|
|
return executePrimaryOnly(ctx, providers, fn, logger)
|
|
case StrategyFallback:
|
|
return executeFallback(ctx, providers, config.Cooldown, fn, logger)
|
|
case StrategyRoundRobin:
|
|
return executeRoundRobin(ctx, providers, config.RoundRobinIndex, fn, logger)
|
|
default:
|
|
var zero T
|
|
return &ExecuteResult[T]{Response: zero}, fmt.Errorf("%w: unknown strategy %s", ErrInvalidConfig, config.Strategy)
|
|
}
|
|
}
|
|
|
|
func executePrimaryOnly[T any](
|
|
ctx context.Context,
|
|
providers []Provider,
|
|
fn ExecuteFunc[T],
|
|
logger *slog.Logger,
|
|
) (*ExecuteResult[T], error) {
|
|
provider := providers[0]
|
|
start := time.Now()
|
|
|
|
resp, err := fn(ctx, provider)
|
|
latency := time.Since(start)
|
|
|
|
if err != nil {
|
|
var zero T
|
|
logger.Error("primary provider failed",
|
|
"provider", provider.Name(),
|
|
"error", err,
|
|
"latency", latency,
|
|
)
|
|
return &ExecuteResult[T]{Response: zero}, fmt.Errorf("primary provider %s failed: %w", provider.Name(), err)
|
|
}
|
|
|
|
return &ExecuteResult[T]{
|
|
Response: resp,
|
|
Provider: provider.Name(),
|
|
Latency: latency,
|
|
AttemptNum: 1,
|
|
}, nil
|
|
}
|
|
|
|
func executeFallback[T any](
|
|
ctx context.Context,
|
|
providers []Provider,
|
|
cooldown CooldownTracker,
|
|
fn ExecuteFunc[T],
|
|
logger *slog.Logger,
|
|
) (*ExecuteResult[T], error) {
|
|
var lastErr error
|
|
attemptNum := 0
|
|
terminusIdx := len(providers) - 1
|
|
|
|
for i, provider := range providers {
|
|
isTerminus := i == terminusIdx
|
|
providerName := provider.Name()
|
|
|
|
// Check cooldown UNLESS this is the terminus provider.
|
|
// TERMINUS IS ALWAYS ATTEMPTED as the fallback of last resort.
|
|
if !isTerminus && cooldown != nil && !cooldown.IsAvailable(providerName) {
|
|
remaining := cooldown.CooldownRemaining(providerName)
|
|
logger.Debug("skipping provider in cooldown",
|
|
"provider", providerName,
|
|
"cooldown_remaining", remaining.Round(time.Second),
|
|
)
|
|
continue
|
|
}
|
|
|
|
// Log when terminus is being attempted despite cooldown
|
|
if isTerminus && cooldown != nil && !cooldown.IsAvailable(providerName) {
|
|
remaining := cooldown.CooldownRemaining(providerName)
|
|
logger.Info("attempting terminus provider despite cooldown",
|
|
"provider", providerName,
|
|
"cooldown_remaining", remaining.Round(time.Second),
|
|
"reason", "terminus_always_tried",
|
|
)
|
|
}
|
|
|
|
attemptNum++
|
|
logger.Debug("attempting provider",
|
|
"provider", providerName,
|
|
"attempt", attemptNum,
|
|
"is_terminus", isTerminus,
|
|
"provider_index", i+1,
|
|
"total_providers", len(providers),
|
|
)
|
|
|
|
start := time.Now()
|
|
resp, err := fn(ctx, provider)
|
|
latency := time.Since(start)
|
|
|
|
if err == nil {
|
|
if attemptNum > 1 {
|
|
logger.Info("succeeded on fallback provider",
|
|
"provider", providerName,
|
|
"attempt", attemptNum,
|
|
"latency", latency,
|
|
"is_terminus", isTerminus,
|
|
)
|
|
}
|
|
return &ExecuteResult[T]{
|
|
Response: resp,
|
|
Provider: providerName,
|
|
Latency: latency,
|
|
AttemptNum: attemptNum,
|
|
WasTerminus: isTerminus,
|
|
}, nil
|
|
}
|
|
|
|
// Record failure for cooldown tracking.
|
|
// Exempt providers (like laozhang) never enter cooldown.
|
|
if cooldown != nil && !ExemptProviders[providerName] {
|
|
if cooldown.RecordFailure(providerName, err) {
|
|
remaining := cooldown.CooldownRemaining(providerName)
|
|
logger.Warn("provider entering cooldown",
|
|
"provider", providerName,
|
|
"cooldown", remaining.Round(time.Second),
|
|
"error", err,
|
|
)
|
|
}
|
|
}
|
|
|
|
logger.Warn("provider failed",
|
|
"provider", providerName,
|
|
"error", err,
|
|
"attempt", attemptNum,
|
|
"is_terminus", isTerminus,
|
|
)
|
|
lastErr = err
|
|
|
|
// Check context before trying next provider
|
|
if ctx.Err() != nil {
|
|
var zero T
|
|
return &ExecuteResult[T]{Response: zero}, ctx.Err()
|
|
}
|
|
}
|
|
|
|
if attemptNum == 0 {
|
|
var zero T
|
|
return &ExecuteResult[T]{Response: zero}, fmt.Errorf("%w: all %d providers are in cooldown",
|
|
ErrAllProvidersFailed, len(providers))
|
|
}
|
|
|
|
var zero T
|
|
return &ExecuteResult[T]{Response: zero}, fmt.Errorf("%w: all %d providers failed (last error: %v)",
|
|
ErrAllProvidersFailed, len(providers), lastErr)
|
|
}
|
|
|
|
func executeRoundRobin[T any](
|
|
ctx context.Context,
|
|
providers []Provider,
|
|
index *atomic.Uint64,
|
|
fn ExecuteFunc[T],
|
|
logger *slog.Logger,
|
|
) (*ExecuteResult[T], error) {
|
|
n := uint64(len(providers))
|
|
|
|
// Get next provider index atomically
|
|
var idx uint64
|
|
if index != nil {
|
|
idx = index.Add(1) - 1
|
|
}
|
|
|
|
provider := providers[idx%n]
|
|
start := time.Now()
|
|
|
|
resp, err := fn(ctx, provider)
|
|
latency := time.Since(start)
|
|
|
|
if err != nil {
|
|
var zero T
|
|
logger.Error("round-robin provider failed",
|
|
"provider", provider.Name(),
|
|
"index", idx%n,
|
|
"error", err,
|
|
"latency", latency,
|
|
)
|
|
return &ExecuteResult[T]{Response: zero}, fmt.Errorf("provider %s failed: %w", provider.Name(), err)
|
|
}
|
|
|
|
return &ExecuteResult[T]{
|
|
Response: resp,
|
|
Provider: provider.Name(),
|
|
Latency: latency,
|
|
AttemptNum: 1,
|
|
}, nil
|
|
}
|