// Package postgres provides PostgreSQL-based implementations of port interfaces. package postgres import ( "context" "database/sql" "errors" "fmt" "github.com/orchard9/rdev/internal/domain" "github.com/orchard9/rdev/internal/port" ) // SessionRepository implements port.SessionRepository using PostgreSQL. type SessionRepository struct { db *sql.DB } // NewSessionRepository creates a new PostgreSQL session repository. func NewSessionRepository(db *sql.DB) *SessionRepository { return &SessionRepository{db: db} } // Ensure SessionRepository implements port.SessionRepository at compile time. var _ port.SessionRepository = (*SessionRepository)(nil) // Create stores a new session record. func (r *SessionRepository) Create(ctx context.Context, session *domain.Session) error { var id string err := r.db.QueryRowContext(ctx, ` INSERT INTO sessions ( project_id, checkout_id, pod_name, preview_url, preview_host, created_by, created_at, expires_at, status, last_activity_at ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) RETURNING id `, string(session.ProjectID), string(session.CheckoutID), session.PodName, session.PreviewURL, session.PreviewHost, session.CreatedBy, session.CreatedAt, session.ExpiresAt, string(session.Status), session.LastActivityAt, ).Scan(&id) if err != nil { if isUniqueViolation(err) { return domain.ErrSessionExists } return fmt.Errorf("insert session: %w", err) } session.ID = domain.SessionID(id) return nil } // SetClaudeSessionID stores the Claude Code session ID and conversation record ID on a session. func (r *SessionRepository) SetClaudeSessionID(ctx context.Context, id domain.SessionID, claudeSessionID, conversationRecordID string) error { result, err := r.db.ExecContext(ctx, `UPDATE sessions SET claude_session_id = $1, conversation_record_id = $2 WHERE id = $3`, claudeSessionID, nullableUUID(conversationRecordID), string(id)) if err != nil { return fmt.Errorf("set claude session id: %w", err) } rows, err := result.RowsAffected() if err != nil { return fmt.Errorf("rows affected: %w", err) } if rows == 0 { return domain.ErrSessionNotFound } return nil } // nullableUUID returns nil for empty string (for nullable UUID columns). func nullableUUID(s string) any { if s == "" { return nil } return s } // Get retrieves a session by ID. func (r *SessionRepository) Get(ctx context.Context, id domain.SessionID) (*domain.Session, error) { session, err := r.scanSession(r.db.QueryRowContext(ctx, ` SELECT id, project_id, checkout_id, pod_name, preview_url, preview_host, created_by, created_at, expires_at, status, last_activity_at, ended_at, COALESCE(claude_session_id, ''), COALESCE(conversation_record_id::text, '') FROM sessions WHERE id = $1 `, string(id))) if errors.Is(err, sql.ErrNoRows) { return nil, domain.ErrSessionNotFound } if err != nil { return nil, fmt.Errorf("query session: %w", err) } return session, nil } // GetActiveByProject retrieves the active session for a project. func (r *SessionRepository) GetActiveByProject(ctx context.Context, projectID domain.ProjectID) (*domain.Session, error) { session, err := r.scanSession(r.db.QueryRowContext(ctx, ` SELECT id, project_id, checkout_id, pod_name, preview_url, preview_host, created_by, created_at, expires_at, status, last_activity_at, ended_at, COALESCE(claude_session_id, ''), COALESCE(conversation_record_id::text, '') FROM sessions WHERE project_id = $1 AND status = 'active' `, string(projectID))) if errors.Is(err, sql.ErrNoRows) { return nil, domain.ErrSessionNotFound } if err != nil { return nil, fmt.Errorf("query active session: %w", err) } return session, nil } // ListByProject returns all sessions for a project. func (r *SessionRepository) ListByProject(ctx context.Context, projectID domain.ProjectID) ([]*domain.Session, error) { rows, err := r.db.QueryContext(ctx, ` SELECT id, project_id, checkout_id, pod_name, preview_url, preview_host, created_by, created_at, expires_at, status, last_activity_at, ended_at, COALESCE(claude_session_id, ''), COALESCE(conversation_record_id::text, '') FROM sessions WHERE project_id = $1 ORDER BY created_at DESC `, string(projectID)) if err != nil { return nil, fmt.Errorf("query sessions by project: %w", err) } defer func() { _ = rows.Close() }() return r.scanSessions(rows) } // SetEnded marks a session as ended with a timestamp. func (r *SessionRepository) SetEnded(ctx context.Context, id domain.SessionID) error { result, err := r.db.ExecContext(ctx, ` UPDATE sessions SET status = 'ended', ended_at = NOW() WHERE id = $1 AND status = 'active' `, string(id)) if err != nil { return fmt.Errorf("set session ended: %w", err) } rows, err := result.RowsAffected() if err != nil { return fmt.Errorf("rows affected: %w", err) } if rows == 0 { return domain.ErrSessionNotActive } return nil } // TouchActivity updates the last_activity_at timestamp for an active session. func (r *SessionRepository) TouchActivity(ctx context.Context, id domain.SessionID) error { result, err := r.db.ExecContext(ctx, ` UPDATE sessions SET last_activity_at = NOW() WHERE id = $1 AND status = 'active' `, string(id)) if err != nil { return fmt.Errorf("touch session activity: %w", err) } rows, err := result.RowsAffected() if err != nil { return fmt.Errorf("rows affected: %w", err) } if rows == 0 { return domain.ErrSessionNotActive } return nil } // CleanupExpired marks expired sessions and returns them for preview teardown. func (r *SessionRepository) CleanupExpired(ctx context.Context) ([]*domain.Session, error) { rows, err := r.db.QueryContext(ctx, ` UPDATE sessions SET status = 'expired', ended_at = NOW() WHERE status = 'active' AND expires_at < NOW() AND last_activity_at < NOW() - INTERVAL '30 minutes' RETURNING id, project_id, checkout_id, pod_name, preview_url, preview_host, created_by, created_at, expires_at, status, last_activity_at, ended_at, COALESCE(claude_session_id, ''), COALESCE(conversation_record_id::text, '') `) if err != nil { return nil, fmt.Errorf("cleanup expired sessions: %w", err) } defer func() { _ = rows.Close() }() return r.scanSessions(rows) } // sessionScanner is an interface for scanning session rows. type sessionScanner interface { Scan(dest ...any) error } // scanSessionFields scans session fields from a scanner into a Session struct. func (r *SessionRepository) scanSessionFields(scanner sessionScanner) (*domain.Session, error) { var ( session domain.Session id string projectID string checkoutID string status string endedAt sql.NullTime claudeSessionID string conversationRecordID string ) err := scanner.Scan( &id, &projectID, &checkoutID, &session.PodName, &session.PreviewURL, &session.PreviewHost, &session.CreatedBy, &session.CreatedAt, &session.ExpiresAt, &status, &session.LastActivityAt, &endedAt, &claudeSessionID, &conversationRecordID, ) if err != nil { return nil, err } session.ID = domain.SessionID(id) session.ProjectID = domain.ProjectID(projectID) session.CheckoutID = domain.CheckoutID(checkoutID) session.Status = domain.SessionStatus(status) session.ClaudeSessionID = claudeSessionID session.ConversationRecordID = conversationRecordID if endedAt.Valid { session.EndedAt = &endedAt.Time } return &session, nil } // scanSession scans a single row into a Session struct. func (r *SessionRepository) scanSession(row *sql.Row) (*domain.Session, error) { return r.scanSessionFields(row) } // scanSessions scans multiple rows into Session structs. func (r *SessionRepository) scanSessions(rows *sql.Rows) ([]*domain.Session, error) { var sessions []*domain.Session for rows.Next() { session, err := r.scanSessionFields(rows) if err != nil { return nil, fmt.Errorf("scan session: %w", err) } sessions = append(sessions, session) } return sessions, rows.Err() }