235 lines
7.5 KiB
Go
235 lines
7.5 KiB
Go
package handlers
|
|
|
|
import (
	"errors"
	"net/http"

	"git.threesix.ai/jordan/persona-community-2/pkg/app"
	"git.threesix.ai/jordan/persona-community-2/pkg/auth"
	"git.threesix.ai/jordan/persona-community-2/pkg/httperror"
	"git.threesix.ai/jordan/persona-community-2/pkg/httpresponse"
	"git.threesix.ai/jordan/persona-community-2/pkg/logging"
	"git.threesix.ai/jordan/persona-community-2/pkg/queue"
	"git.threesix.ai/jordan/persona-community-2/pkg/realtime"

	"github.com/go-chi/chi/v5"
)
|
|
|
|
// Generate handles HTTP requests for AI generation endpoints.
|
|
// All generation is async: validate request, enqueue job, return 202 with job ID.
|
|
// The worker processes jobs and sends results via Redis → SSE.
|
|
// Job status can be polled via GET /generate/jobs/{id} as a fallback to SSE.
|
|
type Generate struct {
|
|
queue queue.Producer
|
|
jobReader queue.JobReader
|
|
sseHub *realtime.SSEHub
|
|
logger *logging.Logger
|
|
}
|
|
|
|
// NewGenerate creates a new Generate handler with injected dependencies.
|
|
func NewGenerate(q queue.Producer, jr queue.JobReader, hub *realtime.SSEHub, logger *logging.Logger) *Generate {
|
|
return &Generate{
|
|
queue: q,
|
|
jobReader: jr,
|
|
sseHub: hub,
|
|
logger: logger.WithComponent("GenerateHandler"),
|
|
}
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// Image generation (async - returns job ID, results via SSE)
|
|
// ---------------------------------------------------------------------------
|
|
|
|
// GenerateImageRequest is the JSON body accepted by the image generation
// endpoint.
type GenerateImageRequest struct {
	// Prompt describes the desired image. Required; length is bounded to
	// 1..2000 by the validate tag.
	Prompt string `json:"prompt" validate:"required,min=1,max=2000"`

	// Count is how many images to produce. The handler defaults it to 1 when
	// unset and caps it at 4.
	Count int `json:"count"`

	// AspectRatio is copied into the job payload; the image handler does not
	// validate it.
	AspectRatio string `json:"aspectRatio"`
}
|
|
|
|
// GenerateAccepted is the payload written with 202 Accepted by every
// generation endpoint. Clients use JobID to match SSE events or to poll the
// job-status endpoint.
type GenerateAccepted struct {
	JobID string `json:"jobId"` // ID returned by the queue when the job was enqueued
}
|
|
|
|
// GenerateImage queues an image generation job.
|
|
// Returns immediately with job ID. Results come via SSE events:
|
|
// - generation_started: Job accepted
|
|
// - generation_progress: Progress updates
|
|
// - generation_complete: Images available
|
|
// - generation_failed: Error occurred
|
|
//
|
|
// Client should subscribe to SSE channel `user:<userId>` before calling.
|
|
func (h *Generate) GenerateImage(w http.ResponseWriter, r *http.Request) error {
|
|
var req GenerateImageRequest
|
|
if err := app.BindAndValidate(r, &req); err != nil {
|
|
return err
|
|
}
|
|
|
|
// Set defaults
|
|
if req.Count == 0 {
|
|
req.Count = 1
|
|
}
|
|
if req.Count > 4 {
|
|
req.Count = 4
|
|
}
|
|
|
|
user := auth.GetUser(r.Context())
|
|
if user == nil {
|
|
return httperror.Unauthorized("authentication required")
|
|
}
|
|
|
|
jobID, err := h.queue.Enqueue(r.Context(), "generate_image", map[string]any{
|
|
"prompt": req.Prompt,
|
|
"count": req.Count,
|
|
"aspectRatio": req.AspectRatio,
|
|
"userID": user.ID,
|
|
})
|
|
if err != nil {
|
|
h.logger.Error("failed to enqueue image job", "error", err)
|
|
return httperror.Internal("failed to queue image generation")
|
|
}
|
|
|
|
h.logger.Info("image generation queued", "jobId", jobID, "userID", user.ID)
|
|
|
|
httpresponse.Accepted(w, r, GenerateAccepted{JobID: jobID})
|
|
return nil
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// Video generation (async - takes 2-5 minutes)
|
|
// ---------------------------------------------------------------------------
|
|
|
|
// GenerateVideoRequest is the JSON body accepted by the video generation
// endpoint.
type GenerateVideoRequest struct {
	// Prompt describes the desired video. Required; length is bounded to
	// 1..2000 by the validate tag.
	Prompt string `json:"prompt" validate:"required,min=1,max=2000"`

	// AspectRatio must be empty, "16:9" or "9:16"; the video handler rejects
	// anything else.
	AspectRatio string `json:"aspectRatio"`

	// Duration is copied into the job payload as-is; its format is not
	// checked by the handler.
	Duration string `json:"duration"`
}
|
|
|
|
// GenerateVideo enqueues a video generation job and responds at once with
// 202 Accepted carrying the job ID; results are delivered later via SSE.
//
// Errors: the app.BindAndValidate error for a bad body; httperror.BadRequest
// for an aspect ratio other than "", "16:9" or "9:16";
// httperror.Unauthorized when no user is in the request context;
// httperror.Internal when the job cannot be enqueued.
func (h *Generate) GenerateVideo(w http.ResponseWriter, r *http.Request) error {
	var body GenerateVideoRequest
	if err := app.BindAndValidate(r, &body); err != nil {
		return err
	}

	// Veo only supports 16:9 and 9:16. An empty value is let through
	// (presumably the worker applies its own default — confirm).
	switch body.AspectRatio {
	case "", "16:9", "9:16":
		// supported
	default:
		return httperror.BadRequest("video only supports 16:9 and 9:16 aspect ratios")
	}

	caller := auth.GetUser(r.Context())
	if caller == nil {
		return httperror.Unauthorized("authentication required")
	}

	payload := map[string]any{
		"prompt":      body.Prompt,
		"aspectRatio": body.AspectRatio,
		"duration":    body.Duration,
		"userID":      caller.ID,
	}
	jobID, err := h.queue.Enqueue(r.Context(), "generate_video", payload)
	if err != nil {
		h.logger.Error("failed to enqueue video job", "error", err)
		return httperror.Internal("failed to queue video generation")
	}
	h.logger.Info("video generation queued", "jobId", jobID, "userID", caller.ID)

	httpresponse.Accepted(w, r, GenerateAccepted{JobID: jobID})
	return nil
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// Text generation (async - returns job ID, results via SSE with streaming chunks)
|
|
// ---------------------------------------------------------------------------
|
|
|
|
// GenerateTextRequest is the JSON body accepted by the text generation
// endpoint.
type GenerateTextRequest struct {
	// Prompt is the user prompt. Required; length is bounded to 1..5000 by
	// the validate tag.
	Prompt string `json:"prompt" validate:"required,min=1,max=5000"`

	// SystemPrompt is optional; no validate tag constrains it.
	SystemPrompt string `json:"systemPrompt"`

	// MaxTokens is copied into the job payload; the handler applies no
	// default for it.
	MaxTokens int `json:"maxTokens"`
}
|
|
|
|
// GenerateText queues a text generation job.
// Returns immediately with 202 Accepted and the job ID. Chunks come via SSE
// as ai_chat_chunk events.
//
// Errors: the app.BindAndValidate error for a bad body;
// httperror.BadRequest when maxTokens is negative; httperror.Unauthorized
// when no user is in the request context; httperror.Internal when the job
// cannot be enqueued.
func (h *Generate) GenerateText(w http.ResponseWriter, r *http.Request) error {
	var req GenerateTextRequest
	if err := app.BindAndValidate(r, &req); err != nil {
		return err
	}

	// A negative token budget can never be satisfied. Rejecting it here gives
	// the client a 400 now, rather than a job that fails later over SSE.
	if req.MaxTokens < 0 {
		return httperror.BadRequest("maxTokens must not be negative")
	}

	user := auth.GetUser(r.Context())
	if user == nil {
		return httperror.Unauthorized("authentication required")
	}

	jobID, err := h.queue.Enqueue(r.Context(), "generate_text", map[string]any{
		"prompt":       req.Prompt,
		"systemPrompt": req.SystemPrompt,
		"maxTokens":    req.MaxTokens,
		"userID":       user.ID,
	})
	if err != nil {
		h.logger.Error("failed to enqueue text job", "error", err)
		return httperror.Internal("failed to queue text generation")
	}

	h.logger.Info("text generation queued", "jobId", jobID, "userID", user.ID)

	httpresponse.Accepted(w, r, GenerateAccepted{JobID: jobID})
	return nil
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// Job status (poll fallback for SSE)
|
|
// ---------------------------------------------------------------------------
|
|
|
|
// GetJobStatus returns the current status of the generation job named by the
// {id} URL parameter. This is a poll-based fallback for clients that can't
// use SSE.
//
// The response always includes id, type, status and createdAt. startedAt,
// completedAt and error are added only when the job has them set.
//
// Errors: httperror.BadRequest for an empty ID; httperror.NotFound when the
// reader reports queue.ErrJobNotFound (directly or wrapped);
// httperror.Internal for any other lookup failure.
//
// NOTE(review): the job is looked up by ID alone, and nothing here checks
// that it belongs to the requesting user. Confirm whether job IDs are
// unguessable or whether an ownership check is needed.
func (h *Generate) GetJobStatus(w http.ResponseWriter, r *http.Request) error {
	jobID := chi.URLParam(r, "id")
	if jobID == "" {
		return httperror.BadRequest("job ID is required")
	}

	job, err := h.jobReader.GetJob(r.Context(), jobID)
	if err != nil {
		// errors.Is also matches the sentinel when an implementation wraps it
		// with context. Plain == would turn such a not-found into a 500.
		if errors.Is(err, queue.ErrJobNotFound) {
			return httperror.NotFound("job not found")
		}
		h.logger.Error("failed to get job status", "error", err, "job_id", jobID)
		return httperror.Internal("failed to get job status")
	}

	resp := map[string]any{
		"id":        job.ID,
		"type":      job.Type,
		"status":    string(job.Status),
		"createdAt": job.CreatedAt,
	}
	if job.StartedAt != nil {
		resp["startedAt"] = job.StartedAt
	}
	if job.CompletedAt != nil {
		resp["completedAt"] = job.CompletedAt
	}
	if job.Error != "" {
		resp["error"] = job.Error
	}

	httpresponse.OK(w, r, resp)
	return nil
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// SSE Events endpoint
|
|
// ---------------------------------------------------------------------------
|
|
|
|
// Events builds the http.Handler that serves SSE event subscriptions backed
// by this handler's hub. Mount it at /api/events.
//
// realtime.NewSSEHandler is invoked on every call, so callers should build it
// once when registering routes.
func (h *Generate) Events() http.Handler {
	hub, base := h.sseHub, h.logger.Logger
	return realtime.NewSSEHandler(hub, base)
}
|