package postgres import ( "context" "crypto/sha256" "database/sql" "encoding/hex" "testing" "time" "github.com/orchard9/rdev/internal/domain" "github.com/orchard9/rdev/internal/testutil" ) // createTestAPIKey creates a test API key and returns its ID. func createTestAPIKey(t *testing.T, db *sql.DB, name string) string { t.Helper() repo := NewAPIKeyRepository(db) ctx := context.Background() key := &domain.APIKey{ Name: "test-ratelimit-" + name, KeyPrefix: "rl123456", Scopes: []domain.Scope{domain.ScopeProjectsRead}, CreatedBy: "test", } h := sha256.Sum256([]byte("ratelimit-key-" + name)) keyHash := hex.EncodeToString(h[:]) err := repo.Create(ctx, key, keyHash) if err != nil { t.Fatalf("create test API key: %v", err) } return string(key.ID) } func cleanupTestRateLimits(t *testing.T, db *sql.DB) { t.Helper() // Clean rate limit state for test keys _, err := db.Exec(` DELETE FROM rate_limit_state WHERE api_key_id IN (SELECT id FROM api_keys WHERE name LIKE 'test-ratelimit-%') `) if err != nil { t.Logf("cleanup test rate limits: %v", err) } // Clean up test API keys testutil.CleanupTestKeys(t, db) } func TestRateLimiter_RecordRequest(t *testing.T) { db := testutil.TestDB(t) t.Cleanup(func() { cleanupTestRateLimits(t, db) }) limiter := NewRateLimiter(db) ctx := context.Background() t.Run("records first request", func(t *testing.T) { keyID := createTestAPIKey(t, db, "record-first") err := limiter.RecordRequest(ctx, keyID) if err != nil { t.Fatalf("RecordRequest() error = %v", err) } // Verify by checking limits result, err := limiter.CheckLimit(ctx, keyID) if err != nil { t.Fatalf("CheckLimit() error = %v", err) } // Should have recorded one request if result.RemainingMinute >= result.LimitMinute { t.Error("RemainingMinute should be less than LimitMinute after recording a request") } }) t.Run("increments existing request count", func(t *testing.T) { keyID := createTestAPIKey(t, db, "record-increment") // Record multiple requests for i := 0; i < 3; i++ { err := limiter.RecordRequest(ctx, keyID) if err != nil { t.Fatalf("RecordRequest() iteration %d error = %v", i, err) } } result, _ := limiter.CheckLimit(ctx, keyID) expectedRemaining := result.LimitMinute - 3 if result.RemainingMinute != expectedRemaining { t.Errorf("RemainingMinute = %d, want %d", result.RemainingMinute, expectedRemaining) } }) } func TestRateLimiter_CheckLimit(t *testing.T) { db := testutil.TestDB(t) t.Cleanup(func() { cleanupTestRateLimits(t, db) }) limiter := NewRateLimiter(db) ctx := context.Background() t.Run("allows request when under limit", func(t *testing.T) { keyID := createTestAPIKey(t, db, "check-under") result, err := limiter.CheckLimit(ctx, keyID) if err != nil { t.Fatalf("CheckLimit() error = %v", err) } if !result.Allowed { t.Error("CheckLimit() should allow request when under limit") } if result.RemainingMinute <= 0 { t.Error("RemainingMinute should be positive") } if result.RemainingHour <= 0 { t.Error("RemainingHour should be positive") } }) t.Run("denies request when minute limit exceeded", func(t *testing.T) { keyID := createTestAPIKey(t, db, "check-minute-exceeded") // Get the limit limits, _ := limiter.GetLimits(ctx, keyID) // Record enough requests to exceed minute limit for i := 0; i < limits.PerMinute; i++ { _ = limiter.RecordRequest(ctx, keyID) } result, err := limiter.CheckLimit(ctx, keyID) if err != nil { t.Fatalf("CheckLimit() error = %v", err) } if result.Allowed { t.Error("CheckLimit() should deny request when minute limit exceeded") } if result.RetryAfter <= 0 { t.Error("RetryAfter should be positive when denied") } if result.RemainingMinute != 0 { t.Errorf("RemainingMinute = %d, want 0", result.RemainingMinute) } }) t.Run("returns correct reset times", func(t *testing.T) { keyID := createTestAPIKey(t, db, "check-reset") result, err := limiter.CheckLimit(ctx, keyID) if err != nil { t.Fatalf("CheckLimit() error = %v", err) } now := time.Now() // ResetMinute should be within the next minute if result.ResetMinute.Before(now) { t.Error("ResetMinute should be in the future") } if result.ResetMinute.After(now.Add(time.Minute + time.Second)) { t.Error("ResetMinute should be within ~1 minute from now") } // ResetHour should be within the next hour if result.ResetHour.Before(now) { t.Error("ResetHour should be in the future") } if result.ResetHour.After(now.Add(time.Hour + time.Second)) { t.Error("ResetHour should be within ~1 hour from now") } }) } func TestRateLimiter_GetLimits(t *testing.T) { db := testutil.TestDB(t) t.Cleanup(func() { cleanupTestRateLimits(t, db) }) limiter := NewRateLimiter(db) ctx := context.Background() t.Run("returns default limits for unknown key", func(t *testing.T) { // Use a UUID that doesn't exist limits, err := limiter.GetLimits(ctx, "00000000-0000-0000-0000-000000000000") if err != nil { t.Fatalf("GetLimits() error = %v", err) } defaults := domain.DefaultRateLimitConfig() if limits.PerMinute != defaults.PerMinute { t.Errorf("PerMinute = %d, want %d", limits.PerMinute, defaults.PerMinute) } if limits.PerHour != defaults.PerHour { t.Errorf("PerHour = %d, want %d", limits.PerHour, defaults.PerHour) } }) t.Run("returns limits from existing key", func(t *testing.T) { keyID := createTestAPIKey(t, db, "get-limits") limits, err := limiter.GetLimits(ctx, keyID) if err != nil { t.Fatalf("GetLimits() error = %v", err) } // Should return defaults since we didn't set custom limits defaults := domain.DefaultRateLimitConfig() if limits.PerMinute != defaults.PerMinute { t.Errorf("PerMinute = %d, want %d", limits.PerMinute, defaults.PerMinute) } }) } func TestRateLimiter_Cleanup(t *testing.T) { db := testutil.TestDB(t) t.Cleanup(func() { cleanupTestRateLimits(t, db) }) limiter := NewRateLimiter(db) ctx := context.Background() t.Run("cleanup runs without error", func(t *testing.T) { // Create some rate limit entries keyID := createTestAPIKey(t, db, "cleanup-entry") _ = limiter.RecordRequest(ctx, keyID) err := limiter.Cleanup(ctx) if err != nil { t.Fatalf("Cleanup() error = %v", err) } // Recent entries should not be deleted result, _ := limiter.CheckLimit(ctx, keyID) if result.RemainingMinute >= result.LimitMinute { // If the entry was cleaned up, remaining would equal limit t.Log("Note: Recent rate limit entry was not cleaned up (expected behavior)") } }) } func TestRateLimiter_WindowHandling(t *testing.T) { db := testutil.TestDB(t) t.Cleanup(func() { cleanupTestRateLimits(t, db) }) limiter := NewRateLimiter(db) ctx := context.Background() t.Run("minute and hour windows are tracked separately", func(t *testing.T) { keyID := createTestAPIKey(t, db, "windows") // Record a request _ = limiter.RecordRequest(ctx, keyID) result, err := limiter.CheckLimit(ctx, keyID) if err != nil { t.Fatalf("CheckLimit() error = %v", err) } // Both counters should reflect one request expectedMinuteRemaining := result.LimitMinute - 1 expectedHourRemaining := result.LimitHour - 1 if result.RemainingMinute != expectedMinuteRemaining { t.Errorf("RemainingMinute = %d, want %d", result.RemainingMinute, expectedMinuteRemaining) } if result.RemainingHour != expectedHourRemaining { t.Errorf("RemainingHour = %d, want %d", result.RemainingHour, expectedHourRemaining) } }) } func TestRateLimiter_ConcurrentRequests(t *testing.T) { db := testutil.TestDB(t) t.Cleanup(func() { cleanupTestRateLimits(t, db) }) limiter := NewRateLimiter(db) ctx := context.Background() keyID := createTestAPIKey(t, db, "concurrent") // Run concurrent requests const numRequests = 10 done := make(chan error, numRequests) for i := 0; i < numRequests; i++ { go func() { done <- limiter.RecordRequest(ctx, keyID) }() } // Wait for all requests to complete for i := 0; i < numRequests; i++ { if err := <-done; err != nil { t.Errorf("Concurrent RecordRequest() error = %v", err) } } // Verify the count result, err := limiter.CheckLimit(ctx, keyID) if err != nil { t.Fatalf("CheckLimit() error = %v", err) } expectedRemaining := result.LimitMinute - numRequests if result.RemainingMinute != expectedRemaining { t.Errorf("RemainingMinute = %d, want %d (all concurrent requests should be counted)", result.RemainingMinute, expectedRemaining) } }