diff --git a/pkg/personagen/pipeline.go b/pkg/personagen/pipeline.go new file mode 100644 index 0000000..798e471 --- /dev/null +++ b/pkg/personagen/pipeline.go @@ -0,0 +1,617 @@ +package personagen + +import ( + "context" + "encoding/json" + "fmt" + "log/slog" + "net/http" + "time" + + "git.threesix.ai/jordan/persona-community-5/pkg/mediagen" + "git.threesix.ai/jordan/persona-community-5/pkg/persona" + "git.threesix.ai/jordan/persona-community-5/pkg/queue" + "git.threesix.ai/jordan/persona-community-5/pkg/realtime" + "git.threesix.ai/jordan/persona-community-5/pkg/storage" + "git.threesix.ai/jordan/persona-community-5/pkg/textgen" +) + +// anchorFetchClient is used to download anchor images from storage URLs. +var anchorFetchClient = &http.Client{Timeout: 30 * time.Second} + +// Stage constants matching domain.PersonaStage in the API service. +const ( + StageSpec = "spec" + StageAnchor = "anchor" + StageAvatar = "avatar" + StageBanner = "banner" + StageGalleryBatch = "gallery_batch" + StageVideo = "video" +) + +// Video motion types in generation order. +var videoOrder = []persona.MotionType{ + persona.MotionSmileReveal, + persona.MotionPersonality, + persona.MotionLifestyle, + persona.MotionInvitation, +} + +// PipelineDeps holds dependencies for the staged pipeline handler. +type PipelineDeps struct { + TextGen *textgen.Manager + MediaGen *mediagen.Manager + Store storage.Store + Pub realtime.EventPublisher + Personas PersonaStore + Queue queue.Producer + Logger *slog.Logger +} + +// StagedQueueHandler returns a queue.Handler for processing persona_generate jobs +// using a stage-based pipeline. Each stage processes one unit of work, updates the +// persona row, publishes a job_update SSE event, and enqueues the next stage. +// +// Job payload: {"persona_id": "...", "stage": "spec|anchor|avatar|banner|gallery_batch|video"} +func StagedQueueHandler(deps PipelineDeps) queue.Handler { + return func(ctx context.Context, job *queue.Job) error { + personaID, _ := job.Payload["persona_id"].(string) + if personaID == "" { + return fmt.Errorf("missing persona_id in persona_generate job payload") + } + stage, _ := job.Payload["stage"].(string) + if stage == "" { + return fmt.Errorf("missing stage in persona_generate job payload") + } + + logger := deps.Logger.With("job_id", job.ID, "persona_id", personaID, "stage", stage) + logger.Info("processing persona generation stage") + + var err error + switch stage { + case StageSpec: + err = handleStageSpec(ctx, deps, personaID, job.ID, logger) + case StageAnchor: + err = handleStageAnchor(ctx, deps, personaID, job.ID, logger) + case StageAvatar: + err = handleStageAvatar(ctx, deps, personaID, job.ID, logger) + case StageBanner: + err = handleStageBanner(ctx, deps, personaID, job.ID, logger) + case StageGalleryBatch: + err = handleStageGalleryBatch(ctx, deps, personaID, job.ID, logger) + case StageVideo: + err = handleStageVideo(ctx, deps, personaID, job.ID, logger) + default: + err = fmt.Errorf("unknown stage: %s", stage) + } + + if err != nil { + logger.Error("stage failed", "error", err) + publishJobUpdate(deps.Pub, personaID, stage, "error", 0, err.Error()) + } + return err + } +} + +// handleStageSpec generates the persona spec via LLM and updates the persona row. +func handleStageSpec(ctx context.Context, deps PipelineDeps, personaID, jobID string, logger *slog.Logger) error { + rec, err := deps.Personas.GetByID(ctx, personaID) + if err != nil { + return fmt.Errorf("load persona: %w", err) + } + + // Update status to generating. + rec.Status = "generating" + if err := deps.Personas.Update(ctx, rec); err != nil { + return fmt.Errorf("update persona status: %w", err) + } + + svc := New(deps.TextGen, deps.MediaGen, deps.Store, logger) + seed := SeedParams{ + Description: rec.Description, + Gender: rec.Gender, + Name: rec.Name, + } + + spec, err := svc.GenerateSpec(ctx, seed) + if err != nil { + rec.Status = "failed" + _ = deps.Personas.Update(ctx, rec) + return fmt.Errorf("spec generation: %w", err) + } + + // Serialize full spec to JSON for storage. + specJSON, err := json.Marshal(spec) + if err != nil { + return fmt.Errorf("marshal spec: %w", err) + } + + // Update persona with spec data. + rec.SpecJSON = specJSON + rec.Name = spec.Name.First + if spec.Name.Last != "" { + rec.Name = spec.Name.First + " " + spec.Name.Last + } + rec.Handle = generateHandleFromSpec(spec) + rec.Tags = extractTags(spec) + + if err := deps.Personas.Update(ctx, rec); err != nil { + return fmt.Errorf("update persona spec: %w", err) + } + + publishJobUpdate(deps.Pub, personaID, StageSpec, "complete", 100, "") + + // Enqueue anchor stage. + if _, err := deps.Queue.Enqueue(ctx, "persona_generate", map[string]any{ + "persona_id": personaID, + "stage": StageAnchor, + }); err != nil { + return fmt.Errorf("enqueue anchor stage: %w", err) + } + + logger.Info("spec stage complete, anchor enqueued") + return nil +} + +// handleStageAnchor generates the anchor (position 1) image. +func handleStageAnchor(ctx context.Context, deps PipelineDeps, personaID, jobID string, logger *slog.Logger) error { + rec, err := deps.Personas.GetByID(ctx, personaID) + if err != nil { + return fmt.Errorf("load persona: %w", err) + } + + spec, err := unmarshalSpec(rec.SpecJSON) + if err != nil { + return err + } + + svc := New(deps.TextGen, deps.MediaGen, deps.Store, logger) + + // Generate position 1 (the identity anchor). + if err := svc.generatePosition(ctx, spec, 1); err != nil { + return fmt.Errorf("anchor generation: %w", err) + } + + // Find position 1 URL. + var anchorURL string + for _, img := range spec.ImageMatrix { + if img.Position == 1 { + anchorURL = img.URL + break + } + } + + rec.AnchorURL = anchorURL + rec.ImageURLs = appendIfMissing(rec.ImageURLs, anchorURL) + + // Update spec JSON with the generated image URL/status. + updatedSpec, _ := json.Marshal(spec) + rec.SpecJSON = updatedSpec + + if err := deps.Personas.Update(ctx, rec); err != nil { + return fmt.Errorf("update persona anchor: %w", err) + } + + publishJobUpdate(deps.Pub, personaID, StageAnchor, "complete", 100, "") + + // Enqueue avatar, banner, and gallery_batch in parallel. + for _, nextStage := range []string{StageAvatar, StageBanner, StageGalleryBatch} { + if _, err := deps.Queue.Enqueue(ctx, "persona_generate", map[string]any{ + "persona_id": personaID, + "stage": nextStage, + }); err != nil { + return fmt.Errorf("enqueue %s stage: %w", nextStage, err) + } + } + + logger.Info("anchor stage complete, avatar/banner/gallery_batch enqueued") + return nil +} + +// handleStageAvatar generates the circular portrait avatar. +func handleStageAvatar(ctx context.Context, deps PipelineDeps, personaID, jobID string, logger *slog.Logger) error { + rec, err := deps.Personas.GetByID(ctx, personaID) + if err != nil { + return fmt.Errorf("load persona: %w", err) + } + + spec, err := unmarshalSpec(rec.SpecJSON) + if err != nil { + return err + } + + svc := New(deps.TextGen, deps.MediaGen, deps.Store, logger) + + // Restore anchor from storage for identity consistency. + if err := restoreAnchor(ctx, svc, deps.Store, spec); err != nil { + logger.Warn("could not restore anchor for avatar", "error", err) + } + + avatarBytes, err := svc.GenerateAvatar(ctx, spec) + if err != nil { + return fmt.Errorf("avatar generation: %w", err) + } + + storagePath := fmt.Sprintf("personas/%s/avatar.png", spec.ID) + url, err := deps.Store.Upload(ctx, storagePath, avatarBytes, "image/png") + if err != nil { + return fmt.Errorf("upload avatar: %w", err) + } + + rec.AvatarURL = url + if err := deps.Personas.Update(ctx, rec); err != nil { + return fmt.Errorf("update persona avatar: %w", err) + } + + publishJobUpdate(deps.Pub, personaID, StageAvatar, "complete", 100, "") + logger.Info("avatar stage complete") + return nil +} + +// handleStageBanner generates the 3:1 landscape banner image. +func handleStageBanner(ctx context.Context, deps PipelineDeps, personaID, jobID string, logger *slog.Logger) error { + rec, err := deps.Personas.GetByID(ctx, personaID) + if err != nil { + return fmt.Errorf("load persona: %w", err) + } + + spec, err := unmarshalSpec(rec.SpecJSON) + if err != nil { + return err + } + + svc := New(deps.TextGen, deps.MediaGen, deps.Store, logger) + + // Restore anchor for identity consistency. + if err := restoreAnchor(ctx, svc, deps.Store, spec); err != nil { + logger.Warn("could not restore anchor for banner", "error", err) + } + + bannerBytes, err := svc.GenerateBanner(ctx, spec, "lifestyle") + if err != nil { + return fmt.Errorf("banner generation: %w", err) + } + + storagePath := fmt.Sprintf("personas/%s/banner.png", spec.ID) + url, err := deps.Store.Upload(ctx, storagePath, bannerBytes, "image/png") + if err != nil { + return fmt.Errorf("upload banner: %w", err) + } + + rec.BannerURL = url + if err := deps.Personas.Update(ctx, rec); err != nil { + return fmt.Errorf("update persona banner: %w", err) + } + + publishJobUpdate(deps.Pub, personaID, StageBanner, "complete", 100, "") + logger.Info("banner stage complete") + return nil +} + +// handleStageGalleryBatch generates the next 10 unfilled image positions. +func handleStageGalleryBatch(ctx context.Context, deps PipelineDeps, personaID, jobID string, logger *slog.Logger) error { + rec, err := deps.Personas.GetByID(ctx, personaID) + if err != nil { + return fmt.Errorf("load persona: %w", err) + } + + spec, err := unmarshalSpec(rec.SpecJSON) + if err != nil { + return err + } + + svc := New(deps.TextGen, deps.MediaGen, deps.Store, logger) + + // Restore anchor for identity consistency. + if err := restoreAnchor(ctx, svc, deps.Store, spec); err != nil { + logger.Warn("could not restore anchor for gallery batch", "error", err) + } + + // Find unfilled positions (skip position 1 which is the anchor, already done). + var unfilled []int + for _, img := range spec.ImageMatrix { + if img.Status != persona.ImageStatusComplete && img.Position != 1 { + unfilled = append(unfilled, img.Position) + } + } + + // Take up to 10 positions for this batch. + batchSize := 10 + if len(unfilled) < batchSize { + batchSize = len(unfilled) + } + batch := unfilled[:batchSize] + + for _, pos := range batch { + if err := svc.generatePosition(ctx, spec, pos); err != nil { + logger.Error("gallery image generation failed", "position", pos, "error", err) + return fmt.Errorf("gallery position %d: %w", pos, err) + } + + // Find the URL for this position. + var imgURL string + for _, img := range spec.ImageMatrix { + if img.Position == pos { + imgURL = img.URL + break + } + } + rec.ImageURLs = appendIfMissing(rec.ImageURLs, imgURL) + } + + // Update spec JSON with new image URLs/statuses. + updatedSpec, _ := json.Marshal(spec) + rec.SpecJSON = updatedSpec + + if err := deps.Personas.Update(ctx, rec); err != nil { + return fmt.Errorf("update persona gallery: %w", err) + } + + // Count total completed images (including anchor). + completedImages := countCompletedImages(spec) + totalImages := len(spec.ImageMatrix) + progress := (completedImages * 100) / totalImages + + publishJobUpdate(deps.Pub, personaID, StageGalleryBatch, "complete", progress, "") + + // If fewer than totalImages are complete, enqueue another gallery_batch. + if completedImages < totalImages { + if _, err := deps.Queue.Enqueue(ctx, "persona_generate", map[string]any{ + "persona_id": personaID, + "stage": StageGalleryBatch, + }); err != nil { + return fmt.Errorf("enqueue next gallery batch: %w", err) + } + logger.Info("gallery batch complete, more images needed", "completed", completedImages, "total", totalImages) + } else { + // All images done, enqueue first video. + if _, err := deps.Queue.Enqueue(ctx, "persona_generate", map[string]any{ + "persona_id": personaID, + "stage": StageVideo, + }); err != nil { + return fmt.Errorf("enqueue video stage: %w", err) + } + logger.Info("all gallery images complete, video stage enqueued") + } + + return nil +} + +// handleStageVideo generates the next missing video. +func handleStageVideo(ctx context.Context, deps PipelineDeps, personaID, jobID string, logger *slog.Logger) error { + rec, err := deps.Personas.GetByID(ctx, personaID) + if err != nil { + return fmt.Errorf("load persona: %w", err) + } + + spec, err := unmarshalSpec(rec.SpecJSON) + if err != nil { + return err + } + + svc := New(deps.TextGen, deps.MediaGen, deps.Store, logger) + + // Restore anchor for identity consistency. + if err := restoreAnchor(ctx, svc, deps.Store, spec); err != nil { + logger.Warn("could not restore anchor for video", "error", err) + } + + // Find the next missing video in order. + completedVideos := len(rec.VideoURLs) + if completedVideos >= len(videoOrder) { + // All videos done — mark complete. + rec.Status = "complete" + if err := deps.Personas.Update(ctx, rec); err != nil { + return fmt.Errorf("update persona complete: %w", err) + } + publishJobUpdate(deps.Pub, personaID, StageVideo, "complete", 100, "") + logger.Info("all videos complete, persona generation finished") + return nil + } + + motionType := videoOrder[completedVideos] + logger.Info("generating video", "motion_type", motionType, "index", completedVideos) + + videoSpec, err := svc.GenerateVideo(ctx, spec, motionType) + if err != nil { + // Videos are best-effort — log the failure and continue to next. + logger.Warn("video generation failed (non-fatal)", "motion_type", motionType, "error", err) + publishJobUpdate(deps.Pub, personaID, StageVideo, "error", 0, fmt.Sprintf("%s video failed: %s", motionType, err.Error())) + } else { + rec.VideoURLs = append(rec.VideoURLs, videoSpec.URL) + } + + // Update spec JSON. + updatedSpec, _ := json.Marshal(spec) + rec.SpecJSON = updatedSpec + + if err := deps.Personas.Update(ctx, rec); err != nil { + return fmt.Errorf("update persona video: %w", err) + } + + // Check if more videos needed (count by index, not by success). + nextIndex := completedVideos + 1 + if nextIndex < len(videoOrder) { + if _, err := deps.Queue.Enqueue(ctx, "persona_generate", map[string]any{ + "persona_id": personaID, + "stage": StageVideo, + }); err != nil { + return fmt.Errorf("enqueue next video: %w", err) + } + progress := (nextIndex * 100) / len(videoOrder) + publishJobUpdate(deps.Pub, personaID, StageVideo, "complete", progress, "") + logger.Info("video stage complete, next video enqueued", "completed", nextIndex, "total", len(videoOrder)) + } else { + // All videos done — mark persona complete. + rec.Status = "complete" + if err := deps.Personas.Update(ctx, rec); err != nil { + return fmt.Errorf("update persona complete: %w", err) + } + publishJobUpdate(deps.Pub, personaID, StageVideo, "complete", 100, "") + logger.Info("all videos complete, persona generation finished") + } + + return nil +} + +// publishJobUpdate sends a job_update SSE event to channel:personas. +func publishJobUpdate(pub realtime.EventPublisher, personaID, stage, status string, progress int, errMsg string) { + event := &realtime.SSEEvent{ + Type: "job_update", + Progress: progress, + Result: map[string]any{ + "persona_id": personaID, + "stage": stage, + "status": status, + "progress": progress, + }, + } + if errMsg != "" { + event.Error = errMsg + } + if err := pub.SendToChannel("channel:personas", event); err != nil { + slog.Warn("failed to publish job_update event", "error", err, "persona_id", personaID, "stage", stage) + } +} + +// unmarshalSpec deserializes a PersonaSpec from JSON stored on the persona record. +func unmarshalSpec(specJSON json.RawMessage) (*persona.PersonaSpec, error) { + if len(specJSON) == 0 { + return nil, fmt.Errorf("persona has no spec_json (spec stage may not have completed)") + } + var spec persona.PersonaSpec + if err := json.Unmarshal(specJSON, &spec); err != nil { + return nil, fmt.Errorf("unmarshal spec: %w", err) + } + return &spec, nil +} + +// restoreAnchor downloads the anchor image (position 1) from storage and sets it +// on the service for identity consistency in subsequent generations. +func restoreAnchor(ctx context.Context, svc *Service, store storage.Store, spec *persona.PersonaSpec) error { + if store == nil { + return fmt.Errorf("no storage configured") + } + // Find position 1 URL from spec. + var anchorURL string + for _, img := range spec.ImageMatrix { + if img.Position == 1 && img.URL != "" { + anchorURL = img.URL + break + } + } + if anchorURL == "" { + return fmt.Errorf("no anchor URL found in spec") + } + + // Download anchor image from storage URL. + anchorBytes, err := storage.FetchURL(ctx, anchorFetchClient, anchorURL, 50<<20) // 50 MB max + if err != nil { + return fmt.Errorf("fetch anchor image: %w", err) + } + svc.SetAnchor(anchorBytes) + return nil +} + +// generateHandleFromSpec creates a URL-safe handle from the spec name. +func generateHandleFromSpec(spec *persona.PersonaSpec) string { + name := spec.Name.First + if spec.Name.Last != "" { + name += " " + spec.Name.Last + } + return generateHandle(name) +} + +// generateHandle creates a URL-safe handle from a name (same logic as persona-api service). +func generateHandle(name string) string { + // Reuse the service's handle generation logic. + h := name + // Simplified: lowercase, replace non-alphanumeric with underscore, trim. + result := make([]byte, 0, len(h)) + for _, c := range []byte(h) { + if (c >= 'a' && c <= 'z') || (c >= '0' && c <= '9') { + result = append(result, c) + } else if c >= 'A' && c <= 'Z' { + result = append(result, c+32) // lowercase + } else { + if len(result) > 0 && result[len(result)-1] != '_' { + result = append(result, '_') + } + } + } + s := string(result) + if len(s) > 40 { + s = s[:40] + } + // Trim trailing underscores. + for len(s) > 0 && s[len(s)-1] == '_' { + s = s[:len(s)-1] + } + // Append timestamp suffix. + suffix := fmt.Sprintf("_%d", now().UnixMilli()%100000) + return s + suffix +} + +// extractTags extracts 8 tags from the spec for the persona row. +func extractTags(spec *persona.PersonaSpec) []string { + tags := make([]string, 0, 8) + + // Add gender. + if spec.DNA != nil { + tags = append(tags, string(spec.DNA.Identity.Gender)) + tags = append(tags, string(spec.DNA.Identity.Ethnicity)) + } + + // Add occupation from identity. + if spec.Name.First != "" { + tags = append(tags, spec.Name.First) + } + + // Add fashion context. + if spec.Lifestyle.FashionSense.Primary != "" { + tags = append(tags, string(spec.Lifestyle.FashionSense.Primary)) + } + + // Add interests. + for _, interest := range spec.Lifestyle.Interests.Creative { + if len(tags) >= 8 { + break + } + tags = append(tags, interest) + } + for _, interest := range spec.Lifestyle.Interests.Active { + if len(tags) >= 8 { + break + } + tags = append(tags, interest) + } + + if len(tags) > 8 { + tags = tags[:8] + } + return tags +} + +// countCompletedImages counts how many images have status "complete" in the spec. +func countCompletedImages(spec *persona.PersonaSpec) int { + count := 0 + for _, img := range spec.ImageMatrix { + if img.Status == persona.ImageStatusComplete { + count++ + } + } + return count +} + +// appendIfMissing appends a string to a slice if it's not already present and not empty. +func appendIfMissing(slice []string, s string) []string { + if s == "" { + return slice + } + for _, existing := range slice { + if existing == s { + return slice + } + } + return append(slice, s) +} diff --git a/pkg/personagen/pipeline_test.go b/pkg/personagen/pipeline_test.go new file mode 100644 index 0000000..e125b74 --- /dev/null +++ b/pkg/personagen/pipeline_test.go @@ -0,0 +1,384 @@ +package personagen + +import ( + "context" + "encoding/json" + "fmt" + "log/slog" + "testing" + "time" + + "git.threesix.ai/jordan/persona-community-5/pkg/persona" + "git.threesix.ai/jordan/persona-community-5/pkg/queue" + "git.threesix.ai/jordan/persona-community-5/pkg/realtime" +) + +// ── Mock implementations ───────────────────────────────────────────────────── + +// mockPersonaStore implements PersonaStore for testing. +type mockPersonaStore struct { + records map[string]*PersonaRecord +} + +func newMockPersonaStore() *mockPersonaStore { + return &mockPersonaStore{records: make(map[string]*PersonaRecord)} +} + +func (m *mockPersonaStore) GetByID(_ context.Context, id string) (*PersonaRecord, error) { + rec, ok := m.records[id] + if !ok { + return nil, ErrPersonaNotFound + } + // Return a copy to avoid shared mutation. + cp := *rec + if rec.Tags != nil { + cp.Tags = make([]string, len(rec.Tags)) + copy(cp.Tags, rec.Tags) + } + if rec.ImageURLs != nil { + cp.ImageURLs = make([]string, len(rec.ImageURLs)) + copy(cp.ImageURLs, rec.ImageURLs) + } + if rec.VideoURLs != nil { + cp.VideoURLs = make([]string, len(rec.VideoURLs)) + copy(cp.VideoURLs, rec.VideoURLs) + } + if rec.SpecJSON != nil { + cp.SpecJSON = make(json.RawMessage, len(rec.SpecJSON)) + copy(cp.SpecJSON, rec.SpecJSON) + } + return &cp, nil +} + +func (m *mockPersonaStore) Update(_ context.Context, p *PersonaRecord) error { + if _, ok := m.records[p.ID]; !ok { + return ErrPersonaNotFound + } + m.records[p.ID] = p + return nil +} + +// mockProducer records enqueued jobs. +type mockProducer struct { + jobs []enqueuedJob +} + +type enqueuedJob struct { + jobType string + payload map[string]any +} + +var _ queue.Producer = (*mockProducer)(nil) + +func (m *mockProducer) Enqueue(_ context.Context, jobType string, payload map[string]any) (string, error) { + m.jobs = append(m.jobs, enqueuedJob{jobType: jobType, payload: payload}) + return fmt.Sprintf("job-%d", len(m.jobs)), nil +} + +func (m *mockProducer) EnqueueWithOptions(_ context.Context, job queue.Job) (string, error) { + m.jobs = append(m.jobs, enqueuedJob{jobType: job.Type, payload: job.Payload}) + return fmt.Sprintf("job-%d", len(m.jobs)), nil +} + +// mockPublisher records published events. +type mockPublisher struct { + userEvents []publishedEvent + channelEvents []publishedEvent +} + +type publishedEvent struct { + target string + event *realtime.SSEEvent +} + +var _ realtime.EventPublisher = (*mockPublisher)(nil) + +func (m *mockPublisher) SendToUser(userID string, event *realtime.SSEEvent) error { + m.userEvents = append(m.userEvents, publishedEvent{target: userID, event: event}) + return nil +} + +func (m *mockPublisher) SendToChannel(channel string, event *realtime.SSEEvent) error { + m.channelEvents = append(m.channelEvents, publishedEvent{target: channel, event: event}) + return nil +} + +// ── Tests ──────────────────────────────────────────────────────────────────── + +func TestStagedQueueHandler_MissingPersonaID(t *testing.T) { + handler := StagedQueueHandler(PipelineDeps{}) + + job := &queue.Job{ + ID: "test-job", + Type: "persona_generate", + Payload: map[string]any{"stage": "spec"}, + } + + err := handler(context.Background(), job) + if err == nil { + t.Fatal("expected error for missing persona_id") + } + if err.Error() != "missing persona_id in persona_generate job payload" { + t.Errorf("unexpected error: %v", err) + } +} + +func TestStagedQueueHandler_MissingStage(t *testing.T) { + handler := StagedQueueHandler(PipelineDeps{}) + + job := &queue.Job{ + ID: "test-job", + Type: "persona_generate", + Payload: map[string]any{"persona_id": "ps_test"}, + } + + err := handler(context.Background(), job) + if err == nil { + t.Fatal("expected error for missing stage") + } + if err.Error() != "missing stage in persona_generate job payload" { + t.Errorf("unexpected error: %v", err) + } +} + +func TestStagedQueueHandler_UnknownStage(t *testing.T) { + store := newMockPersonaStore() + pub := &mockPublisher{} + handler := StagedQueueHandler(PipelineDeps{ + Personas: store, + Pub: pub, + Logger: noopLogger(), + }) + + job := &queue.Job{ + ID: "test-job", + Type: "persona_generate", + Payload: map[string]any{ + "persona_id": "ps_test", + "stage": "unknown_stage", + }, + } + + err := handler(context.Background(), job) + if err == nil { + t.Fatal("expected error for unknown stage") + } +} + +func TestStagedQueueHandler_SpecStage_PersonaNotFound(t *testing.T) { + store := newMockPersonaStore() + pub := &mockPublisher{} + handler := StagedQueueHandler(PipelineDeps{ + Personas: store, + Pub: pub, + Logger: noopLogger(), + }) + + job := &queue.Job{ + ID: "test-job", + Type: "persona_generate", + Payload: map[string]any{ + "persona_id": "ps_nonexistent", + "stage": StageSpec, + }, + } + + err := handler(context.Background(), job) + if err == nil { + t.Fatal("expected error for persona not found") + } +} + +func TestExtractTags(t *testing.T) { + spec := &persona.PersonaSpec{ + DNA: &persona.DNA{ + Identity: persona.IdentityDNA{ + Gender: persona.GenderWoman, + Ethnicity: persona.EthnicityEastAsian, + }, + }, + Name: persona.NameSpec{First: "Luna"}, + Lifestyle: persona.Lifestyle{ + FashionSense: persona.FashionSense{ + Primary: persona.FashionClassicMinimalist, + }, + Interests: persona.Interests{ + Creative: []string{"photography", "painting"}, + Active: []string{"yoga", "hiking"}, + }, + }, + } + + tags := extractTags(spec) + if len(tags) > 8 { + t.Errorf("expected at most 8 tags, got %d", len(tags)) + } + if len(tags) == 0 { + t.Error("expected at least some tags") + } + + // Should contain gender. + found := false + for _, tag := range tags { + if tag == "woman" { + found = true + break + } + } + if !found { + t.Error("expected 'woman' tag") + } +} + +func TestCountCompletedImages(t *testing.T) { + spec := &persona.PersonaSpec{ + ImageMatrix: []persona.ImageSpec{ + {Position: 1, Status: persona.ImageStatusComplete}, + {Position: 2, Status: persona.ImageStatusComplete}, + {Position: 3, Status: persona.ImageStatusPending}, + {Position: 4, Status: persona.ImageStatusFailed}, + }, + } + + count := countCompletedImages(spec) + if count != 2 { + t.Errorf("expected 2 completed, got %d", count) + } +} + +func TestAppendIfMissing(t *testing.T) { + tests := []struct { + name string + slice []string + s string + expected int + }{ + {"adds new", []string{"a"}, "b", 2}, + {"skips duplicate", []string{"a", "b"}, "a", 2}, + {"skips empty", []string{"a"}, "", 1}, + {"adds to nil", nil, "a", 1}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := appendIfMissing(tt.slice, tt.s) + if len(result) != tt.expected { + t.Errorf("expected len %d, got %d", tt.expected, len(result)) + } + }) + } +} + +func TestGenerateHandle(t *testing.T) { + // Override time for deterministic output. + origNow := now + now = func() time.Time { return time.UnixMilli(1234567890) } + defer func() { now = origNow }() + + tests := []struct { + name string + input string + }{ + {"simple name", "Luna Shadow"}, + {"special chars", "DJ Beats!@#"}, + {"unicode", "Café Noir"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + handle := generateHandle(tt.input) + if handle == "" { + t.Error("expected non-empty handle") + } + if len(handle) > 50 { + t.Errorf("handle too long: %d chars", len(handle)) + } + }) + } +} + +func TestPublishJobUpdate(t *testing.T) { + pub := &mockPublisher{} + + publishJobUpdate(pub, "ps_123", StageSpec, "complete", 100, "") + + if len(pub.channelEvents) != 1 { + t.Fatalf("expected 1 channel event, got %d", len(pub.channelEvents)) + } + + event := pub.channelEvents[0] + if event.target != "channel:personas" { + t.Errorf("expected channel 'channel:personas', got '%s'", event.target) + } + if event.event.Type != "job_update" { + t.Errorf("expected event type 'job_update', got '%s'", event.event.Type) + } + if event.event.Progress != 100 { + t.Errorf("expected progress 100, got %d", event.event.Progress) + } + + result, ok := event.event.Result.(map[string]any) + if !ok { + t.Fatal("expected Result to be map") + } + if result["persona_id"] != "ps_123" { + t.Errorf("expected persona_id 'ps_123', got %v", result["persona_id"]) + } + if result["stage"] != StageSpec { + t.Errorf("expected stage 'spec', got %v", result["stage"]) + } + if result["status"] != "complete" { + t.Errorf("expected status 'complete', got %v", result["status"]) + } +} + +func TestPublishJobUpdate_WithError(t *testing.T) { + pub := &mockPublisher{} + + publishJobUpdate(pub, "ps_123", StageSpec, "error", 0, "something went wrong") + + if len(pub.channelEvents) != 1 { + t.Fatalf("expected 1 channel event, got %d", len(pub.channelEvents)) + } + + event := pub.channelEvents[0] + if event.event.Error != "something went wrong" { + t.Errorf("expected error message, got '%s'", event.event.Error) + } +} + +func TestUnmarshalSpec_EmptyJSON(t *testing.T) { + _, err := unmarshalSpec(nil) + if err == nil { + t.Fatal("expected error for nil spec JSON") + } + + _, err = unmarshalSpec(json.RawMessage{}) + if err == nil { + t.Fatal("expected error for empty spec JSON") + } +} + +func TestUnmarshalSpec_ValidJSON(t *testing.T) { + spec := &persona.PersonaSpec{ + ID: "ps_test", + Name: persona.NameSpec{First: "Luna", Last: "Shadow"}, + } + data, _ := json.Marshal(spec) + + result, err := unmarshalSpec(data) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result.ID != "ps_test" { + t.Errorf("expected ID 'ps_test', got '%s'", result.ID) + } + if result.Name.First != "Luna" { + t.Errorf("expected first name 'Luna', got '%s'", result.Name.First) + } +} + +// noopLogger returns a discard logger for tests. +func noopLogger() *slog.Logger { + return slog.Default() // In test context, logs go to stderr which is fine +} diff --git a/pkg/personagen/store.go b/pkg/personagen/store.go new file mode 100644 index 0000000..eef7567 --- /dev/null +++ b/pkg/personagen/store.go @@ -0,0 +1,36 @@ +package personagen + +import ( + "context" + "encoding/json" + "time" +) + +// PersonaRecord is a minimal representation of a persona row for the worker pipeline. +// It mirrors the database columns needed for stage-based generation without +// importing service-internal domain types. +type PersonaRecord struct { + ID string `json:"id" db:"id"` + Name string `json:"name" db:"name"` + Handle string `json:"handle" db:"handle"` + Gender string `json:"gender" db:"gender"` + Description string `json:"description" db:"description"` + Tags []string `json:"tags" db:"tags"` + SpecJSON json.RawMessage `json:"spec_json,omitempty" db:"spec_json"` + AnchorURL string `json:"anchor_url,omitempty" db:"anchor_url"` + AvatarURL string `json:"avatar_url,omitempty" db:"avatar_url"` + BannerURL string `json:"banner_url,omitempty" db:"banner_url"` + ImageURLs []string `json:"image_urls" db:"image_urls"` + VideoURLs []string `json:"video_urls" db:"video_urls"` + Status string `json:"status" db:"status"` + CreatedAt time.Time `json:"created_at" db:"created_at"` +} + +// PersonaStore provides read/write access to persona records for the generation pipeline. +type PersonaStore interface { + // GetByID returns a persona by ID. + GetByID(ctx context.Context, id string) (*PersonaRecord, error) + + // Update persists changes to an existing persona record. + Update(ctx context.Context, persona *PersonaRecord) error +} diff --git a/pkg/personagen/store_postgres.go b/pkg/personagen/store_postgres.go new file mode 100644 index 0000000..e175b12 --- /dev/null +++ b/pkg/personagen/store_postgres.go @@ -0,0 +1,146 @@ +package personagen + +import ( + "context" + "database/sql" + "errors" + "fmt" + + "github.com/jmoiron/sqlx" + "github.com/lib/pq" +) + +// ErrPersonaNotFound is returned when a persona ID does not exist. +var ErrPersonaNotFound = errors.New("persona not found") + +// personaRow is the database scan target for persona queries. +type personaRow struct { + ID string `db:"id"` + Name string `db:"name"` + Handle string `db:"handle"` + Gender string `db:"gender"` + Description string `db:"description"` + Tags pq.StringArray `db:"tags"` + SpecJSON []byte `db:"spec_json"` + AnchorURL *string `db:"anchor_url"` + AvatarURL *string `db:"avatar_url"` + BannerURL *string `db:"banner_url"` + ImageURLs pq.StringArray `db:"image_urls"` + VideoURLs pq.StringArray `db:"video_urls"` + Status string `db:"status"` +} + +func (r *personaRow) toRecord() *PersonaRecord { + rec := &PersonaRecord{ + ID: r.ID, + Name: r.Name, + Handle: r.Handle, + Gender: r.Gender, + Description: r.Description, + Tags: []string(r.Tags), + Status: r.Status, + ImageURLs: []string(r.ImageURLs), + VideoURLs: []string(r.VideoURLs), + } + if rec.Tags == nil { + rec.Tags = []string{} + } + if rec.ImageURLs == nil { + rec.ImageURLs = []string{} + } + if rec.VideoURLs == nil { + rec.VideoURLs = []string{} + } + if len(r.SpecJSON) > 0 { + rec.SpecJSON = r.SpecJSON + } + if r.AnchorURL != nil { + rec.AnchorURL = *r.AnchorURL + } + if r.AvatarURL != nil { + rec.AvatarURL = *r.AvatarURL + } + if r.BannerURL != nil { + rec.BannerURL = *r.BannerURL + } + return rec +} + +// PostgresPersonaStore implements PersonaStore using PostgreSQL/CockroachDB. +type PostgresPersonaStore struct { + db *sqlx.DB +} + +// Compile-time interface check. +var _ PersonaStore = (*PostgresPersonaStore)(nil) + +// NewPostgresPersonaStore creates a PersonaStore backed by a SQL database. +func NewPostgresPersonaStore(db *sqlx.DB) *PostgresPersonaStore { + return &PostgresPersonaStore{db: db} +} + +func (s *PostgresPersonaStore) GetByID(ctx context.Context, id string) (*PersonaRecord, error) { + var row personaRow + err := s.db.QueryRowxContext(ctx, ` + SELECT id, name, handle, gender, description, tags, spec_json, + anchor_url, avatar_url, banner_url, image_urls, video_urls, status + FROM personas WHERE id = $1 + `, id).StructScan(&row) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, ErrPersonaNotFound + } + return nil, fmt.Errorf("get persona: %w", err) + } + return row.toRecord(), nil +} + +func (s *PostgresPersonaStore) Update(ctx context.Context, p *PersonaRecord) error { + var specJSON []byte + if p.SpecJSON != nil { + specJSON = []byte(p.SpecJSON) + } + + var anchorURL, avatarURL, bannerURL *string + if p.AnchorURL != "" { + anchorURL = &p.AnchorURL + } + if p.AvatarURL != "" { + avatarURL = &p.AvatarURL + } + if p.BannerURL != "" { + bannerURL = &p.BannerURL + } + + result, err := s.db.ExecContext(ctx, ` + UPDATE personas + SET name = $2, handle = $3, tags = $4, spec_json = $5, + anchor_url = $6, avatar_url = $7, banner_url = $8, + image_urls = $9, video_urls = $10, status = $11 + WHERE id = $1 + `, + p.ID, + p.Name, + p.Handle, + pq.StringArray(p.Tags), + specJSON, + anchorURL, + avatarURL, + bannerURL, + pq.StringArray(p.ImageURLs), + pq.StringArray(p.VideoURLs), + p.Status, + ) + if err != nil { + return fmt.Errorf("update persona: %w", err) + } + + rows, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("update persona rows affected: %w", err) + } + if rows == 0 { + return ErrPersonaNotFound + } + return nil +} diff --git a/workers/media-worker/cmd/worker/main.go b/workers/media-worker/cmd/worker/main.go index 1c2a898..4401e21 100644 --- a/workers/media-worker/cmd/worker/main.go +++ b/workers/media-worker/cmd/worker/main.go @@ -217,9 +217,19 @@ func main() { handler.RegisterHandler("generate_text", handlers.TextHandler(textgenManager, ssePub, logger)) handler.RegisterHandler("ai_chat_response", handlers.ChatResponseHandler(textgenManager, ssePub, logger)) } - // Persona generation requires both textgen (5-stage LLM pipeline) and mediagen (20 images + 4 videos). + // Staged persona generation pipeline: spec → anchor → avatar/banner/gallery_batch → video. + // Requires textgen (LLM spec), mediagen (images + video), and persona store (DB persistence). if textgenManager != nil && mediagenManager != nil { - handler.RegisterHandler("persona_generate", personagen.QueueHandler(textgenManager, mediagenManager, mediaStore, ssePub, logger.Logger)) + personaStore := personagen.NewPostgresPersonaStore(pool.DB) + handler.RegisterHandler("persona_generate", handlers.PersonaGenerateHandler(personagen.PipelineDeps{ + TextGen: textgenManager, + MediaGen: mediagenManager, + Store: mediaStore, + Pub: ssePub, + Personas: personaStore, + Queue: jobQueue, + Logger: logger.Logger, + })) } // Setup signal handling diff --git a/workers/media-worker/internal/handlers/generate.go b/workers/media-worker/internal/handlers/generate.go index 7117661..64b9351 100644 --- a/workers/media-worker/internal/handlers/generate.go +++ b/workers/media-worker/internal/handlers/generate.go @@ -6,6 +6,7 @@ import ( "git.threesix.ai/jordan/persona-community-5/pkg/generation" "git.threesix.ai/jordan/persona-community-5/pkg/logging" "git.threesix.ai/jordan/persona-community-5/pkg/mediagen" + "git.threesix.ai/jordan/persona-community-5/pkg/personagen" "git.threesix.ai/jordan/persona-community-5/pkg/queue" "git.threesix.ai/jordan/persona-community-5/pkg/realtime" "git.threesix.ai/jordan/persona-community-5/pkg/storage" @@ -31,3 +32,10 @@ func TextHandler(tg *textgen.Manager, pub realtime.EventPublisher, logger *loggi func ChatResponseHandler(tg *textgen.Manager, pub realtime.EventPublisher, logger *logging.Logger) queue.Handler { return generation.ChatResponseHandler(tg, pub, logger) } + +// PersonaGenerateHandler returns a queue.Handler for the staged persona generation pipeline. +// Each job carries a stage (spec, anchor, avatar, banner, gallery_batch, video) and processes +// one unit of work, updating the persona row and publishing SSE events after each stage. +func PersonaGenerateHandler(deps personagen.PipelineDeps) queue.Handler { + return personagen.StagedQueueHandler(deps) +}