241 lines
7.1 KiB
Go
241 lines
7.1 KiB
Go
// Package mediagen provides image and video generation with provider routing.
|
|
//
|
|
// IMPORTANT: This package delegates to pkg/routing for all fallback execution.
|
|
// Do NOT implement custom fallback loops or cooldown logic here.
|
|
// Use pkg/routing directly for new code paths.
|
|
package mediagen
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"log/slog"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"git.threesix.ai/jordan/persona-community-2/pkg/routing"
|
|
)
|
|
|
|
// Manager coordinates multiple providers with configurable routing strategies.
// Safe for concurrent use.
//
// IMPORTANT: Internally delegates to pkg/routing for fallback execution.
// The last provider in the chain (terminus) is ALWAYS tried regardless of
// cooldown state. See pkg/routing.Execute for details.
//
// Construct a Manager with NewManager, which validates the configuration and
// applies defaults. Because it embeds atomic counters, a Manager must not be
// copied after first use; share it by pointer.
type Manager struct {
	// Providers in configuration order. Order determines fallback priority,
	// and the last entry is the terminus. Within this file they are set by
	// NewManager and only read afterwards.
	imageProviders []ImageGenerator
	videoProviders []VideoGenerator

	strategy  routing.Strategy        // passed to routing.Execute on every request
	logger    *slog.Logger            // passed to routing.Execute on every request
	onMetrics MetricsHook             // optional; nil disables per-attempt metrics
	cooldown  routing.CooldownTracker // shared by image and video requests

	// Round-robin state: independent cursors for image and video rotation,
	// handed to routing.Execute by pointer. They are atomic so concurrent
	// requests can advance them without a lock.
	imageIndex atomic.Uint64
	videoIndex atomic.Uint64
}
|
|
|
|
// MetricsHook is called after each generation attempt for observability.
//
// It runs once per provider invocation made by the routing layer, so a
// request that falls back across several providers reports several attempts.
// The hook is called synchronously on the goroutine serving the request, so a
// slow hook delays that request. Because Manager may be used concurrently, the
// hook may be invoked from multiple goroutines at once and must be safe for
// that.
//
//   - provider: name of the provider used (its Name() value)
//   - operation: "GenerateImage" or "GenerateVideo"
//   - latency: time taken for the operation
//   - err: error if the operation failed, nil on success
type MetricsHook func(provider, operation string, latency time.Duration, err error)
|
|
|
|
// ManagerConfig configures the provider manager.
|
|
type ManagerConfig struct {
|
|
ImageProviders []ImageGenerator // Image generation providers (order matters for fallback, last is terminus)
|
|
VideoProviders []VideoGenerator // Video generation providers (order matters for fallback, last is terminus)
|
|
Strategy Strategy // Routing strategy (default: StrategyPrimaryOnly)
|
|
Logger *slog.Logger // Optional: defaults to slog.Default()
|
|
OnMetrics MetricsHook // Optional: callback for metrics collection
|
|
|
|
// Cooldown configuration - use ONE of these:
|
|
// - CircuitBreaker: in-memory only (for long-running services)
|
|
// - Both: combined cooldown tracking
|
|
// If neither is provided, a default CircuitBreaker is created.
|
|
CircuitBreaker *CircuitBreaker // Optional: shared circuit breaker (in-memory)
|
|
CooldownPeriod time.Duration // Optional: cooldown for rate-limited providers (default: 1 hour)
|
|
}
|
|
|
|
// NewManager creates a new provider manager.
|
|
//
|
|
// IMPORTANT: The last provider in ImageProviders/VideoProviders is the "terminus"
|
|
// and will ALWAYS be tried regardless of cooldown. See pkg/routing for details.
|
|
func NewManager(config ManagerConfig) (*Manager, error) {
|
|
if len(config.ImageProviders) == 0 && len(config.VideoProviders) == 0 {
|
|
return nil, fmt.Errorf("%w: at least one provider required", ErrInvalidConfig)
|
|
}
|
|
|
|
strategy := config.Strategy
|
|
if strategy == "" {
|
|
strategy = StrategyPrimaryOnly
|
|
}
|
|
|
|
if !strategy.Valid() {
|
|
return nil, fmt.Errorf("%w: unknown strategy %s", ErrInvalidConfig, strategy)
|
|
}
|
|
|
|
logger := config.Logger
|
|
if logger == nil {
|
|
logger = slog.Default()
|
|
}
|
|
|
|
// Build cooldown tracker using routing.BuildCooldownTracker
|
|
cooldown := routing.BuildCooldownTracker(routing.CooldownConfig{
|
|
CircuitBreaker: config.CircuitBreaker,
|
|
CooldownPeriod: config.CooldownPeriod,
|
|
})
|
|
|
|
return &Manager{
|
|
imageProviders: config.ImageProviders,
|
|
videoProviders: config.VideoProviders,
|
|
strategy: strategy,
|
|
logger: logger,
|
|
onMetrics: config.OnMetrics,
|
|
cooldown: cooldown,
|
|
}, nil
|
|
}
|
|
|
|
// GenerateImage generates images using the configured strategy.
|
|
//
|
|
// When using StrategyFallback, the last provider is the terminus and will
|
|
// ALWAYS be tried regardless of cooldown state.
|
|
func (m *Manager) GenerateImage(ctx context.Context, req ImageRequest) (*ImageResponse, error) {
|
|
if len(m.imageProviders) == 0 {
|
|
return nil, ErrNoProvidersConfigured
|
|
}
|
|
|
|
if req.Prompt == "" {
|
|
return nil, fmt.Errorf("%w: prompt is required", ErrInvalidRequest)
|
|
}
|
|
|
|
// Apply request timeout if specified
|
|
if req.Timeout > 0 {
|
|
var cancel context.CancelFunc
|
|
ctx, cancel = context.WithTimeout(ctx, req.Timeout)
|
|
defer cancel()
|
|
}
|
|
|
|
// Convert to routing.Provider slice
|
|
providers := make([]routing.Provider, len(m.imageProviders))
|
|
for i, p := range m.imageProviders {
|
|
providers[i] = p
|
|
}
|
|
|
|
// Delegate to pkg/routing for execution
|
|
result, err := routing.Execute(ctx, providers, routing.ExecuteConfig{
|
|
Strategy: m.strategy,
|
|
Cooldown: m.cooldown,
|
|
Logger: m.logger,
|
|
RoundRobinIndex: &m.imageIndex,
|
|
}, func(ctx context.Context, p routing.Provider) (*ImageResponse, error) {
|
|
provider := p.(ImageGenerator)
|
|
start := time.Now()
|
|
|
|
resp, genErr := provider.GenerateImage(ctx, req)
|
|
latency := time.Since(start)
|
|
|
|
// Call metrics hook
|
|
if m.onMetrics != nil {
|
|
m.onMetrics(provider.Name(), "GenerateImage", latency, genErr)
|
|
}
|
|
|
|
if genErr != nil {
|
|
return nil, genErr
|
|
}
|
|
|
|
resp.Latency = latency
|
|
return resp, nil
|
|
})
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
result.Response.Provider = result.Provider
|
|
return result.Response, nil
|
|
}
|
|
|
|
// GenerateVideo generates videos using the configured strategy.
|
|
//
|
|
// When using StrategyFallback, the last provider is the terminus and will
|
|
// ALWAYS be tried regardless of cooldown state.
|
|
func (m *Manager) GenerateVideo(ctx context.Context, req VideoRequest) (*VideoResponse, error) {
|
|
if len(m.videoProviders) == 0 {
|
|
return nil, ErrNoProvidersConfigured
|
|
}
|
|
|
|
if req.Prompt == "" {
|
|
return nil, fmt.Errorf("%w: prompt is required", ErrInvalidRequest)
|
|
}
|
|
|
|
// Apply request timeout if specified
|
|
if req.Timeout > 0 {
|
|
var cancel context.CancelFunc
|
|
ctx, cancel = context.WithTimeout(ctx, req.Timeout)
|
|
defer cancel()
|
|
}
|
|
|
|
// Convert to routing.Provider slice
|
|
providers := make([]routing.Provider, len(m.videoProviders))
|
|
for i, p := range m.videoProviders {
|
|
providers[i] = p
|
|
}
|
|
|
|
// Delegate to pkg/routing for execution
|
|
result, err := routing.Execute(ctx, providers, routing.ExecuteConfig{
|
|
Strategy: m.strategy,
|
|
Cooldown: m.cooldown,
|
|
Logger: m.logger,
|
|
RoundRobinIndex: &m.videoIndex,
|
|
}, func(ctx context.Context, p routing.Provider) (*VideoResponse, error) {
|
|
provider := p.(VideoGenerator)
|
|
start := time.Now()
|
|
|
|
resp, genErr := provider.GenerateVideo(ctx, req)
|
|
latency := time.Since(start)
|
|
|
|
// Call metrics hook
|
|
if m.onMetrics != nil {
|
|
m.onMetrics(provider.Name(), "GenerateVideo", latency, genErr)
|
|
}
|
|
|
|
if genErr != nil {
|
|
return nil, genErr
|
|
}
|
|
|
|
resp.Latency = latency
|
|
return resp, nil
|
|
})
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
result.Response.Provider = result.Provider
|
|
return result.Response, nil
|
|
}
|
|
|
|
// Cooldown returns the manager's cooldown tracker.
|
|
// Useful for inspection or manual cooldown management.
|
|
func (m *Manager) Cooldown() routing.CooldownTracker {
|
|
return m.cooldown
|
|
}
|
|
|
|
// Health checks health of all configured providers.
|
|
func (m *Manager) Health(ctx context.Context) error {
|
|
for _, provider := range m.imageProviders {
|
|
if err := provider.Health(ctx); err != nil {
|
|
return fmt.Errorf("image provider %s unhealthy: %w", provider.Name(), err)
|
|
}
|
|
}
|
|
|
|
for _, provider := range m.videoProviders {
|
|
if err := provider.Health(ctx); err != nil {
|
|
return fmt.Errorf("video provider %s unhealthy: %w", provider.Name(), err)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|