261 lines
7.1 KiB
Go
261 lines
7.1 KiB
Go
package postgres
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"errors"
|
|
"fmt"
|
|
"time"
|
|
|
|
"github.com/jmoiron/sqlx"
|
|
"github.com/lib/pq"
|
|
|
|
"git.threesix.ai/jordan/persona-community-3/services/persona-api/internal/domain"
|
|
"git.threesix.ai/jordan/persona-community-3/services/persona-api/internal/port"
|
|
)
|
|
|
|
// Compile-time interface check: the build fails if *UserRepository stops
// satisfying port.UserRepository. The typed nil pointer is never used at
// runtime.
var _ port.UserRepository = (*UserRepository)(nil)
|
|
|
|
// userRow maps to the users table.
//
// Columns are bound to fields by their `db` tags when scanned with sqlx, so
// each tag must match a column named in the SELECT lists of Get/GetByEmail.
// NOTE(review): Name and AvatarURL are plain strings, so scanning would fail
// if those columns can be NULL — confirm against the schema.
type userRow struct {
	ID            string `db:"id"`
	Email         string `db:"email"`
	EmailVerified bool   `db:"email_verified"`
	Name          string `db:"name"`
	AvatarURL     string `db:"avatar_url"`
	Status        string `db:"status"`
	// Pointer so a NULL last_login_at scans as nil; Create never writes this
	// column and only UpdateLastLogin sets it.
	LastLoginAt *time.Time `db:"last_login_at"`
	CreatedAt   time.Time  `db:"created_at"`
	UpdatedAt   time.Time  `db:"updated_at"`
}
|
|
|
|
func (r *userRow) toDomain(roles []string) *domain.User {
|
|
return &domain.User{
|
|
ID: domain.UserID(r.ID),
|
|
Email: r.Email,
|
|
EmailVerified: r.EmailVerified,
|
|
Name: r.Name,
|
|
AvatarURL: r.AvatarURL,
|
|
Status: domain.UserStatus(r.Status),
|
|
Roles: roles,
|
|
LastLoginAt: r.LastLoginAt,
|
|
CreatedAt: r.CreatedAt,
|
|
UpdatedAt: r.UpdatedAt,
|
|
}
|
|
}
|
|
|
|
// UserRepository implements port.UserRepository with PostgreSQL/CockroachDB.
//
// Construct it with NewUserRepository; the zero value has a nil db handle
// and is not usable.
type UserRepository struct {
	// db is the shared database handle every repository method queries through.
	db *sqlx.DB
}
|
|
|
|
// NewUserRepository creates a new Postgres-backed user repository.
|
|
func NewUserRepository(db *sqlx.DB) *UserRepository {
|
|
return &UserRepository{db: db}
|
|
}
|
|
|
|
// Create inserts user and its roles atomically in a single transaction.
//
// The users row is written from user's ID, Email, EmailVerified, Name,
// AvatarURL, Status, CreatedAt and UpdatedAt (LastLoginAt is not written).
// Each entry of user.Roles is then inserted into user_roles, and duplicates
// are ignored. If any statement fails, the transaction is rolled back, so a
// failed Create never leaves a user row behind without its roles. That also
// means a retry does not trip over a half-created user.
//
// Any unique-constraint violation on the users insert (which may include the
// primary key, not only email) is reported as domain.ErrDuplicateEmail. Every
// other failure is returned wrapped with context.
func (r *UserRepository) Create(ctx context.Context, user *domain.User) error {
	tx, err := r.db.BeginTxx(ctx, nil)
	if err != nil {
		return fmt.Errorf("begin create user: %w", err)
	}
	// Rollback after a successful Commit is a no-op that returns
	// sql.ErrTxDone, so deferring it unconditionally is safe and covers
	// every early-return path below.
	defer func() { _ = tx.Rollback() }()

	if _, err := tx.ExecContext(ctx, `
		INSERT INTO users (id, email, email_verified, name, avatar_url, status, created_at, updated_at)
		VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
	`, string(user.ID), user.Email, user.EmailVerified, user.Name, user.AvatarURL,
		string(user.Status), user.CreatedAt, user.UpdatedAt); err != nil {
		if isUniqueViolation(err) {
			return domain.ErrDuplicateEmail
		}
		return fmt.Errorf("insert user: %w", err)
	}

	for _, role := range user.Roles {
		if _, err := tx.ExecContext(ctx, `
			INSERT INTO user_roles (user_id, role) VALUES ($1, $2)
			ON CONFLICT (user_id, role) DO NOTHING
		`, string(user.ID), role); err != nil {
			return fmt.Errorf("insert role: %w", err)
		}
	}

	if err := tx.Commit(); err != nil {
		return fmt.Errorf("commit create user: %w", err)
	}
	return nil
}
|
|
|
|
// Get loads the user with the given id together with its roles.
//
// It returns domain.ErrUserNotFound when no users row matches id. Other query
// failures are wrapped with "get user". Errors from GetRoles are returned
// unchanged. The user row and its roles are read by two separate queries.
func (r *UserRepository) Get(ctx context.Context, id domain.UserID) (*domain.User, error) {
	const query = `
		SELECT id, email, email_verified, name, avatar_url, status, last_login_at, created_at, updated_at
		FROM users WHERE id = $1
	`

	var row userRow
	switch err := r.db.GetContext(ctx, &row, query, string(id)); {
	case errors.Is(err, sql.ErrNoRows):
		return nil, domain.ErrUserNotFound
	case err != nil:
		return nil, fmt.Errorf("get user: %w", err)
	}

	roles, err := r.GetRoles(ctx, id)
	if err != nil {
		return nil, err
	}
	return row.toDomain(roles), nil
}
|
|
|
|
// GetByEmail loads the user whose email column equals email (exact match, as
// compared by the database), together with that user's roles.
//
// It returns domain.ErrUserNotFound when no row matches. Other query failures
// are wrapped with "get user by email". Errors from GetRoles are returned
// unchanged.
func (r *UserRepository) GetByEmail(ctx context.Context, email string) (*domain.User, error) {
	var row userRow
	if err := r.db.GetContext(ctx, &row, `
		SELECT id, email, email_verified, name, avatar_url, status, last_login_at, created_at, updated_at
		FROM users WHERE email = $1
	`, email); err != nil {
		if !errors.Is(err, sql.ErrNoRows) {
			return nil, fmt.Errorf("get user by email: %w", err)
		}
		return nil, domain.ErrUserNotFound
	}

	// Roles are keyed by user id, which is only known after the row is read.
	userID := domain.UserID(row.ID)
	roles, err := r.GetRoles(ctx, userID)
	if err != nil {
		return nil, err
	}
	return row.toDomain(roles), nil
}
|
|
|
|
// Update persists user's email, email_verified, name, avatar_url and status
// for the row with user.ID, and stamps updated_at with the current time.
// Roles, password and last_login_at are not touched.
//
// On success, user.UpdatedAt is set to the stamped value, so the caller's
// entity matches what was stored. Before this change it kept its old value.
//
// It returns domain.ErrDuplicateEmail on a unique-constraint violation,
// domain.ErrUserNotFound when no row has user.ID, and a wrapped error
// otherwise.
func (r *UserRepository) Update(ctx context.Context, user *domain.User) error {
	// Truncate to microseconds, the precision PostgreSQL timestamps store, so
	// the value written back into user equals the persisted one. Truncate
	// also drops the monotonic clock reading.
	now := time.Now().Truncate(time.Microsecond)

	result, err := r.db.ExecContext(ctx, `
		UPDATE users
		SET email = $2, email_verified = $3, name = $4, avatar_url = $5,
			status = $6, updated_at = $7
		WHERE id = $1
	`, string(user.ID), user.Email, user.EmailVerified, user.Name,
		user.AvatarURL, string(user.Status), now)
	if err != nil {
		if isUniqueViolation(err) {
			return domain.ErrDuplicateEmail
		}
		return fmt.Errorf("update user: %w", err)
	}

	rows, err := result.RowsAffected()
	if err != nil {
		return fmt.Errorf("update user rows affected: %w", err)
	}
	if rows == 0 {
		return domain.ErrUserNotFound
	}

	user.UpdatedAt = now
	return nil
}
|
|
|
|
// UpdateLastLogin sets last_login_at to the database's NOW() for the user
// with the given id. It returns domain.ErrUserNotFound when no row matches.
func (r *UserRepository) UpdateLastLogin(ctx context.Context, id domain.UserID) error {
	res, execErr := r.db.ExecContext(ctx, `
		UPDATE users SET last_login_at = NOW() WHERE id = $1
	`, string(id))
	if execErr != nil {
		return fmt.Errorf("update last login: %w", execErr)
	}

	affected, countErr := res.RowsAffected()
	switch {
	case countErr != nil:
		return fmt.Errorf("update last login rows affected: %w", countErr)
	case affected == 0:
		return domain.ErrUserNotFound
	default:
		return nil
	}
}
|
|
|
|
// ExistsByEmail reports whether any users row has exactly this email.
func (r *UserRepository) ExistsByEmail(ctx context.Context, email string) (bool, error) {
	const query = `SELECT EXISTS(SELECT 1 FROM users WHERE email = $1)`

	var found bool
	if err := r.db.GetContext(ctx, &found, query, email); err != nil {
		return false, fmt.Errorf("exists by email: %w", err)
	}
	return found, nil
}
|
|
|
|
func (r *UserRepository) SetPassword(ctx context.Context, userID domain.UserID, hash string) error {
|
|
_, err := r.db.ExecContext(ctx, `
|
|
INSERT INTO user_passwords (user_id, password_hash, updated_at)
|
|
VALUES ($1, $2, NOW())
|
|
ON CONFLICT (user_id) DO UPDATE SET password_hash = $2, updated_at = NOW()
|
|
`, string(userID), hash)
|
|
if err != nil {
|
|
return fmt.Errorf("set password: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// GetPasswordHash returns the stored password hash for userID.
//
// A user without a user_passwords row yields ("", nil), not an error. Use
// HasPassword to tell "no password" apart explicitly.
func (r *UserRepository) GetPasswordHash(ctx context.Context, userID domain.UserID) (string, error) {
	var stored string
	err := r.db.GetContext(ctx, &stored, `
		SELECT password_hash FROM user_passwords WHERE user_id = $1
	`, string(userID))

	switch {
	case err == nil:
		return stored, nil
	case errors.Is(err, sql.ErrNoRows):
		return "", nil
	default:
		return "", fmt.Errorf("get password hash: %w", err)
	}
}
|
|
|
|
// HasPassword reports whether a user_passwords row exists for userID.
func (r *UserRepository) HasPassword(ctx context.Context, userID domain.UserID) (bool, error) {
	const query = `
		SELECT EXISTS(SELECT 1 FROM user_passwords WHERE user_id = $1)
	`

	var present bool
	if err := r.db.GetContext(ctx, &present, query, string(userID)); err != nil {
		return false, fmt.Errorf("has password: %w", err)
	}
	return present, nil
}
|
|
|
|
func (r *UserRepository) AddRole(ctx context.Context, userID domain.UserID, role string) error {
|
|
_, err := r.db.ExecContext(ctx, `
|
|
INSERT INTO user_roles (user_id, role) VALUES ($1, $2)
|
|
ON CONFLICT (user_id, role) DO NOTHING
|
|
`, string(userID), role)
|
|
if err != nil {
|
|
return fmt.Errorf("add role: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (r *UserRepository) RemoveRole(ctx context.Context, userID domain.UserID, role string) error {
|
|
_, err := r.db.ExecContext(ctx, `
|
|
DELETE FROM user_roles WHERE user_id = $1 AND role = $2
|
|
`, string(userID), role)
|
|
if err != nil {
|
|
return fmt.Errorf("remove role: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// GetRoles returns the role names assigned to userID, ordered by role.
// A user with no roles yields an empty, non-nil slice rather than nil.
func (r *UserRepository) GetRoles(ctx context.Context, userID domain.UserID) ([]string, error) {
	const query = `
		SELECT role FROM user_roles WHERE user_id = $1 ORDER BY role
	`

	var names []string
	if err := r.db.SelectContext(ctx, &names, query, string(userID)); err != nil {
		return nil, fmt.Errorf("get roles: %w", err)
	}
	if len(names) == 0 {
		// Always non-nil; presumably so serialized output shows [] rather
		// than null — confirm with callers.
		return []string{}, nil
	}
	return names, nil
}
|
|
|
|
// isUniqueViolation reports whether err, or any error it wraps, is a
// *pq.Error carrying SQLSTATE 23505 (unique_violation). CockroachDB reports
// unique violations with the same SQLSTATE when accessed through lib/pq, so
// this works for both backends.
func isUniqueViolation(err error) bool {
	const uniqueViolation = "23505"

	var pqErr *pq.Error
	return errors.As(err, &pqErr) && pqErr.Code == uniqueViolation
}
|