package postgres import ( "context" "database/sql" "encoding/json" "errors" "fmt" "time" "github.com/orchard9/rdev/internal/domain" ) // GetTask retrieves a task by ID. func (r *WorkQueueRepository) GetTask(ctx context.Context, taskID string) (*domain.WorkTask, error) { var task domain.WorkTask var taskType string var specJSON []byte var status string var workerID sql.NullString var callbackURL sql.NullString var startedAt sql.NullTime var completedAt sql.NullTime var resultJSON []byte var errorMsg sql.NullString var errorCode sql.NullString err := r.db.QueryRowContext(ctx, ` SELECT id, project_id, task_type, task_spec, status, priority, worker_id, callback_url, created_at, started_at, completed_at, result, error, retry_count, max_retries, error_code FROM work_queue WHERE id = $1 `, taskID).Scan( &task.ID, &task.ProjectID, &taskType, &specJSON, &status, &task.Priority, &workerID, &callbackURL, &task.CreatedAt, &startedAt, &completedAt, &resultJSON, &errorMsg, &task.RetryCount, &task.MaxRetries, &errorCode, ) if errors.Is(err, sql.ErrNoRows) { return nil, domain.ErrWorkTaskNotFound } if err != nil { return nil, fmt.Errorf("get work task: %w", err) } task.Type = domain.WorkTaskType(taskType) task.Status = domain.WorkTaskStatus(status) if workerID.Valid { task.WorkerID = workerID.String } if callbackURL.Valid { task.CallbackURL = callbackURL.String } if startedAt.Valid { task.StartedAt = &startedAt.Time } if completedAt.Valid { task.CompletedAt = &completedAt.Time } if errorMsg.Valid { task.Error = errorMsg.String } if errorCode.Valid { task.ErrorCode = domain.WorkErrorCode(errorCode.String) } // Parse task spec if len(specJSON) > 0 { if err := json.Unmarshal(specJSON, &task.Spec); err != nil { return nil, fmt.Errorf("unmarshal task spec: %w", err) } } // Parse result if len(resultJSON) > 0 { task.Result = &domain.WorkResult{} if err := json.Unmarshal(resultJSON, task.Result); err != nil { return nil, fmt.Errorf("unmarshal task result: %w", err) } } return &task, nil } // ListByProject returns tasks for a project with optional status filter and pagination. func (r *WorkQueueRepository) ListByProject(ctx context.Context, projectID string, status *domain.WorkTaskStatus, opts domain.WorkListOptions) (*domain.WorkListResult, error) { // Normalize pagination options opts.Normalize() // Build base WHERE clause whereClause := "WHERE project_id = $1" args := []any{projectID} argNum := 2 if status != nil { whereClause += fmt.Sprintf(" AND status = $%d", argNum) args = append(args, string(*status)) argNum++ } // Get total count for pagination metadata countQuery := "SELECT COUNT(*) FROM work_queue " + whereClause var total int64 if err := r.db.QueryRowContext(ctx, countQuery, args...).Scan(&total); err != nil { return nil, fmt.Errorf("count work tasks: %w", err) } // Build paginated query query := fmt.Sprintf(` SELECT id, project_id, task_type, task_spec, status, priority, worker_id, callback_url, created_at, started_at, completed_at, result, error, retry_count, max_retries, error_code FROM work_queue %s ORDER BY created_at DESC LIMIT $%d OFFSET $%d `, whereClause, argNum, argNum+1) args = append(args, opts.Limit, opts.Offset) rows, err := r.db.QueryContext(ctx, query, args...) if err != nil { return nil, fmt.Errorf("list work tasks: %w", err) } defer func() { _ = rows.Close() }() var tasks []*domain.WorkTask for rows.Next() { task, err := r.scanTask(rows) if err != nil { return nil, err } tasks = append(tasks, task) } return &domain.WorkListResult{ Tasks: tasks, Total: total, Limit: opts.Limit, Offset: opts.Offset, }, nil } // GetStats returns queue statistics. func (r *WorkQueueRepository) GetStats(ctx context.Context) (*domain.WorkQueueStats, error) { var stats domain.WorkQueueStats err := r.db.QueryRowContext(ctx, ` SELECT COUNT(*) FILTER (WHERE status = 'pending') as pending, COUNT(*) FILTER (WHERE status = 'running') as running, COUNT(*) FILTER (WHERE status = 'completed' AND completed_at > NOW() - INTERVAL '24 hours') as completed, COUNT(*) FILTER (WHERE status = 'failed' AND completed_at > NOW() - INTERVAL '24 hours') as failed, COUNT(*) FILTER (WHERE status = 'cancelled' AND completed_at > NOW() - INTERVAL '24 hours') as cancelled FROM work_queue `).Scan( &stats.Pending, &stats.Running, &stats.Completed, &stats.Failed, &stats.Cancelled, ) if err != nil { return nil, fmt.Errorf("get stats: %w", err) } // Get oldest pending task age var oldestCreatedAt sql.NullTime err = r.db.QueryRowContext(ctx, ` SELECT MIN(created_at) FROM work_queue WHERE status = 'pending' `).Scan(&oldestCreatedAt) if err != nil && !errors.Is(err, sql.ErrNoRows) { return nil, fmt.Errorf("get oldest pending: %w", err) } if oldestCreatedAt.Valid { age := time.Since(oldestCreatedAt.Time) stats.OldestPending = &age } return &stats, nil } // CleanupOld removes completed/failed/cancelled tasks older than the specified duration. func (r *WorkQueueRepository) CleanupOld(ctx context.Context, olderThan time.Duration) (int64, error) { cutoff := time.Now().Add(-olderThan) result, err := r.db.ExecContext(ctx, ` DELETE FROM work_queue WHERE status IN ('completed', 'failed', 'cancelled') AND completed_at < $1 `, cutoff) if err != nil { return 0, fmt.Errorf("cleanup old tasks: %w", err) } return result.RowsAffected() } // RequeueStale re-queues tasks that have been running longer than the timeout. func (r *WorkQueueRepository) RequeueStale(ctx context.Context, timeout time.Duration) (int64, error) { ids, err := r.RequeueStaleWithIDs(ctx, timeout) if err != nil { return 0, err } return int64(len(ids)), nil } // RequeueStaleWithIDs re-queues stale tasks and returns their IDs. // Tasks that have exceeded max_retries are marked as failed with STALE_WORKER error code. func (r *WorkQueueRepository) RequeueStaleWithIDs(ctx context.Context, timeout time.Duration) ([]string, error) { cutoff := time.Now().Add(-timeout) // First, mark tasks that have exceeded max_retries as permanently failed _, err := r.db.ExecContext(ctx, ` UPDATE work_queue SET status = 'failed', completed_at = NOW(), error = 'Worker timeout - max retries exceeded', error_code = 'STALE_WORKER' WHERE status = 'running' AND started_at < $1 AND retry_count >= max_retries `, cutoff) if err != nil { return nil, fmt.Errorf("fail stale tasks: %w", err) } // Then, requeue tasks that can still be retried rows, err := r.db.QueryContext(ctx, ` UPDATE work_queue SET status = 'pending', worker_id = NULL, started_at = NULL, retry_count = retry_count + 1, error = 'Worker timeout - task requeued', error_code = NULL WHERE status = 'running' AND started_at < $1 AND retry_count < max_retries RETURNING id `, cutoff) if err != nil { return nil, fmt.Errorf("requeue stale tasks: %w", err) } defer func() { _ = rows.Close() }() var ids []string for rows.Next() { var id string if err := rows.Scan(&id); err != nil { return nil, fmt.Errorf("scan requeued task id: %w", err) } ids = append(ids, id) } return ids, rows.Err() } // scanTask scans a single task row. func (r *WorkQueueRepository) scanTask(rows *sql.Rows) (*domain.WorkTask, error) { var task domain.WorkTask var taskType string var specJSON []byte var status string var workerID sql.NullString var callbackURL sql.NullString var startedAt sql.NullTime var completedAt sql.NullTime var resultJSON []byte var errorMsg sql.NullString var errorCode sql.NullString err := rows.Scan( &task.ID, &task.ProjectID, &taskType, &specJSON, &status, &task.Priority, &workerID, &callbackURL, &task.CreatedAt, &startedAt, &completedAt, &resultJSON, &errorMsg, &task.RetryCount, &task.MaxRetries, &errorCode, ) if err != nil { return nil, fmt.Errorf("scan task: %w", err) } task.Type = domain.WorkTaskType(taskType) task.Status = domain.WorkTaskStatus(status) if workerID.Valid { task.WorkerID = workerID.String } if callbackURL.Valid { task.CallbackURL = callbackURL.String } if startedAt.Valid { task.StartedAt = &startedAt.Time } if completedAt.Valid { task.CompletedAt = &completedAt.Time } if errorMsg.Valid { task.Error = errorMsg.String } if errorCode.Valid { task.ErrorCode = domain.WorkErrorCode(errorCode.String) } // Parse task spec if len(specJSON) > 0 { if err := json.Unmarshal(specJSON, &task.Spec); err != nil { return nil, fmt.Errorf("unmarshal task spec: %w", err) } } // Parse result if len(resultJSON) > 0 { task.Result = &domain.WorkResult{} if err := json.Unmarshal(resultJSON, task.Result); err != nil { return nil, fmt.Errorf("unmarshal task result: %w", err) } } return &task, nil }