// Package postgres provides PostgreSQL-based implementations of port interfaces. package postgres import ( "context" "database/sql" "encoding/json" "errors" "fmt" "strings" "github.com/lib/pq" "github.com/orchard9/rdev/internal/domain" "github.com/orchard9/rdev/internal/port" ) // SagaRepository implements port.SagaRepository using PostgreSQL. type SagaRepository struct { db *sql.DB } // NewSagaRepository creates a new PostgreSQL saga repository. func NewSagaRepository(db *sql.DB) *SagaRepository { return &SagaRepository{db: db} } // Ensure SagaRepository implements port.SagaRepository at compile time. var _ port.SagaRepository = (*SagaRepository)(nil) // ErrSagaNotFound is returned when a saga is not found. var ErrSagaNotFound = errors.New("saga not found") // Create creates a new saga with its steps. func (r *SagaRepository) Create(ctx context.Context, saga *domain.Saga) error { tx, err := r.db.BeginTx(ctx, nil) if err != nil { return fmt.Errorf("begin transaction: %w", err) } defer func() { _ = tx.Rollback() }() varsJSON, err := json.Marshal(saga.Vars) if err != nil { return fmt.Errorf("marshal vars: %w", err) } outputsJSON, err := json.Marshal(saga.Outputs) if err != nil { return fmt.Errorf("marshal outputs: %w", err) } // Insert saga err = tx.QueryRowContext(ctx, ` INSERT INTO sagas ( name, status, definition, vars, outputs, current_step, retry_count, max_retries, error ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) RETURNING id, created_at, updated_at `, saga.Name, string(saga.Status), nullString(saga.Definition), varsJSON, outputsJSON, nullString(saga.CurrentStep), saga.RetryCount, saga.MaxRetries, nullString(saga.Error), ).Scan(&saga.ID, &saga.CreatedAt, &saga.UpdatedAt) if err != nil { return fmt.Errorf("insert saga: %w", err) } // Insert steps for i := range saga.Steps { step := &saga.Steps[i] step.SagaID = saga.ID retryPolicyJSON, err := json.Marshal(step.RetryPolicy) if err != nil { return fmt.Errorf("marshal retry policy: %w", err) } configJSON, err := 
json.Marshal(step.Config) if err != nil { return fmt.Errorf("marshal config: %w", err) } err = tx.QueryRowContext(ctx, ` INSERT INTO saga_steps ( saga_id, name, status, action, depends_on, retry_policy, compensate, config, retry_count ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) RETURNING id `, saga.ID, step.Name, string(step.Status), step.Action, pq.Array(step.DependsOn), retryPolicyJSON, nullString(step.Compensate), configJSON, step.RetryCount, ).Scan(&step.ID) if err != nil { return fmt.Errorf("insert step %s: %w", step.Name, err) } } if err := tx.Commit(); err != nil { return fmt.Errorf("commit transaction: %w", err) } return nil } // Get returns a saga by ID, including all steps. func (r *SagaRepository) Get(ctx context.Context, id string) (*domain.Saga, error) { saga, err := r.getSaga(ctx, id) if err != nil { return nil, err } steps, err := r.getSteps(ctx, id) if err != nil { return nil, err } saga.Steps = steps return saga, nil } // getSaga retrieves just the saga record (no steps). func (r *SagaRepository) getSaga(ctx context.Context, id string) (*domain.Saga, error) { row := r.db.QueryRowContext(ctx, ` SELECT id, name, status, definition, vars, outputs, current_step, retry_count, max_retries, error, created_at, updated_at, completed_at FROM sagas WHERE id = $1 `, id) return r.scanSaga(row) } // getSteps retrieves all steps for a saga. 
func (r *SagaRepository) getSteps(ctx context.Context, sagaID string) ([]domain.SagaStep, error) { rows, err := r.db.QueryContext(ctx, ` SELECT id, saga_id, name, status, action, depends_on, retry_policy, compensate, config, output, error, retry_count, started_at, completed_at FROM saga_steps WHERE saga_id = $1 ORDER BY id `, sagaID) if err != nil { return nil, fmt.Errorf("query steps: %w", err) } defer func() { _ = rows.Close() }() var steps []domain.SagaStep for rows.Next() { step, err := r.scanStep(rows) if err != nil { return nil, err } steps = append(steps, *step) } if err := rows.Err(); err != nil { return nil, fmt.Errorf("iterate steps: %w", err) } return steps, nil } // Update updates a saga's status and metadata (not steps). func (r *SagaRepository) Update(ctx context.Context, saga *domain.Saga) error { varsJSON, err := json.Marshal(saga.Vars) if err != nil { return fmt.Errorf("marshal vars: %w", err) } outputsJSON, err := json.Marshal(saga.Outputs) if err != nil { return fmt.Errorf("marshal outputs: %w", err) } res, err := r.db.ExecContext(ctx, ` UPDATE sagas SET status = $2, vars = $3, outputs = $4, current_step = $5, retry_count = $6, error = $7, completed_at = $8 WHERE id = $1 `, saga.ID, string(saga.Status), varsJSON, outputsJSON, nullString(saga.CurrentStep), saga.RetryCount, nullString(saga.Error), nullTime(saga.CompletedAt), ) if err != nil { return fmt.Errorf("update saga: %w", err) } rows, err := res.RowsAffected() if err != nil { return fmt.Errorf("rows affected: %w", err) } if rows == 0 { return ErrSagaNotFound } return nil } // UpdateStep updates a single step's status and output. 
func (r *SagaRepository) UpdateStep(ctx context.Context, step *domain.SagaStep) error { var outputJSON []byte var err error if step.Output != nil { outputJSON, err = json.Marshal(step.Output) if err != nil { return fmt.Errorf("marshal output: %w", err) } } res, err := r.db.ExecContext(ctx, ` UPDATE saga_steps SET status = $2, output = $3, error = $4, retry_count = $5, started_at = $6, completed_at = $7 WHERE id = $1 `, step.ID, string(step.Status), outputJSON, nullString(step.Error), step.RetryCount, nullTime(step.StartedAt), nullTime(step.CompletedAt), ) if err != nil { return fmt.Errorf("update step: %w", err) } rows, err := res.RowsAffected() if err != nil { return fmt.Errorf("rows affected: %w", err) } if rows == 0 { return fmt.Errorf("step %s not found", step.ID) } return nil } // List returns sagas matching the given filters. func (r *SagaRepository) List(ctx context.Context, filters domain.SagaFilters) ([]*domain.Saga, error) { filters.Normalize() query := strings.Builder{} query.WriteString(` SELECT id, name, status, definition, vars, outputs, current_step, retry_count, max_retries, error, created_at, updated_at, completed_at FROM sagas WHERE 1=1 `) args := []any{} argNum := 1 if filters.Name != "" { fmt.Fprintf(&query, " AND name = $%d", argNum) args = append(args, filters.Name) argNum++ } if filters.Status != "" { fmt.Fprintf(&query, " AND status = $%d", argNum) args = append(args, string(filters.Status)) argNum++ } if !filters.Since.IsZero() { fmt.Fprintf(&query, " AND created_at >= $%d", argNum) args = append(args, filters.Since) argNum++ } query.WriteString(" ORDER BY created_at DESC") fmt.Fprintf(&query, " LIMIT $%d", argNum) args = append(args, filters.Limit) rows, err := r.db.QueryContext(ctx, query.String(), args...) 
if err != nil { return nil, fmt.Errorf("query sagas: %w", err) } defer func() { _ = rows.Close() }() var sagas []*domain.Saga for rows.Next() { saga, err := r.scanSagaRows(rows) if err != nil { return nil, err } sagas = append(sagas, saga) } if err := rows.Err(); err != nil { return nil, fmt.Errorf("iterate sagas: %w", err) } return sagas, nil } // Delete removes a saga and its steps (cascade). func (r *SagaRepository) Delete(ctx context.Context, id string) error { res, err := r.db.ExecContext(ctx, `DELETE FROM sagas WHERE id = $1`, id) if err != nil { return fmt.Errorf("delete saga: %w", err) } rows, err := res.RowsAffected() if err != nil { return fmt.Errorf("rows affected: %w", err) } if rows == 0 { return ErrSagaNotFound } return nil } // GetPendingSteps returns steps ready to execute (no unmet dependencies). func (r *SagaRepository) GetPendingSteps(ctx context.Context, sagaID string) ([]domain.SagaStep, error) { // Get all steps steps, err := r.getSteps(ctx, sagaID) if err != nil { return nil, err } // Build completed steps map (includes skipped steps for dependency resolution) completed := make(map[string]bool) for _, step := range steps { if step.Status == domain.StepStatusCompleted || step.Status == domain.StepStatusSkipped { completed[step.Name] = true } } // Find runnable steps var runnable []domain.SagaStep for _, step := range steps { if step.CanRun(completed) { runnable = append(runnable, step) } } return runnable, nil } // sagaScanner interface abstracts sql.Row and sql.Rows for scanning. type sagaScanner interface { Scan(dest ...any) error } // scanSaga scans a saga from a QueryRow result. func (r *SagaRepository) scanSaga(row *sql.Row) (*domain.Saga, error) { saga, err := r.scanSagaFrom(row) if errors.Is(err, sql.ErrNoRows) { return nil, ErrSagaNotFound } return saga, err } // scanSagaRows scans a saga from a Rows result. 
func (r *SagaRepository) scanSagaRows(rows *sql.Rows) (*domain.Saga, error) {
	return r.scanSagaFrom(rows)
}

// scanSagaFrom scans a saga from any scanner (Row or Rows).
//
// The scanner must yield the columns in this order: id, name, status,
// definition, vars, outputs, current_step, retry_count, max_retries, error,
// created_at, updated_at, completed_at. Nullable columns (definition,
// current_step, error, completed_at) leave the corresponding field at its
// zero value when NULL. Scan errors are wrapped with %w, so callers can still
// match sql.ErrNoRows with errors.Is.
func (r *SagaRepository) scanSagaFrom(scanner sagaScanner) (*domain.Saga, error) {
	var saga domain.Saga
	var status string
	var definition, currentStep, sagaError sql.NullString
	var completedAt sql.NullTime
	var varsJSON, outputsJSON []byte

	err := scanner.Scan(
		&saga.ID,
		&saga.Name,
		&status,
		&definition,
		&varsJSON,
		&outputsJSON,
		&currentStep, // was garbled to "¤tStep" (a mis-decoded "&curren" entity)
		&saga.RetryCount,
		&saga.MaxRetries,
		&sagaError,
		&saga.CreatedAt,
		&saga.UpdatedAt,
		&completedAt,
	)
	if err != nil {
		return nil, fmt.Errorf("scan saga: %w", err)
	}

	saga.Status = domain.SagaStatus(status)
	if definition.Valid {
		saga.Definition = definition.String
	}
	if currentStep.Valid {
		saga.CurrentStep = currentStep.String
	}
	if sagaError.Valid {
		saga.Error = sagaError.String
	}
	if completedAt.Valid {
		saga.CompletedAt = &completedAt.Time
	}

	// Empty/NULL JSON columns are skipped so the fields keep their zero value.
	if len(varsJSON) > 0 {
		if err := json.Unmarshal(varsJSON, &saga.Vars); err != nil {
			return nil, fmt.Errorf("unmarshal vars: %w", err)
		}
	}
	if len(outputsJSON) > 0 {
		if err := json.Unmarshal(outputsJSON, &saga.Outputs); err != nil {
			return nil, fmt.Errorf("unmarshal outputs: %w", err)
		}
	}

	return &saga, nil
}

// scanStep scans a step from a Rows result.
func (r *SagaRepository) scanStep(rows *sql.Rows) (*domain.SagaStep, error) { var step domain.SagaStep var status string var compensate, stepError sql.NullString var startedAt, completedAt sql.NullTime var retryPolicyJSON, configJSON, outputJSON []byte var dependsOn pq.StringArray err := rows.Scan( &step.ID, &step.SagaID, &step.Name, &status, &step.Action, &dependsOn, &retryPolicyJSON, &compensate, &configJSON, &outputJSON, &stepError, &step.RetryCount, &startedAt, &completedAt, ) if err != nil { return nil, fmt.Errorf("scan step: %w", err) } step.Status = domain.StepStatus(status) step.DependsOn = []string(dependsOn) if compensate.Valid { step.Compensate = compensate.String } if stepError.Valid { step.Error = stepError.String } if startedAt.Valid { step.StartedAt = &startedAt.Time } if completedAt.Valid { step.CompletedAt = &completedAt.Time } if len(retryPolicyJSON) > 0 { if err := json.Unmarshal(retryPolicyJSON, &step.RetryPolicy); err != nil { return nil, fmt.Errorf("unmarshal retry policy: %w", err) } } if len(configJSON) > 0 { if err := json.Unmarshal(configJSON, &step.Config); err != nil { return nil, fmt.Errorf("unmarshal config: %w", err) } } if len(outputJSON) > 0 { if err := json.Unmarshal(outputJSON, &step.Output); err != nil { return nil, fmt.Errorf("unmarshal output: %w", err) } } return &step, nil }