// Package middleware provides HTTP middleware components for the rdev API. package middleware import ( "log/slog" "net/http" "strconv" "github.com/orchard9/rdev/internal/auth" "github.com/orchard9/rdev/internal/domain" "github.com/orchard9/rdev/internal/port" "github.com/orchard9/rdev/pkg/api" ) // RateLimitConfig holds configuration for the rate limit middleware. type RateLimitConfig struct { // SkipPaths are paths that should not be rate limited. SkipPaths map[string]bool // Limiter is the rate limiter implementation to use. Limiter port.RateLimiter // Logger for rate limit events (optional). Logger *slog.Logger } // DefaultRateLimitConfig returns a sensible default configuration. func DefaultRateLimitConfig() RateLimitConfig { return RateLimitConfig{ SkipPaths: map[string]bool{ "/health": true, "/ready": true, "/docs": true, "/openapi.json": true, "/metrics": true, }, } } // RateLimitMiddleware returns an HTTP middleware that enforces rate limits. // It requires the auth middleware to run first to set the API key context. func RateLimitMiddleware(cfg RateLimitConfig) func(http.Handler) http.Handler { logger := cfg.Logger if logger == nil { logger = slog.Default() } return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Skip rate limiting for configured paths if cfg.SkipPaths[r.URL.Path] { next.ServeHTTP(w, r) return } // Get API key from context (set by auth middleware) apiKey := auth.GetAPIKey(r.Context()) if apiKey == nil { // No API key means auth middleware hasn't run or request is unauthenticated // Let the auth middleware handle this next.ServeHTTP(w, r) return } // Skip rate limiting for admin keys if apiKey.ID == "admin" { next.ServeHTTP(w, r) return } // Check rate limit and record atomically to prevent race conditions // RecordRequest is called first to ensure the count is incremented before // we check, preventing burst bypass under high concurrency if err := cfg.Limiter.RecordRequest(r.Context(), apiKey.ID); err != nil { logger.Error("failed to record rate limit request", "error", err, "key_id", apiKey.ID) // On error, allow the request (fail open) next.ServeHTTP(w, r) return } // Now check the limit (which includes the just-recorded request) result, err := cfg.Limiter.CheckLimit(r.Context(), apiKey.ID) if err != nil { logger.Error("failed to check rate limit", "error", err, "key_id", apiKey.ID) // On error, allow the request (fail open) next.ServeHTTP(w, r) return } // Set rate limit headers on all responses setRateLimitHeaders(w, result) if !result.Allowed { // Rate limit exceeded retryAfterSeconds := int(result.RetryAfter.Seconds()) if retryAfterSeconds < 1 { retryAfterSeconds = 1 } w.Header().Set("Retry-After", strconv.Itoa(retryAfterSeconds)) api.WriteError(w, r, http.StatusTooManyRequests, "RATE_LIMITED", "Rate limit exceeded. Please retry after "+strconv.Itoa(retryAfterSeconds)+" seconds.") return } next.ServeHTTP(w, r) }) } } // setRateLimitHeaders sets the standard rate limit headers on the response. func setRateLimitHeaders(w http.ResponseWriter, result *domain.RateLimitResult) { // Use the minute limit as the primary limit in headers (more commonly hit) w.Header().Set("X-RateLimit-Limit", strconv.Itoa(result.LimitMinute)) w.Header().Set("X-RateLimit-Remaining", strconv.Itoa(result.RemainingMinute)) w.Header().Set("X-RateLimit-Reset", strconv.FormatInt(result.ResetMinute.Unix(), 10)) // Also include hourly limits in extended headers w.Header().Set("X-RateLimit-Limit-Hour", strconv.Itoa(result.LimitHour)) w.Header().Set("X-RateLimit-Remaining-Hour", strconv.Itoa(result.RemainingHour)) w.Header().Set("X-RateLimit-Reset-Hour", strconv.FormatInt(result.ResetHour.Unix(), 10)) }