package postgres

import (
	"context"
	"database/sql"
	"errors"
	"testing"
	"time"

	"github.com/orchard9/rdev/internal/domain"
	"github.com/orchard9/rdev/internal/testutil"
)

// cleanupTestAuditLogs deletes every audit_log row whose args column starts
// with "test-", the prefix used by all fixtures in this file. It is
// best-effort: a failure is logged rather than failing the test, so a cleanup
// problem cannot mask the real test result.
func cleanupTestAuditLogs(t *testing.T, db *sql.DB) {
	t.Helper()
	if _, err := db.Exec("DELETE FROM audit_log WHERE args LIKE 'test-%'"); err != nil {
		t.Logf("cleanup test audit logs: %v", err)
	}
}

// TestAuditLogger_LogCommandStart verifies that LogCommandStart persists an
// entry that Get can read back, and that the entry is in the running state.
func TestAuditLogger_LogCommandStart(t *testing.T) {
	db := testutil.TestDB(t)
	// Fixtures use fixed IDs, so also clear leftovers from an aborted earlier
	// run before inserting, not only afterwards.
	cleanupTestAuditLogs(t, db)
	t.Cleanup(func() { cleanupTestAuditLogs(t, db) })

	logger := NewAuditLogger(db)
	ctx := context.Background()

	t.Run("logs command start successfully", func(t *testing.T) {
		entry := &domain.AuditLogEntry{
			ID:              "audit-test-1",
			APIKeyID:        "key-test-1",
			CommandID:       "cmd-test-1",
			ProjectID:       "proj-test-1",
			CommandType:     domain.CommandTypeClaude,
			Args:            "test-args-1",
			ClientIP:        "127.0.0.1",
			UserAgent:       "test-agent",
			StartedAt:       time.Now(),
			OutputSizeBytes: 0,
		}
		if err := logger.LogCommandStart(ctx, entry); err != nil {
			t.Fatalf("LogCommandStart() error = %v", err)
		}

		// Verify by reading the entry back through the public API.
		retrieved, err := logger.Get(ctx, "cmd-test-1")
		if err != nil {
			t.Fatalf("Get() error = %v", err)
		}
		if retrieved.CommandID != "cmd-test-1" {
			t.Errorf("CommandID = %q, want %q", retrieved.CommandID, "cmd-test-1")
		}
		if retrieved.ProjectID != "proj-test-1" {
			t.Errorf("ProjectID = %q, want %q", retrieved.ProjectID, "proj-test-1")
		}
		if retrieved.Status != domain.AuditStatusRunning {
			t.Errorf("Status = %q, want %q", retrieved.Status, domain.AuditStatusRunning)
		}
	})
}

// TestAuditLogger_LogCommandEnd verifies that LogCommandEnd records the
// outcome (status, exit code, duration, output size, error message,
// completion time) of a previously started command.
func TestAuditLogger_LogCommandEnd(t *testing.T) {
	db := testutil.TestDB(t)
	// Fixtures use fixed IDs; clear leftovers from an aborted earlier run.
	cleanupTestAuditLogs(t, db)
	t.Cleanup(func() { cleanupTestAuditLogs(t, db) })

	logger := NewAuditLogger(db)
	ctx := context.Background()

	t.Run("logs command end successfully", func(t *testing.T) {
		// A command must be started before it can be ended.
		entry := &domain.AuditLogEntry{
			ID:          "audit-test-end-1",
			APIKeyID:    "key-test-2",
			CommandID:   "cmd-test-end-1",
			ProjectID:   "proj-test-2",
			CommandType: domain.CommandTypeShell,
			Args:        "test-end-args",
			ClientIP:    "127.0.0.1",
			UserAgent:   "test-agent",
			StartedAt:   time.Now(),
		}
		if err := logger.LogCommandStart(ctx, entry); err != nil {
			t.Fatalf("LogCommandStart() error = %v", err)
		}

		result := &domain.AuditResult{
			ExitCode:        0,
			DurationMs:      1000,
			Status:          domain.AuditStatusSuccess,
			ErrorMessage:    "",
			OutputSizeBytes: 256,
		}
		if err := logger.LogCommandEnd(ctx, "cmd-test-end-1", result); err != nil {
			t.Fatalf("LogCommandEnd() error = %v", err)
		}

		retrieved, err := logger.Get(ctx, "cmd-test-end-1")
		if err != nil {
			t.Fatalf("Get() error = %v", err)
		}
		if retrieved.Status != domain.AuditStatusSuccess {
			t.Errorf("Status = %q, want %q", retrieved.Status, domain.AuditStatusSuccess)
		}
		if retrieved.ExitCode == nil || *retrieved.ExitCode != 0 {
			t.Errorf("ExitCode = %v, want 0", retrieved.ExitCode)
		}
		if retrieved.DurationMs == nil || *retrieved.DurationMs != 1000 {
			t.Errorf("DurationMs = %v, want 1000", retrieved.DurationMs)
		}
		if retrieved.OutputSizeBytes != 256 {
			t.Errorf("OutputSizeBytes = %v, want 256", retrieved.OutputSizeBytes)
		}
		if retrieved.CompletedAt == nil {
			t.Error("CompletedAt should be set")
		}
	})

	t.Run("logs failed command", func(t *testing.T) {
		entry := &domain.AuditLogEntry{
			ID:          "audit-test-fail-1",
			APIKeyID:    "key-test-3",
			CommandID:   "cmd-test-fail-1",
			ProjectID:   "proj-test-3",
			CommandType: domain.CommandTypeShell,
			Args:        "test-fail-args",
			ClientIP:    "127.0.0.1",
			UserAgent:   "test-agent",
			StartedAt:   time.Now(),
		}
		if err := logger.LogCommandStart(ctx, entry); err != nil {
			t.Fatalf("LogCommandStart() error = %v", err)
		}

		result := &domain.AuditResult{
			ExitCode:     1,
			DurationMs:   500,
			Status:       domain.AuditStatusError,
			ErrorMessage: "command failed",
		}
		if err := logger.LogCommandEnd(ctx, "cmd-test-fail-1", result); err != nil {
			t.Fatalf("LogCommandEnd() error = %v", err)
		}

		// Check the Get error: on failure retrieved may be nil, and the
		// field accesses below would panic instead of reporting the cause.
		retrieved, err := logger.Get(ctx, "cmd-test-fail-1")
		if err != nil {
			t.Fatalf("Get() error = %v", err)
		}
		if retrieved.Status != domain.AuditStatusError {
			t.Errorf("Status = %q, want %q", retrieved.Status, domain.AuditStatusError)
		}
		if retrieved.ErrorMessage != "command failed" {
			t.Errorf("ErrorMessage = %q, want %q", retrieved.ErrorMessage, "command failed")
		}
		if retrieved.ExitCode == nil || *retrieved.ExitCode != 1 {
			t.Errorf("ExitCode = %v, want 1", retrieved.ExitCode)
		}
	})
}

// TestAuditLogger_List verifies List filtering (project, command type,
// status) and pagination (limit/offset).
func TestAuditLogger_List(t *testing.T) {
	db := testutil.TestDB(t)
	// Fixtures use fixed IDs; clear leftovers from an aborted earlier run.
	cleanupTestAuditLogs(t, db)
	t.Cleanup(func() { cleanupTestAuditLogs(t, db) })

	logger := NewAuditLogger(db)
	ctx := context.Background()

	// Seed five running Claude entries in proj-list-1, one minute apart so
	// each has a distinct start time.
	now := time.Now()
	for i := 0; i < 5; i++ {
		suffix := string(rune('a' + i))
		entry := &domain.AuditLogEntry{
			ID:          "audit-list-" + suffix,
			APIKeyID:    "key-list-1",
			CommandID:   "cmd-list-" + suffix,
			ProjectID:   "proj-list-1",
			CommandType: domain.CommandTypeClaude,
			Args:        "test-list-args",
			ClientIP:    "127.0.0.1",
			UserAgent:   "test-agent",
			StartedAt:   now.Add(time.Duration(i) * time.Minute),
		}
		if err := logger.LogCommandStart(ctx, entry); err != nil {
			t.Fatalf("LogCommandStart(%q) error = %v", entry.CommandID, err)
		}
	}

	t.Run("lists all entries", func(t *testing.T) {
		entries, err := logger.List(ctx, domain.AuditFilters{
			ProjectID: "proj-list-1",
		})
		if err != nil {
			t.Fatalf("List() error = %v", err)
		}
		if len(entries) < 5 {
			t.Errorf("List() returned %d entries, want at least 5", len(entries))
		}
	})

	t.Run("filters by project", func(t *testing.T) {
		// Create an entry in a different project.
		entry := &domain.AuditLogEntry{
			ID:          "audit-list-other",
			APIKeyID:    "key-list-2",
			CommandID:   "cmd-list-other",
			ProjectID:   "proj-list-other",
			CommandType: domain.CommandTypeClaude,
			Args:        "test-list-other",
			ClientIP:    "127.0.0.1",
			UserAgent:   "test-agent",
			StartedAt:   now,
		}
		if err := logger.LogCommandStart(ctx, entry); err != nil {
			t.Fatalf("LogCommandStart() error = %v", err)
		}

		entries, err := logger.List(ctx, domain.AuditFilters{
			ProjectID: "proj-list-other",
		})
		if err != nil {
			t.Fatalf("List() error = %v", err)
		}
		// Without this check the loop below passes vacuously on an empty result.
		if len(entries) == 0 {
			t.Fatal("List() returned no entries for proj-list-other, want at least 1")
		}
		for _, e := range entries {
			if e.ProjectID != "proj-list-other" {
				t.Errorf("Entry has ProjectID = %q, want %q", e.ProjectID, "proj-list-other")
			}
		}
	})

	t.Run("applies limit and offset", func(t *testing.T) {
		first, err := logger.List(ctx, domain.AuditFilters{
			ProjectID: "proj-list-1",
			Limit:     2,
			Offset:    0,
		})
		if err != nil {
			t.Fatalf("List() error = %v", err)
		}
		if len(first) != 2 {
			t.Fatalf("List() returned %d entries, want 2", len(first))
		}

		second, err := logger.List(ctx, domain.AuditFilters{
			ProjectID: "proj-list-1",
			Limit:     2,
			Offset:    2,
		})
		if err != nil {
			t.Fatalf("List(offset 2) error = %v", err)
		}
		if len(second) != 2 {
			t.Fatalf("List(offset 2) returned %d entries, want 2", len(second))
		}

		// The seeded entries have distinct start times. Assuming List orders
		// results deterministically, consecutive pages must not share a row.
		seen := make(map[string]bool, len(first))
		for _, e := range first {
			seen[e.CommandID] = true
		}
		for _, e := range second {
			if seen[e.CommandID] {
				t.Errorf("CommandID %q appears on both the first and second page", e.CommandID)
			}
		}
	})

	t.Run("filters by command type", func(t *testing.T) {
		entries, err := logger.List(ctx, domain.AuditFilters{
			ProjectID:   "proj-list-1",
			CommandType: domain.CommandTypeClaude,
		})
		if err != nil {
			t.Fatalf("List() error = %v", err)
		}
		if len(entries) < 5 {
			t.Errorf("List() returned %d entries, want at least 5", len(entries))
		}
		for _, e := range entries {
			if e.CommandType != domain.CommandTypeClaude {
				t.Errorf("Entry has CommandType = %q, want %q", e.CommandType, domain.CommandTypeClaude)
			}
		}

		// Every seeded entry is a Claude command, so filtering by shell
		// must exclude all of them.
		others, err := logger.List(ctx, domain.AuditFilters{
			ProjectID:   "proj-list-1",
			CommandType: domain.CommandTypeShell,
		})
		if err != nil {
			t.Fatalf("List() error = %v", err)
		}
		if len(others) != 0 {
			t.Errorf("List(shell) returned %d entries, want 0", len(others))
		}
	})

	t.Run("filters by status", func(t *testing.T) {
		entries, err := logger.List(ctx, domain.AuditFilters{
			ProjectID: "proj-list-1",
			Status:    domain.AuditStatusRunning,
		})
		if err != nil {
			t.Fatalf("List() error = %v", err)
		}
		if len(entries) < 5 {
			t.Errorf("List() returned %d entries, want at least 5", len(entries))
		}
		for _, e := range entries {
			if e.Status != domain.AuditStatusRunning {
				t.Errorf("Entry has Status = %q, want %q", e.Status, domain.AuditStatusRunning)
			}
		}

		// No seeded entry was ended, so none can be in the success state.
		finished, err := logger.List(ctx, domain.AuditFilters{
			ProjectID: "proj-list-1",
			Status:    domain.AuditStatusSuccess,
		})
		if err != nil {
			t.Fatalf("List() error = %v", err)
		}
		if len(finished) != 0 {
			t.Errorf("List(success) returned %d entries, want 0", len(finished))
		}
	})
}

// TestAuditLogger_Get verifies lookup by command ID, including the
// ErrAuditNotFound sentinel for unknown IDs.
func TestAuditLogger_Get(t *testing.T) {
	db := testutil.TestDB(t)
	// Fixtures use fixed IDs; clear leftovers from an aborted earlier run.
	cleanupTestAuditLogs(t, db)
	t.Cleanup(func() { cleanupTestAuditLogs(t, db) })

	logger := NewAuditLogger(db)
	ctx := context.Background()

	t.Run("gets existing entry", func(t *testing.T) {
		entry := &domain.AuditLogEntry{
			ID:          "audit-get-1",
			APIKeyID:    "key-get-1",
			CommandID:   "cmd-get-1",
			ProjectID:   "proj-get-1",
			CommandType: domain.CommandTypeClaude,
			Args:        "test-get-args",
			ClientIP:    "10.0.0.1",
			UserAgent:   "test-agent-get",
			StartedAt:   time.Now(),
		}
		if err := logger.LogCommandStart(ctx, entry); err != nil {
			t.Fatalf("LogCommandStart() error = %v", err)
		}

		retrieved, err := logger.Get(ctx, "cmd-get-1")
		if err != nil {
			t.Fatalf("Get() error = %v", err)
		}
		if retrieved.CommandID != "cmd-get-1" {
			t.Errorf("CommandID = %q, want %q", retrieved.CommandID, "cmd-get-1")
		}
		if retrieved.ClientIP != "10.0.0.1" {
			t.Errorf("ClientIP = %q, want %q", retrieved.ClientIP, "10.0.0.1")
		}
	})

	t.Run("returns error for non-existent entry", func(t *testing.T) {
		_, err := logger.Get(ctx, "cmd-nonexistent")
		// errors.Is also matches if the implementation wraps the sentinel.
		if !errors.Is(err, domain.ErrAuditNotFound) {
			t.Errorf("Get() error = %v, want %v", err, domain.ErrAuditNotFound)
		}
	})
}