diff --git a/internal/adapter/postgres/build_audit.go b/internal/adapter/postgres/build_audit.go index 6dcfcdc..2f974af 100644 --- a/internal/adapter/postgres/build_audit.go +++ b/internal/adapter/postgres/build_audit.go @@ -83,6 +83,28 @@ func (r *BuildAuditRepository) Update(ctx context.Context, taskID string, result return nil } +// UpdateStatus updates the status and worker assignment when a task is claimed. +func (r *BuildAuditRepository) UpdateStatus(ctx context.Context, taskID string, status domain.BuildStatus, workerID string) error { + res, err := r.db.ExecContext(ctx, ` + UPDATE build_audit + SET status = $2, worker_id = $3 + WHERE task_id = $1 + `, taskID, status, nullString(workerID)) + if err != nil { + return fmt.Errorf("update build audit status: %w", err) + } + + rows, err := res.RowsAffected() + if err != nil { + return fmt.Errorf("rows affected: %w", err) + } + if rows == 0 { + return domain.ErrBuildNotFound + } + + return nil +} + // Get retrieves a specific audit entry by task ID. func (r *BuildAuditRepository) Get(ctx context.Context, taskID string) (*domain.BuildAuditEntry, error) { rows, err := r.db.QueryContext(ctx, ` diff --git a/internal/adapter/postgres/build_audit_test.go b/internal/adapter/postgres/build_audit_test.go index cb2818e..19866ad 100644 --- a/internal/adapter/postgres/build_audit_test.go +++ b/internal/adapter/postgres/build_audit_test.go @@ -181,6 +181,83 @@ func TestBuildAuditRepository_Update(t *testing.T) { }) } +func TestBuildAuditRepository_UpdateStatus(t *testing.T) { + db := testutil.TestDB(t) + t.Cleanup(func() { cleanupTestBuildAudit(t, db) }) + + repo := NewBuildAuditRepository(db) + ctx := context.Background() + + // Create initial entry + entry := &domain.BuildAuditEntry{ + TaskID: "test-task-status-1", + ProjectID: "test-project-1", + Spec: domain.BuildSpec{Prompt: "Build"}, + Status: domain.BuildStatusPending, + StartedAt: time.Now(), + } + if err := repo.Record(ctx, entry); err != nil { + t.Fatalf("Record() error = %v", err) + } + + t.Run("updates status and worker ID", func(t *testing.T) { + err := repo.UpdateStatus(ctx, "test-task-status-1", domain.BuildStatusRunning, "worker-123") + if err != nil { + t.Fatalf("UpdateStatus() error = %v", err) + } + + got, err := repo.Get(ctx, "test-task-status-1") + if err != nil { + t.Fatalf("Get() error = %v", err) + } + if got.Status != domain.BuildStatusRunning { + t.Errorf("got status %q, want %q", got.Status, domain.BuildStatusRunning) + } + if got.WorkerID != "worker-123" { + t.Errorf("got worker_id %q, want %q", got.WorkerID, "worker-123") + } + }) + + t.Run("updates status with empty worker ID", func(t *testing.T) { + // Create another entry + entry := &domain.BuildAuditEntry{ + TaskID: "test-task-status-2", + ProjectID: "test-project-1", + WorkerID: "old-worker", + Spec: domain.BuildSpec{Prompt: "Build"}, + Status: domain.BuildStatusRunning, + StartedAt: time.Now(), + } + if err := repo.Record(ctx, entry); err != nil { + t.Fatalf("Record() error = %v", err) + } + + err := repo.UpdateStatus(ctx, "test-task-status-2", domain.BuildStatusCompleted, "") + if err != nil { + t.Fatalf("UpdateStatus() error = %v", err) + } + + got, err := repo.Get(ctx, "test-task-status-2") + if err != nil { + t.Fatalf("Get() error = %v", err) + } + if got.Status != domain.BuildStatusCompleted { + t.Errorf("got status %q, want %q", got.Status, domain.BuildStatusCompleted) + } + // WorkerID should be cleared when empty string is passed + if got.WorkerID != "" { + t.Errorf("got worker_id %q, want empty", got.WorkerID) + } + }) + + t.Run("returns error for nonexistent task", func(t *testing.T) { + err := repo.UpdateStatus(ctx, "test-task-nonexistent", domain.BuildStatusRunning, "worker-1") + if err == nil { + t.Error("expected error for nonexistent task") + } + }) +} + func TestBuildAuditRepository_Get(t *testing.T) { db := testutil.TestDB(t) t.Cleanup(func() { cleanupTestBuildAudit(t, db) }) diff --git a/internal/handlers/builds_test.go b/internal/handlers/builds_test.go index 8a2073c..241a698 100644 --- a/internal/handlers/builds_test.go +++ b/internal/handlers/builds_test.go @@ -66,6 +66,19 @@ func (m *mockBuildAudit) Update(_ context.Context, taskID string, result *domain return nil } +func (m *mockBuildAudit) UpdateStatus(_ context.Context, taskID string, status domain.BuildStatus, workerID string) error { + if m.err != nil { + return m.err + } + entry, ok := m.entries[taskID] + if !ok { + return domain.ErrBuildNotFound + } + entry.Status = status + entry.WorkerID = workerID + return nil +} + func (m *mockBuildAudit) Get(_ context.Context, taskID string) (*domain.BuildAuditEntry, error) { if m.err != nil { return nil, m.err diff --git a/internal/port/build_audit.go b/internal/port/build_audit.go index c248861..9211b17 100644 --- a/internal/port/build_audit.go +++ b/internal/port/build_audit.go @@ -19,6 +19,10 @@ type BuildAudit interface { // Update modifies an existing entry when a build completes. Update(ctx context.Context, taskID string, result *domain.BuildResult) error + // UpdateStatus updates the status and worker assignment when a task is claimed. + // This is called when a worker picks up a task to mark it as running. + UpdateStatus(ctx context.Context, taskID string, status domain.BuildStatus, workerID string) error + // Get retrieves a specific audit entry by task ID. // Returns ErrBuildNotFound if the entry does not exist. Get(ctx context.Context, taskID string) (*domain.BuildAuditEntry, error) diff --git a/internal/service/mock_test.go b/internal/service/mock_test.go index ce44967..ddf7363 100644 --- a/internal/service/mock_test.go +++ b/internal/service/mock_test.go @@ -130,6 +130,19 @@ func (m *mockBuildAudit) Update(ctx context.Context, taskID string, result *doma return nil } +func (m *mockBuildAudit) UpdateStatus(ctx context.Context, taskID string, status domain.BuildStatus, workerID string) error { + if m.err != nil { + return m.err + } + entry, ok := m.entries[taskID] + if !ok { + return domain.ErrBuildNotFound + } + entry.Status = status + entry.WorkerID = workerID + return nil +} + func (m *mockBuildAudit) Get(ctx context.Context, taskID string) (*domain.BuildAuditEntry, error) { if m.err != nil { return nil, m.err diff --git a/internal/service/worker_service.go b/internal/service/worker_service.go index 450a027..7316ae2 100644 --- a/internal/service/worker_service.go +++ b/internal/service/worker_service.go @@ -118,12 +118,14 @@ func (s *WorkerService) ClaimTask(ctx context.Context, workerID string) (*domain ) } - // Update audit entry if available + // Update audit entry if available - persist status change to database if s.audit != nil { - entry, _ := s.audit.Get(ctx, task.ID) - if entry != nil { - entry.WorkerID = workerID - entry.Status = domain.BuildStatusRunning + if err := s.audit.UpdateStatus(ctx, task.ID, domain.BuildStatusRunning, workerID); err != nil { + s.logger.Warn("failed to update audit status after claim", + "task_id", task.ID, + "worker_id", workerID, + "error", err, + ) } } diff --git a/internal/service/worker_service_test.go b/internal/service/worker_service_test.go index 5ae8bad..1e728fb 100644 --- a/internal/service/worker_service_test.go +++ b/internal/service/worker_service_test.go @@ -174,6 +174,50 @@ func TestWorkerService_ClaimTask(t *testing.T) { t.Error("expected nil task when queue is empty") } }) + + t.Run("updates audit status when claiming task", func(t *testing.T) { + registry := newMockWorkerRegistry() + registry.workers["worker-1"] = &domain.Worker{ + ID: "worker-1", + Hostname: "host-1", + Status: domain.WorkerStatusIdle, + } + + queue := newMockWorkQueue() + queue.tasks["task-1"] = &domain.WorkTask{ + ID: "task-1", + ProjectID: "project-1", + Type: domain.WorkTaskTypeBuild, + Status: domain.WorkTaskStatusPending, + CreatedAt: time.Now(), + } + + audit := newMockBuildAudit() + audit.entries["task-1"] = &domain.BuildAuditEntry{ + TaskID: "task-1", + ProjectID: "project-1", + Status: domain.BuildStatusPending, + } + + svc := NewWorkerService(registry, queue, nil).WithBuildAudit(audit) + + task, err := svc.ClaimTask(ctx, "worker-1") + if err != nil { + t.Fatalf("ClaimTask() error = %v", err) + } + if task == nil { + t.Fatal("expected task to be returned") + } + + // Verify audit was updated + entry := audit.entries["task-1"] + if entry.Status != domain.BuildStatusRunning { + t.Errorf("got audit status %q, want %q", entry.Status, domain.BuildStatusRunning) + } + if entry.WorkerID != "worker-1" { + t.Errorf("got audit worker_id %q, want %q", entry.WorkerID, "worker-1") + } + }) } func TestWorkerService_CompleteTask(t *testing.T) { diff --git a/internal/worker/mock_test.go b/internal/worker/mock_test.go index 5e074c8..be1c35d 100644 --- a/internal/worker/mock_test.go +++ b/internal/worker/mock_test.go @@ -221,6 +221,18 @@ func (m *mockBuildAudit) Update(_ context.Context, taskID string, result *domain return nil } +func (m *mockBuildAudit) UpdateStatus(_ context.Context, taskID string, status domain.BuildStatus, workerID string) error { + m.mu.Lock() + defer m.mu.Unlock() + entry, ok := m.entries[taskID] + if !ok { + return domain.ErrBuildNotFound + } + entry.Status = status + entry.WorkerID = workerID + return nil +} + func (m *mockBuildAudit) Get(_ context.Context, taskID string) (*domain.BuildAuditEntry, error) { m.mu.Lock() defer m.mu.Unlock() @@ -308,7 +320,7 @@ func newTestDeps() *testDeps { WithBuildAudit(audit) workSvc := service.NewWorkService(queue, service.WorkServiceConfig{}) - buildExec := NewBuildExecutor(agentRegistry, nil, nil) + buildExec := NewBuildExecutor(agentRegistry, nil, nil, nil) return &testDeps{ queue: queue,