package postgres import ( "context" "database/sql" "fmt" "github.com/orchard9/rdev/internal/domain" "github.com/orchard9/rdev/internal/port" ) // ConversationRepository implements port.ConversationRepository using PostgreSQL. type ConversationRepository struct { db *sql.DB } // NewConversationRepository creates a new PostgreSQL conversation repository. func NewConversationRepository(db *sql.DB) *ConversationRepository { return &ConversationRepository{db: db} } // Ensure ConversationRepository implements port.ConversationRepository at compile time. var _ port.ConversationRepository = (*ConversationRepository)(nil) // CreateConversation creates a new conversation. func (r *ConversationRepository) CreateConversation(ctx context.Context, projectID, title string) (*domain.Conversation, error) { var conv domain.Conversation var lastMessage sql.NullTime err := r.db.QueryRowContext(ctx, ` INSERT INTO conversations (project_id, title) VALUES ($1, $2) RETURNING id, project_id, title, created_at, updated_at, last_message_at `, projectID, title).Scan( &conv.ID, &conv.ProjectID, &conv.Title, &conv.CreatedAt, &conv.UpdatedAt, &lastMessage, ) if err != nil { return nil, fmt.Errorf("create conversation: %w", err) } if lastMessage.Valid { conv.LastMessage = &lastMessage.Time } return &conv, nil } // GetConversation retrieves a conversation by ID. func (r *ConversationRepository) GetConversation(ctx context.Context, id domain.ConversationID) (*domain.Conversation, error) { var conv domain.Conversation var lastMessage sql.NullTime err := r.db.QueryRowContext(ctx, ` SELECT id, project_id, title, created_at, updated_at, last_message_at FROM conversations WHERE id = $1 `, id).Scan( &conv.ID, &conv.ProjectID, &conv.Title, &conv.CreatedAt, &conv.UpdatedAt, &lastMessage, ) if err == sql.ErrNoRows { return nil, domain.ErrConversationNotFound } if err != nil { return nil, fmt.Errorf("get conversation: %w", err) } if lastMessage.Valid { conv.LastMessage = &lastMessage.Time } return &conv, nil } // ListConversations returns all conversations for a project. func (r *ConversationRepository) ListConversations(ctx context.Context, projectID string) ([]*domain.Conversation, error) { rows, err := r.db.QueryContext(ctx, ` SELECT id, project_id, title, created_at, updated_at, last_message_at FROM conversations WHERE project_id = $1 ORDER BY last_message_at DESC NULLS LAST, created_at DESC `, projectID) if err != nil { return nil, fmt.Errorf("list conversations: %w", err) } defer rows.Close() var convs []*domain.Conversation for rows.Next() { var conv domain.Conversation var lastMessage sql.NullTime if err := rows.Scan( &conv.ID, &conv.ProjectID, &conv.Title, &conv.CreatedAt, &conv.UpdatedAt, &lastMessage, ); err != nil { return nil, fmt.Errorf("scan conversation: %w", err) } if lastMessage.Valid { conv.LastMessage = &lastMessage.Time } convs = append(convs, &conv) } return convs, rows.Err() } // UpdateConversationTitle updates the conversation title. func (r *ConversationRepository) UpdateConversationTitle(ctx context.Context, id domain.ConversationID, title string) error { result, err := r.db.ExecContext(ctx, ` UPDATE conversations SET title = $1 WHERE id = $2 `, title, id) if err != nil { return fmt.Errorf("update conversation title: %w", err) } rows, err := result.RowsAffected() if err != nil { return fmt.Errorf("rows affected: %w", err) } if rows == 0 { return domain.ErrConversationNotFound } return nil } // DeleteConversation deletes a conversation and all its messages. func (r *ConversationRepository) DeleteConversation(ctx context.Context, id domain.ConversationID) error { result, err := r.db.ExecContext(ctx, ` DELETE FROM conversations WHERE id = $1 `, id) if err != nil { return fmt.Errorf("delete conversation: %w", err) } rows, err := result.RowsAffected() if err != nil { return fmt.Errorf("rows affected: %w", err) } if rows == 0 { return domain.ErrConversationNotFound } return nil } // AddMessage adds a message to a conversation. func (r *ConversationRepository) AddMessage(ctx context.Context, conversationID domain.ConversationID, role domain.MessageRole, content string) (*domain.Message, error) { var msg domain.Message err := r.db.QueryRowContext(ctx, ` INSERT INTO messages (conversation_id, role, content) VALUES ($1, $2, $3) RETURNING id, conversation_id, role, content, created_at `, conversationID, role, content).Scan( &msg.ID, &msg.ConversationID, &msg.Role, &msg.Content, &msg.CreatedAt, ) if err != nil { return nil, fmt.Errorf("add message: %w", err) } return &msg, nil } // GetMessages retrieves all messages for a conversation. func (r *ConversationRepository) GetMessages(ctx context.Context, conversationID domain.ConversationID) ([]*domain.Message, error) { rows, err := r.db.QueryContext(ctx, ` SELECT id, conversation_id, role, content, created_at FROM messages WHERE conversation_id = $1 ORDER BY created_at ASC `, conversationID) if err != nil { return nil, fmt.Errorf("get messages: %w", err) } defer rows.Close() var messages []*domain.Message for rows.Next() { var msg domain.Message if err := rows.Scan( &msg.ID, &msg.ConversationID, &msg.Role, &msg.Content, &msg.CreatedAt, ); err != nil { return nil, fmt.Errorf("scan message: %w", err) } messages = append(messages, &msg) } return messages, rows.Err() } // GetConversationWithMessages retrieves a conversation with all messages. func (r *ConversationRepository) GetConversationWithMessages(ctx context.Context, id domain.ConversationID) (*domain.ConversationWithMessages, error) { conv, err := r.GetConversation(ctx, id) if err != nil { return nil, err } messages, err := r.GetMessages(ctx, id) if err != nil { return nil, err } return &domain.ConversationWithMessages{ Conversation: *conv, Messages: messages, }, nil }