rdev/internal/adapter/postgres/rate_limiter.go
jordan bc47e426b0 feat: Add CI pipeline proxy, DNS alias management, and worker executor system
- Add ListPipelines/GetPipeline to CIProvider port with Woodpecker adapter
- Add DNS alias endpoints: GET/POST/DELETE /projects/{id}/domains
- Implement worker executor daemon, build executor, and git operations
- Add build service, worker service, and build audit tracking
- Add worker registry with PostgreSQL adapter and migration
- Add multi-provider code agent interface (Claude Code + OpenCode)
- Add create-and-build combo endpoint
- Update landing-page cookbook to reflect all gaps closed
- Fix tech debt: unified validation, auth scopes, error wrapping, slog patterns

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-27 21:05:28 -07:00

246 lines
6.5 KiB
Go

// Package postgres provides PostgreSQL-based implementations of port interfaces.
package postgres
import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"log/slog"
	"sync"
	"time"

	"github.com/orchard9/rdev/internal/domain"
	"github.com/orchard9/rdev/internal/port"
)
// RateLimiter is the PostgreSQL-backed implementation of port.RateLimiter.
// Per-key request counters are kept in the rate_limit_state table and the
// per-key limits are read from the api_keys table.
type RateLimiter struct {
	db     *sql.DB      // handle used for every query and transaction
	logger *slog.Logger // diagnostic output (e.g. background cleanup failures)
}
// NewRateLimiter returns a RateLimiter that runs its queries against db.
// Diagnostics go to slog.Default() unless WithLogger installs another logger.
func NewRateLimiter(db *sql.DB) *RateLimiter {
	rl := new(RateLimiter)
	rl.db = db
	rl.logger = slog.Default()
	return rl
}
// WithLogger swaps in logger for the limiter's diagnostics and returns the
// same *RateLimiter, so it can be chained after NewRateLimiter. Passing nil
// is a no-op: the previously configured logger is kept.
func (r *RateLimiter) WithLogger(logger *slog.Logger) *RateLimiter {
	if logger == nil {
		return r
	}
	r.logger = logger
	return r
}
// Compile-time proof that *RateLimiter satisfies port.RateLimiter: the build
// fails here if the port gains a method this adapter does not implement.
var (
	_ port.RateLimiter = (*RateLimiter)(nil)
)
// CheckLimit reports whether keyID may issue another request right now.
//
// It loads the key's limits via GetLimits and the counts already stored for
// the current minute and hour windows. It only reads state and does not
// record the request; RecordRequest does that.
//
// The result carries both limits, the remaining quota in each window
// (clamped at zero) and each window's reset time. The minute limit is
// checked first, so when both are exhausted RetryAfter reflects the minute
// reset. If a reset time has already passed, RetryAfter falls back to one
// second.
func (r *RateLimiter) CheckLimit(ctx context.Context, keyID string) (*domain.RateLimitResult, error) {
	current := time.Now()
	minuteStart := domain.TruncateToMinute(current)
	hourStart := domain.TruncateToHour(current)

	cfg, err := r.GetLimits(ctx, keyID)
	if err != nil {
		return nil, fmt.Errorf("get limits: %w", err)
	}

	usedMinute, err := r.getWindowCount(ctx, keyID, minuteStart, domain.WindowTypeMinute)
	if err != nil {
		return nil, fmt.Errorf("get minute count: %w", err)
	}
	usedHour, err := r.getWindowCount(ctx, keyID, hourStart, domain.WindowTypeHour)
	if err != nil {
		return nil, fmt.Errorf("get hour count: %w", err)
	}

	// nonNegative keeps the reported remaining quota from dipping below zero
	// when a window has already been overrun.
	nonNegative := func(v int) int {
		if v < 0 {
			return 0
		}
		return v
	}

	res := &domain.RateLimitResult{
		LimitMinute:     cfg.PerMinute,
		LimitHour:       cfg.PerHour,
		RemainingMinute: nonNegative(cfg.PerMinute - usedMinute),
		RemainingHour:   nonNegative(cfg.PerHour - usedHour),
		ResetMinute:     minuteStart.Add(time.Minute),
		ResetHour:       hourStart.Add(time.Hour),
	}

	// waitFor returns the time left until reset, or one second when the
	// reset instant is already in the past.
	waitFor := func(reset time.Time) time.Duration {
		if d := time.Until(reset); d >= 0 {
			return d
		}
		return time.Second
	}

	switch {
	case usedMinute >= cfg.PerMinute:
		res.Allowed = false
		res.RetryAfter = waitFor(res.ResetMinute)
	case usedHour >= cfg.PerHour:
		res.Allowed = false
		res.RetryAfter = waitFor(res.ResetHour)
	default:
		res.Allowed = true
	}
	return res, nil
}
// RecordRequest counts one request by keyID against both the current minute
// window and the current hour window. Both increments run in one
// transaction, so either both counters advance or neither does.
func (r *RateLimiter) RecordRequest(ctx context.Context, keyID string) error {
	at := time.Now()
	minuteStart, hourStart := domain.TruncateToMinute(at), domain.TruncateToHour(at)

	tx, err := r.db.BeginTx(ctx, nil)
	if err != nil {
		return fmt.Errorf("begin tx: %w", err)
	}
	// Rolling back after a successful Commit is a harmless no-op
	// (it returns sql.ErrTxDone, which is deliberately ignored).
	defer func() { _ = tx.Rollback() }()

	if err = r.upsertWindow(ctx, tx, keyID, minuteStart, domain.WindowTypeMinute); err != nil {
		return fmt.Errorf("upsert minute window: %w", err)
	}
	if err = r.upsertWindow(ctx, tx, keyID, hourStart, domain.WindowTypeHour); err != nil {
		return fmt.Errorf("upsert hour window: %w", err)
	}

	if err = tx.Commit(); err != nil {
		return fmt.Errorf("commit: %w", err)
	}
	return nil
}
// GetLimits returns the per-minute and per-hour request limits for keyID.
//
// It starts from domain.DefaultRateLimitConfig() and overrides each value
// whose column on the matching api_keys row is non-NULL. An unknown key is
// not an error; it receives the defaults unchanged. Any other query failure
// is returned wrapped.
func (r *RateLimiter) GetLimits(ctx context.Context, keyID string) (*domain.RateLimitConfig, error) {
	var minuteCol, hourCol sql.NullInt64
	err := r.db.QueryRowContext(ctx, `
SELECT rate_limit_per_minute, rate_limit_per_hour
FROM api_keys
WHERE id = $1
`, keyID).Scan(&minuteCol, &hourCol)

	switch {
	case errors.Is(err, sql.ErrNoRows):
		fallback := domain.DefaultRateLimitConfig()
		return &fallback, nil
	case err != nil:
		return nil, fmt.Errorf("query limits: %w", err)
	}

	cfg := domain.DefaultRateLimitConfig()
	if minuteCol.Valid {
		cfg.PerMinute = int(minuteCol.Int64)
	}
	if hourCol.Valid {
		cfg.PerHour = int(hourCol.Int64)
	}
	return &cfg, nil
}
// Cleanup deletes rate_limit_state rows whose window_start is more than two
// hours old. The cutoff is well past the end of any minute or hour window
// that CheckLimit could still read.
//
// Only a failure of the DELETE itself is returned. The affected-row count is
// informational: it is logged at debug level when non-zero. A failure to
// obtain the count is logged as a warning and does not fail the call, because
// the delete has already succeeded.
func (r *RateLimiter) Cleanup(ctx context.Context) error {
	cutoff := time.Now().Add(-2 * time.Hour)
	result, err := r.db.ExecContext(ctx, `
DELETE FROM rate_limit_state
WHERE window_start < $1
`, cutoff)
	if err != nil {
		return fmt.Errorf("delete old entries: %w", err)
	}

	rows, err := result.RowsAffected()
	if err != nil {
		r.logger.Warn("rate limit cleanup: rows affected unavailable", "error", err)
		return nil
	}
	if rows > 0 {
		r.logger.Debug("rate limit cleanup removed expired entries",
			"rows", rows, "cutoff", cutoff)
	}
	return nil
}
// getWindowCount reports how many requests keyID has recorded in the window
// of type windowType that begins at windowStart. A window that has no row
// yet counts as zero; any other query error is returned wrapped.
func (r *RateLimiter) getWindowCount(ctx context.Context, keyID string, windowStart time.Time, windowType string) (int, error) {
	row := r.db.QueryRowContext(ctx, `
SELECT COALESCE(request_count, 0)
FROM rate_limit_state
WHERE api_key_id = $1 AND window_start = $2 AND window_type = $3
`, keyID, windowStart, windowType)

	var n int
	switch err := row.Scan(&n); {
	case err == nil:
		return n, nil
	case errors.Is(err, sql.ErrNoRows):
		return 0, nil
	default:
		return 0, fmt.Errorf("query count: %w", err)
	}
}
// upsertWindow adds one to the request counter identified by
// (keyID, windowStart, windowType) inside tx. If no row exists yet, it
// creates one with a count of 1. updated_at is refreshed in both cases.
func (r *RateLimiter) upsertWindow(ctx context.Context, tx *sql.Tx, keyID string, windowStart time.Time, windowType string) error {
	const stmt = `
INSERT INTO rate_limit_state (api_key_id, window_start, window_type, request_count, updated_at)
VALUES ($1, $2, $3, 1, NOW())
ON CONFLICT (api_key_id, window_start, window_type)
DO UPDATE SET request_count = rate_limit_state.request_count + 1, updated_at = NOW()
`
	if _, err := tx.ExecContext(ctx, stmt, keyID, windowStart, windowType); err != nil {
		return fmt.Errorf("upsert: %w", err)
	}
	return nil
}
// StartCleanupWorker starts a background goroutine that calls Cleanup once
// per interval. It runs until ctx is cancelled or the returned stop function
// is called. Cleanup failures are logged and do not end the worker.
//
// The returned stop function is idempotent and safe to call from multiple
// goroutines: only the first call signals the worker, and later calls are
// no-ops rather than panicking on a double channel close. stop does not wait
// for an in-flight Cleanup to finish.
//
// A non-positive interval is replaced by defaultInterval, with a warning.
// time.NewTicker panics on such a value, and that panic would occur inside
// the goroutine, where the caller cannot recover it.
func (r *RateLimiter) StartCleanupWorker(ctx context.Context, interval time.Duration) func() {
	const defaultInterval = time.Minute
	if interval <= 0 {
		r.logger.Warn("rate limit cleanup interval must be positive; using default",
			"interval", interval, "default", defaultInterval)
		interval = defaultInterval
	}

	stopCh := make(chan struct{})
	var stopOnce sync.Once

	go func() {
		ticker := time.NewTicker(interval)
		defer ticker.Stop()
		for {
			select {
			case <-ctx.Done():
				return
			case <-stopCh:
				return
			case <-ticker.C:
				if err := r.Cleanup(ctx); err != nil {
					r.logger.Error("rate limit cleanup failed", "error", err)
				}
			}
		}
	}()

	return func() {
		stopOnce.Do(func() { close(stopCh) })
	}
}