Major refactoring to hexagonal (ports & adapters) architecture: - Add service layer (apikey_service, project_service) for business logic - Add webhook system with dispatcher and delivery tracking - Add command queue with priority-based processing - Add rate limiting with sliding window algorithm - Add audit logging for command execution - Add OpenTelemetry integration (traces, metrics, spans) - Add circuit breaker for fault tolerance - Add cached repository wrapper for performance - Add comprehensive validation package - Add Kubernetes client integration for pod management - Add database migrations (allowed_ips, audit_log, rate_limiting, queue, webhooks) - Add network policy and PodDisruptionBudget for k8s - Remove legacy executor and projects/registry packages - Untrack secrets.yaml (now managed via envault) - Add coverage.out to .gitignore - Add e2e test infrastructure with docker-compose - Add comprehensive documentation (API, architecture, operations, plans) - Add golangci-lint config and pre-commit hook Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
237 lines
6.2 KiB
Go
237 lines
6.2 KiB
Go
// Package postgres provides PostgreSQL-based implementations of port interfaces.
|
|
package postgres
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"errors"
|
|
"fmt"
|
|
"log/slog"
|
|
"time"
|
|
|
|
"github.com/orchard9/rdev/internal/domain"
|
|
"github.com/orchard9/rdev/internal/port"
|
|
)
|
|
|
|
// RateLimiter implements port.RateLimiter using PostgreSQL.
|
|
type RateLimiter struct {
|
|
db *sql.DB
|
|
}
|
|
|
|
// NewRateLimiter creates a new PostgreSQL rate limiter.
|
|
func NewRateLimiter(db *sql.DB) *RateLimiter {
|
|
return &RateLimiter{db: db}
|
|
}
|
|
|
|
// Ensure RateLimiter implements port.RateLimiter at compile time.
|
|
var _ port.RateLimiter = (*RateLimiter)(nil)
|
|
|
|
// CheckLimit checks if a request is allowed under the rate limit.
|
|
func (r *RateLimiter) CheckLimit(ctx context.Context, keyID string) (*domain.RateLimitResult, error) {
|
|
now := time.Now()
|
|
minuteWindow := domain.TruncateToMinute(now)
|
|
hourWindow := domain.TruncateToHour(now)
|
|
|
|
// Get rate limits for this key
|
|
limits, err := r.GetLimits(ctx, keyID)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("get limits: %w", err)
|
|
}
|
|
|
|
// Get current usage for minute window
|
|
minuteCount, err := r.getWindowCount(ctx, keyID, minuteWindow, domain.WindowTypeMinute)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("get minute count: %w", err)
|
|
}
|
|
|
|
// Get current usage for hour window
|
|
hourCount, err := r.getWindowCount(ctx, keyID, hourWindow, domain.WindowTypeHour)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("get hour count: %w", err)
|
|
}
|
|
|
|
result := &domain.RateLimitResult{
|
|
LimitMinute: limits.PerMinute,
|
|
LimitHour: limits.PerHour,
|
|
RemainingMinute: limits.PerMinute - minuteCount,
|
|
RemainingHour: limits.PerHour - hourCount,
|
|
ResetMinute: minuteWindow.Add(time.Minute),
|
|
ResetHour: hourWindow.Add(time.Hour),
|
|
}
|
|
|
|
// Ensure remaining doesn't go negative
|
|
if result.RemainingMinute < 0 {
|
|
result.RemainingMinute = 0
|
|
}
|
|
if result.RemainingHour < 0 {
|
|
result.RemainingHour = 0
|
|
}
|
|
|
|
// Check if either limit is exceeded
|
|
if minuteCount >= limits.PerMinute {
|
|
result.Allowed = false
|
|
result.RetryAfter = time.Until(result.ResetMinute)
|
|
if result.RetryAfter < 0 {
|
|
result.RetryAfter = time.Second
|
|
}
|
|
return result, nil
|
|
}
|
|
|
|
if hourCount >= limits.PerHour {
|
|
result.Allowed = false
|
|
result.RetryAfter = time.Until(result.ResetHour)
|
|
if result.RetryAfter < 0 {
|
|
result.RetryAfter = time.Second
|
|
}
|
|
return result, nil
|
|
}
|
|
|
|
result.Allowed = true
|
|
return result, nil
|
|
}
|
|
|
|
// RecordRequest records that a request was made for the given API key.
|
|
func (r *RateLimiter) RecordRequest(ctx context.Context, keyID string) error {
|
|
now := time.Now()
|
|
minuteWindow := domain.TruncateToMinute(now)
|
|
hourWindow := domain.TruncateToHour(now)
|
|
|
|
// Use a transaction to update both windows atomically
|
|
tx, err := r.db.BeginTx(ctx, nil)
|
|
if err != nil {
|
|
return fmt.Errorf("begin tx: %w", err)
|
|
}
|
|
defer func() { _ = tx.Rollback() }()
|
|
|
|
// Upsert minute window
|
|
if err := r.upsertWindow(ctx, tx, keyID, minuteWindow, domain.WindowTypeMinute); err != nil {
|
|
return fmt.Errorf("upsert minute window: %w", err)
|
|
}
|
|
|
|
// Upsert hour window
|
|
if err := r.upsertWindow(ctx, tx, keyID, hourWindow, domain.WindowTypeHour); err != nil {
|
|
return fmt.Errorf("upsert hour window: %w", err)
|
|
}
|
|
|
|
if err := tx.Commit(); err != nil {
|
|
return fmt.Errorf("commit: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// GetLimits retrieves the rate limit configuration for an API key.
|
|
func (r *RateLimiter) GetLimits(ctx context.Context, keyID string) (*domain.RateLimitConfig, error) {
|
|
var perMinute, perHour sql.NullInt64
|
|
|
|
err := r.db.QueryRowContext(ctx, `
|
|
SELECT rate_limit_per_minute, rate_limit_per_hour
|
|
FROM api_keys
|
|
WHERE id = $1
|
|
`, keyID).Scan(&perMinute, &perHour)
|
|
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
// Key not found, return defaults
|
|
defaults := domain.DefaultRateLimitConfig()
|
|
return &defaults, nil
|
|
}
|
|
if err != nil {
|
|
return nil, fmt.Errorf("query limits: %w", err)
|
|
}
|
|
|
|
config := domain.DefaultRateLimitConfig()
|
|
if perMinute.Valid {
|
|
config.PerMinute = int(perMinute.Int64)
|
|
}
|
|
if perHour.Valid {
|
|
config.PerHour = int(perHour.Int64)
|
|
}
|
|
|
|
return &config, nil
|
|
}
|
|
|
|
// Cleanup removes expired rate limit state entries.
|
|
func (r *RateLimiter) Cleanup(ctx context.Context) error {
|
|
// Remove entries older than 2 hours (well past any active window)
|
|
cutoff := time.Now().Add(-2 * time.Hour)
|
|
|
|
result, err := r.db.ExecContext(ctx, `
|
|
DELETE FROM rate_limit_state
|
|
WHERE window_start < $1
|
|
`, cutoff)
|
|
if err != nil {
|
|
return fmt.Errorf("delete old entries: %w", err)
|
|
}
|
|
|
|
rows, _ := result.RowsAffected()
|
|
if rows > 0 {
|
|
// Log cleanup (optional, could use structured logging)
|
|
_ = rows
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// getWindowCount returns the request count for a specific window.
|
|
func (r *RateLimiter) getWindowCount(ctx context.Context, keyID string, windowStart time.Time, windowType string) (int, error) {
|
|
var count int
|
|
err := r.db.QueryRowContext(ctx, `
|
|
SELECT COALESCE(request_count, 0)
|
|
FROM rate_limit_state
|
|
WHERE api_key_id = $1 AND window_start = $2 AND window_type = $3
|
|
`, keyID, windowStart, windowType).Scan(&count)
|
|
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return 0, nil
|
|
}
|
|
if err != nil {
|
|
return 0, fmt.Errorf("query count: %w", err)
|
|
}
|
|
|
|
return count, nil
|
|
}
|
|
|
|
// upsertWindow inserts or updates a rate limit window.
|
|
func (r *RateLimiter) upsertWindow(ctx context.Context, tx *sql.Tx, keyID string, windowStart time.Time, windowType string) error {
|
|
_, err := tx.ExecContext(ctx, `
|
|
INSERT INTO rate_limit_state (api_key_id, window_start, window_type, request_count, updated_at)
|
|
VALUES ($1, $2, $3, 1, NOW())
|
|
ON CONFLICT (api_key_id, window_start, window_type)
|
|
DO UPDATE SET request_count = rate_limit_state.request_count + 1, updated_at = NOW()
|
|
`, keyID, windowStart, windowType)
|
|
|
|
if err != nil {
|
|
return fmt.Errorf("upsert: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// StartCleanupWorker starts a background goroutine that periodically cleans up expired entries.
|
|
// Returns a stop function to terminate the worker.
|
|
func (r *RateLimiter) StartCleanupWorker(ctx context.Context, interval time.Duration) func() {
|
|
stopCh := make(chan struct{})
|
|
|
|
go func() {
|
|
ticker := time.NewTicker(interval)
|
|
defer ticker.Stop()
|
|
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
case <-stopCh:
|
|
return
|
|
case <-ticker.C:
|
|
if err := r.Cleanup(ctx); err != nil {
|
|
slog.Error("rate limit cleanup failed", "error", err)
|
|
}
|
|
}
|
|
}
|
|
}()
|
|
|
|
return func() {
|
|
close(stopCh)
|
|
}
|
|
}
|