163 lines
4.8 KiB
Go
163 lines
4.8 KiB
Go
package postgres
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"errors"
|
|
"fmt"
|
|
"time"
|
|
|
|
"github.com/jmoiron/sqlx"
|
|
|
|
"git.threesix.ai/jordan/persona-community-1/services/persona-api/internal/domain"
|
|
"git.threesix.ai/jordan/persona-community-1/services/persona-api/internal/port"
|
|
)
|
|
|
|
// Compile-time interface check.
|
|
var _ port.SessionRepository = (*SessionRepository)(nil)
|
|
|
|
// sessionRow is the sqlx scan target for one row of the sessions table; each
// field's db tag names the column it is populated from.
type sessionRow struct {
	// Identifiers, kept as plain strings here and converted to the domain ID
	// types by toDomain.
	ID     string `db:"id"`
	UserID string `db:"user_id"`

	// Client details written by Create.
	IPAddress   string `db:"ip_address"`
	UserAgent   string `db:"user_agent"`
	DeviceLabel string `db:"device_label"`

	// Timestamps. RevokedAt is a pointer so it can hold SQL NULL, which the
	// revoke queries treat as "not yet revoked" (they match revoked_at IS NULL).
	LastActiveAt time.Time  `db:"last_active_at"`
	ExpiresAt    time.Time  `db:"expires_at"`
	RevokedAt    *time.Time `db:"revoked_at"`
	CreatedAt    time.Time  `db:"created_at"`
}
|
|
|
|
func (r *sessionRow) toDomain() *domain.Session {
|
|
return &domain.Session{
|
|
ID: domain.SessionID(r.ID),
|
|
UserID: domain.UserID(r.UserID),
|
|
IPAddress: r.IPAddress,
|
|
UserAgent: r.UserAgent,
|
|
DeviceLabel: r.DeviceLabel,
|
|
LastActiveAt: r.LastActiveAt,
|
|
ExpiresAt: r.ExpiresAt,
|
|
RevokedAt: r.RevokedAt,
|
|
CreatedAt: r.CreatedAt,
|
|
}
|
|
}
|
|
|
|
// SessionRepository implements port.SessionRepository with PostgreSQL/CockroachDB.
|
|
type SessionRepository struct {
|
|
db *sqlx.DB
|
|
}
|
|
|
|
// NewSessionRepository creates a new Postgres-backed session repository.
|
|
func NewSessionRepository(db *sqlx.DB) *SessionRepository {
|
|
return &SessionRepository{db: db}
|
|
}
|
|
|
|
func (r *SessionRepository) Create(ctx context.Context, session *domain.Session) error {
|
|
_, err := r.db.ExecContext(ctx, `
|
|
INSERT INTO sessions (id, user_id, ip_address, user_agent, device_label, last_active_at, expires_at, created_at)
|
|
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
|
|
`, string(session.ID), string(session.UserID), session.IPAddress, session.UserAgent,
|
|
session.DeviceLabel, session.LastActiveAt, session.ExpiresAt, session.CreatedAt)
|
|
if err != nil {
|
|
return fmt.Errorf("insert session: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (r *SessionRepository) Get(ctx context.Context, id domain.SessionID) (*domain.Session, error) {
|
|
var row sessionRow
|
|
err := r.db.GetContext(ctx, &row, `
|
|
SELECT id, user_id, ip_address, user_agent, device_label, last_active_at, expires_at, revoked_at, created_at
|
|
FROM sessions WHERE id = $1
|
|
`, string(id))
|
|
if err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return nil, domain.ErrSessionNotFound
|
|
}
|
|
return nil, fmt.Errorf("get session: %w", err)
|
|
}
|
|
return row.toDomain(), nil
|
|
}
|
|
|
|
func (r *SessionRepository) ListByUser(ctx context.Context, userID domain.UserID) ([]domain.Session, error) {
|
|
var rows []sessionRow
|
|
err := r.db.SelectContext(ctx, &rows, `
|
|
SELECT id, user_id, ip_address, user_agent, device_label, last_active_at, expires_at, revoked_at, created_at
|
|
FROM sessions
|
|
WHERE user_id = $1 AND revoked_at IS NULL AND expires_at > NOW()
|
|
ORDER BY last_active_at DESC
|
|
`, string(userID))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("list sessions: %w", err)
|
|
}
|
|
|
|
sessions := make([]domain.Session, len(rows))
|
|
for i := range rows {
|
|
sessions[i] = *rows[i].toDomain()
|
|
}
|
|
return sessions, nil
|
|
}
|
|
|
|
func (r *SessionRepository) UpdateLastActive(ctx context.Context, id domain.SessionID) error {
|
|
_, err := r.db.ExecContext(ctx, `
|
|
UPDATE sessions SET last_active_at = NOW() WHERE id = $1
|
|
`, string(id))
|
|
if err != nil {
|
|
return fmt.Errorf("update last active: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (r *SessionRepository) Revoke(ctx context.Context, id domain.SessionID) error {
|
|
result, err := r.db.ExecContext(ctx, `
|
|
UPDATE sessions SET revoked_at = NOW() WHERE id = $1 AND revoked_at IS NULL
|
|
`, string(id))
|
|
if err != nil {
|
|
return fmt.Errorf("revoke session: %w", err)
|
|
}
|
|
|
|
rows, err := result.RowsAffected()
|
|
if err != nil {
|
|
return fmt.Errorf("revoke session rows affected: %w", err)
|
|
}
|
|
if rows == 0 {
|
|
return domain.ErrSessionNotFound
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (r *SessionRepository) RevokeAllForUser(ctx context.Context, userID domain.UserID, exceptID *domain.SessionID) error {
|
|
if exceptID != nil {
|
|
_, err := r.db.ExecContext(ctx, `
|
|
UPDATE sessions SET revoked_at = NOW()
|
|
WHERE user_id = $1 AND revoked_at IS NULL AND id != $2
|
|
`, string(userID), string(*exceptID))
|
|
if err != nil {
|
|
return fmt.Errorf("revoke all sessions except: %w", err)
|
|
}
|
|
} else {
|
|
_, err := r.db.ExecContext(ctx, `
|
|
UPDATE sessions SET revoked_at = NOW()
|
|
WHERE user_id = $1 AND revoked_at IS NULL
|
|
`, string(userID))
|
|
if err != nil {
|
|
return fmt.Errorf("revoke all sessions: %w", err)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (r *SessionRepository) DeleteExpired(ctx context.Context) (int, error) {
|
|
result, err := r.db.ExecContext(ctx, `DELETE FROM sessions WHERE expires_at < NOW()`)
|
|
if err != nil {
|
|
return 0, fmt.Errorf("delete expired sessions: %w", err)
|
|
}
|
|
rows, err := result.RowsAffected()
|
|
if err != nil {
|
|
return 0, fmt.Errorf("delete expired sessions rows: %w", err)
|
|
}
|
|
return int(rows), nil
|
|
}
|