persona-community-2/workers/media-worker/cmd/worker/main.go
2026-02-23 10:54:06 +00:00

270 lines
9.0 KiB
Go

// Package main is the entry point for the media-worker, a queue worker that
// runs image, video, text, chat-response and persona generation jobs.
package main
import (
"context"
"os"
"os/signal"
"syscall"
"time"
"github.com/redis/go-redis/v9"
"git.threesix.ai/jordan/persona-community-2/pkg/database"
"git.threesix.ai/jordan/persona-community-2/pkg/gemini"
"git.threesix.ai/jordan/persona-community-2/pkg/laozhang"
"git.threesix.ai/jordan/persona-community-2/pkg/logging"
"git.threesix.ai/jordan/persona-community-2/pkg/mediagen"
mediagenAdapters "git.threesix.ai/jordan/persona-community-2/pkg/mediagen/adapters"
"git.threesix.ai/jordan/persona-community-2/pkg/personagen"
"git.threesix.ai/jordan/persona-community-2/pkg/queue"
"git.threesix.ai/jordan/persona-community-2/pkg/realtime"
"git.threesix.ai/jordan/persona-community-2/pkg/storage"
"git.threesix.ai/jordan/persona-community-2/pkg/textgen"
textgenAdapters "git.threesix.ai/jordan/persona-community-2/pkg/textgen/adapters"
"git.threesix.ai/jordan/persona-community-2/workers/media-worker/internal/config"
"git.threesix.ai/jordan/persona-community-2/workers/media-worker/internal/handlers"
)
// main runs the worker and translates its outcome into the process exit code.
//
// All work happens in run so that deferred cleanup (context cancellation,
// database pool, Redis client, storage client) executes before the process
// exits: os.Exit terminates immediately without running deferred functions.
func main() {
	if code := run(); code != 0 {
		os.Exit(code)
	}
}

// run wires up the media-worker and blocks until SIGINT or SIGTERM is received.
//
// It returns 0 after a signal-triggered shutdown and 1 when a required
// dependency (configuration, database, queue migrations, Redis) cannot be
// initialized. AI providers and the storage backend are optional: failures
// there are logged as warnings, the job handlers that need a missing manager
// are not registered, and a nil store is passed to the media handlers.
func run() int {
	// Initialize logger first (with defaults) so we can log config errors.
	logger := logging.New(logging.Config{
		Level:  logging.LevelInfo,
		Format: logging.FormatJSON,
	}).WithService("media-worker")

	cfg, err := config.Load()
	if err != nil {
		logger.Error("failed to load config", "error", err)
		return 1
	}

	// Reconfigure logger with loaded config.
	logger = logging.New(logging.Config{
		Level:       logging.ParseLevel(cfg.Logging.Level),
		Format:      logging.ParseFormat(cfg.Logging.Format),
		Environment: cfg.AppConfig.Environment,
		AddSource:   cfg.AppConfig.IsDevelopment(),
	}).WithService("media-worker")
	logger.Info("starting media-worker worker")

	// Root context: cancelling it is the shutdown signal for the worker goroutines.
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// Connect to database.
	pool, err := database.Connect(ctx, cfg.Database.URL, database.Options{
		MaxOpenConns:    cfg.Database.MaxOpenConns,
		MaxIdleConns:    cfg.Database.MaxIdleConns,
		ConnMaxLifetime: cfg.Database.ConnMaxLifetime,
	})
	if err != nil {
		logger.Error("failed to connect to database", "error", err)
		return 1
	}
	defer pool.Close()
	// NOTE(review): confirm pool.URL is redacted; a raw connection URL may carry credentials.
	logger.Info("connected to database", "url", pool.URL)

	// Run queue migrations (idempotent — safe for both service and worker).
	if err := queue.RunMigrations(ctx, pool); err != nil {
		logger.Error("failed to run queue migrations", "error", err)
		return 1
	}
	logger.Info("queue migrations complete")

	jobQueue := queue.NewQueue(pool.DB, logger)

	// Initialize Redis for SSE event publishing.
	if cfg.RedisURL == "" {
		logger.Error("REDIS_URL is required for worker to publish SSE events")
		return 1
	}
	redisOpts, err := redis.ParseURL(cfg.RedisURL)
	if err != nil {
		logger.Error("failed to parse REDIS_URL", "error", err)
		return 1
	}
	redisClient := redis.NewClient(redisOpts)
	// Best-effort close on every exit path; the error is ignored because the
	// process is terminating and there is nothing left to do with it.
	defer func() { _ = redisClient.Close() }()
	if err := redisClient.Ping(ctx).Err(); err != nil {
		logger.Error("failed to connect to Redis", "error", err)
		return 1
	}
	logger.Info("connected to Redis")
	ssePub := realtime.NewSSEPublisher(redisClient, logger.Logger)

	// Initialize AI providers. Each is optional; nil means "not available".
	laozhangClient := newLaoZhangClient(logger)
	geminiClient := newGeminiClient(ctx, logger)
	mediagenManager := newMediagenManager(laozhangClient, geminiClient, logger)
	textgenManager := newTextgenManager(ctx, laozhangClient, logger)

	// Initialize handler.
	handler := handlers.New(logger, jobQueue, handlers.Config{
		PollInterval:    cfg.Worker.PollInterval,
		StaleJobTimeout: cfg.Worker.StaleJobTimeout,
		JobTimeout:      cfg.Worker.JobTimeout,
	})

	// Initialize storage backend for persisting generated media.
	// GCS_BUCKET is injected by the platform; if absent, store is nil (media not persisted).
	var mediaStore storage.Store
	if bucket := os.Getenv("GCS_BUCKET"); bucket != "" {
		gcsStore, err := storage.NewGCSStore(ctx, bucket, os.Getenv("GCS_SERVICE_ACCOUNT_JSON"), logger.Logger)
		if err != nil {
			logger.Warn("failed to create GCS store, generated media will not be persisted", "error", err)
		} else {
			defer func() { _ = gcsStore.Close() }()
			mediaStore = gcsStore
			logger.Info("storage initialized (GCS)", "bucket", bucket)
		}
	}

	// Register job handlers only for the capabilities that were initialized.
	if mediagenManager != nil {
		handler.RegisterHandler("generate_image", handlers.ImageHandler(mediagenManager, mediaStore, ssePub, logger))
		handler.RegisterHandler("generate_video", handlers.VideoHandler(mediagenManager, mediaStore, ssePub, logger))
	}
	if textgenManager != nil {
		handler.RegisterHandler("generate_text", handlers.TextHandler(textgenManager, ssePub, logger))
		handler.RegisterHandler("ai_chat_response", handlers.ChatResponseHandler(textgenManager, ssePub, logger))
	}
	// Persona generation requires both textgen (5-stage LLM pipeline) and mediagen (20 images + 4 videos).
	if textgenManager != nil && mediagenManager != nil {
		handler.RegisterHandler("persona_generate", personagen.QueueHandler(textgenManager, mediagenManager, mediaStore, ssePub, logger.Logger))
	}

	// Subscribe to signals before starting work so an early signal is not missed.
	sigCh := make(chan os.Signal, 1)
	signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)

	// Both goroutines are stopped by cancelling ctx.
	go handler.Run(ctx)
	go runStaleJobRecovery(ctx, jobQueue, cfg.Worker.StaleJobTimeout, logger)

	// Wait for shutdown signal.
	sig := <-sigCh
	logger.Info("received shutdown signal", "signal", sig.String())

	logger.Info("initiating graceful shutdown")
	cancel()

	// Give in-flight jobs time to complete (grace period).
	// NOTE(review): this is a fixed wait, not a join on handler.Run; confirm
	// whether Run blocks until in-flight jobs finish before replacing it.
	const shutdownGracePeriod = 5 * time.Second
	time.Sleep(shutdownGracePeriod)

	logger.Info("media-worker worker stopped")
	return 0
}

// newLaoZhangClient builds the LaoZhang client (primary provider — pay-per-use,
// OpenAI-compatible) from LAOZHANG_API_KEY. It returns nil when the key is
// unset or client construction fails; a failure is logged as a warning.
func newLaoZhangClient(logger *logging.Logger) *laozhang.Client {
	apiKey := os.Getenv("LAOZHANG_API_KEY")
	if apiKey == "" {
		return nil
	}
	client, err := laozhang.NewClient(laozhang.Config{
		APIKey:       apiKey,
		VideoTimeout: 5 * time.Minute,
		Logger:       logger.Logger,
	})
	if err != nil {
		logger.Warn("failed to create LaoZhang client", "error", err)
		// Never hand back a value that was returned alongside an error.
		return nil
	}
	logger.Info("LaoZhang client initialized")
	return client
}

// newGeminiClient builds the Gemini client used for media generation from
// GEMINI_API_KEY. It returns nil when the key is unset or client construction
// fails; a failure is logged as a warning.
func newGeminiClient(ctx context.Context, logger *logging.Logger) *gemini.Client {
	apiKey := os.Getenv("GEMINI_API_KEY")
	if apiKey == "" {
		return nil
	}
	client, err := gemini.NewClient(ctx, gemini.Config{
		APIKey: apiKey,
		Logger: logger.Logger,
	})
	if err != nil {
		logger.Warn("failed to create Gemini client", "error", err)
		return nil
	}
	logger.Info("Gemini client initialized")
	return client
}

// newMediagenManager builds the image + video generation manager from whichever
// of the two clients is non-nil. It returns nil when neither client is
// available or the manager cannot be created.
func newMediagenManager(laozhangClient *laozhang.Client, geminiClient *gemini.Client, logger *logging.Logger) *mediagen.Manager {
	var laozhangProvider *mediagenAdapters.LaoZhangProvider
	var geminiProvider *mediagenAdapters.GeminiProvider
	if laozhangClient != nil {
		laozhangProvider = mediagenAdapters.NewLaoZhangProvider(laozhangClient)
	}
	if geminiClient != nil {
		geminiProvider = mediagenAdapters.NewGeminiProvider(geminiClient)
	}
	if laozhangProvider == nil && geminiProvider == nil {
		return nil
	}

	// NOTE(review): if the ProviderSet fields are interface-typed, a nil
	// provider pointer stored here becomes a non-nil interface — confirm that
	// ProductionConfig tolerates that.
	cfg := mediagen.ProductionConfig(mediagen.ProviderSet{
		LaoZhang: laozhangProvider,
		Gemini:   geminiProvider,
	}, mediagen.WithLogger(logger.Logger))
	// Video providers are appended explicitly, LaoZhang first.
	if laozhangProvider != nil {
		cfg.VideoProviders = append(cfg.VideoProviders, laozhangProvider)
	}
	if geminiProvider != nil {
		cfg.VideoProviders = append(cfg.VideoProviders, geminiProvider)
	}

	manager, err := mediagen.NewManager(cfg)
	if err != nil {
		logger.Warn("failed to create mediagen manager", "error", err)
		return nil
	}
	logger.Info("mediagen manager initialized (image + video)")
	return manager
}

// newTextgenManager builds the text generation manager (text + streaming) from
// the LaoZhang client, when non-nil, and a Gemini text provider created from
// GEMINI_API_KEY, when set. It returns nil when no provider is available or
// the manager cannot be created.
func newTextgenManager(ctx context.Context, laozhangClient *laozhang.Client, logger *logging.Logger) *textgen.Manager {
	var providers []textgen.TextGenerator
	if laozhangClient != nil {
		providers = append(providers, textgenAdapters.NewLaoZhangTextProvider(laozhangClient, ""))
	}
	if apiKey := os.Getenv("GEMINI_API_KEY"); apiKey != "" {
		geminiProvider, err := textgenAdapters.NewGeminiTextProvider(ctx, textgenAdapters.GeminiTextConfig{
			APIKey: apiKey,
		})
		if err != nil {
			logger.Warn("failed to create gemini text provider", "error", err)
		} else {
			providers = append(providers, geminiProvider)
		}
	}
	if len(providers) == 0 {
		return nil
	}

	cfg := textgen.ProductionConfig(textgen.ProviderSet{}, textgen.WithLogger(logger.Logger))
	cfg.Providers = providers
	manager, err := textgen.NewManager(cfg)
	if err != nil {
		logger.Warn("failed to create textgen manager", "error", err)
		return nil
	}
	logger.Info("textgen manager initialized")
	return manager
}
// runStaleJobRecovery runs a once-a-minute sweep that asks the queue to requeue
// jobs considered stale under the given timeout. It blocks until ctx is
// cancelled, so callers are expected to start it in its own goroutine.
//
// A failed sweep is logged and the loop carries on; a sweep that requeues
// nothing is silent.
func runStaleJobRecovery(ctx context.Context, q *queue.DBQueue, timeout time.Duration, logger *logging.Logger) {
	const staleCheckInterval = time.Minute

	// sweep performs a single recovery pass and reports its outcome.
	sweep := func() {
		requeued, err := q.RequeueStale(ctx, timeout)
		switch {
		case err != nil:
			logger.Error("failed to requeue stale jobs", "error", err)
		case requeued > 0:
			logger.Info("requeued stale jobs", "count", requeued)
		}
	}

	tick := time.NewTicker(staleCheckInterval)
	defer tick.Stop()

	cancelled := ctx.Done()
	for {
		select {
		case <-cancelled:
			return
		case <-tick.C:
			sweep()
		}
	}
}