rdev/internal/adapter/postgres/rate_limiter_test.go
jordan 72d16929ca feat: Implement hexagonal architecture with services, webhooks, queue, and telemetry
Major refactoring to hexagonal (ports & adapters) architecture:

- Add service layer (apikey_service, project_service) for business logic
- Add webhook system with dispatcher and delivery tracking
- Add command queue with priority-based processing
- Add rate limiting with sliding window algorithm
- Add audit logging for command execution
- Add OpenTelemetry integration (traces, metrics, spans)
- Add circuit breaker for fault tolerance
- Add cached repository wrapper for performance
- Add comprehensive validation package
- Add Kubernetes client integration for pod management
- Add database migrations (allowed_ips, audit_log, rate_limiting, queue, webhooks)
- Add network policy and PodDisruptionBudget for k8s
- Remove legacy executor and projects/registry packages
- Untrack secrets.yaml (now managed via envault)
- Add coverage.out to .gitignore
- Add e2e test infrastructure with docker-compose
- Add comprehensive documentation (API, architecture, operations, plans)
- Add golangci-lint config and pre-commit hook

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-25 19:57:46 -07:00

313 lines
8.4 KiB
Go

package postgres
import (
"context"
"crypto/sha256"
"database/sql"
"encoding/hex"
"testing"
"time"
"github.com/orchard9/rdev/internal/domain"
"github.com/orchard9/rdev/internal/testutil"
)
// createTestAPIKey creates a test API key and returns its ID.
func createTestAPIKey(t *testing.T, db *sql.DB, name string) string {
t.Helper()
repo := NewAPIKeyRepository(db)
ctx := context.Background()
key := &domain.APIKey{
Name: "test-ratelimit-" + name,
KeyPrefix: "rl123456",
Scopes: []domain.Scope{domain.ScopeProjectsRead},
CreatedBy: "test",
}
h := sha256.Sum256([]byte("ratelimit-key-" + name))
keyHash := hex.EncodeToString(h[:])
err := repo.Create(ctx, key, keyHash)
if err != nil {
t.Fatalf("create test API key: %v", err)
}
return string(key.ID)
}
// cleanupTestRateLimits is best-effort teardown for the tests in this file.
// It deletes rate_limit_state rows whose api_key_id belongs to an API key
// named "test-ratelimit-%". A failure there is only logged, not fatal. It then
// calls testutil.CleanupTestKeys to remove the test keys themselves.
func cleanupTestRateLimits(t *testing.T, db *sql.DB) {
	t.Helper()

	// The state rows must go first: the subquery locates them through the
	// api_keys rows that CleanupTestKeys is about to remove.
	const deleteRateLimitState = `
DELETE FROM rate_limit_state
WHERE api_key_id IN (SELECT id FROM api_keys WHERE name LIKE 'test-ratelimit-%')
`
	if _, err := db.Exec(deleteRateLimitState); err != nil {
		t.Logf("cleanup test rate limits: %v", err)
	}

	testutil.CleanupTestKeys(t, db)
}
// TestRateLimiter_RecordRequest verifies that every RecordRequest call is
// counted against the per-minute allowance that CheckLimit reports.
func TestRateLimiter_RecordRequest(t *testing.T) {
	db := testutil.TestDB(t)
	t.Cleanup(func() { cleanupTestRateLimits(t, db) })

	limiter := NewRateLimiter(db)
	ctx := context.Background()

	t.Run("records first request", func(t *testing.T) {
		keyID := createTestAPIKey(t, db, "record-first")

		if err := limiter.RecordRequest(ctx, keyID); err != nil {
			t.Fatalf("RecordRequest() error = %v", err)
		}

		// Verify by checking limits. Like the other tests in this file, this
		// assumes CheckLimit reports usage without recording a request itself,
		// so exactly one request should be counted.
		result, err := limiter.CheckLimit(ctx, keyID)
		if err != nil {
			t.Fatalf("CheckLimit() error = %v", err)
		}
		if want := result.LimitMinute - 1; result.RemainingMinute != want {
			t.Errorf("RemainingMinute = %d, want %d", result.RemainingMinute, want)
		}
	})

	t.Run("increments existing request count", func(t *testing.T) {
		keyID := createTestAPIKey(t, db, "record-increment")

		const recorded = 3
		for i := 0; i < recorded; i++ {
			if err := limiter.RecordRequest(ctx, keyID); err != nil {
				t.Fatalf("RecordRequest() iteration %d error = %v", i, err)
			}
		}

		// The error must be checked: on failure the result is not meaningful
		// and the comparison below would panic or report a bogus mismatch.
		result, err := limiter.CheckLimit(ctx, keyID)
		if err != nil {
			t.Fatalf("CheckLimit() error = %v", err)
		}
		expectedRemaining := result.LimitMinute - recorded
		if result.RemainingMinute != expectedRemaining {
			t.Errorf("RemainingMinute = %d, want %d", result.RemainingMinute, expectedRemaining)
		}
	})
}
// TestRateLimiter_CheckLimit covers CheckLimit's allow/deny decision, the
// remaining counts and RetryAfter it reports, and its reset timestamps.
func TestRateLimiter_CheckLimit(t *testing.T) {
	db := testutil.TestDB(t)
	t.Cleanup(func() { cleanupTestRateLimits(t, db) })

	limiter := NewRateLimiter(db)
	ctx := context.Background()

	t.Run("allows request when under limit", func(t *testing.T) {
		keyID := createTestAPIKey(t, db, "check-under")

		result, err := limiter.CheckLimit(ctx, keyID)
		if err != nil {
			t.Fatalf("CheckLimit() error = %v", err)
		}
		if !result.Allowed {
			t.Error("CheckLimit() should allow request when under limit")
		}
		if result.RemainingMinute <= 0 {
			t.Error("RemainingMinute should be positive")
		}
		if result.RemainingHour <= 0 {
			t.Error("RemainingHour should be positive")
		}
	})

	t.Run("denies request when minute limit exceeded", func(t *testing.T) {
		keyID := createTestAPIKey(t, db, "check-minute-exceeded")

		// Fail loudly instead of looping over a bogus limit.
		limits, err := limiter.GetLimits(ctx, keyID)
		if err != nil {
			t.Fatalf("GetLimits() error = %v", err)
		}

		// Record exactly PerMinute requests so that the minute allowance is
		// used up and any further request would exceed it.
		for i := 0; i < limits.PerMinute; i++ {
			if err := limiter.RecordRequest(ctx, keyID); err != nil {
				t.Fatalf("RecordRequest() iteration %d error = %v", i, err)
			}
		}

		result, err := limiter.CheckLimit(ctx, keyID)
		if err != nil {
			t.Fatalf("CheckLimit() error = %v", err)
		}
		if result.Allowed {
			t.Error("CheckLimit() should deny request when minute limit exceeded")
		}
		if result.RetryAfter <= 0 {
			t.Error("RetryAfter should be positive when denied")
		}
		if result.RemainingMinute != 0 {
			t.Errorf("RemainingMinute = %d, want 0", result.RemainingMinute)
		}
	})

	t.Run("returns correct reset times", func(t *testing.T) {
		keyID := createTestAPIKey(t, db, "check-reset")

		// Bracket the call: the reset times are relative to the moment
		// CheckLimit ran, which lies between before and after. Using a single
		// timestamp taken afterwards made the lower bound needlessly strict.
		before := time.Now()
		result, err := limiter.CheckLimit(ctx, keyID)
		if err != nil {
			t.Fatalf("CheckLimit() error = %v", err)
		}
		after := time.Now()

		// Tolerance for the limiter's clock (possibly the database's) not
		// exactly matching the test process's clock.
		const slack = time.Second

		// ResetMinute should be within the next minute.
		if result.ResetMinute.Before(before) {
			t.Error("ResetMinute should be in the future")
		}
		if result.ResetMinute.After(after.Add(time.Minute + slack)) {
			t.Error("ResetMinute should be within ~1 minute from now")
		}

		// ResetHour should be within the next hour.
		if result.ResetHour.Before(before) {
			t.Error("ResetHour should be in the future")
		}
		if result.ResetHour.After(after.Add(time.Hour + slack)) {
			t.Error("ResetHour should be within ~1 hour from now")
		}
	})
}
// TestRateLimiter_GetLimits checks that GetLimits reports the values from
// domain.DefaultRateLimitConfig in two cases:
//   - a key ID assumed absent from api_keys;
//   - a freshly created key that has no custom limits.
func TestRateLimiter_GetLimits(t *testing.T) {
	db := testutil.TestDB(t)
	t.Cleanup(func() { cleanupTestRateLimits(t, db) })

	limiter := NewRateLimiter(db)
	ctx := context.Background()

	t.Run("returns default limits for unknown key", func(t *testing.T) {
		// All-zero UUID, assumed not to belong to any stored key.
		const unknownKeyID = "00000000-0000-0000-0000-000000000000"

		got, err := limiter.GetLimits(ctx, unknownKeyID)
		if err != nil {
			t.Fatalf("GetLimits() error = %v", err)
		}

		want := domain.DefaultRateLimitConfig()
		if got.PerMinute != want.PerMinute {
			t.Errorf("PerMinute = %d, want %d", got.PerMinute, want.PerMinute)
		}
		if got.PerHour != want.PerHour {
			t.Errorf("PerHour = %d, want %d", got.PerHour, want.PerHour)
		}
	})

	t.Run("returns limits from existing key", func(t *testing.T) {
		id := createTestAPIKey(t, db, "get-limits")

		got, err := limiter.GetLimits(ctx, id)
		if err != nil {
			t.Fatalf("GetLimits() error = %v", err)
		}

		// createTestAPIKey sets no custom limits, so the defaults apply.
		want := domain.DefaultRateLimitConfig()
		if got.PerMinute != want.PerMinute {
			t.Errorf("PerMinute = %d, want %d", got.PerMinute, want.PerMinute)
		}
	})
}
// TestRateLimiter_Cleanup verifies that Cleanup succeeds and leaves a
// just-recorded request in place.
func TestRateLimiter_Cleanup(t *testing.T) {
	db := testutil.TestDB(t)
	t.Cleanup(func() { cleanupTestRateLimits(t, db) })

	limiter := NewRateLimiter(db)
	ctx := context.Background()

	t.Run("cleanup runs without error", func(t *testing.T) {
		// Create a fresh rate limit entry.
		keyID := createTestAPIKey(t, db, "cleanup-entry")
		if err := limiter.RecordRequest(ctx, keyID); err != nil {
			t.Fatalf("RecordRequest() error = %v", err)
		}

		if err := limiter.Cleanup(ctx); err != nil {
			t.Fatalf("Cleanup() error = %v", err)
		}

		// Recent entries must not be deleted. If Cleanup had removed the entry
		// recorded above, RemainingMinute would be back at LimitMinute. The
		// previous version only logged in that case, with an inverted message,
		// so a regression could never fail the test.
		result, err := limiter.CheckLimit(ctx, keyID)
		if err != nil {
			t.Fatalf("CheckLimit() error = %v", err)
		}
		if result.RemainingMinute >= result.LimitMinute {
			t.Errorf("RemainingMinute = %d, LimitMinute = %d: Cleanup() removed a recent rate limit entry",
				result.RemainingMinute, result.LimitMinute)
		}
	})
}
// TestRateLimiter_WindowHandling verifies that a single recorded request is
// reflected in both the per-minute and the per-hour remaining counts.
func TestRateLimiter_WindowHandling(t *testing.T) {
	db := testutil.TestDB(t)
	t.Cleanup(func() { cleanupTestRateLimits(t, db) })

	limiter := NewRateLimiter(db)
	ctx := context.Background()

	t.Run("minute and hour windows are tracked separately", func(t *testing.T) {
		keyID := createTestAPIKey(t, db, "windows")

		// Record a request. Check the error so that a write failure is reported
		// as such, rather than as a confusing count mismatch below.
		if err := limiter.RecordRequest(ctx, keyID); err != nil {
			t.Fatalf("RecordRequest() error = %v", err)
		}

		result, err := limiter.CheckLimit(ctx, keyID)
		if err != nil {
			t.Fatalf("CheckLimit() error = %v", err)
		}

		// Both counters should reflect one request.
		expectedMinuteRemaining := result.LimitMinute - 1
		expectedHourRemaining := result.LimitHour - 1
		if result.RemainingMinute != expectedMinuteRemaining {
			t.Errorf("RemainingMinute = %d, want %d", result.RemainingMinute, expectedMinuteRemaining)
		}
		if result.RemainingHour != expectedHourRemaining {
			t.Errorf("RemainingHour = %d, want %d", result.RemainingHour, expectedHourRemaining)
		}
	})
}
// TestRateLimiter_ConcurrentRequests issues numRequests RecordRequest calls for
// one key from separate goroutines. It then checks through CheckLimit that every
// call was counted against the per-minute allowance.
func TestRateLimiter_ConcurrentRequests(t *testing.T) {
	db := testutil.TestDB(t)
	t.Cleanup(func() { cleanupTestRateLimits(t, db) })

	limiter := NewRateLimiter(db)
	ctx := context.Background()
	keyID := createTestAPIKey(t, db, "concurrent")

	const numRequests = 10

	// Buffered to numRequests so that no goroutine ever blocks on its send.
	errs := make(chan error, numRequests)
	record := func() { errs <- limiter.RecordRequest(ctx, keyID) }
	for n := numRequests; n > 0; n-- {
		go record()
	}

	// Receive exactly one result per goroutine; this also waits for all of them.
	for received := 0; received < numRequests; received++ {
		if err := <-errs; err != nil {
			t.Errorf("Concurrent RecordRequest() error = %v", err)
		}
	}

	result, err := limiter.CheckLimit(ctx, keyID)
	if err != nil {
		t.Fatalf("CheckLimit() error = %v", err)
	}
	if want := result.LimitMinute - numRequests; result.RemainingMinute != want {
		t.Errorf("RemainingMinute = %d, want %d (all concurrent requests should be counted)", result.RemainingMinute, want)
	}
}