// Package postgres provides PostgreSQL-based implementations of port interfaces. package postgres import ( "context" "database/sql" "encoding/json" "fmt" "time" "github.com/orchard9/rdev/internal/domain" "github.com/orchard9/rdev/internal/port" ) // WorkerRegistryRepository implements port.WorkerRegistry using PostgreSQL. type WorkerRegistryRepository struct { db *sql.DB } // NewWorkerRegistryRepository creates a new PostgreSQL worker registry. func NewWorkerRegistryRepository(db *sql.DB) *WorkerRegistryRepository { return &WorkerRegistryRepository{db: db} } // Ensure WorkerRegistryRepository implements port.WorkerRegistry at compile time. var _ port.WorkerRegistry = (*WorkerRegistryRepository)(nil) // Register adds a worker to the pool. // If a worker with the same ID already exists, it is re-registered as idle. func (r *WorkerRegistryRepository) Register(ctx context.Context, worker *domain.Worker) error { capsJSON, err := json.Marshal(worker.Capabilities) if err != nil { return fmt.Errorf("marshal capabilities: %w", err) } _, err = r.db.ExecContext(ctx, ` INSERT INTO workers (id, hostname, status, capabilities, version, registered_at, last_heartbeat) VALUES ($1, $2, $3, $4, $5, $6, $6) ON CONFLICT (id) DO UPDATE SET hostname = EXCLUDED.hostname, status = 'idle', current_task = NULL, capabilities = EXCLUDED.capabilities, version = EXCLUDED.version, last_heartbeat = EXCLUDED.last_heartbeat `, worker.ID, worker.Hostname, domain.WorkerStatusIdle, capsJSON, nullString(worker.Version), time.Now()) if err != nil { return fmt.Errorf("register worker: %w", err) } return nil } // Heartbeat updates the worker's last_heartbeat timestamp. func (r *WorkerRegistryRepository) Heartbeat(ctx context.Context, workerID string) error { result, err := r.db.ExecContext(ctx, ` UPDATE workers SET last_heartbeat = NOW() WHERE id = $1 AND status != 'offline' `, workerID) if err != nil { return fmt.Errorf("heartbeat worker: %w", err) } rows, err := result.RowsAffected() if err != nil { return fmt.Errorf("rows affected: %w", err) } if rows == 0 { return domain.ErrWorkerNotFound } return nil } // UpdateStatus changes a worker's status and optionally assigns a task. func (r *WorkerRegistryRepository) UpdateStatus(ctx context.Context, workerID string, status domain.WorkerStatus, taskID string) error { var currentTask sql.NullString if taskID != "" { currentTask = sql.NullString{String: taskID, Valid: true} } result, err := r.db.ExecContext(ctx, ` UPDATE workers SET status = $2, current_task = $3, last_heartbeat = NOW() WHERE id = $1 `, workerID, status, currentTask) if err != nil { return fmt.Errorf("update worker status: %w", err) } rows, err := result.RowsAffected() if err != nil { return fmt.Errorf("rows affected: %w", err) } if rows == 0 { return domain.ErrWorkerNotFound } return nil } // Deregister removes a worker from the pool. func (r *WorkerRegistryRepository) Deregister(ctx context.Context, workerID string) error { result, err := r.db.ExecContext(ctx, `DELETE FROM workers WHERE id = $1`, workerID) if err != nil { return fmt.Errorf("deregister worker: %w", err) } rows, err := result.RowsAffected() if err != nil { return fmt.Errorf("rows affected: %w", err) } if rows == 0 { return domain.ErrWorkerNotFound } return nil } // Get retrieves a specific worker by ID. func (r *WorkerRegistryRepository) Get(ctx context.Context, workerID string) (*domain.Worker, error) { rows, err := r.db.QueryContext(ctx, ` SELECT id, hostname, status, current_task, capabilities, version, registered_at, last_heartbeat FROM workers WHERE id = $1 `, workerID) if err != nil { return nil, fmt.Errorf("get worker: %w", err) } defer func() { _ = rows.Close() }() if !rows.Next() { if err := rows.Err(); err != nil { return nil, fmt.Errorf("get worker: %w", err) } return nil, domain.ErrWorkerNotFound } return r.scanWorker(rows) } // List returns all workers matching the filter. func (r *WorkerRegistryRepository) List(ctx context.Context, filter port.WorkerFilter) ([]*domain.Worker, error) { query := ` SELECT id, hostname, status, current_task, capabilities, version, registered_at, last_heartbeat FROM workers WHERE 1=1` args := []any{} argNum := 1 if filter.Status != nil { query += fmt.Sprintf(" AND status = $%d", argNum) args = append(args, string(*filter.Status)) argNum++ } if filter.HasCapability != "" { query += fmt.Sprintf(" AND capabilities @> $%d::jsonb", argNum) capJSON, _ := json.Marshal([]string{filter.HasCapability}) args = append(args, string(capJSON)) argNum++ } query += " ORDER BY registered_at ASC" if filter.Limit > 0 { query += fmt.Sprintf(" LIMIT $%d", argNum) args = append(args, filter.Limit) } rows, err := r.db.QueryContext(ctx, query, args...) if err != nil { return nil, fmt.Errorf("list workers: %w", err) } defer func() { _ = rows.Close() }() var workers []*domain.Worker for rows.Next() { w, err := r.scanWorker(rows) if err != nil { return nil, err } workers = append(workers, w) } return workers, rows.Err() } // MarkStaleOffline marks workers without a recent heartbeat as offline. func (r *WorkerRegistryRepository) MarkStaleOffline(ctx context.Context, threshold time.Duration) (int, error) { cutoff := time.Now().Add(-threshold) result, err := r.db.ExecContext(ctx, ` UPDATE workers SET status = 'offline', current_task = NULL WHERE status != 'offline' AND last_heartbeat < $1 `, cutoff) if err != nil { return 0, fmt.Errorf("mark stale workers offline: %w", err) } rows, err := result.RowsAffected() if err != nil { return 0, fmt.Errorf("rows affected: %w", err) } return int(rows), nil } // scanWorker scans a single worker row from a query result. func (r *WorkerRegistryRepository) scanWorker(rows *sql.Rows) (*domain.Worker, error) { var w domain.Worker var currentTask sql.NullString var capsJSON []byte var version sql.NullString err := rows.Scan( &w.ID, &w.Hostname, &w.Status, ¤tTask, &capsJSON, &version, &w.RegisteredAt, &w.LastHeartbeat, ) if err != nil { return nil, fmt.Errorf("scan worker: %w", err) } if currentTask.Valid { w.CurrentTask = currentTask.String } if version.Valid { w.Version = version.String } if len(capsJSON) > 0 { if err := json.Unmarshal(capsJSON, &w.Capabilities); err != nil { return nil, fmt.Errorf("unmarshal capabilities: %w", err) } } return &w, nil }