// Package ratelimit provides rate limiting middleware for HTTP handlers. package ratelimit import ( "net/http" "sync" "time" "github.com/orchard9/rdev/internal/auth" "github.com/orchard9/rdev/pkg/api" ) // Config defines rate limit parameters. type Config struct { // RequestsPerMinute is the average rate limit. RequestsPerMinute int // BurstSize is the maximum number of requests allowed in a burst. // Defaults to RequestsPerMinute / 2 if not set. BurstSize int // CleanupInterval is how often to clean up stale entries. // Defaults to 5 minutes. CleanupInterval time.Duration // KeyFunc extracts the rate limit key from a request. // Defaults to using the API key ID from context. KeyFunc func(*http.Request) string } // DefaultConfig returns a sensible default configuration. func DefaultConfig() Config { return Config{ RequestsPerMinute: 100, BurstSize: 50, CleanupInterval: 5 * time.Minute, } } // Limiter implements token bucket rate limiting. type Limiter struct { cfg Config buckets map[string]*bucket mu sync.RWMutex stopCh chan struct{} } type bucket struct { tokens float64 lastUpdate time.Time } // New creates a new rate limiter with the given configuration. func New(cfg Config) *Limiter { if cfg.RequestsPerMinute <= 0 { cfg.RequestsPerMinute = 100 } if cfg.BurstSize <= 0 { cfg.BurstSize = cfg.RequestsPerMinute / 2 if cfg.BurstSize < 1 { cfg.BurstSize = 1 } } if cfg.CleanupInterval <= 0 { cfg.CleanupInterval = 5 * time.Minute } l := &Limiter{ cfg: cfg, buckets: make(map[string]*bucket), stopCh: make(chan struct{}), } // Start cleanup goroutine go l.cleanup() return l } // Stop stops the background cleanup goroutine. func (l *Limiter) Stop() { close(l.stopCh) } // Allow checks if a request is allowed under the rate limit. // Returns remaining tokens and whether the request is allowed. func (l *Limiter) Allow(key string) (remaining int, allowed bool) { now := time.Now() l.mu.Lock() defer l.mu.Unlock() b, exists := l.buckets[key] if !exists { b = &bucket{ tokens: float64(l.cfg.BurstSize), lastUpdate: now, } l.buckets[key] = b } // Refill tokens based on time elapsed elapsed := now.Sub(b.lastUpdate).Seconds() rate := float64(l.cfg.RequestsPerMinute) / 60.0 // tokens per second b.tokens += elapsed * rate if b.tokens > float64(l.cfg.BurstSize) { b.tokens = float64(l.cfg.BurstSize) } b.lastUpdate = now // Try to consume a token if b.tokens >= 1 { b.tokens-- return int(b.tokens), true } return 0, false } // Middleware returns an HTTP middleware that enforces rate limits. func (l *Limiter) Middleware() func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Get the rate limit key key := l.getKey(r) if key == "" { // No key means no rate limiting (e.g., health checks) next.ServeHTTP(w, r) return } remaining, allowed := l.Allow(key) // Set rate limit headers w.Header().Set("X-RateLimit-Limit", itoa(l.cfg.RequestsPerMinute)) w.Header().Set("X-RateLimit-Remaining", itoa(remaining)) if !allowed { // Calculate retry time retryAfter := 60.0 / float64(l.cfg.RequestsPerMinute) w.Header().Set("Retry-After", itoa(int(retryAfter)+1)) api.WriteError(w, r, http.StatusTooManyRequests, "RATE_LIMITED", "Rate limit exceeded. Please retry later.") return } next.ServeHTTP(w, r) }) } } func (l *Limiter) getKey(r *http.Request) string { // Use custom key function if provided if l.cfg.KeyFunc != nil { return l.cfg.KeyFunc(r) } // Default: use API key ID from context // This requires the auth middleware to run first if apiKey := auth.GetAPIKey(r.Context()); apiKey != nil { return string(apiKey.ID) } // Fallback: use client IP return getClientIP(r) } func (l *Limiter) cleanup() { ticker := time.NewTicker(l.cfg.CleanupInterval) defer ticker.Stop() for { select { case <-l.stopCh: return case <-ticker.C: l.doCleanup() } } } func (l *Limiter) doCleanup() { l.mu.Lock() defer l.mu.Unlock() // Remove buckets that haven't been used in 2x cleanup interval threshold := time.Now().Add(-2 * l.cfg.CleanupInterval) for key, b := range l.buckets { if b.lastUpdate.Before(threshold) { delete(l.buckets, key) } } } // getClientIP extracts the client IP from the request. func getClientIP(r *http.Request) string { // Check X-Forwarded-For header (set by proxies/load balancers) if xff := r.Header.Get("X-Forwarded-For"); xff != "" { // Take the first IP in the chain for i := 0; i < len(xff); i++ { if xff[i] == ',' { return xff[:i] } } return xff } // Check X-Real-IP header if xri := r.Header.Get("X-Real-IP"); xri != "" { return xri } // Fall back to RemoteAddr // RemoteAddr is "IP:port", so strip the port addr := r.RemoteAddr for i := len(addr) - 1; i >= 0; i-- { if addr[i] == ':' { return addr[:i] } } return addr } // itoa converts an integer to a string without importing strconv. func itoa(i int) string { if i == 0 { return "0" } negative := i < 0 if negative { i = -i } // Max int64 is 19 digits buf := make([]byte, 0, 20) for i > 0 { buf = append(buf, byte('0'+i%10)) i /= 10 } if negative { buf = append(buf, '-') } // Reverse for left, right := 0, len(buf)-1; left < right; left, right = left+1, right-1 { buf[left], buf[right] = buf[right], buf[left] } return string(buf) } // KeyFromAPIKey creates a KeyFunc that extracts the API key ID for rate limiting. // This is useful when you want to rate limit by API key rather than IP. func KeyFromAPIKey() func(*http.Request) string { return func(r *http.Request) string { if apiKey := auth.GetAPIKey(r.Context()); apiKey != nil { return string(apiKey.ID) } return getClientIP(r) } } // KeyFromIP creates a KeyFunc that uses client IP for rate limiting. func KeyFromIP() func(*http.Request) string { return func(r *http.Request) string { return getClientIP(r) } }