package ratelimit

import (
	"net/http"
	"net/http/httptest"
	"sync"
	"testing"
	"time"

	"github.com/orchard9/rdev/internal/auth"
)

// TestNew verifies that New fills in defaults for a zero Config and keeps
// explicitly provided values unchanged.
func TestNew(t *testing.T) {
	t.Run("default config", func(t *testing.T) {
		l := New(Config{})
		defer l.Stop()

		if l.cfg.RequestsPerMinute != 100 {
			t.Errorf("RequestsPerMinute = %d, want 100", l.cfg.RequestsPerMinute)
		}
		if l.cfg.BurstSize != 50 {
			t.Errorf("BurstSize = %d, want 50", l.cfg.BurstSize)
		}
	})

	t.Run("custom config", func(t *testing.T) {
		l := New(Config{
			RequestsPerMinute: 200,
			BurstSize:         100,
			CleanupInterval:   time.Minute,
		})
		defer l.Stop()

		if l.cfg.RequestsPerMinute != 200 {
			t.Errorf("RequestsPerMinute = %d, want 200", l.cfg.RequestsPerMinute)
		}
		if l.cfg.BurstSize != 100 {
			t.Errorf("BurstSize = %d, want 100", l.cfg.BurstSize)
		}
	})
}

// TestDefaultConfig checks the values returned by DefaultConfig.
func TestDefaultConfig(t *testing.T) {
	cfg := DefaultConfig()

	if cfg.RequestsPerMinute != 100 {
		t.Errorf("RequestsPerMinute = %d, want 100", cfg.RequestsPerMinute)
	}
	if cfg.BurstSize != 50 {
		t.Errorf("BurstSize = %d, want 50", cfg.BurstSize)
	}
	if cfg.CleanupInterval != 5*time.Minute {
		t.Errorf("CleanupInterval = %v, want 5m", cfg.CleanupInterval)
	}
}

// TestAllow exercises the token-bucket behavior of Allow:
// - burst consumption
// - denial once the bucket is empty
// - refill over time
// - per-key isolation
func TestAllow(t *testing.T) {
	t.Run("allows requests within limit", func(t *testing.T) {
		l := New(Config{
			RequestsPerMinute: 60,
			BurstSize:         10,
		})
		defer l.Stop()

		// Should allow burst requests, with the reported remaining count
		// dropping by one per request.
		for i := 0; i < 10; i++ {
			remaining, allowed := l.Allow("test-key")
			if !allowed {
				t.Errorf("Request %d was denied, want allowed", i)
			}
			if remaining != 10-i-1 {
				t.Errorf("Request %d: remaining = %d, want %d", i, remaining, 10-i-1)
			}
		}
	})

	t.Run("denies requests exceeding limit", func(t *testing.T) {
		l := New(Config{
			RequestsPerMinute: 60,
			BurstSize:         5,
		})
		defer l.Stop()

		// Exhaust the bucket.
		for i := 0; i < 5; i++ {
			l.Allow("test-key")
		}

		// Next request should be denied.
		_, allowed := l.Allow("test-key")
		if allowed {
			t.Error("Request was allowed, want denied")
		}
	})

	t.Run("refills over time", func(t *testing.T) {
		l := New(Config{
			RequestsPerMinute: 60, // 1 per second
			BurstSize:         1,
		})
		defer l.Stop()

		// Use the one token.
		l.Allow("test-key")

		// Should be denied immediately.
		_, allowed := l.Allow("test-key")
		if allowed {
			t.Error("Request was allowed immediately, want denied")
		}

		// Wait for refill (1 token per second), with a small margin.
		time.Sleep(1100 * time.Millisecond)

		// Should be allowed now.
		_, allowed = l.Allow("test-key")
		if !allowed {
			t.Error("Request was denied after refill, want allowed")
		}
	})

	t.Run("separate buckets per key", func(t *testing.T) {
		l := New(Config{
			RequestsPerMinute: 60,
			BurstSize:         1,
		})
		defer l.Stop()

		// Exhaust key1.
		l.Allow("key1")
		_, allowed1 := l.Allow("key1")
		if allowed1 {
			t.Error("key1 was allowed, want denied")
		}

		// key2 should still have tokens.
		_, allowed2 := l.Allow("key2")
		if !allowed2 {
			t.Error("key2 was denied, want allowed")
		}
	})
}

// TestMiddleware checks the HTTP middleware:
// - it passes requests through and sets X-RateLimit-Limit
// - it answers 429 with Retry-After once limited
// - it skips limiting when the key function returns "".
func TestMiddleware(t *testing.T) {
	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
		w.Write([]byte("OK"))
	})

	t.Run("allows requests within limit", func(t *testing.T) {
		l := New(Config{
			RequestsPerMinute: 100,
			BurstSize:         10,
			KeyFunc:           KeyFromIP(),
		})
		defer l.Stop()

		middleware := l.Middleware()
		wrapped := middleware(handler)

		req := httptest.NewRequest("GET", "/test", nil)
		req.RemoteAddr = "192.168.1.1:12345"
		rec := httptest.NewRecorder()

		wrapped.ServeHTTP(rec, req)

		if rec.Code != http.StatusOK {
			t.Errorf("Status = %d, want 200", rec.Code)
		}
		if rec.Header().Get("X-RateLimit-Limit") != "100" {
			t.Errorf("X-RateLimit-Limit = %q, want 100", rec.Header().Get("X-RateLimit-Limit"))
		}
	})

	t.Run("returns 429 when rate limited", func(t *testing.T) {
		l := New(Config{
			RequestsPerMinute: 60,
			BurstSize:         1,
			KeyFunc:           KeyFromIP(),
		})
		defer l.Stop()

		middleware := l.Middleware()
		wrapped := middleware(handler)

		// First request should succeed.
		req1 := httptest.NewRequest("GET", "/test", nil)
		req1.RemoteAddr = "192.168.1.1:12345"
		rec1 := httptest.NewRecorder()
		wrapped.ServeHTTP(rec1, req1)

		if rec1.Code != http.StatusOK {
			t.Errorf("First request status = %d, want 200", rec1.Code)
		}

		// Second request from the same IP should be rate limited.
		req2 := httptest.NewRequest("GET", "/test", nil)
		req2.RemoteAddr = "192.168.1.1:12345"
		rec2 := httptest.NewRecorder()
		wrapped.ServeHTTP(rec2, req2)

		if rec2.Code != http.StatusTooManyRequests {
			t.Errorf("Second request status = %d, want 429", rec2.Code)
		}
		if rec2.Header().Get("Retry-After") == "" {
			t.Error("Retry-After header not set")
		}
	})

	t.Run("no key means no rate limiting", func(t *testing.T) {
		l := New(Config{
			RequestsPerMinute: 60,
			BurstSize:         1,
			KeyFunc: func(r *http.Request) string {
				return "" // No key
			},
		})
		defer l.Stop()

		middleware := l.Middleware()
		wrapped := middleware(handler)

		// Multiple requests should all succeed even though BurstSize is 1.
		for i := 0; i < 5; i++ {
			req := httptest.NewRequest("GET", "/test", nil)
			rec := httptest.NewRecorder()
			wrapped.ServeHTTP(rec, req)

			if rec.Code != http.StatusOK {
				t.Errorf("Request %d status = %d, want 200", i, rec.Code)
			}
		}
	})
}

// TestGetClientIP is a table test of getClientIP across RemoteAddr,
// X-Forwarded-For and X-Real-IP combinations.
func TestGetClientIP(t *testing.T) {
	tests := []struct {
		name       string
		remoteAddr string
		xff        string
		xri        string
		want       string
	}{
		{"from RemoteAddr", "192.168.1.1:12345", "", "", "192.168.1.1"},
		{"from X-Forwarded-For single", "127.0.0.1:8080", "10.0.0.1", "", "10.0.0.1"},
		{"from X-Forwarded-For multiple", "127.0.0.1:8080", "10.0.0.1, 10.0.0.2", "", "10.0.0.1"},
		{"from X-Real-IP", "127.0.0.1:8080", "", "10.0.0.5", "10.0.0.5"},
		{"X-Forwarded-For takes precedence", "127.0.0.1:8080", "10.0.0.1", "10.0.0.5", "10.0.0.1"},
		{"no port in RemoteAddr", "192.168.1.1", "", "", "192.168.1.1"},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			req := httptest.NewRequest("GET", "/", nil)
			req.RemoteAddr = tt.remoteAddr
			if tt.xff != "" {
				req.Header.Set("X-Forwarded-For", tt.xff)
			}
			if tt.xri != "" {
				req.Header.Set("X-Real-IP", tt.xri)
			}

			got := getClientIP(req)
			if got != tt.want {
				t.Errorf("getClientIP() = %q, want %q", got, tt.want)
			}
		})
	}
}

// TestKeyFromAPIKey checks two cases:
// - the key function uses the API key ID stored in the request context;
// - it falls back to the client IP when no API key is present.
func TestKeyFromAPIKey(t *testing.T) {
	keyFunc := KeyFromAPIKey()

	t.Run("extracts from context", func(t *testing.T) {
		req := httptest.NewRequest("GET", "/", nil)
		apiKey := &auth.APIKey{ID: "test-key-123"}
		ctx := auth.WithAPIKey(req.Context(), apiKey)
		req = req.WithContext(ctx)

		got := keyFunc(req)
		if got != "test-key-123" {
			t.Errorf("KeyFromAPIKey() = %q, want test-key-123", got)
		}
	})

	t.Run("falls back to IP", func(t *testing.T) {
		req := httptest.NewRequest("GET", "/", nil)
		req.RemoteAddr = "192.168.1.100:12345"

		got := keyFunc(req)
		if got != "192.168.1.100" {
			t.Errorf("KeyFromAPIKey() = %q, want 192.168.1.100", got)
		}
	})
}

// TestKeyFromIP checks that the IP key function strips the port from
// RemoteAddr.
func TestKeyFromIP(t *testing.T) {
	keyFunc := KeyFromIP()

	req := httptest.NewRequest("GET", "/", nil)
	req.RemoteAddr = "10.0.0.50:12345"

	got := keyFunc(req)
	if got != "10.0.0.50" {
		t.Errorf("KeyFromIP() = %q, want 10.0.0.50", got)
	}
}

// TestItoa checks the package's integer formatting helper, including zero
// and negative values.
func TestItoa(t *testing.T) {
	tests := []struct {
		input int
		want  string
	}{
		{0, "0"},
		{1, "1"},
		{10, "10"},
		{100, "100"},
		{12345, "12345"},
		{-1, "-1"},
		{-12345, "-12345"},
	}

	for _, tt := range tests {
		t.Run(tt.want, func(t *testing.T) {
			got := itoa(tt.input)
			if got != tt.want {
				t.Errorf("itoa(%d) = %q, want %q", tt.input, got, tt.want)
			}
		})
	}
}

// TestConcurrentAccess hammers a single key from many goroutines.
// It checks that roughly BurstSize requests are admitted. Run with -race
// to also check the limiter's internal synchronization.
func TestConcurrentAccess(t *testing.T) {
	const goroutines = 200

	l := New(Config{
		RequestsPerMinute: 1000,
		BurstSize:         100,
	})
	defer l.Stop()

	var wg sync.WaitGroup
	var allowedCount, deniedCount int64
	var mu sync.Mutex

	// Spawn many goroutines making requests.
	for i := 0; i < goroutines; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			_, allowed := l.Allow("concurrent-test")
			mu.Lock()
			if allowed {
				allowedCount++
			} else {
				deniedCount++
			}
			mu.Unlock()
		}()
	}

	wg.Wait()

	// Every goroutine must have been counted exactly once.
	if total := allowedCount + deniedCount; total != goroutines {
		t.Errorf("allowedCount+deniedCount = %d, want %d", total, goroutines)
	}

	// Should have allowed approximately BurstSize requests and denied the
	// rest. The tolerance absorbs tokens refilled while the goroutines run.
	if allowedCount < 90 || allowedCount > 110 {
		t.Errorf("allowedCount = %d, want ~100", allowedCount)
	}
	if deniedCount < 90 || deniedCount > 110 {
		t.Errorf("deniedCount = %d, want ~100", deniedCount)
	}
}

// TestCleanup checks that the background cleanup removes buckets that
// have not been used since they were created.
func TestCleanup(t *testing.T) {
	l := New(Config{
		RequestsPerMinute: 60,
		BurstSize:         10,
		CleanupInterval:   50 * time.Millisecond,
	})
	defer l.Stop()

	// Make some requests to create buckets.
	l.Allow("key1")
	l.Allow("key2")

	l.mu.RLock()
	bucketCount := len(l.buckets)
	l.mu.RUnlock()

	if bucketCount != 2 {
		t.Errorf("bucketCount = %d, want 2", bucketCount)
	}

	// Wait three cleanup intervals (150ms at a 50ms interval) so that
	// several cleanup passes have run. NOTE(review): this assumes the
	// cleanup pass evicts buckets idle for about one interval; confirm
	// against the eviction rule in the implementation if this becomes
	// flaky.
	time.Sleep(150 * time.Millisecond)

	l.mu.RLock()
	bucketCount = len(l.buckets)
	l.mu.RUnlock()

	// Buckets should be cleaned up.
	if bucketCount != 0 {
		t.Errorf("bucketCount after cleanup = %d, want 0", bucketCount)
	}
}

// TestStop checks that stopping a freshly created limiter does not panic.
//
// Calling Stop a second time is deliberately not exercised here. Whether
// repeated Stop calls are safe depends on how the implementation signals
// shutdown, and this test does not assert it.
func TestStop(t *testing.T) {
	l := New(Config{})

	// Should not panic when stopping.
	l.Stop()
}