350 lines
11 KiB
Go
350 lines
11 KiB
Go
// Package generation provides queue job handlers for AI generation tasks.
|
|
// The handlers are shared by the worker (production) and the service's standalone mode (development).
|
|
package generation
|
|
|
|
import (
|
|
"context"
|
|
"encoding/base64"
|
|
"fmt"
|
|
"log/slog"
|
|
"net/http"
|
|
"time"
|
|
|
|
"git.threesix.ai/jordan/persona-community-5/pkg/logging"
|
|
"git.threesix.ai/jordan/persona-community-5/pkg/mediagen"
|
|
"git.threesix.ai/jordan/persona-community-5/pkg/queue"
|
|
"git.threesix.ai/jordan/persona-community-5/pkg/realtime"
|
|
"git.threesix.ai/jordan/persona-community-5/pkg/storage"
|
|
"git.threesix.ai/jordan/persona-community-5/pkg/textgen"
|
|
)
|
|
|
|
// httpClient fetches generated videos from provider URLs so they can be copied
// into storage. One client is shared by all jobs; its timeout bounds each whole
// request, including reading the response body.
var httpClient = &http.Client{Timeout: 120 * time.Second}
|
|
|
|
// sendUserEvent sends an SSE event and logs delivery failures at warn level.
|
|
// SSE delivery can fail if the user disconnected; this is non-fatal for the job.
|
|
func sendUserEvent(pub realtime.EventPublisher, userID string, event *realtime.SSEEvent) {
|
|
if err := pub.SendToUser(userID, event); err != nil {
|
|
slog.Warn("failed to send SSE event", "error", err, "type", event.Type, "job_id", event.JobID)
|
|
}
|
|
}
|
|
|
|
// GeneratedImage is one image entry inside an image-generation SSE result.
type GeneratedImage struct {
	// Data holds a URL when IsURL is true, otherwise base64-encoded image bytes.
	Data string `json:"data"`
	// IsURL tells the client how to interpret Data.
	IsURL bool `json:"isUrl"`
	// Seed is the generation seed if one was reported; omitted from JSON when nil.
	Seed *int32 `json:"seed,omitempty"`
}
|
|
|
|
// GenerateImageResponse is the SSE result payload for completed image generation.
|
|
type GenerateImageResponse struct {
|
|
Images []GeneratedImage `json:"images"`
|
|
Provider string `json:"provider"`
|
|
LatencyMs int64 `json:"latencyMs"`
|
|
}
|
|
|
|
// ImageHandler returns a queue.Handler that processes image generation jobs.
|
|
// If store is non-nil, generated images are persisted and URLs are returned instead of base64.
|
|
func ImageHandler(mg *mediagen.Manager, store storage.Store, pub realtime.EventPublisher, logger *logging.Logger) queue.Handler {
|
|
return func(ctx context.Context, job *queue.Job) error {
|
|
userID, _ := job.Payload["userID"].(string)
|
|
if userID == "" {
|
|
return fmt.Errorf("missing userID in job payload")
|
|
}
|
|
|
|
prompt, _ := job.Payload["prompt"].(string)
|
|
count := 1
|
|
if c, ok := job.Payload["count"].(float64); ok && c > 0 {
|
|
count = int(c)
|
|
}
|
|
aspectRatio, _ := job.Payload["aspectRatio"].(string)
|
|
|
|
sendUserEvent(pub, userID, &realtime.SSEEvent{
|
|
Type: realtime.EventGenerationStarted,
|
|
JobID: job.ID,
|
|
Message: "Starting image generation...",
|
|
})
|
|
|
|
sendUserEvent(pub, userID, &realtime.SSEEvent{
|
|
Type: realtime.EventGenerationProgress,
|
|
JobID: job.ID,
|
|
Progress: 30,
|
|
Message: "Generating image...",
|
|
})
|
|
|
|
start := time.Now()
|
|
resp, err := mg.GenerateImage(ctx, mediagen.ImageRequest{
|
|
Prompt: prompt,
|
|
Count: count,
|
|
AspectRatio: aspectRatio,
|
|
})
|
|
elapsed := time.Since(start)
|
|
|
|
if err != nil {
|
|
logger.Error("image generation failed", "error", err, "job_id", job.ID)
|
|
sendUserEvent(pub, userID, &realtime.SSEEvent{
|
|
Type: realtime.EventGenerationFailed,
|
|
JobID: job.ID,
|
|
Error: "Image generation failed: " + err.Error(),
|
|
})
|
|
return err
|
|
}
|
|
|
|
images := make([]GeneratedImage, len(resp.Images))
|
|
for i, img := range resp.Images {
|
|
// Try to persist to storage if available
|
|
if store != nil && img.Data != nil {
|
|
storagePath := fmt.Sprintf("media/%s/images/%s_%d.png", userID, job.ID, i)
|
|
url, uploadErr := store.Upload(ctx, storagePath, img.Data, "image/png")
|
|
if uploadErr != nil {
|
|
logger.Warn("failed to persist image to storage", "error", uploadErr, "job_id", job.ID)
|
|
} else {
|
|
captionPath := fmt.Sprintf("media/%s/images/%s_%d.caption", userID, job.ID, i)
|
|
if _, captionErr := store.Upload(ctx, captionPath, []byte(prompt), "text/plain"); captionErr != nil {
|
|
logger.Warn("failed to persist image caption", "error", captionErr, "job_id", job.ID)
|
|
}
|
|
images[i] = GeneratedImage{Data: url, IsURL: true, Seed: resp.Seed}
|
|
continue
|
|
}
|
|
}
|
|
// Fallback: return URL or base64
|
|
if img.URL != "" {
|
|
images[i] = GeneratedImage{Data: img.URL, IsURL: true, Seed: resp.Seed}
|
|
} else {
|
|
images[i] = GeneratedImage{Data: base64.StdEncoding.EncodeToString(img.Data), IsURL: false, Seed: resp.Seed}
|
|
}
|
|
}
|
|
|
|
sendUserEvent(pub, userID, &realtime.SSEEvent{
|
|
Type: realtime.EventGenerationComplete,
|
|
JobID: job.ID,
|
|
Progress: 100,
|
|
Message: "Complete",
|
|
Result: GenerateImageResponse{
|
|
Images: images,
|
|
Provider: resp.Provider,
|
|
LatencyMs: elapsed.Milliseconds(),
|
|
},
|
|
})
|
|
|
|
logger.Info("image generation complete",
|
|
"job_id", job.ID, "provider", resp.Provider,
|
|
"images", len(resp.Images), "elapsed", elapsed)
|
|
return nil
|
|
}
|
|
}
|
|
|
|
// VideoHandler returns a queue.Handler that processes video generation jobs.
|
|
// If store is non-nil, generated videos are persisted and URLs are returned.
|
|
func VideoHandler(mg *mediagen.Manager, store storage.Store, pub realtime.EventPublisher, logger *logging.Logger) queue.Handler {
|
|
return func(ctx context.Context, job *queue.Job) error {
|
|
userID, _ := job.Payload["userID"].(string)
|
|
if userID == "" {
|
|
return fmt.Errorf("missing userID in job payload")
|
|
}
|
|
|
|
prompt, _ := job.Payload["prompt"].(string)
|
|
aspectRatio, _ := job.Payload["aspectRatio"].(string)
|
|
|
|
sendUserEvent(pub, userID, &realtime.SSEEvent{
|
|
Type: realtime.EventGenerationStarted,
|
|
JobID: job.ID,
|
|
Message: "Starting video generation...",
|
|
})
|
|
|
|
sendUserEvent(pub, userID, &realtime.SSEEvent{
|
|
Type: realtime.EventGenerationProgress,
|
|
JobID: job.ID,
|
|
Progress: 10,
|
|
Message: "Initializing video generation (this takes 2-5 minutes)...",
|
|
})
|
|
|
|
start := time.Now()
|
|
resp, err := mg.GenerateVideo(ctx, mediagen.VideoRequest{
|
|
Prompt: prompt,
|
|
AspectRatio: aspectRatio,
|
|
})
|
|
elapsed := time.Since(start)
|
|
|
|
if err != nil {
|
|
logger.Error("video generation failed", "error", err, "job_id", job.ID)
|
|
sendUserEvent(pub, userID, &realtime.SSEEvent{
|
|
Type: realtime.EventGenerationFailed,
|
|
JobID: job.ID,
|
|
Error: "Video generation failed: " + err.Error(),
|
|
})
|
|
return err
|
|
}
|
|
|
|
// Build videos array matching frontend VideoResult shape:
|
|
// { videos: [{ data, isUrl, mimeType }], provider, latencyMs }
|
|
const videoMaxBytes = 500 << 20 // 500 MB — videos can be large
|
|
videos := make([]map[string]any, 0, len(resp.Videos))
|
|
for i, vid := range resp.Videos {
|
|
videoURL := vid.URL
|
|
|
|
// Persist to storage if available.
|
|
// Prefer vid.Data (already downloaded by provider adapter) over re-downloading from URL.
|
|
// Provider URLs (e.g., Gemini API) often require authentication and fail with plain GET.
|
|
if store != nil {
|
|
storagePath := fmt.Sprintf("media/%s/videos/%s_%d.mp4", userID, job.ID, i)
|
|
|
|
var videoData []byte
|
|
if len(vid.Data) > 0 {
|
|
videoData = vid.Data
|
|
} else if vid.URL != "" {
|
|
downloaded, downloadErr := storage.FetchURL(ctx, httpClient, vid.URL, videoMaxBytes)
|
|
if downloadErr != nil {
|
|
logger.Warn("failed to download video from provider", "error", downloadErr, "job_id", job.ID)
|
|
} else {
|
|
videoData = downloaded
|
|
}
|
|
}
|
|
|
|
if len(videoData) > 0 {
|
|
persistedURL, uploadErr := store.Upload(ctx, storagePath, videoData, "video/mp4")
|
|
if uploadErr != nil {
|
|
logger.Warn("failed to persist video to storage", "error", uploadErr, "job_id", job.ID)
|
|
} else {
|
|
videoURL = persistedURL
|
|
}
|
|
}
|
|
|
|
// Save caption alongside the video regardless of where it's stored.
|
|
if videoURL != "" && prompt != "" {
|
|
captionPath := fmt.Sprintf("media/%s/videos/%s_%d.caption", userID, job.ID, i)
|
|
if _, captionErr := store.Upload(ctx, captionPath, []byte(prompt), "text/plain"); captionErr != nil {
|
|
logger.Warn("failed to persist video caption", "error", captionErr, "job_id", job.ID)
|
|
}
|
|
}
|
|
}
|
|
|
|
videos = append(videos, map[string]any{
|
|
"data": videoURL,
|
|
"isUrl": true,
|
|
"mimeType": "video/mp4",
|
|
})
|
|
}
|
|
|
|
sendUserEvent(pub, userID, &realtime.SSEEvent{
|
|
Type: realtime.EventGenerationComplete,
|
|
JobID: job.ID,
|
|
Progress: 100,
|
|
Message: "Complete",
|
|
Result: map[string]any{
|
|
"videos": videos,
|
|
"provider": resp.Provider,
|
|
"latencyMs": elapsed.Milliseconds(),
|
|
},
|
|
})
|
|
|
|
logger.Info("video generation complete",
|
|
"job_id", job.ID, "provider", resp.Provider, "elapsed", elapsed)
|
|
return nil
|
|
}
|
|
}
|
|
|
|
// TextHandler returns a queue.Handler that processes text generation jobs with streaming.
|
|
func TextHandler(tg *textgen.Manager, pub realtime.EventPublisher, logger *logging.Logger) queue.Handler {
|
|
return func(ctx context.Context, job *queue.Job) error {
|
|
userID, _ := job.Payload["userID"].(string)
|
|
if userID == "" {
|
|
return fmt.Errorf("missing userID in job payload")
|
|
}
|
|
|
|
prompt, _ := job.Payload["prompt"].(string)
|
|
systemPrompt, _ := job.Payload["systemPrompt"].(string)
|
|
maxTokens := 0
|
|
if mt, ok := job.Payload["maxTokens"].(float64); ok {
|
|
maxTokens = int(mt)
|
|
}
|
|
|
|
sendUserEvent(pub, userID, &realtime.SSEEvent{
|
|
Type: realtime.EventGenerationStarted,
|
|
JobID: job.ID,
|
|
Message: "Starting text generation...",
|
|
})
|
|
|
|
streamID := job.ID
|
|
|
|
err := tg.GenerateStream(ctx, textgen.TextRequest{
|
|
Prompt: prompt,
|
|
SystemPrompt: systemPrompt,
|
|
MaxTokens: maxTokens,
|
|
Temperature: 0.7,
|
|
}, func(chunk textgen.StreamChunk) {
|
|
sendUserEvent(pub, userID, &realtime.SSEEvent{
|
|
Type: realtime.MessageTypeAIChatChunk,
|
|
JobID: job.ID,
|
|
Result: realtime.AIChunkData{
|
|
StreamID: streamID,
|
|
Text: chunk.Text,
|
|
Done: chunk.Done,
|
|
Provider: chunk.Provider,
|
|
},
|
|
})
|
|
})
|
|
|
|
if err != nil {
|
|
logger.Error("text generation failed", "error", err, "job_id", job.ID)
|
|
sendUserEvent(pub, userID, &realtime.SSEEvent{
|
|
Type: realtime.EventGenerationFailed,
|
|
JobID: job.ID,
|
|
Error: "Text generation failed: " + err.Error(),
|
|
})
|
|
return err
|
|
}
|
|
|
|
sendUserEvent(pub, userID, &realtime.SSEEvent{
|
|
Type: realtime.EventGenerationComplete,
|
|
JobID: job.ID,
|
|
Progress: 100,
|
|
Message: "Complete",
|
|
})
|
|
|
|
logger.Info("text generation complete", "job_id", job.ID)
|
|
return nil
|
|
}
|
|
}
|
|
|
|
// ChatResponseHandler returns a queue.Handler that generates AI chat responses
|
|
// and streams them to a channel (e.g., channel:general) for all participants.
|
|
func ChatResponseHandler(tg *textgen.Manager, pub realtime.EventPublisher, logger *logging.Logger) queue.Handler {
|
|
return func(ctx context.Context, job *queue.Job) error {
|
|
content, _ := job.Payload["content"].(string)
|
|
channel, _ := job.Payload["channel"].(string)
|
|
if channel == "" {
|
|
channel = "channel:general"
|
|
}
|
|
|
|
streamID := job.ID
|
|
|
|
err := tg.GenerateStream(ctx, textgen.TextRequest{
|
|
Prompt: content,
|
|
SystemPrompt: "You are a helpful AI assistant in a chat room. Keep responses concise, friendly, and under 200 words.",
|
|
MaxTokens: 300,
|
|
Temperature: 0.7,
|
|
}, func(chunk textgen.StreamChunk) {
|
|
pub.SendToChannel(channel, &realtime.SSEEvent{
|
|
Type: realtime.MessageTypeAIChatChunk,
|
|
JobID: job.ID,
|
|
Result: realtime.AIChunkData{
|
|
StreamID: streamID,
|
|
Text: chunk.Text,
|
|
Done: chunk.Done,
|
|
Provider: chunk.Provider,
|
|
},
|
|
})
|
|
})
|
|
|
|
if err != nil {
|
|
logger.Error("AI chat response failed", "error", err, "job_id", job.ID)
|
|
return err
|
|
}
|
|
|
|
logger.Info("AI chat response complete", "job_id", job.ID)
|
|
return nil
|
|
}
|
|
}
|
|
|