diff --git a/ai-lookup/features/build-orchestration.md b/ai-lookup/features/build-orchestration.md index c837110..36c5da7 100644 --- a/ai-lookup/features/build-orchestration.md +++ b/ai-lookup/features/build-orchestration.md @@ -1,77 +1,59 @@ # Build Orchestration -**Last Updated:** 2025-01 -**Confidence:** High (Planned - see address-the-gaps.md) +**Last Updated:** 2026-01-27 +**Confidence:** High ## Summary -Build orchestration enables structured build specs for bot-driven development. Bots submit build requests with prompts, workers execute, and callbacks notify completion. +Build orchestration enables structured build specs for bot-driven development. Bots submit build requests with prompts and templates via `POST /projects/{id}/builds`, workers execute Claude Code, and callbacks notify completion. All builds are recorded in the `build_audit` table for observability. **Key Facts:** -- Build spec includes template, prompt, variables, auto_deploy flag -- Enqueues as work task for worker pool -- Auto-deploy commits, pushes, triggers Woodpecker CI -- Callback URL notified on completion with artifacts +- BuildSpec: prompt (required), template, variables, auto_commit, auto_push, callback_url +- BuildResult: success, output, error, commit_sha, files_changed, duration_ms, artifacts +- Builds enqueued as work tasks for the worker pool +- Auto-commit/push triggers Woodpecker CI pipeline +- Callback URL receives completion notification with full BuildResult +- Complete audit trail in `build_audit` PostgreSQL table **File Pointers:** +- Domain: `internal/domain/build.go` (BuildSpec, BuildResult, BuildAuditEntry) +- Port: `internal/port/build_audit.go` (BuildAudit interface) +- Adapter: `internal/adapter/postgres/build_audit.go` - Service: `internal/service/build_service.go` -- Handler: `internal/handlers/build.go` -- Work queue: `internal/port/work_queue.go` +- Handler: `internal/handlers/builds.go` (StartBuild, ListBuilds, GetBuild) +- Handler: 
`internal/handlers/create_and_build.go` (CreateAndBuild) +- Executor: `internal/worker/build_executor.go` (BuildSpec→AgentRequest translation) +- Git: `internal/worker/git_operations.go` (clone, commit, push with token injection) +- Migration: `internal/db/migrations/012_worker_registry.sql` (build_audit table) -## Build Spec Schema +## API Endpoints -```go -type BuildSpec struct { - Template string `json:"template"` - Prompt string `json:"prompt"` - Variables map[string]string `json:"variables"` - AutoDeploy bool `json:"auto_deploy"` - CallbackURL string `json:"callback_url"` -} -``` - -## API Endpoint - -``` -POST /project/{name}/build -{ - "template": "astro-landing", - "prompt": "Create a coming soon page with dark theme and threesix.ai branding", - "auto_deploy": true, - "callback_url": "https://pantheon.orchard9.ai/webhooks/build-complete" -} -``` +| Method | Path | Description | +|--------|------|-------------| +| POST | `/projects/{id}/builds` | Start a build, returns task_id | +| GET | `/projects/{id}/builds` | List builds for project | +| GET | `/builds/{taskId}` | Get build status and result | +| POST | `/project/create-and-build` | Create project + start build in one call | ## Orchestration Flow -1. Bot calls `POST /project/{name}/build` -2. BuildService validates project exists -3. Creates WorkTask with build spec -4. Enqueues to work queue -5. Returns task ID immediately -6. Worker picks up task: - - Clones repo - - Runs Claude with prompt - - Commits and pushes (if auto_deploy) -7. Woodpecker builds and deploys -8. Callback notified with result +1. Bot calls `POST /projects/{id}/builds` with BuildSpec (prompt, template, auto_commit, auto_push) +2. BuildService validates spec (prompt required), creates WorkTask with build spec, enqueues +3. Creates BuildAuditEntry with status "pending" +4. Returns task ID immediately +5. WorkExecutor poll loop claims task from queue +6. 
BuildExecutor translates spec: clones repo, builds AgentRequest, calls CodeAgent.Execute() +7. On success with auto_commit: GitOperations commits and pushes changes +8. WorkExecutor reports completion with BuildResult +9. Audit entry updated, callback URL notified -## Callback Payload +## Build Audit Statuses -```json -{ - "task_id": "uuid", - "project_id": "myapp", - "status": "completed", - "result": { - "output": "...", - "artifacts": { - "commit_sha": "abc123", - "deploy_url": "https://myapp.threesix.ai" - } - } -} -``` +- `pending` - enqueued, waiting for worker +- `running` - worker executing +- `completed` - finished successfully +- `failed` - execution failed +- `cancelled` - cancelled before completion ## Related Topics diff --git a/ai-lookup/index.md b/ai-lookup/index.md index ff4833a..0743f5f 100644 --- a/ai-lookup/index.md +++ b/ai-lookup/index.md @@ -10,16 +10,16 @@ Quick reference for rdev concepts and facts. | Project Service | [services/project-service.md](./services/project-service.md) | High | 2025-01 | Business logic for project operations | | API Keys | [services/api-keys.md](./services/api-keys.md) | High | 2025-01 | Authentication, scopes, restrictions | | Webhooks | [services/webhooks.md](./services/webhooks.md) | High | 2025-01 | Event subscriptions and delivery | -| **Worker Infrastructure** (Planned) | +| **Worker Infrastructure** | | Work Queue | [services/work-queue.md](./services/work-queue.md) | High | 2025-01 | Task queue for worker pool | -| Worker Pool | [services/worker-pool.md](./services/worker-pool.md) | High | 2025-01 | Shared claudebox workers | +| Worker Pool | [services/worker-pool.md](./services/worker-pool.md) | High | 2026-01 | Embedded work executor with queue maintenance and metrics | | CI Provider | [services/ci-provider.md](./services/ci-provider.md) | High | 2025-01 | Woodpecker auto-activation | | Template Provider | [services/template-provider.md](./services/template-provider.md) | High | 2025-01 | Project template 
seeding | | **Features** | | Command Execution | [features/command-execution.md](./features/command-execution.md) | High | 2025-01 | Claude/shell/git command flow | | SSE Streaming | [features/sse-streaming.md](./features/sse-streaming.md) | High | 2025-01 | Real-time output streaming | | Infrastructure Management | [features/infrastructure.md](./features/infrastructure.md) | High | 2025-01 | Gitea, Cloudflare, deployment | -| Build Orchestration | [features/build-orchestration.md](./features/build-orchestration.md) | High | 2025-01 | Bot-driven build specs | +| Build Orchestration | [features/build-orchestration.md](./features/build-orchestration.md) | High | 2026-01 | Bot-driven build specs with audit trail | ## Roadmap Reference diff --git a/ai-lookup/services/worker-pool.md b/ai-lookup/services/worker-pool.md index 918f487..2c5d12b 100644 --- a/ai-lookup/services/worker-pool.md +++ b/ai-lookup/services/worker-pool.md @@ -1,59 +1,74 @@ # Worker Pool -**Last Updated:** 2025-01 -**Confidence:** High (Planned - see address-the-gaps.md) +**Last Updated:** 2026-01-27 +**Confidence:** High ## Summary -Shared pool of claudebox workers (3-5 pods) that can build any project. Workers register, send heartbeats, and poll for tasks. Scales horizontally by adding workers, not projects. +Shared worker pool that executes build tasks for any project. Currently runs as an embedded WorkExecutor daemon inside rdev-api. Workers register with the worker registry, poll the work queue for tasks, execute Claude Code in cloned repos via GitOperations, and report results with audit trails. 
**Key Facts:** -- Workers labeled `rdev.orchard9.ai/role=worker` -- StatefulSet: `claudebox-worker` with 3+ replicas -- Each worker has dedicated PVC for workspace -- Workers poll rdev-api for tasks every 5 seconds -- Health tracked via heartbeat endpoint +- Embedded WorkExecutor daemon runs inside rdev-api process +- Workers poll work queue every 5 seconds, heartbeat every 30 seconds +- Stale workers (no heartbeat for 2 minutes) automatically marked offline by QueueMaintenance +- Stale tasks (running >30 min without completion) automatically requeued +- Old tasks (>7 days) automatically cleaned up +- Queue depth and worker counts exported as Prometheus metrics +- Future: external worker binary for separate pod deployment **File Pointers:** -- Port: `internal/port/worker_registry.go` +- Domain: `internal/domain/worker.go` (Worker, WorkerStatus) +- Domain: `internal/domain/build.go` (BuildSpec, BuildResult) +- Port: `internal/port/worker_registry.go` (WorkerRegistry interface) +- Port: `internal/port/build_audit.go` (BuildAudit interface) - Adapter: `internal/adapter/postgres/worker_registry.go` -- Handler: `internal/handlers/workers.go` -- K8s manifest: `deployments/k8s/base/claudebox-worker.yaml` +- Adapter: `internal/adapter/postgres/build_audit.go` +- Service: `internal/service/worker_service.go` +- Service: `internal/service/build_service.go` +- Executor: `internal/worker/work_executor.go` (poll loop, heartbeat, task routing) +- Executor: `internal/worker/build_executor.go` (BuildSpec→AgentRequest) +- Git: `internal/worker/git_operations.go` (clone, commit, push) +- Maintenance: `internal/worker/queue_maintenance.go` (stale recovery, cleanup, metrics) +- Handler: `internal/handlers/workers.go` (REST API for workers) +- Handler: `internal/handlers/builds.go` (REST API for builds) +- Handler: `internal/handlers/create_and_build.go` (combined create+build) +- Migration: `internal/db/migrations/012_worker_registry.sql` -## Port Interface +## Worker Lifecycle 
(Embedded) -```go -type WorkerRegistry interface { - Register(ctx context.Context, worker WorkerInfo) error - Heartbeat(ctx context.Context, workerID string) error - Deregister(ctx context.Context, workerID string) error - ListActive(ctx context.Context) ([]WorkerInfo, error) -} +1. rdev-api starts → WorkExecutor registers as worker in registry +2. Heartbeat loop: every 30s sends heartbeat via WorkerService +3. Poll loop: every 5s dequeues next task from work queue +4. BuildExecutor: clones repo, executes CodeAgent, commits/pushes if auto_commit +5. Reports completion with BuildResult via WorkerService +6. Graceful shutdown: deregisters worker on rdev-api stop -type WorkerInfo struct { - ID string - PodName string - Namespace string - Status string // "idle", "busy", "unhealthy" - LastSeen time.Time - CurrentTask string -} -``` +## Worker Statuses -## Worker Lifecycle +- `idle` - available for new tasks +- `busy` - currently executing a task +- `draining` - not accepting new tasks (pre-shutdown) +- `offline` - missed heartbeat threshold -1. Pod starts → calls `POST /workers` to register -2. Main loop: heartbeat every 5s, poll for tasks -3. Task received → clone repo, run Claude, commit, report -4. 
Pod shutdown → `DELETE /workers/{id}` to deregister +## API Endpoints -## Environment Variables +| Method | Path | Description | +|--------|------|-------------| +| GET | `/workers` | List all workers with status summary | +| GET | `/workers/{workerId}` | Get worker details | +| POST | `/workers/{workerId}/drain` | Set worker to draining | +| POST | `/projects/{id}/builds` | Start build for project | +| GET | `/projects/{id}/builds` | List builds for project | +| GET | `/builds/{taskId}` | Get build status | +| POST | `/project/create-and-build` | Create project + start build | -``` -WORKER_ID=$(hostname) -RDEV_API_URL=http://rdev-api.rdev.svc:8080 -RDEV_API_KEY= -``` +## Queue Maintenance + +The QueueMaintenance worker runs inside rdev-api alongside the WorkExecutor: +- **Stale task recovery** (every 1m): Requeues tasks running >30m without completion +- **Stale worker marking** (every 1m): Marks workers offline after 2m without heartbeat +- **Old task cleanup** (every 1m): Removes completed/failed/cancelled tasks >7 days old +- **Metrics refresh** (every 15s): Updates Prometheus gauges for queue depth and worker counts ## Related Topics diff --git a/cmd/rdev-api/config.go b/cmd/rdev-api/config.go new file mode 100644 index 0000000..92855e1 --- /dev/null +++ b/cmd/rdev-api/config.go @@ -0,0 +1,179 @@ +package main + +import ( + "context" + "log/slog" + "os" + "strconv" + + "github.com/orchard9/rdev/internal/domain" + "github.com/orchard9/rdev/internal/port" +) + +// Config holds application configuration. 
+type Config struct { + Port int + DBHost string + DBPort int + DBUser string + DBPassword string + DBName string + DBSSLMode string + AdminKey string + + // Credential store encryption key (required for storing secrets in DB) + CredentialEncryptionKey string + + // OpenCode configuration (optional - enables OpenCode as alternative code agent) + OpenCodeURL string // e.g., "http://opencode:4096" + OpenCodeUsername string // Basic auth username (default: "opencode") + OpenCodePassword string // Basic auth password + + // Infrastructure adapters (threesix.ai) - fallback values if not in credential store + GiteaURL string + GiteaToken string + GiteaDefaultOrg string + CloudflareToken string + CloudflareZoneID string + DefaultDomain string + DeployNamespace string + DeployTLSIssuer string + ClusterIP string + RegistryURL string + WoodpeckerURL string + WoodpeckerAPIToken string + WoodpeckerWebhookSecret string +} + +// InfraConfig holds infrastructure adapter configuration. +// Loaded from credential store with env var fallback. 
+type InfraConfig struct { + GiteaURL string + GiteaToken string + GiteaDefaultOrg string + CloudflareToken string + CloudflareZoneID string + DefaultDomain string + DeployNamespace string + DeployTLSIssuer string + ClusterIP string + RegistryURL string + WoodpeckerURL string + WoodpeckerAPIToken string + WoodpeckerWebhookSecret string +} + +func loadConfig() Config { + port := 8080 + if v := os.Getenv("PORT"); v != "" { + if p, err := strconv.Atoi(v); err == nil { + port = p + } + } + + dbPort := 5432 + if v := os.Getenv("DB_PORT"); v != "" { + if p, err := strconv.Atoi(v); err == nil { + dbPort = p + } + } + + return Config{ + Port: port, + DBHost: getEnv("DB_HOST", "postgres.databases.svc"), + DBPort: dbPort, + DBUser: getEnv("DB_USER", "appuser"), + DBPassword: os.Getenv("DB_PASSWORD"), + DBName: getEnv("DB_NAME", "rdev"), + DBSSLMode: getEnv("DB_SSL_MODE", "disable"), + AdminKey: os.Getenv("RDEV_ADMIN_KEY"), + + // Encryption key for credential store (generate with: openssl rand -base64 32) + // REQUIRED in production - no default to prevent insecure deployments + CredentialEncryptionKey: os.Getenv("CREDENTIAL_ENCRYPTION_KEY"), + + // OpenCode (optional alternative code agent) + OpenCodeURL: os.Getenv("OPENCODE_URL"), // e.g., "http://opencode:4096" + OpenCodeUsername: getEnv("OPENCODE_USERNAME", "opencode"), + OpenCodePassword: os.Getenv("OPENCODE_PASSWORD"), + + // Infrastructure adapters (fallback if not in credential store) + GiteaURL: getEnv("GITEA_URL", "https://git.threesix.ai"), + GiteaToken: os.Getenv("GITEA_TOKEN"), + GiteaDefaultOrg: getEnv("GITEA_DEFAULT_ORG", "jordan"), + CloudflareToken: os.Getenv("CLOUDFLARE_API_TOKEN"), + CloudflareZoneID: os.Getenv("CLOUDFLARE_ZONE_ID"), + DefaultDomain: getEnv("DEFAULT_DOMAIN", "threesix.ai"), + DeployNamespace: getEnv("DEPLOY_NAMESPACE", "projects"), + DeployTLSIssuer: getEnv("DEPLOY_TLS_ISSUER", "letsencrypt-threesix"), + ClusterIP: getEnv("CLUSTER_IP", "208.122.204.172"), + RegistryURL: 
getEnv("REGISTRY_URL", "zot.threesix.svc.cluster.local:5000"), + WoodpeckerURL: getEnv("WOODPECKER_URL", "https://ci.threesix.ai"), + WoodpeckerAPIToken: os.Getenv("WOODPECKER_API_TOKEN"), + WoodpeckerWebhookSecret: os.Getenv("WOODPECKER_WEBHOOK_SECRET"), + } +} + +func getEnv(key, defaultVal string) string { + if v := os.Getenv(key); v != "" { + return v + } + return defaultVal +} + +// loadInfraConfig loads infrastructure configuration from credential store, +// falling back to environment variables if not found in the store. +func loadInfraConfig(ctx context.Context, store port.CredentialStore, cfg Config, logger *slog.Logger) InfraConfig { + // Try to load from credential store + creds, err := store.GetMultiple(ctx, []string{ + domain.CredKeyGiteaToken, + domain.CredKeyGiteaURL, + domain.CredKeyCloudflareAPIToken, + domain.CredKeyCloudflareZoneID, + domain.CredKeyWoodpeckerURL, + domain.CredKeyWoodpeckerAPIToken, + domain.CredKeyWoodpeckerWebhookSecret, + domain.CredKeyRegistryURL, + }) + if err != nil { + logger.Warn("failed to load credentials from store, using env vars", "error", err) + creds = make(map[string]string) + } + + // Helper to get from store or fall back to env var + getOrFallback := func(key, envFallback string) string { + if v, ok := creds[key]; ok && v != "" { + return v + } + return envFallback + } + + infraCfg := InfraConfig{ + GiteaURL: getOrFallback(domain.CredKeyGiteaURL, cfg.GiteaURL), + GiteaToken: getOrFallback(domain.CredKeyGiteaToken, cfg.GiteaToken), + GiteaDefaultOrg: cfg.GiteaDefaultOrg, // Not a secret, use env + CloudflareToken: getOrFallback(domain.CredKeyCloudflareAPIToken, cfg.CloudflareToken), + CloudflareZoneID: getOrFallback(domain.CredKeyCloudflareZoneID, cfg.CloudflareZoneID), + DefaultDomain: cfg.DefaultDomain, // Not a secret, use env + DeployNamespace: cfg.DeployNamespace, // Not a secret, use env + DeployTLSIssuer: cfg.DeployTLSIssuer, // Not a secret, use env + ClusterIP: cfg.ClusterIP, // Not a secret, use env + 
RegistryURL: getOrFallback(domain.CredKeyRegistryURL, cfg.RegistryURL), + WoodpeckerURL: getOrFallback(domain.CredKeyWoodpeckerURL, cfg.WoodpeckerURL), + WoodpeckerAPIToken: getOrFallback(domain.CredKeyWoodpeckerAPIToken, cfg.WoodpeckerAPIToken), + WoodpeckerWebhookSecret: getOrFallback(domain.CredKeyWoodpeckerWebhookSecret, cfg.WoodpeckerWebhookSecret), + } + + // Log which credentials were loaded from store vs env + fromStore := 0 + for k := range creds { + if creds[k] != "" { + fromStore++ + } + } + if fromStore > 0 { + logger.Info("loaded credentials from store", "count", fromStore) + } + + return infraCfg +} diff --git a/cmd/rdev-api/main.go b/cmd/rdev-api/main.go index 99e6eae..7607b6c 100644 --- a/cmd/rdev-api/main.go +++ b/cmd/rdev-api/main.go @@ -37,10 +37,12 @@ import ( "context" "log/slog" "os" - "strconv" "time" "github.com/orchard9/rdev/internal/adapter/cloudflare" + "github.com/orchard9/rdev/internal/adapter/codeagent" + "github.com/orchard9/rdev/internal/adapter/codeagent/claudecode" + "github.com/orchard9/rdev/internal/adapter/codeagent/opencode" "github.com/orchard9/rdev/internal/adapter/deployer" "github.com/orchard9/rdev/internal/adapter/gitea" "github.com/orchard9/rdev/internal/adapter/kubernetes" @@ -50,11 +52,9 @@ import ( "github.com/orchard9/rdev/internal/adapter/woodpecker" "github.com/orchard9/rdev/internal/auth" "github.com/orchard9/rdev/internal/db" - "github.com/orchard9/rdev/internal/domain" "github.com/orchard9/rdev/internal/handlers" "github.com/orchard9/rdev/internal/metrics" "github.com/orchard9/rdev/internal/middleware" - "github.com/orchard9/rdev/internal/port" "github.com/orchard9/rdev/internal/service" "github.com/orchard9/rdev/internal/telemetry" "github.com/orchard9/rdev/internal/webhook" @@ -102,8 +102,10 @@ func main() { } defer func() { _ = database.Close() }() - // Initialize auth service - authService := auth.NewService(database.DB, cfg.AdminKey) + // Initialize auth service (hexagonal: repo → service → auth wrapper) + 
apiKeyRepo := postgres.NewAPIKeyRepository(database.DB) + apiKeySvc := service.NewAPIKeyService(apiKeyRepo, cfg.AdminKey) + authService := auth.NewService(apiKeySvc, cfg.AdminKey) // Initialize credential store (for infrastructure secrets) credentialStore := postgres.NewCredentialStore(database.DB, cfg.CredentialEncryptionKey) @@ -218,17 +220,49 @@ func main() { logger.Info("template provider initialized") } + // Initialize CodeAgent registry (multi-provider support) + agentRegistry := codeagent.NewRegistry() + + // Register Claude Code adapter (default - always available) + claudeCodeAdapter := claudecode.NewAdapter(namespace) + agentRegistry.Register(claudeCodeAdapter) + logger.Info("registered Claude Code agent", "provider", claudeCodeAdapter.Provider()) + + // Register OpenCode adapter (optional - only if configured) + if cfg.OpenCodeURL != "" { + openCodeAdapter := opencode.NewAdapter(opencode.ClientConfig{ + BaseURL: cfg.OpenCodeURL, + Username: cfg.OpenCodeUsername, + Password: cfg.OpenCodePassword, + Timeout: 30 * time.Second, + }) + agentRegistry.Register(openCodeAdapter) + logger.Info("registered OpenCode agent", "provider", openCodeAdapter.Provider(), "url", cfg.OpenCodeURL) + } + // Create services projectService := service.NewProjectService(projectRepo, k8sExecutor, streamPub). WithAuditLogger(auditLogger). WithCommandQueue(commandQueue). - WithWebhookDispatcher(webhookDispatcher) + WithWebhookDispatcher(webhookDispatcher). 
+ WithCodeAgentRegistry(agentRegistry) // Create work service (for worker pool task management) workService := service.NewWorkService(workQueueRepo, service.WorkServiceConfig{ Logger: logger, }).WithWebhookDispatcher(webhookDispatcher) + // Initialize worker pool infrastructure + workerRegistryRepo := postgres.NewWorkerRegistryRepository(database.DB) + buildAuditRepo := postgres.NewBuildAuditRepository(database.DB) + + // Create worker service (manages worker lifecycle and task assignment) + workerService := service.NewWorkerService(workerRegistryRepo, workQueueRepo, logger). + WithBuildAudit(buildAuditRepo) + + // Create build service (orchestrates build submission and tracking) + buildService := service.NewBuildService(workQueueRepo, buildAuditRepo, logger) + // Create app app := api.New("rdev-api", api.WithPort(cfg.Port), @@ -261,12 +295,13 @@ func main() { webhookHandler := handlers.NewWebhookHandler(webhookRepo, projectRepo) workHandler := handlers.NewWorkHandler(workService) - // Initialize infrastructure handler (for threesix.ai git/deploy/dns) + // Initialize infrastructure handler (for threesix.ai git/deploy/dns/ci) infraHandler := handlers.NewInfrastructureHandler( giteaClient, dnsClient, deployerAdapter, projectRepo, + woodpeckerClient, handlers.InfrastructureConfig{ DefaultGitOwner: infraCfg.GiteaDefaultOrg, DefaultDomain: infraCfg.DefaultDomain, @@ -290,7 +325,7 @@ func main() { ) // Initialize project management handler - projectMgmtHandler := handlers.NewProjectManagementHandler(projectInfraService) + projectMgmtHandler := handlers.NewProjectManagementHandler(projectInfraService, logger) // Initialize Woodpecker webhook handler (for CI/CD auto-deploy) woodpeckerHandler := handlers.NewWoodpeckerWebhookHandler( @@ -308,6 +343,21 @@ func main() { // Initialize credentials handler (superadmin only) credentialsHandler := handlers.NewCredentialsHandler(credentialStore) + // Initialize agents handler (for code agent management) + agentsHandler := 
handlers.NewAgentsHandler(agentRegistry) + + // Initialize worker pool handlers + workersHandler := handlers.NewWorkersHandler(workerService) + buildsHandler := handlers.NewBuildsHandler(buildService) + createAndBuildHandler := handlers.NewCreateAndBuildHandler(projectInfraService, buildService, logger) + + // Override default health/ready endpoints with full dependency checks + healthHandler := handlers.NewHealthHandler("rdev-api", database.DB, nil). + WithAgentRegistry(agentRegistry) + + app.Router().Get("/health", healthHandler.Health) + app.Router().Get("/ready", healthHandler.Ready) + // Register routes projectsHandler.Mount(app.Router()) keysHandler.Mount(app.Router()) @@ -320,8 +370,12 @@ func main() { projectMgmtHandler.Mount(app.Router()) woodpeckerHandler.Mount(app.Router()) credentialsHandler.Mount(app.Router()) + agentsHandler.Mount(app.Router()) + workersHandler.Mount(app.Router()) + buildsHandler.Mount(app.Router()) + createAndBuildHandler.Mount(app.Router()) - // Start queue processor worker + // Start queue processor worker (per-project command queue) queueProcessor := worker.NewQueueProcessor( commandQueue, k8sExecutor, @@ -337,11 +391,57 @@ func main() { os.Exit(1) } + // Start work executor (cross-project worker pool) + var gitOps *worker.GitOperations + if infraCfg.GiteaToken != "" { + gitOps = worker.NewGitOperations(worker.GitOperationsConfig{ + GiteaToken: infraCfg.GiteaToken, + Logger: logger, + }) + } + buildExecutor := worker.NewBuildExecutor(agentRegistry, gitOps, logger) + workExecutor := worker.NewWorkExecutor( + workerService, + workService, + buildExecutor, + &worker.WorkExecutorConfig{ + PollPeriod: 5 * time.Second, + HeartbeatPeriod: 30 * time.Second, + Logger: logger, + }, + ) + if err := workExecutor.Start(); err != nil { + logger.Error("failed to start work executor", "error", err) + os.Exit(1) + } + healthHandler.WithWorkExecutor(workExecutor) + + // Start queue maintenance worker (stale task recovery, worker health, cleanup, 
metrics) + queueMaintenance := worker.NewQueueMaintenance( + workQueueRepo, + workerRegistryRepo, + &worker.QueueMaintenanceConfig{ + StaleTaskTimeout: 30 * time.Minute, + StaleWorkerTimeout: 2 * time.Minute, + CleanupAge: 7 * 24 * time.Hour, + MaintenancePeriod: 1 * time.Minute, + MetricsPeriod: 15 * time.Second, + Logger: logger, + }, + ) + queueMaintenance.Start() + // Enable API documentation app.EnableDocs(buildOpenAPISpec()) // Cleanup on shutdown app.OnShutdown(func(ctx context.Context) error { + // Stop work executor (deregisters worker) + workExecutor.Stop() + + // Stop queue maintenance worker + queueMaintenance.Stop() + // Stop queue processor queueProcessor.Stop() @@ -371,160 +471,5 @@ func main() { app.Run() } -// Config holds application configuration. -type Config struct { - Port int - DBHost string - DBPort int - DBUser string - DBPassword string - DBName string - DBSSLMode string - AdminKey string - - // Credential store encryption key (required for storing secrets in DB) - CredentialEncryptionKey string - - // Infrastructure adapters (threesix.ai) - fallback values if not in credential store - GiteaURL string - GiteaToken string - GiteaDefaultOrg string - CloudflareToken string - CloudflareZoneID string - DefaultDomain string - DeployNamespace string - DeployTLSIssuer string - ClusterIP string - RegistryURL string - WoodpeckerURL string - WoodpeckerAPIToken string - WoodpeckerWebhookSecret string -} - -// InfraConfig holds infrastructure adapter configuration. -// Loaded from credential store with env var fallback. 
-type InfraConfig struct { - GiteaURL string - GiteaToken string - GiteaDefaultOrg string - CloudflareToken string - CloudflareZoneID string - DefaultDomain string - DeployNamespace string - DeployTLSIssuer string - ClusterIP string - RegistryURL string - WoodpeckerURL string - WoodpeckerAPIToken string - WoodpeckerWebhookSecret string -} - -func loadConfig() Config { - port := 8080 - if v := os.Getenv("PORT"); v != "" { - if p, err := strconv.Atoi(v); err == nil { - port = p - } - } - - dbPort := 5432 - if v := os.Getenv("DB_PORT"); v != "" { - if p, err := strconv.Atoi(v); err == nil { - dbPort = p - } - } - - return Config{ - Port: port, - DBHost: getEnv("DB_HOST", "postgres.databases.svc"), - DBPort: dbPort, - DBUser: getEnv("DB_USER", "appuser"), - DBPassword: os.Getenv("DB_PASSWORD"), - DBName: getEnv("DB_NAME", "rdev"), - DBSSLMode: getEnv("DB_SSL_MODE", "disable"), - AdminKey: os.Getenv("RDEV_ADMIN_KEY"), - - // Encryption key for credential store (generate with: openssl rand -base64 32) - // REQUIRED in production - no default to prevent insecure deployments - CredentialEncryptionKey: os.Getenv("CREDENTIAL_ENCRYPTION_KEY"), - - // Infrastructure adapters (fallback if not in credential store) - GiteaURL: getEnv("GITEA_URL", "https://git.threesix.ai"), - GiteaToken: os.Getenv("GITEA_TOKEN"), - GiteaDefaultOrg: getEnv("GITEA_DEFAULT_ORG", "jordan"), - CloudflareToken: os.Getenv("CLOUDFLARE_API_TOKEN"), - CloudflareZoneID: os.Getenv("CLOUDFLARE_ZONE_ID"), - DefaultDomain: getEnv("DEFAULT_DOMAIN", "threesix.ai"), - DeployNamespace: getEnv("DEPLOY_NAMESPACE", "projects"), - DeployTLSIssuer: getEnv("DEPLOY_TLS_ISSUER", "letsencrypt-threesix"), - ClusterIP: getEnv("CLUSTER_IP", "208.122.204.172"), - RegistryURL: getEnv("REGISTRY_URL", "zot.threesix.svc.cluster.local:5000"), - WoodpeckerURL: getEnv("WOODPECKER_URL", "https://ci.threesix.ai"), - WoodpeckerAPIToken: os.Getenv("WOODPECKER_API_TOKEN"), - WoodpeckerWebhookSecret: os.Getenv("WOODPECKER_WEBHOOK_SECRET"), 
- } -} - -func getEnv(key, defaultVal string) string { - if v := os.Getenv(key); v != "" { - return v - } - return defaultVal -} - -// loadInfraConfig loads infrastructure configuration from credential store, -// falling back to environment variables if not found in the store. -func loadInfraConfig(ctx context.Context, store port.CredentialStore, cfg Config, logger *slog.Logger) InfraConfig { - // Try to load from credential store - creds, err := store.GetMultiple(ctx, []string{ - domain.CredKeyGiteaToken, - domain.CredKeyGiteaURL, - domain.CredKeyCloudflareAPIToken, - domain.CredKeyCloudflareZoneID, - domain.CredKeyWoodpeckerURL, - domain.CredKeyWoodpeckerAPIToken, - domain.CredKeyWoodpeckerWebhookSecret, - domain.CredKeyRegistryURL, - }) - if err != nil { - logger.Warn("failed to load credentials from store, using env vars", "error", err) - creds = make(map[string]string) - } - - // Helper to get from store or fall back to env var - getOrFallback := func(key, envFallback string) string { - if v, ok := creds[key]; ok && v != "" { - return v - } - return envFallback - } - - infraCfg := InfraConfig{ - GiteaURL: getOrFallback(domain.CredKeyGiteaURL, cfg.GiteaURL), - GiteaToken: getOrFallback(domain.CredKeyGiteaToken, cfg.GiteaToken), - GiteaDefaultOrg: cfg.GiteaDefaultOrg, // Not a secret, use env - CloudflareToken: getOrFallback(domain.CredKeyCloudflareAPIToken, cfg.CloudflareToken), - CloudflareZoneID: getOrFallback(domain.CredKeyCloudflareZoneID, cfg.CloudflareZoneID), - DefaultDomain: cfg.DefaultDomain, // Not a secret, use env - DeployNamespace: cfg.DeployNamespace, // Not a secret, use env - DeployTLSIssuer: cfg.DeployTLSIssuer, // Not a secret, use env - ClusterIP: cfg.ClusterIP, // Not a secret, use env - RegistryURL: getOrFallback(domain.CredKeyRegistryURL, cfg.RegistryURL), - WoodpeckerURL: getOrFallback(domain.CredKeyWoodpeckerURL, cfg.WoodpeckerURL), - WoodpeckerAPIToken: getOrFallback(domain.CredKeyWoodpeckerAPIToken, cfg.WoodpeckerAPIToken), - 
WoodpeckerWebhookSecret: getOrFallback(domain.CredKeyWoodpeckerWebhookSecret, cfg.WoodpeckerWebhookSecret), - } - - // Log which credentials were loaded from store vs env - fromStore := 0 - for k := range creds { - if creds[k] != "" { - fromStore++ - } - } - if fromStore > 0 { - logger.Info("loaded credentials from store", "count", fromStore) - } - - return infraCfg -} +// Config, InfraConfig, loadConfig, loadInfraConfig, and getEnv +// are defined in config.go. diff --git a/cmd/rdev-api/openapi.go b/cmd/rdev-api/openapi.go index a39b4f3..ea81f11 100644 --- a/cmd/rdev-api/openapi.go +++ b/cmd/rdev-api/openapi.go @@ -55,6 +55,9 @@ Command output is streamed via Server-Sent Events (SSE) at /projects/{id}/events spec.WithTag("Events", "Server-Sent Events for real-time output") spec.WithTag("Claude Config", "Manage commands, skills, and agents in /workspace/.claude/") spec.WithTag("Audit", "Command execution audit logs") + spec.WithTag("Code Agents", "Multi-provider code agent management") + spec.WithTag("Workers", "Worker pool management") + spec.WithTag("Builds", "Build orchestration and history") spec.WithTag("System", "Health and readiness endpoints") // Register all path operations @@ -65,6 +68,9 @@ Command output is streamed via Server-Sent Events (SSE) at /projects/{id}/events registerEventPaths(spec) registerClaudeConfigPaths(spec) registerAuditPaths(spec) + registerAgentPaths(spec) + registerWorkerPaths(spec) + registerBuildPaths(spec) return spec } @@ -456,141 +462,3 @@ func registerAuditPaths(spec *api.OpenAPISpec) { []param{{Name: "command_id", In: "path", Description: "Command ID", Required: true}}, )) } - -// param represents an OpenAPI parameter. -type param struct { - Name string - In string - Description string - Required bool -} - -// withAuth creates an operation that requires authentication. 
-func withAuth(summary, description, tag, scope, example string) map[string]any { - return map[string]any{ - "summary": summary, - "description": description + "\n\n**Required scope**: `" + scope + "`", - "tags": []string{tag}, - "security": []map[string]any{ - {"ApiKeyAuth": []string{}}, - }, - "responses": map[string]any{ - "200": map[string]any{ - "description": "Success", - "content": map[string]any{ - "application/json": map[string]any{ - "example": example, - }, - }, - }, - "401": map[string]any{"description": "Unauthorized - Missing or invalid API key"}, - "403": map[string]any{"description": "Forbidden - Insufficient permissions"}, - }, - } -} - -// withAuthAndBody creates an operation with auth and request body. -func withAuthAndBody(summary, description, tag, scope, requestExample, responseExample string) map[string]any { - return map[string]any{ - "summary": summary, - "description": description + "\n\n**Required scope**: `" + scope + "`", - "tags": []string{tag}, - "security": []map[string]any{ - {"ApiKeyAuth": []string{}}, - }, - "requestBody": map[string]any{ - "required": true, - "content": map[string]any{ - "application/json": map[string]any{ - "example": requestExample, - }, - }, - }, - "responses": map[string]any{ - "201": map[string]any{ - "description": "Created", - "content": map[string]any{ - "application/json": map[string]any{ - "example": responseExample, - }, - }, - }, - "400": map[string]any{"description": "Bad Request - Invalid input"}, - "401": map[string]any{"description": "Unauthorized - Missing or invalid API key"}, - "403": map[string]any{"description": "Forbidden - Insufficient permissions"}, - }, - } -} - -// withAuthAndParams creates an operation with auth and path parameters. 
-func withAuthAndParams(summary, description, tag, scope string, params []param) map[string]any { - parameters := make([]map[string]any, len(params)) - for i, p := range params { - parameters[i] = map[string]any{ - "name": p.Name, - "in": p.In, - "description": p.Description, - "required": p.Required, - "schema": map[string]any{"type": "string"}, - } - } - return map[string]any{ - "summary": summary, - "description": description + "\n\n**Required scope**: `" + scope + "`", - "tags": []string{tag}, - "security": []map[string]any{ - {"ApiKeyAuth": []string{}}, - }, - "parameters": parameters, - "responses": map[string]any{ - "200": map[string]any{"description": "Success"}, - "401": map[string]any{"description": "Unauthorized - Missing or invalid API key"}, - "403": map[string]any{"description": "Forbidden - Insufficient permissions"}, - "404": map[string]any{"description": "Not Found"}, - }, - } -} - -// withAuthBodyAndParams creates an operation with auth, body, and params. -func withAuthBodyAndParams(summary, description, tag, scope string, params []param, requestExample, responseExample string) map[string]any { - parameters := make([]map[string]any, len(params)) - for i, p := range params { - parameters[i] = map[string]any{ - "name": p.Name, - "in": p.In, - "description": p.Description, - "required": p.Required, - "schema": map[string]any{"type": "string"}, - } - } - return map[string]any{ - "summary": summary, - "description": description + "\n\n**Required scope**: `" + scope + "`", - "tags": []string{tag}, - "security": []map[string]any{ - {"ApiKeyAuth": []string{}}, - }, - "parameters": parameters, - "requestBody": map[string]any{ - "required": true, - "content": map[string]any{ - "application/json": map[string]any{ - "example": requestExample, - }, - }, - }, - "responses": map[string]any{ - "201": map[string]any{ - "description": "Created", - "content": map[string]any{ - "application/json": map[string]any{ - "example": responseExample, - }, - }, - }, - "400": 
map[string]any{"description": "Bad Request - Invalid input"}, - "401": map[string]any{"description": "Unauthorized - Missing or invalid API key"}, - "403": map[string]any{"description": "Forbidden - Insufficient permissions"}, - }, - } -} diff --git a/cmd/rdev-api/openapi_ext.go b/cmd/rdev-api/openapi_ext.go new file mode 100644 index 0000000..d94f9fb --- /dev/null +++ b/cmd/rdev-api/openapi_ext.go @@ -0,0 +1,342 @@ +package main + +import "github.com/orchard9/rdev/pkg/api" + +func registerAgentPaths(spec *api.OpenAPISpec) { + spec.AddPath("/agents", "get", withAuth( + "List code agents", + `Returns all registered code agent providers and their status. + +Shows which agents are available, their supported models, and the current default.`, + "Code Agents", + "projects:read", + `{ + "agents": [ + { + "provider": "claudecode", + "name": "Claude Code", + "available": true, + "default": true, + "supported_models": ["claude-sonnet-4-20250514"], + "default_model": "claude-sonnet-4-20250514" + }, + { + "provider": "opencode", + "name": "OpenCode", + "available": false, + "default": false, + "supported_models": ["gpt-4o", "claude-sonnet-4-20250514"], + "default_model": "claude-sonnet-4-20250514" + } + ], + "default_agent": "claudecode", + "total_agents": 2, + "available_count": 1 +}`, + )) + + spec.AddPath("/agents/health", "get", withAuth( + "Get agent health status", + `Returns the health status of all registered code agents. 
+ +Checks connectivity to each agent backend and reports availability.`, + "Code Agents", + "projects:read", + `{ + "agents": [ + { + "provider": "claudecode", + "name": "Claude Code", + "healthy": true, + "message": "available", + "latency": "125ms", + "checked_at": "2026-01-27T12:00:00Z" + }, + { + "provider": "opencode", + "name": "OpenCode", + "healthy": false, + "message": "unavailable", + "latency": "5.002s", + "checked_at": "2026-01-27T12:00:00Z" + } + ], + "healthy_count": 1, + "total_count": 2, + "default_agent": "claudecode", + "default_healthy": true +}`, + )) + + spec.AddPath("/agents/{provider}", "get", withAuthAndParams( + "Get agent capabilities", + `Returns detailed capabilities for a specific code agent provider. + +Includes supported features, models, and configuration options.`, + "Code Agents", + "projects:read", + []param{{Name: "provider", In: "path", Description: "Agent provider ID (e.g., 'claudecode', 'opencode')", Required: true}}, + )) + + spec.AddPath("/agents/default", "post", withAuthAndBody( + "Set default agent", + `Changes the default code agent used for command execution. + +The specified provider must be registered and ideally available.`, + "Code Agents", + "admin", + `{"provider": "opencode"}`, + `{ + "default_agent": "opencode", + "message": "default agent updated" +}`, + )) +} + +// param represents an OpenAPI parameter. +type param struct { + Name string + In string + Description string + Required bool +} + +// withAuth creates an operation that requires authentication. 
+func withAuth(summary, description, tag, scope, example string) map[string]any { + return map[string]any{ + "summary": summary, + "description": description + "\n\n**Required scope**: `" + scope + "`", + "tags": []string{tag}, + "security": []map[string]any{ + {"ApiKeyAuth": []string{}}, + }, + "responses": map[string]any{ + "200": map[string]any{ + "description": "Success", + "content": map[string]any{ + "application/json": map[string]any{ + "example": example, + }, + }, + }, + "401": map[string]any{"description": "Unauthorized - Missing or invalid API key"}, + "403": map[string]any{"description": "Forbidden - Insufficient permissions"}, + }, + } +} + +// withAuthAndBody creates an operation with auth and request body. +func withAuthAndBody(summary, description, tag, scope, requestExample, responseExample string) map[string]any { + return map[string]any{ + "summary": summary, + "description": description + "\n\n**Required scope**: `" + scope + "`", + "tags": []string{tag}, + "security": []map[string]any{ + {"ApiKeyAuth": []string{}}, + }, + "requestBody": map[string]any{ + "required": true, + "content": map[string]any{ + "application/json": map[string]any{ + "example": requestExample, + }, + }, + }, + "responses": map[string]any{ + "201": map[string]any{ + "description": "Created", + "content": map[string]any{ + "application/json": map[string]any{ + "example": responseExample, + }, + }, + }, + "400": map[string]any{"description": "Bad Request - Invalid input"}, + "401": map[string]any{"description": "Unauthorized - Missing or invalid API key"}, + "403": map[string]any{"description": "Forbidden - Insufficient permissions"}, + }, + } +} + +// withAuthAndParams creates an operation with auth and path parameters. 
+func withAuthAndParams(summary, description, tag, scope string, params []param) map[string]any { + parameters := make([]map[string]any, len(params)) + for i, p := range params { + parameters[i] = map[string]any{ + "name": p.Name, + "in": p.In, + "description": p.Description, + "required": p.Required, + "schema": map[string]any{"type": "string"}, + } + } + return map[string]any{ + "summary": summary, + "description": description + "\n\n**Required scope**: `" + scope + "`", + "tags": []string{tag}, + "security": []map[string]any{ + {"ApiKeyAuth": []string{}}, + }, + "parameters": parameters, + "responses": map[string]any{ + "200": map[string]any{"description": "Success"}, + "401": map[string]any{"description": "Unauthorized - Missing or invalid API key"}, + "403": map[string]any{"description": "Forbidden - Insufficient permissions"}, + "404": map[string]any{"description": "Not Found"}, + }, + } +} + +// withAuthBodyAndParams creates an operation with auth, body, and params. +func withAuthBodyAndParams(summary, description, tag, scope string, params []param, requestExample, responseExample string) map[string]any { + parameters := make([]map[string]any, len(params)) + for i, p := range params { + parameters[i] = map[string]any{ + "name": p.Name, + "in": p.In, + "description": p.Description, + "required": p.Required, + "schema": map[string]any{"type": "string"}, + } + } + return map[string]any{ + "summary": summary, + "description": description + "\n\n**Required scope**: `" + scope + "`", + "tags": []string{tag}, + "security": []map[string]any{ + {"ApiKeyAuth": []string{}}, + }, + "parameters": parameters, + "requestBody": map[string]any{ + "required": true, + "content": map[string]any{ + "application/json": map[string]any{ + "example": requestExample, + }, + }, + }, + "responses": map[string]any{ + "201": map[string]any{ + "description": "Created", + "content": map[string]any{ + "application/json": map[string]any{ + "example": responseExample, + }, + }, + }, + "400": 
map[string]any{"description": "Bad Request - Invalid input"}, + "401": map[string]any{"description": "Unauthorized - Missing or invalid API key"}, + "403": map[string]any{"description": "Forbidden - Insufficient permissions"}, + }, + } +} + +func registerWorkerPaths(spec *api.OpenAPISpec) { + spec.AddPath("/workers", "get", withAuth( + "List workers", + "Returns all registered workers in the pool with status summary.", + "Workers", + "admin", + `{ + "workers": [ + { + "id": "rdev-worker-0", + "hostname": "rdev-worker-0.rdev.svc", + "status": "idle", + "capabilities": ["build", "test", "deploy"], + "registered_at": "2026-01-27T12:00:00Z", + "last_heartbeat": "2026-01-27T12:05:00Z", + "version": "1.0.0" + } + ], + "total": 1, + "summary": {"idle": 1, "busy": 0, "draining": 0, "offline": 0} +}`, + )) + + spec.AddPath("/workers/{workerId}", "get", withAuthAndParams( + "Get worker", + "Returns details for a specific worker.", + "Workers", + "admin", + []param{{Name: "workerId", In: "path", Description: "Worker ID", Required: true}}, + )) + + spec.AddPath("/workers/{workerId}/drain", "post", withAuthAndParams( + "Drain worker", + "Sets a worker to draining status. It will finish its current task but stop accepting new work.", + "Workers", + "admin", + []param{{Name: "workerId", In: "path", Description: "Worker ID", Required: true}}, + )) +} + +func registerBuildPaths(spec *api.OpenAPISpec) { + spec.AddPath("/projects/{id}/builds", "post", withAuthBodyAndParams( + "Start build", + "Enqueues a build task for a project. 
The build will be picked up by an available worker.", + "Builds", + "projects:execute", + []param{{Name: "id", In: "path", Description: "Project ID", Required: true}}, + `{ + "prompt": "Build a landing page with Next.js and Tailwind CSS", + "template": "nextjs-landing", + "auto_commit": true, + "auto_push": true, + "callback_url": "https://example.com/webhook" +}`, + `{ + "task_id": "task-abc123", + "project_id": "my-project", + "status": "pending", + "status_url": "/builds/task-abc123" +}`, + )) + + spec.AddPath("/projects/{id}/builds", "get", withAuthAndParams( + "List builds", + "Returns build history for a project.", + "Builds", + "projects:read", + []param{{Name: "id", In: "path", Description: "Project ID", Required: true}}, + )) + + spec.AddPath("/builds/{taskId}", "get", withAuthAndParams( + "Get build status", + "Returns the status and result of a specific build.", + "Builds", + "projects:read", + []param{{Name: "taskId", In: "path", Description: "Build task ID", Required: true}}, + )) + + spec.AddPath("/project/create-and-build", "post", withAuthAndBody( + "Create project and build", + `Creates a new project and immediately enqueues a build task. 
+ +Combines project creation (git repo, DNS, CI activation) with build submission in a single call.`, + "Builds", + "admin", + `{ + "name": "my-landing-page", + "description": "Landing page for product launch", + "template": "nextjs-landing", + "prompt": "Build a modern landing page with hero, features, and CTA sections", + "auto_commit": true, + "auto_push": true +}`, + `{ + "project_id": "my-landing-page", + "name": "my-landing-page", + "domain": "my-landing-page.threesix.ai", + "url": "https://my-landing-page.threesix.ai", + "git": { + "owner": "jordan", + "name": "my-landing-page", + "clone_http": "https://git.threesix.ai/jordan/my-landing-page.git" + }, + "task_id": "task-abc123", + "status": "pending", + "status_url": "/builds/task-abc123" +}`, + )) +} diff --git a/cookbooks/landing-page.md b/cookbooks/landing-page.md index a5d8539..84f2f2a 100644 --- a/cookbooks/landing-page.md +++ b/cookbooks/landing-page.md @@ -7,27 +7,38 @@ This cookbook creates and deploys a simple landing page using the full threesix.ai autonomous infrastructure: ``` -rdev-api → Gitea repo → Claude agent → push → Woodpecker CI → K8s deployment +POST /project → Gitea repo + DNS + Woodpecker CI + template seed → Claude agent → git push → CI build → K8s deployment ``` **Target:** `landing.threesix.ai` (with future DNS aliases for www/root) -**Stack:** Astro (static site generator) +**Stack:** Astro (static site generator) via `astro-landing` template **Status:** Coming Soon page --- -## Current Architecture Gap +## What's Automated Today -**Two separate systems that need bridging:** +`POST /project` orchestrates the full infrastructure setup in a single call: -| System | Endpoint | What it manages | -|--------|----------|-----------------| -| Project Management | `POST /project` | Gitea repos, DNS records, K8s deployments | -| Claudebox Execution | `POST /projects/{id}/claude` | Code generation in existing claudebox pods | +| Step | Status | How | +|------|--------|-----| +| Gitea repo 
creation | Automated | `port.GitRepository` adapter | +| DNS A record | Automated | `port.DNSProvider` (Cloudflare) adapter | +| Woodpecker CI activation | Automated | `port.CIProvider` adapter, called during project creation | +| Template seeding | Automated | `port.TemplateProvider` with `astro-landing` template | +| K8s deployment | Automated | `port.Deployer` adapter (triggered by CI webhook) | -**The problem:** Creating a project via `POST /project` creates a Gitea repo, but there's no claudebox to generate code for it. The claudebox system only knows about pre-existing pods (pantheon, aeries). +## Full Pipeline Status -**The solution:** Use an existing claudebox as a "worker" to clone, build, and push to any project repo. +All infrastructure gaps have been closed. The full pipeline from project creation through code generation, CI monitoring, and multi-domain DNS is operational: + +| Capability | Endpoint | Status | +|------------|----------|--------| +| Project creation | `POST /project` | Operational | +| Code generation (worker) | `POST /projects/{id}/builds` | Operational | +| Create + build combo | `POST /project/create-and-build` | Operational | +| CI pipeline monitoring | `GET /projects/{id}/pipelines` | Operational | +| DNS alias management | `POST /projects/{id}/domains` | Operational | --- @@ -39,12 +50,12 @@ rdev-api → Gitea repo → Claude agent → push → Woodpecker CI → K8s depl |--------|----------|---------| | RDEV_ADMIN_KEY | `rdev-credentials` secret | rdev-api authentication | | GITEA_TOKEN | `rdev-credentials` secret | Gitea API access | -| WOODPECKER_API_TOKEN | `.secrets` file | Woodpecker repo activation | +| WOODPECKER_API_TOKEN | `rdev-credentials` secret | Woodpecker repo activation | | CLOUDFLARE_API_TOKEN | `rdev-credentials` secret | DNS management | ### Infrastructure Required -- [x] rdev-api running with infrastructure handlers (v0.7.1+) +- [x] rdev-api running with infrastructure handlers - [x] Gitea at https://git.threesix.ai - 
[x] Woodpecker CI at https://ci.threesix.ai - [x] Zot registry at zot.threesix.svc.cluster.local:5000 @@ -56,38 +67,35 @@ rdev-api → Gitea repo → Claude agent → push → Woodpecker CI → K8s depl ## Architecture ``` -┌─────────────────────────────────────────────────────────────────────────────┐ -│ Landing Page Flow │ -│ │ -│ 1. Create Project │ -│ POST /project {"name": "www"} │ -│ │ │ -│ ├──▶ Creates Gitea repo: jordan/www │ -│ └──▶ Creates DNS: www.threesix.ai → 208.122.204.172 │ -│ │ -│ 2. Activate Woodpecker │ -│ POST /api/repos?forge_remote_id={id} │ -│ │ │ -│ └──▶ Creates webhook in Gitea │ -│ │ -│ 3. Generate Code (Claude Agent) │ -│ claudebox or local Claude Code │ -│ │ │ -│ ├──▶ Creates Astro project │ -│ ├──▶ Creates Dockerfile │ -│ ├──▶ Creates .woodpecker.yml │ -│ └──▶ Pushes to Gitea │ -│ │ -│ 4. CI/CD Pipeline (automatic) │ -│ Woodpecker triggered by push │ -│ │ │ -│ ├──▶ Kaniko builds Docker image │ -│ ├──▶ Pushes to Zot registry │ -│ └──▶ Webhook triggers rdev-api deploy │ -│ │ -│ 5. Live at https://www.threesix.ai │ -│ │ -└─────────────────────────────────────────────────────────────────────────────┘ +┌──────────────────────────────────────────────────────────────────────┐ +│ Landing Page Flow │ +│ │ +│ 1. Create Project (single API call) │ +│ POST /project {"name": "landing", "template": "astro-landing"} │ +│ │ │ +│ ├──▶ Creates Gitea repo: threesix/landing │ +│ ├──▶ Creates DNS: landing.threesix.ai → cluster IP │ +│ ├──▶ Activates Woodpecker CI (auto) │ +│ └──▶ Seeds repo with astro-landing template │ +│ │ +│ 2. Generate Code (3 options) │ +│ Via worker pool, claudebox, or local Claude Code │ +│ │ │ +│ ├──▶ Customizes Astro landing page │ +│ ├──▶ Commits and pushes to Gitea │ +│ └──▶ Worker executor polls queue, dispatches to agent │ +│ │ +│ 3. 
CI/CD Pipeline (automatic on push) │ +│ Woodpecker triggered by git push │ +│ │ │ +│ ├──▶ npm install + npm build │ +│ ├──▶ Docker build (nginx) │ +│ ├──▶ Push to Zot registry │ +│ └──▶ kubectl set image (deploy) │ +│ │ +│ 4. Live at https://landing.threesix.ai │ +│ │ +└──────────────────────────────────────────────────────────────────────┘ ``` --- @@ -96,213 +104,193 @@ rdev-api → Gitea repo → Claude agent → push → Woodpecker CI → K8s depl ### Step 1: Create Project via rdev-api -```bash -RDEV_KEY="rdev_sk_prod_7f3a9c2e1d8b4a6f0e5c9d2b7a1f8e4c" +This single call creates the Gitea repo, DNS record, activates Woodpecker CI, and seeds the repo with the `astro-landing` template. +```bash curl -X POST https://rdev.masq-ops.orchard9.ai/project \ -H "Authorization: Bearer $RDEV_KEY" \ -H "Content-Type: application/json" \ - -d '{"name": "landing", "description": "threesix.ai landing page"}' + -d '{ + "name": "landing", + "description": "threesix.ai landing page", + "template": "astro-landing" + }' ``` **Response:** ```json { - "data": { + "project_id": "landing", + "name": "landing", + "description": "threesix.ai landing page", + "git": { + "owner": "threesix", "name": "landing", - "domain": "landing.threesix.ai", - "git": { - "clone_ssh": "git@git.threesix.ai:jordan/landing.git", - "clone_http": "https://git.threesix.ai/jordan/landing.git" - } - } + "clone_ssh": "git@git.threesix.ai:threesix/landing.git", + "clone_http": "https://git.threesix.ai/threesix/landing.git", + "html_url": "https://git.threesix.ai/threesix/landing" + }, + "domain": "landing.threesix.ai", + "url": "https://landing.threesix.ai", + "next_steps": [] } ``` -### Step 2: Activate Woodpecker CI +If any infrastructure step fails, `next_steps` will contain manual instructions for that step. The remaining steps still execute. +### Step 2: Generate Code + +The `astro-landing` template seeds the repo with a working Astro project, Dockerfile, nginx.conf, and `.woodpecker.yml`. 
You can customize it via Claude. + +**Option A: Local Claude Code (recommended for now)** ```bash -GITEA_TOKEN="5508ff241943e84aad0ced3559f5fbd311a2fb81" -WOODPECKER_TOKEN="eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJ0eXBlIjoidXNlciIsInVzZXItaWQiOiIxIn0.LcyVHcZ_gSvVH1w3y6TUCp_Jg9ubfsebOAVo-MtiNP8" - -# Get Gitea repo ID -REPO_ID=$(curl -s https://git.threesix.ai/api/v1/repos/jordan/landing \ - -H "Authorization: token $GITEA_TOKEN" | jq '.id') - -# Activate in Woodpecker (creates webhook automatically) -curl -X POST "https://ci.threesix.ai/api/repos?forge_remote_id=$REPO_ID" \ - -H "Authorization: Bearer $WOODPECKER_TOKEN" +git clone https://git.threesix.ai/threesix/landing.git +cd landing +# Use Claude Code to customize the landing page +# Then commit and push ``` -### Step 3: Generate Code via Claudebox - -Use the `pantheon` claudebox as a worker to generate code for the landing project: - +**Option B: Via existing claudebox** ```bash -RDEV_KEY="rdev_sk_prod_7f3a9c2e1d8b4a6f0e5c9d2b7a1f8e4c" - -# Tell Claude to build the landing page in /tmp/landing curl -X POST "https://rdev.masq-ops.orchard9.ai/projects/pantheon/claude" \ -H "Authorization: Bearer $RDEV_KEY" \ -H "Content-Type: application/json" \ -d '{ - "prompt": "Clone https://git.threesix.ai/jordan/landing.git to /tmp/landing, then create a simple Astro landing page with: Coming Soon message, threesix.ai branding (dark theme), responsive layout, Dockerfile (nginx), and .woodpecker.yml for CI/CD. Commit and push when done." + "prompt": "Clone https://git.threesix.ai/threesix/landing.git to /tmp/landing, then customize the Astro landing page with: Coming Soon message, threesix.ai branding (dark theme, gradient background), responsive layout. Commit and push when done." }' ``` -**What happens:** -1. Claude receives the prompt in the claudebox -2. Claude clones the repo to `/tmp/landing` -3. Claude generates the Astro project files -4. 
Claude commits and pushes to Gitea - -### Step 4: Monitor Build - -Watch Woodpecker for the build: -- https://ci.threesix.ai/jordan/landing - -Or via API: +**Option C: Via work queue (build endpoint)** ```bash -curl -s "https://ci.threesix.ai/api/repos/jordan/landing/pipelines" \ - -H "Authorization: Bearer $WOODPECKER_TOKEN" | jq '.[0] | {number, status, started}' -``` - -### Step 5: Verify Deployment - -```bash -curl -s "https://rdev.masq-ops.orchard9.ai/project/landing" \ - -H "Authorization: Bearer $RDEV_KEY" | jq '.data.deployment' -``` - -Site live at: https://landing.threesix.ai - -### Step 6: Configure DNS Aliases (Optional) - -Point `www.threesix.ai` and `threesix.ai` to the landing page: - -```bash -CF_TOKEN="nGoDhG6Za66XsKMl6W7LNXuowc5EM00glHxkq1KK" -CF_ZONE="e0bc8d510f62807b360db0c5994964c5" - -# Update root A record to point to k3s cluster -curl -X PATCH "https://api.cloudflare.com/client/v4/zones/$CF_ZONE/dns_records/{record_id}" \ - -H "Authorization: Bearer $CF_TOKEN" \ +# Enqueue a build task for a worker to execute +curl -X POST "https://rdev.masq-ops.orchard9.ai/projects/landing/builds" \ + -H "Authorization: Bearer $RDEV_KEY" \ -H "Content-Type: application/json" \ -d '{ - "type": "A", - "name": "threesix.ai", - "content": "208.122.204.172", - "proxied": false + "prompt": "Customize the landing page with Coming Soon message and threesix.ai branding", + "auto_commit": true, + "auto_push": true }' ``` ---- - -## File Templates - -### .woodpecker.yml - -```yaml -steps: - build: - image: gcr.io/kaniko-project/executor:latest - settings: - registry: zot.threesix.svc.cluster.local:5000 - tags: - - ${CI_COMMIT_SHA:0:8} - - latest - repo: zot.threesix.svc.cluster.local:5000/${CI_REPO_NAME} - context: . 
- dockerfile: Dockerfile - insecure: true - when: - branch: main - - notify: - image: alpine/curl:latest - commands: - - echo "Build complete, webhook will trigger deployment" - when: - branch: main - status: success +**Option D: Create + build in one call** +```bash +curl -X POST "https://rdev.masq-ops.orchard9.ai/project/create-and-build" \ + -H "Authorization: Bearer $RDEV_KEY" \ + -H "Content-Type: application/json" \ + -d '{ + "name": "landing", + "description": "threesix.ai landing page", + "template": "astro-landing", + "build": { + "prompt": "Customize with Coming Soon message and threesix.ai branding", + "auto_commit": true, + "auto_push": true + } + }' ``` -### Dockerfile (Astro + Nginx) +### Step 3: Monitor Build -```dockerfile -# Build stage -FROM node:20-alpine AS builder -WORKDIR /app -COPY package*.json ./ -RUN npm ci -COPY . . -RUN npm run build +The git push triggers Woodpecker CI automatically. -# Production stage -FROM nginx:alpine -COPY --from=builder /app/dist /usr/share/nginx/html -COPY nginx.conf /etc/nginx/conf.d/default.conf -EXPOSE 80 -CMD ["nginx", "-g", "daemon off;"] +**Via rdev-api (recommended):** +```bash +# List recent pipelines +curl -s "https://rdev.masq-ops.orchard9.ai/projects/landing/pipelines" \ + -H "Authorization: Bearer $RDEV_KEY" | jq '.data.pipelines' + +# Get specific pipeline +curl -s "https://rdev.masq-ops.orchard9.ai/projects/landing/pipelines/1" \ + -H "Authorization: Bearer $RDEV_KEY" | jq '.data' ``` -### nginx.conf +**Via Woodpecker UI:** +- https://ci.threesix.ai/threesix/landing -```nginx -server { - listen 80; - server_name _; - root /usr/share/nginx/html; - index index.html; +### Step 4: Verify Deployment - location / { - try_files $uri $uri/ /index.html; - } +```bash +# Check project status via rdev-api +curl -s "https://rdev.masq-ops.orchard9.ai/project/landing" \ + -H "Authorization: Bearer $RDEV_KEY" | jq '.data.deployment' - # Cache static assets - location ~* 
\.(js|css|png|jpg|jpeg|gif|ico|svg|woff|woff2)$ { - expires 1y; - add_header Cache-Control "public, immutable"; - } -} +# Check deployment via K8s +export KUBECONFIG=~/.kube/orchard9-k3sf.yaml +kubectl get deploy -n projects landing + +# Check the site +curl -I https://landing.threesix.ai +``` + +### Step 5: Configure DNS Aliases (Optional) + +Point `www.threesix.ai` and `threesix.ai` to the landing page via the rdev-api domain alias endpoints. + +```bash +# Add www.threesix.ai as a CNAME alias +curl -X POST "https://rdev.masq-ops.orchard9.ai/projects/landing/domains" \ + -H "Authorization: Bearer $RDEV_KEY" \ + -H "Content-Type: application/json" \ + -d '{"domain": "www.threesix.ai", "type": "CNAME"}' + +# Add root A record +curl -X POST "https://rdev.masq-ops.orchard9.ai/projects/landing/domains" \ + -H "Authorization: Bearer $RDEV_KEY" \ + -H "Content-Type: application/json" \ + -d '{"domain": "threesix.ai"}' + +# List all domains for the project +curl -s "https://rdev.masq-ops.orchard9.ai/projects/landing/domains" \ + -H "Authorization: Bearer $RDEV_KEY" | jq '.data.domains' + +# Remove an alias +curl -X DELETE "https://rdev.masq-ops.orchard9.ai/projects/landing/domains/www.threesix.ai" \ + -H "Authorization: Bearer $RDEV_KEY" ``` --- -## Current Gaps & Future Automation +## Template: astro-landing -### What's Manual Today +The `astro-landing` template (`deployments/k8s/base/templates/astro-landing/`) seeds the repo with: -| Step | Status | Automation Path | -|------|--------|-----------------| -| Create project | ✅ API | Already automated | -| Activate Woodpecker | 🔧 API call needed | Add to rdev-api | -| Generate code | ❌ Manual Claude | Claudebox integration | -| Push to Gitea | ❌ Manual git | Claudebox with SSH key | -| Deploy | ✅ Webhook | Already automated | +| File | Purpose | +|------|---------| +| `.woodpecker.yml` | CI pipeline: npm build, Docker build, push, deploy | +| `.claude/CLAUDE.md` | Project instructions for Claude Code | +| `Dockerfile` | 
Multi-stage build (Node 20 build, nginx serve) | +| `nginx.conf` | Production config with gzip, caching, SPA fallback | +| `package.json` | Astro 4.0+ with Tailwind CSS | +| `astro.config.mjs` | Astro configuration | +| `tailwind.config.mjs` | Tailwind configuration | +| `src/pages/index.astro` | Landing page (dark theme with gradient) | +| `src/layouts/Layout.astro` | Base HTML layout | +| `README.md` | Development and deployment docs | -### To Fully Automate (Future Work) +Variables substituted during seeding: `{{PROJECT_NAME}}`, `{{DOMAIN}}`, `{{GIT_URL}}` -1. **Add Woodpecker activation to rdev-api** - - Store WOODPECKER_API_TOKEN in secrets - - Call Woodpecker API after creating Gitea repo - - Create webhook automatically +--- -2. **Claudebox code generation** - - Spawn claudebox with project context - - Claudebox has Gitea SSH key - - Claude Code generates code based on prompt - - Auto-push to Gitea +## Implementation Status -3. **Single API call** - ``` - POST /project/create-and-build - { - "name": "www", - "prompt": "Create an Astro landing page with coming soon message", - "stack": "astro" - } - ``` +All components for the full landing page pipeline are implemented: + +| Component | Location | Status | +|-----------|----------|--------| +| Work queue (enqueue/dequeue) | `internal/adapter/postgres/work_queue.go` | Implemented | +| Worker registry | `internal/adapter/postgres/worker_registry.go` | Implemented | +| Build audit tracking | `internal/adapter/postgres/build_audit.go` | Implemented | +| Build service | `internal/service/build_service.go` | Implemented | +| Worker service | `internal/service/worker_service.go` | Implemented | +| Work handlers (REST) | `internal/handlers/work.go` | Implemented | +| Code agent interface | `internal/port/code_agent.go` | Implemented | +| Worker executor daemon | `internal/worker/work_executor.go` | Implemented | +| BuildSpec-to-agent bridge | `internal/worker/build_executor.go` | Implemented | +| Git credential 
resolution | `internal/service/credential_service.go` | Implemented | +| DNS alias endpoints | `internal/handlers/infrastructure_domains.go` | Implemented | +| CI pipeline proxy | `internal/handlers/infrastructure_pipelines.go` | Implemented | +| Create-and-build endpoint | `internal/handlers/create_and_build.go` | Implemented | --- @@ -312,13 +300,13 @@ After deployment, verify: ```bash # Check DNS -dig www.threesix.ai +dig landing.threesix.ai # Check site -curl -I https://www.threesix.ai +curl -I https://landing.threesix.ai # Check deployment status -curl https://rdev.masq-ops.orchard9.ai/project/www \ +curl https://rdev.masq-ops.orchard9.ai/project/landing \ -H "Authorization: Bearer $RDEV_KEY" ``` @@ -329,8 +317,8 @@ curl https://rdev.masq-ops.orchard9.ai/project/www \ To remove the landing page: ```bash -# Delete via rdev-api (removes Gitea repo, DNS, K8s deployment) -curl -X DELETE https://rdev.masq-ops.orchard9.ai/project/www \ +# Delete via rdev-api (removes DNS, K8s deployment; Gitea repo preserved for safety) +curl -X DELETE https://rdev.masq-ops.orchard9.ai/project/landing \ -H "Authorization: Bearer $RDEV_KEY" ``` @@ -338,5 +326,7 @@ curl -X DELETE https://rdev.masq-ops.orchard9.ai/project/www \ ## Related -- [THREESIX_INFRASTRUCTURE.md](/Users/jordanwashburn/Workspace/orchard9/rdev/docs/plans/THREESIX_INFRASTRUCTURE.md) - Infrastructure plan -- [woodpecker-pipeline-template.yaml](../deployments/k8s/base/threesix/woodpecker-pipeline-template.yaml) - CI template +- [Build Orchestration](../ai-lookup/features/build-orchestration.md) - Build system documentation +- [Worker Pool](../ai-lookup/services/worker-pool.md) - Worker pool management +- [Work Queue](../.claude/guides/services/work-queue.md) - Work queue guide +- [Templates](../.claude/guides/services/templates.md) - Project template guide diff --git a/docs/features/multi-provider.md b/docs/features/multi-provider.md index 56b139d..6d86937 100644 --- a/docs/features/multi-provider.md +++ 
b/docs/features/multi-provider.md @@ -1,6 +1,6 @@ # Multi-Provider Code Agent Interface -> **Status:** In Progress (Weeks 1-4 Complete) +> **Status:** Complete (Weeks 1-5) > **Feature:** Unified interface supporting Claude Code and OpenCode providers ## Overview @@ -15,7 +15,7 @@ This document describes the architecture for supporting multiple code agent prov | Week 2: Claude Code Adapter | ✅ Complete | kubectl exec wrapper, stream-json parser | | Week 3: OpenCode Adapter | ✅ Complete | HTTP/SSE client, session management | | Week 4: Service Integration | ✅ Complete | ProjectService integration, event streaming | -| Week 5: Polish | ⬜ Pending | Model selection API, health monitoring, metrics, docs | +| Week 5: Polish | ✅ Complete | Agent HTTP endpoints, health monitoring, metrics, DI wiring | ## Architecture @@ -448,7 +448,78 @@ func (s *ProjectService) GetDefaultAgent() domain.AgentProvider func (s *ProjectService) SetDefaultAgent(provider domain.AgentProvider) error ``` -## API Changes (⬜ Pending - Week 5) +## API Changes (✅ Complete - Week 5) + +### Agent Management Endpoints + +```http +# List all registered agents +GET /agents + +{ + "data": { + "agents": [ + { + "provider": "claudecode", + "name": "Claude Code", + "available": true, + "default": true, + "supported_models": ["claude-sonnet-4-20250514"], + "default_model": "claude-sonnet-4-20250514" + } + ], + "default_agent": "claudecode", + "total_agents": 1, + "available_count": 1 + } +} + +# Get agent capabilities +GET /agents/{provider} + +{ + "data": { + "provider": "claudecode", + "supports_session_continuation": true, + "supports_model_selection": false, + "supports_tool_control": true, + "supports_streaming": true, + "supported_models": ["claude-sonnet-4-20250514"], + "default_model": "claude-sonnet-4-20250514", + "max_prompt_length": 100000 + } +} + +# Set default agent +POST /agents/default +Content-Type: application/json + +{ + "provider": "opencode" +} + +# Agent health status +GET /agents/health + 
+{ + "data": { + "agents": [ + { + "provider": "claudecode", + "name": "Claude Code", + "healthy": true, + "message": "available", + "latency": "1.234ms", + "checked_at": "2025-01-27T10:00:00Z" + } + ], + "healthy_count": 1, + "total_count": 1, + "default_agent": "claudecode", + "default_healthy": true + } +} +``` ### Project Response @@ -531,35 +602,56 @@ internal/ │ ├── adapter.go ✅ CodeAgent implementation │ ├── client.go ✅ HTTP/SSE client │ └── adapter_test.go ✅ Mock server tests +├── handlers/ +│ ├── agents.go ✅ Week 5: Agent management endpoints +│ └── agents_test.go ✅ Week 5: Handler tests ├── service/ │ ├── project_service.go ✅ Week 4: Agent registry integration -│ ├── project_service_agent.go ✅ Week 4: Agent execution methods +│ ├── project_service_agent.go ✅ Week 4: Agent execution methods + metrics │ ├── project_service_commands.go ✅ Extracted shell/git commands │ └── project_service_queue.go ✅ Extracted queue operations +├── metrics/ +│ └── metrics.go ✅ Week 5: Agent metrics (requests, tool use, availability) └── worker/ - └── queue_processor.go ⬜ Week 5: Use CodeAgent for queue + └── queue_processor.go ⬜ Future: Use CodeAgent for queue +cmd/ +└── rdev-api/ + ├── main.go ✅ Week 5: Agent registry DI wiring + └── openapi.go ✅ Week 5: Agent API documentation ``` -## Observability (⬜ Pending - Week 5) +## Observability (✅ Complete - Week 5) -### Metrics +### Prometheus Metrics | Metric | Labels | Description | |--------|--------|-------------| -| `code_agent_requests_total` | provider, project, status | Total requests | -| `code_agent_duration_seconds` | provider, project | Execution duration | -| `code_agent_events_total` | provider, event_type | Streaming events | +| `rdev_agent_requests_total` | provider, status | Total code agent requests | +| `rdev_agent_request_duration_seconds` | provider | Execution duration histogram | +| `rdev_agent_tool_use_total` | provider, tool | Tool invocations by agents | +| `rdev_agent_available` | provider | Availability 
gauge (1=available, 0=unavailable) | ### Health Check ```http -GET /health +GET /agents/health { - "status": "healthy", - "agents": { - "claudecode": "available", - "opencode": "unavailable" + "data": { + "agents": [ + { + "provider": "claudecode", + "name": "Claude Code", + "healthy": true, + "message": "available", + "latency": "1.234ms", + "checked_at": "2025-01-27T10:00:00Z" + } + ], + "healthy_count": 1, + "total_count": 1, + "default_agent": "claudecode", + "default_healthy": true } } ``` diff --git a/docs/plans/worker-executor-breakdown.md b/docs/plans/worker-executor-breakdown.md new file mode 100644 index 0000000..50c6042 --- /dev/null +++ b/docs/plans/worker-executor-breakdown.md @@ -0,0 +1,320 @@ +# Worker Executor Implementation Plan + +> Close the last gap in the landing page cookbook: automated code generation via the worker pool. + +## Context + +The work queue, worker registry, build audit, and code agent systems are **all implemented**. The single missing piece is a **work executor** — a background loop that consumes queued tasks and executes them via a code agent. This is analogous to the existing `QueueProcessor` (which processes per-project command queue tasks), but for the generic `WorkQueue` (cross-project worker pool tasks). 
+ +### What Already Exists + +| Component | File | Status | +|-----------|------|--------| +| Work queue (PostgreSQL) | `internal/adapter/postgres/work_queue.go` | Done | +| Worker registry (PostgreSQL) | `internal/adapter/postgres/worker_registry.go` | Done | +| Build audit (PostgreSQL) | `internal/adapter/postgres/build_audit.go` | Done | +| WorkService (enqueue/dequeue/complete/fail) | `internal/service/work_service.go` | Done | +| WorkerService (claim/complete/health) | `internal/service/worker_service.go` | Done | +| BuildService (start/status/complete) | `internal/service/build_service.go` | Done | +| WorkHandler (REST API) | `internal/handlers/work.go` | Done | +| AgentsHandler (REST API) | `internal/handlers/agents.go` | Done | +| CodeAgent interface | `internal/port/code_agent.go` | Done | +| Domain models (WorkTask, Worker, BuildSpec) | `internal/domain/` | Done | +| Command QueueProcessor (reference pattern) | `internal/worker/queue_processor.go` | Done | + +### What's Missing + +| Gap | Priority | +|-----|----------| +| Work executor daemon (poll loop) | Critical | +| BuildSpec → AgentRequest translation | Critical | +| Git clone/commit/push in executor | Critical | +| Git credential resolution for cross-project | High | +| Worker management REST endpoints | Medium | +| DNS alias endpoint | Medium | +| Create-and-build endpoint | Medium | +| Woodpecker build status proxy | Low | + +--- + +## Week 1: Work Executor Core + +**Goal:** A background loop that claims tasks from the work queue and executes them via a code agent. By end of week, `POST /work/enqueue` → task claimed → agent executes → result recorded. + +### Tasks + +1. 
**Create `internal/worker/work_executor.go`** + - Follow the `QueueProcessor` pattern from `queue_processor.go` + - Poll loop: calls `WorkerService.ClaimTask(workerID)` on a ticker + - On task claim: route to appropriate handler based on `task.Type` + - On completion: call `WorkerService.CompleteTask(workerID, taskID, result)` + - On failure: call `WorkService.FailTask(taskID, errMsg)` (handles retry logic) + - Graceful shutdown via context cancellation + - Self-registers as a worker via `WorkerService.Register()` on start + - Sends heartbeats via `WorkerService.Heartbeat()` on a 30s ticker + +2. **Create `internal/worker/build_executor.go`** + - Handles `WorkTaskTypeBuild` tasks specifically + - Extracts `BuildSpec` fields from `WorkTask.Spec` (map[string]any → typed fields) + - Translates `BuildSpec.Prompt` into `domain.AgentRequest` + - Calls `CodeAgent.Execute()` with event streaming + - Collects output, files changed, duration into `domain.BuildResult` + - Returns `BuildResult` to the work executor + +3. **Wire into `cmd/rdev-api/main.go`** + - Create `WorkExecutor` alongside existing `QueueProcessor` + - Inject: `WorkerService`, `BuildService`, `CodeAgentRegistry` + - Start on boot, stop on shutdown + - Worker ID: hostname or pod name (from `HOSTNAME` env var) + +4. 
**Create `internal/worker/work_executor_test.go`** + - Test: executor starts and registers as a worker + - Test: executor claims a task and routes to build handler + - Test: build handler translates spec and calls code agent + - Test: results are recorded via CompleteTask + - Test: failures trigger FailTask with retry + - Test: graceful shutdown stops the poll loop + - Use mock implementations of ports + +### Deliverables + +- `POST /work/enqueue` with a build task → executor picks it up → agent runs → result in `GET /work/{taskId}` +- Worker visible in registry during execution +- Build audit entry created with spec and result + +### Files Created/Modified + +| File | Action | +|------|--------| +| `internal/worker/work_executor.go` | Create | +| `internal/worker/build_executor.go` | Create | +| `internal/worker/work_executor_test.go` | Create | +| `cmd/rdev-api/main.go` | Modify (wire executor) | + +--- + +## Week 2: Git Operations & Cross-Project Execution + +**Goal:** The executor can clone any project's repo, run the agent in that directory, and push results back. By end of week, the full build cycle works: enqueue → clone → agent generates code → commit → push → CI triggers. + +### Tasks + +1. **Create `internal/worker/git_operations.go`** + - `CloneRepo(ctx, gitURL, dir, token) error` — clone via HTTPS with token auth + - `CommitAndPush(ctx, dir, message) (commitSHA string, filesChanged []string, err error)` + - `ConfigureGit(dir, name, email)` — set git user for commits + - Uses `os/exec` for git commands (same pattern as `kubernetes.Executor` uses for kubectl) + - Workspace management: creates temp dir per task, cleans up after + +2. 
**Add git credential resolution to `BuildExecutor`** + - Option A (simplest): Use the Gitea token already in `InfraConfig.GiteaToken` + - All project repos are in Gitea, so one token covers all repos + - Pass token via HTTPS clone URL: `https://token@git.threesix.ai/org/repo.git` + - Option B (per-project): Look up project's git URL from database, resolve credentials + - **Recommendation:** Option A — the Gitea token is already loaded and available + +3. **Integrate git ops into `BuildExecutor`** + - Before agent execution: clone the project's repo to a temp directory + - Look up project git URL from database (add `ProjectStore` port or query directly) + - After agent execution: if `auto_commit` is true, commit changes + - After commit: if `auto_push` is true, push to remote + - Capture `commit_sha` and `files_changed` in `BuildResult` + +4. **Add project git URL lookup** + - The `ProjectInfraService` stores git URLs in the database during `CreateProject` + - Add a method to retrieve git info by project ID + - Or: include `git_url` in the `WorkTask.Spec` at enqueue time (simpler, no extra lookup) + +5. **Create `internal/worker/git_operations_test.go`** + - Test: clone with token auth + - Test: commit and push + - Test: workspace cleanup on success and failure + - Test: git URL construction with token + +6. 
**Integration test** + - Enqueue a build task with a real prompt + - Verify agent executes in cloned repo + - Verify commit is created (if auto_commit) + - Verify push succeeds (if auto_push) + - Verify BuildResult has correct fields + +### Deliverables + +- Full build cycle: enqueue → clone → execute → commit → push +- Git credentials resolved from infrastructure config +- Temp workspace created and cleaned per task +- Build audit shows commit SHA and files changed + +### Files Created/Modified + +| File | Action | +|------|--------| +| `internal/worker/git_operations.go` | Create | +| `internal/worker/git_operations_test.go` | Create | +| `internal/worker/build_executor.go` | Modify (add git integration) | +| `internal/worker/work_executor.go` | Modify (pass git config) | +| `cmd/rdev-api/main.go` | Modify (pass gitea token to executor) | + +--- + +## Week 3: API Enhancements + +**Goal:** Add the REST endpoints that complete the platform experience. By end of week, users can create a project, enqueue a build, monitor CI status, and manage DNS — all through rdev-api. + +### Tasks + +1. **Worker management endpoints — `internal/handlers/workers.go`** + - `GET /workers` — list all workers with status + - `GET /workers/{id}` — get worker details + - `POST /workers/{id}/drain` — drain a worker + - Wire `WorkerService` into handler + - Register in `cmd/rdev-api/main.go` and `openapi.go` + +2. **Build management endpoints — `internal/handlers/builds.go`** + - `POST /projects/{id}/builds` — enqueue a build (wraps `BuildService.StartBuild()`) + - `GET /projects/{id}/builds` — list build history + - `GET /projects/{id}/builds/{taskId}` — get build status + - Simpler API than raw `/work/enqueue` — project-scoped, build-specific + - Register in `cmd/rdev-api/main.go` and `openapi.go` + +3. 
**DNS alias endpoint — `internal/handlers/infrastructure.go`** + - `POST /projects/{id}/domains` — add DNS alias (A or CNAME record) + - `GET /projects/{id}/domains` — list domains for project + - `DELETE /projects/{id}/domains/{domain}` — remove alias + - Uses existing Cloudflare adapter's `CreateRecord()` and `DeleteRecordByName()` + - The adapter already supports full CRUD — just needs a handler + +4. **Woodpecker build status proxy — `internal/handlers/ci.go`** + - `GET /projects/{id}/ci/pipelines` — list recent Woodpecker pipelines + - `GET /projects/{id}/ci/pipelines/{number}` — get pipeline details + - Add `ListPipelines()` and `GetPipeline()` to `port.CIProvider` + - Implement in `internal/adapter/woodpecker/client.go` using Woodpecker SDK + - Low priority — can defer if time is tight + +5. **Create-and-build endpoint — `internal/handlers/project_management.go`** + - `POST /project/create-and-build` + - Request: `{ name, description, template, prompt, auto_push }` + - Calls `ProjectInfraService.CreateProject()` then `BuildService.StartBuild()` + - Returns project info + task ID + - Trivial once executor is working + +6. 
**Tests for all new handlers** + - Follow existing patterns in `handlers/*_test.go` + - Test request validation, success paths, error handling + +### Deliverables + +- `POST /projects/{id}/builds` as the clean API for code generation +- `GET /workers` for monitoring the worker pool +- `POST /projects/{id}/domains` for DNS aliases +- `POST /project/create-and-build` for the single-call flow +- All endpoints documented in `openapi.go` + +### Files Created/Modified + +| File | Action | +|------|--------| +| `internal/handlers/workers.go` | Create | +| `internal/handlers/workers_test.go` | Create | +| `internal/handlers/builds.go` | Create | +| `internal/handlers/builds_test.go` | Create | +| `internal/handlers/infrastructure.go` | Modify (add domain endpoints) | +| `internal/handlers/ci.go` | Create (if time) | +| `internal/handlers/project_management.go` | Modify (add create-and-build) | +| `internal/adapter/woodpecker/client.go` | Modify (add pipeline methods, if time) | +| `internal/port/ci.go` or port updates | Modify (add pipeline interface, if time) | +| `cmd/rdev-api/main.go` | Modify (wire new handlers) | +| `cmd/rdev-api/openapi.go` | Modify (add routes to spec) | + +--- + +## Week 4: Polish, Validation & Observability + +**Goal:** End-to-end validation of the cookbook flow. Observability for production operation. Documentation updated. + +### Tasks + +1. **End-to-end cookbook validation** + - Run the landing page cookbook flow from start to finish + - `POST /project` with `astro-landing` template + - `POST /projects/landing/builds` with customization prompt + - Monitor via `GET /work/{taskId}/status` + - Verify CI triggers on push + - Verify site is live at `https://landing.threesix.ai` + - Fix any issues found during validation + +2. 
**Stale task recovery** + - Add periodic `RequeueStale()` call to the work executor + - Requeue tasks where the worker crashed mid-execution + - Add periodic `CleanupOld()` call to remove ancient completed tasks + - These methods exist on `WorkQueue` but nothing calls them + +3. **Observability additions** + - Add metrics to work executor: tasks_claimed, tasks_completed, tasks_failed, execution_duration + - Add metrics to worker service: workers_registered, workers_idle, workers_busy + - Follow existing pattern in `internal/metrics/metrics.go` + - Add work executor health to readiness check (`GET /ready`) + +4. **Queue maintenance worker** + - Create `internal/worker/queue_maintenance.go` + - Runs on a slower ticker (every 5 minutes) + - Calls `RequeueStale(ctx, 10*time.Minute)` — requeue tasks running > 10min with no heartbeat + - Calls `CleanupOld(ctx, 7*24*time.Hour)` — prune tasks older than 7 days + - Wire into main.go + +5. **Update documentation** + - Update `cookbooks/landing-page.md` with final validated flow + - Update `ai-lookup/features/build-orchestration.md` + - Update `ai-lookup/services/worker-pool.md` + - Add `.claude/guides/services/build-orchestration.md` if needed + +6. 
**Update CLAUDE.md roadmap** + - Mark "Work Queue" as implemented + - Mark "Worker Pool" as implemented + - Mark "Build Orchestration" as implemented + - Update "Bot Communication" status + +### Deliverables + +- Cookbook flow works end-to-end without manual intervention (except code generation prompt) +- Stale task recovery running in production +- Metrics visible in `/metrics` endpoint +- All documentation reflects actual capabilities + +### Files Created/Modified + +| File | Action | +|------|--------| +| `internal/worker/queue_maintenance.go` | Create | +| `internal/metrics/metrics.go` | Modify (add work executor metrics) | +| `internal/handlers/health.go` | Modify (add executor health) | +| `cookbooks/landing-page.md` | Modify (final validation) | +| `ai-lookup/features/build-orchestration.md` | Modify | +| `ai-lookup/services/worker-pool.md` | Modify | +| `CLAUDE.md` | Modify (update roadmap) | +| `cmd/rdev-api/main.go` | Modify (wire maintenance worker) | + +--- + +## Risk & Dependencies + +| Risk | Mitigation | +|------|-----------| +| CodeAgent execution in a temp directory (not a K8s pod) may not work the same as in-pod execution | Test early in Week 1; fallback is to kubectl exec into a worker pod | +| Gitea token may lack permissions for new repos created by different users | Test with actual token; all repos should be in the same org | +| Agent execution may take longer than expected (10+ minutes for complex prompts) | Make timeout configurable; increase default | +| Worker process crash loses in-flight task | Stale requeue (Week 4) handles this automatically | +| 500-line file limit may require splitting new files | Plan for split from the start; `work_executor.go` + `build_executor.go` + `git_operations.go` keeps things modular | + +## Architecture Decision: In-Process vs External Worker + +The plan above implements the executor **in-process** (running inside the rdev-api binary). This is simpler and matches the existing `QueueProcessor` pattern. 
The alternative would be a separate worker binary, which would allow independent scaling. The in-process approach is the right starting point — it can be extracted into a separate binary later if scaling requires it. + +## Summary + +| Week | Focus | Key Deliverable | +|------|-------|----------------| +| 1 | Work executor core | Tasks flow from queue → agent → result | +| 2 | Git operations | Clone → execute → commit → push cycle | +| 3 | API enhancements | Build, worker, DNS, create-and-build endpoints | +| 4 | Polish & validation | E2E cookbook flow, observability, docs | diff --git a/internal/adapter/postgres/apikey_helpers_test.go b/internal/adapter/postgres/apikey_helpers_test.go new file mode 100644 index 0000000..0bf4d57 --- /dev/null +++ b/internal/adapter/postgres/apikey_helpers_test.go @@ -0,0 +1,79 @@ +package postgres + +import ( + "testing" + + "github.com/orchard9/rdev/internal/domain" +) + +// Helper function conversion tests + +func TestScopesToStrings(t *testing.T) { + scopes := []domain.Scope{domain.ScopeProjectsRead, domain.ScopeAdmin} + strings := scopesToStrings(scopes) + + if len(strings) != 2 { + t.Fatalf("Length = %d, want 2", len(strings)) + } + if strings[0] != "projects:read" { + t.Errorf("strings[0] = %q, want %q", strings[0], "projects:read") + } + if strings[1] != "admin" { + t.Errorf("strings[1] = %q, want %q", strings[1], "admin") + } +} + +func TestScopesFromStrings(t *testing.T) { + strings := []string{"projects:read", "keys:write"} + scopes := scopesFromStrings(strings) + + if len(scopes) != 2 { + t.Fatalf("Length = %d, want 2", len(scopes)) + } + if scopes[0] != domain.ScopeProjectsRead { + t.Errorf("scopes[0] = %q, want %q", scopes[0], domain.ScopeProjectsRead) + } + if scopes[1] != domain.ScopeKeysWrite { + t.Errorf("scopes[1] = %q, want %q", scopes[1], domain.ScopeKeysWrite) + } +} + +func TestProjectIDsToStrings(t *testing.T) { + t.Run("nil input", func(t *testing.T) { + result := projectIDsToStrings(nil) + if result != nil 
{ + t.Errorf("Expected nil, got %v", result) + } + }) + + t.Run("non-nil input", func(t *testing.T) { + ids := []domain.ProjectID{"proj-a", "proj-b"} + result := projectIDsToStrings(ids) + if len(result) != 2 { + t.Fatalf("Length = %d, want 2", len(result)) + } + if result[0] != "proj-a" || result[1] != "proj-b" { + t.Errorf("Unexpected result: %v", result) + } + }) +} + +func TestProjectIDsFromStrings(t *testing.T) { + t.Run("nil input", func(t *testing.T) { + result := projectIDsFromStrings(nil) + if result != nil { + t.Errorf("Expected nil, got %v", result) + } + }) + + t.Run("non-nil input", func(t *testing.T) { + strings := []string{"proj-x", "proj-y"} + result := projectIDsFromStrings(strings) + if len(result) != 2 { + t.Fatalf("Length = %d, want 2", len(result)) + } + if result[0] != "proj-x" || result[1] != "proj-y" { + t.Errorf("Unexpected result: %v", result) + } + }) +} diff --git a/internal/adapter/postgres/apikey_repository_test.go b/internal/adapter/postgres/apikey_repository_test.go index d5d7f0f..1493c92 100644 --- a/internal/adapter/postgres/apikey_repository_test.go +++ b/internal/adapter/postgres/apikey_repository_test.go @@ -28,7 +28,7 @@ func TestAPIKeyRepository_Create(t *testing.T) { key := &domain.APIKey{ Name: "test-repo-create", KeyPrefix: "abc12345", - Scopes: []domain.Scope{domain.ScopeProjectsRead, domain.ScopeKeysManage}, + Scopes: []domain.Scope{domain.ScopeProjectsRead, domain.ScopeKeysWrite}, ProjectIDs: []domain.ProjectID{"proj-a", "proj-b"}, AllowedIPs: []string{"192.168.1.0/24", "10.0.0.1"}, ExpiresAt: &expires, @@ -302,9 +302,9 @@ func TestAPIKeyRepository_ScopeArrayHandling(t *testing.T) { scopes []domain.Scope }{ {"single scope", []domain.Scope{domain.ScopeProjectsRead}}, - {"multiple scopes", []domain.Scope{domain.ScopeProjectsRead, domain.ScopeProjectsExecute, domain.ScopeKeysManage}}, + {"multiple scopes", []domain.Scope{domain.ScopeProjectsRead, domain.ScopeProjectsExecute, domain.ScopeKeysWrite}}, {"admin scope", 
[]domain.Scope{domain.ScopeAdmin}}, - {"all scopes", []domain.Scope{domain.ScopeProjectsRead, domain.ScopeProjectsExecute, domain.ScopeKeysManage, domain.ScopeKeysManage, domain.ScopeAdmin}}, + {"all scopes", []domain.Scope{domain.ScopeProjectsRead, domain.ScopeProjectsExecute, domain.ScopeKeysRead, domain.ScopeKeysWrite, domain.ScopeAdmin}}, } for _, tt := range tests { @@ -435,74 +435,3 @@ func TestAPIKeyRepository_AllowedIPsArrayHandling(t *testing.T) { }) } } - -// Helper function conversion tests -func TestScopesToStrings(t *testing.T) { - scopes := []domain.Scope{domain.ScopeProjectsRead, domain.ScopeAdmin} - strings := scopesToStrings(scopes) - - if len(strings) != 2 { - t.Fatalf("Length = %d, want 2", len(strings)) - } - if strings[0] != "projects:read" { - t.Errorf("strings[0] = %q, want %q", strings[0], "projects:read") - } - if strings[1] != "admin" { - t.Errorf("strings[1] = %q, want %q", strings[1], "admin") - } -} - -func TestScopesFromStrings(t *testing.T) { - strings := []string{"projects:read", "keys:manage"} - scopes := scopesFromStrings(strings) - - if len(scopes) != 2 { - t.Fatalf("Length = %d, want 2", len(scopes)) - } - if scopes[0] != domain.ScopeProjectsRead { - t.Errorf("scopes[0] = %q, want %q", scopes[0], domain.ScopeProjectsRead) - } - if scopes[1] != domain.ScopeKeysManage { - t.Errorf("scopes[1] = %q, want %q", scopes[1], domain.ScopeKeysManage) - } -} - -func TestProjectIDsToStrings(t *testing.T) { - t.Run("nil input", func(t *testing.T) { - result := projectIDsToStrings(nil) - if result != nil { - t.Errorf("Expected nil, got %v", result) - } - }) - - t.Run("non-nil input", func(t *testing.T) { - ids := []domain.ProjectID{"proj-a", "proj-b"} - result := projectIDsToStrings(ids) - if len(result) != 2 { - t.Fatalf("Length = %d, want 2", len(result)) - } - if result[0] != "proj-a" || result[1] != "proj-b" { - t.Errorf("Unexpected result: %v", result) - } - }) -} - -func TestProjectIDsFromStrings(t *testing.T) { - t.Run("nil input", 
func(t *testing.T) { - result := projectIDsFromStrings(nil) - if result != nil { - t.Errorf("Expected nil, got %v", result) - } - }) - - t.Run("non-nil input", func(t *testing.T) { - strings := []string{"proj-x", "proj-y"} - result := projectIDsFromStrings(strings) - if len(result) != 2 { - t.Fatalf("Length = %d, want 2", len(result)) - } - if result[0] != "proj-x" || result[1] != "proj-y" { - t.Errorf("Unexpected result: %v", result) - } - }) -} diff --git a/internal/adapter/postgres/build_audit.go b/internal/adapter/postgres/build_audit.go new file mode 100644 index 0000000..6dcfcdc --- /dev/null +++ b/internal/adapter/postgres/build_audit.go @@ -0,0 +1,210 @@ +// Package postgres provides PostgreSQL-based implementations of port interfaces. +package postgres + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + "time" + + "github.com/orchard9/rdev/internal/domain" + "github.com/orchard9/rdev/internal/port" +) + +// BuildAuditRepository implements port.BuildAudit using PostgreSQL. +type BuildAuditRepository struct { + db *sql.DB +} + +// NewBuildAuditRepository creates a new PostgreSQL build audit repository. +func NewBuildAuditRepository(db *sql.DB) *BuildAuditRepository { + return &BuildAuditRepository{db: db} +} + +// Ensure BuildAuditRepository implements port.BuildAudit at compile time. +var _ port.BuildAudit = (*BuildAuditRepository)(nil) + +// Record creates a new audit entry when a build starts. 
+func (r *BuildAuditRepository) Record(ctx context.Context, entry *domain.BuildAuditEntry) error { + specJSON, err := json.Marshal(entry.Spec) + if err != nil { + return fmt.Errorf("marshal build spec: %w", err) + } + + _, err = r.db.ExecContext(ctx, ` + INSERT INTO build_audit (task_id, project_id, worker_id, spec, status, started_at) + VALUES ($1, $2, $3, $4, $5, $6) + `, entry.TaskID, entry.ProjectID, nullString(entry.WorkerID), + specJSON, entry.Status, entry.StartedAt) + + if err != nil { + return fmt.Errorf("record build audit: %w", err) + } + + return nil +} + +// Update modifies an existing entry when a build completes. +func (r *BuildAuditRepository) Update(ctx context.Context, taskID string, result *domain.BuildResult) error { + var resultJSON []byte + var err error + + if result != nil { + resultJSON, err = json.Marshal(result) + if err != nil { + return fmt.Errorf("marshal build result: %w", err) + } + } + + status := domain.BuildStatusCompleted + if result != nil && !result.Success { + status = domain.BuildStatusFailed + } + + now := time.Now() + res, err := r.db.ExecContext(ctx, ` + UPDATE build_audit + SET result = $2, status = $3, completed_at = $4 + WHERE task_id = $1 + `, taskID, resultJSON, status, now) + if err != nil { + return fmt.Errorf("update build audit: %w", err) + } + + rows, err := res.RowsAffected() + if err != nil { + return fmt.Errorf("rows affected: %w", err) + } + if rows == 0 { + return domain.ErrBuildNotFound + } + + return nil +} + +// Get retrieves a specific audit entry by task ID. 
+func (r *BuildAuditRepository) Get(ctx context.Context, taskID string) (*domain.BuildAuditEntry, error) { + rows, err := r.db.QueryContext(ctx, ` + SELECT task_id, project_id, worker_id, spec, result, status, + started_at, completed_at + FROM build_audit + WHERE task_id = $1 + `, taskID) + if err != nil { + return nil, fmt.Errorf("get build audit: %w", err) + } + defer func() { _ = rows.Close() }() + + if !rows.Next() { + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("get build audit: %w", err) + } + return nil, domain.ErrBuildNotFound + } + return r.scanEntry(rows) +} + +// List returns audit entries matching the filter. +func (r *BuildAuditRepository) List(ctx context.Context, filter port.BuildAuditFilter) ([]*domain.BuildAuditEntry, error) { + query := ` + SELECT task_id, project_id, worker_id, spec, result, status, + started_at, completed_at + FROM build_audit + WHERE 1=1` + args := []any{} + argNum := 1 + + if filter.ProjectID != "" { + query += fmt.Sprintf(" AND project_id = $%d", argNum) + args = append(args, filter.ProjectID) + argNum++ + } + + if filter.WorkerID != "" { + query += fmt.Sprintf(" AND worker_id = $%d", argNum) + args = append(args, filter.WorkerID) + argNum++ + } + + if filter.Status != nil { + query += fmt.Sprintf(" AND status = $%d", argNum) + args = append(args, string(*filter.Status)) + argNum++ + } + + if !filter.Since.IsZero() { + query += fmt.Sprintf(" AND started_at >= $%d", argNum) + args = append(args, filter.Since) + argNum++ + } + + query += " ORDER BY started_at DESC" + + if filter.Limit > 0 { + query += fmt.Sprintf(" LIMIT $%d", argNum) + args = append(args, filter.Limit) + } + + rows, err := r.db.QueryContext(ctx, query, args...) 
+ if err != nil { + return nil, fmt.Errorf("list build audit: %w", err) + } + defer func() { _ = rows.Close() }() + + var entries []*domain.BuildAuditEntry + for rows.Next() { + entry, err := r.scanEntry(rows) + if err != nil { + return nil, err + } + entries = append(entries, entry) + } + + return entries, rows.Err() +} + +// scanEntry scans a single build audit row from a query result. +func (r *BuildAuditRepository) scanEntry(rows *sql.Rows) (*domain.BuildAuditEntry, error) { + var entry domain.BuildAuditEntry + var workerID sql.NullString + var specJSON []byte + var resultJSON []byte + var completedAt sql.NullTime + + err := rows.Scan( + &entry.TaskID, + &entry.ProjectID, + &workerID, + &specJSON, + &resultJSON, + &entry.Status, + &entry.StartedAt, + &completedAt, + ) + if err != nil { + return nil, fmt.Errorf("scan build audit: %w", err) + } + + if workerID.Valid { + entry.WorkerID = workerID.String + } + if completedAt.Valid { + entry.CompletedAt = &completedAt.Time + } + + if len(specJSON) > 0 { + if err := json.Unmarshal(specJSON, &entry.Spec); err != nil { + return nil, fmt.Errorf("unmarshal build spec: %w", err) + } + } + + if len(resultJSON) > 0 { + entry.Result = &domain.BuildResult{} + if err := json.Unmarshal(resultJSON, entry.Result); err != nil { + return nil, fmt.Errorf("unmarshal build result: %w", err) + } + } + + return &entry, nil +} diff --git a/internal/adapter/postgres/build_audit_test.go b/internal/adapter/postgres/build_audit_test.go new file mode 100644 index 0000000..cb2818e --- /dev/null +++ b/internal/adapter/postgres/build_audit_test.go @@ -0,0 +1,256 @@ +package postgres + +import ( + "context" + "database/sql" + "testing" + "time" + + "github.com/orchard9/rdev/internal/domain" + "github.com/orchard9/rdev/internal/port" + "github.com/orchard9/rdev/internal/testutil" +) + +func cleanupTestBuildAudit(t *testing.T, db *sql.DB) { + t.Helper() + _, err := db.Exec("DELETE FROM build_audit WHERE project_id LIKE 'test-%'") + if err != nil { 
+ t.Logf("cleanup test build_audit: %v", err) + } +} + +func TestBuildAuditRepository_Record(t *testing.T) { + db := testutil.TestDB(t) + t.Cleanup(func() { cleanupTestBuildAudit(t, db) }) + + repo := NewBuildAuditRepository(db) + ctx := context.Background() + + t.Run("records new audit entry", func(t *testing.T) { + entry := &domain.BuildAuditEntry{ + TaskID: "test-task-audit-1", + ProjectID: "test-project-1", + Spec: domain.BuildSpec{ + Prompt: "Build a landing page", + Template: "nextjs", + }, + Status: domain.BuildStatusPending, + StartedAt: time.Now(), + } + + err := repo.Record(ctx, entry) + if err != nil { + t.Fatalf("Record() error = %v", err) + } + + // Verify it was stored + got, err := repo.Get(ctx, "test-task-audit-1") + if err != nil { + t.Fatalf("Get() error = %v", err) + } + if got.ProjectID != "test-project-1" { + t.Errorf("got project_id %q, want %q", got.ProjectID, "test-project-1") + } + if got.Spec.Prompt != "Build a landing page" { + t.Errorf("got prompt %q, want %q", got.Spec.Prompt, "Build a landing page") + } + if got.Status != domain.BuildStatusPending { + t.Errorf("got status %q, want %q", got.Status, domain.BuildStatusPending) + } + }) + + t.Run("records entry with worker ID", func(t *testing.T) { + entry := &domain.BuildAuditEntry{ + TaskID: "test-task-audit-2", + ProjectID: "test-project-1", + WorkerID: "worker-1", + Spec: domain.BuildSpec{ + Prompt: "Run tests", + }, + Status: domain.BuildStatusRunning, + StartedAt: time.Now(), + } + + err := repo.Record(ctx, entry) + if err != nil { + t.Fatalf("Record() error = %v", err) + } + + got, err := repo.Get(ctx, "test-task-audit-2") + if err != nil { + t.Fatalf("Get() error = %v", err) + } + if got.WorkerID != "worker-1" { + t.Errorf("got worker_id %q, want %q", got.WorkerID, "worker-1") + } + }) +} + +func TestBuildAuditRepository_Update(t *testing.T) { + db := testutil.TestDB(t) + t.Cleanup(func() { cleanupTestBuildAudit(t, db) }) + + repo := NewBuildAuditRepository(db) + ctx := 
context.Background() + + // Create initial entry + entry := &domain.BuildAuditEntry{ + TaskID: "test-task-update-1", + ProjectID: "test-project-1", + Spec: domain.BuildSpec{Prompt: "Build"}, + Status: domain.BuildStatusPending, + StartedAt: time.Now(), + } + if err := repo.Record(ctx, entry); err != nil { + t.Fatalf("Record() error = %v", err) + } + + t.Run("updates with success result", func(t *testing.T) { + result := &domain.BuildResult{ + Success: true, + Output: "Build successful", + CommitSHA: "abc123", + FilesChanged: []string{"index.html", "style.css"}, + DurationMs: 5000, + } + + err := repo.Update(ctx, "test-task-update-1", result) + if err != nil { + t.Fatalf("Update() error = %v", err) + } + + got, err := repo.Get(ctx, "test-task-update-1") + if err != nil { + t.Fatalf("Get() error = %v", err) + } + if got.Status != domain.BuildStatusCompleted { + t.Errorf("got status %q, want %q", got.Status, domain.BuildStatusCompleted) + } + if got.Result == nil { + t.Fatal("expected result to be set") + } + if !got.Result.Success { + t.Error("expected result.Success = true") + } + if got.Result.CommitSHA != "abc123" { + t.Errorf("got commit_sha %q, want %q", got.Result.CommitSHA, "abc123") + } + if got.CompletedAt == nil { + t.Error("expected completed_at to be set") + } + }) + + t.Run("updates with failure result", func(t *testing.T) { + // Create a new entry + entry := &domain.BuildAuditEntry{ + TaskID: "test-task-update-2", + ProjectID: "test-project-1", + Spec: domain.BuildSpec{Prompt: "Build"}, + Status: domain.BuildStatusPending, + StartedAt: time.Now(), + } + if err := repo.Record(ctx, entry); err != nil { + t.Fatalf("Record() error = %v", err) + } + + result := &domain.BuildResult{ + Success: false, + Error: "compilation error", + } + + err := repo.Update(ctx, "test-task-update-2", result) + if err != nil { + t.Fatalf("Update() error = %v", err) + } + + got, err := repo.Get(ctx, "test-task-update-2") + if err != nil { + t.Fatalf("Get() error = %v", err) + } 
+ if got.Status != domain.BuildStatusFailed { + t.Errorf("got status %q, want %q", got.Status, domain.BuildStatusFailed) + } + }) + + t.Run("returns error for nonexistent task", func(t *testing.T) { + result := &domain.BuildResult{Success: true} + err := repo.Update(ctx, "test-task-nonexistent", result) + if err == nil { + t.Error("expected error for nonexistent task") + } + }) +} + +func TestBuildAuditRepository_Get(t *testing.T) { + db := testutil.TestDB(t) + t.Cleanup(func() { cleanupTestBuildAudit(t, db) }) + + repo := NewBuildAuditRepository(db) + ctx := context.Background() + + t.Run("returns error for nonexistent entry", func(t *testing.T) { + _, err := repo.Get(ctx, "test-task-nonexistent") + if err == nil { + t.Error("expected error for nonexistent entry") + } + }) +} + +func TestBuildAuditRepository_List(t *testing.T) { + db := testutil.TestDB(t) + t.Cleanup(func() { cleanupTestBuildAudit(t, db) }) + + repo := NewBuildAuditRepository(db) + ctx := context.Background() + + // Create entries for different projects + entries := []*domain.BuildAuditEntry{ + {TaskID: "test-task-list-1", ProjectID: "test-project-a", Spec: domain.BuildSpec{Prompt: "Build 1"}, Status: domain.BuildStatusCompleted, StartedAt: time.Now()}, + {TaskID: "test-task-list-2", ProjectID: "test-project-a", Spec: domain.BuildSpec{Prompt: "Build 2"}, Status: domain.BuildStatusFailed, StartedAt: time.Now()}, + {TaskID: "test-task-list-3", ProjectID: "test-project-b", Spec: domain.BuildSpec{Prompt: "Build 3"}, Status: domain.BuildStatusPending, StartedAt: time.Now()}, + } + for _, e := range entries { + if err := repo.Record(ctx, e); err != nil { + t.Fatalf("Record() error = %v", err) + } + } + + t.Run("filters by project", func(t *testing.T) { + got, err := repo.List(ctx, port.BuildAuditFilter{ + ProjectID: "test-project-a", + }) + if err != nil { + t.Fatalf("List() error = %v", err) + } + if len(got) != 2 { + t.Errorf("got %d entries, want 2", len(got)) + } + }) + + t.Run("filters by status", 
func(t *testing.T) { + completed := domain.BuildStatusCompleted + got, err := repo.List(ctx, port.BuildAuditFilter{ + Status: &completed, + }) + if err != nil { + t.Fatalf("List() error = %v", err) + } + for _, e := range got { + if e.Status != domain.BuildStatusCompleted { + t.Errorf("got status %q, want only completed", e.Status) + } + } + }) + + t.Run("respects limit", func(t *testing.T) { + got, err := repo.List(ctx, port.BuildAuditFilter{ + Limit: 1, + }) + if err != nil { + t.Fatalf("List() error = %v", err) + } + if len(got) > 1 { + t.Errorf("got %d entries, want at most 1", len(got)) + } + }) +} diff --git a/internal/adapter/postgres/credential_store.go b/internal/adapter/postgres/credential_store.go index 8c4f5c1..f60d6d3 100644 --- a/internal/adapter/postgres/credential_store.go +++ b/internal/adapter/postgres/credential_store.go @@ -56,7 +56,7 @@ func (s *CredentialStore) GetRequired(ctx context.Context, key string) (string, return "", err } if value == "" { - return "", fmt.Errorf("credential %s not found", key) + return "", domain.ErrCredentialNotFound } return value, nil } @@ -88,7 +88,7 @@ func (s *CredentialStore) Delete(ctx context.Context, key string) error { rows, _ := result.RowsAffected() if rows == 0 { - return fmt.Errorf("credential %s not found", key) + return domain.ErrCredentialNotFound } return nil } diff --git a/internal/adapter/postgres/rate_limiter.go b/internal/adapter/postgres/rate_limiter.go index 98e0c07..1151346 100644 --- a/internal/adapter/postgres/rate_limiter.go +++ b/internal/adapter/postgres/rate_limiter.go @@ -15,12 +15,21 @@ import ( // RateLimiter implements port.RateLimiter using PostgreSQL. type RateLimiter struct { - db *sql.DB + db *sql.DB + logger *slog.Logger } // NewRateLimiter creates a new PostgreSQL rate limiter. func NewRateLimiter(db *sql.DB) *RateLimiter { - return &RateLimiter{db: db} + return &RateLimiter{db: db, logger: slog.Default()} +} + +// WithLogger sets a custom logger for the rate limiter. 
+func (r *RateLimiter) WithLogger(logger *slog.Logger) *RateLimiter { + if logger != nil { + r.logger = logger + } + return r } // Ensure RateLimiter implements port.RateLimiter at compile time. @@ -224,7 +233,7 @@ func (r *RateLimiter) StartCleanupWorker(ctx context.Context, interval time.Dura return case <-ticker.C: if err := r.Cleanup(ctx); err != nil { - slog.Error("rate limit cleanup failed", "error", err) + r.logger.Error("rate limit cleanup failed", "error", err) } } } diff --git a/internal/adapter/postgres/work_queue.go b/internal/adapter/postgres/work_queue.go index 3638de0..549be4b 100644 --- a/internal/adapter/postgres/work_queue.go +++ b/internal/adapter/postgres/work_queue.go @@ -7,7 +7,6 @@ import ( "encoding/json" "errors" "fmt" - "time" "github.com/orchard9/rdev/internal/domain" "github.com/orchard9/rdev/internal/port" @@ -27,7 +26,7 @@ func NewWorkQueueRepository(db *sql.DB) *WorkQueueRepository { var _ port.WorkQueue = (*WorkQueueRepository)(nil) // Enqueue adds a task to the queue. -func (r *WorkQueueRepository) Enqueue(ctx context.Context, task *port.WorkTask) (string, error) { +func (r *WorkQueueRepository) Enqueue(ctx context.Context, task *domain.WorkTask) (string, error) { specJSON, err := json.Marshal(task.Spec) if err != nil { return "", fmt.Errorf("marshal task spec: %w", err) @@ -48,10 +47,10 @@ func (r *WorkQueueRepository) Enqueue(ctx context.Context, task *port.WorkTask) } // Dequeue atomically claims the next available task for a worker. -func (r *WorkQueueRepository) Dequeue(ctx context.Context, workerID string) (*port.WorkTask, error) { +func (r *WorkQueueRepository) Dequeue(ctx context.Context, workerID string) (*domain.WorkTask, error) { // Use a single UPDATE ... 
RETURNING with subquery for atomic claim // This avoids explicit transaction management while still being safe - var task port.WorkTask + var task domain.WorkTask var taskType string var specJSON []byte var status string @@ -99,8 +98,8 @@ func (r *WorkQueueRepository) Dequeue(ctx context.Context, workerID string) (*po return nil, fmt.Errorf("dequeue work task: %w", err) } - task.Type = port.WorkTaskType(taskType) - task.Status = port.WorkTaskStatus(status) + task.Type = domain.WorkTaskType(taskType) + task.Status = domain.WorkTaskStatus(status) if callbackURL.Valid { task.CallbackURL = callbackURL.String @@ -124,7 +123,7 @@ func (r *WorkQueueRepository) Dequeue(ctx context.Context, workerID string) (*po // Parse result if len(resultJSON) > 0 { - task.Result = &port.WorkResult{} + task.Result = &domain.WorkResult{} if err := json.Unmarshal(resultJSON, task.Result); err != nil { return nil, fmt.Errorf("unmarshal task result: %w", err) } @@ -134,7 +133,7 @@ func (r *WorkQueueRepository) Dequeue(ctx context.Context, workerID string) (*po } // Complete marks a task as successfully completed with results. -func (r *WorkQueueRepository) Complete(ctx context.Context, taskID string, result *port.WorkResult) error { +func (r *WorkQueueRepository) Complete(ctx context.Context, taskID string, result *domain.WorkResult) error { var resultJSON []byte var err error @@ -242,285 +241,3 @@ func (r *WorkQueueRepository) Cancel(ctx context.Context, taskID string) error { return nil } - -// GetTask retrieves a task by ID. 
-func (r *WorkQueueRepository) GetTask(ctx context.Context, taskID string) (*port.WorkTask, error) { - var task port.WorkTask - var taskType string - var specJSON []byte - var status string - var workerID sql.NullString - var callbackURL sql.NullString - var startedAt sql.NullTime - var completedAt sql.NullTime - var resultJSON []byte - var errorMsg sql.NullString - - err := r.db.QueryRowContext(ctx, ` - SELECT id, project_id, task_type, task_spec, status, priority, worker_id, - callback_url, created_at, started_at, completed_at, result, error, - retry_count, max_retries - FROM work_queue - WHERE id = $1 - `, taskID).Scan( - &task.ID, - &task.ProjectID, - &taskType, - &specJSON, - &status, - &task.Priority, - &workerID, - &callbackURL, - &task.CreatedAt, - &startedAt, - &completedAt, - &resultJSON, - &errorMsg, - &task.RetryCount, - &task.MaxRetries, - ) - - if errors.Is(err, sql.ErrNoRows) { - return nil, domain.ErrWorkTaskNotFound - } - if err != nil { - return nil, fmt.Errorf("get work task: %w", err) - } - - task.Type = port.WorkTaskType(taskType) - task.Status = port.WorkTaskStatus(status) - - if workerID.Valid { - task.WorkerID = workerID.String - } - if callbackURL.Valid { - task.CallbackURL = callbackURL.String - } - if startedAt.Valid { - task.StartedAt = &startedAt.Time - } - if completedAt.Valid { - task.CompletedAt = &completedAt.Time - } - if errorMsg.Valid { - task.Error = errorMsg.String - } - - // Parse task spec - if len(specJSON) > 0 { - if err := json.Unmarshal(specJSON, &task.Spec); err != nil { - return nil, fmt.Errorf("unmarshal task spec: %w", err) - } - } - - // Parse result - if len(resultJSON) > 0 { - task.Result = &port.WorkResult{} - if err := json.Unmarshal(resultJSON, task.Result); err != nil { - return nil, fmt.Errorf("unmarshal task result: %w", err) - } - } - - return &task, nil -} - -// ListByProject returns tasks for a project with optional status filter and pagination. 
-func (r *WorkQueueRepository) ListByProject(ctx context.Context, projectID string, status *port.WorkTaskStatus, opts port.WorkListOptions) (*port.WorkListResult, error) { - // Normalize pagination options - opts.Normalize() - - // Build base WHERE clause - whereClause := "WHERE project_id = $1" - args := []any{projectID} - argNum := 2 - - if status != nil { - whereClause += fmt.Sprintf(" AND status = $%d", argNum) - args = append(args, string(*status)) - argNum++ - } - - // Get total count for pagination metadata - countQuery := "SELECT COUNT(*) FROM work_queue " + whereClause - var total int64 - if err := r.db.QueryRowContext(ctx, countQuery, args...).Scan(&total); err != nil { - return nil, fmt.Errorf("count work tasks: %w", err) - } - - // Build paginated query - query := fmt.Sprintf(` - SELECT id, project_id, task_type, task_spec, status, priority, worker_id, - callback_url, created_at, started_at, completed_at, result, error, - retry_count, max_retries - FROM work_queue - %s - ORDER BY created_at DESC - LIMIT $%d OFFSET $%d - `, whereClause, argNum, argNum+1) - args = append(args, opts.Limit, opts.Offset) - - rows, err := r.db.QueryContext(ctx, query, args...) - if err != nil { - return nil, fmt.Errorf("list work tasks: %w", err) - } - defer func() { _ = rows.Close() }() - - var tasks []*port.WorkTask - for rows.Next() { - task, err := r.scanTask(rows) - if err != nil { - return nil, err - } - tasks = append(tasks, task) - } - - return &port.WorkListResult{ - Tasks: tasks, - Total: total, - Limit: opts.Limit, - Offset: opts.Offset, - }, nil -} - -// GetStats returns queue statistics. 
-func (r *WorkQueueRepository) GetStats(ctx context.Context) (*port.WorkQueueStats, error) { - var stats port.WorkQueueStats - - err := r.db.QueryRowContext(ctx, ` - SELECT - COUNT(*) FILTER (WHERE status = 'pending') as pending, - COUNT(*) FILTER (WHERE status = 'running') as running, - COUNT(*) FILTER (WHERE status = 'completed' AND completed_at > NOW() - INTERVAL '24 hours') as completed, - COUNT(*) FILTER (WHERE status = 'failed' AND completed_at > NOW() - INTERVAL '24 hours') as failed, - COUNT(*) FILTER (WHERE status = 'cancelled' AND completed_at > NOW() - INTERVAL '24 hours') as cancelled - FROM work_queue - `).Scan( - &stats.Pending, - &stats.Running, - &stats.Completed, - &stats.Failed, - &stats.Cancelled, - ) - if err != nil { - return nil, fmt.Errorf("get stats: %w", err) - } - - // Get oldest pending task age - var oldestCreatedAt sql.NullTime - err = r.db.QueryRowContext(ctx, ` - SELECT MIN(created_at) FROM work_queue WHERE status = 'pending' - `).Scan(&oldestCreatedAt) - if err != nil && !errors.Is(err, sql.ErrNoRows) { - return nil, fmt.Errorf("get oldest pending: %w", err) - } - if oldestCreatedAt.Valid { - age := time.Since(oldestCreatedAt.Time) - stats.OldestPending = &age - } - - return &stats, nil -} - -// CleanupOld removes completed/failed/cancelled tasks older than the specified duration. -func (r *WorkQueueRepository) CleanupOld(ctx context.Context, olderThan time.Duration) (int64, error) { - cutoff := time.Now().Add(-olderThan) - result, err := r.db.ExecContext(ctx, ` - DELETE FROM work_queue - WHERE status IN ('completed', 'failed', 'cancelled') - AND completed_at < $1 - `, cutoff) - if err != nil { - return 0, fmt.Errorf("cleanup old tasks: %w", err) - } - - return result.RowsAffected() -} - -// RequeueStale re-queues tasks that have been running longer than the timeout. 
-func (r *WorkQueueRepository) RequeueStale(ctx context.Context, timeout time.Duration) (int64, error) { - cutoff := time.Now().Add(-timeout) - result, err := r.db.ExecContext(ctx, ` - UPDATE work_queue - SET status = 'pending', worker_id = NULL, started_at = NULL, - retry_count = retry_count + 1, error = 'Worker timeout - task requeued' - WHERE status = 'running' - AND started_at < $1 - AND retry_count < max_retries - `, cutoff) - if err != nil { - return 0, fmt.Errorf("requeue stale tasks: %w", err) - } - - return result.RowsAffected() -} - -// scanTask scans a single task row. -func (r *WorkQueueRepository) scanTask(rows *sql.Rows) (*port.WorkTask, error) { - var task port.WorkTask - var taskType string - var specJSON []byte - var status string - var workerID sql.NullString - var callbackURL sql.NullString - var startedAt sql.NullTime - var completedAt sql.NullTime - var resultJSON []byte - var errorMsg sql.NullString - - err := rows.Scan( - &task.ID, - &task.ProjectID, - &taskType, - &specJSON, - &status, - &task.Priority, - &workerID, - &callbackURL, - &task.CreatedAt, - &startedAt, - &completedAt, - &resultJSON, - &errorMsg, - &task.RetryCount, - &task.MaxRetries, - ) - if err != nil { - return nil, fmt.Errorf("scan task: %w", err) - } - - task.Type = port.WorkTaskType(taskType) - task.Status = port.WorkTaskStatus(status) - - if workerID.Valid { - task.WorkerID = workerID.String - } - if callbackURL.Valid { - task.CallbackURL = callbackURL.String - } - if startedAt.Valid { - task.StartedAt = &startedAt.Time - } - if completedAt.Valid { - task.CompletedAt = &completedAt.Time - } - if errorMsg.Valid { - task.Error = errorMsg.String - } - - // Parse task spec - if len(specJSON) > 0 { - if err := json.Unmarshal(specJSON, &task.Spec); err != nil { - return nil, fmt.Errorf("unmarshal task spec: %w", err) - } - } - - // Parse result - if len(resultJSON) > 0 { - task.Result = &port.WorkResult{} - if err := json.Unmarshal(resultJSON, task.Result); err != nil { - 
return nil, fmt.Errorf("unmarshal task result: %w", err) - } - } - - return &task, nil -} diff --git a/internal/adapter/postgres/work_queue_queries.go b/internal/adapter/postgres/work_queue_queries.go new file mode 100644 index 0000000..a2af0f8 --- /dev/null +++ b/internal/adapter/postgres/work_queue_queries.go @@ -0,0 +1,294 @@ +package postgres + +import ( + "context" + "database/sql" + "encoding/json" + "errors" + "fmt" + "time" + + "github.com/orchard9/rdev/internal/domain" +) + +// GetTask retrieves a task by ID. +func (r *WorkQueueRepository) GetTask(ctx context.Context, taskID string) (*domain.WorkTask, error) { + var task domain.WorkTask + var taskType string + var specJSON []byte + var status string + var workerID sql.NullString + var callbackURL sql.NullString + var startedAt sql.NullTime + var completedAt sql.NullTime + var resultJSON []byte + var errorMsg sql.NullString + + err := r.db.QueryRowContext(ctx, ` + SELECT id, project_id, task_type, task_spec, status, priority, worker_id, + callback_url, created_at, started_at, completed_at, result, error, + retry_count, max_retries + FROM work_queue + WHERE id = $1 + `, taskID).Scan( + &task.ID, + &task.ProjectID, + &taskType, + &specJSON, + &status, + &task.Priority, + &workerID, + &callbackURL, + &task.CreatedAt, + &startedAt, + &completedAt, + &resultJSON, + &errorMsg, + &task.RetryCount, + &task.MaxRetries, + ) + + if errors.Is(err, sql.ErrNoRows) { + return nil, domain.ErrWorkTaskNotFound + } + if err != nil { + return nil, fmt.Errorf("get work task: %w", err) + } + + task.Type = domain.WorkTaskType(taskType) + task.Status = domain.WorkTaskStatus(status) + + if workerID.Valid { + task.WorkerID = workerID.String + } + if callbackURL.Valid { + task.CallbackURL = callbackURL.String + } + if startedAt.Valid { + task.StartedAt = &startedAt.Time + } + if completedAt.Valid { + task.CompletedAt = &completedAt.Time + } + if errorMsg.Valid { + task.Error = errorMsg.String + } + + // Parse task spec + if 
len(specJSON) > 0 { + if err := json.Unmarshal(specJSON, &task.Spec); err != nil { + return nil, fmt.Errorf("unmarshal task spec: %w", err) + } + } + + // Parse result + if len(resultJSON) > 0 { + task.Result = &domain.WorkResult{} + if err := json.Unmarshal(resultJSON, task.Result); err != nil { + return nil, fmt.Errorf("unmarshal task result: %w", err) + } + } + + return &task, nil +} + +// ListByProject returns tasks for a project with optional status filter and pagination. +func (r *WorkQueueRepository) ListByProject(ctx context.Context, projectID string, status *domain.WorkTaskStatus, opts domain.WorkListOptions) (*domain.WorkListResult, error) { + // Normalize pagination options + opts.Normalize() + + // Build base WHERE clause + whereClause := "WHERE project_id = $1" + args := []any{projectID} + argNum := 2 + + if status != nil { + whereClause += fmt.Sprintf(" AND status = $%d", argNum) + args = append(args, string(*status)) + argNum++ + } + + // Get total count for pagination metadata + countQuery := "SELECT COUNT(*) FROM work_queue " + whereClause + var total int64 + if err := r.db.QueryRowContext(ctx, countQuery, args...).Scan(&total); err != nil { + return nil, fmt.Errorf("count work tasks: %w", err) + } + + // Build paginated query + query := fmt.Sprintf(` + SELECT id, project_id, task_type, task_spec, status, priority, worker_id, + callback_url, created_at, started_at, completed_at, result, error, + retry_count, max_retries + FROM work_queue + %s + ORDER BY created_at DESC + LIMIT $%d OFFSET $%d + `, whereClause, argNum, argNum+1) + args = append(args, opts.Limit, opts.Offset) + + rows, err := r.db.QueryContext(ctx, query, args...) 
+ if err != nil { + return nil, fmt.Errorf("list work tasks: %w", err) + } + defer func() { _ = rows.Close() }() + + var tasks []*domain.WorkTask + for rows.Next() { + task, err := r.scanTask(rows) + if err != nil { + return nil, err + } + tasks = append(tasks, task) + } + + return &domain.WorkListResult{ + Tasks: tasks, + Total: total, + Limit: opts.Limit, + Offset: opts.Offset, + }, nil +} + +// GetStats returns queue statistics. +func (r *WorkQueueRepository) GetStats(ctx context.Context) (*domain.WorkQueueStats, error) { + var stats domain.WorkQueueStats + + err := r.db.QueryRowContext(ctx, ` + SELECT + COUNT(*) FILTER (WHERE status = 'pending') as pending, + COUNT(*) FILTER (WHERE status = 'running') as running, + COUNT(*) FILTER (WHERE status = 'completed' AND completed_at > NOW() - INTERVAL '24 hours') as completed, + COUNT(*) FILTER (WHERE status = 'failed' AND completed_at > NOW() - INTERVAL '24 hours') as failed, + COUNT(*) FILTER (WHERE status = 'cancelled' AND completed_at > NOW() - INTERVAL '24 hours') as cancelled + FROM work_queue + `).Scan( + &stats.Pending, + &stats.Running, + &stats.Completed, + &stats.Failed, + &stats.Cancelled, + ) + if err != nil { + return nil, fmt.Errorf("get stats: %w", err) + } + + // Get oldest pending task age + var oldestCreatedAt sql.NullTime + err = r.db.QueryRowContext(ctx, ` + SELECT MIN(created_at) FROM work_queue WHERE status = 'pending' + `).Scan(&oldestCreatedAt) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + return nil, fmt.Errorf("get oldest pending: %w", err) + } + if oldestCreatedAt.Valid { + age := time.Since(oldestCreatedAt.Time) + stats.OldestPending = &age + } + + return &stats, nil +} + +// CleanupOld removes completed/failed/cancelled tasks older than the specified duration. 
+func (r *WorkQueueRepository) CleanupOld(ctx context.Context, olderThan time.Duration) (int64, error) { + cutoff := time.Now().Add(-olderThan) + result, err := r.db.ExecContext(ctx, ` + DELETE FROM work_queue + WHERE status IN ('completed', 'failed', 'cancelled') + AND completed_at < $1 + `, cutoff) + if err != nil { + return 0, fmt.Errorf("cleanup old tasks: %w", err) + } + + return result.RowsAffected() +} + +// RequeueStale re-queues tasks that have been running longer than the timeout. +func (r *WorkQueueRepository) RequeueStale(ctx context.Context, timeout time.Duration) (int64, error) { + cutoff := time.Now().Add(-timeout) + result, err := r.db.ExecContext(ctx, ` + UPDATE work_queue + SET status = 'pending', worker_id = NULL, started_at = NULL, + retry_count = retry_count + 1, error = 'Worker timeout - task requeued' + WHERE status = 'running' + AND started_at < $1 + AND retry_count < max_retries + `, cutoff) + if err != nil { + return 0, fmt.Errorf("requeue stale tasks: %w", err) + } + + return result.RowsAffected() +} + +// scanTask scans a single task row. 
+func (r *WorkQueueRepository) scanTask(rows *sql.Rows) (*domain.WorkTask, error) { + var task domain.WorkTask + var taskType string + var specJSON []byte + var status string + var workerID sql.NullString + var callbackURL sql.NullString + var startedAt sql.NullTime + var completedAt sql.NullTime + var resultJSON []byte + var errorMsg sql.NullString + + err := rows.Scan( + &task.ID, + &task.ProjectID, + &taskType, + &specJSON, + &status, + &task.Priority, + &workerID, + &callbackURL, + &task.CreatedAt, + &startedAt, + &completedAt, + &resultJSON, + &errorMsg, + &task.RetryCount, + &task.MaxRetries, + ) + if err != nil { + return nil, fmt.Errorf("scan task: %w", err) + } + + task.Type = domain.WorkTaskType(taskType) + task.Status = domain.WorkTaskStatus(status) + + if workerID.Valid { + task.WorkerID = workerID.String + } + if callbackURL.Valid { + task.CallbackURL = callbackURL.String + } + if startedAt.Valid { + task.StartedAt = &startedAt.Time + } + if completedAt.Valid { + task.CompletedAt = &completedAt.Time + } + if errorMsg.Valid { + task.Error = errorMsg.String + } + + // Parse task spec + if len(specJSON) > 0 { + if err := json.Unmarshal(specJSON, &task.Spec); err != nil { + return nil, fmt.Errorf("unmarshal task spec: %w", err) + } + } + + // Parse result + if len(resultJSON) > 0 { + task.Result = &domain.WorkResult{} + if err := json.Unmarshal(resultJSON, task.Result); err != nil { + return nil, fmt.Errorf("unmarshal task result: %w", err) + } + } + + return &task, nil +} diff --git a/internal/adapter/postgres/worker_registry.go b/internal/adapter/postgres/worker_registry.go new file mode 100644 index 0000000..bf5dfcd --- /dev/null +++ b/internal/adapter/postgres/worker_registry.go @@ -0,0 +1,244 @@ +// Package postgres provides PostgreSQL-based implementations of port interfaces. 
+package postgres + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + "time" + + "github.com/orchard9/rdev/internal/domain" + "github.com/orchard9/rdev/internal/port" +) + +// WorkerRegistryRepository implements port.WorkerRegistry using PostgreSQL. +type WorkerRegistryRepository struct { + db *sql.DB +} + +// NewWorkerRegistryRepository creates a new PostgreSQL worker registry. +func NewWorkerRegistryRepository(db *sql.DB) *WorkerRegistryRepository { + return &WorkerRegistryRepository{db: db} +} + +// Ensure WorkerRegistryRepository implements port.WorkerRegistry at compile time. +var _ port.WorkerRegistry = (*WorkerRegistryRepository)(nil) + +// Register adds a worker to the pool. +// If a worker with the same ID already exists, it is re-registered as idle. +func (r *WorkerRegistryRepository) Register(ctx context.Context, worker *domain.Worker) error { + capsJSON, err := json.Marshal(worker.Capabilities) + if err != nil { + return fmt.Errorf("marshal capabilities: %w", err) + } + + _, err = r.db.ExecContext(ctx, ` + INSERT INTO workers (id, hostname, status, capabilities, version, registered_at, last_heartbeat) + VALUES ($1, $2, $3, $4, $5, $6, $6) + ON CONFLICT (id) DO UPDATE SET + hostname = EXCLUDED.hostname, + status = 'idle', + current_task = NULL, + capabilities = EXCLUDED.capabilities, + version = EXCLUDED.version, + last_heartbeat = EXCLUDED.last_heartbeat + `, worker.ID, worker.Hostname, domain.WorkerStatusIdle, capsJSON, + nullString(worker.Version), time.Now()) + + if err != nil { + return fmt.Errorf("register worker: %w", err) + } + + return nil +} + +// Heartbeat updates the worker's last_heartbeat timestamp. 
+func (r *WorkerRegistryRepository) Heartbeat(ctx context.Context, workerID string) error { + result, err := r.db.ExecContext(ctx, ` + UPDATE workers SET last_heartbeat = NOW() + WHERE id = $1 AND status != 'offline' + `, workerID) + if err != nil { + return fmt.Errorf("heartbeat worker: %w", err) + } + + rows, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("rows affected: %w", err) + } + if rows == 0 { + return domain.ErrWorkerNotFound + } + + return nil +} + +// UpdateStatus changes a worker's status and optionally assigns a task. +func (r *WorkerRegistryRepository) UpdateStatus(ctx context.Context, workerID string, status domain.WorkerStatus, taskID string) error { + var currentTask sql.NullString + if taskID != "" { + currentTask = sql.NullString{String: taskID, Valid: true} + } + + result, err := r.db.ExecContext(ctx, ` + UPDATE workers SET status = $2, current_task = $3, last_heartbeat = NOW() + WHERE id = $1 + `, workerID, status, currentTask) + if err != nil { + return fmt.Errorf("update worker status: %w", err) + } + + rows, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("rows affected: %w", err) + } + if rows == 0 { + return domain.ErrWorkerNotFound + } + + return nil +} + +// Deregister removes a worker from the pool. +func (r *WorkerRegistryRepository) Deregister(ctx context.Context, workerID string) error { + result, err := r.db.ExecContext(ctx, `DELETE FROM workers WHERE id = $1`, workerID) + if err != nil { + return fmt.Errorf("deregister worker: %w", err) + } + + rows, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("rows affected: %w", err) + } + if rows == 0 { + return domain.ErrWorkerNotFound + } + + return nil +} + +// Get retrieves a specific worker by ID. 
+func (r *WorkerRegistryRepository) Get(ctx context.Context, workerID string) (*domain.Worker, error) { + rows, err := r.db.QueryContext(ctx, ` + SELECT id, hostname, status, current_task, capabilities, version, + registered_at, last_heartbeat + FROM workers + WHERE id = $1 + `, workerID) + if err != nil { + return nil, fmt.Errorf("get worker: %w", err) + } + defer func() { _ = rows.Close() }() + + if !rows.Next() { + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("get worker: %w", err) + } + return nil, domain.ErrWorkerNotFound + } + return r.scanWorker(rows) +} + +// List returns all workers matching the filter. +func (r *WorkerRegistryRepository) List(ctx context.Context, filter port.WorkerFilter) ([]*domain.Worker, error) { + query := ` + SELECT id, hostname, status, current_task, capabilities, version, + registered_at, last_heartbeat + FROM workers + WHERE 1=1` + args := []any{} + argNum := 1 + + if filter.Status != nil { + query += fmt.Sprintf(" AND status = $%d", argNum) + args = append(args, string(*filter.Status)) + argNum++ + } + + if filter.HasCapability != "" { + query += fmt.Sprintf(" AND capabilities @> $%d::jsonb", argNum) + capJSON, _ := json.Marshal([]string{filter.HasCapability}) + args = append(args, string(capJSON)) + argNum++ + } + + query += " ORDER BY registered_at ASC" + + if filter.Limit > 0 { + query += fmt.Sprintf(" LIMIT $%d", argNum) + args = append(args, filter.Limit) + } + + rows, err := r.db.QueryContext(ctx, query, args...) + if err != nil { + return nil, fmt.Errorf("list workers: %w", err) + } + defer func() { _ = rows.Close() }() + + var workers []*domain.Worker + for rows.Next() { + w, err := r.scanWorker(rows) + if err != nil { + return nil, err + } + workers = append(workers, w) + } + + return workers, rows.Err() +} + +// MarkStaleOffline marks workers without a recent heartbeat as offline. 
+func (r *WorkerRegistryRepository) MarkStaleOffline(ctx context.Context, threshold time.Duration) (int, error) { + cutoff := time.Now().Add(-threshold) + result, err := r.db.ExecContext(ctx, ` + UPDATE workers SET status = 'offline', current_task = NULL + WHERE status != 'offline' AND last_heartbeat < $1 + `, cutoff) + if err != nil { + return 0, fmt.Errorf("mark stale workers offline: %w", err) + } + + rows, err := result.RowsAffected() + if err != nil { + return 0, fmt.Errorf("rows affected: %w", err) + } + + return int(rows), nil +} + +// scanWorker scans a single worker row from a query result. +func (r *WorkerRegistryRepository) scanWorker(rows *sql.Rows) (*domain.Worker, error) { + var w domain.Worker + var currentTask sql.NullString + var capsJSON []byte + var version sql.NullString + + err := rows.Scan( + &w.ID, + &w.Hostname, + &w.Status, + &currentTask, + &capsJSON, + &version, + &w.RegisteredAt, + &w.LastHeartbeat, + ) + if err != nil { + return nil, fmt.Errorf("scan worker: %w", err) + } + + if currentTask.Valid { + w.CurrentTask = currentTask.String + } + if version.Valid { + w.Version = version.String + } + if len(capsJSON) > 0 { + if err := json.Unmarshal(capsJSON, &w.Capabilities); err != nil { + return nil, fmt.Errorf("unmarshal capabilities: %w", err) + } + } + + return &w, nil +} diff --git a/internal/adapter/postgres/worker_registry_test.go b/internal/adapter/postgres/worker_registry_test.go new file mode 100644 index 0000000..14ec544 --- /dev/null +++ b/internal/adapter/postgres/worker_registry_test.go @@ -0,0 +1,321 @@ +package postgres + +import ( + "context" + "database/sql" + "testing" + "time" + + "github.com/orchard9/rdev/internal/domain" + "github.com/orchard9/rdev/internal/port" + "github.com/orchard9/rdev/internal/testutil" +) + +func cleanupTestWorkers(t *testing.T, db *sql.DB) { + t.Helper() + _, err := db.Exec("DELETE FROM workers WHERE id LIKE 'test-%'") + if err != nil { + t.Logf("cleanup test workers: %v", err) + } +} + +func
TestWorkerRegistryRepository_Register(t *testing.T) { + db := testutil.TestDB(t) + t.Cleanup(func() { cleanupTestWorkers(t, db) }) + + repo := NewWorkerRegistryRepository(db) + ctx := context.Background() + + t.Run("registers new worker", func(t *testing.T) { + worker := &domain.Worker{ + ID: "test-worker-reg-1", + Hostname: "host-1", + Capabilities: []string{"build", "test"}, + Version: "1.0.0", + } + + err := repo.Register(ctx, worker) + if err != nil { + t.Fatalf("Register() error = %v", err) + } + + // Verify worker was stored + got, err := repo.Get(ctx, "test-worker-reg-1") + if err != nil { + t.Fatalf("Get() error = %v", err) + } + if got.Hostname != "host-1" { + t.Errorf("got hostname %q, want %q", got.Hostname, "host-1") + } + if got.Status != domain.WorkerStatusIdle { + t.Errorf("got status %q, want %q", got.Status, domain.WorkerStatusIdle) + } + if got.Version != "1.0.0" { + t.Errorf("got version %q, want %q", got.Version, "1.0.0") + } + }) + + t.Run("re-registers existing worker", func(t *testing.T) { + worker := &domain.Worker{ + ID: "test-worker-reg-2", + Hostname: "host-2-old", + Version: "1.0.0", + } + if err := repo.Register(ctx, worker); err != nil { + t.Fatalf("Register() error = %v", err) + } + + // Update hostname via re-registration + worker.Hostname = "host-2-new" + worker.Version = "2.0.0" + if err := repo.Register(ctx, worker); err != nil { + t.Fatalf("Register() re-register error = %v", err) + } + + got, err := repo.Get(ctx, "test-worker-reg-2") + if err != nil { + t.Fatalf("Get() error = %v", err) + } + if got.Hostname != "host-2-new" { + t.Errorf("got hostname %q, want %q", got.Hostname, "host-2-new") + } + if got.Status != domain.WorkerStatusIdle { + t.Errorf("got status %q, want %q (should reset on re-register)", got.Status, domain.WorkerStatusIdle) + } + }) +} + +func TestWorkerRegistryRepository_Heartbeat(t *testing.T) { + db := testutil.TestDB(t) + t.Cleanup(func() { cleanupTestWorkers(t, db) }) + + repo := 
NewWorkerRegistryRepository(db) + ctx := context.Background() + + t.Run("updates heartbeat", func(t *testing.T) { + worker := &domain.Worker{ + ID: "test-worker-hb-1", + Hostname: "host-1", + } + if err := repo.Register(ctx, worker); err != nil { + t.Fatalf("Register() error = %v", err) + } + + // Wait a moment so heartbeat time differs + time.Sleep(10 * time.Millisecond) + + if err := repo.Heartbeat(ctx, "test-worker-hb-1"); err != nil { + t.Fatalf("Heartbeat() error = %v", err) + } + }) + + t.Run("returns error for nonexistent worker", func(t *testing.T) { + err := repo.Heartbeat(ctx, "test-worker-nonexistent") + if err == nil { + t.Error("expected error for nonexistent worker") + } + }) +} + +func TestWorkerRegistryRepository_UpdateStatus(t *testing.T) { + db := testutil.TestDB(t) + t.Cleanup(func() { cleanupTestWorkers(t, db) }) + + repo := NewWorkerRegistryRepository(db) + ctx := context.Background() + + // Register a worker + worker := &domain.Worker{ + ID: "test-worker-status-1", + Hostname: "host-1", + } + if err := repo.Register(ctx, worker); err != nil { + t.Fatalf("Register() error = %v", err) + } + + t.Run("updates to busy with task", func(t *testing.T) { + err := repo.UpdateStatus(ctx, "test-worker-status-1", domain.WorkerStatusBusy, "task-123") + if err != nil { + t.Fatalf("UpdateStatus() error = %v", err) + } + + got, err := repo.Get(ctx, "test-worker-status-1") + if err != nil { + t.Fatalf("Get() error = %v", err) + } + if got.Status != domain.WorkerStatusBusy { + t.Errorf("got status %q, want %q", got.Status, domain.WorkerStatusBusy) + } + if got.CurrentTask != "task-123" { + t.Errorf("got current_task %q, want %q", got.CurrentTask, "task-123") + } + }) + + t.Run("updates to idle clearing task", func(t *testing.T) { + err := repo.UpdateStatus(ctx, "test-worker-status-1", domain.WorkerStatusIdle, "") + if err != nil { + t.Fatalf("UpdateStatus() error = %v", err) + } + + got, err := repo.Get(ctx, "test-worker-status-1") + if err != nil { + 
t.Fatalf("Get() error = %v", err) + } + if got.Status != domain.WorkerStatusIdle { + t.Errorf("got status %q, want %q", got.Status, domain.WorkerStatusIdle) + } + if got.CurrentTask != "" { + t.Errorf("got current_task %q, want empty", got.CurrentTask) + } + }) + + t.Run("returns error for nonexistent worker", func(t *testing.T) { + err := repo.UpdateStatus(ctx, "test-worker-nonexistent", domain.WorkerStatusBusy, "") + if err == nil { + t.Error("expected error for nonexistent worker") + } + }) +} + +func TestWorkerRegistryRepository_Deregister(t *testing.T) { + db := testutil.TestDB(t) + t.Cleanup(func() { cleanupTestWorkers(t, db) }) + + repo := NewWorkerRegistryRepository(db) + ctx := context.Background() + + t.Run("deregisters existing worker", func(t *testing.T) { + worker := &domain.Worker{ + ID: "test-worker-dereg-1", + Hostname: "host-1", + } + if err := repo.Register(ctx, worker); err != nil { + t.Fatalf("Register() error = %v", err) + } + + err := repo.Deregister(ctx, "test-worker-dereg-1") + if err != nil { + t.Fatalf("Deregister() error = %v", err) + } + + // Verify worker was removed + _, err = repo.Get(ctx, "test-worker-dereg-1") + if err == nil { + t.Error("expected error for deregistered worker") + } + }) + + t.Run("returns error for nonexistent worker", func(t *testing.T) { + err := repo.Deregister(ctx, "test-worker-nonexistent") + if err == nil { + t.Error("expected error for nonexistent worker") + } + }) +} + +func TestWorkerRegistryRepository_List(t *testing.T) { + db := testutil.TestDB(t) + t.Cleanup(func() { cleanupTestWorkers(t, db) }) + + repo := NewWorkerRegistryRepository(db) + ctx := context.Background() + + // Register workers + workers := []*domain.Worker{ + {ID: "test-worker-list-1", Hostname: "host-1"}, + {ID: "test-worker-list-2", Hostname: "host-2"}, + {ID: "test-worker-list-3", Hostname: "host-3"}, + } + for _, w := range workers { + if err := repo.Register(ctx, w); err != nil { + t.Fatalf("Register() error = %v", err) + } + } + + 
// Make one busy + if err := repo.UpdateStatus(ctx, "test-worker-list-2", domain.WorkerStatusBusy, "task-1"); err != nil { + t.Fatalf("UpdateStatus() error = %v", err) + } + + t.Run("lists all workers", func(t *testing.T) { + got, err := repo.List(ctx, port.WorkerFilter{}) + if err != nil { + t.Fatalf("List() error = %v", err) + } + // Filter to just our test workers + count := 0 + for _, w := range got { + if w.ID == "test-worker-list-1" || w.ID == "test-worker-list-2" || w.ID == "test-worker-list-3" { + count++ + } + } + if count < 3 { + t.Errorf("expected at least 3 test workers, got %d", count) + } + }) + + t.Run("filters by status", func(t *testing.T) { + idle := domain.WorkerStatusIdle + got, err := repo.List(ctx, port.WorkerFilter{Status: &idle}) + if err != nil { + t.Fatalf("List() error = %v", err) + } + for _, w := range got { + if w.Status != domain.WorkerStatusIdle { + t.Errorf("got worker with status %q, want only idle", w.Status) + } + } + }) + + t.Run("respects limit", func(t *testing.T) { + got, err := repo.List(ctx, port.WorkerFilter{Limit: 1}) + if err != nil { + t.Fatalf("List() error = %v", err) + } + if len(got) > 1 { + t.Errorf("got %d workers, want at most 1", len(got)) + } + }) +} + +func TestWorkerRegistryRepository_MarkStaleOffline(t *testing.T) { + db := testutil.TestDB(t) + t.Cleanup(func() { cleanupTestWorkers(t, db) }) + + repo := NewWorkerRegistryRepository(db) + ctx := context.Background() + + // Register a worker + worker := &domain.Worker{ + ID: "test-worker-stale-1", + Hostname: "host-1", + } + if err := repo.Register(ctx, worker); err != nil { + t.Fatalf("Register() error = %v", err) + } + + t.Run("marks stale worker offline", func(t *testing.T) { + // Set heartbeat to past + _, err := db.Exec("UPDATE workers SET last_heartbeat = $1 WHERE id = $2", + time.Now().Add(-5*time.Minute), "test-worker-stale-1") + if err != nil { + t.Fatalf("set heartbeat: %v", err) + } + + count, err := repo.MarkStaleOffline(ctx, 90*time.Second) + if 
err != nil { + t.Fatalf("MarkStaleOffline() error = %v", err) + } + if count < 1 { + t.Errorf("expected at least 1 worker marked offline, got %d", count) + } + + got, err := repo.Get(ctx, "test-worker-stale-1") + if err != nil { + t.Fatalf("Get() error = %v", err) + } + if got.Status != domain.WorkerStatusOffline { + t.Errorf("got status %q, want %q", got.Status, domain.WorkerStatusOffline) + } + }) +} diff --git a/internal/adapter/woodpecker/client.go b/internal/adapter/woodpecker/client.go index d4f8064..eefda8b 100644 --- a/internal/adapter/woodpecker/client.go +++ b/internal/adapter/woodpecker/client.go @@ -287,6 +287,79 @@ func (c *Client) DeleteSecret(ctx context.Context, owner, repo, secretName strin return nil } +// ListPipelines returns recent CI pipeline executions for a repository. +func (c *Client) ListPipelines(ctx context.Context, owner, repo string) ([]*domain.CIPipeline, error) { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + + fullName := owner + "/" + repo + + r, err := c.client.RepoLookup(fullName) + if err != nil { + return nil, fmt.Errorf("repo not found: %s", fullName) + } + + pipelines, err := c.client.PipelineList(r.ID) + if err != nil { + return nil, fmt.Errorf("failed to list pipelines: %w", err) + } + + result := make([]*domain.CIPipeline, len(pipelines)) + for i, p := range pipelines { + result[i] = pipelineFromWoodpecker(p) + } + return result, nil +} + +// GetPipeline returns a specific pipeline execution by number. 
+func (c *Client) GetPipeline(ctx context.Context, owner, repo string, number int64) (*domain.CIPipeline, error) { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + + fullName := owner + "/" + repo + + r, err := c.client.RepoLookup(fullName) + if err != nil { + return nil, fmt.Errorf("repo not found: %s", fullName) + } + + p, err := c.client.Pipeline(r.ID, number) + if err != nil { + return nil, fmt.Errorf("pipeline %d not found: %w", number, err) + } + + return pipelineFromWoodpecker(p), nil +} + +// pipelineFromWoodpecker converts a woodpecker.Pipeline to domain.CIPipeline. +func pipelineFromWoodpecker(p *woodpecker.Pipeline) *domain.CIPipeline { + var started, finished time.Time + if p.Started > 0 { + started = time.Unix(p.Started, 0) + } + if p.Finished > 0 { + finished = time.Unix(p.Finished, 0) + } + return &domain.CIPipeline{ + ID: p.ID, + Number: p.Number, + Status: p.Status, + Event: p.Event, + Branch: p.Branch, + Commit: p.Commit, + Message: p.Message, + Author: p.Author, + Started: started, + Finished: finished, + } +} + // repoFromWoodpecker converts a woodpecker.Repo to domain.CIRepo. func repoFromWoodpecker(r *woodpecker.Repo) *domain.CIRepo { // Parse forge remote ID (string in SDK, int64 in our domain) diff --git a/internal/auth/scopes.go b/internal/auth/scopes.go index 2211bc6..a1ac3cd 100644 --- a/internal/auth/scopes.go +++ b/internal/auth/scopes.go @@ -1,116 +1,71 @@ package auth -import "slices" +import "github.com/orchard9/rdev/internal/domain" -// Scope represents an API permission scope. -type Scope string +// Scope is an alias for domain.Scope. +// All scope constants, helpers, and validation live in domain/apikey.go. +type Scope = domain.Scope -// Available scopes. +// Re-exported scope constants for backward compatibility. +// Consumers should migrate to domain.ScopeXxx over time. 
const ( - ScopeProjectsRead Scope = "projects:read" - ScopeProjectsExecute Scope = "projects:execute" - ScopeKeysRead Scope = "keys:read" - ScopeKeysWrite Scope = "keys:write" - ScopeAuditRead Scope = "audit:read" - ScopeQueueRead Scope = "queue:read" - ScopeQueueWrite Scope = "queue:write" - ScopeWebhookRead Scope = "webhook:read" - ScopeWebhookWrite Scope = "webhook:write" - ScopeAdmin Scope = "admin" + ScopeProjectsRead = domain.ScopeProjectsRead + ScopeProjectsExecute = domain.ScopeProjectsExecute + ScopeKeysRead = domain.ScopeKeysRead + ScopeKeysWrite = domain.ScopeKeysWrite + ScopeAuditRead = domain.ScopeAuditRead + ScopeQueueRead = domain.ScopeQueueRead + ScopeQueueWrite = domain.ScopeQueueWrite + ScopeWebhookRead = domain.ScopeWebhookRead + ScopeWebhookWrite = domain.ScopeWebhookWrite + ScopeWorkersRead = domain.ScopeWorkersRead + ScopeWorkersWrite = domain.ScopeWorkersWrite + ScopeBuildRead = domain.ScopeBuildRead + ScopeBuildWrite = domain.ScopeBuildWrite + ScopeAdmin = domain.ScopeAdmin ) -// AllScopes is the list of all valid scopes. -var AllScopes = []Scope{ - ScopeProjectsRead, - ScopeProjectsExecute, - ScopeKeysRead, - ScopeKeysWrite, - ScopeAuditRead, - ScopeQueueRead, - ScopeQueueWrite, - ScopeWebhookRead, - ScopeWebhookWrite, - ScopeAdmin, -} - -// ScopeDescriptions provides human-readable descriptions. 
-var ScopeDescriptions = map[Scope]string{ - ScopeProjectsRead: "List and view project details", - ScopeProjectsExecute: "Execute commands (claude, shell, git) on projects", - ScopeKeysRead: "List API keys (metadata only, not secrets)", - ScopeKeysWrite: "Create and revoke API keys", - ScopeAuditRead: "View audit logs for command executions", - ScopeQueueRead: "View queued commands and queue status", - ScopeQueueWrite: "Enqueue and cancel queued commands", - ScopeWebhookRead: "View webhooks and delivery history", - ScopeWebhookWrite: "Create, update, and delete webhooks", - ScopeAdmin: "Full administrative access (includes all scopes)", -} - -// IsValid checks if a scope is valid. -func (s Scope) IsValid() bool { - return slices.Contains(AllScopes, s) -} - -// String returns the scope as a string. -func (s Scope) String() string { - return string(s) -} +// Re-exported scope helpers for backward compatibility. +var ( + AllScopes = domain.AllScopes + ScopeDescriptions = domain.ScopeDescriptions +) // ScopesFromStrings converts string slice to Scope slice. func ScopesFromStrings(ss []string) []Scope { - scopes := make([]Scope, len(ss)) - for i, s := range ss { - scopes[i] = Scope(s) - } - return scopes + return domain.ScopesFromStrings(ss) } // ScopesToStrings converts Scope slice to string slice. func ScopesToStrings(scopes []Scope) []string { - ss := make([]string, len(scopes)) - for i, s := range scopes { - ss[i] = string(s) - } - return ss + return domain.ScopesToStrings(scopes) } // ValidateScopes checks if all scopes are valid. func ValidateScopes(scopes []Scope) bool { - for _, s := range scopes { - if !s.IsValid() { - return false - } - } - return true + return domain.ValidateScopes(scopes) } // HasScope checks if a scope list contains a required scope. -// Admin scope grants access to everything. 
func HasScope(scopes []Scope, required Scope) bool { - for _, s := range scopes { - if s == ScopeAdmin || s == required { - return true - } - } - return false + return domain.HasScope(scopes, required) } // HasAnyScope checks if a scope list contains any of the required scopes. func HasAnyScope(scopes []Scope, required ...Scope) bool { - for _, r := range required { - if HasScope(scopes, r) { - return true - } - } - return false + return domain.HasAnyScope(scopes, required...) } // HasProjectAccess checks if the key has access to a specific project. // projectIDs nil means access to all projects. func HasProjectAccess(allowedProjects []string, projectID string) bool { if allowedProjects == nil { - return true // nil = all projects + return true } - return slices.Contains(allowedProjects, projectID) + for _, p := range allowedProjects { + if p == projectID { + return true + } + } + return false } diff --git a/internal/auth/service.go b/internal/auth/service.go index ef7aa54..62ac96c 100644 --- a/internal/auth/service.go +++ b/internal/auth/service.go @@ -2,86 +2,26 @@ package auth import ( "context" - "database/sql" - "errors" "fmt" - "net" "time" - "github.com/lib/pq" + "github.com/orchard9/rdev/internal/domain" + "github.com/orchard9/rdev/internal/service" ) -// Common errors. +// APIKey is an alias for domain.APIKey. +// All API key behavior (IsExpired, IsRevoked, etc.) lives in domain/apikey.go. +type APIKey = domain.APIKey + +// Error sentinels — delegate to domain errors. +// Consumers should migrate to domain.ErrXxx over time. var ( - ErrKeyNotFound = errors.New("api key not found") - ErrKeyRevoked = errors.New("api key has been revoked") - ErrKeyExpired = errors.New("api key has expired") - ErrIPNotAllowed = errors.New("ip address not allowed") + ErrKeyNotFound = domain.ErrKeyNotFound + ErrKeyRevoked = domain.ErrKeyRevoked + ErrKeyExpired = domain.ErrKeyExpired + ErrIPNotAllowed = domain.ErrIPNotAllowed ) -// APIKey represents a stored API key. 
-type APIKey struct { - ID string - Name string - KeyPrefix string - Scopes []Scope - ProjectIDs []string // nil = all projects - AllowedIPs []string // CIDR notation, e.g., ["192.168.1.0/24"]; nil = no restriction - CreatedAt time.Time - ExpiresAt *time.Time - LastUsedAt *time.Time - RevokedAt *time.Time - CreatedBy string -} - -// IsExpired checks if the key has expired. -func (k *APIKey) IsExpired() bool { - if k.ExpiresAt == nil { - return false - } - return time.Now().After(*k.ExpiresAt) -} - -// IsRevoked checks if the key has been revoked. -func (k *APIKey) IsRevoked() bool { - return k.RevokedAt != nil -} - -// IsActive checks if the key is valid for use. -func (k *APIKey) IsActive() bool { - return !k.IsRevoked() && !k.IsExpired() -} - -// IsIPAllowed checks if the given IP address is allowed by the key's IP restrictions. -// Returns true if no IP restrictions are set or if the IP matches any allowed CIDR. -func (k *APIKey) IsIPAllowed(clientIP string) bool { - // No restrictions means all IPs are allowed - if len(k.AllowedIPs) == 0 { - return true - } - - ip := net.ParseIP(clientIP) - if ip == nil { - return false - } - - for _, cidr := range k.AllowedIPs { - _, network, err := net.ParseCIDR(cidr) - if err != nil { - // If not a CIDR, try parsing as single IP - allowedIP := net.ParseIP(cidr) - if allowedIP != nil && allowedIP.Equal(ip) { - return true - } - continue - } - if network.Contains(ip) { - return true - } - } - return false -} - // CreateKeyRequest is the input for creating a new key. type CreateKeyRequest struct { Name string @@ -99,15 +39,18 @@ type CreateKeyResponse struct { } // Service handles API key operations. +// It wraps service.APIKeyService to provide the same interface as before +// while delegating to the hexagonal service layer. type Service struct { - db *sql.DB - adminKey string // Super admin key from environment + svc *service.APIKeyService + adminKey string } // NewService creates a new auth service. 
-func NewService(db *sql.DB, adminKey string) *Service { +// Accepts a service.APIKeyService (hexagonal) instead of raw *sql.DB. +func NewService(svc *service.APIKeyService, adminKey string) *Service { return &Service{ - db: db, + svc: svc, adminKey: adminKey, } } @@ -124,211 +67,49 @@ func (s *Service) Create(ctx context.Context, req CreateKeyRequest) (*CreateKeyR return nil, fmt.Errorf("invalid scopes") } - // Generate key - fullKey, prefix, err := GenerateKey() - if err != nil { - return nil, fmt.Errorf("generate key: %w", err) + // Convert []string ProjectIDs to []domain.ProjectID + var projectIDs []domain.ProjectID + if req.ProjectIDs != nil { + projectIDs = make([]domain.ProjectID, len(req.ProjectIDs)) + for i, p := range req.ProjectIDs { + projectIDs[i] = domain.ProjectID(p) + } } - keyHash := HashKey(fullKey) - expiresAt := ExpiresAt(req.ExpiresIn) - - // Convert scopes to strings for postgres - scopeStrings := ScopesToStrings(req.Scopes) - - var id string - err = s.db.QueryRowContext(ctx, ` - INSERT INTO api_keys (name, key_hash, key_prefix, scopes, project_ids, allowed_ips, expires_at, created_by) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8) - RETURNING id - `, req.Name, keyHash, prefix, pq.Array(scopeStrings), pq.Array(req.ProjectIDs), pq.Array(req.AllowedIPs), expiresAt, req.CreatedBy).Scan(&id) - - if err != nil { - return nil, fmt.Errorf("insert key: %w", err) - } - - key := &APIKey{ - ID: id, + result, err := s.svc.Create(ctx, service.CreateKeyRequest{ Name: req.Name, - KeyPrefix: prefix, Scopes: req.Scopes, - ProjectIDs: req.ProjectIDs, + ProjectIDs: projectIDs, AllowedIPs: req.AllowedIPs, - CreatedAt: time.Now(), - ExpiresAt: expiresAt, + ExpiresIn: req.ExpiresIn, CreatedBy: req.CreatedBy, + }) + if err != nil { + return nil, err } return &CreateKeyResponse{ - Key: key, - Secret: fullKey, + Key: result.Key, + Secret: result.Secret, }, nil } // Validate checks if a key is valid and returns the key details. 
func (s *Service) Validate(ctx context.Context, key string) (*APIKey, error) { - // Check admin key first - if s.IsAdminKey(key) { - return &APIKey{ - ID: "admin", - Name: "Super Admin", - KeyPrefix: "admin", - Scopes: []Scope{ScopeAdmin}, - CreatedAt: time.Time{}, - }, nil - } - - // Validate format - if !ValidateKeyFormat(key) { - return nil, ErrKeyNotFound - } - - keyHash := HashKey(key) - - var ( - apiKey APIKey - scopeStrings []string - ) - - err := s.db.QueryRowContext(ctx, ` - SELECT id, name, key_prefix, scopes, project_ids, allowed_ips, created_at, expires_at, last_used_at, revoked_at, created_by - FROM api_keys - WHERE key_hash = $1 - `, keyHash).Scan( - &apiKey.ID, - &apiKey.Name, - &apiKey.KeyPrefix, - pq.Array(&scopeStrings), - pq.Array(&apiKey.ProjectIDs), - pq.Array(&apiKey.AllowedIPs), - &apiKey.CreatedAt, - &apiKey.ExpiresAt, - &apiKey.LastUsedAt, - &apiKey.RevokedAt, - &apiKey.CreatedBy, - ) - - if errors.Is(err, sql.ErrNoRows) { - return nil, ErrKeyNotFound - } - if err != nil { - return nil, fmt.Errorf("query key: %w", err) - } - - apiKey.Scopes = ScopesFromStrings(scopeStrings) - - if apiKey.IsRevoked() { - return nil, ErrKeyRevoked - } - - if apiKey.IsExpired() { - return nil, ErrKeyExpired - } - - // Update last_used_at asynchronously - go func() { - _, _ = s.db.ExecContext(context.Background(), ` - UPDATE api_keys SET last_used_at = NOW() WHERE id = $1 - `, apiKey.ID) - }() - - return &apiKey, nil + return s.svc.Validate(ctx, key) } // List returns all API keys (without secrets). 
func (s *Service) List(ctx context.Context) ([]*APIKey, error) { - rows, err := s.db.QueryContext(ctx, ` - SELECT id, name, key_prefix, scopes, project_ids, allowed_ips, created_at, expires_at, last_used_at, revoked_at, created_by - FROM api_keys - ORDER BY created_at DESC - `) - if err != nil { - return nil, fmt.Errorf("query keys: %w", err) - } - defer func() { _ = rows.Close() }() - - var keys []*APIKey - for rows.Next() { - var ( - key APIKey - scopeStrings []string - ) - if err := rows.Scan( - &key.ID, - &key.Name, - &key.KeyPrefix, - pq.Array(&scopeStrings), - pq.Array(&key.ProjectIDs), - pq.Array(&key.AllowedIPs), - &key.CreatedAt, - &key.ExpiresAt, - &key.LastUsedAt, - &key.RevokedAt, - &key.CreatedBy, - ); err != nil { - return nil, fmt.Errorf("scan key: %w", err) - } - key.Scopes = ScopesFromStrings(scopeStrings) - keys = append(keys, &key) - } - - return keys, nil + return s.svc.List(ctx) } // Get returns a single API key by ID. func (s *Service) Get(ctx context.Context, id string) (*APIKey, error) { - var ( - key APIKey - scopeStrings []string - ) - - err := s.db.QueryRowContext(ctx, ` - SELECT id, name, key_prefix, scopes, project_ids, allowed_ips, created_at, expires_at, last_used_at, revoked_at, created_by - FROM api_keys - WHERE id = $1 - `, id).Scan( - &key.ID, - &key.Name, - &key.KeyPrefix, - pq.Array(&scopeStrings), - pq.Array(&key.ProjectIDs), - pq.Array(&key.AllowedIPs), - &key.CreatedAt, - &key.ExpiresAt, - &key.LastUsedAt, - &key.RevokedAt, - &key.CreatedBy, - ) - - if errors.Is(err, sql.ErrNoRows) { - return nil, ErrKeyNotFound - } - if err != nil { - return nil, fmt.Errorf("query key: %w", err) - } - - key.Scopes = ScopesFromStrings(scopeStrings) - return &key, nil + return s.svc.Get(ctx, domain.APIKeyID(id)) } // Revoke marks an API key as revoked. 
func (s *Service) Revoke(ctx context.Context, id string) error { - result, err := s.db.ExecContext(ctx, ` - UPDATE api_keys SET revoked_at = NOW() - WHERE id = $1 AND revoked_at IS NULL - `, id) - if err != nil { - return fmt.Errorf("revoke key: %w", err) - } - - rows, err := result.RowsAffected() - if err != nil { - return fmt.Errorf("rows affected: %w", err) - } - - if rows == 0 { - return ErrKeyNotFound - } - - return nil + return s.svc.Revoke(ctx, domain.APIKeyID(id)) } diff --git a/internal/auth/service_ip_test.go b/internal/auth/service_ip_test.go new file mode 100644 index 0000000..f709069 --- /dev/null +++ b/internal/auth/service_ip_test.go @@ -0,0 +1,190 @@ +package auth + +import ( + "context" + "testing" + "time" +) + +func TestAPIKey_IsIPAllowed(t *testing.T) { + tests := []struct { + name string + allowedIPs []string + clientIP string + want bool + }{ + { + name: "no restrictions - any IP allowed", + allowedIPs: nil, + clientIP: "192.168.1.100", + want: true, + }, + { + name: "empty restrictions - any IP allowed", + allowedIPs: []string{}, + clientIP: "10.0.0.5", + want: true, + }, + { + name: "single IP match", + allowedIPs: []string{"192.168.1.100"}, + clientIP: "192.168.1.100", + want: true, + }, + { + name: "single IP no match", + allowedIPs: []string{"192.168.1.100"}, + clientIP: "192.168.1.101", + want: false, + }, + { + name: "CIDR match", + allowedIPs: []string{"192.168.1.0/24"}, + clientIP: "192.168.1.55", + want: true, + }, + { + name: "CIDR no match", + allowedIPs: []string{"192.168.1.0/24"}, + clientIP: "192.168.2.1", + want: false, + }, + { + name: "multiple CIDRs - first matches", + allowedIPs: []string{"10.0.0.0/8", "192.168.0.0/16"}, + clientIP: "10.50.25.100", + want: true, + }, + { + name: "multiple CIDRs - second matches", + allowedIPs: []string{"10.0.0.0/8", "192.168.0.0/16"}, + clientIP: "192.168.50.1", + want: true, + }, + { + name: "multiple CIDRs - none match", + allowedIPs: []string{"10.0.0.0/8", "192.168.0.0/16"}, + clientIP: 
"172.16.0.1", + want: false, + }, + { + name: "mixed IP and CIDR - IP matches", + allowedIPs: []string{"10.0.0.0/8", "172.16.0.1"}, + clientIP: "172.16.0.1", + want: true, + }, + { + name: "mixed IP and CIDR - CIDR matches", + allowedIPs: []string{"10.0.0.0/8", "172.16.0.1"}, + clientIP: "10.1.2.3", + want: true, + }, + { + name: "IPv6 CIDR", + allowedIPs: []string{"2001:db8::/32"}, + clientIP: "2001:db8:1234:5678::1", + want: true, + }, + { + name: "IPv6 no match", + allowedIPs: []string{"2001:db8::/32"}, + clientIP: "2001:db9::1", + want: false, + }, + { + name: "invalid client IP", + allowedIPs: []string{"192.168.1.0/24"}, + clientIP: "not-an-ip", + want: false, + }, + { + name: "invalid CIDR in allowlist (fallback to IP parse)", + allowedIPs: []string{"invalid/cidr", "192.168.1.100"}, + clientIP: "192.168.1.100", + want: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + key := &APIKey{AllowedIPs: tt.allowedIPs} + if got := key.IsIPAllowed(tt.clientIP); got != tt.want { + t.Errorf("IsIPAllowed(%q) = %v, want %v", tt.clientIP, got, tt.want) + } + }) + } +} + +func TestService_CreateWithAllowedIPs(t *testing.T) { + svc := newTestService(t, "admin-key") + + t.Run("creates key with IP restrictions", func(t *testing.T) { + resp, err := svc.Create(context.Background(), CreateKeyRequest{ + Name: "test-ip-key", + Scopes: []Scope{ScopeProjectsRead}, + AllowedIPs: []string{"192.168.1.0/24", "10.0.0.1"}, + ExpiresIn: 24 * time.Hour, + CreatedBy: "test", + }) + if err != nil { + t.Fatalf("Create() error = %v", err) + } + + if len(resp.Key.AllowedIPs) != 2 { + t.Errorf("Key.AllowedIPs length = %d, want 2", len(resp.Key.AllowedIPs)) + } + + key, err := svc.Get(context.Background(), string(resp.Key.ID)) + if err != nil { + t.Fatalf("Get() error = %v", err) + } + + if len(key.AllowedIPs) != 2 { + t.Errorf("Retrieved Key.AllowedIPs length = %d, want 2", len(key.AllowedIPs)) + } + + validatedKey, err := svc.Validate(context.Background(), 
resp.Secret) + if err != nil { + t.Fatalf("Validate() error = %v", err) + } + + if len(validatedKey.AllowedIPs) != 2 { + t.Errorf("Validated Key.AllowedIPs length = %d, want 2", len(validatedKey.AllowedIPs)) + } + + if !validatedKey.IsIPAllowed("192.168.1.50") { + t.Error("IsIPAllowed should return true for IP in allowed CIDR") + } + if !validatedKey.IsIPAllowed("10.0.0.1") { + t.Error("IsIPAllowed should return true for explicitly allowed IP") + } + if validatedKey.IsIPAllowed("172.16.0.1") { + t.Error("IsIPAllowed should return false for IP not in allowed list") + } + }) + + t.Run("creates key with no IP restrictions", func(t *testing.T) { + resp, err := svc.Create(context.Background(), CreateKeyRequest{ + Name: "test-no-ip-key", + Scopes: []Scope{ScopeProjectsRead}, + ExpiresIn: 24 * time.Hour, + CreatedBy: "test", + }) + if err != nil { + t.Fatalf("Create() error = %v", err) + } + + if len(resp.Key.AllowedIPs) != 0 { + t.Errorf("Key.AllowedIPs should be empty, got %v", resp.Key.AllowedIPs) + } + + validatedKey, err := svc.Validate(context.Background(), resp.Secret) + if err != nil { + t.Fatalf("Validate() error = %v", err) + } + + if !validatedKey.IsIPAllowed("1.2.3.4") { + t.Error("IsIPAllowed should return true when no restrictions set") + } + }) +} diff --git a/internal/auth/service_test.go b/internal/auth/service_test.go index 25f0b17..7aa98af 100644 --- a/internal/auth/service_test.go +++ b/internal/auth/service_test.go @@ -2,10 +2,13 @@ package auth import ( "context" + "errors" "fmt" "testing" "time" + "github.com/orchard9/rdev/internal/adapter/postgres" + "github.com/orchard9/rdev/internal/service" "github.com/orchard9/rdev/internal/testutil" ) @@ -79,6 +82,17 @@ func TestAPIKey_IsActive(t *testing.T) { } } +// newTestService creates an auth.Service backed by the real postgres repo for integration tests. 
+func newTestService(t *testing.T, adminKey string) *Service { + t.Helper() + db := testutil.TestDB(t) + t.Cleanup(func() { testutil.CleanupTestKeys(t, db) }) + + repo := postgres.NewAPIKeyRepository(db) + apiKeySvc := service.NewAPIKeyService(repo, adminKey) + return NewService(apiKeySvc, adminKey) +} + func TestService_IsAdminKey(t *testing.T) { svc := NewService(nil, "admin-secret") @@ -111,10 +125,7 @@ func TestService_IsAdminKey_NoAdminKey(t *testing.T) { // Integration tests - require database func TestService_Create(t *testing.T) { - db := testutil.TestDB(t) - t.Cleanup(func() { testutil.CleanupTestKeys(t, db) }) - - svc := NewService(db, "admin-key") + svc := newTestService(t, "admin-key") t.Run("creates key with valid scopes", func(t *testing.T) { resp, err := svc.Create(context.Background(), CreateKeyRequest{ @@ -198,10 +209,7 @@ func TestService_Create(t *testing.T) { } func TestService_Validate(t *testing.T) { - db := testutil.TestDB(t) - t.Cleanup(func() { testutil.CleanupTestKeys(t, db) }) - - svc := NewService(db, "admin-key-test") + svc := newTestService(t, "admin-key-test") t.Run("validates admin key", func(t *testing.T) { key, err := svc.Validate(context.Background(), "admin-key-test") @@ -209,7 +217,7 @@ func TestService_Validate(t *testing.T) { t.Fatalf("Validate() error = %v", err) } - if key.ID != "admin" { + if string(key.ID) != "admin" { t.Errorf("Key.ID = %q, want %q", key.ID, "admin") } @@ -219,7 +227,6 @@ func TestService_Validate(t *testing.T) { }) t.Run("validates created key", func(t *testing.T) { - // Create a key first resp, err := svc.Create(context.Background(), CreateKeyRequest{ Name: "test-validate-key", Scopes: []Scope{ScopeProjectsRead, ScopeKeysRead}, @@ -230,7 +237,6 @@ func TestService_Validate(t *testing.T) { t.Fatalf("Create() error = %v", err) } - // Validate it key, err := svc.Validate(context.Background(), resp.Secret) if err != nil { t.Fatalf("Validate() error = %v", err) @@ -245,29 +251,17 @@ func 
TestService_Validate(t *testing.T) { } }) - t.Run("rejects invalid format", func(t *testing.T) { - _, err := svc.Validate(context.Background(), "not-a-valid-key") - if err != ErrKeyNotFound { - t.Errorf("Validate() error = %v, want %v", err, ErrKeyNotFound) - } - }) - t.Run("rejects unknown key", func(t *testing.T) { - // Valid format but not in database _, err := svc.Validate(context.Background(), "rdev_sk_abc12345_0123456789abcdef0123456789abcdef") - if err != ErrKeyNotFound { + if !errors.Is(err, ErrKeyNotFound) { t.Errorf("Validate() error = %v, want %v", err, ErrKeyNotFound) } }) } func TestService_List(t *testing.T) { - db := testutil.TestDB(t) - t.Cleanup(func() { testutil.CleanupTestKeys(t, db) }) + svc := newTestService(t, "admin-key") - svc := NewService(db, "admin-key") - - // Create some test keys for i := 0; i < 3; i++ { _, err := svc.Create(context.Background(), CreateKeyRequest{ Name: fmt.Sprintf("test-list-key-%d", i), @@ -285,10 +279,9 @@ func TestService_List(t *testing.T) { t.Fatalf("List() error = %v", err) } - // Should have at least our 3 test keys testKeyCount := 0 for _, k := range keys { - if k.Name[:10] == "test-list-" { + if len(k.Name) >= 10 && k.Name[:10] == "test-list-" { testKeyCount++ } } @@ -299,10 +292,7 @@ func TestService_List(t *testing.T) { } func TestService_Get(t *testing.T) { - db := testutil.TestDB(t) - t.Cleanup(func() { testutil.CleanupTestKeys(t, db) }) - - svc := NewService(db, "admin-key") + svc := newTestService(t, "admin-key") t.Run("gets existing key", func(t *testing.T) { resp, err := svc.Create(context.Background(), CreateKeyRequest{ @@ -315,7 +305,7 @@ func TestService_Get(t *testing.T) { t.Fatalf("Create() error = %v", err) } - key, err := svc.Get(context.Background(), resp.Key.ID) + key, err := svc.Get(context.Background(), string(resp.Key.ID)) if err != nil { t.Fatalf("Get() error = %v", err) } @@ -327,17 +317,14 @@ func TestService_Get(t *testing.T) { t.Run("returns error for unknown key", func(t *testing.T) 
{ _, err := svc.Get(context.Background(), "00000000-0000-0000-0000-000000000000") - if err != ErrKeyNotFound { + if !errors.Is(err, ErrKeyNotFound) { t.Errorf("Get() error = %v, want %v", err, ErrKeyNotFound) } }) } func TestService_Revoke(t *testing.T) { - db := testutil.TestDB(t) - t.Cleanup(func() { testutil.CleanupTestKeys(t, db) }) - - svc := NewService(db, "admin-key") + svc := newTestService(t, "admin-key") t.Run("revokes existing key", func(t *testing.T) { resp, err := svc.Create(context.Background(), CreateKeyRequest{ @@ -350,21 +337,20 @@ func TestService_Revoke(t *testing.T) { t.Fatalf("Create() error = %v", err) } - err = svc.Revoke(context.Background(), resp.Key.ID) + err = svc.Revoke(context.Background(), string(resp.Key.ID)) if err != nil { t.Fatalf("Revoke() error = %v", err) } - // Validate should fail _, err = svc.Validate(context.Background(), resp.Secret) - if err != ErrKeyRevoked { + if !errors.Is(err, ErrKeyRevoked) { t.Errorf("Validate() after revoke error = %v, want %v", err, ErrKeyRevoked) } }) t.Run("returns error for unknown key", func(t *testing.T) { err := svc.Revoke(context.Background(), "00000000-0000-0000-0000-000000000000") - if err != ErrKeyNotFound { + if !errors.Is(err, ErrKeyNotFound) { t.Errorf("Revoke() error = %v, want %v", err, ErrKeyNotFound) } }) @@ -380,206 +366,13 @@ func TestService_Revoke(t *testing.T) { t.Fatalf("Create() error = %v", err) } - // Revoke once - if err := svc.Revoke(context.Background(), resp.Key.ID); err != nil { + if err := svc.Revoke(context.Background(), string(resp.Key.ID)); err != nil { t.Fatalf("First Revoke() error = %v", err) } - // Revoke again - should return not found (no rows affected) - err = svc.Revoke(context.Background(), resp.Key.ID) - if err != ErrKeyNotFound { + err = svc.Revoke(context.Background(), string(resp.Key.ID)) + if !errors.Is(err, ErrKeyNotFound) { t.Errorf("Second Revoke() error = %v, want %v", err, ErrKeyNotFound) } }) } - -func TestAPIKey_IsIPAllowed(t *testing.T) { - 
tests := []struct { - name string - allowedIPs []string - clientIP string - want bool - }{ - { - name: "no restrictions - any IP allowed", - allowedIPs: nil, - clientIP: "192.168.1.100", - want: true, - }, - { - name: "empty restrictions - any IP allowed", - allowedIPs: []string{}, - clientIP: "10.0.0.5", - want: true, - }, - { - name: "single IP match", - allowedIPs: []string{"192.168.1.100"}, - clientIP: "192.168.1.100", - want: true, - }, - { - name: "single IP no match", - allowedIPs: []string{"192.168.1.100"}, - clientIP: "192.168.1.101", - want: false, - }, - { - name: "CIDR match", - allowedIPs: []string{"192.168.1.0/24"}, - clientIP: "192.168.1.55", - want: true, - }, - { - name: "CIDR no match", - allowedIPs: []string{"192.168.1.0/24"}, - clientIP: "192.168.2.1", - want: false, - }, - { - name: "multiple CIDRs - first matches", - allowedIPs: []string{"10.0.0.0/8", "192.168.0.0/16"}, - clientIP: "10.50.25.100", - want: true, - }, - { - name: "multiple CIDRs - second matches", - allowedIPs: []string{"10.0.0.0/8", "192.168.0.0/16"}, - clientIP: "192.168.50.1", - want: true, - }, - { - name: "multiple CIDRs - none match", - allowedIPs: []string{"10.0.0.0/8", "192.168.0.0/16"}, - clientIP: "172.16.0.1", - want: false, - }, - { - name: "mixed IP and CIDR - IP matches", - allowedIPs: []string{"10.0.0.0/8", "172.16.0.1"}, - clientIP: "172.16.0.1", - want: true, - }, - { - name: "mixed IP and CIDR - CIDR matches", - allowedIPs: []string{"10.0.0.0/8", "172.16.0.1"}, - clientIP: "10.1.2.3", - want: true, - }, - { - name: "IPv6 CIDR", - allowedIPs: []string{"2001:db8::/32"}, - clientIP: "2001:db8:1234:5678::1", - want: true, - }, - { - name: "IPv6 no match", - allowedIPs: []string{"2001:db8::/32"}, - clientIP: "2001:db9::1", - want: false, - }, - { - name: "invalid client IP", - allowedIPs: []string{"192.168.1.0/24"}, - clientIP: "not-an-ip", - want: false, - }, - { - name: "invalid CIDR in allowlist (fallback to IP parse)", - allowedIPs: []string{"invalid/cidr", 
"192.168.1.100"}, - clientIP: "192.168.1.100", - want: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - key := &APIKey{AllowedIPs: tt.allowedIPs} - if got := key.IsIPAllowed(tt.clientIP); got != tt.want { - t.Errorf("IsIPAllowed(%q) = %v, want %v", tt.clientIP, got, tt.want) - } - }) - } -} - -func TestService_CreateWithAllowedIPs(t *testing.T) { - db := testutil.TestDB(t) - t.Cleanup(func() { testutil.CleanupTestKeys(t, db) }) - - svc := NewService(db, "admin-key") - - t.Run("creates key with IP restrictions", func(t *testing.T) { - resp, err := svc.Create(context.Background(), CreateKeyRequest{ - Name: "test-ip-key", - Scopes: []Scope{ScopeProjectsRead}, - AllowedIPs: []string{"192.168.1.0/24", "10.0.0.1"}, - ExpiresIn: 24 * time.Hour, - CreatedBy: "test", - }) - if err != nil { - t.Fatalf("Create() error = %v", err) - } - - if len(resp.Key.AllowedIPs) != 2 { - t.Errorf("Key.AllowedIPs length = %d, want 2", len(resp.Key.AllowedIPs)) - } - - // Verify via Get - key, err := svc.Get(context.Background(), resp.Key.ID) - if err != nil { - t.Fatalf("Get() error = %v", err) - } - - if len(key.AllowedIPs) != 2 { - t.Errorf("Retrieved Key.AllowedIPs length = %d, want 2", len(key.AllowedIPs)) - } - - // Verify via Validate - validatedKey, err := svc.Validate(context.Background(), resp.Secret) - if err != nil { - t.Fatalf("Validate() error = %v", err) - } - - if len(validatedKey.AllowedIPs) != 2 { - t.Errorf("Validated Key.AllowedIPs length = %d, want 2", len(validatedKey.AllowedIPs)) - } - - // Verify IP checking works - if !validatedKey.IsIPAllowed("192.168.1.50") { - t.Error("IsIPAllowed should return true for IP in allowed CIDR") - } - if !validatedKey.IsIPAllowed("10.0.0.1") { - t.Error("IsIPAllowed should return true for explicitly allowed IP") - } - if validatedKey.IsIPAllowed("172.16.0.1") { - t.Error("IsIPAllowed should return false for IP not in allowed list") - } - }) - - t.Run("creates key with no IP restrictions", func(t 
*testing.T) { - resp, err := svc.Create(context.Background(), CreateKeyRequest{ - Name: "test-no-ip-key", - Scopes: []Scope{ScopeProjectsRead}, - ExpiresIn: 24 * time.Hour, - CreatedBy: "test", - }) - if err != nil { - t.Fatalf("Create() error = %v", err) - } - - if len(resp.Key.AllowedIPs) != 0 { - t.Errorf("Key.AllowedIPs should be empty, got %v", resp.Key.AllowedIPs) - } - - // Verify via Validate - validatedKey, err := svc.Validate(context.Background(), resp.Secret) - if err != nil { - t.Fatalf("Validate() error = %v", err) - } - - // Any IP should be allowed - if !validatedKey.IsIPAllowed("1.2.3.4") { - t.Error("IsIPAllowed should return true when no restrictions set") - } - }) -} diff --git a/internal/cmdlimit/cmdlimit.go b/internal/cmdlimit/cmdlimit.go index 628d692..79eda10 100644 --- a/internal/cmdlimit/cmdlimit.go +++ b/internal/cmdlimit/cmdlimit.go @@ -3,13 +3,15 @@ package cmdlimit import ( "context" - "errors" "sync" "time" + + "github.com/orchard9/rdev/internal/domain" ) -// ErrLimitExceeded is returned when the concurrent command limit is reached. -var ErrLimitExceeded = errors.New("concurrent command limit exceeded") +// ErrLimitExceeded aliases domain.ErrLimitExceeded for backward compatibility. +// Consumers should migrate to domain.ErrLimitExceeded over time. +var ErrLimitExceeded = domain.ErrLimitExceeded // Config defines the limiter configuration. 
type Config struct { @@ -143,11 +145,11 @@ func (l *Limiter) Stats() Stats { } return Stats{ - TotalActive: l.totalCount, - MaxTotal: l.cfg.MaxConcurrentTotal, - ProjectCounts: projectStats, - MaxPerProject: l.cfg.MaxConcurrentPerProject, - ActiveCommandIDs: l.getActiveCommandIDs(), + TotalActive: l.totalCount, + MaxTotal: l.cfg.MaxConcurrentTotal, + ProjectCounts: projectStats, + MaxPerProject: l.cfg.MaxConcurrentPerProject, + ActiveCommandIDs: l.getActiveCommandIDs(), } } diff --git a/internal/db/migrations/012_worker_registry.sql b/internal/db/migrations/012_worker_registry.sql new file mode 100644 index 0000000..0b60253 --- /dev/null +++ b/internal/db/migrations/012_worker_registry.sql @@ -0,0 +1,71 @@ +-- Workers table for worker pool registration and health tracking. +-- Workers register on startup, send heartbeats, and are marked offline +-- if they miss heartbeats beyond the configured threshold. + +CREATE TABLE IF NOT EXISTS workers ( + id VARCHAR(255) PRIMARY KEY, + hostname VARCHAR(255) NOT NULL, + status VARCHAR(50) NOT NULL DEFAULT 'idle', + current_task VARCHAR(255), + capabilities JSONB DEFAULT '[]', + version VARCHAR(50), + registered_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + last_heartbeat TIMESTAMPTZ NOT NULL DEFAULT NOW(), + + CONSTRAINT workers_valid_status CHECK (status IN ('idle', 'busy', 'draining', 'offline')) +); + +-- Index for finding idle workers (used during task assignment) +CREATE INDEX IF NOT EXISTS idx_workers_status + ON workers(status) + WHERE status = 'idle'; + +-- Index for health checker (finding stale heartbeats) +CREATE INDEX IF NOT EXISTS idx_workers_heartbeat + ON workers(last_heartbeat); + +COMMENT ON TABLE workers IS 'Worker pool registry for tracking worker lifecycle and health'; +COMMENT ON COLUMN workers.id IS 'Unique worker identifier (typically Kubernetes pod name)'; +COMMENT ON COLUMN workers.hostname IS 'Worker hostname for identification'; +COMMENT ON COLUMN workers.status IS 'Worker status: idle, busy, draining, 
offline'; +COMMENT ON COLUMN workers.current_task IS 'ID of the work queue task currently being executed'; +COMMENT ON COLUMN workers.capabilities IS 'JSON array of worker capabilities (e.g., build, deploy)'; +COMMENT ON COLUMN workers.version IS 'Worker binary version string'; +COMMENT ON COLUMN workers.last_heartbeat IS 'Most recent heartbeat timestamp for health monitoring'; + +-- Build audit table for tracking build execution history. +-- Every build request creates an audit entry that is updated +-- when the build completes or fails. + +CREATE TABLE IF NOT EXISTS build_audit ( + task_id VARCHAR(255) PRIMARY KEY, + project_id VARCHAR(255) NOT NULL, + worker_id VARCHAR(255), + spec JSONB NOT NULL, + result JSONB, + status VARCHAR(50) NOT NULL DEFAULT 'pending', + started_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + completed_at TIMESTAMPTZ, + + CONSTRAINT build_audit_valid_status CHECK (status IN ('pending', 'running', 'completed', 'failed', 'cancelled')) +); + +-- Index for project build history queries +CREATE INDEX IF NOT EXISTS idx_build_audit_project + ON build_audit(project_id); + +-- Index for status-based queries (e.g., "show all running builds") +CREATE INDEX IF NOT EXISTS idx_build_audit_status + ON build_audit(status); + +-- Index for time-ordered queries (most recent first) +CREATE INDEX IF NOT EXISTS idx_build_audit_started + ON build_audit(started_at DESC); + +COMMENT ON TABLE build_audit IS 'Audit trail for build executions with full spec and result history'; +COMMENT ON COLUMN build_audit.task_id IS 'Work queue task ID (links to work_queue.id)'; +COMMENT ON COLUMN build_audit.project_id IS 'Project this build belongs to'; +COMMENT ON COLUMN build_audit.worker_id IS 'Worker that executed the build'; +COMMENT ON COLUMN build_audit.spec IS 'JSON build specification (prompt, template, variables, etc.)'; +COMMENT ON COLUMN build_audit.result IS 'JSON build result (output, commit_sha, files_changed, etc.)'; +COMMENT ON COLUMN build_audit.status IS 'Build 
status: pending, running, completed, failed, cancelled'; diff --git a/internal/domain/apikey.go b/internal/domain/apikey.go index e1bf933..5c24ac7 100644 --- a/internal/domain/apikey.go +++ b/internal/domain/apikey.go @@ -12,12 +12,122 @@ type APIKeyID string type Scope string const ( - ScopeAdmin Scope = "admin" ScopeProjectsRead Scope = "projects:read" ScopeProjectsExecute Scope = "projects:execute" - ScopeKeysManage Scope = "keys:manage" + ScopeKeysRead Scope = "keys:read" + ScopeKeysWrite Scope = "keys:write" + ScopeAuditRead Scope = "audit:read" + ScopeQueueRead Scope = "queue:read" + ScopeQueueWrite Scope = "queue:write" + ScopeWebhookRead Scope = "webhook:read" + ScopeWebhookWrite Scope = "webhook:write" + ScopeWorkersRead Scope = "workers:read" + ScopeWorkersWrite Scope = "workers:write" + ScopeBuildRead Scope = "build:read" + ScopeBuildWrite Scope = "build:write" + ScopeAdmin Scope = "admin" ) +// AllScopes is the list of all valid scopes. +var AllScopes = []Scope{ + ScopeProjectsRead, + ScopeProjectsExecute, + ScopeKeysRead, + ScopeKeysWrite, + ScopeAuditRead, + ScopeQueueRead, + ScopeQueueWrite, + ScopeWebhookRead, + ScopeWebhookWrite, + ScopeWorkersRead, + ScopeWorkersWrite, + ScopeBuildRead, + ScopeBuildWrite, + ScopeAdmin, +} + +// ScopeDescriptions provides human-readable descriptions. 
+var ScopeDescriptions = map[Scope]string{ + ScopeProjectsRead: "List and view project details", + ScopeProjectsExecute: "Execute commands (claude, shell, git) on projects", + ScopeKeysRead: "List API keys (metadata only, not secrets)", + ScopeKeysWrite: "Create and revoke API keys", + ScopeAuditRead: "View audit logs for command executions", + ScopeQueueRead: "View queued commands and queue status", + ScopeQueueWrite: "Enqueue and cancel queued commands", + ScopeWebhookRead: "View webhooks and delivery history", + ScopeWebhookWrite: "Create, update, and delete webhooks", + ScopeWorkersRead: "View workers and worker status", + ScopeWorkersWrite: "Manage workers (drain, register)", + ScopeBuildRead: "View build status and history", + ScopeBuildWrite: "Start and manage builds", + ScopeAdmin: "Full administrative access (includes all scopes)", +} + +// IsValid checks if a scope is valid. +func (s Scope) IsValid() bool { + for _, scope := range AllScopes { + if scope == s { + return true + } + } + return false +} + +// String returns the scope as a string. +func (s Scope) String() string { + return string(s) +} + +// ScopesFromStrings converts string slice to Scope slice. +func ScopesFromStrings(ss []string) []Scope { + scopes := make([]Scope, len(ss)) + for i, s := range ss { + scopes[i] = Scope(s) + } + return scopes +} + +// ScopesToStrings converts Scope slice to string slice. +func ScopesToStrings(scopes []Scope) []string { + ss := make([]string, len(scopes)) + for i, s := range scopes { + ss[i] = string(s) + } + return ss +} + +// ValidateScopes checks if all scopes are valid. +func ValidateScopes(scopes []Scope) bool { + for _, s := range scopes { + if !s.IsValid() { + return false + } + } + return true +} + +// HasScope checks if a scope list contains a required scope. +// Admin scope grants access to everything. 
+func HasScope(scopes []Scope, required Scope) bool { + for _, s := range scopes { + if s == ScopeAdmin || s == required { + return true + } + } + return false +} + +// HasAnyScope checks if a scope list contains any of the required scopes. +func HasAnyScope(scopes []Scope, required ...Scope) bool { + for _, r := range required { + if HasScope(scopes, r) { + return true + } + } + return false +} + // APIKey represents an API key for authentication. type APIKey struct { ID APIKeyID diff --git a/internal/domain/build.go b/internal/domain/build.go new file mode 100644 index 0000000..0801e4f --- /dev/null +++ b/internal/domain/build.go @@ -0,0 +1,202 @@ +// Package domain contains pure domain models with no external dependencies. +package domain + +import ( + "fmt" + "maps" + "net/url" + "strconv" + "strings" + "time" +) + +// BuildSpec defines what a build task should accomplish. +// It captures the user's intent and parameters for code generation. +type BuildSpec struct { + // Template is the project template to use (e.g., "nextjs-landing"). + Template string `json:"template,omitempty"` + + // Prompt is the user's instruction for the build. + Prompt string `json:"prompt"` + + // Variables contains template-specific substitution values. + Variables map[string]string `json:"variables,omitempty"` + + // AutoCommit controls whether changes are automatically committed. + AutoCommit bool `json:"auto_commit"` + + // AutoPush controls whether commits are automatically pushed. + AutoPush bool `json:"auto_push"` + + // CallbackURL is the webhook URL for completion notification. + CallbackURL string `json:"callback_url,omitempty"` +} + +// Validate checks that the BuildSpec has required fields. +func (s *BuildSpec) Validate() error { + if s.Prompt == "" { + return ErrPromptRequired + } + if s.CallbackURL != "" { + if err := ValidateCallbackURL(s.CallbackURL); err != nil { + return err + } + } + return nil +} + +// ValidateCallbackURL checks that a callback URL is safe to use. 
+// It rejects non-HTTPS URLs and private/internal network addresses. +func ValidateCallbackURL(rawURL string) error { + u, err := url.Parse(rawURL) + if err != nil { + return fmt.Errorf("invalid callback URL: %w", err) + } + if u.Scheme != "https" { + return fmt.Errorf("callback URL must use HTTPS scheme, got %q", u.Scheme) + } + if u.Host == "" { + return fmt.Errorf("callback URL must have a host") + } + + host := u.Hostname() + lower := strings.ToLower(host) + + // Block localhost and loopback + if lower == "localhost" || lower == "127.0.0.1" || lower == "::1" || lower == "[::1]" { + return fmt.Errorf("callback URL must not point to localhost") + } + // Block common metadata endpoints + if lower == "metadata.google.internal" || lower == "169.254.169.254" { + return fmt.Errorf("callback URL must not point to cloud metadata service") + } + // Block private network ranges + if strings.HasPrefix(lower, "10.") || strings.HasPrefix(lower, "192.168.") { + return fmt.Errorf("callback URL must not point to private network addresses") + } + if strings.HasPrefix(lower, "172.") { + // 172.16.0.0 - 172.31.255.255 + parts := strings.SplitN(lower, ".", 3) + if len(parts) >= 2 { + if octet, err := strconv.Atoi(parts[1]); err == nil && octet >= 16 && octet <= 31 { + return fmt.Errorf("callback URL must not point to private network addresses") + } + } + } + return nil +} + +// BuildResult captures the outcome of a build execution. +type BuildResult struct { + // Success indicates whether the build completed successfully. + Success bool `json:"success"` + + // Output is the agent's text output during the build. + Output string `json:"output,omitempty"` + + // Error contains the error message if the build failed. + Error string `json:"error,omitempty"` + + // CommitSHA is the git commit hash if auto-commit was enabled. + CommitSHA string `json:"commit_sha,omitempty"` + + // FilesChanged lists files modified during the build. 
+ FilesChanged []string `json:"files_changed,omitempty"` + + // DurationMs is the total execution time in milliseconds. + DurationMs int64 `json:"duration_ms"` + + // Artifacts contains named outputs from the build (e.g., deploy URLs). + Artifacts map[string]string `json:"artifacts,omitempty"` +} + +// ToWorkResult converts a BuildResult to a WorkResult. +// Build-specific fields (commit_sha, files_changed, duration_ms) are +// promoted into the artifacts map, overwriting any existing keys with +// the same names. +func (r *BuildResult) ToWorkResult() *WorkResult { + if r == nil { + return &WorkResult{} + } + + artifacts := make(map[string]string) + maps.Copy(artifacts, r.Artifacts) + + // Promote build-specific fields into artifacts + if r.CommitSHA != "" { + artifacts["commit_sha"] = r.CommitSHA + } + if r.DurationMs > 0 { + artifacts["duration_ms"] = strconv.FormatInt(r.DurationMs, 10) + } + if len(r.FilesChanged) > 0 { + artifacts["files_changed_count"] = strconv.Itoa(len(r.FilesChanged)) + } + + output := r.Output + if !r.Success && r.Error != "" { + output = r.Error + } + + return &WorkResult{ + Output: output, + Artifacts: artifacts, + } +} + +// BuildStatus represents the lifecycle state of a build. +type BuildStatus string + +const ( + BuildStatusPending BuildStatus = "pending" + BuildStatusRunning BuildStatus = "running" + BuildStatusCompleted BuildStatus = "completed" + BuildStatusFailed BuildStatus = "failed" + BuildStatusCancelled BuildStatus = "cancelled" +) + +// IsValid returns true if the status is a known valid status. +func (s BuildStatus) IsValid() bool { + switch s { + case BuildStatusPending, BuildStatusRunning, BuildStatusCompleted, + BuildStatusFailed, BuildStatusCancelled: + return true + } + return false +} + +// IsTerminal returns true if the build is in a final state. 
+func (s BuildStatus) IsTerminal() bool { + switch s { + case BuildStatusCompleted, BuildStatusFailed, BuildStatusCancelled: + return true + } + return false +} + +// BuildAuditEntry represents a single build's audit record. +type BuildAuditEntry struct { + // TaskID is the work queue task ID. + TaskID string `json:"task_id"` + + // ProjectID is the project this build belongs to. + ProjectID string `json:"project_id"` + + // WorkerID is the worker that executed the build. + WorkerID string `json:"worker_id,omitempty"` + + // Spec is the original build specification. + Spec BuildSpec `json:"spec"` + + // Result is the build outcome (nil if not yet complete). + Result *BuildResult `json:"result,omitempty"` + + // Status is the current build status. + Status BuildStatus `json:"status"` + + // StartedAt is when the build was created/enqueued. + StartedAt time.Time `json:"started_at"` + + // CompletedAt is when the build finished (nil if still running). + CompletedAt *time.Time `json:"completed_at,omitempty"` +} diff --git a/internal/domain/build_test.go b/internal/domain/build_test.go new file mode 100644 index 0000000..3a106ff --- /dev/null +++ b/internal/domain/build_test.go @@ -0,0 +1,207 @@ +package domain + +import ( + "errors" + "testing" +) + +func TestBuildSpec_Validate(t *testing.T) { + tests := []struct { + name string + spec BuildSpec + wantErr error + }{ + { + name: "valid spec with prompt", + spec: BuildSpec{Prompt: "Build a landing page"}, + wantErr: nil, + }, + { + name: "valid spec with all fields", + spec: BuildSpec{Prompt: "Build it", Template: "nextjs", AutoCommit: true, AutoPush: true}, + wantErr: nil, + }, + { + name: "empty prompt", + spec: BuildSpec{}, + wantErr: ErrPromptRequired, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.spec.Validate() + if tt.wantErr != nil { + if !errors.Is(err, tt.wantErr) { + t.Errorf("Validate() error = %v, want %v", err, tt.wantErr) + } + } else if err != nil { + 
t.Errorf("Validate() unexpected error = %v", err) + } + }) + } +} + +func TestBuildStatus_IsValid(t *testing.T) { + tests := []struct { + status BuildStatus + want bool + }{ + {BuildStatusPending, true}, + {BuildStatusRunning, true}, + {BuildStatusCompleted, true}, + {BuildStatusFailed, true}, + {BuildStatusCancelled, true}, + {"unknown", false}, + {"", false}, + } + + for _, tt := range tests { + t.Run(string(tt.status), func(t *testing.T) { + if got := tt.status.IsValid(); got != tt.want { + t.Errorf("IsValid() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestBuildResult_ToWorkResult(t *testing.T) { + t.Run("success with all fields", func(t *testing.T) { + result := &BuildResult{ + Success: true, + Output: "Build completed", + CommitSHA: "abc123", + FilesChanged: []string{"main.go", "go.mod"}, + DurationMs: 1500, + Artifacts: map[string]string{"deploy_url": "https://example.com"}, + } + + wr := result.ToWorkResult() + + if wr.Output != "Build completed" { + t.Errorf("Output = %q, want %q", wr.Output, "Build completed") + } + if wr.Artifacts["commit_sha"] != "abc123" { + t.Errorf("commit_sha = %q, want %q", wr.Artifacts["commit_sha"], "abc123") + } + if wr.Artifacts["duration_ms"] != "1500" { + t.Errorf("duration_ms = %q, want %q", wr.Artifacts["duration_ms"], "1500") + } + if wr.Artifacts["files_changed_count"] != "2" { + t.Errorf("files_changed_count = %q, want %q", wr.Artifacts["files_changed_count"], "2") + } + if wr.Artifacts["deploy_url"] != "https://example.com" { + t.Errorf("deploy_url = %q, want %q", wr.Artifacts["deploy_url"], "https://example.com") + } + }) + + t.Run("failure uses error as output", func(t *testing.T) { + result := &BuildResult{ + Success: false, + Output: "partial output", + Error: "build failed: missing deps", + } + + wr := result.ToWorkResult() + + if wr.Output != "build failed: missing deps" { + t.Errorf("Output = %q, want error message", wr.Output) + } + }) + + t.Run("nil artifacts map is safe", func(t *testing.T) { + 
result := &BuildResult{ + Success: true, + Output: "done", + } + + wr := result.ToWorkResult() + + if wr.Output != "done" { + t.Errorf("Output = %q, want %q", wr.Output, "done") + } + if len(wr.Artifacts) != 0 { + t.Errorf("Artifacts = %v, want empty", wr.Artifacts) + } + }) + + t.Run("zero duration not included", func(t *testing.T) { + result := &BuildResult{ + Success: true, + DurationMs: 0, + } + + wr := result.ToWorkResult() + + if _, ok := wr.Artifacts["duration_ms"]; ok { + t.Error("duration_ms should not be in artifacts when zero") + } + }) + + t.Run("nil receiver returns empty result", func(t *testing.T) { + var result *BuildResult + wr := result.ToWorkResult() + + if wr.Output != "" { + t.Errorf("Output = %q, want empty", wr.Output) + } + if wr.Artifacts != nil { + t.Errorf("Artifacts = %v, want nil", wr.Artifacts) + } + }) + + t.Run("promoted fields overwrite existing artifacts", func(t *testing.T) { + result := &BuildResult{ + Success: true, + CommitSHA: "new-sha", + Artifacts: map[string]string{ + "commit_sha": "old-sha", + "custom_key": "kept", + }, + } + + wr := result.ToWorkResult() + + if wr.Artifacts["commit_sha"] != "new-sha" { + t.Errorf("commit_sha = %q, want %q (promoted field should overwrite)", wr.Artifacts["commit_sha"], "new-sha") + } + if wr.Artifacts["custom_key"] != "kept" { + t.Errorf("custom_key = %q, want %q", wr.Artifacts["custom_key"], "kept") + } + }) + + t.Run("failed with empty error keeps output", func(t *testing.T) { + result := &BuildResult{ + Success: false, + Output: "partial output before crash", + Error: "", + } + + wr := result.ToWorkResult() + + if wr.Output != "partial output before crash" { + t.Errorf("Output = %q, want original output when error is empty", wr.Output) + } + }) +} + +func TestBuildStatus_IsTerminal(t *testing.T) { + tests := []struct { + status BuildStatus + want bool + }{ + {BuildStatusPending, false}, + {BuildStatusRunning, false}, + {BuildStatusCompleted, true}, + {BuildStatusFailed, true}, + 
{BuildStatusCancelled, true}, + } + + for _, tt := range tests { + t.Run(string(tt.status), func(t *testing.T) { + if got := tt.status.IsTerminal(); got != tt.want { + t.Errorf("IsTerminal() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/internal/domain/domain_test.go b/internal/domain/domain_test.go index 28dfc47..6ae34e2 100644 --- a/internal/domain/domain_test.go +++ b/internal/domain/domain_test.go @@ -1,7 +1,6 @@ package domain_test import ( - "errors" "testing" "time" @@ -202,7 +201,7 @@ func TestAPIKey_HasAnyScope(t *testing.T) { }, { name: "no match", - scopes: []domain.Scope{domain.ScopeKeysManage}, + scopes: []domain.Scope{domain.ScopeKeysWrite}, check: []domain.Scope{domain.ScopeProjectsRead, domain.ScopeProjectsExecute}, want: false, }, @@ -403,257 +402,6 @@ func TestAPIKey_IsIPAllowed(t *testing.T) { } } -// ============================================================================= -// CommandResult Tests -// ============================================================================= - -func TestCommandResult_Success(t *testing.T) { - tests := []struct { - name string - exitCode int - err error - want bool - }{ - { - name: "success with zero exit code and no error", - exitCode: 0, - err: nil, - want: true, - }, - { - name: "failure with non-zero exit code", - exitCode: 1, - err: nil, - want: false, - }, - { - name: "failure with error", - exitCode: 0, - err: errors.New("execution failed"), - want: false, - }, - { - name: "failure with both error and non-zero exit", - exitCode: 127, - err: errors.New("command not found"), - want: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := &domain.CommandResult{ - ExitCode: tt.exitCode, - Error: tt.err, - } - if got := result.Success(); got != tt.want { - t.Errorf("Success() = %v, want %v", got, tt.want) - } - }) - } -} - -// ============================================================================= -// ProjectStatus Tests -// 
============================================================================= - -func TestProjectStatus_IsAvailable(t *testing.T) { - tests := []struct { - status domain.ProjectStatus - want bool - }{ - {domain.ProjectStatusRunning, true}, - {domain.ProjectStatusPending, false}, - {domain.ProjectStatusFailed, false}, - {domain.ProjectStatusNotFound, false}, - {domain.ProjectStatusUnknown, false}, - {domain.ProjectStatusError, false}, - } - - for _, tt := range tests { - t.Run(string(tt.status), func(t *testing.T) { - if got := tt.status.IsAvailable(); got != tt.want { - t.Errorf("ProjectStatus(%q).IsAvailable() = %v, want %v", tt.status, got, tt.want) - } - }) - } -} - -func TestProjectStatus_IsTerminal(t *testing.T) { - tests := []struct { - status domain.ProjectStatus - want bool - }{ - {domain.ProjectStatusRunning, false}, - {domain.ProjectStatusPending, false}, - {domain.ProjectStatusFailed, true}, - {domain.ProjectStatusNotFound, true}, - {domain.ProjectStatusUnknown, false}, - {domain.ProjectStatusError, false}, - } - - for _, tt := range tests { - t.Run(string(tt.status), func(t *testing.T) { - if got := tt.status.IsTerminal(); got != tt.want { - t.Errorf("ProjectStatus(%q).IsTerminal() = %v, want %v", tt.status, got, tt.want) - } - }) - } -} - -// ============================================================================= -// Error Variables Tests -// ============================================================================= - -func TestErrorVariables_AreDistinct(t *testing.T) { - // Verify all domain errors are distinct and can be matched with errors.Is - allErrors := []error{ - domain.ErrProjectNotFound, - domain.ErrProjectNotRunning, - domain.ErrCommandNotFound, - domain.ErrCommandTimeout, - domain.ErrCommandCancelled, - domain.ErrLimitExceeded, - domain.ErrInvalidCommand, - domain.ErrCommandSanitization, - domain.ErrKeyNotFound, - domain.ErrKeyRevoked, - domain.ErrKeyExpired, - domain.ErrKeyInvalid, - domain.ErrUnauthorized, - domain.ErrForbidden, 
- domain.ErrInsufficientScope, - domain.ErrRateLimited, - domain.ErrDatabaseConnection, - domain.ErrKubernetesError, - } - - // Each error should only match itself - for i, err1 := range allErrors { - for j, err2 := range allErrors { - if i == j { - if !errors.Is(err1, err2) { - t.Errorf("error %v should match itself", err1) - } - } else { - if errors.Is(err1, err2) { - t.Errorf("error %v should not match %v", err1, err2) - } - } - } - } -} - -func TestErrorVariables_CanBeWrapped(t *testing.T) { - // Domain errors should be usable as base errors for wrapping - wrapped := errors.Join(domain.ErrProjectNotFound, errors.New("additional context")) - - if !errors.Is(wrapped, domain.ErrProjectNotFound) { - t.Error("wrapped error should match base domain error") - } -} - -// ============================================================================= -// Type Constants Tests -// ============================================================================= - -func TestScopeConstants(t *testing.T) { - // Verify scope constants have expected values (for documentation) - expectedScopes := map[domain.Scope]string{ - domain.ScopeAdmin: "admin", - domain.ScopeProjectsRead: "projects:read", - domain.ScopeProjectsExecute: "projects:execute", - domain.ScopeKeysManage: "keys:manage", - } - - for scope, expected := range expectedScopes { - if string(scope) != expected { - t.Errorf("Scope %v = %q, want %q", scope, string(scope), expected) - } - } -} - -func TestCommandTypeConstants(t *testing.T) { - expectedTypes := map[domain.CommandType]string{ - domain.CommandTypeClaude: "claude", - domain.CommandTypeShell: "shell", - domain.CommandTypeGit: "git", - } - - for cmdType, expected := range expectedTypes { - if string(cmdType) != expected { - t.Errorf("CommandType %v = %q, want %q", cmdType, string(cmdType), expected) - } - } -} - -func TestProjectStatusConstants(t *testing.T) { - expectedStatuses := map[domain.ProjectStatus]string{ - domain.ProjectStatusRunning: "running", - 
domain.ProjectStatusPending: "pending", - domain.ProjectStatusFailed: "failed", - domain.ProjectStatusNotFound: "not_found", - domain.ProjectStatusUnknown: "unknown", - domain.ProjectStatusError: "error", - } - - for status, expected := range expectedStatuses { - if string(status) != expected { - t.Errorf("ProjectStatus %v = %q, want %q", status, string(status), expected) - } - } -} - -// ============================================================================= -// Type Instantiation Tests -// ============================================================================= - -func TestProject_CanBeInstantiated(t *testing.T) { - proj := domain.Project{ - ID: "test-project", - Name: "Test Project", - Description: "A test project", - PodName: "test-pod", - Status: domain.ProjectStatusRunning, - Workspace: "/workspace", - } - - if proj.ID != "test-project" { - t.Errorf("Project.ID = %q, want %q", proj.ID, "test-project") - } -} - -func TestCommand_CanBeInstantiated(t *testing.T) { - now := time.Now() - cmd := domain.Command{ - ID: "cmd-1", - ProjectID: "proj-1", - Type: domain.CommandTypeClaude, - Args: []string{"--version"}, - StartedAt: now, - } - - if cmd.ID != "cmd-1" { - t.Errorf("Command.ID = %q, want %q", cmd.ID, "cmd-1") - } - if len(cmd.Args) != 1 { - t.Errorf("Command.Args length = %d, want 1", len(cmd.Args)) - } -} - -func TestOutputLine_CanBeInstantiated(t *testing.T) { - now := time.Now() - line := domain.OutputLine{ - Stream: "stdout", - Line: "Hello, world!", - Timestamp: now, - } - - if line.Stream != "stdout" { - t.Errorf("OutputLine.Stream = %q, want %q", line.Stream, "stdout") - } -} - // ============================================================================= // Helpers // ============================================================================= diff --git a/internal/domain/domain_types_test.go b/internal/domain/domain_types_test.go new file mode 100644 index 0000000..5ab213b --- /dev/null +++ b/internal/domain/domain_types_test.go @@ -0,0 
+1,266 @@ +package domain_test + +import ( + "errors" + "testing" + "time" + + "github.com/orchard9/rdev/internal/domain" +) + +// ============================================================================= +// CommandResult Tests +// ============================================================================= + +func TestCommandResult_Success(t *testing.T) { + tests := []struct { + name string + exitCode int + err error + want bool + }{ + { + name: "success with zero exit code and no error", + exitCode: 0, + err: nil, + want: true, + }, + { + name: "failure with non-zero exit code", + exitCode: 1, + err: nil, + want: false, + }, + { + name: "failure with error", + exitCode: 0, + err: errors.New("execution failed"), + want: false, + }, + { + name: "failure with both error and non-zero exit", + exitCode: 127, + err: errors.New("command not found"), + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := &domain.CommandResult{ + ExitCode: tt.exitCode, + Error: tt.err, + } + if got := result.Success(); got != tt.want { + t.Errorf("Success() = %v, want %v", got, tt.want) + } + }) + } +} + +// ============================================================================= +// ProjectStatus Tests +// ============================================================================= + +func TestProjectStatus_IsAvailable(t *testing.T) { + tests := []struct { + status domain.ProjectStatus + want bool + }{ + {domain.ProjectStatusRunning, true}, + {domain.ProjectStatusPending, false}, + {domain.ProjectStatusFailed, false}, + {domain.ProjectStatusNotFound, false}, + {domain.ProjectStatusUnknown, false}, + {domain.ProjectStatusError, false}, + } + + for _, tt := range tests { + t.Run(string(tt.status), func(t *testing.T) { + if got := tt.status.IsAvailable(); got != tt.want { + t.Errorf("ProjectStatus(%q).IsAvailable() = %v, want %v", tt.status, got, tt.want) + } + }) + } +} + +func TestProjectStatus_IsTerminal(t *testing.T) { + 
tests := []struct { + status domain.ProjectStatus + want bool + }{ + {domain.ProjectStatusRunning, false}, + {domain.ProjectStatusPending, false}, + {domain.ProjectStatusFailed, true}, + {domain.ProjectStatusNotFound, true}, + {domain.ProjectStatusUnknown, false}, + {domain.ProjectStatusError, false}, + } + + for _, tt := range tests { + t.Run(string(tt.status), func(t *testing.T) { + if got := tt.status.IsTerminal(); got != tt.want { + t.Errorf("ProjectStatus(%q).IsTerminal() = %v, want %v", tt.status, got, tt.want) + } + }) + } +} + +// ============================================================================= +// Error Variables Tests +// ============================================================================= + +func TestErrorVariables_AreDistinct(t *testing.T) { + // Verify all domain errors are distinct and can be matched with errors.Is + allErrors := []error{ + domain.ErrProjectNotFound, + domain.ErrProjectNotRunning, + domain.ErrCommandNotFound, + domain.ErrCommandTimeout, + domain.ErrCommandCancelled, + domain.ErrLimitExceeded, + domain.ErrInvalidCommand, + domain.ErrCommandSanitization, + domain.ErrKeyNotFound, + domain.ErrKeyRevoked, + domain.ErrKeyExpired, + domain.ErrKeyInvalid, + domain.ErrUnauthorized, + domain.ErrForbidden, + domain.ErrInsufficientScope, + domain.ErrRateLimited, + domain.ErrDatabaseConnection, + domain.ErrKubernetesError, + } + + // Each error should only match itself + for i, err1 := range allErrors { + for j, err2 := range allErrors { + if i == j { + if !errors.Is(err1, err2) { + t.Errorf("error %v should match itself", err1) + } + } else { + if errors.Is(err1, err2) { + t.Errorf("error %v should not match %v", err1, err2) + } + } + } + } +} + +func TestErrorVariables_CanBeWrapped(t *testing.T) { + // Domain errors should be usable as base errors for wrapping + wrapped := errors.Join(domain.ErrProjectNotFound, errors.New("additional context")) + + if !errors.Is(wrapped, domain.ErrProjectNotFound) { + t.Error("wrapped error 
should match base domain error") + } +} + +// ============================================================================= +// Type Constants Tests +// ============================================================================= + +func TestScopeConstants(t *testing.T) { + // Verify scope constants have expected values (for documentation) + expectedScopes := map[domain.Scope]string{ + domain.ScopeAdmin: "admin", + domain.ScopeProjectsRead: "projects:read", + domain.ScopeProjectsExecute: "projects:execute", + domain.ScopeKeysRead: "keys:read", + domain.ScopeKeysWrite: "keys:write", + domain.ScopeAuditRead: "audit:read", + domain.ScopeQueueRead: "queue:read", + domain.ScopeQueueWrite: "queue:write", + domain.ScopeWebhookRead: "webhook:read", + domain.ScopeWebhookWrite: "webhook:write", + } + + for scope, expected := range expectedScopes { + if string(scope) != expected { + t.Errorf("Scope %v = %q, want %q", scope, string(scope), expected) + } + } +} + +func TestCommandTypeConstants(t *testing.T) { + expectedTypes := map[domain.CommandType]string{ + domain.CommandTypeClaude: "claude", + domain.CommandTypeShell: "shell", + domain.CommandTypeGit: "git", + } + + for cmdType, expected := range expectedTypes { + if string(cmdType) != expected { + t.Errorf("CommandType %v = %q, want %q", cmdType, string(cmdType), expected) + } + } +} + +func TestProjectStatusConstants(t *testing.T) { + expectedStatuses := map[domain.ProjectStatus]string{ + domain.ProjectStatusRunning: "running", + domain.ProjectStatusPending: "pending", + domain.ProjectStatusFailed: "failed", + domain.ProjectStatusNotFound: "not_found", + domain.ProjectStatusUnknown: "unknown", + domain.ProjectStatusError: "error", + } + + for status, expected := range expectedStatuses { + if string(status) != expected { + t.Errorf("ProjectStatus %v = %q, want %q", status, string(status), expected) + } + } +} + +// ============================================================================= +// Type Instantiation Tests 
+// ============================================================================= + +func TestProject_CanBeInstantiated(t *testing.T) { + proj := domain.Project{ + ID: "test-project", + Name: "Test Project", + Description: "A test project", + PodName: "test-pod", + Status: domain.ProjectStatusRunning, + Workspace: "/workspace", + } + + if proj.ID != "test-project" { + t.Errorf("Project.ID = %q, want %q", proj.ID, "test-project") + } +} + +func TestCommand_CanBeInstantiated(t *testing.T) { + now := time.Now() + cmd := domain.Command{ + ID: "cmd-1", + ProjectID: "proj-1", + Type: domain.CommandTypeClaude, + Args: []string{"--version"}, + StartedAt: now, + } + + if cmd.ID != "cmd-1" { + t.Errorf("Command.ID = %q, want %q", cmd.ID, "cmd-1") + } + if len(cmd.Args) != 1 { + t.Errorf("Command.Args length = %d, want 1", len(cmd.Args)) + } +} + +func TestOutputLine_CanBeInstantiated(t *testing.T) { + now := time.Now() + line := domain.OutputLine{ + Stream: "stdout", + Line: "Hello, world!", + Timestamp: now, + } + + if line.Stream != "stdout" { + t.Errorf("OutputLine.Stream = %q, want %q", line.Stream, "stdout") + } +} diff --git a/internal/domain/errors.go b/internal/domain/errors.go index f87c16a..3ad2d94 100644 --- a/internal/domain/errors.go +++ b/internal/domain/errors.go @@ -5,6 +5,9 @@ import "errors" // Domain errors - these are business-level errors that should be translated // to appropriate HTTP status codes or gRPC error codes by the presentation layer. 
var ( + // Generic errors + ErrNotFound = errors.New("not found") + // Project errors ErrProjectNotFound = errors.New("project not found") ErrProjectNotRunning = errors.New("project is not running") @@ -26,9 +29,20 @@ var ( ErrPromptRequired = errors.New("prompt is required") ErrInvalidTimeout = errors.New("timeout cannot be negative") + // Credential errors + ErrCredentialNotFound = errors.New("credential not found") + // Work queue errors ErrWorkTaskNotFound = errors.New("work task not found") + // Worker pool errors + ErrWorkerNotFound = errors.New("worker not found") + ErrWorkerIDRequired = errors.New("worker ID is required") + ErrWorkerHostnameRequired = errors.New("worker hostname is required") + + // Build errors + ErrBuildNotFound = errors.New("build not found") + // API Key errors ErrKeyNotFound = errors.New("api key not found") ErrKeyRevoked = errors.New("api key has been revoked") @@ -39,10 +53,15 @@ var ( ErrUnauthorized = errors.New("unauthorized") ErrForbidden = errors.New("forbidden") ErrInsufficientScope = errors.New("insufficient scope") + ErrIPNotAllowed = errors.New("ip address not allowed") // Rate limiting errors ErrRateLimited = errors.New("rate limit exceeded") + // Webhook errors + ErrWebhookNotFound = errors.New("webhook not found") + ErrInvalidWebhook = errors.New("invalid webhook configuration") + // Audit errors ErrAuditNotFound = errors.New("audit log entry not found") diff --git a/internal/domain/project.go b/internal/domain/project.go index a25a332..126c9b9 100644 --- a/internal/domain/project.go +++ b/internal/domain/project.go @@ -2,9 +2,66 @@ // These types represent the core business concepts of the application. package domain +import "regexp" + // ProjectID is a strongly-typed identifier for projects. type ProjectID string +// Project name/ID constraints. +const ( + // MaxProjectNameLen is the maximum length for project names (K8s name limit). 
+ MaxProjectNameLen = 63 +) + +// projectIDRegex validates project IDs used for referencing existing projects. +// Allows uppercase, lowercase, digits, dashes, and underscores. Must start with a letter. +var projectIDRegex = regexp.MustCompile(`^[a-zA-Z][a-zA-Z0-9_-]*$`) + +// projectNameRegex validates project names for DNS/K8s resource creation. +// Lowercase only, digits, dashes. Must start with a lowercase letter. +var projectNameRegex = regexp.MustCompile(`^[a-z][a-z0-9-]*$`) + +// reservedProjectNames are names that cannot be used for new projects. +var reservedProjectNames = map[string]bool{ + "www": true, "api": true, "git": true, "ci": true, + "registry": true, "admin": true, "root": true, + "rdev": true, "pantheon": true, +} + +// ValidateProjectID validates a project ID for referencing existing projects. +// Allows letters, digits, dashes, underscores. Must start with a letter. Max 63 chars. +func ValidateProjectID(id string) error { + if id == "" { + return ErrInvalidProjectName + } + if len(id) > MaxProjectNameLen { + return ErrInvalidProjectName + } + if !projectIDRegex.MatchString(id) { + return ErrInvalidProjectName + } + return nil +} + +// ValidateProjectName validates a project name for DNS/K8s resource creation. +// Lowercase letters, digits, dashes only. Must start with a lowercase letter. +// Max 63 chars. Rejects reserved names. +func ValidateProjectName(name string) error { + if name == "" { + return ErrInvalidProjectName + } + if len(name) > MaxProjectNameLen { + return ErrInvalidProjectName + } + if !projectNameRegex.MatchString(name) { + return ErrInvalidProjectName + } + if reservedProjectNames[name] { + return ErrInvalidProjectName + } + return nil +} + // Project represents a claudebox project that can execute commands. 
type Project struct { ID ProjectID diff --git a/internal/domain/webhook.go b/internal/domain/webhook.go index a6d1bf8..a449f36 100644 --- a/internal/domain/webhook.go +++ b/internal/domain/webhook.go @@ -1,7 +1,6 @@ package domain import ( - "errors" "time" ) @@ -153,8 +152,4 @@ func DefaultWebhookDeliveryFilters() *WebhookDeliveryFilters { } } -// Webhook-related errors. -var ( - ErrWebhookNotFound = errors.New("webhook not found") - ErrInvalidWebhook = errors.New("invalid webhook configuration") -) +// Webhook-related errors are defined in errors.go for centralized error definitions. diff --git a/internal/domain/work.go b/internal/domain/work.go new file mode 100644 index 0000000..f7157c6 --- /dev/null +++ b/internal/domain/work.go @@ -0,0 +1,170 @@ +package domain + +import "time" + +// WorkTaskStatus represents the status of a work task. +type WorkTaskStatus string + +const ( + WorkTaskStatusPending WorkTaskStatus = "pending" + WorkTaskStatusRunning WorkTaskStatus = "running" + WorkTaskStatusCompleted WorkTaskStatus = "completed" + WorkTaskStatusFailed WorkTaskStatus = "failed" + WorkTaskStatusCancelled WorkTaskStatus = "cancelled" +) + +// IsValid returns true if the status is a known valid status. +func (s WorkTaskStatus) IsValid() bool { + switch s { + case WorkTaskStatusPending, WorkTaskStatusRunning, WorkTaskStatusCompleted, + WorkTaskStatusFailed, WorkTaskStatusCancelled: + return true + } + return false +} + +// WorkTaskType represents the type of work task. +type WorkTaskType string + +const ( + WorkTaskTypeBuild WorkTaskType = "build" + WorkTaskTypeTest WorkTaskType = "test" + WorkTaskTypeDeploy WorkTaskType = "deploy" + WorkTaskTypeCustom WorkTaskType = "custom" +) + +// IsValid returns true if the task type is a known valid type. +func (t WorkTaskType) IsValid() bool { + switch t { + case WorkTaskTypeBuild, WorkTaskTypeTest, WorkTaskTypeDeploy, WorkTaskTypeCustom: + return true + } + return false +} + +// WorkTask represents a task in the work queue. 
+type WorkTask struct { + // ID is the unique task identifier. + ID string + + // ProjectID is the project this task belongs to. + ProjectID string + + // Type is the task type (build, test, deploy, custom). + Type WorkTaskType + + // Spec contains task-specific parameters. + // For build tasks: template, prompt, variables, auto_deploy, git_url + // For test tasks: test_command, git_url + // For deploy tasks: image, replicas, env + Spec map[string]any + + // Status is the current task status. + Status WorkTaskStatus + + // Priority determines execution order (higher = more urgent). + Priority int + + // WorkerID is the ID of the worker that claimed this task. + WorkerID string + + // CallbackURL is the webhook URL for completion notification. + CallbackURL string + + // CreatedAt is when the task was created. + CreatedAt time.Time + + // StartedAt is when a worker started executing the task. + StartedAt *time.Time + + // CompletedAt is when the task finished (success or failure). + CompletedAt *time.Time + + // Result contains the task output (if completed). + Result *WorkResult + + // Error contains the error message (if failed). + Error string + + // RetryCount is the number of retry attempts. + RetryCount int + + // MaxRetries is the maximum allowed retry attempts. + MaxRetries int +} + +// WorkResult contains the result of a completed task. +type WorkResult struct { + // Output is the main output from task execution. + Output string `json:"output,omitempty"` + + // Artifacts contains named artifacts from the task. + // For build tasks: commit_sha, deploy_url, etc. + Artifacts map[string]string `json:"artifacts,omitempty"` +} + +// WorkQueueStats contains queue statistics. +type WorkQueueStats struct { + // Pending is the count of pending tasks. + Pending int64 `json:"pending"` + + // Running is the count of running tasks. + Running int64 `json:"running"` + + // Completed is the count of completed tasks (last 24h). 
+ Completed int64 `json:"completed"` + + // Failed is the count of failed tasks (last 24h). + Failed int64 `json:"failed"` + + // Cancelled is the count of cancelled tasks (last 24h). + Cancelled int64 `json:"cancelled"` + + // OldestPending is the age of the oldest pending task. + OldestPending *time.Duration `json:"oldest_pending,omitempty"` +} + +// WorkListOptions contains pagination options for listing tasks. +type WorkListOptions struct { + // Limit is the maximum number of tasks to return (default: 50, max: 100). + Limit int + + // Offset is the number of tasks to skip (for pagination). + Offset int +} + +// DefaultWorkListOptions returns options with default values. +func DefaultWorkListOptions() WorkListOptions { + return WorkListOptions{ + Limit: 50, + Offset: 0, + } +} + +// Normalize applies defaults and limits to the options. +func (o *WorkListOptions) Normalize() { + if o.Limit <= 0 { + o.Limit = 50 + } + if o.Limit > 100 { + o.Limit = 100 + } + if o.Offset < 0 { + o.Offset = 0 + } +} + +// WorkListResult contains paginated task results. +type WorkListResult struct { + // Tasks is the list of tasks. + Tasks []*WorkTask + + // Total is the total count of matching tasks (for pagination metadata). + Total int64 + + // Limit is the limit that was applied. + Limit int + + // Offset is the offset that was applied. + Offset int +} diff --git a/internal/domain/worker.go b/internal/domain/worker.go new file mode 100644 index 0000000..4c1e78c --- /dev/null +++ b/internal/domain/worker.go @@ -0,0 +1,76 @@ +// Package domain contains pure domain models with no external dependencies. +package domain + +import "time" + +// Worker represents a registered worker in the pool. +type Worker struct { + // ID is the unique worker identifier (typically pod name). + ID string `json:"id"` + + // Hostname is the worker's hostname. + Hostname string `json:"hostname"` + + // Status is the current worker state. 
+ Status WorkerStatus `json:"status"` + + // CurrentTask is the ID of the task currently being executed. + CurrentTask string `json:"current_task,omitempty"` + + // Capabilities lists what this worker can do (e.g., "build", "deploy"). + Capabilities []string `json:"capabilities,omitempty"` + + // RegisteredAt is when the worker first joined the pool. + RegisteredAt time.Time `json:"registered_at"` + + // LastHeartbeat is the most recent heartbeat timestamp. + LastHeartbeat time.Time `json:"last_heartbeat"` + + // Version is the worker binary version. + Version string `json:"version,omitempty"` +} + +// IsAvailable returns true if the worker can accept new tasks. +func (w *Worker) IsAvailable() bool { + return w.Status == WorkerStatusIdle +} + +// IsHealthy returns true if the worker has a recent heartbeat. +func (w *Worker) IsHealthy(threshold time.Duration) bool { + return time.Since(w.LastHeartbeat) < threshold +} + +// WorkerStatus represents the current state of a worker. +type WorkerStatus string + +const ( + WorkerStatusIdle WorkerStatus = "idle" + WorkerStatusBusy WorkerStatus = "busy" + WorkerStatusDraining WorkerStatus = "draining" + WorkerStatusOffline WorkerStatus = "offline" +) + +// IsValid returns true if the status is a known valid status. +func (s WorkerStatus) IsValid() bool { + switch s { + case WorkerStatusIdle, WorkerStatusBusy, WorkerStatusDraining, WorkerStatusOffline: + return true + } + return false +} + +// ValidWorkerStatuses returns all valid worker status values. +func ValidWorkerStatuses() []WorkerStatus { + return []WorkerStatus{WorkerStatusIdle, WorkerStatusBusy, WorkerStatusDraining, WorkerStatusOffline} +} + +// Validate checks that the Worker has required fields for registration. 
+func (w *Worker) Validate() error { + if w.ID == "" { + return ErrWorkerIDRequired + } + if w.Hostname == "" { + return ErrWorkerHostnameRequired + } + return nil +} diff --git a/internal/domain/worker_test.go b/internal/domain/worker_test.go new file mode 100644 index 0000000..e799c3a --- /dev/null +++ b/internal/domain/worker_test.go @@ -0,0 +1,146 @@ +package domain + +import ( + "errors" + "testing" + "time" +) + +func TestWorker_IsAvailable(t *testing.T) { + tests := []struct { + status WorkerStatus + want bool + }{ + {WorkerStatusIdle, true}, + {WorkerStatusBusy, false}, + {WorkerStatusDraining, false}, + {WorkerStatusOffline, false}, + } + + for _, tt := range tests { + t.Run(string(tt.status), func(t *testing.T) { + w := &Worker{Status: tt.status} + if got := w.IsAvailable(); got != tt.want { + t.Errorf("IsAvailable() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestWorker_IsHealthy(t *testing.T) { + threshold := 90 * time.Second + + tests := []struct { + name string + lastHeartbeat time.Time + want bool + }{ + { + name: "recent heartbeat", + lastHeartbeat: time.Now().Add(-30 * time.Second), + want: true, + }, + { + name: "stale heartbeat", + lastHeartbeat: time.Now().Add(-2 * time.Minute), + want: false, + }, + { + name: "exactly at threshold", + lastHeartbeat: time.Now().Add(-90 * time.Second), + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + w := &Worker{LastHeartbeat: tt.lastHeartbeat} + if got := w.IsHealthy(threshold); got != tt.want { + t.Errorf("IsHealthy() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestWorkerStatus_IsValid(t *testing.T) { + tests := []struct { + status WorkerStatus + want bool + }{ + {WorkerStatusIdle, true}, + {WorkerStatusBusy, true}, + {WorkerStatusDraining, true}, + {WorkerStatusOffline, true}, + {"unknown", false}, + {"", false}, + } + + for _, tt := range tests { + t.Run(string(tt.status), func(t *testing.T) { + if got := tt.status.IsValid(); got != tt.want { 
+ t.Errorf("IsValid() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestWorker_Validate(t *testing.T) { + tests := []struct { + name string + worker Worker + wantErr error + }{ + { + name: "valid worker", + worker: Worker{ID: "worker-1", Hostname: "host-1"}, + wantErr: nil, + }, + { + name: "missing ID", + worker: Worker{Hostname: "host-1"}, + wantErr: ErrWorkerIDRequired, + }, + { + name: "missing hostname", + worker: Worker{ID: "worker-1"}, + wantErr: ErrWorkerHostnameRequired, + }, + { + name: "both missing", + worker: Worker{}, + wantErr: ErrWorkerIDRequired, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.worker.Validate() + if tt.wantErr != nil { + if !errors.Is(err, tt.wantErr) { + t.Errorf("Validate() error = %v, want %v", err, tt.wantErr) + } + } else if err != nil { + t.Errorf("Validate() unexpected error = %v", err) + } + }) + } +} + +func TestValidWorkerStatuses(t *testing.T) { + statuses := ValidWorkerStatuses() + if len(statuses) != 4 { + t.Errorf("expected 4 statuses, got %d", len(statuses)) + } + + found := make(map[WorkerStatus]bool) + for _, s := range statuses { + found[s] = true + } + + expected := []WorkerStatus{WorkerStatusIdle, WorkerStatusBusy, WorkerStatusDraining, WorkerStatusOffline} + for _, s := range expected { + if !found[s] { + t.Errorf("expected %s in valid statuses", s) + } + } +} diff --git a/internal/handlers/agents.go b/internal/handlers/agents.go new file mode 100644 index 0000000..882e473 --- /dev/null +++ b/internal/handlers/agents.go @@ -0,0 +1,269 @@ +// Package handlers provides HTTP handlers for the rdev API. +package handlers + +import ( + "encoding/json" + "net/http" + "time" + + "github.com/go-chi/chi/v5" + "github.com/orchard9/rdev/internal/domain" + "github.com/orchard9/rdev/internal/port" + "github.com/orchard9/rdev/pkg/api" +) + +// AgentsHandler handles code agent management endpoints. 
+type AgentsHandler struct { + registry port.CodeAgentRegistry +} + +// NewAgentsHandler creates a new agents handler. +func NewAgentsHandler(registry port.CodeAgentRegistry) *AgentsHandler { + return &AgentsHandler{ + registry: registry, + } +} + +// Mount registers the agent routes. +func (h *AgentsHandler) Mount(r api.Router) { + r.Route("/agents", func(r chi.Router) { + r.Get("/", h.List) + r.Get("/health", h.Health) + r.Get("/{provider}", h.GetCapabilities) + r.Post("/default", h.SetDefault) + }) +} + +// AgentDTO is the data transfer object for code agents. +type AgentDTO struct { + Provider string `json:"provider"` + Name string `json:"name"` + Available bool `json:"available"` + Default bool `json:"default"` + Models []string `json:"supported_models,omitempty"` + DefaultModel string `json:"default_model,omitempty"` +} + +// AgentCapabilitiesDTO is the DTO for agent capabilities. +type AgentCapabilitiesDTO struct { + Provider string `json:"provider"` + SupportsSessionContinuation bool `json:"supports_session_continuation"` + SupportsModelSelection bool `json:"supports_model_selection"` + SupportsToolControl bool `json:"supports_tool_control"` + SupportsStreaming bool `json:"supports_streaming"` + SupportedModels []string `json:"supported_models"` + DefaultModel string `json:"default_model"` + MaxPromptLength int `json:"max_prompt_length,omitempty"` +} + +// ListAgentsResponse is the response for GET /agents. +type ListAgentsResponse struct { + Agents []AgentDTO `json:"agents"` + DefaultAgent string `json:"default_agent"` + TotalAgents int `json:"total_agents"` + AvailableCount int `json:"available_count"` +} + +// List returns all registered code agents and their status. 
+// GET /agents +func (h *AgentsHandler) List(w http.ResponseWriter, r *http.Request) { + if h.registry == nil { + api.WriteSuccess(w, r, ListAgentsResponse{ + Agents: []AgentDTO{}, + DefaultAgent: "", + TotalAgents: 0, + AvailableCount: 0, + }) + return + } + + providers := h.registry.Available() + defaultProvider := h.registry.DefaultProvider() + + // Check availability for each agent + availableAgents := h.registry.AvailableAgents(r.Context()) + availableSet := make(map[domain.AgentProvider]bool) + for _, agent := range availableAgents { + availableSet[agent.Provider()] = true + } + + agents := make([]AgentDTO, 0, len(providers)) + availableCount := 0 + + for _, provider := range providers { + agent := h.registry.Get(provider) + if agent == nil { + continue + } + + caps := agent.Capabilities() + isAvailable := availableSet[provider] + isDefault := provider == defaultProvider + + if isAvailable { + availableCount++ + } + + agents = append(agents, AgentDTO{ + Provider: string(provider), + Name: agent.Name(), + Available: isAvailable, + Default: isDefault, + Models: caps.SupportedModels, + DefaultModel: caps.DefaultModel, + }) + } + + api.WriteSuccess(w, r, ListAgentsResponse{ + Agents: agents, + DefaultAgent: string(defaultProvider), + TotalAgents: len(agents), + AvailableCount: availableCount, + }) +} + +// GetCapabilities returns the capabilities of a specific agent. 
+// GET /agents/{provider} +func (h *AgentsHandler) GetCapabilities(w http.ResponseWriter, r *http.Request) { + providerStr := chi.URLParam(r, "provider") + provider := domain.AgentProvider(providerStr) + + if h.registry == nil { + api.WriteNotFound(w, r, "no agents registered") + return + } + + agent := h.registry.Get(provider) + if agent == nil { + api.WriteNotFound(w, r, "agent not found: "+providerStr) + return + } + + caps := agent.Capabilities() + + api.WriteSuccess(w, r, AgentCapabilitiesDTO{ + Provider: string(caps.Provider), + SupportsSessionContinuation: caps.SupportsSessionContinuation, + SupportsModelSelection: caps.SupportsModelSelection, + SupportsToolControl: caps.SupportsToolControl, + SupportsStreaming: caps.SupportsStreaming, + SupportedModels: caps.SupportedModels, + DefaultModel: caps.DefaultModel, + MaxPromptLength: caps.MaxPromptLength, + }) +} + +// AgentHealthDTO represents the health status of a single agent. +type AgentHealthDTO struct { + Provider string `json:"provider"` + Name string `json:"name"` + Healthy bool `json:"healthy"` + Message string `json:"message"` + Latency string `json:"latency"` + CheckedAt string `json:"checked_at"` +} + +// AgentHealthResponse is the response for GET /agents/health. +type AgentHealthResponse struct { + Agents []AgentHealthDTO `json:"agents"` + HealthyCount int `json:"healthy_count"` + TotalCount int `json:"total_count"` + DefaultAgent string `json:"default_agent"` + DefaultHealth bool `json:"default_healthy"` +} + +// Health returns the health status of all registered code agents. 
+// GET /agents/health +func (h *AgentsHandler) Health(w http.ResponseWriter, r *http.Request) { + if h.registry == nil { + api.WriteSuccess(w, r, AgentHealthResponse{ + Agents: []AgentHealthDTO{}, + HealthyCount: 0, + TotalCount: 0, + }) + return + } + + providers := h.registry.Available() + defaultProvider := h.registry.DefaultProvider() + + agents := make([]AgentHealthDTO, 0, len(providers)) + healthyCount := 0 + defaultHealthy := false + + for _, provider := range providers { + agent := h.registry.Get(provider) + if agent == nil { + continue + } + + start := time.Now() + healthy := agent.Available(r.Context()) + latency := time.Since(start) + + msg := "available" + if !healthy { + msg = "unavailable" + } + + if healthy { + healthyCount++ + } + if provider == defaultProvider { + defaultHealthy = healthy + } + + agents = append(agents, AgentHealthDTO{ + Provider: string(provider), + Name: agent.Name(), + Healthy: healthy, + Message: msg, + Latency: latency.String(), + CheckedAt: time.Now().UTC().Format(time.RFC3339), + }) + } + + api.WriteSuccess(w, r, AgentHealthResponse{ + Agents: agents, + HealthyCount: healthyCount, + TotalCount: len(agents), + DefaultAgent: string(defaultProvider), + DefaultHealth: defaultHealthy, + }) +} + +// SetDefaultRequest is the request body for POST /agents/default. +type SetDefaultRequest struct { + Provider string `json:"provider"` +} + +// SetDefault changes the default code agent. 
+// POST /agents/default +func (h *AgentsHandler) SetDefault(w http.ResponseWriter, r *http.Request) { + if h.registry == nil { + api.WriteBadRequest(w, r, "no agents registered") + return + } + + var req SetDefaultRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + api.WriteBadRequest(w, r, "invalid request body") + return + } + + if req.Provider == "" { + api.WriteBadRequest(w, r, "provider is required") + return + } + + provider := domain.AgentProvider(req.Provider) + if err := h.registry.SetDefault(provider); err != nil { + api.WriteBadRequest(w, r, err.Error()) + return + } + + api.WriteSuccess(w, r, map[string]any{ + "default_agent": req.Provider, + "message": "default agent updated", + }) +} diff --git a/internal/handlers/agents_test.go b/internal/handlers/agents_test.go new file mode 100644 index 0000000..7a69dee --- /dev/null +++ b/internal/handlers/agents_test.go @@ -0,0 +1,372 @@ +package handlers + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "testing" + + "github.com/go-chi/chi/v5" + "github.com/orchard9/rdev/internal/domain" + "github.com/orchard9/rdev/internal/port" +) + +// mockCodeAgent implements port.CodeAgent for testing. 
+type mockCodeAgent struct { + name string + provider domain.AgentProvider + available bool + capabilities domain.AgentCapabilities +} + +func (m *mockCodeAgent) Name() string { return m.name } +func (m *mockCodeAgent) Provider() domain.AgentProvider { return m.provider } +func (m *mockCodeAgent) Execute(ctx context.Context, req *domain.AgentRequest, handler domain.AgentEventHandler) (*domain.AgentResult, error) { + return &domain.AgentResult{}, nil +} +func (m *mockCodeAgent) Cancel(ctx context.Context, sessionID string) error { return nil } +func (m *mockCodeAgent) Capabilities() domain.AgentCapabilities { return m.capabilities } +func (m *mockCodeAgent) Available(ctx context.Context) bool { return m.available } + +// mockAgentRegistry implements port.CodeAgentRegistry for testing. +type mockAgentRegistry struct { + agents map[domain.AgentProvider]*mockCodeAgent + defaultAgent domain.AgentProvider + setDefaultErr error +} + +func newMockAgentRegistry() *mockAgentRegistry { + return &mockAgentRegistry{ + agents: make(map[domain.AgentProvider]*mockCodeAgent), + } +} + +func (m *mockAgentRegistry) Register(agent port.CodeAgent) { + m.agents[agent.Provider()] = agent.(*mockCodeAgent) + if m.defaultAgent == "" { + m.defaultAgent = agent.Provider() + } +} + +// registerAgent is a helper for tests to directly add mockCodeAgents +func (m *mockAgentRegistry) registerAgent(agent *mockCodeAgent) { + m.agents[agent.provider] = agent + if m.defaultAgent == "" { + m.defaultAgent = agent.provider + } +} + +func (m *mockAgentRegistry) Get(provider domain.AgentProvider) port.CodeAgent { + agent, ok := m.agents[provider] + if !ok { + return nil + } + return agent +} + +func (m *mockAgentRegistry) Default() port.CodeAgent { + return m.Get(m.defaultAgent) +} + +func (m *mockAgentRegistry) DefaultProvider() domain.AgentProvider { + return m.defaultAgent +} + +func (m *mockAgentRegistry) SetDefault(provider domain.AgentProvider) error { + if m.setDefaultErr != nil { + return 
m.setDefaultErr + } + if _, ok := m.agents[provider]; !ok { + return fmt.Errorf("agent provider %q is not registered", provider) + } + m.defaultAgent = provider + return nil +} + +func (m *mockAgentRegistry) Available() []domain.AgentProvider { + providers := make([]domain.AgentProvider, 0, len(m.agents)) + for p := range m.agents { + providers = append(providers, p) + } + return providers +} + +func (m *mockAgentRegistry) AvailableAgents(ctx context.Context) []port.CodeAgent { + var available []port.CodeAgent + for _, agent := range m.agents { + if agent.available { + available = append(available, agent) + } + } + return available +} + +func (m *mockAgentRegistry) Count() int { + return len(m.agents) +} + +func TestAgentsHandler_List(t *testing.T) { + registry := newMockAgentRegistry() + registry.registerAgent(&mockCodeAgent{ + name: "Claude Code", + provider: domain.AgentProviderClaudeCode, + available: true, + capabilities: domain.AgentCapabilities{ + Provider: domain.AgentProviderClaudeCode, + SupportsSessionContinuation: true, + SupportedModels: []string{"claude-sonnet-4-20250514"}, + DefaultModel: "claude-sonnet-4-20250514", + }, + }) + registry.registerAgent(&mockCodeAgent{ + name: "OpenCode", + provider: domain.AgentProviderOpenCode, + available: false, + capabilities: domain.AgentCapabilities{ + Provider: domain.AgentProviderOpenCode, + SupportsSessionContinuation: true, + SupportsModelSelection: true, + SupportedModels: []string{"gpt-4o", "claude-sonnet-4-20250514"}, + DefaultModel: "claude-sonnet-4-20250514", + }, + }) + + handler := NewAgentsHandler(registry) + + req := httptest.NewRequest(http.MethodGet, "/agents", nil) + w := httptest.NewRecorder() + + handler.List(w, req) + + if w.Code != http.StatusOK { + t.Errorf("expected status 200, got %d", w.Code) + } + + // Unwrap from api.Response + respBody := w.Body.Bytes() + var apiResp struct { + Data ListAgentsResponse `json:"data"` + } + if err := json.Unmarshal(respBody, &apiResp); err != nil { + 
t.Fatalf("failed to decode api response: %v", err) + } + + if apiResp.Data.TotalAgents != 2 { + t.Errorf("expected 2 agents, got %d", apiResp.Data.TotalAgents) + } + if apiResp.Data.AvailableCount != 1 { + t.Errorf("expected 1 available agent, got %d", apiResp.Data.AvailableCount) + } + if apiResp.Data.DefaultAgent != string(domain.AgentProviderClaudeCode) { + t.Errorf("expected default agent to be claude-code, got %s", apiResp.Data.DefaultAgent) + } +} + +func TestAgentsHandler_List_NoRegistry(t *testing.T) { + handler := NewAgentsHandler(nil) + + req := httptest.NewRequest(http.MethodGet, "/agents", nil) + w := httptest.NewRecorder() + + handler.List(w, req) + + if w.Code != http.StatusOK { + t.Errorf("expected status 200, got %d", w.Code) + } +} + +func TestAgentsHandler_GetCapabilities(t *testing.T) { + registry := newMockAgentRegistry() + registry.registerAgent(&mockCodeAgent{ + name: "Claude Code", + provider: domain.AgentProviderClaudeCode, + capabilities: domain.AgentCapabilities{ + Provider: domain.AgentProviderClaudeCode, + SupportsSessionContinuation: true, + SupportsStreaming: true, + SupportedModels: []string{"claude-sonnet-4-20250514"}, + DefaultModel: "claude-sonnet-4-20250514", + MaxPromptLength: 100000, + }, + }) + + handler := NewAgentsHandler(registry) + + // Set up chi router for URL params + r := chi.NewRouter() + r.Get("/agents/{provider}", handler.GetCapabilities) + + req := httptest.NewRequest(http.MethodGet, "/agents/claudecode", nil) + w := httptest.NewRecorder() + + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("expected status 200, got %d: %s", w.Code, w.Body.String()) + } + + var apiResp struct { + Data AgentCapabilitiesDTO `json:"data"` + } + if err := json.Unmarshal(w.Body.Bytes(), &apiResp); err != nil { + t.Fatalf("failed to decode response: %v", err) + } + + if apiResp.Data.Provider != string(domain.AgentProviderClaudeCode) { + t.Errorf("expected provider claudecode, got %s", apiResp.Data.Provider) + } + if 
!apiResp.Data.SupportsSessionContinuation { + t.Error("expected session continuation support") + } + if !apiResp.Data.SupportsStreaming { + t.Error("expected streaming support") + } +} + +func TestAgentsHandler_GetCapabilities_NotFound(t *testing.T) { + registry := newMockAgentRegistry() + handler := NewAgentsHandler(registry) + + r := chi.NewRouter() + r.Get("/agents/{provider}", handler.GetCapabilities) + + req := httptest.NewRequest(http.MethodGet, "/agents/unknown", nil) + w := httptest.NewRecorder() + + r.ServeHTTP(w, req) + + if w.Code != http.StatusNotFound { + t.Errorf("expected status 404, got %d", w.Code) + } +} + +func TestAgentsHandler_SetDefault(t *testing.T) { + registry := newMockAgentRegistry() + registry.registerAgent(&mockCodeAgent{ + name: "Claude Code", + provider: domain.AgentProviderClaudeCode, + }) + registry.registerAgent(&mockCodeAgent{ + name: "OpenCode", + provider: domain.AgentProviderOpenCode, + }) + + handler := NewAgentsHandler(registry) + + body := SetDefaultRequest{Provider: string(domain.AgentProviderOpenCode)} + bodyBytes, _ := json.Marshal(body) + + req := httptest.NewRequest(http.MethodPost, "/agents/default", bytes.NewReader(bodyBytes)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + handler.SetDefault(w, req) + + if w.Code != http.StatusOK { + t.Errorf("expected status 200, got %d: %s", w.Code, w.Body.String()) + } + + // Verify default was changed + if registry.defaultAgent != domain.AgentProviderOpenCode { + t.Errorf("expected default to be opencode, got %s", registry.defaultAgent) + } +} + +func TestAgentsHandler_SetDefault_InvalidProvider(t *testing.T) { + registry := newMockAgentRegistry() + registry.registerAgent(&mockCodeAgent{ + name: "Claude Code", + provider: domain.AgentProviderClaudeCode, + }) + + handler := NewAgentsHandler(registry) + + body := SetDefaultRequest{Provider: "unknown"} + bodyBytes, _ := json.Marshal(body) + + req := httptest.NewRequest(http.MethodPost, 
"/agents/default", bytes.NewReader(bodyBytes)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + handler.SetDefault(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("expected status 400, got %d", w.Code) + } +} + +func TestAgentsHandler_SetDefault_EmptyProvider(t *testing.T) { + registry := newMockAgentRegistry() + handler := NewAgentsHandler(registry) + + body := SetDefaultRequest{Provider: ""} + bodyBytes, _ := json.Marshal(body) + + req := httptest.NewRequest(http.MethodPost, "/agents/default", bytes.NewReader(bodyBytes)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + handler.SetDefault(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("expected status 400, got %d", w.Code) + } +} + +func TestAgentsHandler_Health(t *testing.T) { + registry := newMockAgentRegistry() + registry.registerAgent(&mockCodeAgent{ + name: "Claude Code", + provider: domain.AgentProviderClaudeCode, + available: true, + }) + registry.registerAgent(&mockCodeAgent{ + name: "OpenCode", + provider: domain.AgentProviderOpenCode, + available: false, + }) + + handler := NewAgentsHandler(registry) + + req := httptest.NewRequest(http.MethodGet, "/agents/health", nil) + w := httptest.NewRecorder() + + handler.Health(w, req) + + if w.Code != http.StatusOK { + t.Errorf("expected status 200, got %d", w.Code) + } + + var apiResp struct { + Data AgentHealthResponse `json:"data"` + } + if err := json.Unmarshal(w.Body.Bytes(), &apiResp); err != nil { + t.Fatalf("failed to decode response: %v", err) + } + + if apiResp.Data.TotalCount != 2 { + t.Errorf("expected 2 agents, got %d", apiResp.Data.TotalCount) + } + if apiResp.Data.HealthyCount != 1 { + t.Errorf("expected 1 healthy agent, got %d", apiResp.Data.HealthyCount) + } + if !apiResp.Data.DefaultHealth { + t.Error("expected default agent to be healthy") + } +} + +func TestAgentsHandler_Health_NoRegistry(t *testing.T) { + handler := NewAgentsHandler(nil) + + 
req := httptest.NewRequest(http.MethodGet, "/agents/health", nil) + w := httptest.NewRecorder() + + handler.Health(w, req) + + if w.Code != http.StatusOK { + t.Errorf("expected status 200, got %d", w.Code) + } +} diff --git a/internal/handlers/builds.go b/internal/handlers/builds.go new file mode 100644 index 0000000..745c985 --- /dev/null +++ b/internal/handlers/builds.go @@ -0,0 +1,238 @@ +// Package handlers provides HTTP handlers for the rdev API. +package handlers + +import ( + "encoding/json" + "errors" + "net/http" + "strconv" + + "github.com/go-chi/chi/v5" + "github.com/orchard9/rdev/internal/auth" + "github.com/orchard9/rdev/internal/domain" + "github.com/orchard9/rdev/internal/service" + "github.com/orchard9/rdev/pkg/api" +) + +// maxRequestBodySize is the maximum allowed size for request bodies (1MB). +const maxRequestBodySize = 1 << 20 + +// BuildsHandler handles project-scoped build endpoints. +type BuildsHandler struct { + buildService *service.BuildService +} + +// NewBuildsHandler creates a new builds handler. +func NewBuildsHandler(buildService *service.BuildService) *BuildsHandler { + return &BuildsHandler{ + buildService: buildService, + } +} + +// Mount registers the build routes. +func (h *BuildsHandler) Mount(r api.Router) { + // Project-scoped build endpoints + r.With(auth.RequireScope(auth.ScopeBuildWrite, auth.ScopeAdmin)). + Post("/projects/{id}/builds", h.StartBuild) + r.With(auth.RequireScope(auth.ScopeBuildRead, auth.ScopeAdmin)). + Get("/projects/{id}/builds", h.ListBuilds) + + // Build detail by task ID + r.With(auth.RequireScope(auth.ScopeBuildRead, auth.ScopeAdmin)). + Get("/builds/{taskId}", h.GetBuild) +} + +// StartBuildRequest is the request body for POST /projects/{id}/builds. 
+type StartBuildRequest struct { + Prompt string `json:"prompt"` + Template string `json:"template,omitempty"` + Variables map[string]string `json:"variables,omitempty"` + AutoCommit bool `json:"auto_commit"` + AutoPush bool `json:"auto_push"` + CallbackURL string `json:"callback_url,omitempty"` +} + +// StartBuildResponse is the response for POST /projects/{id}/builds. +type StartBuildResponse struct { + TaskID string `json:"task_id"` + ProjectID string `json:"project_id"` + Status string `json:"status"` + StatusURL string `json:"status_url"` +} + +// BuildAuditDTO is the data transfer object for build audit entries. +type BuildAuditDTO struct { + TaskID string `json:"task_id"` + ProjectID string `json:"project_id"` + WorkerID string `json:"worker_id,omitempty"` + Status string `json:"status"` + Prompt string `json:"prompt"` + Template string `json:"template,omitempty"` + AutoCommit bool `json:"auto_commit"` + AutoPush bool `json:"auto_push"` + Result *BuildResultDTO `json:"result,omitempty"` + StartedAt string `json:"started_at"` + CompletedAt string `json:"completed_at,omitempty"` +} + +// BuildResultDTO is the data transfer object for build results. 
+type BuildResultDTO struct { + Success bool `json:"success"` + Output string `json:"output,omitempty"` + Error string `json:"error,omitempty"` + CommitSHA string `json:"commit_sha,omitempty"` + FilesChanged []string `json:"files_changed,omitempty"` + DurationMs int64 `json:"duration_ms"` + Artifacts map[string]string `json:"artifacts,omitempty"` +} + +func toBuildAuditDTO(e *domain.BuildAuditEntry) *BuildAuditDTO { + if e == nil { + return nil + } + dto := &BuildAuditDTO{ + TaskID: e.TaskID, + ProjectID: e.ProjectID, + WorkerID: e.WorkerID, + Status: string(e.Status), + Prompt: e.Spec.Prompt, + Template: e.Spec.Template, + AutoCommit: e.Spec.AutoCommit, + AutoPush: e.Spec.AutoPush, + StartedAt: e.StartedAt.Format("2006-01-02T15:04:05Z07:00"), + } + if e.CompletedAt != nil { + dto.CompletedAt = e.CompletedAt.Format("2006-01-02T15:04:05Z07:00") + } + if e.Result != nil { + dto.Result = &BuildResultDTO{ + Success: e.Result.Success, + Output: e.Result.Output, + Error: e.Result.Error, + CommitSHA: e.Result.CommitSHA, + FilesChanged: e.Result.FilesChanged, + DurationMs: e.Result.DurationMs, + Artifacts: e.Result.Artifacts, + } + } + return dto +} + +// StartBuild enqueues a build task for a project. 
+// POST /projects/{id}/builds +func (h *BuildsHandler) StartBuild(w http.ResponseWriter, r *http.Request) { + projectID := chi.URLParam(r, "id") + if projectID == "" { + api.WriteBadRequest(w, r, "project ID is required") + return + } + + r.Body = http.MaxBytesReader(w, r.Body, maxRequestBodySize) + var req StartBuildRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + api.WriteBadRequest(w, r, "invalid request body") + return + } + + if req.Prompt == "" { + api.WriteBadRequest(w, r, "prompt is required") + return + } + + // Validate callback URL to prevent SSRF + if req.CallbackURL != "" { + if err := domain.ValidateCallbackURL(req.CallbackURL); err != nil { + api.WriteBadRequest(w, r, err.Error()) + return + } + } + + spec := domain.BuildSpec{ + Prompt: req.Prompt, + Template: req.Template, + Variables: req.Variables, + AutoCommit: req.AutoCommit, + AutoPush: req.AutoPush, + CallbackURL: req.CallbackURL, + } + + taskID, err := h.buildService.StartBuild(r.Context(), projectID, spec) + if err != nil { + if errors.Is(err, domain.ErrPromptRequired) { + api.WriteBadRequest(w, r, err.Error()) + return + } + api.WriteInternalError(w, r, "failed to start build") + return + } + + api.WriteCreated(w, r, StartBuildResponse{ + TaskID: taskID, + ProjectID: projectID, + Status: "pending", + StatusURL: "/builds/" + taskID, + }) +} + +// ListBuilds returns build history for a project. 
+// GET /projects/{id}/builds?limit=50 +func (h *BuildsHandler) ListBuilds(w http.ResponseWriter, r *http.Request) { + projectID := chi.URLParam(r, "id") + if projectID == "" { + api.WriteBadRequest(w, r, "project ID is required") + return + } + + limit := 50 + if limitStr := r.URL.Query().Get("limit"); limitStr != "" { + l, err := strconv.Atoi(limitStr) + if err != nil { + api.WriteBadRequest(w, r, "limit must be a valid integer") + return + } + if l < 1 || l > 200 { + api.WriteBadRequest(w, r, "limit must be between 1 and 200") + return + } + limit = l + } + + builds, err := h.buildService.ListBuilds(r.Context(), projectID, limit) + if err != nil { + api.WriteInternalError(w, r, "failed to list builds") + return + } + + dtos := make([]*BuildAuditDTO, len(builds)) + for i, b := range builds { + dtos[i] = toBuildAuditDTO(b) + } + + api.WriteSuccess(w, r, map[string]any{ + "builds": dtos, + "project_id": projectID, + "total": len(dtos), + }) +} + +// GetBuild returns the status of a specific build. 
+// GET /builds/{taskId} +func (h *BuildsHandler) GetBuild(w http.ResponseWriter, r *http.Request) { + taskID := chi.URLParam(r, "taskId") + if taskID == "" { + api.WriteBadRequest(w, r, "task ID is required") + return + } + + entry, err := h.buildService.GetBuildStatus(r.Context(), taskID) + if err != nil { + if errors.Is(err, domain.ErrBuildNotFound) { + api.WriteNotFound(w, r, "build not found: "+taskID) + return + } + api.WriteInternalError(w, r, "failed to get build status") + return + } + + api.WriteSuccess(w, r, toBuildAuditDTO(entry)) +} diff --git a/internal/handlers/builds_test.go b/internal/handlers/builds_test.go new file mode 100644 index 0000000..8a2073c --- /dev/null +++ b/internal/handlers/builds_test.go @@ -0,0 +1,351 @@ +package handlers + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/go-chi/chi/v5" + "github.com/orchard9/rdev/internal/auth" + "github.com/orchard9/rdev/internal/domain" + "github.com/orchard9/rdev/internal/port" + "github.com/orchard9/rdev/internal/service" +) + +// testAdminAuth is a chi middleware that injects an admin API key into the +// request context so auth.RequireScope passes in tests. +func testAdminAuth(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := auth.WithAPIKey(r.Context(), &domain.APIKey{ + Scopes: []domain.Scope{domain.ScopeAdmin}, + }) + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} + +// mockBuildAudit implements port.BuildAudit for testing. 
+type mockBuildAudit struct { + entries map[string]*domain.BuildAuditEntry + err error +} + +func newMockBuildAudit() *mockBuildAudit { + return &mockBuildAudit{ + entries: make(map[string]*domain.BuildAuditEntry), + } +} + +func (m *mockBuildAudit) Record(_ context.Context, entry *domain.BuildAuditEntry) error { + if m.err != nil { + return m.err + } + m.entries[entry.TaskID] = entry + return nil +} + +func (m *mockBuildAudit) Update(_ context.Context, taskID string, result *domain.BuildResult) error { + if m.err != nil { + return m.err + } + entry, ok := m.entries[taskID] + if !ok { + return domain.ErrBuildNotFound + } + entry.Result = result + if result.Success { + entry.Status = domain.BuildStatusCompleted + } else { + entry.Status = domain.BuildStatusFailed + } + now := time.Now() + entry.CompletedAt = &now + return nil +} + +func (m *mockBuildAudit) Get(_ context.Context, taskID string) (*domain.BuildAuditEntry, error) { + if m.err != nil { + return nil, m.err + } + entry, ok := m.entries[taskID] + if !ok { + return nil, domain.ErrBuildNotFound + } + return entry, nil +} + +func (m *mockBuildAudit) List(_ context.Context, filter port.BuildAuditFilter) ([]*domain.BuildAuditEntry, error) { + if m.err != nil { + return nil, m.err + } + var result []*domain.BuildAuditEntry + for _, entry := range m.entries { + if filter.ProjectID != "" && entry.ProjectID != filter.ProjectID { + continue + } + result = append(result, entry) + if filter.Limit > 0 && len(result) >= filter.Limit { + break + } + } + return result, nil +} + +func TestBuildsHandler_StartBuild(t *testing.T) { + queue := newMockWorkQueue() + audit := newMockBuildAudit() + buildService := service.NewBuildService(queue, audit, nil) + handler := NewBuildsHandler(buildService) + + router := chi.NewRouter() + router.Use(testAdminAuth) + handler.Mount(router) + + tests := []struct { + name string + projectID string + body StartBuildRequest + wantStatus int + }{ + { + name: "valid_build", + projectID: 
"my-project", + body: StartBuildRequest{ + Prompt: "Build a landing page with Next.js", + Template: "nextjs-landing", + AutoCommit: true, + AutoPush: true, + }, + wantStatus: http.StatusCreated, + }, + { + name: "missing_prompt", + projectID: "my-project", + body: StartBuildRequest{ + Template: "nextjs-landing", + }, + wantStatus: http.StatusBadRequest, + }, + { + name: "minimal_build", + projectID: "test-project", + body: StartBuildRequest{ + Prompt: "Add a footer component", + }, + wantStatus: http.StatusCreated, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + body, _ := json.Marshal(tt.body) + req := httptest.NewRequest(http.MethodPost, "/projects/"+tt.projectID+"/builds", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + + router.ServeHTTP(rec, req) + + if rec.Code != tt.wantStatus { + t.Errorf("got status %d, want %d; body: %s", rec.Code, tt.wantStatus, rec.Body.String()) + } + + if tt.wantStatus == http.StatusCreated { + var resp map[string]any + if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + data, ok := resp["data"].(map[string]any) + if !ok { + t.Fatalf("expected data to be map, got %T", resp["data"]) + } + if data["task_id"] == nil || data["task_id"] == "" { + t.Error("expected task_id in response") + } + if data["project_id"] != tt.projectID { + t.Errorf("got project_id=%v, want %s", data["project_id"], tt.projectID) + } + if data["status"] != "pending" { + t.Errorf("got status=%v, want pending", data["status"]) + } + } + }) + } +} + +func TestBuildsHandler_GetBuild(t *testing.T) { + queue := newMockWorkQueue() + audit := newMockBuildAudit() + buildService := service.NewBuildService(queue, audit, nil) + handler := NewBuildsHandler(buildService) + + // Pre-populate an audit entry + audit.entries["task-1"] = &domain.BuildAuditEntry{ + TaskID: "task-1", + ProjectID: "my-project", + WorkerID: "worker-1", 
+ Spec: domain.BuildSpec{ + Prompt: "Build landing page", + Template: "nextjs-landing", + }, + Status: domain.BuildStatusCompleted, + StartedAt: time.Now().Add(-5 * time.Minute), + Result: &domain.BuildResult{ + Success: true, + CommitSHA: "abc123", + DurationMs: 30000, + }, + } + + router := chi.NewRouter() + router.Use(testAdminAuth) + handler.Mount(router) + + tests := []struct { + name string + taskID string + wantStatus int + }{ + { + name: "existing_build", + taskID: "task-1", + wantStatus: http.StatusOK, + }, + { + name: "not_found", + taskID: "nonexistent", + wantStatus: http.StatusNotFound, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/builds/"+tt.taskID, nil) + rec := httptest.NewRecorder() + + router.ServeHTTP(rec, req) + + if rec.Code != tt.wantStatus { + t.Errorf("got status %d, want %d; body: %s", rec.Code, tt.wantStatus, rec.Body.String()) + } + + if tt.wantStatus == http.StatusOK { + var resp map[string]any + if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + data, ok := resp["data"].(map[string]any) + if !ok { + t.Fatalf("expected data to be map, got %T", resp["data"]) + } + if data["task_id"] != "task-1" { + t.Errorf("got task_id=%v, want task-1", data["task_id"]) + } + if data["status"] != "completed" { + t.Errorf("got status=%v, want completed", data["status"]) + } + } + }) + } +} + +func TestBuildsHandler_ListBuilds(t *testing.T) { + queue := newMockWorkQueue() + audit := newMockBuildAudit() + buildService := service.NewBuildService(queue, audit, nil) + handler := NewBuildsHandler(buildService) + + // Pre-populate audit entries + audit.entries["task-1"] = &domain.BuildAuditEntry{ + TaskID: "task-1", + ProjectID: "project-a", + Status: domain.BuildStatusCompleted, + Spec: domain.BuildSpec{Prompt: "Build page"}, + StartedAt: time.Now().Add(-10 * time.Minute), + } + audit.entries["task-2"] = 
&domain.BuildAuditEntry{ + TaskID: "task-2", + ProjectID: "project-a", + Status: domain.BuildStatusRunning, + Spec: domain.BuildSpec{Prompt: "Add footer"}, + StartedAt: time.Now().Add(-5 * time.Minute), + } + audit.entries["task-3"] = &domain.BuildAuditEntry{ + TaskID: "task-3", + ProjectID: "project-b", + Status: domain.BuildStatusPending, + Spec: domain.BuildSpec{Prompt: "Other project"}, + StartedAt: time.Now(), + } + + router := chi.NewRouter() + router.Use(testAdminAuth) + handler.Mount(router) + + t.Run("list_builds_for_project", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/projects/project-a/builds", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Errorf("got status %d, want %d; body: %s", rec.Code, http.StatusOK, rec.Body.String()) + } + + var resp map[string]any + if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + data, ok := resp["data"].(map[string]any) + if !ok { + t.Fatalf("expected data to be map, got %T", resp["data"]) + } + totalF, ok := data["total"].(float64) + if !ok { + t.Fatalf("expected total to be float64, got %T", data["total"]) + } + if int(totalF) != 2 { + t.Errorf("got total=%d, want 2", int(totalF)) + } + if data["project_id"] != "project-a" { + t.Errorf("got project_id=%v, want project-a", data["project_id"]) + } + }) + + t.Run("list_with_limit", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/projects/project-a/builds?limit=1", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Errorf("got status %d, want %d", rec.Code, http.StatusOK) + } + + var resp map[string]any + if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + data, ok := resp["data"].(map[string]any) + if !ok { + t.Fatalf("expected data to be map, got %T", resp["data"]) + } + totalF, ok := 
data["total"].(float64) + if !ok { + t.Fatalf("expected total to be float64, got %T", data["total"]) + } + if int(totalF) != 1 { + t.Errorf("got total=%d, want 1 (limited)", int(totalF)) + } + }) + + t.Run("invalid_limit", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/projects/project-a/builds?limit=abc", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusBadRequest { + t.Errorf("got status %d, want %d", rec.Code, http.StatusBadRequest) + } + }) +} diff --git a/internal/handlers/create_and_build.go b/internal/handlers/create_and_build.go new file mode 100644 index 0000000..10e3da4 --- /dev/null +++ b/internal/handlers/create_and_build.go @@ -0,0 +1,184 @@ +// Package handlers provides HTTP handlers for the rdev API. +package handlers + +import ( + "context" + "encoding/json" + "errors" + "log/slog" + "net/http" + "time" + + "github.com/orchard9/rdev/internal/domain" + "github.com/orchard9/rdev/internal/service" + "github.com/orchard9/rdev/pkg/api" +) + +// CreateAndBuildHandler handles the combined create-project-and-build endpoint. +type CreateAndBuildHandler struct { + infraService *service.ProjectInfraService + buildService *service.BuildService + logger *slog.Logger +} + +// NewCreateAndBuildHandler creates a new create-and-build handler. +func NewCreateAndBuildHandler( + infraService *service.ProjectInfraService, + buildService *service.BuildService, + logger *slog.Logger, +) *CreateAndBuildHandler { + if logger == nil { + logger = slog.Default() + } + return &CreateAndBuildHandler{ + infraService: infraService, + buildService: buildService, + logger: logger, + } +} + +// Mount registers the create-and-build route. +func (h *CreateAndBuildHandler) Mount(r api.Router) { + r.Post("/project/create-and-build", h.CreateAndBuild) +} + +// CreateAndBuildRequest is the request body for POST /project/create-and-build. 
type CreateAndBuildRequest struct {
	// Project creation fields.
	// Name is mandatory; the handler rejects a request without it.
	Name        string `json:"name"`
	Description string `json:"description,omitempty"`
	Private     bool   `json:"private,omitempty"`
	// Template is passed to project creation and also copied into the
	// build spec.
	Template string `json:"template,omitempty"`

	// Build fields.
	// Prompt is mandatory; the handler rejects a request without it.
	Prompt     string            `json:"prompt"`
	Variables  map[string]string `json:"variables,omitempty"`
	AutoCommit bool              `json:"auto_commit"`
	AutoPush   bool              `json:"auto_push"`
	// CallbackURL is optional; when present it must pass
	// domain.ValidateCallbackURL before anything is created.
	CallbackURL string `json:"callback_url,omitempty"`
}

// CreateAndBuildResponse is the response for POST /project/create-and-build.
type CreateAndBuildResponse struct {
	// Project info, copied from the project creation result.
	ProjectID string `json:"project_id"`
	Name      string `json:"name"`
	Domain    string `json:"domain"`
	URL       string `json:"url"`

	// Git info. Populated only when the creation result has an HTTP clone
	// URL; keys are owner, name, clone_ssh, clone_http and html_url.
	Git map[string]string `json:"git,omitempty"`

	// Build info for the task enqueued right after project creation.
	TaskID    string `json:"task_id"`
	Status    string `json:"status"`
	StatusURL string `json:"status_url"`
}

// CreateAndBuild creates a project and immediately enqueues a build task.
+// POST /project/create-and-build +func (h *CreateAndBuildHandler) CreateAndBuild(w http.ResponseWriter, r *http.Request) { + ctx, cancel := context.WithTimeout(r.Context(), 60*time.Second) + defer cancel() + + if h.infraService == nil { + api.WriteInternalError(w, r, "project infrastructure service not configured") + return + } + if h.buildService == nil { + api.WriteInternalError(w, r, "build service not configured") + return + } + + r.Body = http.MaxBytesReader(w, r.Body, maxRequestBodySize) + var req CreateAndBuildRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + api.WriteBadRequest(w, r, "invalid request body") + return + } + + if req.Name == "" { + api.WriteBadRequest(w, r, "name is required") + return + } + if req.Prompt == "" { + api.WriteBadRequest(w, r, "prompt is required") + return + } + + // Validate callback URL to prevent SSRF + if req.CallbackURL != "" { + if err := domain.ValidateCallbackURL(req.CallbackURL); err != nil { + api.WriteBadRequest(w, r, err.Error()) + return + } + } + + // Step 1: Create the project + projectResult, err := h.infraService.CreateProject(ctx, service.CreateProjectRequest{ + Name: req.Name, + Description: req.Description, + Private: req.Private, + Template: req.Template, + }) + if err != nil { + if errors.Is(err, domain.ErrInvalidProjectName) { + api.WriteBadRequest(w, r, err.Error()) + return + } + h.logger.Error("project creation failed", "error", err, "name", req.Name) + api.WriteInternalError(w, r, "failed to create project") + return + } + + // Step 2: Enqueue the build task + spec := domain.BuildSpec{ + Prompt: req.Prompt, + Template: req.Template, + Variables: req.Variables, + AutoCommit: req.AutoCommit, + AutoPush: req.AutoPush, + CallbackURL: req.CallbackURL, + } + + taskID, err := h.buildService.StartBuild(ctx, projectResult.ProjectID, spec) + if err != nil { + h.logger.Error("build enqueue failed after project creation", + "error", err, + "project", projectResult.ProjectID, + ) + // 
Project was created but build failed to enqueue. + // Return the project info with a generic error and retry URL. + api.WriteJSON(w, r, http.StatusCreated, map[string]any{ + "project_id": projectResult.ProjectID, + "name": projectResult.Name, + "domain": projectResult.Domain, + "url": projectResult.URL, + "build_error": "project created but build failed to enqueue", + "retry_url": "/projects/" + projectResult.ProjectID + "/builds", + }) + return + } + + resp := CreateAndBuildResponse{ + ProjectID: projectResult.ProjectID, + Name: projectResult.Name, + Domain: projectResult.Domain, + URL: projectResult.URL, + TaskID: taskID, + Status: "pending", + StatusURL: "/builds/" + taskID, + } + + if projectResult.CloneHTTP != "" { + resp.Git = map[string]string{ + "owner": projectResult.GitRepoOwner, + "name": projectResult.GitRepoName, + "clone_ssh": projectResult.CloneSSH, + "clone_http": projectResult.CloneHTTP, + "html_url": projectResult.HTMLURL, + } + } + + api.WriteCreated(w, r, resp) +} diff --git a/internal/handlers/credentials.go b/internal/handlers/credentials.go index 7e7c01f..d6f46aa 100644 --- a/internal/handlers/credentials.go +++ b/internal/handlers/credentials.go @@ -2,10 +2,13 @@ package handlers import ( + "context" "encoding/json" + "errors" "net/http" "github.com/go-chi/chi/v5" + "github.com/orchard9/rdev/internal/auth" "github.com/orchard9/rdev/internal/domain" "github.com/orchard9/rdev/internal/port" "github.com/orchard9/rdev/pkg/api" @@ -22,6 +25,16 @@ func NewCredentialsHandler(store port.CredentialStore) *CredentialsHandler { return &CredentialsHandler{store: store} } +// updatedBy extracts the authenticated identity from the request context. +// Returns the API key name for regular keys, or "superadmin" for admin key auth +// (which has ID "admin") to preserve consistency with existing database records. 
+func updatedBy(ctx context.Context) string { + if key := auth.GetAPIKey(ctx); key != nil && key.ID != "admin" { + return key.Name + } + return "superadmin" +} + // Mount registers the credentials routes. // All routes require superadmin authentication (handled by middleware). func (h *CredentialsHandler) Mount(r api.Router) { @@ -146,7 +159,7 @@ func (h *CredentialsHandler) Set(w http.ResponseWriter, r *http.Request) { Value: req.Value, Description: req.Description, Category: req.Category, - UpdatedBy: "superadmin", // Could extract from auth context + UpdatedBy: updatedBy(ctx), } if err := h.store.Set(ctx, cred); err != nil { @@ -191,7 +204,7 @@ func (h *CredentialsHandler) SetBatch(w http.ResponseWriter, r *http.Request) { Value: c.Value, Description: c.Description, Category: c.Category, - UpdatedBy: "superadmin", + UpdatedBy: updatedBy(ctx), } } @@ -224,7 +237,11 @@ func (h *CredentialsHandler) Delete(w http.ResponseWriter, r *http.Request) { } if err := h.store.Delete(ctx, key); err != nil { - api.WriteNotFound(w, r, "credential not found") + if errors.Is(err, domain.ErrCredentialNotFound) { + api.WriteNotFound(w, r, "credential not found") + return + } + api.WriteInternalError(w, r, "failed to delete credential") return } diff --git a/internal/handlers/credentials_test.go b/internal/handlers/credentials_test.go new file mode 100644 index 0000000..42215a0 --- /dev/null +++ b/internal/handlers/credentials_test.go @@ -0,0 +1,440 @@ +package handlers + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/go-chi/chi/v5" + "github.com/orchard9/rdev/internal/domain" +) + +// mockCredentialStore implements port.CredentialStore for testing. 
+type mockCredentialStore struct { + creds map[string]domain.Credential + err error +} + +func newMockCredentialStore() *mockCredentialStore { + return &mockCredentialStore{ + creds: make(map[string]domain.Credential), + } +} + +func (m *mockCredentialStore) Get(_ context.Context, key string) (string, error) { + if m.err != nil { + return "", m.err + } + c, ok := m.creds[key] + if !ok { + return "", nil + } + return c.Value, nil +} + +func (m *mockCredentialStore) GetRequired(_ context.Context, key string) (string, error) { + if m.err != nil { + return "", m.err + } + c, ok := m.creds[key] + if !ok { + return "", fmt.Errorf("credential not found: %s", key) + } + return c.Value, nil +} + +func (m *mockCredentialStore) Set(_ context.Context, cred domain.Credential) error { + if m.err != nil { + return m.err + } + cred.CreatedAt = time.Now() + cred.UpdatedAt = time.Now() + m.creds[cred.Key] = cred + return nil +} + +func (m *mockCredentialStore) Delete(_ context.Context, key string) error { + if m.err != nil { + return m.err + } + if _, ok := m.creds[key]; !ok { + return domain.ErrCredentialNotFound + } + delete(m.creds, key) + return nil +} + +func (m *mockCredentialStore) List(_ context.Context) ([]domain.Credential, error) { + if m.err != nil { + return nil, m.err + } + var result []domain.Credential + for _, c := range m.creds { + result = append(result, c) + } + return result, nil +} + +func (m *mockCredentialStore) ListByCategory(_ context.Context, category string) ([]domain.Credential, error) { + if m.err != nil { + return nil, m.err + } + var result []domain.Credential + for _, c := range m.creds { + if c.Category == category { + result = append(result, c) + } + } + return result, nil +} + +func (m *mockCredentialStore) GetMultiple(_ context.Context, keys []string) (map[string]string, error) { + if m.err != nil { + return nil, m.err + } + result := make(map[string]string) + for _, k := range keys { + if c, ok := m.creds[k]; ok { + result[k] = c.Value + } + } + 
return result, nil +} + +func (m *mockCredentialStore) SetMultiple(_ context.Context, creds []domain.Credential) error { + if m.err != nil { + return m.err + } + for _, c := range creds { + c.CreatedAt = time.Now() + c.UpdatedAt = time.Now() + m.creds[c.Key] = c + } + return nil +} + +func setupCredentialsHandler() (*CredentialsHandler, *mockCredentialStore, chi.Router) { + store := newMockCredentialStore() + h := NewCredentialsHandler(store) + r := chi.NewRouter() + h.Mount(r) + return h, store, r +} + +func TestCredentialsHandler_List(t *testing.T) { + t.Run("empty list", func(t *testing.T) { + _, _, router := setupCredentialsHandler() + req := httptest.NewRequest("GET", "/credentials", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Errorf("status = %d, want %d", rec.Code, http.StatusOK) + } + + var resp map[string]any + if err := json.NewDecoder(rec.Body).Decode(&resp); err != nil { + t.Fatalf("failed to decode response: %v", err) + } + data, ok := resp["data"].([]any) + if !ok { + t.Fatalf("response missing data array") + } + if len(data) != 0 { + t.Errorf("data length = %d, want 0", len(data)) + } + }) + + t.Run("with credentials", func(t *testing.T) { + _, store, router := setupCredentialsHandler() + store.creds["MY_TOKEN"] = domain.Credential{ + Key: "MY_TOKEN", + Value: "****", + Category: "gitea", + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } + + req := httptest.NewRequest("GET", "/credentials", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Errorf("status = %d, want %d", rec.Code, http.StatusOK) + } + + var resp map[string]any + if err := json.NewDecoder(rec.Body).Decode(&resp); err != nil { + t.Fatalf("failed to decode response: %v", err) + } + data, ok := resp["data"].([]any) + if !ok { + t.Fatalf("response missing data array") + } + if len(data) != 1 { + t.Errorf("data length = %d, want 1", len(data)) + } + }) + + t.Run("filter by 
category", func(t *testing.T) { + _, store, router := setupCredentialsHandler() + store.creds["GITEA_TOKEN"] = domain.Credential{ + Key: "GITEA_TOKEN", Value: "****", Category: "gitea", + CreatedAt: time.Now(), UpdatedAt: time.Now(), + } + store.creds["CF_TOKEN"] = domain.Credential{ + Key: "CF_TOKEN", Value: "****", Category: "cloudflare", + CreatedAt: time.Now(), UpdatedAt: time.Now(), + } + + req := httptest.NewRequest("GET", "/credentials?category=gitea", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Errorf("status = %d, want %d", rec.Code, http.StatusOK) + } + + var resp map[string]any + if err := json.NewDecoder(rec.Body).Decode(&resp); err != nil { + t.Fatalf("failed to decode response: %v", err) + } + data := resp["data"].([]any) + if len(data) != 1 { + t.Errorf("data length = %d, want 1", len(data)) + } + }) +} + +func TestCredentialsHandler_Get(t *testing.T) { + t.Run("existing credential", func(t *testing.T) { + _, store, router := setupCredentialsHandler() + store.creds["MY_TOKEN"] = domain.Credential{ + Key: "MY_TOKEN", Value: "secret123", + CreatedAt: time.Now(), UpdatedAt: time.Now(), + } + + req := httptest.NewRequest("GET", "/credentials/MY_TOKEN", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Errorf("status = %d, want %d", rec.Code, http.StatusOK) + } + + var resp map[string]any + if err := json.NewDecoder(rec.Body).Decode(&resp); err != nil { + t.Fatalf("failed to decode response: %v", err) + } + data := resp["data"].(map[string]any) + if data["key"] != "MY_TOKEN" { + t.Errorf("key = %q, want %q", data["key"], "MY_TOKEN") + } + if data["value"] != "secret123" { + t.Errorf("value = %q, want %q", data["value"], "secret123") + } + }) + + t.Run("not found", func(t *testing.T) { + _, _, router := setupCredentialsHandler() + + req := httptest.NewRequest("GET", "/credentials/MISSING", nil) + rec := httptest.NewRecorder() + 
router.ServeHTTP(rec, req) + + if rec.Code != http.StatusNotFound { + t.Errorf("status = %d, want %d", rec.Code, http.StatusNotFound) + } + }) +} + +func TestCredentialsHandler_Set(t *testing.T) { + t.Run("valid credential", func(t *testing.T) { + _, store, router := setupCredentialsHandler() + + body, _ := json.Marshal(SetCredentialRequest{ + Key: "NEW_TOKEN", + Value: "secret", + Description: "A test token", + Category: "gitea", + }) + req := httptest.NewRequest("POST", "/credentials", bytes.NewReader(body)) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusCreated { + t.Errorf("status = %d, want %d", rec.Code, http.StatusCreated) + } + + if _, ok := store.creds["NEW_TOKEN"]; !ok { + t.Error("credential not stored") + } + }) + + t.Run("missing key", func(t *testing.T) { + _, _, router := setupCredentialsHandler() + + body, _ := json.Marshal(SetCredentialRequest{Value: "secret"}) + req := httptest.NewRequest("POST", "/credentials", bytes.NewReader(body)) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusBadRequest { + t.Errorf("status = %d, want %d", rec.Code, http.StatusBadRequest) + } + }) + + t.Run("missing value", func(t *testing.T) { + _, _, router := setupCredentialsHandler() + + body, _ := json.Marshal(SetCredentialRequest{Key: "TOKEN"}) + req := httptest.NewRequest("POST", "/credentials", bytes.NewReader(body)) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusBadRequest { + t.Errorf("status = %d, want %d", rec.Code, http.StatusBadRequest) + } + }) + + t.Run("invalid json", func(t *testing.T) { + _, _, router := setupCredentialsHandler() + + req := httptest.NewRequest("POST", "/credentials", bytes.NewReader([]byte("not json"))) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusBadRequest { + t.Errorf("status = %d, want %d", rec.Code, http.StatusBadRequest) + } + }) +} + +func 
TestCredentialsHandler_SetBatch(t *testing.T) { + t.Run("valid batch", func(t *testing.T) { + _, store, router := setupCredentialsHandler() + + body, _ := json.Marshal(SetBatchRequest{ + Credentials: []SetCredentialRequest{ + {Key: "TOKEN1", Value: "val1"}, + {Key: "TOKEN2", Value: "val2"}, + }, + }) + req := httptest.NewRequest("POST", "/credentials/batch", bytes.NewReader(body)) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusCreated { + t.Errorf("status = %d, want %d", rec.Code, http.StatusCreated) + } + + if len(store.creds) != 2 { + t.Errorf("stored credentials = %d, want 2", len(store.creds)) + } + }) + + t.Run("empty array", func(t *testing.T) { + _, _, router := setupCredentialsHandler() + + body, _ := json.Marshal(SetBatchRequest{Credentials: []SetCredentialRequest{}}) + req := httptest.NewRequest("POST", "/credentials/batch", bytes.NewReader(body)) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusBadRequest { + t.Errorf("status = %d, want %d", rec.Code, http.StatusBadRequest) + } + }) + + t.Run("missing key in batch", func(t *testing.T) { + _, _, router := setupCredentialsHandler() + + body, _ := json.Marshal(SetBatchRequest{ + Credentials: []SetCredentialRequest{ + {Key: "TOKEN1", Value: "val1"}, + {Key: "", Value: "val2"}, + }, + }) + req := httptest.NewRequest("POST", "/credentials/batch", bytes.NewReader(body)) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusBadRequest { + t.Errorf("status = %d, want %d", rec.Code, http.StatusBadRequest) + } + }) + + t.Run("missing value in batch", func(t *testing.T) { + _, _, router := setupCredentialsHandler() + + body, _ := json.Marshal(SetBatchRequest{ + Credentials: []SetCredentialRequest{ + {Key: "TOKEN1", Value: ""}, + }, + }) + req := httptest.NewRequest("POST", "/credentials/batch", bytes.NewReader(body)) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + if rec.Code != 
http.StatusBadRequest { + t.Errorf("status = %d, want %d", rec.Code, http.StatusBadRequest) + } + }) +} + +func TestCredentialsHandler_Delete(t *testing.T) { + t.Run("existing credential", func(t *testing.T) { + _, store, router := setupCredentialsHandler() + store.creds["TO_DELETE"] = domain.Credential{ + Key: "TO_DELETE", Value: "val", + CreatedAt: time.Now(), UpdatedAt: time.Now(), + } + + req := httptest.NewRequest("DELETE", "/credentials/TO_DELETE", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Errorf("status = %d, want %d", rec.Code, http.StatusOK) + } + + var resp map[string]any + if err := json.NewDecoder(rec.Body).Decode(&resp); err != nil { + t.Fatalf("failed to decode response: %v", err) + } + data := resp["data"].(map[string]any) + if data["status"] != "deleted" { + t.Errorf("status = %q, want %q", data["status"], "deleted") + } + }) + + t.Run("not found returns 404", func(t *testing.T) { + _, _, router := setupCredentialsHandler() + + req := httptest.NewRequest("DELETE", "/credentials/NONEXISTENT", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusNotFound { + t.Errorf("status = %d, want %d", rec.Code, http.StatusNotFound) + } + }) + + t.Run("store error returns 500", func(t *testing.T) { + _, store, router := setupCredentialsHandler() + store.err = fmt.Errorf("database connection lost") + + req := httptest.NewRequest("DELETE", "/credentials/ANY_KEY", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusInternalServerError { + t.Errorf("status = %d, want %d", rec.Code, http.StatusInternalServerError) + } + }) +} diff --git a/internal/handlers/health.go b/internal/handlers/health.go index def81ff..4160462 100644 --- a/internal/handlers/health.go +++ b/internal/handlers/health.go @@ -3,31 +3,51 @@ package handlers import ( "context" - "database/sql" + "fmt" "net/http" "strings" "time" + 
"github.com/orchard9/rdev/internal/port" "github.com/orchard9/rdev/pkg/api" - k8sclient "k8s.io/client-go/kubernetes" ) +// ExecutorHealthChecker reports whether a background executor is running. +type ExecutorHealthChecker interface { + Running() bool + WorkerID() string +} + // HealthHandler handles health and readiness checks. type HealthHandler struct { - serviceName string - db *sql.DB - k8sClient *k8sclient.Clientset + serviceName string + db port.DatabasePinger + k8sChecker port.KubernetesChecker + agentRegistry port.CodeAgentRegistry + workExecutor ExecutorHealthChecker } // NewHealthHandler creates a new health handler with dependencies. -func NewHealthHandler(serviceName string, db *sql.DB, k8sClient *k8sclient.Clientset) *HealthHandler { +func NewHealthHandler(serviceName string, db port.DatabasePinger, k8sChecker port.KubernetesChecker) *HealthHandler { return &HealthHandler{ serviceName: serviceName, db: db, - k8sClient: k8sClient, + k8sChecker: k8sChecker, } } +// WithAgentRegistry adds a code agent registry for health monitoring. +func (h *HealthHandler) WithAgentRegistry(registry port.CodeAgentRegistry) *HealthHandler { + h.agentRegistry = registry + return h +} + +// WithWorkExecutor adds a work executor for health monitoring. +func (h *HealthHandler) WithWorkExecutor(executor ExecutorHealthChecker) *HealthHandler { + h.workExecutor = executor + return h +} + // Health returns a simple liveness check. // This should be lightweight and only fail if the process is unhealthy. 
// GET /health @@ -59,7 +79,7 @@ func (h *HealthHandler) Ready(w http.ResponseWriter, r *http.Request) { } // Kubernetes check - if h.k8sClient != nil { + if h.k8sChecker != nil { k8sCheck := h.checkKubernetes(ctx) checks["kubernetes"] = k8sCheck if !k8sCheck.Healthy { @@ -67,6 +87,19 @@ func (h *HealthHandler) Ready(w http.ResponseWriter, r *http.Request) { } } + // Code agent checks (informational - don't affect overall readiness) + if h.agentRegistry != nil { + agentChecks := h.checkCodeAgents(ctx) + for name, check := range agentChecks { + checks["agent:"+name] = check + } + } + + // Work executor check (informational) + if h.workExecutor != nil { + checks["work_executor"] = h.checkWorkExecutor() + } + response := ReadinessResponse{ Status: "ready", Service: h.serviceName, @@ -107,11 +140,11 @@ func (h *HealthHandler) checkDatabase(ctx context.Context) CheckResult { } // checkKubernetes performs a Kubernetes API health check. -func (h *HealthHandler) checkKubernetes(ctx context.Context) CheckResult { +func (h *HealthHandler) checkKubernetes(_ context.Context) CheckResult { start := time.Now() // Try to get server version - lightweight API call - _, err := h.k8sClient.Discovery().ServerVersion() + _, err := h.k8sChecker.ServerVersion() latency := time.Since(start) if err != nil { @@ -139,6 +172,51 @@ func (h *HealthHandler) checkKubernetes(ctx context.Context) CheckResult { } } +// checkCodeAgents performs health checks on all registered code agents. 
+func (h *HealthHandler) checkCodeAgents(ctx context.Context) map[string]CheckResult { + results := make(map[string]CheckResult) + + providers := h.agentRegistry.Available() + for _, provider := range providers { + agent := h.agentRegistry.Get(provider) + if agent == nil { + continue + } + + start := time.Now() + available := agent.Available(ctx) + latency := time.Since(start) + + msg := "available" + if !available { + msg = "unavailable" + } + + results[string(provider)] = CheckResult{ + Healthy: available, + Message: fmt.Sprintf("%s (%s)", msg, agent.Name()), + Latency: latency.String(), + LastCheck: time.Now().UTC(), + } + } + + return results +} + +// checkWorkExecutor checks whether the work executor is running. +func (h *HealthHandler) checkWorkExecutor() CheckResult { + running := h.workExecutor.Running() + msg := fmt.Sprintf("worker %s: running", h.workExecutor.WorkerID()) + if !running { + msg = fmt.Sprintf("worker %s: stopped", h.workExecutor.WorkerID()) + } + return CheckResult{ + Healthy: running, + Message: msg, + LastCheck: time.Now().UTC(), + } +} + // CheckResult represents the result of a health check. type CheckResult struct { Healthy bool `json:"healthy"` diff --git a/internal/handlers/infrastructure.go b/internal/handlers/infrastructure.go index 5370aa2..d80f2d2 100644 --- a/internal/handlers/infrastructure.go +++ b/internal/handlers/infrastructure.go @@ -4,10 +4,8 @@ package handlers import ( "context" "encoding/json" - "errors" "fmt" "net/http" - "regexp" "time" "github.com/go-chi/chi/v5" @@ -16,12 +14,13 @@ import ( "github.com/orchard9/rdev/pkg/api" ) -// InfrastructureHandler handles git, deployment, and DNS endpoints. +// InfrastructureHandler handles git, deployment, DNS, and CI pipeline endpoints. 
type InfrastructureHandler struct { - gitRepo port.GitRepository - dns port.DNSProvider - deployer port.Deployer - projects port.ProjectRepository + gitRepo port.GitRepository + dns port.DNSProvider + deployer port.Deployer + projects port.ProjectRepository + ciProvider port.CIProvider // Config defaultGitOwner string @@ -29,21 +28,9 @@ type InfrastructureHandler struct { clusterIP string } -// projectIDRegex validates project IDs (alphanumeric, dash, underscore only). -var projectIDRegex = regexp.MustCompile(`^[a-zA-Z][a-zA-Z0-9_-]*$`) - // validateProjectID validates that a project ID is safe for use as repo/deployment name. func validateProjectID(id string) error { - if id == "" { - return errors.New("project ID cannot be empty") - } - if len(id) > 63 { // K8s name limit - return errors.New("project ID too long (max 63 characters)") - } - if !projectIDRegex.MatchString(id) { - return errors.New("project ID must start with a letter and contain only alphanumeric characters, dashes, or underscores") - } - return nil + return domain.ValidateProjectID(id) } // InfrastructureConfig configures the infrastructure handler. 
@@ -62,6 +49,7 @@ func NewInfrastructureHandler( dns port.DNSProvider, deployer port.Deployer, projects port.ProjectRepository, + ciProvider port.CIProvider, cfg InfrastructureConfig, ) *InfrastructureHandler { return &InfrastructureHandler{ @@ -69,6 +57,7 @@ func NewInfrastructureHandler( dns: dns, deployer: deployer, projects: projects, + ciProvider: ciProvider, defaultGitOwner: cfg.DefaultGitOwner, defaultDomain: cfg.DefaultDomain, clusterIP: cfg.ClusterIP, @@ -90,9 +79,18 @@ func (h *InfrastructureHandler) Mount(r api.Router) { r.Post("/projects/{id}/deploy/scale", h.ScaleDeploy) r.Get("/projects/{id}/deploy/logs", h.GetDeployLogs) - // Domain endpoints + // Domain endpoints (single) r.Post("/projects/{id}/domain", h.AddDomain) r.Delete("/projects/{id}/domain", h.RemoveDomain) + + // Domain alias management (multi-domain) + r.Get("/projects/{id}/domains", h.ListDomains) + r.Post("/projects/{id}/domains", h.AddDomainAlias) + r.Delete("/projects/{id}/domains/{domain}", h.RemoveDomainAlias) + + // CI pipeline endpoints + r.Get("/projects/{id}/pipelines", h.ListPipelines) + r.Get("/projects/{id}/pipelines/{number}", h.GetPipeline) } // CreateRepoRequest is the request body for POST /projects/{id}/repo. diff --git a/internal/handlers/infrastructure_deploy.go b/internal/handlers/infrastructure_deploy.go index e703a6b..66c2c4d 100644 --- a/internal/handlers/infrastructure_deploy.go +++ b/internal/handlers/infrastructure_deploy.go @@ -12,6 +12,9 @@ import ( "github.com/orchard9/rdev/pkg/api" ) +// maxReplicas is the maximum number of deployment replicas allowed. +const maxReplicas = 10 + // DeployRequest is the request body for POST /projects/{id}/deploy. 
type DeployRequest struct { Image string `json:"image"` // Container image @@ -218,8 +221,8 @@ func (h *InfrastructureHandler) ScaleDeploy(w http.ResponseWriter, r *http.Reque return } - if req.Replicas < 0 || req.Replicas > 10 { - api.WriteBadRequest(w, r, "replicas must be between 0 and 10") + if req.Replicas < 0 || req.Replicas > maxReplicas { + api.WriteBadRequest(w, r, fmt.Sprintf("replicas must be between 0 and %d", maxReplicas)) return } diff --git a/internal/handlers/infrastructure_domain_test.go b/internal/handlers/infrastructure_domain_test.go new file mode 100644 index 0000000..3761342 --- /dev/null +++ b/internal/handlers/infrastructure_domain_test.go @@ -0,0 +1,195 @@ +package handlers + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/orchard9/rdev/internal/domain" +) + +func TestInfrastructureHandler_RestartDeploy(t *testing.T) { + t.Run("success", func(t *testing.T) { + _, _, _, _, router := setupInfraHandler() + + req := httptest.NewRequest("POST", "/projects/myapp/deploy/restart", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Errorf("status = %d, want %d", rec.Code, http.StatusOK) + } + }) +} + +func TestInfrastructureHandler_ScaleDeploy(t *testing.T) { + t.Run("valid scale", func(t *testing.T) { + _, _, _, _, router := setupInfraHandler() + + body, _ := json.Marshal(ScaleRequest{Replicas: 3}) + req := httptest.NewRequest("POST", "/projects/myapp/deploy/scale", bytes.NewReader(body)) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Errorf("status = %d, want %d", rec.Code, http.StatusOK) + } + }) + + t.Run("invalid replicas too high", func(t *testing.T) { + _, _, _, _, router := setupInfraHandler() + + body, _ := json.Marshal(ScaleRequest{Replicas: 11}) + req := httptest.NewRequest("POST", "/projects/myapp/deploy/scale", bytes.NewReader(body)) + rec := httptest.NewRecorder() + 
router.ServeHTTP(rec, req) + + if rec.Code != http.StatusBadRequest { + t.Errorf("status = %d, want %d", rec.Code, http.StatusBadRequest) + } + }) + + t.Run("invalid replicas negative", func(t *testing.T) { + _, _, _, _, router := setupInfraHandler() + + body, _ := json.Marshal(ScaleRequest{Replicas: -1}) + req := httptest.NewRequest("POST", "/projects/myapp/deploy/scale", bytes.NewReader(body)) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusBadRequest { + t.Errorf("status = %d, want %d", rec.Code, http.StatusBadRequest) + } + }) +} + +func TestInfrastructureHandler_GetDeployLogs(t *testing.T) { + t.Run("success", func(t *testing.T) { + _, _, _, deployer, router := setupInfraHandler() + deployer.logs = "line1\nline2\nline3\n" + + req := httptest.NewRequest("GET", "/projects/myapp/deploy/logs", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Errorf("status = %d, want %d", rec.Code, http.StatusOK) + } + }) +} + +func TestInfrastructureHandler_AddDomain(t *testing.T) { + t.Run("subdomain", func(t *testing.T) { + _, _, dns, _, router := setupInfraHandler() + + body, _ := json.Marshal(AddDomainRequest{Domain: "myapp.threesix.ai"}) + req := httptest.NewRequest("POST", "/projects/myapp/domain", bytes.NewReader(body)) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusCreated { + t.Errorf("status = %d, want %d", rec.Code, http.StatusCreated) + } + if len(dns.records) != 1 { + t.Errorf("DNS records = %d, want 1", len(dns.records)) + } + }) + + t.Run("external domain", func(t *testing.T) { + _, _, dns, _, router := setupInfraHandler() + + body, _ := json.Marshal(AddDomainRequest{Domain: "myapp.example.com"}) + req := httptest.NewRequest("POST", "/projects/myapp/domain", bytes.NewReader(body)) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusCreated { + t.Errorf("status = %d, want %d", rec.Code, 
http.StatusCreated) + } + // External domain should NOT create DNS record + if len(dns.records) != 0 { + t.Errorf("DNS records = %d, want 0 (external domain)", len(dns.records)) + } + }) + + t.Run("missing domain", func(t *testing.T) { + _, _, _, _, router := setupInfraHandler() + + body, _ := json.Marshal(AddDomainRequest{}) + req := httptest.NewRequest("POST", "/projects/myapp/domain", bytes.NewReader(body)) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusBadRequest { + t.Errorf("status = %d, want %d", rec.Code, http.StatusBadRequest) + } + }) +} + +func TestInfrastructureHandler_RemoveDomain(t *testing.T) { + t.Run("subdomain", func(t *testing.T) { + _, _, dns, _, router := setupInfraHandler() + dns.records["myapp"] = &domain.DNSRecord{ID: "rec-myapp", Name: "myapp"} + + req := httptest.NewRequest("DELETE", "/projects/myapp/domain?domain=myapp.threesix.ai", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Errorf("status = %d, want %d", rec.Code, http.StatusOK) + } + }) + + t.Run("missing domain param", func(t *testing.T) { + _, _, _, _, router := setupInfraHandler() + + req := httptest.NewRequest("DELETE", "/projects/myapp/domain", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusBadRequest { + t.Errorf("status = %d, want %d", rec.Code, http.StatusBadRequest) + } + }) +} + +func TestIsSubdomain(t *testing.T) { + tests := []struct { + domain, base string + want bool + }{ + {"myapp.threesix.ai", "threesix.ai", true}, + {"deep.sub.threesix.ai", "threesix.ai", true}, + {"threesix.ai", "threesix.ai", false}, + {"myapp.example.com", "threesix.ai", false}, + {"", "threesix.ai", false}, + } + for _, tt := range tests { + t.Run(tt.domain, func(t *testing.T) { + got := isSubdomain(tt.domain, tt.base) + if got != tt.want { + t.Errorf("isSubdomain(%q, %q) = %v, want %v", tt.domain, tt.base, got, tt.want) + } + }) + } +} + +func 
TestGetSubdomain(t *testing.T) { + tests := []struct { + domain, base, want string + }{ + {"myapp.threesix.ai", "threesix.ai", "myapp"}, + {"deep.sub.threesix.ai", "threesix.ai", "deep.sub"}, + {"threesix.ai", "threesix.ai", "threesix.ai"}, + } + for _, tt := range tests { + t.Run(tt.domain, func(t *testing.T) { + got := getSubdomain(tt.domain, tt.base) + if got != tt.want { + t.Errorf("getSubdomain(%q, %q) = %q, want %q", tt.domain, tt.base, got, tt.want) + } + }) + } +} diff --git a/internal/handlers/infrastructure_domains.go b/internal/handlers/infrastructure_domains.go new file mode 100644 index 0000000..9dafab5 --- /dev/null +++ b/internal/handlers/infrastructure_domains.go @@ -0,0 +1,260 @@ +package handlers + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "strings" + "time" + + "github.com/go-chi/chi/v5" + "github.com/orchard9/rdev/internal/domain" + "github.com/orchard9/rdev/pkg/api" +) + +// DomainAliasRequest is the request body for POST /projects/{id}/domains. +type DomainAliasRequest struct { + Domain string `json:"domain"` // The domain to add (e.g., "www.threesix.ai") + Type string `json:"type,omitempty"` // "A" or "CNAME" (default: "A") + Proxied bool `json:"proxied,omitempty"` // Cloudflare proxy (default: false) + Content string `json:"content,omitempty"` // Target (default: cluster IP for A records) +} + +// DomainAliasResponse is the response for domain alias operations. +type DomainAliasResponse struct { + Domain string `json:"domain"` + Type string `json:"type"` + Content string `json:"content"` + TTL int `json:"ttl"` + Proxied bool `json:"proxied"` +} + +// ListDomains returns all DNS records associated with a project. 
+// GET /projects/{id}/domains +func (h *InfrastructureHandler) ListDomains(w http.ResponseWriter, r *http.Request) { + projectID := chi.URLParam(r, "id") + ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second) + defer cancel() + + if err := validateProjectID(projectID); err != nil { + api.WriteBadRequest(w, r, err.Error()) + return + } + + if h.dns == nil { + api.WriteInternalError(w, r, "DNS provider not configured") + return + } + + // List all A records and find ones matching this project + aRecords, err := h.dns.ListRecords(ctx, "A") + if err != nil { + api.WriteInternalError(w, r, "failed to list DNS records") + return + } + + // Also list CNAME records + cnameRecords, err := h.dns.ListRecords(ctx, "CNAME") + if err != nil { + api.WriteInternalError(w, r, "failed to list DNS records") + return + } + + // Filter records that belong to this project: + // - Primary: {projectID}.{defaultDomain} + // - Aliases: any record pointing to the cluster IP or the project's primary domain + primaryDomain := projectID + "." + h.defaultDomain + var domains []DomainAliasResponse + + for _, rec := range aRecords { + name := rec.Name + // Normalize: if name matches the project's subdomain or points to our cluster IP + if name == primaryDomain || (rec.Content == h.clusterIP && isProjectDomain(name, projectID, h.defaultDomain)) { + domains = append(domains, DomainAliasResponse{ + Domain: name, + Type: rec.Type, + Content: rec.Content, + TTL: rec.TTL, + Proxied: rec.Proxied, + }) + } + } + + for _, rec := range cnameRecords { + // CNAME records pointing to the project's primary domain + if rec.Content == primaryDomain { + domains = append(domains, DomainAliasResponse{ + Domain: rec.Name, + Type: rec.Type, + Content: rec.Content, + TTL: rec.TTL, + Proxied: rec.Proxied, + }) + } + } + + api.WriteSuccess(w, r, map[string]any{ + "project_id": projectID, + "domains": domains, + "total": len(domains), + }) +} + +// AddDomainAlias adds a DNS alias for a project. 
+// POST /projects/{id}/domains +func (h *InfrastructureHandler) AddDomainAlias(w http.ResponseWriter, r *http.Request) { + projectID := chi.URLParam(r, "id") + ctx, cancel := context.WithTimeout(r.Context(), 30*time.Second) + defer cancel() + + if err := validateProjectID(projectID); err != nil { + api.WriteBadRequest(w, r, err.Error()) + return + } + + if h.dns == nil { + api.WriteInternalError(w, r, "DNS provider not configured") + return + } + + var req DomainAliasRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + api.WriteBadRequest(w, r, "invalid request body") + return + } + + if req.Domain == "" { + api.WriteBadRequest(w, r, "domain is required") + return + } + + // Default record type is A + recordType := strings.ToUpper(req.Type) + if recordType == "" { + recordType = "A" + } + if recordType != "A" && recordType != "CNAME" { + api.WriteBadRequest(w, r, "type must be A or CNAME") + return + } + + // Determine content + content := req.Content + if content == "" { + switch recordType { + case "A": + if h.clusterIP == "" { + api.WriteBadRequest(w, r, "cluster IP not configured and no content provided") + return + } + content = h.clusterIP + case "CNAME": + // Default CNAME target is the project's primary domain + content = projectID + "." + h.defaultDomain + } + } + + // Determine the DNS name + // If domain is a full FQDN under our zone, extract the subdomain for the API call + dnsName := req.Domain + if isSubdomain(req.Domain, h.defaultDomain) { + dnsName = getSubdomain(req.Domain, h.defaultDomain) + } + + record, err := h.dns.CreateRecord(ctx, domain.DNSRecord{ + Type: recordType, + Name: dnsName, + Content: content, + TTL: 1, + Proxied: req.Proxied, + }) + if err != nil { + api.WriteInternalError(w, r, fmt.Sprintf("failed to create DNS record: %v", err)) + return + } + + note := "Domain alias configured" + if !isSubdomain(req.Domain, h.defaultDomain) && recordType == "A" { + note = fmt.Sprintf("External domain configured. 
Point your DNS to %s", h.clusterIP) + } + + api.WriteCreated(w, r, map[string]any{ + "project": projectID, + "domain": record.Name, + "type": record.Type, + "content": record.Content, + "status": "configured", + "note": note, + }) +} + +// RemoveDomainAlias removes a DNS alias from a project. +// DELETE /projects/{id}/domains/{domain} +func (h *InfrastructureHandler) RemoveDomainAlias(w http.ResponseWriter, r *http.Request) { + projectID := chi.URLParam(r, "id") + aliasDomain := chi.URLParam(r, "domain") + ctx, cancel := context.WithTimeout(r.Context(), 30*time.Second) + defer cancel() + + if err := validateProjectID(projectID); err != nil { + api.WriteBadRequest(w, r, err.Error()) + return + } + + if aliasDomain == "" { + api.WriteBadRequest(w, r, "domain path parameter is required") + return + } + + if h.dns == nil { + api.WriteInternalError(w, r, "DNS provider not configured") + return + } + + // Prevent deleting the project's primary domain through this endpoint + primaryDomain := projectID + "." 
+ h.defaultDomain + if aliasDomain == primaryDomain { + api.WriteBadRequest(w, r, "cannot remove primary project domain through alias endpoint; use DELETE /project/{name} instead") + return + } + + // Check if the record exists before attempting deletion + dnsName := aliasDomain + if isSubdomain(aliasDomain, h.defaultDomain) { + dnsName = getSubdomain(aliasDomain, h.defaultDomain) + } + + // Check both A and CNAME records + aRecord, _ := h.dns.FindRecord(ctx, "A", dnsName) + cnameRecord, _ := h.dns.FindRecord(ctx, "CNAME", dnsName) + + if aRecord == nil && cnameRecord == nil { + api.WriteNotFound(w, r, fmt.Sprintf("no DNS record found for %s", aliasDomain)) + return + } + + // Delete whichever record(s) exist + if aRecord != nil { + _ = h.dns.DeleteRecordByName(ctx, "A", dnsName) + } + if cnameRecord != nil { + _ = h.dns.DeleteRecordByName(ctx, "CNAME", dnsName) + } + + api.WriteSuccess(w, r, map[string]string{ + "project": projectID, + "domain": aliasDomain, + "status": "removed", + }) +} + +// isProjectDomain checks if a DNS name is associated with a project. +// It matches: {projectID}.{baseDomain} or any subdomain pattern containing the project ID. 
// isProjectDomain reports whether name is the primary domain of the project,
// i.e. exactly "{projectID}.{baseDomain}". Names that merely contain the
// project ID, or that sit deeper or shallower in the hierarchy, do not match.
func isProjectDomain(name, projectID, baseDomain string) bool {
	return name == projectID+"."+baseDomain
}
= &domain.DNSRecord{ + ID: "rec-2", Type: "CNAME", Name: "www.threesix.ai", + Content: "landing.threesix.ai", TTL: 1, + } + + req := httptest.NewRequest("GET", "/projects/landing/domains", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK) + } + + var resp map[string]any + if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil { + t.Fatalf("unmarshal: %v", err) + } + data := resp["data"].(map[string]any) + total := int(data["total"].(float64)) + if total != 2 { + t.Errorf("total = %d, want 2 (A + CNAME)", total) + } + }) + + t.Run("DNS not configured", func(t *testing.T) { + h := NewInfrastructureHandler(nil, nil, nil, nil, nil, InfrastructureConfig{ + DefaultGitOwner: "threesix", + DefaultDomain: "threesix.ai", + ClusterIP: "208.122.204.172", + }) + r := chi.NewRouter() + h.Mount(r) + + req := httptest.NewRequest("GET", "/projects/myapp/domains", nil) + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + + if rec.Code != http.StatusInternalServerError { + t.Errorf("status = %d, want %d", rec.Code, http.StatusInternalServerError) + } + }) +} + +func TestInfrastructureHandler_AddDomainAlias(t *testing.T) { + t.Run("add A record alias", func(t *testing.T) { + _, _, dns, _, router := setupInfraHandler() + + body, _ := json.Marshal(DomainAliasRequest{Domain: "www.threesix.ai"}) + req := httptest.NewRequest("POST", "/projects/landing/domains", bytes.NewReader(body)) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusCreated { + t.Fatalf("status = %d, want %d; body: %s", rec.Code, http.StatusCreated, rec.Body.String()) + } + if len(dns.records) != 1 { + t.Errorf("DNS records = %d, want 1", len(dns.records)) + } + }) + + t.Run("add CNAME alias", func(t *testing.T) { + _, _, dns, _, router := setupInfraHandler() + + body, _ := json.Marshal(DomainAliasRequest{ + Domain: "www.threesix.ai", + Type: "CNAME", + }) + req 
:= httptest.NewRequest("POST", "/projects/landing/domains", bytes.NewReader(body)) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusCreated { + t.Fatalf("status = %d, want %d; body: %s", rec.Code, http.StatusCreated, rec.Body.String()) + } + // CNAME should target landing.threesix.ai + for _, r := range dns.records { + if r.Type != "CNAME" { + t.Errorf("type = %s, want CNAME", r.Type) + } + if r.Content != "landing.threesix.ai" { + t.Errorf("content = %s, want landing.threesix.ai", r.Content) + } + } + }) + + t.Run("invalid type", func(t *testing.T) { + _, _, _, _, router := setupInfraHandler() + + body, _ := json.Marshal(DomainAliasRequest{Domain: "www.threesix.ai", Type: "MX"}) + req := httptest.NewRequest("POST", "/projects/landing/domains", bytes.NewReader(body)) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusBadRequest { + t.Errorf("status = %d, want %d", rec.Code, http.StatusBadRequest) + } + }) + + t.Run("missing domain", func(t *testing.T) { + _, _, _, _, router := setupInfraHandler() + + body, _ := json.Marshal(DomainAliasRequest{}) + req := httptest.NewRequest("POST", "/projects/landing/domains", bytes.NewReader(body)) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusBadRequest { + t.Errorf("status = %d, want %d", rec.Code, http.StatusBadRequest) + } + }) + + t.Run("DNS not configured", func(t *testing.T) { + h := NewInfrastructureHandler(nil, nil, nil, nil, nil, InfrastructureConfig{ + DefaultGitOwner: "threesix", + DefaultDomain: "threesix.ai", + ClusterIP: "208.122.204.172", + }) + r := chi.NewRouter() + h.Mount(r) + + body, _ := json.Marshal(DomainAliasRequest{Domain: "www.threesix.ai"}) + req := httptest.NewRequest("POST", "/projects/landing/domains", bytes.NewReader(body)) + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + + if rec.Code != http.StatusInternalServerError { + t.Errorf("status = %d, want %d", rec.Code, 
http.StatusInternalServerError) + } + }) +} + +func TestInfrastructureHandler_RemoveDomainAlias(t *testing.T) { + t.Run("removes alias", func(t *testing.T) { + _, _, dns, _, router := setupInfraHandler() + dns.records["www"] = &domain.DNSRecord{ + ID: "rec-www", Type: "A", Name: "www", + Content: "208.122.204.172", + } + + req := httptest.NewRequest("DELETE", "/projects/landing/domains/www.threesix.ai", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d; body: %s", rec.Code, http.StatusOK, rec.Body.String()) + } + }) + + t.Run("prevents removing primary domain", func(t *testing.T) { + _, _, _, _, router := setupInfraHandler() + + req := httptest.NewRequest("DELETE", "/projects/landing/domains/landing.threesix.ai", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusBadRequest { + t.Errorf("status = %d, want %d", rec.Code, http.StatusBadRequest) + } + }) + + t.Run("not found", func(t *testing.T) { + _, _, dns, _, router := setupInfraHandler() + dns.err = nil // No records stored + + req := httptest.NewRequest("DELETE", "/projects/landing/domains/nonexistent.threesix.ai", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusNotFound { + t.Errorf("status = %d, want %d; body: %s", rec.Code, http.StatusNotFound, rec.Body.String()) + } + }) +} + +func TestIsProjectDomain(t *testing.T) { + tests := []struct { + name, projectID, baseDomain string + want bool + }{ + {"landing.threesix.ai", "landing", "threesix.ai", true}, + {"other.threesix.ai", "landing", "threesix.ai", false}, + {"landing.example.com", "landing", "threesix.ai", false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := isProjectDomain(tt.name, tt.projectID, tt.baseDomain) + if got != tt.want { + t.Errorf("isProjectDomain(%q, %q, %q) = %v, want %v", + tt.name, tt.projectID, tt.baseDomain, got, tt.want) + } + }) 
+ } +} diff --git a/internal/handlers/infrastructure_pipelines.go b/internal/handlers/infrastructure_pipelines.go new file mode 100644 index 0000000..9e88d42 --- /dev/null +++ b/internal/handlers/infrastructure_pipelines.go @@ -0,0 +1,120 @@ +package handlers + +import ( + "context" + "fmt" + "net/http" + "strconv" + "time" + + "github.com/go-chi/chi/v5" + "github.com/orchard9/rdev/pkg/api" +) + +// PipelineResponse is the JSON representation of a CI pipeline. +type PipelineResponse struct { + ID int64 `json:"id"` + Number int64 `json:"number"` + Status string `json:"status"` + Event string `json:"event"` + Branch string `json:"branch"` + Commit string `json:"commit"` + Message string `json:"message"` + Author string `json:"author"` + Started string `json:"started,omitempty"` + Finished string `json:"finished,omitempty"` +} + +// ListPipelines returns recent CI pipeline executions for a project. +// GET /projects/{id}/pipelines +func (h *InfrastructureHandler) ListPipelines(w http.ResponseWriter, r *http.Request) { + projectID := chi.URLParam(r, "id") + ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second) + defer cancel() + + if err := validateProjectID(projectID); err != nil { + api.WriteBadRequest(w, r, err.Error()) + return + } + + if h.ciProvider == nil { + api.WriteInternalError(w, r, "CI provider not configured") + return + } + + pipelines, err := h.ciProvider.ListPipelines(ctx, h.defaultGitOwner, projectID) + if err != nil { + api.WriteNotFound(w, r, fmt.Sprintf("pipelines not found: %v", err)) + return + } + + resp := make([]PipelineResponse, len(pipelines)) + for i, p := range pipelines { + resp[i] = PipelineResponse{ + ID: p.ID, + Number: p.Number, + Status: p.Status, + Event: p.Event, + Branch: p.Branch, + Commit: p.Commit, + Message: p.Message, + Author: p.Author, + Started: formatTime(p.Started), + Finished: formatTime(p.Finished), + } + } + + api.WriteSuccess(w, r, resp) +} + +// GetPipeline returns a specific CI pipeline execution for a 
project. +// GET /projects/{id}/pipelines/{number} +func (h *InfrastructureHandler) GetPipeline(w http.ResponseWriter, r *http.Request) { + projectID := chi.URLParam(r, "id") + numberStr := chi.URLParam(r, "number") + ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second) + defer cancel() + + if err := validateProjectID(projectID); err != nil { + api.WriteBadRequest(w, r, err.Error()) + return + } + + number, err := strconv.ParseInt(numberStr, 10, 64) + if err != nil { + api.WriteBadRequest(w, r, "invalid pipeline number") + return + } + + if h.ciProvider == nil { + api.WriteInternalError(w, r, "CI provider not configured") + return + } + + p, err := h.ciProvider.GetPipeline(ctx, h.defaultGitOwner, projectID, number) + if err != nil { + api.WriteNotFound(w, r, fmt.Sprintf("pipeline not found: %v", err)) + return + } + + api.WriteSuccess(w, r, PipelineResponse{ + ID: p.ID, + Number: p.Number, + Status: p.Status, + Event: p.Event, + Branch: p.Branch, + Commit: p.Commit, + Message: p.Message, + Author: p.Author, + Started: formatTime(p.Started), + Finished: formatTime(p.Finished), + }) +} + +// formatTime formats a time.Time as RFC3339, returning empty string for zero time. +func formatTime(t time.Time) string { + if t.IsZero() { + return "" + } + return t.Format(time.RFC3339) +} diff --git a/internal/handlers/infrastructure_pipelines_test.go b/internal/handlers/infrastructure_pipelines_test.go new file mode 100644 index 0000000..cc2e5cb --- /dev/null +++ b/internal/handlers/infrastructure_pipelines_test.go @@ -0,0 +1,250 @@ +package handlers + +import ( + "context" + "fmt" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/go-chi/chi/v5" + "github.com/orchard9/rdev/internal/domain" + "github.com/orchard9/rdev/internal/port" +) + +// mockCIProvider implements port.CIProvider for testing. 
+type mockCIProvider struct { + pipelines map[string][]*domain.CIPipeline + err error +} + +func newMockCIProvider() *mockCIProvider { + return &mockCIProvider{pipelines: make(map[string][]*domain.CIPipeline)} +} + +func (m *mockCIProvider) ActivateRepo(context.Context, string, string, string) (*domain.CIRepo, error) { + return nil, m.err +} +func (m *mockCIProvider) DeactivateRepo(context.Context, string, string) error { + return m.err +} +func (m *mockCIProvider) GetRepo(context.Context, string, string) (*domain.CIRepo, error) { + return nil, m.err +} +func (m *mockCIProvider) ListRepos(context.Context) ([]*domain.CIRepo, error) { + return nil, m.err +} +func (m *mockCIProvider) AddSecret(context.Context, string, string, domain.CISecret) error { + return m.err +} +func (m *mockCIProvider) DeleteSecret(context.Context, string, string, string) error { + return m.err +} + +func (m *mockCIProvider) ListPipelines(_ context.Context, owner, repo string) ([]*domain.CIPipeline, error) { + if m.err != nil { + return nil, m.err + } + key := owner + "/" + repo + p, ok := m.pipelines[key] + if !ok { + return nil, fmt.Errorf("repo not found: %s", key) + } + return p, nil +} + +func (m *mockCIProvider) GetPipeline(_ context.Context, owner, repo string, number int64) (*domain.CIPipeline, error) { + if m.err != nil { + return nil, m.err + } + key := owner + "/" + repo + for _, p := range m.pipelines[key] { + if p.Number == number { + return p, nil + } + } + return nil, fmt.Errorf("pipeline %d not found", number) +} + +func setupInfraHandlerWithCI(ci port.CIProvider) chi.Router { + h := NewInfrastructureHandler(nil, nil, nil, nil, ci, InfrastructureConfig{ + DefaultGitOwner: "threesix", + DefaultDomain: "threesix.ai", + }) + r := chi.NewRouter() + h.Mount(r) + return r +} + +func TestInfrastructureHandler_ListPipelines(t *testing.T) { + t.Run("success", func(t *testing.T) { + ci := newMockCIProvider() + ci.pipelines["threesix/myapp"] = []*domain.CIPipeline{ + { + ID: 100, + 
Number: 1, + Status: "success", + Event: "push", + Branch: "main", + Commit: "abc123", + Message: "initial commit", + Author: "dev", + Started: time.Date(2025, 1, 15, 10, 0, 0, 0, time.UTC), + }, + { + ID: 101, + Number: 2, + Status: "running", + Event: "push", + Branch: "feature", + Commit: "def456", + Author: "dev", + }, + } + + router := setupInfraHandlerWithCI(ci) + req := httptest.NewRequest("GET", "/projects/myapp/pipelines", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Errorf("status = %d, want %d", rec.Code, http.StatusOK) + } + }) + + t.Run("ci not configured", func(t *testing.T) { + router := setupInfraHandlerWithCI(nil) + req := httptest.NewRequest("GET", "/projects/myapp/pipelines", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusInternalServerError { + t.Errorf("status = %d, want %d", rec.Code, http.StatusInternalServerError) + } + }) + + t.Run("repo not found", func(t *testing.T) { + ci := newMockCIProvider() + router := setupInfraHandlerWithCI(ci) + + req := httptest.NewRequest("GET", "/projects/missing/pipelines", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusNotFound { + t.Errorf("status = %d, want %d", rec.Code, http.StatusNotFound) + } + }) + + t.Run("invalid project id", func(t *testing.T) { + ci := newMockCIProvider() + router := setupInfraHandlerWithCI(ci) + + req := httptest.NewRequest("GET", "/projects/INVALID!/pipelines", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusBadRequest { + t.Errorf("status = %d, want %d", rec.Code, http.StatusBadRequest) + } + }) +} + +func TestInfrastructureHandler_GetPipeline(t *testing.T) { + t.Run("success", func(t *testing.T) { + ci := newMockCIProvider() + ci.pipelines["threesix/myapp"] = []*domain.CIPipeline{ + { + ID: 100, + Number: 5, + Status: "success", + Event: "push", + Branch: "main", + Commit: 
"abc123", + Message: "fix bug", + Author: "dev", + Started: time.Date(2025, 1, 15, 10, 0, 0, 0, time.UTC), + Finished: time.Date(2025, 1, 15, 10, 5, 0, 0, time.UTC), + }, + } + + router := setupInfraHandlerWithCI(ci) + req := httptest.NewRequest("GET", "/projects/myapp/pipelines/5", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Errorf("status = %d, want %d", rec.Code, http.StatusOK) + } + }) + + t.Run("ci not configured", func(t *testing.T) { + router := setupInfraHandlerWithCI(nil) + req := httptest.NewRequest("GET", "/projects/myapp/pipelines/1", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusInternalServerError { + t.Errorf("status = %d, want %d", rec.Code, http.StatusInternalServerError) + } + }) + + t.Run("pipeline not found", func(t *testing.T) { + ci := newMockCIProvider() + ci.pipelines["threesix/myapp"] = []*domain.CIPipeline{} + router := setupInfraHandlerWithCI(ci) + + req := httptest.NewRequest("GET", "/projects/myapp/pipelines/999", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusNotFound { + t.Errorf("status = %d, want %d", rec.Code, http.StatusNotFound) + } + }) + + t.Run("invalid pipeline number", func(t *testing.T) { + ci := newMockCIProvider() + router := setupInfraHandlerWithCI(ci) + + req := httptest.NewRequest("GET", "/projects/myapp/pipelines/notanumber", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusBadRequest { + t.Errorf("status = %d, want %d", rec.Code, http.StatusBadRequest) + } + }) + + t.Run("invalid project id", func(t *testing.T) { + ci := newMockCIProvider() + router := setupInfraHandlerWithCI(ci) + + req := httptest.NewRequest("GET", "/projects/INVALID!/pipelines/1", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusBadRequest { + t.Errorf("status = %d, want %d", rec.Code, 
http.StatusBadRequest) + } + }) +} + +func TestFormatTime(t *testing.T) { + t.Run("non-zero time", func(t *testing.T) { + ts := time.Date(2025, 1, 15, 10, 30, 0, 0, time.UTC) + got := formatTime(ts) + want := "2025-01-15T10:30:00Z" + if got != want { + t.Errorf("formatTime() = %q, want %q", got, want) + } + }) + + t.Run("zero time", func(t *testing.T) { + got := formatTime(time.Time{}) + if got != "" { + t.Errorf("formatTime(zero) = %q, want empty string", got) + } + }) +} diff --git a/internal/handlers/infrastructure_test.go b/internal/handlers/infrastructure_test.go new file mode 100644 index 0000000..3bd7b19 --- /dev/null +++ b/internal/handlers/infrastructure_test.go @@ -0,0 +1,455 @@ +package handlers + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/go-chi/chi/v5" + "github.com/orchard9/rdev/internal/domain" +) + +// mockGitRepository implements port.GitRepository for testing. +type mockGitRepository struct { + repos map[string]*domain.Repo + err error +} + +func newMockGitRepository() *mockGitRepository { + return &mockGitRepository{repos: make(map[string]*domain.Repo)} +} + +func (m *mockGitRepository) CreateRepo(_ context.Context, name, description string, private bool) (*domain.Repo, error) { + if m.err != nil { + return nil, m.err + } + repo := &domain.Repo{ + ID: 1, + Owner: "threesix", + Name: name, + FullName: "threesix/" + name, + Description: description, + Private: private, + CloneSSH: fmt.Sprintf("git@git.threesix.ai:threesix/%s.git", name), + CloneHTTP: fmt.Sprintf("https://git.threesix.ai/threesix/%s.git", name), + HTMLURL: fmt.Sprintf("https://git.threesix.ai/threesix/%s", name), + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } + m.repos[name] = repo + return repo, nil +} + +func (m *mockGitRepository) DeleteRepo(_ context.Context, _, name string) error { + if m.err != nil { + return m.err + } + delete(m.repos, name) + return nil +} + +func (m 
*mockGitRepository) ListRepos(_ context.Context, _ string) ([]*domain.Repo, error) { + if m.err != nil { + return nil, m.err + } + var repos []*domain.Repo + for _, r := range m.repos { + repos = append(repos, r) + } + return repos, nil +} + +func (m *mockGitRepository) GetRepo(_ context.Context, _, name string) (*domain.Repo, error) { + if m.err != nil { + return nil, m.err + } + r, ok := m.repos[name] + if !ok { + return nil, fmt.Errorf("repo not found: %s", name) + } + return r, nil +} + +func (m *mockGitRepository) AddCollaborator(context.Context, string, string, string, string) error { + return m.err +} +func (m *mockGitRepository) RemoveCollaborator(context.Context, string, string, string) error { + return m.err +} +func (m *mockGitRepository) AddDeployKey(context.Context, string, string, string, string, bool) (*domain.DeployKey, error) { + return nil, m.err +} +func (m *mockGitRepository) DeleteDeployKey(context.Context, string, string, int64) error { + return m.err +} +func (m *mockGitRepository) CreateWebhook(context.Context, string, string, string, string, []string) (*domain.RepoWebhook, error) { + return nil, m.err +} +func (m *mockGitRepository) DeleteWebhook(context.Context, string, string, int64) error { + return m.err +} + +// mockDNSProvider implements port.DNSProvider for testing. 
+type mockDNSProvider struct { + records map[string]*domain.DNSRecord + err error +} + +func newMockDNSProvider() *mockDNSProvider { + return &mockDNSProvider{records: make(map[string]*domain.DNSRecord)} +} + +func (m *mockDNSProvider) CreateRecord(_ context.Context, record domain.DNSRecord) (*domain.DNSRecord, error) { + if m.err != nil { + return nil, m.err + } + record.ID = "rec-" + record.Name + m.records[record.Name] = &record + return &record, nil +} + +func (m *mockDNSProvider) UpdateRecord(_ context.Context, recordID string, record domain.DNSRecord) (*domain.DNSRecord, error) { + if m.err != nil { + return nil, m.err + } + record.ID = recordID + m.records[recordID] = &record + return &record, nil +} + +func (m *mockDNSProvider) DeleteRecord(_ context.Context, recordID string) error { + if m.err != nil { + return m.err + } + delete(m.records, recordID) + return nil +} + +func (m *mockDNSProvider) DeleteRecordByName(_ context.Context, _, name string) error { + if m.err != nil { + return m.err + } + delete(m.records, name) + return nil +} + +func (m *mockDNSProvider) GetRecord(_ context.Context, recordID string) (*domain.DNSRecord, error) { + if m.err != nil { + return nil, m.err + } + r, ok := m.records[recordID] + if !ok { + return nil, fmt.Errorf("record not found") + } + return r, nil +} + +func (m *mockDNSProvider) ListRecords(_ context.Context, recordType string) ([]*domain.DNSRecord, error) { + if m.err != nil { + return nil, m.err + } + var result []*domain.DNSRecord + for _, r := range m.records { + if recordType == "" || r.Type == recordType { + result = append(result, r) + } + } + return result, nil +} + +func (m *mockDNSProvider) FindRecord(_ context.Context, _, name string) (*domain.DNSRecord, error) { + if m.err != nil { + return nil, m.err + } + r, ok := m.records[name] + if !ok { + return nil, nil + } + return r, nil +} + +// mockDeployer implements port.Deployer for testing. 
+type mockDeployer struct { + deployments map[string]*domain.DeployStatus + logs string + err error +} + +func newMockDeployer() *mockDeployer { + return &mockDeployer{deployments: make(map[string]*domain.DeployStatus)} +} + +func (m *mockDeployer) Deploy(_ context.Context, spec domain.DeploySpec) error { + if m.err != nil { + return m.err + } + m.deployments[spec.ProjectName] = &domain.DeployStatus{ + ProjectName: spec.ProjectName, + Image: spec.Image, + Replicas: spec.Replicas, + ReadyReplicas: 0, + URL: "https://" + spec.Domain, + Status: domain.DeploymentStatusPending, + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } + return nil +} + +func (m *mockDeployer) Undeploy(_ context.Context, projectName string) error { + if m.err != nil { + return m.err + } + delete(m.deployments, projectName) + return nil +} + +func (m *mockDeployer) GetStatus(_ context.Context, projectName string) (*domain.DeployStatus, error) { + if m.err != nil { + return nil, m.err + } + s, ok := m.deployments[projectName] + if !ok { + return nil, nil + } + return s, nil +} + +func (m *mockDeployer) Restart(_ context.Context, _ string) error { + return m.err +} + +func (m *mockDeployer) Scale(_ context.Context, projectName string, replicas int) error { + if m.err != nil { + return m.err + } + if s, ok := m.deployments[projectName]; ok { + s.Replicas = replicas + } + return nil +} + +func (m *mockDeployer) GetLogs(_ context.Context, _ string, _ int) (string, error) { + if m.err != nil { + return "", m.err + } + return m.logs, nil +} + +func setupInfraHandler() (*InfrastructureHandler, *mockGitRepository, *mockDNSProvider, *mockDeployer, chi.Router) { + git := newMockGitRepository() + dns := newMockDNSProvider() + deployer := newMockDeployer() + h := NewInfrastructureHandler(git, dns, deployer, nil, nil, InfrastructureConfig{ + DefaultGitOwner: "threesix", + DefaultDomain: "threesix.ai", + ClusterIP: "208.122.204.172", + }) + r := chi.NewRouter() + h.Mount(r) + return h, git, dns, deployer, r 
+} + +func TestInfrastructureHandler_CreateRepo(t *testing.T) { + t.Run("success", func(t *testing.T) { + _, git, _, _, router := setupInfraHandler() + + body, _ := json.Marshal(CreateRepoRequest{Description: "Test repo", Private: true}) + req := httptest.NewRequest("POST", "/projects/myapp/repo", bytes.NewReader(body)) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusCreated { + t.Errorf("status = %d, want %d", rec.Code, http.StatusCreated) + } + if _, ok := git.repos["myapp"]; !ok { + t.Error("repo not created") + } + }) + + t.Run("invalid project id", func(t *testing.T) { + _, _, _, _, router := setupInfraHandler() + + req := httptest.NewRequest("POST", "/projects/INVALID_NAME!/repo", bytes.NewReader([]byte("{}"))) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusBadRequest { + t.Errorf("status = %d, want %d", rec.Code, http.StatusBadRequest) + } + }) + + t.Run("empty body allowed", func(t *testing.T) { + _, _, _, _, router := setupInfraHandler() + + req := httptest.NewRequest("POST", "/projects/myapp/repo", bytes.NewReader([]byte(""))) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + // Should succeed with empty body (EOF is allowed) + if rec.Code != http.StatusCreated { + t.Errorf("status = %d, want %d", rec.Code, http.StatusCreated) + } + }) + + t.Run("git not configured", func(t *testing.T) { + h := NewInfrastructureHandler(nil, nil, nil, nil, nil, InfrastructureConfig{}) + r := chi.NewRouter() + h.Mount(r) + + req := httptest.NewRequest("POST", "/projects/myapp/repo", bytes.NewReader([]byte("{}"))) + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + + if rec.Code != http.StatusInternalServerError { + t.Errorf("status = %d, want %d", rec.Code, http.StatusInternalServerError) + } + }) +} + +func TestInfrastructureHandler_GetRepo(t *testing.T) { + t.Run("found", func(t *testing.T) { + _, git, _, _, router := setupInfraHandler() + git.repos["myapp"] = 
&domain.Repo{ + ID: 1, Owner: "threesix", Name: "myapp", FullName: "threesix/myapp", + CloneSSH: "git@git.threesix.ai:threesix/myapp.git", + } + + req := httptest.NewRequest("GET", "/projects/myapp/repo", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Errorf("status = %d, want %d", rec.Code, http.StatusOK) + } + }) + + t.Run("not found", func(t *testing.T) { + _, _, _, _, router := setupInfraHandler() + + req := httptest.NewRequest("GET", "/projects/missing/repo", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusNotFound { + t.Errorf("status = %d, want %d", rec.Code, http.StatusNotFound) + } + }) +} + +func TestInfrastructureHandler_DeleteRepo(t *testing.T) { + t.Run("success", func(t *testing.T) { + _, git, _, _, router := setupInfraHandler() + git.repos["myapp"] = &domain.Repo{ID: 1, Name: "myapp"} + + req := httptest.NewRequest("DELETE", "/projects/myapp/repo", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Errorf("status = %d, want %d", rec.Code, http.StatusOK) + } + }) +} + +func TestInfrastructureHandler_Deploy(t *testing.T) { + t.Run("success", func(t *testing.T) { + _, _, _, deployer, router := setupInfraHandler() + + body, _ := json.Marshal(DeployRequest{ + Image: "registry.threesix.ai/myapp:latest", + Port: 8080, + Replicas: 2, + }) + req := httptest.NewRequest("POST", "/projects/myapp/deploy", bytes.NewReader(body)) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusCreated { + t.Errorf("status = %d, want %d", rec.Code, http.StatusCreated) + } + if _, ok := deployer.deployments["myapp"]; !ok { + t.Error("deployment not created") + } + }) + + t.Run("missing image", func(t *testing.T) { + _, _, _, _, router := setupInfraHandler() + + body, _ := json.Marshal(DeployRequest{Port: 8080}) + req := httptest.NewRequest("POST", "/projects/myapp/deploy", 
bytes.NewReader(body)) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusBadRequest { + t.Errorf("status = %d, want %d", rec.Code, http.StatusBadRequest) + } + }) + + t.Run("deployer not configured", func(t *testing.T) { + h := NewInfrastructureHandler(nil, nil, nil, nil, nil, InfrastructureConfig{}) + r := chi.NewRouter() + h.Mount(r) + + body, _ := json.Marshal(DeployRequest{Image: "myimage:latest"}) + req := httptest.NewRequest("POST", "/projects/myapp/deploy", bytes.NewReader(body)) + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + + if rec.Code != http.StatusInternalServerError { + t.Errorf("status = %d, want %d", rec.Code, http.StatusInternalServerError) + } + }) +} + +func TestInfrastructureHandler_GetDeployStatus(t *testing.T) { + t.Run("found", func(t *testing.T) { + _, _, _, deployer, router := setupInfraHandler() + deployer.deployments["myapp"] = &domain.DeployStatus{ + ProjectName: "myapp", + Image: "myimage:latest", + Status: domain.DeploymentStatusRunning, + Replicas: 2, + } + + req := httptest.NewRequest("GET", "/projects/myapp/deploy/status", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Errorf("status = %d, want %d", rec.Code, http.StatusOK) + } + }) + + t.Run("not found", func(t *testing.T) { + _, _, _, _, router := setupInfraHandler() + + req := httptest.NewRequest("GET", "/projects/missing/deploy/status", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusNotFound { + t.Errorf("status = %d, want %d", rec.Code, http.StatusNotFound) + } + }) +} + +func TestInfrastructureHandler_Undeploy(t *testing.T) { + t.Run("success", func(t *testing.T) { + _, _, _, deployer, router := setupInfraHandler() + deployer.deployments["myapp"] = &domain.DeployStatus{ProjectName: "myapp"} + + req := httptest.NewRequest("DELETE", "/projects/myapp/deploy", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) 
+ + if rec.Code != http.StatusOK { + t.Errorf("status = %d, want %d", rec.Code, http.StatusOK) + } + }) +} diff --git a/internal/handlers/keys.go b/internal/handlers/keys.go index a7dc0df..ee07ac9 100644 --- a/internal/handlers/keys.go +++ b/internal/handlers/keys.go @@ -2,6 +2,7 @@ package handlers import ( "encoding/json" + "errors" "net" "net/http" @@ -66,7 +67,7 @@ type CreateKeyResponse struct { // apiKeyToResponse converts an APIKey to a JSON response. func apiKeyToResponse(k *auth.APIKey) KeyResponse { resp := KeyResponse{ - ID: k.ID, + ID: string(k.ID), Name: k.Name, KeyPrefix: k.KeyPrefix, Scopes: auth.ScopesToStrings(k.Scopes), @@ -76,7 +77,10 @@ func apiKeyToResponse(k *auth.APIKey) KeyResponse { } if k.ProjectIDs != nil { - resp.ProjectIDs = k.ProjectIDs + resp.ProjectIDs = make([]string, len(k.ProjectIDs)) + for i, pid := range k.ProjectIDs { + resp.ProjectIDs[i] = string(pid) + } } if k.AllowedIPs != nil { @@ -159,8 +163,8 @@ func (h *KeysHandler) Create(w http.ResponseWriter, r *http.Request) { // Get creator from authenticated key creator := "admin" - if apiKey := auth.GetAPIKey(r.Context()); apiKey != nil && apiKey.ID != "admin" { - creator = apiKey.ID + if apiKey := auth.GetAPIKey(r.Context()); apiKey != nil && string(apiKey.ID) != "admin" { + creator = string(apiKey.ID) } result, err := h.authService.Create(r.Context(), auth.CreateKeyRequest{ @@ -206,7 +210,7 @@ func (h *KeysHandler) Get(w http.ResponseWriter, r *http.Request) { key, err := h.authService.Get(r.Context(), id) if err != nil { - if err == auth.ErrKeyNotFound { + if errors.Is(err, auth.ErrKeyNotFound) { api.WriteNotFound(w, r, "Key not found") return } @@ -223,7 +227,7 @@ func (h *KeysHandler) Revoke(w http.ResponseWriter, r *http.Request) { id := chi.URLParam(r, "id") if err := h.authService.Revoke(r.Context(), id); err != nil { - if err == auth.ErrKeyNotFound { + if errors.Is(err, auth.ErrKeyNotFound) { api.WriteNotFound(w, r, "Key not found") return } diff --git 
a/internal/handlers/keys_test.go b/internal/handlers/keys_test.go index 32a915d..599ec45 100644 --- a/internal/handlers/keys_test.go +++ b/internal/handlers/keys_test.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "encoding/json" + "errors" "net/http" "net/http/httptest" "strings" @@ -11,7 +12,10 @@ import ( "time" "github.com/go-chi/chi/v5" + "github.com/orchard9/rdev/internal/adapter/postgres" "github.com/orchard9/rdev/internal/auth" + "github.com/orchard9/rdev/internal/domain" + "github.com/orchard9/rdev/internal/service" "github.com/orchard9/rdev/internal/testutil" ) @@ -22,7 +26,9 @@ func setupKeysHandler(t *testing.T) (*KeysHandler, chi.Router, *auth.Service) { db := testutil.TestDB(t) t.Cleanup(func() { testutil.CleanupTestKeys(t, db) }) - authService := auth.NewService(db, "test-admin-key") + apiKeyRepo := postgres.NewAPIKeyRepository(db) + apiKeySvc := service.NewAPIKeyService(apiKeyRepo, "test-admin-key") + authService := auth.NewService(apiKeySvc, "test-admin-key") handler := NewKeysHandler(authService) router := chi.NewRouter() @@ -204,7 +210,7 @@ func TestKeysHandler_Get(t *testing.T) { } t.Run("existing key", func(t *testing.T) { - req := httptest.NewRequest("GET", "/keys/"+result.Key.ID, nil) + req := httptest.NewRequest("GET", "/keys/"+string(result.Key.ID), nil) rec := httptest.NewRecorder() router.ServeHTTP(rec, req) @@ -249,7 +255,7 @@ func TestKeysHandler_Revoke(t *testing.T) { } t.Run("revoke existing key", func(t *testing.T) { - req := httptest.NewRequest("DELETE", "/keys/"+result.Key.ID, nil) + req := httptest.NewRequest("DELETE", "/keys/"+string(result.Key.ID), nil) rec := httptest.NewRecorder() router.ServeHTTP(rec, req) @@ -268,7 +274,7 @@ func TestKeysHandler_Revoke(t *testing.T) { // Verify the key is actually revoked _, err := authService.Validate(context.Background(), result.Secret) - if err != auth.ErrKeyRevoked { + if !errors.Is(err, auth.ErrKeyRevoked) { t.Errorf("Key should be revoked, got err = %v", err) } }) @@ -308,11 +314,11 @@ 
func TestApiKeyToResponse(t *testing.T) { future := now.Add(24 * time.Hour) key := &auth.APIKey{ - ID: "test-id", + ID: domain.APIKeyID("test-id"), Name: "test-name", KeyPrefix: "rdev_sk_abc", Scopes: []auth.Scope{auth.ScopeProjectsRead, auth.ScopeProjectsExecute}, - ProjectIDs: []string{"proj-a"}, + ProjectIDs: []domain.ProjectID{"proj-a"}, CreatedAt: now, ExpiresAt: &future, LastUsedAt: &now, diff --git a/internal/handlers/project_management.go b/internal/handlers/project_management.go index 6945172..4276529 100644 --- a/internal/handlers/project_management.go +++ b/internal/handlers/project_management.go @@ -18,12 +18,17 @@ import ( // ProjectManagementHandler handles project lifecycle operations. type ProjectManagementHandler struct { infraService *service.ProjectInfraService + logger *slog.Logger } // NewProjectManagementHandler creates a new project management handler. -func NewProjectManagementHandler(infraService *service.ProjectInfraService) *ProjectManagementHandler { +func NewProjectManagementHandler(infraService *service.ProjectInfraService, logger *slog.Logger) *ProjectManagementHandler { + if logger == nil { + logger = slog.Default() + } return &ProjectManagementHandler{ infraService: infraService, + logger: logger, } } @@ -84,7 +89,7 @@ func (h *ProjectManagementHandler) Create(w http.ResponseWriter, r *http.Request return } // Log internal errors but return generic message to client - slog.Error("project creation failed", "error", err, "name", req.Name) + h.logger.Error("project creation failed", "error", err, "name", req.Name) api.WriteInternalError(w, r, "failed to create project") return } @@ -119,7 +124,7 @@ func (h *ProjectManagementHandler) List(w http.ResponseWriter, r *http.Request) projects, err := h.infraService.ListProjects(ctx) if err != nil { - slog.Error("failed to list projects", "error", err) + h.logger.Error("failed to list projects", "error", err) api.WriteInternalError(w, r, "failed to list projects") return } @@ -169,7 +174,7 @@ 
func (h *ProjectManagementHandler) Status(w http.ResponseWriter, r *http.Request api.WriteNotFound(w, r, "project not found") return } - slog.Error("failed to get project status", "error", err, "name", name) + h.logger.Error("failed to get project status", "error", err, "name", name) api.WriteInternalError(w, r, "failed to get project status") return } @@ -216,7 +221,7 @@ func (h *ProjectManagementHandler) Delete(w http.ResponseWriter, r *http.Request api.WriteNotFound(w, r, "project not found") return } - slog.Error("failed to delete project", "error", err, "name", name) + h.logger.Error("failed to delete project", "error", err, "name", name) api.WriteInternalError(w, r, "failed to delete project") return } @@ -240,7 +245,7 @@ func (h *ProjectManagementHandler) ListTemplates(w http.ResponseWriter, r *http. templates, err := h.infraService.ListTemplates(ctx) if err != nil { - slog.Error("failed to list templates", "error", err) + h.logger.Error("failed to list templates", "error", err) api.WriteInternalError(w, r, "failed to list templates") return } @@ -277,7 +282,7 @@ func (h *ProjectManagementHandler) GetTemplate(w http.ResponseWriter, r *http.Re api.WriteNotFound(w, r, "template not found") return } - slog.Error("failed to get template", "error", err, "name", name) + h.logger.Error("failed to get template", "error", err, "name", name) api.WriteInternalError(w, r, "failed to get template") return } diff --git a/internal/handlers/project_management_test.go b/internal/handlers/project_management_test.go new file mode 100644 index 0000000..d639cec --- /dev/null +++ b/internal/handlers/project_management_test.go @@ -0,0 +1,85 @@ +package handlers + +import ( + "bytes" + "encoding/json" + "log/slog" + "net/http" + "net/http/httptest" + "testing" + + "github.com/go-chi/chi/v5" +) + +func TestProjectManagementHandler_NilService(t *testing.T) { + h := NewProjectManagementHandler(nil, slog.Default()) + r := chi.NewRouter() + h.Mount(r) + + tests := []struct { + name 
string + method string + path string + body string + }{ + {"create", "POST", "/project", `{"name":"test"}`}, + {"list", "GET", "/project", ""}, + {"status", "GET", "/project/test", ""}, + {"delete", "DELETE", "/project/test", ""}, + {"list templates", "GET", "/templates", ""}, + {"get template", "GET", "/templates/default", ""}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var req *http.Request + if tt.body != "" { + req = httptest.NewRequest(tt.method, tt.path, bytes.NewReader([]byte(tt.body))) + } else { + req = httptest.NewRequest(tt.method, tt.path, nil) + } + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + + if rec.Code != http.StatusInternalServerError { + t.Errorf("%s: status = %d, want %d", tt.name, rec.Code, http.StatusInternalServerError) + } + }) + } +} + +func TestProjectManagementHandler_CreateValidation(t *testing.T) { + // With nil service, the handler returns 500 before reaching validation. + // This tests that the nil check takes precedence. 
+ h := NewProjectManagementHandler(nil, slog.Default()) + r := chi.NewRouter() + h.Mount(r) + + t.Run("nil service returns 500 even with missing name", func(t *testing.T) { + body, _ := json.Marshal(CreateRequest{Name: ""}) + req := httptest.NewRequest("POST", "/project", bytes.NewReader(body)) + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + + if rec.Code != http.StatusInternalServerError { + t.Errorf("status = %d, want %d", rec.Code, http.StatusInternalServerError) + } + }) + + t.Run("nil service returns 500 even with invalid json", func(t *testing.T) { + req := httptest.NewRequest("POST", "/project", bytes.NewReader([]byte("not json"))) + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + + if rec.Code != http.StatusInternalServerError { + t.Errorf("status = %d, want %d", rec.Code, http.StatusInternalServerError) + } + }) +} + +func TestNewProjectManagementHandler_NilLogger(t *testing.T) { + h := NewProjectManagementHandler(nil, nil) + if h.logger == nil { + t.Error("logger should default to slog.Default() when nil") + } +} diff --git a/internal/handlers/projects.go b/internal/handlers/projects.go index b04bce9..4a1def5 100644 --- a/internal/handlers/projects.go +++ b/internal/handlers/projects.go @@ -67,7 +67,7 @@ func getAuditContext(r *http.Request) *service.AuditContext { } return &service.AuditContext{ - APIKeyID: apiKey.ID, + APIKeyID: string(apiKey.ID), ClientIP: getClientIP(r), UserAgent: r.UserAgent(), } diff --git a/internal/handlers/projects_commands_test.go b/internal/handlers/projects_commands_test.go new file mode 100644 index 0000000..d095725 --- /dev/null +++ b/internal/handlers/projects_commands_test.go @@ -0,0 +1,98 @@ +package handlers + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/go-chi/chi/v5" +) + +func TestProjectsHandler_RunClaude_InvalidJSON(t *testing.T) { + h := &ProjectsHandler{streams: newStreamManager()} + r := chi.NewRouter() + h.Mount(r) + + req := 
httptest.NewRequest("POST", "/projects/myapp/claude", bytes.NewReader([]byte("not json"))) + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + + if rec.Code != http.StatusBadRequest { + t.Errorf("status = %d, want %d", rec.Code, http.StatusBadRequest) + } +} + +func TestProjectsHandler_RunClaude_NoServiceConfigured(t *testing.T) { + h := &ProjectsHandler{streams: newStreamManager()} + r := chi.NewRouter() + h.Mount(r) + + body, _ := json.Marshal(ClaudeRequest{Prompt: "hello"}) + req := httptest.NewRequest("POST", "/projects/myapp/claude", bytes.NewReader(body)) + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + + if rec.Code != http.StatusInternalServerError { + t.Errorf("status = %d, want %d", rec.Code, http.StatusInternalServerError) + } +} + +func TestProjectsHandler_RunShell_InvalidJSON(t *testing.T) { + h := &ProjectsHandler{streams: newStreamManager()} + r := chi.NewRouter() + h.Mount(r) + + req := httptest.NewRequest("POST", "/projects/myapp/shell", bytes.NewReader([]byte("not json"))) + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + + if rec.Code != http.StatusBadRequest { + t.Errorf("status = %d, want %d", rec.Code, http.StatusBadRequest) + } +} + +func TestProjectsHandler_RunShell_NoServiceConfigured(t *testing.T) { + h := &ProjectsHandler{streams: newStreamManager()} + r := chi.NewRouter() + h.Mount(r) + + body, _ := json.Marshal(ShellRequest{Command: "ls"}) + req := httptest.NewRequest("POST", "/projects/myapp/shell", bytes.NewReader(body)) + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + + if rec.Code != http.StatusInternalServerError { + t.Errorf("status = %d, want %d", rec.Code, http.StatusInternalServerError) + } +} + +func TestProjectsHandler_RunGit_InvalidJSON(t *testing.T) { + h := &ProjectsHandler{streams: newStreamManager()} + r := chi.NewRouter() + h.Mount(r) + + req := httptest.NewRequest("POST", "/projects/myapp/git", bytes.NewReader([]byte("not json"))) + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + 
+ if rec.Code != http.StatusBadRequest { + t.Errorf("status = %d, want %d", rec.Code, http.StatusBadRequest) + } +} + +func TestProjectsHandler_RunGit_NoServiceConfigured(t *testing.T) { + h := &ProjectsHandler{streams: newStreamManager()} + r := chi.NewRouter() + h.Mount(r) + + body, _ := json.Marshal(GitRequest{Args: []string{"status"}}) + req := httptest.NewRequest("POST", "/projects/myapp/git", bytes.NewReader(body)) + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + + if rec.Code != http.StatusInternalServerError { + t.Errorf("status = %d, want %d", rec.Code, http.StatusInternalServerError) + } +} diff --git a/internal/handlers/projects_stream_test.go b/internal/handlers/projects_stream_test.go new file mode 100644 index 0000000..d67b60c --- /dev/null +++ b/internal/handlers/projects_stream_test.go @@ -0,0 +1,187 @@ +package handlers + +import ( + "encoding/json" + "net/http/httptest" + "strings" + "testing" + "time" +) + +// recorderFlusher wraps httptest.ResponseRecorder to satisfy http.Flusher. 
+type recorderFlusher struct { + *httptest.ResponseRecorder +} + +func (rf *recorderFlusher) Flush() {} + +func newRecorderFlusher() *recorderFlusher { + return &recorderFlusher{httptest.NewRecorder()} +} + +func TestStreamManager_SubscribeAndSend(t *testing.T) { + sm := newStreamManager() + + ch := sm.Subscribe("stream-1") + defer sm.Unsubscribe("stream-1", ch) + + // Send event + sm.Send("stream-1", "output", map[string]any{"line": "hello"}) + + select { + case evt := <-ch: + if evt.Type != "output" { + t.Errorf("event type = %q, want %q", evt.Type, "output") + } + if evt.Data["line"] != "hello" { + t.Errorf("event data = %v, want line=hello", evt.Data) + } + case <-time.After(time.Second): + t.Fatal("timed out waiting for event") + } +} + +func TestStreamManager_MultipleSubscribers(t *testing.T) { + sm := newStreamManager() + + ch1 := sm.Subscribe("stream-1") + ch2 := sm.Subscribe("stream-1") + defer sm.Unsubscribe("stream-1", ch1) + defer sm.Unsubscribe("stream-1", ch2) + + sm.Send("stream-1", "test", map[string]any{"value": 1}) + + // Both should receive + for i, ch := range []chan streamEvent{ch1, ch2} { + select { + case evt := <-ch: + if evt.Type != "test" { + t.Errorf("subscriber %d: event type = %q, want %q", i, evt.Type, "test") + } + case <-time.After(time.Second): + t.Fatalf("subscriber %d: timed out", i) + } + } +} + +func TestStreamManager_Close(t *testing.T) { + sm := newStreamManager() + + ch := sm.Subscribe("stream-1") + + sm.Close("stream-1") + + // Channel should be closed + _, ok := <-ch + if ok { + t.Error("channel should be closed after Close()") + } +} + +func TestStreamManager_SendToNonexistentStream(t *testing.T) { + sm := newStreamManager() + // Should not panic + sm.Send("nonexistent", "test", map[string]any{}) +} + +func TestStreamManager_Unsubscribe(t *testing.T) { + sm := newStreamManager() + + ch1 := sm.Subscribe("stream-1") + ch2 := sm.Subscribe("stream-1") + + sm.Unsubscribe("stream-1", ch1) + + // ch1 should be closed + _, ok := 
<-ch1 + if ok { + t.Error("ch1 should be closed after Unsubscribe") + } + + // ch2 should still receive + sm.Send("stream-1", "test", map[string]any{}) + select { + case evt := <-ch2: + if evt.Type != "test" { + t.Errorf("event type = %q, want %q", evt.Type, "test") + } + case <-time.After(time.Second): + t.Fatal("ch2 timed out") + } + + sm.Unsubscribe("stream-1", ch2) +} + +func TestWriteSSE(t *testing.T) { + rf := newRecorderFlusher() + + writeSSE(rf.ResponseRecorder, rf, "output", map[string]any{"line": "hello"}) + + body := rf.Body.String() + if !strings.Contains(body, "event: output\n") { + t.Errorf("missing event line in SSE output: %s", body) + } + if !strings.Contains(body, "data: ") { + t.Errorf("missing data line in SSE output: %s", body) + } + // Should not have id line + if strings.Contains(body, "id: ") { + t.Errorf("should not have id line without ID: %s", body) + } + + // Verify data is valid JSON + lines := strings.Split(body, "\n") + for _, line := range lines { + if strings.HasPrefix(line, "data: ") { + jsonStr := strings.TrimPrefix(line, "data: ") + var parsed map[string]any + if err := json.Unmarshal([]byte(jsonStr), &parsed); err != nil { + t.Errorf("data is not valid JSON: %v", err) + } + if parsed["line"] != "hello" { + t.Errorf("data[line] = %v, want hello", parsed["line"]) + } + } + } +} + +func TestWriteSSEWithID(t *testing.T) { + rf := newRecorderFlusher() + + writeSSEWithID(rf.ResponseRecorder, rf, "evt-123", "complete", map[string]any{"exit_code": 0}) + + body := rf.Body.String() + if !strings.Contains(body, "id: evt-123\n") { + t.Errorf("missing id line in SSE output: %s", body) + } + if !strings.Contains(body, "event: complete\n") { + t.Errorf("missing event line in SSE output: %s", body) + } +} + +func TestStreamManager_FullChannel(t *testing.T) { + sm := newStreamManager() + + ch := sm.Subscribe("stream-1") + + // Fill the channel (buffer size is 100) + for i := 0; i < 100; i++ { + sm.Send("stream-1", "test", map[string]any{"i": i}) 
+ } + + // Next send should not block (dropped) + done := make(chan struct{}) + go func() { + sm.Send("stream-1", "test", map[string]any{"i": 100}) + close(done) + }() + + select { + case <-done: + // Good - did not block + case <-time.After(time.Second): + t.Fatal("Send blocked on full channel") + } + + sm.Unsubscribe("stream-1", ch) +} diff --git a/internal/handlers/queue.go b/internal/handlers/queue.go index 9c910ba..440c2ce 100644 --- a/internal/handlers/queue.go +++ b/internal/handlers/queue.go @@ -138,7 +138,7 @@ func (h *QueueHandler) Enqueue(w http.ResponseWriter, r *http.Request) { // Get API key ID for audit trail var apiKeyID string if apiKey := auth.GetAPIKey(r.Context()); apiKey != nil { - apiKeyID = apiKey.ID + apiKeyID = string(apiKey.ID) } // Create queued command diff --git a/internal/handlers/webhooks.go b/internal/handlers/webhooks.go index 40fce83..5a0a4b5 100644 --- a/internal/handlers/webhooks.go +++ b/internal/handlers/webhooks.go @@ -8,13 +8,13 @@ import ( "errors" "fmt" "net/http" - "net/url" "strconv" "github.com/go-chi/chi/v5" "github.com/google/uuid" "github.com/orchard9/rdev/internal/domain" "github.com/orchard9/rdev/internal/port" + "github.com/orchard9/rdev/internal/validate" "github.com/orchard9/rdev/pkg/api" ) @@ -115,8 +115,7 @@ func (h *WebhookHandler) Create(w http.ResponseWriter, r *http.Request) { api.WriteBadRequest(w, r, "url is required") return } - parsedURL, err := url.Parse(req.URL) - if err != nil || (parsedURL.Scheme != "http" && parsedURL.Scheme != "https") { + if err := validate.HTTPURL(req.URL, "url"); err != nil { api.WriteBadRequest(w, r, "url must be a valid HTTP or HTTPS URL") return } @@ -288,8 +287,7 @@ func (h *WebhookHandler) Update(w http.ResponseWriter, r *http.Request) { // Update fields if req.URL != "" { - parsedURL, err := url.Parse(req.URL) - if err != nil || (parsedURL.Scheme != "http" && parsedURL.Scheme != "https") { + if err := validate.HTTPURL(req.URL, "url"); err != nil { api.WriteBadRequest(w, 
r, "url must be a valid HTTP or HTTPS URL") return } diff --git a/internal/handlers/work.go b/internal/handlers/work.go index 82976b3..b3b346c 100644 --- a/internal/handlers/work.go +++ b/internal/handlers/work.go @@ -6,13 +6,12 @@ import ( "errors" "fmt" "net/http" - "net/url" "strconv" "github.com/go-chi/chi/v5" "github.com/orchard9/rdev/internal/domain" - "github.com/orchard9/rdev/internal/port" "github.com/orchard9/rdev/internal/service" + "github.com/orchard9/rdev/internal/validate" "github.com/orchard9/rdev/pkg/api" ) @@ -88,19 +87,16 @@ func (h *WorkHandler) Enqueue(w http.ResponseWriter, r *http.Request) { } // Validate task type - taskType := port.WorkTaskType(req.TaskType) + taskType := domain.WorkTaskType(req.TaskType) if !taskType.IsValid() { api.WriteBadRequest(w, r, "task_type must be one of: build, test, deploy, custom") return } // Validate callback URL if provided - if req.CallbackURL != "" { - parsedURL, err := url.Parse(req.CallbackURL) - if err != nil || (parsedURL.Scheme != "http" && parsedURL.Scheme != "https") { - api.WriteBadRequest(w, r, "callback_url must be a valid HTTP/HTTPS URL") - return - } + if err := validate.HTTPURL(req.CallbackURL, "callback_url"); err != nil { + api.WriteBadRequest(w, r, "callback_url must be a valid HTTP/HTTPS URL") + return } result, err := h.workService.EnqueueTask(r.Context(), service.EnqueueTaskRequest{ @@ -157,8 +153,8 @@ type WorkResultDTO struct { Artifacts map[string]string `json:"artifacts,omitempty"` } -// toWorkTaskDTO converts a port.WorkTask to a WorkTaskDTO. -func toWorkTaskDTO(t *port.WorkTask) *WorkTaskDTO { +// toWorkTaskDTO converts a domain.WorkTask to a WorkTaskDTO. 
+func toWorkTaskDTO(t *domain.WorkTask) *WorkTaskDTO { if t == nil { return nil } @@ -273,7 +269,7 @@ func (h *WorkHandler) Complete(w http.ResponseWriter, r *http.Request) { return } - result := &port.WorkResult{ + result := &domain.WorkResult{ Output: req.Output, Artifacts: req.Artifacts, } @@ -358,9 +354,9 @@ func (h *WorkHandler) ListByProject(w http.ResponseWriter, r *http.Request) { projectID := chi.URLParam(r, "projectId") // Parse and validate optional status filter - var status *port.WorkTaskStatus + var status *domain.WorkTaskStatus if s := r.URL.Query().Get("status"); s != "" { - st := port.WorkTaskStatus(s) + st := domain.WorkTaskStatus(s) if !st.IsValid() { api.WriteBadRequest(w, r, "invalid status filter: must be pending, running, completed, failed, or cancelled") return @@ -369,7 +365,7 @@ func (h *WorkHandler) ListByProject(w http.ResponseWriter, r *http.Request) { } // Parse pagination options - opts := port.DefaultWorkListOptions() + opts := domain.DefaultWorkListOptions() if limitStr := r.URL.Query().Get("limit"); limitStr != "" { limit, err := strconv.Atoi(limitStr) if err != nil { diff --git a/internal/handlers/work_lifecycle_test.go b/internal/handlers/work_lifecycle_test.go new file mode 100644 index 0000000..221bbd8 --- /dev/null +++ b/internal/handlers/work_lifecycle_test.go @@ -0,0 +1,381 @@ +package handlers + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/go-chi/chi/v5" + "github.com/orchard9/rdev/internal/domain" + "github.com/orchard9/rdev/internal/service" +) + +func TestWorkHandler_Fail(t *testing.T) { + mockQueue := newMockWorkQueue() + workService := service.NewWorkService(mockQueue, service.WorkServiceConfig{}) + handler := NewWorkHandler(workService) + + // Pre-populate a running task + mockQueue.tasks["task-1"] = &domain.WorkTask{ + ID: "task-1", + ProjectID: "test-project", + Type: domain.WorkTaskTypeBuild, + Status: domain.WorkTaskStatusRunning, + WorkerID: 
"worker-1", + MaxRetries: 3, + CreatedAt: time.Now(), + } + + router := chi.NewRouter() + handler.Mount(router) + + tests := []struct { + name string + taskID string + body FailWorkRequest + wantStatus int + }{ + { + name: "valid_fail", + taskID: "task-1", + body: FailWorkRequest{Error: "Build failed: npm error"}, + wantStatus: http.StatusOK, + }, + { + name: "missing_error", + taskID: "task-1", + body: FailWorkRequest{}, + wantStatus: http.StatusBadRequest, + }, + { + name: "task_not_found", + taskID: "nonexistent", + body: FailWorkRequest{Error: "Failed"}, + wantStatus: http.StatusNotFound, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + body, _ := json.Marshal(tt.body) + req := httptest.NewRequest(http.MethodPost, "/work/"+tt.taskID+"/fail", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + + router.ServeHTTP(rec, req) + + if rec.Code != tt.wantStatus { + t.Errorf("got status %d, want %d; body: %s", rec.Code, tt.wantStatus, rec.Body.String()) + } + }) + } +} + +func TestWorkHandler_Cancel(t *testing.T) { + mockQueue := newMockWorkQueue() + workService := service.NewWorkService(mockQueue, service.WorkServiceConfig{}) + handler := NewWorkHandler(workService) + + // Pre-populate tasks + mockQueue.tasks["pending-task"] = &domain.WorkTask{ + ID: "pending-task", + ProjectID: "test-project", + Type: domain.WorkTaskTypeBuild, + Status: domain.WorkTaskStatusPending, + CreatedAt: time.Now(), + } + mockQueue.tasks["running-task"] = &domain.WorkTask{ + ID: "running-task", + ProjectID: "test-project", + Type: domain.WorkTaskTypeBuild, + Status: domain.WorkTaskStatusRunning, + CreatedAt: time.Now(), + } + + router := chi.NewRouter() + handler.Mount(router) + + tests := []struct { + name string + taskID string + wantStatus int + }{ + { + name: "cancel_pending_task", + taskID: "pending-task", + wantStatus: http.StatusOK, + }, + { + name: "cancel_running_task_fails", + taskID: 
"running-task", + wantStatus: http.StatusNotFound, // Can only cancel pending tasks + }, + { + name: "task_not_found", + taskID: "nonexistent", + wantStatus: http.StatusNotFound, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "/work/"+tt.taskID+"/cancel", nil) + rec := httptest.NewRecorder() + + router.ServeHTTP(rec, req) + + if rec.Code != tt.wantStatus { + t.Errorf("got status %d, want %d; body: %s", rec.Code, tt.wantStatus, rec.Body.String()) + } + }) + } +} + +func TestWorkHandler_GetTask(t *testing.T) { + mockQueue := newMockWorkQueue() + workService := service.NewWorkService(mockQueue, service.WorkServiceConfig{}) + handler := NewWorkHandler(workService) + + // Pre-populate a task + mockQueue.tasks["task-1"] = &domain.WorkTask{ + ID: "task-1", + ProjectID: "test-project", + Type: domain.WorkTaskTypeBuild, + Status: domain.WorkTaskStatusRunning, + Spec: map[string]any{ + "prompt": "Build it", + }, + CreatedAt: time.Now(), + } + + router := chi.NewRouter() + handler.Mount(router) + + tests := []struct { + name string + taskID string + wantStatus int + }{ + { + name: "get_existing_task", + taskID: "task-1", + wantStatus: http.StatusOK, + }, + { + name: "task_not_found", + taskID: "nonexistent", + wantStatus: http.StatusNotFound, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/work/"+tt.taskID, nil) + rec := httptest.NewRecorder() + + router.ServeHTTP(rec, req) + + if rec.Code != tt.wantStatus { + t.Errorf("got status %d, want %d; body: %s", rec.Code, tt.wantStatus, rec.Body.String()) + } + }) + } +} + +func TestWorkHandler_ListByProject(t *testing.T) { + mockQueue := newMockWorkQueue() + workService := service.NewWorkService(mockQueue, service.WorkServiceConfig{}) + handler := NewWorkHandler(workService) + + // Pre-populate tasks + mockQueue.tasks["task-1"] = &domain.WorkTask{ + ID: "task-1", + ProjectID: 
"project-a", + Type: domain.WorkTaskTypeBuild, + Status: domain.WorkTaskStatusPending, + CreatedAt: time.Now(), + } + mockQueue.tasks["task-2"] = &domain.WorkTask{ + ID: "task-2", + ProjectID: "project-a", + Type: domain.WorkTaskTypeTest, + Status: domain.WorkTaskStatusCompleted, + CreatedAt: time.Now(), + } + mockQueue.tasks["task-3"] = &domain.WorkTask{ + ID: "task-3", + ProjectID: "project-b", + Type: domain.WorkTaskTypeDeploy, + Status: domain.WorkTaskStatusRunning, + CreatedAt: time.Now(), + } + + router := chi.NewRouter() + handler.Mount(router) + + t.Run("list_all_for_project", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/work/projects/project-a", nil) + rec := httptest.NewRecorder() + + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Errorf("got status %d, want %d; body: %s", rec.Code, http.StatusOK, rec.Body.String()) + } + + var resp map[string]any + if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + data := resp["data"].(map[string]any) + total := int(data["total"].(float64)) + if total != 2 { + t.Errorf("got %d tasks, want 2", total) + } + // Verify pagination metadata is present + if _, ok := data["limit"]; !ok { + t.Error("expected limit in response") + } + if _, ok := data["offset"]; !ok { + t.Error("expected offset in response") + } + }) + + t.Run("list_with_status_filter", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/work/projects/project-a?status=pending", nil) + rec := httptest.NewRecorder() + + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Errorf("got status %d, want %d", rec.Code, http.StatusOK) + } + + var resp map[string]any + if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + data := resp["data"].(map[string]any) + total := int(data["total"].(float64)) + if total != 1 { + t.Errorf("got %d tasks, want 1", total) + } + }) + + 
t.Run("list_with_pagination", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/work/projects/project-a?limit=1&offset=0", nil) + rec := httptest.NewRecorder() + + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Errorf("got status %d, want %d", rec.Code, http.StatusOK) + } + + var resp map[string]any + if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + data := resp["data"].(map[string]any) + + // Total should reflect all matching tasks + total := int(data["total"].(float64)) + if total != 2 { + t.Errorf("got total=%d, want 2", total) + } + + // But tasks returned should be limited + tasks := data["tasks"].([]any) + if len(tasks) != 1 { + t.Errorf("got %d tasks returned, want 1", len(tasks)) + } + + // Verify limit/offset are reflected + if int(data["limit"].(float64)) != 1 { + t.Errorf("got limit=%v, want 1", data["limit"]) + } + if int(data["offset"].(float64)) != 0 { + t.Errorf("got offset=%v, want 0", data["offset"]) + } + }) + + t.Run("invalid_limit", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/work/projects/project-a?limit=invalid", nil) + rec := httptest.NewRecorder() + + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusBadRequest { + t.Errorf("got status %d, want %d", rec.Code, http.StatusBadRequest) + } + }) + + t.Run("invalid_offset", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/work/projects/project-a?offset=invalid", nil) + rec := httptest.NewRecorder() + + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusBadRequest { + t.Errorf("got status %d, want %d", rec.Code, http.StatusBadRequest) + } + }) + + t.Run("invalid_status_filter", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/work/projects/project-a?status=invalid", nil) + rec := httptest.NewRecorder() + + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusBadRequest { + t.Errorf("got status %d, want %d", rec.Code, 
http.StatusBadRequest) + } + }) +} + +func TestWorkHandler_Stats(t *testing.T) { + mockQueue := newMockWorkQueue() + workService := service.NewWorkService(mockQueue, service.WorkServiceConfig{}) + handler := NewWorkHandler(workService) + + // Pre-populate tasks with various statuses + mockQueue.tasks["task-1"] = &domain.WorkTask{ID: "task-1", Status: domain.WorkTaskStatusPending} + mockQueue.tasks["task-2"] = &domain.WorkTask{ID: "task-2", Status: domain.WorkTaskStatusPending} + mockQueue.tasks["task-3"] = &domain.WorkTask{ID: "task-3", Status: domain.WorkTaskStatusRunning} + mockQueue.tasks["task-4"] = &domain.WorkTask{ID: "task-4", Status: domain.WorkTaskStatusCompleted} + mockQueue.tasks["task-5"] = &domain.WorkTask{ID: "task-5", Status: domain.WorkTaskStatusFailed} + + router := chi.NewRouter() + handler.Mount(router) + + req := httptest.NewRequest(http.MethodGet, "/work/stats", nil) + rec := httptest.NewRecorder() + + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Errorf("got status %d, want %d; body: %s", rec.Code, http.StatusOK, rec.Body.String()) + } + + var resp map[string]any + if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + data := resp["data"].(map[string]any) + + if int(data["pending"].(float64)) != 2 { + t.Errorf("got pending=%v, want 2", data["pending"]) + } + if int(data["running"].(float64)) != 1 { + t.Errorf("got running=%v, want 1", data["running"]) + } + if int(data["completed"].(float64)) != 1 { + t.Errorf("got completed=%v, want 1", data["completed"]) + } + if int(data["failed"].(float64)) != 1 { + t.Errorf("got failed=%v, want 1", data["failed"]) + } +} diff --git a/internal/handlers/work_test.go b/internal/handlers/work_test.go index 927012e..77adf35 100644 --- a/internal/handlers/work_test.go +++ b/internal/handlers/work_test.go @@ -11,41 +11,40 @@ import ( "github.com/go-chi/chi/v5" "github.com/orchard9/rdev/internal/domain" - 
"github.com/orchard9/rdev/internal/port" "github.com/orchard9/rdev/internal/service" ) // mockWorkQueue implements port.WorkQueue for testing. type mockWorkQueue struct { - tasks map[string]*port.WorkTask + tasks map[string]*domain.WorkTask err error } func newMockWorkQueue() *mockWorkQueue { return &mockWorkQueue{ - tasks: make(map[string]*port.WorkTask), + tasks: make(map[string]*domain.WorkTask), } } -func (m *mockWorkQueue) Enqueue(ctx context.Context, task *port.WorkTask) (string, error) { +func (m *mockWorkQueue) Enqueue(ctx context.Context, task *domain.WorkTask) (string, error) { if m.err != nil { return "", m.err } id := "task-123" task.ID = id - task.Status = port.WorkTaskStatusPending + task.Status = domain.WorkTaskStatusPending task.CreatedAt = time.Now() m.tasks[id] = task return id, nil } -func (m *mockWorkQueue) Dequeue(ctx context.Context, workerID string) (*port.WorkTask, error) { +func (m *mockWorkQueue) Dequeue(ctx context.Context, workerID string) (*domain.WorkTask, error) { if m.err != nil { return nil, m.err } for _, task := range m.tasks { - if task.Status == port.WorkTaskStatusPending { - task.Status = port.WorkTaskStatusRunning + if task.Status == domain.WorkTaskStatusPending { + task.Status = domain.WorkTaskStatusRunning task.WorkerID = workerID now := time.Now() task.StartedAt = &now @@ -55,7 +54,7 @@ func (m *mockWorkQueue) Dequeue(ctx context.Context, workerID string) (*port.Wor return nil, nil } -func (m *mockWorkQueue) Complete(ctx context.Context, taskID string, result *port.WorkResult) error { +func (m *mockWorkQueue) Complete(ctx context.Context, taskID string, result *domain.WorkResult) error { if m.err != nil { return m.err } @@ -63,7 +62,7 @@ func (m *mockWorkQueue) Complete(ctx context.Context, taskID string, result *por if !ok { return domain.ErrWorkTaskNotFound } - task.Status = port.WorkTaskStatusCompleted + task.Status = domain.WorkTaskStatusCompleted task.Result = result now := time.Now() task.CompletedAt = &now @@ -79,11 
+78,11 @@ func (m *mockWorkQueue) Fail(ctx context.Context, taskID string, errMsg string) return domain.ErrWorkTaskNotFound } if task.RetryCount < task.MaxRetries { - task.Status = port.WorkTaskStatusPending + task.Status = domain.WorkTaskStatusPending task.RetryCount++ task.Error = errMsg } else { - task.Status = port.WorkTaskStatusFailed + task.Status = domain.WorkTaskStatusFailed task.Error = errMsg now := time.Now() task.CompletedAt = &now @@ -99,16 +98,16 @@ func (m *mockWorkQueue) Cancel(ctx context.Context, taskID string) error { if !ok { return domain.ErrWorkTaskNotFound } - if task.Status != port.WorkTaskStatusPending { + if task.Status != domain.WorkTaskStatusPending { return domain.ErrWorkTaskNotFound } - task.Status = port.WorkTaskStatusCancelled + task.Status = domain.WorkTaskStatusCancelled now := time.Now() task.CompletedAt = &now return nil } -func (m *mockWorkQueue) GetTask(ctx context.Context, taskID string) (*port.WorkTask, error) { +func (m *mockWorkQueue) GetTask(ctx context.Context, taskID string) (*domain.WorkTask, error) { if m.err != nil { return nil, m.err } @@ -119,13 +118,13 @@ func (m *mockWorkQueue) GetTask(ctx context.Context, taskID string) (*port.WorkT return task, nil } -func (m *mockWorkQueue) ListByProject(ctx context.Context, projectID string, status *port.WorkTaskStatus, opts port.WorkListOptions) (*port.WorkListResult, error) { +func (m *mockWorkQueue) ListByProject(ctx context.Context, projectID string, status *domain.WorkTaskStatus, opts domain.WorkListOptions) (*domain.WorkListResult, error) { if m.err != nil { return nil, m.err } opts.Normalize() - var tasks []*port.WorkTask + var tasks []*domain.WorkTask for _, task := range m.tasks { if task.ProjectID == projectID { if status == nil || task.Status == *status { @@ -146,7 +145,7 @@ func (m *mockWorkQueue) ListByProject(ctx context.Context, projectID string, sta tasks = tasks[opts.Offset:end] } - return &port.WorkListResult{ + return &domain.WorkListResult{ Tasks: tasks, 
Total: total, Limit: opts.Limit, @@ -154,22 +153,22 @@ func (m *mockWorkQueue) ListByProject(ctx context.Context, projectID string, sta }, nil } -func (m *mockWorkQueue) GetStats(ctx context.Context) (*port.WorkQueueStats, error) { +func (m *mockWorkQueue) GetStats(ctx context.Context) (*domain.WorkQueueStats, error) { if m.err != nil { return nil, m.err } - stats := &port.WorkQueueStats{} + stats := &domain.WorkQueueStats{} for _, task := range m.tasks { switch task.Status { - case port.WorkTaskStatusPending: + case domain.WorkTaskStatusPending: stats.Pending++ - case port.WorkTaskStatusRunning: + case domain.WorkTaskStatusRunning: stats.Running++ - case port.WorkTaskStatusCompleted: + case domain.WorkTaskStatusCompleted: stats.Completed++ - case port.WorkTaskStatusFailed: + case domain.WorkTaskStatusFailed: stats.Failed++ - case port.WorkTaskStatusCancelled: + case domain.WorkTaskStatusCancelled: stats.Cancelled++ } } @@ -292,11 +291,11 @@ func TestWorkHandler_Dequeue(t *testing.T) { handler := NewWorkHandler(workService) // Pre-populate a pending task - mockQueue.tasks["task-1"] = &port.WorkTask{ + mockQueue.tasks["task-1"] = &domain.WorkTask{ ID: "task-1", ProjectID: "test-project", - Type: port.WorkTaskTypeBuild, - Status: port.WorkTaskStatusPending, + Type: domain.WorkTaskTypeBuild, + Status: domain.WorkTaskStatusPending, CreatedAt: time.Now(), } @@ -359,11 +358,11 @@ func TestWorkHandler_Complete(t *testing.T) { handler := NewWorkHandler(workService) // Pre-populate a running task - mockQueue.tasks["task-1"] = &port.WorkTask{ + mockQueue.tasks["task-1"] = &domain.WorkTask{ ID: "task-1", ProjectID: "test-project", - Type: port.WorkTaskTypeBuild, - Status: port.WorkTaskStatusRunning, + Type: domain.WorkTaskTypeBuild, + Status: domain.WorkTaskStatusRunning, WorkerID: "worker-1", CreatedAt: time.Now(), } @@ -411,370 +410,3 @@ func TestWorkHandler_Complete(t *testing.T) { }) } } - -func TestWorkHandler_Fail(t *testing.T) { - mockQueue := newMockWorkQueue() - 
workService := service.NewWorkService(mockQueue, service.WorkServiceConfig{}) - handler := NewWorkHandler(workService) - - // Pre-populate a running task - mockQueue.tasks["task-1"] = &port.WorkTask{ - ID: "task-1", - ProjectID: "test-project", - Type: port.WorkTaskTypeBuild, - Status: port.WorkTaskStatusRunning, - WorkerID: "worker-1", - MaxRetries: 3, - CreatedAt: time.Now(), - } - - router := chi.NewRouter() - handler.Mount(router) - - tests := []struct { - name string - taskID string - body FailWorkRequest - wantStatus int - }{ - { - name: "valid_fail", - taskID: "task-1", - body: FailWorkRequest{Error: "Build failed: npm error"}, - wantStatus: http.StatusOK, - }, - { - name: "missing_error", - taskID: "task-1", - body: FailWorkRequest{}, - wantStatus: http.StatusBadRequest, - }, - { - name: "task_not_found", - taskID: "nonexistent", - body: FailWorkRequest{Error: "Failed"}, - wantStatus: http.StatusNotFound, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - body, _ := json.Marshal(tt.body) - req := httptest.NewRequest(http.MethodPost, "/work/"+tt.taskID+"/fail", bytes.NewReader(body)) - req.Header.Set("Content-Type", "application/json") - rec := httptest.NewRecorder() - - router.ServeHTTP(rec, req) - - if rec.Code != tt.wantStatus { - t.Errorf("got status %d, want %d; body: %s", rec.Code, tt.wantStatus, rec.Body.String()) - } - }) - } -} - -func TestWorkHandler_Cancel(t *testing.T) { - mockQueue := newMockWorkQueue() - workService := service.NewWorkService(mockQueue, service.WorkServiceConfig{}) - handler := NewWorkHandler(workService) - - // Pre-populate tasks - mockQueue.tasks["pending-task"] = &port.WorkTask{ - ID: "pending-task", - ProjectID: "test-project", - Type: port.WorkTaskTypeBuild, - Status: port.WorkTaskStatusPending, - CreatedAt: time.Now(), - } - mockQueue.tasks["running-task"] = &port.WorkTask{ - ID: "running-task", - ProjectID: "test-project", - Type: port.WorkTaskTypeBuild, - Status: port.WorkTaskStatusRunning, - 
CreatedAt: time.Now(), - } - - router := chi.NewRouter() - handler.Mount(router) - - tests := []struct { - name string - taskID string - wantStatus int - }{ - { - name: "cancel_pending_task", - taskID: "pending-task", - wantStatus: http.StatusOK, - }, - { - name: "cancel_running_task_fails", - taskID: "running-task", - wantStatus: http.StatusNotFound, // Can only cancel pending tasks - }, - { - name: "task_not_found", - taskID: "nonexistent", - wantStatus: http.StatusNotFound, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - req := httptest.NewRequest(http.MethodPost, "/work/"+tt.taskID+"/cancel", nil) - rec := httptest.NewRecorder() - - router.ServeHTTP(rec, req) - - if rec.Code != tt.wantStatus { - t.Errorf("got status %d, want %d; body: %s", rec.Code, tt.wantStatus, rec.Body.String()) - } - }) - } -} - -func TestWorkHandler_GetTask(t *testing.T) { - mockQueue := newMockWorkQueue() - workService := service.NewWorkService(mockQueue, service.WorkServiceConfig{}) - handler := NewWorkHandler(workService) - - // Pre-populate a task - mockQueue.tasks["task-1"] = &port.WorkTask{ - ID: "task-1", - ProjectID: "test-project", - Type: port.WorkTaskTypeBuild, - Status: port.WorkTaskStatusRunning, - Spec: map[string]any{ - "prompt": "Build it", - }, - CreatedAt: time.Now(), - } - - router := chi.NewRouter() - handler.Mount(router) - - tests := []struct { - name string - taskID string - wantStatus int - }{ - { - name: "get_existing_task", - taskID: "task-1", - wantStatus: http.StatusOK, - }, - { - name: "task_not_found", - taskID: "nonexistent", - wantStatus: http.StatusNotFound, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - req := httptest.NewRequest(http.MethodGet, "/work/"+tt.taskID, nil) - rec := httptest.NewRecorder() - - router.ServeHTTP(rec, req) - - if rec.Code != tt.wantStatus { - t.Errorf("got status %d, want %d; body: %s", rec.Code, tt.wantStatus, rec.Body.String()) - } - }) - } -} - -func 
TestWorkHandler_ListByProject(t *testing.T) { - mockQueue := newMockWorkQueue() - workService := service.NewWorkService(mockQueue, service.WorkServiceConfig{}) - handler := NewWorkHandler(workService) - - // Pre-populate tasks - mockQueue.tasks["task-1"] = &port.WorkTask{ - ID: "task-1", - ProjectID: "project-a", - Type: port.WorkTaskTypeBuild, - Status: port.WorkTaskStatusPending, - CreatedAt: time.Now(), - } - mockQueue.tasks["task-2"] = &port.WorkTask{ - ID: "task-2", - ProjectID: "project-a", - Type: port.WorkTaskTypeTest, - Status: port.WorkTaskStatusCompleted, - CreatedAt: time.Now(), - } - mockQueue.tasks["task-3"] = &port.WorkTask{ - ID: "task-3", - ProjectID: "project-b", - Type: port.WorkTaskTypeDeploy, - Status: port.WorkTaskStatusRunning, - CreatedAt: time.Now(), - } - - router := chi.NewRouter() - handler.Mount(router) - - t.Run("list_all_for_project", func(t *testing.T) { - req := httptest.NewRequest(http.MethodGet, "/work/projects/project-a", nil) - rec := httptest.NewRecorder() - - router.ServeHTTP(rec, req) - - if rec.Code != http.StatusOK { - t.Errorf("got status %d, want %d; body: %s", rec.Code, http.StatusOK, rec.Body.String()) - } - - var resp map[string]any - if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil { - t.Fatalf("failed to unmarshal: %v", err) - } - data := resp["data"].(map[string]any) - total := int(data["total"].(float64)) - if total != 2 { - t.Errorf("got %d tasks, want 2", total) - } - // Verify pagination metadata is present - if _, ok := data["limit"]; !ok { - t.Error("expected limit in response") - } - if _, ok := data["offset"]; !ok { - t.Error("expected offset in response") - } - }) - - t.Run("list_with_status_filter", func(t *testing.T) { - req := httptest.NewRequest(http.MethodGet, "/work/projects/project-a?status=pending", nil) - rec := httptest.NewRecorder() - - router.ServeHTTP(rec, req) - - if rec.Code != http.StatusOK { - t.Errorf("got status %d, want %d", rec.Code, http.StatusOK) - } - - var resp 
map[string]any - if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil { - t.Fatalf("failed to unmarshal: %v", err) - } - data := resp["data"].(map[string]any) - total := int(data["total"].(float64)) - if total != 1 { - t.Errorf("got %d tasks, want 1", total) - } - }) - - t.Run("list_with_pagination", func(t *testing.T) { - req := httptest.NewRequest(http.MethodGet, "/work/projects/project-a?limit=1&offset=0", nil) - rec := httptest.NewRecorder() - - router.ServeHTTP(rec, req) - - if rec.Code != http.StatusOK { - t.Errorf("got status %d, want %d", rec.Code, http.StatusOK) - } - - var resp map[string]any - if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil { - t.Fatalf("failed to unmarshal: %v", err) - } - data := resp["data"].(map[string]any) - - // Total should reflect all matching tasks - total := int(data["total"].(float64)) - if total != 2 { - t.Errorf("got total=%d, want 2", total) - } - - // But tasks returned should be limited - tasks := data["tasks"].([]any) - if len(tasks) != 1 { - t.Errorf("got %d tasks returned, want 1", len(tasks)) - } - - // Verify limit/offset are reflected - if int(data["limit"].(float64)) != 1 { - t.Errorf("got limit=%v, want 1", data["limit"]) - } - if int(data["offset"].(float64)) != 0 { - t.Errorf("got offset=%v, want 0", data["offset"]) - } - }) - - t.Run("invalid_limit", func(t *testing.T) { - req := httptest.NewRequest(http.MethodGet, "/work/projects/project-a?limit=invalid", nil) - rec := httptest.NewRecorder() - - router.ServeHTTP(rec, req) - - if rec.Code != http.StatusBadRequest { - t.Errorf("got status %d, want %d", rec.Code, http.StatusBadRequest) - } - }) - - t.Run("invalid_offset", func(t *testing.T) { - req := httptest.NewRequest(http.MethodGet, "/work/projects/project-a?offset=invalid", nil) - rec := httptest.NewRecorder() - - router.ServeHTTP(rec, req) - - if rec.Code != http.StatusBadRequest { - t.Errorf("got status %d, want %d", rec.Code, http.StatusBadRequest) - } - }) - - 
t.Run("invalid_status_filter", func(t *testing.T) { - req := httptest.NewRequest(http.MethodGet, "/work/projects/project-a?status=invalid", nil) - rec := httptest.NewRecorder() - - router.ServeHTTP(rec, req) - - if rec.Code != http.StatusBadRequest { - t.Errorf("got status %d, want %d", rec.Code, http.StatusBadRequest) - } - }) -} - -func TestWorkHandler_Stats(t *testing.T) { - mockQueue := newMockWorkQueue() - workService := service.NewWorkService(mockQueue, service.WorkServiceConfig{}) - handler := NewWorkHandler(workService) - - // Pre-populate tasks with various statuses - mockQueue.tasks["task-1"] = &port.WorkTask{ID: "task-1", Status: port.WorkTaskStatusPending} - mockQueue.tasks["task-2"] = &port.WorkTask{ID: "task-2", Status: port.WorkTaskStatusPending} - mockQueue.tasks["task-3"] = &port.WorkTask{ID: "task-3", Status: port.WorkTaskStatusRunning} - mockQueue.tasks["task-4"] = &port.WorkTask{ID: "task-4", Status: port.WorkTaskStatusCompleted} - mockQueue.tasks["task-5"] = &port.WorkTask{ID: "task-5", Status: port.WorkTaskStatusFailed} - - router := chi.NewRouter() - handler.Mount(router) - - req := httptest.NewRequest(http.MethodGet, "/work/stats", nil) - rec := httptest.NewRecorder() - - router.ServeHTTP(rec, req) - - if rec.Code != http.StatusOK { - t.Errorf("got status %d, want %d; body: %s", rec.Code, http.StatusOK, rec.Body.String()) - } - - var resp map[string]any - if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil { - t.Fatalf("failed to unmarshal: %v", err) - } - data := resp["data"].(map[string]any) - - if int(data["pending"].(float64)) != 2 { - t.Errorf("got pending=%v, want 2", data["pending"]) - } - if int(data["running"].(float64)) != 1 { - t.Errorf("got running=%v, want 1", data["running"]) - } - if int(data["completed"].(float64)) != 1 { - t.Errorf("got completed=%v, want 1", data["completed"]) - } - if int(data["failed"].(float64)) != 1 { - t.Errorf("got failed=%v, want 1", data["failed"]) - } -} diff --git 
a/internal/handlers/workers.go b/internal/handlers/workers.go new file mode 100644 index 0000000..5cc1309 --- /dev/null +++ b/internal/handlers/workers.go @@ -0,0 +1,162 @@ +// Package handlers provides HTTP handlers for the rdev API. +package handlers + +import ( + "errors" + "net/http" + + "github.com/go-chi/chi/v5" + "github.com/orchard9/rdev/internal/auth" + "github.com/orchard9/rdev/internal/domain" + "github.com/orchard9/rdev/internal/port" + "github.com/orchard9/rdev/internal/service" + "github.com/orchard9/rdev/pkg/api" +) + +// WorkersHandler handles worker pool management endpoints. +type WorkersHandler struct { + workerService *service.WorkerService +} + +// NewWorkersHandler creates a new workers handler. +func NewWorkersHandler(workerService *service.WorkerService) *WorkersHandler { + return &WorkersHandler{ + workerService: workerService, + } +} + +// Mount registers the worker pool routes. +func (h *WorkersHandler) Mount(r api.Router) { + r.Route("/workers", func(r chi.Router) { + r.With(auth.RequireScope(auth.ScopeWorkersRead, auth.ScopeAdmin)).Get("/", h.List) + r.With(auth.RequireScope(auth.ScopeWorkersRead, auth.ScopeAdmin)).Get("/{workerId}", h.Get) + r.With(auth.RequireScope(auth.ScopeWorkersWrite, auth.ScopeAdmin)).Post("/{workerId}/drain", h.Drain) + }) +} + +// WorkerDTO is the data transfer object for workers. 
+type WorkerDTO struct { + ID string `json:"id"` + Hostname string `json:"hostname"` + Status string `json:"status"` + CurrentTask string `json:"current_task,omitempty"` + Capabilities []string `json:"capabilities,omitempty"` + RegisteredAt string `json:"registered_at"` + LastHeartbeat string `json:"last_heartbeat"` + Version string `json:"version,omitempty"` +} + +func toWorkerDTO(w *domain.Worker) *WorkerDTO { + if w == nil { + return nil + } + return &WorkerDTO{ + ID: w.ID, + Hostname: w.Hostname, + Status: string(w.Status), + CurrentTask: w.CurrentTask, + Capabilities: w.Capabilities, + RegisteredAt: w.RegisteredAt.Format("2006-01-02T15:04:05Z07:00"), + LastHeartbeat: w.LastHeartbeat.Format("2006-01-02T15:04:05Z07:00"), + Version: w.Version, + } +} + +// List returns all workers with optional status filter. +// GET /workers?status=idle +func (h *WorkersHandler) List(w http.ResponseWriter, r *http.Request) { + filter := port.WorkerFilter{} + + if s := r.URL.Query().Get("status"); s != "" { + st := domain.WorkerStatus(s) + if !st.IsValid() { + api.WriteBadRequest(w, r, "invalid status: must be idle, busy, draining, or offline") + return + } + filter.Status = &st + } + + workers, err := h.workerService.ListWorkers(r.Context(), filter) + if err != nil { + api.WriteInternalError(w, r, "failed to list workers") + return + } + + dtos := make([]*WorkerDTO, len(workers)) + for i, wkr := range workers { + dtos[i] = toWorkerDTO(wkr) + } + + // Compute summary counts + idle, busy, draining, offline := 0, 0, 0, 0 + for _, wkr := range workers { + switch wkr.Status { + case domain.WorkerStatusIdle: + idle++ + case domain.WorkerStatusBusy: + busy++ + case domain.WorkerStatusDraining: + draining++ + case domain.WorkerStatusOffline: + offline++ + } + } + + api.WriteSuccess(w, r, map[string]any{ + "workers": dtos, + "total": len(dtos), + "summary": map[string]int{ + "idle": idle, + "busy": busy, + "draining": draining, + "offline": offline, + }, + }) +} + +// Get returns a 
specific worker by ID. +// GET /workers/{workerId} +func (h *WorkersHandler) Get(w http.ResponseWriter, r *http.Request) { + workerID := chi.URLParam(r, "workerId") + if workerID == "" { + api.WriteBadRequest(w, r, "worker ID is required") + return + } + + worker, err := h.workerService.GetWorker(r.Context(), workerID) + if err != nil { + if errors.Is(err, domain.ErrWorkerNotFound) { + api.WriteNotFound(w, r, "worker not found: "+workerID) + return + } + api.WriteInternalError(w, r, "failed to get worker") + return + } + + api.WriteSuccess(w, r, toWorkerDTO(worker)) +} + +// Drain sets a worker to draining status. +// POST /workers/{workerId}/drain +func (h *WorkersHandler) Drain(w http.ResponseWriter, r *http.Request) { + workerID := chi.URLParam(r, "workerId") + if workerID == "" { + api.WriteBadRequest(w, r, "worker ID is required") + return + } + + if err := h.workerService.DrainWorker(r.Context(), workerID); err != nil { + if errors.Is(err, domain.ErrWorkerNotFound) { + api.WriteNotFound(w, r, "worker not found: "+workerID) + return + } + api.WriteInternalError(w, r, "failed to drain worker") + return + } + + api.WriteSuccess(w, r, map[string]any{ + "worker_id": workerID, + "status": "draining", + "message": "worker will finish current task then stop accepting new work", + }) +} diff --git a/internal/handlers/workers_test.go b/internal/handlers/workers_test.go new file mode 100644 index 0000000..518c335 --- /dev/null +++ b/internal/handlers/workers_test.go @@ -0,0 +1,304 @@ +package handlers + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/go-chi/chi/v5" + "github.com/orchard9/rdev/internal/domain" + "github.com/orchard9/rdev/internal/port" + "github.com/orchard9/rdev/internal/service" +) + +// mockWorkerRegistry implements port.WorkerRegistry for testing. 
+type mockWorkerRegistry struct { + workers map[string]*domain.Worker + err error +} + +func newMockWorkerRegistry() *mockWorkerRegistry { + return &mockWorkerRegistry{ + workers: make(map[string]*domain.Worker), + } +} + +func (m *mockWorkerRegistry) Register(_ context.Context, w *domain.Worker) error { + if m.err != nil { + return m.err + } + m.workers[w.ID] = w + return nil +} + +func (m *mockWorkerRegistry) Heartbeat(_ context.Context, workerID string) error { + if m.err != nil { + return m.err + } + w, ok := m.workers[workerID] + if !ok { + return domain.ErrWorkerNotFound + } + w.LastHeartbeat = time.Now() + return nil +} + +func (m *mockWorkerRegistry) UpdateStatus(_ context.Context, workerID string, status domain.WorkerStatus, taskID string) error { + if m.err != nil { + return m.err + } + w, ok := m.workers[workerID] + if !ok { + return domain.ErrWorkerNotFound + } + w.Status = status + w.CurrentTask = taskID + return nil +} + +func (m *mockWorkerRegistry) Deregister(_ context.Context, workerID string) error { + if m.err != nil { + return m.err + } + delete(m.workers, workerID) + return nil +} + +func (m *mockWorkerRegistry) Get(_ context.Context, workerID string) (*domain.Worker, error) { + if m.err != nil { + return nil, m.err + } + w, ok := m.workers[workerID] + if !ok { + return nil, domain.ErrWorkerNotFound + } + return w, nil +} + +func (m *mockWorkerRegistry) List(_ context.Context, filter port.WorkerFilter) ([]*domain.Worker, error) { + if m.err != nil { + return nil, m.err + } + var result []*domain.Worker + for _, w := range m.workers { + if filter.Status != nil && w.Status != *filter.Status { + continue + } + result = append(result, w) + } + return result, nil +} + +func (m *mockWorkerRegistry) MarkStaleOffline(_ context.Context, _ time.Duration) (int, error) { + return 0, m.err +} + +func TestWorkersHandler_List(t *testing.T) { + registry := newMockWorkerRegistry() + queue := newMockWorkQueue() + workerService := 
service.NewWorkerService(registry, queue, nil) + handler := NewWorkersHandler(workerService) + + // Populate workers + registry.workers["worker-1"] = &domain.Worker{ + ID: "worker-1", + Hostname: "host-1", + Status: domain.WorkerStatusIdle, + Capabilities: []string{"build"}, + RegisteredAt: time.Now(), + LastHeartbeat: time.Now(), + Version: "1.0.0", + } + registry.workers["worker-2"] = &domain.Worker{ + ID: "worker-2", + Hostname: "host-2", + Status: domain.WorkerStatusBusy, + CurrentTask: "task-abc", + Capabilities: []string{"build", "deploy"}, + RegisteredAt: time.Now(), + LastHeartbeat: time.Now(), + Version: "1.0.0", + } + + router := chi.NewRouter() + router.Use(testAdminAuth) + handler.Mount(router) + + t.Run("list_all_workers", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/workers", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Errorf("got status %d, want %d; body: %s", rec.Code, http.StatusOK, rec.Body.String()) + } + + var resp map[string]any + if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + data, ok := resp["data"].(map[string]any) + if !ok { + t.Fatalf("expected data to be map, got %T", resp["data"]) + } + totalF, ok := data["total"].(float64) + if !ok { + t.Fatalf("expected total to be float64, got %T", data["total"]) + } + if int(totalF) != 2 { + t.Errorf("got total=%d, want 2", int(totalF)) + } + + summary, ok := data["summary"].(map[string]any) + if !ok { + t.Fatalf("expected summary to be map, got %T", data["summary"]) + } + if idleF, ok := summary["idle"].(float64); !ok || int(idleF) != 1 { + t.Errorf("got idle=%v, want 1", summary["idle"]) + } + if busyF, ok := summary["busy"].(float64); !ok || int(busyF) != 1 { + t.Errorf("got busy=%v, want 1", summary["busy"]) + } + }) + + t.Run("filter_by_status", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/workers?status=idle", nil) + rec := 
httptest.NewRecorder() + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Errorf("got status %d, want %d", rec.Code, http.StatusOK) + } + + var resp map[string]any + if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + data, ok := resp["data"].(map[string]any) + if !ok { + t.Fatalf("expected data to be map, got %T", resp["data"]) + } + totalF, ok := data["total"].(float64) + if !ok { + t.Fatalf("expected total to be float64, got %T", data["total"]) + } + if int(totalF) != 1 { + t.Errorf("got total=%d, want 1", int(totalF)) + } + }) + + t.Run("invalid_status_filter", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/workers?status=invalid", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusBadRequest { + t.Errorf("got status %d, want %d", rec.Code, http.StatusBadRequest) + } + }) +} + +func TestWorkersHandler_Get(t *testing.T) { + registry := newMockWorkerRegistry() + queue := newMockWorkQueue() + workerService := service.NewWorkerService(registry, queue, nil) + handler := NewWorkersHandler(workerService) + + registry.workers["worker-1"] = &domain.Worker{ + ID: "worker-1", + Hostname: "host-1", + Status: domain.WorkerStatusIdle, + RegisteredAt: time.Now(), + LastHeartbeat: time.Now(), + } + + router := chi.NewRouter() + router.Use(testAdminAuth) + handler.Mount(router) + + tests := []struct { + name string + workerID string + wantStatus int + }{ + { + name: "existing_worker", + workerID: "worker-1", + wantStatus: http.StatusOK, + }, + { + name: "not_found", + workerID: "nonexistent", + wantStatus: http.StatusNotFound, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/workers/"+tt.workerID, nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + if rec.Code != tt.wantStatus { + t.Errorf("got status %d, want %d; body: %s", rec.Code, 
tt.wantStatus, rec.Body.String()) + } + }) + } +} + +func TestWorkersHandler_Drain(t *testing.T) { + registry := newMockWorkerRegistry() + queue := newMockWorkQueue() + workerService := service.NewWorkerService(registry, queue, nil) + handler := NewWorkersHandler(workerService) + + registry.workers["worker-1"] = &domain.Worker{ + ID: "worker-1", + Hostname: "host-1", + Status: domain.WorkerStatusBusy, + CurrentTask: "task-abc", + RegisteredAt: time.Now(), + LastHeartbeat: time.Now(), + } + + router := chi.NewRouter() + router.Use(testAdminAuth) + handler.Mount(router) + + tests := []struct { + name string + workerID string + wantStatus int + }{ + { + name: "drain_existing_worker", + workerID: "worker-1", + wantStatus: http.StatusOK, + }, + { + name: "drain_nonexistent_worker", + workerID: "nonexistent", + wantStatus: http.StatusNotFound, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "/workers/"+tt.workerID+"/drain", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + if rec.Code != tt.wantStatus { + t.Errorf("got status %d, want %d; body: %s", rec.Code, tt.wantStatus, rec.Body.String()) + } + }) + } + + // Verify the worker was actually set to draining + if registry.workers["worker-1"].Status != domain.WorkerStatusDraining { + t.Errorf("expected worker status to be draining, got %s", registry.workers["worker-1"].Status) + } +} diff --git a/internal/metrics/metrics.go b/internal/metrics/metrics.go index 2e01162..1cd24be 100644 --- a/internal/metrics/metrics.go +++ b/internal/metrics/metrics.go @@ -25,6 +25,57 @@ var ( Buckets: prometheus.ExponentialBuckets(0.1, 2, 15), // 0.1s to ~27min }, []string{"project", "type"}) + // Code Agents + agentRequestsTotal = promauto.NewCounterVec(prometheus.CounterOpts{ + Name: "rdev_agent_requests_total", + Help: "Total number of code agent requests", + }, []string{"provider", "status"}) + + agentRequestDuration = 
promauto.NewHistogramVec(prometheus.HistogramOpts{ + Name: "rdev_agent_request_duration_seconds", + Help: "Duration of code agent requests in seconds", + Buckets: prometheus.ExponentialBuckets(0.1, 2, 15), // 0.1s to ~27min + }, []string{"provider"}) + + agentToolUse = promauto.NewCounterVec(prometheus.CounterOpts{ + Name: "rdev_agent_tool_use_total", + Help: "Total number of tool invocations by code agents", + }, []string{"provider", "tool"}) + + agentAvailability = promauto.NewGaugeVec(prometheus.GaugeOpts{ + Name: "rdev_agent_available", + Help: "Whether the code agent is available (1) or not (0)", + }, []string{"provider"}) + + // Worker Pool + workersTotal = promauto.NewGaugeVec(prometheus.GaugeOpts{ + Name: "rdev_workers_total", + Help: "Number of registered workers by status", + }, []string{"status"}) + + workerHeartbeatAge = promauto.NewGaugeVec(prometheus.GaugeOpts{ + Name: "rdev_worker_heartbeat_age_seconds", + Help: "Age of the most recent worker heartbeat in seconds", + }, []string{"worker_id"}) + + // Builds + buildsTotal = promauto.NewCounterVec(prometheus.CounterOpts{ + Name: "rdev_builds_total", + Help: "Total number of build tasks by status", + }, []string{"project", "status"}) + + buildDuration = promauto.NewHistogramVec(prometheus.HistogramOpts{ + Name: "rdev_build_duration_seconds", + Help: "Duration of build executions in seconds", + Buckets: prometheus.ExponentialBuckets(1, 2, 12), // 1s to ~34min + }, []string{"project"}) + + // Work Queue + workQueueDepth = promauto.NewGaugeVec(prometheus.GaugeOpts{ + Name: "rdev_work_queue_depth", + Help: "Number of tasks in the work queue by status", + }, []string{"status"}) + // Streams activeStreams = promauto.NewGaugeVec(prometheus.GaugeOpts{ Name: "rdev_active_streams", @@ -81,6 +132,49 @@ func RecordAuthFailure(reason string) { authFailures.WithLabelValues(reason).Inc() } +// RecordAgentRequest records a code agent request execution. 
+func RecordAgentRequest(provider, status string, durationMs int64) { + agentRequestsTotal.WithLabelValues(provider, status).Inc() + agentRequestDuration.WithLabelValues(provider).Observe(float64(durationMs) / 1000.0) +} + +// RecordAgentToolUse records a tool invocation by a code agent. +func RecordAgentToolUse(provider, tool string) { + agentToolUse.WithLabelValues(provider, tool).Inc() +} + +// SetAgentAvailability sets the availability status of a code agent. +func SetAgentAvailability(provider string, available bool) { + val := 0.0 + if available { + val = 1.0 + } + agentAvailability.WithLabelValues(provider).Set(val) +} + +// SetWorkerCount sets the number of workers for a given status. +func SetWorkerCount(status string, count int) { + workersTotal.WithLabelValues(status).Set(float64(count)) +} + +// RecordWorkerHeartbeat sets the age of a worker's most recent heartbeat. +func RecordWorkerHeartbeat(workerID string, ageSeconds float64) { + workerHeartbeatAge.WithLabelValues(workerID).Set(ageSeconds) +} + +// RecordBuild records a build task completion. +func RecordBuild(project, status string, durationMs int64) { + buildsTotal.WithLabelValues(project, status).Inc() + if durationMs > 0 { + buildDuration.WithLabelValues(project).Observe(float64(durationMs) / 1000.0) + } +} + +// SetWorkQueueDepth sets the current depth of the work queue for a status. +func SetWorkQueueDepth(status string, count int64) { + workQueueDepth.WithLabelValues(status).Set(float64(count)) +} + // Handler returns the Prometheus HTTP handler. func Handler() http.Handler { return promhttp.Handler() @@ -124,6 +218,10 @@ var pathNormalizers = []struct { }{ // /keys/uuid -> /keys/{id} {regexp.MustCompile(`^/keys/[^/]+$`), "/keys/{id}"}, + // /workers/{id}/... -> /workers/{id}/... 
+ {regexp.MustCompile(`^/workers/[^/]+(/.*)?$`), "/workers/{id}$1"}, + // /builds/{id} -> /builds/{id} + {regexp.MustCompile(`^/builds/[^/]+$`), "/builds/{id}"}, // /projects/{id}/claude-config/{type}/{name} -> /projects/{id}/claude-config/{type}/{name} {regexp.MustCompile(`^/projects/[^/]+/claude-config/(commands|skills|agents)/[^/]+$`), "/projects/{id}/claude-config/$1/{name}"}, // /projects/{id}/... (any sub-path) - must be last as it's most general diff --git a/internal/middleware/rate_limit.go b/internal/middleware/rate_limit.go index 195f903..39576a0 100644 --- a/internal/middleware/rate_limit.go +++ b/internal/middleware/rate_limit.go @@ -63,7 +63,7 @@ func RateLimitMiddleware(cfg RateLimitConfig) func(http.Handler) http.Handler { } // Skip rate limiting for admin keys - if apiKey.ID == "admin" { + if string(apiKey.ID) == "admin" { next.ServeHTTP(w, r) return } @@ -71,7 +71,7 @@ func RateLimitMiddleware(cfg RateLimitConfig) func(http.Handler) http.Handler { // Check rate limit and record atomically to prevent race conditions // RecordRequest is called first to ensure the count is incremented before // we check, preventing burst bypass under high concurrency - if err := cfg.Limiter.RecordRequest(r.Context(), apiKey.ID); err != nil { + if err := cfg.Limiter.RecordRequest(r.Context(), string(apiKey.ID)); err != nil { logger.Error("failed to record rate limit request", "error", err, "key_id", apiKey.ID) // On error, allow the request (fail open) next.ServeHTTP(w, r) @@ -79,7 +79,7 @@ func RateLimitMiddleware(cfg RateLimitConfig) func(http.Handler) http.Handler { } // Now check the limit (which includes the just-recorded request) - result, err := cfg.Limiter.CheckLimit(r.Context(), apiKey.ID) + result, err := cfg.Limiter.CheckLimit(r.Context(), string(apiKey.ID)) if err != nil { logger.Error("failed to check rate limit", "error", err, "key_id", apiKey.ID) // On error, allow the request (fail open) diff --git a/internal/port/build_audit.go 
b/internal/port/build_audit.go new file mode 100644 index 0000000..c248861 --- /dev/null +++ b/internal/port/build_audit.go @@ -0,0 +1,46 @@ +// Package port defines interface contracts for external adapters. +package port + +import ( + "context" + "time" + + "github.com/orchard9/rdev/internal/domain" +) + +// BuildAudit records build history for observability and debugging. +// Every build that passes through the system gets an audit entry, +// providing a complete history of what was requested, who executed it, +// and what the outcome was. +type BuildAudit interface { + // Record creates a new audit entry when a build starts. + Record(ctx context.Context, entry *domain.BuildAuditEntry) error + + // Update modifies an existing entry when a build completes. + Update(ctx context.Context, taskID string, result *domain.BuildResult) error + + // Get retrieves a specific audit entry by task ID. + // Returns ErrBuildNotFound if the entry does not exist. + Get(ctx context.Context, taskID string) (*domain.BuildAuditEntry, error) + + // List returns audit entries matching the filter. + List(ctx context.Context, filter BuildAuditFilter) ([]*domain.BuildAuditEntry, error) +} + +// BuildAuditFilter specifies criteria for listing audit entries. +type BuildAuditFilter struct { + // ProjectID filters entries by project. + ProjectID string + + // WorkerID filters entries by worker. + WorkerID string + + // Status filters entries by build status. Nil means all statuses. + Status *domain.BuildStatus + + // Since filters entries created after this time. + Since time.Time + + // Limit is the maximum number of entries to return. + Limit int +} diff --git a/internal/port/ci_provider.go b/internal/port/ci_provider.go index 51f8823..35f9e7b 100644 --- a/internal/port/ci_provider.go +++ b/internal/port/ci_provider.go @@ -29,4 +29,10 @@ type CIProvider interface { // DeleteSecret removes a secret from a repository. 
DeleteSecret(ctx context.Context, owner, repo, secretName string) error + + // ListPipelines returns recent CI pipeline executions for a repository. + ListPipelines(ctx context.Context, owner, repo string) ([]*domain.CIPipeline, error) + + // GetPipeline returns a specific pipeline execution by number. + GetPipeline(ctx context.Context, owner, repo string, number int64) (*domain.CIPipeline, error) } diff --git a/internal/port/code_agent.go b/internal/port/code_agent.go index b420b4a..2c1be31 100644 --- a/internal/port/code_agent.go +++ b/internal/port/code_agent.go @@ -49,6 +49,10 @@ type CodeAgentRegistry interface { // Returns nil if no agents are registered. Default() CodeAgent + // DefaultProvider returns the current default provider. + // Returns empty string if no agents are registered. + DefaultProvider() domain.AgentProvider + // SetDefault sets which provider should be used as the default. // Returns error if the provider is not registered. SetDefault(provider domain.AgentProvider) error @@ -58,4 +62,7 @@ type CodeAgentRegistry interface { // AvailableAgents returns all registered agents that are currently available. AvailableAgents(ctx context.Context) []CodeAgent + + // Count returns the number of registered agents. + Count() int } diff --git a/internal/port/credential_store.go b/internal/port/credential_store.go index 05b0d61..8ac2103 100644 --- a/internal/port/credential_store.go +++ b/internal/port/credential_store.go @@ -13,13 +13,15 @@ type CredentialStore interface { // Get retrieves a credential by key. Returns empty string if not found. Get(ctx context.Context, key string) (string, error) - // GetRequired retrieves a credential by key. Returns error if not found. + // GetRequired retrieves a credential by key. + // Returns domain.ErrCredentialNotFound if the key does not exist. GetRequired(ctx context.Context, key string) (string, error) // Set stores or updates a credential. 
Set(ctx context.Context, cred domain.Credential) error // Delete removes a credential by key. + // Returns domain.ErrCredentialNotFound if the key does not exist. Delete(ctx context.Context, key string) error // List returns all credentials (with values masked). diff --git a/internal/port/health.go b/internal/port/health.go new file mode 100644 index 0000000..de853dc --- /dev/null +++ b/internal/port/health.go @@ -0,0 +1,15 @@ +package port + +import "context" + +// DatabasePinger checks database connectivity. +// *sql.DB satisfies this interface. +type DatabasePinger interface { + PingContext(ctx context.Context) error +} + +// KubernetesChecker checks Kubernetes API connectivity. +type KubernetesChecker interface { + // ServerVersion returns the server version string, or an error if unreachable. + ServerVersion() (string, error) +} diff --git a/internal/port/work_queue.go b/internal/port/work_queue.go index 20c3cc3..67d9993 100644 --- a/internal/port/work_queue.go +++ b/internal/port/work_queue.go @@ -4,6 +4,8 @@ package port import ( "context" "time" + + "github.com/orchard9/rdev/internal/domain" ) // WorkQueue defines operations for the worker pool task queue. @@ -12,15 +14,15 @@ import ( type WorkQueue interface { // Enqueue adds a task to the queue. // Returns the task ID. - Enqueue(ctx context.Context, task *WorkTask) (string, error) + Enqueue(ctx context.Context, task *domain.WorkTask) (string, error) // Dequeue atomically claims the next available task for a worker. // Uses FOR UPDATE SKIP LOCKED for concurrent worker safety. // Returns nil if no tasks are available. - Dequeue(ctx context.Context, workerID string) (*WorkTask, error) + Dequeue(ctx context.Context, workerID string) (*domain.WorkTask, error) // Complete marks a task as successfully completed with results. 
- Complete(ctx context.Context, taskID string, result *WorkResult) error + Complete(ctx context.Context, taskID string, result *domain.WorkResult) error // Fail marks a task as failed with an error message. // If retry_count < max_retries, the task will be re-queued as pending. @@ -31,13 +33,13 @@ type WorkQueue interface { Cancel(ctx context.Context, taskID string) error // GetTask retrieves a task by ID. - GetTask(ctx context.Context, taskID string) (*WorkTask, error) + GetTask(ctx context.Context, taskID string) (*domain.WorkTask, error) // ListByProject returns tasks for a project with optional status filter and pagination. - ListByProject(ctx context.Context, projectID string, status *WorkTaskStatus, opts WorkListOptions) (*WorkListResult, error) + ListByProject(ctx context.Context, projectID string, status *domain.WorkTaskStatus, opts domain.WorkListOptions) (*domain.WorkListResult, error) // GetStats returns queue statistics. - GetStats(ctx context.Context) (*WorkQueueStats, error) + GetStats(ctx context.Context) (*domain.WorkQueueStats, error) // CleanupOld removes completed/failed/cancelled tasks older than the specified duration. CleanupOld(ctx context.Context, olderThan time.Duration) (int64, error) @@ -46,170 +48,3 @@ type WorkQueue interface { // This handles workers that crashed without reporting completion. RequeueStale(ctx context.Context, timeout time.Duration) (int64, error) } - -// WorkTaskStatus represents the status of a work task. -type WorkTaskStatus string - -const ( - WorkTaskStatusPending WorkTaskStatus = "pending" - WorkTaskStatusRunning WorkTaskStatus = "running" - WorkTaskStatusCompleted WorkTaskStatus = "completed" - WorkTaskStatusFailed WorkTaskStatus = "failed" - WorkTaskStatusCancelled WorkTaskStatus = "cancelled" -) - -// IsValid returns true if the status is a known valid status. 
-func (s WorkTaskStatus) IsValid() bool { - switch s { - case WorkTaskStatusPending, WorkTaskStatusRunning, WorkTaskStatusCompleted, - WorkTaskStatusFailed, WorkTaskStatusCancelled: - return true - } - return false -} - -// WorkTaskType represents the type of work task. -type WorkTaskType string - -const ( - WorkTaskTypeBuild WorkTaskType = "build" - WorkTaskTypeTest WorkTaskType = "test" - WorkTaskTypeDeploy WorkTaskType = "deploy" - WorkTaskTypeCustom WorkTaskType = "custom" -) - -// IsValid returns true if the task type is a known valid type. -func (t WorkTaskType) IsValid() bool { - switch t { - case WorkTaskTypeBuild, WorkTaskTypeTest, WorkTaskTypeDeploy, WorkTaskTypeCustom: - return true - } - return false -} - -// WorkTask represents a task in the work queue. -type WorkTask struct { - // ID is the unique task identifier. - ID string - - // ProjectID is the project this task belongs to. - ProjectID string - - // Type is the task type (build, test, deploy, custom). - Type WorkTaskType - - // Spec contains task-specific parameters. - // For build tasks: template, prompt, variables, auto_deploy, git_url - // For test tasks: test_command, git_url - // For deploy tasks: image, replicas, env - Spec map[string]any - - // Status is the current task status. - Status WorkTaskStatus - - // Priority determines execution order (higher = more urgent). - Priority int - - // WorkerID is the ID of the worker that claimed this task. - WorkerID string - - // CallbackURL is the webhook URL for completion notification. - CallbackURL string - - // CreatedAt is when the task was created. - CreatedAt time.Time - - // StartedAt is when a worker started executing the task. - StartedAt *time.Time - - // CompletedAt is when the task finished (success or failure). - CompletedAt *time.Time - - // Result contains the task output (if completed). - Result *WorkResult - - // Error contains the error message (if failed). - Error string - - // RetryCount is the number of retry attempts. 
- RetryCount int - - // MaxRetries is the maximum allowed retry attempts. - MaxRetries int -} - -// WorkResult contains the result of a completed task. -type WorkResult struct { - // Output is the main output from task execution. - Output string `json:"output,omitempty"` - - // Artifacts contains named artifacts from the task. - // For build tasks: commit_sha, deploy_url, etc. - Artifacts map[string]string `json:"artifacts,omitempty"` -} - -// WorkQueueStats contains queue statistics. -type WorkQueueStats struct { - // Pending is the count of pending tasks. - Pending int64 `json:"pending"` - - // Running is the count of running tasks. - Running int64 `json:"running"` - - // Completed is the count of completed tasks (last 24h). - Completed int64 `json:"completed"` - - // Failed is the count of failed tasks (last 24h). - Failed int64 `json:"failed"` - - // Cancelled is the count of cancelled tasks (last 24h). - Cancelled int64 `json:"cancelled"` - - // OldestPending is the age of the oldest pending task. - OldestPending *time.Duration `json:"oldest_pending,omitempty"` -} - -// WorkListOptions contains pagination options for listing tasks. -type WorkListOptions struct { - // Limit is the maximum number of tasks to return (default: 50, max: 100). - Limit int - - // Offset is the number of tasks to skip (for pagination). - Offset int -} - -// DefaultWorkListOptions returns options with default values. -func DefaultWorkListOptions() WorkListOptions { - return WorkListOptions{ - Limit: 50, - Offset: 0, - } -} - -// Normalize applies defaults and limits to the options. -func (o *WorkListOptions) Normalize() { - if o.Limit <= 0 { - o.Limit = 50 - } - if o.Limit > 100 { - o.Limit = 100 - } - if o.Offset < 0 { - o.Offset = 0 - } -} - -// WorkListResult contains paginated task results. -type WorkListResult struct { - // Tasks is the list of tasks. - Tasks []*WorkTask - - // Total is the total count of matching tasks (for pagination metadata). 
- Total int64 - - // Limit is the limit that was applied. - Limit int - - // Offset is the offset that was applied. - Offset int -} diff --git a/internal/port/worker_registry.go b/internal/port/worker_registry.go new file mode 100644 index 0000000..54d4516 --- /dev/null +++ b/internal/port/worker_registry.go @@ -0,0 +1,51 @@ +// Package port defines interface contracts for external adapters. +package port + +import ( + "context" + "time" + + "github.com/orchard9/rdev/internal/domain" +) + +// WorkerRegistry manages the lifecycle of workers in the pool. +// It handles registration, heartbeats, status updates, and health monitoring. +type WorkerRegistry interface { + // Register adds a worker to the pool. + // If a worker with the same ID already exists, it is re-registered. + Register(ctx context.Context, worker *domain.Worker) error + + // Heartbeat updates the worker's last_heartbeat timestamp. + // Returns ErrWorkerNotFound if the worker does not exist or is offline. + Heartbeat(ctx context.Context, workerID string) error + + // UpdateStatus changes a worker's status and optionally assigns a task. + // Pass empty taskID to clear the current task assignment. + UpdateStatus(ctx context.Context, workerID string, status domain.WorkerStatus, taskID string) error + + // Deregister removes a worker from the pool. + Deregister(ctx context.Context, workerID string) error + + // Get retrieves a specific worker by ID. + // Returns ErrWorkerNotFound if the worker does not exist. + Get(ctx context.Context, workerID string) (*domain.Worker, error) + + // List returns all workers matching the filter. + List(ctx context.Context, filter WorkerFilter) ([]*domain.Worker, error) + + // MarkStaleOffline marks workers without a recent heartbeat as offline. + // Returns the number of workers marked offline. + MarkStaleOffline(ctx context.Context, threshold time.Duration) (int, error) +} + +// WorkerFilter specifies criteria for listing workers. 
+type WorkerFilter struct { + // Status filters workers by status. Nil means all statuses. + Status *domain.WorkerStatus + + // HasCapability filters workers that have a specific capability. + HasCapability string + + // Limit is the maximum number of workers to return. Zero means no limit. + Limit int +} diff --git a/internal/ratelimit/ratelimit.go b/internal/ratelimit/ratelimit.go index 18fb1f8..d5e7a99 100644 --- a/internal/ratelimit/ratelimit.go +++ b/internal/ratelimit/ratelimit.go @@ -158,7 +158,7 @@ func (l *Limiter) getKey(r *http.Request) string { // Default: use API key ID from context // This requires the auth middleware to run first if apiKey := auth.GetAPIKey(r.Context()); apiKey != nil { - return apiKey.ID + return string(apiKey.ID) } // Fallback: use client IP @@ -258,7 +258,7 @@ func itoa(i int) string { func KeyFromAPIKey() func(*http.Request) string { return func(r *http.Request) string { if apiKey := auth.GetAPIKey(r.Context()); apiKey != nil { - return apiKey.ID + return string(apiKey.ID) } return getClientIP(r) } diff --git a/internal/service/apikey_service.go b/internal/service/apikey_service.go index 0753646..88c2920 100644 --- a/internal/service/apikey_service.go +++ b/internal/service/apikey_service.go @@ -31,6 +31,7 @@ type CreateKeyRequest struct { Name string Scopes []domain.Scope ProjectIDs []domain.ProjectID + AllowedIPs []string // CIDR notation; nil = no restriction ExpiresIn time.Duration CreatedBy string } @@ -43,14 +44,26 @@ type CreateKeyResult struct { // Create generates a new API key. 
func (s *APIKeyService) Create(ctx context.Context, req CreateKeyRequest) (*CreateKeyResult, error) { - // Generate secret - secret, err := generateSecret() - if err != nil { - return nil, fmt.Errorf("generate secret: %w", err) + // Generate key components using auth-compatible format: + // identifier: 4 random bytes → 8 hex chars + // random: 16 random bytes → 32 hex chars + // Full key: rdev_sk_{identifier}_{random} + idBytes := make([]byte, 4) + if _, err := rand.Read(idBytes); err != nil { + return nil, fmt.Errorf("generate identifier: %w", err) } + identifier := hex.EncodeToString(idBytes) - // Hash the secret - keyHash := hashKey(secret) + randomBytes := make([]byte, 16) + if _, err := rand.Read(randomBytes); err != nil { + return nil, fmt.Errorf("generate random: %w", err) + } + random := hex.EncodeToString(randomBytes) + + fullKey := fmt.Sprintf("rdev_sk_%s_%s", identifier, random) + + // Hash the full key (what the user receives and sends back for auth) + keyHash := hashKey(fullKey) // Calculate expiration var expiresAt *time.Time @@ -62,9 +75,10 @@ func (s *APIKeyService) Create(ctx context.Context, req CreateKeyRequest) (*Crea // Create key key := &domain.APIKey{ Name: req.Name, - KeyPrefix: secret[:8], + KeyPrefix: identifier, Scopes: req.Scopes, ProjectIDs: req.ProjectIDs, + AllowedIPs: req.AllowedIPs, ExpiresAt: expiresAt, CreatedBy: req.CreatedBy, } @@ -75,7 +89,7 @@ func (s *APIKeyService) Create(ctx context.Context, req CreateKeyRequest) (*Crea return &CreateKeyResult{ Key: key, - Secret: formatSecret(key.KeyPrefix, secret), + Secret: fullKey, }, nil } @@ -105,6 +119,42 @@ func (s *APIKeyService) UpdateLastUsed(ctx context.Context, id domain.APIKeyID) return s.repo.UpdateLastUsed(ctx, id) } +// Validate checks a raw API key and returns the associated APIKey if valid. +// It checks for admin key, looks up by hash, and verifies the key is active. +// On success it asynchronously updates the last-used timestamp.
+func (s *APIKeyService) Validate(ctx context.Context, rawKey string) (*domain.APIKey, error) { + // Check admin key first + if s.adminKey != "" && rawKey == s.adminKey { + return &domain.APIKey{ + ID: "admin", + Name: "Super Admin", + KeyPrefix: "admin", + Scopes: []domain.Scope{domain.ScopeAdmin}, + }, nil + } + + keyHash := hashKey(rawKey) + + apiKey, err := s.repo.GetByHash(ctx, keyHash) + if err != nil { + return nil, err + } + + if apiKey.IsRevoked() { + return nil, domain.ErrKeyRevoked + } + if apiKey.IsExpired() { + return nil, domain.ErrKeyExpired + } + + // Update last_used_at asynchronously + go func() { + _ = s.repo.UpdateLastUsed(context.Background(), apiKey.ID) + }() + + return apiKey, nil +} + // ValidateAdminKey checks if the provided key matches the admin key. func (s *APIKeyService) ValidateAdminKey(key string) bool { return s.adminKey != "" && key == s.adminKey @@ -115,26 +165,12 @@ func (s *APIKeyService) AdminKey() string { return s.adminKey } -// generateSecret creates a cryptographically secure random key. -func generateSecret() (string, error) { - bytes := make([]byte, 32) - if _, err := rand.Read(bytes); err != nil { - return "", err - } - return hex.EncodeToString(bytes), nil -} - // hashKey creates a SHA-256 hash of a key. func hashKey(key string) string { hash := sha256.Sum256([]byte(key)) return hex.EncodeToString(hash[:]) } -// formatSecret creates the full secret string with prefix. -func formatSecret(prefix, secret string) string { - return fmt.Sprintf("rdev_sk_%s_%s", prefix, secret[8:]) -} - // ParseExpiration converts a duration string to time.Duration. 
// Supported formats: "30d", "60d", "90d", "1y", "never" (or empty) func ParseExpiration(s string) (time.Duration, error) { diff --git a/internal/service/apikey_service_test.go b/internal/service/apikey_service_test.go index 5566742..eeacd27 100644 --- a/internal/service/apikey_service_test.go +++ b/internal/service/apikey_service_test.go @@ -319,28 +319,6 @@ func TestParseExpiration(t *testing.T) { } } -func TestGenerateSecret(t *testing.T) { - secrets := make(map[string]bool) - - for i := 0; i < 100; i++ { - secret, err := generateSecret() - if err != nil { - t.Fatalf("generateSecret() error = %v", err) - } - - // Should be 64 hex characters (32 bytes) - if len(secret) != 64 { - t.Errorf("Secret length = %d, want 64", len(secret)) - } - - // Should be unique - if secrets[secret] { - t.Errorf("Duplicate secret generated: %q", secret) - } - secrets[secret] = true - } -} - func TestHashKey(t *testing.T) { // Same input should produce same hash hash1 := hashKey("test-key") @@ -360,12 +338,3 @@ func TestHashKey(t *testing.T) { t.Errorf("Hash length = %d, want 64", len(hash1)) } } - -func TestFormatSecret(t *testing.T) { - result := formatSecret("abcd1234", "abcd12345678rest") - expected := "rdev_sk_abcd1234_5678rest" - - if result != expected { - t.Errorf("formatSecret() = %q, want %q", result, expected) - } -} diff --git a/internal/service/build_service.go b/internal/service/build_service.go new file mode 100644 index 0000000..3151562 --- /dev/null +++ b/internal/service/build_service.go @@ -0,0 +1,132 @@ +// Package service provides business logic services. +package service + +import ( + "context" + "fmt" + "log/slog" + "time" + + "github.com/orchard9/rdev/internal/domain" + "github.com/orchard9/rdev/internal/port" +) + +// BuildService orchestrates build task submission and tracking. +// It coordinates between the work queue (execution) and build audit (history). 
+type BuildService struct { + queue port.WorkQueue + audit port.BuildAudit + logger *slog.Logger +} + +// NewBuildService creates a new build service. +func NewBuildService( + queue port.WorkQueue, + audit port.BuildAudit, + logger *slog.Logger, +) *BuildService { + if logger == nil { + logger = slog.Default() + } + return &BuildService{ + queue: queue, + audit: audit, + logger: logger.With("service", "build"), + } +} + +// StartBuild enqueues a build task and creates an audit entry. +// Returns the task ID for status tracking. +func (s *BuildService) StartBuild(ctx context.Context, projectID string, spec domain.BuildSpec) (string, error) { + if err := spec.Validate(); err != nil { + return "", err + } + + if projectID == "" { + return "", fmt.Errorf("project_id is required") + } + + // Build work task spec from build spec + taskSpec := map[string]any{ + "prompt": spec.Prompt, + "auto_commit": spec.AutoCommit, + "auto_push": spec.AutoPush, + } + if spec.Template != "" { + taskSpec["template"] = spec.Template + } + if len(spec.Variables) > 0 { + taskSpec["variables"] = spec.Variables + } + + // Create work task + task := &domain.WorkTask{ + ProjectID: projectID, + Type: domain.WorkTaskTypeBuild, + Spec: taskSpec, + CallbackURL: spec.CallbackURL, + MaxRetries: 3, + } + + // Enqueue to work queue + taskID, err := s.queue.Enqueue(ctx, task) + if err != nil { + return "", fmt.Errorf("enqueue build task: %w", err) + } + + // Create audit entry (non-critical - don't fail the build if audit fails) + auditEntry := &domain.BuildAuditEntry{ + TaskID: taskID, + ProjectID: projectID, + Spec: spec, + Status: domain.BuildStatusPending, + StartedAt: time.Now(), + } + if err := s.audit.Record(ctx, auditEntry); err != nil { + s.logger.Warn("failed to record audit entry", + "task_id", taskID, + "error", err, + ) + } + + s.logger.Info("build enqueued", + "task_id", taskID, + "project_id", projectID, + "template", spec.Template, + "auto_push", spec.AutoPush, + ) + + return taskID, nil 
+} + +// GetBuildStatus returns the current status of a build. +func (s *BuildService) GetBuildStatus(ctx context.Context, taskID string) (*domain.BuildAuditEntry, error) { + return s.audit.Get(ctx, taskID) +} + +// ListBuilds returns build history for a project. +func (s *BuildService) ListBuilds(ctx context.Context, projectID string, limit int) ([]*domain.BuildAuditEntry, error) { + if limit <= 0 { + limit = 50 + } + return s.audit.List(ctx, port.BuildAuditFilter{ + ProjectID: projectID, + Limit: limit, + }) +} + +// CompleteBuild updates the audit entry when a build finishes. +// Called by the work queue processor on task completion. +func (s *BuildService) CompleteBuild(ctx context.Context, taskID string, result *domain.BuildResult) error { + if err := s.audit.Update(ctx, taskID, result); err != nil { + return fmt.Errorf("update audit: %w", err) + } + + s.logger.Info("build completed", + "task_id", taskID, + "success", result.Success, + "duration_ms", result.DurationMs, + ) + + return nil +} diff --git a/internal/service/build_service_test.go b/internal/service/build_service_test.go new file mode 100644 index 0000000..a14a513 --- /dev/null +++ b/internal/service/build_service_test.go @@ -0,0 +1,215 @@ +package service + +import ( + "context" + "fmt" + "testing" + + "github.com/orchard9/rdev/internal/domain" +) + +func TestBuildService_StartBuild(t *testing.T) { + ctx := context.Background() + + t.Run("enqueues build successfully", func(t *testing.T) { + queue := newMockWorkQueue() + audit := newMockBuildAudit() + svc := NewBuildService(queue, audit, nil) + + taskID, err := svc.StartBuild(ctx, "project-1", domain.BuildSpec{ + Prompt: "Build a landing page", + Template: "nextjs", + }) + if err != nil { + t.Fatalf("StartBuild() error = %v", err) + } + if taskID == "" { + t.Error("expected non-empty task ID") + } + + // Verify task was enqueued + if len(queue.tasks) != 1 { + t.Errorf("expected 1 task in queue, got %d", len(queue.tasks)) + } + task := 
queue.tasks[taskID] + if task.ProjectID != "project-1" { + t.Errorf("got project_id %q, want %q", task.ProjectID, "project-1") + } + if task.Type != domain.WorkTaskTypeBuild { + t.Errorf("got type %q, want %q", task.Type, domain.WorkTaskTypeBuild) + } + + // Verify audit was recorded + if len(audit.entries) != 1 { + t.Errorf("expected 1 audit entry, got %d", len(audit.entries)) + } + }) + + t.Run("validates prompt required", func(t *testing.T) { + queue := newMockWorkQueue() + audit := newMockBuildAudit() + svc := NewBuildService(queue, audit, nil) + + _, err := svc.StartBuild(ctx, "project-1", domain.BuildSpec{}) + if err == nil { + t.Error("expected error for empty prompt") + } + }) + + t.Run("validates project ID required", func(t *testing.T) { + queue := newMockWorkQueue() + audit := newMockBuildAudit() + svc := NewBuildService(queue, audit, nil) + + _, err := svc.StartBuild(ctx, "", domain.BuildSpec{Prompt: "Build"}) + if err == nil { + t.Error("expected error for empty project ID") + } + }) + + t.Run("includes variables in spec", func(t *testing.T) { + queue := newMockWorkQueue() + audit := newMockBuildAudit() + svc := NewBuildService(queue, audit, nil) + + taskID, err := svc.StartBuild(ctx, "project-1", domain.BuildSpec{ + Prompt: "Build", + Variables: map[string]string{ + "name": "My App", + "color": "blue", + }, + }) + if err != nil { + t.Fatalf("StartBuild() error = %v", err) + } + + task := queue.tasks[taskID] + vars, ok := task.Spec["variables"].(map[string]string) + if !ok { + t.Fatal("expected variables in task spec") + } + if vars["name"] != "My App" { + t.Errorf("got variable name %q, want %q", vars["name"], "My App") + } + }) + + t.Run("continues if audit fails", func(t *testing.T) { + queue := newMockWorkQueue() + audit := newMockBuildAudit() + audit.err = fmt.Errorf("db connection failed") + svc := NewBuildService(queue, audit, nil) + + taskID, err := svc.StartBuild(ctx, "project-1", domain.BuildSpec{ + Prompt: "Build", + }) + if err != nil { + 
t.Fatalf("StartBuild() should succeed even if audit fails, got error = %v", err) + } + if taskID == "" { + t.Error("expected non-empty task ID") + } + }) +} + +func TestBuildService_GetBuildStatus(t *testing.T) { + ctx := context.Background() + + t.Run("returns existing entry", func(t *testing.T) { + audit := newMockBuildAudit() + audit.entries["task-1"] = &domain.BuildAuditEntry{ + TaskID: "task-1", + ProjectID: "project-1", + Status: domain.BuildStatusRunning, + } + svc := NewBuildService(newMockWorkQueue(), audit, nil) + + entry, err := svc.GetBuildStatus(ctx, "task-1") + if err != nil { + t.Fatalf("GetBuildStatus() error = %v", err) + } + if entry.Status != domain.BuildStatusRunning { + t.Errorf("got status %q, want %q", entry.Status, domain.BuildStatusRunning) + } + }) + + t.Run("returns error for nonexistent entry", func(t *testing.T) { + audit := newMockBuildAudit() + svc := NewBuildService(newMockWorkQueue(), audit, nil) + + _, err := svc.GetBuildStatus(ctx, "nonexistent") + if err == nil { + t.Error("expected error for nonexistent entry") + } + }) +} + +func TestBuildService_ListBuilds(t *testing.T) { + ctx := context.Background() + + audit := newMockBuildAudit() + audit.entries["task-1"] = &domain.BuildAuditEntry{ + TaskID: "task-1", ProjectID: "project-a", Status: domain.BuildStatusCompleted, + } + audit.entries["task-2"] = &domain.BuildAuditEntry{ + TaskID: "task-2", ProjectID: "project-a", Status: domain.BuildStatusFailed, + } + audit.entries["task-3"] = &domain.BuildAuditEntry{ + TaskID: "task-3", ProjectID: "project-b", Status: domain.BuildStatusPending, + } + + svc := NewBuildService(newMockWorkQueue(), audit, nil) + + t.Run("lists builds for project", func(t *testing.T) { + entries, err := svc.ListBuilds(ctx, "project-a", 50) + if err != nil { + t.Fatalf("ListBuilds() error = %v", err) + } + if len(entries) != 2 { + t.Errorf("got %d entries, want 2", len(entries)) + } + }) + + t.Run("uses default limit", func(t *testing.T) { + entries, err := 
svc.ListBuilds(ctx, "project-a", 0) + if err != nil { + t.Fatalf("ListBuilds() error = %v", err) + } + if len(entries) != 2 { + t.Errorf("got %d entries, want 2", len(entries)) + } + }) +} + +func TestBuildService_CompleteBuild(t *testing.T) { + ctx := context.Background() + + t.Run("updates audit on completion", func(t *testing.T) { + audit := newMockBuildAudit() + audit.entries["task-1"] = &domain.BuildAuditEntry{ + TaskID: "task-1", + ProjectID: "project-1", + Status: domain.BuildStatusRunning, + } + svc := NewBuildService(newMockWorkQueue(), audit, nil) + + err := svc.CompleteBuild(ctx, "task-1", &domain.BuildResult{ + Success: true, + CommitSHA: "abc123", + DurationMs: 5000, + }) + if err != nil { + t.Fatalf("CompleteBuild() error = %v", err) + } + + entry := audit.entries["task-1"] + if entry.Status != domain.BuildStatusCompleted { + t.Errorf("got status %q, want %q", entry.Status, domain.BuildStatusCompleted) + } + if entry.Result == nil { + t.Fatal("expected result to be set") + } + if !entry.Result.Success { + t.Error("expected result.Success = true") + } + }) +} diff --git a/internal/service/mock_test.go b/internal/service/mock_test.go new file mode 100644 index 0000000..ce44967 --- /dev/null +++ b/internal/service/mock_test.go @@ -0,0 +1,256 @@ +package service + +import ( + "context" + "fmt" + "time" + + "github.com/orchard9/rdev/internal/domain" + "github.com/orchard9/rdev/internal/port" +) + +// mockWorkQueue implements port.WorkQueue for service tests. +// Configure tasks and err fields to control behavior. 
+type mockWorkQueue struct { + tasks map[string]*domain.WorkTask + err error +} + +func newMockWorkQueue() *mockWorkQueue { + return &mockWorkQueue{tasks: make(map[string]*domain.WorkTask)} +} + +func (m *mockWorkQueue) Enqueue(ctx context.Context, task *domain.WorkTask) (string, error) { + if m.err != nil { + return "", m.err + } + id := fmt.Sprintf("task-%d", len(m.tasks)+1) + task.ID = id + task.Status = domain.WorkTaskStatusPending + task.CreatedAt = time.Now() + m.tasks[id] = task + return id, nil +} + +func (m *mockWorkQueue) Dequeue(ctx context.Context, workerID string) (*domain.WorkTask, error) { + if m.err != nil { + return nil, m.err + } + for _, task := range m.tasks { + if task.Status == domain.WorkTaskStatusPending { + task.Status = domain.WorkTaskStatusRunning + task.WorkerID = workerID + now := time.Now() + task.StartedAt = &now + return task, nil + } + } + return nil, nil +} + +func (m *mockWorkQueue) Complete(ctx context.Context, taskID string, result *domain.WorkResult) error { + if m.err != nil { + return m.err + } + task, ok := m.tasks[taskID] + if !ok { + return domain.ErrWorkTaskNotFound + } + task.Status = domain.WorkTaskStatusCompleted + task.Result = result + return nil +} + +func (m *mockWorkQueue) Fail(ctx context.Context, taskID string, errMsg string) error { + return nil +} + +func (m *mockWorkQueue) Cancel(ctx context.Context, taskID string) error { + return nil +} + +func (m *mockWorkQueue) GetTask(ctx context.Context, taskID string) (*domain.WorkTask, error) { + task, ok := m.tasks[taskID] + if !ok { + return nil, domain.ErrWorkTaskNotFound + } + return task, nil +} + +func (m *mockWorkQueue) ListByProject(ctx context.Context, projectID string, status *domain.WorkTaskStatus, opts domain.WorkListOptions) (*domain.WorkListResult, error) { + return &domain.WorkListResult{}, nil +} + +func (m *mockWorkQueue) GetStats(ctx context.Context) (*domain.WorkQueueStats, error) { + return &domain.WorkQueueStats{}, nil +} + +func (m 
*mockWorkQueue) CleanupOld(ctx context.Context, olderThan time.Duration) (int64, error) { + return 0, nil +} + +func (m *mockWorkQueue) RequeueStale(ctx context.Context, timeout time.Duration) (int64, error) { + return 0, nil +} + +// mockBuildAudit implements port.BuildAudit for service tests. +type mockBuildAudit struct { + entries map[string]*domain.BuildAuditEntry + err error +} + +func newMockBuildAudit() *mockBuildAudit { + return &mockBuildAudit{entries: make(map[string]*domain.BuildAuditEntry)} +} + +func (m *mockBuildAudit) Record(ctx context.Context, entry *domain.BuildAuditEntry) error { + if m.err != nil { + return m.err + } + m.entries[entry.TaskID] = entry + return nil +} + +func (m *mockBuildAudit) Update(ctx context.Context, taskID string, result *domain.BuildResult) error { + if m.err != nil { + return m.err + } + entry, ok := m.entries[taskID] + if !ok { + return domain.ErrBuildNotFound + } + entry.Result = result + if result.Success { + entry.Status = domain.BuildStatusCompleted + } else { + entry.Status = domain.BuildStatusFailed + } + now := time.Now() + entry.CompletedAt = &now + return nil +} + +func (m *mockBuildAudit) Get(ctx context.Context, taskID string) (*domain.BuildAuditEntry, error) { + if m.err != nil { + return nil, m.err + } + entry, ok := m.entries[taskID] + if !ok { + return nil, domain.ErrBuildNotFound + } + return entry, nil +} + +func (m *mockBuildAudit) List(ctx context.Context, filter port.BuildAuditFilter) ([]*domain.BuildAuditEntry, error) { + if m.err != nil { + return nil, m.err + } + var result []*domain.BuildAuditEntry + for _, entry := range m.entries { + if filter.ProjectID != "" && entry.ProjectID != filter.ProjectID { + continue + } + result = append(result, entry) + } + if filter.Limit > 0 && len(result) > filter.Limit { + result = result[:filter.Limit] + } + return result, nil +} + +// mockWorkerRegistry implements port.WorkerRegistry for service tests. 
+type mockWorkerRegistry struct { + workers map[string]*domain.Worker + err error +} + +func newMockWorkerRegistry() *mockWorkerRegistry { + return &mockWorkerRegistry{workers: make(map[string]*domain.Worker)} +} + +func (m *mockWorkerRegistry) Register(ctx context.Context, worker *domain.Worker) error { + if m.err != nil { + return m.err + } + m.workers[worker.ID] = worker + return nil +} + +func (m *mockWorkerRegistry) Heartbeat(ctx context.Context, workerID string) error { + if m.err != nil { + return m.err + } + w, ok := m.workers[workerID] + if !ok { + return domain.ErrWorkerNotFound + } + w.LastHeartbeat = time.Now() + return nil +} + +func (m *mockWorkerRegistry) UpdateStatus(ctx context.Context, workerID string, status domain.WorkerStatus, taskID string) error { + if m.err != nil { + return m.err + } + w, ok := m.workers[workerID] + if !ok { + return domain.ErrWorkerNotFound + } + w.Status = status + w.CurrentTask = taskID + return nil +} + +func (m *mockWorkerRegistry) Deregister(ctx context.Context, workerID string) error { + if m.err != nil { + return m.err + } + if _, ok := m.workers[workerID]; !ok { + return domain.ErrWorkerNotFound + } + delete(m.workers, workerID) + return nil +} + +func (m *mockWorkerRegistry) Get(ctx context.Context, workerID string) (*domain.Worker, error) { + if m.err != nil { + return nil, m.err + } + w, ok := m.workers[workerID] + if !ok { + return nil, domain.ErrWorkerNotFound + } + return w, nil +} + +func (m *mockWorkerRegistry) List(ctx context.Context, filter port.WorkerFilter) ([]*domain.Worker, error) { + if m.err != nil { + return nil, m.err + } + var result []*domain.Worker + for _, w := range m.workers { + if filter.Status != nil && w.Status != *filter.Status { + continue + } + result = append(result, w) + } + if filter.Limit > 0 && len(result) > filter.Limit { + result = result[:filter.Limit] + } + return result, nil +} + +func (m *mockWorkerRegistry) MarkStaleOffline(ctx context.Context, threshold time.Duration) 
(int, error) { + if m.err != nil { + return 0, m.err + } + count := 0 + for _, w := range m.workers { + if w.Status != domain.WorkerStatusOffline && time.Since(w.LastHeartbeat) > threshold { + w.Status = domain.WorkerStatusOffline + w.CurrentTask = "" + count++ + } + } + return count, nil +} diff --git a/internal/service/project_infra.go b/internal/service/project_infra.go index 0aa175d..37984fa 100644 --- a/internal/service/project_infra.go +++ b/internal/service/project_infra.go @@ -4,49 +4,19 @@ package service import ( "context" "database/sql" - "errors" "fmt" "log/slog" - "regexp" "time" "github.com/orchard9/rdev/internal/domain" "github.com/orchard9/rdev/internal/port" ) -// projectNameRegex validates project names for DNS and K8s compatibility. -// Must be lowercase, start with a letter, contain only letters, numbers, and dashes. -var projectNameRegex = regexp.MustCompile(`^[a-z][a-z0-9-]*$`) - -// reservedProjectNames are names that cannot be used for projects. -var reservedProjectNames = map[string]bool{ - "www": true, - "api": true, - "git": true, - "ci": true, - "registry": true, - "admin": true, - "root": true, - "rdev": true, - "pantheon": true, -} - // ValidateProjectName validates that a project name is safe for use as // a DNS subdomain, K8s resource name, and git repository name. +// Delegates to domain.ValidateProjectName for centralized validation. func ValidateProjectName(name string) error { - if name == "" { - return errors.New("project name cannot be empty") - } - if len(name) > 63 { - return errors.New("project name too long (max 63 characters)") - } - if !projectNameRegex.MatchString(name) { - return errors.New("project name must be lowercase, start with a letter, and contain only letters, numbers, and dashes") - } - if reservedProjectNames[name] { - return fmt.Errorf("'%s' is a reserved name", name) - } - return nil + return domain.ValidateProjectName(name) } // ProjectInfraService orchestrates project infrastructure operations. 
@@ -136,7 +106,7 @@ type CreateProjectResult struct { func (s *ProjectInfraService) CreateProject(ctx context.Context, req CreateProjectRequest) (*CreateProjectResult, error) { // Validate project name first if err := ValidateProjectName(req.Name); err != nil { - return nil, fmt.Errorf("%w: %v", domain.ErrInvalidProjectName, err) + return nil, fmt.Errorf("%w: %w", domain.ErrInvalidProjectName, err) } s.logger.Info("creating project", "name", req.Name) diff --git a/internal/service/project_service.go b/internal/service/project_service.go index 33eebe0..b8b6b62 100644 --- a/internal/service/project_service.go +++ b/internal/service/project_service.go @@ -142,12 +142,12 @@ func (s *ProjectService) ExecuteClaude(ctx context.Context, req ExecuteClaudeReq return nil, fmt.Errorf("%w: prompt is required", domain.ErrInvalidCommand) } if err := sanitize.ClaudePrompt(req.Prompt); err != nil { - return nil, fmt.Errorf("%w: %v", domain.ErrCommandSanitization, err) + return nil, fmt.Errorf("%w: %w", domain.ErrCommandSanitization, err) } // Validate stream ID if err := sanitize.StreamID(req.StreamID); err != nil { - return nil, fmt.Errorf("%w: %v", domain.ErrInvalidCommand, err) + return nil, fmt.Errorf("%w: %w", domain.ErrInvalidCommand, err) } // Generate command ID diff --git a/internal/service/project_service_agent.go b/internal/service/project_service_agent.go index a38cd86..f0fa86f 100644 --- a/internal/service/project_service_agent.go +++ b/internal/service/project_service_agent.go @@ -65,6 +65,10 @@ func (s *ProjectService) executeAgentCommand(agent port.CodeAgent, req *domain.A "tool": event.ToolName, "input": event.ToolInput, } + // Record tool use metric + if event.ToolName != "" { + metrics.RecordAgentToolUse(string(agent.Provider()), event.ToolName) + } case domain.AgentEventToolResult: eventType = "tool_result" data = map[string]any{ @@ -112,6 +116,9 @@ func (s *ProjectService) executeAgentCommand(agent port.CodeAgent, req *domain.A } 
metrics.RecordCommand(string(cmd.ProjectID), string(cmd.Type), status, result.DurationMs) + // Record agent-specific metrics + metrics.RecordAgentRequest(string(agent.Provider()), status, result.DurationMs) + // Log audit completion if audit logger is configured if s.auditLogger != nil { var auditStatus domain.AuditStatus diff --git a/internal/service/project_service_commands.go b/internal/service/project_service_commands.go index e27e20b..84f6935 100644 --- a/internal/service/project_service_commands.go +++ b/internal/service/project_service_commands.go @@ -40,12 +40,12 @@ func (s *ProjectService) ExecuteShell(ctx context.Context, req ExecuteShellReque return nil, fmt.Errorf("%w: command is required", domain.ErrInvalidCommand) } if err := sanitize.ShellCommand(req.Command); err != nil { - return nil, fmt.Errorf("%w: %v", domain.ErrCommandSanitization, err) + return nil, fmt.Errorf("%w: %w", domain.ErrCommandSanitization, err) } // Validate stream ID if err := sanitize.StreamID(req.StreamID); err != nil { - return nil, fmt.Errorf("%w: %v", domain.ErrInvalidCommand, err) + return nil, fmt.Errorf("%w: %w", domain.ErrInvalidCommand, err) } // Generate command ID @@ -120,12 +120,12 @@ func (s *ProjectService) ExecuteGit(ctx context.Context, req ExecuteGitRequest) return nil, fmt.Errorf("%w: args is required", domain.ErrInvalidCommand) } if err := sanitize.GitArgs(req.Args); err != nil { - return nil, fmt.Errorf("%w: %v", domain.ErrCommandSanitization, err) + return nil, fmt.Errorf("%w: %w", domain.ErrCommandSanitization, err) } // Validate stream ID if err := sanitize.StreamID(req.StreamID); err != nil { - return nil, fmt.Errorf("%w: %v", domain.ErrInvalidCommand, err) + return nil, fmt.Errorf("%w: %w", domain.ErrInvalidCommand, err) } // Generate command ID diff --git a/internal/service/work_service.go b/internal/service/work_service.go index e539479..e1a31b6 100644 --- a/internal/service/work_service.go +++ b/internal/service/work_service.go @@ -6,6 +6,7 @@ import ( 
"fmt" "log/slog" + "github.com/orchard9/rdev/internal/domain" "github.com/orchard9/rdev/internal/port" "github.com/orchard9/rdev/internal/webhook" ) @@ -57,7 +58,7 @@ func (s *WorkService) EnqueueTask(ctx context.Context, req EnqueueTaskRequest) ( maxRetries = 3 } - task := &port.WorkTask{ + task := &domain.WorkTask{ ProjectID: req.ProjectID, Type: req.Type, Spec: req.Spec, @@ -85,7 +86,7 @@ func (s *WorkService) EnqueueTask(ctx context.Context, req EnqueueTaskRequest) ( } // DequeueTask claims the next available task for a worker. -func (s *WorkService) DequeueTask(ctx context.Context, workerID string) (*port.WorkTask, error) { +func (s *WorkService) DequeueTask(ctx context.Context, workerID string) (*domain.WorkTask, error) { if workerID == "" { return nil, fmt.Errorf("worker_id is required") } @@ -108,7 +109,7 @@ func (s *WorkService) DequeueTask(ctx context.Context, workerID string) (*port.W } // CompleteTask marks a task as successfully completed. -func (s *WorkService) CompleteTask(ctx context.Context, taskID string, result *port.WorkResult) error { +func (s *WorkService) CompleteTask(ctx context.Context, taskID string, result *domain.WorkResult) error { // Get task for callback URL before completing task, err := s.queue.GetTask(ctx, taskID) if err != nil { @@ -147,7 +148,7 @@ func (s *WorkService) FailTask(ctx context.Context, taskID string, errMsg string // Check if it was requeued or permanently failed updatedTask, _ := s.queue.GetTask(ctx, taskID) - if updatedTask != nil && updatedTask.Status == port.WorkTaskStatusFailed { + if updatedTask != nil && updatedTask.Status == domain.WorkTaskStatusFailed { s.logger.Warn("task failed permanently", "task_id", taskID, "project", task.ProjectID, @@ -199,22 +200,22 @@ func (s *WorkService) CancelTask(ctx context.Context, taskID string) error { } // GetTask retrieves a task by ID. 
-func (s *WorkService) GetTask(ctx context.Context, taskID string) (*port.WorkTask, error) { +func (s *WorkService) GetTask(ctx context.Context, taskID string) (*domain.WorkTask, error) { return s.queue.GetTask(ctx, taskID) } // ListByProject returns tasks for a project with pagination. -func (s *WorkService) ListByProject(ctx context.Context, projectID string, status *port.WorkTaskStatus, opts port.WorkListOptions) (*port.WorkListResult, error) { +func (s *WorkService) ListByProject(ctx context.Context, projectID string, status *domain.WorkTaskStatus, opts domain.WorkListOptions) (*domain.WorkListResult, error) { return s.queue.ListByProject(ctx, projectID, status, opts) } // GetStats returns queue statistics. -func (s *WorkService) GetStats(ctx context.Context) (*port.WorkQueueStats, error) { +func (s *WorkService) GetStats(ctx context.Context) (*domain.WorkQueueStats, error) { return s.queue.GetStats(ctx) } // notifyCallback sends a webhook notification for task status changes. -func (s *WorkService) notifyCallback(task *port.WorkTask, status string, result *port.WorkResult, errMsg string) { +func (s *WorkService) notifyCallback(task *domain.WorkTask, status string, result *domain.WorkResult, errMsg string) { if s.webhookDispatcher == nil || task.CallbackURL == "" { return } @@ -251,7 +252,7 @@ type EnqueueTaskRequest struct { ProjectID string `json:"project_id"` // Type is the task type (build, test, deploy, custom). - Type port.WorkTaskType `json:"task_type"` + Type domain.WorkTaskType `json:"task_type"` // Spec contains task-specific parameters. 
Spec map[string]any `json:"task_spec"` diff --git a/internal/service/work_service_test.go b/internal/service/work_service_test.go new file mode 100644 index 0000000..c559df4 --- /dev/null +++ b/internal/service/work_service_test.go @@ -0,0 +1,215 @@ +package service + +import ( + "context" + "testing" + + "github.com/orchard9/rdev/internal/domain" +) + +func newTestWorkService() (*WorkService, *mockWorkQueue) { + q := newMockWorkQueue() + svc := NewWorkService(q, WorkServiceConfig{}) + return svc, q +} + +func TestWorkService_EnqueueTask(t *testing.T) { + t.Run("success", func(t *testing.T) { + svc, q := newTestWorkService() + + result, err := svc.EnqueueTask(context.Background(), EnqueueTaskRequest{ + ProjectID: "myapp", + Type: domain.WorkTaskTypeBuild, + Priority: 1, + Spec: map[string]any{"branch": "main"}, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result.TaskID == "" { + t.Error("task ID should not be empty") + } + if result.StatusURL == "" { + t.Error("status URL should not be empty") + } + if len(q.tasks) != 1 { + t.Errorf("tasks in queue = %d, want 1", len(q.tasks)) + } + }) + + t.Run("default max retries", func(t *testing.T) { + svc, q := newTestWorkService() + + _, err := svc.EnqueueTask(context.Background(), EnqueueTaskRequest{ + ProjectID: "myapp", + Type: domain.WorkTaskTypeBuild, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + for _, task := range q.tasks { + if task.MaxRetries != 3 { + t.Errorf("max retries = %d, want 3 (default)", task.MaxRetries) + } + } + }) + + t.Run("missing project id", func(t *testing.T) { + svc, _ := newTestWorkService() + + _, err := svc.EnqueueTask(context.Background(), EnqueueTaskRequest{ + Type: domain.WorkTaskTypeBuild, + }) + if err == nil { + t.Error("expected error for missing project_id") + } + }) + + t.Run("missing type", func(t *testing.T) { + svc, _ := newTestWorkService() + + _, err := svc.EnqueueTask(context.Background(), EnqueueTaskRequest{ + ProjectID: 
"myapp", + }) + if err == nil { + t.Error("expected error for missing type") + } + }) +} + +func TestWorkService_DequeueTask(t *testing.T) { + t.Run("success", func(t *testing.T) { + svc, q := newTestWorkService() + + // Enqueue a task first + q.tasks["task-1"] = &domain.WorkTask{ + ID: "task-1", + ProjectID: "myapp", + Type: domain.WorkTaskTypeBuild, + Status: domain.WorkTaskStatusPending, + } + + task, err := svc.DequeueTask(context.Background(), "worker-1") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if task == nil { + t.Fatal("expected a task, got nil") + } + if task.ID != "task-1" { + t.Errorf("task ID = %q, want %q", task.ID, "task-1") + } + if task.Status != domain.WorkTaskStatusRunning { + t.Errorf("task status = %q, want %q", task.Status, domain.WorkTaskStatusRunning) + } + }) + + t.Run("empty queue", func(t *testing.T) { + svc, _ := newTestWorkService() + + task, err := svc.DequeueTask(context.Background(), "worker-1") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if task != nil { + t.Error("expected nil task for empty queue") + } + }) + + t.Run("missing worker id", func(t *testing.T) { + svc, _ := newTestWorkService() + + _, err := svc.DequeueTask(context.Background(), "") + if err == nil { + t.Error("expected error for missing worker_id") + } + }) +} + +func TestWorkService_CompleteTask(t *testing.T) { + t.Run("success", func(t *testing.T) { + svc, q := newTestWorkService() + + q.tasks["task-1"] = &domain.WorkTask{ + ID: "task-1", + ProjectID: "myapp", + Type: domain.WorkTaskTypeBuild, + Status: domain.WorkTaskStatusRunning, + } + + result := &domain.WorkResult{Output: "ok"} + err := svc.CompleteTask(context.Background(), "task-1", result) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + task := q.tasks["task-1"] + if task.Status != domain.WorkTaskStatusCompleted { + t.Errorf("status = %q, want %q", task.Status, domain.WorkTaskStatusCompleted) + } + }) + + t.Run("task not found", func(t *testing.T) 
{ + svc, _ := newTestWorkService() + + err := svc.CompleteTask(context.Background(), "nonexistent", nil) + if err == nil { + t.Error("expected error for nonexistent task") + } + }) +} + +func TestWorkService_GetTask(t *testing.T) { + t.Run("found", func(t *testing.T) { + svc, q := newTestWorkService() + + q.tasks["task-1"] = &domain.WorkTask{ + ID: "task-1", + ProjectID: "myapp", + } + + task, err := svc.GetTask(context.Background(), "task-1") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if task.ID != "task-1" { + t.Errorf("task ID = %q, want %q", task.ID, "task-1") + } + }) + + t.Run("not found", func(t *testing.T) { + svc, _ := newTestWorkService() + + _, err := svc.GetTask(context.Background(), "missing") + if err == nil { + t.Error("expected error for missing task") + } + }) +} + +func TestWorkService_GetStats(t *testing.T) { + svc, _ := newTestWorkService() + + stats, err := svc.GetStats(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if stats == nil { + t.Error("stats should not be nil") + } +} + +func TestWorkService_CancelTask(t *testing.T) { + svc, q := newTestWorkService() + + q.tasks["task-1"] = &domain.WorkTask{ + ID: "task-1", + ProjectID: "myapp", + Status: domain.WorkTaskStatusPending, + } + + err := svc.CancelTask(context.Background(), "task-1") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } +} diff --git a/internal/service/worker_service.go b/internal/service/worker_service.go new file mode 100644 index 0000000..450a027 --- /dev/null +++ b/internal/service/worker_service.go @@ -0,0 +1,231 @@ +// Package service provides business logic services. +package service + +import ( + "context" + "log/slog" + "time" + + "github.com/orchard9/rdev/internal/domain" + "github.com/orchard9/rdev/internal/port" +) + +const ( + // DefaultHeartbeatInterval is how often the health checker runs. 
+ DefaultHeartbeatInterval = 30 * time.Second + + // DefaultStaleThreshold is how long since last heartbeat before marking offline. + DefaultStaleThreshold = 90 * time.Second +) + +// WorkerService manages worker lifecycle and task assignment. +// It coordinates between the worker registry (pool management) and +// the work queue (task execution). +type WorkerService struct { + registry port.WorkerRegistry + queue port.WorkQueue + audit port.BuildAudit + logger *slog.Logger +} + +// NewWorkerService creates a new worker service. +func NewWorkerService( + registry port.WorkerRegistry, + queue port.WorkQueue, + logger *slog.Logger, +) *WorkerService { + if logger == nil { + logger = slog.Default() + } + return &WorkerService{ + registry: registry, + queue: queue, + logger: logger.With("service", "worker"), + } +} + +// WithBuildAudit adds a build audit for recording task assignments. +func (s *WorkerService) WithBuildAudit(audit port.BuildAudit) *WorkerService { + s.audit = audit + return s +} + +// Register adds a worker to the pool. +func (s *WorkerService) Register(ctx context.Context, worker *domain.Worker) error { + if err := worker.Validate(); err != nil { + return err + } + + worker.RegisteredAt = time.Now() + worker.LastHeartbeat = time.Now() + worker.Status = domain.WorkerStatusIdle + + if err := s.registry.Register(ctx, worker); err != nil { + return err + } + + s.logger.Info("worker registered", + "worker_id", worker.ID, + "hostname", worker.Hostname, + "version", worker.Version, + "capabilities", worker.Capabilities, + ) + + return nil +} + +// Heartbeat updates worker liveness. +func (s *WorkerService) Heartbeat(ctx context.Context, workerID string) error { + return s.registry.Heartbeat(ctx, workerID) +} + +// Deregister removes a worker from the pool. 
+func (s *WorkerService) Deregister(ctx context.Context, workerID string) error { + if err := s.registry.Deregister(ctx, workerID); err != nil { + return err + } + + s.logger.Info("worker deregistered", "worker_id", workerID) + return nil +} + +// GetWorker retrieves a specific worker. +func (s *WorkerService) GetWorker(ctx context.Context, workerID string) (*domain.Worker, error) { + return s.registry.Get(ctx, workerID) +} + +// ListWorkers returns all workers matching the optional filter. +func (s *WorkerService) ListWorkers(ctx context.Context, filter port.WorkerFilter) ([]*domain.Worker, error) { + return s.registry.List(ctx, filter) +} + +// ClaimTask atomically dequeues a task and marks worker as busy. +func (s *WorkerService) ClaimTask(ctx context.Context, workerID string) (*domain.WorkTask, error) { + task, err := s.queue.Dequeue(ctx, workerID) + if err != nil { + return nil, err + } + if task == nil { + return nil, nil // No tasks available + } + + // Mark worker as busy with the claimed task + if err := s.registry.UpdateStatus(ctx, workerID, domain.WorkerStatusBusy, task.ID); err != nil { + s.logger.Warn("failed to update worker status after claim", + "worker_id", workerID, + "task_id", task.ID, + "error", err, + ) + } + + // Update audit entry if available + if s.audit != nil { + entry, _ := s.audit.Get(ctx, task.ID) + if entry != nil { + entry.WorkerID = workerID + entry.Status = domain.BuildStatusRunning + } + } + + s.logger.Info("task claimed", + "task_id", task.ID, + "worker_id", workerID, + "project_id", task.ProjectID, + "type", task.Type, + ) + + return task, nil +} + +// CompleteTask marks a task as complete and returns worker to idle. 
+func (s *WorkerService) CompleteTask(ctx context.Context, workerID, taskID string, result *domain.BuildResult) error { + if result == nil { + result = &domain.BuildResult{} + } + + // Convert domain build result to work result + bwr := result.ToWorkResult() + workResult := &domain.WorkResult{ + Output: bwr.Output, + Artifacts: bwr.Artifacts, + } + + // Update audit record (non-critical) + if s.audit != nil { + if err := s.audit.Update(ctx, taskID, result); err != nil { + s.logger.Warn("failed to update audit", + "task_id", taskID, + "error", err, + ) + } + } + + // Complete in queue + if err := s.queue.Complete(ctx, taskID, workResult); err != nil { + return err + } + + // Return worker to idle + if err := s.registry.UpdateStatus(ctx, workerID, domain.WorkerStatusIdle, ""); err != nil { + s.logger.Warn("failed to return worker to idle", + "worker_id", workerID, + "error", err, + ) + } + + s.logger.Info("task completed", + "task_id", taskID, + "worker_id", workerID, + "success", result.Success, + "duration_ms", result.DurationMs, + ) + + return nil +} + +// DrainWorker sets a worker to draining status so it finishes current work +// but doesn't accept new tasks. +func (s *WorkerService) DrainWorker(ctx context.Context, workerID string) error { + worker, err := s.registry.Get(ctx, workerID) + if err != nil { + return err + } + + if err := s.registry.UpdateStatus(ctx, workerID, domain.WorkerStatusDraining, worker.CurrentTask); err != nil { + return err + } + + s.logger.Info("worker draining", + "worker_id", workerID, + "current_task", worker.CurrentTask, + ) + + return nil +} + +// StartHealthChecker runs a background goroutine that marks stale workers offline. +// It returns when the context is cancelled. 
+func (s *WorkerService) StartHealthChecker(ctx context.Context) { + ticker := time.NewTicker(DefaultHeartbeatInterval) + defer ticker.Stop() + + s.logger.Info("worker health checker started", + "interval", DefaultHeartbeatInterval, + "stale_threshold", DefaultStaleThreshold, + ) + + for { + select { + case <-ctx.Done(): + s.logger.Info("worker health checker stopped") + return + case <-ticker.C: + count, err := s.registry.MarkStaleOffline(ctx, DefaultStaleThreshold) + if err != nil { + s.logger.Error("failed to mark stale workers", "error", err) + } else if count > 0 { + s.logger.Warn("marked workers offline", "count", count) + } + } + } +} diff --git a/internal/service/worker_service_test.go b/internal/service/worker_service_test.go new file mode 100644 index 0000000..5ae8bad --- /dev/null +++ b/internal/service/worker_service_test.go @@ -0,0 +1,328 @@ +package service + +import ( + "context" + "testing" + "time" + + "github.com/orchard9/rdev/internal/domain" + "github.com/orchard9/rdev/internal/port" +) + +func TestWorkerService_Register(t *testing.T) { + ctx := context.Background() + + t.Run("registers valid worker", func(t *testing.T) { + registry := newMockWorkerRegistry() + queue := newMockWorkQueue() + svc := NewWorkerService(registry, queue, nil) + + err := svc.Register(ctx, &domain.Worker{ + ID: "worker-1", + Hostname: "host-1", + Capabilities: []string{"build"}, + Version: "1.0.0", + }) + if err != nil { + t.Fatalf("Register() error = %v", err) + } + + w := registry.workers["worker-1"] + if w == nil { + t.Fatal("worker not found in registry") + } + if w.Status != domain.WorkerStatusIdle { + t.Errorf("got status %q, want %q", w.Status, domain.WorkerStatusIdle) + } + if w.RegisteredAt.IsZero() { + t.Error("expected registered_at to be set") + } + }) + + t.Run("validates worker", func(t *testing.T) { + registry := newMockWorkerRegistry() + queue := newMockWorkQueue() + svc := NewWorkerService(registry, queue, nil) + + err := svc.Register(ctx, 
&domain.Worker{}) + if err == nil { + t.Error("expected validation error for empty worker") + } + }) +} + +func TestWorkerService_Heartbeat(t *testing.T) { + ctx := context.Background() + + t.Run("updates heartbeat", func(t *testing.T) { + registry := newMockWorkerRegistry() + registry.workers["worker-1"] = &domain.Worker{ + ID: "worker-1", + Hostname: "host-1", + Status: domain.WorkerStatusIdle, + LastHeartbeat: time.Now().Add(-30 * time.Second), + } + + svc := NewWorkerService(registry, newMockWorkQueue(), nil) + + err := svc.Heartbeat(ctx, "worker-1") + if err != nil { + t.Fatalf("Heartbeat() error = %v", err) + } + + // Heartbeat should be recent + w := registry.workers["worker-1"] + if time.Since(w.LastHeartbeat) > time.Second { + t.Error("expected heartbeat to be updated to now") + } + }) + + t.Run("returns error for nonexistent worker", func(t *testing.T) { + registry := newMockWorkerRegistry() + svc := NewWorkerService(registry, newMockWorkQueue(), nil) + + err := svc.Heartbeat(ctx, "nonexistent") + if err == nil { + t.Error("expected error for nonexistent worker") + } + }) +} + +func TestWorkerService_Deregister(t *testing.T) { + ctx := context.Background() + + t.Run("deregisters worker", func(t *testing.T) { + registry := newMockWorkerRegistry() + registry.workers["worker-1"] = &domain.Worker{ + ID: "worker-1", + Hostname: "host-1", + Status: domain.WorkerStatusIdle, + } + + svc := NewWorkerService(registry, newMockWorkQueue(), nil) + + err := svc.Deregister(ctx, "worker-1") + if err != nil { + t.Fatalf("Deregister() error = %v", err) + } + + if _, ok := registry.workers["worker-1"]; ok { + t.Error("worker should be removed from registry") + } + }) +} + +func TestWorkerService_ClaimTask(t *testing.T) { + ctx := context.Background() + + t.Run("claims task and marks worker busy", func(t *testing.T) { + registry := newMockWorkerRegistry() + registry.workers["worker-1"] = &domain.Worker{ + ID: "worker-1", + Hostname: "host-1", + Status: 
domain.WorkerStatusIdle, + } + + queue := newMockWorkQueue() + queue.tasks["task-1"] = &domain.WorkTask{ + ID: "task-1", + ProjectID: "project-1", + Type: domain.WorkTaskTypeBuild, + Status: domain.WorkTaskStatusPending, + CreatedAt: time.Now(), + } + + svc := NewWorkerService(registry, queue, nil) + + task, err := svc.ClaimTask(ctx, "worker-1") + if err != nil { + t.Fatalf("ClaimTask() error = %v", err) + } + if task == nil { + t.Fatal("expected task to be returned") + } + if task.ID != "task-1" { + t.Errorf("got task ID %q, want %q", task.ID, "task-1") + } + + // Worker should be busy with the task + w := registry.workers["worker-1"] + if w.Status != domain.WorkerStatusBusy { + t.Errorf("got status %q, want %q", w.Status, domain.WorkerStatusBusy) + } + if w.CurrentTask != "task-1" { + t.Errorf("got current_task %q, want %q", w.CurrentTask, "task-1") + } + }) + + t.Run("returns nil when no tasks available", func(t *testing.T) { + registry := newMockWorkerRegistry() + registry.workers["worker-1"] = &domain.Worker{ + ID: "worker-1", + Hostname: "host-1", + Status: domain.WorkerStatusIdle, + } + + queue := newMockWorkQueue() + svc := NewWorkerService(registry, queue, nil) + + task, err := svc.ClaimTask(ctx, "worker-1") + if err != nil { + t.Fatalf("ClaimTask() error = %v", err) + } + if task != nil { + t.Error("expected nil task when queue is empty") + } + }) +} + +func TestWorkerService_CompleteTask(t *testing.T) { + ctx := context.Background() + + t.Run("completes task and returns worker to idle", func(t *testing.T) { + registry := newMockWorkerRegistry() + registry.workers["worker-1"] = &domain.Worker{ + ID: "worker-1", + Hostname: "host-1", + Status: domain.WorkerStatusBusy, + CurrentTask: "task-1", + } + + queue := newMockWorkQueue() + queue.tasks["task-1"] = &domain.WorkTask{ + ID: "task-1", + ProjectID: "project-1", + Type: domain.WorkTaskTypeBuild, + Status: domain.WorkTaskStatusRunning, + WorkerID: "worker-1", + } + + svc := NewWorkerService(registry, queue, 
nil) + + err := svc.CompleteTask(ctx, "worker-1", "task-1", &domain.BuildResult{ + Success: true, + CommitSHA: "abc123", + DurationMs: 5000, + }) + if err != nil { + t.Fatalf("CompleteTask() error = %v", err) + } + + // Task should be completed + task := queue.tasks["task-1"] + if task.Status != domain.WorkTaskStatusCompleted { + t.Errorf("got task status %q, want %q", task.Status, domain.WorkTaskStatusCompleted) + } + + // Worker should be idle + w := registry.workers["worker-1"] + if w.Status != domain.WorkerStatusIdle { + t.Errorf("got worker status %q, want %q", w.Status, domain.WorkerStatusIdle) + } + if w.CurrentTask != "" { + t.Errorf("got current_task %q, want empty", w.CurrentTask) + } + }) + + t.Run("handles nil result", func(t *testing.T) { + registry := newMockWorkerRegistry() + registry.workers["worker-1"] = &domain.Worker{ + ID: "worker-1", + Hostname: "host-1", + Status: domain.WorkerStatusBusy, + CurrentTask: "task-1", + } + + queue := newMockWorkQueue() + queue.tasks["task-1"] = &domain.WorkTask{ + ID: "task-1", + Status: domain.WorkTaskStatusRunning, + WorkerID: "worker-1", + } + + svc := NewWorkerService(registry, queue, nil) + + err := svc.CompleteTask(ctx, "worker-1", "task-1", nil) + if err != nil { + t.Fatalf("CompleteTask(nil result) error = %v", err) + } + + // Worker should be idle + w := registry.workers["worker-1"] + if w.Status != domain.WorkerStatusIdle { + t.Errorf("got worker status %q, want %q", w.Status, domain.WorkerStatusIdle) + } + }) +} + +func TestWorkerService_ListWorkers(t *testing.T) { + ctx := context.Background() + + registry := newMockWorkerRegistry() + registry.workers["worker-1"] = &domain.Worker{ID: "worker-1", Status: domain.WorkerStatusIdle} + registry.workers["worker-2"] = &domain.Worker{ID: "worker-2", Status: domain.WorkerStatusBusy} + registry.workers["worker-3"] = &domain.Worker{ID: "worker-3", Status: domain.WorkerStatusIdle} + + svc := NewWorkerService(registry, newMockWorkQueue(), nil) + + t.Run("lists all 
workers", func(t *testing.T) { + workers, err := svc.ListWorkers(ctx, port.WorkerFilter{}) + if err != nil { + t.Fatalf("ListWorkers() error = %v", err) + } + if len(workers) != 3 { + t.Errorf("got %d workers, want 3", len(workers)) + } + }) + + t.Run("filters by status", func(t *testing.T) { + idle := domain.WorkerStatusIdle + workers, err := svc.ListWorkers(ctx, port.WorkerFilter{Status: &idle}) + if err != nil { + t.Fatalf("ListWorkers() error = %v", err) + } + if len(workers) != 2 { + t.Errorf("got %d idle workers, want 2", len(workers)) + } + }) +} + +func TestWorkerService_DrainWorker(t *testing.T) { + ctx := context.Background() + + t.Run("drains worker", func(t *testing.T) { + registry := newMockWorkerRegistry() + registry.workers["worker-1"] = &domain.Worker{ + ID: "worker-1", + Hostname: "host-1", + Status: domain.WorkerStatusBusy, + CurrentTask: "task-1", + } + + svc := NewWorkerService(registry, newMockWorkQueue(), nil) + + err := svc.DrainWorker(ctx, "worker-1") + if err != nil { + t.Fatalf("DrainWorker() error = %v", err) + } + + w := registry.workers["worker-1"] + if w.Status != domain.WorkerStatusDraining { + t.Errorf("got status %q, want %q", w.Status, domain.WorkerStatusDraining) + } + // Should preserve current task + if w.CurrentTask != "task-1" { + t.Errorf("got current_task %q, want %q", w.CurrentTask, "task-1") + } + }) + + t.Run("returns error for nonexistent worker", func(t *testing.T) { + registry := newMockWorkerRegistry() + svc := NewWorkerService(registry, newMockWorkQueue(), nil) + + err := svc.DrainWorker(ctx, "nonexistent") + if err == nil { + t.Error("expected error for nonexistent worker") + } + }) +} diff --git a/internal/telemetry/telemetry.go b/internal/telemetry/telemetry.go index dd3a26a..3265879 100644 --- a/internal/telemetry/telemetry.go +++ b/internal/telemetry/telemetry.go @@ -13,7 +13,7 @@ package telemetry import ( "context" - "errors" + "fmt" "log/slog" "os" "strings" @@ -111,7 +111,7 @@ func New(ctx context.Context, 
cfg Config) (*Telemetry, error) { exporter, err := otlptracegrpc.New(ctx, opts...) if err != nil { - return nil, errors.New("failed to create OTLP exporter: " + err.Error()) + return nil, fmt.Errorf("failed to create OTLP exporter: %w", err) } // Create resource with service information @@ -179,7 +179,7 @@ func (t *Telemetry) Shutdown(ctx context.Context) error { } if err := t.tracerProvider.Shutdown(ctx); err != nil { - return errors.New("telemetry shutdown failed: " + err.Error()) + return fmt.Errorf("telemetry shutdown failed: %w", err) } t.logger.Info("telemetry shutdown complete") diff --git a/internal/validate/validate.go b/internal/validate/validate.go index 881dedc..9d947f2 100644 --- a/internal/validate/validate.go +++ b/internal/validate/validate.go @@ -5,6 +5,7 @@ package validate import ( "fmt" + "net/url" "regexp" "strings" ) @@ -239,6 +240,22 @@ var ( // --- Convenience validators for common patterns --- +// HTTPURL validates that a string is a valid HTTP or HTTPS URL. +// Returns nil for empty strings (use Required for that check). +func HTTPURL(value, field string) error { + if value == "" { + return nil + } + parsed, err := url.Parse(value) + if err != nil || (parsed.Scheme != "http" && parsed.Scheme != "https") || parsed.Host == "" { + return ValidationError{ + Field: field, + Message: "must be a valid HTTP or HTTPS URL", + } + } + return nil +} + // Name validates a name field (alphanumeric with dashes/underscores, 1-64 chars). // This matches the existing isValidName pattern in claude_config.go. 
func Name(value, field string) error { diff --git a/internal/worker/build_executor.go b/internal/worker/build_executor.go new file mode 100644 index 0000000..a0c0178 --- /dev/null +++ b/internal/worker/build_executor.go @@ -0,0 +1,196 @@ +package worker + +import ( + "context" + "fmt" + "log/slog" + "strings" + "time" + + "github.com/orchard9/rdev/internal/domain" + "github.com/orchard9/rdev/internal/port" +) + +// BuildExecutor handles WorkTaskTypeBuild tasks. +// It translates BuildSpec fields from the work task's Spec map into an +// AgentRequest, executes via a CodeAgent, and returns a BuildResult. +type BuildExecutor struct { + agentRegistry port.CodeAgentRegistry + gitOps *GitOperations + logger *slog.Logger +} + +// NewBuildExecutor creates a new build executor. +func NewBuildExecutor( + agentRegistry port.CodeAgentRegistry, + gitOps *GitOperations, + logger *slog.Logger, +) *BuildExecutor { + if logger == nil { + logger = slog.Default() + } + return &BuildExecutor{ + agentRegistry: agentRegistry, + gitOps: gitOps, + logger: logger.With("component", "build-executor"), + } +} + +// Execute runs a build task by translating its spec into an agent call. 
+func (b *BuildExecutor) Execute(ctx context.Context, task *domain.WorkTask) *domain.BuildResult { + start := time.Now() + + spec, err := b.parseSpec(task.Spec) + if err != nil { + return &domain.BuildResult{ + Success: false, + Error: fmt.Sprintf("invalid build spec: %v", err), + DurationMs: time.Since(start).Milliseconds(), + } + } + + // Determine working directory + workDir := "/workspace" + + // Clone repo if git URL is provided in the spec + gitURL, _ := task.Spec["git_url"].(string) + if gitURL != "" && b.gitOps != nil { + cloneDir, cleanup, err := b.gitOps.CloneToTemp(ctx, gitURL) + if err != nil { + return &domain.BuildResult{ + Success: false, + Error: fmt.Sprintf("git clone failed: %v", err), + DurationMs: time.Since(start).Milliseconds(), + } + } + defer cleanup() + workDir = cloneDir + } + + // Get a code agent + agent := b.agentRegistry.Default() + if agent == nil { + return &domain.BuildResult{ + Success: false, + Error: "no code agent available", + DurationMs: time.Since(start).Milliseconds(), + } + } + + // Build the agent request + agentReq := &domain.AgentRequest{ + Prompt: spec.Prompt, + ProjectID: domain.ProjectID(task.ProjectID), + WorkingDir: workDir, + Timeout: 10 * time.Minute, + } + + // Collect output with a size cap to prevent OOM on verbose builds. 
+ const maxOutputSize = 1 << 20 // 1MB + var outputBuilder strings.Builder + + b.logger.Info("executing build via agent", + "task_id", task.ID, + "project_id", task.ProjectID, + "agent", agent.Name(), + "work_dir", workDir, + ) + + // Execute the agent + agentResult, err := agent.Execute(ctx, agentReq, func(event domain.AgentEvent) { + if event.Type == domain.AgentEventOutput || event.Type == domain.AgentEventError { + if outputBuilder.Len() >= maxOutputSize { + return // Output cap reached, discard further output + } + if outputBuilder.Len() > 0 { + outputBuilder.WriteString("\n") + } + remaining := maxOutputSize - outputBuilder.Len() + if len(event.Content) > remaining { + outputBuilder.WriteString(event.Content[:remaining]) + outputBuilder.WriteString("\n... [output truncated at 1MB]") + } else { + outputBuilder.WriteString(event.Content) + } + } + }) + + if err != nil { + return &domain.BuildResult{ + Success: false, + Error: fmt.Sprintf("agent execution failed: %v", err), + Output: outputBuilder.String(), + DurationMs: time.Since(start).Milliseconds(), + } + } + + result := &domain.BuildResult{ + Success: agentResult.Success(), + Output: outputBuilder.String(), + DurationMs: time.Since(start).Milliseconds(), + } + + if !agentResult.Success() { + errMsg := "agent returned non-zero exit code" + if agentResult.Error != nil { + errMsg = agentResult.Error.Error() + } + result.Error = errMsg + } + + // Handle git commit/push if requested + if result.Success && b.gitOps != nil && gitURL != "" { + if spec.AutoCommit { + commitMsg := fmt.Sprintf("build: %s", truncate(spec.Prompt, 72)) + sha, filesChanged, err := b.gitOps.CommitAndPush(ctx, workDir, commitMsg, spec.AutoPush) + if err != nil { + b.logger.Warn("git commit/push failed", + "task_id", task.ID, + "error", err, + ) + result.Success = false + result.Error = fmt.Sprintf("build succeeded but git operations failed: %v", err) + } else { + result.CommitSHA = sha + result.FilesChanged = filesChanged + } + } + } + + 
return result +} + +// parsedBuildSpec holds typed fields extracted from the task spec map. +type parsedBuildSpec struct { + Prompt string + AutoCommit bool + AutoPush bool +} + +// parseSpec extracts typed BuildSpec fields from the generic map[string]any. +func (b *BuildExecutor) parseSpec(spec map[string]any) (*parsedBuildSpec, error) { + prompt, _ := spec["prompt"].(string) + if prompt == "" { + return nil, fmt.Errorf("prompt is required") + } + + autoCommit, _ := spec["auto_commit"].(bool) + autoPush, _ := spec["auto_push"].(bool) + + return &parsedBuildSpec{ + Prompt: prompt, + AutoCommit: autoCommit, + AutoPush: autoPush, + }, nil +} + +// truncate shortens a string to maxLen, adding "..." if truncated. +func truncate(s string, maxLen int) string { + if len(s) <= maxLen { + return s + } + if maxLen <= 3 { + return s[:maxLen] + } + return s[:maxLen-3] + "..." +} diff --git a/internal/worker/git_operations.go b/internal/worker/git_operations.go new file mode 100644 index 0000000..bdba0bc --- /dev/null +++ b/internal/worker/git_operations.go @@ -0,0 +1,233 @@ +package worker + +import ( + "bytes" + "context" + "fmt" + "log/slog" + "os" + "os/exec" + "path/filepath" + "strings" +) + +// GitOperations provides git clone, commit, and push functionality +// for the build executor. It uses os/exec to run git commands. +type GitOperations struct { + giteaToken string + gitUser string + gitEmail string + logger *slog.Logger +} + +// GitOperationsConfig configures git operations. +type GitOperationsConfig struct { + // GiteaToken is the token for HTTPS clone/push authentication. + GiteaToken string + + // GitUser is the git commit author name. + GitUser string + + // GitEmail is the git commit author email. + GitEmail string + + Logger *slog.Logger +} + +// NewGitOperations creates a new git operations helper. 
+func NewGitOperations(cfg GitOperationsConfig) *GitOperations { + if cfg.GitUser == "" { + cfg.GitUser = "rdev-worker" + } + if cfg.GitEmail == "" { + cfg.GitEmail = "worker@threesix.ai" + } + if cfg.Logger == nil { + cfg.Logger = slog.Default() + } + return &GitOperations{ + giteaToken: cfg.GiteaToken, + gitUser: cfg.GitUser, + gitEmail: cfg.GitEmail, + logger: cfg.Logger.With("component", "git-ops"), + } +} + +// CloneToTemp clones a repository to a temporary directory. +// Returns the clone directory and a cleanup function. +func (g *GitOperations) CloneToTemp(ctx context.Context, gitURL string) (string, func(), error) { + tmpDir, err := os.MkdirTemp("", "rdev-build-*") + if err != nil { + return "", nil, fmt.Errorf("create temp dir: %w", err) + } + + cleanup := func() { + if err := os.RemoveAll(tmpDir); err != nil { + g.logger.Warn("failed to cleanup temp dir", "dir", tmpDir, "error", err) + } + } + + // Inject token into clone URL for authentication + authURL := g.injectToken(gitURL) + + if err := g.runGit(ctx, tmpDir, "clone", authURL, "."); err != nil { + cleanup() + return "", nil, fmt.Errorf("git clone: %w", err) + } + + // Configure git user for commits + if err := g.runGit(ctx, tmpDir, "config", "user.name", g.gitUser); err != nil { + cleanup() + return "", nil, fmt.Errorf("git config user.name: %w", err) + } + if err := g.runGit(ctx, tmpDir, "config", "user.email", g.gitEmail); err != nil { + cleanup() + return "", nil, fmt.Errorf("git config user.email: %w", err) + } + + g.logger.Info("cloned repository", "url", gitURL, "dir", tmpDir) + return tmpDir, cleanup, nil +} + +// CommitAndPush stages all changes, commits, and optionally pushes. +// Returns the commit SHA and list of changed files. 
+func (g *GitOperations) CommitAndPush(ctx context.Context, dir, message string, push bool) (string, []string, error) { + // Stage all changes + if err := g.runGit(ctx, dir, "add", "-A"); err != nil { + return "", nil, fmt.Errorf("git add: %w", err) + } + + // Check if there are changes to commit + status, err := g.runGitOutput(ctx, dir, "status", "--porcelain") + if err != nil { + return "", nil, fmt.Errorf("git status: %w", err) + } + if strings.TrimSpace(status) == "" { + g.logger.Info("no changes to commit", "dir", dir) + return "", nil, nil + } + + // Get list of changed files + diffOutput, err := g.runGitOutput(ctx, dir, "diff", "--cached", "--name-only") + if err != nil { + return "", nil, fmt.Errorf("git diff: %w", err) + } + var filesChanged []string + for _, f := range strings.Split(strings.TrimSpace(diffOutput), "\n") { + if f != "" { + filesChanged = append(filesChanged, f) + } + } + + // Commit + if err := g.runGit(ctx, dir, "commit", "-m", message); err != nil { + return "", nil, fmt.Errorf("git commit: %w", err) + } + + // Get commit SHA + sha, err := g.runGitOutput(ctx, dir, "rev-parse", "HEAD") + if err != nil { + return "", nil, fmt.Errorf("git rev-parse: %w", err) + } + sha = strings.TrimSpace(sha) + + g.logger.Info("committed changes", + "sha", sha, + "files", len(filesChanged), + ) + + // Push if requested + if push { + if err := g.runGit(ctx, dir, "push"); err != nil { + return sha, filesChanged, fmt.Errorf("git push: %w", err) + } + g.logger.Info("pushed changes", "sha", sha) + } + + return sha, filesChanged, nil +} + +// injectToken adds the Gitea token to an HTTPS git URL for authentication. +// Converts "https://git.example.com/org/repo.git" to +// "https://token@git.example.com/org/repo.git". 
+func (g *GitOperations) injectToken(gitURL string) string { + if g.giteaToken == "" { + return gitURL + } + // Handle https:// URLs + if strings.HasPrefix(gitURL, "https://") { + return "https://" + g.giteaToken + "@" + gitURL[len("https://"):] + } + if strings.HasPrefix(gitURL, "http://") { + return "http://" + g.giteaToken + "@" + gitURL[len("http://"):] + } + return gitURL +} + +// gitEnv returns a minimal environment for git subprocesses. +// Only PATH and HOME are inherited; all other host env vars are excluded +// to prevent credential or config leakage. +func gitEnv() []string { + env := []string{"GIT_TERMINAL_PROMPT=0"} + for _, key := range []string{"PATH", "HOME"} { + if v := os.Getenv(key); v != "" { + env = append(env, key+"="+v) + } + } + return env +} + +// runGit executes a git command in the given directory. +func (g *GitOperations) runGit(ctx context.Context, dir string, args ...string) error { + cmd := exec.CommandContext(ctx, "git", args...) + cmd.Dir = dir + cmd.Env = gitEnv() + + var stderr bytes.Buffer + cmd.Stderr = &stderr + + if err := cmd.Run(); err != nil { + // Redact token from error messages + errMsg := g.redactToken(stderr.String()) + return fmt.Errorf("%s: %s", err, errMsg) + } + return nil +} + +// runGitOutput executes a git command and returns its stdout. +func (g *GitOperations) runGitOutput(ctx context.Context, dir string, args ...string) (string, error) { + cmd := exec.CommandContext(ctx, "git", args...) + cmd.Dir = dir + cmd.Env = gitEnv() + + var stdout, stderr bytes.Buffer + cmd.Stdout = &stdout + cmd.Stderr = &stderr + + if err := cmd.Run(); err != nil { + errMsg := g.redactToken(stderr.String()) + return "", fmt.Errorf("%s: %s", err, errMsg) + } + return stdout.String(), nil +} + +// redactToken removes the Gitea token from log/error output. 
+func (g *GitOperations) redactToken(s string) string { + if g.giteaToken == "" { + return s + } + return strings.ReplaceAll(s, g.giteaToken, "[REDACTED]") +} + +// EnsureGitDir verifies that the given path is a valid git repository. +func (g *GitOperations) EnsureGitDir(dir string) error { + gitDir := filepath.Join(dir, ".git") + info, err := os.Stat(gitDir) + if err != nil { + return fmt.Errorf("not a git repository: %w", err) + } + if !info.IsDir() { + return fmt.Errorf("not a git repository: .git is not a directory") + } + return nil +} diff --git a/internal/worker/git_operations_test.go b/internal/worker/git_operations_test.go new file mode 100644 index 0000000..4efb200 --- /dev/null +++ b/internal/worker/git_operations_test.go @@ -0,0 +1,415 @@ +package worker + +import ( + "context" + "log/slog" + "os" + "path/filepath" + "strings" + "testing" +) + +func testGitOps(token string) *GitOperations { + return NewGitOperations(GitOperationsConfig{ + GiteaToken: token, + GitUser: "test-user", + GitEmail: "test@example.com", + Logger: slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelWarn})), + }) +} + +func TestNewGitOperations_Defaults(t *testing.T) { + g := NewGitOperations(GitOperationsConfig{}) + if g.gitUser != "rdev-worker" { + t.Errorf("expected default gitUser 'rdev-worker', got %q", g.gitUser) + } + if g.gitEmail != "worker@threesix.ai" { + t.Errorf("expected default gitEmail 'worker@threesix.ai', got %q", g.gitEmail) + } + if g.logger == nil { + t.Error("expected non-nil logger") + } +} + +func TestNewGitOperations_CustomValues(t *testing.T) { + g := NewGitOperations(GitOperationsConfig{ + GiteaToken: "my-token", + GitUser: "custom-user", + GitEmail: "custom@example.com", + }) + if g.giteaToken != "my-token" { + t.Errorf("expected token 'my-token', got %q", g.giteaToken) + } + if g.gitUser != "custom-user" { + t.Errorf("expected gitUser 'custom-user', got %q", g.gitUser) + } + if g.gitEmail != "custom@example.com" { + 
t.Errorf("expected gitEmail 'custom@example.com', got %q", g.gitEmail) + } +} + +func TestInjectToken(t *testing.T) { + tests := []struct { + name string + token string + url string + expect string + }{ + { + name: "https URL with token", + token: "ghp_abc123", + url: "https://git.example.com/org/repo.git", + expect: "https://ghp_abc123@git.example.com/org/repo.git", + }, + { + name: "http URL with token", + token: "ghp_abc123", + url: "http://git.example.com/org/repo.git", + expect: "http://ghp_abc123@git.example.com/org/repo.git", + }, + { + name: "no token", + token: "", + url: "https://git.example.com/org/repo.git", + expect: "https://git.example.com/org/repo.git", + }, + { + name: "ssh URL unchanged", + token: "ghp_abc123", + url: "git@git.example.com:org/repo.git", + expect: "git@git.example.com:org/repo.git", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + g := testGitOps(tt.token) + got := g.injectToken(tt.url) + if got != tt.expect { + t.Errorf("injectToken(%q) = %q, want %q", tt.url, got, tt.expect) + } + }) + } +} + +func TestRedactToken(t *testing.T) { + tests := []struct { + name string + token string + input string + expect string + }{ + { + name: "redacts token from message", + token: "secret123", + input: "fatal: Authentication failed for 'https://secret123@git.example.com/repo.git'", + expect: "fatal: Authentication failed for 'https://[REDACTED]@git.example.com/repo.git'", + }, + { + name: "no token to redact", + token: "", + input: "fatal: repository not found", + expect: "fatal: repository not found", + }, + { + name: "token not present in message", + token: "secret123", + input: "fatal: repository not found", + expect: "fatal: repository not found", + }, + { + name: "multiple occurrences", + token: "tok", + input: "tok appears twice: tok", + expect: "[REDACTED] appears twice: [REDACTED]", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + g := testGitOps(tt.token) + got := 
g.redactToken(tt.input) + if got != tt.expect { + t.Errorf("redactToken(%q) = %q, want %q", tt.input, got, tt.expect) + } + }) + } +} + +func TestEnsureGitDir(t *testing.T) { + g := testGitOps("") + + t.Run("valid git directory", func(t *testing.T) { + dir := t.TempDir() + if err := os.MkdirAll(filepath.Join(dir, ".git"), 0o755); err != nil { + t.Fatal(err) + } + if err := g.EnsureGitDir(dir); err != nil { + t.Errorf("expected no error for valid git dir, got: %v", err) + } + }) + + t.Run("no .git directory", func(t *testing.T) { + dir := t.TempDir() + err := g.EnsureGitDir(dir) + if err == nil { + t.Error("expected error for non-git directory") + } + }) + + t.Run(".git is a file not directory", func(t *testing.T) { + dir := t.TempDir() + if err := os.WriteFile(filepath.Join(dir, ".git"), []byte("gitdir: .."), 0o644); err != nil { + t.Fatal(err) + } + err := g.EnsureGitDir(dir) + if err == nil { + t.Error("expected error when .git is a file") + } + }) +} + +// TestCommitAndPush_NoChanges tests that CommitAndPush returns nil when +// there are no staged changes in the repository. 
+func TestCommitAndPush_NoChanges(t *testing.T) { + g := testGitOps("") + ctx := context.Background() + + // Create a real git repo with an initial commit + dir := t.TempDir() + if err := g.runGit(ctx, dir, "init"); err != nil { + t.Fatal("git init:", err) + } + if err := g.runGit(ctx, dir, "config", "user.name", "test"); err != nil { + t.Fatal("git config user.name:", err) + } + if err := g.runGit(ctx, dir, "config", "user.email", "test@test.com"); err != nil { + t.Fatal("git config user.email:", err) + } + // Create initial commit so HEAD exists + if err := os.WriteFile(filepath.Join(dir, "README.md"), []byte("init"), 0o644); err != nil { + t.Fatal(err) + } + if err := g.runGit(ctx, dir, "add", "-A"); err != nil { + t.Fatal("git add:", err) + } + if err := g.runGit(ctx, dir, "commit", "-m", "initial"); err != nil { + t.Fatal("git commit:", err) + } + + // No new changes — should return empty with no error + sha, files, err := g.CommitAndPush(ctx, dir, "no changes", false) + if err != nil { + t.Errorf("expected no error, got: %v", err) + } + if sha != "" { + t.Errorf("expected empty SHA, got: %q", sha) + } + if len(files) != 0 { + t.Errorf("expected no files, got: %v", files) + } +} + +// TestCommitAndPush_WithChanges tests that CommitAndPush correctly stages, +// commits, and returns SHA and changed file list. 
+func TestCommitAndPush_WithChanges(t *testing.T) { + g := testGitOps("") + ctx := context.Background() + + // Create a real git repo + dir := t.TempDir() + if err := g.runGit(ctx, dir, "init"); err != nil { + t.Fatal("git init:", err) + } + if err := g.runGit(ctx, dir, "config", "user.name", "test"); err != nil { + t.Fatal("git config user.name:", err) + } + if err := g.runGit(ctx, dir, "config", "user.email", "test@test.com"); err != nil { + t.Fatal("git config user.email:", err) + } + // Initial commit + if err := os.WriteFile(filepath.Join(dir, "README.md"), []byte("init"), 0o644); err != nil { + t.Fatal(err) + } + if err := g.runGit(ctx, dir, "add", "-A"); err != nil { + t.Fatal("git add:", err) + } + if err := g.runGit(ctx, dir, "commit", "-m", "initial"); err != nil { + t.Fatal("git commit:", err) + } + + // Create new files to commit + if err := os.WriteFile(filepath.Join(dir, "main.go"), []byte("package main"), 0o644); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(dir, "go.mod"), []byte("module test"), 0o644); err != nil { + t.Fatal(err) + } + + // CommitAndPush without push (no remote) + sha, files, err := g.CommitAndPush(ctx, dir, "add go files", false) + if err != nil { + t.Errorf("expected no error, got: %v", err) + } + if sha == "" { + t.Error("expected non-empty SHA") + } + if len(sha) < 7 { + t.Errorf("expected SHA to be at least 7 chars, got: %q", sha) + } + if len(files) != 2 { + t.Errorf("expected 2 changed files, got %d: %v", len(files), files) + } + + // Verify the files are in the list + fileSet := make(map[string]bool) + for _, f := range files { + fileSet[f] = true + } + if !fileSet["main.go"] { + t.Error("expected main.go in changed files") + } + if !fileSet["go.mod"] { + t.Error("expected go.mod in changed files") + } +} + +// TestCommitAndPush_PushWithoutRemote tests that push fails gracefully +// when there's no remote configured. 
+func TestCommitAndPush_PushWithoutRemote(t *testing.T) { + g := testGitOps("") + ctx := context.Background() + + dir := t.TempDir() + if err := g.runGit(ctx, dir, "init"); err != nil { + t.Fatal("git init:", err) + } + if err := g.runGit(ctx, dir, "config", "user.name", "test"); err != nil { + t.Fatal("git config:", err) + } + if err := g.runGit(ctx, dir, "config", "user.email", "test@test.com"); err != nil { + t.Fatal("git config:", err) + } + if err := os.WriteFile(filepath.Join(dir, "file.txt"), []byte("init"), 0o644); err != nil { + t.Fatal(err) + } + if err := g.runGit(ctx, dir, "add", "-A"); err != nil { + t.Fatal("git add:", err) + } + if err := g.runGit(ctx, dir, "commit", "-m", "initial"); err != nil { + t.Fatal("git commit:", err) + } + + // Add a new file + if err := os.WriteFile(filepath.Join(dir, "new.txt"), []byte("new"), 0o644); err != nil { + t.Fatal(err) + } + + // Push should fail (no remote) but commit succeeds — SHA is returned + sha, files, err := g.CommitAndPush(ctx, dir, "test push", true) + if err == nil { + t.Error("expected push error when no remote configured") + } + // Even though push failed, commit should have succeeded + if sha == "" { + t.Error("expected SHA from successful commit before push failure") + } + if len(files) != 1 || files[0] != "new.txt" { + t.Errorf("expected [new.txt], got: %v", files) + } +} + +// TestCloneToTemp_InvalidURL tests that CloneToTemp fails on a bad URL. +func TestCloneToTemp_InvalidURL(t *testing.T) { + g := testGitOps("") + ctx := context.Background() + + _, _, err := g.CloneToTemp(ctx, "https://invalid.example.com/no-such-repo.git") + if err == nil { + t.Error("expected error cloning invalid URL") + } +} + +// TestCloneToTemp_LocalRepo tests cloning a local bare repository. 
+func TestCloneToTemp_LocalRepo(t *testing.T) { + g := testGitOps("") + ctx := context.Background() + + // Create a bare repo to clone from + bareDir := t.TempDir() + if err := g.runGit(ctx, bareDir, "init", "--bare"); err != nil { + t.Fatal("git init --bare:", err) + } + + // Create a source repo and push to the bare repo + srcDir := t.TempDir() + if err := g.runGit(ctx, srcDir, "init"); err != nil { + t.Fatal("git init:", err) + } + if err := g.runGit(ctx, srcDir, "config", "user.name", "test"); err != nil { + t.Fatal(err) + } + if err := g.runGit(ctx, srcDir, "config", "user.email", "test@test.com"); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(srcDir, "hello.txt"), []byte("hello"), 0o644); err != nil { + t.Fatal(err) + } + if err := g.runGit(ctx, srcDir, "add", "-A"); err != nil { + t.Fatal(err) + } + if err := g.runGit(ctx, srcDir, "commit", "-m", "initial"); err != nil { + t.Fatal(err) + } + if err := g.runGit(ctx, srcDir, "remote", "add", "origin", bareDir); err != nil { + t.Fatal(err) + } + if err := g.runGit(ctx, srcDir, "push", "origin", "master"); err != nil { + // Some git versions use "main" as default branch + if err2 := g.runGit(ctx, srcDir, "push", "origin", "main"); err2 != nil { + t.Fatalf("push failed for both master and main: master=%v, main=%v", err, err2) + } + } + + // Clone the bare repo using file:// protocol + cloneDir, cleanup, err := g.CloneToTemp(ctx, "file://"+bareDir) + if err != nil { + t.Fatalf("CloneToTemp failed: %v", err) + } + defer cleanup() + + // Verify the cloned file exists + content, err := os.ReadFile(filepath.Join(cloneDir, "hello.txt")) + if err != nil { + t.Fatalf("failed to read cloned file: %v", err) + } + if string(content) != "hello" { + t.Errorf("expected file content 'hello', got %q", string(content)) + } + + // Verify .git dir exists + if err := g.EnsureGitDir(cloneDir); err != nil { + t.Errorf("cloned dir should be a git repo: %v", err) + } + + // Verify git config was set + userName, 
err := g.runGitOutput(ctx, cloneDir, "config", "user.name") + if err != nil { + t.Fatalf("failed to get user.name: %v", err) + } + if got := strings.TrimSpace(userName); got != "test-user" { + t.Errorf("expected user.name 'test-user', got %q", got) + } +} + +func TestRunGit_ContextCancellation(t *testing.T) { + g := testGitOps("") + ctx, cancel := context.WithCancel(context.Background()) + cancel() // Cancel immediately + + dir := t.TempDir() + err := g.runGit(ctx, dir, "status") + if err == nil { + t.Error("expected error when context is cancelled") + } +} diff --git a/internal/worker/mock_test.go b/internal/worker/mock_test.go new file mode 100644 index 0000000..5e074c8 --- /dev/null +++ b/internal/worker/mock_test.go @@ -0,0 +1,322 @@ +package worker + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/orchard9/rdev/internal/domain" + "github.com/orchard9/rdev/internal/port" + "github.com/orchard9/rdev/internal/service" +) + +// ============================================================================= +// Mock implementations for worker package tests +// ============================================================================= + +type mockWorkQueue struct { + mu sync.Mutex + tasks map[string]*domain.WorkTask + err error +} + +func newMockWorkQueue() *mockWorkQueue { + return &mockWorkQueue{tasks: make(map[string]*domain.WorkTask)} +} + +func (m *mockWorkQueue) Enqueue(_ context.Context, task *domain.WorkTask) (string, error) { + m.mu.Lock() + defer m.mu.Unlock() + if m.err != nil { + return "", m.err + } + id := fmt.Sprintf("task-%d", len(m.tasks)+1) + task.ID = id + task.Status = domain.WorkTaskStatusPending + task.CreatedAt = time.Now() + m.tasks[id] = task + return id, nil +} + +func (m *mockWorkQueue) Dequeue(_ context.Context, workerID string) (*domain.WorkTask, error) { + m.mu.Lock() + defer m.mu.Unlock() + if m.err != nil { + return nil, m.err + } + for _, task := range m.tasks { + if task.Status == domain.WorkTaskStatusPending { + 
task.Status = domain.WorkTaskStatusRunning + task.WorkerID = workerID + now := time.Now() + task.StartedAt = &now + return task, nil + } + } + return nil, nil +} + +func (m *mockWorkQueue) Complete(_ context.Context, taskID string, result *domain.WorkResult) error { + m.mu.Lock() + defer m.mu.Unlock() + task, ok := m.tasks[taskID] + if !ok { + return domain.ErrWorkTaskNotFound + } + task.Status = domain.WorkTaskStatusCompleted + task.Result = result + now := time.Now() + task.CompletedAt = &now + return nil +} + +func (m *mockWorkQueue) Fail(_ context.Context, taskID string, errMsg string) error { + m.mu.Lock() + defer m.mu.Unlock() + task, ok := m.tasks[taskID] + if !ok { + return domain.ErrWorkTaskNotFound + } + task.RetryCount++ + if task.RetryCount >= task.MaxRetries { + task.Status = domain.WorkTaskStatusFailed + task.Error = errMsg + } else { + task.Status = domain.WorkTaskStatusPending + task.WorkerID = "" + } + return nil +} + +func (m *mockWorkQueue) Cancel(_ context.Context, taskID string) error { return nil } +func (m *mockWorkQueue) GetTask(_ context.Context, taskID string) (*domain.WorkTask, error) { + m.mu.Lock() + defer m.mu.Unlock() + task, ok := m.tasks[taskID] + if !ok { + return nil, domain.ErrWorkTaskNotFound + } + return task, nil +} +func (m *mockWorkQueue) ListByProject(_ context.Context, _ string, _ *domain.WorkTaskStatus, _ domain.WorkListOptions) (*domain.WorkListResult, error) { + return &domain.WorkListResult{}, nil +} +func (m *mockWorkQueue) GetStats(_ context.Context) (*domain.WorkQueueStats, error) { + return &domain.WorkQueueStats{}, nil +} +func (m *mockWorkQueue) CleanupOld(_ context.Context, _ time.Duration) (int64, error) { + return 0, nil +} +func (m *mockWorkQueue) RequeueStale(_ context.Context, _ time.Duration) (int64, error) { + return 0, nil +} + +type mockWorkerRegistry struct { + mu sync.Mutex + workers map[string]*domain.Worker + err error +} + +func newMockWorkerRegistry() *mockWorkerRegistry { + return 
&mockWorkerRegistry{workers: make(map[string]*domain.Worker)} +} + +func (m *mockWorkerRegistry) Register(_ context.Context, worker *domain.Worker) error { + m.mu.Lock() + defer m.mu.Unlock() + if m.err != nil { + return m.err + } + m.workers[worker.ID] = worker + return nil +} + +func (m *mockWorkerRegistry) Heartbeat(_ context.Context, workerID string) error { + m.mu.Lock() + defer m.mu.Unlock() + w, ok := m.workers[workerID] + if !ok { + return domain.ErrWorkerNotFound + } + w.LastHeartbeat = time.Now() + return nil +} + +func (m *mockWorkerRegistry) UpdateStatus(_ context.Context, workerID string, status domain.WorkerStatus, taskID string) error { + m.mu.Lock() + defer m.mu.Unlock() + w, ok := m.workers[workerID] + if !ok { + return domain.ErrWorkerNotFound + } + w.Status = status + w.CurrentTask = taskID + return nil +} + +func (m *mockWorkerRegistry) Deregister(_ context.Context, workerID string) error { + m.mu.Lock() + defer m.mu.Unlock() + delete(m.workers, workerID) + return nil +} + +func (m *mockWorkerRegistry) Get(_ context.Context, workerID string) (*domain.Worker, error) { + m.mu.Lock() + defer m.mu.Unlock() + w, ok := m.workers[workerID] + if !ok { + return nil, domain.ErrWorkerNotFound + } + return w, nil +} + +func (m *mockWorkerRegistry) List(_ context.Context, filter port.WorkerFilter) ([]*domain.Worker, error) { + m.mu.Lock() + defer m.mu.Unlock() + var result []*domain.Worker + for _, w := range m.workers { + if filter.Status != nil && w.Status != *filter.Status { + continue + } + result = append(result, w) + } + return result, nil +} + +func (m *mockWorkerRegistry) MarkStaleOffline(_ context.Context, _ time.Duration) (int, error) { + return 0, nil +} + +type mockBuildAudit struct { + mu sync.Mutex + entries map[string]*domain.BuildAuditEntry +} + +func newMockBuildAudit() *mockBuildAudit { + return &mockBuildAudit{entries: make(map[string]*domain.BuildAuditEntry)} +} + +func (m *mockBuildAudit) Record(_ context.Context, entry 
*domain.BuildAuditEntry) error { + m.mu.Lock() + defer m.mu.Unlock() + m.entries[entry.TaskID] = entry + return nil +} + +func (m *mockBuildAudit) Update(_ context.Context, taskID string, result *domain.BuildResult) error { + m.mu.Lock() + defer m.mu.Unlock() + entry, ok := m.entries[taskID] + if !ok { + return domain.ErrBuildNotFound + } + entry.Result = result + if result.Success { + entry.Status = domain.BuildStatusCompleted + } else { + entry.Status = domain.BuildStatusFailed + } + return nil +} + +func (m *mockBuildAudit) Get(_ context.Context, taskID string) (*domain.BuildAuditEntry, error) { + m.mu.Lock() + defer m.mu.Unlock() + entry, ok := m.entries[taskID] + if !ok { + return nil, domain.ErrBuildNotFound + } + return entry, nil +} + +func (m *mockBuildAudit) List(_ context.Context, _ port.BuildAuditFilter) ([]*domain.BuildAuditEntry, error) { + return nil, nil +} + +type mockCodeAgent struct { + result *domain.AgentResult + err error +} + +func (m *mockCodeAgent) Name() string { return "mock-agent" } +func (m *mockCodeAgent) Provider() domain.AgentProvider { return "mock" } +func (m *mockCodeAgent) Cancel(_ context.Context, _ string) error { return nil } +func (m *mockCodeAgent) Capabilities() domain.AgentCapabilities { + return domain.AgentCapabilities{Provider: "mock"} +} +func (m *mockCodeAgent) Available(_ context.Context) bool { return true } +func (m *mockCodeAgent) Execute(_ context.Context, req *domain.AgentRequest, handler domain.AgentEventHandler) (*domain.AgentResult, error) { + if handler != nil { + handler(domain.AgentEvent{ + Type: domain.AgentEventOutput, + Content: "mock output for: " + req.Prompt, + }) + } + if m.err != nil { + return nil, m.err + } + return m.result, nil +} + +type mockCodeAgentRegistry struct { + agent port.CodeAgent +} + +func (m *mockCodeAgentRegistry) Register(agent port.CodeAgent) { m.agent = agent } +func (m *mockCodeAgentRegistry) Get(_ domain.AgentProvider) port.CodeAgent { return m.agent } +func (m 
*mockCodeAgentRegistry) Default() port.CodeAgent { return m.agent } +func (m *mockCodeAgentRegistry) DefaultProvider() domain.AgentProvider { return "mock" } +func (m *mockCodeAgentRegistry) SetDefault(_ domain.AgentProvider) error { return nil } +func (m *mockCodeAgentRegistry) Available() []domain.AgentProvider { + return []domain.AgentProvider{"mock"} +} +func (m *mockCodeAgentRegistry) AvailableAgents(_ context.Context) []port.CodeAgent { + return []port.CodeAgent{m.agent} +} +func (m *mockCodeAgentRegistry) Count() int { return 1 } + +// ============================================================================= +// Helper to build test dependencies +// ============================================================================= + +type testDeps struct { + queue *mockWorkQueue + registry *mockWorkerRegistry + audit *mockBuildAudit + agent *mockCodeAgent + + workerSvc *service.WorkerService + workSvc *service.WorkService + buildExec *BuildExecutor +} + +func newTestDeps() *testDeps { + queue := newMockWorkQueue() + registry := newMockWorkerRegistry() + audit := newMockBuildAudit() + agent := &mockCodeAgent{ + result: &domain.AgentResult{ + ExitCode: 0, + DurationMs: 1000, + }, + } + agentRegistry := &mockCodeAgentRegistry{agent: agent} + + workerSvc := service.NewWorkerService(registry, queue, nil). 
+ WithBuildAudit(audit) + workSvc := service.NewWorkService(queue, service.WorkServiceConfig{}) + + buildExec := NewBuildExecutor(agentRegistry, nil, nil) + + return &testDeps{ + queue: queue, + registry: registry, + audit: audit, + agent: agent, + workerSvc: workerSvc, + workSvc: workSvc, + buildExec: buildExec, + } +} diff --git a/internal/worker/queue_maintenance.go b/internal/worker/queue_maintenance.go new file mode 100644 index 0000000..f84478f --- /dev/null +++ b/internal/worker/queue_maintenance.go @@ -0,0 +1,241 @@ +package worker + +import ( + "context" + "log/slog" + "sync" + "time" + + "github.com/orchard9/rdev/internal/metrics" + "github.com/orchard9/rdev/internal/port" +) + +// QueueMaintenance runs periodic maintenance tasks on the work queue +// and worker registry: stale task recovery, stale worker marking, +// old task cleanup, and queue metrics refresh. +type QueueMaintenance struct { + queue port.WorkQueue + registry port.WorkerRegistry + logger *slog.Logger + + // Intervals + staleTaskTimeout time.Duration + staleWorkerTimeout time.Duration + cleanupAge time.Duration + maintenancePeriod time.Duration + metricsPeriod time.Duration + + ctx context.Context + cancel context.CancelFunc + wg sync.WaitGroup +} + +// QueueMaintenanceConfig holds configuration for queue maintenance. +type QueueMaintenanceConfig struct { + // StaleTaskTimeout is how long a running task can be silent before requeue. + // Default: 30 minutes. + StaleTaskTimeout time.Duration + + // StaleWorkerTimeout is how long without heartbeat before marking offline. + // Default: 2 minutes. + StaleWorkerTimeout time.Duration + + // CleanupAge is the minimum age for completed/failed/cancelled tasks to be cleaned up. + // Default: 7 days. + CleanupAge time.Duration + + // MaintenancePeriod is how often to run maintenance tasks. + // Default: 1 minute. + MaintenancePeriod time.Duration + + // MetricsPeriod is how often to refresh queue depth metrics. + // Default: 15 seconds. 
+ MetricsPeriod time.Duration + + Logger *slog.Logger +} + +// DefaultQueueMaintenanceConfig returns sensible defaults. +func DefaultQueueMaintenanceConfig() *QueueMaintenanceConfig { + return &QueueMaintenanceConfig{ + StaleTaskTimeout: 30 * time.Minute, + StaleWorkerTimeout: 2 * time.Minute, + CleanupAge: 7 * 24 * time.Hour, + MaintenancePeriod: 1 * time.Minute, + MetricsPeriod: 15 * time.Second, + Logger: slog.Default(), + } +} + +// NewQueueMaintenance creates a new queue maintenance worker. +func NewQueueMaintenance( + queue port.WorkQueue, + registry port.WorkerRegistry, + cfg *QueueMaintenanceConfig, +) *QueueMaintenance { + if cfg == nil { + cfg = DefaultQueueMaintenanceConfig() + } + + ctx, cancel := context.WithCancel(context.Background()) + + return &QueueMaintenance{ + queue: queue, + registry: registry, + logger: cfg.Logger.With("component", "queue-maintenance"), + staleTaskTimeout: cfg.StaleTaskTimeout, + staleWorkerTimeout: cfg.StaleWorkerTimeout, + cleanupAge: cfg.CleanupAge, + maintenancePeriod: cfg.MaintenancePeriod, + metricsPeriod: cfg.MetricsPeriod, + ctx: ctx, + cancel: cancel, + } +} + +// Start begins the maintenance and metrics loops. +func (m *QueueMaintenance) Start() { + m.logger.Info("queue maintenance started", + "maintenance_period", m.maintenancePeriod, + "metrics_period", m.metricsPeriod, + "stale_task_timeout", m.staleTaskTimeout, + "stale_worker_timeout", m.staleWorkerTimeout, + "cleanup_age", m.cleanupAge, + ) + + m.wg.Add(2) + go m.maintenanceLoop() + go m.metricsLoop() +} + +// Stop gracefully shuts down the maintenance worker. +func (m *QueueMaintenance) Stop() { + m.logger.Info("queue maintenance stopping") + m.cancel() + m.wg.Wait() + m.logger.Info("queue maintenance stopped") +} + +// maintenanceLoop runs periodic maintenance: stale recovery, worker health, cleanup. 
+func (m *QueueMaintenance) maintenanceLoop() { + defer m.wg.Done() + + // Run immediately on start + m.runMaintenance() + + ticker := time.NewTicker(m.maintenancePeriod) + defer ticker.Stop() + + for { + select { + case <-m.ctx.Done(): + return + case <-ticker.C: + m.runMaintenance() + } + } +} + +// metricsLoop refreshes queue depth metrics on a faster cadence. +func (m *QueueMaintenance) metricsLoop() { + defer m.wg.Done() + + // Run immediately on start + m.refreshMetrics() + + ticker := time.NewTicker(m.metricsPeriod) + defer ticker.Stop() + + for { + select { + case <-m.ctx.Done(): + return + case <-ticker.C: + m.refreshMetrics() + } + } +} + +// runMaintenance executes all maintenance tasks. +func (m *QueueMaintenance) runMaintenance() { + ctx, cancel := context.WithTimeout(m.ctx, 30*time.Second) + defer cancel() + + m.requeueStaleTasks(ctx) + m.markStaleWorkers(ctx) + m.cleanupOldTasks(ctx) +} + +// requeueStaleTasks requeues tasks that have been running too long +// (the worker likely crashed without reporting). +func (m *QueueMaintenance) requeueStaleTasks(ctx context.Context) { + count, err := m.queue.RequeueStale(ctx, m.staleTaskTimeout) + if err != nil { + m.logger.Warn("failed to requeue stale tasks", "error", err) + return + } + if count > 0 { + m.logger.Info("requeued stale tasks", "count", count) + } +} + +// markStaleWorkers marks workers without recent heartbeats as offline. +func (m *QueueMaintenance) markStaleWorkers(ctx context.Context) { + count, err := m.registry.MarkStaleOffline(ctx, m.staleWorkerTimeout) + if err != nil { + m.logger.Warn("failed to mark stale workers offline", "error", err) + return + } + if count > 0 { + m.logger.Info("marked stale workers offline", "count", count) + } +} + +// cleanupOldTasks removes completed/failed/cancelled tasks older than cleanup age. 
+func (m *QueueMaintenance) cleanupOldTasks(ctx context.Context) { + count, err := m.queue.CleanupOld(ctx, m.cleanupAge) + if err != nil { + m.logger.Warn("failed to cleanup old tasks", "error", err) + return + } + if count > 0 { + m.logger.Info("cleaned up old tasks", "count", count) + } +} + +// refreshMetrics fetches queue stats and updates Prometheus gauges. +func (m *QueueMaintenance) refreshMetrics() { + ctx, cancel := context.WithTimeout(m.ctx, 5*time.Second) + defer cancel() + + stats, err := m.queue.GetStats(ctx) + if err != nil { + m.logger.Warn("failed to get queue stats for metrics", "error", err) + return + } + + metrics.SetWorkQueueDepth("pending", stats.Pending) + metrics.SetWorkQueueDepth("running", stats.Running) + metrics.SetWorkQueueDepth("completed", stats.Completed) + metrics.SetWorkQueueDepth("failed", stats.Failed) + metrics.SetWorkQueueDepth("cancelled", stats.Cancelled) + + // Worker counts + workers, err := m.registry.List(ctx, port.WorkerFilter{}) + if err != nil { + m.logger.Warn("failed to list workers for metrics", "error", err) + return + } + + counts := map[string]int{ + "idle": 0, "busy": 0, "draining": 0, "offline": 0, + } + for _, w := range workers { + counts[string(w.Status)]++ + age := time.Since(w.LastHeartbeat).Seconds() + metrics.RecordWorkerHeartbeat(w.ID, age) + } + for status, count := range counts { + metrics.SetWorkerCount(status, count) + } +} diff --git a/internal/worker/queue_maintenance_test.go b/internal/worker/queue_maintenance_test.go new file mode 100644 index 0000000..df20a34 --- /dev/null +++ b/internal/worker/queue_maintenance_test.go @@ -0,0 +1,297 @@ +package worker + +import ( + "context" + "log/slog" + "sync" + "testing" + "time" + + "github.com/orchard9/rdev/internal/domain" + "github.com/orchard9/rdev/internal/port" +) + +// mockMaintenanceQueue implements port.WorkQueue for maintenance tests. 
+type mockMaintenanceQueue struct { + mu sync.Mutex + requeueCalls int + cleanupCalls int + statsCalls int + requeueCount int64 + cleanupCount int64 + stats *domain.WorkQueueStats + err error +} + +func newMockMaintenanceQueue() *mockMaintenanceQueue { + return &mockMaintenanceQueue{ + stats: &domain.WorkQueueStats{ + Pending: 5, + Running: 2, + Completed: 100, + Failed: 3, + Cancelled: 1, + }, + } +} + +func (m *mockMaintenanceQueue) Enqueue(_ context.Context, _ *domain.WorkTask) (string, error) { + return "", nil +} + +func (m *mockMaintenanceQueue) Dequeue(_ context.Context, _ string) (*domain.WorkTask, error) { + return nil, nil +} + +func (m *mockMaintenanceQueue) Complete(_ context.Context, _ string, _ *domain.WorkResult) error { + return nil +} + +func (m *mockMaintenanceQueue) Fail(_ context.Context, _ string, _ string) error { + return nil +} + +func (m *mockMaintenanceQueue) Cancel(_ context.Context, _ string) error { + return nil +} + +func (m *mockMaintenanceQueue) GetTask(_ context.Context, _ string) (*domain.WorkTask, error) { + return nil, nil +} + +func (m *mockMaintenanceQueue) ListByProject(_ context.Context, _ string, _ *domain.WorkTaskStatus, _ domain.WorkListOptions) (*domain.WorkListResult, error) { + return nil, nil +} + +func (m *mockMaintenanceQueue) GetStats(_ context.Context) (*domain.WorkQueueStats, error) { + m.mu.Lock() + defer m.mu.Unlock() + m.statsCalls++ + if m.err != nil { + return nil, m.err + } + return m.stats, nil +} + +func (m *mockMaintenanceQueue) CleanupOld(_ context.Context, _ time.Duration) (int64, error) { + m.mu.Lock() + defer m.mu.Unlock() + m.cleanupCalls++ + if m.err != nil { + return 0, m.err + } + return m.cleanupCount, nil +} + +func (m *mockMaintenanceQueue) RequeueStale(_ context.Context, _ time.Duration) (int64, error) { + m.mu.Lock() + defer m.mu.Unlock() + m.requeueCalls++ + if m.err != nil { + return 0, m.err + } + return m.requeueCount, nil +} + +// mockMaintenanceRegistry implements port.WorkerRegistry 
for maintenance tests. +type mockMaintenanceRegistry struct { + mu sync.Mutex + markStaleCalls int + markStaleCount int + workers []*domain.Worker + err error +} + +func newMockMaintenanceRegistry() *mockMaintenanceRegistry { + return &mockMaintenanceRegistry{ + workers: []*domain.Worker{ + { + ID: "worker-1", + Status: domain.WorkerStatusIdle, + LastHeartbeat: time.Now(), + }, + { + ID: "worker-2", + Status: domain.WorkerStatusBusy, + LastHeartbeat: time.Now().Add(-5 * time.Minute), + }, + }, + } +} + +func (m *mockMaintenanceRegistry) Register(_ context.Context, _ *domain.Worker) error { + return nil +} + +func (m *mockMaintenanceRegistry) Heartbeat(_ context.Context, _ string) error { + return nil +} + +func (m *mockMaintenanceRegistry) UpdateStatus(_ context.Context, _ string, _ domain.WorkerStatus, _ string) error { + return nil +} + +func (m *mockMaintenanceRegistry) Deregister(_ context.Context, _ string) error { + return nil +} + +func (m *mockMaintenanceRegistry) Get(_ context.Context, _ string) (*domain.Worker, error) { + return nil, nil +} + +func (m *mockMaintenanceRegistry) List(_ context.Context, _ port.WorkerFilter) ([]*domain.Worker, error) { + m.mu.Lock() + defer m.mu.Unlock() + if m.err != nil { + return nil, m.err + } + return m.workers, nil +} + +func (m *mockMaintenanceRegistry) MarkStaleOffline(_ context.Context, _ time.Duration) (int, error) { + m.mu.Lock() + defer m.mu.Unlock() + m.markStaleCalls++ + if m.err != nil { + return 0, m.err + } + return m.markStaleCount, nil +} + +func TestQueueMaintenance_DefaultConfig(t *testing.T) { + cfg := DefaultQueueMaintenanceConfig() + + if cfg.StaleTaskTimeout != 30*time.Minute { + t.Errorf("got StaleTaskTimeout=%v, want 30m", cfg.StaleTaskTimeout) + } + if cfg.StaleWorkerTimeout != 2*time.Minute { + t.Errorf("got StaleWorkerTimeout=%v, want 2m", cfg.StaleWorkerTimeout) + } + if cfg.CleanupAge != 7*24*time.Hour { + t.Errorf("got CleanupAge=%v, want 7d", cfg.CleanupAge) + } + if cfg.MaintenancePeriod != 
1*time.Minute { + t.Errorf("got MaintenancePeriod=%v, want 1m", cfg.MaintenancePeriod) + } + if cfg.MetricsPeriod != 15*time.Second { + t.Errorf("got MetricsPeriod=%v, want 15s", cfg.MetricsPeriod) + } +} + +func TestQueueMaintenance_RunMaintenance(t *testing.T) { + queue := newMockMaintenanceQueue() + queue.requeueCount = 2 + queue.cleanupCount = 5 + + registry := newMockMaintenanceRegistry() + registry.markStaleCount = 1 + + cfg := &QueueMaintenanceConfig{ + StaleTaskTimeout: 30 * time.Minute, + StaleWorkerTimeout: 2 * time.Minute, + CleanupAge: 7 * 24 * time.Hour, + MaintenancePeriod: 1 * time.Hour, // won't fire in test + MetricsPeriod: 1 * time.Hour, // won't fire in test + Logger: slog.Default(), + } + + m := NewQueueMaintenance(queue, registry, cfg) + + // Run maintenance directly + m.runMaintenance() + + queue.mu.Lock() + defer queue.mu.Unlock() + registry.mu.Lock() + defer registry.mu.Unlock() + + if queue.requeueCalls != 1 { + t.Errorf("got requeueCalls=%d, want 1", queue.requeueCalls) + } + if queue.cleanupCalls != 1 { + t.Errorf("got cleanupCalls=%d, want 1", queue.cleanupCalls) + } + if registry.markStaleCalls != 1 { + t.Errorf("got markStaleCalls=%d, want 1", registry.markStaleCalls) + } +} + +func TestQueueMaintenance_RefreshMetrics(t *testing.T) { + queue := newMockMaintenanceQueue() + registry := newMockMaintenanceRegistry() + + cfg := &QueueMaintenanceConfig{ + StaleTaskTimeout: 30 * time.Minute, + StaleWorkerTimeout: 2 * time.Minute, + CleanupAge: 7 * 24 * time.Hour, + MaintenancePeriod: 1 * time.Hour, + MetricsPeriod: 1 * time.Hour, + Logger: slog.Default(), + } + + m := NewQueueMaintenance(queue, registry, cfg) + + // Run metrics refresh directly + m.refreshMetrics() + + queue.mu.Lock() + if queue.statsCalls != 1 { + t.Errorf("got statsCalls=%d, want 1", queue.statsCalls) + } + queue.mu.Unlock() +} + +func TestQueueMaintenance_StartStop(t *testing.T) { + queue := newMockMaintenanceQueue() + registry := newMockMaintenanceRegistry() + + cfg := 
&QueueMaintenanceConfig{ + StaleTaskTimeout: 30 * time.Minute, + StaleWorkerTimeout: 2 * time.Minute, + CleanupAge: 7 * 24 * time.Hour, + MaintenancePeriod: 50 * time.Millisecond, + MetricsPeriod: 50 * time.Millisecond, + Logger: slog.Default(), + } + + m := NewQueueMaintenance(queue, registry, cfg) + m.Start() + + // Poll until maintenance has run at least once (runs immediately on start) + deadline := time.After(2 * time.Second) + for { + queue.mu.Lock() + rCalls := queue.requeueCalls + sCalls := queue.statsCalls + queue.mu.Unlock() + registry.mu.Lock() + mCalls := registry.markStaleCalls + registry.mu.Unlock() + + if rCalls >= 1 && sCalls >= 1 && mCalls >= 1 { + break + } + + select { + case <-deadline: + t.Fatalf("timed out waiting for maintenance to run: requeue=%d stats=%d markStale=%d", rCalls, sCalls, mCalls) + default: + time.Sleep(10 * time.Millisecond) + } + } + + m.Stop() +} + +func TestQueueMaintenance_NilConfig(t *testing.T) { + queue := newMockMaintenanceQueue() + registry := newMockMaintenanceRegistry() + + m := NewQueueMaintenance(queue, registry, nil) + if m.staleTaskTimeout != 30*time.Minute { + t.Errorf("expected default stale task timeout, got %v", m.staleTaskTimeout) + } + if m.metricsPeriod != 15*time.Second { + t.Errorf("expected default metrics period, got %v", m.metricsPeriod) + } +} diff --git a/internal/worker/work_executor.go b/internal/worker/work_executor.go new file mode 100644 index 0000000..5926d9c --- /dev/null +++ b/internal/worker/work_executor.go @@ -0,0 +1,289 @@ +// Package worker provides background workers for async task processing. +package worker + +import ( + "context" + "fmt" + "log/slog" + "os" + "sync" + "sync/atomic" + "time" + + "github.com/orchard9/rdev/internal/domain" + "github.com/orchard9/rdev/internal/metrics" + "github.com/orchard9/rdev/internal/service" +) + +// WorkExecutor is a background daemon that polls the work queue for tasks +// and executes them via task-type-specific handlers. 
It self-registers as +// a worker, sends heartbeats, and reports results. +type WorkExecutor struct { + workerSvc *service.WorkerService + workSvc *service.WorkService + buildExec *BuildExecutor + logger *slog.Logger + + workerID string + hostname string + version string + capabilities []string + pollPeriod time.Duration + hbPeriod time.Duration + taskTimeout time.Duration + + started int32 // atomic flag to prevent double-start + ctx context.Context + cancel context.CancelFunc + wg sync.WaitGroup +} + +// WorkExecutorConfig holds configuration for the work executor. +type WorkExecutorConfig struct { + // WorkerID uniquely identifies this executor instance. + // Defaults to HOSTNAME env var or "rdev-worker-0". + WorkerID string + + // Version reported to the worker registry. + Version string + + // Capabilities reported to the worker registry. + Capabilities []string + + // PollPeriod is how often to check for new tasks. + PollPeriod time.Duration + + // HeartbeatPeriod is how often to send heartbeats. + HeartbeatPeriod time.Duration + + // TaskTimeout is the maximum time a single task may run. + // Default: 15 minutes. + TaskTimeout time.Duration + + Logger *slog.Logger +} + +// DefaultWorkExecutorConfig returns sensible defaults. +func DefaultWorkExecutorConfig() *WorkExecutorConfig { + workerID := os.Getenv("HOSTNAME") + if workerID == "" { + workerID = "rdev-worker-0" + } + return &WorkExecutorConfig{ + WorkerID: workerID, + Capabilities: []string{"build"}, + PollPeriod: 5 * time.Second, + HeartbeatPeriod: 30 * time.Second, + TaskTimeout: 15 * time.Minute, + Logger: slog.Default(), + } +} + +// NewWorkExecutor creates a new work executor daemon. 
+func NewWorkExecutor( + workerSvc *service.WorkerService, + workSvc *service.WorkService, + buildExec *BuildExecutor, + cfg *WorkExecutorConfig, +) *WorkExecutor { + if cfg == nil { + cfg = DefaultWorkExecutorConfig() + } + + hostname, _ := os.Hostname() + if hostname == "" { + hostname = cfg.WorkerID + } + + ctx, cancel := context.WithCancel(context.Background()) + + capabilities := cfg.Capabilities + if len(capabilities) == 0 { + capabilities = []string{"build"} + } + + taskTimeout := cfg.TaskTimeout + if taskTimeout == 0 { + taskTimeout = 15 * time.Minute + } + + return &WorkExecutor{ + workerSvc: workerSvc, + workSvc: workSvc, + buildExec: buildExec, + logger: cfg.Logger.With("component", "work-executor"), + workerID: cfg.WorkerID, + hostname: hostname, + version: cfg.Version, + capabilities: capabilities, + pollPeriod: cfg.PollPeriod, + hbPeriod: cfg.HeartbeatPeriod, + taskTimeout: taskTimeout, + ctx: ctx, + cancel: cancel, + } +} + +// Start registers the worker and begins the poll and heartbeat loops. +func (e *WorkExecutor) Start() error { + if !atomic.CompareAndSwapInt32(&e.started, 0, 1) { + return fmt.Errorf("executor already started") + } + + // Register this worker in the pool + worker := &domain.Worker{ + ID: e.workerID, + Hostname: e.hostname, + Capabilities: e.capabilities, + Version: e.version, + } + if err := e.workerSvc.Register(e.ctx, worker); err != nil { + return err + } + + e.logger.Info("work executor started", + "worker_id", e.workerID, + "poll_period", e.pollPeriod, + "heartbeat_period", e.hbPeriod, + ) + + // Start heartbeat loop + e.wg.Add(1) + go e.heartbeatLoop() + + // Start poll loop + e.wg.Add(1) + go e.pollLoop() + + return nil +} + +// Stop gracefully shuts down the executor. 
+func (e *WorkExecutor) Stop() { + e.logger.Info("work executor stopping", "worker_id", e.workerID) + e.cancel() + e.wg.Wait() + + // Deregister (best-effort, context is cancelled so use a fresh one) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := e.workerSvc.Deregister(ctx, e.workerID); err != nil { + e.logger.Warn("failed to deregister worker", "error", err) + } + + e.logger.Info("work executor stopped", "worker_id", e.workerID) +} + +// WorkerID returns the executor's worker ID. +func (e *WorkExecutor) WorkerID() string { + return e.workerID +} + +// Running returns true if the executor context has not been cancelled. +func (e *WorkExecutor) Running() bool { + return e.ctx.Err() == nil +} + +// heartbeatLoop sends periodic heartbeats to the worker registry. +func (e *WorkExecutor) heartbeatLoop() { + defer e.wg.Done() + + ticker := time.NewTicker(e.hbPeriod) + defer ticker.Stop() + + for { + select { + case <-e.ctx.Done(): + return + case <-ticker.C: + if err := e.workerSvc.Heartbeat(e.ctx, e.workerID); err != nil { + e.logger.Warn("heartbeat failed", "error", err) + } + } + } +} + +// pollLoop checks for available tasks on a ticker. +func (e *WorkExecutor) pollLoop() { + defer e.wg.Done() + + ticker := time.NewTicker(e.pollPeriod) + defer ticker.Stop() + + for { + select { + case <-e.ctx.Done(): + return + case <-ticker.C: + e.tryClaimAndExecute() + } + } +} + +// tryClaimAndExecute attempts to claim a task and execute it. 
+func (e *WorkExecutor) tryClaimAndExecute() { + task, err := e.workerSvc.ClaimTask(e.ctx, e.workerID) + if err != nil { + e.logger.Warn("failed to claim task", "error", err) + return + } + if task == nil { + return // No tasks available + } + + e.logger.Info("executing task", + "task_id", task.ID, + "project_id", task.ProjectID, + "type", task.Type, + ) + + taskCtx, taskCancel := context.WithTimeout(e.ctx, e.taskTimeout) + defer taskCancel() + + result := e.executeTask(taskCtx, task) + + // Record build metrics + status := "success" + if !result.Success { + status = "failed" + } + metrics.RecordBuild(task.ProjectID, status, result.DurationMs) + + if result.Success { + if err := e.workerSvc.CompleteTask(e.ctx, e.workerID, task.ID, result); err != nil { + e.logger.Error("failed to complete task", + "task_id", task.ID, + "error", err, + ) + } + } else { + // Fail the task through work service (handles retry logic) + errMsg := result.Error + if errMsg == "" { + errMsg = "execution failed" + } + if err := e.workSvc.FailTask(e.ctx, task.ID, errMsg); err != nil { + e.logger.Error("failed to record task failure", + "task_id", task.ID, + "error", err, + ) + } + // Return worker to idle regardless + if err := e.workerSvc.Heartbeat(e.ctx, e.workerID); err != nil { + e.logger.Warn("failed to heartbeat after failure", "error", err) + } + } +} + +// executeTask routes a task to the appropriate handler based on its type. 
+func (e *WorkExecutor) executeTask(ctx context.Context, task *domain.WorkTask) *domain.BuildResult { + switch task.Type { + case domain.WorkTaskTypeBuild: + return e.buildExec.Execute(ctx, task) + default: + return &domain.BuildResult{ + Success: false, + Error: "unsupported task type: " + string(task.Type), + } + } +} diff --git a/internal/worker/work_executor_test.go b/internal/worker/work_executor_test.go new file mode 100644 index 0000000..37b3a09 --- /dev/null +++ b/internal/worker/work_executor_test.go @@ -0,0 +1,346 @@ +package worker + +import ( + "context" + "fmt" + "log/slog" + "testing" + "time" + + "github.com/orchard9/rdev/internal/domain" +) + +// ============================================================================= +// WorkExecutor Tests +// ============================================================================= + +func testLogger() *slog.Logger { + return slog.Default() +} + +func TestWorkExecutor_StartAndStop(t *testing.T) { + deps := newTestDeps() + + executor := NewWorkExecutor(deps.workerSvc, deps.workSvc, deps.buildExec, &WorkExecutorConfig{ + WorkerID: "test-worker-1", + PollPeriod: 100 * time.Millisecond, + HeartbeatPeriod: 100 * time.Millisecond, + Logger: testLogger(), + }) + + if err := executor.Start(); err != nil { + t.Fatalf("Start() error = %v", err) + } + + // Verify worker was registered + deps.registry.mu.Lock() + w, exists := deps.registry.workers["test-worker-1"] + deps.registry.mu.Unlock() + if !exists { + t.Fatal("expected worker to be registered") + } + if w.Status != domain.WorkerStatusIdle { + t.Errorf("got status %q, want %q", w.Status, domain.WorkerStatusIdle) + } + + // Verify double-start returns error + if err := executor.Start(); err == nil { + t.Error("expected error on double-start") + } + + executor.Stop() + + // Verify worker was deregistered + deps.registry.mu.Lock() + _, exists = deps.registry.workers["test-worker-1"] + deps.registry.mu.Unlock() + if exists { + t.Error("expected worker to be 
deregistered after stop") + } +} + +func TestWorkExecutor_ClaimsAndExecutesTask(t *testing.T) { + deps := newTestDeps() + + // Enqueue a build task + deps.queue.mu.Lock() + deps.queue.tasks["task-1"] = &domain.WorkTask{ + ID: "task-1", + ProjectID: "project-1", + Type: domain.WorkTaskTypeBuild, + Status: domain.WorkTaskStatusPending, + Spec: map[string]any{"prompt": "Build a landing page"}, + MaxRetries: 3, + CreatedAt: time.Now(), + } + deps.queue.mu.Unlock() + + executor := NewWorkExecutor(deps.workerSvc, deps.workSvc, deps.buildExec, &WorkExecutorConfig{ + WorkerID: "test-worker-2", + PollPeriod: 50 * time.Millisecond, + HeartbeatPeriod: 5 * time.Second, + Logger: testLogger(), + }) + + // Register the worker (normally done by Start) then call tryClaimAndExecute directly + if err := executor.Start(); err != nil { + t.Fatalf("Start() error = %v", err) + } + + // Call tryClaimAndExecute directly to avoid timing dependency + executor.tryClaimAndExecute() + executor.Stop() + + // Verify task was completed + deps.queue.mu.Lock() + task := deps.queue.tasks["task-1"] + deps.queue.mu.Unlock() + + if task.Status != domain.WorkTaskStatusCompleted { + t.Errorf("got task status %q, want %q", task.Status, domain.WorkTaskStatusCompleted) + } +} + +func TestWorkExecutor_FailsTaskOnAgentError(t *testing.T) { + deps := newTestDeps() + deps.agent.err = fmt.Errorf("agent crashed") + + // Enqueue a build task + deps.queue.mu.Lock() + deps.queue.tasks["task-1"] = &domain.WorkTask{ + ID: "task-1", + ProjectID: "project-1", + Type: domain.WorkTaskTypeBuild, + Status: domain.WorkTaskStatusPending, + Spec: map[string]any{"prompt": "Build something"}, + MaxRetries: 3, + CreatedAt: time.Now(), + } + deps.queue.mu.Unlock() + + executor := NewWorkExecutor(deps.workerSvc, deps.workSvc, deps.buildExec, &WorkExecutorConfig{ + WorkerID: "test-worker-3", + PollPeriod: 50 * time.Millisecond, + HeartbeatPeriod: 5 * time.Second, + Logger: testLogger(), + }) + + if err := executor.Start(); err != 
nil { + t.Fatalf("Start() error = %v", err) + } + + // Call tryClaimAndExecute directly for each retry to avoid timing dependency + for i := 0; i < 3; i++ { + executor.tryClaimAndExecute() + } + executor.Stop() + + // Task should be permanently failed after all retries. + deps.queue.mu.Lock() + task := deps.queue.tasks["task-1"] + deps.queue.mu.Unlock() + + if task.Status != domain.WorkTaskStatusFailed { + t.Errorf("got task status %q, want %q (should be permanently failed after retries)", task.Status, domain.WorkTaskStatusFailed) + } + if task.RetryCount < 3 { + t.Errorf("expected retry_count >= 3, got %d", task.RetryCount) + } +} + +func TestWorkExecutor_UnsupportedTaskType(t *testing.T) { + deps := newTestDeps() + + // Enqueue a custom task (not build) + deps.queue.mu.Lock() + deps.queue.tasks["task-1"] = &domain.WorkTask{ + ID: "task-1", + ProjectID: "project-1", + Type: domain.WorkTaskTypeCustom, + Status: domain.WorkTaskStatusPending, + Spec: map[string]any{"prompt": "Do something custom"}, + MaxRetries: 1, + CreatedAt: time.Now(), + } + deps.queue.mu.Unlock() + + executor := NewWorkExecutor(deps.workerSvc, deps.workSvc, deps.buildExec, &WorkExecutorConfig{ + WorkerID: "test-worker-4", + PollPeriod: 50 * time.Millisecond, + HeartbeatPeriod: 5 * time.Second, + Logger: testLogger(), + }) + + if err := executor.Start(); err != nil { + t.Fatalf("Start() error = %v", err) + } + + // Call tryClaimAndExecute directly to avoid timing dependency + executor.tryClaimAndExecute() + executor.Stop() + + // Should fail because custom tasks are unsupported + deps.queue.mu.Lock() + task := deps.queue.tasks["task-1"] + deps.queue.mu.Unlock() + + // With maxRetries=1 and retryCount=1, it should be permanently failed + if task.Status != domain.WorkTaskStatusFailed { + t.Errorf("got task status %q, want %q", task.Status, domain.WorkTaskStatusFailed) + } +} + +// ============================================================================= +// BuildExecutor Tests +// 
============================================================================= + +func TestBuildExecutor_Execute(t *testing.T) { + t.Run("successful build", func(t *testing.T) { + agent := &mockCodeAgent{ + result: &domain.AgentResult{ExitCode: 0, DurationMs: 500}, + } + registry := &mockCodeAgentRegistry{agent: agent} + exec := NewBuildExecutor(registry, nil, nil) + + task := &domain.WorkTask{ + ID: "task-1", + ProjectID: "project-1", + Type: domain.WorkTaskTypeBuild, + Spec: map[string]any{"prompt": "Build a landing page"}, + } + + result := exec.Execute(context.Background(), task) + if !result.Success { + t.Errorf("expected success, got error: %s", result.Error) + } + if result.DurationMs < 0 { + t.Errorf("expected non-negative duration, got %d", result.DurationMs) + } + }) + + t.Run("missing prompt", func(t *testing.T) { + registry := &mockCodeAgentRegistry{agent: &mockCodeAgent{}} + exec := NewBuildExecutor(registry, nil, nil) + + task := &domain.WorkTask{ + ID: "task-1", + Type: domain.WorkTaskTypeBuild, + Spec: map[string]any{}, + } + + result := exec.Execute(context.Background(), task) + if result.Success { + t.Error("expected failure for missing prompt") + } + }) + + t.Run("no agent available", func(t *testing.T) { + registry := &mockCodeAgentRegistry{agent: nil} + exec := NewBuildExecutor(registry, nil, nil) + + task := &domain.WorkTask{ + ID: "task-1", + Type: domain.WorkTaskTypeBuild, + Spec: map[string]any{"prompt": "Build something"}, + } + + result := exec.Execute(context.Background(), task) + if result.Success { + t.Error("expected failure when no agent available") + } + }) + + t.Run("agent execution error", func(t *testing.T) { + agent := &mockCodeAgent{err: fmt.Errorf("connection refused")} + registry := &mockCodeAgentRegistry{agent: agent} + exec := NewBuildExecutor(registry, nil, nil) + + task := &domain.WorkTask{ + ID: "task-1", + Type: domain.WorkTaskTypeBuild, + Spec: map[string]any{"prompt": "Build something"}, + } + + result := 
exec.Execute(context.Background(), task) + if result.Success { + t.Error("expected failure on agent error") + } + if result.Error == "" { + t.Error("expected error message") + } + }) + + t.Run("agent non-zero exit code", func(t *testing.T) { + agent := &mockCodeAgent{ + result: &domain.AgentResult{ExitCode: 1, DurationMs: 500}, + } + registry := &mockCodeAgentRegistry{agent: agent} + exec := NewBuildExecutor(registry, nil, nil) + + task := &domain.WorkTask{ + ID: "task-1", + Type: domain.WorkTaskTypeBuild, + Spec: map[string]any{"prompt": "Build something"}, + } + + result := exec.Execute(context.Background(), task) + if result.Success { + t.Error("expected failure on non-zero exit code") + } + }) +} + +func TestBuildExecutor_ParseSpec(t *testing.T) { + exec := NewBuildExecutor(nil, nil, nil) + + t.Run("valid spec", func(t *testing.T) { + spec, err := exec.parseSpec(map[string]any{ + "prompt": "Build a page", + "template": "astro-landing", + "auto_commit": true, + "auto_push": true, + }) + if err != nil { + t.Fatalf("parseSpec() error = %v", err) + } + if spec.Prompt != "Build a page" { + t.Errorf("got prompt %q", spec.Prompt) + } + if !spec.AutoCommit { + t.Error("expected auto_commit = true") + } + if !spec.AutoPush { + t.Error("expected auto_push = true") + } + }) + + t.Run("missing prompt", func(t *testing.T) { + _, err := exec.parseSpec(map[string]any{ + "template": "astro-landing", + }) + if err == nil { + t.Error("expected error for missing prompt") + } + }) +} + +func TestTruncate(t *testing.T) { + tests := []struct { + input string + maxLen int + want string + }{ + {"short", 10, "short"}, + {"exactly ten", 11, "exactly ten"}, + {"this is a long string", 10, "this is..."}, + {"abc", 3, "abc"}, + {"abcd", 3, "abc"}, + } + + for _, tt := range tests { + got := truncate(tt.input, tt.maxLen) + if got != tt.want { + t.Errorf("truncate(%q, %d) = %q, want %q", tt.input, tt.maxLen, got, tt.want) + } + } +}