- Add ListPipelines/GetPipeline to CIProvider port with Woodpecker adapter
- Add DNS alias endpoints: GET/POST/DELETE /projects/{id}/domains
- Implement worker executor daemon, build executor, and git operations
- Add build service, worker service, and build audit tracking
- Add worker registry with PostgreSQL adapter and migration
- Add multi-provider code agent interface (Claude Code + OpenCode)
- Add create-and-build combo endpoint
- Update landing-page cookbook to reflect all gaps closed
- Fix tech debt: unified validation, auth scopes, error wrapping, slog patterns
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
273 lines
6.0 KiB
Go
273 lines
6.0 KiB
Go
// Package ratelimit provides rate limiting middleware for HTTP handlers.
|
|
package ratelimit
|
|
|
|
import (
	"net"
	"net/http"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/orchard9/rdev/internal/auth"
	"github.com/orchard9/rdev/pkg/api"
)
|
|
|
|
// Config defines rate limit parameters for a Limiter.
//
// Zero or negative numeric fields are replaced with defaults by New.
type Config struct {
	// RequestsPerMinute is the sustained refill rate of every token bucket.
	// It is also reported to clients in the X-RateLimit-Limit header.
	// Values <= 0 default to 100.
	RequestsPerMinute int

	// BurstSize is the capacity of each token bucket: the maximum number of
	// back-to-back requests a single key can make after being idle.
	// Values <= 0 default to RequestsPerMinute / 2 (at least 1).
	BurstSize int

	// CleanupInterval is how often idle buckets are swept. A bucket that
	// has not been touched for more than twice this interval is discarded.
	// Values <= 0 default to 5 minutes.
	CleanupInterval time.Duration

	// KeyFunc extracts the rate limit key from a request. Requests for which
	// it returns "" are not rate limited at all.
	// When nil, the key is the authenticated API key ID from the request
	// context (so the auth middleware must run first), falling back to the
	// client IP when no API key is present.
	KeyFunc func(*http.Request) string
}
|
|
|
|
// DefaultConfig returns a sensible default configuration.
|
|
func DefaultConfig() Config {
|
|
return Config{
|
|
RequestsPerMinute: 100,
|
|
BurstSize: 50,
|
|
CleanupInterval: 5 * time.Minute,
|
|
}
|
|
}
|
|
|
|
// Limiter implements token bucket rate limiting.
|
|
type Limiter struct {
|
|
cfg Config
|
|
buckets map[string]*bucket
|
|
mu sync.RWMutex
|
|
stopCh chan struct{}
|
|
}
|
|
|
|
type bucket struct {
|
|
tokens float64
|
|
lastUpdate time.Time
|
|
}
|
|
|
|
// New creates a new rate limiter with the given configuration.
|
|
func New(cfg Config) *Limiter {
|
|
if cfg.RequestsPerMinute <= 0 {
|
|
cfg.RequestsPerMinute = 100
|
|
}
|
|
if cfg.BurstSize <= 0 {
|
|
cfg.BurstSize = cfg.RequestsPerMinute / 2
|
|
if cfg.BurstSize < 1 {
|
|
cfg.BurstSize = 1
|
|
}
|
|
}
|
|
if cfg.CleanupInterval <= 0 {
|
|
cfg.CleanupInterval = 5 * time.Minute
|
|
}
|
|
|
|
l := &Limiter{
|
|
cfg: cfg,
|
|
buckets: make(map[string]*bucket),
|
|
stopCh: make(chan struct{}),
|
|
}
|
|
|
|
// Start cleanup goroutine
|
|
go l.cleanup()
|
|
|
|
return l
|
|
}
|
|
|
|
// Stop stops the background cleanup goroutine.
|
|
func (l *Limiter) Stop() {
|
|
close(l.stopCh)
|
|
}
|
|
|
|
// Allow checks if a request is allowed under the rate limit.
|
|
// Returns remaining tokens and whether the request is allowed.
|
|
func (l *Limiter) Allow(key string) (remaining int, allowed bool) {
|
|
now := time.Now()
|
|
|
|
l.mu.Lock()
|
|
defer l.mu.Unlock()
|
|
|
|
b, exists := l.buckets[key]
|
|
if !exists {
|
|
b = &bucket{
|
|
tokens: float64(l.cfg.BurstSize),
|
|
lastUpdate: now,
|
|
}
|
|
l.buckets[key] = b
|
|
}
|
|
|
|
// Refill tokens based on time elapsed
|
|
elapsed := now.Sub(b.lastUpdate).Seconds()
|
|
rate := float64(l.cfg.RequestsPerMinute) / 60.0 // tokens per second
|
|
b.tokens += elapsed * rate
|
|
if b.tokens > float64(l.cfg.BurstSize) {
|
|
b.tokens = float64(l.cfg.BurstSize)
|
|
}
|
|
b.lastUpdate = now
|
|
|
|
// Try to consume a token
|
|
if b.tokens >= 1 {
|
|
b.tokens--
|
|
return int(b.tokens), true
|
|
}
|
|
|
|
return 0, false
|
|
}
|
|
|
|
// Middleware returns an HTTP middleware that enforces rate limits.
|
|
func (l *Limiter) Middleware() func(http.Handler) http.Handler {
|
|
return func(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
// Get the rate limit key
|
|
key := l.getKey(r)
|
|
if key == "" {
|
|
// No key means no rate limiting (e.g., health checks)
|
|
next.ServeHTTP(w, r)
|
|
return
|
|
}
|
|
|
|
remaining, allowed := l.Allow(key)
|
|
|
|
// Set rate limit headers
|
|
w.Header().Set("X-RateLimit-Limit", itoa(l.cfg.RequestsPerMinute))
|
|
w.Header().Set("X-RateLimit-Remaining", itoa(remaining))
|
|
|
|
if !allowed {
|
|
// Calculate retry time
|
|
retryAfter := 60.0 / float64(l.cfg.RequestsPerMinute)
|
|
w.Header().Set("Retry-After", itoa(int(retryAfter)+1))
|
|
api.WriteError(w, r, http.StatusTooManyRequests, "RATE_LIMITED",
|
|
"Rate limit exceeded. Please retry later.")
|
|
return
|
|
}
|
|
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
}
|
|
|
|
func (l *Limiter) getKey(r *http.Request) string {
|
|
// Use custom key function if provided
|
|
if l.cfg.KeyFunc != nil {
|
|
return l.cfg.KeyFunc(r)
|
|
}
|
|
|
|
// Default: use API key ID from context
|
|
// This requires the auth middleware to run first
|
|
if apiKey := auth.GetAPIKey(r.Context()); apiKey != nil {
|
|
return string(apiKey.ID)
|
|
}
|
|
|
|
// Fallback: use client IP
|
|
return getClientIP(r)
|
|
}
|
|
|
|
func (l *Limiter) cleanup() {
|
|
ticker := time.NewTicker(l.cfg.CleanupInterval)
|
|
defer ticker.Stop()
|
|
|
|
for {
|
|
select {
|
|
case <-l.stopCh:
|
|
return
|
|
case <-ticker.C:
|
|
l.doCleanup()
|
|
}
|
|
}
|
|
}
|
|
|
|
func (l *Limiter) doCleanup() {
|
|
l.mu.Lock()
|
|
defer l.mu.Unlock()
|
|
|
|
// Remove buckets that haven't been used in 2x cleanup interval
|
|
threshold := time.Now().Add(-2 * l.cfg.CleanupInterval)
|
|
|
|
for key, b := range l.buckets {
|
|
if b.lastUpdate.Before(threshold) {
|
|
delete(l.buckets, key)
|
|
}
|
|
}
|
|
}
|
|
|
|
// getClientIP returns a best-effort client address for r. It is used as the
// rate limit key when no API key is available.
//
// Sources are consulted in order:
//  1. the first entry of X-Forwarded-For (whitespace trimmed),
//  2. X-Real-IP (whitespace trimmed),
//  3. the host part of r.RemoteAddr.
//
// A header whose relevant value is empty after trimming is skipped rather
// than returned. An empty key disables rate limiting in Middleware, so
// returning "" for a header like "X-Forwarded-For: ,1.2.3.4" would let a
// client opt out of limiting entirely.
//
// NOTE(review): X-Forwarded-For and X-Real-IP are client-controlled unless a
// trusted proxy in front of this service overwrites them. If requests can
// reach the service directly, a client can rotate these headers to evade
// per-IP limits. Consider honoring them only from trusted proxy addresses.
func getClientIP(r *http.Request) string {
	if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
		// By convention the left-most entry is the originating client;
		// later entries are intermediate proxies.
		first, _, _ := strings.Cut(xff, ",")
		if ip := strings.TrimSpace(first); ip != "" {
			return ip
		}
	}

	if xri := strings.TrimSpace(r.Header.Get("X-Real-IP")); xri != "" {
		return xri
	}

	// RemoteAddr is normally "host:port". SplitHostPort also handles
	// bracketed IPv6 ("[::1]:8080" -> "::1"), which splitting on the last
	// ':' gets wrong.
	if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
		return host
	}

	// No port (or otherwise unparsable): use the address as-is.
	return r.RemoteAddr
}
|
|
|
|
// itoa formats i in base 10.
//
// It is a thin wrapper over strconv.Itoa so existing call sites stay
// unchanged. strconv.Itoa handles every int, including the minimum value.
// That value's negation overflows back to itself, which made a hand-rolled
// "negate then emit digits" loop return just "-".
func itoa(i int) string {
	return strconv.Itoa(i)
}
|
|
|
|
// KeyFromAPIKey creates a KeyFunc that extracts the API key ID for rate limiting.
|
|
// This is useful when you want to rate limit by API key rather than IP.
|
|
func KeyFromAPIKey() func(*http.Request) string {
|
|
return func(r *http.Request) string {
|
|
if apiKey := auth.GetAPIKey(r.Context()); apiKey != nil {
|
|
return string(apiKey.ID)
|
|
}
|
|
return getClientIP(r)
|
|
}
|
|
}
|
|
|
|
// KeyFromIP creates a KeyFunc that uses client IP for rate limiting.
|
|
func KeyFromIP() func(*http.Request) string {
|
|
return func(r *http.Request) string {
|
|
return getClientIP(r)
|
|
}
|
|
}
|