Major refactoring to hexagonal (ports & adapters) architecture: - Add service layer (apikey_service, project_service) for business logic - Add webhook system with dispatcher and delivery tracking - Add command queue with priority-based processing - Add rate limiting with sliding window algorithm - Add audit logging for command execution - Add OpenTelemetry integration (traces, metrics, spans) - Add circuit breaker for fault tolerance - Add cached repository wrapper for performance - Add comprehensive validation package - Add Kubernetes client integration for pod management - Add database migrations (allowed_ips, audit_log, rate_limiting, queue, webhooks) - Add network policy and PodDisruptionBudget for k8s - Remove legacy executor and projects/registry packages - Untrack secrets.yaml (now managed via envault) - Add coverage.out to .gitignore - Add e2e test infrastructure with docker-compose - Add comprehensive documentation (API, architecture, operations, plans) - Add golangci-lint config and pre-commit hook Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
320 lines
8.4 KiB
Go
320 lines
8.4 KiB
Go
package middleware
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/orchard9/rdev/internal/auth"
|
|
"github.com/orchard9/rdev/internal/domain"
|
|
)
|
|
|
|
// mockRateLimiter implements port.RateLimiter for testing.
|
|
type mockRateLimiter struct {
|
|
result *domain.RateLimitResult
|
|
checkErr error
|
|
recordErr error
|
|
recordCalls int
|
|
checkCalls int
|
|
}
|
|
|
|
// CheckLimit counts the call, then returns (in priority order) checkErr if
// set, the canned result if set, or a default result that allows the request
// with limits of 60/minute and 1000/hour.
func (m *mockRateLimiter) CheckLimit(_ context.Context, _ string) (*domain.RateLimitResult, error) {
	m.checkCalls++

	switch {
	case m.checkErr != nil:
		return nil, m.checkErr
	case m.result != nil:
		return m.result, nil
	}

	// No canned response configured: report the request as allowed.
	allowed := &domain.RateLimitResult{
		Allowed:         true,
		RemainingMinute: 50,
		RemainingHour:   900,
		LimitMinute:     60,
		LimitHour:       1000,
		ResetMinute:     time.Now().Add(time.Minute),
		ResetHour:       time.Now().Add(time.Hour),
	}
	return allowed, nil
}
|
|
|
|
// RecordRequest counts the call and returns the configured recordErr
// (nil unless a test sets it).
func (m *mockRateLimiter) RecordRequest(_ context.Context, _ string) error {
	m.recordCalls++
	return m.recordErr
}
|
|
|
|
// GetLimits reports fixed limits of 60 requests per minute and 1000 per hour
// for any key, and never fails.
func (m *mockRateLimiter) GetLimits(_ context.Context, _ string) (*domain.RateLimitConfig, error) {
	limits := domain.RateLimitConfig{PerMinute: 60, PerHour: 1000}
	return &limits, nil
}
|
|
|
|
// Cleanup is a no-op in the mock and always returns nil.
func (m *mockRateLimiter) Cleanup(context.Context) error { return nil }
|
|
|
|
// TestRateLimitMiddleware_AllowedRequest verifies that a request under the
// limit reaches the wrapped handler, carries the rate-limit headers, and that
// the limiter's RecordRequest and CheckLimit are each invoked exactly once.
func TestRateLimitMiddleware_AllowedRequest(t *testing.T) {
	limiter := &mockRateLimiter{}
	cfg := RateLimitConfig{
		Limiter:   limiter,
		SkipPaths: make(map[string]bool),
	}

	// httptest.ResponseRecorder starts with Code == 200, so a 200 status alone
	// does not prove the wrapped handler ran; track the invocation explicitly.
	called := false
	handler := RateLimitMiddleware(cfg)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		called = true
		w.WriteHeader(http.StatusOK)
	}))

	// Create request with API key context
	req := httptest.NewRequest(http.MethodGet, "/api/test", nil)
	ctx := auth.WithAPIKey(req.Context(), &auth.APIKey{ID: "test-key"})
	req = req.WithContext(ctx)
	w := httptest.NewRecorder()

	handler.ServeHTTP(w, req)

	if !called {
		t.Error("expected next handler to be called for allowed request")
	}
	if w.Code != http.StatusOK {
		t.Errorf("expected status %d, got %d", http.StatusOK, w.Code)
	}

	// Verify rate limit headers are set
	for _, h := range []string{"X-RateLimit-Limit", "X-RateLimit-Remaining", "X-RateLimit-Reset"} {
		if w.Header().Get(h) == "" {
			t.Errorf("expected %s header to be set", h)
		}
	}

	// Verify the request was both recorded and checked exactly once. The mock
	// only counts calls, so this does not assert their relative order.
	if limiter.recordCalls != 1 {
		t.Errorf("expected RecordRequest to be called 1 time, got %d", limiter.recordCalls)
	}
	if limiter.checkCalls != 1 {
		t.Errorf("expected CheckLimit to be called 1 time, got %d", limiter.checkCalls)
	}
}
|
|
|
|
// TestRateLimitMiddleware_RateLimitExceeded verifies that when the limiter
// reports the request as disallowed, the middleware answers 429 with a
// Retry-After header and never invokes the wrapped handler.
func TestRateLimitMiddleware_RateLimitExceeded(t *testing.T) {
	denied := &domain.RateLimitResult{
		Allowed:         false,
		RetryAfter:      5 * time.Second,
		RemainingMinute: 0,
		RemainingHour:   0,
		LimitMinute:     60,
		LimitHour:       1000,
		ResetMinute:     time.Now().Add(time.Minute),
		ResetHour:       time.Now().Add(time.Hour),
	}
	mw := RateLimitMiddleware(RateLimitConfig{
		Limiter:   &mockRateLimiter{result: denied},
		SkipPaths: make(map[string]bool),
	})

	// Reaching the wrapped handler is itself a failure for this scenario.
	next := http.HandlerFunc(func(http.ResponseWriter, *http.Request) {
		t.Error("handler should not be called when rate limit exceeded")
	})

	req := httptest.NewRequest(http.MethodGet, "/api/test", nil)
	req = req.WithContext(auth.WithAPIKey(req.Context(), &auth.APIKey{ID: "test-key"}))
	rec := httptest.NewRecorder()

	mw(next).ServeHTTP(rec, req)

	if got, want := rec.Code, http.StatusTooManyRequests; got != want {
		t.Errorf("expected status %d, got %d", want, got)
	}
	if rec.Header().Get("Retry-After") == "" {
		t.Error("expected Retry-After header to be set")
	}
}
|
|
|
|
// TestRateLimitMiddleware_SkipPaths verifies that a request to a path listed
// in SkipPaths bypasses the limiter entirely and reaches the wrapped handler,
// even when it carries an API key.
func TestRateLimitMiddleware_SkipPaths(t *testing.T) {
	limiter := &mockRateLimiter{}
	cfg := RateLimitConfig{
		Limiter: limiter,
		SkipPaths: map[string]bool{
			"/health": true,
		},
	}

	// httptest.ResponseRecorder starts with Code == 200, so track the
	// handler invocation explicitly instead of relying on the status.
	called := false
	handler := RateLimitMiddleware(cfg)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		called = true
		w.WriteHeader(http.StatusOK)
	}))

	// Attach an API key so that the skip list is the only reason to bypass
	// the limiter. Requests without an API key are also passed through (see
	// TestRateLimitMiddleware_NoAPIKey), so without one this test would pass
	// even if SkipPaths were ignored.
	req := httptest.NewRequest(http.MethodGet, "/health", nil)
	req = req.WithContext(auth.WithAPIKey(req.Context(), &auth.APIKey{ID: "test-key"}))
	w := httptest.NewRecorder()

	handler.ServeHTTP(w, req)

	if !called {
		t.Error("expected next handler to be called for skipped path")
	}
	if w.Code != http.StatusOK {
		t.Errorf("expected status %d, got %d", http.StatusOK, w.Code)
	}

	// Rate limiter should not be called for skipped paths
	if limiter.recordCalls != 0 {
		t.Errorf("expected RecordRequest to not be called for skipped path, got %d calls", limiter.recordCalls)
	}
	if limiter.checkCalls != 0 {
		t.Errorf("expected CheckLimit to not be called for skipped path, got %d calls", limiter.checkCalls)
	}
}
|
|
|
|
// TestRateLimitMiddleware_NoAPIKey verifies that a request with no API key in
// its context is passed straight to the wrapped handler without consulting
// the limiter.
func TestRateLimitMiddleware_NoAPIKey(t *testing.T) {
	limiter := &mockRateLimiter{}
	cfg := RateLimitConfig{
		Limiter:   limiter,
		SkipPaths: make(map[string]bool),
	}

	// httptest.ResponseRecorder starts with Code == 200, so track the
	// handler invocation explicitly instead of relying on the status.
	called := false
	handler := RateLimitMiddleware(cfg)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		called = true
		w.WriteHeader(http.StatusOK)
	}))

	// Request without API key context
	req := httptest.NewRequest(http.MethodGet, "/api/test", nil)
	w := httptest.NewRecorder()

	handler.ServeHTTP(w, req)

	// Should pass through without rate limiting
	if !called {
		t.Error("expected next handler to be called without API key")
	}
	if w.Code != http.StatusOK {
		t.Errorf("expected status %d, got %d", http.StatusOK, w.Code)
	}

	// Rate limiter should not be called
	if limiter.recordCalls != 0 {
		t.Errorf("expected RecordRequest to not be called without API key, got %d calls", limiter.recordCalls)
	}
	if limiter.checkCalls != 0 {
		t.Errorf("expected CheckLimit to not be called without API key, got %d calls", limiter.checkCalls)
	}
}
|
|
|
|
// TestRateLimitMiddleware_AdminKeyBypass verifies that a request authenticated
// with the API key whose ID is "admin" bypasses the limiter and reaches the
// wrapped handler.
func TestRateLimitMiddleware_AdminKeyBypass(t *testing.T) {
	limiter := &mockRateLimiter{}
	cfg := RateLimitConfig{
		Limiter:   limiter,
		SkipPaths: make(map[string]bool),
	}

	// httptest.ResponseRecorder starts with Code == 200, so track the
	// handler invocation explicitly instead of relying on the status.
	called := false
	handler := RateLimitMiddleware(cfg)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		called = true
		w.WriteHeader(http.StatusOK)
	}))

	req := httptest.NewRequest(http.MethodGet, "/api/test", nil)
	ctx := auth.WithAPIKey(req.Context(), &auth.APIKey{ID: "admin"})
	req = req.WithContext(ctx)
	w := httptest.NewRecorder()

	handler.ServeHTTP(w, req)

	if !called {
		t.Error("expected next handler to be called for admin")
	}
	if w.Code != http.StatusOK {
		t.Errorf("expected status %d, got %d", http.StatusOK, w.Code)
	}

	// Rate limiter should not be called for admin
	if limiter.recordCalls != 0 {
		t.Errorf("expected RecordRequest to not be called for admin, got %d calls", limiter.recordCalls)
	}
	if limiter.checkCalls != 0 {
		t.Errorf("expected CheckLimit to not be called for admin, got %d calls", limiter.checkCalls)
	}
}
|
|
|
|
// TestRateLimitMiddleware_RecordError verifies that the middleware fails open
// when RecordRequest returns an error: the request is still served by the
// wrapped handler.
func TestRateLimitMiddleware_RecordError(t *testing.T) {
	limiter := &mockRateLimiter{
		recordErr: errors.New("record error"),
	}
	cfg := RateLimitConfig{
		Limiter:   limiter,
		SkipPaths: make(map[string]bool),
	}

	// httptest.ResponseRecorder starts with Code == 200, so a 200 status would
	// also be seen if the middleware dropped the request without writing.
	// Track the handler invocation to actually prove the fail-open behavior.
	called := false
	handler := RateLimitMiddleware(cfg)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		called = true
		w.WriteHeader(http.StatusOK)
	}))

	req := httptest.NewRequest(http.MethodGet, "/api/test", nil)
	ctx := auth.WithAPIKey(req.Context(), &auth.APIKey{ID: "test-key"})
	req = req.WithContext(ctx)
	w := httptest.NewRecorder()

	handler.ServeHTTP(w, req)

	// Make sure the error path was actually exercised.
	if limiter.recordCalls != 1 {
		t.Errorf("expected RecordRequest to be called 1 time, got %d", limiter.recordCalls)
	}

	// Should fail open on error
	if !called {
		t.Error("expected fail-open: next handler was not called")
	}
	if w.Code != http.StatusOK {
		t.Errorf("expected fail-open with status %d, got %d", http.StatusOK, w.Code)
	}
}
|
|
|
|
// TestRateLimitMiddleware_CheckError verifies that the middleware fails open
// when CheckLimit returns an error: the request is still served by the
// wrapped handler.
func TestRateLimitMiddleware_CheckError(t *testing.T) {
	limiter := &mockRateLimiter{
		checkErr: errors.New("check error"),
	}
	cfg := RateLimitConfig{
		Limiter:   limiter,
		SkipPaths: make(map[string]bool),
	}

	// httptest.ResponseRecorder starts with Code == 200, so a 200 status would
	// also be seen if the middleware dropped the request without writing.
	// Track the handler invocation to actually prove the fail-open behavior.
	called := false
	handler := RateLimitMiddleware(cfg)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		called = true
		w.WriteHeader(http.StatusOK)
	}))

	req := httptest.NewRequest(http.MethodGet, "/api/test", nil)
	ctx := auth.WithAPIKey(req.Context(), &auth.APIKey{ID: "test-key"})
	req = req.WithContext(ctx)
	w := httptest.NewRecorder()

	handler.ServeHTTP(w, req)

	// Make sure the error path was actually exercised.
	if limiter.checkCalls != 1 {
		t.Errorf("expected CheckLimit to be called 1 time, got %d", limiter.checkCalls)
	}

	// Should fail open on error
	if !called {
		t.Error("expected fail-open: next handler was not called")
	}
	if w.Code != http.StatusOK {
		t.Errorf("expected fail-open with status %d, got %d", http.StatusOK, w.Code)
	}
}
|
|
|
|
// TestDefaultRateLimitConfig verifies that the default configuration exempts
// the health, readiness, docs, OpenAPI and metrics endpoints from rate limiting.
func TestDefaultRateLimitConfig(t *testing.T) {
	skip := DefaultRateLimitConfig().SkipPaths

	for _, p := range [...]string{"/health", "/ready", "/docs", "/openapi.json", "/metrics"} {
		if exempt := skip[p]; !exempt {
			t.Errorf("expected %s to be in SkipPaths", p)
		}
	}
}
|
|
|
|
// TestSetRateLimitHeaders verifies that setRateLimitHeaders writes the
// X-RateLimit-{Limit,Remaining,Reset} headers and their "-Hour" variants.
// Only presence is checked, not the formatted values.
func TestSetRateLimitHeaders(t *testing.T) {
	rec := httptest.NewRecorder()
	setRateLimitHeaders(rec, &domain.RateLimitResult{
		Allowed:         true,
		RemainingMinute: 50,
		RemainingHour:   900,
		LimitMinute:     60,
		LimitHour:       1000,
		ResetMinute:     time.Now().Add(time.Minute),
		ResetHour:       time.Now().Add(time.Hour),
	})

	hdr := rec.Header()
	cases := []struct {
		header string
		want   bool
	}{
		{"X-RateLimit-Limit", true},
		{"X-RateLimit-Remaining", true},
		{"X-RateLimit-Reset", true},
		{"X-RateLimit-Limit-Hour", true},
		{"X-RateLimit-Remaining-Hour", true},
		{"X-RateLimit-Reset-Hour", true},
	}

	for _, tc := range cases {
		got := hdr.Get(tc.header)
		if present := got != ""; present != tc.want {
			t.Errorf("header %s: got %q, want present=%v", tc.header, got, tc.want)
		}
	}
}
|