// Package postgres provides PostgreSQL implementations of port interfaces. package postgres import ( "context" "database/sql" "errors" "fmt" "strings" "time" "github.com/orchard9/rdev/internal/domain" "github.com/orchard9/rdev/internal/port" ) // Ensure CredentialStore implements port.CredentialStore. var _ port.CredentialStore = (*CredentialStore)(nil) // CredentialStore implements credential storage with encryption. type CredentialStore struct { db *sql.DB encryptionKey string } // NewCredentialStore creates a new credential store. // The encryptionKey is used for pgcrypto symmetric encryption. func NewCredentialStore(db *sql.DB, encryptionKey string) *CredentialStore { return &CredentialStore{ db: db, encryptionKey: encryptionKey, } } // Get retrieves a credential by key. Returns empty string if not found. func (s *CredentialStore) Get(ctx context.Context, key string) (string, error) { var value string err := s.db.QueryRowContext(ctx, ` SELECT pgp_sym_decrypt(value, $1) FROM credentials WHERE key = $2 `, s.encryptionKey, key).Scan(&value) if errors.Is(err, sql.ErrNoRows) { return "", nil } if err != nil { return "", fmt.Errorf("get credential %s: %w", key, err) } return value, nil } // GetRequired retrieves a credential by key. Returns error if not found. func (s *CredentialStore) GetRequired(ctx context.Context, key string) (string, error) { value, err := s.Get(ctx, key) if err != nil { return "", err } if value == "" { return "", fmt.Errorf("credential %s not found", key) } return value, nil } // Set stores or updates a credential. func (s *CredentialStore) Set(ctx context.Context, cred domain.Credential) error { _, err := s.db.ExecContext(ctx, ` INSERT INTO credentials (key, value, description, category, updated_by) VALUES ($1, pgp_sym_encrypt($2, $3), $4, $5, $6) ON CONFLICT (key) DO UPDATE SET value = pgp_sym_encrypt($2, $3), description = COALESCE(NULLIF($4, ''), credentials.description), category = COALESCE(NULLIF($5, ''), credentials.category), updated_by = $6 `, cred.Key, cred.Value, s.encryptionKey, cred.Description, cred.Category, cred.UpdatedBy) if err != nil { return fmt.Errorf("set credential %s: %w", cred.Key, err) } return nil } // Delete removes a credential by key. func (s *CredentialStore) Delete(ctx context.Context, key string) error { result, err := s.db.ExecContext(ctx, `DELETE FROM credentials WHERE key = $1`, key) if err != nil { return fmt.Errorf("delete credential %s: %w", key, err) } rows, _ := result.RowsAffected() if rows == 0 { return fmt.Errorf("credential %s not found", key) } return nil } // List returns all credentials (with values masked). func (s *CredentialStore) List(ctx context.Context) ([]domain.Credential, error) { rows, err := s.db.QueryContext(ctx, ` SELECT key, description, category, created_at, updated_at, COALESCE(updated_by, '') FROM credentials ORDER BY category, key `) if err != nil { return nil, fmt.Errorf("list credentials: %w", err) } defer func() { _ = rows.Close() }() var creds []domain.Credential for rows.Next() { var c domain.Credential var desc, cat sql.NullString if err := rows.Scan(&c.Key, &desc, &cat, &c.CreatedAt, &c.UpdatedAt, &c.UpdatedBy); err != nil { return nil, fmt.Errorf("scan credential: %w", err) } c.Description = desc.String c.Category = cat.String c.Value = "********" // Masked creds = append(creds, c) } return creds, rows.Err() } // ListByCategory returns credentials in a category (with values masked). func (s *CredentialStore) ListByCategory(ctx context.Context, category string) ([]domain.Credential, error) { rows, err := s.db.QueryContext(ctx, ` SELECT key, description, category, created_at, updated_at, COALESCE(updated_by, '') FROM credentials WHERE category = $1 ORDER BY key `, category) if err != nil { return nil, fmt.Errorf("list credentials by category: %w", err) } defer func() { _ = rows.Close() }() var creds []domain.Credential for rows.Next() { var c domain.Credential var desc, cat sql.NullString if err := rows.Scan(&c.Key, &desc, &cat, &c.CreatedAt, &c.UpdatedAt, &c.UpdatedBy); err != nil { return nil, fmt.Errorf("scan credential: %w", err) } c.Description = desc.String c.Category = cat.String c.Value = "********" // Masked creds = append(creds, c) } return creds, rows.Err() } // GetMultiple retrieves multiple credentials by keys. func (s *CredentialStore) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) { if len(keys) == 0 { return make(map[string]string), nil } // Build placeholders for IN clause placeholders := make([]string, len(keys)) args := make([]any, len(keys)+1) args[0] = s.encryptionKey for i, key := range keys { placeholders[i] = fmt.Sprintf("$%d", i+2) args[i+1] = key } query := fmt.Sprintf(` SELECT key, pgp_sym_decrypt(value, $1) FROM credentials WHERE key IN (%s) `, strings.Join(placeholders, ",")) rows, err := s.db.QueryContext(ctx, query, args...) if err != nil { return nil, fmt.Errorf("get multiple credentials: %w", err) } defer func() { _ = rows.Close() }() result := make(map[string]string) for rows.Next() { var key, value string if err := rows.Scan(&key, &value); err != nil { return nil, fmt.Errorf("scan credential: %w", err) } result[key] = value } return result, rows.Err() } // SetMultiple stores multiple credentials in a single transaction. func (s *CredentialStore) SetMultiple(ctx context.Context, creds []domain.Credential) error { if len(creds) == 0 { return nil } tx, err := s.db.BeginTx(ctx, nil) if err != nil { return fmt.Errorf("begin transaction: %w", err) } defer func() { _ = tx.Rollback() }() stmt, err := tx.PrepareContext(ctx, ` INSERT INTO credentials (key, value, description, category, updated_by) VALUES ($1, pgp_sym_encrypt($2, $3), $4, $5, $6) ON CONFLICT (key) DO UPDATE SET value = pgp_sym_encrypt($2, $3), description = COALESCE(NULLIF($4, ''), credentials.description), category = COALESCE(NULLIF($5, ''), credentials.category), updated_by = $6 `) if err != nil { return fmt.Errorf("prepare statement: %w", err) } defer func() { _ = stmt.Close() }() now := time.Now() for _, cred := range creds { updatedBy := cred.UpdatedBy if updatedBy == "" { updatedBy = "system" } _, err := stmt.ExecContext(ctx, cred.Key, cred.Value, s.encryptionKey, cred.Description, cred.Category, updatedBy) if err != nil { return fmt.Errorf("set credential %s: %w", cred.Key, err) } _ = now // silence unused } if err := tx.Commit(); err != nil { return fmt.Errorf("commit transaction: %w", err) } return nil }