rdev/internal/middleware/rate_limit_test.go
jordan 72d16929ca feat: Implement hexagonal architecture with services, webhooks, queue, and telemetry
Major refactoring to hexagonal (ports & adapters) architecture:

- Add service layer (apikey_service, project_service) for business logic
- Add webhook system with dispatcher and delivery tracking
- Add command queue with priority-based processing
- Add rate limiting with sliding window algorithm
- Add audit logging for command execution
- Add OpenTelemetry integration (traces, metrics, spans)
- Add circuit breaker for fault tolerance
- Add cached repository wrapper for performance
- Add comprehensive validation package
- Add Kubernetes client integration for pod management
- Add database migrations (allowed_ips, audit_log, rate_limiting, queue, webhooks)
- Add network policy and PodDisruptionBudget for k8s
- Remove legacy executor and projects/registry packages
- Untrack secrets.yaml (now managed via envault)
- Add coverage.out to .gitignore
- Add e2e test infrastructure with docker-compose
- Add comprehensive documentation (API, architecture, operations, plans)
- Add golangci-lint config and pre-commit hook

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-25 19:57:46 -07:00

320 lines
8.4 KiB
Go

package middleware
import (
"context"
"errors"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/orchard9/rdev/internal/auth"
"github.com/orchard9/rdev/internal/domain"
)
// mockRateLimiter is a hand-rolled test double that implements
// port.RateLimiter for testing.
//
// Tests configure canned behavior through result, checkErr and recordErr,
// and inspect the call counters afterwards. The counters are plain ints
// incremented without synchronization, so the double is not safe for
// concurrent use.
type mockRateLimiter struct {
	// Canned responses.
	result    *domain.RateLimitResult // returned by CheckLimit when non-nil
	checkErr  error                   // when non-nil, CheckLimit fails with it
	recordErr error                   // returned as-is by RecordRequest

	// Call counters.
	checkCalls  int // number of CheckLimit invocations
	recordCalls int // number of RecordRequest invocations
}
// CheckLimit records the call and then answers, in priority order, with the
// configured checkErr, the configured result, or a default "allowed" result
// reporting 50 of 60 per-minute and 900 of 1000 per-hour requests remaining.
func (m *mockRateLimiter) CheckLimit(_ context.Context, _ string) (*domain.RateLimitResult, error) {
	m.checkCalls++

	switch {
	case m.checkErr != nil:
		return nil, m.checkErr
	case m.result != nil:
		return m.result, nil
	}

	// Nothing configured: report an allowed request with quota to spare.
	resetMinute := time.Now().Add(time.Minute)
	resetHour := time.Now().Add(time.Hour)
	return &domain.RateLimitResult{
		Allowed:         true,
		LimitMinute:     60,
		RemainingMinute: 50,
		ResetMinute:     resetMinute,
		LimitHour:       1000,
		RemainingHour:   900,
		ResetHour:       resetHour,
	}, nil
}
// RecordRequest counts the call and returns the configured recordErr,
// which is nil unless a test set it.
func (m *mockRateLimiter) RecordRequest(_ context.Context, _ string) error {
	err := m.recordErr
	m.recordCalls++
	return err
}
// GetLimits ignores its arguments and always reports a fixed quota of
// 60 requests per minute and 1000 per hour.
func (m *mockRateLimiter) GetLimits(_ context.Context, _ string) (*domain.RateLimitConfig, error) {
	limits := domain.RateLimitConfig{PerMinute: 60, PerHour: 1000}
	return &limits, nil
}
// Cleanup satisfies the interface as a no-op and always reports success.
func (m *mockRateLimiter) Cleanup(context.Context) error { return nil }
// TestRateLimitMiddleware_AllowedRequest verifies that a request under quota
// reaches the wrapped handler, receives the per-minute X-RateLimit-* headers,
// and causes exactly one CheckLimit and one RecordRequest call.
func TestRateLimitMiddleware_AllowedRequest(t *testing.T) {
	limiter := &mockRateLimiter{}
	cfg := RateLimitConfig{
		Limiter:   limiter,
		SkipPaths: make(map[string]bool),
	}
	handlerCalled := false
	handler := RateLimitMiddleware(cfg)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		handlerCalled = true
		w.WriteHeader(http.StatusOK)
	}))

	// Create request with API key context
	req := httptest.NewRequest(http.MethodGet, "/api/test", nil)
	ctx := auth.WithAPIKey(req.Context(), &auth.APIKey{ID: "test-key"})
	req = req.WithContext(ctx)
	w := httptest.NewRecorder()

	handler.ServeHTTP(w, req)

	// httptest.ResponseRecorder defaults Code to 200, so the status check
	// alone cannot prove the wrapped handler actually ran.
	if !handlerCalled {
		t.Error("expected wrapped handler to be called for an allowed request")
	}
	if w.Code != http.StatusOK {
		t.Errorf("expected status %d, got %d", http.StatusOK, w.Code)
	}

	// Verify rate limit headers are set
	for _, name := range []string{"X-RateLimit-Limit", "X-RateLimit-Remaining", "X-RateLimit-Reset"} {
		if w.Header().Get(name) == "" {
			t.Errorf("expected %s header to be set", name)
		}
	}

	// Each limiter operation should run exactly once per request. Only the
	// counts are checked; the mock does not record call order.
	if limiter.recordCalls != 1 {
		t.Errorf("expected RecordRequest to be called 1 time, got %d", limiter.recordCalls)
	}
	if limiter.checkCalls != 1 {
		t.Errorf("expected CheckLimit to be called 1 time, got %d", limiter.checkCalls)
	}
}
func TestRateLimitMiddleware_RateLimitExceeded(t *testing.T) {
limiter := &mockRateLimiter{
result: &domain.RateLimitResult{
Allowed: false,
RetryAfter: 5 * time.Second,
RemainingMinute: 0,
RemainingHour: 0,
LimitMinute: 60,
LimitHour: 1000,
ResetMinute: time.Now().Add(time.Minute),
ResetHour: time.Now().Add(time.Hour),
},
}
cfg := RateLimitConfig{
Limiter: limiter,
SkipPaths: make(map[string]bool),
}
handler := RateLimitMiddleware(cfg)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Error("handler should not be called when rate limit exceeded")
}))
req := httptest.NewRequest(http.MethodGet, "/api/test", nil)
ctx := auth.WithAPIKey(req.Context(), &auth.APIKey{ID: "test-key"})
req = req.WithContext(ctx)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code != http.StatusTooManyRequests {
t.Errorf("expected status %d, got %d", http.StatusTooManyRequests, w.Code)
}
if w.Header().Get("Retry-After") == "" {
t.Error("expected Retry-After header to be set")
}
}
// TestRateLimitMiddleware_SkipPaths verifies that a request to a path listed
// in SkipPaths reaches the wrapped handler without the limiter being consulted.
//
// The request deliberately carries a regular (non-admin) API key. Without one,
// the separate no-API-key pass-through would also let it through, and the test
// could not tell the two bypasses apart.
func TestRateLimitMiddleware_SkipPaths(t *testing.T) {
	limiter := &mockRateLimiter{}
	cfg := RateLimitConfig{
		Limiter: limiter,
		SkipPaths: map[string]bool{
			"/health": true,
		},
	}
	handlerCalled := false
	handler := RateLimitMiddleware(cfg)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		handlerCalled = true
		w.WriteHeader(http.StatusOK)
	}))

	req := httptest.NewRequest(http.MethodGet, "/health", nil)
	ctx := auth.WithAPIKey(req.Context(), &auth.APIKey{ID: "test-key"})
	req = req.WithContext(ctx)
	w := httptest.NewRecorder()

	handler.ServeHTTP(w, req)

	// httptest.ResponseRecorder defaults Code to 200, so the status check
	// alone cannot prove the wrapped handler actually ran.
	if !handlerCalled {
		t.Error("expected wrapped handler to be called for skipped path")
	}
	if w.Code != http.StatusOK {
		t.Errorf("expected status %d, got %d", http.StatusOK, w.Code)
	}

	// Rate limiter should not be called for skipped paths
	if limiter.recordCalls != 0 {
		t.Errorf("expected RecordRequest to not be called for skipped path, got %d calls", limiter.recordCalls)
	}
	if limiter.checkCalls != 0 {
		t.Errorf("expected CheckLimit to not be called for skipped path, got %d calls", limiter.checkCalls)
	}
}
// TestRateLimitMiddleware_NoAPIKey verifies that a request with no API key in
// its context passes straight through to the wrapped handler and that
// neither CheckLimit nor RecordRequest is invoked.
func TestRateLimitMiddleware_NoAPIKey(t *testing.T) {
	limiter := &mockRateLimiter{}
	cfg := RateLimitConfig{
		Limiter:   limiter,
		SkipPaths: make(map[string]bool),
	}
	handlerCalled := false
	handler := RateLimitMiddleware(cfg)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		handlerCalled = true
		w.WriteHeader(http.StatusOK)
	}))

	// Request without API key context
	req := httptest.NewRequest(http.MethodGet, "/api/test", nil)
	w := httptest.NewRecorder()

	handler.ServeHTTP(w, req)

	// Should pass through without rate limiting. httptest.ResponseRecorder
	// defaults Code to 200, so also confirm the handler actually ran.
	if !handlerCalled {
		t.Error("expected wrapped handler to be called without API key")
	}
	if w.Code != http.StatusOK {
		t.Errorf("expected status %d, got %d", http.StatusOK, w.Code)
	}

	// Rate limiter should not be called
	if limiter.recordCalls != 0 {
		t.Errorf("expected RecordRequest to not be called without API key, got %d calls", limiter.recordCalls)
	}
	if limiter.checkCalls != 0 {
		t.Errorf("expected CheckLimit to not be called without API key, got %d calls", limiter.checkCalls)
	}
}
// TestRateLimitMiddleware_AdminKeyBypass verifies that a request authenticated
// with an API key whose ID is "admin" reaches the wrapped handler and that
// neither CheckLimit nor RecordRequest is invoked.
func TestRateLimitMiddleware_AdminKeyBypass(t *testing.T) {
	limiter := &mockRateLimiter{}
	cfg := RateLimitConfig{
		Limiter:   limiter,
		SkipPaths: make(map[string]bool),
	}
	handlerCalled := false
	handler := RateLimitMiddleware(cfg)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		handlerCalled = true
		w.WriteHeader(http.StatusOK)
	}))

	req := httptest.NewRequest(http.MethodGet, "/api/test", nil)
	ctx := auth.WithAPIKey(req.Context(), &auth.APIKey{ID: "admin"})
	req = req.WithContext(ctx)
	w := httptest.NewRecorder()

	handler.ServeHTTP(w, req)

	// httptest.ResponseRecorder defaults Code to 200, so the status check
	// alone cannot prove the wrapped handler actually ran.
	if !handlerCalled {
		t.Error("expected wrapped handler to be called for admin")
	}
	if w.Code != http.StatusOK {
		t.Errorf("expected status %d, got %d", http.StatusOK, w.Code)
	}

	// Rate limiter should not be called for admin
	if limiter.recordCalls != 0 {
		t.Errorf("expected RecordRequest to not be called for admin, got %d calls", limiter.recordCalls)
	}
	if limiter.checkCalls != 0 {
		t.Errorf("expected CheckLimit to not be called for admin, got %d calls", limiter.checkCalls)
	}
}
// TestRateLimitMiddleware_RecordError verifies that the middleware fails open
// when RecordRequest returns an error: the request still reaches the wrapped
// handler with 200 OK.
func TestRateLimitMiddleware_RecordError(t *testing.T) {
	limiter := &mockRateLimiter{
		recordErr: errors.New("record error"),
	}
	cfg := RateLimitConfig{
		Limiter:   limiter,
		SkipPaths: make(map[string]bool),
	}
	handlerCalled := false
	handler := RateLimitMiddleware(cfg)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		handlerCalled = true
		w.WriteHeader(http.StatusOK)
	}))

	req := httptest.NewRequest(http.MethodGet, "/api/test", nil)
	ctx := auth.WithAPIKey(req.Context(), &auth.APIKey{ID: "test-key"})
	req = req.WithContext(ctx)
	w := httptest.NewRecorder()

	handler.ServeHTTP(w, req)

	// Confirm the failing RecordRequest was actually reached; otherwise the
	// fail-open assertions below would pass without exercising the error path.
	if limiter.recordCalls != 1 {
		t.Errorf("expected RecordRequest to be called 1 time, got %d", limiter.recordCalls)
	}

	// Should fail open on error. httptest.ResponseRecorder defaults Code to
	// 200, so also confirm the handler actually ran.
	if !handlerCalled {
		t.Error("expected wrapped handler to be called when RecordRequest fails")
	}
	if w.Code != http.StatusOK {
		t.Errorf("expected fail-open with status %d, got %d", http.StatusOK, w.Code)
	}
}
// TestRateLimitMiddleware_CheckError verifies that the middleware fails open
// when CheckLimit returns an error: the request still reaches the wrapped
// handler with 200 OK.
func TestRateLimitMiddleware_CheckError(t *testing.T) {
	limiter := &mockRateLimiter{
		checkErr: errors.New("check error"),
	}
	cfg := RateLimitConfig{
		Limiter:   limiter,
		SkipPaths: make(map[string]bool),
	}
	handlerCalled := false
	handler := RateLimitMiddleware(cfg)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		handlerCalled = true
		w.WriteHeader(http.StatusOK)
	}))

	req := httptest.NewRequest(http.MethodGet, "/api/test", nil)
	ctx := auth.WithAPIKey(req.Context(), &auth.APIKey{ID: "test-key"})
	req = req.WithContext(ctx)
	w := httptest.NewRecorder()

	handler.ServeHTTP(w, req)

	// Confirm the failing CheckLimit was actually reached; otherwise the
	// fail-open assertions below would pass without exercising the error path.
	if limiter.checkCalls != 1 {
		t.Errorf("expected CheckLimit to be called 1 time, got %d", limiter.checkCalls)
	}

	// Should fail open on error. httptest.ResponseRecorder defaults Code to
	// 200, so also confirm the handler actually ran.
	if !handlerCalled {
		t.Error("expected wrapped handler to be called when CheckLimit fails")
	}
	if w.Code != http.StatusOK {
		t.Errorf("expected fail-open with status %d, got %d", http.StatusOK, w.Code)
	}
}
// TestDefaultRateLimitConfig asserts that DefaultRateLimitConfig exempts
// /health, /ready, /docs, /openapi.json and /metrics from rate limiting.
func TestDefaultRateLimitConfig(t *testing.T) {
	skip := DefaultRateLimitConfig().SkipPaths

	for _, path := range [...]string{"/health", "/ready", "/docs", "/openapi.json", "/metrics"} {
		if skipped := skip[path]; !skipped {
			t.Errorf("expected %s to be in SkipPaths", path)
		}
	}
}
func TestSetRateLimitHeaders(t *testing.T) {
w := httptest.NewRecorder()
result := &domain.RateLimitResult{
Allowed: true,
RemainingMinute: 50,
RemainingHour: 900,
LimitMinute: 60,
LimitHour: 1000,
ResetMinute: time.Now().Add(time.Minute),
ResetHour: time.Now().Add(time.Hour),
}
setRateLimitHeaders(w, result)
tests := []struct {
header string
want bool
}{
{"X-RateLimit-Limit", true},
{"X-RateLimit-Remaining", true},
{"X-RateLimit-Reset", true},
{"X-RateLimit-Limit-Hour", true},
{"X-RateLimit-Remaining-Hour", true},
{"X-RateLimit-Reset-Hour", true},
}
for _, tt := range tests {
if (w.Header().Get(tt.header) != "") != tt.want {
t.Errorf("header %s: got %q, want present=%v", tt.header, w.Header().Get(tt.header), tt.want)
}
}
}