build: /implement-feature persona-generation --requirements 'Implement the g...
Some checks failed
ci/woodpecker/push/woodpecker Pipeline failed
Some checks failed
ci/woodpecker/push/woodpecker Pipeline failed
This commit is contained in:
parent
9c009926d1
commit
66ceb7e55f
617
pkg/personagen/pipeline.go
Normal file
617
pkg/personagen/pipeline.go
Normal file
@ -0,0 +1,617 @@
|
||||
package personagen
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"git.threesix.ai/jordan/persona-community-5/pkg/mediagen"
|
||||
"git.threesix.ai/jordan/persona-community-5/pkg/persona"
|
||||
"git.threesix.ai/jordan/persona-community-5/pkg/queue"
|
||||
"git.threesix.ai/jordan/persona-community-5/pkg/realtime"
|
||||
"git.threesix.ai/jordan/persona-community-5/pkg/storage"
|
||||
"git.threesix.ai/jordan/persona-community-5/pkg/textgen"
|
||||
)
|
||||
|
||||
// anchorFetchClient is used to download anchor images from storage URLs.
|
||||
var anchorFetchClient = &http.Client{Timeout: 30 * time.Second}
|
||||
|
||||
// Stage constants matching domain.PersonaStage in the API service.
|
||||
const (
|
||||
StageSpec = "spec"
|
||||
StageAnchor = "anchor"
|
||||
StageAvatar = "avatar"
|
||||
StageBanner = "banner"
|
||||
StageGalleryBatch = "gallery_batch"
|
||||
StageVideo = "video"
|
||||
)
|
||||
|
||||
// Video motion types in generation order.
|
||||
var videoOrder = []persona.MotionType{
|
||||
persona.MotionSmileReveal,
|
||||
persona.MotionPersonality,
|
||||
persona.MotionLifestyle,
|
||||
persona.MotionInvitation,
|
||||
}
|
||||
|
||||
// PipelineDeps holds dependencies for the staged pipeline handler.
|
||||
type PipelineDeps struct {
|
||||
TextGen *textgen.Manager
|
||||
MediaGen *mediagen.Manager
|
||||
Store storage.Store
|
||||
Pub realtime.EventPublisher
|
||||
Personas PersonaStore
|
||||
Queue queue.Producer
|
||||
Logger *slog.Logger
|
||||
}
|
||||
|
||||
// StagedQueueHandler returns a queue.Handler for processing persona_generate jobs
|
||||
// using a stage-based pipeline. Each stage processes one unit of work, updates the
|
||||
// persona row, publishes a job_update SSE event, and enqueues the next stage.
|
||||
//
|
||||
// Job payload: {"persona_id": "...", "stage": "spec|anchor|avatar|banner|gallery_batch|video"}
|
||||
func StagedQueueHandler(deps PipelineDeps) queue.Handler {
|
||||
return func(ctx context.Context, job *queue.Job) error {
|
||||
personaID, _ := job.Payload["persona_id"].(string)
|
||||
if personaID == "" {
|
||||
return fmt.Errorf("missing persona_id in persona_generate job payload")
|
||||
}
|
||||
stage, _ := job.Payload["stage"].(string)
|
||||
if stage == "" {
|
||||
return fmt.Errorf("missing stage in persona_generate job payload")
|
||||
}
|
||||
|
||||
logger := deps.Logger.With("job_id", job.ID, "persona_id", personaID, "stage", stage)
|
||||
logger.Info("processing persona generation stage")
|
||||
|
||||
var err error
|
||||
switch stage {
|
||||
case StageSpec:
|
||||
err = handleStageSpec(ctx, deps, personaID, job.ID, logger)
|
||||
case StageAnchor:
|
||||
err = handleStageAnchor(ctx, deps, personaID, job.ID, logger)
|
||||
case StageAvatar:
|
||||
err = handleStageAvatar(ctx, deps, personaID, job.ID, logger)
|
||||
case StageBanner:
|
||||
err = handleStageBanner(ctx, deps, personaID, job.ID, logger)
|
||||
case StageGalleryBatch:
|
||||
err = handleStageGalleryBatch(ctx, deps, personaID, job.ID, logger)
|
||||
case StageVideo:
|
||||
err = handleStageVideo(ctx, deps, personaID, job.ID, logger)
|
||||
default:
|
||||
err = fmt.Errorf("unknown stage: %s", stage)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
logger.Error("stage failed", "error", err)
|
||||
publishJobUpdate(deps.Pub, personaID, stage, "error", 0, err.Error())
|
||||
}
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// handleStageSpec generates the persona spec via LLM and updates the persona row.
|
||||
func handleStageSpec(ctx context.Context, deps PipelineDeps, personaID, jobID string, logger *slog.Logger) error {
|
||||
rec, err := deps.Personas.GetByID(ctx, personaID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("load persona: %w", err)
|
||||
}
|
||||
|
||||
// Update status to generating.
|
||||
rec.Status = "generating"
|
||||
if err := deps.Personas.Update(ctx, rec); err != nil {
|
||||
return fmt.Errorf("update persona status: %w", err)
|
||||
}
|
||||
|
||||
svc := New(deps.TextGen, deps.MediaGen, deps.Store, logger)
|
||||
seed := SeedParams{
|
||||
Description: rec.Description,
|
||||
Gender: rec.Gender,
|
||||
Name: rec.Name,
|
||||
}
|
||||
|
||||
spec, err := svc.GenerateSpec(ctx, seed)
|
||||
if err != nil {
|
||||
rec.Status = "failed"
|
||||
_ = deps.Personas.Update(ctx, rec)
|
||||
return fmt.Errorf("spec generation: %w", err)
|
||||
}
|
||||
|
||||
// Serialize full spec to JSON for storage.
|
||||
specJSON, err := json.Marshal(spec)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal spec: %w", err)
|
||||
}
|
||||
|
||||
// Update persona with spec data.
|
||||
rec.SpecJSON = specJSON
|
||||
rec.Name = spec.Name.First
|
||||
if spec.Name.Last != "" {
|
||||
rec.Name = spec.Name.First + " " + spec.Name.Last
|
||||
}
|
||||
rec.Handle = generateHandleFromSpec(spec)
|
||||
rec.Tags = extractTags(spec)
|
||||
|
||||
if err := deps.Personas.Update(ctx, rec); err != nil {
|
||||
return fmt.Errorf("update persona spec: %w", err)
|
||||
}
|
||||
|
||||
publishJobUpdate(deps.Pub, personaID, StageSpec, "complete", 100, "")
|
||||
|
||||
// Enqueue anchor stage.
|
||||
if _, err := deps.Queue.Enqueue(ctx, "persona_generate", map[string]any{
|
||||
"persona_id": personaID,
|
||||
"stage": StageAnchor,
|
||||
}); err != nil {
|
||||
return fmt.Errorf("enqueue anchor stage: %w", err)
|
||||
}
|
||||
|
||||
logger.Info("spec stage complete, anchor enqueued")
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleStageAnchor generates the anchor (position 1) image.
|
||||
func handleStageAnchor(ctx context.Context, deps PipelineDeps, personaID, jobID string, logger *slog.Logger) error {
|
||||
rec, err := deps.Personas.GetByID(ctx, personaID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("load persona: %w", err)
|
||||
}
|
||||
|
||||
spec, err := unmarshalSpec(rec.SpecJSON)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
svc := New(deps.TextGen, deps.MediaGen, deps.Store, logger)
|
||||
|
||||
// Generate position 1 (the identity anchor).
|
||||
if err := svc.generatePosition(ctx, spec, 1); err != nil {
|
||||
return fmt.Errorf("anchor generation: %w", err)
|
||||
}
|
||||
|
||||
// Find position 1 URL.
|
||||
var anchorURL string
|
||||
for _, img := range spec.ImageMatrix {
|
||||
if img.Position == 1 {
|
||||
anchorURL = img.URL
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
rec.AnchorURL = anchorURL
|
||||
rec.ImageURLs = appendIfMissing(rec.ImageURLs, anchorURL)
|
||||
|
||||
// Update spec JSON with the generated image URL/status.
|
||||
updatedSpec, _ := json.Marshal(spec)
|
||||
rec.SpecJSON = updatedSpec
|
||||
|
||||
if err := deps.Personas.Update(ctx, rec); err != nil {
|
||||
return fmt.Errorf("update persona anchor: %w", err)
|
||||
}
|
||||
|
||||
publishJobUpdate(deps.Pub, personaID, StageAnchor, "complete", 100, "")
|
||||
|
||||
// Enqueue avatar, banner, and gallery_batch in parallel.
|
||||
for _, nextStage := range []string{StageAvatar, StageBanner, StageGalleryBatch} {
|
||||
if _, err := deps.Queue.Enqueue(ctx, "persona_generate", map[string]any{
|
||||
"persona_id": personaID,
|
||||
"stage": nextStage,
|
||||
}); err != nil {
|
||||
return fmt.Errorf("enqueue %s stage: %w", nextStage, err)
|
||||
}
|
||||
}
|
||||
|
||||
logger.Info("anchor stage complete, avatar/banner/gallery_batch enqueued")
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleStageAvatar generates the circular portrait avatar.
|
||||
func handleStageAvatar(ctx context.Context, deps PipelineDeps, personaID, jobID string, logger *slog.Logger) error {
|
||||
rec, err := deps.Personas.GetByID(ctx, personaID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("load persona: %w", err)
|
||||
}
|
||||
|
||||
spec, err := unmarshalSpec(rec.SpecJSON)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
svc := New(deps.TextGen, deps.MediaGen, deps.Store, logger)
|
||||
|
||||
// Restore anchor from storage for identity consistency.
|
||||
if err := restoreAnchor(ctx, svc, deps.Store, spec); err != nil {
|
||||
logger.Warn("could not restore anchor for avatar", "error", err)
|
||||
}
|
||||
|
||||
avatarBytes, err := svc.GenerateAvatar(ctx, spec)
|
||||
if err != nil {
|
||||
return fmt.Errorf("avatar generation: %w", err)
|
||||
}
|
||||
|
||||
storagePath := fmt.Sprintf("personas/%s/avatar.png", spec.ID)
|
||||
url, err := deps.Store.Upload(ctx, storagePath, avatarBytes, "image/png")
|
||||
if err != nil {
|
||||
return fmt.Errorf("upload avatar: %w", err)
|
||||
}
|
||||
|
||||
rec.AvatarURL = url
|
||||
if err := deps.Personas.Update(ctx, rec); err != nil {
|
||||
return fmt.Errorf("update persona avatar: %w", err)
|
||||
}
|
||||
|
||||
publishJobUpdate(deps.Pub, personaID, StageAvatar, "complete", 100, "")
|
||||
logger.Info("avatar stage complete")
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleStageBanner generates the 3:1 landscape banner image.
|
||||
func handleStageBanner(ctx context.Context, deps PipelineDeps, personaID, jobID string, logger *slog.Logger) error {
|
||||
rec, err := deps.Personas.GetByID(ctx, personaID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("load persona: %w", err)
|
||||
}
|
||||
|
||||
spec, err := unmarshalSpec(rec.SpecJSON)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
svc := New(deps.TextGen, deps.MediaGen, deps.Store, logger)
|
||||
|
||||
// Restore anchor for identity consistency.
|
||||
if err := restoreAnchor(ctx, svc, deps.Store, spec); err != nil {
|
||||
logger.Warn("could not restore anchor for banner", "error", err)
|
||||
}
|
||||
|
||||
bannerBytes, err := svc.GenerateBanner(ctx, spec, "lifestyle")
|
||||
if err != nil {
|
||||
return fmt.Errorf("banner generation: %w", err)
|
||||
}
|
||||
|
||||
storagePath := fmt.Sprintf("personas/%s/banner.png", spec.ID)
|
||||
url, err := deps.Store.Upload(ctx, storagePath, bannerBytes, "image/png")
|
||||
if err != nil {
|
||||
return fmt.Errorf("upload banner: %w", err)
|
||||
}
|
||||
|
||||
rec.BannerURL = url
|
||||
if err := deps.Personas.Update(ctx, rec); err != nil {
|
||||
return fmt.Errorf("update persona banner: %w", err)
|
||||
}
|
||||
|
||||
publishJobUpdate(deps.Pub, personaID, StageBanner, "complete", 100, "")
|
||||
logger.Info("banner stage complete")
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleStageGalleryBatch generates the next 10 unfilled image positions.
|
||||
func handleStageGalleryBatch(ctx context.Context, deps PipelineDeps, personaID, jobID string, logger *slog.Logger) error {
|
||||
rec, err := deps.Personas.GetByID(ctx, personaID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("load persona: %w", err)
|
||||
}
|
||||
|
||||
spec, err := unmarshalSpec(rec.SpecJSON)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
svc := New(deps.TextGen, deps.MediaGen, deps.Store, logger)
|
||||
|
||||
// Restore anchor for identity consistency.
|
||||
if err := restoreAnchor(ctx, svc, deps.Store, spec); err != nil {
|
||||
logger.Warn("could not restore anchor for gallery batch", "error", err)
|
||||
}
|
||||
|
||||
// Find unfilled positions (skip position 1 which is the anchor, already done).
|
||||
var unfilled []int
|
||||
for _, img := range spec.ImageMatrix {
|
||||
if img.Status != persona.ImageStatusComplete && img.Position != 1 {
|
||||
unfilled = append(unfilled, img.Position)
|
||||
}
|
||||
}
|
||||
|
||||
// Take up to 10 positions for this batch.
|
||||
batchSize := 10
|
||||
if len(unfilled) < batchSize {
|
||||
batchSize = len(unfilled)
|
||||
}
|
||||
batch := unfilled[:batchSize]
|
||||
|
||||
for _, pos := range batch {
|
||||
if err := svc.generatePosition(ctx, spec, pos); err != nil {
|
||||
logger.Error("gallery image generation failed", "position", pos, "error", err)
|
||||
return fmt.Errorf("gallery position %d: %w", pos, err)
|
||||
}
|
||||
|
||||
// Find the URL for this position.
|
||||
var imgURL string
|
||||
for _, img := range spec.ImageMatrix {
|
||||
if img.Position == pos {
|
||||
imgURL = img.URL
|
||||
break
|
||||
}
|
||||
}
|
||||
rec.ImageURLs = appendIfMissing(rec.ImageURLs, imgURL)
|
||||
}
|
||||
|
||||
// Update spec JSON with new image URLs/statuses.
|
||||
updatedSpec, _ := json.Marshal(spec)
|
||||
rec.SpecJSON = updatedSpec
|
||||
|
||||
if err := deps.Personas.Update(ctx, rec); err != nil {
|
||||
return fmt.Errorf("update persona gallery: %w", err)
|
||||
}
|
||||
|
||||
// Count total completed images (including anchor).
|
||||
completedImages := countCompletedImages(spec)
|
||||
totalImages := len(spec.ImageMatrix)
|
||||
progress := (completedImages * 100) / totalImages
|
||||
|
||||
publishJobUpdate(deps.Pub, personaID, StageGalleryBatch, "complete", progress, "")
|
||||
|
||||
// If fewer than totalImages are complete, enqueue another gallery_batch.
|
||||
if completedImages < totalImages {
|
||||
if _, err := deps.Queue.Enqueue(ctx, "persona_generate", map[string]any{
|
||||
"persona_id": personaID,
|
||||
"stage": StageGalleryBatch,
|
||||
}); err != nil {
|
||||
return fmt.Errorf("enqueue next gallery batch: %w", err)
|
||||
}
|
||||
logger.Info("gallery batch complete, more images needed", "completed", completedImages, "total", totalImages)
|
||||
} else {
|
||||
// All images done, enqueue first video.
|
||||
if _, err := deps.Queue.Enqueue(ctx, "persona_generate", map[string]any{
|
||||
"persona_id": personaID,
|
||||
"stage": StageVideo,
|
||||
}); err != nil {
|
||||
return fmt.Errorf("enqueue video stage: %w", err)
|
||||
}
|
||||
logger.Info("all gallery images complete, video stage enqueued")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleStageVideo generates the next missing video.
|
||||
func handleStageVideo(ctx context.Context, deps PipelineDeps, personaID, jobID string, logger *slog.Logger) error {
|
||||
rec, err := deps.Personas.GetByID(ctx, personaID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("load persona: %w", err)
|
||||
}
|
||||
|
||||
spec, err := unmarshalSpec(rec.SpecJSON)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
svc := New(deps.TextGen, deps.MediaGen, deps.Store, logger)
|
||||
|
||||
// Restore anchor for identity consistency.
|
||||
if err := restoreAnchor(ctx, svc, deps.Store, spec); err != nil {
|
||||
logger.Warn("could not restore anchor for video", "error", err)
|
||||
}
|
||||
|
||||
// Find the next missing video in order.
|
||||
completedVideos := len(rec.VideoURLs)
|
||||
if completedVideos >= len(videoOrder) {
|
||||
// All videos done — mark complete.
|
||||
rec.Status = "complete"
|
||||
if err := deps.Personas.Update(ctx, rec); err != nil {
|
||||
return fmt.Errorf("update persona complete: %w", err)
|
||||
}
|
||||
publishJobUpdate(deps.Pub, personaID, StageVideo, "complete", 100, "")
|
||||
logger.Info("all videos complete, persona generation finished")
|
||||
return nil
|
||||
}
|
||||
|
||||
motionType := videoOrder[completedVideos]
|
||||
logger.Info("generating video", "motion_type", motionType, "index", completedVideos)
|
||||
|
||||
videoSpec, err := svc.GenerateVideo(ctx, spec, motionType)
|
||||
if err != nil {
|
||||
// Videos are best-effort — log the failure and continue to next.
|
||||
logger.Warn("video generation failed (non-fatal)", "motion_type", motionType, "error", err)
|
||||
publishJobUpdate(deps.Pub, personaID, StageVideo, "error", 0, fmt.Sprintf("%s video failed: %s", motionType, err.Error()))
|
||||
} else {
|
||||
rec.VideoURLs = append(rec.VideoURLs, videoSpec.URL)
|
||||
}
|
||||
|
||||
// Update spec JSON.
|
||||
updatedSpec, _ := json.Marshal(spec)
|
||||
rec.SpecJSON = updatedSpec
|
||||
|
||||
if err := deps.Personas.Update(ctx, rec); err != nil {
|
||||
return fmt.Errorf("update persona video: %w", err)
|
||||
}
|
||||
|
||||
// Check if more videos needed (count by index, not by success).
|
||||
nextIndex := completedVideos + 1
|
||||
if nextIndex < len(videoOrder) {
|
||||
if _, err := deps.Queue.Enqueue(ctx, "persona_generate", map[string]any{
|
||||
"persona_id": personaID,
|
||||
"stage": StageVideo,
|
||||
}); err != nil {
|
||||
return fmt.Errorf("enqueue next video: %w", err)
|
||||
}
|
||||
progress := (nextIndex * 100) / len(videoOrder)
|
||||
publishJobUpdate(deps.Pub, personaID, StageVideo, "complete", progress, "")
|
||||
logger.Info("video stage complete, next video enqueued", "completed", nextIndex, "total", len(videoOrder))
|
||||
} else {
|
||||
// All videos done — mark persona complete.
|
||||
rec.Status = "complete"
|
||||
if err := deps.Personas.Update(ctx, rec); err != nil {
|
||||
return fmt.Errorf("update persona complete: %w", err)
|
||||
}
|
||||
publishJobUpdate(deps.Pub, personaID, StageVideo, "complete", 100, "")
|
||||
logger.Info("all videos complete, persona generation finished")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// publishJobUpdate sends a job_update SSE event to channel:personas.
|
||||
func publishJobUpdate(pub realtime.EventPublisher, personaID, stage, status string, progress int, errMsg string) {
|
||||
event := &realtime.SSEEvent{
|
||||
Type: "job_update",
|
||||
Progress: progress,
|
||||
Result: map[string]any{
|
||||
"persona_id": personaID,
|
||||
"stage": stage,
|
||||
"status": status,
|
||||
"progress": progress,
|
||||
},
|
||||
}
|
||||
if errMsg != "" {
|
||||
event.Error = errMsg
|
||||
}
|
||||
if err := pub.SendToChannel("channel:personas", event); err != nil {
|
||||
slog.Warn("failed to publish job_update event", "error", err, "persona_id", personaID, "stage", stage)
|
||||
}
|
||||
}
|
||||
|
||||
// unmarshalSpec deserializes a PersonaSpec from JSON stored on the persona record.
|
||||
func unmarshalSpec(specJSON json.RawMessage) (*persona.PersonaSpec, error) {
|
||||
if len(specJSON) == 0 {
|
||||
return nil, fmt.Errorf("persona has no spec_json (spec stage may not have completed)")
|
||||
}
|
||||
var spec persona.PersonaSpec
|
||||
if err := json.Unmarshal(specJSON, &spec); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal spec: %w", err)
|
||||
}
|
||||
return &spec, nil
|
||||
}
|
||||
|
||||
// restoreAnchor downloads the anchor image (position 1) from storage and sets it
|
||||
// on the service for identity consistency in subsequent generations.
|
||||
func restoreAnchor(ctx context.Context, svc *Service, store storage.Store, spec *persona.PersonaSpec) error {
|
||||
if store == nil {
|
||||
return fmt.Errorf("no storage configured")
|
||||
}
|
||||
// Find position 1 URL from spec.
|
||||
var anchorURL string
|
||||
for _, img := range spec.ImageMatrix {
|
||||
if img.Position == 1 && img.URL != "" {
|
||||
anchorURL = img.URL
|
||||
break
|
||||
}
|
||||
}
|
||||
if anchorURL == "" {
|
||||
return fmt.Errorf("no anchor URL found in spec")
|
||||
}
|
||||
|
||||
// Download anchor image from storage URL.
|
||||
anchorBytes, err := storage.FetchURL(ctx, anchorFetchClient, anchorURL, 50<<20) // 50 MB max
|
||||
if err != nil {
|
||||
return fmt.Errorf("fetch anchor image: %w", err)
|
||||
}
|
||||
svc.SetAnchor(anchorBytes)
|
||||
return nil
|
||||
}
|
||||
|
||||
// generateHandleFromSpec creates a URL-safe handle from the spec name.
|
||||
func generateHandleFromSpec(spec *persona.PersonaSpec) string {
|
||||
name := spec.Name.First
|
||||
if spec.Name.Last != "" {
|
||||
name += " " + spec.Name.Last
|
||||
}
|
||||
return generateHandle(name)
|
||||
}
|
||||
|
||||
// generateHandle creates a URL-safe handle from a name (same logic as persona-api service).
|
||||
func generateHandle(name string) string {
|
||||
// Reuse the service's handle generation logic.
|
||||
h := name
|
||||
// Simplified: lowercase, replace non-alphanumeric with underscore, trim.
|
||||
result := make([]byte, 0, len(h))
|
||||
for _, c := range []byte(h) {
|
||||
if (c >= 'a' && c <= 'z') || (c >= '0' && c <= '9') {
|
||||
result = append(result, c)
|
||||
} else if c >= 'A' && c <= 'Z' {
|
||||
result = append(result, c+32) // lowercase
|
||||
} else {
|
||||
if len(result) > 0 && result[len(result)-1] != '_' {
|
||||
result = append(result, '_')
|
||||
}
|
||||
}
|
||||
}
|
||||
s := string(result)
|
||||
if len(s) > 40 {
|
||||
s = s[:40]
|
||||
}
|
||||
// Trim trailing underscores.
|
||||
for len(s) > 0 && s[len(s)-1] == '_' {
|
||||
s = s[:len(s)-1]
|
||||
}
|
||||
// Append timestamp suffix.
|
||||
suffix := fmt.Sprintf("_%d", now().UnixMilli()%100000)
|
||||
return s + suffix
|
||||
}
|
||||
|
||||
// extractTags extracts 8 tags from the spec for the persona row.
|
||||
func extractTags(spec *persona.PersonaSpec) []string {
|
||||
tags := make([]string, 0, 8)
|
||||
|
||||
// Add gender.
|
||||
if spec.DNA != nil {
|
||||
tags = append(tags, string(spec.DNA.Identity.Gender))
|
||||
tags = append(tags, string(spec.DNA.Identity.Ethnicity))
|
||||
}
|
||||
|
||||
// Add occupation from identity.
|
||||
if spec.Name.First != "" {
|
||||
tags = append(tags, spec.Name.First)
|
||||
}
|
||||
|
||||
// Add fashion context.
|
||||
if spec.Lifestyle.FashionSense.Primary != "" {
|
||||
tags = append(tags, string(spec.Lifestyle.FashionSense.Primary))
|
||||
}
|
||||
|
||||
// Add interests.
|
||||
for _, interest := range spec.Lifestyle.Interests.Creative {
|
||||
if len(tags) >= 8 {
|
||||
break
|
||||
}
|
||||
tags = append(tags, interest)
|
||||
}
|
||||
for _, interest := range spec.Lifestyle.Interests.Active {
|
||||
if len(tags) >= 8 {
|
||||
break
|
||||
}
|
||||
tags = append(tags, interest)
|
||||
}
|
||||
|
||||
if len(tags) > 8 {
|
||||
tags = tags[:8]
|
||||
}
|
||||
return tags
|
||||
}
|
||||
|
||||
// countCompletedImages counts how many images have status "complete" in the spec.
|
||||
func countCompletedImages(spec *persona.PersonaSpec) int {
|
||||
count := 0
|
||||
for _, img := range spec.ImageMatrix {
|
||||
if img.Status == persona.ImageStatusComplete {
|
||||
count++
|
||||
}
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
// appendIfMissing appends a string to a slice if it's not already present and not empty.
|
||||
func appendIfMissing(slice []string, s string) []string {
|
||||
if s == "" {
|
||||
return slice
|
||||
}
|
||||
for _, existing := range slice {
|
||||
if existing == s {
|
||||
return slice
|
||||
}
|
||||
}
|
||||
return append(slice, s)
|
||||
}
|
||||
384
pkg/personagen/pipeline_test.go
Normal file
384
pkg/personagen/pipeline_test.go
Normal file
@ -0,0 +1,384 @@
|
||||
package personagen
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.threesix.ai/jordan/persona-community-5/pkg/persona"
|
||||
"git.threesix.ai/jordan/persona-community-5/pkg/queue"
|
||||
"git.threesix.ai/jordan/persona-community-5/pkg/realtime"
|
||||
)
|
||||
|
||||
// ── Mock implementations ─────────────────────────────────────────────────────
|
||||
|
||||
// mockPersonaStore implements PersonaStore for testing.
|
||||
type mockPersonaStore struct {
|
||||
records map[string]*PersonaRecord
|
||||
}
|
||||
|
||||
func newMockPersonaStore() *mockPersonaStore {
|
||||
return &mockPersonaStore{records: make(map[string]*PersonaRecord)}
|
||||
}
|
||||
|
||||
func (m *mockPersonaStore) GetByID(_ context.Context, id string) (*PersonaRecord, error) {
|
||||
rec, ok := m.records[id]
|
||||
if !ok {
|
||||
return nil, ErrPersonaNotFound
|
||||
}
|
||||
// Return a copy to avoid shared mutation.
|
||||
cp := *rec
|
||||
if rec.Tags != nil {
|
||||
cp.Tags = make([]string, len(rec.Tags))
|
||||
copy(cp.Tags, rec.Tags)
|
||||
}
|
||||
if rec.ImageURLs != nil {
|
||||
cp.ImageURLs = make([]string, len(rec.ImageURLs))
|
||||
copy(cp.ImageURLs, rec.ImageURLs)
|
||||
}
|
||||
if rec.VideoURLs != nil {
|
||||
cp.VideoURLs = make([]string, len(rec.VideoURLs))
|
||||
copy(cp.VideoURLs, rec.VideoURLs)
|
||||
}
|
||||
if rec.SpecJSON != nil {
|
||||
cp.SpecJSON = make(json.RawMessage, len(rec.SpecJSON))
|
||||
copy(cp.SpecJSON, rec.SpecJSON)
|
||||
}
|
||||
return &cp, nil
|
||||
}
|
||||
|
||||
func (m *mockPersonaStore) Update(_ context.Context, p *PersonaRecord) error {
|
||||
if _, ok := m.records[p.ID]; !ok {
|
||||
return ErrPersonaNotFound
|
||||
}
|
||||
m.records[p.ID] = p
|
||||
return nil
|
||||
}
|
||||
|
||||
// mockProducer records enqueued jobs.
|
||||
type mockProducer struct {
|
||||
jobs []enqueuedJob
|
||||
}
|
||||
|
||||
type enqueuedJob struct {
|
||||
jobType string
|
||||
payload map[string]any
|
||||
}
|
||||
|
||||
var _ queue.Producer = (*mockProducer)(nil)
|
||||
|
||||
func (m *mockProducer) Enqueue(_ context.Context, jobType string, payload map[string]any) (string, error) {
|
||||
m.jobs = append(m.jobs, enqueuedJob{jobType: jobType, payload: payload})
|
||||
return fmt.Sprintf("job-%d", len(m.jobs)), nil
|
||||
}
|
||||
|
||||
func (m *mockProducer) EnqueueWithOptions(_ context.Context, job queue.Job) (string, error) {
|
||||
m.jobs = append(m.jobs, enqueuedJob{jobType: job.Type, payload: job.Payload})
|
||||
return fmt.Sprintf("job-%d", len(m.jobs)), nil
|
||||
}
|
||||
|
||||
// mockPublisher records published events.
|
||||
type mockPublisher struct {
|
||||
userEvents []publishedEvent
|
||||
channelEvents []publishedEvent
|
||||
}
|
||||
|
||||
type publishedEvent struct {
|
||||
target string
|
||||
event *realtime.SSEEvent
|
||||
}
|
||||
|
||||
var _ realtime.EventPublisher = (*mockPublisher)(nil)
|
||||
|
||||
func (m *mockPublisher) SendToUser(userID string, event *realtime.SSEEvent) error {
|
||||
m.userEvents = append(m.userEvents, publishedEvent{target: userID, event: event})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockPublisher) SendToChannel(channel string, event *realtime.SSEEvent) error {
|
||||
m.channelEvents = append(m.channelEvents, publishedEvent{target: channel, event: event})
|
||||
return nil
|
||||
}
|
||||
|
||||
// ── Tests ────────────────────────────────────────────────────────────────────
|
||||
|
||||
func TestStagedQueueHandler_MissingPersonaID(t *testing.T) {
|
||||
handler := StagedQueueHandler(PipelineDeps{})
|
||||
|
||||
job := &queue.Job{
|
||||
ID: "test-job",
|
||||
Type: "persona_generate",
|
||||
Payload: map[string]any{"stage": "spec"},
|
||||
}
|
||||
|
||||
err := handler(context.Background(), job)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for missing persona_id")
|
||||
}
|
||||
if err.Error() != "missing persona_id in persona_generate job payload" {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStagedQueueHandler_MissingStage(t *testing.T) {
|
||||
handler := StagedQueueHandler(PipelineDeps{})
|
||||
|
||||
job := &queue.Job{
|
||||
ID: "test-job",
|
||||
Type: "persona_generate",
|
||||
Payload: map[string]any{"persona_id": "ps_test"},
|
||||
}
|
||||
|
||||
err := handler(context.Background(), job)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for missing stage")
|
||||
}
|
||||
if err.Error() != "missing stage in persona_generate job payload" {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStagedQueueHandler_UnknownStage(t *testing.T) {
|
||||
store := newMockPersonaStore()
|
||||
pub := &mockPublisher{}
|
||||
handler := StagedQueueHandler(PipelineDeps{
|
||||
Personas: store,
|
||||
Pub: pub,
|
||||
Logger: noopLogger(),
|
||||
})
|
||||
|
||||
job := &queue.Job{
|
||||
ID: "test-job",
|
||||
Type: "persona_generate",
|
||||
Payload: map[string]any{
|
||||
"persona_id": "ps_test",
|
||||
"stage": "unknown_stage",
|
||||
},
|
||||
}
|
||||
|
||||
err := handler(context.Background(), job)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for unknown stage")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStagedQueueHandler_SpecStage_PersonaNotFound(t *testing.T) {
|
||||
store := newMockPersonaStore()
|
||||
pub := &mockPublisher{}
|
||||
handler := StagedQueueHandler(PipelineDeps{
|
||||
Personas: store,
|
||||
Pub: pub,
|
||||
Logger: noopLogger(),
|
||||
})
|
||||
|
||||
job := &queue.Job{
|
||||
ID: "test-job",
|
||||
Type: "persona_generate",
|
||||
Payload: map[string]any{
|
||||
"persona_id": "ps_nonexistent",
|
||||
"stage": StageSpec,
|
||||
},
|
||||
}
|
||||
|
||||
err := handler(context.Background(), job)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for persona not found")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractTags(t *testing.T) {
|
||||
spec := &persona.PersonaSpec{
|
||||
DNA: &persona.DNA{
|
||||
Identity: persona.IdentityDNA{
|
||||
Gender: persona.GenderWoman,
|
||||
Ethnicity: persona.EthnicityEastAsian,
|
||||
},
|
||||
},
|
||||
Name: persona.NameSpec{First: "Luna"},
|
||||
Lifestyle: persona.Lifestyle{
|
||||
FashionSense: persona.FashionSense{
|
||||
Primary: persona.FashionClassicMinimalist,
|
||||
},
|
||||
Interests: persona.Interests{
|
||||
Creative: []string{"photography", "painting"},
|
||||
Active: []string{"yoga", "hiking"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
tags := extractTags(spec)
|
||||
if len(tags) > 8 {
|
||||
t.Errorf("expected at most 8 tags, got %d", len(tags))
|
||||
}
|
||||
if len(tags) == 0 {
|
||||
t.Error("expected at least some tags")
|
||||
}
|
||||
|
||||
// Should contain gender.
|
||||
found := false
|
||||
for _, tag := range tags {
|
||||
if tag == "woman" {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("expected 'woman' tag")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCountCompletedImages(t *testing.T) {
|
||||
spec := &persona.PersonaSpec{
|
||||
ImageMatrix: []persona.ImageSpec{
|
||||
{Position: 1, Status: persona.ImageStatusComplete},
|
||||
{Position: 2, Status: persona.ImageStatusComplete},
|
||||
{Position: 3, Status: persona.ImageStatusPending},
|
||||
{Position: 4, Status: persona.ImageStatusFailed},
|
||||
},
|
||||
}
|
||||
|
||||
count := countCompletedImages(spec)
|
||||
if count != 2 {
|
||||
t.Errorf("expected 2 completed, got %d", count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAppendIfMissing(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
slice []string
|
||||
s string
|
||||
expected int
|
||||
}{
|
||||
{"adds new", []string{"a"}, "b", 2},
|
||||
{"skips duplicate", []string{"a", "b"}, "a", 2},
|
||||
{"skips empty", []string{"a"}, "", 1},
|
||||
{"adds to nil", nil, "a", 1},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := appendIfMissing(tt.slice, tt.s)
|
||||
if len(result) != tt.expected {
|
||||
t.Errorf("expected len %d, got %d", tt.expected, len(result))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateHandle(t *testing.T) {
|
||||
// Override time for deterministic output.
|
||||
origNow := now
|
||||
now = func() time.Time { return time.UnixMilli(1234567890) }
|
||||
defer func() { now = origNow }()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
}{
|
||||
{"simple name", "Luna Shadow"},
|
||||
{"special chars", "DJ Beats!@#"},
|
||||
{"unicode", "Café Noir"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
handle := generateHandle(tt.input)
|
||||
if handle == "" {
|
||||
t.Error("expected non-empty handle")
|
||||
}
|
||||
if len(handle) > 50 {
|
||||
t.Errorf("handle too long: %d chars", len(handle))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPublishJobUpdate(t *testing.T) {
|
||||
pub := &mockPublisher{}
|
||||
|
||||
publishJobUpdate(pub, "ps_123", StageSpec, "complete", 100, "")
|
||||
|
||||
if len(pub.channelEvents) != 1 {
|
||||
t.Fatalf("expected 1 channel event, got %d", len(pub.channelEvents))
|
||||
}
|
||||
|
||||
event := pub.channelEvents[0]
|
||||
if event.target != "channel:personas" {
|
||||
t.Errorf("expected channel 'channel:personas', got '%s'", event.target)
|
||||
}
|
||||
if event.event.Type != "job_update" {
|
||||
t.Errorf("expected event type 'job_update', got '%s'", event.event.Type)
|
||||
}
|
||||
if event.event.Progress != 100 {
|
||||
t.Errorf("expected progress 100, got %d", event.event.Progress)
|
||||
}
|
||||
|
||||
result, ok := event.event.Result.(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("expected Result to be map")
|
||||
}
|
||||
if result["persona_id"] != "ps_123" {
|
||||
t.Errorf("expected persona_id 'ps_123', got %v", result["persona_id"])
|
||||
}
|
||||
if result["stage"] != StageSpec {
|
||||
t.Errorf("expected stage 'spec', got %v", result["stage"])
|
||||
}
|
||||
if result["status"] != "complete" {
|
||||
t.Errorf("expected status 'complete', got %v", result["status"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestPublishJobUpdate_WithError(t *testing.T) {
|
||||
pub := &mockPublisher{}
|
||||
|
||||
publishJobUpdate(pub, "ps_123", StageSpec, "error", 0, "something went wrong")
|
||||
|
||||
if len(pub.channelEvents) != 1 {
|
||||
t.Fatalf("expected 1 channel event, got %d", len(pub.channelEvents))
|
||||
}
|
||||
|
||||
event := pub.channelEvents[0]
|
||||
if event.event.Error != "something went wrong" {
|
||||
t.Errorf("expected error message, got '%s'", event.event.Error)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnmarshalSpec_EmptyJSON(t *testing.T) {
|
||||
_, err := unmarshalSpec(nil)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for nil spec JSON")
|
||||
}
|
||||
|
||||
_, err = unmarshalSpec(json.RawMessage{})
|
||||
if err == nil {
|
||||
t.Fatal("expected error for empty spec JSON")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnmarshalSpec_ValidJSON(t *testing.T) {
|
||||
spec := &persona.PersonaSpec{
|
||||
ID: "ps_test",
|
||||
Name: persona.NameSpec{First: "Luna", Last: "Shadow"},
|
||||
}
|
||||
data, _ := json.Marshal(spec)
|
||||
|
||||
result, err := unmarshalSpec(data)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if result.ID != "ps_test" {
|
||||
t.Errorf("expected ID 'ps_test', got '%s'", result.ID)
|
||||
}
|
||||
if result.Name.First != "Luna" {
|
||||
t.Errorf("expected first name 'Luna', got '%s'", result.Name.First)
|
||||
}
|
||||
}
|
||||
|
||||
// noopLogger returns a discard logger for tests.
|
||||
func noopLogger() *slog.Logger {
|
||||
return slog.Default() // In test context, logs go to stderr which is fine
|
||||
}
|
||||
36
pkg/personagen/store.go
Normal file
36
pkg/personagen/store.go
Normal file
@ -0,0 +1,36 @@
|
||||
package personagen
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"time"
|
||||
)
|
||||
|
||||
// PersonaRecord is a minimal representation of a persona row for the worker pipeline.
|
||||
// It mirrors the database columns needed for stage-based generation without
|
||||
// importing service-internal domain types.
|
||||
type PersonaRecord struct {
|
||||
ID string `json:"id" db:"id"`
|
||||
Name string `json:"name" db:"name"`
|
||||
Handle string `json:"handle" db:"handle"`
|
||||
Gender string `json:"gender" db:"gender"`
|
||||
Description string `json:"description" db:"description"`
|
||||
Tags []string `json:"tags" db:"tags"`
|
||||
SpecJSON json.RawMessage `json:"spec_json,omitempty" db:"spec_json"`
|
||||
AnchorURL string `json:"anchor_url,omitempty" db:"anchor_url"`
|
||||
AvatarURL string `json:"avatar_url,omitempty" db:"avatar_url"`
|
||||
BannerURL string `json:"banner_url,omitempty" db:"banner_url"`
|
||||
ImageURLs []string `json:"image_urls" db:"image_urls"`
|
||||
VideoURLs []string `json:"video_urls" db:"video_urls"`
|
||||
Status string `json:"status" db:"status"`
|
||||
CreatedAt time.Time `json:"created_at" db:"created_at"`
|
||||
}
|
||||
|
||||
// PersonaStore provides read/write access to persona records for the generation pipeline.
|
||||
type PersonaStore interface {
|
||||
// GetByID returns a persona by ID.
|
||||
GetByID(ctx context.Context, id string) (*PersonaRecord, error)
|
||||
|
||||
// Update persists changes to an existing persona record.
|
||||
Update(ctx context.Context, persona *PersonaRecord) error
|
||||
}
|
||||
146
pkg/personagen/store_postgres.go
Normal file
146
pkg/personagen/store_postgres.go
Normal file
@ -0,0 +1,146 @@
|
||||
package personagen
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/jmoiron/sqlx"
|
||||
"github.com/lib/pq"
|
||||
)
|
||||
|
||||
// ErrPersonaNotFound is returned when a persona ID does not exist.
|
||||
var ErrPersonaNotFound = errors.New("persona not found")
|
||||
|
||||
// personaRow is the database scan target for persona queries.
|
||||
type personaRow struct {
|
||||
ID string `db:"id"`
|
||||
Name string `db:"name"`
|
||||
Handle string `db:"handle"`
|
||||
Gender string `db:"gender"`
|
||||
Description string `db:"description"`
|
||||
Tags pq.StringArray `db:"tags"`
|
||||
SpecJSON []byte `db:"spec_json"`
|
||||
AnchorURL *string `db:"anchor_url"`
|
||||
AvatarURL *string `db:"avatar_url"`
|
||||
BannerURL *string `db:"banner_url"`
|
||||
ImageURLs pq.StringArray `db:"image_urls"`
|
||||
VideoURLs pq.StringArray `db:"video_urls"`
|
||||
Status string `db:"status"`
|
||||
}
|
||||
|
||||
func (r *personaRow) toRecord() *PersonaRecord {
|
||||
rec := &PersonaRecord{
|
||||
ID: r.ID,
|
||||
Name: r.Name,
|
||||
Handle: r.Handle,
|
||||
Gender: r.Gender,
|
||||
Description: r.Description,
|
||||
Tags: []string(r.Tags),
|
||||
Status: r.Status,
|
||||
ImageURLs: []string(r.ImageURLs),
|
||||
VideoURLs: []string(r.VideoURLs),
|
||||
}
|
||||
if rec.Tags == nil {
|
||||
rec.Tags = []string{}
|
||||
}
|
||||
if rec.ImageURLs == nil {
|
||||
rec.ImageURLs = []string{}
|
||||
}
|
||||
if rec.VideoURLs == nil {
|
||||
rec.VideoURLs = []string{}
|
||||
}
|
||||
if len(r.SpecJSON) > 0 {
|
||||
rec.SpecJSON = r.SpecJSON
|
||||
}
|
||||
if r.AnchorURL != nil {
|
||||
rec.AnchorURL = *r.AnchorURL
|
||||
}
|
||||
if r.AvatarURL != nil {
|
||||
rec.AvatarURL = *r.AvatarURL
|
||||
}
|
||||
if r.BannerURL != nil {
|
||||
rec.BannerURL = *r.BannerURL
|
||||
}
|
||||
return rec
|
||||
}
|
||||
|
||||
// PostgresPersonaStore implements PersonaStore using PostgreSQL/CockroachDB.
|
||||
type PostgresPersonaStore struct {
|
||||
db *sqlx.DB
|
||||
}
|
||||
|
||||
// Compile-time interface check.
|
||||
var _ PersonaStore = (*PostgresPersonaStore)(nil)
|
||||
|
||||
// NewPostgresPersonaStore creates a PersonaStore backed by a SQL database.
|
||||
func NewPostgresPersonaStore(db *sqlx.DB) *PostgresPersonaStore {
|
||||
return &PostgresPersonaStore{db: db}
|
||||
}
|
||||
|
||||
func (s *PostgresPersonaStore) GetByID(ctx context.Context, id string) (*PersonaRecord, error) {
|
||||
var row personaRow
|
||||
err := s.db.QueryRowxContext(ctx, `
|
||||
SELECT id, name, handle, gender, description, tags, spec_json,
|
||||
anchor_url, avatar_url, banner_url, image_urls, video_urls, status
|
||||
FROM personas WHERE id = $1
|
||||
`, id).StructScan(&row)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, ErrPersonaNotFound
|
||||
}
|
||||
return nil, fmt.Errorf("get persona: %w", err)
|
||||
}
|
||||
return row.toRecord(), nil
|
||||
}
|
||||
|
||||
func (s *PostgresPersonaStore) Update(ctx context.Context, p *PersonaRecord) error {
|
||||
var specJSON []byte
|
||||
if p.SpecJSON != nil {
|
||||
specJSON = []byte(p.SpecJSON)
|
||||
}
|
||||
|
||||
var anchorURL, avatarURL, bannerURL *string
|
||||
if p.AnchorURL != "" {
|
||||
anchorURL = &p.AnchorURL
|
||||
}
|
||||
if p.AvatarURL != "" {
|
||||
avatarURL = &p.AvatarURL
|
||||
}
|
||||
if p.BannerURL != "" {
|
||||
bannerURL = &p.BannerURL
|
||||
}
|
||||
|
||||
result, err := s.db.ExecContext(ctx, `
|
||||
UPDATE personas
|
||||
SET name = $2, handle = $3, tags = $4, spec_json = $5,
|
||||
anchor_url = $6, avatar_url = $7, banner_url = $8,
|
||||
image_urls = $9, video_urls = $10, status = $11
|
||||
WHERE id = $1
|
||||
`,
|
||||
p.ID,
|
||||
p.Name,
|
||||
p.Handle,
|
||||
pq.StringArray(p.Tags),
|
||||
specJSON,
|
||||
anchorURL,
|
||||
avatarURL,
|
||||
bannerURL,
|
||||
pq.StringArray(p.ImageURLs),
|
||||
pq.StringArray(p.VideoURLs),
|
||||
p.Status,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("update persona: %w", err)
|
||||
}
|
||||
|
||||
rows, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return fmt.Errorf("update persona rows affected: %w", err)
|
||||
}
|
||||
if rows == 0 {
|
||||
return ErrPersonaNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@ -217,9 +217,19 @@ func main() {
|
||||
handler.RegisterHandler("generate_text", handlers.TextHandler(textgenManager, ssePub, logger))
|
||||
handler.RegisterHandler("ai_chat_response", handlers.ChatResponseHandler(textgenManager, ssePub, logger))
|
||||
}
|
||||
// Persona generation requires both textgen (5-stage LLM pipeline) and mediagen (20 images + 4 videos).
|
||||
// Staged persona generation pipeline: spec → anchor → avatar/banner/gallery_batch → video.
|
||||
// Requires textgen (LLM spec), mediagen (images + video), and persona store (DB persistence).
|
||||
if textgenManager != nil && mediagenManager != nil {
|
||||
handler.RegisterHandler("persona_generate", personagen.QueueHandler(textgenManager, mediagenManager, mediaStore, ssePub, logger.Logger))
|
||||
personaStore := personagen.NewPostgresPersonaStore(pool.DB)
|
||||
handler.RegisterHandler("persona_generate", handlers.PersonaGenerateHandler(personagen.PipelineDeps{
|
||||
TextGen: textgenManager,
|
||||
MediaGen: mediagenManager,
|
||||
Store: mediaStore,
|
||||
Pub: ssePub,
|
||||
Personas: personaStore,
|
||||
Queue: jobQueue,
|
||||
Logger: logger.Logger,
|
||||
}))
|
||||
}
|
||||
|
||||
// Setup signal handling
|
||||
|
||||
@ -6,6 +6,7 @@ import (
|
||||
"git.threesix.ai/jordan/persona-community-5/pkg/generation"
|
||||
"git.threesix.ai/jordan/persona-community-5/pkg/logging"
|
||||
"git.threesix.ai/jordan/persona-community-5/pkg/mediagen"
|
||||
"git.threesix.ai/jordan/persona-community-5/pkg/personagen"
|
||||
"git.threesix.ai/jordan/persona-community-5/pkg/queue"
|
||||
"git.threesix.ai/jordan/persona-community-5/pkg/realtime"
|
||||
"git.threesix.ai/jordan/persona-community-5/pkg/storage"
|
||||
@ -31,3 +32,10 @@ func TextHandler(tg *textgen.Manager, pub realtime.EventPublisher, logger *loggi
|
||||
func ChatResponseHandler(tg *textgen.Manager, pub realtime.EventPublisher, logger *logging.Logger) queue.Handler {
|
||||
return generation.ChatResponseHandler(tg, pub, logger)
|
||||
}
|
||||
|
||||
// PersonaGenerateHandler returns a queue.Handler for the staged persona generation pipeline.
|
||||
// Each job carries a stage (spec, anchor, avatar, banner, gallery_batch, video) and processes
|
||||
// one unit of work, updating the persona row and publishing SSE events after each stage.
|
||||
func PersonaGenerateHandler(deps personagen.PipelineDeps) queue.Handler {
|
||||
return personagen.StagedQueueHandler(deps)
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user