persona-community-1/pkg/generation/handlers.go
jordan 4004f88f4a
Some checks failed
ci/woodpecker/push/woodpecker Pipeline failed
ci/woodpecker/manual/woodpecker Pipeline was successful
Initialize project from skeleton template
2026-02-23 10:20:59 +00:00

361 lines
11 KiB
Go

// Package generation provides queue job handlers for AI generation tasks.
// Used by both the worker (production) and service standalone mode (development).
package generation
import (
"context"
"encoding/base64"
"fmt"
"io"
"log/slog"
"net/http"
"time"
"git.threesix.ai/jordan/persona-community-1/pkg/logging"
"git.threesix.ai/jordan/persona-community-1/pkg/mediagen"
"git.threesix.ai/jordan/persona-community-1/pkg/queue"
"git.threesix.ai/jordan/persona-community-1/pkg/realtime"
"git.threesix.ai/jordan/persona-community-1/pkg/storage"
"git.threesix.ai/jordan/persona-community-1/pkg/textgen"
)
// httpClient is the shared HTTP client used by downloadURL to fetch video content
// from provider URLs before it is persisted to storage.
// The 2-minute Timeout applies to each request made through it.
var httpClient = &http.Client{Timeout: 2 * time.Minute}
// sendUserEvent pushes event to userID through pub on a best-effort basis.
// A delivery error (for example, because the user has disconnected) is logged
// at warn level and otherwise ignored, so it never fails the surrounding job.
func sendUserEvent(pub realtime.EventPublisher, userID string, event *realtime.SSEEvent) {
	err := pub.SendToUser(userID, event)
	if err == nil {
		return
	}
	slog.Warn("failed to send SSE event", "error", err, "type", event.Type, "job_id", event.JobID)
}
// GeneratedImage represents a single generated image in SSE response payloads.
type GeneratedImage struct {
// Data holds either a URL or base64-encoded image bytes, as indicated by IsURL.
Data string `json:"data"`
// IsURL is true when Data is a URL and false when Data is base64-encoded bytes.
IsURL bool `json:"isUrl"`
// Seed is optional; a nil Seed is omitted from the JSON output.
Seed *int32 `json:"seed,omitempty"`
}
// GenerateImageResponse is the SSE result payload for completed image generation.
type GenerateImageResponse struct {
// Images lists the generated images, one entry per image in the generation response.
Images []GeneratedImage `json:"images"`
// Provider is the provider name reported by the generation response.
Provider string `json:"provider"`
// LatencyMs is the elapsed time of the generation call, in milliseconds.
LatencyMs int64 `json:"latencyMs"`
}
// ImageHandler returns a queue.Handler that processes image generation jobs.
//
// The job payload is read as: "userID" (string, required), "prompt" (string),
// "count" (float64; defaults to 1 when absent or not positive) and
// "aspectRatio" (string). Started, progress, complete and failed events are
// sent to the user over SSE with the job ID attached.
//
// If store is non-nil, every image that carries bytes is uploaded under
// media/<userID>/images/<jobID>_<index>.png and the URL returned by the store
// is reported instead of base64. When the upload is skipped or fails, the
// handler falls back to the provider URL if one is set, and otherwise to the
// base64-encoded image bytes.
//
// The returned handler reports an error when userID is missing or when
// generation fails; storage failures are only logged and do not fail the job.
func ImageHandler(mg *mediagen.Manager, store storage.Store, pub realtime.EventPublisher, logger *logging.Logger) queue.Handler {
	return func(ctx context.Context, job *queue.Job) error {
		userID, _ := job.Payload["userID"].(string)
		if userID == "" {
			return fmt.Errorf("missing userID in job payload")
		}
		prompt, _ := job.Payload["prompt"].(string)
		count := 1
		if c, ok := job.Payload["count"].(float64); ok && c > 0 {
			count = int(c)
		}
		aspectRatio, _ := job.Payload["aspectRatio"].(string)

		sendUserEvent(pub, userID, &realtime.SSEEvent{
			Type:    realtime.EventGenerationStarted,
			JobID:   job.ID,
			Message: "Starting image generation...",
		})
		sendUserEvent(pub, userID, &realtime.SSEEvent{
			Type:     realtime.EventGenerationProgress,
			JobID:    job.ID,
			Progress: 30,
			Message:  "Generating image...",
		})

		start := time.Now()
		resp, err := mg.GenerateImage(ctx, mediagen.ImageRequest{
			Prompt:      prompt,
			Count:       count,
			AspectRatio: aspectRatio,
		})
		elapsed := time.Since(start)
		if err != nil {
			logger.Error("image generation failed", "error", err, "job_id", job.ID)
			sendUserEvent(pub, userID, &realtime.SSEEvent{
				Type:  realtime.EventGenerationFailed,
				JobID: job.ID,
				Error: "Image generation failed: " + err.Error(),
			})
			return err
		}

		images := make([]GeneratedImage, len(resp.Images))
		for i, img := range resp.Images {
			// Persist to storage only when there are bytes to store. An empty
			// (even if non-nil) slice is skipped so that a zero-byte object is
			// never uploaded and the provider URL fallback below is used instead.
			if store != nil && len(img.Data) > 0 {
				storagePath := fmt.Sprintf("media/%s/images/%s_%d.png", userID, job.ID, i)
				url, uploadErr := store.Upload(ctx, storagePath, img.Data, "image/png")
				if uploadErr != nil {
					logger.Warn("failed to persist image to storage", "error", uploadErr, "job_id", job.ID)
				} else {
					images[i] = GeneratedImage{Data: url, IsURL: true, Seed: resp.Seed}
					continue
				}
			}
			// Fallback: return the provider URL, or base64 when no URL is set.
			if img.URL != "" {
				images[i] = GeneratedImage{Data: img.URL, IsURL: true, Seed: resp.Seed}
			} else {
				images[i] = GeneratedImage{Data: base64.StdEncoding.EncodeToString(img.Data), IsURL: false, Seed: resp.Seed}
			}
		}

		sendUserEvent(pub, userID, &realtime.SSEEvent{
			Type:     realtime.EventGenerationComplete,
			JobID:    job.ID,
			Progress: 100,
			Message:  "Complete",
			Result: GenerateImageResponse{
				Images:    images,
				Provider:  resp.Provider,
				LatencyMs: elapsed.Milliseconds(),
			},
		})
		logger.Info("image generation complete",
			"job_id", job.ID, "provider", resp.Provider,
			"images", len(resp.Images), "elapsed", elapsed)
		return nil
	}
}
// VideoHandler returns a queue.Handler that processes video generation jobs.
//
// The job payload is read as: "userID" (string, required), "prompt" (string)
// and "aspectRatio" (string). Started, progress, complete and failed events
// are sent to the user over SSE with the job ID attached.
//
// If store is non-nil, each video is uploaded under
// media/<userID>/videos/<jobID>_<index>.mp4 and the URL returned by the store
// is reported. Bytes already present on the video are preferred; only when
// there are none is the provider URL downloaded. Download and upload failures
// are logged and the provider URL is reported unchanged.
func VideoHandler(mg *mediagen.Manager, store storage.Store, pub realtime.EventPublisher, logger *logging.Logger) queue.Handler {
	return func(ctx context.Context, job *queue.Job) error {
		userID, _ := job.Payload["userID"].(string)
		if userID == "" {
			return fmt.Errorf("missing userID in job payload")
		}
		promptText, _ := job.Payload["prompt"].(string)
		ratio, _ := job.Payload["aspectRatio"].(string)

		started := &realtime.SSEEvent{
			Type:    realtime.EventGenerationStarted,
			JobID:   job.ID,
			Message: "Starting video generation...",
		}
		sendUserEvent(pub, userID, started)

		initializing := &realtime.SSEEvent{
			Type:     realtime.EventGenerationProgress,
			JobID:    job.ID,
			Progress: 10,
			Message:  "Initializing video generation (this takes 2-5 minutes)...",
		}
		sendUserEvent(pub, userID, initializing)

		begin := time.Now()
		resp, err := mg.GenerateVideo(ctx, mediagen.VideoRequest{
			Prompt:      promptText,
			AspectRatio: ratio,
		})
		elapsed := time.Since(begin)
		if err != nil {
			logger.Error("video generation failed", "error", err, "job_id", job.ID)
			sendUserEvent(pub, userID, &realtime.SSEEvent{
				Type:  realtime.EventGenerationFailed,
				JobID: job.ID,
				Error: "Video generation failed: " + err.Error(),
			})
			return err
		}

		// Each entry matches the frontend VideoResult shape:
		// { videos: [{ data, isUrl, mimeType }], provider, latencyMs }
		results := make([]map[string]any, 0, len(resp.Videos))
		for idx, video := range resp.Videos {
			// Start from the provider URL; replaced below if persisting succeeds.
			finalURL := video.URL

			if store != nil {
				// Prefer bytes the provider adapter already returned over a
				// re-download: provider URLs (e.g., Gemini API) often require
				// authentication and fail with a plain GET.
				payload := video.Data
				if len(payload) == 0 && video.URL != "" {
					fetched, fetchErr := downloadURL(ctx, video.URL)
					if fetchErr != nil {
						logger.Warn("failed to download video from provider", "error", fetchErr, "job_id", job.ID)
					} else {
						payload = fetched
					}
				}

				if len(payload) > 0 {
					objectPath := fmt.Sprintf("media/%s/videos/%s_%d.mp4", userID, job.ID, idx)
					stored, putErr := store.Upload(ctx, objectPath, payload, "video/mp4")
					if putErr == nil {
						finalURL = stored
					} else {
						logger.Warn("failed to persist video to storage", "error", putErr, "job_id", job.ID)
					}
				}
			}

			entry := map[string]any{
				"data":     finalURL,
				"isUrl":    true,
				"mimeType": "video/mp4",
			}
			results = append(results, entry)
		}

		sendUserEvent(pub, userID, &realtime.SSEEvent{
			Type:     realtime.EventGenerationComplete,
			JobID:    job.ID,
			Progress: 100,
			Message:  "Complete",
			Result: map[string]any{
				"videos":    results,
				"provider":  resp.Provider,
				"latencyMs": elapsed.Milliseconds(),
			},
		})
		logger.Info("video generation complete",
			"job_id", job.ID, "provider", resp.Provider, "elapsed", elapsed)
		return nil
	}
}
// TextHandler returns a queue.Handler that processes text generation jobs and
// streams the output to the requesting user.
//
// The job payload is read as: "userID" (string, required), "prompt" (string),
// "systemPrompt" (string) and "maxTokens" (float64; 0 when absent). Every
// stream chunk is forwarded to the user as an AI chat chunk event whose
// StreamID is the job ID. A failed-generation event is sent and the error is
// returned if streaming fails; otherwise a completion event is sent.
func TextHandler(tg *textgen.Manager, pub realtime.EventPublisher, logger *logging.Logger) queue.Handler {
	return func(ctx context.Context, job *queue.Job) error {
		userID, _ := job.Payload["userID"].(string)
		if userID == "" {
			return fmt.Errorf("missing userID in job payload")
		}

		userPrompt, _ := job.Payload["prompt"].(string)
		sysPrompt, _ := job.Payload["systemPrompt"].(string)
		var tokenLimit int
		if raw, ok := job.Payload["maxTokens"].(float64); ok {
			tokenLimit = int(raw)
		}

		sendUserEvent(pub, userID, &realtime.SSEEvent{
			Type:    realtime.EventGenerationStarted,
			JobID:   job.ID,
			Message: "Starting text generation...",
		})

		streamID := job.ID
		request := textgen.TextRequest{
			Prompt:       userPrompt,
			SystemPrompt: sysPrompt,
			MaxTokens:    tokenLimit,
			Temperature:  0.7,
		}
		// forward relays a single streamed chunk to the user.
		forward := func(chunk textgen.StreamChunk) {
			payload := realtime.AIChunkData{
				StreamID: streamID,
				Text:     chunk.Text,
				Done:     chunk.Done,
				Provider: chunk.Provider,
			}
			sendUserEvent(pub, userID, &realtime.SSEEvent{
				Type:   realtime.MessageTypeAIChatChunk,
				JobID:  job.ID,
				Result: payload,
			})
		}

		if err := tg.GenerateStream(ctx, request, forward); err != nil {
			logger.Error("text generation failed", "error", err, "job_id", job.ID)
			sendUserEvent(pub, userID, &realtime.SSEEvent{
				Type:  realtime.EventGenerationFailed,
				JobID: job.ID,
				Error: "Text generation failed: " + err.Error(),
			})
			return err
		}

		sendUserEvent(pub, userID, &realtime.SSEEvent{
			Type:     realtime.EventGenerationComplete,
			JobID:    job.ID,
			Progress: 100,
			Message:  "Complete",
		})
		logger.Info("text generation complete", "job_id", job.ID)
		return nil
	}
}
// ChatResponseHandler returns a queue.Handler that generates an AI chat
// response and streams it to a channel for all participants.
//
// The job payload is read as: "content" (string, used as the prompt) and
// "channel" (string; "channel:general" when absent or empty). Every stream
// chunk is published to the channel as an AI chat chunk event whose StreamID
// is the job ID. A streaming error is logged and returned.
func ChatResponseHandler(tg *textgen.Manager, pub realtime.EventPublisher, logger *logging.Logger) queue.Handler {
	const (
		fallbackChannel = "channel:general"
		assistantPrompt = "You are a helpful AI assistant in a chat room. Keep responses concise, friendly, and under 200 words."
	)

	return func(ctx context.Context, job *queue.Job) error {
		message, _ := job.Payload["content"].(string)
		target, _ := job.Payload["channel"].(string)
		if target == "" {
			target = fallbackChannel
		}

		streamID := job.ID
		request := textgen.TextRequest{
			Prompt:       message,
			SystemPrompt: assistantPrompt,
			MaxTokens:    300,
			Temperature:  0.7,
		}
		// relay publishes a single streamed chunk to the target channel.
		// NOTE(review): unlike sendUserEvent, any result of SendToChannel is
		// not inspected here — confirm whether delivery failures should be logged.
		relay := func(chunk textgen.StreamChunk) {
			pub.SendToChannel(target, &realtime.SSEEvent{
				Type:  realtime.MessageTypeAIChatChunk,
				JobID: job.ID,
				Result: realtime.AIChunkData{
					StreamID: streamID,
					Text:     chunk.Text,
					Done:     chunk.Done,
					Provider: chunk.Provider,
				},
			})
		}

		if err := tg.GenerateStream(ctx, request, relay); err != nil {
			logger.Error("AI chat response failed", "error", err, "job_id", job.ID)
			return err
		}

		logger.Info("AI chat response complete", "job_id", job.ID)
		return nil
	}
}
// downloadURL fetches the content at url with a GET request bound to ctx and
// returns the response body.
// It is used to download provider-hosted videos before persisting to storage.
//
// An error is returned when the request cannot be created or sent, when the
// response status is not 200 OK, when reading the body fails, or when the
// body is larger than 500MB. Oversized bodies are rejected rather than
// truncated so that a partial (corrupt) file is never handed to the caller.
func downloadURL(ctx context.Context, url string) ([]byte, error) {
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
	if err != nil {
		return nil, fmt.Errorf("create request: %w", err)
	}
	resp, err := httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("download: %w", err)
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("download: status %d", resp.StatusCode)
	}

	// Limit body to 500MB to prevent OOM from unexpected large responses.
	// Reading one byte past the limit distinguishes "exactly at the limit"
	// from "exceeded it", which a plain LimitReader would silently truncate.
	const maxBodySize = 500 << 20
	data, err := io.ReadAll(io.LimitReader(resp.Body, maxBodySize+1))
	if err != nil {
		return nil, fmt.Errorf("read body: %w", err)
	}
	if len(data) > maxBodySize {
		return nil, fmt.Errorf("download: body exceeds %d bytes", maxBodySize)
	}
	return data, nil
}