persona-community-5/pkg/personagen/pipeline_test.go
rdev-worker 66ceb7e55f
Some checks failed
ci/woodpecker/push/woodpecker Pipeline failed
build: /implement-feature persona-generation --requirements 'Implement the g...
2026-02-24 08:13:52 +00:00

385 lines
9.7 KiB
Go

package personagen
import (
	"context"
	"encoding/json"
	"fmt"
	"io"
	"log/slog"
	"testing"
	"time"

	"git.threesix.ai/jordan/persona-community-5/pkg/persona"
	"git.threesix.ai/jordan/persona-community-5/pkg/queue"
	"git.threesix.ai/jordan/persona-community-5/pkg/realtime"
)
// ── Mock implementations ─────────────────────────────────────────────────────
// mockPersonaStore implements PersonaStore for testing.
//
// Records are copied both on the way out (GetByID) and on the way in (Update),
// so the code under test can only change stored state through an explicit
// Update call. Tests seed data by writing directly to records.
type mockPersonaStore struct {
	records map[string]*PersonaRecord
}

// newMockPersonaStore returns a store with an empty, non-nil records map.
func newMockPersonaStore() *mockPersonaStore {
	return &mockPersonaStore{records: make(map[string]*PersonaRecord)}
}

// GetByID returns a copy of the record stored under id, or ErrPersonaNotFound
// when no such record exists.
func (m *mockPersonaStore) GetByID(_ context.Context, id string) (*PersonaRecord, error) {
	rec, ok := m.records[id]
	if !ok {
		return nil, ErrPersonaNotFound
	}
	// Return a copy to avoid shared mutation.
	return mockCloneRecord(rec), nil
}

// Update replaces the stored record with a copy of p. It never inserts: it
// returns ErrPersonaNotFound when no record with p.ID is already stored.
func (m *mockPersonaStore) Update(_ context.Context, p *PersonaRecord) error {
	if _, ok := m.records[p.ID]; !ok {
		return ErrPersonaNotFound
	}
	// Store a copy rather than the caller's pointer: otherwise mutations made
	// to p after this call would silently show up in the store, masking a
	// missing follow-up Update in the code under test.
	m.records[p.ID] = mockCloneRecord(p)
	return nil
}

// mockCloneRecord returns a copy of rec whose Tags, ImageURLs, VideoURLs and
// SpecJSON slices do not share backing arrays with rec. Nil slices stay nil.
// NOTE(review): any other reference-typed fields of PersonaRecord are copied
// shallowly; extend this if the record gains maps or pointer fields.
func mockCloneRecord(rec *PersonaRecord) *PersonaRecord {
	cp := *rec
	if rec.Tags != nil {
		cp.Tags = make([]string, len(rec.Tags))
		copy(cp.Tags, rec.Tags)
	}
	if rec.ImageURLs != nil {
		cp.ImageURLs = make([]string, len(rec.ImageURLs))
		copy(cp.ImageURLs, rec.ImageURLs)
	}
	if rec.VideoURLs != nil {
		cp.VideoURLs = make([]string, len(rec.VideoURLs))
		copy(cp.VideoURLs, rec.VideoURLs)
	}
	if rec.SpecJSON != nil {
		cp.SpecJSON = make(json.RawMessage, len(rec.SpecJSON))
		copy(cp.SpecJSON, rec.SpecJSON)
	}
	return &cp
}
// mockProducer is a queue.Producer that delivers nothing; it only remembers
// every job handed to it, in call order.
type mockProducer struct {
	jobs []enqueuedJob
}

// enqueuedJob is one captured call to mockProducer.
type enqueuedJob struct {
	jobType string
	payload map[string]any
}

var _ queue.Producer = (*mockProducer)(nil)

// Enqueue captures the job and returns a synthetic ID ("job-1", "job-2", ...).
func (p *mockProducer) Enqueue(_ context.Context, jobType string, payload map[string]any) (string, error) {
	return p.capture(jobType, payload), nil
}

// EnqueueWithOptions captures only job.Type and job.Payload; every other
// field of the job is ignored.
func (p *mockProducer) EnqueueWithOptions(_ context.Context, job queue.Job) (string, error) {
	return p.capture(job.Type, job.Payload), nil
}

// capture appends the job and returns its 1-based position formatted as an ID.
func (p *mockProducer) capture(jobType string, payload map[string]any) string {
	p.jobs = append(p.jobs, enqueuedJob{jobType: jobType, payload: payload})
	return fmt.Sprintf("job-%d", len(p.jobs))
}
// mockPublisher is a realtime.EventPublisher that captures events instead of
// delivering them. Per-user and per-channel sends go into separate slices.
type mockPublisher struct {
	userEvents    []publishedEvent
	channelEvents []publishedEvent
}

// publishedEvent pairs a captured event with the user ID or channel it targeted.
type publishedEvent struct {
	target string
	event  *realtime.SSEEvent
}

var _ realtime.EventPublisher = (*mockPublisher)(nil)

// SendToUser records event against userID and always succeeds.
func (p *mockPublisher) SendToUser(userID string, event *realtime.SSEEvent) error {
	captured := publishedEvent{target: userID, event: event}
	p.userEvents = append(p.userEvents, captured)
	return nil
}

// SendToChannel records event against channel and always succeeds.
func (p *mockPublisher) SendToChannel(channel string, event *realtime.SSEEvent) error {
	captured := publishedEvent{target: channel, event: event}
	p.channelEvents = append(p.channelEvents, captured)
	return nil
}
// ── Tests ────────────────────────────────────────────────────────────────────
// TestStagedQueueHandler_MissingPersonaID checks that a job whose payload has
// a stage but no persona_id is rejected with the exact validation message.
func TestStagedQueueHandler_MissingPersonaID(t *testing.T) {
	const want = "missing persona_id in persona_generate job payload"

	run := StagedQueueHandler(PipelineDeps{})
	err := run(context.Background(), &queue.Job{
		ID:      "test-job",
		Type:    "persona_generate",
		Payload: map[string]any{"stage": "spec"},
	})

	switch {
	case err == nil:
		t.Fatal("expected error for missing persona_id")
	case err.Error() != want:
		t.Errorf("unexpected error: %v", err)
	}
}
// TestStagedQueueHandler_MissingStage checks that a job whose payload has a
// persona_id but no stage is rejected with the exact validation message.
func TestStagedQueueHandler_MissingStage(t *testing.T) {
	const want = "missing stage in persona_generate job payload"

	run := StagedQueueHandler(PipelineDeps{})
	err := run(context.Background(), &queue.Job{
		ID:      "test-job",
		Type:    "persona_generate",
		Payload: map[string]any{"persona_id": "ps_test"},
	})

	switch {
	case err == nil:
		t.Fatal("expected error for missing stage")
	case err.Error() != want:
		t.Errorf("unexpected error: %v", err)
	}
}
// TestStagedQueueHandler_UnknownStage checks that a job naming an unrecognized
// stage is rejected. The persona is seeded into the store so that the error
// cannot come from a failed persona lookup instead of stage dispatch.
func TestStagedQueueHandler_UnknownStage(t *testing.T) {
	store := newMockPersonaStore()
	// Without this record the handler could fail on the lookup and the test
	// would pass without ever exercising the unknown-stage path.
	store.records["ps_test"] = &PersonaRecord{ID: "ps_test"}
	pub := &mockPublisher{}
	handler := StagedQueueHandler(PipelineDeps{
		Personas: store,
		Pub:      pub,
		Logger:   noopLogger(),
	})
	job := &queue.Job{
		ID:   "test-job",
		Type: "persona_generate",
		Payload: map[string]any{
			"persona_id": "ps_test",
			"stage":      "unknown_stage",
		},
	}
	err := handler(context.Background(), job)
	if err == nil {
		t.Fatal("expected error for unknown stage")
	}
}
func TestStagedQueueHandler_SpecStage_PersonaNotFound(t *testing.T) {
store := newMockPersonaStore()
pub := &mockPublisher{}
handler := StagedQueueHandler(PipelineDeps{
Personas: store,
Pub: pub,
Logger: noopLogger(),
})
job := &queue.Job{
ID: "test-job",
Type: "persona_generate",
Payload: map[string]any{
"persona_id": "ps_nonexistent",
"stage": StageSpec,
},
}
err := handler(context.Background(), job)
if err == nil {
t.Fatal("expected error for persona not found")
}
}
// TestExtractTags builds a spec with identity, fashion and interest data and
// checks that extractTags returns between 1 and 8 tags, including "woman".
func TestExtractTags(t *testing.T) {
	interests := persona.Interests{
		Creative: []string{"photography", "painting"},
		Active:   []string{"yoga", "hiking"},
	}
	spec := &persona.PersonaSpec{
		DNA: &persona.DNA{
			Identity: persona.IdentityDNA{
				Gender:    persona.GenderWoman,
				Ethnicity: persona.EthnicityEastAsian,
			},
		},
		Name: persona.NameSpec{First: "Luna"},
		Lifestyle: persona.Lifestyle{
			FashionSense: persona.FashionSense{Primary: persona.FashionClassicMinimalist},
			Interests:    interests,
		},
	}

	tags := extractTags(spec)

	switch n := len(tags); {
	case n > 8:
		t.Errorf("expected at most 8 tags, got %d", n)
	case n == 0:
		t.Error("expected at least some tags")
	}

	// The gender tag must always be present.
	contains := func(want string) bool {
		for _, tag := range tags {
			if tag == want {
				return true
			}
		}
		return false
	}
	if !contains("woman") {
		t.Error("expected 'woman' tag")
	}
}
// TestCountCompletedImages checks that only images with
// ImageStatusComplete are counted; pending and failed entries are skipped.
func TestCountCompletedImages(t *testing.T) {
	matrix := []persona.ImageSpec{
		{Position: 1, Status: persona.ImageStatusComplete},
		{Position: 2, Status: persona.ImageStatusComplete},
		{Position: 3, Status: persona.ImageStatusPending},
		{Position: 4, Status: persona.ImageStatusFailed},
	}

	if got := countCompletedImages(&persona.PersonaSpec{ImageMatrix: matrix}); got != 2 {
		t.Errorf("expected 2 completed, got %d", got)
	}
}
// TestAppendIfMissing checks that appendIfMissing appends s only when it is
// non-empty and not already present. It compares full contents rather than
// just lengths, so appending the wrong element (or reordering) is caught too.
func TestAppendIfMissing(t *testing.T) {
	tests := []struct {
		name     string
		slice    []string
		s        string
		expected []string
	}{
		{"adds new", []string{"a"}, "b", []string{"a", "b"}},
		{"skips duplicate", []string{"a", "b"}, "a", []string{"a", "b"}},
		{"skips empty", []string{"a"}, "", []string{"a"}},
		{"adds to nil", nil, "a", []string{"a"}},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result := appendIfMissing(tt.slice, tt.s)
			if len(result) != len(tt.expected) {
				t.Fatalf("expected %q, got %q", tt.expected, result)
			}
			for i := range tt.expected {
				if result[i] != tt.expected[i] {
					t.Errorf("expected %q, got %q", tt.expected, result)
					break
				}
			}
		})
	}
}
// TestGenerateHandle checks that generateHandle yields a non-empty handle of
// at most 50 bytes for plain, punctuated and non-ASCII display names.
func TestGenerateHandle(t *testing.T) {
	// Pin the package clock so generateHandle sees a fixed instant; restore
	// it once the test and all of its subtests have finished.
	saved := now
	now = func() time.Time { return time.UnixMilli(1234567890) }
	t.Cleanup(func() { now = saved })

	cases := []struct{ name, input string }{
		{"simple name", "Luna Shadow"},
		{"special chars", "DJ Beats!@#"},
		{"unicode", "Café Noir"},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			h := generateHandle(tc.input)
			switch {
			case h == "":
				t.Error("expected non-empty handle")
			case len(h) > 50:
				t.Errorf("handle too long: %d chars", len(h))
			}
		})
	}
}
// TestPublishJobUpdate checks that a successful stage update is broadcast as
// exactly one "job_update" event on "channel:personas". The event must carry
// the progress value and a map Result with persona_id, stage and status.
func TestPublishJobUpdate(t *testing.T) {
	pub := &mockPublisher{}
	publishJobUpdate(pub, "ps_123", StageSpec, "complete", 100, "")

	if n := len(pub.channelEvents); n != 1 {
		t.Fatalf("expected 1 channel event, got %d", n)
	}
	target, ev := pub.channelEvents[0].target, pub.channelEvents[0].event

	if target != "channel:personas" {
		t.Errorf("expected channel 'channel:personas', got '%s'", target)
	}
	if ev.Type != "job_update" {
		t.Errorf("expected event type 'job_update', got '%s'", ev.Type)
	}
	if ev.Progress != 100 {
		t.Errorf("expected progress 100, got %d", ev.Progress)
	}

	result, ok := ev.Result.(map[string]any)
	if !ok {
		t.Fatal("expected Result to be map")
	}
	checks := []struct {
		key  string
		want any
		msg  string
	}{
		{"persona_id", "ps_123", "expected persona_id 'ps_123', got %v"},
		{"stage", StageSpec, "expected stage 'spec', got %v"},
		{"status", "complete", "expected status 'complete', got %v"},
	}
	for _, c := range checks {
		if result[c.key] != c.want {
			t.Errorf(c.msg, result[c.key])
		}
	}
}
// TestPublishJobUpdate_WithError checks that the error text passed to
// publishJobUpdate is copied onto the single published event.
func TestPublishJobUpdate_WithError(t *testing.T) {
	pub := &mockPublisher{}
	publishJobUpdate(pub, "ps_123", StageSpec, "error", 0, "something went wrong")

	if n := len(pub.channelEvents); n != 1 {
		t.Fatalf("expected 1 channel event, got %d", n)
	}
	if got := pub.channelEvents[0].event.Error; got != "something went wrong" {
		t.Errorf("expected error message, got '%s'", got)
	}
}
// TestUnmarshalSpec_EmptyJSON checks that unmarshalSpec rejects both a nil
// and a zero-length (non-nil) JSON payload.
func TestUnmarshalSpec_EmptyJSON(t *testing.T) {
	if _, err := unmarshalSpec(nil); err == nil {
		t.Fatal("expected error for nil spec JSON")
	}
	if _, err := unmarshalSpec(json.RawMessage{}); err == nil {
		t.Fatal("expected error for empty spec JSON")
	}
}
// TestUnmarshalSpec_ValidJSON round-trips a minimal PersonaSpec through
// json.Marshal and unmarshalSpec and checks that ID and name fields survive.
func TestUnmarshalSpec_ValidJSON(t *testing.T) {
	spec := &persona.PersonaSpec{
		ID:   "ps_test",
		Name: persona.NameSpec{First: "Luna", Last: "Shadow"},
	}
	data, err := json.Marshal(spec)
	if err != nil {
		// Without this check a marshal failure would surface later as a
		// misleading unmarshalSpec error on empty input.
		t.Fatalf("marshal spec: %v", err)
	}
	result, err := unmarshalSpec(data)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if result.ID != "ps_test" {
		t.Errorf("expected ID 'ps_test', got '%s'", result.ID)
	}
	if result.Name.First != "Luna" {
		t.Errorf("expected first name 'Luna', got '%s'", result.Name.First)
	}
	if result.Name.Last != "Shadow" {
		t.Errorf("expected last name 'Shadow', got '%s'", result.Name.Last)
	}
}
// noopLogger returns a logger that discards every record. Using a dedicated
// handler, rather than slog.Default(), keeps test output quiet and keeps the
// tests independent of any global logger installed via slog.SetDefault.
func noopLogger() *slog.Logger {
	return slog.New(slog.NewTextHandler(io.Discard, nil))
}