rdev/internal/adapter/postgres/work_queue_queries.go
jordan cfba724f8a feat: add work task error classification and user-facing error codes
- Add WorkErrorCode type with RATE_LIMITED, AUTH_FAILED, TIMEOUT, STALE_WORKER, AGENT_ERROR, INVALID_SPEC
- Add ClassifyAgentError function to detect error patterns from stderr
- Add error_code column to work_queue table (migration 016)
- Add FailWithCode method to WorkQueue interface and implementations
- Update RequeueStaleWithIDs to mark permanently failed tasks with STALE_WORKER
- Add ErrorCode to BuildResult for API responses
- Update work executor to classify errors before failing tasks

This enables users to see actual failure reasons (e.g., "RATE_LIMITED") instead of
builds stuck in "running" state forever when Claude hits rate limits.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-03 00:07:34 -07:00

342 lines
8.8 KiB
Go

package postgres
import (
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"time"
"github.com/orchard9/rdev/internal/domain"
)
// GetTask retrieves a single work task by ID.
//
// It returns domain.ErrWorkTaskNotFound when no row in work_queue has the
// given ID. Decoding of the row is delegated to scanTask so that single-task
// lookups and list queries share one column-to-struct mapping; the SELECT
// list below must therefore stay in the column order scanTask scans.
func (r *WorkQueueRepository) GetTask(ctx context.Context, taskID string) (*domain.WorkTask, error) {
	rows, err := r.db.QueryContext(ctx, `
SELECT id, project_id, task_type, task_spec, status, priority, worker_id,
callback_url, created_at, started_at, completed_at, result, error,
retry_count, max_retries, error_code
FROM work_queue
WHERE id = $1
`, taskID)
	if err != nil {
		return nil, fmt.Errorf("get work task: %w", err)
	}
	defer func() { _ = rows.Close() }()

	if !rows.Next() {
		// Next reports false both for an empty result set and for a fetch
		// failure; Err distinguishes the two.
		if err := rows.Err(); err != nil {
			return nil, fmt.Errorf("get work task: %w", err)
		}
		return nil, domain.ErrWorkTaskNotFound
	}
	return r.scanTask(rows)
}
// ListByProject returns one page of a project's tasks, newest first, along
// with the total number of tasks matching the filter.
//
// When status is non-nil, only tasks in that status are counted and listed.
// opts is normalized with opts.Normalize() before its Limit and Offset are
// applied. The total comes from a separate COUNT query issued before the page
// query.
func (r *WorkQueueRepository) ListByProject(ctx context.Context, projectID string, status *domain.WorkTaskStatus, opts domain.WorkListOptions) (*domain.WorkListResult, error) {
	// Normalize pagination options
	opts.Normalize()

	// Build base WHERE clause. argNum tracks the next free positional
	// placeholder so the optional status filter and the LIMIT/OFFSET values
	// stay aligned with args.
	whereClause := "WHERE project_id = $1"
	args := []any{projectID}
	argNum := 2
	if status != nil {
		whereClause += fmt.Sprintf(" AND status = $%d", argNum)
		args = append(args, string(*status))
		argNum++
	}

	// Get total count for pagination metadata
	countQuery := "SELECT COUNT(*) FROM work_queue " + whereClause
	var total int64
	if err := r.db.QueryRowContext(ctx, countQuery, args...).Scan(&total); err != nil {
		return nil, fmt.Errorf("count work tasks: %w", err)
	}

	// Build paginated query
	query := fmt.Sprintf(`
SELECT id, project_id, task_type, task_spec, status, priority, worker_id,
callback_url, created_at, started_at, completed_at, result, error,
retry_count, max_retries, error_code
FROM work_queue
%s
ORDER BY created_at DESC
LIMIT $%d OFFSET $%d
`, whereClause, argNum, argNum+1)
	args = append(args, opts.Limit, opts.Offset)

	rows, err := r.db.QueryContext(ctx, query, args...)
	if err != nil {
		return nil, fmt.Errorf("list work tasks: %w", err)
	}
	defer func() { _ = rows.Close() }()

	var tasks []*domain.WorkTask
	for rows.Next() {
		task, err := r.scanTask(rows)
		if err != nil {
			return nil, err
		}
		tasks = append(tasks, task)
	}
	// Next also returns false when iteration fails part-way through; without
	// this check a truncated page would be reported as a successful result.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("list work tasks: %w", err)
	}

	return &domain.WorkListResult{
		Tasks:  tasks,
		Total:  total,
		Limit:  opts.Limit,
		Offset: opts.Offset,
	}, nil
}
// GetStats reports queue-wide counters and the age of the oldest pending task.
//
// Pending and running counts cover every row in those statuses; completed,
// failed and cancelled counts only include rows whose completed_at falls
// within the last 24 hours. OldestPending stays nil when the MIN(created_at)
// lookup over pending tasks yields no value.
func (r *WorkQueueRepository) GetStats(ctx context.Context) (*domain.WorkQueueStats, error) {
	summary := new(domain.WorkQueueStats)

	countsRow := r.db.QueryRowContext(ctx, `
SELECT
COUNT(*) FILTER (WHERE status = 'pending') as pending,
COUNT(*) FILTER (WHERE status = 'running') as running,
COUNT(*) FILTER (WHERE status = 'completed' AND completed_at > NOW() - INTERVAL '24 hours') as completed,
COUNT(*) FILTER (WHERE status = 'failed' AND completed_at > NOW() - INTERVAL '24 hours') as failed,
COUNT(*) FILTER (WHERE status = 'cancelled' AND completed_at > NOW() - INTERVAL '24 hours') as cancelled
FROM work_queue
`)
	if err := countsRow.Scan(
		&summary.Pending,
		&summary.Running,
		&summary.Completed,
		&summary.Failed,
		&summary.Cancelled,
	); err != nil {
		return nil, fmt.Errorf("get stats: %w", err)
	}

	// Age of the oldest pending task, if any. A "no rows" outcome is
	// tolerated and treated the same as a NULL minimum.
	var earliest sql.NullTime
	oldestRow := r.db.QueryRowContext(ctx, `
SELECT MIN(created_at) FROM work_queue WHERE status = 'pending'
`)
	if err := oldestRow.Scan(&earliest); !(err == nil || errors.Is(err, sql.ErrNoRows)) {
		return nil, fmt.Errorf("get oldest pending: %w", err)
	}
	if earliest.Valid {
		waited := time.Since(earliest.Time)
		summary.OldestPending = &waited
	}

	return summary, nil
}
// CleanupOld deletes tasks in a terminal status ('completed', 'failed' or
// 'cancelled') whose completed_at is earlier than the current time minus
// olderThan. It returns the affected-row count reported by the driver.
func (r *WorkQueueRepository) CleanupOld(ctx context.Context, olderThan time.Duration) (int64, error) {
	const deleteStmt = `
DELETE FROM work_queue
WHERE status IN ('completed', 'failed', 'cancelled')
AND completed_at < $1
`
	threshold := time.Now().Add(-olderThan)

	res, err := r.db.ExecContext(ctx, deleteStmt, threshold)
	if err != nil {
		return 0, fmt.Errorf("cleanup old tasks: %w", err)
	}

	removed, err := res.RowsAffected()
	return removed, err
}
// RequeueStale is a count-only wrapper around RequeueStaleWithIDs: it runs the
// same stale-task handling for the given timeout and returns how many task IDs
// that call reported as re-queued. On failure the count is 0 and the
// underlying error is passed through unchanged.
func (r *WorkQueueRepository) RequeueStale(ctx context.Context, timeout time.Duration) (int64, error) {
	requeued, err := r.RequeueStaleWithIDs(ctx, timeout)
	if err == nil {
		return int64(len(requeued)), nil
	}
	return 0, err
}
// RequeueStaleWithIDs handles tasks that are still 'running' with a started_at
// earlier than the current time minus timeout, and returns the IDs of the
// tasks that were put back into the queue.
//
// It issues two separate statements:
//
//  1. Stale tasks with retry_count >= max_retries are marked 'failed' with
//     error_code 'STALE_WORKER'. Their IDs are not part of the result.
//  2. The remaining stale tasks are reset to 'pending': worker_id, started_at
//     and error_code are cleared and retry_count is incremented. Their IDs
//     are returned.
//
// On any error the returned slice is nil and the error is wrapped with the
// step that failed.
func (r *WorkQueueRepository) RequeueStaleWithIDs(ctx context.Context, timeout time.Duration) ([]string, error) {
	cutoff := time.Now().Add(-timeout)

	// First, mark tasks that have exceeded max_retries as permanently failed
	_, err := r.db.ExecContext(ctx, `
UPDATE work_queue
SET status = 'failed', completed_at = NOW(),
error = 'Worker timeout - max retries exceeded',
error_code = 'STALE_WORKER'
WHERE status = 'running'
AND started_at < $1
AND retry_count >= max_retries
`, cutoff)
	if err != nil {
		return nil, fmt.Errorf("fail stale tasks: %w", err)
	}

	// Then, requeue tasks that can still be retried
	rows, err := r.db.QueryContext(ctx, `
UPDATE work_queue
SET status = 'pending', worker_id = NULL, started_at = NULL,
retry_count = retry_count + 1, error = 'Worker timeout - task requeued',
error_code = NULL
WHERE status = 'running'
AND started_at < $1
AND retry_count < max_retries
RETURNING id
`, cutoff)
	if err != nil {
		return nil, fmt.Errorf("requeue stale tasks: %w", err)
	}
	defer func() { _ = rows.Close() }()

	var ids []string
	for rows.Next() {
		var id string
		if err := rows.Scan(&id); err != nil {
			return nil, fmt.Errorf("scan requeued task id: %w", err)
		}
		ids = append(ids, id)
	}
	// Wrap the iteration error like every other failure path, and do not hand
	// back a partially collected ID list next to a non-nil error.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("iterate requeued task ids: %w", err)
	}
	return ids, nil
}
// scanTask decodes the row that rows is currently positioned on into a
// domain.WorkTask.
//
// The row must carry these columns in this order: id, project_id, task_type,
// task_spec, status, priority, worker_id, callback_url, created_at,
// started_at, completed_at, result, error, retry_count, max_retries,
// error_code. NULL values in the nullable columns leave the matching task
// field at its zero value; task_spec and result are JSON-decoded only when
// non-empty.
func (r *WorkQueueRepository) scanTask(rows *sql.Rows) (*domain.WorkTask, error) {
	var (
		out       domain.WorkTask
		kind      string
		state     string
		rawSpec   []byte
		rawResult []byte
		worker    sql.NullString
		callback  sql.NullString
		failMsg   sql.NullString
		failCode  sql.NullString
		started   sql.NullTime
		finished  sql.NullTime
	)

	if err := rows.Scan(
		&out.ID, &out.ProjectID, &kind, &rawSpec, &state, &out.Priority,
		&worker, &callback, &out.CreatedAt, &started, &finished,
		&rawResult, &failMsg, &out.RetryCount, &out.MaxRetries, &failCode,
	); err != nil {
		return nil, fmt.Errorf("scan task: %w", err)
	}

	out.Type = domain.WorkTaskType(kind)
	out.Status = domain.WorkTaskStatus(state)

	// Copy over nullable columns only when the database supplied a value.
	if worker.Valid {
		out.WorkerID = worker.String
	}
	if callback.Valid {
		out.CallbackURL = callback.String
	}
	if failMsg.Valid {
		out.Error = failMsg.String
	}
	if failCode.Valid {
		out.ErrorCode = domain.WorkErrorCode(failCode.String)
	}
	if started.Valid {
		out.StartedAt = &started.Time
	}
	if finished.Valid {
		out.CompletedAt = &finished.Time
	}

	// Decode the JSON task spec, if present.
	if len(rawSpec) != 0 {
		if err := json.Unmarshal(rawSpec, &out.Spec); err != nil {
			return nil, fmt.Errorf("unmarshal task spec: %w", err)
		}
	}

	// Decode the JSON result, if present; Result stays nil otherwise.
	if len(rawResult) == 0 {
		return &out, nil
	}
	out.Result = &domain.WorkResult{}
	if err := json.Unmarshal(rawResult, out.Result); err != nil {
		return nil, fmt.Errorf("unmarshal task result: %w", err)
	}
	return &out, nil
}