package routing import ( "context" "fmt" "log/slog" "sync/atomic" "time" ) // Provider is the minimal interface that all providers must implement. // This is intentionally minimal to work with any provider type. type Provider interface { Name() string } // ExemptProviders lists providers that never enter cooldown. // These are typically pay-per-use providers with no rate limits. // // The executor checks this map and skips cooldown recording for exempt providers. // This is a package-level decision (business logic) rather than per-request config. var ExemptProviders = map[string]bool{ "laozhang": true, } // ExecuteConfig configures the fallback executor. type ExecuteConfig struct { // Strategy for routing requests to providers. // Required. Use StrategyFallback for production workloads. Strategy Strategy // Cooldown tracker for managing provider cooldowns. // Optional. If nil, cooldowns are not tracked. Cooldown CooldownTracker // Logger for debug and operational logging. // Optional. If nil, uses slog.Default(). Logger *slog.Logger // RoundRobinIndex is an atomic counter for round-robin distribution. // Required only for StrategyRoundRobin. Pass a shared pointer. RoundRobinIndex *atomic.Uint64 } // ExecuteResult contains the result of executing across providers. type ExecuteResult[T any] struct { // Response is the successful response from the provider. Response T // Provider is the name of the provider that succeeded. Provider string // Latency is the time taken for the successful call. Latency time.Duration // AttemptNum is which attempt succeeded (1-based). // 1 = first provider succeeded, 2 = first failed + second succeeded, etc. AttemptNum int // WasTerminus is true if success came from the terminus (last) provider. // This can indicate that all earlier providers are having issues. WasTerminus bool } // ExecuteFunc is the function signature for provider execution. // T is the response type (e.g., *ImageResponse, *TextResponse). type ExecuteFunc[T any] func(ctx context.Context, provider Provider) (T, error) // Execute runs the given function across providers using the configured strategy. // // MANDATORY: All provider routing in the codebase MUST use this function. // Do NOT implement custom fallback loops elsewhere. // // # Terminus Semantics (StrategyFallback) // // The LAST provider in the list is the "terminus" and will ALWAYS be attempted // regardless of cooldown state. This ensures there's always a fallback of last resort. // // Example: With providers [Gemini, Grok, LaoZhang]: // - Gemini: Checked against cooldown, may be skipped // - Grok: Checked against cooldown, may be skipped // - LaoZhang (terminus): ALWAYS tried, even if in cooldown // // # Exempt Providers // // Providers listed in ExemptProviders (like "laozhang") never enter cooldown, // even if they fail. This is appropriate for pay-per-use providers. // // # Type Safety // // The generic type parameter T allows compile-time type checking for the response. // Usage: // // result, err := routing.Execute(ctx, providers, config, // func(ctx context.Context, p routing.Provider) (*MyResponse, error) { // return p.(MyProvider).DoSomething(ctx, req) // }) func Execute[T any]( ctx context.Context, providers []Provider, config ExecuteConfig, fn ExecuteFunc[T], ) (*ExecuteResult[T], error) { if len(providers) == 0 { var zero T return &ExecuteResult[T]{Response: zero}, ErrNoProviders } logger := config.Logger if logger == nil { logger = slog.Default() } switch config.Strategy { case StrategyPrimaryOnly, "": return executePrimaryOnly(ctx, providers, fn, logger) case StrategyFallback: return executeFallback(ctx, providers, config.Cooldown, fn, logger) case StrategyRoundRobin: return executeRoundRobin(ctx, providers, config.RoundRobinIndex, fn, logger) default: var zero T return &ExecuteResult[T]{Response: zero}, fmt.Errorf("%w: unknown strategy %s", ErrInvalidConfig, config.Strategy) } } func executePrimaryOnly[T any]( ctx context.Context, providers []Provider, fn ExecuteFunc[T], logger *slog.Logger, ) (*ExecuteResult[T], error) { provider := providers[0] start := time.Now() resp, err := fn(ctx, provider) latency := time.Since(start) if err != nil { var zero T logger.Error("primary provider failed", "provider", provider.Name(), "error", err, "latency", latency, ) return &ExecuteResult[T]{Response: zero}, fmt.Errorf("primary provider %s failed: %w", provider.Name(), err) } return &ExecuteResult[T]{ Response: resp, Provider: provider.Name(), Latency: latency, AttemptNum: 1, }, nil } func executeFallback[T any]( ctx context.Context, providers []Provider, cooldown CooldownTracker, fn ExecuteFunc[T], logger *slog.Logger, ) (*ExecuteResult[T], error) { var lastErr error attemptNum := 0 terminusIdx := len(providers) - 1 for i, provider := range providers { isTerminus := i == terminusIdx providerName := provider.Name() // Check cooldown UNLESS this is the terminus provider. // TERMINUS IS ALWAYS ATTEMPTED as the fallback of last resort. if !isTerminus && cooldown != nil && !cooldown.IsAvailable(providerName) { remaining := cooldown.CooldownRemaining(providerName) logger.Debug("skipping provider in cooldown", "provider", providerName, "cooldown_remaining", remaining.Round(time.Second), ) continue } // Log when terminus is being attempted despite cooldown if isTerminus && cooldown != nil && !cooldown.IsAvailable(providerName) { remaining := cooldown.CooldownRemaining(providerName) logger.Info("attempting terminus provider despite cooldown", "provider", providerName, "cooldown_remaining", remaining.Round(time.Second), "reason", "terminus_always_tried", ) } attemptNum++ logger.Debug("attempting provider", "provider", providerName, "attempt", attemptNum, "is_terminus", isTerminus, "provider_index", i+1, "total_providers", len(providers), ) start := time.Now() resp, err := fn(ctx, provider) latency := time.Since(start) if err == nil { if attemptNum > 1 { logger.Info("succeeded on fallback provider", "provider", providerName, "attempt", attemptNum, "latency", latency, "is_terminus", isTerminus, ) } return &ExecuteResult[T]{ Response: resp, Provider: providerName, Latency: latency, AttemptNum: attemptNum, WasTerminus: isTerminus, }, nil } // Record failure for cooldown tracking. // Exempt providers (like laozhang) never enter cooldown. if cooldown != nil && !ExemptProviders[providerName] { if cooldown.RecordFailure(providerName, err) { remaining := cooldown.CooldownRemaining(providerName) logger.Warn("provider entering cooldown", "provider", providerName, "cooldown", remaining.Round(time.Second), "error", err, ) } } logger.Warn("provider failed", "provider", providerName, "error", err, "attempt", attemptNum, "is_terminus", isTerminus, ) lastErr = err // Check context before trying next provider if ctx.Err() != nil { var zero T return &ExecuteResult[T]{Response: zero}, ctx.Err() } } if attemptNum == 0 { var zero T return &ExecuteResult[T]{Response: zero}, fmt.Errorf("%w: all %d providers are in cooldown", ErrAllProvidersFailed, len(providers)) } var zero T return &ExecuteResult[T]{Response: zero}, fmt.Errorf("%w: all %d providers failed (last error: %v)", ErrAllProvidersFailed, len(providers), lastErr) } func executeRoundRobin[T any]( ctx context.Context, providers []Provider, index *atomic.Uint64, fn ExecuteFunc[T], logger *slog.Logger, ) (*ExecuteResult[T], error) { n := uint64(len(providers)) // Get next provider index atomically var idx uint64 if index != nil { idx = index.Add(1) - 1 } provider := providers[idx%n] start := time.Now() resp, err := fn(ctx, provider) latency := time.Since(start) if err != nil { var zero T logger.Error("round-robin provider failed", "provider", provider.Name(), "index", idx%n, "error", err, "latency", latency, ) return &ExecuteResult[T]{Response: zero}, fmt.Errorf("provider %s failed: %w", provider.Name(), err) } return &ExecuteResult[T]{ Response: resp, Provider: provider.Name(), Latency: latency, AttemptNum: 1, }, nil }