385 lines
9.7 KiB
Go
385 lines
9.7 KiB
Go
package personagen
|
|
|
|
import (
	"context"
	"encoding/json"
	"fmt"
	"io"
	"log/slog"
	"testing"
	"time"

	"git.threesix.ai/jordan/persona-community-5/pkg/persona"
	"git.threesix.ai/jordan/persona-community-5/pkg/queue"
	"git.threesix.ai/jordan/persona-community-5/pkg/realtime"
)
|
|
|
|
// ── Mock implementations ─────────────────────────────────────────────────────
|
|
|
|
// mockPersonaStore implements PersonaStore for testing.
|
|
type mockPersonaStore struct {
|
|
records map[string]*PersonaRecord
|
|
}
|
|
|
|
func newMockPersonaStore() *mockPersonaStore {
|
|
return &mockPersonaStore{records: make(map[string]*PersonaRecord)}
|
|
}
|
|
|
|
func (m *mockPersonaStore) GetByID(_ context.Context, id string) (*PersonaRecord, error) {
|
|
rec, ok := m.records[id]
|
|
if !ok {
|
|
return nil, ErrPersonaNotFound
|
|
}
|
|
// Return a copy to avoid shared mutation.
|
|
cp := *rec
|
|
if rec.Tags != nil {
|
|
cp.Tags = make([]string, len(rec.Tags))
|
|
copy(cp.Tags, rec.Tags)
|
|
}
|
|
if rec.ImageURLs != nil {
|
|
cp.ImageURLs = make([]string, len(rec.ImageURLs))
|
|
copy(cp.ImageURLs, rec.ImageURLs)
|
|
}
|
|
if rec.VideoURLs != nil {
|
|
cp.VideoURLs = make([]string, len(rec.VideoURLs))
|
|
copy(cp.VideoURLs, rec.VideoURLs)
|
|
}
|
|
if rec.SpecJSON != nil {
|
|
cp.SpecJSON = make(json.RawMessage, len(rec.SpecJSON))
|
|
copy(cp.SpecJSON, rec.SpecJSON)
|
|
}
|
|
return &cp, nil
|
|
}
|
|
|
|
func (m *mockPersonaStore) Update(_ context.Context, p *PersonaRecord) error {
|
|
if _, ok := m.records[p.ID]; !ok {
|
|
return ErrPersonaNotFound
|
|
}
|
|
m.records[p.ID] = p
|
|
return nil
|
|
}
|
|
|
|
// mockProducer records enqueued jobs.
|
|
type mockProducer struct {
|
|
jobs []enqueuedJob
|
|
}
|
|
|
|
type enqueuedJob struct {
|
|
jobType string
|
|
payload map[string]any
|
|
}
|
|
|
|
var _ queue.Producer = (*mockProducer)(nil)
|
|
|
|
func (m *mockProducer) Enqueue(_ context.Context, jobType string, payload map[string]any) (string, error) {
|
|
m.jobs = append(m.jobs, enqueuedJob{jobType: jobType, payload: payload})
|
|
return fmt.Sprintf("job-%d", len(m.jobs)), nil
|
|
}
|
|
|
|
func (m *mockProducer) EnqueueWithOptions(_ context.Context, job queue.Job) (string, error) {
|
|
m.jobs = append(m.jobs, enqueuedJob{jobType: job.Type, payload: job.Payload})
|
|
return fmt.Sprintf("job-%d", len(m.jobs)), nil
|
|
}
|
|
|
|
// mockPublisher records published events.
|
|
type mockPublisher struct {
|
|
userEvents []publishedEvent
|
|
channelEvents []publishedEvent
|
|
}
|
|
|
|
type publishedEvent struct {
|
|
target string
|
|
event *realtime.SSEEvent
|
|
}
|
|
|
|
var _ realtime.EventPublisher = (*mockPublisher)(nil)
|
|
|
|
func (m *mockPublisher) SendToUser(userID string, event *realtime.SSEEvent) error {
|
|
m.userEvents = append(m.userEvents, publishedEvent{target: userID, event: event})
|
|
return nil
|
|
}
|
|
|
|
func (m *mockPublisher) SendToChannel(channel string, event *realtime.SSEEvent) error {
|
|
m.channelEvents = append(m.channelEvents, publishedEvent{target: channel, event: event})
|
|
return nil
|
|
}
|
|
|
|
// ── Tests ────────────────────────────────────────────────────────────────────
|
|
|
|
func TestStagedQueueHandler_MissingPersonaID(t *testing.T) {
|
|
handler := StagedQueueHandler(PipelineDeps{})
|
|
|
|
job := &queue.Job{
|
|
ID: "test-job",
|
|
Type: "persona_generate",
|
|
Payload: map[string]any{"stage": "spec"},
|
|
}
|
|
|
|
err := handler(context.Background(), job)
|
|
if err == nil {
|
|
t.Fatal("expected error for missing persona_id")
|
|
}
|
|
if err.Error() != "missing persona_id in persona_generate job payload" {
|
|
t.Errorf("unexpected error: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestStagedQueueHandler_MissingStage(t *testing.T) {
|
|
handler := StagedQueueHandler(PipelineDeps{})
|
|
|
|
job := &queue.Job{
|
|
ID: "test-job",
|
|
Type: "persona_generate",
|
|
Payload: map[string]any{"persona_id": "ps_test"},
|
|
}
|
|
|
|
err := handler(context.Background(), job)
|
|
if err == nil {
|
|
t.Fatal("expected error for missing stage")
|
|
}
|
|
if err.Error() != "missing stage in persona_generate job payload" {
|
|
t.Errorf("unexpected error: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestStagedQueueHandler_UnknownStage(t *testing.T) {
|
|
store := newMockPersonaStore()
|
|
pub := &mockPublisher{}
|
|
handler := StagedQueueHandler(PipelineDeps{
|
|
Personas: store,
|
|
Pub: pub,
|
|
Logger: noopLogger(),
|
|
})
|
|
|
|
job := &queue.Job{
|
|
ID: "test-job",
|
|
Type: "persona_generate",
|
|
Payload: map[string]any{
|
|
"persona_id": "ps_test",
|
|
"stage": "unknown_stage",
|
|
},
|
|
}
|
|
|
|
err := handler(context.Background(), job)
|
|
if err == nil {
|
|
t.Fatal("expected error for unknown stage")
|
|
}
|
|
}
|
|
|
|
func TestStagedQueueHandler_SpecStage_PersonaNotFound(t *testing.T) {
|
|
store := newMockPersonaStore()
|
|
pub := &mockPublisher{}
|
|
handler := StagedQueueHandler(PipelineDeps{
|
|
Personas: store,
|
|
Pub: pub,
|
|
Logger: noopLogger(),
|
|
})
|
|
|
|
job := &queue.Job{
|
|
ID: "test-job",
|
|
Type: "persona_generate",
|
|
Payload: map[string]any{
|
|
"persona_id": "ps_nonexistent",
|
|
"stage": StageSpec,
|
|
},
|
|
}
|
|
|
|
err := handler(context.Background(), job)
|
|
if err == nil {
|
|
t.Fatal("expected error for persona not found")
|
|
}
|
|
}
|
|
|
|
func TestExtractTags(t *testing.T) {
|
|
spec := &persona.PersonaSpec{
|
|
DNA: &persona.DNA{
|
|
Identity: persona.IdentityDNA{
|
|
Gender: persona.GenderWoman,
|
|
Ethnicity: persona.EthnicityEastAsian,
|
|
},
|
|
},
|
|
Name: persona.NameSpec{First: "Luna"},
|
|
Lifestyle: persona.Lifestyle{
|
|
FashionSense: persona.FashionSense{
|
|
Primary: persona.FashionClassicMinimalist,
|
|
},
|
|
Interests: persona.Interests{
|
|
Creative: []string{"photography", "painting"},
|
|
Active: []string{"yoga", "hiking"},
|
|
},
|
|
},
|
|
}
|
|
|
|
tags := extractTags(spec)
|
|
if len(tags) > 8 {
|
|
t.Errorf("expected at most 8 tags, got %d", len(tags))
|
|
}
|
|
if len(tags) == 0 {
|
|
t.Error("expected at least some tags")
|
|
}
|
|
|
|
// Should contain gender.
|
|
found := false
|
|
for _, tag := range tags {
|
|
if tag == "woman" {
|
|
found = true
|
|
break
|
|
}
|
|
}
|
|
if !found {
|
|
t.Error("expected 'woman' tag")
|
|
}
|
|
}
|
|
|
|
func TestCountCompletedImages(t *testing.T) {
|
|
spec := &persona.PersonaSpec{
|
|
ImageMatrix: []persona.ImageSpec{
|
|
{Position: 1, Status: persona.ImageStatusComplete},
|
|
{Position: 2, Status: persona.ImageStatusComplete},
|
|
{Position: 3, Status: persona.ImageStatusPending},
|
|
{Position: 4, Status: persona.ImageStatusFailed},
|
|
},
|
|
}
|
|
|
|
count := countCompletedImages(spec)
|
|
if count != 2 {
|
|
t.Errorf("expected 2 completed, got %d", count)
|
|
}
|
|
}
|
|
|
|
func TestAppendIfMissing(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
slice []string
|
|
s string
|
|
expected int
|
|
}{
|
|
{"adds new", []string{"a"}, "b", 2},
|
|
{"skips duplicate", []string{"a", "b"}, "a", 2},
|
|
{"skips empty", []string{"a"}, "", 1},
|
|
{"adds to nil", nil, "a", 1},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
result := appendIfMissing(tt.slice, tt.s)
|
|
if len(result) != tt.expected {
|
|
t.Errorf("expected len %d, got %d", tt.expected, len(result))
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestGenerateHandle(t *testing.T) {
|
|
// Override time for deterministic output.
|
|
origNow := now
|
|
now = func() time.Time { return time.UnixMilli(1234567890) }
|
|
defer func() { now = origNow }()
|
|
|
|
tests := []struct {
|
|
name string
|
|
input string
|
|
}{
|
|
{"simple name", "Luna Shadow"},
|
|
{"special chars", "DJ Beats!@#"},
|
|
{"unicode", "Café Noir"},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
handle := generateHandle(tt.input)
|
|
if handle == "" {
|
|
t.Error("expected non-empty handle")
|
|
}
|
|
if len(handle) > 50 {
|
|
t.Errorf("handle too long: %d chars", len(handle))
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestPublishJobUpdate(t *testing.T) {
|
|
pub := &mockPublisher{}
|
|
|
|
publishJobUpdate(pub, "ps_123", StageSpec, "complete", 100, "")
|
|
|
|
if len(pub.channelEvents) != 1 {
|
|
t.Fatalf("expected 1 channel event, got %d", len(pub.channelEvents))
|
|
}
|
|
|
|
event := pub.channelEvents[0]
|
|
if event.target != "channel:personas" {
|
|
t.Errorf("expected channel 'channel:personas', got '%s'", event.target)
|
|
}
|
|
if event.event.Type != "job_update" {
|
|
t.Errorf("expected event type 'job_update', got '%s'", event.event.Type)
|
|
}
|
|
if event.event.Progress != 100 {
|
|
t.Errorf("expected progress 100, got %d", event.event.Progress)
|
|
}
|
|
|
|
result, ok := event.event.Result.(map[string]any)
|
|
if !ok {
|
|
t.Fatal("expected Result to be map")
|
|
}
|
|
if result["persona_id"] != "ps_123" {
|
|
t.Errorf("expected persona_id 'ps_123', got %v", result["persona_id"])
|
|
}
|
|
if result["stage"] != StageSpec {
|
|
t.Errorf("expected stage 'spec', got %v", result["stage"])
|
|
}
|
|
if result["status"] != "complete" {
|
|
t.Errorf("expected status 'complete', got %v", result["status"])
|
|
}
|
|
}
|
|
|
|
func TestPublishJobUpdate_WithError(t *testing.T) {
|
|
pub := &mockPublisher{}
|
|
|
|
publishJobUpdate(pub, "ps_123", StageSpec, "error", 0, "something went wrong")
|
|
|
|
if len(pub.channelEvents) != 1 {
|
|
t.Fatalf("expected 1 channel event, got %d", len(pub.channelEvents))
|
|
}
|
|
|
|
event := pub.channelEvents[0]
|
|
if event.event.Error != "something went wrong" {
|
|
t.Errorf("expected error message, got '%s'", event.event.Error)
|
|
}
|
|
}
|
|
|
|
func TestUnmarshalSpec_EmptyJSON(t *testing.T) {
|
|
_, err := unmarshalSpec(nil)
|
|
if err == nil {
|
|
t.Fatal("expected error for nil spec JSON")
|
|
}
|
|
|
|
_, err = unmarshalSpec(json.RawMessage{})
|
|
if err == nil {
|
|
t.Fatal("expected error for empty spec JSON")
|
|
}
|
|
}
|
|
|
|
func TestUnmarshalSpec_ValidJSON(t *testing.T) {
|
|
spec := &persona.PersonaSpec{
|
|
ID: "ps_test",
|
|
Name: persona.NameSpec{First: "Luna", Last: "Shadow"},
|
|
}
|
|
data, _ := json.Marshal(spec)
|
|
|
|
result, err := unmarshalSpec(data)
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if result.ID != "ps_test" {
|
|
t.Errorf("expected ID 'ps_test', got '%s'", result.ID)
|
|
}
|
|
if result.Name.First != "Luna" {
|
|
t.Errorf("expected first name 'Luna', got '%s'", result.Name.First)
|
|
}
|
|
}
|
|
|
|
// noopLogger returns a discard logger for tests.
|
|
func noopLogger() *slog.Logger {
|
|
return slog.Default() // In test context, logs go to stderr which is fine
|
|
}
|