133 lines
3.0 KiB
Go
133 lines
3.0 KiB
Go
package middleware
|
|
|
|
import (
|
|
"net"
|
|
"net/http"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
)
|
|
|
|
// RateLimitConfig holds the settings for per-client-IP rate limiting.
type RateLimitConfig struct {
	// Requests caps how many requests one IP may make within any Window.
	Requests int
	// Window is the length of the sliding interval over which Requests is counted.
	Window time.Duration
}
|
|
|
|
// ipEntry is the per-IP state of the sliding window: the admission times of
// that IP's recent requests, in the order they were admitted.
type ipEntry struct {
	timestamps []time.Time
}
|
|
|
|
// rateLimiter implements a sliding window rate limiter.
|
|
type rateLimiter struct {
|
|
mu sync.Mutex
|
|
entries map[string]*ipEntry
|
|
config RateLimitConfig
|
|
}
|
|
|
|
// RateLimit returns middleware that limits requests per IP using a sliding window.
|
|
// When the limit is exceeded, it responds with 429 Too Many Requests.
|
|
func RateLimit(cfg RateLimitConfig) func(http.Handler) http.Handler {
|
|
rl := &rateLimiter{
|
|
entries: make(map[string]*ipEntry),
|
|
config: cfg,
|
|
}
|
|
|
|
// Periodically evict stale entries to prevent unbounded growth.
|
|
go rl.cleanup()
|
|
|
|
return func(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
ip := rateLimitClientIP(r)
|
|
if !rl.allow(ip) {
|
|
http.Error(w, `{"error":{"message":"too many requests, please try again later"}}`, http.StatusTooManyRequests)
|
|
return
|
|
}
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
}
|
|
|
|
// allow checks if the IP is within its rate limit and records the request.
|
|
func (rl *rateLimiter) allow(ip string) bool {
|
|
rl.mu.Lock()
|
|
defer rl.mu.Unlock()
|
|
|
|
now := time.Now()
|
|
cutoff := now.Add(-rl.config.Window)
|
|
|
|
entry, ok := rl.entries[ip]
|
|
if !ok {
|
|
entry = &ipEntry{}
|
|
rl.entries[ip] = entry
|
|
}
|
|
|
|
// Remove timestamps outside the window.
|
|
valid := entry.timestamps[:0]
|
|
for _, t := range entry.timestamps {
|
|
if t.After(cutoff) {
|
|
valid = append(valid, t)
|
|
}
|
|
}
|
|
entry.timestamps = valid
|
|
|
|
if len(entry.timestamps) >= rl.config.Requests {
|
|
return false
|
|
}
|
|
|
|
entry.timestamps = append(entry.timestamps, now)
|
|
return true
|
|
}
|
|
|
|
// cleanup removes stale entries every 5 minutes.
|
|
func (rl *rateLimiter) cleanup() {
|
|
ticker := time.NewTicker(5 * time.Minute)
|
|
defer ticker.Stop()
|
|
|
|
for range ticker.C {
|
|
rl.mu.Lock()
|
|
now := time.Now()
|
|
cutoff := now.Add(-rl.config.Window)
|
|
for ip, entry := range rl.entries {
|
|
valid := entry.timestamps[:0]
|
|
for _, t := range entry.timestamps {
|
|
if t.After(cutoff) {
|
|
valid = append(valid, t)
|
|
}
|
|
}
|
|
if len(valid) == 0 {
|
|
delete(rl.entries, ip)
|
|
} else {
|
|
entry.timestamps = valid
|
|
}
|
|
}
|
|
rl.mu.Unlock()
|
|
}
|
|
}
|
|
|
|
// rateLimitClientIP returns the address that rate limiting is keyed on.
//
// The TCP peer (r.RemoteAddr) is used as-is unless it is a loopback or private
// address, presumably a reverse proxy in front of this server. Only then are
// proxy headers consulted:
//
//   - X-Forwarded-For (all header lines, comma-joined) is walked from the right
//     (the hop appended by the proxy nearest to us) toward the left, skipping
//     loopback/private hops, and the first public address is returned. The
//     left-most entries are supplied by the client itself; keying on them would
//     let any client evade the limiter by sending a new fake address each time.
//   - Otherwise X-Real-Ip is used if it parses as an IP address.
//
// Header-derived addresses are parsed and returned in canonical form, so
// arbitrary header strings cannot be used to mint unlimited limiter keys.
func rateLimitClientIP(r *http.Request) string {
	host, _, err := net.SplitHostPort(r.RemoteAddr)
	if err != nil {
		host = r.RemoteAddr
	}

	peer := net.ParseIP(host)
	if peer == nil || !rateLimitTrustedHop(peer) {
		return host
	}

	// A client-supplied X-Forwarded-For may arrive as its own header line ahead
	// of the proxy's, so consider every line, not just the first.
	if xff := r.Header.Values("X-Forwarded-For"); len(xff) > 0 {
		if ip := rateLimitForwardedFor(strings.Join(xff, ",")); ip != "" {
			return ip
		}
	}

	if ip := rateLimitParseIP(r.Header.Get("X-Real-Ip")); ip != nil {
		return ip.String()
	}
	return host
}

// rateLimitForwardedFor picks the client address from an X-Forwarded-For list
// by scanning right to left past trusted (loopback/private) hops. Empty items
// are skipped. If every hop is trusted, the left-most one is returned. A
// malformed hop stops the scan and yields "", since nothing to its left can be
// attributed to a trusted proxy.
func rateLimitForwardedFor(list string) string {
	hops := strings.Split(list, ",")
	var leftmost net.IP
	for i := len(hops) - 1; i >= 0; i-- {
		if strings.TrimSpace(hops[i]) == "" {
			continue
		}
		ip := rateLimitParseIP(hops[i])
		if ip == nil {
			return ""
		}
		if !rateLimitTrustedHop(ip) {
			return ip.String()
		}
		leftmost = ip
	}
	if leftmost != nil {
		return leftmost.String()
	}
	return ""
}

// rateLimitParseIP parses s as an IP address, tolerating surrounding spaces and
// an optional port ("1.2.3.4:80", "[::1]:80"). It returns nil if s is not an
// address.
func rateLimitParseIP(s string) net.IP {
	s = strings.TrimSpace(s)
	if ip := net.ParseIP(s); ip != nil {
		return ip
	}
	if h, _, err := net.SplitHostPort(s); err == nil {
		return net.ParseIP(h)
	}
	return nil
}

// rateLimitTrustedHop reports whether ip is an address whose forwarding
// headers are trusted: loopback or private-range.
func rateLimitTrustedHop(ip net.IP) bool {
	return ip.IsLoopback() || ip.IsPrivate()
}
|