// Package postgres provides PostgreSQL-based implementations of port interfaces. package postgres import ( "context" "database/sql" "encoding/json" "errors" "fmt" "time" "github.com/orchard9/rdev/internal/domain" "github.com/orchard9/rdev/internal/port" ) // WorkQueueRepository implements port.WorkQueue using PostgreSQL. type WorkQueueRepository struct { db *sql.DB } // NewWorkQueueRepository creates a new PostgreSQL work queue repository. func NewWorkQueueRepository(db *sql.DB) *WorkQueueRepository { return &WorkQueueRepository{db: db} } // Ensure WorkQueueRepository implements port.WorkQueue at compile time. var _ port.WorkQueue = (*WorkQueueRepository)(nil) // Enqueue adds a task to the queue. func (r *WorkQueueRepository) Enqueue(ctx context.Context, task *port.WorkTask) (string, error) { specJSON, err := json.Marshal(task.Spec) if err != nil { return "", fmt.Errorf("marshal task spec: %w", err) } var id string err = r.db.QueryRowContext(ctx, ` INSERT INTO work_queue (project_id, task_type, task_spec, priority, callback_url, max_retries) VALUES ($1, $2, $3, $4, $5, $6) RETURNING id `, task.ProjectID, string(task.Type), specJSON, task.Priority, nullString(task.CallbackURL), task.MaxRetries).Scan(&id) if err != nil { return "", fmt.Errorf("enqueue work task: %w", err) } return id, nil } // Dequeue atomically claims the next available task for a worker. func (r *WorkQueueRepository) Dequeue(ctx context.Context, workerID string) (*port.WorkTask, error) { // Use a single UPDATE ... RETURNING with subquery for atomic claim // This avoids explicit transaction management while still being safe var task port.WorkTask var taskType string var specJSON []byte var status string var callbackURL sql.NullString var startedAt sql.NullTime var completedAt sql.NullTime var resultJSON []byte var errorMsg sql.NullString err := r.db.QueryRowContext(ctx, ` UPDATE work_queue SET status = 'running', worker_id = $1, started_at = NOW() WHERE id = ( SELECT id FROM work_queue WHERE status = 'pending' ORDER BY priority DESC, created_at ASC LIMIT 1 FOR UPDATE SKIP LOCKED ) RETURNING id, project_id, task_type, task_spec, status, priority, worker_id, callback_url, created_at, started_at, completed_at, result, error, retry_count, max_retries `, workerID).Scan( &task.ID, &task.ProjectID, &taskType, &specJSON, &status, &task.Priority, &task.WorkerID, &callbackURL, &task.CreatedAt, &startedAt, &completedAt, &resultJSON, &errorMsg, &task.RetryCount, &task.MaxRetries, ) if errors.Is(err, sql.ErrNoRows) { return nil, nil // No pending tasks } if err != nil { return nil, fmt.Errorf("dequeue work task: %w", err) } task.Type = port.WorkTaskType(taskType) task.Status = port.WorkTaskStatus(status) if callbackURL.Valid { task.CallbackURL = callbackURL.String } if startedAt.Valid { task.StartedAt = &startedAt.Time } if completedAt.Valid { task.CompletedAt = &completedAt.Time } if errorMsg.Valid { task.Error = errorMsg.String } // Parse task spec if len(specJSON) > 0 { if err := json.Unmarshal(specJSON, &task.Spec); err != nil { return nil, fmt.Errorf("unmarshal task spec: %w", err) } } // Parse result if len(resultJSON) > 0 { task.Result = &port.WorkResult{} if err := json.Unmarshal(resultJSON, task.Result); err != nil { return nil, fmt.Errorf("unmarshal task result: %w", err) } } return &task, nil } // Complete marks a task as successfully completed with results. func (r *WorkQueueRepository) Complete(ctx context.Context, taskID string, result *port.WorkResult) error { var resultJSON []byte var err error if result != nil { resultJSON, err = json.Marshal(result) if err != nil { return fmt.Errorf("marshal result: %w", err) } } res, err := r.db.ExecContext(ctx, ` UPDATE work_queue SET status = 'completed', completed_at = NOW(), result = $1 WHERE id = $2 AND status = 'running' `, resultJSON, taskID) if err != nil { return fmt.Errorf("complete work task: %w", err) } rows, err := res.RowsAffected() if err != nil { return fmt.Errorf("rows affected: %w", err) } if rows == 0 { return domain.ErrWorkTaskNotFound } return nil } // Fail marks a task as failed with an error message. // Uses a single atomic UPDATE to avoid race conditions between SELECT and UPDATE. func (r *WorkQueueRepository) Fail(ctx context.Context, taskID string, errMsg string) error { // Use a single atomic query that handles both retry and permanent failure cases result, err := r.db.ExecContext(ctx, ` UPDATE work_queue SET status = CASE WHEN retry_count < max_retries THEN 'pending' ELSE 'failed' END, worker_id = CASE WHEN retry_count < max_retries THEN NULL ELSE worker_id END, started_at = CASE WHEN retry_count < max_retries THEN NULL ELSE started_at END, completed_at = CASE WHEN retry_count >= max_retries THEN NOW() ELSE completed_at END, retry_count = CASE WHEN retry_count < max_retries THEN retry_count + 1 ELSE retry_count END, error = $1 WHERE id = $2 `, errMsg, taskID) if err != nil { return fmt.Errorf("fail work task: %w", err) } rows, err := result.RowsAffected() if err != nil { return fmt.Errorf("rows affected: %w", err) } if rows == 0 { return domain.ErrWorkTaskNotFound } return nil } // Cancel marks a pending task as cancelled. func (r *WorkQueueRepository) Cancel(ctx context.Context, taskID string) error { result, err := r.db.ExecContext(ctx, ` UPDATE work_queue SET status = 'cancelled', completed_at = NOW() WHERE id = $1 AND status = 'pending' `, taskID) if err != nil { return fmt.Errorf("cancel work task: %w", err) } rows, err := result.RowsAffected() if err != nil { return fmt.Errorf("rows affected: %w", err) } if rows == 0 { // Check if task exists var exists bool err := r.db.QueryRowContext(ctx, `SELECT EXISTS(SELECT 1 FROM work_queue WHERE id = $1)`, taskID).Scan(&exists) if err != nil { return fmt.Errorf("check exists: %w", err) } if !exists { return domain.ErrWorkTaskNotFound } return fmt.Errorf("task is not in pending state") } return nil } // GetTask retrieves a task by ID. func (r *WorkQueueRepository) GetTask(ctx context.Context, taskID string) (*port.WorkTask, error) { var task port.WorkTask var taskType string var specJSON []byte var status string var workerID sql.NullString var callbackURL sql.NullString var startedAt sql.NullTime var completedAt sql.NullTime var resultJSON []byte var errorMsg sql.NullString err := r.db.QueryRowContext(ctx, ` SELECT id, project_id, task_type, task_spec, status, priority, worker_id, callback_url, created_at, started_at, completed_at, result, error, retry_count, max_retries FROM work_queue WHERE id = $1 `, taskID).Scan( &task.ID, &task.ProjectID, &taskType, &specJSON, &status, &task.Priority, &workerID, &callbackURL, &task.CreatedAt, &startedAt, &completedAt, &resultJSON, &errorMsg, &task.RetryCount, &task.MaxRetries, ) if errors.Is(err, sql.ErrNoRows) { return nil, domain.ErrWorkTaskNotFound } if err != nil { return nil, fmt.Errorf("get work task: %w", err) } task.Type = port.WorkTaskType(taskType) task.Status = port.WorkTaskStatus(status) if workerID.Valid { task.WorkerID = workerID.String } if callbackURL.Valid { task.CallbackURL = callbackURL.String } if startedAt.Valid { task.StartedAt = &startedAt.Time } if completedAt.Valid { task.CompletedAt = &completedAt.Time } if errorMsg.Valid { task.Error = errorMsg.String } // Parse task spec if len(specJSON) > 0 { if err := json.Unmarshal(specJSON, &task.Spec); err != nil { return nil, fmt.Errorf("unmarshal task spec: %w", err) } } // Parse result if len(resultJSON) > 0 { task.Result = &port.WorkResult{} if err := json.Unmarshal(resultJSON, task.Result); err != nil { return nil, fmt.Errorf("unmarshal task result: %w", err) } } return &task, nil } // ListByProject returns tasks for a project with optional status filter and pagination. func (r *WorkQueueRepository) ListByProject(ctx context.Context, projectID string, status *port.WorkTaskStatus, opts port.WorkListOptions) (*port.WorkListResult, error) { // Normalize pagination options opts.Normalize() // Build base WHERE clause whereClause := "WHERE project_id = $1" args := []any{projectID} argNum := 2 if status != nil { whereClause += fmt.Sprintf(" AND status = $%d", argNum) args = append(args, string(*status)) argNum++ } // Get total count for pagination metadata countQuery := "SELECT COUNT(*) FROM work_queue " + whereClause var total int64 if err := r.db.QueryRowContext(ctx, countQuery, args...).Scan(&total); err != nil { return nil, fmt.Errorf("count work tasks: %w", err) } // Build paginated query query := fmt.Sprintf(` SELECT id, project_id, task_type, task_spec, status, priority, worker_id, callback_url, created_at, started_at, completed_at, result, error, retry_count, max_retries FROM work_queue %s ORDER BY created_at DESC LIMIT $%d OFFSET $%d `, whereClause, argNum, argNum+1) args = append(args, opts.Limit, opts.Offset) rows, err := r.db.QueryContext(ctx, query, args...) if err != nil { return nil, fmt.Errorf("list work tasks: %w", err) } defer func() { _ = rows.Close() }() var tasks []*port.WorkTask for rows.Next() { task, err := r.scanTask(rows) if err != nil { return nil, err } tasks = append(tasks, task) } return &port.WorkListResult{ Tasks: tasks, Total: total, Limit: opts.Limit, Offset: opts.Offset, }, nil } // GetStats returns queue statistics. func (r *WorkQueueRepository) GetStats(ctx context.Context) (*port.WorkQueueStats, error) { var stats port.WorkQueueStats err := r.db.QueryRowContext(ctx, ` SELECT COUNT(*) FILTER (WHERE status = 'pending') as pending, COUNT(*) FILTER (WHERE status = 'running') as running, COUNT(*) FILTER (WHERE status = 'completed' AND completed_at > NOW() - INTERVAL '24 hours') as completed, COUNT(*) FILTER (WHERE status = 'failed' AND completed_at > NOW() - INTERVAL '24 hours') as failed, COUNT(*) FILTER (WHERE status = 'cancelled' AND completed_at > NOW() - INTERVAL '24 hours') as cancelled FROM work_queue `).Scan( &stats.Pending, &stats.Running, &stats.Completed, &stats.Failed, &stats.Cancelled, ) if err != nil { return nil, fmt.Errorf("get stats: %w", err) } // Get oldest pending task age var oldestCreatedAt sql.NullTime err = r.db.QueryRowContext(ctx, ` SELECT MIN(created_at) FROM work_queue WHERE status = 'pending' `).Scan(&oldestCreatedAt) if err != nil && !errors.Is(err, sql.ErrNoRows) { return nil, fmt.Errorf("get oldest pending: %w", err) } if oldestCreatedAt.Valid { age := time.Since(oldestCreatedAt.Time) stats.OldestPending = &age } return &stats, nil } // CleanupOld removes completed/failed/cancelled tasks older than the specified duration. func (r *WorkQueueRepository) CleanupOld(ctx context.Context, olderThan time.Duration) (int64, error) { cutoff := time.Now().Add(-olderThan) result, err := r.db.ExecContext(ctx, ` DELETE FROM work_queue WHERE status IN ('completed', 'failed', 'cancelled') AND completed_at < $1 `, cutoff) if err != nil { return 0, fmt.Errorf("cleanup old tasks: %w", err) } return result.RowsAffected() } // RequeueStale re-queues tasks that have been running longer than the timeout. func (r *WorkQueueRepository) RequeueStale(ctx context.Context, timeout time.Duration) (int64, error) { cutoff := time.Now().Add(-timeout) result, err := r.db.ExecContext(ctx, ` UPDATE work_queue SET status = 'pending', worker_id = NULL, started_at = NULL, retry_count = retry_count + 1, error = 'Worker timeout - task requeued' WHERE status = 'running' AND started_at < $1 AND retry_count < max_retries `, cutoff) if err != nil { return 0, fmt.Errorf("requeue stale tasks: %w", err) } return result.RowsAffected() } // scanTask scans a single task row. func (r *WorkQueueRepository) scanTask(rows *sql.Rows) (*port.WorkTask, error) { var task port.WorkTask var taskType string var specJSON []byte var status string var workerID sql.NullString var callbackURL sql.NullString var startedAt sql.NullTime var completedAt sql.NullTime var resultJSON []byte var errorMsg sql.NullString err := rows.Scan( &task.ID, &task.ProjectID, &taskType, &specJSON, &status, &task.Priority, &workerID, &callbackURL, &task.CreatedAt, &startedAt, &completedAt, &resultJSON, &errorMsg, &task.RetryCount, &task.MaxRetries, ) if err != nil { return nil, fmt.Errorf("scan task: %w", err) } task.Type = port.WorkTaskType(taskType) task.Status = port.WorkTaskStatus(status) if workerID.Valid { task.WorkerID = workerID.String } if callbackURL.Valid { task.CallbackURL = callbackURL.String } if startedAt.Valid { task.StartedAt = &startedAt.Time } if completedAt.Valid { task.CompletedAt = &completedAt.Time } if errorMsg.Valid { task.Error = errorMsg.String } // Parse task spec if len(specJSON) > 0 { if err := json.Unmarshal(specJSON, &task.Spec); err != nil { return nil, fmt.Errorf("unmarshal task spec: %w", err) } } // Parse result if len(resultJSON) > 0 { task.Result = &port.WorkResult{} if err := json.Unmarshal(resultJSON, task.Result); err != nil { return nil, fmt.Errorf("unmarshal task result: %w", err) } } return &task, nil }