Major refactoring to hexagonal (ports & adapters) architecture: - Add service layer (apikey_service, project_service) for business logic - Add webhook system with dispatcher and delivery tracking - Add command queue with priority-based processing - Add rate limiting with sliding window algorithm - Add audit logging for command execution - Add OpenTelemetry integration (traces, metrics, spans) - Add circuit breaker for fault tolerance - Add cached repository wrapper for performance - Add comprehensive validation package - Add Kubernetes client integration for pod management - Add database migrations (allowed_ips, audit_log, rate_limiting, queue, webhooks) - Add network policy and PodDisruptionBudget for k8s - Remove legacy executor and projects/registry packages - Untrack secrets.yaml (now managed via envault) - Add coverage.out to .gitignore - Add e2e test infrastructure with docker-compose - Add comprehensive documentation (API, architecture, operations, plans) - Add golangci-lint config and pre-commit hook Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
317 lines
8.3 KiB
Go
317 lines
8.3 KiB
Go
package postgres
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/orchard9/rdev/internal/domain"
|
|
"github.com/orchard9/rdev/internal/testutil"
|
|
)
|
|
|
|
func cleanupTestAuditLogs(t *testing.T, db *sql.DB) {
|
|
t.Helper()
|
|
_, err := db.Exec("DELETE FROM audit_log WHERE args LIKE 'test-%'")
|
|
if err != nil {
|
|
t.Logf("cleanup test audit logs: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestAuditLogger_LogCommandStart(t *testing.T) {
|
|
db := testutil.TestDB(t)
|
|
t.Cleanup(func() { cleanupTestAuditLogs(t, db) })
|
|
|
|
logger := NewAuditLogger(db)
|
|
ctx := context.Background()
|
|
|
|
t.Run("logs command start successfully", func(t *testing.T) {
|
|
now := time.Now()
|
|
entry := &domain.AuditLogEntry{
|
|
ID: "audit-test-1",
|
|
APIKeyID: "key-test-1",
|
|
CommandID: "cmd-test-1",
|
|
ProjectID: "proj-test-1",
|
|
CommandType: domain.CommandTypeClaude,
|
|
Args: "test-args-1",
|
|
ClientIP: "127.0.0.1",
|
|
UserAgent: "test-agent",
|
|
StartedAt: now,
|
|
OutputSizeBytes: 0,
|
|
}
|
|
|
|
err := logger.LogCommandStart(ctx, entry)
|
|
if err != nil {
|
|
t.Fatalf("LogCommandStart() error = %v", err)
|
|
}
|
|
|
|
// Verify by retrieving
|
|
retrieved, err := logger.Get(ctx, "cmd-test-1")
|
|
if err != nil {
|
|
t.Fatalf("Get() error = %v", err)
|
|
}
|
|
|
|
if retrieved.CommandID != "cmd-test-1" {
|
|
t.Errorf("CommandID = %q, want %q", retrieved.CommandID, "cmd-test-1")
|
|
}
|
|
if retrieved.Status != domain.AuditStatusRunning {
|
|
t.Errorf("Status = %q, want %q", retrieved.Status, domain.AuditStatusRunning)
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestAuditLogger_LogCommandEnd(t *testing.T) {
|
|
db := testutil.TestDB(t)
|
|
t.Cleanup(func() { cleanupTestAuditLogs(t, db) })
|
|
|
|
logger := NewAuditLogger(db)
|
|
ctx := context.Background()
|
|
|
|
t.Run("logs command end successfully", func(t *testing.T) {
|
|
// First create a command start
|
|
now := time.Now()
|
|
entry := &domain.AuditLogEntry{
|
|
ID: "audit-test-end-1",
|
|
APIKeyID: "key-test-2",
|
|
CommandID: "cmd-test-end-1",
|
|
ProjectID: "proj-test-2",
|
|
CommandType: domain.CommandTypeShell,
|
|
Args: "test-end-args",
|
|
ClientIP: "127.0.0.1",
|
|
UserAgent: "test-agent",
|
|
StartedAt: now,
|
|
}
|
|
|
|
err := logger.LogCommandStart(ctx, entry)
|
|
if err != nil {
|
|
t.Fatalf("LogCommandStart() error = %v", err)
|
|
}
|
|
|
|
// Now log the end
|
|
result := &domain.AuditResult{
|
|
ExitCode: 0,
|
|
DurationMs: 1000,
|
|
Status: domain.AuditStatusSuccess,
|
|
ErrorMessage: "",
|
|
OutputSizeBytes: 256,
|
|
}
|
|
|
|
err = logger.LogCommandEnd(ctx, "cmd-test-end-1", result)
|
|
if err != nil {
|
|
t.Fatalf("LogCommandEnd() error = %v", err)
|
|
}
|
|
|
|
// Verify
|
|
retrieved, err := logger.Get(ctx, "cmd-test-end-1")
|
|
if err != nil {
|
|
t.Fatalf("Get() error = %v", err)
|
|
}
|
|
|
|
if retrieved.Status != domain.AuditStatusSuccess {
|
|
t.Errorf("Status = %q, want %q", retrieved.Status, domain.AuditStatusSuccess)
|
|
}
|
|
if retrieved.ExitCode == nil || *retrieved.ExitCode != 0 {
|
|
t.Errorf("ExitCode = %v, want 0", retrieved.ExitCode)
|
|
}
|
|
if retrieved.DurationMs == nil || *retrieved.DurationMs != 1000 {
|
|
t.Errorf("DurationMs = %v, want 1000", retrieved.DurationMs)
|
|
}
|
|
if retrieved.CompletedAt == nil {
|
|
t.Error("CompletedAt should be set")
|
|
}
|
|
})
|
|
|
|
t.Run("logs failed command", func(t *testing.T) {
|
|
entry := &domain.AuditLogEntry{
|
|
ID: "audit-test-fail-1",
|
|
APIKeyID: "key-test-3",
|
|
CommandID: "cmd-test-fail-1",
|
|
ProjectID: "proj-test-3",
|
|
CommandType: domain.CommandTypeShell,
|
|
Args: "test-fail-args",
|
|
ClientIP: "127.0.0.1",
|
|
UserAgent: "test-agent",
|
|
StartedAt: time.Now(),
|
|
}
|
|
|
|
_ = logger.LogCommandStart(ctx, entry)
|
|
|
|
result := &domain.AuditResult{
|
|
ExitCode: 1,
|
|
DurationMs: 500,
|
|
Status: domain.AuditStatusError,
|
|
ErrorMessage: "command failed",
|
|
}
|
|
|
|
err := logger.LogCommandEnd(ctx, "cmd-test-fail-1", result)
|
|
if err != nil {
|
|
t.Fatalf("LogCommandEnd() error = %v", err)
|
|
}
|
|
|
|
retrieved, _ := logger.Get(ctx, "cmd-test-fail-1")
|
|
if retrieved.Status != domain.AuditStatusError {
|
|
t.Errorf("Status = %q, want %q", retrieved.Status, domain.AuditStatusError)
|
|
}
|
|
if retrieved.ErrorMessage != "command failed" {
|
|
t.Errorf("ErrorMessage = %q, want %q", retrieved.ErrorMessage, "command failed")
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestAuditLogger_List(t *testing.T) {
|
|
db := testutil.TestDB(t)
|
|
t.Cleanup(func() { cleanupTestAuditLogs(t, db) })
|
|
|
|
logger := NewAuditLogger(db)
|
|
ctx := context.Background()
|
|
|
|
// Create test entries
|
|
now := time.Now()
|
|
for i := 0; i < 5; i++ {
|
|
entry := &domain.AuditLogEntry{
|
|
ID: "audit-list-" + string(rune('a'+i)),
|
|
APIKeyID: "key-list-1",
|
|
CommandID: "cmd-list-" + string(rune('a'+i)),
|
|
ProjectID: "proj-list-1",
|
|
CommandType: domain.CommandTypeClaude,
|
|
Args: "test-list-args",
|
|
ClientIP: "127.0.0.1",
|
|
UserAgent: "test-agent",
|
|
StartedAt: now.Add(time.Duration(i) * time.Minute),
|
|
}
|
|
_ = logger.LogCommandStart(ctx, entry)
|
|
}
|
|
|
|
t.Run("lists all entries", func(t *testing.T) {
|
|
entries, err := logger.List(ctx, domain.AuditFilters{
|
|
ProjectID: "proj-list-1",
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("List() error = %v", err)
|
|
}
|
|
|
|
if len(entries) < 5 {
|
|
t.Errorf("List() returned %d entries, want at least 5", len(entries))
|
|
}
|
|
})
|
|
|
|
t.Run("filters by project", func(t *testing.T) {
|
|
// Create entry in different project
|
|
entry := &domain.AuditLogEntry{
|
|
ID: "audit-list-other",
|
|
APIKeyID: "key-list-2",
|
|
CommandID: "cmd-list-other",
|
|
ProjectID: "proj-list-other",
|
|
CommandType: domain.CommandTypeClaude,
|
|
Args: "test-list-other",
|
|
ClientIP: "127.0.0.1",
|
|
UserAgent: "test-agent",
|
|
StartedAt: now,
|
|
}
|
|
_ = logger.LogCommandStart(ctx, entry)
|
|
|
|
entries, err := logger.List(ctx, domain.AuditFilters{
|
|
ProjectID: "proj-list-other",
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("List() error = %v", err)
|
|
}
|
|
|
|
// Check all entries have the filtered project
|
|
for _, e := range entries {
|
|
if e.ProjectID != "proj-list-other" {
|
|
t.Errorf("Entry has ProjectID = %q, want %q", e.ProjectID, "proj-list-other")
|
|
}
|
|
}
|
|
})
|
|
|
|
t.Run("applies limit and offset", func(t *testing.T) {
|
|
entries, err := logger.List(ctx, domain.AuditFilters{
|
|
ProjectID: "proj-list-1",
|
|
Limit: 2,
|
|
Offset: 0,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("List() error = %v", err)
|
|
}
|
|
|
|
if len(entries) != 2 {
|
|
t.Errorf("List() returned %d entries, want 2", len(entries))
|
|
}
|
|
})
|
|
|
|
t.Run("filters by command type", func(t *testing.T) {
|
|
entries, err := logger.List(ctx, domain.AuditFilters{
|
|
ProjectID: "proj-list-1",
|
|
CommandType: domain.CommandTypeClaude,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("List() error = %v", err)
|
|
}
|
|
|
|
for _, e := range entries {
|
|
if e.CommandType != domain.CommandTypeClaude {
|
|
t.Errorf("Entry has CommandType = %q, want %q", e.CommandType, domain.CommandTypeClaude)
|
|
}
|
|
}
|
|
})
|
|
|
|
t.Run("filters by status", func(t *testing.T) {
|
|
entries, err := logger.List(ctx, domain.AuditFilters{
|
|
ProjectID: "proj-list-1",
|
|
Status: domain.AuditStatusRunning,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("List() error = %v", err)
|
|
}
|
|
|
|
for _, e := range entries {
|
|
if e.Status != domain.AuditStatusRunning {
|
|
t.Errorf("Entry has Status = %q, want %q", e.Status, domain.AuditStatusRunning)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestAuditLogger_Get(t *testing.T) {
|
|
db := testutil.TestDB(t)
|
|
t.Cleanup(func() { cleanupTestAuditLogs(t, db) })
|
|
|
|
logger := NewAuditLogger(db)
|
|
ctx := context.Background()
|
|
|
|
t.Run("gets existing entry", func(t *testing.T) {
|
|
entry := &domain.AuditLogEntry{
|
|
ID: "audit-get-1",
|
|
APIKeyID: "key-get-1",
|
|
CommandID: "cmd-get-1",
|
|
ProjectID: "proj-get-1",
|
|
CommandType: domain.CommandTypeClaude,
|
|
Args: "test-get-args",
|
|
ClientIP: "10.0.0.1",
|
|
UserAgent: "test-agent-get",
|
|
StartedAt: time.Now(),
|
|
}
|
|
logger.LogCommandStart(ctx, entry)
|
|
|
|
retrieved, err := logger.Get(ctx, "cmd-get-1")
|
|
if err != nil {
|
|
t.Fatalf("Get() error = %v", err)
|
|
}
|
|
|
|
if retrieved.CommandID != "cmd-get-1" {
|
|
t.Errorf("CommandID = %q, want %q", retrieved.CommandID, "cmd-get-1")
|
|
}
|
|
if retrieved.ClientIP != "10.0.0.1" {
|
|
t.Errorf("ClientIP = %q, want %q", retrieved.ClientIP, "10.0.0.1")
|
|
}
|
|
})
|
|
|
|
t.Run("returns error for non-existent entry", func(t *testing.T) {
|
|
_, err := logger.Get(ctx, "cmd-nonexistent")
|
|
if err != domain.ErrAuditNotFound {
|
|
t.Errorf("Get() error = %v, want %v", err, domain.ErrAuditNotFound)
|
|
}
|
|
})
|
|
}
|