rdev/internal/ratelimit/ratelimit.go
jordan bc47e426b0 feat: Add CI pipeline proxy, DNS alias management, and worker executor system
- Add ListPipelines/GetPipeline to CIProvider port with Woodpecker adapter
- Add DNS alias endpoints: GET/POST/DELETE /projects/{id}/domains
- Implement worker executor daemon, build executor, and git operations
- Add build service, worker service, and build audit tracking
- Add worker registry with PostgreSQL adapter and migration
- Add multi-provider code agent interface (Claude Code + OpenCode)
- Add create-and-build combo endpoint
- Update landing-page cookbook to reflect all gaps closed
- Fix tech debt: unified validation, auth scopes, error wrapping, slog patterns

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-27 21:05:28 -07:00

273 lines
6.0 KiB
Go

// Package ratelimit provides rate limiting middleware for HTTP handlers.
package ratelimit
import (
"net/http"
"sync"
"time"
"github.com/orchard9/rdev/internal/auth"
"github.com/orchard9/rdev/pkg/api"
)
// Config holds the tuning knobs for a Limiter. New replaces any non-positive
// numeric field with a default before the limiter is used.
type Config struct {
	// RequestsPerMinute is the sustained refill rate of every bucket, in
	// requests per minute. New uses 100 when this is <= 0.
	RequestsPerMinute int

	// BurstSize is the bucket capacity: how many requests a single key may
	// make back-to-back before being throttled to RequestsPerMinute. When
	// <= 0, New uses RequestsPerMinute/2, but never less than 1.
	BurstSize int

	// CleanupInterval controls how often idle buckets are swept away. New
	// uses 5 minutes when this is <= 0.
	CleanupInterval time.Duration

	// KeyFunc maps a request to its rate-limit key; a request whose key is ""
	// is not rate limited. When nil, the limiter keys on the authenticated API
	// key ID from the request context and falls back to the client IP.
	KeyFunc func(*http.Request) string
}
// DefaultConfig returns the stock limiter settings: 100 requests per minute,
// bursts of up to 50 requests, and an idle-bucket sweep every five minutes.
// KeyFunc is left nil so the limiter's built-in key selection applies.
func DefaultConfig() Config {
	var cfg Config
	cfg.RequestsPerMinute = 100
	cfg.BurstSize = 50
	cfg.CleanupInterval = 5 * time.Minute
	return cfg
}
// Limiter enforces a per-key token-bucket rate limit. Build one with New:
// every distinct key gets its own bucket, and buckets that go idle are
// removed by a background goroutine until Stop is called.
type Limiter struct {
	cfg     Config             // settings, with defaults already applied by New
	buckets map[string]*bucket // one bucket per rate-limit key
	mu      sync.RWMutex       // guards buckets
	stopCh  chan struct{}      // closed to terminate the cleanup goroutine
}
// bucket is the token-bucket state tracked for a single key.
type bucket struct {
	tokens     float64   // tokens currently available; may be fractional between refills
	lastUpdate time.Time // last refill time, also used to detect idle buckets
}
// New builds a Limiter from cfg, filling in defaults for any non-positive
// numeric field, and starts the background goroutine that evicts idle
// buckets. Call Stop once the limiter is no longer needed to end it.
func New(cfg Config) *Limiter {
	if cfg.RequestsPerMinute <= 0 {
		cfg.RequestsPerMinute = 100
	}

	// The burst defaults to half the per-minute rate, floored at one request.
	switch {
	case cfg.BurstSize > 0:
		// Keep the caller's explicit burst size.
	case cfg.RequestsPerMinute >= 2:
		cfg.BurstSize = cfg.RequestsPerMinute / 2
	default:
		cfg.BurstSize = 1
	}

	if cfg.CleanupInterval <= 0 {
		cfg.CleanupInterval = 5 * time.Minute
	}

	lim := &Limiter{
		cfg:     cfg,
		buckets: make(map[string]*bucket),
		stopCh:  make(chan struct{}),
	}
	go lim.cleanup()
	return lim
}
// Stop signals the background cleanup goroutine started by New to exit.
//
// Stop is safe to call more than once and from multiple goroutines: the stop
// channel is closed only on the first call. Without that guard a repeated call
// would panic with "close of closed channel". Stop does not wait for the
// goroutine to finish. Allow keeps working afterwards; only the periodic
// eviction of idle buckets stops.
func (l *Limiter) Stop() {
	// Holding l.mu makes the check-then-close below atomic across concurrent
	// Stop calls.
	l.mu.Lock()
	defer l.mu.Unlock()

	select {
	case <-l.stopCh:
		// Already closed by an earlier Stop.
	default:
		close(l.stopCh)
	}
}
// Allow reports whether a request for key may proceed, consuming one token
// from the key's bucket when it does.
//
// A key seen for the first time starts with a full bucket of BurstSize
// tokens. After that, the bucket refills continuously at RequestsPerMinute/60
// tokens per second, capped at BurstSize.
//
// remaining is the whole number of tokens left after this request. It is 0
// when the request is rejected.
func (l *Limiter) Allow(key string) (remaining int, allowed bool) {
	now := time.Now()

	l.mu.Lock()
	defer l.mu.Unlock()

	capacity := float64(l.cfg.BurstSize)
	perSecond := float64(l.cfg.RequestsPerMinute) / 60.0

	b, ok := l.buckets[key]
	if !ok {
		b = &bucket{tokens: capacity, lastUpdate: now}
		l.buckets[key] = b
	}

	// Credit tokens earned since the previous visit, never exceeding capacity.
	refilled := b.tokens + now.Sub(b.lastUpdate).Seconds()*perSecond
	if refilled > capacity {
		refilled = capacity
	}
	b.tokens = refilled
	b.lastUpdate = now

	if b.tokens < 1 {
		return 0, false
	}
	b.tokens--
	return int(b.tokens), true
}
// Middleware returns a wrapper that checks each request against the limiter
// before passing it to the next handler.
//
// Requests with a non-empty key receive X-RateLimit-Limit (the configured
// RequestsPerMinute) and X-RateLimit-Remaining headers. A rejected request
// also gets a Retry-After header in whole seconds and a 429 RATE_LIMITED error
// written through api.WriteError, and next is not called. Requests whose key
// is empty bypass limiting entirely.
func (l *Limiter) Middleware() func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		serve := func(w http.ResponseWriter, r *http.Request) {
			key := l.getKey(r)
			if key == "" {
				// An empty key opts the request out of rate limiting.
				next.ServeHTTP(w, r)
				return
			}

			remaining, ok := l.Allow(key)
			hdr := w.Header()
			hdr.Set("X-RateLimit-Limit", itoa(l.cfg.RequestsPerMinute))
			hdr.Set("X-RateLimit-Remaining", itoa(remaining))

			if ok {
				next.ServeHTTP(w, r)
				return
			}

			// Rough upper bound: the time to earn one token, rounded up by a second.
			secondsPerToken := 60.0 / float64(l.cfg.RequestsPerMinute)
			hdr.Set("Retry-After", itoa(int(secondsPerToken)+1))
			api.WriteError(w, r, http.StatusTooManyRequests, "RATE_LIMITED",
				"Rate limit exceeded. Please retry later.")
		}
		return http.HandlerFunc(serve)
	}
}
// getKey chooses the rate-limit key for r.
//
// A configured KeyFunc always wins. Otherwise the API key ID placed in the
// context by the auth middleware is used, so that middleware must run first.
// Without an API key, the client IP is used.
func (l *Limiter) getKey(r *http.Request) string {
	if keyFn := l.cfg.KeyFunc; keyFn != nil {
		return keyFn(r)
	}

	apiKey := auth.GetAPIKey(r.Context())
	if apiKey == nil {
		return getClientIP(r)
	}
	return string(apiKey.ID)
}
// cleanup runs doCleanup once per CleanupInterval until stopCh is closed.
// It is the loop body of the goroutine launched by New.
func (l *Limiter) cleanup() {
	tick := time.NewTicker(l.cfg.CleanupInterval)
	defer tick.Stop()

	for {
		select {
		case <-tick.C:
			l.doCleanup()
		case <-l.stopCh:
			return
		}
	}
}
// doCleanup drops every bucket whose last refill is older than twice the
// cleanup interval, so keys that have gone quiet do not accumulate forever.
func (l *Limiter) doCleanup() {
	l.mu.Lock()
	defer l.mu.Unlock()

	cutoff := time.Now().Add(-2 * l.cfg.CleanupInterval)
	for key, b := range l.buckets {
		if !b.lastUpdate.Before(cutoff) {
			continue
		}
		// Deleting map entries while ranging over the map is allowed in Go.
		delete(l.buckets, key)
	}
}
// getClientIP returns a best-effort client address to use as a rate-limit key.
//
// Sources are consulted in order, and the first non-blank one wins:
//  1. the first comma-separated entry of X-Forwarded-For,
//  2. X-Real-IP,
//  3. r.RemoteAddr with everything from its last ':' removed (the port).
//
// Header values are trimmed of surrounding spaces and tabs, and blank values
// are skipped. In particular, an X-Forwarded-For whose first entry is empty
// (e.g. ", 1.2.3.4") falls through to the next source. This matters because
// Middleware skips rate limiting for an empty key, so returning "" here would
// let any client bypass the limiter just by sending such a header.
//
// NOTE(security): X-Forwarded-For and X-Real-IP are client-controlled unless a
// trusted proxy in front of this service overwrites them. A client that can
// reach the service directly can rotate these headers to get a fresh bucket
// on every request. Consider keying on RemoteAddr only, unless a trusted proxy
// is guaranteed.
func getClientIP(r *http.Request) string {
	// trim strips ASCII spaces and tabs from both ends of s.
	trim := func(s string) string {
		for len(s) > 0 && (s[0] == ' ' || s[0] == '\t') {
			s = s[1:]
		}
		for len(s) > 0 && (s[len(s)-1] == ' ' || s[len(s)-1] == '\t') {
			s = s[:len(s)-1]
		}
		return s
	}

	if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
		// The left-most entry is the originating client as reported by proxies.
		first := xff
		for i := 0; i < len(xff); i++ {
			if xff[i] == ',' {
				first = xff[:i]
				break
			}
		}
		if ip := trim(first); ip != "" {
			return ip
		}
	}

	if xri := trim(r.Header.Get("X-Real-IP")); xri != "" {
		return xri
	}

	// RemoteAddr is normally "host:port" (e.g. "192.0.2.1:1234" or
	// "[::1]:1234"); cut at the last colon so the port is not part of the key.
	addr := r.RemoteAddr
	for i := len(addr) - 1; i >= 0; i-- {
		if addr[i] == ':' {
			return addr[:i]
		}
	}
	return addr
}
// itoa formats i in base 10, producing the same output as strconv.Itoa.
//
// Digits are generated from the unsigned magnitude of i. Negating a signed
// math.MinInt overflows back to itself, which made the earlier signed version
// of this function return "-" for that input. Unsigned negation is exact for
// every int value.
func itoa(i int) string {
	if i == 0 {
		return "0"
	}

	u := uint(i)
	if i < 0 {
		u = -u // two's-complement magnitude; correct even for math.MinInt
	}

	// 20 bytes fits the 19 digits of a 64-bit int plus a sign.
	var buf [20]byte
	pos := len(buf)
	for u > 0 {
		pos--
		buf[pos] = byte('0' + u%10)
		u /= 10
	}
	if i < 0 {
		pos--
		buf[pos] = '-'
	}
	return string(buf[pos:])
}
// KeyFromAPIKey returns a KeyFunc that keys requests by the authenticated API
// key ID stored in the request context, so each key shares one bucket no
// matter which address it calls from. Requests without an API key fall back to
// the client IP.
func KeyFromAPIKey() func(*http.Request) string {
	keyFor := func(r *http.Request) string {
		apiKey := auth.GetAPIKey(r.Context())
		if apiKey == nil {
			return getClientIP(r)
		}
		return string(apiKey.ID)
	}
	return keyFor
}
// KeyFromIP returns a KeyFunc that keys every request by its client IP, as
// determined by getClientIP, ignoring any API key.
func KeyFromIP() func(*http.Request) string {
	return getClientIP
}