203 lines
6.0 KiB
Go
203 lines
6.0 KiB
Go
package textgen
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"log/slog"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"git.threesix.ai/jordan/persona-community-2/pkg/routing"
|
|
)
|
|
|
|
// Manager coordinates multiple providers with configurable routing strategies.
|
|
// Safe for concurrent use.
|
|
//
|
|
// IMPORTANT: All text generation MUST go through this Manager.
|
|
// Do NOT implement custom provider routing elsewhere. The Manager delegates to
|
|
// pkg/routing for consistent terminus semantics and cooldown handling.
|
|
type Manager struct {
|
|
providers []TextGenerator
|
|
strategy Strategy
|
|
logger *slog.Logger
|
|
onMetrics MetricsHook
|
|
cooldown routing.CooldownTracker
|
|
|
|
// Round-robin state
|
|
index atomic.Uint64
|
|
}
|
|
|
|
// MetricsHook is the observability callback the Manager invokes once a
// provider attempt has finished.
//
// Arguments:
//   - provider:  name of the provider that handled the attempt
//   - operation: the Manager operation, "GenerateText" or "GenerateStream"
//   - latency:   wall-clock time the attempt took
//   - err:       the attempt's error, or nil when it succeeded
type MetricsHook func(provider string, operation string, latency time.Duration, err error)
|
|
|
|
// ManagerConfig configures the provider manager.
|
|
type ManagerConfig struct {
|
|
Providers []TextGenerator // Text generation providers (order matters for fallback)
|
|
Strategy Strategy // Routing strategy (default: StrategyPrimaryOnly)
|
|
Logger *slog.Logger // Optional: defaults to slog.Default()
|
|
OnMetrics MetricsHook // Optional: callback for metrics collection
|
|
CircuitBreaker *CircuitBreaker // Optional: in-memory cooldown tracker
|
|
CooldownPeriod time.Duration // Optional: cooldown for rate-limited providers (default: 1 hour)
|
|
}
|
|
|
|
// NewManager creates a new provider manager.
|
|
//
|
|
// IMPORTANT: The Manager uses pkg/routing for all provider routing.
|
|
// The LAST provider in the list is the "terminus" and will ALWAYS be attempted
|
|
// regardless of cooldown state when using StrategyFallback.
|
|
func NewManager(config ManagerConfig) (*Manager, error) {
|
|
if len(config.Providers) == 0 {
|
|
return nil, fmt.Errorf("%w: at least one provider required", ErrInvalidConfig)
|
|
}
|
|
|
|
if config.Strategy == "" {
|
|
config.Strategy = StrategyPrimaryOnly
|
|
}
|
|
|
|
if !config.Strategy.Valid() {
|
|
return nil, fmt.Errorf("%w: unknown strategy %s", ErrInvalidConfig, config.Strategy)
|
|
}
|
|
|
|
if config.Logger == nil {
|
|
config.Logger = slog.Default()
|
|
}
|
|
|
|
// Build cooldown tracker using routing.BuildCooldownTracker
|
|
cooldown := routing.BuildCooldownTracker(routing.CooldownConfig{
|
|
CircuitBreaker: config.CircuitBreaker,
|
|
CooldownPeriod: config.CooldownPeriod,
|
|
})
|
|
|
|
return &Manager{
|
|
providers: config.Providers,
|
|
strategy: config.Strategy,
|
|
logger: config.Logger,
|
|
onMetrics: config.OnMetrics,
|
|
cooldown: cooldown,
|
|
}, nil
|
|
}
|
|
|
|
// recordMetrics calls the metrics hook if configured.
|
|
func (m *Manager) recordMetrics(provider, operation string, latency time.Duration, err error) {
|
|
if m.onMetrics != nil {
|
|
m.onMetrics(provider, operation, latency, err)
|
|
}
|
|
}
|
|
|
|
// GenerateText generates text using the configured strategy.
|
|
//
|
|
// IMPORTANT: This method delegates to pkg/routing for consistent terminus semantics.
|
|
// The LAST provider is ALWAYS attempted regardless of cooldown when using StrategyFallback.
|
|
func (m *Manager) GenerateText(ctx context.Context, req TextRequest) (*TextResponse, error) {
|
|
if len(m.providers) == 0 {
|
|
return nil, ErrNoProvidersConfigured
|
|
}
|
|
|
|
if req.Prompt == "" && len(req.Messages) == 0 {
|
|
return nil, fmt.Errorf("%w: prompt or messages required", ErrInvalidRequest)
|
|
}
|
|
|
|
// Apply request timeout if specified
|
|
if req.Timeout > 0 {
|
|
var cancel context.CancelFunc
|
|
ctx, cancel = context.WithTimeout(ctx, req.Timeout)
|
|
defer cancel()
|
|
}
|
|
|
|
// Convert to routing.Provider slice
|
|
providers := make([]routing.Provider, len(m.providers))
|
|
for i, p := range m.providers {
|
|
providers[i] = p
|
|
}
|
|
|
|
// Use routing.Execute for consistent terminus semantics
|
|
result, err := routing.Execute(ctx, providers, routing.ExecuteConfig{
|
|
Strategy: m.strategy,
|
|
Cooldown: m.cooldown,
|
|
Logger: m.logger,
|
|
RoundRobinIndex: &m.index,
|
|
}, func(ctx context.Context, p routing.Provider) (*TextResponse, error) {
|
|
provider := p.(TextGenerator)
|
|
start := time.Now()
|
|
|
|
resp, err := provider.GenerateText(ctx, req)
|
|
latency := time.Since(start)
|
|
|
|
// Record metrics for observability
|
|
m.recordMetrics(provider.Name(), "GenerateText", latency, err)
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Attach provider metadata to response
|
|
resp.Provider = provider.Name()
|
|
resp.Latency = latency
|
|
return resp, nil
|
|
})
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return result.Response, nil
|
|
}
|
|
|
|
// GenerateStream generates text with streaming delivery via onChunk callback.
|
|
// Routes to the first provider implementing TextStreamer.
|
|
// Falls back to non-streaming GenerateText + single chunk if no streaming provider.
|
|
func (m *Manager) GenerateStream(ctx context.Context, req TextRequest, onChunk func(StreamChunk)) error {
|
|
if len(m.providers) == 0 {
|
|
return ErrNoProvidersConfigured
|
|
}
|
|
|
|
if req.Prompt == "" && len(req.Messages) == 0 {
|
|
return fmt.Errorf("%w: prompt or messages required", ErrInvalidRequest)
|
|
}
|
|
|
|
if req.Timeout > 0 {
|
|
var cancel context.CancelFunc
|
|
ctx, cancel = context.WithTimeout(ctx, req.Timeout)
|
|
defer cancel()
|
|
}
|
|
|
|
// Try to find a streaming provider
|
|
for _, p := range m.providers {
|
|
if streamer, ok := p.(TextStreamer); ok {
|
|
start := time.Now()
|
|
err := streamer.GenerateStream(ctx, req, onChunk)
|
|
latency := time.Since(start)
|
|
m.recordMetrics(p.Name(), "GenerateStream", latency, err)
|
|
return err
|
|
}
|
|
}
|
|
|
|
// Fallback: use non-streaming GenerateText and deliver as single chunk
|
|
resp, err := m.GenerateText(ctx, req)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
onChunk(StreamChunk{Text: resp.Text, Done: true, Provider: resp.Provider})
|
|
return nil
|
|
}
|
|
|
|
// Health checks all providers and returns the first error encountered.
|
|
func (m *Manager) Health(ctx context.Context) error {
|
|
for _, provider := range m.providers {
|
|
if err := provider.Health(ctx); err != nil {
|
|
return fmt.Errorf("provider %s unhealthy: %w", provider.Name(), err)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Name returns "textgen-manager" for logging purposes.
|
|
func (m *Manager) Name() string {
|
|
return "textgen-manager"
|
|
}
|