121 lines
3.3 KiB
Go
121 lines
3.3 KiB
Go
package postgres
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"errors"
|
|
"fmt"
|
|
"time"
|
|
|
|
"github.com/jmoiron/sqlx"
|
|
|
|
"git.threesix.ai/jordan/persona-community-3/services/persona-api/internal/domain"
|
|
"git.threesix.ai/jordan/persona-community-3/services/persona-api/internal/port"
|
|
)
|
|
|
|
// Compile-time interface check.
|
|
var _ port.AuthCodeRepository = (*AuthCodeRepository)(nil)
|
|
|
|
// authCodeRow is the scan target for one row of the auth_codes table.
// Columns that may be NULL (user_id, used_at) are pointer fields so that a
// SQL NULL scans to nil.
type authCodeRow struct {
	ID        string     `db:"id"`
	UserID    *string    `db:"user_id"` // nil when user_id is NULL
	Email     string     `db:"email"`
	Code      string     `db:"code"`
	Purpose   string     `db:"purpose"`
	ExpiresAt time.Time  `db:"expires_at"`
	UsedAt    *time.Time `db:"used_at"` // nil when used_at is NULL, i.e. not yet consumed by MarkUsed
	IPAddress string     `db:"ip_address"`
	CreatedAt time.Time  `db:"created_at"`
}
|
|
|
|
func (r *authCodeRow) toDomain() *domain.AuthCode {
|
|
ac := &domain.AuthCode{
|
|
ID: r.ID,
|
|
Email: r.Email,
|
|
Code: r.Code,
|
|
Purpose: domain.AuthCodePurpose(r.Purpose),
|
|
ExpiresAt: r.ExpiresAt,
|
|
UsedAt: r.UsedAt,
|
|
IPAddress: r.IPAddress,
|
|
CreatedAt: r.CreatedAt,
|
|
}
|
|
if r.UserID != nil {
|
|
uid := domain.UserID(*r.UserID)
|
|
ac.UserID = &uid
|
|
}
|
|
return ac
|
|
}
|
|
|
|
// AuthCodeRepository implements port.AuthCodeRepository with PostgreSQL/CockroachDB.
|
|
type AuthCodeRepository struct {
|
|
db *sqlx.DB
|
|
}
|
|
|
|
// NewAuthCodeRepository creates a new Postgres-backed auth code repository.
|
|
func NewAuthCodeRepository(db *sqlx.DB) *AuthCodeRepository {
|
|
return &AuthCodeRepository{db: db}
|
|
}
|
|
|
|
func (r *AuthCodeRepository) Create(ctx context.Context, code *domain.AuthCode) error {
|
|
var userID *string
|
|
if code.UserID != nil {
|
|
s := string(*code.UserID)
|
|
userID = &s
|
|
}
|
|
|
|
_, err := r.db.ExecContext(ctx, `
|
|
INSERT INTO auth_codes (id, user_id, email, code, purpose, expires_at, ip_address, created_at)
|
|
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
|
|
`, code.ID, userID, code.Email, code.Code, string(code.Purpose),
|
|
code.ExpiresAt, code.IPAddress, code.CreatedAt)
|
|
if err != nil {
|
|
return fmt.Errorf("insert auth code: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (r *AuthCodeRepository) FindValid(ctx context.Context, email string, code string, purpose domain.AuthCodePurpose) (*domain.AuthCode, error) {
|
|
var row authCodeRow
|
|
err := r.db.GetContext(ctx, &row, `
|
|
SELECT id, user_id, email, code, purpose, expires_at, used_at, ip_address, created_at
|
|
FROM auth_codes
|
|
WHERE email = $1 AND code = $2 AND purpose = $3
|
|
AND used_at IS NULL AND expires_at > NOW()
|
|
ORDER BY created_at DESC
|
|
LIMIT 1
|
|
`, email, code, string(purpose))
|
|
if err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return nil, domain.ErrInvalidAuthCode
|
|
}
|
|
return nil, fmt.Errorf("find valid auth code: %w", err)
|
|
}
|
|
return row.toDomain(), nil
|
|
}
|
|
|
|
func (r *AuthCodeRepository) MarkUsed(ctx context.Context, id string) error {
|
|
_, err := r.db.ExecContext(ctx, `
|
|
UPDATE auth_codes SET used_at = NOW() WHERE id = $1
|
|
`, id)
|
|
if err != nil {
|
|
return fmt.Errorf("mark auth code used: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (r *AuthCodeRepository) DeleteExpired(ctx context.Context) (int, error) {
|
|
result, err := r.db.ExecContext(ctx, `
|
|
DELETE FROM auth_codes WHERE expires_at < NOW()
|
|
`)
|
|
if err != nil {
|
|
return 0, fmt.Errorf("delete expired auth codes: %w", err)
|
|
}
|
|
|
|
rows, err := result.RowsAffected()
|
|
if err != nil {
|
|
return 0, fmt.Errorf("delete expired rows affected: %w", err)
|
|
}
|
|
return int(rows), nil
|
|
}
|