persona-community-2/pkg/middleware/ratelimit.go
jordan cb3d4d5786
Some checks failed
ci/woodpecker/push/woodpecker Pipeline failed
ci/woodpecker/manual/woodpecker Pipeline was successful
Initialize project from skeleton template
2026-02-23 10:53:55 +00:00

133 lines
3.0 KiB
Go

package middleware
import (
"net"
"net/http"
"strings"
"sync"
"time"
)
// RateLimitConfig configures per-IP rate limiting.
// It is passed to RateLimit, which stores a copy inside the limiter it builds.
type RateLimitConfig struct {
// Requests is the maximum number of requests allowed per window.
// The limiter rejects a request once the number of in-window requests is
// greater than or equal to this value, so a value of zero or less rejects
// every request.
Requests int
// Window is the time window for the rate limit.
// Only requests recorded strictly after (now - Window) count toward the limit.
Window time.Duration
}
// ipEntry tracks request timestamps for a single IP.
// Values are stored by pointer in rateLimiter.entries and are only touched
// while the limiter's mutex is held.
type ipEntry struct {
// timestamps holds the time of each accepted request still considered to be
// inside the window. Rejected requests are not recorded. New values are
// appended at the end, and expired values are filtered out in place.
timestamps []time.Time
}
// rateLimiter implements a sliding window rate limiter.
// One instance is created per RateLimit call and shared by every request that
// passes through the returned middleware.
type rateLimiter struct {
// mu guards entries and the ipEntry values it points to.
mu sync.Mutex
// entries maps a client key (the string returned by rateLimitClientIP) to
// that client's recorded request timestamps.
entries map[string]*ipEntry
// config holds the request limit and window length supplied to RateLimit.
config RateLimitConfig
}
// RateLimit returns middleware that limits requests per client IP using a
// sliding window described by cfg.
//
// Requests within the limit are passed to the wrapped handler unchanged. When
// the limit is exceeded the middleware responds with 429 Too Many Requests and
// a JSON error body, and the wrapped handler is not invoked.
//
// Each call to RateLimit creates an independent limiter and starts one
// background goroutine that evicts stale entries. That goroutine is never
// stopped, so RateLimit should be called once per limiter at startup, not
// per request.
func RateLimit(cfg RateLimitConfig) func(http.Handler) http.Handler {
	rl := &rateLimiter{
		entries: make(map[string]*ipEntry),
		config:  cfg,
	}

	// Periodically evict stale entries to prevent unbounded growth.
	go rl.cleanup()

	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			if rl.allow(rateLimitClientIP(r)) {
				next.ServeHTTP(w, r)
				return
			}

			// The body is JSON, so it must not go through http.Error, which
			// forces "Content-Type: text/plain" and makes the payload
			// unparseable for clients that dispatch on the content type.
			hdr := w.Header()
			hdr.Set("Content-Type", "application/json; charset=utf-8")
			hdr.Set("X-Content-Type-Options", "nosniff")
			w.WriteHeader(http.StatusTooManyRequests)

			// The trailing newline keeps the body byte-identical to what
			// http.Error used to emit. A write error here means the client
			// has gone away; there is nothing useful left to do with it.
			_, _ = w.Write([]byte(`{"error":{"message":"too many requests, please try again later"}}` + "\n"))
		})
	}
}
// allow reports whether the client identified by ip may make a request now.
// It first discards that client's timestamps that have fallen out of the
// window; if the remaining count is below the configured limit the current
// time is recorded and true is returned, otherwise nothing is recorded and
// false is returned. An entry is created for clients seen for the first time.
func (rl *rateLimiter) allow(ip string) bool {
	rl.mu.Lock()
	defer rl.mu.Unlock()

	now := time.Now()
	windowStart := now.Add(-rl.config.Window)

	rec, known := rl.entries[ip]
	if !known {
		rec = &ipEntry{}
		rl.entries[ip] = rec
	}

	// Compact the still-valid timestamps to the front of the same backing
	// array, then truncate.
	kept := 0
	for _, ts := range rec.timestamps {
		if !ts.After(windowStart) {
			continue
		}
		rec.timestamps[kept] = ts
		kept++
	}
	rec.timestamps = rec.timestamps[:kept]

	if kept < rl.config.Requests {
		rec.timestamps = append(rec.timestamps, now)
		return true
	}
	return false
}
// cleanup runs forever, sweeping the limiter every 5 minutes. On each sweep it
// drops timestamps that are outside the window and deletes clients that have
// no timestamps left, holding the limiter's mutex for the whole sweep.
func (rl *rateLimiter) cleanup() {
	const sweepInterval = 5 * time.Minute

	ticker := time.NewTicker(sweepInterval)
	defer ticker.Stop()

	for range ticker.C {
		func() {
			rl.mu.Lock()
			defer rl.mu.Unlock()

			windowStart := time.Now().Add(-rl.config.Window)
			for key, rec := range rl.entries {
				// Compact in place; kept counts the surviving timestamps.
				kept := 0
				for _, ts := range rec.timestamps {
					if !ts.After(windowStart) {
						continue
					}
					rec.timestamps[kept] = ts
					kept++
				}

				if kept == 0 {
					delete(rl.entries, key)
					continue
				}
				rec.timestamps = rec.timestamps[:kept]
			}
		}()
	}
}
// rateLimitClientIP returns the string used as the rate-limit key for r.
//
// The key is normally the host part of r.RemoteAddr. Proxy headers are only
// consulted when the direct peer is a loopback or private address, in which
// case the first X-Forwarded-For entry is preferred, then X-Real-Ip.
//
// A header value is used only if it parses as an IP address (optionally
// followed by a port), and it is returned in canonical textual form. This
// keeps arbitrary header text out of the limiter's key space and stops one
// address from being counted under several spellings. If no header yields a
// valid IP, the peer host is returned.
//
// NOTE(review): the first X-Forwarded-For entry is supplied by the client when
// the proxy appends to the header instead of overwriting it; the fronting proxy
// must sanitise it for the limit to be meaningful.
func rateLimitClientIP(r *http.Request) string {
	host, _, err := net.SplitHostPort(r.RemoteAddr)
	if err != nil {
		// RemoteAddr carried no port (or was malformed); use it verbatim.
		host = r.RemoteAddr
	}

	peer := net.ParseIP(host)
	if peer == nil || !(peer.IsLoopback() || peer.IsPrivate()) {
		// Untrusted or unparseable peer: never honour forwarded headers.
		return host
	}

	if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
		first := strings.SplitN(xff, ",", 2)[0]
		if ip := rateLimitForwardedIP(first); ip != "" {
			return ip
		}
	}
	if ip := rateLimitForwardedIP(r.Header.Get("X-Real-Ip")); ip != "" {
		return ip
	}
	return host
}

// rateLimitForwardedIP parses a single forwarded-address header value and
// returns the canonical form of the IP it contains, or "" if it is not an IP.
func rateLimitForwardedIP(raw string) string {
	raw = strings.TrimSpace(raw)
	if raw == "" {
		return ""
	}
	if ip := net.ParseIP(raw); ip != nil {
		return ip.String()
	}
	// Some proxies include the source port: "203.0.113.7:4711" or
	// "[2001:db8::1]:4711".
	if h, _, err := net.SplitHostPort(raw); err == nil {
		if ip := net.ParseIP(h); ip != nil {
			return ip.String()
		}
	}
	return ""
}