rdev/internal/service/project_service_test.go
jordan 4f01015132
All checks were successful
ci/woodpecker/push/woodpecker Pipeline was successful
feat: implement project access enforcement and management API
- Fix no-op RequireProjectAccess middleware to enforce project_ids
- Apply project access middleware to all project-scoped routes
- Filter GET /projects by allowed project IDs for restricted keys
- Add GET /me endpoint with key identity, scopes, and project access info
- Add PATCH /keys/{id} for partial key updates (name, scopes, project_ids, allowed_ips, expires_in)
- Add GET/POST/DELETE /projects/{id}/access for project-centric access management
- Auto-grant creating key access when using POST /project/create-and-build
- Accept grant_to_key_ids in create-and-build to grant multiple keys on project creation
- Move newProvisionerWithDeps test helper from production code to test file

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-02-21 15:38:37 -07:00

436 lines
11 KiB
Go

package service
import (
"context"
"errors"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/orchard9/rdev/internal/domain"
"github.com/orchard9/rdev/internal/port"
)
// MockProjectRepository is an in-memory stand-in for port.ProjectRepository
// used by the ProjectService tests.
//
// It has no locking; the tests in this file populate it through Register
// before handing it to the service.
type MockProjectRepository struct {
	// projects holds every registered project, keyed by its ID.
	projects map[domain.ProjectID]*domain.Project
	// refreshCalls counts RefreshStatus invocations.
	refreshCalls int
	// refreshErr, when non-nil, is returned from every RefreshStatus call.
	refreshErr error
}

// NewMockProjectRepository returns an empty repository ready for Register.
func NewMockProjectRepository() *MockProjectRepository {
	repo := &MockProjectRepository{}
	repo.projects = make(map[domain.ProjectID]*domain.Project)
	return repo
}

// List returns a value copy of every registered project. The order follows
// Go map iteration and is therefore unspecified. It never fails.
func (r *MockProjectRepository) List(ctx context.Context) ([]domain.Project, error) {
	out := make([]domain.Project, 0, len(r.projects))
	for _, proj := range r.projects {
		out = append(out, *proj)
	}
	return out, nil
}

// Get returns the stored pointer for id, or domain.ErrProjectNotFound when no
// project with that ID has been registered.
func (r *MockProjectRepository) Get(ctx context.Context, id domain.ProjectID) (*domain.Project, error) {
	if proj, found := r.projects[id]; found {
		return proj, nil
	}
	return nil, domain.ErrProjectNotFound
}

// Exists reports whether a project with id has been registered. It never fails.
func (r *MockProjectRepository) Exists(ctx context.Context, id domain.ProjectID) (bool, error) {
	_, found := r.projects[id]
	return found, nil
}

// RefreshStatus records the call and returns the configured refreshErr.
func (r *MockProjectRepository) RefreshStatus(ctx context.Context) error {
	r.refreshCalls++
	return r.refreshErr
}

// Register stores p under p.ID, replacing any existing entry. It never fails.
func (r *MockProjectRepository) Register(ctx context.Context, p *domain.Project) error {
	r.projects[p.ID] = p
	return nil
}

// Unregister removes the project with id, if present. It never fails.
func (r *MockProjectRepository) Unregister(ctx context.Context, id domain.ProjectID) error {
	delete(r.projects, id)
	return nil
}
// MockCommandExecutor is a port.CommandExecutor test double. It counts calls
// and returns a configurable canned result.
//
// The call counters are atomics, and the canned result and error are guarded
// by mu. This lets tests read the counters while the service's background
// goroutines may still be calling Execute or Cancel.
type MockCommandExecutor struct {
	executeCalls atomic.Int32 // number of Execute invocations
	cancelCalls  atomic.Int32 // number of Cancel invocations

	mu     sync.RWMutex          // protects result and err
	result *domain.CommandResult // canned result; nil means "synthesize one per command"
	err    error                 // returned alongside every Execute result
}

// Execute counts the call and returns the canned result set via SetResult.
// When no canned result is configured, it returns a successful result for cmd
// (exit code 0, 100ms duration). The configured error is returned in both
// cases. The output handler is never invoked.
func (e *MockCommandExecutor) Execute(ctx context.Context, cmd *domain.Command, podName string, handler domain.OutputHandler) (*domain.CommandResult, error) {
	e.executeCalls.Add(1)

	e.mu.RLock()
	defer e.mu.RUnlock()

	res := e.result
	if res == nil {
		res = &domain.CommandResult{
			CommandID:  cmd.ID,
			ExitCode:   0,
			DurationMs: 100,
		}
	}
	return res, e.err
}

// Cancel counts the call and always succeeds.
func (e *MockCommandExecutor) Cancel(ctx context.Context, cmdID domain.CommandID) error {
	e.cancelCalls.Add(1)
	return nil
}

// PodExists reports every pod as present.
func (e *MockCommandExecutor) PodExists(ctx context.Context, podName string) (bool, error) {
	return true, nil
}

// CheckConnection always reports a healthy connection.
func (e *MockCommandExecutor) CheckConnection(ctx context.Context) error {
	return nil
}

// ExecuteCallCount returns how many times Execute has been called. It is safe
// to call concurrently with Execute.
func (e *MockCommandExecutor) ExecuteCallCount() int {
	n := e.executeCalls.Load()
	return int(n)
}

// CancelCallCount returns how many times Cancel has been called. It is safe
// to call concurrently with Cancel.
func (e *MockCommandExecutor) CancelCallCount() int {
	n := e.cancelCalls.Load()
	return int(n)
}

// SetResult configures the result and error returned by subsequent Execute
// calls. Pass a nil result to go back to per-command synthesized results.
func (e *MockCommandExecutor) SetResult(result *domain.CommandResult, err error) {
	e.mu.Lock()
	e.result, e.err = result, err
	e.mu.Unlock()
}
// MockStreamPublisher implements port.StreamPublisher for testing.
// It records published events per stream so tests can inspect them via
// GetEvents. A mutex guards the event map because the service publishes from
// background goroutines.
type MockStreamPublisher struct {
	mu      sync.RWMutex
	streams map[string][]port.StreamEvent // recorded events, keyed by stream ID
}

// NewMockStreamPublisher returns a publisher with no recorded events.
func NewMockStreamPublisher() *MockStreamPublisher {
	return &MockStreamPublisher{
		streams: make(map[string][]port.StreamEvent),
	}
}

// Subscribe returns a buffered channel (capacity 100) and an unsubscribe func
// that closes it. The channel never receives events: Publish only records
// events for GetEvents.
//
// The unsubscribe func is idempotent. A second call is a no-op instead of a
// "close of closed channel" panic.
func (m *MockStreamPublisher) Subscribe(streamID string) (<-chan port.StreamEvent, func()) {
	ch := make(chan port.StreamEvent, 100)
	var once sync.Once
	return ch, func() { once.Do(func() { close(ch) }) }
}

// SubscribeFromID ignores lastEventID (there is no replay) and behaves like
// Subscribe.
func (m *MockStreamPublisher) SubscribeFromID(streamID, lastEventID string) (<-chan port.StreamEvent, func()) {
	return m.Subscribe(streamID)
}

// Publish appends event to the stream's recorded events. It returns the fixed
// ID "event-1" for every event.
func (m *MockStreamPublisher) Publish(streamID string, event port.StreamEvent) string {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.streams[streamID] = append(m.streams[streamID], event)
	return "event-1"
}

// Close discards all recorded events for streamID. Subscriber channels are
// not touched; they are closed only by their unsubscribe funcs.
func (m *MockStreamPublisher) Close(streamID string) {
	m.mu.Lock()
	defer m.mu.Unlock()
	delete(m.streams, streamID)
}

// GetEvents returns a copy of the events recorded for streamID. The result is
// an empty, non-nil slice for an unknown or closed stream. It is safe to call
// concurrently with Publish.
func (m *MockStreamPublisher) GetEvents(streamID string) []port.StreamEvent {
	m.mu.RLock()
	defer m.mu.RUnlock()
	events := make([]port.StreamEvent, len(m.streams[streamID]))
	copy(events, m.streams[streamID])
	return events
}
// TestProjectService_List checks that List returns every registered project
// and triggers exactly one repository status refresh.
func TestProjectService_List(t *testing.T) {
	ctx := context.Background()
	repo := NewMockProjectRepository()
	for _, p := range []*domain.Project{
		{ID: "proj-a", Name: "Project A"},
		{ID: "proj-b", Name: "Project B"},
	} {
		repo.Register(ctx, p)
	}
	svc := NewProjectService(repo, nil, nil)

	got, err := svc.List(ctx, nil)
	if err != nil {
		t.Fatalf("List() error = %v", err)
	}
	if n := len(got); n != 2 {
		t.Errorf("List() returned %d projects, want 2", n)
	}

	// List must trigger a status refresh on the repository.
	if calls := repo.refreshCalls; calls != 1 {
		t.Errorf("RefreshStatus() called %d times, want 1", calls)
	}
}
// TestProjectService_List_RefreshError checks that a failing status refresh
// does not prevent List from returning the registered projects.
func TestProjectService_List_RefreshError(t *testing.T) {
	ctx := context.Background()
	repo := NewMockProjectRepository()
	repo.refreshErr = errors.New("refresh failed")
	repo.Register(ctx, &domain.Project{ID: "proj-a", Name: "Project A"})
	svc := NewProjectService(repo, nil, nil)

	got, err := svc.List(ctx, nil)
	if err != nil {
		t.Fatalf("List() error = %v", err)
	}
	if n := len(got); n != 1 {
		t.Errorf("List() returned %d projects, want 1", n)
	}
}
// TestProjectService_Get covers lookup of a registered project and the
// not-found error for an unknown ID.
func TestProjectService_Get(t *testing.T) {
	ctx := context.Background()
	repo := NewMockProjectRepository()
	repo.Register(ctx, &domain.Project{ID: "my-project", Name: "My Project", PodName: "pod-0"})
	svc := NewProjectService(repo, nil, nil)

	t.Run("existing project", func(t *testing.T) {
		got, err := svc.Get(ctx, "my-project")
		if err != nil {
			t.Fatalf("Get() error = %v", err)
		}
		const wantName = "My Project"
		if got.Name != wantName {
			t.Errorf("Name = %q, want %q", got.Name, wantName)
		}
	})

	t.Run("non-existent project", func(t *testing.T) {
		_, err := svc.Get(ctx, "unknown")
		if !errors.Is(err, domain.ErrProjectNotFound) {
			t.Errorf("Get() error = %v, want %v", err, domain.ErrProjectNotFound)
		}
	})
}
// TestProjectService_Exists checks Exists for a registered ID and an unknown ID.
func TestProjectService_Exists(t *testing.T) {
	ctx := context.Background()
	repo := NewMockProjectRepository()
	repo.Register(ctx, &domain.Project{ID: "existing"})
	svc := NewProjectService(repo, nil, nil)

	for _, tc := range []struct {
		id   domain.ProjectID
		want bool
	}{
		{id: "existing", want: true},
		{id: "unknown", want: false},
	} {
		got, err := svc.Exists(ctx, tc.id)
		if err != nil {
			t.Errorf("Exists(%q) error = %v", tc.id, err)
		}
		if got != tc.want {
			t.Errorf("Exists(%q) = %v, want %v", tc.id, got, tc.want)
		}
	}
}
// TestProjectService_ExecuteClaude covers ExecuteClaude's happy path, prompt
// validation, unknown projects, and caller-supplied stream IDs.
func TestProjectService_ExecuteClaude(t *testing.T) {
	repo := NewMockProjectRepository()
	repo.Register(context.Background(), &domain.Project{
		ID:      "my-project",
		PodName: "pod-0",
	})
	executor := &MockCommandExecutor{}
	streams := NewMockStreamPublisher()
	svc := NewProjectService(repo, executor, streams)

	t.Run("valid request", func(t *testing.T) {
		result, err := svc.ExecuteClaude(context.Background(), ExecuteClaudeRequest{
			ProjectID: "my-project",
			Prompt:    "Hello Claude",
		})
		if err != nil {
			t.Fatalf("ExecuteClaude() error = %v", err)
		}
		if result.CommandID == "" {
			t.Error("CommandID should not be empty")
		}
		if result.StreamURL == "" {
			t.Error("StreamURL should not be empty")
		}

		// Execute runs on a background goroutine. Poll against a generous
		// deadline instead of sleeping a fixed interval. A fixed sleep races
		// the goroutine and flakes on slow or heavily loaded CI runners.
		deadline := time.Now().Add(2 * time.Second)
		for executor.ExecuteCallCount() < 1 && time.Now().Before(deadline) {
			time.Sleep(5 * time.Millisecond)
		}
		if got := executor.ExecuteCallCount(); got != 1 {
			t.Errorf("Execute() called %d times, want 1", got)
		}
	})

	t.Run("empty prompt", func(t *testing.T) {
		_, err := svc.ExecuteClaude(context.Background(), ExecuteClaudeRequest{
			ProjectID: "my-project",
			Prompt:    "",
		})
		if !errors.Is(err, domain.ErrInvalidCommand) {
			t.Errorf("ExecuteClaude() error = %v, want %v", err, domain.ErrInvalidCommand)
		}
	})

	t.Run("non-existent project", func(t *testing.T) {
		_, err := svc.ExecuteClaude(context.Background(), ExecuteClaudeRequest{
			ProjectID: "unknown",
			Prompt:    "Hello",
		})
		if !errors.Is(err, domain.ErrProjectNotFound) {
			t.Errorf("ExecuteClaude() error = %v, want %v", err, domain.ErrProjectNotFound)
		}
	})

	// A caller-supplied StreamID is expected to become the command ID.
	t.Run("custom stream ID", func(t *testing.T) {
		result, err := svc.ExecuteClaude(context.Background(), ExecuteClaudeRequest{
			ProjectID: "my-project",
			Prompt:    "Hello",
			StreamID:  "custom-stream-123",
		})
		if err != nil {
			t.Fatalf("ExecuteClaude() error = %v", err)
		}
		if result.CommandID != "custom-stream-123" {
			t.Errorf("CommandID = %q, want %q", result.CommandID, "custom-stream-123")
		}
	})
}
// TestProjectService_ExecuteShell covers a valid shell command plus the two
// rejection paths: empty input and a command the sanitizer refuses.
func TestProjectService_ExecuteShell(t *testing.T) {
	ctx := context.Background()
	repo := NewMockProjectRepository()
	repo.Register(ctx, &domain.Project{ID: "my-project", PodName: "pod-0"})
	svc := NewProjectService(repo, &MockCommandExecutor{}, NewMockStreamPublisher())

	t.Run("valid request", func(t *testing.T) {
		res, err := svc.ExecuteShell(ctx, ExecuteShellRequest{
			ProjectID: "my-project",
			Command:   "ls -la",
		})
		if err != nil {
			t.Fatalf("ExecuteShell() error = %v", err)
		}
		if res.CommandID == "" {
			t.Error("CommandID should not be empty")
		}
	})

	rejections := []struct {
		name    string
		command string
		wantErr error
	}{
		{name: "empty command", command: "", wantErr: domain.ErrInvalidCommand},
		{name: "dangerous command rejected", command: "rm -rf /", wantErr: domain.ErrCommandSanitization},
	}
	for _, tc := range rejections {
		t.Run(tc.name, func(t *testing.T) {
			_, err := svc.ExecuteShell(ctx, ExecuteShellRequest{
				ProjectID: "my-project",
				Command:   tc.command,
			})
			if !errors.Is(err, tc.wantErr) {
				t.Errorf("ExecuteShell() error = %v, want %v", err, tc.wantErr)
			}
		})
	}
}
// TestProjectService_ExecuteGit covers a valid git invocation and rejection of
// an empty argument list.
func TestProjectService_ExecuteGit(t *testing.T) {
	ctx := context.Background()
	repo := NewMockProjectRepository()
	repo.Register(ctx, &domain.Project{ID: "my-project", PodName: "pod-0"})
	svc := NewProjectService(repo, &MockCommandExecutor{}, NewMockStreamPublisher())

	t.Run("valid request", func(t *testing.T) {
		req := ExecuteGitRequest{ProjectID: "my-project", Args: []string{"status"}}
		res, err := svc.ExecuteGit(ctx, req)
		if err != nil {
			t.Fatalf("ExecuteGit() error = %v", err)
		}
		if res.CommandID == "" {
			t.Error("CommandID should not be empty")
		}
	})

	t.Run("empty args", func(t *testing.T) {
		req := ExecuteGitRequest{ProjectID: "my-project", Args: []string{}}
		if _, err := svc.ExecuteGit(ctx, req); !errors.Is(err, domain.ErrInvalidCommand) {
			t.Errorf("ExecuteGit() error = %v, want %v", err, domain.ErrInvalidCommand)
		}
	})
}
func TestProjectService_Subscribe(t *testing.T) {
streams := NewMockStreamPublisher()
svc := NewProjectService(nil, nil, streams)
ch, cleanup := svc.Subscribe("test-stream")
defer cleanup()
if ch == nil {
t.Error("Subscribe() returned nil channel")
}
}
func TestProjectService_SubscribeFromID(t *testing.T) {
streams := NewMockStreamPublisher()
svc := NewProjectService(nil, nil, streams)
ch, cleanup := svc.SubscribeFromID("test-stream", "last-event-123")
defer cleanup()
if ch == nil {
t.Error("SubscribeFromID() returned nil channel")
}
}