package handlers

import (
	"bytes"
	"context"
	"encoding/json"
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"

	"github.com/go-chi/chi/v5"

	"github.com/orchard9/rdev/internal/adapter/kubernetes"
	"github.com/orchard9/rdev/internal/domain"
)

// newTestProjectsHandler creates a ProjectsHandler for testing.
//
// The handler's project repository and executor are both created for the
// "test-namespace" namespace. The repository is pre-populated with a single
// project, ID "test-project" (status running, pod "test-project-pod-0",
// workspace "/workspace"), which tests use for success paths. Any other ID
// exercises the not-found paths.
//
// It panics if the test project cannot be registered. Every caller depends on
// that project existing, and silently continuing would only surface later as
// unrelated-looking 404 failures.
func newTestProjectsHandler() *ProjectsHandler {
	repo := kubernetes.NewProjectRepository("test-namespace")

	err := repo.Register(context.Background(), &domain.Project{
		ID:          "test-project",
		Name:        "Test Project",
		Description: "Test project for unit tests",
		PodName:     "test-project-pod-0",
		Status:      domain.ProjectStatusRunning,
		Workspace:   "/workspace",
	})
	if err != nil {
		panic("newTestProjectsHandler: registering test project: " + err.Error())
	}

	exec := kubernetes.NewExecutor("test-namespace")
	return NewProjectsHandler(repo, exec)
}

// TestProjectsHandler_List tests the List endpoint: GET /projects must return
// 200 with a JSON envelope containing both "data" and "meta" fields.
func TestProjectsHandler_List(t *testing.T) {
	h := newTestProjectsHandler()
	router := chi.NewRouter()
	router.Use(testAdminAuth)
	h.Mount(router)

	req := httptest.NewRequest(http.MethodGet, "/projects", nil)
	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, req)

	// Stop on a bad status. The body is then an error payload, and the
	// envelope checks below would only add misleading noise.
	if rec.Code != http.StatusOK {
		t.Fatalf("Status = %d, want 200. Body: %s", rec.Code, rec.Body.String())
	}

	var resp map[string]any
	if err := json.NewDecoder(rec.Body).Decode(&resp); err != nil {
		t.Fatalf("Failed to decode response: %v", err)
	}

	if _, ok := resp["data"]; !ok {
		t.Error("Response missing 'data' field")
	}
	if _, ok := resp["meta"]; !ok {
		t.Error("Response missing 'meta' field")
	}
}

// TestProjectsHandler_Get tests the Get endpoint.
func TestProjectsHandler_Get(t *testing.T) { h := newTestProjectsHandler() router := chi.NewRouter() router.Use(testAdminAuth) h.Mount(router) tests := []struct { name string projectID string wantStatus int }{ {"existing project", "test-project", http.StatusOK}, {"non-existent project", "nonexistent", http.StatusNotFound}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { req := httptest.NewRequest("GET", "/projects/"+tt.projectID, nil) rec := httptest.NewRecorder() router.ServeHTTP(rec, req) if rec.Code != tt.wantStatus { t.Errorf("Status = %d, want %d", rec.Code, tt.wantStatus) } }) } } // TestProjectsHandler_RunClaude tests the RunClaude endpoint. func TestProjectsHandler_RunClaude(t *testing.T) { h := newTestProjectsHandler() router := chi.NewRouter() router.Use(testAdminAuth) h.Mount(router) tests := []struct { name string projectID string body any wantStatus int wantErr string }{ { name: "valid request", projectID: "test-project", body: ClaudeRequest{ Prompt: "Hello, world!", }, wantStatus: http.StatusCreated, }, { name: "missing prompt", projectID: "test-project", body: ClaudeRequest{ Prompt: "", }, wantStatus: http.StatusBadRequest, wantErr: "prompt: is required", }, { name: "project not found", projectID: "nonexistent", body: ClaudeRequest{Prompt: "test"}, wantStatus: http.StatusNotFound, }, { name: "null byte in prompt", projectID: "test-project", body: ClaudeRequest{ Prompt: "Hello\x00World", }, wantStatus: http.StatusBadRequest, wantErr: "null byte", }, { name: "invalid stream ID", projectID: "test-project", body: ClaudeRequest{ Prompt: "Hello", StreamID: "invalid stream id with spaces", }, wantStatus: http.StatusBadRequest, wantErr: "alphanumeric", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { body, _ := json.Marshal(tt.body) req := httptest.NewRequest("POST", "/projects/"+tt.projectID+"/claude", bytes.NewReader(body)) req.Header.Set("Content-Type", "application/json") rec := httptest.NewRecorder() 
router.ServeHTTP(rec, req) if rec.Code != tt.wantStatus { t.Errorf("Status = %d, want %d. Body: %s", rec.Code, tt.wantStatus, rec.Body.String()) } if tt.wantErr != "" { if !strings.Contains(rec.Body.String(), tt.wantErr) { t.Errorf("Body = %q, want to contain %q", rec.Body.String(), tt.wantErr) } } }) } } // TestProjectsHandler_RunShell tests the RunShell endpoint. func TestProjectsHandler_RunShell(t *testing.T) { h := newTestProjectsHandler() router := chi.NewRouter() router.Use(testAdminAuth) h.Mount(router) tests := []struct { name string projectID string body any wantStatus int wantErr string }{ { name: "valid command", projectID: "test-project", body: ShellRequest{ Command: "ls -la", }, wantStatus: http.StatusCreated, }, { name: "missing command", projectID: "test-project", body: ShellRequest{ Command: "", }, wantStatus: http.StatusBadRequest, wantErr: "command: is required", }, { name: "dangerous command with semicolon", projectID: "test-project", body: ShellRequest{ Command: "ls; rm -rf /", }, wantStatus: http.StatusBadRequest, wantErr: "command chaining", }, { name: "dangerous command with pipe", projectID: "test-project", body: ShellRequest{ Command: "cat /etc/passwd | grep root", }, wantStatus: http.StatusBadRequest, wantErr: "command chaining", }, { name: "command substitution", projectID: "test-project", body: ShellRequest{ Command: "echo $(whoami)", }, wantStatus: http.StatusBadRequest, wantErr: "command chaining", }, { name: "redirect", projectID: "test-project", body: ShellRequest{ Command: "ls > /tmp/out.txt", }, wantStatus: http.StatusBadRequest, wantErr: "redirect", }, { name: "rm rf root", projectID: "test-project", body: ShellRequest{ Command: "rm -rf /", }, wantStatus: http.StatusBadRequest, wantErr: "destructive rm", }, { name: "project not found", projectID: "nonexistent", body: ShellRequest{Command: "ls"}, wantStatus: http.StatusNotFound, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { body, _ := json.Marshal(tt.body) req 
:= httptest.NewRequest("POST", "/projects/"+tt.projectID+"/shell", bytes.NewReader(body)) req.Header.Set("Content-Type", "application/json") rec := httptest.NewRecorder() router.ServeHTTP(rec, req) if rec.Code != tt.wantStatus { t.Errorf("Status = %d, want %d. Body: %s", rec.Code, tt.wantStatus, rec.Body.String()) } if tt.wantErr != "" { if !strings.Contains(rec.Body.String(), tt.wantErr) { t.Errorf("Body = %q, want to contain %q", rec.Body.String(), tt.wantErr) } } }) } } // TestProjectsHandler_RunGit tests the RunGit endpoint. func TestProjectsHandler_RunGit(t *testing.T) { h := newTestProjectsHandler() router := chi.NewRouter() router.Use(testAdminAuth) h.Mount(router) tests := []struct { name string projectID string body any wantStatus int wantErr string }{ { name: "valid git status", projectID: "test-project", body: GitRequest{ Args: []string{"status"}, }, wantStatus: http.StatusCreated, }, { name: "valid git log", projectID: "test-project", body: GitRequest{ Args: []string{"log", "--oneline", "-10"}, }, wantStatus: http.StatusCreated, }, { name: "missing args", projectID: "test-project", body: GitRequest{ Args: []string{}, }, wantStatus: http.StatusBadRequest, wantErr: "args: is required", }, { name: "git config blocked", projectID: "test-project", body: GitRequest{ Args: []string{"config", "--global", "user.name", "attacker"}, }, wantStatus: http.StatusBadRequest, wantErr: "git config", }, { name: "git remote blocked", projectID: "test-project", body: GitRequest{ Args: []string{"remote", "add", "evil", "https://evil.com/repo"}, }, wantStatus: http.StatusBadRequest, wantErr: "git remote", }, { name: "force push blocked", projectID: "test-project", body: GitRequest{ Args: []string{"push", "-f", "origin", "main"}, }, wantStatus: http.StatusBadRequest, wantErr: "force push", }, { name: "project not found", projectID: "nonexistent", body: GitRequest{Args: []string{"status"}}, wantStatus: http.StatusNotFound, }, } for _, tt := range tests { t.Run(tt.name, func(t 
*testing.T) { body, _ := json.Marshal(tt.body) req := httptest.NewRequest("POST", "/projects/"+tt.projectID+"/git", bytes.NewReader(body)) req.Header.Set("Content-Type", "application/json") rec := httptest.NewRecorder() router.ServeHTTP(rec, req) if rec.Code != tt.wantStatus { t.Errorf("Status = %d, want %d. Body: %s", rec.Code, tt.wantStatus, rec.Body.String()) } if tt.wantErr != "" { if !strings.Contains(rec.Body.String(), tt.wantErr) { t.Errorf("Body = %q, want to contain %q", rec.Body.String(), tt.wantErr) } } }) } } // TestProjectsHandler_Events tests the Events SSE endpoint. func TestProjectsHandler_Events(t *testing.T) { h := newTestProjectsHandler() router := chi.NewRouter() router.Use(testAdminAuth) h.Mount(router) // Note: SSE tests with headers are difficult in httptest because the // handler blocks waiting for events. We test what we can without blocking. t.Run("project not found", func(t *testing.T) { req := httptest.NewRequest("GET", "/projects/nonexistent/events", nil) rec := httptest.NewRecorder() router.ServeHTTP(rec, req) if rec.Code != http.StatusNotFound { t.Errorf("Status = %d, want 404", rec.Code) } }) } // TestProjectsHandler_InvalidJSON tests handling of invalid JSON bodies. func TestProjectsHandler_InvalidJSON(t *testing.T) { h := newTestProjectsHandler() router := chi.NewRouter() router.Use(testAdminAuth) h.Mount(router) endpoints := []struct { method string path string }{ {"POST", "/projects/test-project/claude"}, {"POST", "/projects/test-project/shell"}, {"POST", "/projects/test-project/git"}, } for _, ep := range endpoints { t.Run(ep.path, func(t *testing.T) { req := httptest.NewRequest(ep.method, ep.path, strings.NewReader("invalid json{")) req.Header.Set("Content-Type", "application/json") rec := httptest.NewRecorder() router.ServeHTTP(rec, req) if rec.Code != http.StatusBadRequest { t.Errorf("Status = %d, want 400. 
Body: %s", rec.Code, rec.Body.String()) } if !strings.Contains(rec.Body.String(), "invalid") { t.Errorf("Body = %q, want to contain 'invalid'", rec.Body.String()) } }) } } // TestCommandIDGeneration tests that command IDs are generated correctly. func TestCommandIDGeneration(t *testing.T) { h := newTestProjectsHandler() router := chi.NewRouter() router.Use(testAdminAuth) h.Mount(router) // Send two requests and verify they get different command IDs body := ClaudeRequest{Prompt: "test"} bodyBytes, _ := json.Marshal(body) req1 := httptest.NewRequest("POST", "/projects/test-project/claude", bytes.NewReader(bodyBytes)) req1.Header.Set("Content-Type", "application/json") rec1 := httptest.NewRecorder() router.ServeHTTP(rec1, req1) req2 := httptest.NewRequest("POST", "/projects/test-project/claude", bytes.NewReader(bodyBytes)) req2.Header.Set("Content-Type", "application/json") rec2 := httptest.NewRecorder() router.ServeHTTP(rec2, req2) // Parse both responses var resp1, resp2 map[string]any json.NewDecoder(bytes.NewReader(rec1.Body.Bytes())).Decode(&resp1) json.NewDecoder(bytes.NewReader(rec2.Body.Bytes())).Decode(&resp2) data1, _ := resp1["data"].(map[string]any) data2, _ := resp2["data"].(map[string]any) if data1["id"] == data2["id"] { t.Error("Two requests should have different command IDs") } } // TestCustomStreamID tests that custom stream IDs are used when provided. 
func TestCustomStreamID(t *testing.T) { h := newTestProjectsHandler() router := chi.NewRouter() router.Use(testAdminAuth) h.Mount(router) body := ClaudeRequest{ Prompt: "test", StreamID: "my-custom-stream-id", } bodyBytes, _ := json.Marshal(body) req := httptest.NewRequest("POST", "/projects/test-project/claude", bytes.NewReader(bodyBytes)) req.Header.Set("Content-Type", "application/json") rec := httptest.NewRecorder() router.ServeHTTP(rec, req) var resp map[string]any json.NewDecoder(rec.Body).Decode(&resp) data, _ := resp["data"].(map[string]any) if data["id"] != "my-custom-stream-id" { t.Errorf("Command ID = %v, want my-custom-stream-id", data["id"]) } }