// Package postgres provides PostgreSQL-based implementations of port interfaces. package postgres import ( "context" "database/sql" "errors" "fmt" "log/slog" "time" "github.com/orchard9/rdev/internal/domain" "github.com/orchard9/rdev/internal/port" ) // RateLimiter implements port.RateLimiter using PostgreSQL. type RateLimiter struct { db *sql.DB logger *slog.Logger } // NewRateLimiter creates a new PostgreSQL rate limiter. func NewRateLimiter(db *sql.DB) *RateLimiter { return &RateLimiter{db: db, logger: slog.Default()} } // WithLogger sets a custom logger for the rate limiter. func (r *RateLimiter) WithLogger(logger *slog.Logger) *RateLimiter { if logger != nil { r.logger = logger } return r } // Ensure RateLimiter implements port.RateLimiter at compile time. var _ port.RateLimiter = (*RateLimiter)(nil) // CheckLimit checks if a request is allowed under the rate limit. func (r *RateLimiter) CheckLimit(ctx context.Context, keyID string) (*domain.RateLimitResult, error) { now := time.Now() minuteWindow := domain.TruncateToMinute(now) hourWindow := domain.TruncateToHour(now) // Get rate limits for this key limits, err := r.GetLimits(ctx, keyID) if err != nil { return nil, fmt.Errorf("get limits: %w", err) } // Get current usage for minute window minuteCount, err := r.getWindowCount(ctx, keyID, minuteWindow, domain.WindowTypeMinute) if err != nil { return nil, fmt.Errorf("get minute count: %w", err) } // Get current usage for hour window hourCount, err := r.getWindowCount(ctx, keyID, hourWindow, domain.WindowTypeHour) if err != nil { return nil, fmt.Errorf("get hour count: %w", err) } result := &domain.RateLimitResult{ LimitMinute: limits.PerMinute, LimitHour: limits.PerHour, RemainingMinute: limits.PerMinute - minuteCount, RemainingHour: limits.PerHour - hourCount, ResetMinute: minuteWindow.Add(time.Minute), ResetHour: hourWindow.Add(time.Hour), } // Ensure remaining doesn't go negative if result.RemainingMinute < 0 { result.RemainingMinute = 0 } if result.RemainingHour < 0 { result.RemainingHour = 0 } // Check if either limit is exceeded if minuteCount >= limits.PerMinute { result.Allowed = false result.RetryAfter = time.Until(result.ResetMinute) if result.RetryAfter < 0 { result.RetryAfter = time.Second } return result, nil } if hourCount >= limits.PerHour { result.Allowed = false result.RetryAfter = time.Until(result.ResetHour) if result.RetryAfter < 0 { result.RetryAfter = time.Second } return result, nil } result.Allowed = true return result, nil } // RecordRequest records that a request was made for the given API key. func (r *RateLimiter) RecordRequest(ctx context.Context, keyID string) error { now := time.Now() minuteWindow := domain.TruncateToMinute(now) hourWindow := domain.TruncateToHour(now) // Use a transaction to update both windows atomically tx, err := r.db.BeginTx(ctx, nil) if err != nil { return fmt.Errorf("begin tx: %w", err) } defer func() { _ = tx.Rollback() }() // Upsert minute window if err := r.upsertWindow(ctx, tx, keyID, minuteWindow, domain.WindowTypeMinute); err != nil { return fmt.Errorf("upsert minute window: %w", err) } // Upsert hour window if err := r.upsertWindow(ctx, tx, keyID, hourWindow, domain.WindowTypeHour); err != nil { return fmt.Errorf("upsert hour window: %w", err) } if err := tx.Commit(); err != nil { return fmt.Errorf("commit: %w", err) } return nil } // GetLimits retrieves the rate limit configuration for an API key. func (r *RateLimiter) GetLimits(ctx context.Context, keyID string) (*domain.RateLimitConfig, error) { var perMinute, perHour sql.NullInt64 err := r.db.QueryRowContext(ctx, ` SELECT rate_limit_per_minute, rate_limit_per_hour FROM api_keys WHERE id = $1 `, keyID).Scan(&perMinute, &perHour) if errors.Is(err, sql.ErrNoRows) { // Key not found, return defaults defaults := domain.DefaultRateLimitConfig() return &defaults, nil } if err != nil { return nil, fmt.Errorf("query limits: %w", err) } config := domain.DefaultRateLimitConfig() if perMinute.Valid { config.PerMinute = int(perMinute.Int64) } if perHour.Valid { config.PerHour = int(perHour.Int64) } return &config, nil } // Cleanup removes expired rate limit state entries. func (r *RateLimiter) Cleanup(ctx context.Context) error { // Remove entries older than 2 hours (well past any active window) cutoff := time.Now().Add(-2 * time.Hour) result, err := r.db.ExecContext(ctx, ` DELETE FROM rate_limit_state WHERE window_start < $1 `, cutoff) if err != nil { return fmt.Errorf("delete old entries: %w", err) } rows, _ := result.RowsAffected() if rows > 0 { // Log cleanup (optional, could use structured logging) _ = rows } return nil } // getWindowCount returns the request count for a specific window. func (r *RateLimiter) getWindowCount(ctx context.Context, keyID string, windowStart time.Time, windowType string) (int, error) { var count int err := r.db.QueryRowContext(ctx, ` SELECT COALESCE(request_count, 0) FROM rate_limit_state WHERE api_key_id = $1 AND window_start = $2 AND window_type = $3 `, keyID, windowStart, windowType).Scan(&count) if errors.Is(err, sql.ErrNoRows) { return 0, nil } if err != nil { return 0, fmt.Errorf("query count: %w", err) } return count, nil } // upsertWindow inserts or updates a rate limit window. func (r *RateLimiter) upsertWindow(ctx context.Context, tx *sql.Tx, keyID string, windowStart time.Time, windowType string) error { _, err := tx.ExecContext(ctx, ` INSERT INTO rate_limit_state (api_key_id, window_start, window_type, request_count, updated_at) VALUES ($1, $2, $3, 1, NOW()) ON CONFLICT (api_key_id, window_start, window_type) DO UPDATE SET request_count = rate_limit_state.request_count + 1, updated_at = NOW() `, keyID, windowStart, windowType) if err != nil { return fmt.Errorf("upsert: %w", err) } return nil } // StartCleanupWorker starts a background goroutine that periodically cleans up expired entries. // Returns a stop function to terminate the worker. func (r *RateLimiter) StartCleanupWorker(ctx context.Context, interval time.Duration) func() { stopCh := make(chan struct{}) go func() { ticker := time.NewTicker(interval) defer ticker.Stop() for { select { case <-ctx.Done(): return case <-stopCh: return case <-ticker.C: if err := r.Cleanup(ctx); err != nil { r.logger.Error("rate limit cleanup failed", "error", err) } } } }() return func() { close(stopCh) } }