package handlers

import (
	"context"
	"encoding/json"
	"net/http"
	"net/http/httptest"
	"testing"
	"time"

	"github.com/go-chi/chi/v5"
	"github.com/orchard9/rdev/internal/domain"
)

// mockAuditLogger implements port.AuditLogger for testing.
//
// List and Get serve the canned entries; whenever err is non-nil, every
// method returns it instead.
type mockAuditLogger struct {
	entries []domain.AuditLogEntry
	err     error
}

// LogCommandStart discards entry and returns the configured error.
func (m *mockAuditLogger) LogCommandStart(ctx context.Context, entry *domain.AuditLogEntry) error {
	return m.err
}

// LogCommandEnd discards the result and returns the configured error.
func (m *mockAuditLogger) LogCommandEnd(ctx context.Context, commandID string, result *domain.AuditResult) error {
	return m.err
}

// List returns all canned entries; filters are ignored.
func (m *mockAuditLogger) List(ctx context.Context, filters domain.AuditFilters) ([]domain.AuditLogEntry, error) {
	if m.err != nil {
		return nil, m.err
	}
	return m.entries, nil
}

// Get returns a copy of the first canned entry whose CommandID equals
// commandID, or domain.ErrAuditNotFound when none matches.
func (m *mockAuditLogger) Get(ctx context.Context, commandID string) (*domain.AuditLogEntry, error) {
	if m.err != nil {
		return nil, m.err
	}
	for i := range m.entries {
		if m.entries[i].CommandID == commandID {
			// Take an explicit copy before returning its address so callers
			// never alias the mock's backing slice (and linters don't flag
			// taking the address of a range variable).
			found := m.entries[i]
			return &found, nil
		}
	}
	return nil, domain.ErrAuditNotFound
}

// TestAuditHandler_List drives AuditHandler.List directly with various query
// strings, checking the status code and, for 200 responses, the number of
// entries in the {"data": ...} envelope.
func TestAuditHandler_List(t *testing.T) {
	now := time.Now()
	entries := []domain.AuditLogEntry{
		{
			ID:          "audit-1",
			CommandID:   "cmd-1",
			ProjectID:   "proj-1",
			CommandType: domain.CommandTypeClaude,
			Status:      domain.AuditStatusSuccess,
			StartedAt:   now,
			CreatedAt:   now,
		},
		{
			ID:          "audit-2",
			CommandID:   "cmd-2",
			ProjectID:   "proj-1",
			CommandType: domain.CommandTypeShell,
			Status:      domain.AuditStatusRunning,
			StartedAt:   now,
			CreatedAt:   now,
		},
	}

	tests := []struct {
		name       string
		query      string
		mock       *mockAuditLogger
		wantStatus int
		wantCount  int
	}{
		{
			name:       "list all entries",
			query:      "",
			mock:       &mockAuditLogger{entries: entries},
			wantStatus: http.StatusOK,
			wantCount:  2,
		},
		{
			name:       "filter by project",
			query:      "?project=proj-1",
			mock:       &mockAuditLogger{entries: entries},
			wantStatus: http.StatusOK,
			wantCount:  2,
		},
		{
			name:       "invalid command_type",
			query:      "?command_type=invalid",
			mock:       &mockAuditLogger{entries: entries},
			wantStatus: http.StatusBadRequest,
		},
		{
			name:       "invalid status",
			query:      "?status=invalid",
			mock:       &mockAuditLogger{entries: entries},
			wantStatus: http.StatusBadRequest,
		},
		{
			name:       "invalid start time",
			query:      "?start=invalid",
			mock:       &mockAuditLogger{entries: entries},
			wantStatus: http.StatusBadRequest,
		},
		{
			name:       "invalid limit",
			query:      "?limit=-1",
			mock:       &mockAuditLogger{entries: entries},
			wantStatus: http.StatusBadRequest,
		},
		{
			name:       "invalid offset",
			query:      "?offset=-1",
			mock:       &mockAuditLogger{entries: entries},
			wantStatus: http.StatusBadRequest,
		},
		{
			name:       "valid limit and offset",
			query:      "?limit=10&offset=0",
			mock:       &mockAuditLogger{entries: entries},
			wantStatus: http.StatusOK,
			wantCount:  2,
		},
		{
			name:       "empty result",
			query:      "",
			mock:       &mockAuditLogger{entries: nil},
			wantStatus: http.StatusOK,
			wantCount:  0,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			h := NewAuditHandler(tt.mock)
			req := httptest.NewRequest(http.MethodGet, "/audit-log"+tt.query, nil)
			w := httptest.NewRecorder()

			h.List(w, req)

			if w.Code != tt.wantStatus {
				// Stop here: decoding a body produced by the wrong branch
				// would only add misleading follow-up failures.
				t.Fatalf("List() status = %d, want %d; body: %s", w.Code, tt.wantStatus, w.Body.String())
			}
			if tt.wantStatus != http.StatusOK {
				return
			}

			var resp struct {
				Data ListAuditLogResponse `json:"data"`
			}
			if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
				t.Fatalf("failed to unmarshal response: %v; body: %s", err, w.Body.String())
			}
			if len(resp.Data.Entries) != tt.wantCount {
				t.Errorf("List() count = %d, want %d", len(resp.Data.Entries), tt.wantCount)
			}
		})
	}
}

// TestAuditHandler_Get routes requests through a chi router so the
// {command_id} URL parameter is populated, and checks the status code for
// found, missing and empty command IDs.
func TestAuditHandler_Get(t *testing.T) {
	now := time.Now()
	entries := []domain.AuditLogEntry{
		{
			ID:          "audit-1",
			CommandID:   "cmd-123",
			ProjectID:   "proj-1",
			CommandType: domain.CommandTypeClaude,
			Status:      domain.AuditStatusSuccess,
			StartedAt:   now,
			CreatedAt:   now,
		},
	}

	tests := []struct {
		name       string
		commandID  string
		mock       *mockAuditLogger
		wantStatus int
	}{
		{
			name:       "existing entry",
			commandID:  "cmd-123",
			mock:       &mockAuditLogger{entries: entries},
			wantStatus: http.StatusOK,
		},
		{
			name:       "non-existent entry",
			commandID:  "cmd-unknown",
			mock:       &mockAuditLogger{entries: entries},
			wantStatus: http.StatusNotFound,
		},
		{
			name:       "empty command_id",
			commandID:  "",
			mock:       &mockAuditLogger{entries: entries},
			wantStatus: http.StatusBadRequest,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			h := NewAuditHandler(tt.mock)
			r := chi.NewRouter()
			r.Get("/audit-log/{command_id}", h.Get)

			path := "/audit-log/" + tt.commandID
			if tt.commandID == "" {
				// Register the bare path explicitly so the handler is reached
				// with an empty command_id even if the parameterized route
				// does not match an empty segment.
				r.Get("/audit-log/", h.Get)
				path = "/audit-log/"
			}

			req := httptest.NewRequest(http.MethodGet, path, nil)
			w := httptest.NewRecorder()

			r.ServeHTTP(w, req)

			if w.Code != tt.wantStatus {
				t.Errorf("Get() status = %d, want %d; body: %s", w.Code, tt.wantStatus, w.Body.String())
			}
		})
	}
}

// TestAuditLogToResponse checks that auditLogToResponse carries over the
// identifiers and the optional (pointer) completion fields of a finished entry.
func TestAuditLogToResponse(t *testing.T) {
	now := time.Now()
	completedAt := now.Add(time.Second)
	exitCode := 0
	durationMs := int64(1000)

	entry := &domain.AuditLogEntry{
		ID:              "audit-1",
		APIKeyID:        "key-1",
		CommandID:       "cmd-1",
		ProjectID:       "proj-1",
		CommandType:     domain.CommandTypeClaude,
		Args:            "some args",
		ClientIP:        "127.0.0.1",
		UserAgent:       "test-agent",
		StartedAt:       now,
		CompletedAt:     &completedAt,
		ExitCode:        &exitCode,
		DurationMs:      &durationMs,
		Status:          domain.AuditStatusSuccess,
		ErrorMessage:    "",
		OutputSizeBytes: 1024,
		CreatedAt:       now,
	}

	resp := auditLogToResponse(entry)

	if resp.ID != entry.ID {
		t.Errorf("ID = %s, want %s", resp.ID, entry.ID)
	}
	if resp.CommandID != entry.CommandID {
		t.Errorf("CommandID = %s, want %s", resp.CommandID, entry.CommandID)
	}
	// Report nil and value mismatches separately: printing the pointer with
	// %v would show an address rather than the value.
	if resp.ExitCode == nil {
		t.Errorf("ExitCode = nil, want %d", exitCode)
	} else if *resp.ExitCode != exitCode {
		t.Errorf("ExitCode = %d, want %d", *resp.ExitCode, exitCode)
	}
	if resp.DurationMs == nil {
		t.Errorf("DurationMs = nil, want %d", durationMs)
	} else if *resp.DurationMs != durationMs {
		t.Errorf("DurationMs = %d, want %d", *resp.DurationMs, durationMs)
	}
	if resp.CompletedAt == nil {
		t.Error("CompletedAt should not be nil")
	}
}