package middleware

import (
	"context"
	"errors"
	"net/http"
	"net/http/httptest"
	"testing"
	"time"

	"github.com/orchard9/rdev/internal/auth"
	"github.com/orchard9/rdev/internal/domain"
)

// mockRateLimiter implements port.RateLimiter for testing.
//
// It counts calls to CheckLimit and RecordRequest and returns the canned
// results/errors configured through its fields. It is not safe for
// concurrent use; each test creates its own instance.
type mockRateLimiter struct {
	result      *domain.RateLimitResult // returned by CheckLimit when non-nil and checkErr is nil
	checkErr    error                   // returned by CheckLimit when non-nil
	recordErr   error                   // returned by RecordRequest
	recordCalls int                     // number of RecordRequest calls
	checkCalls  int                     // number of CheckLimit calls
}

// CheckLimit counts the call and returns checkErr if set, otherwise result if
// set, otherwise an "allowed" result with remaining capacity in both windows.
func (m *mockRateLimiter) CheckLimit(ctx context.Context, apiKeyID string) (*domain.RateLimitResult, error) {
	m.checkCalls++
	if m.checkErr != nil {
		return nil, m.checkErr
	}
	if m.result != nil {
		return m.result, nil
	}
	// Default: allowed
	now := time.Now()
	return &domain.RateLimitResult{
		Allowed:         true,
		RemainingMinute: 50,
		RemainingHour:   900,
		LimitMinute:     60,
		LimitHour:       1000,
		ResetMinute:     now.Add(time.Minute),
		ResetHour:       now.Add(time.Hour),
	}, nil
}

// RecordRequest counts the call and returns recordErr (nil unless configured).
func (m *mockRateLimiter) RecordRequest(ctx context.Context, apiKeyID string) error {
	m.recordCalls++
	return m.recordErr
}

// GetLimits returns a fixed 60/minute, 1000/hour configuration.
func (m *mockRateLimiter) GetLimits(ctx context.Context, apiKeyID string) (*domain.RateLimitConfig, error) {
	return &domain.RateLimitConfig{
		PerMinute: 60,
		PerHour:   1000,
	}, nil
}

// Cleanup is a no-op.
func (m *mockRateLimiter) Cleanup(ctx context.Context) error {
	return nil
}

// newRateLimitTestConfig returns a RateLimitConfig backed by limiter with no
// skip paths.
func newRateLimitTestConfig(limiter *mockRateLimiter) RateLimitConfig {
	return RateLimitConfig{
		Limiter:   limiter,
		SkipPaths: make(map[string]bool),
	}
}

// newRateLimitTestHandler wraps a terminal handler in RateLimitMiddleware(cfg).
// The terminal handler replies 200 OK and sets *called. Because
// httptest.ResponseRecorder already reports 200 when nothing is written,
// checking *called is the only reliable way to tell that the middleware
// actually forwarded the request.
func newRateLimitTestHandler(cfg RateLimitConfig) (http.Handler, *bool) {
	called := new(bool)
	next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		*called = true
		w.WriteHeader(http.StatusOK)
	})
	return RateLimitMiddleware(cfg)(next), called
}

// newRateLimitTestRequest builds a GET request for path. When apiKeyID is
// non-empty, an *auth.APIKey with that ID is attached to the request context.
func newRateLimitTestRequest(path, apiKeyID string) *http.Request {
	req := httptest.NewRequest(http.MethodGet, path, nil)
	if apiKeyID == "" {
		return req
	}
	ctx := auth.WithAPIKey(req.Context(), &auth.APIKey{ID: apiKeyID})
	return req.WithContext(ctx)
}

// TestRateLimitMiddleware_AllowedRequest checks that an allowed request is
// forwarded, carries rate-limit headers, and consults the limiter once.
func TestRateLimitMiddleware_AllowedRequest(t *testing.T) {
	limiter := &mockRateLimiter{}
	handler, called := newRateLimitTestHandler(newRateLimitTestConfig(limiter))

	w := httptest.NewRecorder()
	handler.ServeHTTP(w, newRateLimitTestRequest("/api/test", "test-key"))

	if !*called {
		t.Error("expected next handler to be called for an allowed request")
	}
	if w.Code != http.StatusOK {
		t.Errorf("expected status %d, got %d", http.StatusOK, w.Code)
	}

	// Verify rate limit headers are set
	for _, header := range []string{"X-RateLimit-Limit", "X-RateLimit-Remaining", "X-RateLimit-Reset"} {
		if w.Header().Get(header) == "" {
			t.Errorf("expected %s header to be set", header)
		}
	}

	// Each limiter method should run exactly once per request. The mock only
	// counts calls, so this does not verify their relative order.
	if limiter.recordCalls != 1 {
		t.Errorf("expected RecordRequest to be called 1 time, got %d", limiter.recordCalls)
	}
	if limiter.checkCalls != 1 {
		t.Errorf("expected CheckLimit to be called 1 time, got %d", limiter.checkCalls)
	}
}

// TestRateLimitMiddleware_RateLimitExceeded checks that a denied request gets
// 429 with Retry-After and never reaches the next handler.
func TestRateLimitMiddleware_RateLimitExceeded(t *testing.T) {
	now := time.Now()
	limiter := &mockRateLimiter{
		result: &domain.RateLimitResult{
			Allowed:         false,
			RetryAfter:      5 * time.Second,
			RemainingMinute: 0,
			RemainingHour:   0,
			LimitMinute:     60,
			LimitHour:       1000,
			ResetMinute:     now.Add(time.Minute),
			ResetHour:       now.Add(time.Hour),
		},
	}
	handler, called := newRateLimitTestHandler(newRateLimitTestConfig(limiter))

	w := httptest.NewRecorder()
	handler.ServeHTTP(w, newRateLimitTestRequest("/api/test", "test-key"))

	if *called {
		t.Error("handler should not be called when rate limit exceeded")
	}
	if w.Code != http.StatusTooManyRequests {
		t.Errorf("expected status %d, got %d", http.StatusTooManyRequests, w.Code)
	}
	if w.Header().Get("Retry-After") == "" {
		t.Error("expected Retry-After header to be set")
	}
}

// TestRateLimitMiddleware_SkipPaths checks that configured skip paths bypass
// the limiter entirely.
func TestRateLimitMiddleware_SkipPaths(t *testing.T) {
	limiter := &mockRateLimiter{}
	cfg := RateLimitConfig{
		Limiter: limiter,
		SkipPaths: map[string]bool{
			"/health": true,
		},
	}
	handler, called := newRateLimitTestHandler(cfg)

	w := httptest.NewRecorder()
	handler.ServeHTTP(w, newRateLimitTestRequest("/health", ""))

	if !*called {
		t.Error("expected next handler to be called for skipped path")
	}
	if w.Code != http.StatusOK {
		t.Errorf("expected status %d, got %d", http.StatusOK, w.Code)
	}

	// Rate limiter should not be called for skipped paths
	if limiter.recordCalls != 0 {
		t.Errorf("expected RecordRequest to not be called for skipped path, got %d calls", limiter.recordCalls)
	}
	if limiter.checkCalls != 0 {
		t.Errorf("expected CheckLimit to not be called for skipped path, got %d calls", limiter.checkCalls)
	}
}

// TestRateLimitMiddleware_NoAPIKey checks that requests without an API key in
// the context pass through without rate limiting.
func TestRateLimitMiddleware_NoAPIKey(t *testing.T) {
	limiter := &mockRateLimiter{}
	handler, called := newRateLimitTestHandler(newRateLimitTestConfig(limiter))

	// Request without API key context
	w := httptest.NewRecorder()
	handler.ServeHTTP(w, newRateLimitTestRequest("/api/test", ""))

	// Should pass through without rate limiting
	if !*called {
		t.Error("expected next handler to be called without API key")
	}
	if w.Code != http.StatusOK {
		t.Errorf("expected status %d, got %d", http.StatusOK, w.Code)
	}

	// Rate limiter should not be called
	if limiter.recordCalls != 0 {
		t.Errorf("expected RecordRequest to not be called without API key, got %d calls", limiter.recordCalls)
	}
	if limiter.checkCalls != 0 {
		t.Errorf("expected CheckLimit to not be called without API key, got %d calls", limiter.checkCalls)
	}
}

// TestRateLimitMiddleware_AdminKeyBypass checks that the key with ID "admin"
// is not rate limited.
func TestRateLimitMiddleware_AdminKeyBypass(t *testing.T) {
	limiter := &mockRateLimiter{}
	handler, called := newRateLimitTestHandler(newRateLimitTestConfig(limiter))

	w := httptest.NewRecorder()
	handler.ServeHTTP(w, newRateLimitTestRequest("/api/test", "admin"))

	if !*called {
		t.Error("expected next handler to be called for admin")
	}
	if w.Code != http.StatusOK {
		t.Errorf("expected status %d, got %d", http.StatusOK, w.Code)
	}

	// Rate limiter should not be called for admin
	if limiter.recordCalls != 0 {
		t.Errorf("expected RecordRequest to not be called for admin, got %d calls", limiter.recordCalls)
	}
	if limiter.checkCalls != 0 {
		t.Errorf("expected CheckLimit to not be called for admin, got %d calls", limiter.checkCalls)
	}
}

// TestRateLimitMiddleware_RecordError checks that a RecordRequest failure
// fails open: the request is still forwarded.
func TestRateLimitMiddleware_RecordError(t *testing.T) {
	limiter := &mockRateLimiter{
		recordErr: errors.New("record error"),
	}
	handler, called := newRateLimitTestHandler(newRateLimitTestConfig(limiter))

	w := httptest.NewRecorder()
	handler.ServeHTTP(w, newRateLimitTestRequest("/api/test", "test-key"))

	// Should fail open on error
	if !*called {
		t.Error("expected fail-open to call next handler on RecordRequest error")
	}
	if w.Code != http.StatusOK {
		t.Errorf("expected fail-open with status %d, got %d", http.StatusOK, w.Code)
	}
}

// TestRateLimitMiddleware_CheckError checks that a CheckLimit failure fails
// open: the request is still forwarded.
func TestRateLimitMiddleware_CheckError(t *testing.T) {
	limiter := &mockRateLimiter{
		checkErr: errors.New("check error"),
	}
	handler, called := newRateLimitTestHandler(newRateLimitTestConfig(limiter))

	w := httptest.NewRecorder()
	handler.ServeHTTP(w, newRateLimitTestRequest("/api/test", "test-key"))

	// Should fail open on error
	if !*called {
		t.Error("expected fail-open to call next handler on CheckLimit error")
	}
	if w.Code != http.StatusOK {
		t.Errorf("expected fail-open with status %d, got %d", http.StatusOK, w.Code)
	}
}

// TestDefaultRateLimitConfig checks the default set of unthrottled paths.
func TestDefaultRateLimitConfig(t *testing.T) {
	cfg := DefaultRateLimitConfig()

	expectedPaths := []string{"/health", "/ready", "/docs", "/openapi.json", "/metrics"}
	for _, path := range expectedPaths {
		if !cfg.SkipPaths[path] {
			t.Errorf("expected %s to be in SkipPaths", path)
		}
	}
}

// TestSetRateLimitHeaders checks that both the per-minute and per-hour header
// families are emitted.
func TestSetRateLimitHeaders(t *testing.T) {
	w := httptest.NewRecorder()
	now := time.Now()
	result := &domain.RateLimitResult{
		Allowed:         true,
		RemainingMinute: 50,
		RemainingHour:   900,
		LimitMinute:     60,
		LimitHour:       1000,
		ResetMinute:     now.Add(time.Minute),
		ResetHour:       now.Add(time.Hour),
	}

	setRateLimitHeaders(w, result)

	tests := []struct {
		header string
		want   bool
	}{
		{"X-RateLimit-Limit", true},
		{"X-RateLimit-Remaining", true},
		{"X-RateLimit-Reset", true},
		{"X-RateLimit-Limit-Hour", true},
		{"X-RateLimit-Remaining-Hour", true},
		{"X-RateLimit-Reset-Hour", true},
	}

	for _, tt := range tests {
		got := w.Header().Get(tt.header)
		if (got != "") != tt.want {
			t.Errorf("header %s: got %q, want present=%v", tt.header, got, tt.want)
		}
	}
}