// Package postgres provides PostgreSQL-based implementations of port interfaces. package postgres import ( "context" "database/sql" "errors" "fmt" "time" "github.com/lib/pq" "github.com/orchard9/rdev/internal/domain" "github.com/orchard9/rdev/internal/port" ) // APIKeyRepository implements port.APIKeyRepository using PostgreSQL. type APIKeyRepository struct { db *sql.DB } // NewAPIKeyRepository creates a new PostgreSQL API key repository. func NewAPIKeyRepository(db *sql.DB) *APIKeyRepository { return &APIKeyRepository{db: db} } // Ensure APIKeyRepository implements port.APIKeyRepository at compile time. var _ port.APIKeyRepository = (*APIKeyRepository)(nil) // Create stores a new API key. func (r *APIKeyRepository) Create(ctx context.Context, key *domain.APIKey, keyHash string) error { scopeStrings := scopesToStrings(key.Scopes) projectIDStrings := projectIDsToStrings(key.ProjectIDs) var id string err := r.db.QueryRowContext(ctx, ` INSERT INTO api_keys (name, key_hash, key_prefix, scopes, project_ids, allowed_ips, expires_at, created_by) VALUES ($1, $2, $3, $4, $5, $6, $7, $8) RETURNING id `, key.Name, keyHash, key.KeyPrefix, pq.Array(scopeStrings), pq.Array(projectIDStrings), pq.Array(key.AllowedIPs), key.ExpiresAt, key.CreatedBy).Scan(&id) if err != nil { return fmt.Errorf("insert key: %w", err) } key.ID = domain.APIKeyID(id) key.CreatedAt = time.Now() return nil } // GetByHash retrieves an API key by its hash. func (r *APIKeyRepository) GetByHash(ctx context.Context, keyHash string) (*domain.APIKey, error) { var ( key domain.APIKey id string scopeStrings []string projectIDs []string ) err := r.db.QueryRowContext(ctx, ` SELECT id, name, key_prefix, scopes, project_ids, allowed_ips, created_at, expires_at, last_used_at, revoked_at, created_by FROM api_keys WHERE key_hash = $1 `, keyHash).Scan( &id, &key.Name, &key.KeyPrefix, pq.Array(&scopeStrings), pq.Array(&projectIDs), pq.Array(&key.AllowedIPs), &key.CreatedAt, &key.ExpiresAt, &key.LastUsedAt, &key.RevokedAt, &key.CreatedBy, ) if errors.Is(err, sql.ErrNoRows) { return nil, domain.ErrKeyNotFound } if err != nil { return nil, fmt.Errorf("query key: %w", err) } key.ID = domain.APIKeyID(id) key.Scopes = scopesFromStrings(scopeStrings) key.ProjectIDs = projectIDsFromStrings(projectIDs) return &key, nil } // Get retrieves an API key by ID. func (r *APIKeyRepository) Get(ctx context.Context, id domain.APIKeyID) (*domain.APIKey, error) { var ( key domain.APIKey keyID string scopeStrings []string projectIDs []string ) err := r.db.QueryRowContext(ctx, ` SELECT id, name, key_prefix, scopes, project_ids, allowed_ips, created_at, expires_at, last_used_at, revoked_at, created_by FROM api_keys WHERE id = $1 `, string(id)).Scan( &keyID, &key.Name, &key.KeyPrefix, pq.Array(&scopeStrings), pq.Array(&projectIDs), pq.Array(&key.AllowedIPs), &key.CreatedAt, &key.ExpiresAt, &key.LastUsedAt, &key.RevokedAt, &key.CreatedBy, ) if errors.Is(err, sql.ErrNoRows) { return nil, domain.ErrKeyNotFound } if err != nil { return nil, fmt.Errorf("query key: %w", err) } key.ID = domain.APIKeyID(keyID) key.Scopes = scopesFromStrings(scopeStrings) key.ProjectIDs = projectIDsFromStrings(projectIDs) return &key, nil } // List returns all API keys (without secrets). func (r *APIKeyRepository) List(ctx context.Context) ([]*domain.APIKey, error) { rows, err := r.db.QueryContext(ctx, ` SELECT id, name, key_prefix, scopes, project_ids, allowed_ips, created_at, expires_at, last_used_at, revoked_at, created_by FROM api_keys ORDER BY created_at DESC `) if err != nil { return nil, fmt.Errorf("query keys: %w", err) } defer func() { _ = rows.Close() }() var keys []*domain.APIKey for rows.Next() { var ( key domain.APIKey id string scopeStrings []string projectIDs []string ) if err := rows.Scan( &id, &key.Name, &key.KeyPrefix, pq.Array(&scopeStrings), pq.Array(&projectIDs), pq.Array(&key.AllowedIPs), &key.CreatedAt, &key.ExpiresAt, &key.LastUsedAt, &key.RevokedAt, &key.CreatedBy, ); err != nil { return nil, fmt.Errorf("scan key: %w", err) } key.ID = domain.APIKeyID(id) key.Scopes = scopesFromStrings(scopeStrings) key.ProjectIDs = projectIDsFromStrings(projectIDs) keys = append(keys, &key) } return keys, nil } // Revoke marks an API key as revoked. func (r *APIKeyRepository) Revoke(ctx context.Context, id domain.APIKeyID) error { result, err := r.db.ExecContext(ctx, ` UPDATE api_keys SET revoked_at = NOW() WHERE id = $1 AND revoked_at IS NULL `, string(id)) if err != nil { return fmt.Errorf("revoke key: %w", err) } rows, err := result.RowsAffected() if err != nil { return fmt.Errorf("rows affected: %w", err) } if rows == 0 { return domain.ErrKeyNotFound } return nil } // UpdateLastUsed updates the last used timestamp for a key. func (r *APIKeyRepository) UpdateLastUsed(ctx context.Context, id domain.APIKeyID) error { _, err := r.db.ExecContext(ctx, ` UPDATE api_keys SET last_used_at = NOW() WHERE id = $1 `, string(id)) return err } // Update applies a partial update to an API key. func (r *APIKeyRepository) Update(ctx context.Context, id domain.APIKeyID, update port.APIKeyUpdate) error { // Build SET clauses dynamically based on non-nil fields setClauses := []string{} args := []any{} argIdx := 1 if update.Name != nil { setClauses = append(setClauses, fmt.Sprintf("name = $%d", argIdx)) args = append(args, *update.Name) argIdx++ } if update.Scopes != nil { setClauses = append(setClauses, fmt.Sprintf("scopes = $%d", argIdx)) args = append(args, pq.Array(scopesToStrings(update.Scopes))) argIdx++ } if update.ProjectIDs != nil { setClauses = append(setClauses, fmt.Sprintf("project_ids = $%d", argIdx)) args = append(args, pq.Array(projectIDsToStrings(*update.ProjectIDs))) argIdx++ } if update.AllowedIPs != nil { setClauses = append(setClauses, fmt.Sprintf("allowed_ips = $%d", argIdx)) args = append(args, pq.Array(*update.AllowedIPs)) argIdx++ } if update.ExpiresAt != nil { setClauses = append(setClauses, fmt.Sprintf("expires_at = $%d", argIdx)) args = append(args, *update.ExpiresAt) argIdx++ } if len(setClauses) == 0 { return nil // nothing to update } args = append(args, string(id)) query := fmt.Sprintf("UPDATE api_keys SET %s WHERE id = $%d AND revoked_at IS NULL", joinStrings(setClauses, ", "), argIdx) result, err := r.db.ExecContext(ctx, query, args...) if err != nil { return fmt.Errorf("update key: %w", err) } rows, err := result.RowsAffected() if err != nil { return fmt.Errorf("rows affected: %w", err) } if rows == 0 { return domain.ErrKeyNotFound } return nil } // ListByProjectID returns all active keys that have the given project ID in their project_ids. func (r *APIKeyRepository) ListByProjectID(ctx context.Context, projectID domain.ProjectID) ([]*domain.APIKey, error) { rows, err := r.db.QueryContext(ctx, ` SELECT id, name, key_prefix, scopes, project_ids, allowed_ips, created_at, expires_at, last_used_at, revoked_at, created_by FROM api_keys WHERE $1 = ANY(project_ids) AND revoked_at IS NULL ORDER BY created_at DESC `, string(projectID)) if err != nil { return nil, fmt.Errorf("query keys by project: %w", err) } defer func() { _ = rows.Close() }() var keys []*domain.APIKey for rows.Next() { var ( key domain.APIKey id string scopeStrings []string projectIDs []string ) if err := rows.Scan( &id, &key.Name, &key.KeyPrefix, pq.Array(&scopeStrings), pq.Array(&projectIDs), pq.Array(&key.AllowedIPs), &key.CreatedAt, &key.ExpiresAt, &key.LastUsedAt, &key.RevokedAt, &key.CreatedBy, ); err != nil { return nil, fmt.Errorf("scan key: %w", err) } key.ID = domain.APIKeyID(id) key.Scopes = scopesFromStrings(scopeStrings) key.ProjectIDs = projectIDsFromStrings(projectIDs) keys = append(keys, &key) } return keys, nil } // joinStrings joins string slices with a separator (avoids importing strings in this file). func joinStrings(ss []string, sep string) string { result := "" for i, s := range ss { if i > 0 { result += sep } result += s } return result } // Helper functions for scope conversion func scopesToStrings(scopes []domain.Scope) []string { ss := make([]string, len(scopes)) for i, s := range scopes { ss[i] = string(s) } return ss } func scopesFromStrings(ss []string) []domain.Scope { scopes := make([]domain.Scope, len(ss)) for i, s := range ss { scopes[i] = domain.Scope(s) } return scopes } func projectIDsToStrings(ids []domain.ProjectID) []string { if ids == nil { return nil } ss := make([]string, len(ids)) for i, id := range ids { ss[i] = string(id) } return ss } func projectIDsFromStrings(ss []string) []domain.ProjectID { if ss == nil { return nil } ids := make([]domain.ProjectID, len(ss)) for i, s := range ss { ids[i] = domain.ProjectID(s) } return ids }