rdev/internal/adapter/postgres/audit_logger_test.go
jordan 72d16929ca feat: Implement hexagonal architecture with services, webhooks, queue, and telemetry
Major refactoring to hexagonal (ports & adapters) architecture:

- Add service layer (apikey_service, project_service) for business logic
- Add webhook system with dispatcher and delivery tracking
- Add command queue with priority-based processing
- Add rate limiting with sliding window algorithm
- Add audit logging for command execution
- Add OpenTelemetry integration (traces, metrics, spans)
- Add circuit breaker for fault tolerance
- Add cached repository wrapper for performance
- Add comprehensive validation package
- Add Kubernetes client integration for pod management
- Add database migrations (allowed_ips, audit_log, rate_limiting, queue, webhooks)
- Add network policy and PodDisruptionBudget for k8s
- Remove legacy executor and projects/registry packages
- Untrack secrets.yaml (now managed via envault)
- Add coverage.out to .gitignore
- Add e2e test infrastructure with docker-compose
- Add comprehensive documentation (API, architecture, operations, plans)
- Add golangci-lint config and pre-commit hook

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-25 19:57:46 -07:00

317 lines
8.3 KiB
Go

package postgres
import (
	"context"
	"database/sql"
	"errors"
	"testing"
	"time"

	"github.com/orchard9/rdev/internal/domain"
	"github.com/orchard9/rdev/internal/testutil"
)
// cleanupTestAuditLogs deletes the audit_log rows created by the tests in this
// file. Every fixture here uses an Args value beginning with "test-", which is
// what the DELETE matches on. A failure is only logged, never reported as a
// test failure, so cleanup problems cannot turn a passing test red.
func cleanupTestAuditLogs(t *testing.T, db *sql.DB) {
	t.Helper()

	const deleteTestRows = "DELETE FROM audit_log WHERE args LIKE 'test-%'"
	if _, execErr := db.Exec(deleteTestRows); execErr != nil {
		t.Logf("cleanup test audit logs: %v", execErr)
	}
}
// TestAuditLogger_LogCommandStart verifies that LogCommandStart stores a new
// audit entry that Get can read back with status AuditStatusRunning.
func TestAuditLogger_LogCommandStart(t *testing.T) {
	db := testutil.TestDB(t)
	// Clear rows left by an interrupted earlier run first: the fixture below
	// uses fixed IDs, so stale rows could collide with (or be mistaken for)
	// the row inserted here. Cleanup also runs again when the test ends.
	cleanupTestAuditLogs(t, db)
	t.Cleanup(func() { cleanupTestAuditLogs(t, db) })
	logger := NewAuditLogger(db)
	ctx := context.Background()

	t.Run("logs command start successfully", func(t *testing.T) {
		entry := &domain.AuditLogEntry{
			ID:              "audit-test-1",
			APIKeyID:        "key-test-1",
			CommandID:       "cmd-test-1",
			ProjectID:       "proj-test-1",
			CommandType:     domain.CommandTypeClaude,
			Args:            "test-args-1",
			ClientIP:        "127.0.0.1",
			UserAgent:       "test-agent",
			StartedAt:       time.Now(),
			OutputSizeBytes: 0,
		}
		if err := logger.LogCommandStart(ctx, entry); err != nil {
			t.Fatalf("LogCommandStart() error = %v", err)
		}

		// Verify by reading the entry back.
		retrieved, err := logger.Get(ctx, "cmd-test-1")
		if err != nil {
			t.Fatalf("Get() error = %v", err)
		}
		if retrieved.CommandID != "cmd-test-1" {
			t.Errorf("CommandID = %q, want %q", retrieved.CommandID, "cmd-test-1")
		}
		if retrieved.ProjectID != "proj-test-1" {
			t.Errorf("ProjectID = %q, want %q", retrieved.ProjectID, "proj-test-1")
		}
		if retrieved.CommandType != domain.CommandTypeClaude {
			t.Errorf("CommandType = %q, want %q", retrieved.CommandType, domain.CommandTypeClaude)
		}
		// A command that has only started should still be in the running state.
		if retrieved.Status != domain.AuditStatusRunning {
			t.Errorf("Status = %q, want %q", retrieved.Status, domain.AuditStatusRunning)
		}
	})
}
// TestAuditLogger_LogCommandEnd verifies that LogCommandEnd records the outcome
// (status, exit code, duration, error message, completion time) of a command
// previously registered with LogCommandStart.
func TestAuditLogger_LogCommandEnd(t *testing.T) {
	db := testutil.TestDB(t)
	// Clear rows left by an interrupted earlier run first: the fixtures below
	// use fixed IDs, so stale rows could collide with (or be mistaken for)
	// the rows inserted here. Cleanup also runs again when the test ends.
	cleanupTestAuditLogs(t, db)
	t.Cleanup(func() { cleanupTestAuditLogs(t, db) })
	logger := NewAuditLogger(db)
	ctx := context.Background()

	t.Run("logs command end successfully", func(t *testing.T) {
		// First register the command start.
		entry := &domain.AuditLogEntry{
			ID:          "audit-test-end-1",
			APIKeyID:    "key-test-2",
			CommandID:   "cmd-test-end-1",
			ProjectID:   "proj-test-2",
			CommandType: domain.CommandTypeShell,
			Args:        "test-end-args",
			ClientIP:    "127.0.0.1",
			UserAgent:   "test-agent",
			StartedAt:   time.Now(),
		}
		if err := logger.LogCommandStart(ctx, entry); err != nil {
			t.Fatalf("LogCommandStart() error = %v", err)
		}

		// Now log the end.
		result := &domain.AuditResult{
			ExitCode:        0,
			DurationMs:      1000,
			Status:          domain.AuditStatusSuccess,
			ErrorMessage:    "",
			OutputSizeBytes: 256,
		}
		if err := logger.LogCommandEnd(ctx, "cmd-test-end-1", result); err != nil {
			t.Fatalf("LogCommandEnd() error = %v", err)
		}

		retrieved, err := logger.Get(ctx, "cmd-test-end-1")
		if err != nil {
			t.Fatalf("Get() error = %v", err)
		}
		if retrieved.Status != domain.AuditStatusSuccess {
			t.Errorf("Status = %q, want %q", retrieved.Status, domain.AuditStatusSuccess)
		}
		if retrieved.ExitCode == nil || *retrieved.ExitCode != 0 {
			t.Errorf("ExitCode = %v, want 0", retrieved.ExitCode)
		}
		if retrieved.DurationMs == nil || *retrieved.DurationMs != 1000 {
			t.Errorf("DurationMs = %v, want 1000", retrieved.DurationMs)
		}
		if retrieved.CompletedAt == nil {
			t.Error("CompletedAt should be set")
		}
	})

	t.Run("logs failed command", func(t *testing.T) {
		entry := &domain.AuditLogEntry{
			ID:          "audit-test-fail-1",
			APIKeyID:    "key-test-3",
			CommandID:   "cmd-test-fail-1",
			ProjectID:   "proj-test-3",
			CommandType: domain.CommandTypeShell,
			Args:        "test-fail-args",
			ClientIP:    "127.0.0.1",
			UserAgent:   "test-agent",
			StartedAt:   time.Now(),
		}
		// Stop on setup failure. Otherwise the steps below would report a
		// misleading error about a row that was never written.
		if err := logger.LogCommandStart(ctx, entry); err != nil {
			t.Fatalf("LogCommandStart() error = %v", err)
		}

		result := &domain.AuditResult{
			ExitCode:     1,
			DurationMs:   500,
			Status:       domain.AuditStatusError,
			ErrorMessage: "command failed",
		}
		if err := logger.LogCommandEnd(ctx, "cmd-test-fail-1", result); err != nil {
			t.Fatalf("LogCommandEnd() error = %v", err)
		}

		// Check Get's error before touching the result: dereferencing the
		// value returned by a failed Get could panic and hide the real cause.
		retrieved, err := logger.Get(ctx, "cmd-test-fail-1")
		if err != nil {
			t.Fatalf("Get() error = %v", err)
		}
		if retrieved.Status != domain.AuditStatusError {
			t.Errorf("Status = %q, want %q", retrieved.Status, domain.AuditStatusError)
		}
		if retrieved.ExitCode == nil || *retrieved.ExitCode != 1 {
			t.Errorf("ExitCode = %v, want 1", retrieved.ExitCode)
		}
		if retrieved.ErrorMessage != "command failed" {
			t.Errorf("ErrorMessage = %q, want %q", retrieved.ErrorMessage, "command failed")
		}
	})
}
// TestAuditLogger_List verifies List's handling of AuditFilters: filtering by
// project, command type and status, and the Limit option.
func TestAuditLogger_List(t *testing.T) {
	db := testutil.TestDB(t)
	// Clear rows left by an interrupted earlier run first: the fixtures below
	// use fixed IDs, so stale rows could collide with (or be mistaken for)
	// the rows inserted here. Cleanup also runs again when the test ends.
	cleanupTestAuditLogs(t, db)
	t.Cleanup(func() { cleanupTestAuditLogs(t, db) })
	logger := NewAuditLogger(db)
	ctx := context.Background()

	// Fixtures in proj-list-1: five running Claude commands started one
	// minute apart...
	const runningClaude = 5
	now := time.Now()
	for i := 0; i < runningClaude; i++ {
		suffix := string(rune('a' + i))
		entry := &domain.AuditLogEntry{
			ID:          "audit-list-" + suffix,
			APIKeyID:    "key-list-1",
			CommandID:   "cmd-list-" + suffix,
			ProjectID:   "proj-list-1",
			CommandType: domain.CommandTypeClaude,
			Args:        "test-list-args",
			ClientIP:    "127.0.0.1",
			UserAgent:   "test-agent",
			StartedAt:   now.Add(time.Duration(i) * time.Minute),
		}
		if err := logger.LogCommandStart(ctx, entry); err != nil {
			t.Fatalf("LogCommandStart(%s) error = %v", entry.CommandID, err)
		}
	}

	// ...plus one finished shell command. It gives the command-type and
	// status filters a row they must exclude; without it every row matches
	// and those filters would never really be exercised.
	shellEntry := &domain.AuditLogEntry{
		ID:          "audit-list-shell",
		APIKeyID:    "key-list-1",
		CommandID:   "cmd-list-shell",
		ProjectID:   "proj-list-1",
		CommandType: domain.CommandTypeShell,
		Args:        "test-list-shell",
		ClientIP:    "127.0.0.1",
		UserAgent:   "test-agent",
		StartedAt:   now.Add(runningClaude * time.Minute),
	}
	if err := logger.LogCommandStart(ctx, shellEntry); err != nil {
		t.Fatalf("LogCommandStart(%s) error = %v", shellEntry.CommandID, err)
	}
	shellResult := &domain.AuditResult{
		ExitCode:   0,
		DurationMs: 10,
		Status:     domain.AuditStatusSuccess,
	}
	if err := logger.LogCommandEnd(ctx, shellEntry.CommandID, shellResult); err != nil {
		t.Fatalf("LogCommandEnd(%s) error = %v", shellEntry.CommandID, err)
	}
	const totalInProject = runningClaude + 1

	t.Run("lists all entries", func(t *testing.T) {
		entries, err := logger.List(ctx, domain.AuditFilters{
			ProjectID: "proj-list-1",
		})
		if err != nil {
			t.Fatalf("List() error = %v", err)
		}
		if len(entries) < totalInProject {
			t.Errorf("List() returned %d entries, want at least %d", len(entries), totalInProject)
		}
	})

	t.Run("filters by project", func(t *testing.T) {
		// Create an entry in a different project.
		entry := &domain.AuditLogEntry{
			ID:          "audit-list-other",
			APIKeyID:    "key-list-2",
			CommandID:   "cmd-list-other",
			ProjectID:   "proj-list-other",
			CommandType: domain.CommandTypeClaude,
			Args:        "test-list-other",
			ClientIP:    "127.0.0.1",
			UserAgent:   "test-agent",
			StartedAt:   now,
		}
		if err := logger.LogCommandStart(ctx, entry); err != nil {
			t.Fatalf("LogCommandStart() error = %v", err)
		}

		entries, err := logger.List(ctx, domain.AuditFilters{
			ProjectID: "proj-list-other",
		})
		if err != nil {
			t.Fatalf("List() error = %v", err)
		}
		// An empty result would make the per-entry check below vacuous.
		if len(entries) == 0 {
			t.Fatal("List() returned no entries for proj-list-other, want at least 1")
		}
		for _, e := range entries {
			if e.ProjectID != "proj-list-other" {
				t.Errorf("Entry has ProjectID = %q, want %q", e.ProjectID, "proj-list-other")
			}
		}
	})

	t.Run("applies limit and offset", func(t *testing.T) {
		entries, err := logger.List(ctx, domain.AuditFilters{
			ProjectID: "proj-list-1",
			Limit:     2,
			Offset:    0,
		})
		if err != nil {
			t.Fatalf("List() error = %v", err)
		}
		if len(entries) != 2 {
			t.Errorf("List() returned %d entries, want 2", len(entries))
		}
	})

	t.Run("filters by command type", func(t *testing.T) {
		entries, err := logger.List(ctx, domain.AuditFilters{
			ProjectID:   "proj-list-1",
			CommandType: domain.CommandTypeClaude,
		})
		if err != nil {
			t.Fatalf("List() error = %v", err)
		}
		if len(entries) < runningClaude {
			t.Errorf("List() returned %d entries, want at least %d", len(entries), runningClaude)
		}
		for _, e := range entries {
			if e.CommandType != domain.CommandTypeClaude {
				t.Errorf("Entry has CommandType = %q, want %q", e.CommandType, domain.CommandTypeClaude)
			}
		}
	})

	t.Run("filters by status", func(t *testing.T) {
		entries, err := logger.List(ctx, domain.AuditFilters{
			ProjectID: "proj-list-1",
			Status:    domain.AuditStatusRunning,
		})
		if err != nil {
			t.Fatalf("List() error = %v", err)
		}
		if len(entries) < runningClaude {
			t.Errorf("List() returned %d entries, want at least %d", len(entries), runningClaude)
		}
		for _, e := range entries {
			if e.Status != domain.AuditStatusRunning {
				t.Errorf("Entry has Status = %q, want %q", e.Status, domain.AuditStatusRunning)
			}
		}
	})
}
// TestAuditLogger_Get verifies that Get returns a stored entry by command ID,
// and returns domain.ErrAuditNotFound (possibly wrapped) for an unknown ID.
func TestAuditLogger_Get(t *testing.T) {
	db := testutil.TestDB(t)
	// Clear rows left by an interrupted earlier run first: the fixture below
	// uses fixed IDs, so stale rows could collide with (or be mistaken for)
	// the row inserted here. Cleanup also runs again when the test ends.
	cleanupTestAuditLogs(t, db)
	t.Cleanup(func() { cleanupTestAuditLogs(t, db) })
	logger := NewAuditLogger(db)
	ctx := context.Background()

	t.Run("gets existing entry", func(t *testing.T) {
		entry := &domain.AuditLogEntry{
			ID:          "audit-get-1",
			APIKeyID:    "key-get-1",
			CommandID:   "cmd-get-1",
			ProjectID:   "proj-get-1",
			CommandType: domain.CommandTypeClaude,
			Args:        "test-get-args",
			ClientIP:    "10.0.0.1",
			UserAgent:   "test-agent-get",
			StartedAt:   time.Now(),
		}
		// A setup failure must fail here, not surface later as a confusing
		// Get error.
		if err := logger.LogCommandStart(ctx, entry); err != nil {
			t.Fatalf("LogCommandStart() error = %v", err)
		}

		retrieved, err := logger.Get(ctx, "cmd-get-1")
		if err != nil {
			t.Fatalf("Get() error = %v", err)
		}
		if retrieved.CommandID != "cmd-get-1" {
			t.Errorf("CommandID = %q, want %q", retrieved.CommandID, "cmd-get-1")
		}
		if retrieved.ClientIP != "10.0.0.1" {
			t.Errorf("ClientIP = %q, want %q", retrieved.ClientIP, "10.0.0.1")
		}
	})

	t.Run("returns error for non-existent entry", func(t *testing.T) {
		_, err := logger.Get(ctx, "cmd-nonexistent")
		// errors.Is still matches if the logger wraps the sentinel with
		// context, which a plain != comparison would reject.
		if !errors.Is(err, domain.ErrAuditNotFound) {
			t.Errorf("Get() error = %v, want %v", err, domain.ErrAuditNotFound)
		}
	})
}