package postgres import ( "context" "database/sql" "encoding/json" "fmt" "time" "github.com/lib/pq" "github.com/orchard9/rdev/internal/domain" "github.com/orchard9/rdev/internal/port" ) // QuestionRepository implements port.QuestionRepository using PostgreSQL. type QuestionRepository struct { db *sql.DB } // NewQuestionRepository creates a new PostgreSQL question repository. func NewQuestionRepository(db *sql.DB) *QuestionRepository { return &QuestionRepository{db: db} } // Ensure QuestionRepository implements port.QuestionRepository at compile time. var _ port.QuestionRepository = (*QuestionRepository)(nil) // CreateQuestion creates a new question. func (r *QuestionRepository) CreateQuestion(ctx context.Context, question *domain.Question) error { metadataJSON, err := json.Marshal(question.Metadata) if err != nil { return fmt.Errorf("marshal metadata: %w", err) } err = r.db.QueryRowContext(ctx, ` INSERT INTO questions (conversation_id, project_id, question_type, text, choices, metadata) VALUES ($1, $2, $3, $4, $5, $6) RETURNING id, created_at `, question.ConversationID, question.ProjectID, question.Type, question.Text, pq.Array(question.Choices), metadataJSON).Scan( &question.ID, &question.CreatedAt, ) if err != nil { return fmt.Errorf("create question: %w", err) } return nil } // GetQuestion retrieves a question by ID. func (r *QuestionRepository) GetQuestion(ctx context.Context, id domain.QuestionID) (*domain.Question, error) { var question domain.Question var answer sql.NullString var answeredAt sql.NullTime var metadataJSON []byte err := r.db.QueryRowContext(ctx, ` SELECT id, conversation_id, project_id, question_type, text, choices, answer, answer_choices, metadata, created_at, answered_at FROM questions WHERE id = $1 `, id).Scan( &question.ID, &question.ConversationID, &question.ProjectID, &question.Type, &question.Text, pq.Array(&question.Choices), &answer, pq.Array(&question.AnswerChoices), &metadataJSON, &question.CreatedAt, &answeredAt, ) if err == sql.ErrNoRows { return nil, domain.ErrQuestionNotFound } if err != nil { return nil, fmt.Errorf("get question: %w", err) } if answer.Valid { question.Answer = &answer.String } if answeredAt.Valid { question.AnsweredAt = &answeredAt.Time } if len(metadataJSON) > 0 { if err := json.Unmarshal(metadataJSON, &question.Metadata); err != nil { return nil, fmt.Errorf("unmarshal metadata: %w", err) } } return &question, nil } // ListUnansweredQuestions returns all unanswered questions for a project. func (r *QuestionRepository) ListUnansweredQuestions(ctx context.Context, projectID string) ([]*domain.Question, error) { rows, err := r.db.QueryContext(ctx, ` SELECT id, conversation_id, project_id, question_type, text, choices, answer, answer_choices, metadata, created_at, answered_at FROM questions WHERE project_id = $1 AND answered_at IS NULL ORDER BY created_at DESC `, projectID) if err != nil { return nil, fmt.Errorf("list unanswered questions: %w", err) } defer rows.Close() var questions []*domain.Question for rows.Next() { var question domain.Question var answer sql.NullString var answeredAt sql.NullTime var metadataJSON []byte if err := rows.Scan( &question.ID, &question.ConversationID, &question.ProjectID, &question.Type, &question.Text, pq.Array(&question.Choices), &answer, pq.Array(&question.AnswerChoices), &metadataJSON, &question.CreatedAt, &answeredAt, ); err != nil { return nil, fmt.Errorf("scan question: %w", err) } if answer.Valid { question.Answer = &answer.String } if answeredAt.Valid { question.AnsweredAt = &answeredAt.Time } if len(metadataJSON) > 0 { if err := json.Unmarshal(metadataJSON, &question.Metadata); err != nil { return nil, fmt.Errorf("unmarshal metadata: %w", err) } } questions = append(questions, &question) } return questions, rows.Err() } // ListQuestionsByConversation returns all questions for a conversation. func (r *QuestionRepository) ListQuestionsByConversation(ctx context.Context, conversationID domain.ConversationID) ([]*domain.Question, error) { rows, err := r.db.QueryContext(ctx, ` SELECT id, conversation_id, project_id, question_type, text, choices, answer, answer_choices, metadata, created_at, answered_at FROM questions WHERE conversation_id = $1 ORDER BY created_at DESC `, conversationID) if err != nil { return nil, fmt.Errorf("list questions by conversation: %w", err) } defer rows.Close() var questions []*domain.Question for rows.Next() { var question domain.Question var answer sql.NullString var answeredAt sql.NullTime var metadataJSON []byte if err := rows.Scan( &question.ID, &question.ConversationID, &question.ProjectID, &question.Type, &question.Text, pq.Array(&question.Choices), &answer, pq.Array(&question.AnswerChoices), &metadataJSON, &question.CreatedAt, &answeredAt, ); err != nil { return nil, fmt.Errorf("scan question: %w", err) } if answer.Valid { question.Answer = &answer.String } if answeredAt.Valid { question.AnsweredAt = &answeredAt.Time } if len(metadataJSON) > 0 { if err := json.Unmarshal(metadataJSON, &question.Metadata); err != nil { return nil, fmt.Errorf("unmarshal metadata: %w", err) } } questions = append(questions, &question) } return questions, rows.Err() } // AnswerQuestion records an answer to a question. func (r *QuestionRepository) AnswerQuestion(ctx context.Context, id domain.QuestionID, answer *string, answerChoices []string) error { now := time.Now() result, err := r.db.ExecContext(ctx, ` UPDATE questions SET answer = $1, answer_choices = $2, answered_at = $3 WHERE id = $4 `, answer, pq.Array(answerChoices), now, id) if err != nil { return fmt.Errorf("answer question: %w", err) } rows, err := result.RowsAffected() if err != nil { return fmt.Errorf("rows affected: %w", err) } if rows == 0 { return domain.ErrQuestionNotFound } return nil } // DeleteQuestion deletes a question. func (r *QuestionRepository) DeleteQuestion(ctx context.Context, id domain.QuestionID) error { result, err := r.db.ExecContext(ctx, ` DELETE FROM questions WHERE id = $1 `, id) if err != nil { return fmt.Errorf("delete question: %w", err) } rows, err := result.RowsAffected() if err != nil { return fmt.Errorf("rows affected: %w", err) } if rows == 0 { return domain.ErrQuestionNotFound } return nil }