package middleware import ( "net" "net/http" "strings" "sync" "time" ) // RateLimitConfig configures per-IP rate limiting. type RateLimitConfig struct { // Requests is the maximum number of requests allowed per window. Requests int // Window is the time window for the rate limit. Window time.Duration } // ipEntry tracks request timestamps for a single IP. type ipEntry struct { timestamps []time.Time } // rateLimiter implements a sliding window rate limiter. type rateLimiter struct { mu sync.Mutex entries map[string]*ipEntry config RateLimitConfig } // RateLimit returns middleware that limits requests per IP using a sliding window. // When the limit is exceeded, it responds with 429 Too Many Requests. func RateLimit(cfg RateLimitConfig) func(http.Handler) http.Handler { rl := &rateLimiter{ entries: make(map[string]*ipEntry), config: cfg, } // Periodically evict stale entries to prevent unbounded growth. go rl.cleanup() return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ip := rateLimitClientIP(r) if !rl.allow(ip) { http.Error(w, `{"error":{"message":"too many requests, please try again later"}}`, http.StatusTooManyRequests) return } next.ServeHTTP(w, r) }) } } // allow checks if the IP is within its rate limit and records the request. func (rl *rateLimiter) allow(ip string) bool { rl.mu.Lock() defer rl.mu.Unlock() now := time.Now() cutoff := now.Add(-rl.config.Window) entry, ok := rl.entries[ip] if !ok { entry = &ipEntry{} rl.entries[ip] = entry } // Remove timestamps outside the window. valid := entry.timestamps[:0] for _, t := range entry.timestamps { if t.After(cutoff) { valid = append(valid, t) } } entry.timestamps = valid if len(entry.timestamps) >= rl.config.Requests { return false } entry.timestamps = append(entry.timestamps, now) return true } // cleanup removes stale entries every 5 minutes. func (rl *rateLimiter) cleanup() { ticker := time.NewTicker(5 * time.Minute) defer ticker.Stop() for range ticker.C { rl.mu.Lock() now := time.Now() cutoff := now.Add(-rl.config.Window) for ip, entry := range rl.entries { valid := entry.timestamps[:0] for _, t := range entry.timestamps { if t.After(cutoff) { valid = append(valid, t) } } if len(valid) == 0 { delete(rl.entries, ip) } else { entry.timestamps = valid } } rl.mu.Unlock() } } // rateLimitClientIP extracts the client IP, trusting proxy headers only from private IPs. func rateLimitClientIP(r *http.Request) string { host, _, err := net.SplitHostPort(r.RemoteAddr) if err != nil { host = r.RemoteAddr } ip := net.ParseIP(host) if ip != nil && (ip.IsLoopback() || ip.IsPrivate()) { if xff := r.Header.Get("X-Forwarded-For"); xff != "" { parts := strings.SplitN(xff, ",", 2) if fwd := strings.TrimSpace(parts[0]); fwd != "" { return fwd } } if xri := r.Header.Get("X-Real-Ip"); xri != "" { return xri } } return host }