441 lines
15 KiB
Go
441 lines
15 KiB
Go
// Package personagen provides persona generation services using LLM and media generation pipelines.
|
|
// It orchestrates a 5-stage spec generation pipeline (text) and image/video generation (media).
|
|
//
|
|
// Usage:
|
|
//
|
|
// svc := personagen.New(textgenManager, mediagenManager, store, logger)
|
|
// spec, err := svc.GenerateSpec(ctx, personagen.SeedParams{
|
|
// Description: "mysterious woman with dark hair who loves poetry",
|
|
// Gender: "woman",
|
|
// })
|
|
// err = svc.GenerateImages(ctx, spec, nil) // all 20 positions
|
|
// video, err := svc.GenerateVideo(ctx, spec, persona.MotionSmileReveal)
|
|
package personagen
|
|
|
|
import (
|
|
"context"
|
|
"crypto/rand"
|
|
"encoding/hex"
|
|
"errors"
|
|
"fmt"
|
|
"log/slog"
|
|
"strings"
|
|
"time"
|
|
|
|
"git.threesix.ai/jordan/persona-community-2/pkg/mediagen"
|
|
"git.threesix.ai/jordan/persona-community-2/pkg/persona"
|
|
"git.threesix.ai/jordan/persona-community-2/pkg/queue"
|
|
"git.threesix.ai/jordan/persona-community-2/pkg/realtime"
|
|
"git.threesix.ai/jordan/persona-community-2/pkg/storage"
|
|
"git.threesix.ai/jordan/persona-community-2/pkg/textgen"
|
|
)
|
|
|
|
// ErrAnchorNotSet is returned when GenerateVideo is called without a set anchor image.
// Call SetAnchor(), or run GenerateImages() with position 1 included, before
// requesting a video. Callers should match it with errors.Is.
var ErrAnchorNotSet = errors.New("anchor image not set: call SetAnchor() or GenerateImages() first")
|
|
|
|
// SeedParams contains the initial inputs for the persona generation pipeline.
// It is passed unchanged to GenerateSpec; no validation of the fields happens
// in this file, so "required" below is a contract on the caller.
type SeedParams struct {
	// Description is a natural-language persona concept (required).
	// Example: "mysterious woman with dark hair who loves poetry"
	Description string

	// Gender is the gender identity: "woman", "man", or "non_binary" (required).
	Gender string

	// Name is an optional name override. If empty, the LLM generates one.
	Name string
}
|
|
|
|
// Service generates complete persona specs, images, and videos.
// Create with New(). A Service is NOT safe for concurrent use: the anchor
// field is read and written without synchronization, so each job should
// create its own instance to avoid shared anchor state between concurrent
// generations.
type Service struct {
	textgen  *textgen.Manager  // handed to generatePersonaSpec by GenerateSpec
	mediagen *mediagen.Manager // backend used for image and video generation
	store    storage.Store     // destination for uploaded images, captions, and videos
	logger   *slog.Logger      // logger tagged with pkg=personagen by New
	anchor   []byte            // position 1 PNG bytes — identity anchor for subsequent generations
}
|
|
|
|
// New creates a new personagen Service.
|
|
func New(tg *textgen.Manager, mg *mediagen.Manager, store storage.Store, logger *slog.Logger) *Service {
|
|
return &Service{
|
|
textgen: tg,
|
|
mediagen: mg,
|
|
store: store,
|
|
logger: logger.With("pkg", "personagen"),
|
|
}
|
|
}
|
|
|
|
// SetAnchor updates the anchor image used for identity consistency in subsequent generations.
// The anchor is the position 1 image — always generated first to establish identity.
//
// The slice is stored as-is, without copying, so the caller must not modify
// imageBytes afterwards. Passing nil clears the anchor. The write is not
// synchronized with other methods of the Service.
func (s *Service) SetAnchor(imageBytes []byte) {
	s.anchor = imageBytes
}
|
|
|
|
// GenerateSpec runs the 5-stage LLM pipeline to produce a complete PersonaSpec.
|
|
// Stages: (1) identity, (2) psychology, (3) lifestyle, (4) visual DNA, (5) image matrix.
|
|
func (s *Service) GenerateSpec(ctx context.Context, seed SeedParams) (*persona.PersonaSpec, error) {
|
|
return generatePersonaSpec(ctx, s.textgen, seed, s.logger)
|
|
}
|
|
|
|
// GenerateImages generates images for the specified positions in the image matrix.
|
|
// If positions is nil or empty, all 20 positions are generated sequentially.
|
|
// Position 1 (the anchor) is always generated first when included.
|
|
// Automatically calls SetAnchor() after position 1 is generated.
|
|
func (s *Service) GenerateImages(ctx context.Context, spec *persona.PersonaSpec, positions []int) error {
|
|
if len(positions) == 0 {
|
|
positions = make([]int, 20)
|
|
for i := range positions {
|
|
positions[i] = i + 1
|
|
}
|
|
}
|
|
|
|
// Check if position 1 is in the list — it must be generated first.
|
|
hasPos1 := false
|
|
for _, p := range positions {
|
|
if p == 1 {
|
|
hasPos1 = true
|
|
break
|
|
}
|
|
}
|
|
|
|
if hasPos1 {
|
|
if err := s.generatePosition(ctx, spec, 1); err != nil {
|
|
return fmt.Errorf("generating anchor position 1: %w", err)
|
|
}
|
|
// Remove position 1 from remaining
|
|
remaining := positions[:0]
|
|
for _, p := range positions {
|
|
if p != 1 {
|
|
remaining = append(remaining, p)
|
|
}
|
|
}
|
|
positions = remaining
|
|
}
|
|
|
|
for _, pos := range positions {
|
|
if err := s.generatePosition(ctx, spec, pos); err != nil {
|
|
return fmt.Errorf("generating position %d: %w", pos, err)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// GenerateVideo generates a video for the given motion type and uploads it to storage.
|
|
// Requires SetAnchor() to have been called first (or GenerateImages() for position 1).
|
|
// Returns ErrAnchorNotSet if no anchor is available.
|
|
func (s *Service) GenerateVideo(ctx context.Context, spec *persona.PersonaSpec, motionType persona.MotionType) (*persona.VideoSpec, error) {
|
|
if s.anchor == nil {
|
|
return nil, ErrAnchorNotSet
|
|
}
|
|
|
|
videoSpec, videoData, err := generateVideo(ctx, s.mediagen, spec, motionType, s.anchor, s.logger)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
storagePath := fmt.Sprintf("personas/%s/videos/%s.mp4", spec.ID, string(motionType))
|
|
url, err := s.store.Upload(ctx, storagePath, videoData, "video/mp4")
|
|
if err != nil {
|
|
videoSpec.Status = persona.VideoStatusFailed
|
|
return nil, fmt.Errorf("storing video %s: %w", motionType, err)
|
|
}
|
|
videoSpec.URL = url
|
|
return videoSpec, nil
|
|
}
|
|
|
|
// GenerateAvatar generates a square profile picture (close-up face, 1:1 crop).
|
|
// Uses the anchor image for identity consistency if available.
|
|
func (s *Service) GenerateAvatar(ctx context.Context, spec *persona.PersonaSpec) ([]byte, error) {
|
|
if s.mediagen == nil {
|
|
return nil, fmt.Errorf("mediagen not configured")
|
|
}
|
|
prompt := buildAvatarPrompt(spec)
|
|
req := mediagen.ImageRequest{
|
|
Prompt: prompt,
|
|
AspectRatio: "1:1",
|
|
}
|
|
if s.anchor != nil {
|
|
req.ReferenceImage = s.anchor
|
|
req.ReferenceMime = "image/png"
|
|
}
|
|
resp, err := s.mediagen.GenerateImage(ctx, req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("generating avatar: %w", err)
|
|
}
|
|
if len(resp.Images) == 0 {
|
|
return nil, fmt.Errorf("no images returned from provider")
|
|
}
|
|
return resp.Images[0].Data, nil
|
|
}
|
|
|
|
// GenerateBanner generates a wide banner image (16:9, landscape).
|
|
// style hints at the backdrop mood (e.g., "lifestyle", "luxury", "outdoor").
|
|
func (s *Service) GenerateBanner(ctx context.Context, spec *persona.PersonaSpec, style string) ([]byte, error) {
|
|
if s.mediagen == nil {
|
|
return nil, fmt.Errorf("mediagen not configured")
|
|
}
|
|
prompt := buildBannerPrompt(spec, style)
|
|
req := mediagen.ImageRequest{
|
|
Prompt: prompt,
|
|
AspectRatio: "16:9",
|
|
}
|
|
if s.anchor != nil {
|
|
req.ReferenceImage = s.anchor
|
|
req.ReferenceMime = "image/png"
|
|
}
|
|
resp, err := s.mediagen.GenerateImage(ctx, req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("generating banner: %w", err)
|
|
}
|
|
if len(resp.Images) == 0 {
|
|
return nil, fmt.Errorf("no images returned from provider")
|
|
}
|
|
return resp.Images[0].Data, nil
|
|
}
|
|
|
|
// generatePosition generates and stores a single image position in the spec.
|
|
func (s *Service) generatePosition(ctx context.Context, spec *persona.PersonaSpec, pos int) error {
|
|
var imgSpec *persona.ImageSpec
|
|
for i := range spec.ImageMatrix {
|
|
if spec.ImageMatrix[i].Position == pos {
|
|
imgSpec = &spec.ImageMatrix[i]
|
|
break
|
|
}
|
|
}
|
|
if imgSpec == nil {
|
|
return fmt.Errorf("position %d not found in image matrix", pos)
|
|
}
|
|
|
|
imageBytes, err := generateImage(ctx, s.mediagen, spec, imgSpec, s.anchor, s.logger)
|
|
if err != nil {
|
|
imgSpec.Status = persona.ImageStatusFailed
|
|
return err
|
|
}
|
|
|
|
// Position 1 becomes the anchor for all subsequent generations.
|
|
if pos == 1 {
|
|
s.anchor = imageBytes
|
|
spec.AnchorImage = imageBytes
|
|
}
|
|
|
|
storagePath := fmt.Sprintf("personas/%s/images/%02d.png", spec.ID, pos)
|
|
url, err := s.store.Upload(ctx, storagePath, imageBytes, "image/png")
|
|
if err != nil {
|
|
imgSpec.Status = persona.ImageStatusFailed
|
|
return fmt.Errorf("storing position %d: %w", pos, err)
|
|
}
|
|
|
|
if imgSpec.Prompt != "" {
|
|
captionPath := fmt.Sprintf("personas/%s/images/%02d.caption", spec.ID, pos)
|
|
if _, captionErr := s.store.Upload(ctx, captionPath, []byte(imgSpec.Prompt), "text/plain"); captionErr != nil {
|
|
s.logger.Warn("failed to persist image caption", "error", captionErr, "position", pos)
|
|
}
|
|
}
|
|
|
|
imgSpec.URL = url
|
|
imgSpec.Status = persona.ImageStatusComplete
|
|
return nil
|
|
}
|
|
|
|
// QueueHandler returns a queue.Handler for processing persona_generate jobs.
|
|
// Creates a fresh Service per job to avoid shared anchor state between concurrent jobs.
|
|
// Publishes SSE events to the user's channel throughout generation.
|
|
func QueueHandler(tg *textgen.Manager, mg *mediagen.Manager, store storage.Store, pub realtime.EventPublisher, logger *slog.Logger) queue.Handler {
|
|
return func(ctx context.Context, job *queue.Job) error {
|
|
userID, _ := job.Payload["userID"].(string)
|
|
if userID == "" {
|
|
return fmt.Errorf("missing userID in persona_generate job payload")
|
|
}
|
|
|
|
description, _ := job.Payload["description"].(string)
|
|
gender, _ := job.Payload["gender"].(string)
|
|
name, _ := job.Payload["name"].(string)
|
|
|
|
sendEvent := func(event *realtime.SSEEvent) {
|
|
if err := pub.SendToUser(userID, event); err != nil {
|
|
logger.Warn("failed to send persona SSE event", "error", err, "type", event.Type)
|
|
}
|
|
}
|
|
|
|
svc := New(tg, mg, store, logger)
|
|
seed := SeedParams{Description: description, Gender: gender, Name: name}
|
|
|
|
sendEvent(&realtime.SSEEvent{
|
|
Type: "persona_spec_started",
|
|
JobID: job.ID,
|
|
Message: "Generating persona profile...",
|
|
})
|
|
|
|
spec, err := svc.GenerateSpec(ctx, seed)
|
|
if err != nil {
|
|
logger.Error("persona spec generation failed", "error", err, "job_id", job.ID)
|
|
sendEvent(&realtime.SSEEvent{
|
|
Type: "persona_failed",
|
|
JobID: job.ID,
|
|
Error: "Spec generation failed: " + err.Error(),
|
|
})
|
|
return err
|
|
}
|
|
|
|
sendEvent(&realtime.SSEEvent{
|
|
Type: "persona_spec_complete",
|
|
JobID: job.ID,
|
|
Message: "Profile complete, generating images...",
|
|
Result: map[string]any{"personaId": spec.ID},
|
|
})
|
|
|
|
// Build an ordered position list — position 1 (anchor) must always be generated first.
|
|
// generatePosition() mutates the spec.ImageMatrix entry in place (URL, Status),
|
|
// so we keep a pointer to each entry to read the URL after generation.
|
|
type posEntry struct {
|
|
pos int
|
|
imgSpec *persona.ImageSpec
|
|
}
|
|
orderedPositions := make([]posEntry, 0, len(spec.ImageMatrix))
|
|
for i := range spec.ImageMatrix {
|
|
orderedPositions = append(orderedPositions, posEntry{
|
|
pos: spec.ImageMatrix[i].Position,
|
|
imgSpec: &spec.ImageMatrix[i],
|
|
})
|
|
}
|
|
// Swap position 1 to front if it isn't already.
|
|
for i, e := range orderedPositions {
|
|
if e.pos == 1 && i != 0 {
|
|
orderedPositions[0], orderedPositions[i] = orderedPositions[i], orderedPositions[0]
|
|
break
|
|
}
|
|
}
|
|
|
|
// Generate all 20 image positions, publishing progress events.
|
|
for _, entry := range orderedPositions {
|
|
pos := entry.pos
|
|
sendEvent(&realtime.SSEEvent{
|
|
Type: "persona_image_started",
|
|
JobID: job.ID,
|
|
Message: fmt.Sprintf("Generating position %d...", pos),
|
|
Result: map[string]any{"position": pos},
|
|
})
|
|
|
|
if err := svc.generatePosition(ctx, spec, pos); err != nil {
|
|
logger.Error("persona image generation failed", "error", err, "position", pos, "job_id", job.ID)
|
|
sendEvent(&realtime.SSEEvent{
|
|
Type: "persona_failed",
|
|
JobID: job.ID,
|
|
Error: fmt.Sprintf("Image position %d failed: %s", pos, err.Error()),
|
|
})
|
|
return err
|
|
}
|
|
|
|
progress := (pos * 100) / 20
|
|
sendEvent(&realtime.SSEEvent{
|
|
Type: "persona_image_progress",
|
|
JobID: job.ID,
|
|
Progress: progress,
|
|
Result: map[string]any{"position": pos, "url": entry.imgSpec.URL},
|
|
})
|
|
}
|
|
|
|
sendEvent(&realtime.SSEEvent{
|
|
Type: "persona_image_complete",
|
|
JobID: job.ID,
|
|
Progress: 100,
|
|
Message: "All images generated",
|
|
Result: map[string]any{"personaId": spec.ID},
|
|
})
|
|
|
|
// Generate 4 videos. Videos are best-effort — a failed video does not abort the job,
|
|
// but a persona_video_failed event is sent so the frontend can reflect partial completion.
|
|
for _, vs := range spec.Videos {
|
|
sendEvent(&realtime.SSEEvent{
|
|
Type: "persona_video_started",
|
|
JobID: job.ID,
|
|
Message: fmt.Sprintf("Generating %s video...", vs.MotionType),
|
|
Result: map[string]any{"motionType": string(vs.MotionType)},
|
|
})
|
|
|
|
videoSpec, err := svc.GenerateVideo(ctx, spec, vs.MotionType)
|
|
if err != nil {
|
|
logger.Warn("persona video generation failed (non-fatal)", "error", err, "motion", vs.MotionType, "job_id", job.ID)
|
|
sendEvent(&realtime.SSEEvent{
|
|
Type: "persona_video_failed",
|
|
JobID: job.ID,
|
|
Error: fmt.Sprintf("%s video failed: %s", vs.MotionType, err.Error()),
|
|
Result: map[string]any{"motionType": string(vs.MotionType)},
|
|
})
|
|
continue
|
|
}
|
|
|
|
sendEvent(&realtime.SSEEvent{
|
|
Type: "persona_video_complete",
|
|
JobID: job.ID,
|
|
Message: "Video complete",
|
|
Result: map[string]any{"motionType": string(vs.MotionType), "url": videoSpec.URL},
|
|
})
|
|
}
|
|
|
|
logger.Info("persona generation complete", "job_id", job.ID, "persona_id", spec.ID)
|
|
return nil
|
|
}
|
|
}
|
|
|
|
// generateID creates a random ID for a new persona spec: the prefix "ps_"
// followed by 24 lowercase hex characters (12 random bytes).
//
// The ID is used in storage paths, so it must be unpredictable and unique.
// If the system random source fails, the function panics instead of
// returning an ID built from an unfilled (all-zero) buffer, which would make
// every persona collide on the same storage paths.
func generateID() string {
	b := make([]byte, 12)
	if _, err := rand.Read(b); err != nil {
		panic("personagen: reading random bytes for persona ID: " + err.Error())
	}
	return "ps_" + hex.EncodeToString(b)
}
|
|
|
|
// buildAvatarPrompt creates a close-up portrait prompt for avatar generation.
|
|
func buildAvatarPrompt(spec *persona.PersonaSpec) string {
|
|
return fmt.Sprintf(
|
|
"%s Close-up portrait, square 1:1 composition, face centered and sharp, soft bokeh background, professional headshot quality.",
|
|
buildIdentitySection(spec),
|
|
)
|
|
}
|
|
|
|
// buildBannerPrompt creates a wide landscape prompt for banner generation.
|
|
func buildBannerPrompt(spec *persona.PersonaSpec, style string) string {
|
|
if style == "" {
|
|
style = "lifestyle"
|
|
}
|
|
return fmt.Sprintf(
|
|
"%s Wide landscape banner, 16:9 cinematic composition, %s aesthetic, professional photography quality.",
|
|
buildIdentitySection(spec),
|
|
style,
|
|
)
|
|
}
|
|
|
|
// inferGenerationTier infers a generation tier from the description keywords (case-insensitive).
|
|
func inferGenerationTier(description string) persona.GenerationTier {
|
|
lower := strings.ToLower(description)
|
|
for _, kw := range []string{"supermodel", "model", "editorial", "high fashion"} {
|
|
if strings.Contains(lower, kw) {
|
|
return persona.GenerationTierSupermodel
|
|
}
|
|
}
|
|
for _, kw := range []string{"influencer", "content creator", "blogger", "social media"} {
|
|
if strings.Contains(lower, kw) {
|
|
return persona.GenerationTierInfluencer
|
|
}
|
|
}
|
|
return persona.GenerationTierEveryday
|
|
}
|
|
|
|
// inferAttractiveness infers an attractiveness tier from the generation tier.
|
|
func inferAttractiveness(tier persona.GenerationTier) persona.AttractivenessTier {
|
|
switch tier {
|
|
case persona.GenerationTierSupermodel:
|
|
return persona.AttractivenessTierStunning
|
|
case persona.GenerationTierInfluencer:
|
|
return persona.AttractivenessTierVery
|
|
default:
|
|
return persona.AttractivenessTierAttractive
|
|
}
|
|
}
|
|
|
|
// now yields the current time. It is a package-level variable rather than a
// direct time.Now call so tests can substitute a fixed clock.
var now = time.Now
|