feat: Add CI pipeline proxy, DNS alias management, and worker executor system
- Add ListPipelines/GetPipeline to CIProvider port with Woodpecker adapter
- Add DNS alias endpoints: GET/POST/DELETE /projects/{id}/domains
- Implement worker executor daemon, build executor, and git operations
- Add build service, worker service, and build audit tracking
- Add worker registry with PostgreSQL adapter and migration
- Add multi-provider code agent interface (Claude Code + OpenCode)
- Add create-and-build combo endpoint
- Update landing-page cookbook to reflect all gaps closed
- Fix tech debt: unified validation, auth scopes, error wrapping, slog patterns
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
parent
39df51defd
commit
bc47e426b0
@ -1,77 +1,59 @@
|
||||
# Build Orchestration
|
||||
|
||||
**Last Updated:** 2025-01
|
||||
**Confidence:** High (Planned - see address-the-gaps.md)
|
||||
**Last Updated:** 2026-01-27
|
||||
**Confidence:** High
|
||||
|
||||
## Summary
|
||||
|
||||
Build orchestration enables structured build specs for bot-driven development. Bots submit build requests with prompts, workers execute, and callbacks notify completion.
|
||||
Build orchestration enables structured build specs for bot-driven development. Bots submit build requests with prompts and templates via `POST /project/{name}/build`, workers execute Claude Code, and callbacks notify completion. All builds are recorded in the `build_audit` table for observability.
|
||||
|
||||
**Key Facts:**
|
||||
- Build spec includes template, prompt, variables, auto_deploy flag
|
||||
- Enqueues as work task for worker pool
|
||||
- Auto-deploy commits, pushes, triggers Woodpecker CI
|
||||
- Callback URL notified on completion with artifacts
|
||||
- BuildSpec: prompt (required), template, variables, auto_commit, auto_push, callback_url
|
||||
- BuildResult: success, output, error, commit_sha, files_changed, duration_ms, artifacts
|
||||
- Builds enqueued as work tasks for the worker pool
|
||||
- Auto-commit/push triggers Woodpecker CI pipeline
|
||||
- Callback URL receives completion notification with full BuildResult
|
||||
- Complete audit trail in `build_audit` PostgreSQL table
|
||||
|
||||
**File Pointers:**
|
||||
- Domain: `internal/domain/build.go` (BuildSpec, BuildResult, BuildAuditEntry)
|
||||
- Port: `internal/port/build_audit.go` (BuildAudit interface)
|
||||
- Adapter: `internal/adapter/postgres/build_audit.go`
|
||||
- Service: `internal/service/build_service.go`
|
||||
- Handler: `internal/handlers/build.go`
|
||||
- Work queue: `internal/port/work_queue.go`
|
||||
- Handler: `internal/handlers/builds.go` (StartBuild, ListBuilds, GetBuild)
|
||||
- Handler: `internal/handlers/create_and_build.go` (CreateAndBuild)
|
||||
- Executor: `internal/worker/build_executor.go` (BuildSpec→AgentRequest translation)
|
||||
- Git: `internal/worker/git_operations.go` (clone, commit, push with token injection)
|
||||
- Migration: `internal/db/migrations/012_worker_registry.sql` (build_audit table)
|
||||
|
||||
## Build Spec Schema
|
||||
## API Endpoints
|
||||
|
||||
```go
|
||||
type BuildSpec struct {
|
||||
Template string `json:"template"`
|
||||
Prompt string `json:"prompt"`
|
||||
Variables map[string]string `json:"variables"`
|
||||
AutoDeploy bool `json:"auto_deploy"`
|
||||
CallbackURL string `json:"callback_url"`
|
||||
}
|
||||
```
|
||||
|
||||
## API Endpoint
|
||||
|
||||
```
|
||||
POST /project/{name}/build
|
||||
{
|
||||
"template": "astro-landing",
|
||||
"prompt": "Create a coming soon page with dark theme and threesix.ai branding",
|
||||
"auto_deploy": true,
|
||||
"callback_url": "https://pantheon.orchard9.ai/webhooks/build-complete"
|
||||
}
|
||||
```
|
||||
| Method | Path | Description |
|
||||
|--------|------|-------------|
|
||||
| POST | `/projects/{id}/builds` | Start a build, returns task_id |
|
||||
| GET | `/projects/{id}/builds` | List builds for project |
|
||||
| GET | `/builds/{taskId}` | Get build status and result |
|
||||
| POST | `/project/create-and-build` | Create project + start build in one call |
|
||||
|
||||
## Orchestration Flow
|
||||
|
||||
1. Bot calls `POST /project/{name}/build`
|
||||
2. BuildService validates project exists
|
||||
3. Creates WorkTask with build spec
|
||||
4. Enqueues to work queue
|
||||
5. Returns task ID immediately
|
||||
6. Worker picks up task:
|
||||
- Clones repo
|
||||
- Runs Claude with prompt
|
||||
- Commits and pushes (if auto_deploy)
|
||||
7. Woodpecker builds and deploys
|
||||
8. Callback notified with result
|
||||
1. Bot calls `POST /projects/{id}/builds` with BuildSpec (prompt, template, auto_commit, auto_push)
|
||||
2. BuildService validates spec (prompt required), creates WorkTask with build spec, enqueues
|
||||
3. Creates BuildAuditEntry with status "pending"
|
||||
4. Returns task ID immediately
|
||||
5. WorkExecutor poll loop claims task from queue
|
||||
6. BuildExecutor translates spec: clones repo, builds AgentRequest, calls CodeAgent.Execute()
|
||||
7. On success with auto_commit: GitOperations commits and pushes changes
|
||||
8. WorkExecutor reports completion with BuildResult
|
||||
9. Audit entry updated, callback URL notified
|
||||
|
||||
## Callback Payload
|
||||
## Build Audit Statuses
|
||||
|
||||
```json
|
||||
{
|
||||
"task_id": "uuid",
|
||||
"project_id": "myapp",
|
||||
"status": "completed",
|
||||
"result": {
|
||||
"output": "...",
|
||||
"artifacts": {
|
||||
"commit_sha": "abc123",
|
||||
"deploy_url": "https://myapp.threesix.ai"
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
- `pending` - enqueued, waiting for worker
|
||||
- `running` - worker executing
|
||||
- `completed` - finished successfully
|
||||
- `failed` - execution failed
|
||||
- `cancelled` - cancelled before completion
|
||||
|
||||
## Related Topics
|
||||
|
||||
|
||||
@ -10,16 +10,16 @@ Quick reference for rdev concepts and facts.
|
||||
| Project Service | [services/project-service.md](./services/project-service.md) | High | 2025-01 | Business logic for project operations |
|
||||
| API Keys | [services/api-keys.md](./services/api-keys.md) | High | 2025-01 | Authentication, scopes, restrictions |
|
||||
| Webhooks | [services/webhooks.md](./services/webhooks.md) | High | 2025-01 | Event subscriptions and delivery |
|
||||
| **Worker Infrastructure** (Planned) |
|
||||
| **Worker Infrastructure** |
|
||||
| Work Queue | [services/work-queue.md](./services/work-queue.md) | High | 2025-01 | Task queue for worker pool |
|
||||
| Worker Pool | [services/worker-pool.md](./services/worker-pool.md) | High | 2025-01 | Shared claudebox workers |
|
||||
| Worker Pool | [services/worker-pool.md](./services/worker-pool.md) | High | 2026-01 | Embedded work executor with queue maintenance and metrics |
|
||||
| CI Provider | [services/ci-provider.md](./services/ci-provider.md) | High | 2025-01 | Woodpecker auto-activation |
|
||||
| Template Provider | [services/template-provider.md](./services/template-provider.md) | High | 2025-01 | Project template seeding |
|
||||
| **Features** |
|
||||
| Command Execution | [features/command-execution.md](./features/command-execution.md) | High | 2025-01 | Claude/shell/git command flow |
|
||||
| SSE Streaming | [features/sse-streaming.md](./features/sse-streaming.md) | High | 2025-01 | Real-time output streaming |
|
||||
| Infrastructure Management | [features/infrastructure.md](./features/infrastructure.md) | High | 2025-01 | Gitea, Cloudflare, deployment |
|
||||
| Build Orchestration | [features/build-orchestration.md](./features/build-orchestration.md) | High | 2025-01 | Bot-driven build specs |
|
||||
| Build Orchestration | [features/build-orchestration.md](./features/build-orchestration.md) | High | 2026-01 | Bot-driven build specs with audit trail |
|
||||
|
||||
## Roadmap Reference
|
||||
|
||||
|
||||
@ -1,59 +1,74 @@
|
||||
# Worker Pool
|
||||
|
||||
**Last Updated:** 2025-01
|
||||
**Confidence:** High (Planned - see address-the-gaps.md)
|
||||
**Last Updated:** 2026-01-27
|
||||
**Confidence:** High
|
||||
|
||||
## Summary
|
||||
|
||||
Shared pool of claudebox workers (3-5 pods) that can build any project. Workers register, send heartbeats, and poll for tasks. Scales horizontally by adding workers, not projects.
|
||||
Shared worker pool that executes build tasks for any project. Currently runs as an embedded WorkExecutor daemon inside rdev-api. Workers register with the worker registry, poll the work queue for tasks, execute Claude Code in cloned repos via GitOperations, and report results with audit trails.
|
||||
|
||||
**Key Facts:**
|
||||
- Workers labeled `rdev.orchard9.ai/role=worker`
|
||||
- StatefulSet: `claudebox-worker` with 3+ replicas
|
||||
- Each worker has dedicated PVC for workspace
|
||||
- Workers poll rdev-api for tasks every 5 seconds
|
||||
- Health tracked via heartbeat endpoint
|
||||
- Embedded WorkExecutor daemon runs inside rdev-api process
|
||||
- Workers poll work queue every 5 seconds, heartbeat every 30 seconds
|
||||
- Stale workers (no heartbeat for 2 minutes) automatically marked offline by QueueMaintenance
|
||||
- Stale tasks (running >30 min without completion) automatically requeued
|
||||
- Old tasks (>7 days) automatically cleaned up
|
||||
- Queue depth and worker counts exported as Prometheus metrics
|
||||
- Future: external worker binary for separate pod deployment
|
||||
|
||||
**File Pointers:**
|
||||
- Port: `internal/port/worker_registry.go`
|
||||
- Domain: `internal/domain/worker.go` (Worker, WorkerStatus)
|
||||
- Domain: `internal/domain/build.go` (BuildSpec, BuildResult)
|
||||
- Port: `internal/port/worker_registry.go` (WorkerRegistry interface)
|
||||
- Port: `internal/port/build_audit.go` (BuildAudit interface)
|
||||
- Adapter: `internal/adapter/postgres/worker_registry.go`
|
||||
- Handler: `internal/handlers/workers.go`
|
||||
- K8s manifest: `deployments/k8s/base/claudebox-worker.yaml`
|
||||
- Adapter: `internal/adapter/postgres/build_audit.go`
|
||||
- Service: `internal/service/worker_service.go`
|
||||
- Service: `internal/service/build_service.go`
|
||||
- Executor: `internal/worker/work_executor.go` (poll loop, heartbeat, task routing)
|
||||
- Executor: `internal/worker/build_executor.go` (BuildSpec→AgentRequest)
|
||||
- Git: `internal/worker/git_operations.go` (clone, commit, push)
|
||||
- Maintenance: `internal/worker/queue_maintenance.go` (stale recovery, cleanup, metrics)
|
||||
- Handler: `internal/handlers/workers.go` (REST API for workers)
|
||||
- Handler: `internal/handlers/builds.go` (REST API for builds)
|
||||
- Handler: `internal/handlers/create_and_build.go` (combined create+build)
|
||||
- Migration: `internal/db/migrations/012_worker_registry.sql`
|
||||
|
||||
## Port Interface
|
||||
## Worker Lifecycle (Embedded)
|
||||
|
||||
```go
|
||||
type WorkerRegistry interface {
|
||||
Register(ctx context.Context, worker WorkerInfo) error
|
||||
Heartbeat(ctx context.Context, workerID string) error
|
||||
Deregister(ctx context.Context, workerID string) error
|
||||
ListActive(ctx context.Context) ([]WorkerInfo, error)
|
||||
}
|
||||
1. rdev-api starts → WorkExecutor registers as worker in registry
|
||||
2. Heartbeat loop: every 30s sends heartbeat via WorkerService
|
||||
3. Poll loop: every 5s dequeues next task from work queue
|
||||
4. BuildExecutor: clones repo, executes CodeAgent, commits/pushes if auto_commit
|
||||
5. Reports completion with BuildResult via WorkerService
|
||||
6. Graceful shutdown: deregisters worker on rdev-api stop
|
||||
|
||||
type WorkerInfo struct {
|
||||
ID string
|
||||
PodName string
|
||||
Namespace string
|
||||
Status string // "idle", "busy", "unhealthy"
|
||||
LastSeen time.Time
|
||||
CurrentTask string
|
||||
}
|
||||
```
|
||||
## Worker Statuses
|
||||
|
||||
## Worker Lifecycle
|
||||
- `idle` - available for new tasks
|
||||
- `busy` - currently executing a task
|
||||
- `draining` - not accepting new tasks (pre-shutdown)
|
||||
- `offline` - missed heartbeat threshold
|
||||
|
||||
1. Pod starts → calls `POST /workers` to register
|
||||
2. Main loop: heartbeat every 5s, poll for tasks
|
||||
3. Task received → clone repo, run Claude, commit, report
|
||||
4. Pod shutdown → `DELETE /workers/{id}` to deregister
|
||||
## API Endpoints
|
||||
|
||||
## Environment Variables
|
||||
| Method | Path | Description |
|
||||
|--------|------|-------------|
|
||||
| GET | `/workers` | List all workers with status summary |
|
||||
| GET | `/workers/{workerId}` | Get worker details |
|
||||
| POST | `/workers/{workerId}/drain` | Set worker to draining |
|
||||
| POST | `/projects/{id}/builds` | Start build for project |
|
||||
| GET | `/projects/{id}/builds` | List builds for project |
|
||||
| GET | `/builds/{taskId}` | Get build status |
|
||||
| POST | `/project/create-and-build` | Create project + start build |
|
||||
|
||||
```
|
||||
WORKER_ID=$(hostname)
|
||||
RDEV_API_URL=http://rdev-api.rdev.svc:8080
|
||||
RDEV_API_KEY=<worker service key>
|
||||
```
|
||||
## Queue Maintenance
|
||||
|
||||
The QueueMaintenance worker runs inside rdev-api alongside the WorkExecutor:
|
||||
- **Stale task recovery** (every 1m): Requeues tasks running >30m without completion
|
||||
- **Stale worker marking** (every 1m): Marks workers offline after 2m without heartbeat
|
||||
- **Old task cleanup** (every 1m): Removes completed/failed/cancelled tasks >7 days old
|
||||
- **Metrics refresh** (every 15s): Updates Prometheus gauges for queue depth and worker counts
|
||||
|
||||
## Related Topics
|
||||
|
||||
|
||||
179
cmd/rdev-api/config.go
Normal file
179
cmd/rdev-api/config.go
Normal file
@ -0,0 +1,179 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"os"
|
||||
"strconv"
|
||||
|
||||
"github.com/orchard9/rdev/internal/domain"
|
||||
"github.com/orchard9/rdev/internal/port"
|
||||
)
|
||||
|
||||
// Config holds application configuration.
|
||||
type Config struct {
|
||||
Port int
|
||||
DBHost string
|
||||
DBPort int
|
||||
DBUser string
|
||||
DBPassword string
|
||||
DBName string
|
||||
DBSSLMode string
|
||||
AdminKey string
|
||||
|
||||
// Credential store encryption key (required for storing secrets in DB)
|
||||
CredentialEncryptionKey string
|
||||
|
||||
// OpenCode configuration (optional - enables OpenCode as alternative code agent)
|
||||
OpenCodeURL string // e.g., "http://opencode:4096"
|
||||
OpenCodeUsername string // Basic auth username (default: "opencode")
|
||||
OpenCodePassword string // Basic auth password
|
||||
|
||||
// Infrastructure adapters (threesix.ai) - fallback values if not in credential store
|
||||
GiteaURL string
|
||||
GiteaToken string
|
||||
GiteaDefaultOrg string
|
||||
CloudflareToken string
|
||||
CloudflareZoneID string
|
||||
DefaultDomain string
|
||||
DeployNamespace string
|
||||
DeployTLSIssuer string
|
||||
ClusterIP string
|
||||
RegistryURL string
|
||||
WoodpeckerURL string
|
||||
WoodpeckerAPIToken string
|
||||
WoodpeckerWebhookSecret string
|
||||
}
|
||||
|
||||
// InfraConfig holds infrastructure adapter configuration.
|
||||
// Loaded from credential store with env var fallback.
|
||||
type InfraConfig struct {
|
||||
GiteaURL string
|
||||
GiteaToken string
|
||||
GiteaDefaultOrg string
|
||||
CloudflareToken string
|
||||
CloudflareZoneID string
|
||||
DefaultDomain string
|
||||
DeployNamespace string
|
||||
DeployTLSIssuer string
|
||||
ClusterIP string
|
||||
RegistryURL string
|
||||
WoodpeckerURL string
|
||||
WoodpeckerAPIToken string
|
||||
WoodpeckerWebhookSecret string
|
||||
}
|
||||
|
||||
func loadConfig() Config {
|
||||
port := 8080
|
||||
if v := os.Getenv("PORT"); v != "" {
|
||||
if p, err := strconv.Atoi(v); err == nil {
|
||||
port = p
|
||||
}
|
||||
}
|
||||
|
||||
dbPort := 5432
|
||||
if v := os.Getenv("DB_PORT"); v != "" {
|
||||
if p, err := strconv.Atoi(v); err == nil {
|
||||
dbPort = p
|
||||
}
|
||||
}
|
||||
|
||||
return Config{
|
||||
Port: port,
|
||||
DBHost: getEnv("DB_HOST", "postgres.databases.svc"),
|
||||
DBPort: dbPort,
|
||||
DBUser: getEnv("DB_USER", "appuser"),
|
||||
DBPassword: os.Getenv("DB_PASSWORD"),
|
||||
DBName: getEnv("DB_NAME", "rdev"),
|
||||
DBSSLMode: getEnv("DB_SSL_MODE", "disable"),
|
||||
AdminKey: os.Getenv("RDEV_ADMIN_KEY"),
|
||||
|
||||
// Encryption key for credential store (generate with: openssl rand -base64 32)
|
||||
// REQUIRED in production - no default to prevent insecure deployments
|
||||
CredentialEncryptionKey: os.Getenv("CREDENTIAL_ENCRYPTION_KEY"),
|
||||
|
||||
// OpenCode (optional alternative code agent)
|
||||
OpenCodeURL: os.Getenv("OPENCODE_URL"), // e.g., "http://opencode:4096"
|
||||
OpenCodeUsername: getEnv("OPENCODE_USERNAME", "opencode"),
|
||||
OpenCodePassword: os.Getenv("OPENCODE_PASSWORD"),
|
||||
|
||||
// Infrastructure adapters (fallback if not in credential store)
|
||||
GiteaURL: getEnv("GITEA_URL", "https://git.threesix.ai"),
|
||||
GiteaToken: os.Getenv("GITEA_TOKEN"),
|
||||
GiteaDefaultOrg: getEnv("GITEA_DEFAULT_ORG", "jordan"),
|
||||
CloudflareToken: os.Getenv("CLOUDFLARE_API_TOKEN"),
|
||||
CloudflareZoneID: os.Getenv("CLOUDFLARE_ZONE_ID"),
|
||||
DefaultDomain: getEnv("DEFAULT_DOMAIN", "threesix.ai"),
|
||||
DeployNamespace: getEnv("DEPLOY_NAMESPACE", "projects"),
|
||||
DeployTLSIssuer: getEnv("DEPLOY_TLS_ISSUER", "letsencrypt-threesix"),
|
||||
ClusterIP: getEnv("CLUSTER_IP", "208.122.204.172"),
|
||||
RegistryURL: getEnv("REGISTRY_URL", "zot.threesix.svc.cluster.local:5000"),
|
||||
WoodpeckerURL: getEnv("WOODPECKER_URL", "https://ci.threesix.ai"),
|
||||
WoodpeckerAPIToken: os.Getenv("WOODPECKER_API_TOKEN"),
|
||||
WoodpeckerWebhookSecret: os.Getenv("WOODPECKER_WEBHOOK_SECRET"),
|
||||
}
|
||||
}
|
||||
|
||||
func getEnv(key, defaultVal string) string {
|
||||
if v := os.Getenv(key); v != "" {
|
||||
return v
|
||||
}
|
||||
return defaultVal
|
||||
}
|
||||
|
||||
// loadInfraConfig loads infrastructure configuration from credential store,
|
||||
// falling back to environment variables if not found in the store.
|
||||
func loadInfraConfig(ctx context.Context, store port.CredentialStore, cfg Config, logger *slog.Logger) InfraConfig {
|
||||
// Try to load from credential store
|
||||
creds, err := store.GetMultiple(ctx, []string{
|
||||
domain.CredKeyGiteaToken,
|
||||
domain.CredKeyGiteaURL,
|
||||
domain.CredKeyCloudflareAPIToken,
|
||||
domain.CredKeyCloudflareZoneID,
|
||||
domain.CredKeyWoodpeckerURL,
|
||||
domain.CredKeyWoodpeckerAPIToken,
|
||||
domain.CredKeyWoodpeckerWebhookSecret,
|
||||
domain.CredKeyRegistryURL,
|
||||
})
|
||||
if err != nil {
|
||||
logger.Warn("failed to load credentials from store, using env vars", "error", err)
|
||||
creds = make(map[string]string)
|
||||
}
|
||||
|
||||
// Helper to get from store or fall back to env var
|
||||
getOrFallback := func(key, envFallback string) string {
|
||||
if v, ok := creds[key]; ok && v != "" {
|
||||
return v
|
||||
}
|
||||
return envFallback
|
||||
}
|
||||
|
||||
infraCfg := InfraConfig{
|
||||
GiteaURL: getOrFallback(domain.CredKeyGiteaURL, cfg.GiteaURL),
|
||||
GiteaToken: getOrFallback(domain.CredKeyGiteaToken, cfg.GiteaToken),
|
||||
GiteaDefaultOrg: cfg.GiteaDefaultOrg, // Not a secret, use env
|
||||
CloudflareToken: getOrFallback(domain.CredKeyCloudflareAPIToken, cfg.CloudflareToken),
|
||||
CloudflareZoneID: getOrFallback(domain.CredKeyCloudflareZoneID, cfg.CloudflareZoneID),
|
||||
DefaultDomain: cfg.DefaultDomain, // Not a secret, use env
|
||||
DeployNamespace: cfg.DeployNamespace, // Not a secret, use env
|
||||
DeployTLSIssuer: cfg.DeployTLSIssuer, // Not a secret, use env
|
||||
ClusterIP: cfg.ClusterIP, // Not a secret, use env
|
||||
RegistryURL: getOrFallback(domain.CredKeyRegistryURL, cfg.RegistryURL),
|
||||
WoodpeckerURL: getOrFallback(domain.CredKeyWoodpeckerURL, cfg.WoodpeckerURL),
|
||||
WoodpeckerAPIToken: getOrFallback(domain.CredKeyWoodpeckerAPIToken, cfg.WoodpeckerAPIToken),
|
||||
WoodpeckerWebhookSecret: getOrFallback(domain.CredKeyWoodpeckerWebhookSecret, cfg.WoodpeckerWebhookSecret),
|
||||
}
|
||||
|
||||
// Log which credentials were loaded from store vs env
|
||||
fromStore := 0
|
||||
for k := range creds {
|
||||
if creds[k] != "" {
|
||||
fromStore++
|
||||
}
|
||||
}
|
||||
if fromStore > 0 {
|
||||
logger.Info("loaded credentials from store", "count", fromStore)
|
||||
}
|
||||
|
||||
return infraCfg
|
||||
}
|
||||
@ -37,10 +37,12 @@ import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"os"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/orchard9/rdev/internal/adapter/cloudflare"
|
||||
"github.com/orchard9/rdev/internal/adapter/codeagent"
|
||||
"github.com/orchard9/rdev/internal/adapter/codeagent/claudecode"
|
||||
"github.com/orchard9/rdev/internal/adapter/codeagent/opencode"
|
||||
"github.com/orchard9/rdev/internal/adapter/deployer"
|
||||
"github.com/orchard9/rdev/internal/adapter/gitea"
|
||||
"github.com/orchard9/rdev/internal/adapter/kubernetes"
|
||||
@ -50,11 +52,9 @@ import (
|
||||
"github.com/orchard9/rdev/internal/adapter/woodpecker"
|
||||
"github.com/orchard9/rdev/internal/auth"
|
||||
"github.com/orchard9/rdev/internal/db"
|
||||
"github.com/orchard9/rdev/internal/domain"
|
||||
"github.com/orchard9/rdev/internal/handlers"
|
||||
"github.com/orchard9/rdev/internal/metrics"
|
||||
"github.com/orchard9/rdev/internal/middleware"
|
||||
"github.com/orchard9/rdev/internal/port"
|
||||
"github.com/orchard9/rdev/internal/service"
|
||||
"github.com/orchard9/rdev/internal/telemetry"
|
||||
"github.com/orchard9/rdev/internal/webhook"
|
||||
@ -102,8 +102,10 @@ func main() {
|
||||
}
|
||||
defer func() { _ = database.Close() }()
|
||||
|
||||
// Initialize auth service
|
||||
authService := auth.NewService(database.DB, cfg.AdminKey)
|
||||
// Initialize auth service (hexagonal: repo → service → auth wrapper)
|
||||
apiKeyRepo := postgres.NewAPIKeyRepository(database.DB)
|
||||
apiKeySvc := service.NewAPIKeyService(apiKeyRepo, cfg.AdminKey)
|
||||
authService := auth.NewService(apiKeySvc, cfg.AdminKey)
|
||||
|
||||
// Initialize credential store (for infrastructure secrets)
|
||||
credentialStore := postgres.NewCredentialStore(database.DB, cfg.CredentialEncryptionKey)
|
||||
@ -218,17 +220,49 @@ func main() {
|
||||
logger.Info("template provider initialized")
|
||||
}
|
||||
|
||||
// Initialize CodeAgent registry (multi-provider support)
|
||||
agentRegistry := codeagent.NewRegistry()
|
||||
|
||||
// Register Claude Code adapter (default - always available)
|
||||
claudeCodeAdapter := claudecode.NewAdapter(namespace)
|
||||
agentRegistry.Register(claudeCodeAdapter)
|
||||
logger.Info("registered Claude Code agent", "provider", claudeCodeAdapter.Provider())
|
||||
|
||||
// Register OpenCode adapter (optional - only if configured)
|
||||
if cfg.OpenCodeURL != "" {
|
||||
openCodeAdapter := opencode.NewAdapter(opencode.ClientConfig{
|
||||
BaseURL: cfg.OpenCodeURL,
|
||||
Username: cfg.OpenCodeUsername,
|
||||
Password: cfg.OpenCodePassword,
|
||||
Timeout: 30 * time.Second,
|
||||
})
|
||||
agentRegistry.Register(openCodeAdapter)
|
||||
logger.Info("registered OpenCode agent", "provider", openCodeAdapter.Provider(), "url", cfg.OpenCodeURL)
|
||||
}
|
||||
|
||||
// Create services
|
||||
projectService := service.NewProjectService(projectRepo, k8sExecutor, streamPub).
|
||||
WithAuditLogger(auditLogger).
|
||||
WithCommandQueue(commandQueue).
|
||||
WithWebhookDispatcher(webhookDispatcher)
|
||||
WithWebhookDispatcher(webhookDispatcher).
|
||||
WithCodeAgentRegistry(agentRegistry)
|
||||
|
||||
// Create work service (for worker pool task management)
|
||||
workService := service.NewWorkService(workQueueRepo, service.WorkServiceConfig{
|
||||
Logger: logger,
|
||||
}).WithWebhookDispatcher(webhookDispatcher)
|
||||
|
||||
// Initialize worker pool infrastructure
|
||||
workerRegistryRepo := postgres.NewWorkerRegistryRepository(database.DB)
|
||||
buildAuditRepo := postgres.NewBuildAuditRepository(database.DB)
|
||||
|
||||
// Create worker service (manages worker lifecycle and task assignment)
|
||||
workerService := service.NewWorkerService(workerRegistryRepo, workQueueRepo, logger).
|
||||
WithBuildAudit(buildAuditRepo)
|
||||
|
||||
// Create build service (orchestrates build submission and tracking)
|
||||
buildService := service.NewBuildService(workQueueRepo, buildAuditRepo, logger)
|
||||
|
||||
// Create app
|
||||
app := api.New("rdev-api",
|
||||
api.WithPort(cfg.Port),
|
||||
@ -261,12 +295,13 @@ func main() {
|
||||
webhookHandler := handlers.NewWebhookHandler(webhookRepo, projectRepo)
|
||||
workHandler := handlers.NewWorkHandler(workService)
|
||||
|
||||
// Initialize infrastructure handler (for threesix.ai git/deploy/dns)
|
||||
// Initialize infrastructure handler (for threesix.ai git/deploy/dns/ci)
|
||||
infraHandler := handlers.NewInfrastructureHandler(
|
||||
giteaClient,
|
||||
dnsClient,
|
||||
deployerAdapter,
|
||||
projectRepo,
|
||||
woodpeckerClient,
|
||||
handlers.InfrastructureConfig{
|
||||
DefaultGitOwner: infraCfg.GiteaDefaultOrg,
|
||||
DefaultDomain: infraCfg.DefaultDomain,
|
||||
@ -290,7 +325,7 @@ func main() {
|
||||
)
|
||||
|
||||
// Initialize project management handler
|
||||
projectMgmtHandler := handlers.NewProjectManagementHandler(projectInfraService)
|
||||
projectMgmtHandler := handlers.NewProjectManagementHandler(projectInfraService, logger)
|
||||
|
||||
// Initialize Woodpecker webhook handler (for CI/CD auto-deploy)
|
||||
woodpeckerHandler := handlers.NewWoodpeckerWebhookHandler(
|
||||
@ -308,6 +343,21 @@ func main() {
|
||||
// Initialize credentials handler (superadmin only)
|
||||
credentialsHandler := handlers.NewCredentialsHandler(credentialStore)
|
||||
|
||||
// Initialize agents handler (for code agent management)
|
||||
agentsHandler := handlers.NewAgentsHandler(agentRegistry)
|
||||
|
||||
// Initialize worker pool handlers
|
||||
workersHandler := handlers.NewWorkersHandler(workerService)
|
||||
buildsHandler := handlers.NewBuildsHandler(buildService)
|
||||
createAndBuildHandler := handlers.NewCreateAndBuildHandler(projectInfraService, buildService, logger)
|
||||
|
||||
// Override default health/ready endpoints with full dependency checks
|
||||
healthHandler := handlers.NewHealthHandler("rdev-api", database.DB, nil).
|
||||
WithAgentRegistry(agentRegistry)
|
||||
|
||||
app.Router().Get("/health", healthHandler.Health)
|
||||
app.Router().Get("/ready", healthHandler.Ready)
|
||||
|
||||
// Register routes
|
||||
projectsHandler.Mount(app.Router())
|
||||
keysHandler.Mount(app.Router())
|
||||
@ -320,8 +370,12 @@ func main() {
|
||||
projectMgmtHandler.Mount(app.Router())
|
||||
woodpeckerHandler.Mount(app.Router())
|
||||
credentialsHandler.Mount(app.Router())
|
||||
agentsHandler.Mount(app.Router())
|
||||
workersHandler.Mount(app.Router())
|
||||
buildsHandler.Mount(app.Router())
|
||||
createAndBuildHandler.Mount(app.Router())
|
||||
|
||||
// Start queue processor worker
|
||||
// Start queue processor worker (per-project command queue)
|
||||
queueProcessor := worker.NewQueueProcessor(
|
||||
commandQueue,
|
||||
k8sExecutor,
|
||||
@ -337,11 +391,57 @@ func main() {
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Start work executor (cross-project worker pool)
|
||||
var gitOps *worker.GitOperations
|
||||
if infraCfg.GiteaToken != "" {
|
||||
gitOps = worker.NewGitOperations(worker.GitOperationsConfig{
|
||||
GiteaToken: infraCfg.GiteaToken,
|
||||
Logger: logger,
|
||||
})
|
||||
}
|
||||
buildExecutor := worker.NewBuildExecutor(agentRegistry, gitOps, logger)
|
||||
workExecutor := worker.NewWorkExecutor(
|
||||
workerService,
|
||||
workService,
|
||||
buildExecutor,
|
||||
&worker.WorkExecutorConfig{
|
||||
PollPeriod: 5 * time.Second,
|
||||
HeartbeatPeriod: 30 * time.Second,
|
||||
Logger: logger,
|
||||
},
|
||||
)
|
||||
if err := workExecutor.Start(); err != nil {
|
||||
logger.Error("failed to start work executor", "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
healthHandler.WithWorkExecutor(workExecutor)
|
||||
|
||||
// Start queue maintenance worker (stale task recovery, worker health, cleanup, metrics)
|
||||
queueMaintenance := worker.NewQueueMaintenance(
|
||||
workQueueRepo,
|
||||
workerRegistryRepo,
|
||||
&worker.QueueMaintenanceConfig{
|
||||
StaleTaskTimeout: 30 * time.Minute,
|
||||
StaleWorkerTimeout: 2 * time.Minute,
|
||||
CleanupAge: 7 * 24 * time.Hour,
|
||||
MaintenancePeriod: 1 * time.Minute,
|
||||
MetricsPeriod: 15 * time.Second,
|
||||
Logger: logger,
|
||||
},
|
||||
)
|
||||
queueMaintenance.Start()
|
||||
|
||||
// Enable API documentation
|
||||
app.EnableDocs(buildOpenAPISpec())
|
||||
|
||||
// Cleanup on shutdown
|
||||
app.OnShutdown(func(ctx context.Context) error {
|
||||
// Stop work executor (deregisters worker)
|
||||
workExecutor.Stop()
|
||||
|
||||
// Stop queue maintenance worker
|
||||
queueMaintenance.Stop()
|
||||
|
||||
// Stop queue processor
|
||||
queueProcessor.Stop()
|
||||
|
||||
@ -371,160 +471,5 @@ func main() {
|
||||
app.Run()
|
||||
}
|
||||
|
||||
// Config holds application configuration.
|
||||
type Config struct {
|
||||
Port int
|
||||
DBHost string
|
||||
DBPort int
|
||||
DBUser string
|
||||
DBPassword string
|
||||
DBName string
|
||||
DBSSLMode string
|
||||
AdminKey string
|
||||
|
||||
// Credential store encryption key (required for storing secrets in DB)
|
||||
CredentialEncryptionKey string
|
||||
|
||||
// Infrastructure adapters (threesix.ai) - fallback values if not in credential store
|
||||
GiteaURL string
|
||||
GiteaToken string
|
||||
GiteaDefaultOrg string
|
||||
CloudflareToken string
|
||||
CloudflareZoneID string
|
||||
DefaultDomain string
|
||||
DeployNamespace string
|
||||
DeployTLSIssuer string
|
||||
ClusterIP string
|
||||
RegistryURL string
|
||||
WoodpeckerURL string
|
||||
WoodpeckerAPIToken string
|
||||
WoodpeckerWebhookSecret string
|
||||
}
|
||||
|
||||
// InfraConfig holds infrastructure adapter configuration.
|
||||
// Loaded from credential store with env var fallback.
|
||||
type InfraConfig struct {
|
||||
GiteaURL string
|
||||
GiteaToken string
|
||||
GiteaDefaultOrg string
|
||||
CloudflareToken string
|
||||
CloudflareZoneID string
|
||||
DefaultDomain string
|
||||
DeployNamespace string
|
||||
DeployTLSIssuer string
|
||||
ClusterIP string
|
||||
RegistryURL string
|
||||
WoodpeckerURL string
|
||||
WoodpeckerAPIToken string
|
||||
WoodpeckerWebhookSecret string
|
||||
}
|
||||
|
||||
func loadConfig() Config {
|
||||
port := 8080
|
||||
if v := os.Getenv("PORT"); v != "" {
|
||||
if p, err := strconv.Atoi(v); err == nil {
|
||||
port = p
|
||||
}
|
||||
}
|
||||
|
||||
dbPort := 5432
|
||||
if v := os.Getenv("DB_PORT"); v != "" {
|
||||
if p, err := strconv.Atoi(v); err == nil {
|
||||
dbPort = p
|
||||
}
|
||||
}
|
||||
|
||||
return Config{
|
||||
Port: port,
|
||||
DBHost: getEnv("DB_HOST", "postgres.databases.svc"),
|
||||
DBPort: dbPort,
|
||||
DBUser: getEnv("DB_USER", "appuser"),
|
||||
DBPassword: os.Getenv("DB_PASSWORD"),
|
||||
DBName: getEnv("DB_NAME", "rdev"),
|
||||
DBSSLMode: getEnv("DB_SSL_MODE", "disable"),
|
||||
AdminKey: os.Getenv("RDEV_ADMIN_KEY"),
|
||||
|
||||
// Encryption key for credential store (generate with: openssl rand -base64 32)
|
||||
// REQUIRED in production - no default to prevent insecure deployments
|
||||
CredentialEncryptionKey: os.Getenv("CREDENTIAL_ENCRYPTION_KEY"),
|
||||
|
||||
// Infrastructure adapters (fallback if not in credential store)
|
||||
GiteaURL: getEnv("GITEA_URL", "https://git.threesix.ai"),
|
||||
GiteaToken: os.Getenv("GITEA_TOKEN"),
|
||||
GiteaDefaultOrg: getEnv("GITEA_DEFAULT_ORG", "jordan"),
|
||||
CloudflareToken: os.Getenv("CLOUDFLARE_API_TOKEN"),
|
||||
CloudflareZoneID: os.Getenv("CLOUDFLARE_ZONE_ID"),
|
||||
DefaultDomain: getEnv("DEFAULT_DOMAIN", "threesix.ai"),
|
||||
DeployNamespace: getEnv("DEPLOY_NAMESPACE", "projects"),
|
||||
DeployTLSIssuer: getEnv("DEPLOY_TLS_ISSUER", "letsencrypt-threesix"),
|
||||
ClusterIP: getEnv("CLUSTER_IP", "208.122.204.172"),
|
||||
RegistryURL: getEnv("REGISTRY_URL", "zot.threesix.svc.cluster.local:5000"),
|
||||
WoodpeckerURL: getEnv("WOODPECKER_URL", "https://ci.threesix.ai"),
|
||||
WoodpeckerAPIToken: os.Getenv("WOODPECKER_API_TOKEN"),
|
||||
WoodpeckerWebhookSecret: os.Getenv("WOODPECKER_WEBHOOK_SECRET"),
|
||||
}
|
||||
}
|
||||
|
||||
func getEnv(key, defaultVal string) string {
|
||||
if v := os.Getenv(key); v != "" {
|
||||
return v
|
||||
}
|
||||
return defaultVal
|
||||
}
|
||||
|
||||
// loadInfraConfig loads infrastructure configuration from credential store,
|
||||
// falling back to environment variables if not found in the store.
|
||||
func loadInfraConfig(ctx context.Context, store port.CredentialStore, cfg Config, logger *slog.Logger) InfraConfig {
|
||||
// Try to load from credential store
|
||||
creds, err := store.GetMultiple(ctx, []string{
|
||||
domain.CredKeyGiteaToken,
|
||||
domain.CredKeyGiteaURL,
|
||||
domain.CredKeyCloudflareAPIToken,
|
||||
domain.CredKeyCloudflareZoneID,
|
||||
domain.CredKeyWoodpeckerURL,
|
||||
domain.CredKeyWoodpeckerAPIToken,
|
||||
domain.CredKeyWoodpeckerWebhookSecret,
|
||||
domain.CredKeyRegistryURL,
|
||||
})
|
||||
if err != nil {
|
||||
logger.Warn("failed to load credentials from store, using env vars", "error", err)
|
||||
creds = make(map[string]string)
|
||||
}
|
||||
|
||||
// Helper to get from store or fall back to env var
|
||||
getOrFallback := func(key, envFallback string) string {
|
||||
if v, ok := creds[key]; ok && v != "" {
|
||||
return v
|
||||
}
|
||||
return envFallback
|
||||
}
|
||||
|
||||
infraCfg := InfraConfig{
|
||||
GiteaURL: getOrFallback(domain.CredKeyGiteaURL, cfg.GiteaURL),
|
||||
GiteaToken: getOrFallback(domain.CredKeyGiteaToken, cfg.GiteaToken),
|
||||
GiteaDefaultOrg: cfg.GiteaDefaultOrg, // Not a secret, use env
|
||||
CloudflareToken: getOrFallback(domain.CredKeyCloudflareAPIToken, cfg.CloudflareToken),
|
||||
CloudflareZoneID: getOrFallback(domain.CredKeyCloudflareZoneID, cfg.CloudflareZoneID),
|
||||
DefaultDomain: cfg.DefaultDomain, // Not a secret, use env
|
||||
DeployNamespace: cfg.DeployNamespace, // Not a secret, use env
|
||||
DeployTLSIssuer: cfg.DeployTLSIssuer, // Not a secret, use env
|
||||
ClusterIP: cfg.ClusterIP, // Not a secret, use env
|
||||
RegistryURL: getOrFallback(domain.CredKeyRegistryURL, cfg.RegistryURL),
|
||||
WoodpeckerURL: getOrFallback(domain.CredKeyWoodpeckerURL, cfg.WoodpeckerURL),
|
||||
WoodpeckerAPIToken: getOrFallback(domain.CredKeyWoodpeckerAPIToken, cfg.WoodpeckerAPIToken),
|
||||
WoodpeckerWebhookSecret: getOrFallback(domain.CredKeyWoodpeckerWebhookSecret, cfg.WoodpeckerWebhookSecret),
|
||||
}
|
||||
|
||||
// Log which credentials were loaded from store vs env
|
||||
fromStore := 0
|
||||
for k := range creds {
|
||||
if creds[k] != "" {
|
||||
fromStore++
|
||||
}
|
||||
}
|
||||
if fromStore > 0 {
|
||||
logger.Info("loaded credentials from store", "count", fromStore)
|
||||
}
|
||||
|
||||
return infraCfg
|
||||
}
|
||||
// Config, InfraConfig, loadConfig, loadInfraConfig, and getEnv
|
||||
// are defined in config.go.
|
||||
|
||||
@ -55,6 +55,9 @@ Command output is streamed via Server-Sent Events (SSE) at /projects/{id}/events
|
||||
spec.WithTag("Events", "Server-Sent Events for real-time output")
|
||||
spec.WithTag("Claude Config", "Manage commands, skills, and agents in /workspace/.claude/")
|
||||
spec.WithTag("Audit", "Command execution audit logs")
|
||||
spec.WithTag("Code Agents", "Multi-provider code agent management")
|
||||
spec.WithTag("Workers", "Worker pool management")
|
||||
spec.WithTag("Builds", "Build orchestration and history")
|
||||
spec.WithTag("System", "Health and readiness endpoints")
|
||||
|
||||
// Register all path operations
|
||||
@ -65,6 +68,9 @@ Command output is streamed via Server-Sent Events (SSE) at /projects/{id}/events
|
||||
registerEventPaths(spec)
|
||||
registerClaudeConfigPaths(spec)
|
||||
registerAuditPaths(spec)
|
||||
registerAgentPaths(spec)
|
||||
registerWorkerPaths(spec)
|
||||
registerBuildPaths(spec)
|
||||
|
||||
return spec
|
||||
}
|
||||
@ -456,141 +462,3 @@ func registerAuditPaths(spec *api.OpenAPISpec) {
|
||||
[]param{{Name: "command_id", In: "path", Description: "Command ID", Required: true}},
|
||||
))
|
||||
}
|
||||
|
||||
// param represents an OpenAPI parameter.
|
||||
type param struct {
|
||||
Name string
|
||||
In string
|
||||
Description string
|
||||
Required bool
|
||||
}
|
||||
|
||||
// withAuth creates an operation that requires authentication.
|
||||
func withAuth(summary, description, tag, scope, example string) map[string]any {
|
||||
return map[string]any{
|
||||
"summary": summary,
|
||||
"description": description + "\n\n**Required scope**: `" + scope + "`",
|
||||
"tags": []string{tag},
|
||||
"security": []map[string]any{
|
||||
{"ApiKeyAuth": []string{}},
|
||||
},
|
||||
"responses": map[string]any{
|
||||
"200": map[string]any{
|
||||
"description": "Success",
|
||||
"content": map[string]any{
|
||||
"application/json": map[string]any{
|
||||
"example": example,
|
||||
},
|
||||
},
|
||||
},
|
||||
"401": map[string]any{"description": "Unauthorized - Missing or invalid API key"},
|
||||
"403": map[string]any{"description": "Forbidden - Insufficient permissions"},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// withAuthAndBody creates an operation with auth and request body.
|
||||
func withAuthAndBody(summary, description, tag, scope, requestExample, responseExample string) map[string]any {
|
||||
return map[string]any{
|
||||
"summary": summary,
|
||||
"description": description + "\n\n**Required scope**: `" + scope + "`",
|
||||
"tags": []string{tag},
|
||||
"security": []map[string]any{
|
||||
{"ApiKeyAuth": []string{}},
|
||||
},
|
||||
"requestBody": map[string]any{
|
||||
"required": true,
|
||||
"content": map[string]any{
|
||||
"application/json": map[string]any{
|
||||
"example": requestExample,
|
||||
},
|
||||
},
|
||||
},
|
||||
"responses": map[string]any{
|
||||
"201": map[string]any{
|
||||
"description": "Created",
|
||||
"content": map[string]any{
|
||||
"application/json": map[string]any{
|
||||
"example": responseExample,
|
||||
},
|
||||
},
|
||||
},
|
||||
"400": map[string]any{"description": "Bad Request - Invalid input"},
|
||||
"401": map[string]any{"description": "Unauthorized - Missing or invalid API key"},
|
||||
"403": map[string]any{"description": "Forbidden - Insufficient permissions"},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// withAuthAndParams creates an operation with auth and path parameters.
|
||||
func withAuthAndParams(summary, description, tag, scope string, params []param) map[string]any {
|
||||
parameters := make([]map[string]any, len(params))
|
||||
for i, p := range params {
|
||||
parameters[i] = map[string]any{
|
||||
"name": p.Name,
|
||||
"in": p.In,
|
||||
"description": p.Description,
|
||||
"required": p.Required,
|
||||
"schema": map[string]any{"type": "string"},
|
||||
}
|
||||
}
|
||||
return map[string]any{
|
||||
"summary": summary,
|
||||
"description": description + "\n\n**Required scope**: `" + scope + "`",
|
||||
"tags": []string{tag},
|
||||
"security": []map[string]any{
|
||||
{"ApiKeyAuth": []string{}},
|
||||
},
|
||||
"parameters": parameters,
|
||||
"responses": map[string]any{
|
||||
"200": map[string]any{"description": "Success"},
|
||||
"401": map[string]any{"description": "Unauthorized - Missing or invalid API key"},
|
||||
"403": map[string]any{"description": "Forbidden - Insufficient permissions"},
|
||||
"404": map[string]any{"description": "Not Found"},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// withAuthBodyAndParams creates an operation with auth, body, and params.
|
||||
func withAuthBodyAndParams(summary, description, tag, scope string, params []param, requestExample, responseExample string) map[string]any {
|
||||
parameters := make([]map[string]any, len(params))
|
||||
for i, p := range params {
|
||||
parameters[i] = map[string]any{
|
||||
"name": p.Name,
|
||||
"in": p.In,
|
||||
"description": p.Description,
|
||||
"required": p.Required,
|
||||
"schema": map[string]any{"type": "string"},
|
||||
}
|
||||
}
|
||||
return map[string]any{
|
||||
"summary": summary,
|
||||
"description": description + "\n\n**Required scope**: `" + scope + "`",
|
||||
"tags": []string{tag},
|
||||
"security": []map[string]any{
|
||||
{"ApiKeyAuth": []string{}},
|
||||
},
|
||||
"parameters": parameters,
|
||||
"requestBody": map[string]any{
|
||||
"required": true,
|
||||
"content": map[string]any{
|
||||
"application/json": map[string]any{
|
||||
"example": requestExample,
|
||||
},
|
||||
},
|
||||
},
|
||||
"responses": map[string]any{
|
||||
"201": map[string]any{
|
||||
"description": "Created",
|
||||
"content": map[string]any{
|
||||
"application/json": map[string]any{
|
||||
"example": responseExample,
|
||||
},
|
||||
},
|
||||
},
|
||||
"400": map[string]any{"description": "Bad Request - Invalid input"},
|
||||
"401": map[string]any{"description": "Unauthorized - Missing or invalid API key"},
|
||||
"403": map[string]any{"description": "Forbidden - Insufficient permissions"},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
342
cmd/rdev-api/openapi_ext.go
Normal file
342
cmd/rdev-api/openapi_ext.go
Normal file
@ -0,0 +1,342 @@
|
||||
package main
|
||||
|
||||
import "github.com/orchard9/rdev/pkg/api"
|
||||
|
||||
func registerAgentPaths(spec *api.OpenAPISpec) {
|
||||
spec.AddPath("/agents", "get", withAuth(
|
||||
"List code agents",
|
||||
`Returns all registered code agent providers and their status.
|
||||
|
||||
Shows which agents are available, their supported models, and the current default.`,
|
||||
"Code Agents",
|
||||
"projects:read",
|
||||
`{
|
||||
"agents": [
|
||||
{
|
||||
"provider": "claudecode",
|
||||
"name": "Claude Code",
|
||||
"available": true,
|
||||
"default": true,
|
||||
"supported_models": ["claude-sonnet-4-20250514"],
|
||||
"default_model": "claude-sonnet-4-20250514"
|
||||
},
|
||||
{
|
||||
"provider": "opencode",
|
||||
"name": "OpenCode",
|
||||
"available": false,
|
||||
"default": false,
|
||||
"supported_models": ["gpt-4o", "claude-sonnet-4-20250514"],
|
||||
"default_model": "claude-sonnet-4-20250514"
|
||||
}
|
||||
],
|
||||
"default_agent": "claudecode",
|
||||
"total_agents": 2,
|
||||
"available_count": 1
|
||||
}`,
|
||||
))
|
||||
|
||||
spec.AddPath("/agents/health", "get", withAuth(
|
||||
"Get agent health status",
|
||||
`Returns the health status of all registered code agents.
|
||||
|
||||
Checks connectivity to each agent backend and reports availability.`,
|
||||
"Code Agents",
|
||||
"projects:read",
|
||||
`{
|
||||
"agents": [
|
||||
{
|
||||
"provider": "claudecode",
|
||||
"name": "Claude Code",
|
||||
"healthy": true,
|
||||
"message": "available",
|
||||
"latency": "125ms",
|
||||
"checked_at": "2026-01-27T12:00:00Z"
|
||||
},
|
||||
{
|
||||
"provider": "opencode",
|
||||
"name": "OpenCode",
|
||||
"healthy": false,
|
||||
"message": "unavailable",
|
||||
"latency": "5.002s",
|
||||
"checked_at": "2026-01-27T12:00:00Z"
|
||||
}
|
||||
],
|
||||
"healthy_count": 1,
|
||||
"total_count": 2,
|
||||
"default_agent": "claudecode",
|
||||
"default_healthy": true
|
||||
}`,
|
||||
))
|
||||
|
||||
spec.AddPath("/agents/{provider}", "get", withAuthAndParams(
|
||||
"Get agent capabilities",
|
||||
`Returns detailed capabilities for a specific code agent provider.
|
||||
|
||||
Includes supported features, models, and configuration options.`,
|
||||
"Code Agents",
|
||||
"projects:read",
|
||||
[]param{{Name: "provider", In: "path", Description: "Agent provider ID (e.g., 'claudecode', 'opencode')", Required: true}},
|
||||
))
|
||||
|
||||
spec.AddPath("/agents/default", "post", withAuthAndBody(
|
||||
"Set default agent",
|
||||
`Changes the default code agent used for command execution.
|
||||
|
||||
The specified provider must be registered and ideally available.`,
|
||||
"Code Agents",
|
||||
"admin",
|
||||
`{"provider": "opencode"}`,
|
||||
`{
|
||||
"default_agent": "opencode",
|
||||
"message": "default agent updated"
|
||||
}`,
|
||||
))
|
||||
}
|
||||
|
||||
// param represents an OpenAPI parameter.
|
||||
type param struct {
|
||||
Name string
|
||||
In string
|
||||
Description string
|
||||
Required bool
|
||||
}
|
||||
|
||||
// withAuth creates an operation that requires authentication.
|
||||
func withAuth(summary, description, tag, scope, example string) map[string]any {
|
||||
return map[string]any{
|
||||
"summary": summary,
|
||||
"description": description + "\n\n**Required scope**: `" + scope + "`",
|
||||
"tags": []string{tag},
|
||||
"security": []map[string]any{
|
||||
{"ApiKeyAuth": []string{}},
|
||||
},
|
||||
"responses": map[string]any{
|
||||
"200": map[string]any{
|
||||
"description": "Success",
|
||||
"content": map[string]any{
|
||||
"application/json": map[string]any{
|
||||
"example": example,
|
||||
},
|
||||
},
|
||||
},
|
||||
"401": map[string]any{"description": "Unauthorized - Missing or invalid API key"},
|
||||
"403": map[string]any{"description": "Forbidden - Insufficient permissions"},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// withAuthAndBody creates an operation with auth and request body.
|
||||
func withAuthAndBody(summary, description, tag, scope, requestExample, responseExample string) map[string]any {
|
||||
return map[string]any{
|
||||
"summary": summary,
|
||||
"description": description + "\n\n**Required scope**: `" + scope + "`",
|
||||
"tags": []string{tag},
|
||||
"security": []map[string]any{
|
||||
{"ApiKeyAuth": []string{}},
|
||||
},
|
||||
"requestBody": map[string]any{
|
||||
"required": true,
|
||||
"content": map[string]any{
|
||||
"application/json": map[string]any{
|
||||
"example": requestExample,
|
||||
},
|
||||
},
|
||||
},
|
||||
"responses": map[string]any{
|
||||
"201": map[string]any{
|
||||
"description": "Created",
|
||||
"content": map[string]any{
|
||||
"application/json": map[string]any{
|
||||
"example": responseExample,
|
||||
},
|
||||
},
|
||||
},
|
||||
"400": map[string]any{"description": "Bad Request - Invalid input"},
|
||||
"401": map[string]any{"description": "Unauthorized - Missing or invalid API key"},
|
||||
"403": map[string]any{"description": "Forbidden - Insufficient permissions"},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// withAuthAndParams creates an operation with auth and path parameters.
|
||||
func withAuthAndParams(summary, description, tag, scope string, params []param) map[string]any {
|
||||
parameters := make([]map[string]any, len(params))
|
||||
for i, p := range params {
|
||||
parameters[i] = map[string]any{
|
||||
"name": p.Name,
|
||||
"in": p.In,
|
||||
"description": p.Description,
|
||||
"required": p.Required,
|
||||
"schema": map[string]any{"type": "string"},
|
||||
}
|
||||
}
|
||||
return map[string]any{
|
||||
"summary": summary,
|
||||
"description": description + "\n\n**Required scope**: `" + scope + "`",
|
||||
"tags": []string{tag},
|
||||
"security": []map[string]any{
|
||||
{"ApiKeyAuth": []string{}},
|
||||
},
|
||||
"parameters": parameters,
|
||||
"responses": map[string]any{
|
||||
"200": map[string]any{"description": "Success"},
|
||||
"401": map[string]any{"description": "Unauthorized - Missing or invalid API key"},
|
||||
"403": map[string]any{"description": "Forbidden - Insufficient permissions"},
|
||||
"404": map[string]any{"description": "Not Found"},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// withAuthBodyAndParams creates an operation with auth, body, and params.
|
||||
func withAuthBodyAndParams(summary, description, tag, scope string, params []param, requestExample, responseExample string) map[string]any {
|
||||
parameters := make([]map[string]any, len(params))
|
||||
for i, p := range params {
|
||||
parameters[i] = map[string]any{
|
||||
"name": p.Name,
|
||||
"in": p.In,
|
||||
"description": p.Description,
|
||||
"required": p.Required,
|
||||
"schema": map[string]any{"type": "string"},
|
||||
}
|
||||
}
|
||||
return map[string]any{
|
||||
"summary": summary,
|
||||
"description": description + "\n\n**Required scope**: `" + scope + "`",
|
||||
"tags": []string{tag},
|
||||
"security": []map[string]any{
|
||||
{"ApiKeyAuth": []string{}},
|
||||
},
|
||||
"parameters": parameters,
|
||||
"requestBody": map[string]any{
|
||||
"required": true,
|
||||
"content": map[string]any{
|
||||
"application/json": map[string]any{
|
||||
"example": requestExample,
|
||||
},
|
||||
},
|
||||
},
|
||||
"responses": map[string]any{
|
||||
"201": map[string]any{
|
||||
"description": "Created",
|
||||
"content": map[string]any{
|
||||
"application/json": map[string]any{
|
||||
"example": responseExample,
|
||||
},
|
||||
},
|
||||
},
|
||||
"400": map[string]any{"description": "Bad Request - Invalid input"},
|
||||
"401": map[string]any{"description": "Unauthorized - Missing or invalid API key"},
|
||||
"403": map[string]any{"description": "Forbidden - Insufficient permissions"},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func registerWorkerPaths(spec *api.OpenAPISpec) {
|
||||
spec.AddPath("/workers", "get", withAuth(
|
||||
"List workers",
|
||||
"Returns all registered workers in the pool with status summary.",
|
||||
"Workers",
|
||||
"admin",
|
||||
`{
|
||||
"workers": [
|
||||
{
|
||||
"id": "rdev-worker-0",
|
||||
"hostname": "rdev-worker-0.rdev.svc",
|
||||
"status": "idle",
|
||||
"capabilities": ["build", "test", "deploy"],
|
||||
"registered_at": "2026-01-27T12:00:00Z",
|
||||
"last_heartbeat": "2026-01-27T12:05:00Z",
|
||||
"version": "1.0.0"
|
||||
}
|
||||
],
|
||||
"total": 1,
|
||||
"summary": {"idle": 1, "busy": 0, "draining": 0, "offline": 0}
|
||||
}`,
|
||||
))
|
||||
|
||||
spec.AddPath("/workers/{workerId}", "get", withAuthAndParams(
|
||||
"Get worker",
|
||||
"Returns details for a specific worker.",
|
||||
"Workers",
|
||||
"admin",
|
||||
[]param{{Name: "workerId", In: "path", Description: "Worker ID", Required: true}},
|
||||
))
|
||||
|
||||
spec.AddPath("/workers/{workerId}/drain", "post", withAuthAndParams(
|
||||
"Drain worker",
|
||||
"Sets a worker to draining status. It will finish its current task but stop accepting new work.",
|
||||
"Workers",
|
||||
"admin",
|
||||
[]param{{Name: "workerId", In: "path", Description: "Worker ID", Required: true}},
|
||||
))
|
||||
}
|
||||
|
||||
func registerBuildPaths(spec *api.OpenAPISpec) {
|
||||
spec.AddPath("/projects/{id}/builds", "post", withAuthBodyAndParams(
|
||||
"Start build",
|
||||
"Enqueues a build task for a project. The build will be picked up by an available worker.",
|
||||
"Builds",
|
||||
"projects:execute",
|
||||
[]param{{Name: "id", In: "path", Description: "Project ID", Required: true}},
|
||||
`{
|
||||
"prompt": "Build a landing page with Next.js and Tailwind CSS",
|
||||
"template": "nextjs-landing",
|
||||
"auto_commit": true,
|
||||
"auto_push": true,
|
||||
"callback_url": "https://example.com/webhook"
|
||||
}`,
|
||||
`{
|
||||
"task_id": "task-abc123",
|
||||
"project_id": "my-project",
|
||||
"status": "pending",
|
||||
"status_url": "/builds/task-abc123"
|
||||
}`,
|
||||
))
|
||||
|
||||
spec.AddPath("/projects/{id}/builds", "get", withAuthAndParams(
|
||||
"List builds",
|
||||
"Returns build history for a project.",
|
||||
"Builds",
|
||||
"projects:read",
|
||||
[]param{{Name: "id", In: "path", Description: "Project ID", Required: true}},
|
||||
))
|
||||
|
||||
spec.AddPath("/builds/{taskId}", "get", withAuthAndParams(
|
||||
"Get build status",
|
||||
"Returns the status and result of a specific build.",
|
||||
"Builds",
|
||||
"projects:read",
|
||||
[]param{{Name: "taskId", In: "path", Description: "Build task ID", Required: true}},
|
||||
))
|
||||
|
||||
spec.AddPath("/project/create-and-build", "post", withAuthAndBody(
|
||||
"Create project and build",
|
||||
`Creates a new project and immediately enqueues a build task.
|
||||
|
||||
Combines project creation (git repo, DNS, CI activation) with build submission in a single call.`,
|
||||
"Builds",
|
||||
"admin",
|
||||
`{
|
||||
"name": "my-landing-page",
|
||||
"description": "Landing page for product launch",
|
||||
"template": "nextjs-landing",
|
||||
"prompt": "Build a modern landing page with hero, features, and CTA sections",
|
||||
"auto_commit": true,
|
||||
"auto_push": true
|
||||
}`,
|
||||
`{
|
||||
"project_id": "my-landing-page",
|
||||
"name": "my-landing-page",
|
||||
"domain": "my-landing-page.threesix.ai",
|
||||
"url": "https://my-landing-page.threesix.ai",
|
||||
"git": {
|
||||
"owner": "jordan",
|
||||
"name": "my-landing-page",
|
||||
"clone_http": "https://git.threesix.ai/jordan/my-landing-page.git"
|
||||
},
|
||||
"task_id": "task-abc123",
|
||||
"status": "pending",
|
||||
"status_url": "/builds/task-abc123"
|
||||
}`,
|
||||
))
|
||||
}
|
||||
@ -7,27 +7,38 @@
|
||||
This cookbook creates and deploys a simple landing page using the full threesix.ai autonomous infrastructure:
|
||||
|
||||
```
|
||||
rdev-api → Gitea repo → Claude agent → push → Woodpecker CI → K8s deployment
|
||||
POST /project → Gitea repo + DNS + Woodpecker CI + template seed → Claude agent → git push → CI build → K8s deployment
|
||||
```
|
||||
|
||||
**Target:** `landing.threesix.ai` (with future DNS aliases for www/root)
|
||||
**Stack:** Astro (static site generator)
|
||||
**Stack:** Astro (static site generator) via `astro-landing` template
|
||||
**Status:** Coming Soon page
|
||||
|
||||
---
|
||||
|
||||
## Current Architecture Gap
|
||||
## What's Automated Today
|
||||
|
||||
**Two separate systems that need bridging:**
|
||||
`POST /project` orchestrates the full infrastructure setup in a single call:
|
||||
|
||||
| System | Endpoint | What it manages |
|
||||
|--------|----------|-----------------|
|
||||
| Project Management | `POST /project` | Gitea repos, DNS records, K8s deployments |
|
||||
| Claudebox Execution | `POST /projects/{id}/claude` | Code generation in existing claudebox pods |
|
||||
| Step | Status | How |
|
||||
|------|--------|-----|
|
||||
| Gitea repo creation | Automated | `port.GitRepository` adapter |
|
||||
| DNS A record | Automated | `port.DNSProvider` (Cloudflare) adapter |
|
||||
| Woodpecker CI activation | Automated | `port.CIProvider` adapter, called during project creation |
|
||||
| Template seeding | Automated | `port.TemplateProvider` with `astro-landing` template |
|
||||
| K8s deployment | Automated | `port.Deployer` adapter (triggered by CI webhook) |
|
||||
|
||||
**The problem:** Creating a project via `POST /project` creates a Gitea repo, but there's no claudebox to generate code for it. The claudebox system only knows about pre-existing pods (pantheon, aeries).
|
||||
## Full Pipeline Status
|
||||
|
||||
**The solution:** Use an existing claudebox as a "worker" to clone, build, and push to any project repo.
|
||||
All infrastructure gaps have been closed. The full pipeline from project creation through code generation, CI monitoring, and multi-domain DNS is operational:
|
||||
|
||||
| Capability | Endpoint | Status |
|
||||
|------------|----------|--------|
|
||||
| Project creation | `POST /project` | Operational |
|
||||
| Code generation (worker) | `POST /projects/{id}/builds` | Operational |
|
||||
| Create + build combo | `POST /project/create-and-build` | Operational |
|
||||
| CI pipeline monitoring | `GET /projects/{id}/pipelines` | Operational |
|
||||
| DNS alias management | `POST /projects/{id}/domains` | Operational |
|
||||
|
||||
---
|
||||
|
||||
@ -39,12 +50,12 @@ rdev-api → Gitea repo → Claude agent → push → Woodpecker CI → K8s depl
|
||||
|--------|----------|---------|
|
||||
| RDEV_ADMIN_KEY | `rdev-credentials` secret | rdev-api authentication |
|
||||
| GITEA_TOKEN | `rdev-credentials` secret | Gitea API access |
|
||||
| WOODPECKER_API_TOKEN | `.secrets` file | Woodpecker repo activation |
|
||||
| WOODPECKER_API_TOKEN | `rdev-credentials` secret | Woodpecker repo activation |
|
||||
| CLOUDFLARE_API_TOKEN | `rdev-credentials` secret | DNS management |
|
||||
|
||||
### Infrastructure Required
|
||||
|
||||
- [x] rdev-api running with infrastructure handlers (v0.7.1+)
|
||||
- [x] rdev-api running with infrastructure handlers
|
||||
- [x] Gitea at https://git.threesix.ai
|
||||
- [x] Woodpecker CI at https://ci.threesix.ai
|
||||
- [x] Zot registry at zot.threesix.svc.cluster.local:5000
|
||||
@ -56,38 +67,35 @@ rdev-api → Gitea repo → Claude agent → push → Woodpecker CI → K8s depl
|
||||
## Architecture
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────────────────────────┐
|
||||
│ Landing Page Flow │
|
||||
│ │
|
||||
│ 1. Create Project │
|
||||
│ POST /project {"name": "www"} │
|
||||
│ │ │
|
||||
│ ├──▶ Creates Gitea repo: jordan/www │
|
||||
│ └──▶ Creates DNS: www.threesix.ai → 208.122.204.172 │
|
||||
│ │
|
||||
│ 2. Activate Woodpecker │
|
||||
│ POST /api/repos?forge_remote_id={id} │
|
||||
│ │ │
|
||||
│ └──▶ Creates webhook in Gitea │
|
||||
│ │
|
||||
│ 3. Generate Code (Claude Agent) │
|
||||
│ claudebox or local Claude Code │
|
||||
│ │ │
|
||||
│ ├──▶ Creates Astro project │
|
||||
│ ├──▶ Creates Dockerfile │
|
||||
│ ├──▶ Creates .woodpecker.yml │
|
||||
│ └──▶ Pushes to Gitea │
|
||||
│ │
|
||||
│ 4. CI/CD Pipeline (automatic) │
|
||||
│ Woodpecker triggered by push │
|
||||
│ │ │
|
||||
│ ├──▶ Kaniko builds Docker image │
|
||||
│ ├──▶ Pushes to Zot registry │
|
||||
│ └──▶ Webhook triggers rdev-api deploy │
|
||||
│ │
|
||||
│ 5. Live at https://www.threesix.ai │
|
||||
│ │
|
||||
└─────────────────────────────────────────────────────────────────────────────┘
|
||||
┌──────────────────────────────────────────────────────────────────────┐
|
||||
│ Landing Page Flow │
|
||||
│ │
|
||||
│ 1. Create Project (single API call) │
|
||||
│ POST /project {"name": "landing", "template": "astro-landing"} │
|
||||
│ │ │
|
||||
│ ├──▶ Creates Gitea repo: threesix/landing │
|
||||
│ ├──▶ Creates DNS: landing.threesix.ai → cluster IP │
|
||||
│ ├──▶ Activates Woodpecker CI (auto) │
|
||||
│ └──▶ Seeds repo with astro-landing template │
|
||||
│ │
|
||||
│ 2. Generate Code (3 options) │
|
||||
│ Via worker pool, claudebox, or local Claude Code │
|
||||
│ │ │
|
||||
│ ├──▶ Customizes Astro landing page │
|
||||
│ ├──▶ Commits and pushes to Gitea │
|
||||
│ └──▶ Worker executor polls queue, dispatches to agent │
|
||||
│ │
|
||||
│ 3. CI/CD Pipeline (automatic on push) │
|
||||
│ Woodpecker triggered by git push │
|
||||
│ │ │
|
||||
│ ├──▶ npm install + npm build │
|
||||
│ ├──▶ Docker build (nginx) │
|
||||
│ ├──▶ Push to Zot registry │
|
||||
│ └──▶ kubectl set image (deploy) │
|
||||
│ │
|
||||
│ 4. Live at https://landing.threesix.ai │
|
||||
│ │
|
||||
└──────────────────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
---
|
||||
@ -96,213 +104,193 @@ rdev-api → Gitea repo → Claude agent → push → Woodpecker CI → K8s depl
|
||||
|
||||
### Step 1: Create Project via rdev-api
|
||||
|
||||
```bash
|
||||
RDEV_KEY="rdev_sk_prod_7f3a9c2e1d8b4a6f0e5c9d2b7a1f8e4c"
|
||||
This single call creates the Gitea repo, DNS record, activates Woodpecker CI, and seeds the repo with the `astro-landing` template.
|
||||
|
||||
```bash
|
||||
curl -X POST https://rdev.masq-ops.orchard9.ai/project \
|
||||
-H "Authorization: Bearer $RDEV_KEY" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"name": "landing", "description": "threesix.ai landing page"}'
|
||||
-d '{
|
||||
"name": "landing",
|
||||
"description": "threesix.ai landing page",
|
||||
"template": "astro-landing"
|
||||
}'
|
||||
```
|
||||
|
||||
**Response:**
|
||||
```json
|
||||
{
|
||||
"data": {
|
||||
"project_id": "landing",
|
||||
"name": "landing",
|
||||
"description": "threesix.ai landing page",
|
||||
"git": {
|
||||
"owner": "threesix",
|
||||
"name": "landing",
|
||||
"domain": "landing.threesix.ai",
|
||||
"git": {
|
||||
"clone_ssh": "git@git.threesix.ai:jordan/landing.git",
|
||||
"clone_http": "https://git.threesix.ai/jordan/landing.git"
|
||||
}
|
||||
}
|
||||
"clone_ssh": "git@git.threesix.ai:threesix/landing.git",
|
||||
"clone_http": "https://git.threesix.ai/threesix/landing.git",
|
||||
"html_url": "https://git.threesix.ai/threesix/landing"
|
||||
},
|
||||
"domain": "landing.threesix.ai",
|
||||
"url": "https://landing.threesix.ai",
|
||||
"next_steps": []
|
||||
}
|
||||
```
|
||||
|
||||
### Step 2: Activate Woodpecker CI
|
||||
If any infrastructure step fails, `next_steps` will contain manual instructions for that step. The remaining steps still execute.
|
||||
|
||||
### Step 2: Generate Code
|
||||
|
||||
The `astro-landing` template seeds the repo with a working Astro project, Dockerfile, nginx.conf, and `.woodpecker.yml`. You can customize it via Claude.
|
||||
|
||||
**Option A: Local Claude Code (recommended for now)**
|
||||
```bash
|
||||
GITEA_TOKEN="5508ff241943e84aad0ced3559f5fbd311a2fb81"
|
||||
WOODPECKER_TOKEN="eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJ0eXBlIjoidXNlciIsInVzZXItaWQiOiIxIn0.LcyVHcZ_gSvVH1w3y6TUCp_Jg9ubfsebOAVo-MtiNP8"
|
||||
|
||||
# Get Gitea repo ID
|
||||
REPO_ID=$(curl -s https://git.threesix.ai/api/v1/repos/jordan/landing \
|
||||
-H "Authorization: token $GITEA_TOKEN" | jq '.id')
|
||||
|
||||
# Activate in Woodpecker (creates webhook automatically)
|
||||
curl -X POST "https://ci.threesix.ai/api/repos?forge_remote_id=$REPO_ID" \
|
||||
-H "Authorization: Bearer $WOODPECKER_TOKEN"
|
||||
git clone https://git.threesix.ai/threesix/landing.git
|
||||
cd landing
|
||||
# Use Claude Code to customize the landing page
|
||||
# Then commit and push
|
||||
```
|
||||
|
||||
### Step 3: Generate Code via Claudebox
|
||||
|
||||
Use the `pantheon` claudebox as a worker to generate code for the landing project:
|
||||
|
||||
**Option B: Via existing claudebox**
|
||||
```bash
|
||||
RDEV_KEY="rdev_sk_prod_7f3a9c2e1d8b4a6f0e5c9d2b7a1f8e4c"
|
||||
|
||||
# Tell Claude to build the landing page in /tmp/landing
|
||||
curl -X POST "https://rdev.masq-ops.orchard9.ai/projects/pantheon/claude" \
|
||||
-H "Authorization: Bearer $RDEV_KEY" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"prompt": "Clone https://git.threesix.ai/jordan/landing.git to /tmp/landing, then create a simple Astro landing page with: Coming Soon message, threesix.ai branding (dark theme), responsive layout, Dockerfile (nginx), and .woodpecker.yml for CI/CD. Commit and push when done."
|
||||
"prompt": "Clone https://git.threesix.ai/threesix/landing.git to /tmp/landing, then customize the Astro landing page with: Coming Soon message, threesix.ai branding (dark theme, gradient background), responsive layout. Commit and push when done."
|
||||
}'
|
||||
```
|
||||
|
||||
**What happens:**
|
||||
1. Claude receives the prompt in the claudebox
|
||||
2. Claude clones the repo to `/tmp/landing`
|
||||
3. Claude generates the Astro project files
|
||||
4. Claude commits and pushes to Gitea
|
||||
|
||||
### Step 4: Monitor Build
|
||||
|
||||
Watch Woodpecker for the build:
|
||||
- https://ci.threesix.ai/jordan/landing
|
||||
|
||||
Or via API:
|
||||
**Option C: Via work queue (build endpoint)**
|
||||
```bash
|
||||
curl -s "https://ci.threesix.ai/api/repos/jordan/landing/pipelines" \
|
||||
-H "Authorization: Bearer $WOODPECKER_TOKEN" | jq '.[0] | {number, status, started}'
|
||||
```
|
||||
|
||||
### Step 5: Verify Deployment
|
||||
|
||||
```bash
|
||||
curl -s "https://rdev.masq-ops.orchard9.ai/project/landing" \
|
||||
-H "Authorization: Bearer $RDEV_KEY" | jq '.data.deployment'
|
||||
```
|
||||
|
||||
Site live at: https://landing.threesix.ai
|
||||
|
||||
### Step 6: Configure DNS Aliases (Optional)
|
||||
|
||||
Point `www.threesix.ai` and `threesix.ai` to the landing page:
|
||||
|
||||
```bash
|
||||
CF_TOKEN="nGoDhG6Za66XsKMl6W7LNXuowc5EM00glHxkq1KK"
|
||||
CF_ZONE="e0bc8d510f62807b360db0c5994964c5"
|
||||
|
||||
# Update root A record to point to k3s cluster
|
||||
curl -X PATCH "https://api.cloudflare.com/client/v4/zones/$CF_ZONE/dns_records/{record_id}" \
|
||||
-H "Authorization: Bearer $CF_TOKEN" \
|
||||
# Enqueue a build task for a worker to execute
|
||||
curl -X POST "https://rdev.masq-ops.orchard9.ai/projects/landing/builds" \
|
||||
-H "Authorization: Bearer $RDEV_KEY" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"type": "A",
|
||||
"name": "threesix.ai",
|
||||
"content": "208.122.204.172",
|
||||
"proxied": false
|
||||
"prompt": "Customize the landing page with Coming Soon message and threesix.ai branding",
|
||||
"auto_commit": true,
|
||||
"auto_push": true
|
||||
}'
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## File Templates
|
||||
|
||||
### .woodpecker.yml
|
||||
|
||||
```yaml
|
||||
steps:
|
||||
build:
|
||||
image: gcr.io/kaniko-project/executor:latest
|
||||
settings:
|
||||
registry: zot.threesix.svc.cluster.local:5000
|
||||
tags:
|
||||
- ${CI_COMMIT_SHA:0:8}
|
||||
- latest
|
||||
repo: zot.threesix.svc.cluster.local:5000/${CI_REPO_NAME}
|
||||
context: .
|
||||
dockerfile: Dockerfile
|
||||
insecure: true
|
||||
when:
|
||||
branch: main
|
||||
|
||||
notify:
|
||||
image: alpine/curl:latest
|
||||
commands:
|
||||
- echo "Build complete, webhook will trigger deployment"
|
||||
when:
|
||||
branch: main
|
||||
status: success
|
||||
**Option D: Create + build in one call**
|
||||
```bash
|
||||
curl -X POST "https://rdev.masq-ops.orchard9.ai/project/create-and-build" \
|
||||
-H "Authorization: Bearer $RDEV_KEY" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"name": "landing",
|
||||
"description": "threesix.ai landing page",
|
||||
"template": "astro-landing",
|
||||
"build": {
|
||||
"prompt": "Customize with Coming Soon message and threesix.ai branding",
|
||||
"auto_commit": true,
|
||||
"auto_push": true
|
||||
}
|
||||
}'
|
||||
```
|
||||
|
||||
### Dockerfile (Astro + Nginx)
|
||||
### Step 3: Monitor Build
|
||||
|
||||
```dockerfile
|
||||
# Build stage
|
||||
FROM node:20-alpine AS builder
|
||||
WORKDIR /app
|
||||
COPY package*.json ./
|
||||
RUN npm ci
|
||||
COPY . .
|
||||
RUN npm run build
|
||||
The git push triggers Woodpecker CI automatically.
|
||||
|
||||
# Production stage
|
||||
FROM nginx:alpine
|
||||
COPY --from=builder /app/dist /usr/share/nginx/html
|
||||
COPY nginx.conf /etc/nginx/conf.d/default.conf
|
||||
EXPOSE 80
|
||||
CMD ["nginx", "-g", "daemon off;"]
|
||||
**Via rdev-api (recommended):**
|
||||
```bash
|
||||
# List recent pipelines
|
||||
curl -s "https://rdev.masq-ops.orchard9.ai/projects/landing/pipelines" \
|
||||
-H "Authorization: Bearer $RDEV_KEY" | jq '.data.pipelines'
|
||||
|
||||
# Get specific pipeline
|
||||
curl -s "https://rdev.masq-ops.orchard9.ai/projects/landing/pipelines/1" \
|
||||
-H "Authorization: Bearer $RDEV_KEY" | jq '.data'
|
||||
```
|
||||
|
||||
### nginx.conf
|
||||
**Via Woodpecker UI:**
|
||||
- https://ci.threesix.ai/threesix/landing
|
||||
|
||||
```nginx
|
||||
server {
|
||||
listen 80;
|
||||
server_name _;
|
||||
root /usr/share/nginx/html;
|
||||
index index.html;
|
||||
### Step 4: Verify Deployment
|
||||
|
||||
location / {
|
||||
try_files $uri $uri/ /index.html;
|
||||
}
|
||||
```bash
|
||||
# Check project status via rdev-api
|
||||
curl -s "https://rdev.masq-ops.orchard9.ai/project/landing" \
|
||||
-H "Authorization: Bearer $RDEV_KEY" | jq '.data.deployment'
|
||||
|
||||
# Cache static assets
|
||||
location ~* \.(js|css|png|jpg|jpeg|gif|ico|svg|woff|woff2)$ {
|
||||
expires 1y;
|
||||
add_header Cache-Control "public, immutable";
|
||||
}
|
||||
}
|
||||
# Check deployment via K8s
|
||||
export KUBECONFIG=~/.kube/orchard9-k3sf.yaml
|
||||
kubectl get deploy -n projects landing
|
||||
|
||||
# Check the site
|
||||
curl -I https://landing.threesix.ai
|
||||
```
|
||||
|
||||
### Step 5: Configure DNS Aliases (Optional)
|
||||
|
||||
Point `www.threesix.ai` and `threesix.ai` to the landing page via the rdev-api domain alias endpoints.
|
||||
|
||||
```bash
|
||||
# Add www.threesix.ai as a CNAME alias
|
||||
curl -X POST "https://rdev.masq-ops.orchard9.ai/projects/landing/domains" \
|
||||
-H "Authorization: Bearer $RDEV_KEY" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"domain": "www.threesix.ai", "type": "CNAME"}'
|
||||
|
||||
# Add root A record
|
||||
curl -X POST "https://rdev.masq-ops.orchard9.ai/projects/landing/domains" \
|
||||
-H "Authorization: Bearer $RDEV_KEY" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"domain": "threesix.ai"}'
|
||||
|
||||
# List all domains for the project
|
||||
curl -s "https://rdev.masq-ops.orchard9.ai/projects/landing/domains" \
|
||||
-H "Authorization: Bearer $RDEV_KEY" | jq '.data.domains'
|
||||
|
||||
# Remove an alias
|
||||
curl -X DELETE "https://rdev.masq-ops.orchard9.ai/projects/landing/domains/www.threesix.ai" \
|
||||
-H "Authorization: Bearer $RDEV_KEY"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Current Gaps & Future Automation
|
||||
## Template: astro-landing
|
||||
|
||||
### What's Manual Today
|
||||
The `astro-landing` template (`deployments/k8s/base/templates/astro-landing/`) seeds the repo with:
|
||||
|
||||
| Step | Status | Automation Path |
|
||||
|------|--------|-----------------|
|
||||
| Create project | ✅ API | Already automated |
|
||||
| Activate Woodpecker | 🔧 API call needed | Add to rdev-api |
|
||||
| Generate code | ❌ Manual Claude | Claudebox integration |
|
||||
| Push to Gitea | ❌ Manual git | Claudebox with SSH key |
|
||||
| Deploy | ✅ Webhook | Already automated |
|
||||
| File | Purpose |
|
||||
|------|---------|
|
||||
| `.woodpecker.yml` | CI pipeline: npm build, Docker build, push, deploy |
|
||||
| `.claude/CLAUDE.md` | Project instructions for Claude Code |
|
||||
| `Dockerfile` | Multi-stage build (Node 20 build, nginx serve) |
|
||||
| `nginx.conf` | Production config with gzip, caching, SPA fallback |
|
||||
| `package.json` | Astro 4.0+ with Tailwind CSS |
|
||||
| `astro.config.mjs` | Astro configuration |
|
||||
| `tailwind.config.mjs` | Tailwind configuration |
|
||||
| `src/pages/index.astro` | Landing page (dark theme with gradient) |
|
||||
| `src/layouts/Layout.astro` | Base HTML layout |
|
||||
| `README.md` | Development and deployment docs |
|
||||
|
||||
### To Fully Automate (Future Work)
|
||||
Variables substituted during seeding: `{{PROJECT_NAME}}`, `{{DOMAIN}}`, `{{GIT_URL}}`
|
||||
|
||||
1. **Add Woodpecker activation to rdev-api**
|
||||
- Store WOODPECKER_API_TOKEN in secrets
|
||||
- Call Woodpecker API after creating Gitea repo
|
||||
- Create webhook automatically
|
||||
---
|
||||
|
||||
2. **Claudebox code generation**
|
||||
- Spawn claudebox with project context
|
||||
- Claudebox has Gitea SSH key
|
||||
- Claude Code generates code based on prompt
|
||||
- Auto-push to Gitea
|
||||
## Implementation Status
|
||||
|
||||
3. **Single API call**
|
||||
```
|
||||
POST /project/create-and-build
|
||||
{
|
||||
"name": "www",
|
||||
"prompt": "Create an Astro landing page with coming soon message",
|
||||
"stack": "astro"
|
||||
}
|
||||
```
|
||||
All components for the full landing page pipeline are implemented:
|
||||
|
||||
| Component | Location | Status |
|
||||
|-----------|----------|--------|
|
||||
| Work queue (enqueue/dequeue) | `internal/adapter/postgres/work_queue.go` | Implemented |
|
||||
| Worker registry | `internal/adapter/postgres/worker_registry.go` | Implemented |
|
||||
| Build audit tracking | `internal/adapter/postgres/build_audit.go` | Implemented |
|
||||
| Build service | `internal/service/build_service.go` | Implemented |
|
||||
| Worker service | `internal/service/worker_service.go` | Implemented |
|
||||
| Work handlers (REST) | `internal/handlers/work.go` | Implemented |
|
||||
| Code agent interface | `internal/port/code_agent.go` | Implemented |
|
||||
| Worker executor daemon | `internal/worker/work_executor.go` | Implemented |
|
||||
| BuildSpec-to-agent bridge | `internal/worker/build_executor.go` | Implemented |
|
||||
| Git credential resolution | `internal/service/credential_service.go` | Implemented |
|
||||
| DNS alias endpoints | `internal/handlers/infrastructure_domains.go` | Implemented |
|
||||
| CI pipeline proxy | `internal/handlers/infrastructure_pipelines.go` | Implemented |
|
||||
| Create-and-build endpoint | `internal/handlers/create_and_build.go` | Implemented |
|
||||
|
||||
---
|
||||
|
||||
@ -312,13 +300,13 @@ After deployment, verify:
|
||||
|
||||
```bash
|
||||
# Check DNS
|
||||
dig www.threesix.ai
|
||||
dig landing.threesix.ai
|
||||
|
||||
# Check site
|
||||
curl -I https://www.threesix.ai
|
||||
curl -I https://landing.threesix.ai
|
||||
|
||||
# Check deployment status
|
||||
curl https://rdev.masq-ops.orchard9.ai/project/www \
|
||||
curl https://rdev.masq-ops.orchard9.ai/project/landing \
|
||||
-H "Authorization: Bearer $RDEV_KEY"
|
||||
```
|
||||
|
||||
@ -329,8 +317,8 @@ curl https://rdev.masq-ops.orchard9.ai/project/www \
|
||||
To remove the landing page:
|
||||
|
||||
```bash
|
||||
# Delete via rdev-api (removes Gitea repo, DNS, K8s deployment)
|
||||
curl -X DELETE https://rdev.masq-ops.orchard9.ai/project/www \
|
||||
# Delete via rdev-api (removes DNS, K8s deployment; Gitea repo preserved for safety)
|
||||
curl -X DELETE https://rdev.masq-ops.orchard9.ai/project/landing \
|
||||
-H "Authorization: Bearer $RDEV_KEY"
|
||||
```
|
||||
|
||||
@ -338,5 +326,7 @@ curl -X DELETE https://rdev.masq-ops.orchard9.ai/project/www \
|
||||
|
||||
## Related
|
||||
|
||||
- [THREESIX_INFRASTRUCTURE.md](/Users/jordanwashburn/Workspace/orchard9/rdev/docs/plans/THREESIX_INFRASTRUCTURE.md) - Infrastructure plan
|
||||
- [woodpecker-pipeline-template.yaml](../deployments/k8s/base/threesix/woodpecker-pipeline-template.yaml) - CI template
|
||||
- [Build Orchestration](../ai-lookup/features/build-orchestration.md) - Build system documentation
|
||||
- [Worker Pool](../ai-lookup/services/worker-pool.md) - Worker pool management
|
||||
- [Work Queue](../.claude/guides/services/work-queue.md) - Work queue guide
|
||||
- [Templates](../.claude/guides/services/templates.md) - Project template guide
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
# Multi-Provider Code Agent Interface
|
||||
|
||||
> **Status:** In Progress (Weeks 1-4 Complete)
|
||||
> **Status:** Complete (Weeks 1-5)
|
||||
> **Feature:** Unified interface supporting Claude Code and OpenCode providers
|
||||
|
||||
## Overview
|
||||
@ -15,7 +15,7 @@ This document describes the architecture for supporting multiple code agent prov
|
||||
| Week 2: Claude Code Adapter | ✅ Complete | kubectl exec wrapper, stream-json parser |
|
||||
| Week 3: OpenCode Adapter | ✅ Complete | HTTP/SSE client, session management |
|
||||
| Week 4: Service Integration | ✅ Complete | ProjectService integration, event streaming |
|
||||
| Week 5: Polish | ⬜ Pending | Model selection API, health monitoring, metrics, docs |
|
||||
| Week 5: Polish | ✅ Complete | Agent HTTP endpoints, health monitoring, metrics, DI wiring |
|
||||
|
||||
## Architecture
|
||||
|
||||
@ -448,7 +448,78 @@ func (s *ProjectService) GetDefaultAgent() domain.AgentProvider
|
||||
func (s *ProjectService) SetDefaultAgent(provider domain.AgentProvider) error
|
||||
```
|
||||
|
||||
## API Changes (⬜ Pending - Week 5)
|
||||
## API Changes (✅ Complete - Week 5)
|
||||
|
||||
### Agent Management Endpoints
|
||||
|
||||
```http
|
||||
# List all registered agents
|
||||
GET /agents
|
||||
|
||||
{
|
||||
"data": {
|
||||
"agents": [
|
||||
{
|
||||
"provider": "claudecode",
|
||||
"name": "Claude Code",
|
||||
"available": true,
|
||||
"default": true,
|
||||
"supported_models": ["claude-sonnet-4-20250514"],
|
||||
"default_model": "claude-sonnet-4-20250514"
|
||||
}
|
||||
],
|
||||
"default_agent": "claudecode",
|
||||
"total_agents": 1,
|
||||
"available_count": 1
|
||||
}
|
||||
}
|
||||
|
||||
# Get agent capabilities
|
||||
GET /agents/{provider}
|
||||
|
||||
{
|
||||
"data": {
|
||||
"provider": "claudecode",
|
||||
"supports_session_continuation": true,
|
||||
"supports_model_selection": false,
|
||||
"supports_tool_control": true,
|
||||
"supports_streaming": true,
|
||||
"supported_models": ["claude-sonnet-4-20250514"],
|
||||
"default_model": "claude-sonnet-4-20250514",
|
||||
"max_prompt_length": 100000
|
||||
}
|
||||
}
|
||||
|
||||
# Set default agent
|
||||
POST /agents/default
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"provider": "opencode"
|
||||
}
|
||||
|
||||
# Agent health status
|
||||
GET /agents/health
|
||||
|
||||
{
|
||||
"data": {
|
||||
"agents": [
|
||||
{
|
||||
"provider": "claudecode",
|
||||
"name": "Claude Code",
|
||||
"healthy": true,
|
||||
"message": "available",
|
||||
"latency": "1.234ms",
|
||||
"checked_at": "2025-01-27T10:00:00Z"
|
||||
}
|
||||
],
|
||||
"healthy_count": 1,
|
||||
"total_count": 1,
|
||||
"default_agent": "claudecode",
|
||||
"default_healthy": true
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Project Response
|
||||
|
||||
@ -531,35 +602,56 @@ internal/
|
||||
│ ├── adapter.go ✅ CodeAgent implementation
|
||||
│ ├── client.go ✅ HTTP/SSE client
|
||||
│ └── adapter_test.go ✅ Mock server tests
|
||||
├── handlers/
|
||||
│ ├── agents.go ✅ Week 5: Agent management endpoints
|
||||
│ └── agents_test.go ✅ Week 5: Handler tests
|
||||
├── service/
|
||||
│ ├── project_service.go ✅ Week 4: Agent registry integration
|
||||
│ ├── project_service_agent.go ✅ Week 4: Agent execution methods
|
||||
│ ├── project_service_agent.go ✅ Week 4: Agent execution methods + metrics
|
||||
│ ├── project_service_commands.go ✅ Extracted shell/git commands
|
||||
│ └── project_service_queue.go ✅ Extracted queue operations
|
||||
├── metrics/
|
||||
│ └── metrics.go ✅ Week 5: Agent metrics (requests, tool use, availability)
|
||||
└── worker/
|
||||
└── queue_processor.go ⬜ Week 5: Use CodeAgent for queue
|
||||
└── queue_processor.go ⬜ Future: Use CodeAgent for queue
|
||||
cmd/
|
||||
└── rdev-api/
|
||||
├── main.go ✅ Week 5: Agent registry DI wiring
|
||||
└── openapi.go ✅ Week 5: Agent API documentation
|
||||
```
|
||||
|
||||
## Observability (⬜ Pending - Week 5)
|
||||
## Observability (✅ Complete - Week 5)
|
||||
|
||||
### Metrics
|
||||
### Prometheus Metrics
|
||||
|
||||
| Metric | Labels | Description |
|
||||
|--------|--------|-------------|
|
||||
| `code_agent_requests_total` | provider, project, status | Total requests |
|
||||
| `code_agent_duration_seconds` | provider, project | Execution duration |
|
||||
| `code_agent_events_total` | provider, event_type | Streaming events |
|
||||
| `rdev_agent_requests_total` | provider, status | Total code agent requests |
|
||||
| `rdev_agent_request_duration_seconds` | provider | Execution duration histogram |
|
||||
| `rdev_agent_tool_use_total` | provider, tool | Tool invocations by agents |
|
||||
| `rdev_agent_available` | provider | Availability gauge (1=available, 0=unavailable) |
|
||||
|
||||
### Health Check
|
||||
|
||||
```http
|
||||
GET /health
|
||||
GET /agents/health
|
||||
|
||||
{
|
||||
"status": "healthy",
|
||||
"agents": {
|
||||
"claudecode": "available",
|
||||
"opencode": "unavailable"
|
||||
"data": {
|
||||
"agents": [
|
||||
{
|
||||
"provider": "claudecode",
|
||||
"name": "Claude Code",
|
||||
"healthy": true,
|
||||
"message": "available",
|
||||
"latency": "1.234ms",
|
||||
"checked_at": "2025-01-27T10:00:00Z"
|
||||
}
|
||||
],
|
||||
"healthy_count": 1,
|
||||
"total_count": 1,
|
||||
"default_agent": "claudecode",
|
||||
"default_healthy": true
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
320
docs/plans/worker-executor-breakdown.md
Normal file
320
docs/plans/worker-executor-breakdown.md
Normal file
@ -0,0 +1,320 @@
|
||||
# Worker Executor Implementation Plan
|
||||
|
||||
> Close the last gap in the landing page cookbook: automated code generation via the worker pool.
|
||||
|
||||
## Context
|
||||
|
||||
The work queue, worker registry, build audit, and code agent systems are **all implemented**. The single missing piece is a **work executor** — a background loop that consumes queued tasks and executes them via a code agent. This is analogous to the existing `QueueProcessor` (which processes per-project command queue tasks), but for the generic `WorkQueue` (cross-project worker pool tasks).
|
||||
|
||||
### What Already Exists
|
||||
|
||||
| Component | File | Status |
|
||||
|-----------|------|--------|
|
||||
| Work queue (PostgreSQL) | `internal/adapter/postgres/work_queue.go` | Done |
|
||||
| Worker registry (PostgreSQL) | `internal/adapter/postgres/worker_registry.go` | Done |
|
||||
| Build audit (PostgreSQL) | `internal/adapter/postgres/build_audit.go` | Done |
|
||||
| WorkService (enqueue/dequeue/complete/fail) | `internal/service/work_service.go` | Done |
|
||||
| WorkerService (claim/complete/health) | `internal/service/worker_service.go` | Done |
|
||||
| BuildService (start/status/complete) | `internal/service/build_service.go` | Done |
|
||||
| WorkHandler (REST API) | `internal/handlers/work.go` | Done |
|
||||
| AgentsHandler (REST API) | `internal/handlers/agents.go` | Done |
|
||||
| CodeAgent interface | `internal/port/code_agent.go` | Done |
|
||||
| Domain models (WorkTask, Worker, BuildSpec) | `internal/domain/` | Done |
|
||||
| Command QueueProcessor (reference pattern) | `internal/worker/queue_processor.go` | Done |
|
||||
|
||||
### What's Missing
|
||||
|
||||
| Gap | Priority |
|
||||
|-----|----------|
|
||||
| Work executor daemon (poll loop) | Critical |
|
||||
| BuildSpec → AgentRequest translation | Critical |
|
||||
| Git clone/commit/push in executor | Critical |
|
||||
| Git credential resolution for cross-project | High |
|
||||
| Worker management REST endpoints | Medium |
|
||||
| DNS alias endpoint | Medium |
|
||||
| Create-and-build endpoint | Medium |
|
||||
| Woodpecker build status proxy | Low |
|
||||
|
||||
---
|
||||
|
||||
## Week 1: Work Executor Core
|
||||
|
||||
**Goal:** A background loop that claims tasks from the work queue and executes them via a code agent. By end of week, `POST /work/enqueue` → task claimed → agent executes → result recorded.
|
||||
|
||||
### Tasks
|
||||
|
||||
1. **Create `internal/worker/work_executor.go`**
|
||||
- Follow the `QueueProcessor` pattern from `queue_processor.go`
|
||||
- Poll loop: calls `WorkerService.ClaimTask(workerID)` on a ticker
|
||||
- On task claim: route to appropriate handler based on `task.Type`
|
||||
- On completion: call `WorkerService.CompleteTask(workerID, taskID, result)`
|
||||
- On failure: call `WorkService.FailTask(taskID, errMsg)` (handles retry logic)
|
||||
- Graceful shutdown via context cancellation
|
||||
- Self-registers as a worker via `WorkerService.Register()` on start
|
||||
- Sends heartbeats via `WorkerService.Heartbeat()` on a 30s ticker
|
||||
|
||||
2. **Create `internal/worker/build_executor.go`**
|
||||
- Handles `WorkTaskTypeBuild` tasks specifically
|
||||
- Extracts `BuildSpec` fields from `WorkTask.Spec` (map[string]any → typed fields)
|
||||
- Translates `BuildSpec.Prompt` into `domain.AgentRequest`
|
||||
- Calls `CodeAgent.Execute()` with event streaming
|
||||
- Collects output, files changed, duration into `domain.BuildResult`
|
||||
- Returns `BuildResult` to the work executor
|
||||
|
||||
3. **Wire into `cmd/rdev-api/main.go`**
|
||||
- Create `WorkExecutor` alongside existing `QueueProcessor`
|
||||
- Inject: `WorkerService`, `BuildService`, `CodeAgentRegistry`
|
||||
- Start on boot, stop on shutdown
|
||||
- Worker ID: hostname or pod name (from `HOSTNAME` env var)
|
||||
|
||||
4. **Create `internal/worker/work_executor_test.go`**
|
||||
- Test: executor starts and registers as a worker
|
||||
- Test: executor claims a task and routes to build handler
|
||||
- Test: build handler translates spec and calls code agent
|
||||
- Test: results are recorded via CompleteTask
|
||||
- Test: failures trigger FailTask with retry
|
||||
- Test: graceful shutdown stops the poll loop
|
||||
- Use mock implementations of ports
|
||||
|
||||
### Deliverables
|
||||
|
||||
- `POST /work/enqueue` with a build task → executor picks it up → agent runs → result in `GET /work/{taskId}`
|
||||
- Worker visible in registry during execution
|
||||
- Build audit entry created with spec and result
|
||||
|
||||
### Files Created/Modified
|
||||
|
||||
| File | Action |
|
||||
|------|--------|
|
||||
| `internal/worker/work_executor.go` | Create |
|
||||
| `internal/worker/build_executor.go` | Create |
|
||||
| `internal/worker/work_executor_test.go` | Create |
|
||||
| `cmd/rdev-api/main.go` | Modify (wire executor) |
|
||||
|
||||
---
|
||||
|
||||
## Week 2: Git Operations & Cross-Project Execution
|
||||
|
||||
**Goal:** The executor can clone any project's repo, run the agent in that directory, and push results back. By end of week, the full build cycle works: enqueue → clone → agent generates code → commit → push → CI triggers.
|
||||
|
||||
### Tasks
|
||||
|
||||
1. **Create `internal/worker/git_operations.go`**
|
||||
- `CloneRepo(ctx, gitURL, dir, token) error` — clone via HTTPS with token auth
|
||||
- `CommitAndPush(ctx, dir, message) (commitSHA string, filesChanged []string, err error)`
|
||||
- `ConfigureGit(dir, name, email)` — set git user for commits
|
||||
- Uses `os/exec` for git commands (same pattern as `kubernetes.Executor` uses for kubectl)
|
||||
- Workspace management: creates temp dir per task, cleans up after
|
||||
|
||||
2. **Add git credential resolution to `BuildExecutor`**
|
||||
- Option A (simplest): Use the Gitea token already in `InfraConfig.GiteaToken`
|
||||
- All project repos are in Gitea, so one token covers all repos
|
||||
- Pass token via HTTPS clone URL: `https://token@git.threesix.ai/org/repo.git`
|
||||
- Option B (per-project): Look up project's git URL from database, resolve credentials
|
||||
- **Recommendation:** Option A — the Gitea token is already loaded and available
|
||||
|
||||
3. **Integrate git ops into `BuildExecutor`**
|
||||
- Before agent execution: clone the project's repo to a temp directory
|
||||
- Look up project git URL from database (add `ProjectStore` port or query directly)
|
||||
- After agent execution: if `auto_commit` is true, commit changes
|
||||
- After commit: if `auto_push` is true, push to remote
|
||||
- Capture `commit_sha` and `files_changed` in `BuildResult`
|
||||
|
||||
4. **Add project git URL lookup**
|
||||
- The `ProjectInfraService` stores git URLs in the database during `CreateProject`
|
||||
- Add a method to retrieve git info by project ID
|
||||
- Or: include `git_url` in the `WorkTask.Spec` at enqueue time (simpler, no extra lookup)
|
||||
|
||||
5. **Create `internal/worker/git_operations_test.go`**
|
||||
- Test: clone with token auth
|
||||
- Test: commit and push
|
||||
- Test: workspace cleanup on success and failure
|
||||
- Test: git URL construction with token
|
||||
|
||||
6. **Integration test**
|
||||
- Enqueue a build task with a real prompt
|
||||
- Verify agent executes in cloned repo
|
||||
- Verify commit is created (if auto_commit)
|
||||
- Verify push succeeds (if auto_push)
|
||||
- Verify BuildResult has correct fields
|
||||
|
||||
### Deliverables
|
||||
|
||||
- Full build cycle: enqueue → clone → execute → commit → push
|
||||
- Git credentials resolved from infrastructure config
|
||||
- Temp workspace created and cleaned per task
|
||||
- Build audit shows commit SHA and files changed
|
||||
|
||||
### Files Created/Modified
|
||||
|
||||
| File | Action |
|
||||
|------|--------|
|
||||
| `internal/worker/git_operations.go` | Create |
|
||||
| `internal/worker/git_operations_test.go` | Create |
|
||||
| `internal/worker/build_executor.go` | Modify (add git integration) |
|
||||
| `internal/worker/work_executor.go` | Modify (pass git config) |
|
||||
| `cmd/rdev-api/main.go` | Modify (pass gitea token to executor) |
|
||||
|
||||
---
|
||||
|
||||
## Week 3: API Enhancements
|
||||
|
||||
**Goal:** Add the REST endpoints that complete the platform experience. By end of week, users can create a project, enqueue a build, monitor CI status, and manage DNS — all through rdev-api.
|
||||
|
||||
### Tasks
|
||||
|
||||
1. **Worker management endpoints — `internal/handlers/workers.go`**
|
||||
- `GET /workers` — list all workers with status
|
||||
- `GET /workers/{id}` — get worker details
|
||||
- `POST /workers/{id}/drain` — drain a worker
|
||||
- Wire `WorkerService` into handler
|
||||
- Register in `cmd/rdev-api/main.go` and `openapi.go`
|
||||
|
||||
2. **Build management endpoints — `internal/handlers/builds.go`**
|
||||
- `POST /projects/{id}/builds` — enqueue a build (wraps `BuildService.StartBuild()`)
|
||||
- `GET /projects/{id}/builds` — list build history
|
||||
- `GET /projects/{id}/builds/{taskId}` — get build status
|
||||
- Simpler API than raw `/work/enqueue` — project-scoped, build-specific
|
||||
- Register in `cmd/rdev-api/main.go` and `openapi.go`
|
||||
|
||||
3. **DNS alias endpoint — `internal/handlers/infrastructure.go`**
|
||||
- `POST /projects/{id}/domains` — add DNS alias (A or CNAME record)
|
||||
- `GET /projects/{id}/domains` — list domains for project
|
||||
- `DELETE /projects/{id}/domains/{domain}` — remove alias
|
||||
- Uses existing Cloudflare adapter's `CreateRecord()` and `DeleteRecordByName()`
|
||||
- The adapter already supports full CRUD — just needs a handler
|
||||
|
||||
4. **Woodpecker build status proxy — `internal/handlers/ci.go`**
|
||||
- `GET /projects/{id}/ci/pipelines` — list recent Woodpecker pipelines
|
||||
- `GET /projects/{id}/ci/pipelines/{number}` — get pipeline details
|
||||
- Add `ListPipelines()` and `GetPipeline()` to `port.CIProvider`
|
||||
- Implement in `internal/adapter/woodpecker/client.go` using Woodpecker SDK
|
||||
- Low priority — can defer if time is tight
|
||||
|
||||
5. **Create-and-build endpoint — `internal/handlers/project_management.go`**
|
||||
- `POST /project/create-and-build`
|
||||
- Request: `{ name, description, template, prompt, auto_push }`
|
||||
- Calls `ProjectInfraService.CreateProject()` then `BuildService.StartBuild()`
|
||||
- Returns project info + task ID
|
||||
- Trivial once executor is working
|
||||
|
||||
6. **Tests for all new handlers**
|
||||
- Follow existing patterns in `handlers/*_test.go`
|
||||
- Test request validation, success paths, error handling
|
||||
|
||||
### Deliverables
|
||||
|
||||
- `POST /projects/{id}/builds` as the clean API for code generation
|
||||
- `GET /workers` for monitoring the worker pool
|
||||
- `POST /projects/{id}/domains` for DNS aliases
|
||||
- `POST /project/create-and-build` for the single-call flow
|
||||
- All endpoints documented in `openapi.go`
|
||||
|
||||
### Files Created/Modified
|
||||
|
||||
| File | Action |
|
||||
|------|--------|
|
||||
| `internal/handlers/workers.go` | Create |
|
||||
| `internal/handlers/workers_test.go` | Create |
|
||||
| `internal/handlers/builds.go` | Create |
|
||||
| `internal/handlers/builds_test.go` | Create |
|
||||
| `internal/handlers/infrastructure.go` | Modify (add domain endpoints) |
|
||||
| `internal/handlers/ci.go` | Create (if time) |
|
||||
| `internal/handlers/project_management.go` | Modify (add create-and-build) |
|
||||
| `internal/adapter/woodpecker/client.go` | Modify (add pipeline methods, if time) |
|
||||
| `internal/port/ci.go` or port updates | Modify (add pipeline interface, if time) |
|
||||
| `cmd/rdev-api/main.go` | Modify (wire new handlers) |
|
||||
| `cmd/rdev-api/openapi.go` | Modify (add routes to spec) |
|
||||
|
||||
---
|
||||
|
||||
## Week 4: Polish, Validation & Observability
|
||||
|
||||
**Goal:** End-to-end validation of the cookbook flow. Observability for production operation. Documentation updated.
|
||||
|
||||
### Tasks
|
||||
|
||||
1. **End-to-end cookbook validation**
|
||||
- Run the landing page cookbook flow from start to finish
|
||||
- `POST /project` with `astro-landing` template
|
||||
- `POST /projects/landing/builds` with customization prompt
|
||||
- Monitor via `GET /work/{taskId}/status`
|
||||
- Verify CI triggers on push
|
||||
- Verify site is live at `https://landing.threesix.ai`
|
||||
- Fix any issues found during validation
|
||||
|
||||
2. **Stale task recovery**
|
||||
- Add periodic `RequeueStale()` call to the work executor
|
||||
- Requeue tasks where the worker crashed mid-execution
|
||||
- Add periodic `CleanupOld()` call to remove ancient completed tasks
|
||||
- These methods exist on `WorkQueue` but nothing calls them
|
||||
|
||||
3. **Observability additions**
|
||||
- Add metrics to work executor: tasks_claimed, tasks_completed, tasks_failed, execution_duration
|
||||
- Add metrics to worker service: workers_registered, workers_idle, workers_busy
|
||||
- Follow existing pattern in `internal/metrics/metrics.go`
|
||||
- Add work executor health to readiness check (`GET /ready`)
|
||||
|
||||
4. **Queue maintenance worker**
|
||||
- Create `internal/worker/queue_maintenance.go`
|
||||
- Runs on a slower ticker (every 5 minutes)
|
||||
- Calls `RequeueStale(ctx, 10*time.Minute)` — requeue tasks running > 10min with no heartbeat
|
||||
- Calls `CleanupOld(ctx, 7*24*time.Hour)` — prune tasks older than 7 days
|
||||
- Wire into main.go
|
||||
|
||||
5. **Update documentation**
|
||||
- Update `cookbooks/landing-page.md` with final validated flow
|
||||
- Update `ai-lookup/features/build-orchestration.md`
|
||||
- Update `ai-lookup/services/worker-pool.md`
|
||||
- Add `.claude/guides/services/build-orchestration.md` if needed
|
||||
|
||||
6. **Update CLAUDE.md roadmap**
|
||||
- Mark "Work Queue" as implemented
|
||||
- Mark "Worker Pool" as implemented
|
||||
- Mark "Build Orchestration" as implemented
|
||||
- Update "Bot Communication" status
|
||||
|
||||
### Deliverables
|
||||
|
||||
- Cookbook flow works end-to-end without manual intervention (except code generation prompt)
|
||||
- Stale task recovery running in production
|
||||
- Metrics visible in `/metrics` endpoint
|
||||
- All documentation reflects actual capabilities
|
||||
|
||||
### Files Created/Modified
|
||||
|
||||
| File | Action |
|
||||
|------|--------|
|
||||
| `internal/worker/queue_maintenance.go` | Create |
|
||||
| `internal/metrics/metrics.go` | Modify (add work executor metrics) |
|
||||
| `internal/handlers/health.go` | Modify (add executor health) |
|
||||
| `cookbooks/landing-page.md` | Modify (final validation) |
|
||||
| `ai-lookup/features/build-orchestration.md` | Modify |
|
||||
| `ai-lookup/services/worker-pool.md` | Modify |
|
||||
| `CLAUDE.md` | Modify (update roadmap) |
|
||||
| `cmd/rdev-api/main.go` | Modify (wire maintenance worker) |
|
||||
|
||||
---
|
||||
|
||||
## Risk & Dependencies
|
||||
|
||||
| Risk | Mitigation |
|
||||
|------|-----------|
|
||||
| CodeAgent execution in a temp directory (not a K8s pod) may not work the same as in-pod execution | Test early in Week 1; fallback is to kubectl exec into a worker pod |
|
||||
| Gitea token may lack permissions for new repos created by different users | Test with actual token; all repos should be in the same org |
|
||||
| Agent execution may take longer than expected (10+ minutes for complex prompts) | Make timeout configurable; increase default |
|
||||
| Worker process crash loses in-flight task | Stale requeue (Week 4) handles this automatically |
|
||||
| 500-line file limit may require splitting new files | Plan for split from the start; `work_executor.go` + `build_executor.go` + `git_operations.go` keeps things modular |
|
||||
|
||||
## Architecture Decision: In-Process vs External Worker
|
||||
|
||||
The plan above implements the executor **in-process** (running inside the rdev-api binary). This is simpler and matches the existing `QueueProcessor` pattern. The alternative would be a separate worker binary, which would allow independent scaling. The in-process approach is the right starting point — it can be extracted into a separate binary later if scaling requires it.
|
||||
|
||||
## Summary
|
||||
|
||||
| Week | Focus | Key Deliverable |
|
||||
|------|-------|----------------|
|
||||
| 1 | Work executor core | Tasks flow from queue → agent → result |
|
||||
| 2 | Git operations | Clone → execute → commit → push cycle |
|
||||
| 3 | API enhancements | Build, worker, DNS, create-and-build endpoints |
|
||||
| 4 | Polish & validation | E2E cookbook flow, observability, docs |
|
||||
79
internal/adapter/postgres/apikey_helpers_test.go
Normal file
79
internal/adapter/postgres/apikey_helpers_test.go
Normal file
@ -0,0 +1,79 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/orchard9/rdev/internal/domain"
|
||||
)
|
||||
|
||||
// Helper function conversion tests
|
||||
|
||||
func TestScopesToStrings(t *testing.T) {
|
||||
scopes := []domain.Scope{domain.ScopeProjectsRead, domain.ScopeAdmin}
|
||||
strings := scopesToStrings(scopes)
|
||||
|
||||
if len(strings) != 2 {
|
||||
t.Fatalf("Length = %d, want 2", len(strings))
|
||||
}
|
||||
if strings[0] != "projects:read" {
|
||||
t.Errorf("strings[0] = %q, want %q", strings[0], "projects:read")
|
||||
}
|
||||
if strings[1] != "admin" {
|
||||
t.Errorf("strings[1] = %q, want %q", strings[1], "admin")
|
||||
}
|
||||
}
|
||||
|
||||
func TestScopesFromStrings(t *testing.T) {
|
||||
strings := []string{"projects:read", "keys:write"}
|
||||
scopes := scopesFromStrings(strings)
|
||||
|
||||
if len(scopes) != 2 {
|
||||
t.Fatalf("Length = %d, want 2", len(scopes))
|
||||
}
|
||||
if scopes[0] != domain.ScopeProjectsRead {
|
||||
t.Errorf("scopes[0] = %q, want %q", scopes[0], domain.ScopeProjectsRead)
|
||||
}
|
||||
if scopes[1] != domain.ScopeKeysWrite {
|
||||
t.Errorf("scopes[1] = %q, want %q", scopes[1], domain.ScopeKeysWrite)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProjectIDsToStrings(t *testing.T) {
|
||||
t.Run("nil input", func(t *testing.T) {
|
||||
result := projectIDsToStrings(nil)
|
||||
if result != nil {
|
||||
t.Errorf("Expected nil, got %v", result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("non-nil input", func(t *testing.T) {
|
||||
ids := []domain.ProjectID{"proj-a", "proj-b"}
|
||||
result := projectIDsToStrings(ids)
|
||||
if len(result) != 2 {
|
||||
t.Fatalf("Length = %d, want 2", len(result))
|
||||
}
|
||||
if result[0] != "proj-a" || result[1] != "proj-b" {
|
||||
t.Errorf("Unexpected result: %v", result)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestProjectIDsFromStrings(t *testing.T) {
|
||||
t.Run("nil input", func(t *testing.T) {
|
||||
result := projectIDsFromStrings(nil)
|
||||
if result != nil {
|
||||
t.Errorf("Expected nil, got %v", result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("non-nil input", func(t *testing.T) {
|
||||
strings := []string{"proj-x", "proj-y"}
|
||||
result := projectIDsFromStrings(strings)
|
||||
if len(result) != 2 {
|
||||
t.Fatalf("Length = %d, want 2", len(result))
|
||||
}
|
||||
if result[0] != "proj-x" || result[1] != "proj-y" {
|
||||
t.Errorf("Unexpected result: %v", result)
|
||||
}
|
||||
})
|
||||
}
|
||||
@ -28,7 +28,7 @@ func TestAPIKeyRepository_Create(t *testing.T) {
|
||||
key := &domain.APIKey{
|
||||
Name: "test-repo-create",
|
||||
KeyPrefix: "abc12345",
|
||||
Scopes: []domain.Scope{domain.ScopeProjectsRead, domain.ScopeKeysManage},
|
||||
Scopes: []domain.Scope{domain.ScopeProjectsRead, domain.ScopeKeysWrite},
|
||||
ProjectIDs: []domain.ProjectID{"proj-a", "proj-b"},
|
||||
AllowedIPs: []string{"192.168.1.0/24", "10.0.0.1"},
|
||||
ExpiresAt: &expires,
|
||||
@ -302,9 +302,9 @@ func TestAPIKeyRepository_ScopeArrayHandling(t *testing.T) {
|
||||
scopes []domain.Scope
|
||||
}{
|
||||
{"single scope", []domain.Scope{domain.ScopeProjectsRead}},
|
||||
{"multiple scopes", []domain.Scope{domain.ScopeProjectsRead, domain.ScopeProjectsExecute, domain.ScopeKeysManage}},
|
||||
{"multiple scopes", []domain.Scope{domain.ScopeProjectsRead, domain.ScopeProjectsExecute, domain.ScopeKeysWrite}},
|
||||
{"admin scope", []domain.Scope{domain.ScopeAdmin}},
|
||||
{"all scopes", []domain.Scope{domain.ScopeProjectsRead, domain.ScopeProjectsExecute, domain.ScopeKeysManage, domain.ScopeKeysManage, domain.ScopeAdmin}},
|
||||
{"all scopes", []domain.Scope{domain.ScopeProjectsRead, domain.ScopeProjectsExecute, domain.ScopeKeysRead, domain.ScopeKeysWrite, domain.ScopeAdmin}},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@ -435,74 +435,3 @@ func TestAPIKeyRepository_AllowedIPsArrayHandling(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function conversion tests
|
||||
func TestScopesToStrings(t *testing.T) {
|
||||
scopes := []domain.Scope{domain.ScopeProjectsRead, domain.ScopeAdmin}
|
||||
strings := scopesToStrings(scopes)
|
||||
|
||||
if len(strings) != 2 {
|
||||
t.Fatalf("Length = %d, want 2", len(strings))
|
||||
}
|
||||
if strings[0] != "projects:read" {
|
||||
t.Errorf("strings[0] = %q, want %q", strings[0], "projects:read")
|
||||
}
|
||||
if strings[1] != "admin" {
|
||||
t.Errorf("strings[1] = %q, want %q", strings[1], "admin")
|
||||
}
|
||||
}
|
||||
|
||||
func TestScopesFromStrings(t *testing.T) {
|
||||
strings := []string{"projects:read", "keys:manage"}
|
||||
scopes := scopesFromStrings(strings)
|
||||
|
||||
if len(scopes) != 2 {
|
||||
t.Fatalf("Length = %d, want 2", len(scopes))
|
||||
}
|
||||
if scopes[0] != domain.ScopeProjectsRead {
|
||||
t.Errorf("scopes[0] = %q, want %q", scopes[0], domain.ScopeProjectsRead)
|
||||
}
|
||||
if scopes[1] != domain.ScopeKeysManage {
|
||||
t.Errorf("scopes[1] = %q, want %q", scopes[1], domain.ScopeKeysManage)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProjectIDsToStrings(t *testing.T) {
|
||||
t.Run("nil input", func(t *testing.T) {
|
||||
result := projectIDsToStrings(nil)
|
||||
if result != nil {
|
||||
t.Errorf("Expected nil, got %v", result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("non-nil input", func(t *testing.T) {
|
||||
ids := []domain.ProjectID{"proj-a", "proj-b"}
|
||||
result := projectIDsToStrings(ids)
|
||||
if len(result) != 2 {
|
||||
t.Fatalf("Length = %d, want 2", len(result))
|
||||
}
|
||||
if result[0] != "proj-a" || result[1] != "proj-b" {
|
||||
t.Errorf("Unexpected result: %v", result)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestProjectIDsFromStrings(t *testing.T) {
|
||||
t.Run("nil input", func(t *testing.T) {
|
||||
result := projectIDsFromStrings(nil)
|
||||
if result != nil {
|
||||
t.Errorf("Expected nil, got %v", result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("non-nil input", func(t *testing.T) {
|
||||
strings := []string{"proj-x", "proj-y"}
|
||||
result := projectIDsFromStrings(strings)
|
||||
if len(result) != 2 {
|
||||
t.Fatalf("Length = %d, want 2", len(result))
|
||||
}
|
||||
if result[0] != "proj-x" || result[1] != "proj-y" {
|
||||
t.Errorf("Unexpected result: %v", result)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
210
internal/adapter/postgres/build_audit.go
Normal file
210
internal/adapter/postgres/build_audit.go
Normal file
@ -0,0 +1,210 @@
|
||||
// Package postgres provides PostgreSQL-based implementations of port interfaces.
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/orchard9/rdev/internal/domain"
|
||||
"github.com/orchard9/rdev/internal/port"
|
||||
)
|
||||
|
||||
// BuildAuditRepository implements port.BuildAudit using PostgreSQL.
|
||||
type BuildAuditRepository struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
// NewBuildAuditRepository creates a new PostgreSQL build audit repository.
|
||||
func NewBuildAuditRepository(db *sql.DB) *BuildAuditRepository {
|
||||
return &BuildAuditRepository{db: db}
|
||||
}
|
||||
|
||||
// Ensure BuildAuditRepository implements port.BuildAudit at compile time.
|
||||
var _ port.BuildAudit = (*BuildAuditRepository)(nil)
|
||||
|
||||
// Record creates a new audit entry when a build starts.
|
||||
func (r *BuildAuditRepository) Record(ctx context.Context, entry *domain.BuildAuditEntry) error {
|
||||
specJSON, err := json.Marshal(entry.Spec)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal build spec: %w", err)
|
||||
}
|
||||
|
||||
_, err = r.db.ExecContext(ctx, `
|
||||
INSERT INTO build_audit (task_id, project_id, worker_id, spec, status, started_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6)
|
||||
`, entry.TaskID, entry.ProjectID, nullString(entry.WorkerID),
|
||||
specJSON, entry.Status, entry.StartedAt)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("record build audit: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Update modifies an existing entry when a build completes.
|
||||
func (r *BuildAuditRepository) Update(ctx context.Context, taskID string, result *domain.BuildResult) error {
|
||||
var resultJSON []byte
|
||||
var err error
|
||||
|
||||
if result != nil {
|
||||
resultJSON, err = json.Marshal(result)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal build result: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
status := domain.BuildStatusCompleted
|
||||
if result != nil && !result.Success {
|
||||
status = domain.BuildStatusFailed
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
res, err := r.db.ExecContext(ctx, `
|
||||
UPDATE build_audit
|
||||
SET result = $2, status = $3, completed_at = $4
|
||||
WHERE task_id = $1
|
||||
`, taskID, resultJSON, status, now)
|
||||
if err != nil {
|
||||
return fmt.Errorf("update build audit: %w", err)
|
||||
}
|
||||
|
||||
rows, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return fmt.Errorf("rows affected: %w", err)
|
||||
}
|
||||
if rows == 0 {
|
||||
return domain.ErrBuildNotFound
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get retrieves a specific audit entry by task ID.
|
||||
func (r *BuildAuditRepository) Get(ctx context.Context, taskID string) (*domain.BuildAuditEntry, error) {
|
||||
rows, err := r.db.QueryContext(ctx, `
|
||||
SELECT task_id, project_id, worker_id, spec, result, status,
|
||||
started_at, completed_at
|
||||
FROM build_audit
|
||||
WHERE task_id = $1
|
||||
`, taskID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get build audit: %w", err)
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
if !rows.Next() {
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("get build audit: %w", err)
|
||||
}
|
||||
return nil, domain.ErrBuildNotFound
|
||||
}
|
||||
return r.scanEntry(rows)
|
||||
}
|
||||
|
||||
// List returns audit entries matching the filter.
|
||||
func (r *BuildAuditRepository) List(ctx context.Context, filter port.BuildAuditFilter) ([]*domain.BuildAuditEntry, error) {
|
||||
query := `
|
||||
SELECT task_id, project_id, worker_id, spec, result, status,
|
||||
started_at, completed_at
|
||||
FROM build_audit
|
||||
WHERE 1=1`
|
||||
args := []any{}
|
||||
argNum := 1
|
||||
|
||||
if filter.ProjectID != "" {
|
||||
query += fmt.Sprintf(" AND project_id = $%d", argNum)
|
||||
args = append(args, filter.ProjectID)
|
||||
argNum++
|
||||
}
|
||||
|
||||
if filter.WorkerID != "" {
|
||||
query += fmt.Sprintf(" AND worker_id = $%d", argNum)
|
||||
args = append(args, filter.WorkerID)
|
||||
argNum++
|
||||
}
|
||||
|
||||
if filter.Status != nil {
|
||||
query += fmt.Sprintf(" AND status = $%d", argNum)
|
||||
args = append(args, string(*filter.Status))
|
||||
argNum++
|
||||
}
|
||||
|
||||
if !filter.Since.IsZero() {
|
||||
query += fmt.Sprintf(" AND started_at >= $%d", argNum)
|
||||
args = append(args, filter.Since)
|
||||
argNum++
|
||||
}
|
||||
|
||||
query += " ORDER BY started_at DESC"
|
||||
|
||||
if filter.Limit > 0 {
|
||||
query += fmt.Sprintf(" LIMIT $%d", argNum)
|
||||
args = append(args, filter.Limit)
|
||||
}
|
||||
|
||||
rows, err := r.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list build audit: %w", err)
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
var entries []*domain.BuildAuditEntry
|
||||
for rows.Next() {
|
||||
entry, err := r.scanEntry(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
entries = append(entries, entry)
|
||||
}
|
||||
|
||||
return entries, rows.Err()
|
||||
}
|
||||
|
||||
// scanEntry scans a single build audit row from a query result.
|
||||
func (r *BuildAuditRepository) scanEntry(rows *sql.Rows) (*domain.BuildAuditEntry, error) {
|
||||
var entry domain.BuildAuditEntry
|
||||
var workerID sql.NullString
|
||||
var specJSON []byte
|
||||
var resultJSON []byte
|
||||
var completedAt sql.NullTime
|
||||
|
||||
err := rows.Scan(
|
||||
&entry.TaskID,
|
||||
&entry.ProjectID,
|
||||
&workerID,
|
||||
&specJSON,
|
||||
&resultJSON,
|
||||
&entry.Status,
|
||||
&entry.StartedAt,
|
||||
&completedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("scan build audit: %w", err)
|
||||
}
|
||||
|
||||
if workerID.Valid {
|
||||
entry.WorkerID = workerID.String
|
||||
}
|
||||
if completedAt.Valid {
|
||||
entry.CompletedAt = &completedAt.Time
|
||||
}
|
||||
|
||||
if len(specJSON) > 0 {
|
||||
if err := json.Unmarshal(specJSON, &entry.Spec); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal build spec: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if len(resultJSON) > 0 {
|
||||
entry.Result = &domain.BuildResult{}
|
||||
if err := json.Unmarshal(resultJSON, entry.Result); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal build result: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return &entry, nil
|
||||
}
|
||||
256
internal/adapter/postgres/build_audit_test.go
Normal file
256
internal/adapter/postgres/build_audit_test.go
Normal file
@ -0,0 +1,256 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/orchard9/rdev/internal/domain"
|
||||
"github.com/orchard9/rdev/internal/port"
|
||||
"github.com/orchard9/rdev/internal/testutil"
|
||||
)
|
||||
|
||||
func cleanupTestBuildAudit(t *testing.T, db *sql.DB) {
|
||||
t.Helper()
|
||||
_, err := db.Exec("DELETE FROM build_audit WHERE project_id LIKE 'test-%'")
|
||||
if err != nil {
|
||||
t.Logf("cleanup test build_audit: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAuditRepository_Record(t *testing.T) {
|
||||
db := testutil.TestDB(t)
|
||||
t.Cleanup(func() { cleanupTestBuildAudit(t, db) })
|
||||
|
||||
repo := NewBuildAuditRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("records new audit entry", func(t *testing.T) {
|
||||
entry := &domain.BuildAuditEntry{
|
||||
TaskID: "test-task-audit-1",
|
||||
ProjectID: "test-project-1",
|
||||
Spec: domain.BuildSpec{
|
||||
Prompt: "Build a landing page",
|
||||
Template: "nextjs",
|
||||
},
|
||||
Status: domain.BuildStatusPending,
|
||||
StartedAt: time.Now(),
|
||||
}
|
||||
|
||||
err := repo.Record(ctx, entry)
|
||||
if err != nil {
|
||||
t.Fatalf("Record() error = %v", err)
|
||||
}
|
||||
|
||||
// Verify it was stored
|
||||
got, err := repo.Get(ctx, "test-task-audit-1")
|
||||
if err != nil {
|
||||
t.Fatalf("Get() error = %v", err)
|
||||
}
|
||||
if got.ProjectID != "test-project-1" {
|
||||
t.Errorf("got project_id %q, want %q", got.ProjectID, "test-project-1")
|
||||
}
|
||||
if got.Spec.Prompt != "Build a landing page" {
|
||||
t.Errorf("got prompt %q, want %q", got.Spec.Prompt, "Build a landing page")
|
||||
}
|
||||
if got.Status != domain.BuildStatusPending {
|
||||
t.Errorf("got status %q, want %q", got.Status, domain.BuildStatusPending)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("records entry with worker ID", func(t *testing.T) {
|
||||
entry := &domain.BuildAuditEntry{
|
||||
TaskID: "test-task-audit-2",
|
||||
ProjectID: "test-project-1",
|
||||
WorkerID: "worker-1",
|
||||
Spec: domain.BuildSpec{
|
||||
Prompt: "Run tests",
|
||||
},
|
||||
Status: domain.BuildStatusRunning,
|
||||
StartedAt: time.Now(),
|
||||
}
|
||||
|
||||
err := repo.Record(ctx, entry)
|
||||
if err != nil {
|
||||
t.Fatalf("Record() error = %v", err)
|
||||
}
|
||||
|
||||
got, err := repo.Get(ctx, "test-task-audit-2")
|
||||
if err != nil {
|
||||
t.Fatalf("Get() error = %v", err)
|
||||
}
|
||||
if got.WorkerID != "worker-1" {
|
||||
t.Errorf("got worker_id %q, want %q", got.WorkerID, "worker-1")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestBuildAuditRepository_Update(t *testing.T) {
|
||||
db := testutil.TestDB(t)
|
||||
t.Cleanup(func() { cleanupTestBuildAudit(t, db) })
|
||||
|
||||
repo := NewBuildAuditRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create initial entry
|
||||
entry := &domain.BuildAuditEntry{
|
||||
TaskID: "test-task-update-1",
|
||||
ProjectID: "test-project-1",
|
||||
Spec: domain.BuildSpec{Prompt: "Build"},
|
||||
Status: domain.BuildStatusPending,
|
||||
StartedAt: time.Now(),
|
||||
}
|
||||
if err := repo.Record(ctx, entry); err != nil {
|
||||
t.Fatalf("Record() error = %v", err)
|
||||
}
|
||||
|
||||
t.Run("updates with success result", func(t *testing.T) {
|
||||
result := &domain.BuildResult{
|
||||
Success: true,
|
||||
Output: "Build successful",
|
||||
CommitSHA: "abc123",
|
||||
FilesChanged: []string{"index.html", "style.css"},
|
||||
DurationMs: 5000,
|
||||
}
|
||||
|
||||
err := repo.Update(ctx, "test-task-update-1", result)
|
||||
if err != nil {
|
||||
t.Fatalf("Update() error = %v", err)
|
||||
}
|
||||
|
||||
got, err := repo.Get(ctx, "test-task-update-1")
|
||||
if err != nil {
|
||||
t.Fatalf("Get() error = %v", err)
|
||||
}
|
||||
if got.Status != domain.BuildStatusCompleted {
|
||||
t.Errorf("got status %q, want %q", got.Status, domain.BuildStatusCompleted)
|
||||
}
|
||||
if got.Result == nil {
|
||||
t.Fatal("expected result to be set")
|
||||
}
|
||||
if !got.Result.Success {
|
||||
t.Error("expected result.Success = true")
|
||||
}
|
||||
if got.Result.CommitSHA != "abc123" {
|
||||
t.Errorf("got commit_sha %q, want %q", got.Result.CommitSHA, "abc123")
|
||||
}
|
||||
if got.CompletedAt == nil {
|
||||
t.Error("expected completed_at to be set")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("updates with failure result", func(t *testing.T) {
|
||||
// Create a new entry
|
||||
entry := &domain.BuildAuditEntry{
|
||||
TaskID: "test-task-update-2",
|
||||
ProjectID: "test-project-1",
|
||||
Spec: domain.BuildSpec{Prompt: "Build"},
|
||||
Status: domain.BuildStatusPending,
|
||||
StartedAt: time.Now(),
|
||||
}
|
||||
if err := repo.Record(ctx, entry); err != nil {
|
||||
t.Fatalf("Record() error = %v", err)
|
||||
}
|
||||
|
||||
result := &domain.BuildResult{
|
||||
Success: false,
|
||||
Error: "compilation error",
|
||||
}
|
||||
|
||||
err := repo.Update(ctx, "test-task-update-2", result)
|
||||
if err != nil {
|
||||
t.Fatalf("Update() error = %v", err)
|
||||
}
|
||||
|
||||
got, err := repo.Get(ctx, "test-task-update-2")
|
||||
if err != nil {
|
||||
t.Fatalf("Get() error = %v", err)
|
||||
}
|
||||
if got.Status != domain.BuildStatusFailed {
|
||||
t.Errorf("got status %q, want %q", got.Status, domain.BuildStatusFailed)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns error for nonexistent task", func(t *testing.T) {
|
||||
result := &domain.BuildResult{Success: true}
|
||||
err := repo.Update(ctx, "test-task-nonexistent", result)
|
||||
if err == nil {
|
||||
t.Error("expected error for nonexistent task")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestBuildAuditRepository_Get(t *testing.T) {
|
||||
db := testutil.TestDB(t)
|
||||
t.Cleanup(func() { cleanupTestBuildAudit(t, db) })
|
||||
|
||||
repo := NewBuildAuditRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("returns error for nonexistent entry", func(t *testing.T) {
|
||||
_, err := repo.Get(ctx, "test-task-nonexistent")
|
||||
if err == nil {
|
||||
t.Error("expected error for nonexistent entry")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestBuildAuditRepository_List(t *testing.T) {
|
||||
db := testutil.TestDB(t)
|
||||
t.Cleanup(func() { cleanupTestBuildAudit(t, db) })
|
||||
|
||||
repo := NewBuildAuditRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create entries for different projects
|
||||
entries := []*domain.BuildAuditEntry{
|
||||
{TaskID: "test-task-list-1", ProjectID: "test-project-a", Spec: domain.BuildSpec{Prompt: "Build 1"}, Status: domain.BuildStatusCompleted, StartedAt: time.Now()},
|
||||
{TaskID: "test-task-list-2", ProjectID: "test-project-a", Spec: domain.BuildSpec{Prompt: "Build 2"}, Status: domain.BuildStatusFailed, StartedAt: time.Now()},
|
||||
{TaskID: "test-task-list-3", ProjectID: "test-project-b", Spec: domain.BuildSpec{Prompt: "Build 3"}, Status: domain.BuildStatusPending, StartedAt: time.Now()},
|
||||
}
|
||||
for _, e := range entries {
|
||||
if err := repo.Record(ctx, e); err != nil {
|
||||
t.Fatalf("Record() error = %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
t.Run("filters by project", func(t *testing.T) {
|
||||
got, err := repo.List(ctx, port.BuildAuditFilter{
|
||||
ProjectID: "test-project-a",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("List() error = %v", err)
|
||||
}
|
||||
if len(got) != 2 {
|
||||
t.Errorf("got %d entries, want 2", len(got))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("filters by status", func(t *testing.T) {
|
||||
completed := domain.BuildStatusCompleted
|
||||
got, err := repo.List(ctx, port.BuildAuditFilter{
|
||||
Status: &completed,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("List() error = %v", err)
|
||||
}
|
||||
for _, e := range got {
|
||||
if e.Status != domain.BuildStatusCompleted {
|
||||
t.Errorf("got status %q, want only completed", e.Status)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("respects limit", func(t *testing.T) {
|
||||
got, err := repo.List(ctx, port.BuildAuditFilter{
|
||||
Limit: 1,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("List() error = %v", err)
|
||||
}
|
||||
if len(got) > 1 {
|
||||
t.Errorf("got %d entries, want at most 1", len(got))
|
||||
}
|
||||
})
|
||||
}
|
||||
@ -56,7 +56,7 @@ func (s *CredentialStore) GetRequired(ctx context.Context, key string) (string,
|
||||
return "", err
|
||||
}
|
||||
if value == "" {
|
||||
return "", fmt.Errorf("credential %s not found", key)
|
||||
return "", domain.ErrCredentialNotFound
|
||||
}
|
||||
return value, nil
|
||||
}
|
||||
@ -88,7 +88,7 @@ func (s *CredentialStore) Delete(ctx context.Context, key string) error {
|
||||
|
||||
rows, _ := result.RowsAffected()
|
||||
if rows == 0 {
|
||||
return fmt.Errorf("credential %s not found", key)
|
||||
return domain.ErrCredentialNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -15,12 +15,21 @@ import (
|
||||
|
||||
// RateLimiter implements port.RateLimiter using PostgreSQL.
|
||||
type RateLimiter struct {
|
||||
db *sql.DB
|
||||
db *sql.DB
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
// NewRateLimiter creates a new PostgreSQL rate limiter.
|
||||
func NewRateLimiter(db *sql.DB) *RateLimiter {
|
||||
return &RateLimiter{db: db}
|
||||
return &RateLimiter{db: db, logger: slog.Default()}
|
||||
}
|
||||
|
||||
// WithLogger sets a custom logger for the rate limiter.
|
||||
func (r *RateLimiter) WithLogger(logger *slog.Logger) *RateLimiter {
|
||||
if logger != nil {
|
||||
r.logger = logger
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
// Ensure RateLimiter implements port.RateLimiter at compile time.
|
||||
@ -224,7 +233,7 @@ func (r *RateLimiter) StartCleanupWorker(ctx context.Context, interval time.Dura
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := r.Cleanup(ctx); err != nil {
|
||||
slog.Error("rate limit cleanup failed", "error", err)
|
||||
r.logger.Error("rate limit cleanup failed", "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -7,7 +7,6 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/orchard9/rdev/internal/domain"
|
||||
"github.com/orchard9/rdev/internal/port"
|
||||
@ -27,7 +26,7 @@ func NewWorkQueueRepository(db *sql.DB) *WorkQueueRepository {
|
||||
var _ port.WorkQueue = (*WorkQueueRepository)(nil)
|
||||
|
||||
// Enqueue adds a task to the queue.
|
||||
func (r *WorkQueueRepository) Enqueue(ctx context.Context, task *port.WorkTask) (string, error) {
|
||||
func (r *WorkQueueRepository) Enqueue(ctx context.Context, task *domain.WorkTask) (string, error) {
|
||||
specJSON, err := json.Marshal(task.Spec)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("marshal task spec: %w", err)
|
||||
@ -48,10 +47,10 @@ func (r *WorkQueueRepository) Enqueue(ctx context.Context, task *port.WorkTask)
|
||||
}
|
||||
|
||||
// Dequeue atomically claims the next available task for a worker.
|
||||
func (r *WorkQueueRepository) Dequeue(ctx context.Context, workerID string) (*port.WorkTask, error) {
|
||||
func (r *WorkQueueRepository) Dequeue(ctx context.Context, workerID string) (*domain.WorkTask, error) {
|
||||
// Use a single UPDATE ... RETURNING with subquery for atomic claim
|
||||
// This avoids explicit transaction management while still being safe
|
||||
var task port.WorkTask
|
||||
var task domain.WorkTask
|
||||
var taskType string
|
||||
var specJSON []byte
|
||||
var status string
|
||||
@ -99,8 +98,8 @@ func (r *WorkQueueRepository) Dequeue(ctx context.Context, workerID string) (*po
|
||||
return nil, fmt.Errorf("dequeue work task: %w", err)
|
||||
}
|
||||
|
||||
task.Type = port.WorkTaskType(taskType)
|
||||
task.Status = port.WorkTaskStatus(status)
|
||||
task.Type = domain.WorkTaskType(taskType)
|
||||
task.Status = domain.WorkTaskStatus(status)
|
||||
|
||||
if callbackURL.Valid {
|
||||
task.CallbackURL = callbackURL.String
|
||||
@ -124,7 +123,7 @@ func (r *WorkQueueRepository) Dequeue(ctx context.Context, workerID string) (*po
|
||||
|
||||
// Parse result
|
||||
if len(resultJSON) > 0 {
|
||||
task.Result = &port.WorkResult{}
|
||||
task.Result = &domain.WorkResult{}
|
||||
if err := json.Unmarshal(resultJSON, task.Result); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal task result: %w", err)
|
||||
}
|
||||
@ -134,7 +133,7 @@ func (r *WorkQueueRepository) Dequeue(ctx context.Context, workerID string) (*po
|
||||
}
|
||||
|
||||
// Complete marks a task as successfully completed with results.
|
||||
func (r *WorkQueueRepository) Complete(ctx context.Context, taskID string, result *port.WorkResult) error {
|
||||
func (r *WorkQueueRepository) Complete(ctx context.Context, taskID string, result *domain.WorkResult) error {
|
||||
var resultJSON []byte
|
||||
var err error
|
||||
|
||||
@ -242,285 +241,3 @@ func (r *WorkQueueRepository) Cancel(ctx context.Context, taskID string) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetTask retrieves a task by ID.
|
||||
func (r *WorkQueueRepository) GetTask(ctx context.Context, taskID string) (*port.WorkTask, error) {
|
||||
var task port.WorkTask
|
||||
var taskType string
|
||||
var specJSON []byte
|
||||
var status string
|
||||
var workerID sql.NullString
|
||||
var callbackURL sql.NullString
|
||||
var startedAt sql.NullTime
|
||||
var completedAt sql.NullTime
|
||||
var resultJSON []byte
|
||||
var errorMsg sql.NullString
|
||||
|
||||
err := r.db.QueryRowContext(ctx, `
|
||||
SELECT id, project_id, task_type, task_spec, status, priority, worker_id,
|
||||
callback_url, created_at, started_at, completed_at, result, error,
|
||||
retry_count, max_retries
|
||||
FROM work_queue
|
||||
WHERE id = $1
|
||||
`, taskID).Scan(
|
||||
&task.ID,
|
||||
&task.ProjectID,
|
||||
&taskType,
|
||||
&specJSON,
|
||||
&status,
|
||||
&task.Priority,
|
||||
&workerID,
|
||||
&callbackURL,
|
||||
&task.CreatedAt,
|
||||
&startedAt,
|
||||
&completedAt,
|
||||
&resultJSON,
|
||||
&errorMsg,
|
||||
&task.RetryCount,
|
||||
&task.MaxRetries,
|
||||
)
|
||||
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, domain.ErrWorkTaskNotFound
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get work task: %w", err)
|
||||
}
|
||||
|
||||
task.Type = port.WorkTaskType(taskType)
|
||||
task.Status = port.WorkTaskStatus(status)
|
||||
|
||||
if workerID.Valid {
|
||||
task.WorkerID = workerID.String
|
||||
}
|
||||
if callbackURL.Valid {
|
||||
task.CallbackURL = callbackURL.String
|
||||
}
|
||||
if startedAt.Valid {
|
||||
task.StartedAt = &startedAt.Time
|
||||
}
|
||||
if completedAt.Valid {
|
||||
task.CompletedAt = &completedAt.Time
|
||||
}
|
||||
if errorMsg.Valid {
|
||||
task.Error = errorMsg.String
|
||||
}
|
||||
|
||||
// Parse task spec
|
||||
if len(specJSON) > 0 {
|
||||
if err := json.Unmarshal(specJSON, &task.Spec); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal task spec: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Parse result
|
||||
if len(resultJSON) > 0 {
|
||||
task.Result = &port.WorkResult{}
|
||||
if err := json.Unmarshal(resultJSON, task.Result); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal task result: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return &task, nil
|
||||
}
|
||||
|
||||
// ListByProject returns tasks for a project with optional status filter and pagination.
|
||||
func (r *WorkQueueRepository) ListByProject(ctx context.Context, projectID string, status *port.WorkTaskStatus, opts port.WorkListOptions) (*port.WorkListResult, error) {
|
||||
// Normalize pagination options
|
||||
opts.Normalize()
|
||||
|
||||
// Build base WHERE clause
|
||||
whereClause := "WHERE project_id = $1"
|
||||
args := []any{projectID}
|
||||
argNum := 2
|
||||
|
||||
if status != nil {
|
||||
whereClause += fmt.Sprintf(" AND status = $%d", argNum)
|
||||
args = append(args, string(*status))
|
||||
argNum++
|
||||
}
|
||||
|
||||
// Get total count for pagination metadata
|
||||
countQuery := "SELECT COUNT(*) FROM work_queue " + whereClause
|
||||
var total int64
|
||||
if err := r.db.QueryRowContext(ctx, countQuery, args...).Scan(&total); err != nil {
|
||||
return nil, fmt.Errorf("count work tasks: %w", err)
|
||||
}
|
||||
|
||||
// Build paginated query
|
||||
query := fmt.Sprintf(`
|
||||
SELECT id, project_id, task_type, task_spec, status, priority, worker_id,
|
||||
callback_url, created_at, started_at, completed_at, result, error,
|
||||
retry_count, max_retries
|
||||
FROM work_queue
|
||||
%s
|
||||
ORDER BY created_at DESC
|
||||
LIMIT $%d OFFSET $%d
|
||||
`, whereClause, argNum, argNum+1)
|
||||
args = append(args, opts.Limit, opts.Offset)
|
||||
|
||||
rows, err := r.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list work tasks: %w", err)
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
var tasks []*port.WorkTask
|
||||
for rows.Next() {
|
||||
task, err := r.scanTask(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tasks = append(tasks, task)
|
||||
}
|
||||
|
||||
return &port.WorkListResult{
|
||||
Tasks: tasks,
|
||||
Total: total,
|
||||
Limit: opts.Limit,
|
||||
Offset: opts.Offset,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetStats returns queue statistics.
|
||||
func (r *WorkQueueRepository) GetStats(ctx context.Context) (*port.WorkQueueStats, error) {
|
||||
var stats port.WorkQueueStats
|
||||
|
||||
err := r.db.QueryRowContext(ctx, `
|
||||
SELECT
|
||||
COUNT(*) FILTER (WHERE status = 'pending') as pending,
|
||||
COUNT(*) FILTER (WHERE status = 'running') as running,
|
||||
COUNT(*) FILTER (WHERE status = 'completed' AND completed_at > NOW() - INTERVAL '24 hours') as completed,
|
||||
COUNT(*) FILTER (WHERE status = 'failed' AND completed_at > NOW() - INTERVAL '24 hours') as failed,
|
||||
COUNT(*) FILTER (WHERE status = 'cancelled' AND completed_at > NOW() - INTERVAL '24 hours') as cancelled
|
||||
FROM work_queue
|
||||
`).Scan(
|
||||
&stats.Pending,
|
||||
&stats.Running,
|
||||
&stats.Completed,
|
||||
&stats.Failed,
|
||||
&stats.Cancelled,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get stats: %w", err)
|
||||
}
|
||||
|
||||
// Get oldest pending task age
|
||||
var oldestCreatedAt sql.NullTime
|
||||
err = r.db.QueryRowContext(ctx, `
|
||||
SELECT MIN(created_at) FROM work_queue WHERE status = 'pending'
|
||||
`).Scan(&oldestCreatedAt)
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, fmt.Errorf("get oldest pending: %w", err)
|
||||
}
|
||||
if oldestCreatedAt.Valid {
|
||||
age := time.Since(oldestCreatedAt.Time)
|
||||
stats.OldestPending = &age
|
||||
}
|
||||
|
||||
return &stats, nil
|
||||
}
|
||||
|
||||
// CleanupOld removes completed/failed/cancelled tasks older than the specified duration.
|
||||
func (r *WorkQueueRepository) CleanupOld(ctx context.Context, olderThan time.Duration) (int64, error) {
|
||||
cutoff := time.Now().Add(-olderThan)
|
||||
result, err := r.db.ExecContext(ctx, `
|
||||
DELETE FROM work_queue
|
||||
WHERE status IN ('completed', 'failed', 'cancelled')
|
||||
AND completed_at < $1
|
||||
`, cutoff)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("cleanup old tasks: %w", err)
|
||||
}
|
||||
|
||||
return result.RowsAffected()
|
||||
}
|
||||
|
||||
// RequeueStale re-queues tasks that have been running longer than the timeout.
|
||||
func (r *WorkQueueRepository) RequeueStale(ctx context.Context, timeout time.Duration) (int64, error) {
|
||||
cutoff := time.Now().Add(-timeout)
|
||||
result, err := r.db.ExecContext(ctx, `
|
||||
UPDATE work_queue
|
||||
SET status = 'pending', worker_id = NULL, started_at = NULL,
|
||||
retry_count = retry_count + 1, error = 'Worker timeout - task requeued'
|
||||
WHERE status = 'running'
|
||||
AND started_at < $1
|
||||
AND retry_count < max_retries
|
||||
`, cutoff)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("requeue stale tasks: %w", err)
|
||||
}
|
||||
|
||||
return result.RowsAffected()
|
||||
}
|
||||
|
||||
// scanTask scans a single task row.
|
||||
func (r *WorkQueueRepository) scanTask(rows *sql.Rows) (*port.WorkTask, error) {
|
||||
var task port.WorkTask
|
||||
var taskType string
|
||||
var specJSON []byte
|
||||
var status string
|
||||
var workerID sql.NullString
|
||||
var callbackURL sql.NullString
|
||||
var startedAt sql.NullTime
|
||||
var completedAt sql.NullTime
|
||||
var resultJSON []byte
|
||||
var errorMsg sql.NullString
|
||||
|
||||
err := rows.Scan(
|
||||
&task.ID,
|
||||
&task.ProjectID,
|
||||
&taskType,
|
||||
&specJSON,
|
||||
&status,
|
||||
&task.Priority,
|
||||
&workerID,
|
||||
&callbackURL,
|
||||
&task.CreatedAt,
|
||||
&startedAt,
|
||||
&completedAt,
|
||||
&resultJSON,
|
||||
&errorMsg,
|
||||
&task.RetryCount,
|
||||
&task.MaxRetries,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("scan task: %w", err)
|
||||
}
|
||||
|
||||
task.Type = port.WorkTaskType(taskType)
|
||||
task.Status = port.WorkTaskStatus(status)
|
||||
|
||||
if workerID.Valid {
|
||||
task.WorkerID = workerID.String
|
||||
}
|
||||
if callbackURL.Valid {
|
||||
task.CallbackURL = callbackURL.String
|
||||
}
|
||||
if startedAt.Valid {
|
||||
task.StartedAt = &startedAt.Time
|
||||
}
|
||||
if completedAt.Valid {
|
||||
task.CompletedAt = &completedAt.Time
|
||||
}
|
||||
if errorMsg.Valid {
|
||||
task.Error = errorMsg.String
|
||||
}
|
||||
|
||||
// Parse task spec
|
||||
if len(specJSON) > 0 {
|
||||
if err := json.Unmarshal(specJSON, &task.Spec); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal task spec: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Parse result
|
||||
if len(resultJSON) > 0 {
|
||||
task.Result = &port.WorkResult{}
|
||||
if err := json.Unmarshal(resultJSON, task.Result); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal task result: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return &task, nil
|
||||
}
|
||||
|
||||
294
internal/adapter/postgres/work_queue_queries.go
Normal file
294
internal/adapter/postgres/work_queue_queries.go
Normal file
@ -0,0 +1,294 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/orchard9/rdev/internal/domain"
|
||||
)
|
||||
|
||||
// GetTask retrieves a task by ID.
|
||||
func (r *WorkQueueRepository) GetTask(ctx context.Context, taskID string) (*domain.WorkTask, error) {
|
||||
var task domain.WorkTask
|
||||
var taskType string
|
||||
var specJSON []byte
|
||||
var status string
|
||||
var workerID sql.NullString
|
||||
var callbackURL sql.NullString
|
||||
var startedAt sql.NullTime
|
||||
var completedAt sql.NullTime
|
||||
var resultJSON []byte
|
||||
var errorMsg sql.NullString
|
||||
|
||||
err := r.db.QueryRowContext(ctx, `
|
||||
SELECT id, project_id, task_type, task_spec, status, priority, worker_id,
|
||||
callback_url, created_at, started_at, completed_at, result, error,
|
||||
retry_count, max_retries
|
||||
FROM work_queue
|
||||
WHERE id = $1
|
||||
`, taskID).Scan(
|
||||
&task.ID,
|
||||
&task.ProjectID,
|
||||
&taskType,
|
||||
&specJSON,
|
||||
&status,
|
||||
&task.Priority,
|
||||
&workerID,
|
||||
&callbackURL,
|
||||
&task.CreatedAt,
|
||||
&startedAt,
|
||||
&completedAt,
|
||||
&resultJSON,
|
||||
&errorMsg,
|
||||
&task.RetryCount,
|
||||
&task.MaxRetries,
|
||||
)
|
||||
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, domain.ErrWorkTaskNotFound
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get work task: %w", err)
|
||||
}
|
||||
|
||||
task.Type = domain.WorkTaskType(taskType)
|
||||
task.Status = domain.WorkTaskStatus(status)
|
||||
|
||||
if workerID.Valid {
|
||||
task.WorkerID = workerID.String
|
||||
}
|
||||
if callbackURL.Valid {
|
||||
task.CallbackURL = callbackURL.String
|
||||
}
|
||||
if startedAt.Valid {
|
||||
task.StartedAt = &startedAt.Time
|
||||
}
|
||||
if completedAt.Valid {
|
||||
task.CompletedAt = &completedAt.Time
|
||||
}
|
||||
if errorMsg.Valid {
|
||||
task.Error = errorMsg.String
|
||||
}
|
||||
|
||||
// Parse task spec
|
||||
if len(specJSON) > 0 {
|
||||
if err := json.Unmarshal(specJSON, &task.Spec); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal task spec: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Parse result
|
||||
if len(resultJSON) > 0 {
|
||||
task.Result = &domain.WorkResult{}
|
||||
if err := json.Unmarshal(resultJSON, task.Result); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal task result: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return &task, nil
|
||||
}
|
||||
|
||||
// ListByProject returns tasks for a project with optional status filter and pagination.
|
||||
func (r *WorkQueueRepository) ListByProject(ctx context.Context, projectID string, status *domain.WorkTaskStatus, opts domain.WorkListOptions) (*domain.WorkListResult, error) {
|
||||
// Normalize pagination options
|
||||
opts.Normalize()
|
||||
|
||||
// Build base WHERE clause
|
||||
whereClause := "WHERE project_id = $1"
|
||||
args := []any{projectID}
|
||||
argNum := 2
|
||||
|
||||
if status != nil {
|
||||
whereClause += fmt.Sprintf(" AND status = $%d", argNum)
|
||||
args = append(args, string(*status))
|
||||
argNum++
|
||||
}
|
||||
|
||||
// Get total count for pagination metadata
|
||||
countQuery := "SELECT COUNT(*) FROM work_queue " + whereClause
|
||||
var total int64
|
||||
if err := r.db.QueryRowContext(ctx, countQuery, args...).Scan(&total); err != nil {
|
||||
return nil, fmt.Errorf("count work tasks: %w", err)
|
||||
}
|
||||
|
||||
// Build paginated query
|
||||
query := fmt.Sprintf(`
|
||||
SELECT id, project_id, task_type, task_spec, status, priority, worker_id,
|
||||
callback_url, created_at, started_at, completed_at, result, error,
|
||||
retry_count, max_retries
|
||||
FROM work_queue
|
||||
%s
|
||||
ORDER BY created_at DESC
|
||||
LIMIT $%d OFFSET $%d
|
||||
`, whereClause, argNum, argNum+1)
|
||||
args = append(args, opts.Limit, opts.Offset)
|
||||
|
||||
rows, err := r.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list work tasks: %w", err)
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
var tasks []*domain.WorkTask
|
||||
for rows.Next() {
|
||||
task, err := r.scanTask(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tasks = append(tasks, task)
|
||||
}
|
||||
|
||||
return &domain.WorkListResult{
|
||||
Tasks: tasks,
|
||||
Total: total,
|
||||
Limit: opts.Limit,
|
||||
Offset: opts.Offset,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetStats returns queue statistics.
|
||||
func (r *WorkQueueRepository) GetStats(ctx context.Context) (*domain.WorkQueueStats, error) {
|
||||
var stats domain.WorkQueueStats
|
||||
|
||||
err := r.db.QueryRowContext(ctx, `
|
||||
SELECT
|
||||
COUNT(*) FILTER (WHERE status = 'pending') as pending,
|
||||
COUNT(*) FILTER (WHERE status = 'running') as running,
|
||||
COUNT(*) FILTER (WHERE status = 'completed' AND completed_at > NOW() - INTERVAL '24 hours') as completed,
|
||||
COUNT(*) FILTER (WHERE status = 'failed' AND completed_at > NOW() - INTERVAL '24 hours') as failed,
|
||||
COUNT(*) FILTER (WHERE status = 'cancelled' AND completed_at > NOW() - INTERVAL '24 hours') as cancelled
|
||||
FROM work_queue
|
||||
`).Scan(
|
||||
&stats.Pending,
|
||||
&stats.Running,
|
||||
&stats.Completed,
|
||||
&stats.Failed,
|
||||
&stats.Cancelled,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get stats: %w", err)
|
||||
}
|
||||
|
||||
// Get oldest pending task age
|
||||
var oldestCreatedAt sql.NullTime
|
||||
err = r.db.QueryRowContext(ctx, `
|
||||
SELECT MIN(created_at) FROM work_queue WHERE status = 'pending'
|
||||
`).Scan(&oldestCreatedAt)
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, fmt.Errorf("get oldest pending: %w", err)
|
||||
}
|
||||
if oldestCreatedAt.Valid {
|
||||
age := time.Since(oldestCreatedAt.Time)
|
||||
stats.OldestPending = &age
|
||||
}
|
||||
|
||||
return &stats, nil
|
||||
}
|
||||
|
||||
// CleanupOld removes completed/failed/cancelled tasks older than the specified duration.
|
||||
func (r *WorkQueueRepository) CleanupOld(ctx context.Context, olderThan time.Duration) (int64, error) {
|
||||
cutoff := time.Now().Add(-olderThan)
|
||||
result, err := r.db.ExecContext(ctx, `
|
||||
DELETE FROM work_queue
|
||||
WHERE status IN ('completed', 'failed', 'cancelled')
|
||||
AND completed_at < $1
|
||||
`, cutoff)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("cleanup old tasks: %w", err)
|
||||
}
|
||||
|
||||
return result.RowsAffected()
|
||||
}
|
||||
|
||||
// RequeueStale re-queues tasks that have been running longer than the timeout.
|
||||
func (r *WorkQueueRepository) RequeueStale(ctx context.Context, timeout time.Duration) (int64, error) {
|
||||
cutoff := time.Now().Add(-timeout)
|
||||
result, err := r.db.ExecContext(ctx, `
|
||||
UPDATE work_queue
|
||||
SET status = 'pending', worker_id = NULL, started_at = NULL,
|
||||
retry_count = retry_count + 1, error = 'Worker timeout - task requeued'
|
||||
WHERE status = 'running'
|
||||
AND started_at < $1
|
||||
AND retry_count < max_retries
|
||||
`, cutoff)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("requeue stale tasks: %w", err)
|
||||
}
|
||||
|
||||
return result.RowsAffected()
|
||||
}
|
||||
|
||||
// scanTask scans a single task row.
|
||||
func (r *WorkQueueRepository) scanTask(rows *sql.Rows) (*domain.WorkTask, error) {
|
||||
var task domain.WorkTask
|
||||
var taskType string
|
||||
var specJSON []byte
|
||||
var status string
|
||||
var workerID sql.NullString
|
||||
var callbackURL sql.NullString
|
||||
var startedAt sql.NullTime
|
||||
var completedAt sql.NullTime
|
||||
var resultJSON []byte
|
||||
var errorMsg sql.NullString
|
||||
|
||||
err := rows.Scan(
|
||||
&task.ID,
|
||||
&task.ProjectID,
|
||||
&taskType,
|
||||
&specJSON,
|
||||
&status,
|
||||
&task.Priority,
|
||||
&workerID,
|
||||
&callbackURL,
|
||||
&task.CreatedAt,
|
||||
&startedAt,
|
||||
&completedAt,
|
||||
&resultJSON,
|
||||
&errorMsg,
|
||||
&task.RetryCount,
|
||||
&task.MaxRetries,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("scan task: %w", err)
|
||||
}
|
||||
|
||||
task.Type = domain.WorkTaskType(taskType)
|
||||
task.Status = domain.WorkTaskStatus(status)
|
||||
|
||||
if workerID.Valid {
|
||||
task.WorkerID = workerID.String
|
||||
}
|
||||
if callbackURL.Valid {
|
||||
task.CallbackURL = callbackURL.String
|
||||
}
|
||||
if startedAt.Valid {
|
||||
task.StartedAt = &startedAt.Time
|
||||
}
|
||||
if completedAt.Valid {
|
||||
task.CompletedAt = &completedAt.Time
|
||||
}
|
||||
if errorMsg.Valid {
|
||||
task.Error = errorMsg.String
|
||||
}
|
||||
|
||||
// Parse task spec
|
||||
if len(specJSON) > 0 {
|
||||
if err := json.Unmarshal(specJSON, &task.Spec); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal task spec: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Parse result
|
||||
if len(resultJSON) > 0 {
|
||||
task.Result = &domain.WorkResult{}
|
||||
if err := json.Unmarshal(resultJSON, task.Result); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal task result: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return &task, nil
|
||||
}
|
||||
244
internal/adapter/postgres/worker_registry.go
Normal file
244
internal/adapter/postgres/worker_registry.go
Normal file
@ -0,0 +1,244 @@
|
||||
// Package postgres provides PostgreSQL-based implementations of port interfaces.
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/orchard9/rdev/internal/domain"
|
||||
"github.com/orchard9/rdev/internal/port"
|
||||
)
|
||||
|
||||
// WorkerRegistryRepository implements port.WorkerRegistry using PostgreSQL.
|
||||
type WorkerRegistryRepository struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
// NewWorkerRegistryRepository creates a new PostgreSQL worker registry.
|
||||
func NewWorkerRegistryRepository(db *sql.DB) *WorkerRegistryRepository {
|
||||
return &WorkerRegistryRepository{db: db}
|
||||
}
|
||||
|
||||
// Ensure WorkerRegistryRepository implements port.WorkerRegistry at compile time.
|
||||
var _ port.WorkerRegistry = (*WorkerRegistryRepository)(nil)
|
||||
|
||||
// Register adds a worker to the pool.
|
||||
// If a worker with the same ID already exists, it is re-registered as idle.
|
||||
func (r *WorkerRegistryRepository) Register(ctx context.Context, worker *domain.Worker) error {
|
||||
capsJSON, err := json.Marshal(worker.Capabilities)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal capabilities: %w", err)
|
||||
}
|
||||
|
||||
_, err = r.db.ExecContext(ctx, `
|
||||
INSERT INTO workers (id, hostname, status, capabilities, version, registered_at, last_heartbeat)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $6)
|
||||
ON CONFLICT (id) DO UPDATE SET
|
||||
hostname = EXCLUDED.hostname,
|
||||
status = 'idle',
|
||||
current_task = NULL,
|
||||
capabilities = EXCLUDED.capabilities,
|
||||
version = EXCLUDED.version,
|
||||
last_heartbeat = EXCLUDED.last_heartbeat
|
||||
`, worker.ID, worker.Hostname, domain.WorkerStatusIdle, capsJSON,
|
||||
nullString(worker.Version), time.Now())
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("register worker: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Heartbeat updates the worker's last_heartbeat timestamp.
|
||||
func (r *WorkerRegistryRepository) Heartbeat(ctx context.Context, workerID string) error {
|
||||
result, err := r.db.ExecContext(ctx, `
|
||||
UPDATE workers SET last_heartbeat = NOW()
|
||||
WHERE id = $1 AND status != 'offline'
|
||||
`, workerID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("heartbeat worker: %w", err)
|
||||
}
|
||||
|
||||
rows, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return fmt.Errorf("rows affected: %w", err)
|
||||
}
|
||||
if rows == 0 {
|
||||
return domain.ErrWorkerNotFound
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateStatus changes a worker's status and optionally assigns a task.
|
||||
func (r *WorkerRegistryRepository) UpdateStatus(ctx context.Context, workerID string, status domain.WorkerStatus, taskID string) error {
|
||||
var currentTask sql.NullString
|
||||
if taskID != "" {
|
||||
currentTask = sql.NullString{String: taskID, Valid: true}
|
||||
}
|
||||
|
||||
result, err := r.db.ExecContext(ctx, `
|
||||
UPDATE workers SET status = $2, current_task = $3, last_heartbeat = NOW()
|
||||
WHERE id = $1
|
||||
`, workerID, status, currentTask)
|
||||
if err != nil {
|
||||
return fmt.Errorf("update worker status: %w", err)
|
||||
}
|
||||
|
||||
rows, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return fmt.Errorf("rows affected: %w", err)
|
||||
}
|
||||
if rows == 0 {
|
||||
return domain.ErrWorkerNotFound
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Deregister removes a worker from the pool.
|
||||
func (r *WorkerRegistryRepository) Deregister(ctx context.Context, workerID string) error {
|
||||
result, err := r.db.ExecContext(ctx, `DELETE FROM workers WHERE id = $1`, workerID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("deregister worker: %w", err)
|
||||
}
|
||||
|
||||
rows, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return fmt.Errorf("rows affected: %w", err)
|
||||
}
|
||||
if rows == 0 {
|
||||
return domain.ErrWorkerNotFound
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get retrieves a specific worker by ID.
|
||||
func (r *WorkerRegistryRepository) Get(ctx context.Context, workerID string) (*domain.Worker, error) {
|
||||
rows, err := r.db.QueryContext(ctx, `
|
||||
SELECT id, hostname, status, current_task, capabilities, version,
|
||||
registered_at, last_heartbeat
|
||||
FROM workers
|
||||
WHERE id = $1
|
||||
`, workerID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get worker: %w", err)
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
if !rows.Next() {
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("get worker: %w", err)
|
||||
}
|
||||
return nil, domain.ErrWorkerNotFound
|
||||
}
|
||||
return r.scanWorker(rows)
|
||||
}
|
||||
|
||||
// List returns all workers matching the filter.
|
||||
func (r *WorkerRegistryRepository) List(ctx context.Context, filter port.WorkerFilter) ([]*domain.Worker, error) {
|
||||
query := `
|
||||
SELECT id, hostname, status, current_task, capabilities, version,
|
||||
registered_at, last_heartbeat
|
||||
FROM workers
|
||||
WHERE 1=1`
|
||||
args := []any{}
|
||||
argNum := 1
|
||||
|
||||
if filter.Status != nil {
|
||||
query += fmt.Sprintf(" AND status = $%d", argNum)
|
||||
args = append(args, string(*filter.Status))
|
||||
argNum++
|
||||
}
|
||||
|
||||
if filter.HasCapability != "" {
|
||||
query += fmt.Sprintf(" AND capabilities @> $%d::jsonb", argNum)
|
||||
capJSON, _ := json.Marshal([]string{filter.HasCapability})
|
||||
args = append(args, string(capJSON))
|
||||
argNum++
|
||||
}
|
||||
|
||||
query += " ORDER BY registered_at ASC"
|
||||
|
||||
if filter.Limit > 0 {
|
||||
query += fmt.Sprintf(" LIMIT $%d", argNum)
|
||||
args = append(args, filter.Limit)
|
||||
}
|
||||
|
||||
rows, err := r.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list workers: %w", err)
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
var workers []*domain.Worker
|
||||
for rows.Next() {
|
||||
w, err := r.scanWorker(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
workers = append(workers, w)
|
||||
}
|
||||
|
||||
return workers, rows.Err()
|
||||
}
|
||||
|
||||
// MarkStaleOffline marks workers without a recent heartbeat as offline.
|
||||
func (r *WorkerRegistryRepository) MarkStaleOffline(ctx context.Context, threshold time.Duration) (int, error) {
|
||||
cutoff := time.Now().Add(-threshold)
|
||||
result, err := r.db.ExecContext(ctx, `
|
||||
UPDATE workers SET status = 'offline', current_task = NULL
|
||||
WHERE status != 'offline' AND last_heartbeat < $1
|
||||
`, cutoff)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("mark stale workers offline: %w", err)
|
||||
}
|
||||
|
||||
rows, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("rows affected: %w", err)
|
||||
}
|
||||
|
||||
return int(rows), nil
|
||||
}
|
||||
|
||||
// scanWorker scans a single worker row from a query result.
|
||||
func (r *WorkerRegistryRepository) scanWorker(rows *sql.Rows) (*domain.Worker, error) {
|
||||
var w domain.Worker
|
||||
var currentTask sql.NullString
|
||||
var capsJSON []byte
|
||||
var version sql.NullString
|
||||
|
||||
err := rows.Scan(
|
||||
&w.ID,
|
||||
&w.Hostname,
|
||||
&w.Status,
|
||||
¤tTask,
|
||||
&capsJSON,
|
||||
&version,
|
||||
&w.RegisteredAt,
|
||||
&w.LastHeartbeat,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("scan worker: %w", err)
|
||||
}
|
||||
|
||||
if currentTask.Valid {
|
||||
w.CurrentTask = currentTask.String
|
||||
}
|
||||
if version.Valid {
|
||||
w.Version = version.String
|
||||
}
|
||||
if len(capsJSON) > 0 {
|
||||
if err := json.Unmarshal(capsJSON, &w.Capabilities); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal capabilities: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return &w, nil
|
||||
}
|
||||
321
internal/adapter/postgres/worker_registry_test.go
Normal file
321
internal/adapter/postgres/worker_registry_test.go
Normal file
@ -0,0 +1,321 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/orchard9/rdev/internal/domain"
|
||||
"github.com/orchard9/rdev/internal/port"
|
||||
"github.com/orchard9/rdev/internal/testutil"
|
||||
)
|
||||
|
||||
func cleanupTestWorkers(t *testing.T, db *sql.DB) {
|
||||
t.Helper()
|
||||
_, err := db.Exec("DELETE FROM workers WHERE id LIKE 'test-%'")
|
||||
if err != nil {
|
||||
t.Logf("cleanup test workers: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWorkerRegistryRepository_Register(t *testing.T) {
|
||||
db := testutil.TestDB(t)
|
||||
t.Cleanup(func() { cleanupTestWorkers(t, db) })
|
||||
|
||||
repo := NewWorkerRegistryRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("registers new worker", func(t *testing.T) {
|
||||
worker := &domain.Worker{
|
||||
ID: "test-worker-reg-1",
|
||||
Hostname: "host-1",
|
||||
Capabilities: []string{"build", "test"},
|
||||
Version: "1.0.0",
|
||||
}
|
||||
|
||||
err := repo.Register(ctx, worker)
|
||||
if err != nil {
|
||||
t.Fatalf("Register() error = %v", err)
|
||||
}
|
||||
|
||||
// Verify worker was stored
|
||||
got, err := repo.Get(ctx, "test-worker-reg-1")
|
||||
if err != nil {
|
||||
t.Fatalf("Get() error = %v", err)
|
||||
}
|
||||
if got.Hostname != "host-1" {
|
||||
t.Errorf("got hostname %q, want %q", got.Hostname, "host-1")
|
||||
}
|
||||
if got.Status != domain.WorkerStatusIdle {
|
||||
t.Errorf("got status %q, want %q", got.Status, domain.WorkerStatusIdle)
|
||||
}
|
||||
if got.Version != "1.0.0" {
|
||||
t.Errorf("got version %q, want %q", got.Version, "1.0.0")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("re-registers existing worker", func(t *testing.T) {
|
||||
worker := &domain.Worker{
|
||||
ID: "test-worker-reg-2",
|
||||
Hostname: "host-2-old",
|
||||
Version: "1.0.0",
|
||||
}
|
||||
if err := repo.Register(ctx, worker); err != nil {
|
||||
t.Fatalf("Register() error = %v", err)
|
||||
}
|
||||
|
||||
// Update hostname via re-registration
|
||||
worker.Hostname = "host-2-new"
|
||||
worker.Version = "2.0.0"
|
||||
if err := repo.Register(ctx, worker); err != nil {
|
||||
t.Fatalf("Register() re-register error = %v", err)
|
||||
}
|
||||
|
||||
got, err := repo.Get(ctx, "test-worker-reg-2")
|
||||
if err != nil {
|
||||
t.Fatalf("Get() error = %v", err)
|
||||
}
|
||||
if got.Hostname != "host-2-new" {
|
||||
t.Errorf("got hostname %q, want %q", got.Hostname, "host-2-new")
|
||||
}
|
||||
if got.Status != domain.WorkerStatusIdle {
|
||||
t.Errorf("got status %q, want %q (should reset on re-register)", got.Status, domain.WorkerStatusIdle)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestWorkerRegistryRepository_Heartbeat(t *testing.T) {
|
||||
db := testutil.TestDB(t)
|
||||
t.Cleanup(func() { cleanupTestWorkers(t, db) })
|
||||
|
||||
repo := NewWorkerRegistryRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("updates heartbeat", func(t *testing.T) {
|
||||
worker := &domain.Worker{
|
||||
ID: "test-worker-hb-1",
|
||||
Hostname: "host-1",
|
||||
}
|
||||
if err := repo.Register(ctx, worker); err != nil {
|
||||
t.Fatalf("Register() error = %v", err)
|
||||
}
|
||||
|
||||
// Wait a moment so heartbeat time differs
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
if err := repo.Heartbeat(ctx, "test-worker-hb-1"); err != nil {
|
||||
t.Fatalf("Heartbeat() error = %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns error for nonexistent worker", func(t *testing.T) {
|
||||
err := repo.Heartbeat(ctx, "test-worker-nonexistent")
|
||||
if err == nil {
|
||||
t.Error("expected error for nonexistent worker")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestWorkerRegistryRepository_UpdateStatus(t *testing.T) {
|
||||
db := testutil.TestDB(t)
|
||||
t.Cleanup(func() { cleanupTestWorkers(t, db) })
|
||||
|
||||
repo := NewWorkerRegistryRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
// Register a worker
|
||||
worker := &domain.Worker{
|
||||
ID: "test-worker-status-1",
|
||||
Hostname: "host-1",
|
||||
}
|
||||
if err := repo.Register(ctx, worker); err != nil {
|
||||
t.Fatalf("Register() error = %v", err)
|
||||
}
|
||||
|
||||
t.Run("updates to busy with task", func(t *testing.T) {
|
||||
err := repo.UpdateStatus(ctx, "test-worker-status-1", domain.WorkerStatusBusy, "task-123")
|
||||
if err != nil {
|
||||
t.Fatalf("UpdateStatus() error = %v", err)
|
||||
}
|
||||
|
||||
got, err := repo.Get(ctx, "test-worker-status-1")
|
||||
if err != nil {
|
||||
t.Fatalf("Get() error = %v", err)
|
||||
}
|
||||
if got.Status != domain.WorkerStatusBusy {
|
||||
t.Errorf("got status %q, want %q", got.Status, domain.WorkerStatusBusy)
|
||||
}
|
||||
if got.CurrentTask != "task-123" {
|
||||
t.Errorf("got current_task %q, want %q", got.CurrentTask, "task-123")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("updates to idle clearing task", func(t *testing.T) {
|
||||
err := repo.UpdateStatus(ctx, "test-worker-status-1", domain.WorkerStatusIdle, "")
|
||||
if err != nil {
|
||||
t.Fatalf("UpdateStatus() error = %v", err)
|
||||
}
|
||||
|
||||
got, err := repo.Get(ctx, "test-worker-status-1")
|
||||
if err != nil {
|
||||
t.Fatalf("Get() error = %v", err)
|
||||
}
|
||||
if got.Status != domain.WorkerStatusIdle {
|
||||
t.Errorf("got status %q, want %q", got.Status, domain.WorkerStatusIdle)
|
||||
}
|
||||
if got.CurrentTask != "" {
|
||||
t.Errorf("got current_task %q, want empty", got.CurrentTask)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns error for nonexistent worker", func(t *testing.T) {
|
||||
err := repo.UpdateStatus(ctx, "test-worker-nonexistent", domain.WorkerStatusBusy, "")
|
||||
if err == nil {
|
||||
t.Error("expected error for nonexistent worker")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestWorkerRegistryRepository_Deregister(t *testing.T) {
|
||||
db := testutil.TestDB(t)
|
||||
t.Cleanup(func() { cleanupTestWorkers(t, db) })
|
||||
|
||||
repo := NewWorkerRegistryRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("deregisters existing worker", func(t *testing.T) {
|
||||
worker := &domain.Worker{
|
||||
ID: "test-worker-dereg-1",
|
||||
Hostname: "host-1",
|
||||
}
|
||||
if err := repo.Register(ctx, worker); err != nil {
|
||||
t.Fatalf("Register() error = %v", err)
|
||||
}
|
||||
|
||||
err := repo.Deregister(ctx, "test-worker-dereg-1")
|
||||
if err != nil {
|
||||
t.Fatalf("Deregister() error = %v", err)
|
||||
}
|
||||
|
||||
// Verify worker was removed
|
||||
_, err = repo.Get(ctx, "test-worker-dereg-1")
|
||||
if err == nil {
|
||||
t.Error("expected error for deregistered worker")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns error for nonexistent worker", func(t *testing.T) {
|
||||
err := repo.Deregister(ctx, "test-worker-nonexistent")
|
||||
if err == nil {
|
||||
t.Error("expected error for nonexistent worker")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestWorkerRegistryRepository_List(t *testing.T) {
|
||||
db := testutil.TestDB(t)
|
||||
t.Cleanup(func() { cleanupTestWorkers(t, db) })
|
||||
|
||||
repo := NewWorkerRegistryRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
// Register workers
|
||||
workers := []*domain.Worker{
|
||||
{ID: "test-worker-list-1", Hostname: "host-1"},
|
||||
{ID: "test-worker-list-2", Hostname: "host-2"},
|
||||
{ID: "test-worker-list-3", Hostname: "host-3"},
|
||||
}
|
||||
for _, w := range workers {
|
||||
if err := repo.Register(ctx, w); err != nil {
|
||||
t.Fatalf("Register() error = %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Make one busy
|
||||
if err := repo.UpdateStatus(ctx, "test-worker-list-2", domain.WorkerStatusBusy, "task-1"); err != nil {
|
||||
t.Fatalf("UpdateStatus() error = %v", err)
|
||||
}
|
||||
|
||||
t.Run("lists all workers", func(t *testing.T) {
|
||||
got, err := repo.List(ctx, port.WorkerFilter{})
|
||||
if err != nil {
|
||||
t.Fatalf("List() error = %v", err)
|
||||
}
|
||||
// Filter to just our test workers
|
||||
count := 0
|
||||
for _, w := range got {
|
||||
if w.ID == "test-worker-list-1" || w.ID == "test-worker-list-2" || w.ID == "test-worker-list-3" {
|
||||
count++
|
||||
}
|
||||
}
|
||||
if count < 3 {
|
||||
t.Errorf("expected at least 3 test workers, got %d", count)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("filters by status", func(t *testing.T) {
|
||||
idle := domain.WorkerStatusIdle
|
||||
got, err := repo.List(ctx, port.WorkerFilter{Status: &idle})
|
||||
if err != nil {
|
||||
t.Fatalf("List() error = %v", err)
|
||||
}
|
||||
for _, w := range got {
|
||||
if w.Status != domain.WorkerStatusIdle {
|
||||
t.Errorf("got worker with status %q, want only idle", w.Status)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("respects limit", func(t *testing.T) {
|
||||
got, err := repo.List(ctx, port.WorkerFilter{Limit: 1})
|
||||
if err != nil {
|
||||
t.Fatalf("List() error = %v", err)
|
||||
}
|
||||
if len(got) > 1 {
|
||||
t.Errorf("got %d workers, want at most 1", len(got))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestWorkerRegistryRepository_MarkStaleOffline(t *testing.T) {
|
||||
db := testutil.TestDB(t)
|
||||
t.Cleanup(func() { cleanupTestWorkers(t, db) })
|
||||
|
||||
repo := NewWorkerRegistryRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
// Register a worker
|
||||
worker := &domain.Worker{
|
||||
ID: "test-worker-stale-1",
|
||||
Hostname: "host-1",
|
||||
}
|
||||
if err := repo.Register(ctx, worker); err != nil {
|
||||
t.Fatalf("Register() error = %v", err)
|
||||
}
|
||||
|
||||
t.Run("marks stale worker offline", func(t *testing.T) {
|
||||
// Set heartbeat to past
|
||||
_, err := db.Exec("UPDATE workers SET last_heartbeat = $1 WHERE id = $2",
|
||||
time.Now().Add(-5*time.Minute), "test-worker-stale-1")
|
||||
if err != nil {
|
||||
t.Fatalf("set heartbeat: %v", err)
|
||||
}
|
||||
|
||||
count, err := repo.MarkStaleOffline(ctx, 90*time.Second)
|
||||
if err != nil {
|
||||
t.Fatalf("MarkStaleOffline() error = %v", err)
|
||||
}
|
||||
if count < 1 {
|
||||
t.Errorf("expected at least 1 worker marked offline, got %d", count)
|
||||
}
|
||||
|
||||
got, err := repo.Get(ctx, "test-worker-stale-1")
|
||||
if err != nil {
|
||||
t.Fatalf("Get() error = %v", err)
|
||||
}
|
||||
if got.Status != domain.WorkerStatusOffline {
|
||||
t.Errorf("got status %q, want %q", got.Status, domain.WorkerStatusOffline)
|
||||
}
|
||||
})
|
||||
}
|
||||
@ -287,6 +287,79 @@ func (c *Client) DeleteSecret(ctx context.Context, owner, repo, secretName strin
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListPipelines returns recent CI pipeline executions for a repository.
|
||||
func (c *Client) ListPipelines(ctx context.Context, owner, repo string) ([]*domain.CIPipeline, error) {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
fullName := owner + "/" + repo
|
||||
|
||||
r, err := c.client.RepoLookup(fullName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("repo not found: %s", fullName)
|
||||
}
|
||||
|
||||
pipelines, err := c.client.PipelineList(r.ID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list pipelines: %w", err)
|
||||
}
|
||||
|
||||
result := make([]*domain.CIPipeline, len(pipelines))
|
||||
for i, p := range pipelines {
|
||||
result[i] = pipelineFromWoodpecker(p)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GetPipeline returns a specific pipeline execution by number.
|
||||
func (c *Client) GetPipeline(ctx context.Context, owner, repo string, number int64) (*domain.CIPipeline, error) {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
fullName := owner + "/" + repo
|
||||
|
||||
r, err := c.client.RepoLookup(fullName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("repo not found: %s", fullName)
|
||||
}
|
||||
|
||||
p, err := c.client.Pipeline(r.ID, number)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("pipeline %d not found: %w", number, err)
|
||||
}
|
||||
|
||||
return pipelineFromWoodpecker(p), nil
|
||||
}
|
||||
|
||||
// pipelineFromWoodpecker converts a woodpecker.Pipeline to domain.CIPipeline.
|
||||
func pipelineFromWoodpecker(p *woodpecker.Pipeline) *domain.CIPipeline {
|
||||
var started, finished time.Time
|
||||
if p.Started > 0 {
|
||||
started = time.Unix(p.Started, 0)
|
||||
}
|
||||
if p.Finished > 0 {
|
||||
finished = time.Unix(p.Finished, 0)
|
||||
}
|
||||
return &domain.CIPipeline{
|
||||
ID: p.ID,
|
||||
Number: p.Number,
|
||||
Status: p.Status,
|
||||
Event: p.Event,
|
||||
Branch: p.Branch,
|
||||
Commit: p.Commit,
|
||||
Message: p.Message,
|
||||
Author: p.Author,
|
||||
Started: started,
|
||||
Finished: finished,
|
||||
}
|
||||
}
|
||||
|
||||
// repoFromWoodpecker converts a woodpecker.Repo to domain.CIRepo.
|
||||
func repoFromWoodpecker(r *woodpecker.Repo) *domain.CIRepo {
|
||||
// Parse forge remote ID (string in SDK, int64 in our domain)
|
||||
|
||||
@ -1,116 +1,71 @@
|
||||
package auth
|
||||
|
||||
import "slices"
|
||||
import "github.com/orchard9/rdev/internal/domain"
|
||||
|
||||
// Scope represents an API permission scope.
|
||||
type Scope string
|
||||
// Scope is an alias for domain.Scope.
|
||||
// All scope constants, helpers, and validation live in domain/apikey.go.
|
||||
type Scope = domain.Scope
|
||||
|
||||
// Available scopes.
|
||||
// Re-exported scope constants for backward compatibility.
|
||||
// Consumers should migrate to domain.ScopeXxx over time.
|
||||
const (
|
||||
ScopeProjectsRead Scope = "projects:read"
|
||||
ScopeProjectsExecute Scope = "projects:execute"
|
||||
ScopeKeysRead Scope = "keys:read"
|
||||
ScopeKeysWrite Scope = "keys:write"
|
||||
ScopeAuditRead Scope = "audit:read"
|
||||
ScopeQueueRead Scope = "queue:read"
|
||||
ScopeQueueWrite Scope = "queue:write"
|
||||
ScopeWebhookRead Scope = "webhook:read"
|
||||
ScopeWebhookWrite Scope = "webhook:write"
|
||||
ScopeAdmin Scope = "admin"
|
||||
ScopeProjectsRead = domain.ScopeProjectsRead
|
||||
ScopeProjectsExecute = domain.ScopeProjectsExecute
|
||||
ScopeKeysRead = domain.ScopeKeysRead
|
||||
ScopeKeysWrite = domain.ScopeKeysWrite
|
||||
ScopeAuditRead = domain.ScopeAuditRead
|
||||
ScopeQueueRead = domain.ScopeQueueRead
|
||||
ScopeQueueWrite = domain.ScopeQueueWrite
|
||||
ScopeWebhookRead = domain.ScopeWebhookRead
|
||||
ScopeWebhookWrite = domain.ScopeWebhookWrite
|
||||
ScopeWorkersRead = domain.ScopeWorkersRead
|
||||
ScopeWorkersWrite = domain.ScopeWorkersWrite
|
||||
ScopeBuildRead = domain.ScopeBuildRead
|
||||
ScopeBuildWrite = domain.ScopeBuildWrite
|
||||
ScopeAdmin = domain.ScopeAdmin
|
||||
)
|
||||
|
||||
// AllScopes is the list of all valid scopes.
|
||||
var AllScopes = []Scope{
|
||||
ScopeProjectsRead,
|
||||
ScopeProjectsExecute,
|
||||
ScopeKeysRead,
|
||||
ScopeKeysWrite,
|
||||
ScopeAuditRead,
|
||||
ScopeQueueRead,
|
||||
ScopeQueueWrite,
|
||||
ScopeWebhookRead,
|
||||
ScopeWebhookWrite,
|
||||
ScopeAdmin,
|
||||
}
|
||||
|
||||
// ScopeDescriptions provides human-readable descriptions.
|
||||
var ScopeDescriptions = map[Scope]string{
|
||||
ScopeProjectsRead: "List and view project details",
|
||||
ScopeProjectsExecute: "Execute commands (claude, shell, git) on projects",
|
||||
ScopeKeysRead: "List API keys (metadata only, not secrets)",
|
||||
ScopeKeysWrite: "Create and revoke API keys",
|
||||
ScopeAuditRead: "View audit logs for command executions",
|
||||
ScopeQueueRead: "View queued commands and queue status",
|
||||
ScopeQueueWrite: "Enqueue and cancel queued commands",
|
||||
ScopeWebhookRead: "View webhooks and delivery history",
|
||||
ScopeWebhookWrite: "Create, update, and delete webhooks",
|
||||
ScopeAdmin: "Full administrative access (includes all scopes)",
|
||||
}
|
||||
|
||||
// IsValid checks if a scope is valid.
|
||||
func (s Scope) IsValid() bool {
|
||||
return slices.Contains(AllScopes, s)
|
||||
}
|
||||
|
||||
// String returns the scope as a string.
|
||||
func (s Scope) String() string {
|
||||
return string(s)
|
||||
}
|
||||
// Re-exported scope helpers for backward compatibility.
|
||||
var (
|
||||
AllScopes = domain.AllScopes
|
||||
ScopeDescriptions = domain.ScopeDescriptions
|
||||
)
|
||||
|
||||
// ScopesFromStrings converts string slice to Scope slice.
|
||||
func ScopesFromStrings(ss []string) []Scope {
|
||||
scopes := make([]Scope, len(ss))
|
||||
for i, s := range ss {
|
||||
scopes[i] = Scope(s)
|
||||
}
|
||||
return scopes
|
||||
return domain.ScopesFromStrings(ss)
|
||||
}
|
||||
|
||||
// ScopesToStrings converts Scope slice to string slice.
|
||||
func ScopesToStrings(scopes []Scope) []string {
|
||||
ss := make([]string, len(scopes))
|
||||
for i, s := range scopes {
|
||||
ss[i] = string(s)
|
||||
}
|
||||
return ss
|
||||
return domain.ScopesToStrings(scopes)
|
||||
}
|
||||
|
||||
// ValidateScopes checks if all scopes are valid.
|
||||
func ValidateScopes(scopes []Scope) bool {
|
||||
for _, s := range scopes {
|
||||
if !s.IsValid() {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
return domain.ValidateScopes(scopes)
|
||||
}
|
||||
|
||||
// HasScope checks if a scope list contains a required scope.
|
||||
// Admin scope grants access to everything.
|
||||
func HasScope(scopes []Scope, required Scope) bool {
|
||||
for _, s := range scopes {
|
||||
if s == ScopeAdmin || s == required {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
return domain.HasScope(scopes, required)
|
||||
}
|
||||
|
||||
// HasAnyScope checks if a scope list contains any of the required scopes.
|
||||
func HasAnyScope(scopes []Scope, required ...Scope) bool {
|
||||
for _, r := range required {
|
||||
if HasScope(scopes, r) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
return domain.HasAnyScope(scopes, required...)
|
||||
}
|
||||
|
||||
// HasProjectAccess checks if the key has access to a specific project.
|
||||
// projectIDs nil means access to all projects.
|
||||
func HasProjectAccess(allowedProjects []string, projectID string) bool {
|
||||
if allowedProjects == nil {
|
||||
return true // nil = all projects
|
||||
return true
|
||||
}
|
||||
return slices.Contains(allowedProjects, projectID)
|
||||
for _, p := range allowedProjects {
|
||||
if p == projectID {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
@ -2,86 +2,26 @@ package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/lib/pq"
|
||||
"github.com/orchard9/rdev/internal/domain"
|
||||
"github.com/orchard9/rdev/internal/service"
|
||||
)
|
||||
|
||||
// Common errors.
|
||||
// APIKey is an alias for domain.APIKey.
|
||||
// All API key behavior (IsExpired, IsRevoked, etc.) lives in domain/apikey.go.
|
||||
type APIKey = domain.APIKey
|
||||
|
||||
// Error sentinels — delegate to domain errors.
|
||||
// Consumers should migrate to domain.ErrXxx over time.
|
||||
var (
|
||||
ErrKeyNotFound = errors.New("api key not found")
|
||||
ErrKeyRevoked = errors.New("api key has been revoked")
|
||||
ErrKeyExpired = errors.New("api key has expired")
|
||||
ErrIPNotAllowed = errors.New("ip address not allowed")
|
||||
ErrKeyNotFound = domain.ErrKeyNotFound
|
||||
ErrKeyRevoked = domain.ErrKeyRevoked
|
||||
ErrKeyExpired = domain.ErrKeyExpired
|
||||
ErrIPNotAllowed = domain.ErrIPNotAllowed
|
||||
)
|
||||
|
||||
// APIKey represents a stored API key.
|
||||
type APIKey struct {
|
||||
ID string
|
||||
Name string
|
||||
KeyPrefix string
|
||||
Scopes []Scope
|
||||
ProjectIDs []string // nil = all projects
|
||||
AllowedIPs []string // CIDR notation, e.g., ["192.168.1.0/24"]; nil = no restriction
|
||||
CreatedAt time.Time
|
||||
ExpiresAt *time.Time
|
||||
LastUsedAt *time.Time
|
||||
RevokedAt *time.Time
|
||||
CreatedBy string
|
||||
}
|
||||
|
||||
// IsExpired checks if the key has expired.
|
||||
func (k *APIKey) IsExpired() bool {
|
||||
if k.ExpiresAt == nil {
|
||||
return false
|
||||
}
|
||||
return time.Now().After(*k.ExpiresAt)
|
||||
}
|
||||
|
||||
// IsRevoked checks if the key has been revoked.
|
||||
func (k *APIKey) IsRevoked() bool {
|
||||
return k.RevokedAt != nil
|
||||
}
|
||||
|
||||
// IsActive checks if the key is valid for use.
|
||||
func (k *APIKey) IsActive() bool {
|
||||
return !k.IsRevoked() && !k.IsExpired()
|
||||
}
|
||||
|
||||
// IsIPAllowed checks if the given IP address is allowed by the key's IP restrictions.
|
||||
// Returns true if no IP restrictions are set or if the IP matches any allowed CIDR.
|
||||
func (k *APIKey) IsIPAllowed(clientIP string) bool {
|
||||
// No restrictions means all IPs are allowed
|
||||
if len(k.AllowedIPs) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
ip := net.ParseIP(clientIP)
|
||||
if ip == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, cidr := range k.AllowedIPs {
|
||||
_, network, err := net.ParseCIDR(cidr)
|
||||
if err != nil {
|
||||
// If not a CIDR, try parsing as single IP
|
||||
allowedIP := net.ParseIP(cidr)
|
||||
if allowedIP != nil && allowedIP.Equal(ip) {
|
||||
return true
|
||||
}
|
||||
continue
|
||||
}
|
||||
if network.Contains(ip) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// CreateKeyRequest is the input for creating a new key.
|
||||
type CreateKeyRequest struct {
|
||||
Name string
|
||||
@ -99,15 +39,18 @@ type CreateKeyResponse struct {
|
||||
}
|
||||
|
||||
// Service handles API key operations.
|
||||
// It wraps service.APIKeyService to provide the same interface as before
|
||||
// while delegating to the hexagonal service layer.
|
||||
type Service struct {
|
||||
db *sql.DB
|
||||
adminKey string // Super admin key from environment
|
||||
svc *service.APIKeyService
|
||||
adminKey string
|
||||
}
|
||||
|
||||
// NewService creates a new auth service.
|
||||
func NewService(db *sql.DB, adminKey string) *Service {
|
||||
// Accepts a service.APIKeyService (hexagonal) instead of raw *sql.DB.
|
||||
func NewService(svc *service.APIKeyService, adminKey string) *Service {
|
||||
return &Service{
|
||||
db: db,
|
||||
svc: svc,
|
||||
adminKey: adminKey,
|
||||
}
|
||||
}
|
||||
@ -124,211 +67,49 @@ func (s *Service) Create(ctx context.Context, req CreateKeyRequest) (*CreateKeyR
|
||||
return nil, fmt.Errorf("invalid scopes")
|
||||
}
|
||||
|
||||
// Generate key
|
||||
fullKey, prefix, err := GenerateKey()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generate key: %w", err)
|
||||
// Convert []string ProjectIDs to []domain.ProjectID
|
||||
var projectIDs []domain.ProjectID
|
||||
if req.ProjectIDs != nil {
|
||||
projectIDs = make([]domain.ProjectID, len(req.ProjectIDs))
|
||||
for i, p := range req.ProjectIDs {
|
||||
projectIDs[i] = domain.ProjectID(p)
|
||||
}
|
||||
}
|
||||
|
||||
keyHash := HashKey(fullKey)
|
||||
expiresAt := ExpiresAt(req.ExpiresIn)
|
||||
|
||||
// Convert scopes to strings for postgres
|
||||
scopeStrings := ScopesToStrings(req.Scopes)
|
||||
|
||||
var id string
|
||||
err = s.db.QueryRowContext(ctx, `
|
||||
INSERT INTO api_keys (name, key_hash, key_prefix, scopes, project_ids, allowed_ips, expires_at, created_by)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
|
||||
RETURNING id
|
||||
`, req.Name, keyHash, prefix, pq.Array(scopeStrings), pq.Array(req.ProjectIDs), pq.Array(req.AllowedIPs), expiresAt, req.CreatedBy).Scan(&id)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("insert key: %w", err)
|
||||
}
|
||||
|
||||
key := &APIKey{
|
||||
ID: id,
|
||||
result, err := s.svc.Create(ctx, service.CreateKeyRequest{
|
||||
Name: req.Name,
|
||||
KeyPrefix: prefix,
|
||||
Scopes: req.Scopes,
|
||||
ProjectIDs: req.ProjectIDs,
|
||||
ProjectIDs: projectIDs,
|
||||
AllowedIPs: req.AllowedIPs,
|
||||
CreatedAt: time.Now(),
|
||||
ExpiresAt: expiresAt,
|
||||
ExpiresIn: req.ExpiresIn,
|
||||
CreatedBy: req.CreatedBy,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &CreateKeyResponse{
|
||||
Key: key,
|
||||
Secret: fullKey,
|
||||
Key: result.Key,
|
||||
Secret: result.Secret,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Validate checks if a key is valid and returns the key details.
|
||||
func (s *Service) Validate(ctx context.Context, key string) (*APIKey, error) {
|
||||
// Check admin key first
|
||||
if s.IsAdminKey(key) {
|
||||
return &APIKey{
|
||||
ID: "admin",
|
||||
Name: "Super Admin",
|
||||
KeyPrefix: "admin",
|
||||
Scopes: []Scope{ScopeAdmin},
|
||||
CreatedAt: time.Time{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Validate format
|
||||
if !ValidateKeyFormat(key) {
|
||||
return nil, ErrKeyNotFound
|
||||
}
|
||||
|
||||
keyHash := HashKey(key)
|
||||
|
||||
var (
|
||||
apiKey APIKey
|
||||
scopeStrings []string
|
||||
)
|
||||
|
||||
err := s.db.QueryRowContext(ctx, `
|
||||
SELECT id, name, key_prefix, scopes, project_ids, allowed_ips, created_at, expires_at, last_used_at, revoked_at, created_by
|
||||
FROM api_keys
|
||||
WHERE key_hash = $1
|
||||
`, keyHash).Scan(
|
||||
&apiKey.ID,
|
||||
&apiKey.Name,
|
||||
&apiKey.KeyPrefix,
|
||||
pq.Array(&scopeStrings),
|
||||
pq.Array(&apiKey.ProjectIDs),
|
||||
pq.Array(&apiKey.AllowedIPs),
|
||||
&apiKey.CreatedAt,
|
||||
&apiKey.ExpiresAt,
|
||||
&apiKey.LastUsedAt,
|
||||
&apiKey.RevokedAt,
|
||||
&apiKey.CreatedBy,
|
||||
)
|
||||
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, ErrKeyNotFound
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query key: %w", err)
|
||||
}
|
||||
|
||||
apiKey.Scopes = ScopesFromStrings(scopeStrings)
|
||||
|
||||
if apiKey.IsRevoked() {
|
||||
return nil, ErrKeyRevoked
|
||||
}
|
||||
|
||||
if apiKey.IsExpired() {
|
||||
return nil, ErrKeyExpired
|
||||
}
|
||||
|
||||
// Update last_used_at asynchronously
|
||||
go func() {
|
||||
_, _ = s.db.ExecContext(context.Background(), `
|
||||
UPDATE api_keys SET last_used_at = NOW() WHERE id = $1
|
||||
`, apiKey.ID)
|
||||
}()
|
||||
|
||||
return &apiKey, nil
|
||||
return s.svc.Validate(ctx, key)
|
||||
}
|
||||
|
||||
// List returns all API keys (without secrets).
|
||||
func (s *Service) List(ctx context.Context) ([]*APIKey, error) {
|
||||
rows, err := s.db.QueryContext(ctx, `
|
||||
SELECT id, name, key_prefix, scopes, project_ids, allowed_ips, created_at, expires_at, last_used_at, revoked_at, created_by
|
||||
FROM api_keys
|
||||
ORDER BY created_at DESC
|
||||
`)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query keys: %w", err)
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
var keys []*APIKey
|
||||
for rows.Next() {
|
||||
var (
|
||||
key APIKey
|
||||
scopeStrings []string
|
||||
)
|
||||
if err := rows.Scan(
|
||||
&key.ID,
|
||||
&key.Name,
|
||||
&key.KeyPrefix,
|
||||
pq.Array(&scopeStrings),
|
||||
pq.Array(&key.ProjectIDs),
|
||||
pq.Array(&key.AllowedIPs),
|
||||
&key.CreatedAt,
|
||||
&key.ExpiresAt,
|
||||
&key.LastUsedAt,
|
||||
&key.RevokedAt,
|
||||
&key.CreatedBy,
|
||||
); err != nil {
|
||||
return nil, fmt.Errorf("scan key: %w", err)
|
||||
}
|
||||
key.Scopes = ScopesFromStrings(scopeStrings)
|
||||
keys = append(keys, &key)
|
||||
}
|
||||
|
||||
return keys, nil
|
||||
return s.svc.List(ctx)
|
||||
}
|
||||
|
||||
// Get returns a single API key by ID.
|
||||
func (s *Service) Get(ctx context.Context, id string) (*APIKey, error) {
|
||||
var (
|
||||
key APIKey
|
||||
scopeStrings []string
|
||||
)
|
||||
|
||||
err := s.db.QueryRowContext(ctx, `
|
||||
SELECT id, name, key_prefix, scopes, project_ids, allowed_ips, created_at, expires_at, last_used_at, revoked_at, created_by
|
||||
FROM api_keys
|
||||
WHERE id = $1
|
||||
`, id).Scan(
|
||||
&key.ID,
|
||||
&key.Name,
|
||||
&key.KeyPrefix,
|
||||
pq.Array(&scopeStrings),
|
||||
pq.Array(&key.ProjectIDs),
|
||||
pq.Array(&key.AllowedIPs),
|
||||
&key.CreatedAt,
|
||||
&key.ExpiresAt,
|
||||
&key.LastUsedAt,
|
||||
&key.RevokedAt,
|
||||
&key.CreatedBy,
|
||||
)
|
||||
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, ErrKeyNotFound
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query key: %w", err)
|
||||
}
|
||||
|
||||
key.Scopes = ScopesFromStrings(scopeStrings)
|
||||
return &key, nil
|
||||
return s.svc.Get(ctx, domain.APIKeyID(id))
|
||||
}
|
||||
|
||||
// Revoke marks an API key as revoked.
|
||||
func (s *Service) Revoke(ctx context.Context, id string) error {
|
||||
result, err := s.db.ExecContext(ctx, `
|
||||
UPDATE api_keys SET revoked_at = NOW()
|
||||
WHERE id = $1 AND revoked_at IS NULL
|
||||
`, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("revoke key: %w", err)
|
||||
}
|
||||
|
||||
rows, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return fmt.Errorf("rows affected: %w", err)
|
||||
}
|
||||
|
||||
if rows == 0 {
|
||||
return ErrKeyNotFound
|
||||
}
|
||||
|
||||
return nil
|
||||
return s.svc.Revoke(ctx, domain.APIKeyID(id))
|
||||
}
|
||||
|
||||
190
internal/auth/service_ip_test.go
Normal file
190
internal/auth/service_ip_test.go
Normal file
@ -0,0 +1,190 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestAPIKey_IsIPAllowed(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
allowedIPs []string
|
||||
clientIP string
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "no restrictions - any IP allowed",
|
||||
allowedIPs: nil,
|
||||
clientIP: "192.168.1.100",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "empty restrictions - any IP allowed",
|
||||
allowedIPs: []string{},
|
||||
clientIP: "10.0.0.5",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "single IP match",
|
||||
allowedIPs: []string{"192.168.1.100"},
|
||||
clientIP: "192.168.1.100",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "single IP no match",
|
||||
allowedIPs: []string{"192.168.1.100"},
|
||||
clientIP: "192.168.1.101",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "CIDR match",
|
||||
allowedIPs: []string{"192.168.1.0/24"},
|
||||
clientIP: "192.168.1.55",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "CIDR no match",
|
||||
allowedIPs: []string{"192.168.1.0/24"},
|
||||
clientIP: "192.168.2.1",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "multiple CIDRs - first matches",
|
||||
allowedIPs: []string{"10.0.0.0/8", "192.168.0.0/16"},
|
||||
clientIP: "10.50.25.100",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "multiple CIDRs - second matches",
|
||||
allowedIPs: []string{"10.0.0.0/8", "192.168.0.0/16"},
|
||||
clientIP: "192.168.50.1",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "multiple CIDRs - none match",
|
||||
allowedIPs: []string{"10.0.0.0/8", "192.168.0.0/16"},
|
||||
clientIP: "172.16.0.1",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "mixed IP and CIDR - IP matches",
|
||||
allowedIPs: []string{"10.0.0.0/8", "172.16.0.1"},
|
||||
clientIP: "172.16.0.1",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "mixed IP and CIDR - CIDR matches",
|
||||
allowedIPs: []string{"10.0.0.0/8", "172.16.0.1"},
|
||||
clientIP: "10.1.2.3",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "IPv6 CIDR",
|
||||
allowedIPs: []string{"2001:db8::/32"},
|
||||
clientIP: "2001:db8:1234:5678::1",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "IPv6 no match",
|
||||
allowedIPs: []string{"2001:db8::/32"},
|
||||
clientIP: "2001:db9::1",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "invalid client IP",
|
||||
allowedIPs: []string{"192.168.1.0/24"},
|
||||
clientIP: "not-an-ip",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "invalid CIDR in allowlist (fallback to IP parse)",
|
||||
allowedIPs: []string{"invalid/cidr", "192.168.1.100"},
|
||||
clientIP: "192.168.1.100",
|
||||
want: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
key := &APIKey{AllowedIPs: tt.allowedIPs}
|
||||
if got := key.IsIPAllowed(tt.clientIP); got != tt.want {
|
||||
t.Errorf("IsIPAllowed(%q) = %v, want %v", tt.clientIP, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestService_CreateWithAllowedIPs(t *testing.T) {
|
||||
svc := newTestService(t, "admin-key")
|
||||
|
||||
t.Run("creates key with IP restrictions", func(t *testing.T) {
|
||||
resp, err := svc.Create(context.Background(), CreateKeyRequest{
|
||||
Name: "test-ip-key",
|
||||
Scopes: []Scope{ScopeProjectsRead},
|
||||
AllowedIPs: []string{"192.168.1.0/24", "10.0.0.1"},
|
||||
ExpiresIn: 24 * time.Hour,
|
||||
CreatedBy: "test",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Create() error = %v", err)
|
||||
}
|
||||
|
||||
if len(resp.Key.AllowedIPs) != 2 {
|
||||
t.Errorf("Key.AllowedIPs length = %d, want 2", len(resp.Key.AllowedIPs))
|
||||
}
|
||||
|
||||
key, err := svc.Get(context.Background(), string(resp.Key.ID))
|
||||
if err != nil {
|
||||
t.Fatalf("Get() error = %v", err)
|
||||
}
|
||||
|
||||
if len(key.AllowedIPs) != 2 {
|
||||
t.Errorf("Retrieved Key.AllowedIPs length = %d, want 2", len(key.AllowedIPs))
|
||||
}
|
||||
|
||||
validatedKey, err := svc.Validate(context.Background(), resp.Secret)
|
||||
if err != nil {
|
||||
t.Fatalf("Validate() error = %v", err)
|
||||
}
|
||||
|
||||
if len(validatedKey.AllowedIPs) != 2 {
|
||||
t.Errorf("Validated Key.AllowedIPs length = %d, want 2", len(validatedKey.AllowedIPs))
|
||||
}
|
||||
|
||||
if !validatedKey.IsIPAllowed("192.168.1.50") {
|
||||
t.Error("IsIPAllowed should return true for IP in allowed CIDR")
|
||||
}
|
||||
if !validatedKey.IsIPAllowed("10.0.0.1") {
|
||||
t.Error("IsIPAllowed should return true for explicitly allowed IP")
|
||||
}
|
||||
if validatedKey.IsIPAllowed("172.16.0.1") {
|
||||
t.Error("IsIPAllowed should return false for IP not in allowed list")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("creates key with no IP restrictions", func(t *testing.T) {
|
||||
resp, err := svc.Create(context.Background(), CreateKeyRequest{
|
||||
Name: "test-no-ip-key",
|
||||
Scopes: []Scope{ScopeProjectsRead},
|
||||
ExpiresIn: 24 * time.Hour,
|
||||
CreatedBy: "test",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Create() error = %v", err)
|
||||
}
|
||||
|
||||
if len(resp.Key.AllowedIPs) != 0 {
|
||||
t.Errorf("Key.AllowedIPs should be empty, got %v", resp.Key.AllowedIPs)
|
||||
}
|
||||
|
||||
validatedKey, err := svc.Validate(context.Background(), resp.Secret)
|
||||
if err != nil {
|
||||
t.Fatalf("Validate() error = %v", err)
|
||||
}
|
||||
|
||||
if !validatedKey.IsIPAllowed("1.2.3.4") {
|
||||
t.Error("IsIPAllowed should return true when no restrictions set")
|
||||
}
|
||||
})
|
||||
}
|
||||
@ -2,10 +2,13 @@ package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/orchard9/rdev/internal/adapter/postgres"
|
||||
"github.com/orchard9/rdev/internal/service"
|
||||
"github.com/orchard9/rdev/internal/testutil"
|
||||
)
|
||||
|
||||
@ -79,6 +82,17 @@ func TestAPIKey_IsActive(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// newTestService creates an auth.Service backed by the real postgres repo for integration tests.
|
||||
func newTestService(t *testing.T, adminKey string) *Service {
|
||||
t.Helper()
|
||||
db := testutil.TestDB(t)
|
||||
t.Cleanup(func() { testutil.CleanupTestKeys(t, db) })
|
||||
|
||||
repo := postgres.NewAPIKeyRepository(db)
|
||||
apiKeySvc := service.NewAPIKeyService(repo, adminKey)
|
||||
return NewService(apiKeySvc, adminKey)
|
||||
}
|
||||
|
||||
func TestService_IsAdminKey(t *testing.T) {
|
||||
svc := NewService(nil, "admin-secret")
|
||||
|
||||
@ -111,10 +125,7 @@ func TestService_IsAdminKey_NoAdminKey(t *testing.T) {
|
||||
|
||||
// Integration tests - require database
|
||||
func TestService_Create(t *testing.T) {
|
||||
db := testutil.TestDB(t)
|
||||
t.Cleanup(func() { testutil.CleanupTestKeys(t, db) })
|
||||
|
||||
svc := NewService(db, "admin-key")
|
||||
svc := newTestService(t, "admin-key")
|
||||
|
||||
t.Run("creates key with valid scopes", func(t *testing.T) {
|
||||
resp, err := svc.Create(context.Background(), CreateKeyRequest{
|
||||
@ -198,10 +209,7 @@ func TestService_Create(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestService_Validate(t *testing.T) {
|
||||
db := testutil.TestDB(t)
|
||||
t.Cleanup(func() { testutil.CleanupTestKeys(t, db) })
|
||||
|
||||
svc := NewService(db, "admin-key-test")
|
||||
svc := newTestService(t, "admin-key-test")
|
||||
|
||||
t.Run("validates admin key", func(t *testing.T) {
|
||||
key, err := svc.Validate(context.Background(), "admin-key-test")
|
||||
@ -209,7 +217,7 @@ func TestService_Validate(t *testing.T) {
|
||||
t.Fatalf("Validate() error = %v", err)
|
||||
}
|
||||
|
||||
if key.ID != "admin" {
|
||||
if string(key.ID) != "admin" {
|
||||
t.Errorf("Key.ID = %q, want %q", key.ID, "admin")
|
||||
}
|
||||
|
||||
@ -219,7 +227,6 @@ func TestService_Validate(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("validates created key", func(t *testing.T) {
|
||||
// Create a key first
|
||||
resp, err := svc.Create(context.Background(), CreateKeyRequest{
|
||||
Name: "test-validate-key",
|
||||
Scopes: []Scope{ScopeProjectsRead, ScopeKeysRead},
|
||||
@ -230,7 +237,6 @@ func TestService_Validate(t *testing.T) {
|
||||
t.Fatalf("Create() error = %v", err)
|
||||
}
|
||||
|
||||
// Validate it
|
||||
key, err := svc.Validate(context.Background(), resp.Secret)
|
||||
if err != nil {
|
||||
t.Fatalf("Validate() error = %v", err)
|
||||
@ -245,29 +251,17 @@ func TestService_Validate(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("rejects invalid format", func(t *testing.T) {
|
||||
_, err := svc.Validate(context.Background(), "not-a-valid-key")
|
||||
if err != ErrKeyNotFound {
|
||||
t.Errorf("Validate() error = %v, want %v", err, ErrKeyNotFound)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("rejects unknown key", func(t *testing.T) {
|
||||
// Valid format but not in database
|
||||
_, err := svc.Validate(context.Background(), "rdev_sk_abc12345_0123456789abcdef0123456789abcdef")
|
||||
if err != ErrKeyNotFound {
|
||||
if !errors.Is(err, ErrKeyNotFound) {
|
||||
t.Errorf("Validate() error = %v, want %v", err, ErrKeyNotFound)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestService_List(t *testing.T) {
|
||||
db := testutil.TestDB(t)
|
||||
t.Cleanup(func() { testutil.CleanupTestKeys(t, db) })
|
||||
svc := newTestService(t, "admin-key")
|
||||
|
||||
svc := NewService(db, "admin-key")
|
||||
|
||||
// Create some test keys
|
||||
for i := 0; i < 3; i++ {
|
||||
_, err := svc.Create(context.Background(), CreateKeyRequest{
|
||||
Name: fmt.Sprintf("test-list-key-%d", i),
|
||||
@ -285,10 +279,9 @@ func TestService_List(t *testing.T) {
|
||||
t.Fatalf("List() error = %v", err)
|
||||
}
|
||||
|
||||
// Should have at least our 3 test keys
|
||||
testKeyCount := 0
|
||||
for _, k := range keys {
|
||||
if k.Name[:10] == "test-list-" {
|
||||
if len(k.Name) >= 10 && k.Name[:10] == "test-list-" {
|
||||
testKeyCount++
|
||||
}
|
||||
}
|
||||
@ -299,10 +292,7 @@ func TestService_List(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestService_Get(t *testing.T) {
|
||||
db := testutil.TestDB(t)
|
||||
t.Cleanup(func() { testutil.CleanupTestKeys(t, db) })
|
||||
|
||||
svc := NewService(db, "admin-key")
|
||||
svc := newTestService(t, "admin-key")
|
||||
|
||||
t.Run("gets existing key", func(t *testing.T) {
|
||||
resp, err := svc.Create(context.Background(), CreateKeyRequest{
|
||||
@ -315,7 +305,7 @@ func TestService_Get(t *testing.T) {
|
||||
t.Fatalf("Create() error = %v", err)
|
||||
}
|
||||
|
||||
key, err := svc.Get(context.Background(), resp.Key.ID)
|
||||
key, err := svc.Get(context.Background(), string(resp.Key.ID))
|
||||
if err != nil {
|
||||
t.Fatalf("Get() error = %v", err)
|
||||
}
|
||||
@ -327,17 +317,14 @@ func TestService_Get(t *testing.T) {
|
||||
|
||||
t.Run("returns error for unknown key", func(t *testing.T) {
|
||||
_, err := svc.Get(context.Background(), "00000000-0000-0000-0000-000000000000")
|
||||
if err != ErrKeyNotFound {
|
||||
if !errors.Is(err, ErrKeyNotFound) {
|
||||
t.Errorf("Get() error = %v, want %v", err, ErrKeyNotFound)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestService_Revoke(t *testing.T) {
|
||||
db := testutil.TestDB(t)
|
||||
t.Cleanup(func() { testutil.CleanupTestKeys(t, db) })
|
||||
|
||||
svc := NewService(db, "admin-key")
|
||||
svc := newTestService(t, "admin-key")
|
||||
|
||||
t.Run("revokes existing key", func(t *testing.T) {
|
||||
resp, err := svc.Create(context.Background(), CreateKeyRequest{
|
||||
@ -350,21 +337,20 @@ func TestService_Revoke(t *testing.T) {
|
||||
t.Fatalf("Create() error = %v", err)
|
||||
}
|
||||
|
||||
err = svc.Revoke(context.Background(), resp.Key.ID)
|
||||
err = svc.Revoke(context.Background(), string(resp.Key.ID))
|
||||
if err != nil {
|
||||
t.Fatalf("Revoke() error = %v", err)
|
||||
}
|
||||
|
||||
// Validate should fail
|
||||
_, err = svc.Validate(context.Background(), resp.Secret)
|
||||
if err != ErrKeyRevoked {
|
||||
if !errors.Is(err, ErrKeyRevoked) {
|
||||
t.Errorf("Validate() after revoke error = %v, want %v", err, ErrKeyRevoked)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns error for unknown key", func(t *testing.T) {
|
||||
err := svc.Revoke(context.Background(), "00000000-0000-0000-0000-000000000000")
|
||||
if err != ErrKeyNotFound {
|
||||
if !errors.Is(err, ErrKeyNotFound) {
|
||||
t.Errorf("Revoke() error = %v, want %v", err, ErrKeyNotFound)
|
||||
}
|
||||
})
|
||||
@ -380,206 +366,13 @@ func TestService_Revoke(t *testing.T) {
|
||||
t.Fatalf("Create() error = %v", err)
|
||||
}
|
||||
|
||||
// Revoke once
|
||||
if err := svc.Revoke(context.Background(), resp.Key.ID); err != nil {
|
||||
if err := svc.Revoke(context.Background(), string(resp.Key.ID)); err != nil {
|
||||
t.Fatalf("First Revoke() error = %v", err)
|
||||
}
|
||||
|
||||
// Revoke again - should return not found (no rows affected)
|
||||
err = svc.Revoke(context.Background(), resp.Key.ID)
|
||||
if err != ErrKeyNotFound {
|
||||
err = svc.Revoke(context.Background(), string(resp.Key.ID))
|
||||
if !errors.Is(err, ErrKeyNotFound) {
|
||||
t.Errorf("Second Revoke() error = %v, want %v", err, ErrKeyNotFound)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestAPIKey_IsIPAllowed(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
allowedIPs []string
|
||||
clientIP string
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "no restrictions - any IP allowed",
|
||||
allowedIPs: nil,
|
||||
clientIP: "192.168.1.100",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "empty restrictions - any IP allowed",
|
||||
allowedIPs: []string{},
|
||||
clientIP: "10.0.0.5",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "single IP match",
|
||||
allowedIPs: []string{"192.168.1.100"},
|
||||
clientIP: "192.168.1.100",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "single IP no match",
|
||||
allowedIPs: []string{"192.168.1.100"},
|
||||
clientIP: "192.168.1.101",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "CIDR match",
|
||||
allowedIPs: []string{"192.168.1.0/24"},
|
||||
clientIP: "192.168.1.55",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "CIDR no match",
|
||||
allowedIPs: []string{"192.168.1.0/24"},
|
||||
clientIP: "192.168.2.1",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "multiple CIDRs - first matches",
|
||||
allowedIPs: []string{"10.0.0.0/8", "192.168.0.0/16"},
|
||||
clientIP: "10.50.25.100",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "multiple CIDRs - second matches",
|
||||
allowedIPs: []string{"10.0.0.0/8", "192.168.0.0/16"},
|
||||
clientIP: "192.168.50.1",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "multiple CIDRs - none match",
|
||||
allowedIPs: []string{"10.0.0.0/8", "192.168.0.0/16"},
|
||||
clientIP: "172.16.0.1",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "mixed IP and CIDR - IP matches",
|
||||
allowedIPs: []string{"10.0.0.0/8", "172.16.0.1"},
|
||||
clientIP: "172.16.0.1",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "mixed IP and CIDR - CIDR matches",
|
||||
allowedIPs: []string{"10.0.0.0/8", "172.16.0.1"},
|
||||
clientIP: "10.1.2.3",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "IPv6 CIDR",
|
||||
allowedIPs: []string{"2001:db8::/32"},
|
||||
clientIP: "2001:db8:1234:5678::1",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "IPv6 no match",
|
||||
allowedIPs: []string{"2001:db8::/32"},
|
||||
clientIP: "2001:db9::1",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "invalid client IP",
|
||||
allowedIPs: []string{"192.168.1.0/24"},
|
||||
clientIP: "not-an-ip",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "invalid CIDR in allowlist (fallback to IP parse)",
|
||||
allowedIPs: []string{"invalid/cidr", "192.168.1.100"},
|
||||
clientIP: "192.168.1.100",
|
||||
want: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
key := &APIKey{AllowedIPs: tt.allowedIPs}
|
||||
if got := key.IsIPAllowed(tt.clientIP); got != tt.want {
|
||||
t.Errorf("IsIPAllowed(%q) = %v, want %v", tt.clientIP, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestService_CreateWithAllowedIPs(t *testing.T) {
|
||||
db := testutil.TestDB(t)
|
||||
t.Cleanup(func() { testutil.CleanupTestKeys(t, db) })
|
||||
|
||||
svc := NewService(db, "admin-key")
|
||||
|
||||
t.Run("creates key with IP restrictions", func(t *testing.T) {
|
||||
resp, err := svc.Create(context.Background(), CreateKeyRequest{
|
||||
Name: "test-ip-key",
|
||||
Scopes: []Scope{ScopeProjectsRead},
|
||||
AllowedIPs: []string{"192.168.1.0/24", "10.0.0.1"},
|
||||
ExpiresIn: 24 * time.Hour,
|
||||
CreatedBy: "test",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Create() error = %v", err)
|
||||
}
|
||||
|
||||
if len(resp.Key.AllowedIPs) != 2 {
|
||||
t.Errorf("Key.AllowedIPs length = %d, want 2", len(resp.Key.AllowedIPs))
|
||||
}
|
||||
|
||||
// Verify via Get
|
||||
key, err := svc.Get(context.Background(), resp.Key.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("Get() error = %v", err)
|
||||
}
|
||||
|
||||
if len(key.AllowedIPs) != 2 {
|
||||
t.Errorf("Retrieved Key.AllowedIPs length = %d, want 2", len(key.AllowedIPs))
|
||||
}
|
||||
|
||||
// Verify via Validate
|
||||
validatedKey, err := svc.Validate(context.Background(), resp.Secret)
|
||||
if err != nil {
|
||||
t.Fatalf("Validate() error = %v", err)
|
||||
}
|
||||
|
||||
if len(validatedKey.AllowedIPs) != 2 {
|
||||
t.Errorf("Validated Key.AllowedIPs length = %d, want 2", len(validatedKey.AllowedIPs))
|
||||
}
|
||||
|
||||
// Verify IP checking works
|
||||
if !validatedKey.IsIPAllowed("192.168.1.50") {
|
||||
t.Error("IsIPAllowed should return true for IP in allowed CIDR")
|
||||
}
|
||||
if !validatedKey.IsIPAllowed("10.0.0.1") {
|
||||
t.Error("IsIPAllowed should return true for explicitly allowed IP")
|
||||
}
|
||||
if validatedKey.IsIPAllowed("172.16.0.1") {
|
||||
t.Error("IsIPAllowed should return false for IP not in allowed list")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("creates key with no IP restrictions", func(t *testing.T) {
|
||||
resp, err := svc.Create(context.Background(), CreateKeyRequest{
|
||||
Name: "test-no-ip-key",
|
||||
Scopes: []Scope{ScopeProjectsRead},
|
||||
ExpiresIn: 24 * time.Hour,
|
||||
CreatedBy: "test",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Create() error = %v", err)
|
||||
}
|
||||
|
||||
if len(resp.Key.AllowedIPs) != 0 {
|
||||
t.Errorf("Key.AllowedIPs should be empty, got %v", resp.Key.AllowedIPs)
|
||||
}
|
||||
|
||||
// Verify via Validate
|
||||
validatedKey, err := svc.Validate(context.Background(), resp.Secret)
|
||||
if err != nil {
|
||||
t.Fatalf("Validate() error = %v", err)
|
||||
}
|
||||
|
||||
// Any IP should be allowed
|
||||
if !validatedKey.IsIPAllowed("1.2.3.4") {
|
||||
t.Error("IsIPAllowed should return true when no restrictions set")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@ -3,13 +3,15 @@ package cmdlimit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/orchard9/rdev/internal/domain"
|
||||
)
|
||||
|
||||
// ErrLimitExceeded is returned when the concurrent command limit is reached.
|
||||
var ErrLimitExceeded = errors.New("concurrent command limit exceeded")
|
||||
// ErrLimitExceeded aliases domain.ErrLimitExceeded for backward compatibility.
|
||||
// Consumers should migrate to domain.ErrLimitExceeded over time.
|
||||
var ErrLimitExceeded = domain.ErrLimitExceeded
|
||||
|
||||
// Config defines the limiter configuration.
|
||||
type Config struct {
|
||||
@ -143,11 +145,11 @@ func (l *Limiter) Stats() Stats {
|
||||
}
|
||||
|
||||
return Stats{
|
||||
TotalActive: l.totalCount,
|
||||
MaxTotal: l.cfg.MaxConcurrentTotal,
|
||||
ProjectCounts: projectStats,
|
||||
MaxPerProject: l.cfg.MaxConcurrentPerProject,
|
||||
ActiveCommandIDs: l.getActiveCommandIDs(),
|
||||
TotalActive: l.totalCount,
|
||||
MaxTotal: l.cfg.MaxConcurrentTotal,
|
||||
ProjectCounts: projectStats,
|
||||
MaxPerProject: l.cfg.MaxConcurrentPerProject,
|
||||
ActiveCommandIDs: l.getActiveCommandIDs(),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
71
internal/db/migrations/012_worker_registry.sql
Normal file
71
internal/db/migrations/012_worker_registry.sql
Normal file
@ -0,0 +1,71 @@
|
||||
-- Workers table for worker pool registration and health tracking.
|
||||
-- Workers register on startup, send heartbeats, and are marked offline
|
||||
-- if they miss heartbeats beyond the configured threshold.
|
||||
|
||||
CREATE TABLE IF NOT EXISTS workers (
|
||||
id VARCHAR(255) PRIMARY KEY,
|
||||
hostname VARCHAR(255) NOT NULL,
|
||||
status VARCHAR(50) NOT NULL DEFAULT 'idle',
|
||||
current_task VARCHAR(255),
|
||||
capabilities JSONB DEFAULT '[]',
|
||||
version VARCHAR(50),
|
||||
registered_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
last_heartbeat TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
|
||||
CONSTRAINT workers_valid_status CHECK (status IN ('idle', 'busy', 'draining', 'offline'))
|
||||
);
|
||||
|
||||
-- Index for finding idle workers (used during task assignment)
|
||||
CREATE INDEX IF NOT EXISTS idx_workers_status
|
||||
ON workers(status)
|
||||
WHERE status = 'idle';
|
||||
|
||||
-- Index for health checker (finding stale heartbeats)
|
||||
CREATE INDEX IF NOT EXISTS idx_workers_heartbeat
|
||||
ON workers(last_heartbeat);
|
||||
|
||||
COMMENT ON TABLE workers IS 'Worker pool registry for tracking worker lifecycle and health';
|
||||
COMMENT ON COLUMN workers.id IS 'Unique worker identifier (typically Kubernetes pod name)';
|
||||
COMMENT ON COLUMN workers.hostname IS 'Worker hostname for identification';
|
||||
COMMENT ON COLUMN workers.status IS 'Worker status: idle, busy, draining, offline';
|
||||
COMMENT ON COLUMN workers.current_task IS 'ID of the work queue task currently being executed';
|
||||
COMMENT ON COLUMN workers.capabilities IS 'JSON array of worker capabilities (e.g., build, deploy)';
|
||||
COMMENT ON COLUMN workers.version IS 'Worker binary version string';
|
||||
COMMENT ON COLUMN workers.last_heartbeat IS 'Most recent heartbeat timestamp for health monitoring';
|
||||
|
||||
-- Build audit table for tracking build execution history.
|
||||
-- Every build request creates an audit entry that is updated
|
||||
-- when the build completes or fails.
|
||||
|
||||
CREATE TABLE IF NOT EXISTS build_audit (
|
||||
task_id VARCHAR(255) PRIMARY KEY,
|
||||
project_id VARCHAR(255) NOT NULL,
|
||||
worker_id VARCHAR(255),
|
||||
spec JSONB NOT NULL,
|
||||
result JSONB,
|
||||
status VARCHAR(50) NOT NULL DEFAULT 'pending',
|
||||
started_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
completed_at TIMESTAMPTZ,
|
||||
|
||||
CONSTRAINT build_audit_valid_status CHECK (status IN ('pending', 'running', 'completed', 'failed', 'cancelled'))
|
||||
);
|
||||
|
||||
-- Index for project build history queries
|
||||
CREATE INDEX IF NOT EXISTS idx_build_audit_project
|
||||
ON build_audit(project_id);
|
||||
|
||||
-- Index for status-based queries (e.g., "show all running builds")
|
||||
CREATE INDEX IF NOT EXISTS idx_build_audit_status
|
||||
ON build_audit(status);
|
||||
|
||||
-- Index for time-ordered queries (most recent first)
|
||||
CREATE INDEX IF NOT EXISTS idx_build_audit_started
|
||||
ON build_audit(started_at DESC);
|
||||
|
||||
COMMENT ON TABLE build_audit IS 'Audit trail for build executions with full spec and result history';
|
||||
COMMENT ON COLUMN build_audit.task_id IS 'Work queue task ID (links to work_queue.id)';
|
||||
COMMENT ON COLUMN build_audit.project_id IS 'Project this build belongs to';
|
||||
COMMENT ON COLUMN build_audit.worker_id IS 'Worker that executed the build';
|
||||
COMMENT ON COLUMN build_audit.spec IS 'JSON build specification (prompt, template, variables, etc.)';
|
||||
COMMENT ON COLUMN build_audit.result IS 'JSON build result (output, commit_sha, files_changed, etc.)';
|
||||
COMMENT ON COLUMN build_audit.status IS 'Build status: pending, running, completed, failed, cancelled';
|
||||
@ -12,12 +12,122 @@ type APIKeyID string
|
||||
type Scope string
|
||||
|
||||
const (
|
||||
ScopeAdmin Scope = "admin"
|
||||
ScopeProjectsRead Scope = "projects:read"
|
||||
ScopeProjectsExecute Scope = "projects:execute"
|
||||
ScopeKeysManage Scope = "keys:manage"
|
||||
ScopeKeysRead Scope = "keys:read"
|
||||
ScopeKeysWrite Scope = "keys:write"
|
||||
ScopeAuditRead Scope = "audit:read"
|
||||
ScopeQueueRead Scope = "queue:read"
|
||||
ScopeQueueWrite Scope = "queue:write"
|
||||
ScopeWebhookRead Scope = "webhook:read"
|
||||
ScopeWebhookWrite Scope = "webhook:write"
|
||||
ScopeWorkersRead Scope = "workers:read"
|
||||
ScopeWorkersWrite Scope = "workers:write"
|
||||
ScopeBuildRead Scope = "build:read"
|
||||
ScopeBuildWrite Scope = "build:write"
|
||||
ScopeAdmin Scope = "admin"
|
||||
)
|
||||
|
||||
// AllScopes is the list of all valid scopes.
|
||||
var AllScopes = []Scope{
|
||||
ScopeProjectsRead,
|
||||
ScopeProjectsExecute,
|
||||
ScopeKeysRead,
|
||||
ScopeKeysWrite,
|
||||
ScopeAuditRead,
|
||||
ScopeQueueRead,
|
||||
ScopeQueueWrite,
|
||||
ScopeWebhookRead,
|
||||
ScopeWebhookWrite,
|
||||
ScopeWorkersRead,
|
||||
ScopeWorkersWrite,
|
||||
ScopeBuildRead,
|
||||
ScopeBuildWrite,
|
||||
ScopeAdmin,
|
||||
}
|
||||
|
||||
// ScopeDescriptions provides human-readable descriptions.
|
||||
var ScopeDescriptions = map[Scope]string{
|
||||
ScopeProjectsRead: "List and view project details",
|
||||
ScopeProjectsExecute: "Execute commands (claude, shell, git) on projects",
|
||||
ScopeKeysRead: "List API keys (metadata only, not secrets)",
|
||||
ScopeKeysWrite: "Create and revoke API keys",
|
||||
ScopeAuditRead: "View audit logs for command executions",
|
||||
ScopeQueueRead: "View queued commands and queue status",
|
||||
ScopeQueueWrite: "Enqueue and cancel queued commands",
|
||||
ScopeWebhookRead: "View webhooks and delivery history",
|
||||
ScopeWebhookWrite: "Create, update, and delete webhooks",
|
||||
ScopeWorkersRead: "View workers and worker status",
|
||||
ScopeWorkersWrite: "Manage workers (drain, register)",
|
||||
ScopeBuildRead: "View build status and history",
|
||||
ScopeBuildWrite: "Start and manage builds",
|
||||
ScopeAdmin: "Full administrative access (includes all scopes)",
|
||||
}
|
||||
|
||||
// IsValid checks if a scope is valid.
|
||||
func (s Scope) IsValid() bool {
|
||||
for _, scope := range AllScopes {
|
||||
if scope == s {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// String returns the scope as a string.
|
||||
func (s Scope) String() string {
|
||||
return string(s)
|
||||
}
|
||||
|
||||
// ScopesFromStrings converts string slice to Scope slice.
|
||||
func ScopesFromStrings(ss []string) []Scope {
|
||||
scopes := make([]Scope, len(ss))
|
||||
for i, s := range ss {
|
||||
scopes[i] = Scope(s)
|
||||
}
|
||||
return scopes
|
||||
}
|
||||
|
||||
// ScopesToStrings converts Scope slice to string slice.
|
||||
func ScopesToStrings(scopes []Scope) []string {
|
||||
ss := make([]string, len(scopes))
|
||||
for i, s := range scopes {
|
||||
ss[i] = string(s)
|
||||
}
|
||||
return ss
|
||||
}
|
||||
|
||||
// ValidateScopes checks if all scopes are valid.
|
||||
func ValidateScopes(scopes []Scope) bool {
|
||||
for _, s := range scopes {
|
||||
if !s.IsValid() {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// HasScope checks if a scope list contains a required scope.
|
||||
// Admin scope grants access to everything.
|
||||
func HasScope(scopes []Scope, required Scope) bool {
|
||||
for _, s := range scopes {
|
||||
if s == ScopeAdmin || s == required {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// HasAnyScope checks if a scope list contains any of the required scopes.
|
||||
func HasAnyScope(scopes []Scope, required ...Scope) bool {
|
||||
for _, r := range required {
|
||||
if HasScope(scopes, r) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// APIKey represents an API key for authentication.
|
||||
type APIKey struct {
|
||||
ID APIKeyID
|
||||
|
||||
202
internal/domain/build.go
Normal file
202
internal/domain/build.go
Normal file
@ -0,0 +1,202 @@
|
||||
// Package domain contains pure domain models with no external dependencies.
|
||||
package domain
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"maps"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// BuildSpec defines what a build task should accomplish.
|
||||
// It captures the user's intent and parameters for code generation.
|
||||
type BuildSpec struct {
|
||||
// Template is the project template to use (e.g., "nextjs-landing").
|
||||
Template string `json:"template,omitempty"`
|
||||
|
||||
// Prompt is the user's instruction for the build.
|
||||
Prompt string `json:"prompt"`
|
||||
|
||||
// Variables contains template-specific substitution values.
|
||||
Variables map[string]string `json:"variables,omitempty"`
|
||||
|
||||
// AutoCommit controls whether changes are automatically committed.
|
||||
AutoCommit bool `json:"auto_commit"`
|
||||
|
||||
// AutoPush controls whether commits are automatically pushed.
|
||||
AutoPush bool `json:"auto_push"`
|
||||
|
||||
// CallbackURL is the webhook URL for completion notification.
|
||||
CallbackURL string `json:"callback_url,omitempty"`
|
||||
}
|
||||
|
||||
// Validate checks that the BuildSpec has required fields.
|
||||
func (s *BuildSpec) Validate() error {
|
||||
if s.Prompt == "" {
|
||||
return ErrPromptRequired
|
||||
}
|
||||
if s.CallbackURL != "" {
|
||||
if err := ValidateCallbackURL(s.CallbackURL); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateCallbackURL checks that a callback URL is safe to use.
|
||||
// It rejects non-HTTPS URLs and private/internal network addresses.
|
||||
func ValidateCallbackURL(rawURL string) error {
|
||||
u, err := url.Parse(rawURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid callback URL: %w", err)
|
||||
}
|
||||
if u.Scheme != "https" {
|
||||
return fmt.Errorf("callback URL must use HTTPS scheme, got %q", u.Scheme)
|
||||
}
|
||||
if u.Host == "" {
|
||||
return fmt.Errorf("callback URL must have a host")
|
||||
}
|
||||
|
||||
host := u.Hostname()
|
||||
lower := strings.ToLower(host)
|
||||
|
||||
// Block localhost and loopback
|
||||
if lower == "localhost" || lower == "127.0.0.1" || lower == "::1" || lower == "[::1]" {
|
||||
return fmt.Errorf("callback URL must not point to localhost")
|
||||
}
|
||||
// Block common metadata endpoints
|
||||
if lower == "metadata.google.internal" || lower == "169.254.169.254" {
|
||||
return fmt.Errorf("callback URL must not point to cloud metadata service")
|
||||
}
|
||||
// Block private network ranges
|
||||
if strings.HasPrefix(lower, "10.") || strings.HasPrefix(lower, "192.168.") {
|
||||
return fmt.Errorf("callback URL must not point to private network addresses")
|
||||
}
|
||||
if strings.HasPrefix(lower, "172.") {
|
||||
// 172.16.0.0 - 172.31.255.255
|
||||
parts := strings.SplitN(lower, ".", 3)
|
||||
if len(parts) >= 2 {
|
||||
if octet, err := strconv.Atoi(parts[1]); err == nil && octet >= 16 && octet <= 31 {
|
||||
return fmt.Errorf("callback URL must not point to private network addresses")
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// BuildResult captures the outcome of a build execution.
|
||||
type BuildResult struct {
|
||||
// Success indicates whether the build completed successfully.
|
||||
Success bool `json:"success"`
|
||||
|
||||
// Output is the agent's text output during the build.
|
||||
Output string `json:"output,omitempty"`
|
||||
|
||||
// Error contains the error message if the build failed.
|
||||
Error string `json:"error,omitempty"`
|
||||
|
||||
// CommitSHA is the git commit hash if auto-commit was enabled.
|
||||
CommitSHA string `json:"commit_sha,omitempty"`
|
||||
|
||||
// FilesChanged lists files modified during the build.
|
||||
FilesChanged []string `json:"files_changed,omitempty"`
|
||||
|
||||
// DurationMs is the total execution time in milliseconds.
|
||||
DurationMs int64 `json:"duration_ms"`
|
||||
|
||||
// Artifacts contains named outputs from the build (e.g., deploy URLs).
|
||||
Artifacts map[string]string `json:"artifacts,omitempty"`
|
||||
}
|
||||
|
||||
// ToWorkResult converts a BuildResult to a WorkResult.
|
||||
// Build-specific fields (commit_sha, files_changed, duration_ms) are
|
||||
// promoted into the artifacts map, overwriting any existing keys with
|
||||
// the same names.
|
||||
func (r *BuildResult) ToWorkResult() *WorkResult {
|
||||
if r == nil {
|
||||
return &WorkResult{}
|
||||
}
|
||||
|
||||
artifacts := make(map[string]string)
|
||||
maps.Copy(artifacts, r.Artifacts)
|
||||
|
||||
// Promote build-specific fields into artifacts
|
||||
if r.CommitSHA != "" {
|
||||
artifacts["commit_sha"] = r.CommitSHA
|
||||
}
|
||||
if r.DurationMs > 0 {
|
||||
artifacts["duration_ms"] = strconv.FormatInt(r.DurationMs, 10)
|
||||
}
|
||||
if len(r.FilesChanged) > 0 {
|
||||
artifacts["files_changed_count"] = strconv.Itoa(len(r.FilesChanged))
|
||||
}
|
||||
|
||||
output := r.Output
|
||||
if !r.Success && r.Error != "" {
|
||||
output = r.Error
|
||||
}
|
||||
|
||||
return &WorkResult{
|
||||
Output: output,
|
||||
Artifacts: artifacts,
|
||||
}
|
||||
}
|
||||
|
||||
// BuildStatus represents the lifecycle state of a build.
|
||||
type BuildStatus string
|
||||
|
||||
const (
|
||||
BuildStatusPending BuildStatus = "pending"
|
||||
BuildStatusRunning BuildStatus = "running"
|
||||
BuildStatusCompleted BuildStatus = "completed"
|
||||
BuildStatusFailed BuildStatus = "failed"
|
||||
BuildStatusCancelled BuildStatus = "cancelled"
|
||||
)
|
||||
|
||||
// IsValid returns true if the status is a known valid status.
|
||||
func (s BuildStatus) IsValid() bool {
|
||||
switch s {
|
||||
case BuildStatusPending, BuildStatusRunning, BuildStatusCompleted,
|
||||
BuildStatusFailed, BuildStatusCancelled:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// IsTerminal returns true if the build is in a final state.
|
||||
func (s BuildStatus) IsTerminal() bool {
|
||||
switch s {
|
||||
case BuildStatusCompleted, BuildStatusFailed, BuildStatusCancelled:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// BuildAuditEntry represents a single build's audit record.
|
||||
type BuildAuditEntry struct {
|
||||
// TaskID is the work queue task ID.
|
||||
TaskID string `json:"task_id"`
|
||||
|
||||
// ProjectID is the project this build belongs to.
|
||||
ProjectID string `json:"project_id"`
|
||||
|
||||
// WorkerID is the worker that executed the build.
|
||||
WorkerID string `json:"worker_id,omitempty"`
|
||||
|
||||
// Spec is the original build specification.
|
||||
Spec BuildSpec `json:"spec"`
|
||||
|
||||
// Result is the build outcome (nil if not yet complete).
|
||||
Result *BuildResult `json:"result,omitempty"`
|
||||
|
||||
// Status is the current build status.
|
||||
Status BuildStatus `json:"status"`
|
||||
|
||||
// StartedAt is when the build was created/enqueued.
|
||||
StartedAt time.Time `json:"started_at"`
|
||||
|
||||
// CompletedAt is when the build finished (nil if still running).
|
||||
CompletedAt *time.Time `json:"completed_at,omitempty"`
|
||||
}
|
||||
207
internal/domain/build_test.go
Normal file
207
internal/domain/build_test.go
Normal file
@ -0,0 +1,207 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestBuildSpec_Validate(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
spec BuildSpec
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "valid spec with prompt",
|
||||
spec: BuildSpec{Prompt: "Build a landing page"},
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "valid spec with all fields",
|
||||
spec: BuildSpec{Prompt: "Build it", Template: "nextjs", AutoCommit: true, AutoPush: true},
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "empty prompt",
|
||||
spec: BuildSpec{},
|
||||
wantErr: ErrPromptRequired,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := tt.spec.Validate()
|
||||
if tt.wantErr != nil {
|
||||
if !errors.Is(err, tt.wantErr) {
|
||||
t.Errorf("Validate() error = %v, want %v", err, tt.wantErr)
|
||||
}
|
||||
} else if err != nil {
|
||||
t.Errorf("Validate() unexpected error = %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildStatus_IsValid(t *testing.T) {
|
||||
tests := []struct {
|
||||
status BuildStatus
|
||||
want bool
|
||||
}{
|
||||
{BuildStatusPending, true},
|
||||
{BuildStatusRunning, true},
|
||||
{BuildStatusCompleted, true},
|
||||
{BuildStatusFailed, true},
|
||||
{BuildStatusCancelled, true},
|
||||
{"unknown", false},
|
||||
{"", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(string(tt.status), func(t *testing.T) {
|
||||
if got := tt.status.IsValid(); got != tt.want {
|
||||
t.Errorf("IsValid() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildResult_ToWorkResult(t *testing.T) {
|
||||
t.Run("success with all fields", func(t *testing.T) {
|
||||
result := &BuildResult{
|
||||
Success: true,
|
||||
Output: "Build completed",
|
||||
CommitSHA: "abc123",
|
||||
FilesChanged: []string{"main.go", "go.mod"},
|
||||
DurationMs: 1500,
|
||||
Artifacts: map[string]string{"deploy_url": "https://example.com"},
|
||||
}
|
||||
|
||||
wr := result.ToWorkResult()
|
||||
|
||||
if wr.Output != "Build completed" {
|
||||
t.Errorf("Output = %q, want %q", wr.Output, "Build completed")
|
||||
}
|
||||
if wr.Artifacts["commit_sha"] != "abc123" {
|
||||
t.Errorf("commit_sha = %q, want %q", wr.Artifacts["commit_sha"], "abc123")
|
||||
}
|
||||
if wr.Artifacts["duration_ms"] != "1500" {
|
||||
t.Errorf("duration_ms = %q, want %q", wr.Artifacts["duration_ms"], "1500")
|
||||
}
|
||||
if wr.Artifacts["files_changed_count"] != "2" {
|
||||
t.Errorf("files_changed_count = %q, want %q", wr.Artifacts["files_changed_count"], "2")
|
||||
}
|
||||
if wr.Artifacts["deploy_url"] != "https://example.com" {
|
||||
t.Errorf("deploy_url = %q, want %q", wr.Artifacts["deploy_url"], "https://example.com")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("failure uses error as output", func(t *testing.T) {
|
||||
result := &BuildResult{
|
||||
Success: false,
|
||||
Output: "partial output",
|
||||
Error: "build failed: missing deps",
|
||||
}
|
||||
|
||||
wr := result.ToWorkResult()
|
||||
|
||||
if wr.Output != "build failed: missing deps" {
|
||||
t.Errorf("Output = %q, want error message", wr.Output)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("nil artifacts map is safe", func(t *testing.T) {
|
||||
result := &BuildResult{
|
||||
Success: true,
|
||||
Output: "done",
|
||||
}
|
||||
|
||||
wr := result.ToWorkResult()
|
||||
|
||||
if wr.Output != "done" {
|
||||
t.Errorf("Output = %q, want %q", wr.Output, "done")
|
||||
}
|
||||
if len(wr.Artifacts) != 0 {
|
||||
t.Errorf("Artifacts = %v, want empty", wr.Artifacts)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("zero duration not included", func(t *testing.T) {
|
||||
result := &BuildResult{
|
||||
Success: true,
|
||||
DurationMs: 0,
|
||||
}
|
||||
|
||||
wr := result.ToWorkResult()
|
||||
|
||||
if _, ok := wr.Artifacts["duration_ms"]; ok {
|
||||
t.Error("duration_ms should not be in artifacts when zero")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("nil receiver returns empty result", func(t *testing.T) {
|
||||
var result *BuildResult
|
||||
wr := result.ToWorkResult()
|
||||
|
||||
if wr.Output != "" {
|
||||
t.Errorf("Output = %q, want empty", wr.Output)
|
||||
}
|
||||
if wr.Artifacts != nil {
|
||||
t.Errorf("Artifacts = %v, want nil", wr.Artifacts)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("promoted fields overwrite existing artifacts", func(t *testing.T) {
|
||||
result := &BuildResult{
|
||||
Success: true,
|
||||
CommitSHA: "new-sha",
|
||||
Artifacts: map[string]string{
|
||||
"commit_sha": "old-sha",
|
||||
"custom_key": "kept",
|
||||
},
|
||||
}
|
||||
|
||||
wr := result.ToWorkResult()
|
||||
|
||||
if wr.Artifacts["commit_sha"] != "new-sha" {
|
||||
t.Errorf("commit_sha = %q, want %q (promoted field should overwrite)", wr.Artifacts["commit_sha"], "new-sha")
|
||||
}
|
||||
if wr.Artifacts["custom_key"] != "kept" {
|
||||
t.Errorf("custom_key = %q, want %q", wr.Artifacts["custom_key"], "kept")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("failed with empty error keeps output", func(t *testing.T) {
|
||||
result := &BuildResult{
|
||||
Success: false,
|
||||
Output: "partial output before crash",
|
||||
Error: "",
|
||||
}
|
||||
|
||||
wr := result.ToWorkResult()
|
||||
|
||||
if wr.Output != "partial output before crash" {
|
||||
t.Errorf("Output = %q, want original output when error is empty", wr.Output)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestBuildStatus_IsTerminal(t *testing.T) {
|
||||
tests := []struct {
|
||||
status BuildStatus
|
||||
want bool
|
||||
}{
|
||||
{BuildStatusPending, false},
|
||||
{BuildStatusRunning, false},
|
||||
{BuildStatusCompleted, true},
|
||||
{BuildStatusFailed, true},
|
||||
{BuildStatusCancelled, true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(string(tt.status), func(t *testing.T) {
|
||||
if got := tt.status.IsTerminal(); got != tt.want {
|
||||
t.Errorf("IsTerminal() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -1,7 +1,6 @@
|
||||
package domain_test
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@ -202,7 +201,7 @@ func TestAPIKey_HasAnyScope(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "no match",
|
||||
scopes: []domain.Scope{domain.ScopeKeysManage},
|
||||
scopes: []domain.Scope{domain.ScopeKeysWrite},
|
||||
check: []domain.Scope{domain.ScopeProjectsRead, domain.ScopeProjectsExecute},
|
||||
want: false,
|
||||
},
|
||||
@ -403,257 +402,6 @@ func TestAPIKey_IsIPAllowed(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// CommandResult Tests
|
||||
// =============================================================================
|
||||
|
||||
func TestCommandResult_Success(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
exitCode int
|
||||
err error
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "success with zero exit code and no error",
|
||||
exitCode: 0,
|
||||
err: nil,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "failure with non-zero exit code",
|
||||
exitCode: 1,
|
||||
err: nil,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "failure with error",
|
||||
exitCode: 0,
|
||||
err: errors.New("execution failed"),
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "failure with both error and non-zero exit",
|
||||
exitCode: 127,
|
||||
err: errors.New("command not found"),
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := &domain.CommandResult{
|
||||
ExitCode: tt.exitCode,
|
||||
Error: tt.err,
|
||||
}
|
||||
if got := result.Success(); got != tt.want {
|
||||
t.Errorf("Success() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// ProjectStatus Tests
|
||||
// =============================================================================
|
||||
|
||||
func TestProjectStatus_IsAvailable(t *testing.T) {
|
||||
tests := []struct {
|
||||
status domain.ProjectStatus
|
||||
want bool
|
||||
}{
|
||||
{domain.ProjectStatusRunning, true},
|
||||
{domain.ProjectStatusPending, false},
|
||||
{domain.ProjectStatusFailed, false},
|
||||
{domain.ProjectStatusNotFound, false},
|
||||
{domain.ProjectStatusUnknown, false},
|
||||
{domain.ProjectStatusError, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(string(tt.status), func(t *testing.T) {
|
||||
if got := tt.status.IsAvailable(); got != tt.want {
|
||||
t.Errorf("ProjectStatus(%q).IsAvailable() = %v, want %v", tt.status, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProjectStatus_IsTerminal(t *testing.T) {
|
||||
tests := []struct {
|
||||
status domain.ProjectStatus
|
||||
want bool
|
||||
}{
|
||||
{domain.ProjectStatusRunning, false},
|
||||
{domain.ProjectStatusPending, false},
|
||||
{domain.ProjectStatusFailed, true},
|
||||
{domain.ProjectStatusNotFound, true},
|
||||
{domain.ProjectStatusUnknown, false},
|
||||
{domain.ProjectStatusError, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(string(tt.status), func(t *testing.T) {
|
||||
if got := tt.status.IsTerminal(); got != tt.want {
|
||||
t.Errorf("ProjectStatus(%q).IsTerminal() = %v, want %v", tt.status, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Error Variables Tests
|
||||
// =============================================================================
|
||||
|
||||
func TestErrorVariables_AreDistinct(t *testing.T) {
|
||||
// Verify all domain errors are distinct and can be matched with errors.Is
|
||||
allErrors := []error{
|
||||
domain.ErrProjectNotFound,
|
||||
domain.ErrProjectNotRunning,
|
||||
domain.ErrCommandNotFound,
|
||||
domain.ErrCommandTimeout,
|
||||
domain.ErrCommandCancelled,
|
||||
domain.ErrLimitExceeded,
|
||||
domain.ErrInvalidCommand,
|
||||
domain.ErrCommandSanitization,
|
||||
domain.ErrKeyNotFound,
|
||||
domain.ErrKeyRevoked,
|
||||
domain.ErrKeyExpired,
|
||||
domain.ErrKeyInvalid,
|
||||
domain.ErrUnauthorized,
|
||||
domain.ErrForbidden,
|
||||
domain.ErrInsufficientScope,
|
||||
domain.ErrRateLimited,
|
||||
domain.ErrDatabaseConnection,
|
||||
domain.ErrKubernetesError,
|
||||
}
|
||||
|
||||
// Each error should only match itself
|
||||
for i, err1 := range allErrors {
|
||||
for j, err2 := range allErrors {
|
||||
if i == j {
|
||||
if !errors.Is(err1, err2) {
|
||||
t.Errorf("error %v should match itself", err1)
|
||||
}
|
||||
} else {
|
||||
if errors.Is(err1, err2) {
|
||||
t.Errorf("error %v should not match %v", err1, err2)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrorVariables_CanBeWrapped(t *testing.T) {
|
||||
// Domain errors should be usable as base errors for wrapping
|
||||
wrapped := errors.Join(domain.ErrProjectNotFound, errors.New("additional context"))
|
||||
|
||||
if !errors.Is(wrapped, domain.ErrProjectNotFound) {
|
||||
t.Error("wrapped error should match base domain error")
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Type Constants Tests
|
||||
// =============================================================================
|
||||
|
||||
func TestScopeConstants(t *testing.T) {
|
||||
// Verify scope constants have expected values (for documentation)
|
||||
expectedScopes := map[domain.Scope]string{
|
||||
domain.ScopeAdmin: "admin",
|
||||
domain.ScopeProjectsRead: "projects:read",
|
||||
domain.ScopeProjectsExecute: "projects:execute",
|
||||
domain.ScopeKeysManage: "keys:manage",
|
||||
}
|
||||
|
||||
for scope, expected := range expectedScopes {
|
||||
if string(scope) != expected {
|
||||
t.Errorf("Scope %v = %q, want %q", scope, string(scope), expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCommandTypeConstants(t *testing.T) {
|
||||
expectedTypes := map[domain.CommandType]string{
|
||||
domain.CommandTypeClaude: "claude",
|
||||
domain.CommandTypeShell: "shell",
|
||||
domain.CommandTypeGit: "git",
|
||||
}
|
||||
|
||||
for cmdType, expected := range expectedTypes {
|
||||
if string(cmdType) != expected {
|
||||
t.Errorf("CommandType %v = %q, want %q", cmdType, string(cmdType), expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestProjectStatusConstants(t *testing.T) {
|
||||
expectedStatuses := map[domain.ProjectStatus]string{
|
||||
domain.ProjectStatusRunning: "running",
|
||||
domain.ProjectStatusPending: "pending",
|
||||
domain.ProjectStatusFailed: "failed",
|
||||
domain.ProjectStatusNotFound: "not_found",
|
||||
domain.ProjectStatusUnknown: "unknown",
|
||||
domain.ProjectStatusError: "error",
|
||||
}
|
||||
|
||||
for status, expected := range expectedStatuses {
|
||||
if string(status) != expected {
|
||||
t.Errorf("ProjectStatus %v = %q, want %q", status, string(status), expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Type Instantiation Tests
|
||||
// =============================================================================
|
||||
|
||||
func TestProject_CanBeInstantiated(t *testing.T) {
|
||||
proj := domain.Project{
|
||||
ID: "test-project",
|
||||
Name: "Test Project",
|
||||
Description: "A test project",
|
||||
PodName: "test-pod",
|
||||
Status: domain.ProjectStatusRunning,
|
||||
Workspace: "/workspace",
|
||||
}
|
||||
|
||||
if proj.ID != "test-project" {
|
||||
t.Errorf("Project.ID = %q, want %q", proj.ID, "test-project")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCommand_CanBeInstantiated(t *testing.T) {
|
||||
now := time.Now()
|
||||
cmd := domain.Command{
|
||||
ID: "cmd-1",
|
||||
ProjectID: "proj-1",
|
||||
Type: domain.CommandTypeClaude,
|
||||
Args: []string{"--version"},
|
||||
StartedAt: now,
|
||||
}
|
||||
|
||||
if cmd.ID != "cmd-1" {
|
||||
t.Errorf("Command.ID = %q, want %q", cmd.ID, "cmd-1")
|
||||
}
|
||||
if len(cmd.Args) != 1 {
|
||||
t.Errorf("Command.Args length = %d, want 1", len(cmd.Args))
|
||||
}
|
||||
}
|
||||
|
||||
func TestOutputLine_CanBeInstantiated(t *testing.T) {
|
||||
now := time.Now()
|
||||
line := domain.OutputLine{
|
||||
Stream: "stdout",
|
||||
Line: "Hello, world!",
|
||||
Timestamp: now,
|
||||
}
|
||||
|
||||
if line.Stream != "stdout" {
|
||||
t.Errorf("OutputLine.Stream = %q, want %q", line.Stream, "stdout")
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Helpers
|
||||
// =============================================================================
|
||||
|
||||
266
internal/domain/domain_types_test.go
Normal file
266
internal/domain/domain_types_test.go
Normal file
@ -0,0 +1,266 @@
|
||||
package domain_test
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/orchard9/rdev/internal/domain"
|
||||
)
|
||||
|
||||
// =============================================================================
|
||||
// CommandResult Tests
|
||||
// =============================================================================
|
||||
|
||||
func TestCommandResult_Success(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
exitCode int
|
||||
err error
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "success with zero exit code and no error",
|
||||
exitCode: 0,
|
||||
err: nil,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "failure with non-zero exit code",
|
||||
exitCode: 1,
|
||||
err: nil,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "failure with error",
|
||||
exitCode: 0,
|
||||
err: errors.New("execution failed"),
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "failure with both error and non-zero exit",
|
||||
exitCode: 127,
|
||||
err: errors.New("command not found"),
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := &domain.CommandResult{
|
||||
ExitCode: tt.exitCode,
|
||||
Error: tt.err,
|
||||
}
|
||||
if got := result.Success(); got != tt.want {
|
||||
t.Errorf("Success() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// ProjectStatus Tests
|
||||
// =============================================================================
|
||||
|
||||
func TestProjectStatus_IsAvailable(t *testing.T) {
|
||||
tests := []struct {
|
||||
status domain.ProjectStatus
|
||||
want bool
|
||||
}{
|
||||
{domain.ProjectStatusRunning, true},
|
||||
{domain.ProjectStatusPending, false},
|
||||
{domain.ProjectStatusFailed, false},
|
||||
{domain.ProjectStatusNotFound, false},
|
||||
{domain.ProjectStatusUnknown, false},
|
||||
{domain.ProjectStatusError, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(string(tt.status), func(t *testing.T) {
|
||||
if got := tt.status.IsAvailable(); got != tt.want {
|
||||
t.Errorf("ProjectStatus(%q).IsAvailable() = %v, want %v", tt.status, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProjectStatus_IsTerminal(t *testing.T) {
|
||||
tests := []struct {
|
||||
status domain.ProjectStatus
|
||||
want bool
|
||||
}{
|
||||
{domain.ProjectStatusRunning, false},
|
||||
{domain.ProjectStatusPending, false},
|
||||
{domain.ProjectStatusFailed, true},
|
||||
{domain.ProjectStatusNotFound, true},
|
||||
{domain.ProjectStatusUnknown, false},
|
||||
{domain.ProjectStatusError, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(string(tt.status), func(t *testing.T) {
|
||||
if got := tt.status.IsTerminal(); got != tt.want {
|
||||
t.Errorf("ProjectStatus(%q).IsTerminal() = %v, want %v", tt.status, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Error Variables Tests
|
||||
// =============================================================================
|
||||
|
||||
func TestErrorVariables_AreDistinct(t *testing.T) {
|
||||
// Verify all domain errors are distinct and can be matched with errors.Is
|
||||
allErrors := []error{
|
||||
domain.ErrProjectNotFound,
|
||||
domain.ErrProjectNotRunning,
|
||||
domain.ErrCommandNotFound,
|
||||
domain.ErrCommandTimeout,
|
||||
domain.ErrCommandCancelled,
|
||||
domain.ErrLimitExceeded,
|
||||
domain.ErrInvalidCommand,
|
||||
domain.ErrCommandSanitization,
|
||||
domain.ErrKeyNotFound,
|
||||
domain.ErrKeyRevoked,
|
||||
domain.ErrKeyExpired,
|
||||
domain.ErrKeyInvalid,
|
||||
domain.ErrUnauthorized,
|
||||
domain.ErrForbidden,
|
||||
domain.ErrInsufficientScope,
|
||||
domain.ErrRateLimited,
|
||||
domain.ErrDatabaseConnection,
|
||||
domain.ErrKubernetesError,
|
||||
}
|
||||
|
||||
// Each error should only match itself
|
||||
for i, err1 := range allErrors {
|
||||
for j, err2 := range allErrors {
|
||||
if i == j {
|
||||
if !errors.Is(err1, err2) {
|
||||
t.Errorf("error %v should match itself", err1)
|
||||
}
|
||||
} else {
|
||||
if errors.Is(err1, err2) {
|
||||
t.Errorf("error %v should not match %v", err1, err2)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrorVariables_CanBeWrapped(t *testing.T) {
|
||||
// Domain errors should be usable as base errors for wrapping
|
||||
wrapped := errors.Join(domain.ErrProjectNotFound, errors.New("additional context"))
|
||||
|
||||
if !errors.Is(wrapped, domain.ErrProjectNotFound) {
|
||||
t.Error("wrapped error should match base domain error")
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Type Constants Tests
|
||||
// =============================================================================
|
||||
|
||||
func TestScopeConstants(t *testing.T) {
|
||||
// Verify scope constants have expected values (for documentation)
|
||||
expectedScopes := map[domain.Scope]string{
|
||||
domain.ScopeAdmin: "admin",
|
||||
domain.ScopeProjectsRead: "projects:read",
|
||||
domain.ScopeProjectsExecute: "projects:execute",
|
||||
domain.ScopeKeysRead: "keys:read",
|
||||
domain.ScopeKeysWrite: "keys:write",
|
||||
domain.ScopeAuditRead: "audit:read",
|
||||
domain.ScopeQueueRead: "queue:read",
|
||||
domain.ScopeQueueWrite: "queue:write",
|
||||
domain.ScopeWebhookRead: "webhook:read",
|
||||
domain.ScopeWebhookWrite: "webhook:write",
|
||||
}
|
||||
|
||||
for scope, expected := range expectedScopes {
|
||||
if string(scope) != expected {
|
||||
t.Errorf("Scope %v = %q, want %q", scope, string(scope), expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCommandTypeConstants(t *testing.T) {
|
||||
expectedTypes := map[domain.CommandType]string{
|
||||
domain.CommandTypeClaude: "claude",
|
||||
domain.CommandTypeShell: "shell",
|
||||
domain.CommandTypeGit: "git",
|
||||
}
|
||||
|
||||
for cmdType, expected := range expectedTypes {
|
||||
if string(cmdType) != expected {
|
||||
t.Errorf("CommandType %v = %q, want %q", cmdType, string(cmdType), expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestProjectStatusConstants(t *testing.T) {
|
||||
expectedStatuses := map[domain.ProjectStatus]string{
|
||||
domain.ProjectStatusRunning: "running",
|
||||
domain.ProjectStatusPending: "pending",
|
||||
domain.ProjectStatusFailed: "failed",
|
||||
domain.ProjectStatusNotFound: "not_found",
|
||||
domain.ProjectStatusUnknown: "unknown",
|
||||
domain.ProjectStatusError: "error",
|
||||
}
|
||||
|
||||
for status, expected := range expectedStatuses {
|
||||
if string(status) != expected {
|
||||
t.Errorf("ProjectStatus %v = %q, want %q", status, string(status), expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Type Instantiation Tests
|
||||
// =============================================================================
|
||||
|
||||
func TestProject_CanBeInstantiated(t *testing.T) {
|
||||
proj := domain.Project{
|
||||
ID: "test-project",
|
||||
Name: "Test Project",
|
||||
Description: "A test project",
|
||||
PodName: "test-pod",
|
||||
Status: domain.ProjectStatusRunning,
|
||||
Workspace: "/workspace",
|
||||
}
|
||||
|
||||
if proj.ID != "test-project" {
|
||||
t.Errorf("Project.ID = %q, want %q", proj.ID, "test-project")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCommand_CanBeInstantiated(t *testing.T) {
|
||||
now := time.Now()
|
||||
cmd := domain.Command{
|
||||
ID: "cmd-1",
|
||||
ProjectID: "proj-1",
|
||||
Type: domain.CommandTypeClaude,
|
||||
Args: []string{"--version"},
|
||||
StartedAt: now,
|
||||
}
|
||||
|
||||
if cmd.ID != "cmd-1" {
|
||||
t.Errorf("Command.ID = %q, want %q", cmd.ID, "cmd-1")
|
||||
}
|
||||
if len(cmd.Args) != 1 {
|
||||
t.Errorf("Command.Args length = %d, want 1", len(cmd.Args))
|
||||
}
|
||||
}
|
||||
|
||||
func TestOutputLine_CanBeInstantiated(t *testing.T) {
|
||||
now := time.Now()
|
||||
line := domain.OutputLine{
|
||||
Stream: "stdout",
|
||||
Line: "Hello, world!",
|
||||
Timestamp: now,
|
||||
}
|
||||
|
||||
if line.Stream != "stdout" {
|
||||
t.Errorf("OutputLine.Stream = %q, want %q", line.Stream, "stdout")
|
||||
}
|
||||
}
|
||||
@ -5,6 +5,9 @@ import "errors"
|
||||
// Domain errors - these are business-level errors that should be translated
|
||||
// to appropriate HTTP status codes or gRPC error codes by the presentation layer.
|
||||
var (
|
||||
// Generic errors
|
||||
ErrNotFound = errors.New("not found")
|
||||
|
||||
// Project errors
|
||||
ErrProjectNotFound = errors.New("project not found")
|
||||
ErrProjectNotRunning = errors.New("project is not running")
|
||||
@ -26,9 +29,20 @@ var (
|
||||
ErrPromptRequired = errors.New("prompt is required")
|
||||
ErrInvalidTimeout = errors.New("timeout cannot be negative")
|
||||
|
||||
// Credential errors
|
||||
ErrCredentialNotFound = errors.New("credential not found")
|
||||
|
||||
// Work queue errors
|
||||
ErrWorkTaskNotFound = errors.New("work task not found")
|
||||
|
||||
// Worker pool errors
|
||||
ErrWorkerNotFound = errors.New("worker not found")
|
||||
ErrWorkerIDRequired = errors.New("worker ID is required")
|
||||
ErrWorkerHostnameRequired = errors.New("worker hostname is required")
|
||||
|
||||
// Build errors
|
||||
ErrBuildNotFound = errors.New("build not found")
|
||||
|
||||
// API Key errors
|
||||
ErrKeyNotFound = errors.New("api key not found")
|
||||
ErrKeyRevoked = errors.New("api key has been revoked")
|
||||
@ -39,10 +53,15 @@ var (
|
||||
ErrUnauthorized = errors.New("unauthorized")
|
||||
ErrForbidden = errors.New("forbidden")
|
||||
ErrInsufficientScope = errors.New("insufficient scope")
|
||||
ErrIPNotAllowed = errors.New("ip address not allowed")
|
||||
|
||||
// Rate limiting errors
|
||||
ErrRateLimited = errors.New("rate limit exceeded")
|
||||
|
||||
// Webhook errors
|
||||
ErrWebhookNotFound = errors.New("webhook not found")
|
||||
ErrInvalidWebhook = errors.New("invalid webhook configuration")
|
||||
|
||||
// Audit errors
|
||||
ErrAuditNotFound = errors.New("audit log entry not found")
|
||||
|
||||
|
||||
@ -2,9 +2,66 @@
|
||||
// These types represent the core business concepts of the application.
|
||||
package domain
|
||||
|
||||
import "regexp"
|
||||
|
||||
// ProjectID is a strongly-typed identifier for projects.
|
||||
type ProjectID string
|
||||
|
||||
// Project name/ID constraints.
|
||||
const (
|
||||
// MaxProjectNameLen is the maximum length for project names (K8s name limit).
|
||||
MaxProjectNameLen = 63
|
||||
)
|
||||
|
||||
// projectIDRegex validates project IDs used for referencing existing projects.
|
||||
// Allows uppercase, lowercase, digits, dashes, and underscores. Must start with a letter.
|
||||
var projectIDRegex = regexp.MustCompile(`^[a-zA-Z][a-zA-Z0-9_-]*$`)
|
||||
|
||||
// projectNameRegex validates project names for DNS/K8s resource creation.
|
||||
// Lowercase only, digits, dashes. Must start with a lowercase letter.
|
||||
var projectNameRegex = regexp.MustCompile(`^[a-z][a-z0-9-]*$`)
|
||||
|
||||
// reservedProjectNames are names that cannot be used for new projects.
|
||||
var reservedProjectNames = map[string]bool{
|
||||
"www": true, "api": true, "git": true, "ci": true,
|
||||
"registry": true, "admin": true, "root": true,
|
||||
"rdev": true, "pantheon": true,
|
||||
}
|
||||
|
||||
// ValidateProjectID validates a project ID for referencing existing projects.
|
||||
// Allows letters, digits, dashes, underscores. Must start with a letter. Max 63 chars.
|
||||
func ValidateProjectID(id string) error {
|
||||
if id == "" {
|
||||
return ErrInvalidProjectName
|
||||
}
|
||||
if len(id) > MaxProjectNameLen {
|
||||
return ErrInvalidProjectName
|
||||
}
|
||||
if !projectIDRegex.MatchString(id) {
|
||||
return ErrInvalidProjectName
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateProjectName validates a project name for DNS/K8s resource creation.
|
||||
// Lowercase letters, digits, dashes only. Must start with a lowercase letter.
|
||||
// Max 63 chars. Rejects reserved names.
|
||||
func ValidateProjectName(name string) error {
|
||||
if name == "" {
|
||||
return ErrInvalidProjectName
|
||||
}
|
||||
if len(name) > MaxProjectNameLen {
|
||||
return ErrInvalidProjectName
|
||||
}
|
||||
if !projectNameRegex.MatchString(name) {
|
||||
return ErrInvalidProjectName
|
||||
}
|
||||
if reservedProjectNames[name] {
|
||||
return ErrInvalidProjectName
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Project represents a claudebox project that can execute commands.
|
||||
type Project struct {
|
||||
ID ProjectID
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
)
|
||||
|
||||
@ -153,8 +152,4 @@ func DefaultWebhookDeliveryFilters() *WebhookDeliveryFilters {
|
||||
}
|
||||
}
|
||||
|
||||
// Webhook-related errors.
|
||||
var (
|
||||
ErrWebhookNotFound = errors.New("webhook not found")
|
||||
ErrInvalidWebhook = errors.New("invalid webhook configuration")
|
||||
)
|
||||
// Webhook-related errors are defined in errors.go for centralized error definitions.
|
||||
|
||||
170
internal/domain/work.go
Normal file
170
internal/domain/work.go
Normal file
@ -0,0 +1,170 @@
|
||||
package domain
|
||||
|
||||
import "time"
|
||||
|
||||
// WorkTaskStatus represents the status of a work task.
|
||||
type WorkTaskStatus string
|
||||
|
||||
const (
|
||||
WorkTaskStatusPending WorkTaskStatus = "pending"
|
||||
WorkTaskStatusRunning WorkTaskStatus = "running"
|
||||
WorkTaskStatusCompleted WorkTaskStatus = "completed"
|
||||
WorkTaskStatusFailed WorkTaskStatus = "failed"
|
||||
WorkTaskStatusCancelled WorkTaskStatus = "cancelled"
|
||||
)
|
||||
|
||||
// IsValid returns true if the status is a known valid status.
|
||||
func (s WorkTaskStatus) IsValid() bool {
|
||||
switch s {
|
||||
case WorkTaskStatusPending, WorkTaskStatusRunning, WorkTaskStatusCompleted,
|
||||
WorkTaskStatusFailed, WorkTaskStatusCancelled:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// WorkTaskType represents the type of work task.
|
||||
type WorkTaskType string
|
||||
|
||||
const (
|
||||
WorkTaskTypeBuild WorkTaskType = "build"
|
||||
WorkTaskTypeTest WorkTaskType = "test"
|
||||
WorkTaskTypeDeploy WorkTaskType = "deploy"
|
||||
WorkTaskTypeCustom WorkTaskType = "custom"
|
||||
)
|
||||
|
||||
// IsValid returns true if the task type is a known valid type.
|
||||
func (t WorkTaskType) IsValid() bool {
|
||||
switch t {
|
||||
case WorkTaskTypeBuild, WorkTaskTypeTest, WorkTaskTypeDeploy, WorkTaskTypeCustom:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// WorkTask represents a task in the work queue.
|
||||
type WorkTask struct {
|
||||
// ID is the unique task identifier.
|
||||
ID string
|
||||
|
||||
// ProjectID is the project this task belongs to.
|
||||
ProjectID string
|
||||
|
||||
// Type is the task type (build, test, deploy, custom).
|
||||
Type WorkTaskType
|
||||
|
||||
// Spec contains task-specific parameters.
|
||||
// For build tasks: template, prompt, variables, auto_deploy, git_url
|
||||
// For test tasks: test_command, git_url
|
||||
// For deploy tasks: image, replicas, env
|
||||
Spec map[string]any
|
||||
|
||||
// Status is the current task status.
|
||||
Status WorkTaskStatus
|
||||
|
||||
// Priority determines execution order (higher = more urgent).
|
||||
Priority int
|
||||
|
||||
// WorkerID is the ID of the worker that claimed this task.
|
||||
WorkerID string
|
||||
|
||||
// CallbackURL is the webhook URL for completion notification.
|
||||
CallbackURL string
|
||||
|
||||
// CreatedAt is when the task was created.
|
||||
CreatedAt time.Time
|
||||
|
||||
// StartedAt is when a worker started executing the task.
|
||||
StartedAt *time.Time
|
||||
|
||||
// CompletedAt is when the task finished (success or failure).
|
||||
CompletedAt *time.Time
|
||||
|
||||
// Result contains the task output (if completed).
|
||||
Result *WorkResult
|
||||
|
||||
// Error contains the error message (if failed).
|
||||
Error string
|
||||
|
||||
// RetryCount is the number of retry attempts.
|
||||
RetryCount int
|
||||
|
||||
// MaxRetries is the maximum allowed retry attempts.
|
||||
MaxRetries int
|
||||
}
|
||||
|
||||
// WorkResult contains the result of a completed task.
|
||||
type WorkResult struct {
|
||||
// Output is the main output from task execution.
|
||||
Output string `json:"output,omitempty"`
|
||||
|
||||
// Artifacts contains named artifacts from the task.
|
||||
// For build tasks: commit_sha, deploy_url, etc.
|
||||
Artifacts map[string]string `json:"artifacts,omitempty"`
|
||||
}
|
||||
|
||||
// WorkQueueStats contains queue statistics.
|
||||
type WorkQueueStats struct {
|
||||
// Pending is the count of pending tasks.
|
||||
Pending int64 `json:"pending"`
|
||||
|
||||
// Running is the count of running tasks.
|
||||
Running int64 `json:"running"`
|
||||
|
||||
// Completed is the count of completed tasks (last 24h).
|
||||
Completed int64 `json:"completed"`
|
||||
|
||||
// Failed is the count of failed tasks (last 24h).
|
||||
Failed int64 `json:"failed"`
|
||||
|
||||
// Cancelled is the count of cancelled tasks (last 24h).
|
||||
Cancelled int64 `json:"cancelled"`
|
||||
|
||||
// OldestPending is the age of the oldest pending task.
|
||||
OldestPending *time.Duration `json:"oldest_pending,omitempty"`
|
||||
}
|
||||
|
||||
// WorkListOptions contains pagination options for listing tasks.
|
||||
type WorkListOptions struct {
|
||||
// Limit is the maximum number of tasks to return (default: 50, max: 100).
|
||||
Limit int
|
||||
|
||||
// Offset is the number of tasks to skip (for pagination).
|
||||
Offset int
|
||||
}
|
||||
|
||||
// DefaultWorkListOptions returns options with default values.
|
||||
func DefaultWorkListOptions() WorkListOptions {
|
||||
return WorkListOptions{
|
||||
Limit: 50,
|
||||
Offset: 0,
|
||||
}
|
||||
}
|
||||
|
||||
// Normalize applies defaults and limits to the options.
|
||||
func (o *WorkListOptions) Normalize() {
|
||||
if o.Limit <= 0 {
|
||||
o.Limit = 50
|
||||
}
|
||||
if o.Limit > 100 {
|
||||
o.Limit = 100
|
||||
}
|
||||
if o.Offset < 0 {
|
||||
o.Offset = 0
|
||||
}
|
||||
}
|
||||
|
||||
// WorkListResult contains paginated task results.
|
||||
type WorkListResult struct {
|
||||
// Tasks is the list of tasks.
|
||||
Tasks []*WorkTask
|
||||
|
||||
// Total is the total count of matching tasks (for pagination metadata).
|
||||
Total int64
|
||||
|
||||
// Limit is the limit that was applied.
|
||||
Limit int
|
||||
|
||||
// Offset is the offset that was applied.
|
||||
Offset int
|
||||
}
|
||||
76
internal/domain/worker.go
Normal file
76
internal/domain/worker.go
Normal file
@ -0,0 +1,76 @@
|
||||
// Package domain contains pure domain models with no external dependencies.
|
||||
package domain
|
||||
|
||||
import "time"
|
||||
|
||||
// Worker represents a registered worker in the pool.
|
||||
type Worker struct {
|
||||
// ID is the unique worker identifier (typically pod name).
|
||||
ID string `json:"id"`
|
||||
|
||||
// Hostname is the worker's hostname.
|
||||
Hostname string `json:"hostname"`
|
||||
|
||||
// Status is the current worker state.
|
||||
Status WorkerStatus `json:"status"`
|
||||
|
||||
// CurrentTask is the ID of the task currently being executed.
|
||||
CurrentTask string `json:"current_task,omitempty"`
|
||||
|
||||
// Capabilities lists what this worker can do (e.g., "build", "deploy").
|
||||
Capabilities []string `json:"capabilities,omitempty"`
|
||||
|
||||
// RegisteredAt is when the worker first joined the pool.
|
||||
RegisteredAt time.Time `json:"registered_at"`
|
||||
|
||||
// LastHeartbeat is the most recent heartbeat timestamp.
|
||||
LastHeartbeat time.Time `json:"last_heartbeat"`
|
||||
|
||||
// Version is the worker binary version.
|
||||
Version string `json:"version,omitempty"`
|
||||
}
|
||||
|
||||
// IsAvailable returns true if the worker can accept new tasks.
|
||||
func (w *Worker) IsAvailable() bool {
|
||||
return w.Status == WorkerStatusIdle
|
||||
}
|
||||
|
||||
// IsHealthy returns true if the worker has a recent heartbeat.
|
||||
func (w *Worker) IsHealthy(threshold time.Duration) bool {
|
||||
return time.Since(w.LastHeartbeat) < threshold
|
||||
}
|
||||
|
||||
// WorkerStatus represents the current state of a worker.
|
||||
type WorkerStatus string
|
||||
|
||||
const (
|
||||
WorkerStatusIdle WorkerStatus = "idle"
|
||||
WorkerStatusBusy WorkerStatus = "busy"
|
||||
WorkerStatusDraining WorkerStatus = "draining"
|
||||
WorkerStatusOffline WorkerStatus = "offline"
|
||||
)
|
||||
|
||||
// IsValid returns true if the status is a known valid status.
|
||||
func (s WorkerStatus) IsValid() bool {
|
||||
switch s {
|
||||
case WorkerStatusIdle, WorkerStatusBusy, WorkerStatusDraining, WorkerStatusOffline:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ValidWorkerStatuses returns all valid worker status values.
|
||||
func ValidWorkerStatuses() []WorkerStatus {
|
||||
return []WorkerStatus{WorkerStatusIdle, WorkerStatusBusy, WorkerStatusDraining, WorkerStatusOffline}
|
||||
}
|
||||
|
||||
// Validate checks that the Worker has required fields for registration.
|
||||
func (w *Worker) Validate() error {
|
||||
if w.ID == "" {
|
||||
return ErrWorkerIDRequired
|
||||
}
|
||||
if w.Hostname == "" {
|
||||
return ErrWorkerHostnameRequired
|
||||
}
|
||||
return nil
|
||||
}
|
||||
146
internal/domain/worker_test.go
Normal file
146
internal/domain/worker_test.go
Normal file
@ -0,0 +1,146 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestWorker_IsAvailable(t *testing.T) {
|
||||
tests := []struct {
|
||||
status WorkerStatus
|
||||
want bool
|
||||
}{
|
||||
{WorkerStatusIdle, true},
|
||||
{WorkerStatusBusy, false},
|
||||
{WorkerStatusDraining, false},
|
||||
{WorkerStatusOffline, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(string(tt.status), func(t *testing.T) {
|
||||
w := &Worker{Status: tt.status}
|
||||
if got := w.IsAvailable(); got != tt.want {
|
||||
t.Errorf("IsAvailable() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWorker_IsHealthy(t *testing.T) {
|
||||
threshold := 90 * time.Second
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
lastHeartbeat time.Time
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "recent heartbeat",
|
||||
lastHeartbeat: time.Now().Add(-30 * time.Second),
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "stale heartbeat",
|
||||
lastHeartbeat: time.Now().Add(-2 * time.Minute),
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "exactly at threshold",
|
||||
lastHeartbeat: time.Now().Add(-90 * time.Second),
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := &Worker{LastHeartbeat: tt.lastHeartbeat}
|
||||
if got := w.IsHealthy(threshold); got != tt.want {
|
||||
t.Errorf("IsHealthy() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWorkerStatus_IsValid(t *testing.T) {
|
||||
tests := []struct {
|
||||
status WorkerStatus
|
||||
want bool
|
||||
}{
|
||||
{WorkerStatusIdle, true},
|
||||
{WorkerStatusBusy, true},
|
||||
{WorkerStatusDraining, true},
|
||||
{WorkerStatusOffline, true},
|
||||
{"unknown", false},
|
||||
{"", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(string(tt.status), func(t *testing.T) {
|
||||
if got := tt.status.IsValid(); got != tt.want {
|
||||
t.Errorf("IsValid() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWorker_Validate(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
worker Worker
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "valid worker",
|
||||
worker: Worker{ID: "worker-1", Hostname: "host-1"},
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "missing ID",
|
||||
worker: Worker{Hostname: "host-1"},
|
||||
wantErr: ErrWorkerIDRequired,
|
||||
},
|
||||
{
|
||||
name: "missing hostname",
|
||||
worker: Worker{ID: "worker-1"},
|
||||
wantErr: ErrWorkerHostnameRequired,
|
||||
},
|
||||
{
|
||||
name: "both missing",
|
||||
worker: Worker{},
|
||||
wantErr: ErrWorkerIDRequired,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := tt.worker.Validate()
|
||||
if tt.wantErr != nil {
|
||||
if !errors.Is(err, tt.wantErr) {
|
||||
t.Errorf("Validate() error = %v, want %v", err, tt.wantErr)
|
||||
}
|
||||
} else if err != nil {
|
||||
t.Errorf("Validate() unexpected error = %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidWorkerStatuses(t *testing.T) {
|
||||
statuses := ValidWorkerStatuses()
|
||||
if len(statuses) != 4 {
|
||||
t.Errorf("expected 4 statuses, got %d", len(statuses))
|
||||
}
|
||||
|
||||
found := make(map[WorkerStatus]bool)
|
||||
for _, s := range statuses {
|
||||
found[s] = true
|
||||
}
|
||||
|
||||
expected := []WorkerStatus{WorkerStatusIdle, WorkerStatusBusy, WorkerStatusDraining, WorkerStatusOffline}
|
||||
for _, s := range expected {
|
||||
if !found[s] {
|
||||
t.Errorf("expected %s in valid statuses", s)
|
||||
}
|
||||
}
|
||||
}
|
||||
269
internal/handlers/agents.go
Normal file
269
internal/handlers/agents.go
Normal file
@ -0,0 +1,269 @@
|
||||
// Package handlers provides HTTP handlers for the rdev API.
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/orchard9/rdev/internal/domain"
|
||||
"github.com/orchard9/rdev/internal/port"
|
||||
"github.com/orchard9/rdev/pkg/api"
|
||||
)
|
||||
|
||||
// AgentsHandler handles code agent management endpoints.
|
||||
type AgentsHandler struct {
|
||||
registry port.CodeAgentRegistry
|
||||
}
|
||||
|
||||
// NewAgentsHandler creates a new agents handler.
|
||||
func NewAgentsHandler(registry port.CodeAgentRegistry) *AgentsHandler {
|
||||
return &AgentsHandler{
|
||||
registry: registry,
|
||||
}
|
||||
}
|
||||
|
||||
// Mount registers the agent routes.
|
||||
func (h *AgentsHandler) Mount(r api.Router) {
|
||||
r.Route("/agents", func(r chi.Router) {
|
||||
r.Get("/", h.List)
|
||||
r.Get("/health", h.Health)
|
||||
r.Get("/{provider}", h.GetCapabilities)
|
||||
r.Post("/default", h.SetDefault)
|
||||
})
|
||||
}
|
||||
|
||||
// AgentDTO is the data transfer object for code agents.
|
||||
type AgentDTO struct {
|
||||
Provider string `json:"provider"`
|
||||
Name string `json:"name"`
|
||||
Available bool `json:"available"`
|
||||
Default bool `json:"default"`
|
||||
Models []string `json:"supported_models,omitempty"`
|
||||
DefaultModel string `json:"default_model,omitempty"`
|
||||
}
|
||||
|
||||
// AgentCapabilitiesDTO is the DTO for agent capabilities.
|
||||
type AgentCapabilitiesDTO struct {
|
||||
Provider string `json:"provider"`
|
||||
SupportsSessionContinuation bool `json:"supports_session_continuation"`
|
||||
SupportsModelSelection bool `json:"supports_model_selection"`
|
||||
SupportsToolControl bool `json:"supports_tool_control"`
|
||||
SupportsStreaming bool `json:"supports_streaming"`
|
||||
SupportedModels []string `json:"supported_models"`
|
||||
DefaultModel string `json:"default_model"`
|
||||
MaxPromptLength int `json:"max_prompt_length,omitempty"`
|
||||
}
|
||||
|
||||
// ListAgentsResponse is the response for GET /agents.
|
||||
type ListAgentsResponse struct {
|
||||
Agents []AgentDTO `json:"agents"`
|
||||
DefaultAgent string `json:"default_agent"`
|
||||
TotalAgents int `json:"total_agents"`
|
||||
AvailableCount int `json:"available_count"`
|
||||
}
|
||||
|
||||
// List returns all registered code agents and their status.
|
||||
// GET /agents
|
||||
func (h *AgentsHandler) List(w http.ResponseWriter, r *http.Request) {
|
||||
if h.registry == nil {
|
||||
api.WriteSuccess(w, r, ListAgentsResponse{
|
||||
Agents: []AgentDTO{},
|
||||
DefaultAgent: "",
|
||||
TotalAgents: 0,
|
||||
AvailableCount: 0,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
providers := h.registry.Available()
|
||||
defaultProvider := h.registry.DefaultProvider()
|
||||
|
||||
// Check availability for each agent
|
||||
availableAgents := h.registry.AvailableAgents(r.Context())
|
||||
availableSet := make(map[domain.AgentProvider]bool)
|
||||
for _, agent := range availableAgents {
|
||||
availableSet[agent.Provider()] = true
|
||||
}
|
||||
|
||||
agents := make([]AgentDTO, 0, len(providers))
|
||||
availableCount := 0
|
||||
|
||||
for _, provider := range providers {
|
||||
agent := h.registry.Get(provider)
|
||||
if agent == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
caps := agent.Capabilities()
|
||||
isAvailable := availableSet[provider]
|
||||
isDefault := provider == defaultProvider
|
||||
|
||||
if isAvailable {
|
||||
availableCount++
|
||||
}
|
||||
|
||||
agents = append(agents, AgentDTO{
|
||||
Provider: string(provider),
|
||||
Name: agent.Name(),
|
||||
Available: isAvailable,
|
||||
Default: isDefault,
|
||||
Models: caps.SupportedModels,
|
||||
DefaultModel: caps.DefaultModel,
|
||||
})
|
||||
}
|
||||
|
||||
api.WriteSuccess(w, r, ListAgentsResponse{
|
||||
Agents: agents,
|
||||
DefaultAgent: string(defaultProvider),
|
||||
TotalAgents: len(agents),
|
||||
AvailableCount: availableCount,
|
||||
})
|
||||
}
|
||||
|
||||
// GetCapabilities returns the capabilities of a specific agent.
|
||||
// GET /agents/{provider}
|
||||
func (h *AgentsHandler) GetCapabilities(w http.ResponseWriter, r *http.Request) {
|
||||
providerStr := chi.URLParam(r, "provider")
|
||||
provider := domain.AgentProvider(providerStr)
|
||||
|
||||
if h.registry == nil {
|
||||
api.WriteNotFound(w, r, "no agents registered")
|
||||
return
|
||||
}
|
||||
|
||||
agent := h.registry.Get(provider)
|
||||
if agent == nil {
|
||||
api.WriteNotFound(w, r, "agent not found: "+providerStr)
|
||||
return
|
||||
}
|
||||
|
||||
caps := agent.Capabilities()
|
||||
|
||||
api.WriteSuccess(w, r, AgentCapabilitiesDTO{
|
||||
Provider: string(caps.Provider),
|
||||
SupportsSessionContinuation: caps.SupportsSessionContinuation,
|
||||
SupportsModelSelection: caps.SupportsModelSelection,
|
||||
SupportsToolControl: caps.SupportsToolControl,
|
||||
SupportsStreaming: caps.SupportsStreaming,
|
||||
SupportedModels: caps.SupportedModels,
|
||||
DefaultModel: caps.DefaultModel,
|
||||
MaxPromptLength: caps.MaxPromptLength,
|
||||
})
|
||||
}
|
||||
|
||||
// AgentHealthDTO represents the health status of a single agent.
|
||||
type AgentHealthDTO struct {
|
||||
Provider string `json:"provider"`
|
||||
Name string `json:"name"`
|
||||
Healthy bool `json:"healthy"`
|
||||
Message string `json:"message"`
|
||||
Latency string `json:"latency"`
|
||||
CheckedAt string `json:"checked_at"`
|
||||
}
|
||||
|
||||
// AgentHealthResponse is the response for GET /agents/health.
|
||||
type AgentHealthResponse struct {
|
||||
Agents []AgentHealthDTO `json:"agents"`
|
||||
HealthyCount int `json:"healthy_count"`
|
||||
TotalCount int `json:"total_count"`
|
||||
DefaultAgent string `json:"default_agent"`
|
||||
DefaultHealth bool `json:"default_healthy"`
|
||||
}
|
||||
|
||||
// Health returns the health status of all registered code agents.
|
||||
// GET /agents/health
|
||||
func (h *AgentsHandler) Health(w http.ResponseWriter, r *http.Request) {
|
||||
if h.registry == nil {
|
||||
api.WriteSuccess(w, r, AgentHealthResponse{
|
||||
Agents: []AgentHealthDTO{},
|
||||
HealthyCount: 0,
|
||||
TotalCount: 0,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
providers := h.registry.Available()
|
||||
defaultProvider := h.registry.DefaultProvider()
|
||||
|
||||
agents := make([]AgentHealthDTO, 0, len(providers))
|
||||
healthyCount := 0
|
||||
defaultHealthy := false
|
||||
|
||||
for _, provider := range providers {
|
||||
agent := h.registry.Get(provider)
|
||||
if agent == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
healthy := agent.Available(r.Context())
|
||||
latency := time.Since(start)
|
||||
|
||||
msg := "available"
|
||||
if !healthy {
|
||||
msg = "unavailable"
|
||||
}
|
||||
|
||||
if healthy {
|
||||
healthyCount++
|
||||
}
|
||||
if provider == defaultProvider {
|
||||
defaultHealthy = healthy
|
||||
}
|
||||
|
||||
agents = append(agents, AgentHealthDTO{
|
||||
Provider: string(provider),
|
||||
Name: agent.Name(),
|
||||
Healthy: healthy,
|
||||
Message: msg,
|
||||
Latency: latency.String(),
|
||||
CheckedAt: time.Now().UTC().Format(time.RFC3339),
|
||||
})
|
||||
}
|
||||
|
||||
api.WriteSuccess(w, r, AgentHealthResponse{
|
||||
Agents: agents,
|
||||
HealthyCount: healthyCount,
|
||||
TotalCount: len(agents),
|
||||
DefaultAgent: string(defaultProvider),
|
||||
DefaultHealth: defaultHealthy,
|
||||
})
|
||||
}
|
||||
|
||||
// SetDefaultRequest is the request body for POST /agents/default.
|
||||
type SetDefaultRequest struct {
|
||||
Provider string `json:"provider"`
|
||||
}
|
||||
|
||||
// SetDefault changes the default code agent.
|
||||
// POST /agents/default
|
||||
func (h *AgentsHandler) SetDefault(w http.ResponseWriter, r *http.Request) {
|
||||
if h.registry == nil {
|
||||
api.WriteBadRequest(w, r, "no agents registered")
|
||||
return
|
||||
}
|
||||
|
||||
var req SetDefaultRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
api.WriteBadRequest(w, r, "invalid request body")
|
||||
return
|
||||
}
|
||||
|
||||
if req.Provider == "" {
|
||||
api.WriteBadRequest(w, r, "provider is required")
|
||||
return
|
||||
}
|
||||
|
||||
provider := domain.AgentProvider(req.Provider)
|
||||
if err := h.registry.SetDefault(provider); err != nil {
|
||||
api.WriteBadRequest(w, r, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
api.WriteSuccess(w, r, map[string]any{
|
||||
"default_agent": req.Provider,
|
||||
"message": "default agent updated",
|
||||
})
|
||||
}
|
||||
372
internal/handlers/agents_test.go
Normal file
372
internal/handlers/agents_test.go
Normal file
@ -0,0 +1,372 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/orchard9/rdev/internal/domain"
|
||||
"github.com/orchard9/rdev/internal/port"
|
||||
)
|
||||
|
||||
// mockCodeAgent implements port.CodeAgent for testing.
|
||||
type mockCodeAgent struct {
|
||||
name string
|
||||
provider domain.AgentProvider
|
||||
available bool
|
||||
capabilities domain.AgentCapabilities
|
||||
}
|
||||
|
||||
func (m *mockCodeAgent) Name() string { return m.name }
|
||||
func (m *mockCodeAgent) Provider() domain.AgentProvider { return m.provider }
|
||||
func (m *mockCodeAgent) Execute(ctx context.Context, req *domain.AgentRequest, handler domain.AgentEventHandler) (*domain.AgentResult, error) {
|
||||
return &domain.AgentResult{}, nil
|
||||
}
|
||||
func (m *mockCodeAgent) Cancel(ctx context.Context, sessionID string) error { return nil }
|
||||
func (m *mockCodeAgent) Capabilities() domain.AgentCapabilities { return m.capabilities }
|
||||
func (m *mockCodeAgent) Available(ctx context.Context) bool { return m.available }
|
||||
|
||||
// mockAgentRegistry implements port.CodeAgentRegistry for testing.
|
||||
type mockAgentRegistry struct {
|
||||
agents map[domain.AgentProvider]*mockCodeAgent
|
||||
defaultAgent domain.AgentProvider
|
||||
setDefaultErr error
|
||||
}
|
||||
|
||||
func newMockAgentRegistry() *mockAgentRegistry {
|
||||
return &mockAgentRegistry{
|
||||
agents: make(map[domain.AgentProvider]*mockCodeAgent),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockAgentRegistry) Register(agent port.CodeAgent) {
|
||||
m.agents[agent.Provider()] = agent.(*mockCodeAgent)
|
||||
if m.defaultAgent == "" {
|
||||
m.defaultAgent = agent.Provider()
|
||||
}
|
||||
}
|
||||
|
||||
// registerAgent is a helper for tests to directly add mockCodeAgents
|
||||
func (m *mockAgentRegistry) registerAgent(agent *mockCodeAgent) {
|
||||
m.agents[agent.provider] = agent
|
||||
if m.defaultAgent == "" {
|
||||
m.defaultAgent = agent.provider
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockAgentRegistry) Get(provider domain.AgentProvider) port.CodeAgent {
|
||||
agent, ok := m.agents[provider]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return agent
|
||||
}
|
||||
|
||||
func (m *mockAgentRegistry) Default() port.CodeAgent {
|
||||
return m.Get(m.defaultAgent)
|
||||
}
|
||||
|
||||
func (m *mockAgentRegistry) DefaultProvider() domain.AgentProvider {
|
||||
return m.defaultAgent
|
||||
}
|
||||
|
||||
func (m *mockAgentRegistry) SetDefault(provider domain.AgentProvider) error {
|
||||
if m.setDefaultErr != nil {
|
||||
return m.setDefaultErr
|
||||
}
|
||||
if _, ok := m.agents[provider]; !ok {
|
||||
return fmt.Errorf("agent provider %q is not registered", provider)
|
||||
}
|
||||
m.defaultAgent = provider
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockAgentRegistry) Available() []domain.AgentProvider {
|
||||
providers := make([]domain.AgentProvider, 0, len(m.agents))
|
||||
for p := range m.agents {
|
||||
providers = append(providers, p)
|
||||
}
|
||||
return providers
|
||||
}
|
||||
|
||||
func (m *mockAgentRegistry) AvailableAgents(ctx context.Context) []port.CodeAgent {
|
||||
var available []port.CodeAgent
|
||||
for _, agent := range m.agents {
|
||||
if agent.available {
|
||||
available = append(available, agent)
|
||||
}
|
||||
}
|
||||
return available
|
||||
}
|
||||
|
||||
func (m *mockAgentRegistry) Count() int {
|
||||
return len(m.agents)
|
||||
}
|
||||
|
||||
func TestAgentsHandler_List(t *testing.T) {
|
||||
registry := newMockAgentRegistry()
|
||||
registry.registerAgent(&mockCodeAgent{
|
||||
name: "Claude Code",
|
||||
provider: domain.AgentProviderClaudeCode,
|
||||
available: true,
|
||||
capabilities: domain.AgentCapabilities{
|
||||
Provider: domain.AgentProviderClaudeCode,
|
||||
SupportsSessionContinuation: true,
|
||||
SupportedModels: []string{"claude-sonnet-4-20250514"},
|
||||
DefaultModel: "claude-sonnet-4-20250514",
|
||||
},
|
||||
})
|
||||
registry.registerAgent(&mockCodeAgent{
|
||||
name: "OpenCode",
|
||||
provider: domain.AgentProviderOpenCode,
|
||||
available: false,
|
||||
capabilities: domain.AgentCapabilities{
|
||||
Provider: domain.AgentProviderOpenCode,
|
||||
SupportsSessionContinuation: true,
|
||||
SupportsModelSelection: true,
|
||||
SupportedModels: []string{"gpt-4o", "claude-sonnet-4-20250514"},
|
||||
DefaultModel: "claude-sonnet-4-20250514",
|
||||
},
|
||||
})
|
||||
|
||||
handler := NewAgentsHandler(registry)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/agents", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.List(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
// Unwrap from api.Response
|
||||
respBody := w.Body.Bytes()
|
||||
var apiResp struct {
|
||||
Data ListAgentsResponse `json:"data"`
|
||||
}
|
||||
if err := json.Unmarshal(respBody, &apiResp); err != nil {
|
||||
t.Fatalf("failed to decode api response: %v", err)
|
||||
}
|
||||
|
||||
if apiResp.Data.TotalAgents != 2 {
|
||||
t.Errorf("expected 2 agents, got %d", apiResp.Data.TotalAgents)
|
||||
}
|
||||
if apiResp.Data.AvailableCount != 1 {
|
||||
t.Errorf("expected 1 available agent, got %d", apiResp.Data.AvailableCount)
|
||||
}
|
||||
if apiResp.Data.DefaultAgent != string(domain.AgentProviderClaudeCode) {
|
||||
t.Errorf("expected default agent to be claude-code, got %s", apiResp.Data.DefaultAgent)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentsHandler_List_NoRegistry(t *testing.T) {
|
||||
handler := NewAgentsHandler(nil)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/agents", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.List(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentsHandler_GetCapabilities(t *testing.T) {
|
||||
registry := newMockAgentRegistry()
|
||||
registry.registerAgent(&mockCodeAgent{
|
||||
name: "Claude Code",
|
||||
provider: domain.AgentProviderClaudeCode,
|
||||
capabilities: domain.AgentCapabilities{
|
||||
Provider: domain.AgentProviderClaudeCode,
|
||||
SupportsSessionContinuation: true,
|
||||
SupportsStreaming: true,
|
||||
SupportedModels: []string{"claude-sonnet-4-20250514"},
|
||||
DefaultModel: "claude-sonnet-4-20250514",
|
||||
MaxPromptLength: 100000,
|
||||
},
|
||||
})
|
||||
|
||||
handler := NewAgentsHandler(registry)
|
||||
|
||||
// Set up chi router for URL params
|
||||
r := chi.NewRouter()
|
||||
r.Get("/agents/{provider}", handler.GetCapabilities)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/agents/claudecode", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected status 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
var apiResp struct {
|
||||
Data AgentCapabilitiesDTO `json:"data"`
|
||||
}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &apiResp); err != nil {
|
||||
t.Fatalf("failed to decode response: %v", err)
|
||||
}
|
||||
|
||||
if apiResp.Data.Provider != string(domain.AgentProviderClaudeCode) {
|
||||
t.Errorf("expected provider claudecode, got %s", apiResp.Data.Provider)
|
||||
}
|
||||
if !apiResp.Data.SupportsSessionContinuation {
|
||||
t.Error("expected session continuation support")
|
||||
}
|
||||
if !apiResp.Data.SupportsStreaming {
|
||||
t.Error("expected streaming support")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentsHandler_GetCapabilities_NotFound(t *testing.T) {
|
||||
registry := newMockAgentRegistry()
|
||||
handler := NewAgentsHandler(registry)
|
||||
|
||||
r := chi.NewRouter()
|
||||
r.Get("/agents/{provider}", handler.GetCapabilities)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/agents/unknown", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("expected status 404, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentsHandler_SetDefault(t *testing.T) {
|
||||
registry := newMockAgentRegistry()
|
||||
registry.registerAgent(&mockCodeAgent{
|
||||
name: "Claude Code",
|
||||
provider: domain.AgentProviderClaudeCode,
|
||||
})
|
||||
registry.registerAgent(&mockCodeAgent{
|
||||
name: "OpenCode",
|
||||
provider: domain.AgentProviderOpenCode,
|
||||
})
|
||||
|
||||
handler := NewAgentsHandler(registry)
|
||||
|
||||
body := SetDefaultRequest{Provider: string(domain.AgentProviderOpenCode)}
|
||||
bodyBytes, _ := json.Marshal(body)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/agents/default", bytes.NewReader(bodyBytes))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.SetDefault(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected status 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
// Verify default was changed
|
||||
if registry.defaultAgent != domain.AgentProviderOpenCode {
|
||||
t.Errorf("expected default to be opencode, got %s", registry.defaultAgent)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentsHandler_SetDefault_InvalidProvider(t *testing.T) {
|
||||
registry := newMockAgentRegistry()
|
||||
registry.registerAgent(&mockCodeAgent{
|
||||
name: "Claude Code",
|
||||
provider: domain.AgentProviderClaudeCode,
|
||||
})
|
||||
|
||||
handler := NewAgentsHandler(registry)
|
||||
|
||||
body := SetDefaultRequest{Provider: "unknown"}
|
||||
bodyBytes, _ := json.Marshal(body)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/agents/default", bytes.NewReader(bodyBytes))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.SetDefault(w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected status 400, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentsHandler_SetDefault_EmptyProvider(t *testing.T) {
|
||||
registry := newMockAgentRegistry()
|
||||
handler := NewAgentsHandler(registry)
|
||||
|
||||
body := SetDefaultRequest{Provider: ""}
|
||||
bodyBytes, _ := json.Marshal(body)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/agents/default", bytes.NewReader(bodyBytes))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.SetDefault(w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected status 400, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentsHandler_Health(t *testing.T) {
|
||||
registry := newMockAgentRegistry()
|
||||
registry.registerAgent(&mockCodeAgent{
|
||||
name: "Claude Code",
|
||||
provider: domain.AgentProviderClaudeCode,
|
||||
available: true,
|
||||
})
|
||||
registry.registerAgent(&mockCodeAgent{
|
||||
name: "OpenCode",
|
||||
provider: domain.AgentProviderOpenCode,
|
||||
available: false,
|
||||
})
|
||||
|
||||
handler := NewAgentsHandler(registry)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/agents/health", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.Health(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
var apiResp struct {
|
||||
Data AgentHealthResponse `json:"data"`
|
||||
}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &apiResp); err != nil {
|
||||
t.Fatalf("failed to decode response: %v", err)
|
||||
}
|
||||
|
||||
if apiResp.Data.TotalCount != 2 {
|
||||
t.Errorf("expected 2 agents, got %d", apiResp.Data.TotalCount)
|
||||
}
|
||||
if apiResp.Data.HealthyCount != 1 {
|
||||
t.Errorf("expected 1 healthy agent, got %d", apiResp.Data.HealthyCount)
|
||||
}
|
||||
if !apiResp.Data.DefaultHealth {
|
||||
t.Error("expected default agent to be healthy")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentsHandler_Health_NoRegistry(t *testing.T) {
|
||||
handler := NewAgentsHandler(nil)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/agents/health", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.Health(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
238
internal/handlers/builds.go
Normal file
238
internal/handlers/builds.go
Normal file
@ -0,0 +1,238 @@
|
||||
// Package handlers provides HTTP handlers for the rdev API.
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/orchard9/rdev/internal/auth"
|
||||
"github.com/orchard9/rdev/internal/domain"
|
||||
"github.com/orchard9/rdev/internal/service"
|
||||
"github.com/orchard9/rdev/pkg/api"
|
||||
)
|
||||
|
||||
// maxRequestBodySize is the maximum allowed size for request bodies (1MB).
|
||||
const maxRequestBodySize = 1 << 20
|
||||
|
||||
// BuildsHandler handles project-scoped build endpoints.
|
||||
type BuildsHandler struct {
|
||||
buildService *service.BuildService
|
||||
}
|
||||
|
||||
// NewBuildsHandler creates a new builds handler.
|
||||
func NewBuildsHandler(buildService *service.BuildService) *BuildsHandler {
|
||||
return &BuildsHandler{
|
||||
buildService: buildService,
|
||||
}
|
||||
}
|
||||
|
||||
// Mount registers the build routes.
|
||||
func (h *BuildsHandler) Mount(r api.Router) {
|
||||
// Project-scoped build endpoints
|
||||
r.With(auth.RequireScope(auth.ScopeBuildWrite, auth.ScopeAdmin)).
|
||||
Post("/projects/{id}/builds", h.StartBuild)
|
||||
r.With(auth.RequireScope(auth.ScopeBuildRead, auth.ScopeAdmin)).
|
||||
Get("/projects/{id}/builds", h.ListBuilds)
|
||||
|
||||
// Build detail by task ID
|
||||
r.With(auth.RequireScope(auth.ScopeBuildRead, auth.ScopeAdmin)).
|
||||
Get("/builds/{taskId}", h.GetBuild)
|
||||
}
|
||||
|
||||
// StartBuildRequest is the request body for POST /projects/{id}/builds.
|
||||
type StartBuildRequest struct {
|
||||
Prompt string `json:"prompt"`
|
||||
Template string `json:"template,omitempty"`
|
||||
Variables map[string]string `json:"variables,omitempty"`
|
||||
AutoCommit bool `json:"auto_commit"`
|
||||
AutoPush bool `json:"auto_push"`
|
||||
CallbackURL string `json:"callback_url,omitempty"`
|
||||
}
|
||||
|
||||
// StartBuildResponse is the response for POST /projects/{id}/builds.
|
||||
type StartBuildResponse struct {
|
||||
TaskID string `json:"task_id"`
|
||||
ProjectID string `json:"project_id"`
|
||||
Status string `json:"status"`
|
||||
StatusURL string `json:"status_url"`
|
||||
}
|
||||
|
||||
// BuildAuditDTO is the data transfer object for build audit entries.
|
||||
type BuildAuditDTO struct {
|
||||
TaskID string `json:"task_id"`
|
||||
ProjectID string `json:"project_id"`
|
||||
WorkerID string `json:"worker_id,omitempty"`
|
||||
Status string `json:"status"`
|
||||
Prompt string `json:"prompt"`
|
||||
Template string `json:"template,omitempty"`
|
||||
AutoCommit bool `json:"auto_commit"`
|
||||
AutoPush bool `json:"auto_push"`
|
||||
Result *BuildResultDTO `json:"result,omitempty"`
|
||||
StartedAt string `json:"started_at"`
|
||||
CompletedAt string `json:"completed_at,omitempty"`
|
||||
}
|
||||
|
||||
// BuildResultDTO is the data transfer object for build results.
|
||||
type BuildResultDTO struct {
|
||||
Success bool `json:"success"`
|
||||
Output string `json:"output,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
CommitSHA string `json:"commit_sha,omitempty"`
|
||||
FilesChanged []string `json:"files_changed,omitempty"`
|
||||
DurationMs int64 `json:"duration_ms"`
|
||||
Artifacts map[string]string `json:"artifacts,omitempty"`
|
||||
}
|
||||
|
||||
func toBuildAuditDTO(e *domain.BuildAuditEntry) *BuildAuditDTO {
|
||||
if e == nil {
|
||||
return nil
|
||||
}
|
||||
dto := &BuildAuditDTO{
|
||||
TaskID: e.TaskID,
|
||||
ProjectID: e.ProjectID,
|
||||
WorkerID: e.WorkerID,
|
||||
Status: string(e.Status),
|
||||
Prompt: e.Spec.Prompt,
|
||||
Template: e.Spec.Template,
|
||||
AutoCommit: e.Spec.AutoCommit,
|
||||
AutoPush: e.Spec.AutoPush,
|
||||
StartedAt: e.StartedAt.Format("2006-01-02T15:04:05Z07:00"),
|
||||
}
|
||||
if e.CompletedAt != nil {
|
||||
dto.CompletedAt = e.CompletedAt.Format("2006-01-02T15:04:05Z07:00")
|
||||
}
|
||||
if e.Result != nil {
|
||||
dto.Result = &BuildResultDTO{
|
||||
Success: e.Result.Success,
|
||||
Output: e.Result.Output,
|
||||
Error: e.Result.Error,
|
||||
CommitSHA: e.Result.CommitSHA,
|
||||
FilesChanged: e.Result.FilesChanged,
|
||||
DurationMs: e.Result.DurationMs,
|
||||
Artifacts: e.Result.Artifacts,
|
||||
}
|
||||
}
|
||||
return dto
|
||||
}
|
||||
|
||||
// StartBuild enqueues a build task for a project.
|
||||
// POST /projects/{id}/builds
|
||||
func (h *BuildsHandler) StartBuild(w http.ResponseWriter, r *http.Request) {
|
||||
projectID := chi.URLParam(r, "id")
|
||||
if projectID == "" {
|
||||
api.WriteBadRequest(w, r, "project ID is required")
|
||||
return
|
||||
}
|
||||
|
||||
r.Body = http.MaxBytesReader(w, r.Body, maxRequestBodySize)
|
||||
var req StartBuildRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
api.WriteBadRequest(w, r, "invalid request body")
|
||||
return
|
||||
}
|
||||
|
||||
if req.Prompt == "" {
|
||||
api.WriteBadRequest(w, r, "prompt is required")
|
||||
return
|
||||
}
|
||||
|
||||
// Validate callback URL to prevent SSRF
|
||||
if req.CallbackURL != "" {
|
||||
if err := domain.ValidateCallbackURL(req.CallbackURL); err != nil {
|
||||
api.WriteBadRequest(w, r, err.Error())
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
spec := domain.BuildSpec{
|
||||
Prompt: req.Prompt,
|
||||
Template: req.Template,
|
||||
Variables: req.Variables,
|
||||
AutoCommit: req.AutoCommit,
|
||||
AutoPush: req.AutoPush,
|
||||
CallbackURL: req.CallbackURL,
|
||||
}
|
||||
|
||||
taskID, err := h.buildService.StartBuild(r.Context(), projectID, spec)
|
||||
if err != nil {
|
||||
if errors.Is(err, domain.ErrPromptRequired) {
|
||||
api.WriteBadRequest(w, r, err.Error())
|
||||
return
|
||||
}
|
||||
api.WriteInternalError(w, r, "failed to start build")
|
||||
return
|
||||
}
|
||||
|
||||
api.WriteCreated(w, r, StartBuildResponse{
|
||||
TaskID: taskID,
|
||||
ProjectID: projectID,
|
||||
Status: "pending",
|
||||
StatusURL: "/builds/" + taskID,
|
||||
})
|
||||
}
|
||||
|
||||
// ListBuilds returns build history for a project.
|
||||
// GET /projects/{id}/builds?limit=50
|
||||
func (h *BuildsHandler) ListBuilds(w http.ResponseWriter, r *http.Request) {
|
||||
projectID := chi.URLParam(r, "id")
|
||||
if projectID == "" {
|
||||
api.WriteBadRequest(w, r, "project ID is required")
|
||||
return
|
||||
}
|
||||
|
||||
limit := 50
|
||||
if limitStr := r.URL.Query().Get("limit"); limitStr != "" {
|
||||
l, err := strconv.Atoi(limitStr)
|
||||
if err != nil {
|
||||
api.WriteBadRequest(w, r, "limit must be a valid integer")
|
||||
return
|
||||
}
|
||||
if l < 1 || l > 200 {
|
||||
api.WriteBadRequest(w, r, "limit must be between 1 and 200")
|
||||
return
|
||||
}
|
||||
limit = l
|
||||
}
|
||||
|
||||
builds, err := h.buildService.ListBuilds(r.Context(), projectID, limit)
|
||||
if err != nil {
|
||||
api.WriteInternalError(w, r, "failed to list builds")
|
||||
return
|
||||
}
|
||||
|
||||
dtos := make([]*BuildAuditDTO, len(builds))
|
||||
for i, b := range builds {
|
||||
dtos[i] = toBuildAuditDTO(b)
|
||||
}
|
||||
|
||||
api.WriteSuccess(w, r, map[string]any{
|
||||
"builds": dtos,
|
||||
"project_id": projectID,
|
||||
"total": len(dtos),
|
||||
})
|
||||
}
|
||||
|
||||
// GetBuild returns the status of a specific build.
|
||||
// GET /builds/{taskId}
|
||||
func (h *BuildsHandler) GetBuild(w http.ResponseWriter, r *http.Request) {
|
||||
taskID := chi.URLParam(r, "taskId")
|
||||
if taskID == "" {
|
||||
api.WriteBadRequest(w, r, "task ID is required")
|
||||
return
|
||||
}
|
||||
|
||||
entry, err := h.buildService.GetBuildStatus(r.Context(), taskID)
|
||||
if err != nil {
|
||||
if errors.Is(err, domain.ErrBuildNotFound) {
|
||||
api.WriteNotFound(w, r, "build not found: "+taskID)
|
||||
return
|
||||
}
|
||||
api.WriteInternalError(w, r, "failed to get build status")
|
||||
return
|
||||
}
|
||||
|
||||
api.WriteSuccess(w, r, toBuildAuditDTO(entry))
|
||||
}
|
||||
351
internal/handlers/builds_test.go
Normal file
351
internal/handlers/builds_test.go
Normal file
@ -0,0 +1,351 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/orchard9/rdev/internal/auth"
|
||||
"github.com/orchard9/rdev/internal/domain"
|
||||
"github.com/orchard9/rdev/internal/port"
|
||||
"github.com/orchard9/rdev/internal/service"
|
||||
)
|
||||
|
||||
// testAdminAuth is a chi middleware that injects an admin API key into the
|
||||
// request context so auth.RequireScope passes in tests.
|
||||
func testAdminAuth(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := auth.WithAPIKey(r.Context(), &domain.APIKey{
|
||||
Scopes: []domain.Scope{domain.ScopeAdmin},
|
||||
})
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
// mockBuildAudit implements port.BuildAudit for testing.
|
||||
type mockBuildAudit struct {
|
||||
entries map[string]*domain.BuildAuditEntry
|
||||
err error
|
||||
}
|
||||
|
||||
func newMockBuildAudit() *mockBuildAudit {
|
||||
return &mockBuildAudit{
|
||||
entries: make(map[string]*domain.BuildAuditEntry),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockBuildAudit) Record(_ context.Context, entry *domain.BuildAuditEntry) error {
|
||||
if m.err != nil {
|
||||
return m.err
|
||||
}
|
||||
m.entries[entry.TaskID] = entry
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockBuildAudit) Update(_ context.Context, taskID string, result *domain.BuildResult) error {
|
||||
if m.err != nil {
|
||||
return m.err
|
||||
}
|
||||
entry, ok := m.entries[taskID]
|
||||
if !ok {
|
||||
return domain.ErrBuildNotFound
|
||||
}
|
||||
entry.Result = result
|
||||
if result.Success {
|
||||
entry.Status = domain.BuildStatusCompleted
|
||||
} else {
|
||||
entry.Status = domain.BuildStatusFailed
|
||||
}
|
||||
now := time.Now()
|
||||
entry.CompletedAt = &now
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockBuildAudit) Get(_ context.Context, taskID string) (*domain.BuildAuditEntry, error) {
|
||||
if m.err != nil {
|
||||
return nil, m.err
|
||||
}
|
||||
entry, ok := m.entries[taskID]
|
||||
if !ok {
|
||||
return nil, domain.ErrBuildNotFound
|
||||
}
|
||||
return entry, nil
|
||||
}
|
||||
|
||||
func (m *mockBuildAudit) List(_ context.Context, filter port.BuildAuditFilter) ([]*domain.BuildAuditEntry, error) {
|
||||
if m.err != nil {
|
||||
return nil, m.err
|
||||
}
|
||||
var result []*domain.BuildAuditEntry
|
||||
for _, entry := range m.entries {
|
||||
if filter.ProjectID != "" && entry.ProjectID != filter.ProjectID {
|
||||
continue
|
||||
}
|
||||
result = append(result, entry)
|
||||
if filter.Limit > 0 && len(result) >= filter.Limit {
|
||||
break
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func TestBuildsHandler_StartBuild(t *testing.T) {
|
||||
queue := newMockWorkQueue()
|
||||
audit := newMockBuildAudit()
|
||||
buildService := service.NewBuildService(queue, audit, nil)
|
||||
handler := NewBuildsHandler(buildService)
|
||||
|
||||
router := chi.NewRouter()
|
||||
router.Use(testAdminAuth)
|
||||
handler.Mount(router)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
projectID string
|
||||
body StartBuildRequest
|
||||
wantStatus int
|
||||
}{
|
||||
{
|
||||
name: "valid_build",
|
||||
projectID: "my-project",
|
||||
body: StartBuildRequest{
|
||||
Prompt: "Build a landing page with Next.js",
|
||||
Template: "nextjs-landing",
|
||||
AutoCommit: true,
|
||||
AutoPush: true,
|
||||
},
|
||||
wantStatus: http.StatusCreated,
|
||||
},
|
||||
{
|
||||
name: "missing_prompt",
|
||||
projectID: "my-project",
|
||||
body: StartBuildRequest{
|
||||
Template: "nextjs-landing",
|
||||
},
|
||||
wantStatus: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "minimal_build",
|
||||
projectID: "test-project",
|
||||
body: StartBuildRequest{
|
||||
Prompt: "Add a footer component",
|
||||
},
|
||||
wantStatus: http.StatusCreated,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
body, _ := json.Marshal(tt.body)
|
||||
req := httptest.NewRequest(http.MethodPost, "/projects/"+tt.projectID+"/builds", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != tt.wantStatus {
|
||||
t.Errorf("got status %d, want %d; body: %s", rec.Code, tt.wantStatus, rec.Body.String())
|
||||
}
|
||||
|
||||
if tt.wantStatus == http.StatusCreated {
|
||||
var resp map[string]any
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("failed to unmarshal: %v", err)
|
||||
}
|
||||
data, ok := resp["data"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("expected data to be map, got %T", resp["data"])
|
||||
}
|
||||
if data["task_id"] == nil || data["task_id"] == "" {
|
||||
t.Error("expected task_id in response")
|
||||
}
|
||||
if data["project_id"] != tt.projectID {
|
||||
t.Errorf("got project_id=%v, want %s", data["project_id"], tt.projectID)
|
||||
}
|
||||
if data["status"] != "pending" {
|
||||
t.Errorf("got status=%v, want pending", data["status"])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildsHandler_GetBuild(t *testing.T) {
|
||||
queue := newMockWorkQueue()
|
||||
audit := newMockBuildAudit()
|
||||
buildService := service.NewBuildService(queue, audit, nil)
|
||||
handler := NewBuildsHandler(buildService)
|
||||
|
||||
// Pre-populate an audit entry
|
||||
audit.entries["task-1"] = &domain.BuildAuditEntry{
|
||||
TaskID: "task-1",
|
||||
ProjectID: "my-project",
|
||||
WorkerID: "worker-1",
|
||||
Spec: domain.BuildSpec{
|
||||
Prompt: "Build landing page",
|
||||
Template: "nextjs-landing",
|
||||
},
|
||||
Status: domain.BuildStatusCompleted,
|
||||
StartedAt: time.Now().Add(-5 * time.Minute),
|
||||
Result: &domain.BuildResult{
|
||||
Success: true,
|
||||
CommitSHA: "abc123",
|
||||
DurationMs: 30000,
|
||||
},
|
||||
}
|
||||
|
||||
router := chi.NewRouter()
|
||||
router.Use(testAdminAuth)
|
||||
handler.Mount(router)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
taskID string
|
||||
wantStatus int
|
||||
}{
|
||||
{
|
||||
name: "existing_build",
|
||||
taskID: "task-1",
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "not_found",
|
||||
taskID: "nonexistent",
|
||||
wantStatus: http.StatusNotFound,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/builds/"+tt.taskID, nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != tt.wantStatus {
|
||||
t.Errorf("got status %d, want %d; body: %s", rec.Code, tt.wantStatus, rec.Body.String())
|
||||
}
|
||||
|
||||
if tt.wantStatus == http.StatusOK {
|
||||
var resp map[string]any
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("failed to unmarshal: %v", err)
|
||||
}
|
||||
data, ok := resp["data"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("expected data to be map, got %T", resp["data"])
|
||||
}
|
||||
if data["task_id"] != "task-1" {
|
||||
t.Errorf("got task_id=%v, want task-1", data["task_id"])
|
||||
}
|
||||
if data["status"] != "completed" {
|
||||
t.Errorf("got status=%v, want completed", data["status"])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildsHandler_ListBuilds(t *testing.T) {
|
||||
queue := newMockWorkQueue()
|
||||
audit := newMockBuildAudit()
|
||||
buildService := service.NewBuildService(queue, audit, nil)
|
||||
handler := NewBuildsHandler(buildService)
|
||||
|
||||
// Pre-populate audit entries
|
||||
audit.entries["task-1"] = &domain.BuildAuditEntry{
|
||||
TaskID: "task-1",
|
||||
ProjectID: "project-a",
|
||||
Status: domain.BuildStatusCompleted,
|
||||
Spec: domain.BuildSpec{Prompt: "Build page"},
|
||||
StartedAt: time.Now().Add(-10 * time.Minute),
|
||||
}
|
||||
audit.entries["task-2"] = &domain.BuildAuditEntry{
|
||||
TaskID: "task-2",
|
||||
ProjectID: "project-a",
|
||||
Status: domain.BuildStatusRunning,
|
||||
Spec: domain.BuildSpec{Prompt: "Add footer"},
|
||||
StartedAt: time.Now().Add(-5 * time.Minute),
|
||||
}
|
||||
audit.entries["task-3"] = &domain.BuildAuditEntry{
|
||||
TaskID: "task-3",
|
||||
ProjectID: "project-b",
|
||||
Status: domain.BuildStatusPending,
|
||||
Spec: domain.BuildSpec{Prompt: "Other project"},
|
||||
StartedAt: time.Now(),
|
||||
}
|
||||
|
||||
router := chi.NewRouter()
|
||||
router.Use(testAdminAuth)
|
||||
handler.Mount(router)
|
||||
|
||||
t.Run("list_builds_for_project", func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/projects/project-a/builds", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Errorf("got status %d, want %d; body: %s", rec.Code, http.StatusOK, rec.Body.String())
|
||||
}
|
||||
|
||||
var resp map[string]any
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("failed to unmarshal: %v", err)
|
||||
}
|
||||
data, ok := resp["data"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("expected data to be map, got %T", resp["data"])
|
||||
}
|
||||
totalF, ok := data["total"].(float64)
|
||||
if !ok {
|
||||
t.Fatalf("expected total to be float64, got %T", data["total"])
|
||||
}
|
||||
if int(totalF) != 2 {
|
||||
t.Errorf("got total=%d, want 2", int(totalF))
|
||||
}
|
||||
if data["project_id"] != "project-a" {
|
||||
t.Errorf("got project_id=%v, want project-a", data["project_id"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("list_with_limit", func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/projects/project-a/builds?limit=1", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Errorf("got status %d, want %d", rec.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
var resp map[string]any
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("failed to unmarshal: %v", err)
|
||||
}
|
||||
data, ok := resp["data"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("expected data to be map, got %T", resp["data"])
|
||||
}
|
||||
totalF, ok := data["total"].(float64)
|
||||
if !ok {
|
||||
t.Fatalf("expected total to be float64, got %T", data["total"])
|
||||
}
|
||||
if int(totalF) != 1 {
|
||||
t.Errorf("got total=%d, want 1 (limited)", int(totalF))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("invalid_limit", func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/projects/project-a/builds?limit=abc", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Errorf("got status %d, want %d", rec.Code, http.StatusBadRequest)
|
||||
}
|
||||
})
|
||||
}
|
||||
184
internal/handlers/create_and_build.go
Normal file
184
internal/handlers/create_and_build.go
Normal file
@ -0,0 +1,184 @@
|
||||
// Package handlers provides HTTP handlers for the rdev API.
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/orchard9/rdev/internal/domain"
|
||||
"github.com/orchard9/rdev/internal/service"
|
||||
"github.com/orchard9/rdev/pkg/api"
|
||||
)
|
||||
|
||||
// CreateAndBuildHandler handles the combined create-project-and-build endpoint.
|
||||
type CreateAndBuildHandler struct {
|
||||
infraService *service.ProjectInfraService
|
||||
buildService *service.BuildService
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
// NewCreateAndBuildHandler creates a new create-and-build handler.
|
||||
func NewCreateAndBuildHandler(
|
||||
infraService *service.ProjectInfraService,
|
||||
buildService *service.BuildService,
|
||||
logger *slog.Logger,
|
||||
) *CreateAndBuildHandler {
|
||||
if logger == nil {
|
||||
logger = slog.Default()
|
||||
}
|
||||
return &CreateAndBuildHandler{
|
||||
infraService: infraService,
|
||||
buildService: buildService,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// Mount registers the create-and-build route.
|
||||
func (h *CreateAndBuildHandler) Mount(r api.Router) {
|
||||
r.Post("/project/create-and-build", h.CreateAndBuild)
|
||||
}
|
||||
|
||||
// CreateAndBuildRequest is the request body for POST /project/create-and-build.
|
||||
type CreateAndBuildRequest struct {
|
||||
// Project creation fields
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"`
|
||||
Private bool `json:"private,omitempty"`
|
||||
Template string `json:"template,omitempty"`
|
||||
|
||||
// Build fields
|
||||
Prompt string `json:"prompt"`
|
||||
Variables map[string]string `json:"variables,omitempty"`
|
||||
AutoCommit bool `json:"auto_commit"`
|
||||
AutoPush bool `json:"auto_push"`
|
||||
CallbackURL string `json:"callback_url,omitempty"`
|
||||
}
|
||||
|
||||
// CreateAndBuildResponse is the response for POST /project/create-and-build.
|
||||
type CreateAndBuildResponse struct {
|
||||
// Project info
|
||||
ProjectID string `json:"project_id"`
|
||||
Name string `json:"name"`
|
||||
Domain string `json:"domain"`
|
||||
URL string `json:"url"`
|
||||
|
||||
// Git info
|
||||
Git map[string]string `json:"git,omitempty"`
|
||||
|
||||
// Build info
|
||||
TaskID string `json:"task_id"`
|
||||
Status string `json:"status"`
|
||||
StatusURL string `json:"status_url"`
|
||||
}
|
||||
|
||||
// CreateAndBuild creates a project and immediately enqueues a build task.
|
||||
// POST /project/create-and-build
|
||||
func (h *CreateAndBuildHandler) CreateAndBuild(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 60*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if h.infraService == nil {
|
||||
api.WriteInternalError(w, r, "project infrastructure service not configured")
|
||||
return
|
||||
}
|
||||
if h.buildService == nil {
|
||||
api.WriteInternalError(w, r, "build service not configured")
|
||||
return
|
||||
}
|
||||
|
||||
r.Body = http.MaxBytesReader(w, r.Body, maxRequestBodySize)
|
||||
var req CreateAndBuildRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
api.WriteBadRequest(w, r, "invalid request body")
|
||||
return
|
||||
}
|
||||
|
||||
if req.Name == "" {
|
||||
api.WriteBadRequest(w, r, "name is required")
|
||||
return
|
||||
}
|
||||
if req.Prompt == "" {
|
||||
api.WriteBadRequest(w, r, "prompt is required")
|
||||
return
|
||||
}
|
||||
|
||||
// Validate callback URL to prevent SSRF
|
||||
if req.CallbackURL != "" {
|
||||
if err := domain.ValidateCallbackURL(req.CallbackURL); err != nil {
|
||||
api.WriteBadRequest(w, r, err.Error())
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Step 1: Create the project
|
||||
projectResult, err := h.infraService.CreateProject(ctx, service.CreateProjectRequest{
|
||||
Name: req.Name,
|
||||
Description: req.Description,
|
||||
Private: req.Private,
|
||||
Template: req.Template,
|
||||
})
|
||||
if err != nil {
|
||||
if errors.Is(err, domain.ErrInvalidProjectName) {
|
||||
api.WriteBadRequest(w, r, err.Error())
|
||||
return
|
||||
}
|
||||
h.logger.Error("project creation failed", "error", err, "name", req.Name)
|
||||
api.WriteInternalError(w, r, "failed to create project")
|
||||
return
|
||||
}
|
||||
|
||||
// Step 2: Enqueue the build task
|
||||
spec := domain.BuildSpec{
|
||||
Prompt: req.Prompt,
|
||||
Template: req.Template,
|
||||
Variables: req.Variables,
|
||||
AutoCommit: req.AutoCommit,
|
||||
AutoPush: req.AutoPush,
|
||||
CallbackURL: req.CallbackURL,
|
||||
}
|
||||
|
||||
taskID, err := h.buildService.StartBuild(ctx, projectResult.ProjectID, spec)
|
||||
if err != nil {
|
||||
h.logger.Error("build enqueue failed after project creation",
|
||||
"error", err,
|
||||
"project", projectResult.ProjectID,
|
||||
)
|
||||
// Project was created but build failed to enqueue.
|
||||
// Return the project info with a generic error and retry URL.
|
||||
api.WriteJSON(w, r, http.StatusCreated, map[string]any{
|
||||
"project_id": projectResult.ProjectID,
|
||||
"name": projectResult.Name,
|
||||
"domain": projectResult.Domain,
|
||||
"url": projectResult.URL,
|
||||
"build_error": "project created but build failed to enqueue",
|
||||
"retry_url": "/projects/" + projectResult.ProjectID + "/builds",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
resp := CreateAndBuildResponse{
|
||||
ProjectID: projectResult.ProjectID,
|
||||
Name: projectResult.Name,
|
||||
Domain: projectResult.Domain,
|
||||
URL: projectResult.URL,
|
||||
TaskID: taskID,
|
||||
Status: "pending",
|
||||
StatusURL: "/builds/" + taskID,
|
||||
}
|
||||
|
||||
if projectResult.CloneHTTP != "" {
|
||||
resp.Git = map[string]string{
|
||||
"owner": projectResult.GitRepoOwner,
|
||||
"name": projectResult.GitRepoName,
|
||||
"clone_ssh": projectResult.CloneSSH,
|
||||
"clone_http": projectResult.CloneHTTP,
|
||||
"html_url": projectResult.HTMLURL,
|
||||
}
|
||||
}
|
||||
|
||||
api.WriteCreated(w, r, resp)
|
||||
}
|
||||
@ -2,10 +2,13 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/orchard9/rdev/internal/auth"
|
||||
"github.com/orchard9/rdev/internal/domain"
|
||||
"github.com/orchard9/rdev/internal/port"
|
||||
"github.com/orchard9/rdev/pkg/api"
|
||||
@ -22,6 +25,16 @@ func NewCredentialsHandler(store port.CredentialStore) *CredentialsHandler {
|
||||
return &CredentialsHandler{store: store}
|
||||
}
|
||||
|
||||
// updatedBy extracts the authenticated identity from the request context.
|
||||
// Returns the API key name for regular keys, or "superadmin" for admin key auth
|
||||
// (which has ID "admin") to preserve consistency with existing database records.
|
||||
func updatedBy(ctx context.Context) string {
|
||||
if key := auth.GetAPIKey(ctx); key != nil && key.ID != "admin" {
|
||||
return key.Name
|
||||
}
|
||||
return "superadmin"
|
||||
}
|
||||
|
||||
// Mount registers the credentials routes.
|
||||
// All routes require superadmin authentication (handled by middleware).
|
||||
func (h *CredentialsHandler) Mount(r api.Router) {
|
||||
@ -146,7 +159,7 @@ func (h *CredentialsHandler) Set(w http.ResponseWriter, r *http.Request) {
|
||||
Value: req.Value,
|
||||
Description: req.Description,
|
||||
Category: req.Category,
|
||||
UpdatedBy: "superadmin", // Could extract from auth context
|
||||
UpdatedBy: updatedBy(ctx),
|
||||
}
|
||||
|
||||
if err := h.store.Set(ctx, cred); err != nil {
|
||||
@ -191,7 +204,7 @@ func (h *CredentialsHandler) SetBatch(w http.ResponseWriter, r *http.Request) {
|
||||
Value: c.Value,
|
||||
Description: c.Description,
|
||||
Category: c.Category,
|
||||
UpdatedBy: "superadmin",
|
||||
UpdatedBy: updatedBy(ctx),
|
||||
}
|
||||
}
|
||||
|
||||
@ -224,7 +237,11 @@ func (h *CredentialsHandler) Delete(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
if err := h.store.Delete(ctx, key); err != nil {
|
||||
api.WriteNotFound(w, r, "credential not found")
|
||||
if errors.Is(err, domain.ErrCredentialNotFound) {
|
||||
api.WriteNotFound(w, r, "credential not found")
|
||||
return
|
||||
}
|
||||
api.WriteInternalError(w, r, "failed to delete credential")
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
440
internal/handlers/credentials_test.go
Normal file
440
internal/handlers/credentials_test.go
Normal file
@ -0,0 +1,440 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/orchard9/rdev/internal/domain"
|
||||
)
|
||||
|
||||
// mockCredentialStore implements port.CredentialStore for testing.
|
||||
type mockCredentialStore struct {
|
||||
creds map[string]domain.Credential
|
||||
err error
|
||||
}
|
||||
|
||||
func newMockCredentialStore() *mockCredentialStore {
|
||||
return &mockCredentialStore{
|
||||
creds: make(map[string]domain.Credential),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockCredentialStore) Get(_ context.Context, key string) (string, error) {
|
||||
if m.err != nil {
|
||||
return "", m.err
|
||||
}
|
||||
c, ok := m.creds[key]
|
||||
if !ok {
|
||||
return "", nil
|
||||
}
|
||||
return c.Value, nil
|
||||
}
|
||||
|
||||
func (m *mockCredentialStore) GetRequired(_ context.Context, key string) (string, error) {
|
||||
if m.err != nil {
|
||||
return "", m.err
|
||||
}
|
||||
c, ok := m.creds[key]
|
||||
if !ok {
|
||||
return "", fmt.Errorf("credential not found: %s", key)
|
||||
}
|
||||
return c.Value, nil
|
||||
}
|
||||
|
||||
func (m *mockCredentialStore) Set(_ context.Context, cred domain.Credential) error {
|
||||
if m.err != nil {
|
||||
return m.err
|
||||
}
|
||||
cred.CreatedAt = time.Now()
|
||||
cred.UpdatedAt = time.Now()
|
||||
m.creds[cred.Key] = cred
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockCredentialStore) Delete(_ context.Context, key string) error {
|
||||
if m.err != nil {
|
||||
return m.err
|
||||
}
|
||||
if _, ok := m.creds[key]; !ok {
|
||||
return domain.ErrCredentialNotFound
|
||||
}
|
||||
delete(m.creds, key)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockCredentialStore) List(_ context.Context) ([]domain.Credential, error) {
|
||||
if m.err != nil {
|
||||
return nil, m.err
|
||||
}
|
||||
var result []domain.Credential
|
||||
for _, c := range m.creds {
|
||||
result = append(result, c)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (m *mockCredentialStore) ListByCategory(_ context.Context, category string) ([]domain.Credential, error) {
|
||||
if m.err != nil {
|
||||
return nil, m.err
|
||||
}
|
||||
var result []domain.Credential
|
||||
for _, c := range m.creds {
|
||||
if c.Category == category {
|
||||
result = append(result, c)
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (m *mockCredentialStore) GetMultiple(_ context.Context, keys []string) (map[string]string, error) {
|
||||
if m.err != nil {
|
||||
return nil, m.err
|
||||
}
|
||||
result := make(map[string]string)
|
||||
for _, k := range keys {
|
||||
if c, ok := m.creds[k]; ok {
|
||||
result[k] = c.Value
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (m *mockCredentialStore) SetMultiple(_ context.Context, creds []domain.Credential) error {
|
||||
if m.err != nil {
|
||||
return m.err
|
||||
}
|
||||
for _, c := range creds {
|
||||
c.CreatedAt = time.Now()
|
||||
c.UpdatedAt = time.Now()
|
||||
m.creds[c.Key] = c
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func setupCredentialsHandler() (*CredentialsHandler, *mockCredentialStore, chi.Router) {
|
||||
store := newMockCredentialStore()
|
||||
h := NewCredentialsHandler(store)
|
||||
r := chi.NewRouter()
|
||||
h.Mount(r)
|
||||
return h, store, r
|
||||
}
|
||||
|
||||
func TestCredentialsHandler_List(t *testing.T) {
|
||||
t.Run("empty list", func(t *testing.T) {
|
||||
_, _, router := setupCredentialsHandler()
|
||||
req := httptest.NewRequest("GET", "/credentials", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Errorf("status = %d, want %d", rec.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
var resp map[string]any
|
||||
if err := json.NewDecoder(rec.Body).Decode(&resp); err != nil {
|
||||
t.Fatalf("failed to decode response: %v", err)
|
||||
}
|
||||
data, ok := resp["data"].([]any)
|
||||
if !ok {
|
||||
t.Fatalf("response missing data array")
|
||||
}
|
||||
if len(data) != 0 {
|
||||
t.Errorf("data length = %d, want 0", len(data))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("with credentials", func(t *testing.T) {
|
||||
_, store, router := setupCredentialsHandler()
|
||||
store.creds["MY_TOKEN"] = domain.Credential{
|
||||
Key: "MY_TOKEN",
|
||||
Value: "****",
|
||||
Category: "gitea",
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "/credentials", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Errorf("status = %d, want %d", rec.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
var resp map[string]any
|
||||
if err := json.NewDecoder(rec.Body).Decode(&resp); err != nil {
|
||||
t.Fatalf("failed to decode response: %v", err)
|
||||
}
|
||||
data, ok := resp["data"].([]any)
|
||||
if !ok {
|
||||
t.Fatalf("response missing data array")
|
||||
}
|
||||
if len(data) != 1 {
|
||||
t.Errorf("data length = %d, want 1", len(data))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("filter by category", func(t *testing.T) {
|
||||
_, store, router := setupCredentialsHandler()
|
||||
store.creds["GITEA_TOKEN"] = domain.Credential{
|
||||
Key: "GITEA_TOKEN", Value: "****", Category: "gitea",
|
||||
CreatedAt: time.Now(), UpdatedAt: time.Now(),
|
||||
}
|
||||
store.creds["CF_TOKEN"] = domain.Credential{
|
||||
Key: "CF_TOKEN", Value: "****", Category: "cloudflare",
|
||||
CreatedAt: time.Now(), UpdatedAt: time.Now(),
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "/credentials?category=gitea", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Errorf("status = %d, want %d", rec.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
var resp map[string]any
|
||||
if err := json.NewDecoder(rec.Body).Decode(&resp); err != nil {
|
||||
t.Fatalf("failed to decode response: %v", err)
|
||||
}
|
||||
data := resp["data"].([]any)
|
||||
if len(data) != 1 {
|
||||
t.Errorf("data length = %d, want 1", len(data))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestCredentialsHandler_Get(t *testing.T) {
|
||||
t.Run("existing credential", func(t *testing.T) {
|
||||
_, store, router := setupCredentialsHandler()
|
||||
store.creds["MY_TOKEN"] = domain.Credential{
|
||||
Key: "MY_TOKEN", Value: "secret123",
|
||||
CreatedAt: time.Now(), UpdatedAt: time.Now(),
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "/credentials/MY_TOKEN", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Errorf("status = %d, want %d", rec.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
var resp map[string]any
|
||||
if err := json.NewDecoder(rec.Body).Decode(&resp); err != nil {
|
||||
t.Fatalf("failed to decode response: %v", err)
|
||||
}
|
||||
data := resp["data"].(map[string]any)
|
||||
if data["key"] != "MY_TOKEN" {
|
||||
t.Errorf("key = %q, want %q", data["key"], "MY_TOKEN")
|
||||
}
|
||||
if data["value"] != "secret123" {
|
||||
t.Errorf("value = %q, want %q", data["value"], "secret123")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("not found", func(t *testing.T) {
|
||||
_, _, router := setupCredentialsHandler()
|
||||
|
||||
req := httptest.NewRequest("GET", "/credentials/MISSING", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusNotFound {
|
||||
t.Errorf("status = %d, want %d", rec.Code, http.StatusNotFound)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestCredentialsHandler_Set(t *testing.T) {
|
||||
t.Run("valid credential", func(t *testing.T) {
|
||||
_, store, router := setupCredentialsHandler()
|
||||
|
||||
body, _ := json.Marshal(SetCredentialRequest{
|
||||
Key: "NEW_TOKEN",
|
||||
Value: "secret",
|
||||
Description: "A test token",
|
||||
Category: "gitea",
|
||||
})
|
||||
req := httptest.NewRequest("POST", "/credentials", bytes.NewReader(body))
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusCreated {
|
||||
t.Errorf("status = %d, want %d", rec.Code, http.StatusCreated)
|
||||
}
|
||||
|
||||
if _, ok := store.creds["NEW_TOKEN"]; !ok {
|
||||
t.Error("credential not stored")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("missing key", func(t *testing.T) {
|
||||
_, _, router := setupCredentialsHandler()
|
||||
|
||||
body, _ := json.Marshal(SetCredentialRequest{Value: "secret"})
|
||||
req := httptest.NewRequest("POST", "/credentials", bytes.NewReader(body))
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Errorf("status = %d, want %d", rec.Code, http.StatusBadRequest)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("missing value", func(t *testing.T) {
|
||||
_, _, router := setupCredentialsHandler()
|
||||
|
||||
body, _ := json.Marshal(SetCredentialRequest{Key: "TOKEN"})
|
||||
req := httptest.NewRequest("POST", "/credentials", bytes.NewReader(body))
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Errorf("status = %d, want %d", rec.Code, http.StatusBadRequest)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("invalid json", func(t *testing.T) {
|
||||
_, _, router := setupCredentialsHandler()
|
||||
|
||||
req := httptest.NewRequest("POST", "/credentials", bytes.NewReader([]byte("not json")))
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Errorf("status = %d, want %d", rec.Code, http.StatusBadRequest)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestCredentialsHandler_SetBatch(t *testing.T) {
|
||||
t.Run("valid batch", func(t *testing.T) {
|
||||
_, store, router := setupCredentialsHandler()
|
||||
|
||||
body, _ := json.Marshal(SetBatchRequest{
|
||||
Credentials: []SetCredentialRequest{
|
||||
{Key: "TOKEN1", Value: "val1"},
|
||||
{Key: "TOKEN2", Value: "val2"},
|
||||
},
|
||||
})
|
||||
req := httptest.NewRequest("POST", "/credentials/batch", bytes.NewReader(body))
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusCreated {
|
||||
t.Errorf("status = %d, want %d", rec.Code, http.StatusCreated)
|
||||
}
|
||||
|
||||
if len(store.creds) != 2 {
|
||||
t.Errorf("stored credentials = %d, want 2", len(store.creds))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("empty array", func(t *testing.T) {
|
||||
_, _, router := setupCredentialsHandler()
|
||||
|
||||
body, _ := json.Marshal(SetBatchRequest{Credentials: []SetCredentialRequest{}})
|
||||
req := httptest.NewRequest("POST", "/credentials/batch", bytes.NewReader(body))
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Errorf("status = %d, want %d", rec.Code, http.StatusBadRequest)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("missing key in batch", func(t *testing.T) {
|
||||
_, _, router := setupCredentialsHandler()
|
||||
|
||||
body, _ := json.Marshal(SetBatchRequest{
|
||||
Credentials: []SetCredentialRequest{
|
||||
{Key: "TOKEN1", Value: "val1"},
|
||||
{Key: "", Value: "val2"},
|
||||
},
|
||||
})
|
||||
req := httptest.NewRequest("POST", "/credentials/batch", bytes.NewReader(body))
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Errorf("status = %d, want %d", rec.Code, http.StatusBadRequest)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("missing value in batch", func(t *testing.T) {
|
||||
_, _, router := setupCredentialsHandler()
|
||||
|
||||
body, _ := json.Marshal(SetBatchRequest{
|
||||
Credentials: []SetCredentialRequest{
|
||||
{Key: "TOKEN1", Value: ""},
|
||||
},
|
||||
})
|
||||
req := httptest.NewRequest("POST", "/credentials/batch", bytes.NewReader(body))
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Errorf("status = %d, want %d", rec.Code, http.StatusBadRequest)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestCredentialsHandler_Delete(t *testing.T) {
|
||||
t.Run("existing credential", func(t *testing.T) {
|
||||
_, store, router := setupCredentialsHandler()
|
||||
store.creds["TO_DELETE"] = domain.Credential{
|
||||
Key: "TO_DELETE", Value: "val",
|
||||
CreatedAt: time.Now(), UpdatedAt: time.Now(),
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("DELETE", "/credentials/TO_DELETE", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Errorf("status = %d, want %d", rec.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
var resp map[string]any
|
||||
if err := json.NewDecoder(rec.Body).Decode(&resp); err != nil {
|
||||
t.Fatalf("failed to decode response: %v", err)
|
||||
}
|
||||
data := resp["data"].(map[string]any)
|
||||
if data["status"] != "deleted" {
|
||||
t.Errorf("status = %q, want %q", data["status"], "deleted")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("not found returns 404", func(t *testing.T) {
|
||||
_, _, router := setupCredentialsHandler()
|
||||
|
||||
req := httptest.NewRequest("DELETE", "/credentials/NONEXISTENT", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusNotFound {
|
||||
t.Errorf("status = %d, want %d", rec.Code, http.StatusNotFound)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("store error returns 500", func(t *testing.T) {
|
||||
_, store, router := setupCredentialsHandler()
|
||||
store.err = fmt.Errorf("database connection lost")
|
||||
|
||||
req := httptest.NewRequest("DELETE", "/credentials/ANY_KEY", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusInternalServerError {
|
||||
t.Errorf("status = %d, want %d", rec.Code, http.StatusInternalServerError)
|
||||
}
|
||||
})
|
||||
}
|
||||
@ -3,31 +3,51 @@ package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/orchard9/rdev/internal/port"
|
||||
"github.com/orchard9/rdev/pkg/api"
|
||||
k8sclient "k8s.io/client-go/kubernetes"
|
||||
)
|
||||
|
||||
// ExecutorHealthChecker reports whether a background executor is running.
|
||||
type ExecutorHealthChecker interface {
|
||||
Running() bool
|
||||
WorkerID() string
|
||||
}
|
||||
|
||||
// HealthHandler handles health and readiness checks.
|
||||
type HealthHandler struct {
|
||||
serviceName string
|
||||
db *sql.DB
|
||||
k8sClient *k8sclient.Clientset
|
||||
serviceName string
|
||||
db port.DatabasePinger
|
||||
k8sChecker port.KubernetesChecker
|
||||
agentRegistry port.CodeAgentRegistry
|
||||
workExecutor ExecutorHealthChecker
|
||||
}
|
||||
|
||||
// NewHealthHandler creates a new health handler with dependencies.
|
||||
func NewHealthHandler(serviceName string, db *sql.DB, k8sClient *k8sclient.Clientset) *HealthHandler {
|
||||
func NewHealthHandler(serviceName string, db port.DatabasePinger, k8sChecker port.KubernetesChecker) *HealthHandler {
|
||||
return &HealthHandler{
|
||||
serviceName: serviceName,
|
||||
db: db,
|
||||
k8sClient: k8sClient,
|
||||
k8sChecker: k8sChecker,
|
||||
}
|
||||
}
|
||||
|
||||
// WithAgentRegistry adds a code agent registry for health monitoring.
|
||||
func (h *HealthHandler) WithAgentRegistry(registry port.CodeAgentRegistry) *HealthHandler {
|
||||
h.agentRegistry = registry
|
||||
return h
|
||||
}
|
||||
|
||||
// WithWorkExecutor adds a work executor for health monitoring.
|
||||
func (h *HealthHandler) WithWorkExecutor(executor ExecutorHealthChecker) *HealthHandler {
|
||||
h.workExecutor = executor
|
||||
return h
|
||||
}
|
||||
|
||||
// Health returns a simple liveness check.
|
||||
// This should be lightweight and only fail if the process is unhealthy.
|
||||
// GET /health
|
||||
@ -59,7 +79,7 @@ func (h *HealthHandler) Ready(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// Kubernetes check
|
||||
if h.k8sClient != nil {
|
||||
if h.k8sChecker != nil {
|
||||
k8sCheck := h.checkKubernetes(ctx)
|
||||
checks["kubernetes"] = k8sCheck
|
||||
if !k8sCheck.Healthy {
|
||||
@ -67,6 +87,19 @@ func (h *HealthHandler) Ready(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
// Code agent checks (informational - don't affect overall readiness)
|
||||
if h.agentRegistry != nil {
|
||||
agentChecks := h.checkCodeAgents(ctx)
|
||||
for name, check := range agentChecks {
|
||||
checks["agent:"+name] = check
|
||||
}
|
||||
}
|
||||
|
||||
// Work executor check (informational)
|
||||
if h.workExecutor != nil {
|
||||
checks["work_executor"] = h.checkWorkExecutor()
|
||||
}
|
||||
|
||||
response := ReadinessResponse{
|
||||
Status: "ready",
|
||||
Service: h.serviceName,
|
||||
@ -107,11 +140,11 @@ func (h *HealthHandler) checkDatabase(ctx context.Context) CheckResult {
|
||||
}
|
||||
|
||||
// checkKubernetes performs a Kubernetes API health check.
|
||||
func (h *HealthHandler) checkKubernetes(ctx context.Context) CheckResult {
|
||||
func (h *HealthHandler) checkKubernetes(_ context.Context) CheckResult {
|
||||
start := time.Now()
|
||||
|
||||
// Try to get server version - lightweight API call
|
||||
_, err := h.k8sClient.Discovery().ServerVersion()
|
||||
_, err := h.k8sChecker.ServerVersion()
|
||||
latency := time.Since(start)
|
||||
|
||||
if err != nil {
|
||||
@ -139,6 +172,51 @@ func (h *HealthHandler) checkKubernetes(ctx context.Context) CheckResult {
|
||||
}
|
||||
}
|
||||
|
||||
// checkCodeAgents performs health checks on all registered code agents.
|
||||
func (h *HealthHandler) checkCodeAgents(ctx context.Context) map[string]CheckResult {
|
||||
results := make(map[string]CheckResult)
|
||||
|
||||
providers := h.agentRegistry.Available()
|
||||
for _, provider := range providers {
|
||||
agent := h.agentRegistry.Get(provider)
|
||||
if agent == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
available := agent.Available(ctx)
|
||||
latency := time.Since(start)
|
||||
|
||||
msg := "available"
|
||||
if !available {
|
||||
msg = "unavailable"
|
||||
}
|
||||
|
||||
results[string(provider)] = CheckResult{
|
||||
Healthy: available,
|
||||
Message: fmt.Sprintf("%s (%s)", msg, agent.Name()),
|
||||
Latency: latency.String(),
|
||||
LastCheck: time.Now().UTC(),
|
||||
}
|
||||
}
|
||||
|
||||
return results
|
||||
}
|
||||
|
||||
// checkWorkExecutor checks whether the work executor is running.
|
||||
func (h *HealthHandler) checkWorkExecutor() CheckResult {
|
||||
running := h.workExecutor.Running()
|
||||
msg := fmt.Sprintf("worker %s: running", h.workExecutor.WorkerID())
|
||||
if !running {
|
||||
msg = fmt.Sprintf("worker %s: stopped", h.workExecutor.WorkerID())
|
||||
}
|
||||
return CheckResult{
|
||||
Healthy: running,
|
||||
Message: msg,
|
||||
LastCheck: time.Now().UTC(),
|
||||
}
|
||||
}
|
||||
|
||||
// CheckResult represents the result of a health check.
|
||||
type CheckResult struct {
|
||||
Healthy bool `json:"healthy"`
|
||||
|
||||
@ -4,10 +4,8 @@ package handlers
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
@ -16,12 +14,13 @@ import (
|
||||
"github.com/orchard9/rdev/pkg/api"
|
||||
)
|
||||
|
||||
// InfrastructureHandler handles git, deployment, and DNS endpoints.
|
||||
// InfrastructureHandler handles git, deployment, DNS, and CI pipeline endpoints.
|
||||
type InfrastructureHandler struct {
|
||||
gitRepo port.GitRepository
|
||||
dns port.DNSProvider
|
||||
deployer port.Deployer
|
||||
projects port.ProjectRepository
|
||||
gitRepo port.GitRepository
|
||||
dns port.DNSProvider
|
||||
deployer port.Deployer
|
||||
projects port.ProjectRepository
|
||||
ciProvider port.CIProvider
|
||||
|
||||
// Config
|
||||
defaultGitOwner string
|
||||
@ -29,21 +28,9 @@ type InfrastructureHandler struct {
|
||||
clusterIP string
|
||||
}
|
||||
|
||||
// projectIDRegex validates project IDs (alphanumeric, dash, underscore only).
|
||||
var projectIDRegex = regexp.MustCompile(`^[a-zA-Z][a-zA-Z0-9_-]*$`)
|
||||
|
||||
// validateProjectID validates that a project ID is safe for use as repo/deployment name.
|
||||
func validateProjectID(id string) error {
|
||||
if id == "" {
|
||||
return errors.New("project ID cannot be empty")
|
||||
}
|
||||
if len(id) > 63 { // K8s name limit
|
||||
return errors.New("project ID too long (max 63 characters)")
|
||||
}
|
||||
if !projectIDRegex.MatchString(id) {
|
||||
return errors.New("project ID must start with a letter and contain only alphanumeric characters, dashes, or underscores")
|
||||
}
|
||||
return nil
|
||||
return domain.ValidateProjectID(id)
|
||||
}
|
||||
|
||||
// InfrastructureConfig configures the infrastructure handler.
|
||||
@ -62,6 +49,7 @@ func NewInfrastructureHandler(
|
||||
dns port.DNSProvider,
|
||||
deployer port.Deployer,
|
||||
projects port.ProjectRepository,
|
||||
ciProvider port.CIProvider,
|
||||
cfg InfrastructureConfig,
|
||||
) *InfrastructureHandler {
|
||||
return &InfrastructureHandler{
|
||||
@ -69,6 +57,7 @@ func NewInfrastructureHandler(
|
||||
dns: dns,
|
||||
deployer: deployer,
|
||||
projects: projects,
|
||||
ciProvider: ciProvider,
|
||||
defaultGitOwner: cfg.DefaultGitOwner,
|
||||
defaultDomain: cfg.DefaultDomain,
|
||||
clusterIP: cfg.ClusterIP,
|
||||
@ -90,9 +79,18 @@ func (h *InfrastructureHandler) Mount(r api.Router) {
|
||||
r.Post("/projects/{id}/deploy/scale", h.ScaleDeploy)
|
||||
r.Get("/projects/{id}/deploy/logs", h.GetDeployLogs)
|
||||
|
||||
// Domain endpoints
|
||||
// Domain endpoints (single)
|
||||
r.Post("/projects/{id}/domain", h.AddDomain)
|
||||
r.Delete("/projects/{id}/domain", h.RemoveDomain)
|
||||
|
||||
// Domain alias management (multi-domain)
|
||||
r.Get("/projects/{id}/domains", h.ListDomains)
|
||||
r.Post("/projects/{id}/domains", h.AddDomainAlias)
|
||||
r.Delete("/projects/{id}/domains/{domain}", h.RemoveDomainAlias)
|
||||
|
||||
// CI pipeline endpoints
|
||||
r.Get("/projects/{id}/pipelines", h.ListPipelines)
|
||||
r.Get("/projects/{id}/pipelines/{number}", h.GetPipeline)
|
||||
}
|
||||
|
||||
// CreateRepoRequest is the request body for POST /projects/{id}/repo.
|
||||
|
||||
@ -12,6 +12,9 @@ import (
|
||||
"github.com/orchard9/rdev/pkg/api"
|
||||
)
|
||||
|
||||
// maxReplicas is the maximum number of deployment replicas allowed.
|
||||
const maxReplicas = 10
|
||||
|
||||
// DeployRequest is the request body for POST /projects/{id}/deploy.
|
||||
type DeployRequest struct {
|
||||
Image string `json:"image"` // Container image
|
||||
@ -218,8 +221,8 @@ func (h *InfrastructureHandler) ScaleDeploy(w http.ResponseWriter, r *http.Reque
|
||||
return
|
||||
}
|
||||
|
||||
if req.Replicas < 0 || req.Replicas > 10 {
|
||||
api.WriteBadRequest(w, r, "replicas must be between 0 and 10")
|
||||
if req.Replicas < 0 || req.Replicas > maxReplicas {
|
||||
api.WriteBadRequest(w, r, fmt.Sprintf("replicas must be between 0 and %d", maxReplicas))
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
195
internal/handlers/infrastructure_domain_test.go
Normal file
195
internal/handlers/infrastructure_domain_test.go
Normal file
@ -0,0 +1,195 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/orchard9/rdev/internal/domain"
|
||||
)
|
||||
|
||||
func TestInfrastructureHandler_RestartDeploy(t *testing.T) {
|
||||
t.Run("success", func(t *testing.T) {
|
||||
_, _, _, _, router := setupInfraHandler()
|
||||
|
||||
req := httptest.NewRequest("POST", "/projects/myapp/deploy/restart", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Errorf("status = %d, want %d", rec.Code, http.StatusOK)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestInfrastructureHandler_ScaleDeploy(t *testing.T) {
|
||||
t.Run("valid scale", func(t *testing.T) {
|
||||
_, _, _, _, router := setupInfraHandler()
|
||||
|
||||
body, _ := json.Marshal(ScaleRequest{Replicas: 3})
|
||||
req := httptest.NewRequest("POST", "/projects/myapp/deploy/scale", bytes.NewReader(body))
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Errorf("status = %d, want %d", rec.Code, http.StatusOK)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("invalid replicas too high", func(t *testing.T) {
|
||||
_, _, _, _, router := setupInfraHandler()
|
||||
|
||||
body, _ := json.Marshal(ScaleRequest{Replicas: 11})
|
||||
req := httptest.NewRequest("POST", "/projects/myapp/deploy/scale", bytes.NewReader(body))
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Errorf("status = %d, want %d", rec.Code, http.StatusBadRequest)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("invalid replicas negative", func(t *testing.T) {
|
||||
_, _, _, _, router := setupInfraHandler()
|
||||
|
||||
body, _ := json.Marshal(ScaleRequest{Replicas: -1})
|
||||
req := httptest.NewRequest("POST", "/projects/myapp/deploy/scale", bytes.NewReader(body))
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Errorf("status = %d, want %d", rec.Code, http.StatusBadRequest)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestInfrastructureHandler_GetDeployLogs(t *testing.T) {
|
||||
t.Run("success", func(t *testing.T) {
|
||||
_, _, _, deployer, router := setupInfraHandler()
|
||||
deployer.logs = "line1\nline2\nline3\n"
|
||||
|
||||
req := httptest.NewRequest("GET", "/projects/myapp/deploy/logs", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Errorf("status = %d, want %d", rec.Code, http.StatusOK)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestInfrastructureHandler_AddDomain(t *testing.T) {
|
||||
t.Run("subdomain", func(t *testing.T) {
|
||||
_, _, dns, _, router := setupInfraHandler()
|
||||
|
||||
body, _ := json.Marshal(AddDomainRequest{Domain: "myapp.threesix.ai"})
|
||||
req := httptest.NewRequest("POST", "/projects/myapp/domain", bytes.NewReader(body))
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusCreated {
|
||||
t.Errorf("status = %d, want %d", rec.Code, http.StatusCreated)
|
||||
}
|
||||
if len(dns.records) != 1 {
|
||||
t.Errorf("DNS records = %d, want 1", len(dns.records))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("external domain", func(t *testing.T) {
|
||||
_, _, dns, _, router := setupInfraHandler()
|
||||
|
||||
body, _ := json.Marshal(AddDomainRequest{Domain: "myapp.example.com"})
|
||||
req := httptest.NewRequest("POST", "/projects/myapp/domain", bytes.NewReader(body))
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusCreated {
|
||||
t.Errorf("status = %d, want %d", rec.Code, http.StatusCreated)
|
||||
}
|
||||
// External domain should NOT create DNS record
|
||||
if len(dns.records) != 0 {
|
||||
t.Errorf("DNS records = %d, want 0 (external domain)", len(dns.records))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("missing domain", func(t *testing.T) {
|
||||
_, _, _, _, router := setupInfraHandler()
|
||||
|
||||
body, _ := json.Marshal(AddDomainRequest{})
|
||||
req := httptest.NewRequest("POST", "/projects/myapp/domain", bytes.NewReader(body))
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Errorf("status = %d, want %d", rec.Code, http.StatusBadRequest)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestInfrastructureHandler_RemoveDomain(t *testing.T) {
|
||||
t.Run("subdomain", func(t *testing.T) {
|
||||
_, _, dns, _, router := setupInfraHandler()
|
||||
dns.records["myapp"] = &domain.DNSRecord{ID: "rec-myapp", Name: "myapp"}
|
||||
|
||||
req := httptest.NewRequest("DELETE", "/projects/myapp/domain?domain=myapp.threesix.ai", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Errorf("status = %d, want %d", rec.Code, http.StatusOK)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("missing domain param", func(t *testing.T) {
|
||||
_, _, _, _, router := setupInfraHandler()
|
||||
|
||||
req := httptest.NewRequest("DELETE", "/projects/myapp/domain", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Errorf("status = %d, want %d", rec.Code, http.StatusBadRequest)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestIsSubdomain(t *testing.T) {
|
||||
tests := []struct {
|
||||
domain, base string
|
||||
want bool
|
||||
}{
|
||||
{"myapp.threesix.ai", "threesix.ai", true},
|
||||
{"deep.sub.threesix.ai", "threesix.ai", true},
|
||||
{"threesix.ai", "threesix.ai", false},
|
||||
{"myapp.example.com", "threesix.ai", false},
|
||||
{"", "threesix.ai", false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.domain, func(t *testing.T) {
|
||||
got := isSubdomain(tt.domain, tt.base)
|
||||
if got != tt.want {
|
||||
t.Errorf("isSubdomain(%q, %q) = %v, want %v", tt.domain, tt.base, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetSubdomain(t *testing.T) {
|
||||
tests := []struct {
|
||||
domain, base, want string
|
||||
}{
|
||||
{"myapp.threesix.ai", "threesix.ai", "myapp"},
|
||||
{"deep.sub.threesix.ai", "threesix.ai", "deep.sub"},
|
||||
{"threesix.ai", "threesix.ai", "threesix.ai"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.domain, func(t *testing.T) {
|
||||
got := getSubdomain(tt.domain, tt.base)
|
||||
if got != tt.want {
|
||||
t.Errorf("getSubdomain(%q, %q) = %q, want %q", tt.domain, tt.base, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
260
internal/handlers/infrastructure_domains.go
Normal file
260
internal/handlers/infrastructure_domains.go
Normal file
@ -0,0 +1,260 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/orchard9/rdev/internal/domain"
|
||||
"github.com/orchard9/rdev/pkg/api"
|
||||
)
|
||||
|
||||
// DomainAliasRequest is the request body for POST /projects/{id}/domains.
|
||||
type DomainAliasRequest struct {
|
||||
Domain string `json:"domain"` // The domain to add (e.g., "www.threesix.ai")
|
||||
Type string `json:"type,omitempty"` // "A" or "CNAME" (default: "A")
|
||||
Proxied bool `json:"proxied,omitempty"` // Cloudflare proxy (default: false)
|
||||
Content string `json:"content,omitempty"` // Target (default: cluster IP for A records)
|
||||
}
|
||||
|
||||
// DomainAliasResponse is the response for domain alias operations.
|
||||
type DomainAliasResponse struct {
|
||||
Domain string `json:"domain"`
|
||||
Type string `json:"type"`
|
||||
Content string `json:"content"`
|
||||
TTL int `json:"ttl"`
|
||||
Proxied bool `json:"proxied"`
|
||||
}
|
||||
|
||||
// ListDomains returns all DNS records associated with a project.
|
||||
// GET /projects/{id}/domains
|
||||
func (h *InfrastructureHandler) ListDomains(w http.ResponseWriter, r *http.Request) {
|
||||
projectID := chi.URLParam(r, "id")
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := validateProjectID(projectID); err != nil {
|
||||
api.WriteBadRequest(w, r, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if h.dns == nil {
|
||||
api.WriteInternalError(w, r, "DNS provider not configured")
|
||||
return
|
||||
}
|
||||
|
||||
// List all A records and find ones matching this project
|
||||
aRecords, err := h.dns.ListRecords(ctx, "A")
|
||||
if err != nil {
|
||||
api.WriteInternalError(w, r, "failed to list DNS records")
|
||||
return
|
||||
}
|
||||
|
||||
// Also list CNAME records
|
||||
cnameRecords, err := h.dns.ListRecords(ctx, "CNAME")
|
||||
if err != nil {
|
||||
api.WriteInternalError(w, r, "failed to list DNS records")
|
||||
return
|
||||
}
|
||||
|
||||
// Filter records that belong to this project:
|
||||
// - Primary: {projectID}.{defaultDomain}
|
||||
// - Aliases: any record pointing to the cluster IP or the project's primary domain
|
||||
primaryDomain := projectID + "." + h.defaultDomain
|
||||
var domains []DomainAliasResponse
|
||||
|
||||
for _, rec := range aRecords {
|
||||
name := rec.Name
|
||||
// Normalize: if name matches the project's subdomain or points to our cluster IP
|
||||
if name == primaryDomain || (rec.Content == h.clusterIP && isProjectDomain(name, projectID, h.defaultDomain)) {
|
||||
domains = append(domains, DomainAliasResponse{
|
||||
Domain: name,
|
||||
Type: rec.Type,
|
||||
Content: rec.Content,
|
||||
TTL: rec.TTL,
|
||||
Proxied: rec.Proxied,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
for _, rec := range cnameRecords {
|
||||
// CNAME records pointing to the project's primary domain
|
||||
if rec.Content == primaryDomain {
|
||||
domains = append(domains, DomainAliasResponse{
|
||||
Domain: rec.Name,
|
||||
Type: rec.Type,
|
||||
Content: rec.Content,
|
||||
TTL: rec.TTL,
|
||||
Proxied: rec.Proxied,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
api.WriteSuccess(w, r, map[string]any{
|
||||
"project_id": projectID,
|
||||
"domains": domains,
|
||||
"total": len(domains),
|
||||
})
|
||||
}
|
||||
|
||||
// AddDomainAlias adds a DNS alias for a project.
|
||||
// POST /projects/{id}/domains
|
||||
func (h *InfrastructureHandler) AddDomainAlias(w http.ResponseWriter, r *http.Request) {
|
||||
projectID := chi.URLParam(r, "id")
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := validateProjectID(projectID); err != nil {
|
||||
api.WriteBadRequest(w, r, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if h.dns == nil {
|
||||
api.WriteInternalError(w, r, "DNS provider not configured")
|
||||
return
|
||||
}
|
||||
|
||||
var req DomainAliasRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
api.WriteBadRequest(w, r, "invalid request body")
|
||||
return
|
||||
}
|
||||
|
||||
if req.Domain == "" {
|
||||
api.WriteBadRequest(w, r, "domain is required")
|
||||
return
|
||||
}
|
||||
|
||||
// Default record type is A
|
||||
recordType := strings.ToUpper(req.Type)
|
||||
if recordType == "" {
|
||||
recordType = "A"
|
||||
}
|
||||
if recordType != "A" && recordType != "CNAME" {
|
||||
api.WriteBadRequest(w, r, "type must be A or CNAME")
|
||||
return
|
||||
}
|
||||
|
||||
// Determine content
|
||||
content := req.Content
|
||||
if content == "" {
|
||||
switch recordType {
|
||||
case "A":
|
||||
if h.clusterIP == "" {
|
||||
api.WriteBadRequest(w, r, "cluster IP not configured and no content provided")
|
||||
return
|
||||
}
|
||||
content = h.clusterIP
|
||||
case "CNAME":
|
||||
// Default CNAME target is the project's primary domain
|
||||
content = projectID + "." + h.defaultDomain
|
||||
}
|
||||
}
|
||||
|
||||
// Determine the DNS name
|
||||
// If domain is a full FQDN under our zone, extract the subdomain for the API call
|
||||
dnsName := req.Domain
|
||||
if isSubdomain(req.Domain, h.defaultDomain) {
|
||||
dnsName = getSubdomain(req.Domain, h.defaultDomain)
|
||||
}
|
||||
|
||||
record, err := h.dns.CreateRecord(ctx, domain.DNSRecord{
|
||||
Type: recordType,
|
||||
Name: dnsName,
|
||||
Content: content,
|
||||
TTL: 1,
|
||||
Proxied: req.Proxied,
|
||||
})
|
||||
if err != nil {
|
||||
api.WriteInternalError(w, r, fmt.Sprintf("failed to create DNS record: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
note := "Domain alias configured"
|
||||
if !isSubdomain(req.Domain, h.defaultDomain) && recordType == "A" {
|
||||
note = fmt.Sprintf("External domain configured. Point your DNS to %s", h.clusterIP)
|
||||
}
|
||||
|
||||
api.WriteCreated(w, r, map[string]any{
|
||||
"project": projectID,
|
||||
"domain": record.Name,
|
||||
"type": record.Type,
|
||||
"content": record.Content,
|
||||
"status": "configured",
|
||||
"note": note,
|
||||
})
|
||||
}
|
||||
|
||||
// RemoveDomainAlias removes a DNS alias from a project.
|
||||
// DELETE /projects/{id}/domains/{domain}
|
||||
func (h *InfrastructureHandler) RemoveDomainAlias(w http.ResponseWriter, r *http.Request) {
|
||||
projectID := chi.URLParam(r, "id")
|
||||
aliasDomain := chi.URLParam(r, "domain")
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := validateProjectID(projectID); err != nil {
|
||||
api.WriteBadRequest(w, r, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if aliasDomain == "" {
|
||||
api.WriteBadRequest(w, r, "domain path parameter is required")
|
||||
return
|
||||
}
|
||||
|
||||
if h.dns == nil {
|
||||
api.WriteInternalError(w, r, "DNS provider not configured")
|
||||
return
|
||||
}
|
||||
|
||||
// Prevent deleting the project's primary domain through this endpoint
|
||||
primaryDomain := projectID + "." + h.defaultDomain
|
||||
if aliasDomain == primaryDomain {
|
||||
api.WriteBadRequest(w, r, "cannot remove primary project domain through alias endpoint; use DELETE /project/{name} instead")
|
||||
return
|
||||
}
|
||||
|
||||
// Check if the record exists before attempting deletion
|
||||
dnsName := aliasDomain
|
||||
if isSubdomain(aliasDomain, h.defaultDomain) {
|
||||
dnsName = getSubdomain(aliasDomain, h.defaultDomain)
|
||||
}
|
||||
|
||||
// Check both A and CNAME records
|
||||
aRecord, _ := h.dns.FindRecord(ctx, "A", dnsName)
|
||||
cnameRecord, _ := h.dns.FindRecord(ctx, "CNAME", dnsName)
|
||||
|
||||
if aRecord == nil && cnameRecord == nil {
|
||||
api.WriteNotFound(w, r, fmt.Sprintf("no DNS record found for %s", aliasDomain))
|
||||
return
|
||||
}
|
||||
|
||||
// Delete whichever record(s) exist
|
||||
if aRecord != nil {
|
||||
_ = h.dns.DeleteRecordByName(ctx, "A", dnsName)
|
||||
}
|
||||
if cnameRecord != nil {
|
||||
_ = h.dns.DeleteRecordByName(ctx, "CNAME", dnsName)
|
||||
}
|
||||
|
||||
api.WriteSuccess(w, r, map[string]string{
|
||||
"project": projectID,
|
||||
"domain": aliasDomain,
|
||||
"status": "removed",
|
||||
})
|
||||
}
|
||||
|
||||
// isProjectDomain checks if a DNS name is associated with a project.
|
||||
// It matches: {projectID}.{baseDomain} or any subdomain pattern containing the project ID.
|
||||
func isProjectDomain(name, projectID, baseDomain string) bool {
|
||||
// Exact match: landing.threesix.ai
|
||||
if name == projectID+"."+baseDomain {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
246
internal/handlers/infrastructure_domains_test.go
Normal file
246
internal/handlers/infrastructure_domains_test.go
Normal file
@ -0,0 +1,246 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/orchard9/rdev/internal/domain"
|
||||
)
|
||||
|
||||
func TestInfrastructureHandler_ListDomains(t *testing.T) {
|
||||
t.Run("returns matching A records", func(t *testing.T) {
|
||||
_, _, dns, _, router := setupInfraHandler()
|
||||
|
||||
// Add records — one matching the project, one unrelated
|
||||
dns.records["landing.threesix.ai"] = &domain.DNSRecord{
|
||||
ID: "rec-1", Type: "A", Name: "landing.threesix.ai",
|
||||
Content: "208.122.204.172", TTL: 1,
|
||||
}
|
||||
dns.records["other.threesix.ai"] = &domain.DNSRecord{
|
||||
ID: "rec-2", Type: "A", Name: "other.threesix.ai",
|
||||
Content: "208.122.204.172", TTL: 1,
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "/projects/landing/domains", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d; body: %s", rec.Code, http.StatusOK, rec.Body.String())
|
||||
}
|
||||
|
||||
var resp map[string]any
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
data := resp["data"].(map[string]any)
|
||||
total := int(data["total"].(float64))
|
||||
if total != 1 {
|
||||
t.Errorf("total = %d, want 1 (only landing.threesix.ai)", total)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns CNAME aliases", func(t *testing.T) {
|
||||
_, _, dns, _, router := setupInfraHandler()
|
||||
|
||||
dns.records["landing.threesix.ai"] = &domain.DNSRecord{
|
||||
ID: "rec-1", Type: "A", Name: "landing.threesix.ai",
|
||||
Content: "208.122.204.172", TTL: 1,
|
||||
}
|
||||
dns.records["www.threesix.ai"] = &domain.DNSRecord{
|
||||
ID: "rec-2", Type: "CNAME", Name: "www.threesix.ai",
|
||||
Content: "landing.threesix.ai", TTL: 1,
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "/projects/landing/domains", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
var resp map[string]any
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
data := resp["data"].(map[string]any)
|
||||
total := int(data["total"].(float64))
|
||||
if total != 2 {
|
||||
t.Errorf("total = %d, want 2 (A + CNAME)", total)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("DNS not configured", func(t *testing.T) {
|
||||
h := NewInfrastructureHandler(nil, nil, nil, nil, nil, InfrastructureConfig{
|
||||
DefaultGitOwner: "threesix",
|
||||
DefaultDomain: "threesix.ai",
|
||||
ClusterIP: "208.122.204.172",
|
||||
})
|
||||
r := chi.NewRouter()
|
||||
h.Mount(r)
|
||||
|
||||
req := httptest.NewRequest("GET", "/projects/myapp/domains", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusInternalServerError {
|
||||
t.Errorf("status = %d, want %d", rec.Code, http.StatusInternalServerError)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestInfrastructureHandler_AddDomainAlias(t *testing.T) {
|
||||
t.Run("add A record alias", func(t *testing.T) {
|
||||
_, _, dns, _, router := setupInfraHandler()
|
||||
|
||||
body, _ := json.Marshal(DomainAliasRequest{Domain: "www.threesix.ai"})
|
||||
req := httptest.NewRequest("POST", "/projects/landing/domains", bytes.NewReader(body))
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusCreated {
|
||||
t.Fatalf("status = %d, want %d; body: %s", rec.Code, http.StatusCreated, rec.Body.String())
|
||||
}
|
||||
if len(dns.records) != 1 {
|
||||
t.Errorf("DNS records = %d, want 1", len(dns.records))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("add CNAME alias", func(t *testing.T) {
|
||||
_, _, dns, _, router := setupInfraHandler()
|
||||
|
||||
body, _ := json.Marshal(DomainAliasRequest{
|
||||
Domain: "www.threesix.ai",
|
||||
Type: "CNAME",
|
||||
})
|
||||
req := httptest.NewRequest("POST", "/projects/landing/domains", bytes.NewReader(body))
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusCreated {
|
||||
t.Fatalf("status = %d, want %d; body: %s", rec.Code, http.StatusCreated, rec.Body.String())
|
||||
}
|
||||
// CNAME should target landing.threesix.ai
|
||||
for _, r := range dns.records {
|
||||
if r.Type != "CNAME" {
|
||||
t.Errorf("type = %s, want CNAME", r.Type)
|
||||
}
|
||||
if r.Content != "landing.threesix.ai" {
|
||||
t.Errorf("content = %s, want landing.threesix.ai", r.Content)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("invalid type", func(t *testing.T) {
|
||||
_, _, _, _, router := setupInfraHandler()
|
||||
|
||||
body, _ := json.Marshal(DomainAliasRequest{Domain: "www.threesix.ai", Type: "MX"})
|
||||
req := httptest.NewRequest("POST", "/projects/landing/domains", bytes.NewReader(body))
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Errorf("status = %d, want %d", rec.Code, http.StatusBadRequest)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("missing domain", func(t *testing.T) {
|
||||
_, _, _, _, router := setupInfraHandler()
|
||||
|
||||
body, _ := json.Marshal(DomainAliasRequest{})
|
||||
req := httptest.NewRequest("POST", "/projects/landing/domains", bytes.NewReader(body))
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Errorf("status = %d, want %d", rec.Code, http.StatusBadRequest)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("DNS not configured", func(t *testing.T) {
|
||||
h := NewInfrastructureHandler(nil, nil, nil, nil, nil, InfrastructureConfig{
|
||||
DefaultGitOwner: "threesix",
|
||||
DefaultDomain: "threesix.ai",
|
||||
ClusterIP: "208.122.204.172",
|
||||
})
|
||||
r := chi.NewRouter()
|
||||
h.Mount(r)
|
||||
|
||||
body, _ := json.Marshal(DomainAliasRequest{Domain: "www.threesix.ai"})
|
||||
req := httptest.NewRequest("POST", "/projects/landing/domains", bytes.NewReader(body))
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusInternalServerError {
|
||||
t.Errorf("status = %d, want %d", rec.Code, http.StatusInternalServerError)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestInfrastructureHandler_RemoveDomainAlias(t *testing.T) {
|
||||
t.Run("removes alias", func(t *testing.T) {
|
||||
_, _, dns, _, router := setupInfraHandler()
|
||||
dns.records["www"] = &domain.DNSRecord{
|
||||
ID: "rec-www", Type: "A", Name: "www",
|
||||
Content: "208.122.204.172",
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("DELETE", "/projects/landing/domains/www.threesix.ai", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d; body: %s", rec.Code, http.StatusOK, rec.Body.String())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("prevents removing primary domain", func(t *testing.T) {
|
||||
_, _, _, _, router := setupInfraHandler()
|
||||
|
||||
req := httptest.NewRequest("DELETE", "/projects/landing/domains/landing.threesix.ai", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Errorf("status = %d, want %d", rec.Code, http.StatusBadRequest)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("not found", func(t *testing.T) {
|
||||
_, _, dns, _, router := setupInfraHandler()
|
||||
dns.err = nil // No records stored
|
||||
|
||||
req := httptest.NewRequest("DELETE", "/projects/landing/domains/nonexistent.threesix.ai", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusNotFound {
|
||||
t.Errorf("status = %d, want %d; body: %s", rec.Code, http.StatusNotFound, rec.Body.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestIsProjectDomain(t *testing.T) {
|
||||
tests := []struct {
|
||||
name, projectID, baseDomain string
|
||||
want bool
|
||||
}{
|
||||
{"landing.threesix.ai", "landing", "threesix.ai", true},
|
||||
{"other.threesix.ai", "landing", "threesix.ai", false},
|
||||
{"landing.example.com", "landing", "threesix.ai", false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := isProjectDomain(tt.name, tt.projectID, tt.baseDomain)
|
||||
if got != tt.want {
|
||||
t.Errorf("isProjectDomain(%q, %q, %q) = %v, want %v",
|
||||
tt.name, tt.projectID, tt.baseDomain, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
120
internal/handlers/infrastructure_pipelines.go
Normal file
120
internal/handlers/infrastructure_pipelines.go
Normal file
@ -0,0 +1,120 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/orchard9/rdev/pkg/api"
|
||||
)
|
||||
|
||||
// PipelineResponse is the JSON representation of a CI pipeline.
|
||||
type PipelineResponse struct {
|
||||
ID int64 `json:"id"`
|
||||
Number int64 `json:"number"`
|
||||
Status string `json:"status"`
|
||||
Event string `json:"event"`
|
||||
Branch string `json:"branch"`
|
||||
Commit string `json:"commit"`
|
||||
Message string `json:"message"`
|
||||
Author string `json:"author"`
|
||||
Started string `json:"started,omitempty"`
|
||||
Finished string `json:"finished,omitempty"`
|
||||
}
|
||||
|
||||
// ListPipelines returns recent CI pipeline executions for a project.
|
||||
// GET /projects/{id}/pipelines
|
||||
func (h *InfrastructureHandler) ListPipelines(w http.ResponseWriter, r *http.Request) {
|
||||
projectID := chi.URLParam(r, "id")
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := validateProjectID(projectID); err != nil {
|
||||
api.WriteBadRequest(w, r, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if h.ciProvider == nil {
|
||||
api.WriteInternalError(w, r, "CI provider not configured")
|
||||
return
|
||||
}
|
||||
|
||||
pipelines, err := h.ciProvider.ListPipelines(ctx, h.defaultGitOwner, projectID)
|
||||
if err != nil {
|
||||
api.WriteNotFound(w, r, fmt.Sprintf("pipelines not found: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
resp := make([]PipelineResponse, len(pipelines))
|
||||
for i, p := range pipelines {
|
||||
resp[i] = PipelineResponse{
|
||||
ID: p.ID,
|
||||
Number: p.Number,
|
||||
Status: p.Status,
|
||||
Event: p.Event,
|
||||
Branch: p.Branch,
|
||||
Commit: p.Commit,
|
||||
Message: p.Message,
|
||||
Author: p.Author,
|
||||
Started: formatTime(p.Started),
|
||||
Finished: formatTime(p.Finished),
|
||||
}
|
||||
}
|
||||
|
||||
api.WriteSuccess(w, r, resp)
|
||||
}
|
||||
|
||||
// GetPipeline returns a specific CI pipeline execution for a project.
|
||||
// GET /projects/{id}/pipelines/{number}
|
||||
func (h *InfrastructureHandler) GetPipeline(w http.ResponseWriter, r *http.Request) {
|
||||
projectID := chi.URLParam(r, "id")
|
||||
numberStr := chi.URLParam(r, "number")
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := validateProjectID(projectID); err != nil {
|
||||
api.WriteBadRequest(w, r, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
number, err := strconv.ParseInt(numberStr, 10, 64)
|
||||
if err != nil {
|
||||
api.WriteBadRequest(w, r, "invalid pipeline number")
|
||||
return
|
||||
}
|
||||
|
||||
if h.ciProvider == nil {
|
||||
api.WriteInternalError(w, r, "CI provider not configured")
|
||||
return
|
||||
}
|
||||
|
||||
p, err := h.ciProvider.GetPipeline(ctx, h.defaultGitOwner, projectID, number)
|
||||
if err != nil {
|
||||
api.WriteNotFound(w, r, fmt.Sprintf("pipeline not found: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
api.WriteSuccess(w, r, PipelineResponse{
|
||||
ID: p.ID,
|
||||
Number: p.Number,
|
||||
Status: p.Status,
|
||||
Event: p.Event,
|
||||
Branch: p.Branch,
|
||||
Commit: p.Commit,
|
||||
Message: p.Message,
|
||||
Author: p.Author,
|
||||
Started: formatTime(p.Started),
|
||||
Finished: formatTime(p.Finished),
|
||||
})
|
||||
}
|
||||
|
||||
// formatTime formats a time.Time as RFC3339, returning empty string for zero time.
|
||||
func formatTime(t time.Time) string {
|
||||
if t.IsZero() {
|
||||
return ""
|
||||
}
|
||||
return t.Format(time.RFC3339)
|
||||
}
|
||||
250
internal/handlers/infrastructure_pipelines_test.go
Normal file
250
internal/handlers/infrastructure_pipelines_test.go
Normal file
@ -0,0 +1,250 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/orchard9/rdev/internal/domain"
|
||||
"github.com/orchard9/rdev/internal/port"
|
||||
)
|
||||
|
||||
// mockCIProvider implements port.CIProvider for testing.
|
||||
type mockCIProvider struct {
|
||||
pipelines map[string][]*domain.CIPipeline
|
||||
err error
|
||||
}
|
||||
|
||||
func newMockCIProvider() *mockCIProvider {
|
||||
return &mockCIProvider{pipelines: make(map[string][]*domain.CIPipeline)}
|
||||
}
|
||||
|
||||
func (m *mockCIProvider) ActivateRepo(context.Context, string, string, string) (*domain.CIRepo, error) {
|
||||
return nil, m.err
|
||||
}
|
||||
func (m *mockCIProvider) DeactivateRepo(context.Context, string, string) error {
|
||||
return m.err
|
||||
}
|
||||
func (m *mockCIProvider) GetRepo(context.Context, string, string) (*domain.CIRepo, error) {
|
||||
return nil, m.err
|
||||
}
|
||||
func (m *mockCIProvider) ListRepos(context.Context) ([]*domain.CIRepo, error) {
|
||||
return nil, m.err
|
||||
}
|
||||
func (m *mockCIProvider) AddSecret(context.Context, string, string, domain.CISecret) error {
|
||||
return m.err
|
||||
}
|
||||
func (m *mockCIProvider) DeleteSecret(context.Context, string, string, string) error {
|
||||
return m.err
|
||||
}
|
||||
|
||||
func (m *mockCIProvider) ListPipelines(_ context.Context, owner, repo string) ([]*domain.CIPipeline, error) {
|
||||
if m.err != nil {
|
||||
return nil, m.err
|
||||
}
|
||||
key := owner + "/" + repo
|
||||
p, ok := m.pipelines[key]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("repo not found: %s", key)
|
||||
}
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func (m *mockCIProvider) GetPipeline(_ context.Context, owner, repo string, number int64) (*domain.CIPipeline, error) {
|
||||
if m.err != nil {
|
||||
return nil, m.err
|
||||
}
|
||||
key := owner + "/" + repo
|
||||
for _, p := range m.pipelines[key] {
|
||||
if p.Number == number {
|
||||
return p, nil
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("pipeline %d not found", number)
|
||||
}
|
||||
|
||||
func setupInfraHandlerWithCI(ci port.CIProvider) chi.Router {
|
||||
h := NewInfrastructureHandler(nil, nil, nil, nil, ci, InfrastructureConfig{
|
||||
DefaultGitOwner: "threesix",
|
||||
DefaultDomain: "threesix.ai",
|
||||
})
|
||||
r := chi.NewRouter()
|
||||
h.Mount(r)
|
||||
return r
|
||||
}
|
||||
|
||||
func TestInfrastructureHandler_ListPipelines(t *testing.T) {
|
||||
t.Run("success", func(t *testing.T) {
|
||||
ci := newMockCIProvider()
|
||||
ci.pipelines["threesix/myapp"] = []*domain.CIPipeline{
|
||||
{
|
||||
ID: 100,
|
||||
Number: 1,
|
||||
Status: "success",
|
||||
Event: "push",
|
||||
Branch: "main",
|
||||
Commit: "abc123",
|
||||
Message: "initial commit",
|
||||
Author: "dev",
|
||||
Started: time.Date(2025, 1, 15, 10, 0, 0, 0, time.UTC),
|
||||
},
|
||||
{
|
||||
ID: 101,
|
||||
Number: 2,
|
||||
Status: "running",
|
||||
Event: "push",
|
||||
Branch: "feature",
|
||||
Commit: "def456",
|
||||
Author: "dev",
|
||||
},
|
||||
}
|
||||
|
||||
router := setupInfraHandlerWithCI(ci)
|
||||
req := httptest.NewRequest("GET", "/projects/myapp/pipelines", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Errorf("status = %d, want %d", rec.Code, http.StatusOK)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ci not configured", func(t *testing.T) {
|
||||
router := setupInfraHandlerWithCI(nil)
|
||||
req := httptest.NewRequest("GET", "/projects/myapp/pipelines", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusInternalServerError {
|
||||
t.Errorf("status = %d, want %d", rec.Code, http.StatusInternalServerError)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("repo not found", func(t *testing.T) {
|
||||
ci := newMockCIProvider()
|
||||
router := setupInfraHandlerWithCI(ci)
|
||||
|
||||
req := httptest.NewRequest("GET", "/projects/missing/pipelines", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusNotFound {
|
||||
t.Errorf("status = %d, want %d", rec.Code, http.StatusNotFound)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("invalid project id", func(t *testing.T) {
|
||||
ci := newMockCIProvider()
|
||||
router := setupInfraHandlerWithCI(ci)
|
||||
|
||||
req := httptest.NewRequest("GET", "/projects/INVALID!/pipelines", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Errorf("status = %d, want %d", rec.Code, http.StatusBadRequest)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestInfrastructureHandler_GetPipeline(t *testing.T) {
|
||||
t.Run("success", func(t *testing.T) {
|
||||
ci := newMockCIProvider()
|
||||
ci.pipelines["threesix/myapp"] = []*domain.CIPipeline{
|
||||
{
|
||||
ID: 100,
|
||||
Number: 5,
|
||||
Status: "success",
|
||||
Event: "push",
|
||||
Branch: "main",
|
||||
Commit: "abc123",
|
||||
Message: "fix bug",
|
||||
Author: "dev",
|
||||
Started: time.Date(2025, 1, 15, 10, 0, 0, 0, time.UTC),
|
||||
Finished: time.Date(2025, 1, 15, 10, 5, 0, 0, time.UTC),
|
||||
},
|
||||
}
|
||||
|
||||
router := setupInfraHandlerWithCI(ci)
|
||||
req := httptest.NewRequest("GET", "/projects/myapp/pipelines/5", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Errorf("status = %d, want %d", rec.Code, http.StatusOK)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ci not configured", func(t *testing.T) {
|
||||
router := setupInfraHandlerWithCI(nil)
|
||||
req := httptest.NewRequest("GET", "/projects/myapp/pipelines/1", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusInternalServerError {
|
||||
t.Errorf("status = %d, want %d", rec.Code, http.StatusInternalServerError)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("pipeline not found", func(t *testing.T) {
|
||||
ci := newMockCIProvider()
|
||||
ci.pipelines["threesix/myapp"] = []*domain.CIPipeline{}
|
||||
router := setupInfraHandlerWithCI(ci)
|
||||
|
||||
req := httptest.NewRequest("GET", "/projects/myapp/pipelines/999", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusNotFound {
|
||||
t.Errorf("status = %d, want %d", rec.Code, http.StatusNotFound)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("invalid pipeline number", func(t *testing.T) {
|
||||
ci := newMockCIProvider()
|
||||
router := setupInfraHandlerWithCI(ci)
|
||||
|
||||
req := httptest.NewRequest("GET", "/projects/myapp/pipelines/notanumber", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Errorf("status = %d, want %d", rec.Code, http.StatusBadRequest)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("invalid project id", func(t *testing.T) {
|
||||
ci := newMockCIProvider()
|
||||
router := setupInfraHandlerWithCI(ci)
|
||||
|
||||
req := httptest.NewRequest("GET", "/projects/INVALID!/pipelines/1", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Errorf("status = %d, want %d", rec.Code, http.StatusBadRequest)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestFormatTime(t *testing.T) {
|
||||
t.Run("non-zero time", func(t *testing.T) {
|
||||
ts := time.Date(2025, 1, 15, 10, 30, 0, 0, time.UTC)
|
||||
got := formatTime(ts)
|
||||
want := "2025-01-15T10:30:00Z"
|
||||
if got != want {
|
||||
t.Errorf("formatTime() = %q, want %q", got, want)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("zero time", func(t *testing.T) {
|
||||
got := formatTime(time.Time{})
|
||||
if got != "" {
|
||||
t.Errorf("formatTime(zero) = %q, want empty string", got)
|
||||
}
|
||||
})
|
||||
}
|
||||
455
internal/handlers/infrastructure_test.go
Normal file
455
internal/handlers/infrastructure_test.go
Normal file
@ -0,0 +1,455 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/orchard9/rdev/internal/domain"
|
||||
)
|
||||
|
||||
// mockGitRepository implements port.GitRepository for testing.
|
||||
type mockGitRepository struct {
|
||||
repos map[string]*domain.Repo
|
||||
err error
|
||||
}
|
||||
|
||||
func newMockGitRepository() *mockGitRepository {
|
||||
return &mockGitRepository{repos: make(map[string]*domain.Repo)}
|
||||
}
|
||||
|
||||
func (m *mockGitRepository) CreateRepo(_ context.Context, name, description string, private bool) (*domain.Repo, error) {
|
||||
if m.err != nil {
|
||||
return nil, m.err
|
||||
}
|
||||
repo := &domain.Repo{
|
||||
ID: 1,
|
||||
Owner: "threesix",
|
||||
Name: name,
|
||||
FullName: "threesix/" + name,
|
||||
Description: description,
|
||||
Private: private,
|
||||
CloneSSH: fmt.Sprintf("git@git.threesix.ai:threesix/%s.git", name),
|
||||
CloneHTTP: fmt.Sprintf("https://git.threesix.ai/threesix/%s.git", name),
|
||||
HTMLURL: fmt.Sprintf("https://git.threesix.ai/threesix/%s", name),
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
m.repos[name] = repo
|
||||
return repo, nil
|
||||
}
|
||||
|
||||
func (m *mockGitRepository) DeleteRepo(_ context.Context, _, name string) error {
|
||||
if m.err != nil {
|
||||
return m.err
|
||||
}
|
||||
delete(m.repos, name)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockGitRepository) ListRepos(_ context.Context, _ string) ([]*domain.Repo, error) {
|
||||
if m.err != nil {
|
||||
return nil, m.err
|
||||
}
|
||||
var repos []*domain.Repo
|
||||
for _, r := range m.repos {
|
||||
repos = append(repos, r)
|
||||
}
|
||||
return repos, nil
|
||||
}
|
||||
|
||||
func (m *mockGitRepository) GetRepo(_ context.Context, _, name string) (*domain.Repo, error) {
|
||||
if m.err != nil {
|
||||
return nil, m.err
|
||||
}
|
||||
r, ok := m.repos[name]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("repo not found: %s", name)
|
||||
}
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func (m *mockGitRepository) AddCollaborator(context.Context, string, string, string, string) error {
|
||||
return m.err
|
||||
}
|
||||
func (m *mockGitRepository) RemoveCollaborator(context.Context, string, string, string) error {
|
||||
return m.err
|
||||
}
|
||||
func (m *mockGitRepository) AddDeployKey(context.Context, string, string, string, string, bool) (*domain.DeployKey, error) {
|
||||
return nil, m.err
|
||||
}
|
||||
func (m *mockGitRepository) DeleteDeployKey(context.Context, string, string, int64) error {
|
||||
return m.err
|
||||
}
|
||||
func (m *mockGitRepository) CreateWebhook(context.Context, string, string, string, string, []string) (*domain.RepoWebhook, error) {
|
||||
return nil, m.err
|
||||
}
|
||||
func (m *mockGitRepository) DeleteWebhook(context.Context, string, string, int64) error {
|
||||
return m.err
|
||||
}
|
||||
|
||||
// mockDNSProvider implements port.DNSProvider for testing.
|
||||
type mockDNSProvider struct {
|
||||
records map[string]*domain.DNSRecord
|
||||
err error
|
||||
}
|
||||
|
||||
func newMockDNSProvider() *mockDNSProvider {
|
||||
return &mockDNSProvider{records: make(map[string]*domain.DNSRecord)}
|
||||
}
|
||||
|
||||
func (m *mockDNSProvider) CreateRecord(_ context.Context, record domain.DNSRecord) (*domain.DNSRecord, error) {
|
||||
if m.err != nil {
|
||||
return nil, m.err
|
||||
}
|
||||
record.ID = "rec-" + record.Name
|
||||
m.records[record.Name] = &record
|
||||
return &record, nil
|
||||
}
|
||||
|
||||
func (m *mockDNSProvider) UpdateRecord(_ context.Context, recordID string, record domain.DNSRecord) (*domain.DNSRecord, error) {
|
||||
if m.err != nil {
|
||||
return nil, m.err
|
||||
}
|
||||
record.ID = recordID
|
||||
m.records[recordID] = &record
|
||||
return &record, nil
|
||||
}
|
||||
|
||||
func (m *mockDNSProvider) DeleteRecord(_ context.Context, recordID string) error {
|
||||
if m.err != nil {
|
||||
return m.err
|
||||
}
|
||||
delete(m.records, recordID)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockDNSProvider) DeleteRecordByName(_ context.Context, _, name string) error {
|
||||
if m.err != nil {
|
||||
return m.err
|
||||
}
|
||||
delete(m.records, name)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockDNSProvider) GetRecord(_ context.Context, recordID string) (*domain.DNSRecord, error) {
|
||||
if m.err != nil {
|
||||
return nil, m.err
|
||||
}
|
||||
r, ok := m.records[recordID]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("record not found")
|
||||
}
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func (m *mockDNSProvider) ListRecords(_ context.Context, recordType string) ([]*domain.DNSRecord, error) {
|
||||
if m.err != nil {
|
||||
return nil, m.err
|
||||
}
|
||||
var result []*domain.DNSRecord
|
||||
for _, r := range m.records {
|
||||
if recordType == "" || r.Type == recordType {
|
||||
result = append(result, r)
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (m *mockDNSProvider) FindRecord(_ context.Context, _, name string) (*domain.DNSRecord, error) {
|
||||
if m.err != nil {
|
||||
return nil, m.err
|
||||
}
|
||||
r, ok := m.records[name]
|
||||
if !ok {
|
||||
return nil, nil
|
||||
}
|
||||
return r, nil
|
||||
}
|
||||
|
||||
// mockDeployer implements port.Deployer for testing.
|
||||
type mockDeployer struct {
|
||||
deployments map[string]*domain.DeployStatus
|
||||
logs string
|
||||
err error
|
||||
}
|
||||
|
||||
func newMockDeployer() *mockDeployer {
|
||||
return &mockDeployer{deployments: make(map[string]*domain.DeployStatus)}
|
||||
}
|
||||
|
||||
func (m *mockDeployer) Deploy(_ context.Context, spec domain.DeploySpec) error {
|
||||
if m.err != nil {
|
||||
return m.err
|
||||
}
|
||||
m.deployments[spec.ProjectName] = &domain.DeployStatus{
|
||||
ProjectName: spec.ProjectName,
|
||||
Image: spec.Image,
|
||||
Replicas: spec.Replicas,
|
||||
ReadyReplicas: 0,
|
||||
URL: "https://" + spec.Domain,
|
||||
Status: domain.DeploymentStatusPending,
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockDeployer) Undeploy(_ context.Context, projectName string) error {
|
||||
if m.err != nil {
|
||||
return m.err
|
||||
}
|
||||
delete(m.deployments, projectName)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockDeployer) GetStatus(_ context.Context, projectName string) (*domain.DeployStatus, error) {
|
||||
if m.err != nil {
|
||||
return nil, m.err
|
||||
}
|
||||
s, ok := m.deployments[projectName]
|
||||
if !ok {
|
||||
return nil, nil
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (m *mockDeployer) Restart(_ context.Context, _ string) error {
|
||||
return m.err
|
||||
}
|
||||
|
||||
func (m *mockDeployer) Scale(_ context.Context, projectName string, replicas int) error {
|
||||
if m.err != nil {
|
||||
return m.err
|
||||
}
|
||||
if s, ok := m.deployments[projectName]; ok {
|
||||
s.Replicas = replicas
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockDeployer) GetLogs(_ context.Context, _ string, _ int) (string, error) {
|
||||
if m.err != nil {
|
||||
return "", m.err
|
||||
}
|
||||
return m.logs, nil
|
||||
}
|
||||
|
||||
func setupInfraHandler() (*InfrastructureHandler, *mockGitRepository, *mockDNSProvider, *mockDeployer, chi.Router) {
|
||||
git := newMockGitRepository()
|
||||
dns := newMockDNSProvider()
|
||||
deployer := newMockDeployer()
|
||||
h := NewInfrastructureHandler(git, dns, deployer, nil, nil, InfrastructureConfig{
|
||||
DefaultGitOwner: "threesix",
|
||||
DefaultDomain: "threesix.ai",
|
||||
ClusterIP: "208.122.204.172",
|
||||
})
|
||||
r := chi.NewRouter()
|
||||
h.Mount(r)
|
||||
return h, git, dns, deployer, r
|
||||
}
|
||||
|
||||
func TestInfrastructureHandler_CreateRepo(t *testing.T) {
|
||||
t.Run("success", func(t *testing.T) {
|
||||
_, git, _, _, router := setupInfraHandler()
|
||||
|
||||
body, _ := json.Marshal(CreateRepoRequest{Description: "Test repo", Private: true})
|
||||
req := httptest.NewRequest("POST", "/projects/myapp/repo", bytes.NewReader(body))
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusCreated {
|
||||
t.Errorf("status = %d, want %d", rec.Code, http.StatusCreated)
|
||||
}
|
||||
if _, ok := git.repos["myapp"]; !ok {
|
||||
t.Error("repo not created")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("invalid project id", func(t *testing.T) {
|
||||
_, _, _, _, router := setupInfraHandler()
|
||||
|
||||
req := httptest.NewRequest("POST", "/projects/INVALID_NAME!/repo", bytes.NewReader([]byte("{}")))
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Errorf("status = %d, want %d", rec.Code, http.StatusBadRequest)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("empty body allowed", func(t *testing.T) {
|
||||
_, _, _, _, router := setupInfraHandler()
|
||||
|
||||
req := httptest.NewRequest("POST", "/projects/myapp/repo", bytes.NewReader([]byte("")))
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
// Should succeed with empty body (EOF is allowed)
|
||||
if rec.Code != http.StatusCreated {
|
||||
t.Errorf("status = %d, want %d", rec.Code, http.StatusCreated)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("git not configured", func(t *testing.T) {
|
||||
h := NewInfrastructureHandler(nil, nil, nil, nil, nil, InfrastructureConfig{})
|
||||
r := chi.NewRouter()
|
||||
h.Mount(r)
|
||||
|
||||
req := httptest.NewRequest("POST", "/projects/myapp/repo", bytes.NewReader([]byte("{}")))
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusInternalServerError {
|
||||
t.Errorf("status = %d, want %d", rec.Code, http.StatusInternalServerError)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestInfrastructureHandler_GetRepo(t *testing.T) {
|
||||
t.Run("found", func(t *testing.T) {
|
||||
_, git, _, _, router := setupInfraHandler()
|
||||
git.repos["myapp"] = &domain.Repo{
|
||||
ID: 1, Owner: "threesix", Name: "myapp", FullName: "threesix/myapp",
|
||||
CloneSSH: "git@git.threesix.ai:threesix/myapp.git",
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "/projects/myapp/repo", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Errorf("status = %d, want %d", rec.Code, http.StatusOK)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("not found", func(t *testing.T) {
|
||||
_, _, _, _, router := setupInfraHandler()
|
||||
|
||||
req := httptest.NewRequest("GET", "/projects/missing/repo", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusNotFound {
|
||||
t.Errorf("status = %d, want %d", rec.Code, http.StatusNotFound)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestInfrastructureHandler_DeleteRepo(t *testing.T) {
|
||||
t.Run("success", func(t *testing.T) {
|
||||
_, git, _, _, router := setupInfraHandler()
|
||||
git.repos["myapp"] = &domain.Repo{ID: 1, Name: "myapp"}
|
||||
|
||||
req := httptest.NewRequest("DELETE", "/projects/myapp/repo", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Errorf("status = %d, want %d", rec.Code, http.StatusOK)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestInfrastructureHandler_Deploy(t *testing.T) {
|
||||
t.Run("success", func(t *testing.T) {
|
||||
_, _, _, deployer, router := setupInfraHandler()
|
||||
|
||||
body, _ := json.Marshal(DeployRequest{
|
||||
Image: "registry.threesix.ai/myapp:latest",
|
||||
Port: 8080,
|
||||
Replicas: 2,
|
||||
})
|
||||
req := httptest.NewRequest("POST", "/projects/myapp/deploy", bytes.NewReader(body))
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusCreated {
|
||||
t.Errorf("status = %d, want %d", rec.Code, http.StatusCreated)
|
||||
}
|
||||
if _, ok := deployer.deployments["myapp"]; !ok {
|
||||
t.Error("deployment not created")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("missing image", func(t *testing.T) {
|
||||
_, _, _, _, router := setupInfraHandler()
|
||||
|
||||
body, _ := json.Marshal(DeployRequest{Port: 8080})
|
||||
req := httptest.NewRequest("POST", "/projects/myapp/deploy", bytes.NewReader(body))
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Errorf("status = %d, want %d", rec.Code, http.StatusBadRequest)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("deployer not configured", func(t *testing.T) {
|
||||
h := NewInfrastructureHandler(nil, nil, nil, nil, nil, InfrastructureConfig{})
|
||||
r := chi.NewRouter()
|
||||
h.Mount(r)
|
||||
|
||||
body, _ := json.Marshal(DeployRequest{Image: "myimage:latest"})
|
||||
req := httptest.NewRequest("POST", "/projects/myapp/deploy", bytes.NewReader(body))
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusInternalServerError {
|
||||
t.Errorf("status = %d, want %d", rec.Code, http.StatusInternalServerError)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestInfrastructureHandler_GetDeployStatus(t *testing.T) {
|
||||
t.Run("found", func(t *testing.T) {
|
||||
_, _, _, deployer, router := setupInfraHandler()
|
||||
deployer.deployments["myapp"] = &domain.DeployStatus{
|
||||
ProjectName: "myapp",
|
||||
Image: "myimage:latest",
|
||||
Status: domain.DeploymentStatusRunning,
|
||||
Replicas: 2,
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "/projects/myapp/deploy/status", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Errorf("status = %d, want %d", rec.Code, http.StatusOK)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("not found", func(t *testing.T) {
|
||||
_, _, _, _, router := setupInfraHandler()
|
||||
|
||||
req := httptest.NewRequest("GET", "/projects/missing/deploy/status", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusNotFound {
|
||||
t.Errorf("status = %d, want %d", rec.Code, http.StatusNotFound)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestInfrastructureHandler_Undeploy(t *testing.T) {
|
||||
t.Run("success", func(t *testing.T) {
|
||||
_, _, _, deployer, router := setupInfraHandler()
|
||||
deployer.deployments["myapp"] = &domain.DeployStatus{ProjectName: "myapp"}
|
||||
|
||||
req := httptest.NewRequest("DELETE", "/projects/myapp/deploy", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Errorf("status = %d, want %d", rec.Code, http.StatusOK)
|
||||
}
|
||||
})
|
||||
}
|
||||
@ -2,6 +2,7 @@ package handlers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net"
|
||||
"net/http"
|
||||
|
||||
@ -66,7 +67,7 @@ type CreateKeyResponse struct {
|
||||
// apiKeyToResponse converts an APIKey to a JSON response.
|
||||
func apiKeyToResponse(k *auth.APIKey) KeyResponse {
|
||||
resp := KeyResponse{
|
||||
ID: k.ID,
|
||||
ID: string(k.ID),
|
||||
Name: k.Name,
|
||||
KeyPrefix: k.KeyPrefix,
|
||||
Scopes: auth.ScopesToStrings(k.Scopes),
|
||||
@ -76,7 +77,10 @@ func apiKeyToResponse(k *auth.APIKey) KeyResponse {
|
||||
}
|
||||
|
||||
if k.ProjectIDs != nil {
|
||||
resp.ProjectIDs = k.ProjectIDs
|
||||
resp.ProjectIDs = make([]string, len(k.ProjectIDs))
|
||||
for i, pid := range k.ProjectIDs {
|
||||
resp.ProjectIDs[i] = string(pid)
|
||||
}
|
||||
}
|
||||
|
||||
if k.AllowedIPs != nil {
|
||||
@ -159,8 +163,8 @@ func (h *KeysHandler) Create(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// Get creator from authenticated key
|
||||
creator := "admin"
|
||||
if apiKey := auth.GetAPIKey(r.Context()); apiKey != nil && apiKey.ID != "admin" {
|
||||
creator = apiKey.ID
|
||||
if apiKey := auth.GetAPIKey(r.Context()); apiKey != nil && string(apiKey.ID) != "admin" {
|
||||
creator = string(apiKey.ID)
|
||||
}
|
||||
|
||||
result, err := h.authService.Create(r.Context(), auth.CreateKeyRequest{
|
||||
@ -206,7 +210,7 @@ func (h *KeysHandler) Get(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
key, err := h.authService.Get(r.Context(), id)
|
||||
if err != nil {
|
||||
if err == auth.ErrKeyNotFound {
|
||||
if errors.Is(err, auth.ErrKeyNotFound) {
|
||||
api.WriteNotFound(w, r, "Key not found")
|
||||
return
|
||||
}
|
||||
@ -223,7 +227,7 @@ func (h *KeysHandler) Revoke(w http.ResponseWriter, r *http.Request) {
|
||||
id := chi.URLParam(r, "id")
|
||||
|
||||
if err := h.authService.Revoke(r.Context(), id); err != nil {
|
||||
if err == auth.ErrKeyNotFound {
|
||||
if errors.Is(err, auth.ErrKeyNotFound) {
|
||||
api.WriteNotFound(w, r, "Key not found")
|
||||
return
|
||||
}
|
||||
|
||||
@ -4,6 +4,7 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
@ -11,7 +12,10 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/orchard9/rdev/internal/adapter/postgres"
|
||||
"github.com/orchard9/rdev/internal/auth"
|
||||
"github.com/orchard9/rdev/internal/domain"
|
||||
"github.com/orchard9/rdev/internal/service"
|
||||
"github.com/orchard9/rdev/internal/testutil"
|
||||
)
|
||||
|
||||
@ -22,7 +26,9 @@ func setupKeysHandler(t *testing.T) (*KeysHandler, chi.Router, *auth.Service) {
|
||||
db := testutil.TestDB(t)
|
||||
t.Cleanup(func() { testutil.CleanupTestKeys(t, db) })
|
||||
|
||||
authService := auth.NewService(db, "test-admin-key")
|
||||
apiKeyRepo := postgres.NewAPIKeyRepository(db)
|
||||
apiKeySvc := service.NewAPIKeyService(apiKeyRepo, "test-admin-key")
|
||||
authService := auth.NewService(apiKeySvc, "test-admin-key")
|
||||
handler := NewKeysHandler(authService)
|
||||
|
||||
router := chi.NewRouter()
|
||||
@ -204,7 +210,7 @@ func TestKeysHandler_Get(t *testing.T) {
|
||||
}
|
||||
|
||||
t.Run("existing key", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/keys/"+result.Key.ID, nil)
|
||||
req := httptest.NewRequest("GET", "/keys/"+string(result.Key.ID), nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(rec, req)
|
||||
@ -249,7 +255,7 @@ func TestKeysHandler_Revoke(t *testing.T) {
|
||||
}
|
||||
|
||||
t.Run("revoke existing key", func(t *testing.T) {
|
||||
req := httptest.NewRequest("DELETE", "/keys/"+result.Key.ID, nil)
|
||||
req := httptest.NewRequest("DELETE", "/keys/"+string(result.Key.ID), nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(rec, req)
|
||||
@ -268,7 +274,7 @@ func TestKeysHandler_Revoke(t *testing.T) {
|
||||
|
||||
// Verify the key is actually revoked
|
||||
_, err := authService.Validate(context.Background(), result.Secret)
|
||||
if err != auth.ErrKeyRevoked {
|
||||
if !errors.Is(err, auth.ErrKeyRevoked) {
|
||||
t.Errorf("Key should be revoked, got err = %v", err)
|
||||
}
|
||||
})
|
||||
@ -308,11 +314,11 @@ func TestApiKeyToResponse(t *testing.T) {
|
||||
future := now.Add(24 * time.Hour)
|
||||
|
||||
key := &auth.APIKey{
|
||||
ID: "test-id",
|
||||
ID: domain.APIKeyID("test-id"),
|
||||
Name: "test-name",
|
||||
KeyPrefix: "rdev_sk_abc",
|
||||
Scopes: []auth.Scope{auth.ScopeProjectsRead, auth.ScopeProjectsExecute},
|
||||
ProjectIDs: []string{"proj-a"},
|
||||
ProjectIDs: []domain.ProjectID{"proj-a"},
|
||||
CreatedAt: now,
|
||||
ExpiresAt: &future,
|
||||
LastUsedAt: &now,
|
||||
|
||||
@ -18,12 +18,17 @@ import (
|
||||
// ProjectManagementHandler handles project lifecycle operations.
|
||||
type ProjectManagementHandler struct {
|
||||
infraService *service.ProjectInfraService
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
// NewProjectManagementHandler creates a new project management handler.
|
||||
func NewProjectManagementHandler(infraService *service.ProjectInfraService) *ProjectManagementHandler {
|
||||
func NewProjectManagementHandler(infraService *service.ProjectInfraService, logger *slog.Logger) *ProjectManagementHandler {
|
||||
if logger == nil {
|
||||
logger = slog.Default()
|
||||
}
|
||||
return &ProjectManagementHandler{
|
||||
infraService: infraService,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
@ -84,7 +89,7 @@ func (h *ProjectManagementHandler) Create(w http.ResponseWriter, r *http.Request
|
||||
return
|
||||
}
|
||||
// Log internal errors but return generic message to client
|
||||
slog.Error("project creation failed", "error", err, "name", req.Name)
|
||||
h.logger.Error("project creation failed", "error", err, "name", req.Name)
|
||||
api.WriteInternalError(w, r, "failed to create project")
|
||||
return
|
||||
}
|
||||
@ -119,7 +124,7 @@ func (h *ProjectManagementHandler) List(w http.ResponseWriter, r *http.Request)
|
||||
|
||||
projects, err := h.infraService.ListProjects(ctx)
|
||||
if err != nil {
|
||||
slog.Error("failed to list projects", "error", err)
|
||||
h.logger.Error("failed to list projects", "error", err)
|
||||
api.WriteInternalError(w, r, "failed to list projects")
|
||||
return
|
||||
}
|
||||
@ -169,7 +174,7 @@ func (h *ProjectManagementHandler) Status(w http.ResponseWriter, r *http.Request
|
||||
api.WriteNotFound(w, r, "project not found")
|
||||
return
|
||||
}
|
||||
slog.Error("failed to get project status", "error", err, "name", name)
|
||||
h.logger.Error("failed to get project status", "error", err, "name", name)
|
||||
api.WriteInternalError(w, r, "failed to get project status")
|
||||
return
|
||||
}
|
||||
@ -216,7 +221,7 @@ func (h *ProjectManagementHandler) Delete(w http.ResponseWriter, r *http.Request
|
||||
api.WriteNotFound(w, r, "project not found")
|
||||
return
|
||||
}
|
||||
slog.Error("failed to delete project", "error", err, "name", name)
|
||||
h.logger.Error("failed to delete project", "error", err, "name", name)
|
||||
api.WriteInternalError(w, r, "failed to delete project")
|
||||
return
|
||||
}
|
||||
@ -240,7 +245,7 @@ func (h *ProjectManagementHandler) ListTemplates(w http.ResponseWriter, r *http.
|
||||
|
||||
templates, err := h.infraService.ListTemplates(ctx)
|
||||
if err != nil {
|
||||
slog.Error("failed to list templates", "error", err)
|
||||
h.logger.Error("failed to list templates", "error", err)
|
||||
api.WriteInternalError(w, r, "failed to list templates")
|
||||
return
|
||||
}
|
||||
@ -277,7 +282,7 @@ func (h *ProjectManagementHandler) GetTemplate(w http.ResponseWriter, r *http.Re
|
||||
api.WriteNotFound(w, r, "template not found")
|
||||
return
|
||||
}
|
||||
slog.Error("failed to get template", "error", err, "name", name)
|
||||
h.logger.Error("failed to get template", "error", err, "name", name)
|
||||
api.WriteInternalError(w, r, "failed to get template")
|
||||
return
|
||||
}
|
||||
|
||||
85
internal/handlers/project_management_test.go
Normal file
85
internal/handlers/project_management_test.go
Normal file
@ -0,0 +1,85 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
)
|
||||
|
||||
func TestProjectManagementHandler_NilService(t *testing.T) {
|
||||
h := NewProjectManagementHandler(nil, slog.Default())
|
||||
r := chi.NewRouter()
|
||||
h.Mount(r)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
method string
|
||||
path string
|
||||
body string
|
||||
}{
|
||||
{"create", "POST", "/project", `{"name":"test"}`},
|
||||
{"list", "GET", "/project", ""},
|
||||
{"status", "GET", "/project/test", ""},
|
||||
{"delete", "DELETE", "/project/test", ""},
|
||||
{"list templates", "GET", "/templates", ""},
|
||||
{"get template", "GET", "/templates/default", ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var req *http.Request
|
||||
if tt.body != "" {
|
||||
req = httptest.NewRequest(tt.method, tt.path, bytes.NewReader([]byte(tt.body)))
|
||||
} else {
|
||||
req = httptest.NewRequest(tt.method, tt.path, nil)
|
||||
}
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusInternalServerError {
|
||||
t.Errorf("%s: status = %d, want %d", tt.name, rec.Code, http.StatusInternalServerError)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProjectManagementHandler_CreateValidation(t *testing.T) {
|
||||
// With nil service, the handler returns 500 before reaching validation.
|
||||
// This tests that the nil check takes precedence.
|
||||
h := NewProjectManagementHandler(nil, slog.Default())
|
||||
r := chi.NewRouter()
|
||||
h.Mount(r)
|
||||
|
||||
t.Run("nil service returns 500 even with missing name", func(t *testing.T) {
|
||||
body, _ := json.Marshal(CreateRequest{Name: ""})
|
||||
req := httptest.NewRequest("POST", "/project", bytes.NewReader(body))
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusInternalServerError {
|
||||
t.Errorf("status = %d, want %d", rec.Code, http.StatusInternalServerError)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("nil service returns 500 even with invalid json", func(t *testing.T) {
|
||||
req := httptest.NewRequest("POST", "/project", bytes.NewReader([]byte("not json")))
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusInternalServerError {
|
||||
t.Errorf("status = %d, want %d", rec.Code, http.StatusInternalServerError)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestNewProjectManagementHandler_NilLogger(t *testing.T) {
|
||||
h := NewProjectManagementHandler(nil, nil)
|
||||
if h.logger == nil {
|
||||
t.Error("logger should default to slog.Default() when nil")
|
||||
}
|
||||
}
|
||||
@ -67,7 +67,7 @@ func getAuditContext(r *http.Request) *service.AuditContext {
|
||||
}
|
||||
|
||||
return &service.AuditContext{
|
||||
APIKeyID: apiKey.ID,
|
||||
APIKeyID: string(apiKey.ID),
|
||||
ClientIP: getClientIP(r),
|
||||
UserAgent: r.UserAgent(),
|
||||
}
|
||||
|
||||
98
internal/handlers/projects_commands_test.go
Normal file
98
internal/handlers/projects_commands_test.go
Normal file
@ -0,0 +1,98 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
)
|
||||
|
||||
func TestProjectsHandler_RunClaude_InvalidJSON(t *testing.T) {
|
||||
h := &ProjectsHandler{streams: newStreamManager()}
|
||||
r := chi.NewRouter()
|
||||
h.Mount(r)
|
||||
|
||||
req := httptest.NewRequest("POST", "/projects/myapp/claude", bytes.NewReader([]byte("not json")))
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Errorf("status = %d, want %d", rec.Code, http.StatusBadRequest)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProjectsHandler_RunClaude_NoServiceConfigured(t *testing.T) {
|
||||
h := &ProjectsHandler{streams: newStreamManager()}
|
||||
r := chi.NewRouter()
|
||||
h.Mount(r)
|
||||
|
||||
body, _ := json.Marshal(ClaudeRequest{Prompt: "hello"})
|
||||
req := httptest.NewRequest("POST", "/projects/myapp/claude", bytes.NewReader(body))
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusInternalServerError {
|
||||
t.Errorf("status = %d, want %d", rec.Code, http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProjectsHandler_RunShell_InvalidJSON(t *testing.T) {
|
||||
h := &ProjectsHandler{streams: newStreamManager()}
|
||||
r := chi.NewRouter()
|
||||
h.Mount(r)
|
||||
|
||||
req := httptest.NewRequest("POST", "/projects/myapp/shell", bytes.NewReader([]byte("not json")))
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Errorf("status = %d, want %d", rec.Code, http.StatusBadRequest)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProjectsHandler_RunShell_NoServiceConfigured(t *testing.T) {
|
||||
h := &ProjectsHandler{streams: newStreamManager()}
|
||||
r := chi.NewRouter()
|
||||
h.Mount(r)
|
||||
|
||||
body, _ := json.Marshal(ShellRequest{Command: "ls"})
|
||||
req := httptest.NewRequest("POST", "/projects/myapp/shell", bytes.NewReader(body))
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusInternalServerError {
|
||||
t.Errorf("status = %d, want %d", rec.Code, http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProjectsHandler_RunGit_InvalidJSON(t *testing.T) {
|
||||
h := &ProjectsHandler{streams: newStreamManager()}
|
||||
r := chi.NewRouter()
|
||||
h.Mount(r)
|
||||
|
||||
req := httptest.NewRequest("POST", "/projects/myapp/git", bytes.NewReader([]byte("not json")))
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Errorf("status = %d, want %d", rec.Code, http.StatusBadRequest)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProjectsHandler_RunGit_NoServiceConfigured(t *testing.T) {
|
||||
h := &ProjectsHandler{streams: newStreamManager()}
|
||||
r := chi.NewRouter()
|
||||
h.Mount(r)
|
||||
|
||||
body, _ := json.Marshal(GitRequest{Args: []string{"status"}})
|
||||
req := httptest.NewRequest("POST", "/projects/myapp/git", bytes.NewReader(body))
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusInternalServerError {
|
||||
t.Errorf("status = %d, want %d", rec.Code, http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
187
internal/handlers/projects_stream_test.go
Normal file
187
internal/handlers/projects_stream_test.go
Normal file
@ -0,0 +1,187 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// recorderFlusher wraps httptest.ResponseRecorder to satisfy http.Flusher.
|
||||
type recorderFlusher struct {
|
||||
*httptest.ResponseRecorder
|
||||
}
|
||||
|
||||
func (rf *recorderFlusher) Flush() {}
|
||||
|
||||
func newRecorderFlusher() *recorderFlusher {
|
||||
return &recorderFlusher{httptest.NewRecorder()}
|
||||
}
|
||||
|
||||
func TestStreamManager_SubscribeAndSend(t *testing.T) {
|
||||
sm := newStreamManager()
|
||||
|
||||
ch := sm.Subscribe("stream-1")
|
||||
defer sm.Unsubscribe("stream-1", ch)
|
||||
|
||||
// Send event
|
||||
sm.Send("stream-1", "output", map[string]any{"line": "hello"})
|
||||
|
||||
select {
|
||||
case evt := <-ch:
|
||||
if evt.Type != "output" {
|
||||
t.Errorf("event type = %q, want %q", evt.Type, "output")
|
||||
}
|
||||
if evt.Data["line"] != "hello" {
|
||||
t.Errorf("event data = %v, want line=hello", evt.Data)
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timed out waiting for event")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamManager_MultipleSubscribers(t *testing.T) {
|
||||
sm := newStreamManager()
|
||||
|
||||
ch1 := sm.Subscribe("stream-1")
|
||||
ch2 := sm.Subscribe("stream-1")
|
||||
defer sm.Unsubscribe("stream-1", ch1)
|
||||
defer sm.Unsubscribe("stream-1", ch2)
|
||||
|
||||
sm.Send("stream-1", "test", map[string]any{"value": 1})
|
||||
|
||||
// Both should receive
|
||||
for i, ch := range []chan streamEvent{ch1, ch2} {
|
||||
select {
|
||||
case evt := <-ch:
|
||||
if evt.Type != "test" {
|
||||
t.Errorf("subscriber %d: event type = %q, want %q", i, evt.Type, "test")
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Fatalf("subscriber %d: timed out", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamManager_Close(t *testing.T) {
|
||||
sm := newStreamManager()
|
||||
|
||||
ch := sm.Subscribe("stream-1")
|
||||
|
||||
sm.Close("stream-1")
|
||||
|
||||
// Channel should be closed
|
||||
_, ok := <-ch
|
||||
if ok {
|
||||
t.Error("channel should be closed after Close()")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamManager_SendToNonexistentStream(t *testing.T) {
|
||||
sm := newStreamManager()
|
||||
// Should not panic
|
||||
sm.Send("nonexistent", "test", map[string]any{})
|
||||
}
|
||||
|
||||
func TestStreamManager_Unsubscribe(t *testing.T) {
|
||||
sm := newStreamManager()
|
||||
|
||||
ch1 := sm.Subscribe("stream-1")
|
||||
ch2 := sm.Subscribe("stream-1")
|
||||
|
||||
sm.Unsubscribe("stream-1", ch1)
|
||||
|
||||
// ch1 should be closed
|
||||
_, ok := <-ch1
|
||||
if ok {
|
||||
t.Error("ch1 should be closed after Unsubscribe")
|
||||
}
|
||||
|
||||
// ch2 should still receive
|
||||
sm.Send("stream-1", "test", map[string]any{})
|
||||
select {
|
||||
case evt := <-ch2:
|
||||
if evt.Type != "test" {
|
||||
t.Errorf("event type = %q, want %q", evt.Type, "test")
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("ch2 timed out")
|
||||
}
|
||||
|
||||
sm.Unsubscribe("stream-1", ch2)
|
||||
}
|
||||
|
||||
func TestWriteSSE(t *testing.T) {
|
||||
rf := newRecorderFlusher()
|
||||
|
||||
writeSSE(rf.ResponseRecorder, rf, "output", map[string]any{"line": "hello"})
|
||||
|
||||
body := rf.Body.String()
|
||||
if !strings.Contains(body, "event: output\n") {
|
||||
t.Errorf("missing event line in SSE output: %s", body)
|
||||
}
|
||||
if !strings.Contains(body, "data: ") {
|
||||
t.Errorf("missing data line in SSE output: %s", body)
|
||||
}
|
||||
// Should not have id line
|
||||
if strings.Contains(body, "id: ") {
|
||||
t.Errorf("should not have id line without ID: %s", body)
|
||||
}
|
||||
|
||||
// Verify data is valid JSON
|
||||
lines := strings.Split(body, "\n")
|
||||
for _, line := range lines {
|
||||
if strings.HasPrefix(line, "data: ") {
|
||||
jsonStr := strings.TrimPrefix(line, "data: ")
|
||||
var parsed map[string]any
|
||||
if err := json.Unmarshal([]byte(jsonStr), &parsed); err != nil {
|
||||
t.Errorf("data is not valid JSON: %v", err)
|
||||
}
|
||||
if parsed["line"] != "hello" {
|
||||
t.Errorf("data[line] = %v, want hello", parsed["line"])
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteSSEWithID(t *testing.T) {
|
||||
rf := newRecorderFlusher()
|
||||
|
||||
writeSSEWithID(rf.ResponseRecorder, rf, "evt-123", "complete", map[string]any{"exit_code": 0})
|
||||
|
||||
body := rf.Body.String()
|
||||
if !strings.Contains(body, "id: evt-123\n") {
|
||||
t.Errorf("missing id line in SSE output: %s", body)
|
||||
}
|
||||
if !strings.Contains(body, "event: complete\n") {
|
||||
t.Errorf("missing event line in SSE output: %s", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamManager_FullChannel(t *testing.T) {
|
||||
sm := newStreamManager()
|
||||
|
||||
ch := sm.Subscribe("stream-1")
|
||||
|
||||
// Fill the channel (buffer size is 100)
|
||||
for i := 0; i < 100; i++ {
|
||||
sm.Send("stream-1", "test", map[string]any{"i": i})
|
||||
}
|
||||
|
||||
// Next send should not block (dropped)
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
sm.Send("stream-1", "test", map[string]any{"i": 100})
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
// Good - did not block
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("Send blocked on full channel")
|
||||
}
|
||||
|
||||
sm.Unsubscribe("stream-1", ch)
|
||||
}
|
||||
@ -138,7 +138,7 @@ func (h *QueueHandler) Enqueue(w http.ResponseWriter, r *http.Request) {
|
||||
// Get API key ID for audit trail
|
||||
var apiKeyID string
|
||||
if apiKey := auth.GetAPIKey(r.Context()); apiKey != nil {
|
||||
apiKeyID = apiKey.ID
|
||||
apiKeyID = string(apiKey.ID)
|
||||
}
|
||||
|
||||
// Create queued command
|
||||
|
||||
@ -8,13 +8,13 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/google/uuid"
|
||||
"github.com/orchard9/rdev/internal/domain"
|
||||
"github.com/orchard9/rdev/internal/port"
|
||||
"github.com/orchard9/rdev/internal/validate"
|
||||
"github.com/orchard9/rdev/pkg/api"
|
||||
)
|
||||
|
||||
@ -115,8 +115,7 @@ func (h *WebhookHandler) Create(w http.ResponseWriter, r *http.Request) {
|
||||
api.WriteBadRequest(w, r, "url is required")
|
||||
return
|
||||
}
|
||||
parsedURL, err := url.Parse(req.URL)
|
||||
if err != nil || (parsedURL.Scheme != "http" && parsedURL.Scheme != "https") {
|
||||
if err := validate.HTTPURL(req.URL, "url"); err != nil {
|
||||
api.WriteBadRequest(w, r, "url must be a valid HTTP or HTTPS URL")
|
||||
return
|
||||
}
|
||||
@ -288,8 +287,7 @@ func (h *WebhookHandler) Update(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// Update fields
|
||||
if req.URL != "" {
|
||||
parsedURL, err := url.Parse(req.URL)
|
||||
if err != nil || (parsedURL.Scheme != "http" && parsedURL.Scheme != "https") {
|
||||
if err := validate.HTTPURL(req.URL, "url"); err != nil {
|
||||
api.WriteBadRequest(w, r, "url must be a valid HTTP or HTTPS URL")
|
||||
return
|
||||
}
|
||||
|
||||
@ -6,13 +6,12 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/orchard9/rdev/internal/domain"
|
||||
"github.com/orchard9/rdev/internal/port"
|
||||
"github.com/orchard9/rdev/internal/service"
|
||||
"github.com/orchard9/rdev/internal/validate"
|
||||
"github.com/orchard9/rdev/pkg/api"
|
||||
)
|
||||
|
||||
@ -88,19 +87,16 @@ func (h *WorkHandler) Enqueue(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// Validate task type
|
||||
taskType := port.WorkTaskType(req.TaskType)
|
||||
taskType := domain.WorkTaskType(req.TaskType)
|
||||
if !taskType.IsValid() {
|
||||
api.WriteBadRequest(w, r, "task_type must be one of: build, test, deploy, custom")
|
||||
return
|
||||
}
|
||||
|
||||
// Validate callback URL if provided
|
||||
if req.CallbackURL != "" {
|
||||
parsedURL, err := url.Parse(req.CallbackURL)
|
||||
if err != nil || (parsedURL.Scheme != "http" && parsedURL.Scheme != "https") {
|
||||
api.WriteBadRequest(w, r, "callback_url must be a valid HTTP/HTTPS URL")
|
||||
return
|
||||
}
|
||||
if err := validate.HTTPURL(req.CallbackURL, "callback_url"); err != nil {
|
||||
api.WriteBadRequest(w, r, "callback_url must be a valid HTTP/HTTPS URL")
|
||||
return
|
||||
}
|
||||
|
||||
result, err := h.workService.EnqueueTask(r.Context(), service.EnqueueTaskRequest{
|
||||
@ -157,8 +153,8 @@ type WorkResultDTO struct {
|
||||
Artifacts map[string]string `json:"artifacts,omitempty"`
|
||||
}
|
||||
|
||||
// toWorkTaskDTO converts a port.WorkTask to a WorkTaskDTO.
|
||||
func toWorkTaskDTO(t *port.WorkTask) *WorkTaskDTO {
|
||||
// toWorkTaskDTO converts a domain.WorkTask to a WorkTaskDTO.
|
||||
func toWorkTaskDTO(t *domain.WorkTask) *WorkTaskDTO {
|
||||
if t == nil {
|
||||
return nil
|
||||
}
|
||||
@ -273,7 +269,7 @@ func (h *WorkHandler) Complete(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
result := &port.WorkResult{
|
||||
result := &domain.WorkResult{
|
||||
Output: req.Output,
|
||||
Artifacts: req.Artifacts,
|
||||
}
|
||||
@ -358,9 +354,9 @@ func (h *WorkHandler) ListByProject(w http.ResponseWriter, r *http.Request) {
|
||||
projectID := chi.URLParam(r, "projectId")
|
||||
|
||||
// Parse and validate optional status filter
|
||||
var status *port.WorkTaskStatus
|
||||
var status *domain.WorkTaskStatus
|
||||
if s := r.URL.Query().Get("status"); s != "" {
|
||||
st := port.WorkTaskStatus(s)
|
||||
st := domain.WorkTaskStatus(s)
|
||||
if !st.IsValid() {
|
||||
api.WriteBadRequest(w, r, "invalid status filter: must be pending, running, completed, failed, or cancelled")
|
||||
return
|
||||
@ -369,7 +365,7 @@ func (h *WorkHandler) ListByProject(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// Parse pagination options
|
||||
opts := port.DefaultWorkListOptions()
|
||||
opts := domain.DefaultWorkListOptions()
|
||||
if limitStr := r.URL.Query().Get("limit"); limitStr != "" {
|
||||
limit, err := strconv.Atoi(limitStr)
|
||||
if err != nil {
|
||||
|
||||
381
internal/handlers/work_lifecycle_test.go
Normal file
381
internal/handlers/work_lifecycle_test.go
Normal file
@ -0,0 +1,381 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/orchard9/rdev/internal/domain"
|
||||
"github.com/orchard9/rdev/internal/service"
|
||||
)
|
||||
|
||||
func TestWorkHandler_Fail(t *testing.T) {
|
||||
mockQueue := newMockWorkQueue()
|
||||
workService := service.NewWorkService(mockQueue, service.WorkServiceConfig{})
|
||||
handler := NewWorkHandler(workService)
|
||||
|
||||
// Pre-populate a running task
|
||||
mockQueue.tasks["task-1"] = &domain.WorkTask{
|
||||
ID: "task-1",
|
||||
ProjectID: "test-project",
|
||||
Type: domain.WorkTaskTypeBuild,
|
||||
Status: domain.WorkTaskStatusRunning,
|
||||
WorkerID: "worker-1",
|
||||
MaxRetries: 3,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
router := chi.NewRouter()
|
||||
handler.Mount(router)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
taskID string
|
||||
body FailWorkRequest
|
||||
wantStatus int
|
||||
}{
|
||||
{
|
||||
name: "valid_fail",
|
||||
taskID: "task-1",
|
||||
body: FailWorkRequest{Error: "Build failed: npm error"},
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "missing_error",
|
||||
taskID: "task-1",
|
||||
body: FailWorkRequest{},
|
||||
wantStatus: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "task_not_found",
|
||||
taskID: "nonexistent",
|
||||
body: FailWorkRequest{Error: "Failed"},
|
||||
wantStatus: http.StatusNotFound,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
body, _ := json.Marshal(tt.body)
|
||||
req := httptest.NewRequest(http.MethodPost, "/work/"+tt.taskID+"/fail", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != tt.wantStatus {
|
||||
t.Errorf("got status %d, want %d; body: %s", rec.Code, tt.wantStatus, rec.Body.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWorkHandler_Cancel(t *testing.T) {
|
||||
mockQueue := newMockWorkQueue()
|
||||
workService := service.NewWorkService(mockQueue, service.WorkServiceConfig{})
|
||||
handler := NewWorkHandler(workService)
|
||||
|
||||
// Pre-populate tasks
|
||||
mockQueue.tasks["pending-task"] = &domain.WorkTask{
|
||||
ID: "pending-task",
|
||||
ProjectID: "test-project",
|
||||
Type: domain.WorkTaskTypeBuild,
|
||||
Status: domain.WorkTaskStatusPending,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
mockQueue.tasks["running-task"] = &domain.WorkTask{
|
||||
ID: "running-task",
|
||||
ProjectID: "test-project",
|
||||
Type: domain.WorkTaskTypeBuild,
|
||||
Status: domain.WorkTaskStatusRunning,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
router := chi.NewRouter()
|
||||
handler.Mount(router)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
taskID string
|
||||
wantStatus int
|
||||
}{
|
||||
{
|
||||
name: "cancel_pending_task",
|
||||
taskID: "pending-task",
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "cancel_running_task_fails",
|
||||
taskID: "running-task",
|
||||
wantStatus: http.StatusNotFound, // Can only cancel pending tasks
|
||||
},
|
||||
{
|
||||
name: "task_not_found",
|
||||
taskID: "nonexistent",
|
||||
wantStatus: http.StatusNotFound,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodPost, "/work/"+tt.taskID+"/cancel", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != tt.wantStatus {
|
||||
t.Errorf("got status %d, want %d; body: %s", rec.Code, tt.wantStatus, rec.Body.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWorkHandler_GetTask(t *testing.T) {
|
||||
mockQueue := newMockWorkQueue()
|
||||
workService := service.NewWorkService(mockQueue, service.WorkServiceConfig{})
|
||||
handler := NewWorkHandler(workService)
|
||||
|
||||
// Pre-populate a task
|
||||
mockQueue.tasks["task-1"] = &domain.WorkTask{
|
||||
ID: "task-1",
|
||||
ProjectID: "test-project",
|
||||
Type: domain.WorkTaskTypeBuild,
|
||||
Status: domain.WorkTaskStatusRunning,
|
||||
Spec: map[string]any{
|
||||
"prompt": "Build it",
|
||||
},
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
router := chi.NewRouter()
|
||||
handler.Mount(router)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
taskID string
|
||||
wantStatus int
|
||||
}{
|
||||
{
|
||||
name: "get_existing_task",
|
||||
taskID: "task-1",
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "task_not_found",
|
||||
taskID: "nonexistent",
|
||||
wantStatus: http.StatusNotFound,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/work/"+tt.taskID, nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != tt.wantStatus {
|
||||
t.Errorf("got status %d, want %d; body: %s", rec.Code, tt.wantStatus, rec.Body.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWorkHandler_ListByProject(t *testing.T) {
|
||||
mockQueue := newMockWorkQueue()
|
||||
workService := service.NewWorkService(mockQueue, service.WorkServiceConfig{})
|
||||
handler := NewWorkHandler(workService)
|
||||
|
||||
// Pre-populate tasks
|
||||
mockQueue.tasks["task-1"] = &domain.WorkTask{
|
||||
ID: "task-1",
|
||||
ProjectID: "project-a",
|
||||
Type: domain.WorkTaskTypeBuild,
|
||||
Status: domain.WorkTaskStatusPending,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
mockQueue.tasks["task-2"] = &domain.WorkTask{
|
||||
ID: "task-2",
|
||||
ProjectID: "project-a",
|
||||
Type: domain.WorkTaskTypeTest,
|
||||
Status: domain.WorkTaskStatusCompleted,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
mockQueue.tasks["task-3"] = &domain.WorkTask{
|
||||
ID: "task-3",
|
||||
ProjectID: "project-b",
|
||||
Type: domain.WorkTaskTypeDeploy,
|
||||
Status: domain.WorkTaskStatusRunning,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
router := chi.NewRouter()
|
||||
handler.Mount(router)
|
||||
|
||||
t.Run("list_all_for_project", func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/work/projects/project-a", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Errorf("got status %d, want %d; body: %s", rec.Code, http.StatusOK, rec.Body.String())
|
||||
}
|
||||
|
||||
var resp map[string]any
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("failed to unmarshal: %v", err)
|
||||
}
|
||||
data := resp["data"].(map[string]any)
|
||||
total := int(data["total"].(float64))
|
||||
if total != 2 {
|
||||
t.Errorf("got %d tasks, want 2", total)
|
||||
}
|
||||
// Verify pagination metadata is present
|
||||
if _, ok := data["limit"]; !ok {
|
||||
t.Error("expected limit in response")
|
||||
}
|
||||
if _, ok := data["offset"]; !ok {
|
||||
t.Error("expected offset in response")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("list_with_status_filter", func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/work/projects/project-a?status=pending", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Errorf("got status %d, want %d", rec.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
var resp map[string]any
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("failed to unmarshal: %v", err)
|
||||
}
|
||||
data := resp["data"].(map[string]any)
|
||||
total := int(data["total"].(float64))
|
||||
if total != 1 {
|
||||
t.Errorf("got %d tasks, want 1", total)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("list_with_pagination", func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/work/projects/project-a?limit=1&offset=0", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Errorf("got status %d, want %d", rec.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
var resp map[string]any
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("failed to unmarshal: %v", err)
|
||||
}
|
||||
data := resp["data"].(map[string]any)
|
||||
|
||||
// Total should reflect all matching tasks
|
||||
total := int(data["total"].(float64))
|
||||
if total != 2 {
|
||||
t.Errorf("got total=%d, want 2", total)
|
||||
}
|
||||
|
||||
// But tasks returned should be limited
|
||||
tasks := data["tasks"].([]any)
|
||||
if len(tasks) != 1 {
|
||||
t.Errorf("got %d tasks returned, want 1", len(tasks))
|
||||
}
|
||||
|
||||
// Verify limit/offset are reflected
|
||||
if int(data["limit"].(float64)) != 1 {
|
||||
t.Errorf("got limit=%v, want 1", data["limit"])
|
||||
}
|
||||
if int(data["offset"].(float64)) != 0 {
|
||||
t.Errorf("got offset=%v, want 0", data["offset"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("invalid_limit", func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/work/projects/project-a?limit=invalid", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Errorf("got status %d, want %d", rec.Code, http.StatusBadRequest)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("invalid_offset", func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/work/projects/project-a?offset=invalid", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Errorf("got status %d, want %d", rec.Code, http.StatusBadRequest)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("invalid_status_filter", func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/work/projects/project-a?status=invalid", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Errorf("got status %d, want %d", rec.Code, http.StatusBadRequest)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestWorkHandler_Stats(t *testing.T) {
|
||||
mockQueue := newMockWorkQueue()
|
||||
workService := service.NewWorkService(mockQueue, service.WorkServiceConfig{})
|
||||
handler := NewWorkHandler(workService)
|
||||
|
||||
// Pre-populate tasks with various statuses
|
||||
mockQueue.tasks["task-1"] = &domain.WorkTask{ID: "task-1", Status: domain.WorkTaskStatusPending}
|
||||
mockQueue.tasks["task-2"] = &domain.WorkTask{ID: "task-2", Status: domain.WorkTaskStatusPending}
|
||||
mockQueue.tasks["task-3"] = &domain.WorkTask{ID: "task-3", Status: domain.WorkTaskStatusRunning}
|
||||
mockQueue.tasks["task-4"] = &domain.WorkTask{ID: "task-4", Status: domain.WorkTaskStatusCompleted}
|
||||
mockQueue.tasks["task-5"] = &domain.WorkTask{ID: "task-5", Status: domain.WorkTaskStatusFailed}
|
||||
|
||||
router := chi.NewRouter()
|
||||
handler.Mount(router)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/work/stats", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Errorf("got status %d, want %d; body: %s", rec.Code, http.StatusOK, rec.Body.String())
|
||||
}
|
||||
|
||||
var resp map[string]any
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("failed to unmarshal: %v", err)
|
||||
}
|
||||
data := resp["data"].(map[string]any)
|
||||
|
||||
if int(data["pending"].(float64)) != 2 {
|
||||
t.Errorf("got pending=%v, want 2", data["pending"])
|
||||
}
|
||||
if int(data["running"].(float64)) != 1 {
|
||||
t.Errorf("got running=%v, want 1", data["running"])
|
||||
}
|
||||
if int(data["completed"].(float64)) != 1 {
|
||||
t.Errorf("got completed=%v, want 1", data["completed"])
|
||||
}
|
||||
if int(data["failed"].(float64)) != 1 {
|
||||
t.Errorf("got failed=%v, want 1", data["failed"])
|
||||
}
|
||||
}
|
||||
@ -11,41 +11,40 @@ import (
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/orchard9/rdev/internal/domain"
|
||||
"github.com/orchard9/rdev/internal/port"
|
||||
"github.com/orchard9/rdev/internal/service"
|
||||
)
|
||||
|
||||
// mockWorkQueue implements port.WorkQueue for testing.
|
||||
type mockWorkQueue struct {
|
||||
tasks map[string]*port.WorkTask
|
||||
tasks map[string]*domain.WorkTask
|
||||
err error
|
||||
}
|
||||
|
||||
func newMockWorkQueue() *mockWorkQueue {
|
||||
return &mockWorkQueue{
|
||||
tasks: make(map[string]*port.WorkTask),
|
||||
tasks: make(map[string]*domain.WorkTask),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockWorkQueue) Enqueue(ctx context.Context, task *port.WorkTask) (string, error) {
|
||||
func (m *mockWorkQueue) Enqueue(ctx context.Context, task *domain.WorkTask) (string, error) {
|
||||
if m.err != nil {
|
||||
return "", m.err
|
||||
}
|
||||
id := "task-123"
|
||||
task.ID = id
|
||||
task.Status = port.WorkTaskStatusPending
|
||||
task.Status = domain.WorkTaskStatusPending
|
||||
task.CreatedAt = time.Now()
|
||||
m.tasks[id] = task
|
||||
return id, nil
|
||||
}
|
||||
|
||||
func (m *mockWorkQueue) Dequeue(ctx context.Context, workerID string) (*port.WorkTask, error) {
|
||||
func (m *mockWorkQueue) Dequeue(ctx context.Context, workerID string) (*domain.WorkTask, error) {
|
||||
if m.err != nil {
|
||||
return nil, m.err
|
||||
}
|
||||
for _, task := range m.tasks {
|
||||
if task.Status == port.WorkTaskStatusPending {
|
||||
task.Status = port.WorkTaskStatusRunning
|
||||
if task.Status == domain.WorkTaskStatusPending {
|
||||
task.Status = domain.WorkTaskStatusRunning
|
||||
task.WorkerID = workerID
|
||||
now := time.Now()
|
||||
task.StartedAt = &now
|
||||
@ -55,7 +54,7 @@ func (m *mockWorkQueue) Dequeue(ctx context.Context, workerID string) (*port.Wor
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockWorkQueue) Complete(ctx context.Context, taskID string, result *port.WorkResult) error {
|
||||
func (m *mockWorkQueue) Complete(ctx context.Context, taskID string, result *domain.WorkResult) error {
|
||||
if m.err != nil {
|
||||
return m.err
|
||||
}
|
||||
@ -63,7 +62,7 @@ func (m *mockWorkQueue) Complete(ctx context.Context, taskID string, result *por
|
||||
if !ok {
|
||||
return domain.ErrWorkTaskNotFound
|
||||
}
|
||||
task.Status = port.WorkTaskStatusCompleted
|
||||
task.Status = domain.WorkTaskStatusCompleted
|
||||
task.Result = result
|
||||
now := time.Now()
|
||||
task.CompletedAt = &now
|
||||
@ -79,11 +78,11 @@ func (m *mockWorkQueue) Fail(ctx context.Context, taskID string, errMsg string)
|
||||
return domain.ErrWorkTaskNotFound
|
||||
}
|
||||
if task.RetryCount < task.MaxRetries {
|
||||
task.Status = port.WorkTaskStatusPending
|
||||
task.Status = domain.WorkTaskStatusPending
|
||||
task.RetryCount++
|
||||
task.Error = errMsg
|
||||
} else {
|
||||
task.Status = port.WorkTaskStatusFailed
|
||||
task.Status = domain.WorkTaskStatusFailed
|
||||
task.Error = errMsg
|
||||
now := time.Now()
|
||||
task.CompletedAt = &now
|
||||
@ -99,16 +98,16 @@ func (m *mockWorkQueue) Cancel(ctx context.Context, taskID string) error {
|
||||
if !ok {
|
||||
return domain.ErrWorkTaskNotFound
|
||||
}
|
||||
if task.Status != port.WorkTaskStatusPending {
|
||||
if task.Status != domain.WorkTaskStatusPending {
|
||||
return domain.ErrWorkTaskNotFound
|
||||
}
|
||||
task.Status = port.WorkTaskStatusCancelled
|
||||
task.Status = domain.WorkTaskStatusCancelled
|
||||
now := time.Now()
|
||||
task.CompletedAt = &now
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockWorkQueue) GetTask(ctx context.Context, taskID string) (*port.WorkTask, error) {
|
||||
func (m *mockWorkQueue) GetTask(ctx context.Context, taskID string) (*domain.WorkTask, error) {
|
||||
if m.err != nil {
|
||||
return nil, m.err
|
||||
}
|
||||
@ -119,13 +118,13 @@ func (m *mockWorkQueue) GetTask(ctx context.Context, taskID string) (*port.WorkT
|
||||
return task, nil
|
||||
}
|
||||
|
||||
func (m *mockWorkQueue) ListByProject(ctx context.Context, projectID string, status *port.WorkTaskStatus, opts port.WorkListOptions) (*port.WorkListResult, error) {
|
||||
func (m *mockWorkQueue) ListByProject(ctx context.Context, projectID string, status *domain.WorkTaskStatus, opts domain.WorkListOptions) (*domain.WorkListResult, error) {
|
||||
if m.err != nil {
|
||||
return nil, m.err
|
||||
}
|
||||
opts.Normalize()
|
||||
|
||||
var tasks []*port.WorkTask
|
||||
var tasks []*domain.WorkTask
|
||||
for _, task := range m.tasks {
|
||||
if task.ProjectID == projectID {
|
||||
if status == nil || task.Status == *status {
|
||||
@ -146,7 +145,7 @@ func (m *mockWorkQueue) ListByProject(ctx context.Context, projectID string, sta
|
||||
tasks = tasks[opts.Offset:end]
|
||||
}
|
||||
|
||||
return &port.WorkListResult{
|
||||
return &domain.WorkListResult{
|
||||
Tasks: tasks,
|
||||
Total: total,
|
||||
Limit: opts.Limit,
|
||||
@ -154,22 +153,22 @@ func (m *mockWorkQueue) ListByProject(ctx context.Context, projectID string, sta
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *mockWorkQueue) GetStats(ctx context.Context) (*port.WorkQueueStats, error) {
|
||||
func (m *mockWorkQueue) GetStats(ctx context.Context) (*domain.WorkQueueStats, error) {
|
||||
if m.err != nil {
|
||||
return nil, m.err
|
||||
}
|
||||
stats := &port.WorkQueueStats{}
|
||||
stats := &domain.WorkQueueStats{}
|
||||
for _, task := range m.tasks {
|
||||
switch task.Status {
|
||||
case port.WorkTaskStatusPending:
|
||||
case domain.WorkTaskStatusPending:
|
||||
stats.Pending++
|
||||
case port.WorkTaskStatusRunning:
|
||||
case domain.WorkTaskStatusRunning:
|
||||
stats.Running++
|
||||
case port.WorkTaskStatusCompleted:
|
||||
case domain.WorkTaskStatusCompleted:
|
||||
stats.Completed++
|
||||
case port.WorkTaskStatusFailed:
|
||||
case domain.WorkTaskStatusFailed:
|
||||
stats.Failed++
|
||||
case port.WorkTaskStatusCancelled:
|
||||
case domain.WorkTaskStatusCancelled:
|
||||
stats.Cancelled++
|
||||
}
|
||||
}
|
||||
@ -292,11 +291,11 @@ func TestWorkHandler_Dequeue(t *testing.T) {
|
||||
handler := NewWorkHandler(workService)
|
||||
|
||||
// Pre-populate a pending task
|
||||
mockQueue.tasks["task-1"] = &port.WorkTask{
|
||||
mockQueue.tasks["task-1"] = &domain.WorkTask{
|
||||
ID: "task-1",
|
||||
ProjectID: "test-project",
|
||||
Type: port.WorkTaskTypeBuild,
|
||||
Status: port.WorkTaskStatusPending,
|
||||
Type: domain.WorkTaskTypeBuild,
|
||||
Status: domain.WorkTaskStatusPending,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
@ -359,11 +358,11 @@ func TestWorkHandler_Complete(t *testing.T) {
|
||||
handler := NewWorkHandler(workService)
|
||||
|
||||
// Pre-populate a running task
|
||||
mockQueue.tasks["task-1"] = &port.WorkTask{
|
||||
mockQueue.tasks["task-1"] = &domain.WorkTask{
|
||||
ID: "task-1",
|
||||
ProjectID: "test-project",
|
||||
Type: port.WorkTaskTypeBuild,
|
||||
Status: port.WorkTaskStatusRunning,
|
||||
Type: domain.WorkTaskTypeBuild,
|
||||
Status: domain.WorkTaskStatusRunning,
|
||||
WorkerID: "worker-1",
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
@ -411,370 +410,3 @@ func TestWorkHandler_Complete(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWorkHandler_Fail(t *testing.T) {
|
||||
mockQueue := newMockWorkQueue()
|
||||
workService := service.NewWorkService(mockQueue, service.WorkServiceConfig{})
|
||||
handler := NewWorkHandler(workService)
|
||||
|
||||
// Pre-populate a running task
|
||||
mockQueue.tasks["task-1"] = &port.WorkTask{
|
||||
ID: "task-1",
|
||||
ProjectID: "test-project",
|
||||
Type: port.WorkTaskTypeBuild,
|
||||
Status: port.WorkTaskStatusRunning,
|
||||
WorkerID: "worker-1",
|
||||
MaxRetries: 3,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
router := chi.NewRouter()
|
||||
handler.Mount(router)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
taskID string
|
||||
body FailWorkRequest
|
||||
wantStatus int
|
||||
}{
|
||||
{
|
||||
name: "valid_fail",
|
||||
taskID: "task-1",
|
||||
body: FailWorkRequest{Error: "Build failed: npm error"},
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "missing_error",
|
||||
taskID: "task-1",
|
||||
body: FailWorkRequest{},
|
||||
wantStatus: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "task_not_found",
|
||||
taskID: "nonexistent",
|
||||
body: FailWorkRequest{Error: "Failed"},
|
||||
wantStatus: http.StatusNotFound,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
body, _ := json.Marshal(tt.body)
|
||||
req := httptest.NewRequest(http.MethodPost, "/work/"+tt.taskID+"/fail", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != tt.wantStatus {
|
||||
t.Errorf("got status %d, want %d; body: %s", rec.Code, tt.wantStatus, rec.Body.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWorkHandler_Cancel(t *testing.T) {
|
||||
mockQueue := newMockWorkQueue()
|
||||
workService := service.NewWorkService(mockQueue, service.WorkServiceConfig{})
|
||||
handler := NewWorkHandler(workService)
|
||||
|
||||
// Pre-populate tasks
|
||||
mockQueue.tasks["pending-task"] = &port.WorkTask{
|
||||
ID: "pending-task",
|
||||
ProjectID: "test-project",
|
||||
Type: port.WorkTaskTypeBuild,
|
||||
Status: port.WorkTaskStatusPending,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
mockQueue.tasks["running-task"] = &port.WorkTask{
|
||||
ID: "running-task",
|
||||
ProjectID: "test-project",
|
||||
Type: port.WorkTaskTypeBuild,
|
||||
Status: port.WorkTaskStatusRunning,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
router := chi.NewRouter()
|
||||
handler.Mount(router)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
taskID string
|
||||
wantStatus int
|
||||
}{
|
||||
{
|
||||
name: "cancel_pending_task",
|
||||
taskID: "pending-task",
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "cancel_running_task_fails",
|
||||
taskID: "running-task",
|
||||
wantStatus: http.StatusNotFound, // Can only cancel pending tasks
|
||||
},
|
||||
{
|
||||
name: "task_not_found",
|
||||
taskID: "nonexistent",
|
||||
wantStatus: http.StatusNotFound,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodPost, "/work/"+tt.taskID+"/cancel", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != tt.wantStatus {
|
||||
t.Errorf("got status %d, want %d; body: %s", rec.Code, tt.wantStatus, rec.Body.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWorkHandler_GetTask(t *testing.T) {
|
||||
mockQueue := newMockWorkQueue()
|
||||
workService := service.NewWorkService(mockQueue, service.WorkServiceConfig{})
|
||||
handler := NewWorkHandler(workService)
|
||||
|
||||
// Pre-populate a task
|
||||
mockQueue.tasks["task-1"] = &port.WorkTask{
|
||||
ID: "task-1",
|
||||
ProjectID: "test-project",
|
||||
Type: port.WorkTaskTypeBuild,
|
||||
Status: port.WorkTaskStatusRunning,
|
||||
Spec: map[string]any{
|
||||
"prompt": "Build it",
|
||||
},
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
router := chi.NewRouter()
|
||||
handler.Mount(router)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
taskID string
|
||||
wantStatus int
|
||||
}{
|
||||
{
|
||||
name: "get_existing_task",
|
||||
taskID: "task-1",
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "task_not_found",
|
||||
taskID: "nonexistent",
|
||||
wantStatus: http.StatusNotFound,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/work/"+tt.taskID, nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != tt.wantStatus {
|
||||
t.Errorf("got status %d, want %d; body: %s", rec.Code, tt.wantStatus, rec.Body.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWorkHandler_ListByProject(t *testing.T) {
|
||||
mockQueue := newMockWorkQueue()
|
||||
workService := service.NewWorkService(mockQueue, service.WorkServiceConfig{})
|
||||
handler := NewWorkHandler(workService)
|
||||
|
||||
// Pre-populate tasks
|
||||
mockQueue.tasks["task-1"] = &port.WorkTask{
|
||||
ID: "task-1",
|
||||
ProjectID: "project-a",
|
||||
Type: port.WorkTaskTypeBuild,
|
||||
Status: port.WorkTaskStatusPending,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
mockQueue.tasks["task-2"] = &port.WorkTask{
|
||||
ID: "task-2",
|
||||
ProjectID: "project-a",
|
||||
Type: port.WorkTaskTypeTest,
|
||||
Status: port.WorkTaskStatusCompleted,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
mockQueue.tasks["task-3"] = &port.WorkTask{
|
||||
ID: "task-3",
|
||||
ProjectID: "project-b",
|
||||
Type: port.WorkTaskTypeDeploy,
|
||||
Status: port.WorkTaskStatusRunning,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
router := chi.NewRouter()
|
||||
handler.Mount(router)
|
||||
|
||||
t.Run("list_all_for_project", func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/work/projects/project-a", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Errorf("got status %d, want %d; body: %s", rec.Code, http.StatusOK, rec.Body.String())
|
||||
}
|
||||
|
||||
var resp map[string]any
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("failed to unmarshal: %v", err)
|
||||
}
|
||||
data := resp["data"].(map[string]any)
|
||||
total := int(data["total"].(float64))
|
||||
if total != 2 {
|
||||
t.Errorf("got %d tasks, want 2", total)
|
||||
}
|
||||
// Verify pagination metadata is present
|
||||
if _, ok := data["limit"]; !ok {
|
||||
t.Error("expected limit in response")
|
||||
}
|
||||
if _, ok := data["offset"]; !ok {
|
||||
t.Error("expected offset in response")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("list_with_status_filter", func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/work/projects/project-a?status=pending", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Errorf("got status %d, want %d", rec.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
var resp map[string]any
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("failed to unmarshal: %v", err)
|
||||
}
|
||||
data := resp["data"].(map[string]any)
|
||||
total := int(data["total"].(float64))
|
||||
if total != 1 {
|
||||
t.Errorf("got %d tasks, want 1", total)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("list_with_pagination", func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/work/projects/project-a?limit=1&offset=0", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Errorf("got status %d, want %d", rec.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
var resp map[string]any
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("failed to unmarshal: %v", err)
|
||||
}
|
||||
data := resp["data"].(map[string]any)
|
||||
|
||||
// Total should reflect all matching tasks
|
||||
total := int(data["total"].(float64))
|
||||
if total != 2 {
|
||||
t.Errorf("got total=%d, want 2", total)
|
||||
}
|
||||
|
||||
// But tasks returned should be limited
|
||||
tasks := data["tasks"].([]any)
|
||||
if len(tasks) != 1 {
|
||||
t.Errorf("got %d tasks returned, want 1", len(tasks))
|
||||
}
|
||||
|
||||
// Verify limit/offset are reflected
|
||||
if int(data["limit"].(float64)) != 1 {
|
||||
t.Errorf("got limit=%v, want 1", data["limit"])
|
||||
}
|
||||
if int(data["offset"].(float64)) != 0 {
|
||||
t.Errorf("got offset=%v, want 0", data["offset"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("invalid_limit", func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/work/projects/project-a?limit=invalid", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Errorf("got status %d, want %d", rec.Code, http.StatusBadRequest)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("invalid_offset", func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/work/projects/project-a?offset=invalid", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Errorf("got status %d, want %d", rec.Code, http.StatusBadRequest)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("invalid_status_filter", func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/work/projects/project-a?status=invalid", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Errorf("got status %d, want %d", rec.Code, http.StatusBadRequest)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestWorkHandler_Stats(t *testing.T) {
|
||||
mockQueue := newMockWorkQueue()
|
||||
workService := service.NewWorkService(mockQueue, service.WorkServiceConfig{})
|
||||
handler := NewWorkHandler(workService)
|
||||
|
||||
// Pre-populate tasks with various statuses
|
||||
mockQueue.tasks["task-1"] = &port.WorkTask{ID: "task-1", Status: port.WorkTaskStatusPending}
|
||||
mockQueue.tasks["task-2"] = &port.WorkTask{ID: "task-2", Status: port.WorkTaskStatusPending}
|
||||
mockQueue.tasks["task-3"] = &port.WorkTask{ID: "task-3", Status: port.WorkTaskStatusRunning}
|
||||
mockQueue.tasks["task-4"] = &port.WorkTask{ID: "task-4", Status: port.WorkTaskStatusCompleted}
|
||||
mockQueue.tasks["task-5"] = &port.WorkTask{ID: "task-5", Status: port.WorkTaskStatusFailed}
|
||||
|
||||
router := chi.NewRouter()
|
||||
handler.Mount(router)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/work/stats", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Errorf("got status %d, want %d; body: %s", rec.Code, http.StatusOK, rec.Body.String())
|
||||
}
|
||||
|
||||
var resp map[string]any
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("failed to unmarshal: %v", err)
|
||||
}
|
||||
data := resp["data"].(map[string]any)
|
||||
|
||||
if int(data["pending"].(float64)) != 2 {
|
||||
t.Errorf("got pending=%v, want 2", data["pending"])
|
||||
}
|
||||
if int(data["running"].(float64)) != 1 {
|
||||
t.Errorf("got running=%v, want 1", data["running"])
|
||||
}
|
||||
if int(data["completed"].(float64)) != 1 {
|
||||
t.Errorf("got completed=%v, want 1", data["completed"])
|
||||
}
|
||||
if int(data["failed"].(float64)) != 1 {
|
||||
t.Errorf("got failed=%v, want 1", data["failed"])
|
||||
}
|
||||
}
|
||||
|
||||
162
internal/handlers/workers.go
Normal file
162
internal/handlers/workers.go
Normal file
@ -0,0 +1,162 @@
|
||||
// Package handlers provides HTTP handlers for the rdev API.
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/orchard9/rdev/internal/auth"
|
||||
"github.com/orchard9/rdev/internal/domain"
|
||||
"github.com/orchard9/rdev/internal/port"
|
||||
"github.com/orchard9/rdev/internal/service"
|
||||
"github.com/orchard9/rdev/pkg/api"
|
||||
)
|
||||
|
||||
// WorkersHandler handles worker pool management endpoints.
|
||||
type WorkersHandler struct {
|
||||
workerService *service.WorkerService
|
||||
}
|
||||
|
||||
// NewWorkersHandler creates a new workers handler.
|
||||
func NewWorkersHandler(workerService *service.WorkerService) *WorkersHandler {
|
||||
return &WorkersHandler{
|
||||
workerService: workerService,
|
||||
}
|
||||
}
|
||||
|
||||
// Mount registers the worker pool routes.
|
||||
func (h *WorkersHandler) Mount(r api.Router) {
|
||||
r.Route("/workers", func(r chi.Router) {
|
||||
r.With(auth.RequireScope(auth.ScopeWorkersRead, auth.ScopeAdmin)).Get("/", h.List)
|
||||
r.With(auth.RequireScope(auth.ScopeWorkersRead, auth.ScopeAdmin)).Get("/{workerId}", h.Get)
|
||||
r.With(auth.RequireScope(auth.ScopeWorkersWrite, auth.ScopeAdmin)).Post("/{workerId}/drain", h.Drain)
|
||||
})
|
||||
}
|
||||
|
||||
// WorkerDTO is the data transfer object for workers.
|
||||
type WorkerDTO struct {
|
||||
ID string `json:"id"`
|
||||
Hostname string `json:"hostname"`
|
||||
Status string `json:"status"`
|
||||
CurrentTask string `json:"current_task,omitempty"`
|
||||
Capabilities []string `json:"capabilities,omitempty"`
|
||||
RegisteredAt string `json:"registered_at"`
|
||||
LastHeartbeat string `json:"last_heartbeat"`
|
||||
Version string `json:"version,omitempty"`
|
||||
}
|
||||
|
||||
func toWorkerDTO(w *domain.Worker) *WorkerDTO {
|
||||
if w == nil {
|
||||
return nil
|
||||
}
|
||||
return &WorkerDTO{
|
||||
ID: w.ID,
|
||||
Hostname: w.Hostname,
|
||||
Status: string(w.Status),
|
||||
CurrentTask: w.CurrentTask,
|
||||
Capabilities: w.Capabilities,
|
||||
RegisteredAt: w.RegisteredAt.Format("2006-01-02T15:04:05Z07:00"),
|
||||
LastHeartbeat: w.LastHeartbeat.Format("2006-01-02T15:04:05Z07:00"),
|
||||
Version: w.Version,
|
||||
}
|
||||
}
|
||||
|
||||
// List returns all workers with optional status filter.
|
||||
// GET /workers?status=idle
|
||||
func (h *WorkersHandler) List(w http.ResponseWriter, r *http.Request) {
|
||||
filter := port.WorkerFilter{}
|
||||
|
||||
if s := r.URL.Query().Get("status"); s != "" {
|
||||
st := domain.WorkerStatus(s)
|
||||
if !st.IsValid() {
|
||||
api.WriteBadRequest(w, r, "invalid status: must be idle, busy, draining, or offline")
|
||||
return
|
||||
}
|
||||
filter.Status = &st
|
||||
}
|
||||
|
||||
workers, err := h.workerService.ListWorkers(r.Context(), filter)
|
||||
if err != nil {
|
||||
api.WriteInternalError(w, r, "failed to list workers")
|
||||
return
|
||||
}
|
||||
|
||||
dtos := make([]*WorkerDTO, len(workers))
|
||||
for i, wkr := range workers {
|
||||
dtos[i] = toWorkerDTO(wkr)
|
||||
}
|
||||
|
||||
// Compute summary counts
|
||||
idle, busy, draining, offline := 0, 0, 0, 0
|
||||
for _, wkr := range workers {
|
||||
switch wkr.Status {
|
||||
case domain.WorkerStatusIdle:
|
||||
idle++
|
||||
case domain.WorkerStatusBusy:
|
||||
busy++
|
||||
case domain.WorkerStatusDraining:
|
||||
draining++
|
||||
case domain.WorkerStatusOffline:
|
||||
offline++
|
||||
}
|
||||
}
|
||||
|
||||
api.WriteSuccess(w, r, map[string]any{
|
||||
"workers": dtos,
|
||||
"total": len(dtos),
|
||||
"summary": map[string]int{
|
||||
"idle": idle,
|
||||
"busy": busy,
|
||||
"draining": draining,
|
||||
"offline": offline,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// Get returns a specific worker by ID.
|
||||
// GET /workers/{workerId}
|
||||
func (h *WorkersHandler) Get(w http.ResponseWriter, r *http.Request) {
|
||||
workerID := chi.URLParam(r, "workerId")
|
||||
if workerID == "" {
|
||||
api.WriteBadRequest(w, r, "worker ID is required")
|
||||
return
|
||||
}
|
||||
|
||||
worker, err := h.workerService.GetWorker(r.Context(), workerID)
|
||||
if err != nil {
|
||||
if errors.Is(err, domain.ErrWorkerNotFound) {
|
||||
api.WriteNotFound(w, r, "worker not found: "+workerID)
|
||||
return
|
||||
}
|
||||
api.WriteInternalError(w, r, "failed to get worker")
|
||||
return
|
||||
}
|
||||
|
||||
api.WriteSuccess(w, r, toWorkerDTO(worker))
|
||||
}
|
||||
|
||||
// Drain sets a worker to draining status.
|
||||
// POST /workers/{workerId}/drain
|
||||
func (h *WorkersHandler) Drain(w http.ResponseWriter, r *http.Request) {
|
||||
workerID := chi.URLParam(r, "workerId")
|
||||
if workerID == "" {
|
||||
api.WriteBadRequest(w, r, "worker ID is required")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.workerService.DrainWorker(r.Context(), workerID); err != nil {
|
||||
if errors.Is(err, domain.ErrWorkerNotFound) {
|
||||
api.WriteNotFound(w, r, "worker not found: "+workerID)
|
||||
return
|
||||
}
|
||||
api.WriteInternalError(w, r, "failed to drain worker")
|
||||
return
|
||||
}
|
||||
|
||||
api.WriteSuccess(w, r, map[string]any{
|
||||
"worker_id": workerID,
|
||||
"status": "draining",
|
||||
"message": "worker will finish current task then stop accepting new work",
|
||||
})
|
||||
}
|
||||
304
internal/handlers/workers_test.go
Normal file
304
internal/handlers/workers_test.go
Normal file
@ -0,0 +1,304 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/orchard9/rdev/internal/domain"
|
||||
"github.com/orchard9/rdev/internal/port"
|
||||
"github.com/orchard9/rdev/internal/service"
|
||||
)
|
||||
|
||||
// mockWorkerRegistry implements port.WorkerRegistry for testing.
|
||||
type mockWorkerRegistry struct {
|
||||
workers map[string]*domain.Worker
|
||||
err error
|
||||
}
|
||||
|
||||
func newMockWorkerRegistry() *mockWorkerRegistry {
|
||||
return &mockWorkerRegistry{
|
||||
workers: make(map[string]*domain.Worker),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockWorkerRegistry) Register(_ context.Context, w *domain.Worker) error {
|
||||
if m.err != nil {
|
||||
return m.err
|
||||
}
|
||||
m.workers[w.ID] = w
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockWorkerRegistry) Heartbeat(_ context.Context, workerID string) error {
|
||||
if m.err != nil {
|
||||
return m.err
|
||||
}
|
||||
w, ok := m.workers[workerID]
|
||||
if !ok {
|
||||
return domain.ErrWorkerNotFound
|
||||
}
|
||||
w.LastHeartbeat = time.Now()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockWorkerRegistry) UpdateStatus(_ context.Context, workerID string, status domain.WorkerStatus, taskID string) error {
|
||||
if m.err != nil {
|
||||
return m.err
|
||||
}
|
||||
w, ok := m.workers[workerID]
|
||||
if !ok {
|
||||
return domain.ErrWorkerNotFound
|
||||
}
|
||||
w.Status = status
|
||||
w.CurrentTask = taskID
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockWorkerRegistry) Deregister(_ context.Context, workerID string) error {
|
||||
if m.err != nil {
|
||||
return m.err
|
||||
}
|
||||
delete(m.workers, workerID)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockWorkerRegistry) Get(_ context.Context, workerID string) (*domain.Worker, error) {
|
||||
if m.err != nil {
|
||||
return nil, m.err
|
||||
}
|
||||
w, ok := m.workers[workerID]
|
||||
if !ok {
|
||||
return nil, domain.ErrWorkerNotFound
|
||||
}
|
||||
return w, nil
|
||||
}
|
||||
|
||||
func (m *mockWorkerRegistry) List(_ context.Context, filter port.WorkerFilter) ([]*domain.Worker, error) {
|
||||
if m.err != nil {
|
||||
return nil, m.err
|
||||
}
|
||||
var result []*domain.Worker
|
||||
for _, w := range m.workers {
|
||||
if filter.Status != nil && w.Status != *filter.Status {
|
||||
continue
|
||||
}
|
||||
result = append(result, w)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (m *mockWorkerRegistry) MarkStaleOffline(_ context.Context, _ time.Duration) (int, error) {
|
||||
return 0, m.err
|
||||
}
|
||||
|
||||
func TestWorkersHandler_List(t *testing.T) {
|
||||
registry := newMockWorkerRegistry()
|
||||
queue := newMockWorkQueue()
|
||||
workerService := service.NewWorkerService(registry, queue, nil)
|
||||
handler := NewWorkersHandler(workerService)
|
||||
|
||||
// Populate workers
|
||||
registry.workers["worker-1"] = &domain.Worker{
|
||||
ID: "worker-1",
|
||||
Hostname: "host-1",
|
||||
Status: domain.WorkerStatusIdle,
|
||||
Capabilities: []string{"build"},
|
||||
RegisteredAt: time.Now(),
|
||||
LastHeartbeat: time.Now(),
|
||||
Version: "1.0.0",
|
||||
}
|
||||
registry.workers["worker-2"] = &domain.Worker{
|
||||
ID: "worker-2",
|
||||
Hostname: "host-2",
|
||||
Status: domain.WorkerStatusBusy,
|
||||
CurrentTask: "task-abc",
|
||||
Capabilities: []string{"build", "deploy"},
|
||||
RegisteredAt: time.Now(),
|
||||
LastHeartbeat: time.Now(),
|
||||
Version: "1.0.0",
|
||||
}
|
||||
|
||||
router := chi.NewRouter()
|
||||
router.Use(testAdminAuth)
|
||||
handler.Mount(router)
|
||||
|
||||
t.Run("list_all_workers", func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/workers", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Errorf("got status %d, want %d; body: %s", rec.Code, http.StatusOK, rec.Body.String())
|
||||
}
|
||||
|
||||
var resp map[string]any
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("failed to unmarshal: %v", err)
|
||||
}
|
||||
data, ok := resp["data"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("expected data to be map, got %T", resp["data"])
|
||||
}
|
||||
totalF, ok := data["total"].(float64)
|
||||
if !ok {
|
||||
t.Fatalf("expected total to be float64, got %T", data["total"])
|
||||
}
|
||||
if int(totalF) != 2 {
|
||||
t.Errorf("got total=%d, want 2", int(totalF))
|
||||
}
|
||||
|
||||
summary, ok := data["summary"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("expected summary to be map, got %T", data["summary"])
|
||||
}
|
||||
if idleF, ok := summary["idle"].(float64); !ok || int(idleF) != 1 {
|
||||
t.Errorf("got idle=%v, want 1", summary["idle"])
|
||||
}
|
||||
if busyF, ok := summary["busy"].(float64); !ok || int(busyF) != 1 {
|
||||
t.Errorf("got busy=%v, want 1", summary["busy"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("filter_by_status", func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/workers?status=idle", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Errorf("got status %d, want %d", rec.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
var resp map[string]any
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("failed to unmarshal: %v", err)
|
||||
}
|
||||
data, ok := resp["data"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("expected data to be map, got %T", resp["data"])
|
||||
}
|
||||
totalF, ok := data["total"].(float64)
|
||||
if !ok {
|
||||
t.Fatalf("expected total to be float64, got %T", data["total"])
|
||||
}
|
||||
if int(totalF) != 1 {
|
||||
t.Errorf("got total=%d, want 1", int(totalF))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("invalid_status_filter", func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/workers?status=invalid", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Errorf("got status %d, want %d", rec.Code, http.StatusBadRequest)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestWorkersHandler_Get(t *testing.T) {
|
||||
registry := newMockWorkerRegistry()
|
||||
queue := newMockWorkQueue()
|
||||
workerService := service.NewWorkerService(registry, queue, nil)
|
||||
handler := NewWorkersHandler(workerService)
|
||||
|
||||
registry.workers["worker-1"] = &domain.Worker{
|
||||
ID: "worker-1",
|
||||
Hostname: "host-1",
|
||||
Status: domain.WorkerStatusIdle,
|
||||
RegisteredAt: time.Now(),
|
||||
LastHeartbeat: time.Now(),
|
||||
}
|
||||
|
||||
router := chi.NewRouter()
|
||||
router.Use(testAdminAuth)
|
||||
handler.Mount(router)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
workerID string
|
||||
wantStatus int
|
||||
}{
|
||||
{
|
||||
name: "existing_worker",
|
||||
workerID: "worker-1",
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "not_found",
|
||||
workerID: "nonexistent",
|
||||
wantStatus: http.StatusNotFound,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/workers/"+tt.workerID, nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != tt.wantStatus {
|
||||
t.Errorf("got status %d, want %d; body: %s", rec.Code, tt.wantStatus, rec.Body.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWorkersHandler_Drain(t *testing.T) {
|
||||
registry := newMockWorkerRegistry()
|
||||
queue := newMockWorkQueue()
|
||||
workerService := service.NewWorkerService(registry, queue, nil)
|
||||
handler := NewWorkersHandler(workerService)
|
||||
|
||||
registry.workers["worker-1"] = &domain.Worker{
|
||||
ID: "worker-1",
|
||||
Hostname: "host-1",
|
||||
Status: domain.WorkerStatusBusy,
|
||||
CurrentTask: "task-abc",
|
||||
RegisteredAt: time.Now(),
|
||||
LastHeartbeat: time.Now(),
|
||||
}
|
||||
|
||||
router := chi.NewRouter()
|
||||
router.Use(testAdminAuth)
|
||||
handler.Mount(router)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
workerID string
|
||||
wantStatus int
|
||||
}{
|
||||
{
|
||||
name: "drain_existing_worker",
|
||||
workerID: "worker-1",
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "drain_nonexistent_worker",
|
||||
workerID: "nonexistent",
|
||||
wantStatus: http.StatusNotFound,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodPost, "/workers/"+tt.workerID+"/drain", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != tt.wantStatus {
|
||||
t.Errorf("got status %d, want %d; body: %s", rec.Code, tt.wantStatus, rec.Body.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Verify the worker was actually set to draining
|
||||
if registry.workers["worker-1"].Status != domain.WorkerStatusDraining {
|
||||
t.Errorf("expected worker status to be draining, got %s", registry.workers["worker-1"].Status)
|
||||
}
|
||||
}
|
||||
@ -25,6 +25,57 @@ var (
|
||||
Buckets: prometheus.ExponentialBuckets(0.1, 2, 15), // 0.1s to ~27min
|
||||
}, []string{"project", "type"})
|
||||
|
||||
// Code Agents
|
||||
agentRequestsTotal = promauto.NewCounterVec(prometheus.CounterOpts{
|
||||
Name: "rdev_agent_requests_total",
|
||||
Help: "Total number of code agent requests",
|
||||
}, []string{"provider", "status"})
|
||||
|
||||
agentRequestDuration = promauto.NewHistogramVec(prometheus.HistogramOpts{
|
||||
Name: "rdev_agent_request_duration_seconds",
|
||||
Help: "Duration of code agent requests in seconds",
|
||||
Buckets: prometheus.ExponentialBuckets(0.1, 2, 15), // 0.1s to ~27min
|
||||
}, []string{"provider"})
|
||||
|
||||
agentToolUse = promauto.NewCounterVec(prometheus.CounterOpts{
|
||||
Name: "rdev_agent_tool_use_total",
|
||||
Help: "Total number of tool invocations by code agents",
|
||||
}, []string{"provider", "tool"})
|
||||
|
||||
agentAvailability = promauto.NewGaugeVec(prometheus.GaugeOpts{
|
||||
Name: "rdev_agent_available",
|
||||
Help: "Whether the code agent is available (1) or not (0)",
|
||||
}, []string{"provider"})
|
||||
|
||||
// Worker Pool
|
||||
workersTotal = promauto.NewGaugeVec(prometheus.GaugeOpts{
|
||||
Name: "rdev_workers_total",
|
||||
Help: "Number of registered workers by status",
|
||||
}, []string{"status"})
|
||||
|
||||
workerHeartbeatAge = promauto.NewGaugeVec(prometheus.GaugeOpts{
|
||||
Name: "rdev_worker_heartbeat_age_seconds",
|
||||
Help: "Age of the most recent worker heartbeat in seconds",
|
||||
}, []string{"worker_id"})
|
||||
|
||||
// Builds
|
||||
buildsTotal = promauto.NewCounterVec(prometheus.CounterOpts{
|
||||
Name: "rdev_builds_total",
|
||||
Help: "Total number of build tasks by status",
|
||||
}, []string{"project", "status"})
|
||||
|
||||
buildDuration = promauto.NewHistogramVec(prometheus.HistogramOpts{
|
||||
Name: "rdev_build_duration_seconds",
|
||||
Help: "Duration of build executions in seconds",
|
||||
Buckets: prometheus.ExponentialBuckets(1, 2, 12), // 1s to ~34min
|
||||
}, []string{"project"})
|
||||
|
||||
// Work Queue
|
||||
workQueueDepth = promauto.NewGaugeVec(prometheus.GaugeOpts{
|
||||
Name: "rdev_work_queue_depth",
|
||||
Help: "Number of tasks in the work queue by status",
|
||||
}, []string{"status"})
|
||||
|
||||
// Streams
|
||||
activeStreams = promauto.NewGaugeVec(prometheus.GaugeOpts{
|
||||
Name: "rdev_active_streams",
|
||||
@ -81,6 +132,49 @@ func RecordAuthFailure(reason string) {
|
||||
authFailures.WithLabelValues(reason).Inc()
|
||||
}
|
||||
|
||||
// RecordAgentRequest records a code agent request execution.
|
||||
func RecordAgentRequest(provider, status string, durationMs int64) {
|
||||
agentRequestsTotal.WithLabelValues(provider, status).Inc()
|
||||
agentRequestDuration.WithLabelValues(provider).Observe(float64(durationMs) / 1000.0)
|
||||
}
|
||||
|
||||
// RecordAgentToolUse records a tool invocation by a code agent.
|
||||
func RecordAgentToolUse(provider, tool string) {
|
||||
agentToolUse.WithLabelValues(provider, tool).Inc()
|
||||
}
|
||||
|
||||
// SetAgentAvailability sets the availability status of a code agent.
|
||||
func SetAgentAvailability(provider string, available bool) {
|
||||
val := 0.0
|
||||
if available {
|
||||
val = 1.0
|
||||
}
|
||||
agentAvailability.WithLabelValues(provider).Set(val)
|
||||
}
|
||||
|
||||
// SetWorkerCount sets the number of workers for a given status.
|
||||
func SetWorkerCount(status string, count int) {
|
||||
workersTotal.WithLabelValues(status).Set(float64(count))
|
||||
}
|
||||
|
||||
// RecordWorkerHeartbeat sets the age of a worker's most recent heartbeat.
|
||||
func RecordWorkerHeartbeat(workerID string, ageSeconds float64) {
|
||||
workerHeartbeatAge.WithLabelValues(workerID).Set(ageSeconds)
|
||||
}
|
||||
|
||||
// RecordBuild records a build task completion.
|
||||
func RecordBuild(project, status string, durationMs int64) {
|
||||
buildsTotal.WithLabelValues(project, status).Inc()
|
||||
if durationMs > 0 {
|
||||
buildDuration.WithLabelValues(project).Observe(float64(durationMs) / 1000.0)
|
||||
}
|
||||
}
|
||||
|
||||
// SetWorkQueueDepth sets the current depth of the work queue for a status.
|
||||
func SetWorkQueueDepth(status string, count int64) {
|
||||
workQueueDepth.WithLabelValues(status).Set(float64(count))
|
||||
}
|
||||
|
||||
// Handler returns the Prometheus HTTP handler.
|
||||
func Handler() http.Handler {
|
||||
return promhttp.Handler()
|
||||
@ -124,6 +218,10 @@ var pathNormalizers = []struct {
|
||||
}{
|
||||
// /keys/uuid -> /keys/{id}
|
||||
{regexp.MustCompile(`^/keys/[^/]+$`), "/keys/{id}"},
|
||||
// /workers/{id}/... -> /workers/{id}/...
|
||||
{regexp.MustCompile(`^/workers/[^/]+(/.*)?$`), "/workers/{id}$1"},
|
||||
// /builds/{id} -> /builds/{id}
|
||||
{regexp.MustCompile(`^/builds/[^/]+$`), "/builds/{id}"},
|
||||
// /projects/{id}/claude-config/{type}/{name} -> /projects/{id}/claude-config/{type}/{name}
|
||||
{regexp.MustCompile(`^/projects/[^/]+/claude-config/(commands|skills|agents)/[^/]+$`), "/projects/{id}/claude-config/$1/{name}"},
|
||||
// /projects/{id}/... (any sub-path) - must be last as it's most general
|
||||
|
||||
@ -63,7 +63,7 @@ func RateLimitMiddleware(cfg RateLimitConfig) func(http.Handler) http.Handler {
|
||||
}
|
||||
|
||||
// Skip rate limiting for admin keys
|
||||
if apiKey.ID == "admin" {
|
||||
if string(apiKey.ID) == "admin" {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
@ -71,7 +71,7 @@ func RateLimitMiddleware(cfg RateLimitConfig) func(http.Handler) http.Handler {
|
||||
// Check rate limit and record atomically to prevent race conditions
|
||||
// RecordRequest is called first to ensure the count is incremented before
|
||||
// we check, preventing burst bypass under high concurrency
|
||||
if err := cfg.Limiter.RecordRequest(r.Context(), apiKey.ID); err != nil {
|
||||
if err := cfg.Limiter.RecordRequest(r.Context(), string(apiKey.ID)); err != nil {
|
||||
logger.Error("failed to record rate limit request", "error", err, "key_id", apiKey.ID)
|
||||
// On error, allow the request (fail open)
|
||||
next.ServeHTTP(w, r)
|
||||
@ -79,7 +79,7 @@ func RateLimitMiddleware(cfg RateLimitConfig) func(http.Handler) http.Handler {
|
||||
}
|
||||
|
||||
// Now check the limit (which includes the just-recorded request)
|
||||
result, err := cfg.Limiter.CheckLimit(r.Context(), apiKey.ID)
|
||||
result, err := cfg.Limiter.CheckLimit(r.Context(), string(apiKey.ID))
|
||||
if err != nil {
|
||||
logger.Error("failed to check rate limit", "error", err, "key_id", apiKey.ID)
|
||||
// On error, allow the request (fail open)
|
||||
|
||||
46
internal/port/build_audit.go
Normal file
46
internal/port/build_audit.go
Normal file
@ -0,0 +1,46 @@
|
||||
// Package port defines interface contracts for external adapters.
|
||||
package port
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/orchard9/rdev/internal/domain"
|
||||
)
|
||||
|
||||
// BuildAudit records build history for observability and debugging.
|
||||
// Every build that passes through the system gets an audit entry,
|
||||
// providing a complete history of what was requested, who executed it,
|
||||
// and what the outcome was.
|
||||
type BuildAudit interface {
|
||||
// Record creates a new audit entry when a build starts.
|
||||
Record(ctx context.Context, entry *domain.BuildAuditEntry) error
|
||||
|
||||
// Update modifies an existing entry when a build completes.
|
||||
Update(ctx context.Context, taskID string, result *domain.BuildResult) error
|
||||
|
||||
// Get retrieves a specific audit entry by task ID.
|
||||
// Returns ErrBuildNotFound if the entry does not exist.
|
||||
Get(ctx context.Context, taskID string) (*domain.BuildAuditEntry, error)
|
||||
|
||||
// List returns audit entries matching the filter.
|
||||
List(ctx context.Context, filter BuildAuditFilter) ([]*domain.BuildAuditEntry, error)
|
||||
}
|
||||
|
||||
// BuildAuditFilter specifies criteria for listing audit entries.
|
||||
type BuildAuditFilter struct {
|
||||
// ProjectID filters entries by project.
|
||||
ProjectID string
|
||||
|
||||
// WorkerID filters entries by worker.
|
||||
WorkerID string
|
||||
|
||||
// Status filters entries by build status. Nil means all statuses.
|
||||
Status *domain.BuildStatus
|
||||
|
||||
// Since filters entries created after this time.
|
||||
Since time.Time
|
||||
|
||||
// Limit is the maximum number of entries to return.
|
||||
Limit int
|
||||
}
|
||||
@ -29,4 +29,10 @@ type CIProvider interface {
|
||||
|
||||
// DeleteSecret removes a secret from a repository.
|
||||
DeleteSecret(ctx context.Context, owner, repo, secretName string) error
|
||||
|
||||
// ListPipelines returns recent CI pipeline executions for a repository.
|
||||
ListPipelines(ctx context.Context, owner, repo string) ([]*domain.CIPipeline, error)
|
||||
|
||||
// GetPipeline returns a specific pipeline execution by number.
|
||||
GetPipeline(ctx context.Context, owner, repo string, number int64) (*domain.CIPipeline, error)
|
||||
}
|
||||
|
||||
@ -49,6 +49,10 @@ type CodeAgentRegistry interface {
|
||||
// Returns nil if no agents are registered.
|
||||
Default() CodeAgent
|
||||
|
||||
// DefaultProvider returns the current default provider.
|
||||
// Returns empty string if no agents are registered.
|
||||
DefaultProvider() domain.AgentProvider
|
||||
|
||||
// SetDefault sets which provider should be used as the default.
|
||||
// Returns error if the provider is not registered.
|
||||
SetDefault(provider domain.AgentProvider) error
|
||||
@ -58,4 +62,7 @@ type CodeAgentRegistry interface {
|
||||
|
||||
// AvailableAgents returns all registered agents that are currently available.
|
||||
AvailableAgents(ctx context.Context) []CodeAgent
|
||||
|
||||
// Count returns the number of registered agents.
|
||||
Count() int
|
||||
}
|
||||
|
||||
@ -13,13 +13,15 @@ type CredentialStore interface {
|
||||
// Get retrieves a credential by key. Returns empty string if not found.
|
||||
Get(ctx context.Context, key string) (string, error)
|
||||
|
||||
// GetRequired retrieves a credential by key. Returns error if not found.
|
||||
// GetRequired retrieves a credential by key.
|
||||
// Returns domain.ErrCredentialNotFound if the key does not exist.
|
||||
GetRequired(ctx context.Context, key string) (string, error)
|
||||
|
||||
// Set stores or updates a credential.
|
||||
Set(ctx context.Context, cred domain.Credential) error
|
||||
|
||||
// Delete removes a credential by key.
|
||||
// Returns domain.ErrCredentialNotFound if the key does not exist.
|
||||
Delete(ctx context.Context, key string) error
|
||||
|
||||
// List returns all credentials (with values masked).
|
||||
|
||||
15
internal/port/health.go
Normal file
15
internal/port/health.go
Normal file
@ -0,0 +1,15 @@
|
||||
package port
|
||||
|
||||
import "context"
|
||||
|
||||
// DatabasePinger checks database connectivity.
|
||||
// *sql.DB satisfies this interface.
|
||||
type DatabasePinger interface {
|
||||
PingContext(ctx context.Context) error
|
||||
}
|
||||
|
||||
// KubernetesChecker checks Kubernetes API connectivity.
|
||||
type KubernetesChecker interface {
|
||||
// ServerVersion returns the server version string, or an error if unreachable.
|
||||
ServerVersion() (string, error)
|
||||
}
|
||||
@ -4,6 +4,8 @@ package port
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/orchard9/rdev/internal/domain"
|
||||
)
|
||||
|
||||
// WorkQueue defines operations for the worker pool task queue.
|
||||
@ -12,15 +14,15 @@ import (
|
||||
type WorkQueue interface {
|
||||
// Enqueue adds a task to the queue.
|
||||
// Returns the task ID.
|
||||
Enqueue(ctx context.Context, task *WorkTask) (string, error)
|
||||
Enqueue(ctx context.Context, task *domain.WorkTask) (string, error)
|
||||
|
||||
// Dequeue atomically claims the next available task for a worker.
|
||||
// Uses FOR UPDATE SKIP LOCKED for concurrent worker safety.
|
||||
// Returns nil if no tasks are available.
|
||||
Dequeue(ctx context.Context, workerID string) (*WorkTask, error)
|
||||
Dequeue(ctx context.Context, workerID string) (*domain.WorkTask, error)
|
||||
|
||||
// Complete marks a task as successfully completed with results.
|
||||
Complete(ctx context.Context, taskID string, result *WorkResult) error
|
||||
Complete(ctx context.Context, taskID string, result *domain.WorkResult) error
|
||||
|
||||
// Fail marks a task as failed with an error message.
|
||||
// If retry_count < max_retries, the task will be re-queued as pending.
|
||||
@ -31,13 +33,13 @@ type WorkQueue interface {
|
||||
Cancel(ctx context.Context, taskID string) error
|
||||
|
||||
// GetTask retrieves a task by ID.
|
||||
GetTask(ctx context.Context, taskID string) (*WorkTask, error)
|
||||
GetTask(ctx context.Context, taskID string) (*domain.WorkTask, error)
|
||||
|
||||
// ListByProject returns tasks for a project with optional status filter and pagination.
|
||||
ListByProject(ctx context.Context, projectID string, status *WorkTaskStatus, opts WorkListOptions) (*WorkListResult, error)
|
||||
ListByProject(ctx context.Context, projectID string, status *domain.WorkTaskStatus, opts domain.WorkListOptions) (*domain.WorkListResult, error)
|
||||
|
||||
// GetStats returns queue statistics.
|
||||
GetStats(ctx context.Context) (*WorkQueueStats, error)
|
||||
GetStats(ctx context.Context) (*domain.WorkQueueStats, error)
|
||||
|
||||
// CleanupOld removes completed/failed/cancelled tasks older than the specified duration.
|
||||
CleanupOld(ctx context.Context, olderThan time.Duration) (int64, error)
|
||||
@ -46,170 +48,3 @@ type WorkQueue interface {
|
||||
// This handles workers that crashed without reporting completion.
|
||||
RequeueStale(ctx context.Context, timeout time.Duration) (int64, error)
|
||||
}
|
||||
|
||||
// WorkTaskStatus represents the status of a work task.
|
||||
type WorkTaskStatus string
|
||||
|
||||
const (
|
||||
WorkTaskStatusPending WorkTaskStatus = "pending"
|
||||
WorkTaskStatusRunning WorkTaskStatus = "running"
|
||||
WorkTaskStatusCompleted WorkTaskStatus = "completed"
|
||||
WorkTaskStatusFailed WorkTaskStatus = "failed"
|
||||
WorkTaskStatusCancelled WorkTaskStatus = "cancelled"
|
||||
)
|
||||
|
||||
// IsValid returns true if the status is a known valid status.
|
||||
func (s WorkTaskStatus) IsValid() bool {
|
||||
switch s {
|
||||
case WorkTaskStatusPending, WorkTaskStatusRunning, WorkTaskStatusCompleted,
|
||||
WorkTaskStatusFailed, WorkTaskStatusCancelled:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// WorkTaskType represents the type of work task.
|
||||
type WorkTaskType string
|
||||
|
||||
const (
|
||||
WorkTaskTypeBuild WorkTaskType = "build"
|
||||
WorkTaskTypeTest WorkTaskType = "test"
|
||||
WorkTaskTypeDeploy WorkTaskType = "deploy"
|
||||
WorkTaskTypeCustom WorkTaskType = "custom"
|
||||
)
|
||||
|
||||
// IsValid returns true if the task type is a known valid type.
|
||||
func (t WorkTaskType) IsValid() bool {
|
||||
switch t {
|
||||
case WorkTaskTypeBuild, WorkTaskTypeTest, WorkTaskTypeDeploy, WorkTaskTypeCustom:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// WorkTask represents a task in the work queue.
|
||||
type WorkTask struct {
|
||||
// ID is the unique task identifier.
|
||||
ID string
|
||||
|
||||
// ProjectID is the project this task belongs to.
|
||||
ProjectID string
|
||||
|
||||
// Type is the task type (build, test, deploy, custom).
|
||||
Type WorkTaskType
|
||||
|
||||
// Spec contains task-specific parameters.
|
||||
// For build tasks: template, prompt, variables, auto_deploy, git_url
|
||||
// For test tasks: test_command, git_url
|
||||
// For deploy tasks: image, replicas, env
|
||||
Spec map[string]any
|
||||
|
||||
// Status is the current task status.
|
||||
Status WorkTaskStatus
|
||||
|
||||
// Priority determines execution order (higher = more urgent).
|
||||
Priority int
|
||||
|
||||
// WorkerID is the ID of the worker that claimed this task.
|
||||
WorkerID string
|
||||
|
||||
// CallbackURL is the webhook URL for completion notification.
|
||||
CallbackURL string
|
||||
|
||||
// CreatedAt is when the task was created.
|
||||
CreatedAt time.Time
|
||||
|
||||
// StartedAt is when a worker started executing the task.
|
||||
StartedAt *time.Time
|
||||
|
||||
// CompletedAt is when the task finished (success or failure).
|
||||
CompletedAt *time.Time
|
||||
|
||||
// Result contains the task output (if completed).
|
||||
Result *WorkResult
|
||||
|
||||
// Error contains the error message (if failed).
|
||||
Error string
|
||||
|
||||
// RetryCount is the number of retry attempts.
|
||||
RetryCount int
|
||||
|
||||
// MaxRetries is the maximum allowed retry attempts.
|
||||
MaxRetries int
|
||||
}
|
||||
|
||||
// WorkResult contains the result of a completed task.
|
||||
type WorkResult struct {
|
||||
// Output is the main output from task execution.
|
||||
Output string `json:"output,omitempty"`
|
||||
|
||||
// Artifacts contains named artifacts from the task.
|
||||
// For build tasks: commit_sha, deploy_url, etc.
|
||||
Artifacts map[string]string `json:"artifacts,omitempty"`
|
||||
}
|
||||
|
||||
// WorkQueueStats contains queue statistics.
|
||||
type WorkQueueStats struct {
|
||||
// Pending is the count of pending tasks.
|
||||
Pending int64 `json:"pending"`
|
||||
|
||||
// Running is the count of running tasks.
|
||||
Running int64 `json:"running"`
|
||||
|
||||
// Completed is the count of completed tasks (last 24h).
|
||||
Completed int64 `json:"completed"`
|
||||
|
||||
// Failed is the count of failed tasks (last 24h).
|
||||
Failed int64 `json:"failed"`
|
||||
|
||||
// Cancelled is the count of cancelled tasks (last 24h).
|
||||
Cancelled int64 `json:"cancelled"`
|
||||
|
||||
// OldestPending is the age of the oldest pending task.
|
||||
OldestPending *time.Duration `json:"oldest_pending,omitempty"`
|
||||
}
|
||||
|
||||
// WorkListOptions contains pagination options for listing tasks.
|
||||
type WorkListOptions struct {
|
||||
// Limit is the maximum number of tasks to return (default: 50, max: 100).
|
||||
Limit int
|
||||
|
||||
// Offset is the number of tasks to skip (for pagination).
|
||||
Offset int
|
||||
}
|
||||
|
||||
// DefaultWorkListOptions returns options with default values.
|
||||
func DefaultWorkListOptions() WorkListOptions {
|
||||
return WorkListOptions{
|
||||
Limit: 50,
|
||||
Offset: 0,
|
||||
}
|
||||
}
|
||||
|
||||
// Normalize applies defaults and limits to the options.
|
||||
func (o *WorkListOptions) Normalize() {
|
||||
if o.Limit <= 0 {
|
||||
o.Limit = 50
|
||||
}
|
||||
if o.Limit > 100 {
|
||||
o.Limit = 100
|
||||
}
|
||||
if o.Offset < 0 {
|
||||
o.Offset = 0
|
||||
}
|
||||
}
|
||||
|
||||
// WorkListResult contains paginated task results.
|
||||
type WorkListResult struct {
|
||||
// Tasks is the list of tasks.
|
||||
Tasks []*WorkTask
|
||||
|
||||
// Total is the total count of matching tasks (for pagination metadata).
|
||||
Total int64
|
||||
|
||||
// Limit is the limit that was applied.
|
||||
Limit int
|
||||
|
||||
// Offset is the offset that was applied.
|
||||
Offset int
|
||||
}
|
||||
|
||||
51
internal/port/worker_registry.go
Normal file
51
internal/port/worker_registry.go
Normal file
@ -0,0 +1,51 @@
|
||||
// Package port defines interface contracts for external adapters.
|
||||
package port
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/orchard9/rdev/internal/domain"
|
||||
)
|
||||
|
||||
// WorkerRegistry manages the lifecycle of workers in the pool.
|
||||
// It handles registration, heartbeats, status updates, and health monitoring.
|
||||
type WorkerRegistry interface {
|
||||
// Register adds a worker to the pool.
|
||||
// If a worker with the same ID already exists, it is re-registered.
|
||||
Register(ctx context.Context, worker *domain.Worker) error
|
||||
|
||||
// Heartbeat updates the worker's last_heartbeat timestamp.
|
||||
// Returns ErrWorkerNotFound if the worker does not exist or is offline.
|
||||
Heartbeat(ctx context.Context, workerID string) error
|
||||
|
||||
// UpdateStatus changes a worker's status and optionally assigns a task.
|
||||
// Pass empty taskID to clear the current task assignment.
|
||||
UpdateStatus(ctx context.Context, workerID string, status domain.WorkerStatus, taskID string) error
|
||||
|
||||
// Deregister removes a worker from the pool.
|
||||
Deregister(ctx context.Context, workerID string) error
|
||||
|
||||
// Get retrieves a specific worker by ID.
|
||||
// Returns ErrWorkerNotFound if the worker does not exist.
|
||||
Get(ctx context.Context, workerID string) (*domain.Worker, error)
|
||||
|
||||
// List returns all workers matching the filter.
|
||||
List(ctx context.Context, filter WorkerFilter) ([]*domain.Worker, error)
|
||||
|
||||
// MarkStaleOffline marks workers without a recent heartbeat as offline.
|
||||
// Returns the number of workers marked offline.
|
||||
MarkStaleOffline(ctx context.Context, threshold time.Duration) (int, error)
|
||||
}
|
||||
|
||||
// WorkerFilter specifies criteria for listing workers.
|
||||
type WorkerFilter struct {
|
||||
// Status filters workers by status. Nil means all statuses.
|
||||
Status *domain.WorkerStatus
|
||||
|
||||
// HasCapability filters workers that have a specific capability.
|
||||
HasCapability string
|
||||
|
||||
// Limit is the maximum number of workers to return. Zero means no limit.
|
||||
Limit int
|
||||
}
|
||||
@ -158,7 +158,7 @@ func (l *Limiter) getKey(r *http.Request) string {
|
||||
// Default: use API key ID from context
|
||||
// This requires the auth middleware to run first
|
||||
if apiKey := auth.GetAPIKey(r.Context()); apiKey != nil {
|
||||
return apiKey.ID
|
||||
return string(apiKey.ID)
|
||||
}
|
||||
|
||||
// Fallback: use client IP
|
||||
@ -258,7 +258,7 @@ func itoa(i int) string {
|
||||
func KeyFromAPIKey() func(*http.Request) string {
|
||||
return func(r *http.Request) string {
|
||||
if apiKey := auth.GetAPIKey(r.Context()); apiKey != nil {
|
||||
return apiKey.ID
|
||||
return string(apiKey.ID)
|
||||
}
|
||||
return getClientIP(r)
|
||||
}
|
||||
|
||||
@ -31,6 +31,7 @@ type CreateKeyRequest struct {
|
||||
Name string
|
||||
Scopes []domain.Scope
|
||||
ProjectIDs []domain.ProjectID
|
||||
AllowedIPs []string // CIDR notation; nil = no restriction
|
||||
ExpiresIn time.Duration
|
||||
CreatedBy string
|
||||
}
|
||||
@ -43,14 +44,26 @@ type CreateKeyResult struct {
|
||||
|
||||
// Create generates a new API key.
|
||||
func (s *APIKeyService) Create(ctx context.Context, req CreateKeyRequest) (*CreateKeyResult, error) {
|
||||
// Generate secret
|
||||
secret, err := generateSecret()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generate secret: %w", err)
|
||||
// Generate key components using auth-compatible format:
|
||||
// identifier: 4 random bytes → 8 hex chars
|
||||
// random: 16 random bytes → 32 hex chars
|
||||
// Full key: rdev_sk_<identifier>_<random>
|
||||
idBytes := make([]byte, 4)
|
||||
if _, err := rand.Read(idBytes); err != nil {
|
||||
return nil, fmt.Errorf("generate identifier: %w", err)
|
||||
}
|
||||
identifier := hex.EncodeToString(idBytes)
|
||||
|
||||
// Hash the secret
|
||||
keyHash := hashKey(secret)
|
||||
randomBytes := make([]byte, 16)
|
||||
if _, err := rand.Read(randomBytes); err != nil {
|
||||
return nil, fmt.Errorf("generate random: %w", err)
|
||||
}
|
||||
random := hex.EncodeToString(randomBytes)
|
||||
|
||||
fullKey := fmt.Sprintf("rdev_sk_%s_%s", identifier, random)
|
||||
|
||||
// Hash the full key (what the user receives and sends back for auth)
|
||||
keyHash := hashKey(fullKey)
|
||||
|
||||
// Calculate expiration
|
||||
var expiresAt *time.Time
|
||||
@ -62,9 +75,10 @@ func (s *APIKeyService) Create(ctx context.Context, req CreateKeyRequest) (*Crea
|
||||
// Create key
|
||||
key := &domain.APIKey{
|
||||
Name: req.Name,
|
||||
KeyPrefix: secret[:8],
|
||||
KeyPrefix: identifier,
|
||||
Scopes: req.Scopes,
|
||||
ProjectIDs: req.ProjectIDs,
|
||||
AllowedIPs: req.AllowedIPs,
|
||||
ExpiresAt: expiresAt,
|
||||
CreatedBy: req.CreatedBy,
|
||||
}
|
||||
@ -75,7 +89,7 @@ func (s *APIKeyService) Create(ctx context.Context, req CreateKeyRequest) (*Crea
|
||||
|
||||
return &CreateKeyResult{
|
||||
Key: key,
|
||||
Secret: formatSecret(key.KeyPrefix, secret),
|
||||
Secret: fullKey,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@ -105,6 +119,42 @@ func (s *APIKeyService) UpdateLastUsed(ctx context.Context, id domain.APIKeyID)
|
||||
return s.repo.UpdateLastUsed(ctx, id)
|
||||
}
|
||||
|
||||
// Validate checks a raw API key and returns the associated APIKey if valid.
|
||||
// It checks for admin key, looks up by hash, and verifies the key is active.
|
||||
// On success it asynchronously updates the last-used timestamp.
|
||||
func (s *APIKeyService) Validate(ctx context.Context, rawKey string) (*domain.APIKey, error) {
|
||||
// Check admin key first
|
||||
if s.adminKey != "" && rawKey == s.adminKey {
|
||||
return &domain.APIKey{
|
||||
ID: "admin",
|
||||
Name: "Super Admin",
|
||||
KeyPrefix: "admin",
|
||||
Scopes: []domain.Scope{domain.ScopeAdmin},
|
||||
}, nil
|
||||
}
|
||||
|
||||
keyHash := hashKey(rawKey)
|
||||
|
||||
apiKey, err := s.repo.GetByHash(ctx, keyHash)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if apiKey.IsRevoked() {
|
||||
return nil, domain.ErrKeyRevoked
|
||||
}
|
||||
if apiKey.IsExpired() {
|
||||
return nil, domain.ErrKeyExpired
|
||||
}
|
||||
|
||||
// Update last_used_at asynchronously
|
||||
go func() {
|
||||
_ = s.repo.UpdateLastUsed(context.Background(), apiKey.ID)
|
||||
}()
|
||||
|
||||
return apiKey, nil
|
||||
}
|
||||
|
||||
// ValidateAdminKey checks if the provided key matches the admin key.
|
||||
func (s *APIKeyService) ValidateAdminKey(key string) bool {
|
||||
return s.adminKey != "" && key == s.adminKey
|
||||
@ -115,26 +165,12 @@ func (s *APIKeyService) AdminKey() string {
|
||||
return s.adminKey
|
||||
}
|
||||
|
||||
// generateSecret creates a cryptographically secure random key.
|
||||
func generateSecret() (string, error) {
|
||||
bytes := make([]byte, 32)
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(bytes), nil
|
||||
}
|
||||
|
||||
// hashKey creates a SHA-256 hash of a key.
|
||||
func hashKey(key string) string {
|
||||
hash := sha256.Sum256([]byte(key))
|
||||
return hex.EncodeToString(hash[:])
|
||||
}
|
||||
|
||||
// formatSecret creates the full secret string with prefix.
|
||||
func formatSecret(prefix, secret string) string {
|
||||
return fmt.Sprintf("rdev_sk_%s_%s", prefix, secret[8:])
|
||||
}
|
||||
|
||||
// ParseExpiration converts a duration string to time.Duration.
|
||||
// Supported formats: "30d", "60d", "90d", "1y", "never" (or empty)
|
||||
func ParseExpiration(s string) (time.Duration, error) {
|
||||
|
||||
@ -319,28 +319,6 @@ func TestParseExpiration(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateSecret(t *testing.T) {
|
||||
secrets := make(map[string]bool)
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
secret, err := generateSecret()
|
||||
if err != nil {
|
||||
t.Fatalf("generateSecret() error = %v", err)
|
||||
}
|
||||
|
||||
// Should be 64 hex characters (32 bytes)
|
||||
if len(secret) != 64 {
|
||||
t.Errorf("Secret length = %d, want 64", len(secret))
|
||||
}
|
||||
|
||||
// Should be unique
|
||||
if secrets[secret] {
|
||||
t.Errorf("Duplicate secret generated: %q", secret)
|
||||
}
|
||||
secrets[secret] = true
|
||||
}
|
||||
}
|
||||
|
||||
func TestHashKey(t *testing.T) {
|
||||
// Same input should produce same hash
|
||||
hash1 := hashKey("test-key")
|
||||
@ -360,12 +338,3 @@ func TestHashKey(t *testing.T) {
|
||||
t.Errorf("Hash length = %d, want 64", len(hash1))
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatSecret(t *testing.T) {
|
||||
result := formatSecret("abcd1234", "abcd12345678rest")
|
||||
expected := "rdev_sk_abcd1234_5678rest"
|
||||
|
||||
if result != expected {
|
||||
t.Errorf("formatSecret() = %q, want %q", result, expected)
|
||||
}
|
||||
}
|
||||
|
||||
132
internal/service/build_service.go
Normal file
132
internal/service/build_service.go
Normal file
@ -0,0 +1,132 @@
|
||||
// Package service provides business logic services.
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"github.com/orchard9/rdev/internal/domain"
|
||||
"github.com/orchard9/rdev/internal/port"
|
||||
)
|
||||
|
||||
// BuildService orchestrates build task submission and tracking.
|
||||
// It coordinates between the work queue (execution) and build audit (history).
|
||||
type BuildService struct {
|
||||
queue port.WorkQueue
|
||||
audit port.BuildAudit
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
// NewBuildService creates a new build service.
|
||||
func NewBuildService(
|
||||
queue port.WorkQueue,
|
||||
audit port.BuildAudit,
|
||||
logger *slog.Logger,
|
||||
) *BuildService {
|
||||
if logger == nil {
|
||||
logger = slog.Default()
|
||||
}
|
||||
return &BuildService{
|
||||
queue: queue,
|
||||
audit: audit,
|
||||
logger: logger.With("service", "build"),
|
||||
}
|
||||
}
|
||||
|
||||
// StartBuild enqueues a build task and creates an audit entry.
|
||||
// Returns the task ID for status tracking.
|
||||
func (s *BuildService) StartBuild(ctx context.Context, projectID string, spec domain.BuildSpec) (string, error) {
|
||||
if err := spec.Validate(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if projectID == "" {
|
||||
return "", fmt.Errorf("project_id is required")
|
||||
}
|
||||
|
||||
// Build work task spec from build spec
|
||||
taskSpec := map[string]any{
|
||||
"prompt": spec.Prompt,
|
||||
"auto_commit": spec.AutoCommit,
|
||||
"auto_push": spec.AutoPush,
|
||||
}
|
||||
if spec.Template != "" {
|
||||
taskSpec["template"] = spec.Template
|
||||
}
|
||||
if len(spec.Variables) > 0 {
|
||||
taskSpec["variables"] = spec.Variables
|
||||
}
|
||||
|
||||
// Create work task
|
||||
task := &domain.WorkTask{
|
||||
ProjectID: projectID,
|
||||
Type: domain.WorkTaskTypeBuild,
|
||||
Spec: taskSpec,
|
||||
CallbackURL: spec.CallbackURL,
|
||||
MaxRetries: 3,
|
||||
}
|
||||
|
||||
// Enqueue to work queue
|
||||
taskID, err := s.queue.Enqueue(ctx, task)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("enqueue build task: %w", err)
|
||||
}
|
||||
|
||||
// Create audit entry (non-critical - don't fail the build if audit fails)
|
||||
auditEntry := &domain.BuildAuditEntry{
|
||||
TaskID: taskID,
|
||||
ProjectID: projectID,
|
||||
Spec: spec,
|
||||
Status: domain.BuildStatusPending,
|
||||
StartedAt: time.Now(),
|
||||
}
|
||||
if err := s.audit.Record(ctx, auditEntry); err != nil {
|
||||
s.logger.Warn("failed to record audit entry",
|
||||
"task_id", taskID,
|
||||
"error", err,
|
||||
)
|
||||
}
|
||||
|
||||
s.logger.Info("build enqueued",
|
||||
"task_id", taskID,
|
||||
"project_id", projectID,
|
||||
"template", spec.Template,
|
||||
"auto_push", spec.AutoPush,
|
||||
)
|
||||
|
||||
return taskID, nil
|
||||
}
|
||||
|
||||
// GetBuildStatus returns the current status of a build.
|
||||
func (s *BuildService) GetBuildStatus(ctx context.Context, taskID string) (*domain.BuildAuditEntry, error) {
|
||||
return s.audit.Get(ctx, taskID)
|
||||
}
|
||||
|
||||
// ListBuilds returns build history for a project.
|
||||
func (s *BuildService) ListBuilds(ctx context.Context, projectID string, limit int) ([]*domain.BuildAuditEntry, error) {
|
||||
if limit <= 0 {
|
||||
limit = 50
|
||||
}
|
||||
return s.audit.List(ctx, port.BuildAuditFilter{
|
||||
ProjectID: projectID,
|
||||
Limit: limit,
|
||||
})
|
||||
}
|
||||
|
||||
// CompleteBuild updates the audit entry when a build finishes.
|
||||
// Called by the work queue processor on task completion.
|
||||
func (s *BuildService) CompleteBuild(ctx context.Context, taskID string, result *domain.BuildResult) error {
|
||||
if err := s.audit.Update(ctx, taskID, result); err != nil {
|
||||
return fmt.Errorf("update audit: %w", err)
|
||||
}
|
||||
|
||||
s.logger.Info("build completed",
|
||||
"task_id", taskID,
|
||||
"success", result.Success,
|
||||
"duration_ms", result.DurationMs,
|
||||
)
|
||||
|
||||
return nil
|
||||
}
|
||||
215
internal/service/build_service_test.go
Normal file
215
internal/service/build_service_test.go
Normal file
@ -0,0 +1,215 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/orchard9/rdev/internal/domain"
|
||||
)
|
||||
|
||||
func TestBuildService_StartBuild(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("enqueues build successfully", func(t *testing.T) {
|
||||
queue := newMockWorkQueue()
|
||||
audit := newMockBuildAudit()
|
||||
svc := NewBuildService(queue, audit, nil)
|
||||
|
||||
taskID, err := svc.StartBuild(ctx, "project-1", domain.BuildSpec{
|
||||
Prompt: "Build a landing page",
|
||||
Template: "nextjs",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("StartBuild() error = %v", err)
|
||||
}
|
||||
if taskID == "" {
|
||||
t.Error("expected non-empty task ID")
|
||||
}
|
||||
|
||||
// Verify task was enqueued
|
||||
if len(queue.tasks) != 1 {
|
||||
t.Errorf("expected 1 task in queue, got %d", len(queue.tasks))
|
||||
}
|
||||
task := queue.tasks[taskID]
|
||||
if task.ProjectID != "project-1" {
|
||||
t.Errorf("got project_id %q, want %q", task.ProjectID, "project-1")
|
||||
}
|
||||
if task.Type != domain.WorkTaskTypeBuild {
|
||||
t.Errorf("got type %q, want %q", task.Type, domain.WorkTaskTypeBuild)
|
||||
}
|
||||
|
||||
// Verify audit was recorded
|
||||
if len(audit.entries) != 1 {
|
||||
t.Errorf("expected 1 audit entry, got %d", len(audit.entries))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("validates prompt required", func(t *testing.T) {
|
||||
queue := newMockWorkQueue()
|
||||
audit := newMockBuildAudit()
|
||||
svc := NewBuildService(queue, audit, nil)
|
||||
|
||||
_, err := svc.StartBuild(ctx, "project-1", domain.BuildSpec{})
|
||||
if err == nil {
|
||||
t.Error("expected error for empty prompt")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("validates project ID required", func(t *testing.T) {
|
||||
queue := newMockWorkQueue()
|
||||
audit := newMockBuildAudit()
|
||||
svc := NewBuildService(queue, audit, nil)
|
||||
|
||||
_, err := svc.StartBuild(ctx, "", domain.BuildSpec{Prompt: "Build"})
|
||||
if err == nil {
|
||||
t.Error("expected error for empty project ID")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("includes variables in spec", func(t *testing.T) {
|
||||
queue := newMockWorkQueue()
|
||||
audit := newMockBuildAudit()
|
||||
svc := NewBuildService(queue, audit, nil)
|
||||
|
||||
taskID, err := svc.StartBuild(ctx, "project-1", domain.BuildSpec{
|
||||
Prompt: "Build",
|
||||
Variables: map[string]string{
|
||||
"name": "My App",
|
||||
"color": "blue",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("StartBuild() error = %v", err)
|
||||
}
|
||||
|
||||
task := queue.tasks[taskID]
|
||||
vars, ok := task.Spec["variables"].(map[string]string)
|
||||
if !ok {
|
||||
t.Fatal("expected variables in task spec")
|
||||
}
|
||||
if vars["name"] != "My App" {
|
||||
t.Errorf("got variable name %q, want %q", vars["name"], "My App")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("continues if audit fails", func(t *testing.T) {
|
||||
queue := newMockWorkQueue()
|
||||
audit := newMockBuildAudit()
|
||||
audit.err = fmt.Errorf("db connection failed")
|
||||
svc := NewBuildService(queue, audit, nil)
|
||||
|
||||
taskID, err := svc.StartBuild(ctx, "project-1", domain.BuildSpec{
|
||||
Prompt: "Build",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("StartBuild() should succeed even if audit fails, got error = %v", err)
|
||||
}
|
||||
if taskID == "" {
|
||||
t.Error("expected non-empty task ID")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestBuildService_GetBuildStatus(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("returns existing entry", func(t *testing.T) {
|
||||
audit := newMockBuildAudit()
|
||||
audit.entries["task-1"] = &domain.BuildAuditEntry{
|
||||
TaskID: "task-1",
|
||||
ProjectID: "project-1",
|
||||
Status: domain.BuildStatusRunning,
|
||||
}
|
||||
svc := NewBuildService(newMockWorkQueue(), audit, nil)
|
||||
|
||||
entry, err := svc.GetBuildStatus(ctx, "task-1")
|
||||
if err != nil {
|
||||
t.Fatalf("GetBuildStatus() error = %v", err)
|
||||
}
|
||||
if entry.Status != domain.BuildStatusRunning {
|
||||
t.Errorf("got status %q, want %q", entry.Status, domain.BuildStatusRunning)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns error for nonexistent entry", func(t *testing.T) {
|
||||
audit := newMockBuildAudit()
|
||||
svc := NewBuildService(newMockWorkQueue(), audit, nil)
|
||||
|
||||
_, err := svc.GetBuildStatus(ctx, "nonexistent")
|
||||
if err == nil {
|
||||
t.Error("expected error for nonexistent entry")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestBuildService_ListBuilds(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
audit := newMockBuildAudit()
|
||||
audit.entries["task-1"] = &domain.BuildAuditEntry{
|
||||
TaskID: "task-1", ProjectID: "project-a", Status: domain.BuildStatusCompleted,
|
||||
}
|
||||
audit.entries["task-2"] = &domain.BuildAuditEntry{
|
||||
TaskID: "task-2", ProjectID: "project-a", Status: domain.BuildStatusFailed,
|
||||
}
|
||||
audit.entries["task-3"] = &domain.BuildAuditEntry{
|
||||
TaskID: "task-3", ProjectID: "project-b", Status: domain.BuildStatusPending,
|
||||
}
|
||||
|
||||
svc := NewBuildService(newMockWorkQueue(), audit, nil)
|
||||
|
||||
t.Run("lists builds for project", func(t *testing.T) {
|
||||
entries, err := svc.ListBuilds(ctx, "project-a", 50)
|
||||
if err != nil {
|
||||
t.Fatalf("ListBuilds() error = %v", err)
|
||||
}
|
||||
if len(entries) != 2 {
|
||||
t.Errorf("got %d entries, want 2", len(entries))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("uses default limit", func(t *testing.T) {
|
||||
entries, err := svc.ListBuilds(ctx, "project-a", 0)
|
||||
if err != nil {
|
||||
t.Fatalf("ListBuilds() error = %v", err)
|
||||
}
|
||||
if len(entries) != 2 {
|
||||
t.Errorf("got %d entries, want 2", len(entries))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestBuildService_CompleteBuild(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("updates audit on completion", func(t *testing.T) {
|
||||
audit := newMockBuildAudit()
|
||||
audit.entries["task-1"] = &domain.BuildAuditEntry{
|
||||
TaskID: "task-1",
|
||||
ProjectID: "project-1",
|
||||
Status: domain.BuildStatusRunning,
|
||||
}
|
||||
svc := NewBuildService(newMockWorkQueue(), audit, nil)
|
||||
|
||||
err := svc.CompleteBuild(ctx, "task-1", &domain.BuildResult{
|
||||
Success: true,
|
||||
CommitSHA: "abc123",
|
||||
DurationMs: 5000,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("CompleteBuild() error = %v", err)
|
||||
}
|
||||
|
||||
entry := audit.entries["task-1"]
|
||||
if entry.Status != domain.BuildStatusCompleted {
|
||||
t.Errorf("got status %q, want %q", entry.Status, domain.BuildStatusCompleted)
|
||||
}
|
||||
if entry.Result == nil {
|
||||
t.Fatal("expected result to be set")
|
||||
}
|
||||
if !entry.Result.Success {
|
||||
t.Error("expected result.Success = true")
|
||||
}
|
||||
})
|
||||
}
|
||||
256
internal/service/mock_test.go
Normal file
256
internal/service/mock_test.go
Normal file
@ -0,0 +1,256 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/orchard9/rdev/internal/domain"
|
||||
"github.com/orchard9/rdev/internal/port"
|
||||
)
|
||||
|
||||
// mockWorkQueue implements port.WorkQueue for service tests.
|
||||
// Configure tasks and err fields to control behavior.
|
||||
type mockWorkQueue struct {
|
||||
tasks map[string]*domain.WorkTask
|
||||
err error
|
||||
}
|
||||
|
||||
func newMockWorkQueue() *mockWorkQueue {
|
||||
return &mockWorkQueue{tasks: make(map[string]*domain.WorkTask)}
|
||||
}
|
||||
|
||||
func (m *mockWorkQueue) Enqueue(ctx context.Context, task *domain.WorkTask) (string, error) {
|
||||
if m.err != nil {
|
||||
return "", m.err
|
||||
}
|
||||
id := fmt.Sprintf("task-%d", len(m.tasks)+1)
|
||||
task.ID = id
|
||||
task.Status = domain.WorkTaskStatusPending
|
||||
task.CreatedAt = time.Now()
|
||||
m.tasks[id] = task
|
||||
return id, nil
|
||||
}
|
||||
|
||||
func (m *mockWorkQueue) Dequeue(ctx context.Context, workerID string) (*domain.WorkTask, error) {
|
||||
if m.err != nil {
|
||||
return nil, m.err
|
||||
}
|
||||
for _, task := range m.tasks {
|
||||
if task.Status == domain.WorkTaskStatusPending {
|
||||
task.Status = domain.WorkTaskStatusRunning
|
||||
task.WorkerID = workerID
|
||||
now := time.Now()
|
||||
task.StartedAt = &now
|
||||
return task, nil
|
||||
}
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockWorkQueue) Complete(ctx context.Context, taskID string, result *domain.WorkResult) error {
|
||||
if m.err != nil {
|
||||
return m.err
|
||||
}
|
||||
task, ok := m.tasks[taskID]
|
||||
if !ok {
|
||||
return domain.ErrWorkTaskNotFound
|
||||
}
|
||||
task.Status = domain.WorkTaskStatusCompleted
|
||||
task.Result = result
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockWorkQueue) Fail(ctx context.Context, taskID string, errMsg string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockWorkQueue) Cancel(ctx context.Context, taskID string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockWorkQueue) GetTask(ctx context.Context, taskID string) (*domain.WorkTask, error) {
|
||||
task, ok := m.tasks[taskID]
|
||||
if !ok {
|
||||
return nil, domain.ErrWorkTaskNotFound
|
||||
}
|
||||
return task, nil
|
||||
}
|
||||
|
||||
func (m *mockWorkQueue) ListByProject(ctx context.Context, projectID string, status *domain.WorkTaskStatus, opts domain.WorkListOptions) (*domain.WorkListResult, error) {
|
||||
return &domain.WorkListResult{}, nil
|
||||
}
|
||||
|
||||
func (m *mockWorkQueue) GetStats(ctx context.Context) (*domain.WorkQueueStats, error) {
|
||||
return &domain.WorkQueueStats{}, nil
|
||||
}
|
||||
|
||||
func (m *mockWorkQueue) CleanupOld(ctx context.Context, olderThan time.Duration) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (m *mockWorkQueue) RequeueStale(ctx context.Context, timeout time.Duration) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// mockBuildAudit implements port.BuildAudit for service tests.
|
||||
type mockBuildAudit struct {
|
||||
entries map[string]*domain.BuildAuditEntry
|
||||
err error
|
||||
}
|
||||
|
||||
func newMockBuildAudit() *mockBuildAudit {
|
||||
return &mockBuildAudit{entries: make(map[string]*domain.BuildAuditEntry)}
|
||||
}
|
||||
|
||||
func (m *mockBuildAudit) Record(ctx context.Context, entry *domain.BuildAuditEntry) error {
|
||||
if m.err != nil {
|
||||
return m.err
|
||||
}
|
||||
m.entries[entry.TaskID] = entry
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockBuildAudit) Update(ctx context.Context, taskID string, result *domain.BuildResult) error {
|
||||
if m.err != nil {
|
||||
return m.err
|
||||
}
|
||||
entry, ok := m.entries[taskID]
|
||||
if !ok {
|
||||
return domain.ErrBuildNotFound
|
||||
}
|
||||
entry.Result = result
|
||||
if result.Success {
|
||||
entry.Status = domain.BuildStatusCompleted
|
||||
} else {
|
||||
entry.Status = domain.BuildStatusFailed
|
||||
}
|
||||
now := time.Now()
|
||||
entry.CompletedAt = &now
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockBuildAudit) Get(ctx context.Context, taskID string) (*domain.BuildAuditEntry, error) {
|
||||
if m.err != nil {
|
||||
return nil, m.err
|
||||
}
|
||||
entry, ok := m.entries[taskID]
|
||||
if !ok {
|
||||
return nil, domain.ErrBuildNotFound
|
||||
}
|
||||
return entry, nil
|
||||
}
|
||||
|
||||
func (m *mockBuildAudit) List(ctx context.Context, filter port.BuildAuditFilter) ([]*domain.BuildAuditEntry, error) {
|
||||
if m.err != nil {
|
||||
return nil, m.err
|
||||
}
|
||||
var result []*domain.BuildAuditEntry
|
||||
for _, entry := range m.entries {
|
||||
if filter.ProjectID != "" && entry.ProjectID != filter.ProjectID {
|
||||
continue
|
||||
}
|
||||
result = append(result, entry)
|
||||
}
|
||||
if filter.Limit > 0 && len(result) > filter.Limit {
|
||||
result = result[:filter.Limit]
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// mockWorkerRegistry implements port.WorkerRegistry for service tests.
|
||||
type mockWorkerRegistry struct {
|
||||
workers map[string]*domain.Worker
|
||||
err error
|
||||
}
|
||||
|
||||
func newMockWorkerRegistry() *mockWorkerRegistry {
|
||||
return &mockWorkerRegistry{workers: make(map[string]*domain.Worker)}
|
||||
}
|
||||
|
||||
func (m *mockWorkerRegistry) Register(ctx context.Context, worker *domain.Worker) error {
|
||||
if m.err != nil {
|
||||
return m.err
|
||||
}
|
||||
m.workers[worker.ID] = worker
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockWorkerRegistry) Heartbeat(ctx context.Context, workerID string) error {
|
||||
if m.err != nil {
|
||||
return m.err
|
||||
}
|
||||
w, ok := m.workers[workerID]
|
||||
if !ok {
|
||||
return domain.ErrWorkerNotFound
|
||||
}
|
||||
w.LastHeartbeat = time.Now()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockWorkerRegistry) UpdateStatus(ctx context.Context, workerID string, status domain.WorkerStatus, taskID string) error {
|
||||
if m.err != nil {
|
||||
return m.err
|
||||
}
|
||||
w, ok := m.workers[workerID]
|
||||
if !ok {
|
||||
return domain.ErrWorkerNotFound
|
||||
}
|
||||
w.Status = status
|
||||
w.CurrentTask = taskID
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockWorkerRegistry) Deregister(ctx context.Context, workerID string) error {
|
||||
if m.err != nil {
|
||||
return m.err
|
||||
}
|
||||
if _, ok := m.workers[workerID]; !ok {
|
||||
return domain.ErrWorkerNotFound
|
||||
}
|
||||
delete(m.workers, workerID)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockWorkerRegistry) Get(ctx context.Context, workerID string) (*domain.Worker, error) {
|
||||
if m.err != nil {
|
||||
return nil, m.err
|
||||
}
|
||||
w, ok := m.workers[workerID]
|
||||
if !ok {
|
||||
return nil, domain.ErrWorkerNotFound
|
||||
}
|
||||
return w, nil
|
||||
}
|
||||
|
||||
func (m *mockWorkerRegistry) List(ctx context.Context, filter port.WorkerFilter) ([]*domain.Worker, error) {
|
||||
if m.err != nil {
|
||||
return nil, m.err
|
||||
}
|
||||
var result []*domain.Worker
|
||||
for _, w := range m.workers {
|
||||
if filter.Status != nil && w.Status != *filter.Status {
|
||||
continue
|
||||
}
|
||||
result = append(result, w)
|
||||
}
|
||||
if filter.Limit > 0 && len(result) > filter.Limit {
|
||||
result = result[:filter.Limit]
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (m *mockWorkerRegistry) MarkStaleOffline(ctx context.Context, threshold time.Duration) (int, error) {
|
||||
if m.err != nil {
|
||||
return 0, m.err
|
||||
}
|
||||
count := 0
|
||||
for _, w := range m.workers {
|
||||
if w.Status != domain.WorkerStatusOffline && time.Since(w.LastHeartbeat) > threshold {
|
||||
w.Status = domain.WorkerStatusOffline
|
||||
w.CurrentTask = ""
|
||||
count++
|
||||
}
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
@ -4,49 +4,19 @@ package service
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"regexp"
|
||||
"time"
|
||||
|
||||
"github.com/orchard9/rdev/internal/domain"
|
||||
"github.com/orchard9/rdev/internal/port"
|
||||
)
|
||||
|
||||
// projectNameRegex validates project names for DNS and K8s compatibility.
|
||||
// Must be lowercase, start with a letter, contain only letters, numbers, and dashes.
|
||||
var projectNameRegex = regexp.MustCompile(`^[a-z][a-z0-9-]*$`)
|
||||
|
||||
// reservedProjectNames are names that cannot be used for projects.
|
||||
var reservedProjectNames = map[string]bool{
|
||||
"www": true,
|
||||
"api": true,
|
||||
"git": true,
|
||||
"ci": true,
|
||||
"registry": true,
|
||||
"admin": true,
|
||||
"root": true,
|
||||
"rdev": true,
|
||||
"pantheon": true,
|
||||
}
|
||||
|
||||
// ValidateProjectName validates that a project name is safe for use as
|
||||
// a DNS subdomain, K8s resource name, and git repository name.
|
||||
// Delegates to domain.ValidateProjectName for centralized validation.
|
||||
func ValidateProjectName(name string) error {
|
||||
if name == "" {
|
||||
return errors.New("project name cannot be empty")
|
||||
}
|
||||
if len(name) > 63 {
|
||||
return errors.New("project name too long (max 63 characters)")
|
||||
}
|
||||
if !projectNameRegex.MatchString(name) {
|
||||
return errors.New("project name must be lowercase, start with a letter, and contain only letters, numbers, and dashes")
|
||||
}
|
||||
if reservedProjectNames[name] {
|
||||
return fmt.Errorf("'%s' is a reserved name", name)
|
||||
}
|
||||
return nil
|
||||
return domain.ValidateProjectName(name)
|
||||
}
|
||||
|
||||
// ProjectInfraService orchestrates project infrastructure operations.
|
||||
@ -136,7 +106,7 @@ type CreateProjectResult struct {
|
||||
func (s *ProjectInfraService) CreateProject(ctx context.Context, req CreateProjectRequest) (*CreateProjectResult, error) {
|
||||
// Validate project name first
|
||||
if err := ValidateProjectName(req.Name); err != nil {
|
||||
return nil, fmt.Errorf("%w: %v", domain.ErrInvalidProjectName, err)
|
||||
return nil, fmt.Errorf("%w: %w", domain.ErrInvalidProjectName, err)
|
||||
}
|
||||
|
||||
s.logger.Info("creating project", "name", req.Name)
|
||||
|
||||
@ -142,12 +142,12 @@ func (s *ProjectService) ExecuteClaude(ctx context.Context, req ExecuteClaudeReq
|
||||
return nil, fmt.Errorf("%w: prompt is required", domain.ErrInvalidCommand)
|
||||
}
|
||||
if err := sanitize.ClaudePrompt(req.Prompt); err != nil {
|
||||
return nil, fmt.Errorf("%w: %v", domain.ErrCommandSanitization, err)
|
||||
return nil, fmt.Errorf("%w: %w", domain.ErrCommandSanitization, err)
|
||||
}
|
||||
|
||||
// Validate stream ID
|
||||
if err := sanitize.StreamID(req.StreamID); err != nil {
|
||||
return nil, fmt.Errorf("%w: %v", domain.ErrInvalidCommand, err)
|
||||
return nil, fmt.Errorf("%w: %w", domain.ErrInvalidCommand, err)
|
||||
}
|
||||
|
||||
// Generate command ID
|
||||
|
||||
@ -65,6 +65,10 @@ func (s *ProjectService) executeAgentCommand(agent port.CodeAgent, req *domain.A
|
||||
"tool": event.ToolName,
|
||||
"input": event.ToolInput,
|
||||
}
|
||||
// Record tool use metric
|
||||
if event.ToolName != "" {
|
||||
metrics.RecordAgentToolUse(string(agent.Provider()), event.ToolName)
|
||||
}
|
||||
case domain.AgentEventToolResult:
|
||||
eventType = "tool_result"
|
||||
data = map[string]any{
|
||||
@ -112,6 +116,9 @@ func (s *ProjectService) executeAgentCommand(agent port.CodeAgent, req *domain.A
|
||||
}
|
||||
metrics.RecordCommand(string(cmd.ProjectID), string(cmd.Type), status, result.DurationMs)
|
||||
|
||||
// Record agent-specific metrics
|
||||
metrics.RecordAgentRequest(string(agent.Provider()), status, result.DurationMs)
|
||||
|
||||
// Log audit completion if audit logger is configured
|
||||
if s.auditLogger != nil {
|
||||
var auditStatus domain.AuditStatus
|
||||
|
||||
@ -40,12 +40,12 @@ func (s *ProjectService) ExecuteShell(ctx context.Context, req ExecuteShellReque
|
||||
return nil, fmt.Errorf("%w: command is required", domain.ErrInvalidCommand)
|
||||
}
|
||||
if err := sanitize.ShellCommand(req.Command); err != nil {
|
||||
return nil, fmt.Errorf("%w: %v", domain.ErrCommandSanitization, err)
|
||||
return nil, fmt.Errorf("%w: %w", domain.ErrCommandSanitization, err)
|
||||
}
|
||||
|
||||
// Validate stream ID
|
||||
if err := sanitize.StreamID(req.StreamID); err != nil {
|
||||
return nil, fmt.Errorf("%w: %v", domain.ErrInvalidCommand, err)
|
||||
return nil, fmt.Errorf("%w: %w", domain.ErrInvalidCommand, err)
|
||||
}
|
||||
|
||||
// Generate command ID
|
||||
@ -120,12 +120,12 @@ func (s *ProjectService) ExecuteGit(ctx context.Context, req ExecuteGitRequest)
|
||||
return nil, fmt.Errorf("%w: args is required", domain.ErrInvalidCommand)
|
||||
}
|
||||
if err := sanitize.GitArgs(req.Args); err != nil {
|
||||
return nil, fmt.Errorf("%w: %v", domain.ErrCommandSanitization, err)
|
||||
return nil, fmt.Errorf("%w: %w", domain.ErrCommandSanitization, err)
|
||||
}
|
||||
|
||||
// Validate stream ID
|
||||
if err := sanitize.StreamID(req.StreamID); err != nil {
|
||||
return nil, fmt.Errorf("%w: %v", domain.ErrInvalidCommand, err)
|
||||
return nil, fmt.Errorf("%w: %w", domain.ErrInvalidCommand, err)
|
||||
}
|
||||
|
||||
// Generate command ID
|
||||
|
||||
@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
|
||||
"github.com/orchard9/rdev/internal/domain"
|
||||
"github.com/orchard9/rdev/internal/port"
|
||||
"github.com/orchard9/rdev/internal/webhook"
|
||||
)
|
||||
@ -57,7 +58,7 @@ func (s *WorkService) EnqueueTask(ctx context.Context, req EnqueueTaskRequest) (
|
||||
maxRetries = 3
|
||||
}
|
||||
|
||||
task := &port.WorkTask{
|
||||
task := &domain.WorkTask{
|
||||
ProjectID: req.ProjectID,
|
||||
Type: req.Type,
|
||||
Spec: req.Spec,
|
||||
@ -85,7 +86,7 @@ func (s *WorkService) EnqueueTask(ctx context.Context, req EnqueueTaskRequest) (
|
||||
}
|
||||
|
||||
// DequeueTask claims the next available task for a worker.
|
||||
func (s *WorkService) DequeueTask(ctx context.Context, workerID string) (*port.WorkTask, error) {
|
||||
func (s *WorkService) DequeueTask(ctx context.Context, workerID string) (*domain.WorkTask, error) {
|
||||
if workerID == "" {
|
||||
return nil, fmt.Errorf("worker_id is required")
|
||||
}
|
||||
@ -108,7 +109,7 @@ func (s *WorkService) DequeueTask(ctx context.Context, workerID string) (*port.W
|
||||
}
|
||||
|
||||
// CompleteTask marks a task as successfully completed.
|
||||
func (s *WorkService) CompleteTask(ctx context.Context, taskID string, result *port.WorkResult) error {
|
||||
func (s *WorkService) CompleteTask(ctx context.Context, taskID string, result *domain.WorkResult) error {
|
||||
// Get task for callback URL before completing
|
||||
task, err := s.queue.GetTask(ctx, taskID)
|
||||
if err != nil {
|
||||
@ -147,7 +148,7 @@ func (s *WorkService) FailTask(ctx context.Context, taskID string, errMsg string
|
||||
|
||||
// Check if it was requeued or permanently failed
|
||||
updatedTask, _ := s.queue.GetTask(ctx, taskID)
|
||||
if updatedTask != nil && updatedTask.Status == port.WorkTaskStatusFailed {
|
||||
if updatedTask != nil && updatedTask.Status == domain.WorkTaskStatusFailed {
|
||||
s.logger.Warn("task failed permanently",
|
||||
"task_id", taskID,
|
||||
"project", task.ProjectID,
|
||||
@ -199,22 +200,22 @@ func (s *WorkService) CancelTask(ctx context.Context, taskID string) error {
|
||||
}
|
||||
|
||||
// GetTask retrieves a task by ID.
|
||||
func (s *WorkService) GetTask(ctx context.Context, taskID string) (*port.WorkTask, error) {
|
||||
func (s *WorkService) GetTask(ctx context.Context, taskID string) (*domain.WorkTask, error) {
|
||||
return s.queue.GetTask(ctx, taskID)
|
||||
}
|
||||
|
||||
// ListByProject returns tasks for a project with pagination.
|
||||
func (s *WorkService) ListByProject(ctx context.Context, projectID string, status *port.WorkTaskStatus, opts port.WorkListOptions) (*port.WorkListResult, error) {
|
||||
func (s *WorkService) ListByProject(ctx context.Context, projectID string, status *domain.WorkTaskStatus, opts domain.WorkListOptions) (*domain.WorkListResult, error) {
|
||||
return s.queue.ListByProject(ctx, projectID, status, opts)
|
||||
}
|
||||
|
||||
// GetStats returns queue statistics.
|
||||
func (s *WorkService) GetStats(ctx context.Context) (*port.WorkQueueStats, error) {
|
||||
func (s *WorkService) GetStats(ctx context.Context) (*domain.WorkQueueStats, error) {
|
||||
return s.queue.GetStats(ctx)
|
||||
}
|
||||
|
||||
// notifyCallback sends a webhook notification for task status changes.
|
||||
func (s *WorkService) notifyCallback(task *port.WorkTask, status string, result *port.WorkResult, errMsg string) {
|
||||
func (s *WorkService) notifyCallback(task *domain.WorkTask, status string, result *domain.WorkResult, errMsg string) {
|
||||
if s.webhookDispatcher == nil || task.CallbackURL == "" {
|
||||
return
|
||||
}
|
||||
@ -251,7 +252,7 @@ type EnqueueTaskRequest struct {
|
||||
ProjectID string `json:"project_id"`
|
||||
|
||||
// Type is the task type (build, test, deploy, custom).
|
||||
Type port.WorkTaskType `json:"task_type"`
|
||||
Type domain.WorkTaskType `json:"task_type"`
|
||||
|
||||
// Spec contains task-specific parameters.
|
||||
Spec map[string]any `json:"task_spec"`
|
||||
|
||||
215
internal/service/work_service_test.go
Normal file
215
internal/service/work_service_test.go
Normal file
@ -0,0 +1,215 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/orchard9/rdev/internal/domain"
|
||||
)
|
||||
|
||||
func newTestWorkService() (*WorkService, *mockWorkQueue) {
|
||||
q := newMockWorkQueue()
|
||||
svc := NewWorkService(q, WorkServiceConfig{})
|
||||
return svc, q
|
||||
}
|
||||
|
||||
func TestWorkService_EnqueueTask(t *testing.T) {
|
||||
t.Run("success", func(t *testing.T) {
|
||||
svc, q := newTestWorkService()
|
||||
|
||||
result, err := svc.EnqueueTask(context.Background(), EnqueueTaskRequest{
|
||||
ProjectID: "myapp",
|
||||
Type: domain.WorkTaskTypeBuild,
|
||||
Priority: 1,
|
||||
Spec: map[string]any{"branch": "main"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if result.TaskID == "" {
|
||||
t.Error("task ID should not be empty")
|
||||
}
|
||||
if result.StatusURL == "" {
|
||||
t.Error("status URL should not be empty")
|
||||
}
|
||||
if len(q.tasks) != 1 {
|
||||
t.Errorf("tasks in queue = %d, want 1", len(q.tasks))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("default max retries", func(t *testing.T) {
|
||||
svc, q := newTestWorkService()
|
||||
|
||||
_, err := svc.EnqueueTask(context.Background(), EnqueueTaskRequest{
|
||||
ProjectID: "myapp",
|
||||
Type: domain.WorkTaskTypeBuild,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
for _, task := range q.tasks {
|
||||
if task.MaxRetries != 3 {
|
||||
t.Errorf("max retries = %d, want 3 (default)", task.MaxRetries)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("missing project id", func(t *testing.T) {
|
||||
svc, _ := newTestWorkService()
|
||||
|
||||
_, err := svc.EnqueueTask(context.Background(), EnqueueTaskRequest{
|
||||
Type: domain.WorkTaskTypeBuild,
|
||||
})
|
||||
if err == nil {
|
||||
t.Error("expected error for missing project_id")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("missing type", func(t *testing.T) {
|
||||
svc, _ := newTestWorkService()
|
||||
|
||||
_, err := svc.EnqueueTask(context.Background(), EnqueueTaskRequest{
|
||||
ProjectID: "myapp",
|
||||
})
|
||||
if err == nil {
|
||||
t.Error("expected error for missing type")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestWorkService_DequeueTask(t *testing.T) {
|
||||
t.Run("success", func(t *testing.T) {
|
||||
svc, q := newTestWorkService()
|
||||
|
||||
// Enqueue a task first
|
||||
q.tasks["task-1"] = &domain.WorkTask{
|
||||
ID: "task-1",
|
||||
ProjectID: "myapp",
|
||||
Type: domain.WorkTaskTypeBuild,
|
||||
Status: domain.WorkTaskStatusPending,
|
||||
}
|
||||
|
||||
task, err := svc.DequeueTask(context.Background(), "worker-1")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if task == nil {
|
||||
t.Fatal("expected a task, got nil")
|
||||
}
|
||||
if task.ID != "task-1" {
|
||||
t.Errorf("task ID = %q, want %q", task.ID, "task-1")
|
||||
}
|
||||
if task.Status != domain.WorkTaskStatusRunning {
|
||||
t.Errorf("task status = %q, want %q", task.Status, domain.WorkTaskStatusRunning)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("empty queue", func(t *testing.T) {
|
||||
svc, _ := newTestWorkService()
|
||||
|
||||
task, err := svc.DequeueTask(context.Background(), "worker-1")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if task != nil {
|
||||
t.Error("expected nil task for empty queue")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("missing worker id", func(t *testing.T) {
|
||||
svc, _ := newTestWorkService()
|
||||
|
||||
_, err := svc.DequeueTask(context.Background(), "")
|
||||
if err == nil {
|
||||
t.Error("expected error for missing worker_id")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestWorkService_CompleteTask(t *testing.T) {
|
||||
t.Run("success", func(t *testing.T) {
|
||||
svc, q := newTestWorkService()
|
||||
|
||||
q.tasks["task-1"] = &domain.WorkTask{
|
||||
ID: "task-1",
|
||||
ProjectID: "myapp",
|
||||
Type: domain.WorkTaskTypeBuild,
|
||||
Status: domain.WorkTaskStatusRunning,
|
||||
}
|
||||
|
||||
result := &domain.WorkResult{Output: "ok"}
|
||||
err := svc.CompleteTask(context.Background(), "task-1", result)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
task := q.tasks["task-1"]
|
||||
if task.Status != domain.WorkTaskStatusCompleted {
|
||||
t.Errorf("status = %q, want %q", task.Status, domain.WorkTaskStatusCompleted)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("task not found", func(t *testing.T) {
|
||||
svc, _ := newTestWorkService()
|
||||
|
||||
err := svc.CompleteTask(context.Background(), "nonexistent", nil)
|
||||
if err == nil {
|
||||
t.Error("expected error for nonexistent task")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestWorkService_GetTask(t *testing.T) {
|
||||
t.Run("found", func(t *testing.T) {
|
||||
svc, q := newTestWorkService()
|
||||
|
||||
q.tasks["task-1"] = &domain.WorkTask{
|
||||
ID: "task-1",
|
||||
ProjectID: "myapp",
|
||||
}
|
||||
|
||||
task, err := svc.GetTask(context.Background(), "task-1")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if task.ID != "task-1" {
|
||||
t.Errorf("task ID = %q, want %q", task.ID, "task-1")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("not found", func(t *testing.T) {
|
||||
svc, _ := newTestWorkService()
|
||||
|
||||
_, err := svc.GetTask(context.Background(), "missing")
|
||||
if err == nil {
|
||||
t.Error("expected error for missing task")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestWorkService_GetStats(t *testing.T) {
|
||||
svc, _ := newTestWorkService()
|
||||
|
||||
stats, err := svc.GetStats(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if stats == nil {
|
||||
t.Error("stats should not be nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWorkService_CancelTask(t *testing.T) {
|
||||
svc, q := newTestWorkService()
|
||||
|
||||
q.tasks["task-1"] = &domain.WorkTask{
|
||||
ID: "task-1",
|
||||
ProjectID: "myapp",
|
||||
Status: domain.WorkTaskStatusPending,
|
||||
}
|
||||
|
||||
err := svc.CancelTask(context.Background(), "task-1")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
231
internal/service/worker_service.go
Normal file
231
internal/service/worker_service.go
Normal file
@ -0,0 +1,231 @@
|
||||
// Package service provides business logic services.
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"github.com/orchard9/rdev/internal/domain"
|
||||
"github.com/orchard9/rdev/internal/port"
|
||||
)
|
||||
|
||||
const (
|
||||
// DefaultHeartbeatInterval is how often the health checker runs.
|
||||
DefaultHeartbeatInterval = 30 * time.Second
|
||||
|
||||
// DefaultStaleThreshold is how long since last heartbeat before marking offline.
|
||||
DefaultStaleThreshold = 90 * time.Second
|
||||
)
|
||||
|
||||
// WorkerService manages worker lifecycle and task assignment.
|
||||
// It coordinates between the worker registry (pool management) and
|
||||
// the work queue (task execution).
|
||||
type WorkerService struct {
|
||||
registry port.WorkerRegistry
|
||||
queue port.WorkQueue
|
||||
audit port.BuildAudit
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
// NewWorkerService creates a new worker service.
|
||||
func NewWorkerService(
|
||||
registry port.WorkerRegistry,
|
||||
queue port.WorkQueue,
|
||||
logger *slog.Logger,
|
||||
) *WorkerService {
|
||||
if logger == nil {
|
||||
logger = slog.Default()
|
||||
}
|
||||
return &WorkerService{
|
||||
registry: registry,
|
||||
queue: queue,
|
||||
logger: logger.With("service", "worker"),
|
||||
}
|
||||
}
|
||||
|
||||
// WithBuildAudit adds a build audit for recording task assignments.
|
||||
func (s *WorkerService) WithBuildAudit(audit port.BuildAudit) *WorkerService {
|
||||
s.audit = audit
|
||||
return s
|
||||
}
|
||||
|
||||
// Register adds a worker to the pool.
|
||||
func (s *WorkerService) Register(ctx context.Context, worker *domain.Worker) error {
|
||||
if err := worker.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
worker.RegisteredAt = time.Now()
|
||||
worker.LastHeartbeat = time.Now()
|
||||
worker.Status = domain.WorkerStatusIdle
|
||||
|
||||
if err := s.registry.Register(ctx, worker); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.logger.Info("worker registered",
|
||||
"worker_id", worker.ID,
|
||||
"hostname", worker.Hostname,
|
||||
"version", worker.Version,
|
||||
"capabilities", worker.Capabilities,
|
||||
)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Heartbeat updates worker liveness.
|
||||
func (s *WorkerService) Heartbeat(ctx context.Context, workerID string) error {
|
||||
return s.registry.Heartbeat(ctx, workerID)
|
||||
}
|
||||
|
||||
// Deregister removes a worker from the pool.
|
||||
func (s *WorkerService) Deregister(ctx context.Context, workerID string) error {
|
||||
if err := s.registry.Deregister(ctx, workerID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.logger.Info("worker deregistered", "worker_id", workerID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetWorker retrieves a specific worker.
|
||||
func (s *WorkerService) GetWorker(ctx context.Context, workerID string) (*domain.Worker, error) {
|
||||
return s.registry.Get(ctx, workerID)
|
||||
}
|
||||
|
||||
// ListWorkers returns all workers matching the optional filter.
|
||||
func (s *WorkerService) ListWorkers(ctx context.Context, filter port.WorkerFilter) ([]*domain.Worker, error) {
|
||||
return s.registry.List(ctx, filter)
|
||||
}
|
||||
|
||||
// ClaimTask atomically dequeues a task and marks worker as busy.
|
||||
func (s *WorkerService) ClaimTask(ctx context.Context, workerID string) (*domain.WorkTask, error) {
|
||||
task, err := s.queue.Dequeue(ctx, workerID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if task == nil {
|
||||
return nil, nil // No tasks available
|
||||
}
|
||||
|
||||
// Mark worker as busy with the claimed task
|
||||
if err := s.registry.UpdateStatus(ctx, workerID, domain.WorkerStatusBusy, task.ID); err != nil {
|
||||
s.logger.Warn("failed to update worker status after claim",
|
||||
"worker_id", workerID,
|
||||
"task_id", task.ID,
|
||||
"error", err,
|
||||
)
|
||||
}
|
||||
|
||||
// Update audit entry if available
|
||||
if s.audit != nil {
|
||||
entry, _ := s.audit.Get(ctx, task.ID)
|
||||
if entry != nil {
|
||||
entry.WorkerID = workerID
|
||||
entry.Status = domain.BuildStatusRunning
|
||||
}
|
||||
}
|
||||
|
||||
s.logger.Info("task claimed",
|
||||
"task_id", task.ID,
|
||||
"worker_id", workerID,
|
||||
"project_id", task.ProjectID,
|
||||
"type", task.Type,
|
||||
)
|
||||
|
||||
return task, nil
|
||||
}
|
||||
|
||||
// CompleteTask marks a task as complete and returns worker to idle.
|
||||
func (s *WorkerService) CompleteTask(ctx context.Context, workerID, taskID string, result *domain.BuildResult) error {
|
||||
if result == nil {
|
||||
result = &domain.BuildResult{}
|
||||
}
|
||||
|
||||
// Convert domain build result to work result
|
||||
bwr := result.ToWorkResult()
|
||||
workResult := &domain.WorkResult{
|
||||
Output: bwr.Output,
|
||||
Artifacts: bwr.Artifacts,
|
||||
}
|
||||
|
||||
// Update audit record (non-critical)
|
||||
if s.audit != nil {
|
||||
if err := s.audit.Update(ctx, taskID, result); err != nil {
|
||||
s.logger.Warn("failed to update audit",
|
||||
"task_id", taskID,
|
||||
"error", err,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Complete in queue
|
||||
if err := s.queue.Complete(ctx, taskID, workResult); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Return worker to idle
|
||||
if err := s.registry.UpdateStatus(ctx, workerID, domain.WorkerStatusIdle, ""); err != nil {
|
||||
s.logger.Warn("failed to return worker to idle",
|
||||
"worker_id", workerID,
|
||||
"error", err,
|
||||
)
|
||||
}
|
||||
|
||||
s.logger.Info("task completed",
|
||||
"task_id", taskID,
|
||||
"worker_id", workerID,
|
||||
"success", result.Success,
|
||||
"duration_ms", result.DurationMs,
|
||||
)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DrainWorker sets a worker to draining status so it finishes current work
|
||||
// but doesn't accept new tasks.
|
||||
func (s *WorkerService) DrainWorker(ctx context.Context, workerID string) error {
|
||||
worker, err := s.registry.Get(ctx, workerID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := s.registry.UpdateStatus(ctx, workerID, domain.WorkerStatusDraining, worker.CurrentTask); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.logger.Info("worker draining",
|
||||
"worker_id", workerID,
|
||||
"current_task", worker.CurrentTask,
|
||||
)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// StartHealthChecker runs a background goroutine that marks stale workers offline.
|
||||
// It returns when the context is cancelled.
|
||||
func (s *WorkerService) StartHealthChecker(ctx context.Context) {
|
||||
ticker := time.NewTicker(DefaultHeartbeatInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
s.logger.Info("worker health checker started",
|
||||
"interval", DefaultHeartbeatInterval,
|
||||
"stale_threshold", DefaultStaleThreshold,
|
||||
)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
s.logger.Info("worker health checker stopped")
|
||||
return
|
||||
case <-ticker.C:
|
||||
count, err := s.registry.MarkStaleOffline(ctx, DefaultStaleThreshold)
|
||||
if err != nil {
|
||||
s.logger.Error("failed to mark stale workers", "error", err)
|
||||
} else if count > 0 {
|
||||
s.logger.Warn("marked workers offline", "count", count)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
328
internal/service/worker_service_test.go
Normal file
328
internal/service/worker_service_test.go
Normal file
@ -0,0 +1,328 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/orchard9/rdev/internal/domain"
|
||||
"github.com/orchard9/rdev/internal/port"
|
||||
)
|
||||
|
||||
func TestWorkerService_Register(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("registers valid worker", func(t *testing.T) {
|
||||
registry := newMockWorkerRegistry()
|
||||
queue := newMockWorkQueue()
|
||||
svc := NewWorkerService(registry, queue, nil)
|
||||
|
||||
err := svc.Register(ctx, &domain.Worker{
|
||||
ID: "worker-1",
|
||||
Hostname: "host-1",
|
||||
Capabilities: []string{"build"},
|
||||
Version: "1.0.0",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Register() error = %v", err)
|
||||
}
|
||||
|
||||
w := registry.workers["worker-1"]
|
||||
if w == nil {
|
||||
t.Fatal("worker not found in registry")
|
||||
}
|
||||
if w.Status != domain.WorkerStatusIdle {
|
||||
t.Errorf("got status %q, want %q", w.Status, domain.WorkerStatusIdle)
|
||||
}
|
||||
if w.RegisteredAt.IsZero() {
|
||||
t.Error("expected registered_at to be set")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("validates worker", func(t *testing.T) {
|
||||
registry := newMockWorkerRegistry()
|
||||
queue := newMockWorkQueue()
|
||||
svc := NewWorkerService(registry, queue, nil)
|
||||
|
||||
err := svc.Register(ctx, &domain.Worker{})
|
||||
if err == nil {
|
||||
t.Error("expected validation error for empty worker")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestWorkerService_Heartbeat(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("updates heartbeat", func(t *testing.T) {
|
||||
registry := newMockWorkerRegistry()
|
||||
registry.workers["worker-1"] = &domain.Worker{
|
||||
ID: "worker-1",
|
||||
Hostname: "host-1",
|
||||
Status: domain.WorkerStatusIdle,
|
||||
LastHeartbeat: time.Now().Add(-30 * time.Second),
|
||||
}
|
||||
|
||||
svc := NewWorkerService(registry, newMockWorkQueue(), nil)
|
||||
|
||||
err := svc.Heartbeat(ctx, "worker-1")
|
||||
if err != nil {
|
||||
t.Fatalf("Heartbeat() error = %v", err)
|
||||
}
|
||||
|
||||
// Heartbeat should be recent
|
||||
w := registry.workers["worker-1"]
|
||||
if time.Since(w.LastHeartbeat) > time.Second {
|
||||
t.Error("expected heartbeat to be updated to now")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns error for nonexistent worker", func(t *testing.T) {
|
||||
registry := newMockWorkerRegistry()
|
||||
svc := NewWorkerService(registry, newMockWorkQueue(), nil)
|
||||
|
||||
err := svc.Heartbeat(ctx, "nonexistent")
|
||||
if err == nil {
|
||||
t.Error("expected error for nonexistent worker")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestWorkerService_Deregister(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("deregisters worker", func(t *testing.T) {
|
||||
registry := newMockWorkerRegistry()
|
||||
registry.workers["worker-1"] = &domain.Worker{
|
||||
ID: "worker-1",
|
||||
Hostname: "host-1",
|
||||
Status: domain.WorkerStatusIdle,
|
||||
}
|
||||
|
||||
svc := NewWorkerService(registry, newMockWorkQueue(), nil)
|
||||
|
||||
err := svc.Deregister(ctx, "worker-1")
|
||||
if err != nil {
|
||||
t.Fatalf("Deregister() error = %v", err)
|
||||
}
|
||||
|
||||
if _, ok := registry.workers["worker-1"]; ok {
|
||||
t.Error("worker should be removed from registry")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestWorkerService_ClaimTask(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("claims task and marks worker busy", func(t *testing.T) {
|
||||
registry := newMockWorkerRegistry()
|
||||
registry.workers["worker-1"] = &domain.Worker{
|
||||
ID: "worker-1",
|
||||
Hostname: "host-1",
|
||||
Status: domain.WorkerStatusIdle,
|
||||
}
|
||||
|
||||
queue := newMockWorkQueue()
|
||||
queue.tasks["task-1"] = &domain.WorkTask{
|
||||
ID: "task-1",
|
||||
ProjectID: "project-1",
|
||||
Type: domain.WorkTaskTypeBuild,
|
||||
Status: domain.WorkTaskStatusPending,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
svc := NewWorkerService(registry, queue, nil)
|
||||
|
||||
task, err := svc.ClaimTask(ctx, "worker-1")
|
||||
if err != nil {
|
||||
t.Fatalf("ClaimTask() error = %v", err)
|
||||
}
|
||||
if task == nil {
|
||||
t.Fatal("expected task to be returned")
|
||||
}
|
||||
if task.ID != "task-1" {
|
||||
t.Errorf("got task ID %q, want %q", task.ID, "task-1")
|
||||
}
|
||||
|
||||
// Worker should be busy with the task
|
||||
w := registry.workers["worker-1"]
|
||||
if w.Status != domain.WorkerStatusBusy {
|
||||
t.Errorf("got status %q, want %q", w.Status, domain.WorkerStatusBusy)
|
||||
}
|
||||
if w.CurrentTask != "task-1" {
|
||||
t.Errorf("got current_task %q, want %q", w.CurrentTask, "task-1")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns nil when no tasks available", func(t *testing.T) {
|
||||
registry := newMockWorkerRegistry()
|
||||
registry.workers["worker-1"] = &domain.Worker{
|
||||
ID: "worker-1",
|
||||
Hostname: "host-1",
|
||||
Status: domain.WorkerStatusIdle,
|
||||
}
|
||||
|
||||
queue := newMockWorkQueue()
|
||||
svc := NewWorkerService(registry, queue, nil)
|
||||
|
||||
task, err := svc.ClaimTask(ctx, "worker-1")
|
||||
if err != nil {
|
||||
t.Fatalf("ClaimTask() error = %v", err)
|
||||
}
|
||||
if task != nil {
|
||||
t.Error("expected nil task when queue is empty")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestWorkerService_CompleteTask(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("completes task and returns worker to idle", func(t *testing.T) {
|
||||
registry := newMockWorkerRegistry()
|
||||
registry.workers["worker-1"] = &domain.Worker{
|
||||
ID: "worker-1",
|
||||
Hostname: "host-1",
|
||||
Status: domain.WorkerStatusBusy,
|
||||
CurrentTask: "task-1",
|
||||
}
|
||||
|
||||
queue := newMockWorkQueue()
|
||||
queue.tasks["task-1"] = &domain.WorkTask{
|
||||
ID: "task-1",
|
||||
ProjectID: "project-1",
|
||||
Type: domain.WorkTaskTypeBuild,
|
||||
Status: domain.WorkTaskStatusRunning,
|
||||
WorkerID: "worker-1",
|
||||
}
|
||||
|
||||
svc := NewWorkerService(registry, queue, nil)
|
||||
|
||||
err := svc.CompleteTask(ctx, "worker-1", "task-1", &domain.BuildResult{
|
||||
Success: true,
|
||||
CommitSHA: "abc123",
|
||||
DurationMs: 5000,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("CompleteTask() error = %v", err)
|
||||
}
|
||||
|
||||
// Task should be completed
|
||||
task := queue.tasks["task-1"]
|
||||
if task.Status != domain.WorkTaskStatusCompleted {
|
||||
t.Errorf("got task status %q, want %q", task.Status, domain.WorkTaskStatusCompleted)
|
||||
}
|
||||
|
||||
// Worker should be idle
|
||||
w := registry.workers["worker-1"]
|
||||
if w.Status != domain.WorkerStatusIdle {
|
||||
t.Errorf("got worker status %q, want %q", w.Status, domain.WorkerStatusIdle)
|
||||
}
|
||||
if w.CurrentTask != "" {
|
||||
t.Errorf("got current_task %q, want empty", w.CurrentTask)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("handles nil result", func(t *testing.T) {
|
||||
registry := newMockWorkerRegistry()
|
||||
registry.workers["worker-1"] = &domain.Worker{
|
||||
ID: "worker-1",
|
||||
Hostname: "host-1",
|
||||
Status: domain.WorkerStatusBusy,
|
||||
CurrentTask: "task-1",
|
||||
}
|
||||
|
||||
queue := newMockWorkQueue()
|
||||
queue.tasks["task-1"] = &domain.WorkTask{
|
||||
ID: "task-1",
|
||||
Status: domain.WorkTaskStatusRunning,
|
||||
WorkerID: "worker-1",
|
||||
}
|
||||
|
||||
svc := NewWorkerService(registry, queue, nil)
|
||||
|
||||
err := svc.CompleteTask(ctx, "worker-1", "task-1", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("CompleteTask(nil result) error = %v", err)
|
||||
}
|
||||
|
||||
// Worker should be idle
|
||||
w := registry.workers["worker-1"]
|
||||
if w.Status != domain.WorkerStatusIdle {
|
||||
t.Errorf("got worker status %q, want %q", w.Status, domain.WorkerStatusIdle)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestWorkerService_ListWorkers(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
registry := newMockWorkerRegistry()
|
||||
registry.workers["worker-1"] = &domain.Worker{ID: "worker-1", Status: domain.WorkerStatusIdle}
|
||||
registry.workers["worker-2"] = &domain.Worker{ID: "worker-2", Status: domain.WorkerStatusBusy}
|
||||
registry.workers["worker-3"] = &domain.Worker{ID: "worker-3", Status: domain.WorkerStatusIdle}
|
||||
|
||||
svc := NewWorkerService(registry, newMockWorkQueue(), nil)
|
||||
|
||||
t.Run("lists all workers", func(t *testing.T) {
|
||||
workers, err := svc.ListWorkers(ctx, port.WorkerFilter{})
|
||||
if err != nil {
|
||||
t.Fatalf("ListWorkers() error = %v", err)
|
||||
}
|
||||
if len(workers) != 3 {
|
||||
t.Errorf("got %d workers, want 3", len(workers))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("filters by status", func(t *testing.T) {
|
||||
idle := domain.WorkerStatusIdle
|
||||
workers, err := svc.ListWorkers(ctx, port.WorkerFilter{Status: &idle})
|
||||
if err != nil {
|
||||
t.Fatalf("ListWorkers() error = %v", err)
|
||||
}
|
||||
if len(workers) != 2 {
|
||||
t.Errorf("got %d idle workers, want 2", len(workers))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestWorkerService_DrainWorker(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("drains worker", func(t *testing.T) {
|
||||
registry := newMockWorkerRegistry()
|
||||
registry.workers["worker-1"] = &domain.Worker{
|
||||
ID: "worker-1",
|
||||
Hostname: "host-1",
|
||||
Status: domain.WorkerStatusBusy,
|
||||
CurrentTask: "task-1",
|
||||
}
|
||||
|
||||
svc := NewWorkerService(registry, newMockWorkQueue(), nil)
|
||||
|
||||
err := svc.DrainWorker(ctx, "worker-1")
|
||||
if err != nil {
|
||||
t.Fatalf("DrainWorker() error = %v", err)
|
||||
}
|
||||
|
||||
w := registry.workers["worker-1"]
|
||||
if w.Status != domain.WorkerStatusDraining {
|
||||
t.Errorf("got status %q, want %q", w.Status, domain.WorkerStatusDraining)
|
||||
}
|
||||
// Should preserve current task
|
||||
if w.CurrentTask != "task-1" {
|
||||
t.Errorf("got current_task %q, want %q", w.CurrentTask, "task-1")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns error for nonexistent worker", func(t *testing.T) {
|
||||
registry := newMockWorkerRegistry()
|
||||
svc := NewWorkerService(registry, newMockWorkQueue(), nil)
|
||||
|
||||
err := svc.DrainWorker(ctx, "nonexistent")
|
||||
if err == nil {
|
||||
t.Error("expected error for nonexistent worker")
|
||||
}
|
||||
})
|
||||
}
|
||||
@ -13,7 +13,7 @@ package telemetry
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"strings"
|
||||
@ -111,7 +111,7 @@ func New(ctx context.Context, cfg Config) (*Telemetry, error) {
|
||||
|
||||
exporter, err := otlptracegrpc.New(ctx, opts...)
|
||||
if err != nil {
|
||||
return nil, errors.New("failed to create OTLP exporter: " + err.Error())
|
||||
return nil, fmt.Errorf("failed to create OTLP exporter: %w", err)
|
||||
}
|
||||
|
||||
// Create resource with service information
|
||||
@ -179,7 +179,7 @@ func (t *Telemetry) Shutdown(ctx context.Context) error {
|
||||
}
|
||||
|
||||
if err := t.tracerProvider.Shutdown(ctx); err != nil {
|
||||
return errors.New("telemetry shutdown failed: " + err.Error())
|
||||
return fmt.Errorf("telemetry shutdown failed: %w", err)
|
||||
}
|
||||
|
||||
t.logger.Info("telemetry shutdown complete")
|
||||
|
||||
@ -5,6 +5,7 @@ package validate
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
@ -239,6 +240,22 @@ var (
|
||||
|
||||
// --- Convenience validators for common patterns ---
|
||||
|
||||
// HTTPURL validates that a string is a valid HTTP or HTTPS URL.
|
||||
// Returns nil for empty strings (use Required for that check).
|
||||
func HTTPURL(value, field string) error {
|
||||
if value == "" {
|
||||
return nil
|
||||
}
|
||||
parsed, err := url.Parse(value)
|
||||
if err != nil || (parsed.Scheme != "http" && parsed.Scheme != "https") || parsed.Host == "" {
|
||||
return ValidationError{
|
||||
Field: field,
|
||||
Message: "must be a valid HTTP or HTTPS URL",
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Name validates a name field (alphanumeric with dashes/underscores, 1-64 chars).
|
||||
// This matches the existing isValidName pattern in claude_config.go.
|
||||
func Name(value, field string) error {
|
||||
|
||||
196
internal/worker/build_executor.go
Normal file
196
internal/worker/build_executor.go
Normal file
@ -0,0 +1,196 @@
|
||||
package worker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/orchard9/rdev/internal/domain"
|
||||
"github.com/orchard9/rdev/internal/port"
|
||||
)
|
||||
|
||||
// BuildExecutor handles WorkTaskTypeBuild tasks.
|
||||
// It translates BuildSpec fields from the work task's Spec map into an
|
||||
// AgentRequest, executes via a CodeAgent, and returns a BuildResult.
|
||||
type BuildExecutor struct {
|
||||
agentRegistry port.CodeAgentRegistry
|
||||
gitOps *GitOperations
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
// NewBuildExecutor creates a new build executor.
|
||||
func NewBuildExecutor(
|
||||
agentRegistry port.CodeAgentRegistry,
|
||||
gitOps *GitOperations,
|
||||
logger *slog.Logger,
|
||||
) *BuildExecutor {
|
||||
if logger == nil {
|
||||
logger = slog.Default()
|
||||
}
|
||||
return &BuildExecutor{
|
||||
agentRegistry: agentRegistry,
|
||||
gitOps: gitOps,
|
||||
logger: logger.With("component", "build-executor"),
|
||||
}
|
||||
}
|
||||
|
||||
// Execute runs a build task by translating its spec into an agent call.
|
||||
func (b *BuildExecutor) Execute(ctx context.Context, task *domain.WorkTask) *domain.BuildResult {
|
||||
start := time.Now()
|
||||
|
||||
spec, err := b.parseSpec(task.Spec)
|
||||
if err != nil {
|
||||
return &domain.BuildResult{
|
||||
Success: false,
|
||||
Error: fmt.Sprintf("invalid build spec: %v", err),
|
||||
DurationMs: time.Since(start).Milliseconds(),
|
||||
}
|
||||
}
|
||||
|
||||
// Determine working directory
|
||||
workDir := "/workspace"
|
||||
|
||||
// Clone repo if git URL is provided in the spec
|
||||
gitURL, _ := task.Spec["git_url"].(string)
|
||||
if gitURL != "" && b.gitOps != nil {
|
||||
cloneDir, cleanup, err := b.gitOps.CloneToTemp(ctx, gitURL)
|
||||
if err != nil {
|
||||
return &domain.BuildResult{
|
||||
Success: false,
|
||||
Error: fmt.Sprintf("git clone failed: %v", err),
|
||||
DurationMs: time.Since(start).Milliseconds(),
|
||||
}
|
||||
}
|
||||
defer cleanup()
|
||||
workDir = cloneDir
|
||||
}
|
||||
|
||||
// Get a code agent
|
||||
agent := b.agentRegistry.Default()
|
||||
if agent == nil {
|
||||
return &domain.BuildResult{
|
||||
Success: false,
|
||||
Error: "no code agent available",
|
||||
DurationMs: time.Since(start).Milliseconds(),
|
||||
}
|
||||
}
|
||||
|
||||
// Build the agent request
|
||||
agentReq := &domain.AgentRequest{
|
||||
Prompt: spec.Prompt,
|
||||
ProjectID: domain.ProjectID(task.ProjectID),
|
||||
WorkingDir: workDir,
|
||||
Timeout: 10 * time.Minute,
|
||||
}
|
||||
|
||||
// Collect output with a size cap to prevent OOM on verbose builds.
|
||||
const maxOutputSize = 1 << 20 // 1MB
|
||||
var outputBuilder strings.Builder
|
||||
|
||||
b.logger.Info("executing build via agent",
|
||||
"task_id", task.ID,
|
||||
"project_id", task.ProjectID,
|
||||
"agent", agent.Name(),
|
||||
"work_dir", workDir,
|
||||
)
|
||||
|
||||
// Execute the agent
|
||||
agentResult, err := agent.Execute(ctx, agentReq, func(event domain.AgentEvent) {
|
||||
if event.Type == domain.AgentEventOutput || event.Type == domain.AgentEventError {
|
||||
if outputBuilder.Len() >= maxOutputSize {
|
||||
return // Output cap reached, discard further output
|
||||
}
|
||||
if outputBuilder.Len() > 0 {
|
||||
outputBuilder.WriteString("\n")
|
||||
}
|
||||
remaining := maxOutputSize - outputBuilder.Len()
|
||||
if len(event.Content) > remaining {
|
||||
outputBuilder.WriteString(event.Content[:remaining])
|
||||
outputBuilder.WriteString("\n... [output truncated at 1MB]")
|
||||
} else {
|
||||
outputBuilder.WriteString(event.Content)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return &domain.BuildResult{
|
||||
Success: false,
|
||||
Error: fmt.Sprintf("agent execution failed: %v", err),
|
||||
Output: outputBuilder.String(),
|
||||
DurationMs: time.Since(start).Milliseconds(),
|
||||
}
|
||||
}
|
||||
|
||||
result := &domain.BuildResult{
|
||||
Success: agentResult.Success(),
|
||||
Output: outputBuilder.String(),
|
||||
DurationMs: time.Since(start).Milliseconds(),
|
||||
}
|
||||
|
||||
if !agentResult.Success() {
|
||||
errMsg := "agent returned non-zero exit code"
|
||||
if agentResult.Error != nil {
|
||||
errMsg = agentResult.Error.Error()
|
||||
}
|
||||
result.Error = errMsg
|
||||
}
|
||||
|
||||
// Handle git commit/push if requested
|
||||
if result.Success && b.gitOps != nil && gitURL != "" {
|
||||
if spec.AutoCommit {
|
||||
commitMsg := fmt.Sprintf("build: %s", truncate(spec.Prompt, 72))
|
||||
sha, filesChanged, err := b.gitOps.CommitAndPush(ctx, workDir, commitMsg, spec.AutoPush)
|
||||
if err != nil {
|
||||
b.logger.Warn("git commit/push failed",
|
||||
"task_id", task.ID,
|
||||
"error", err,
|
||||
)
|
||||
result.Success = false
|
||||
result.Error = fmt.Sprintf("build succeeded but git operations failed: %v", err)
|
||||
} else {
|
||||
result.CommitSHA = sha
|
||||
result.FilesChanged = filesChanged
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// parsedBuildSpec holds typed fields extracted from the task spec map.
|
||||
type parsedBuildSpec struct {
|
||||
Prompt string
|
||||
AutoCommit bool
|
||||
AutoPush bool
|
||||
}
|
||||
|
||||
// parseSpec extracts typed BuildSpec fields from the generic map[string]any.
|
||||
func (b *BuildExecutor) parseSpec(spec map[string]any) (*parsedBuildSpec, error) {
|
||||
prompt, _ := spec["prompt"].(string)
|
||||
if prompt == "" {
|
||||
return nil, fmt.Errorf("prompt is required")
|
||||
}
|
||||
|
||||
autoCommit, _ := spec["auto_commit"].(bool)
|
||||
autoPush, _ := spec["auto_push"].(bool)
|
||||
|
||||
return &parsedBuildSpec{
|
||||
Prompt: prompt,
|
||||
AutoCommit: autoCommit,
|
||||
AutoPush: autoPush,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// truncate shortens a string to maxLen, adding "..." if truncated.
|
||||
func truncate(s string, maxLen int) string {
|
||||
if len(s) <= maxLen {
|
||||
return s
|
||||
}
|
||||
if maxLen <= 3 {
|
||||
return s[:maxLen]
|
||||
}
|
||||
return s[:maxLen-3] + "..."
|
||||
}
|
||||
233
internal/worker/git_operations.go
Normal file
233
internal/worker/git_operations.go
Normal file
@ -0,0 +1,233 @@
|
||||
package worker
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// GitOperations provides git clone, commit, and push functionality
|
||||
// for the build executor. It uses os/exec to run git commands.
|
||||
type GitOperations struct {
|
||||
giteaToken string
|
||||
gitUser string
|
||||
gitEmail string
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
// GitOperationsConfig configures git operations.
|
||||
type GitOperationsConfig struct {
|
||||
// GiteaToken is the token for HTTPS clone/push authentication.
|
||||
GiteaToken string
|
||||
|
||||
// GitUser is the git commit author name.
|
||||
GitUser string
|
||||
|
||||
// GitEmail is the git commit author email.
|
||||
GitEmail string
|
||||
|
||||
Logger *slog.Logger
|
||||
}
|
||||
|
||||
// NewGitOperations creates a new git operations helper.
|
||||
func NewGitOperations(cfg GitOperationsConfig) *GitOperations {
|
||||
if cfg.GitUser == "" {
|
||||
cfg.GitUser = "rdev-worker"
|
||||
}
|
||||
if cfg.GitEmail == "" {
|
||||
cfg.GitEmail = "worker@threesix.ai"
|
||||
}
|
||||
if cfg.Logger == nil {
|
||||
cfg.Logger = slog.Default()
|
||||
}
|
||||
return &GitOperations{
|
||||
giteaToken: cfg.GiteaToken,
|
||||
gitUser: cfg.GitUser,
|
||||
gitEmail: cfg.GitEmail,
|
||||
logger: cfg.Logger.With("component", "git-ops"),
|
||||
}
|
||||
}
|
||||
|
||||
// CloneToTemp clones a repository to a temporary directory.
|
||||
// Returns the clone directory and a cleanup function.
|
||||
func (g *GitOperations) CloneToTemp(ctx context.Context, gitURL string) (string, func(), error) {
|
||||
tmpDir, err := os.MkdirTemp("", "rdev-build-*")
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("create temp dir: %w", err)
|
||||
}
|
||||
|
||||
cleanup := func() {
|
||||
if err := os.RemoveAll(tmpDir); err != nil {
|
||||
g.logger.Warn("failed to cleanup temp dir", "dir", tmpDir, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Inject token into clone URL for authentication
|
||||
authURL := g.injectToken(gitURL)
|
||||
|
||||
if err := g.runGit(ctx, tmpDir, "clone", authURL, "."); err != nil {
|
||||
cleanup()
|
||||
return "", nil, fmt.Errorf("git clone: %w", err)
|
||||
}
|
||||
|
||||
// Configure git user for commits
|
||||
if err := g.runGit(ctx, tmpDir, "config", "user.name", g.gitUser); err != nil {
|
||||
cleanup()
|
||||
return "", nil, fmt.Errorf("git config user.name: %w", err)
|
||||
}
|
||||
if err := g.runGit(ctx, tmpDir, "config", "user.email", g.gitEmail); err != nil {
|
||||
cleanup()
|
||||
return "", nil, fmt.Errorf("git config user.email: %w", err)
|
||||
}
|
||||
|
||||
g.logger.Info("cloned repository", "url", gitURL, "dir", tmpDir)
|
||||
return tmpDir, cleanup, nil
|
||||
}
|
||||
|
||||
// CommitAndPush stages all changes, commits, and optionally pushes.
|
||||
// Returns the commit SHA and list of changed files.
|
||||
func (g *GitOperations) CommitAndPush(ctx context.Context, dir, message string, push bool) (string, []string, error) {
|
||||
// Stage all changes
|
||||
if err := g.runGit(ctx, dir, "add", "-A"); err != nil {
|
||||
return "", nil, fmt.Errorf("git add: %w", err)
|
||||
}
|
||||
|
||||
// Check if there are changes to commit
|
||||
status, err := g.runGitOutput(ctx, dir, "status", "--porcelain")
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("git status: %w", err)
|
||||
}
|
||||
if strings.TrimSpace(status) == "" {
|
||||
g.logger.Info("no changes to commit", "dir", dir)
|
||||
return "", nil, nil
|
||||
}
|
||||
|
||||
// Get list of changed files
|
||||
diffOutput, err := g.runGitOutput(ctx, dir, "diff", "--cached", "--name-only")
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("git diff: %w", err)
|
||||
}
|
||||
var filesChanged []string
|
||||
for _, f := range strings.Split(strings.TrimSpace(diffOutput), "\n") {
|
||||
if f != "" {
|
||||
filesChanged = append(filesChanged, f)
|
||||
}
|
||||
}
|
||||
|
||||
// Commit
|
||||
if err := g.runGit(ctx, dir, "commit", "-m", message); err != nil {
|
||||
return "", nil, fmt.Errorf("git commit: %w", err)
|
||||
}
|
||||
|
||||
// Get commit SHA
|
||||
sha, err := g.runGitOutput(ctx, dir, "rev-parse", "HEAD")
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("git rev-parse: %w", err)
|
||||
}
|
||||
sha = strings.TrimSpace(sha)
|
||||
|
||||
g.logger.Info("committed changes",
|
||||
"sha", sha,
|
||||
"files", len(filesChanged),
|
||||
)
|
||||
|
||||
// Push if requested
|
||||
if push {
|
||||
if err := g.runGit(ctx, dir, "push"); err != nil {
|
||||
return sha, filesChanged, fmt.Errorf("git push: %w", err)
|
||||
}
|
||||
g.logger.Info("pushed changes", "sha", sha)
|
||||
}
|
||||
|
||||
return sha, filesChanged, nil
|
||||
}
|
||||
|
||||
// injectToken adds the Gitea token to an HTTPS git URL for authentication.
|
||||
// Converts "https://git.example.com/org/repo.git" to
|
||||
// "https://token@git.example.com/org/repo.git".
|
||||
func (g *GitOperations) injectToken(gitURL string) string {
|
||||
if g.giteaToken == "" {
|
||||
return gitURL
|
||||
}
|
||||
// Handle https:// URLs
|
||||
if strings.HasPrefix(gitURL, "https://") {
|
||||
return "https://" + g.giteaToken + "@" + gitURL[len("https://"):]
|
||||
}
|
||||
if strings.HasPrefix(gitURL, "http://") {
|
||||
return "http://" + g.giteaToken + "@" + gitURL[len("http://"):]
|
||||
}
|
||||
return gitURL
|
||||
}
|
||||
|
||||
// gitEnv returns a minimal environment for git subprocesses.
|
||||
// Only PATH and HOME are inherited; all other host env vars are excluded
|
||||
// to prevent credential or config leakage.
|
||||
func gitEnv() []string {
|
||||
env := []string{"GIT_TERMINAL_PROMPT=0"}
|
||||
for _, key := range []string{"PATH", "HOME"} {
|
||||
if v := os.Getenv(key); v != "" {
|
||||
env = append(env, key+"="+v)
|
||||
}
|
||||
}
|
||||
return env
|
||||
}
|
||||
|
||||
// runGit executes a git command in the given directory.
|
||||
func (g *GitOperations) runGit(ctx context.Context, dir string, args ...string) error {
|
||||
cmd := exec.CommandContext(ctx, "git", args...)
|
||||
cmd.Dir = dir
|
||||
cmd.Env = gitEnv()
|
||||
|
||||
var stderr bytes.Buffer
|
||||
cmd.Stderr = &stderr
|
||||
|
||||
if err := cmd.Run(); err != nil {
|
||||
// Redact token from error messages
|
||||
errMsg := g.redactToken(stderr.String())
|
||||
return fmt.Errorf("%s: %s", err, errMsg)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// runGitOutput executes a git command and returns its stdout.
|
||||
func (g *GitOperations) runGitOutput(ctx context.Context, dir string, args ...string) (string, error) {
|
||||
cmd := exec.CommandContext(ctx, "git", args...)
|
||||
cmd.Dir = dir
|
||||
cmd.Env = gitEnv()
|
||||
|
||||
var stdout, stderr bytes.Buffer
|
||||
cmd.Stdout = &stdout
|
||||
cmd.Stderr = &stderr
|
||||
|
||||
if err := cmd.Run(); err != nil {
|
||||
errMsg := g.redactToken(stderr.String())
|
||||
return "", fmt.Errorf("%s: %s", err, errMsg)
|
||||
}
|
||||
return stdout.String(), nil
|
||||
}
|
||||
|
||||
// redactToken removes the Gitea token from log/error output.
|
||||
func (g *GitOperations) redactToken(s string) string {
|
||||
if g.giteaToken == "" {
|
||||
return s
|
||||
}
|
||||
return strings.ReplaceAll(s, g.giteaToken, "[REDACTED]")
|
||||
}
|
||||
|
||||
// EnsureGitDir verifies that the given path is a valid git repository.
|
||||
func (g *GitOperations) EnsureGitDir(dir string) error {
|
||||
gitDir := filepath.Join(dir, ".git")
|
||||
info, err := os.Stat(gitDir)
|
||||
if err != nil {
|
||||
return fmt.Errorf("not a git repository: %w", err)
|
||||
}
|
||||
if !info.IsDir() {
|
||||
return fmt.Errorf("not a git repository: .git is not a directory")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
415
internal/worker/git_operations_test.go
Normal file
415
internal/worker/git_operations_test.go
Normal file
@ -0,0 +1,415 @@
|
||||
package worker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func testGitOps(token string) *GitOperations {
|
||||
return NewGitOperations(GitOperationsConfig{
|
||||
GiteaToken: token,
|
||||
GitUser: "test-user",
|
||||
GitEmail: "test@example.com",
|
||||
Logger: slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelWarn})),
|
||||
})
|
||||
}
|
||||
|
||||
func TestNewGitOperations_Defaults(t *testing.T) {
|
||||
g := NewGitOperations(GitOperationsConfig{})
|
||||
if g.gitUser != "rdev-worker" {
|
||||
t.Errorf("expected default gitUser 'rdev-worker', got %q", g.gitUser)
|
||||
}
|
||||
if g.gitEmail != "worker@threesix.ai" {
|
||||
t.Errorf("expected default gitEmail 'worker@threesix.ai', got %q", g.gitEmail)
|
||||
}
|
||||
if g.logger == nil {
|
||||
t.Error("expected non-nil logger")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewGitOperations_CustomValues(t *testing.T) {
|
||||
g := NewGitOperations(GitOperationsConfig{
|
||||
GiteaToken: "my-token",
|
||||
GitUser: "custom-user",
|
||||
GitEmail: "custom@example.com",
|
||||
})
|
||||
if g.giteaToken != "my-token" {
|
||||
t.Errorf("expected token 'my-token', got %q", g.giteaToken)
|
||||
}
|
||||
if g.gitUser != "custom-user" {
|
||||
t.Errorf("expected gitUser 'custom-user', got %q", g.gitUser)
|
||||
}
|
||||
if g.gitEmail != "custom@example.com" {
|
||||
t.Errorf("expected gitEmail 'custom@example.com', got %q", g.gitEmail)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInjectToken(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
token string
|
||||
url string
|
||||
expect string
|
||||
}{
|
||||
{
|
||||
name: "https URL with token",
|
||||
token: "ghp_abc123",
|
||||
url: "https://git.example.com/org/repo.git",
|
||||
expect: "https://ghp_abc123@git.example.com/org/repo.git",
|
||||
},
|
||||
{
|
||||
name: "http URL with token",
|
||||
token: "ghp_abc123",
|
||||
url: "http://git.example.com/org/repo.git",
|
||||
expect: "http://ghp_abc123@git.example.com/org/repo.git",
|
||||
},
|
||||
{
|
||||
name: "no token",
|
||||
token: "",
|
||||
url: "https://git.example.com/org/repo.git",
|
||||
expect: "https://git.example.com/org/repo.git",
|
||||
},
|
||||
{
|
||||
name: "ssh URL unchanged",
|
||||
token: "ghp_abc123",
|
||||
url: "git@git.example.com:org/repo.git",
|
||||
expect: "git@git.example.com:org/repo.git",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
g := testGitOps(tt.token)
|
||||
got := g.injectToken(tt.url)
|
||||
if got != tt.expect {
|
||||
t.Errorf("injectToken(%q) = %q, want %q", tt.url, got, tt.expect)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedactToken(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
token string
|
||||
input string
|
||||
expect string
|
||||
}{
|
||||
{
|
||||
name: "redacts token from message",
|
||||
token: "secret123",
|
||||
input: "fatal: Authentication failed for 'https://secret123@git.example.com/repo.git'",
|
||||
expect: "fatal: Authentication failed for 'https://[REDACTED]@git.example.com/repo.git'",
|
||||
},
|
||||
{
|
||||
name: "no token to redact",
|
||||
token: "",
|
||||
input: "fatal: repository not found",
|
||||
expect: "fatal: repository not found",
|
||||
},
|
||||
{
|
||||
name: "token not present in message",
|
||||
token: "secret123",
|
||||
input: "fatal: repository not found",
|
||||
expect: "fatal: repository not found",
|
||||
},
|
||||
{
|
||||
name: "multiple occurrences",
|
||||
token: "tok",
|
||||
input: "tok appears twice: tok",
|
||||
expect: "[REDACTED] appears twice: [REDACTED]",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
g := testGitOps(tt.token)
|
||||
got := g.redactToken(tt.input)
|
||||
if got != tt.expect {
|
||||
t.Errorf("redactToken(%q) = %q, want %q", tt.input, got, tt.expect)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsureGitDir(t *testing.T) {
|
||||
g := testGitOps("")
|
||||
|
||||
t.Run("valid git directory", func(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
if err := os.MkdirAll(filepath.Join(dir, ".git"), 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := g.EnsureGitDir(dir); err != nil {
|
||||
t.Errorf("expected no error for valid git dir, got: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("no .git directory", func(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
err := g.EnsureGitDir(dir)
|
||||
if err == nil {
|
||||
t.Error("expected error for non-git directory")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run(".git is a file not directory", func(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
if err := os.WriteFile(filepath.Join(dir, ".git"), []byte("gitdir: .."), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err := g.EnsureGitDir(dir)
|
||||
if err == nil {
|
||||
t.Error("expected error when .git is a file")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestCommitAndPush_NoChanges tests that CommitAndPush returns nil when
|
||||
// there are no staged changes in the repository.
|
||||
func TestCommitAndPush_NoChanges(t *testing.T) {
|
||||
g := testGitOps("")
|
||||
ctx := context.Background()
|
||||
|
||||
// Create a real git repo with an initial commit
|
||||
dir := t.TempDir()
|
||||
if err := g.runGit(ctx, dir, "init"); err != nil {
|
||||
t.Fatal("git init:", err)
|
||||
}
|
||||
if err := g.runGit(ctx, dir, "config", "user.name", "test"); err != nil {
|
||||
t.Fatal("git config user.name:", err)
|
||||
}
|
||||
if err := g.runGit(ctx, dir, "config", "user.email", "test@test.com"); err != nil {
|
||||
t.Fatal("git config user.email:", err)
|
||||
}
|
||||
// Create initial commit so HEAD exists
|
||||
if err := os.WriteFile(filepath.Join(dir, "README.md"), []byte("init"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := g.runGit(ctx, dir, "add", "-A"); err != nil {
|
||||
t.Fatal("git add:", err)
|
||||
}
|
||||
if err := g.runGit(ctx, dir, "commit", "-m", "initial"); err != nil {
|
||||
t.Fatal("git commit:", err)
|
||||
}
|
||||
|
||||
// No new changes — should return empty with no error
|
||||
sha, files, err := g.CommitAndPush(ctx, dir, "no changes", false)
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, got: %v", err)
|
||||
}
|
||||
if sha != "" {
|
||||
t.Errorf("expected empty SHA, got: %q", sha)
|
||||
}
|
||||
if len(files) != 0 {
|
||||
t.Errorf("expected no files, got: %v", files)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCommitAndPush_WithChanges tests that CommitAndPush correctly stages,
|
||||
// commits, and returns SHA and changed file list.
|
||||
func TestCommitAndPush_WithChanges(t *testing.T) {
|
||||
g := testGitOps("")
|
||||
ctx := context.Background()
|
||||
|
||||
// Create a real git repo
|
||||
dir := t.TempDir()
|
||||
if err := g.runGit(ctx, dir, "init"); err != nil {
|
||||
t.Fatal("git init:", err)
|
||||
}
|
||||
if err := g.runGit(ctx, dir, "config", "user.name", "test"); err != nil {
|
||||
t.Fatal("git config user.name:", err)
|
||||
}
|
||||
if err := g.runGit(ctx, dir, "config", "user.email", "test@test.com"); err != nil {
|
||||
t.Fatal("git config user.email:", err)
|
||||
}
|
||||
// Initial commit
|
||||
if err := os.WriteFile(filepath.Join(dir, "README.md"), []byte("init"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := g.runGit(ctx, dir, "add", "-A"); err != nil {
|
||||
t.Fatal("git add:", err)
|
||||
}
|
||||
if err := g.runGit(ctx, dir, "commit", "-m", "initial"); err != nil {
|
||||
t.Fatal("git commit:", err)
|
||||
}
|
||||
|
||||
// Create new files to commit
|
||||
if err := os.WriteFile(filepath.Join(dir, "main.go"), []byte("package main"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(dir, "go.mod"), []byte("module test"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// CommitAndPush without push (no remote)
|
||||
sha, files, err := g.CommitAndPush(ctx, dir, "add go files", false)
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, got: %v", err)
|
||||
}
|
||||
if sha == "" {
|
||||
t.Error("expected non-empty SHA")
|
||||
}
|
||||
if len(sha) < 7 {
|
||||
t.Errorf("expected SHA to be at least 7 chars, got: %q", sha)
|
||||
}
|
||||
if len(files) != 2 {
|
||||
t.Errorf("expected 2 changed files, got %d: %v", len(files), files)
|
||||
}
|
||||
|
||||
// Verify the files are in the list
|
||||
fileSet := make(map[string]bool)
|
||||
for _, f := range files {
|
||||
fileSet[f] = true
|
||||
}
|
||||
if !fileSet["main.go"] {
|
||||
t.Error("expected main.go in changed files")
|
||||
}
|
||||
if !fileSet["go.mod"] {
|
||||
t.Error("expected go.mod in changed files")
|
||||
}
|
||||
}
|
||||
|
||||
// TestCommitAndPush_PushWithoutRemote tests that push fails gracefully
|
||||
// when there's no remote configured.
|
||||
func TestCommitAndPush_PushWithoutRemote(t *testing.T) {
|
||||
g := testGitOps("")
|
||||
ctx := context.Background()
|
||||
|
||||
dir := t.TempDir()
|
||||
if err := g.runGit(ctx, dir, "init"); err != nil {
|
||||
t.Fatal("git init:", err)
|
||||
}
|
||||
if err := g.runGit(ctx, dir, "config", "user.name", "test"); err != nil {
|
||||
t.Fatal("git config:", err)
|
||||
}
|
||||
if err := g.runGit(ctx, dir, "config", "user.email", "test@test.com"); err != nil {
|
||||
t.Fatal("git config:", err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(dir, "file.txt"), []byte("init"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := g.runGit(ctx, dir, "add", "-A"); err != nil {
|
||||
t.Fatal("git add:", err)
|
||||
}
|
||||
if err := g.runGit(ctx, dir, "commit", "-m", "initial"); err != nil {
|
||||
t.Fatal("git commit:", err)
|
||||
}
|
||||
|
||||
// Add a new file
|
||||
if err := os.WriteFile(filepath.Join(dir, "new.txt"), []byte("new"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Push should fail (no remote) but commit succeeds — SHA is returned
|
||||
sha, files, err := g.CommitAndPush(ctx, dir, "test push", true)
|
||||
if err == nil {
|
||||
t.Error("expected push error when no remote configured")
|
||||
}
|
||||
// Even though push failed, commit should have succeeded
|
||||
if sha == "" {
|
||||
t.Error("expected SHA from successful commit before push failure")
|
||||
}
|
||||
if len(files) != 1 || files[0] != "new.txt" {
|
||||
t.Errorf("expected [new.txt], got: %v", files)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCloneToTemp_InvalidURL tests that CloneToTemp fails on a bad URL.
|
||||
func TestCloneToTemp_InvalidURL(t *testing.T) {
|
||||
g := testGitOps("")
|
||||
ctx := context.Background()
|
||||
|
||||
_, _, err := g.CloneToTemp(ctx, "https://invalid.example.com/no-such-repo.git")
|
||||
if err == nil {
|
||||
t.Error("expected error cloning invalid URL")
|
||||
}
|
||||
}
|
||||
|
||||
// TestCloneToTemp_LocalRepo tests cloning a local bare repository.
|
||||
func TestCloneToTemp_LocalRepo(t *testing.T) {
|
||||
g := testGitOps("")
|
||||
ctx := context.Background()
|
||||
|
||||
// Create a bare repo to clone from
|
||||
bareDir := t.TempDir()
|
||||
if err := g.runGit(ctx, bareDir, "init", "--bare"); err != nil {
|
||||
t.Fatal("git init --bare:", err)
|
||||
}
|
||||
|
||||
// Create a source repo and push to the bare repo
|
||||
srcDir := t.TempDir()
|
||||
if err := g.runGit(ctx, srcDir, "init"); err != nil {
|
||||
t.Fatal("git init:", err)
|
||||
}
|
||||
if err := g.runGit(ctx, srcDir, "config", "user.name", "test"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := g.runGit(ctx, srcDir, "config", "user.email", "test@test.com"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(srcDir, "hello.txt"), []byte("hello"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := g.runGit(ctx, srcDir, "add", "-A"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := g.runGit(ctx, srcDir, "commit", "-m", "initial"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := g.runGit(ctx, srcDir, "remote", "add", "origin", bareDir); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := g.runGit(ctx, srcDir, "push", "origin", "master"); err != nil {
|
||||
// Some git versions use "main" as default branch
|
||||
if err2 := g.runGit(ctx, srcDir, "push", "origin", "main"); err2 != nil {
|
||||
t.Fatalf("push failed for both master and main: master=%v, main=%v", err, err2)
|
||||
}
|
||||
}
|
||||
|
||||
// Clone the bare repo using file:// protocol
|
||||
cloneDir, cleanup, err := g.CloneToTemp(ctx, "file://"+bareDir)
|
||||
if err != nil {
|
||||
t.Fatalf("CloneToTemp failed: %v", err)
|
||||
}
|
||||
defer cleanup()
|
||||
|
||||
// Verify the cloned file exists
|
||||
content, err := os.ReadFile(filepath.Join(cloneDir, "hello.txt"))
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read cloned file: %v", err)
|
||||
}
|
||||
if string(content) != "hello" {
|
||||
t.Errorf("expected file content 'hello', got %q", string(content))
|
||||
}
|
||||
|
||||
// Verify .git dir exists
|
||||
if err := g.EnsureGitDir(cloneDir); err != nil {
|
||||
t.Errorf("cloned dir should be a git repo: %v", err)
|
||||
}
|
||||
|
||||
// Verify git config was set
|
||||
userName, err := g.runGitOutput(ctx, cloneDir, "config", "user.name")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get user.name: %v", err)
|
||||
}
|
||||
if got := strings.TrimSpace(userName); got != "test-user" {
|
||||
t.Errorf("expected user.name 'test-user', got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunGit_ContextCancellation(t *testing.T) {
|
||||
g := testGitOps("")
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel() // Cancel immediately
|
||||
|
||||
dir := t.TempDir()
|
||||
err := g.runGit(ctx, dir, "status")
|
||||
if err == nil {
|
||||
t.Error("expected error when context is cancelled")
|
||||
}
|
||||
}
|
||||
322
internal/worker/mock_test.go
Normal file
322
internal/worker/mock_test.go
Normal file
@ -0,0 +1,322 @@
|
||||
package worker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/orchard9/rdev/internal/domain"
|
||||
"github.com/orchard9/rdev/internal/port"
|
||||
"github.com/orchard9/rdev/internal/service"
|
||||
)
|
||||
|
||||
// =============================================================================
|
||||
// Mock implementations for worker package tests
|
||||
// =============================================================================
|
||||
|
||||
type mockWorkQueue struct {
|
||||
mu sync.Mutex
|
||||
tasks map[string]*domain.WorkTask
|
||||
err error
|
||||
}
|
||||
|
||||
func newMockWorkQueue() *mockWorkQueue {
|
||||
return &mockWorkQueue{tasks: make(map[string]*domain.WorkTask)}
|
||||
}
|
||||
|
||||
func (m *mockWorkQueue) Enqueue(_ context.Context, task *domain.WorkTask) (string, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if m.err != nil {
|
||||
return "", m.err
|
||||
}
|
||||
id := fmt.Sprintf("task-%d", len(m.tasks)+1)
|
||||
task.ID = id
|
||||
task.Status = domain.WorkTaskStatusPending
|
||||
task.CreatedAt = time.Now()
|
||||
m.tasks[id] = task
|
||||
return id, nil
|
||||
}
|
||||
|
||||
func (m *mockWorkQueue) Dequeue(_ context.Context, workerID string) (*domain.WorkTask, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if m.err != nil {
|
||||
return nil, m.err
|
||||
}
|
||||
for _, task := range m.tasks {
|
||||
if task.Status == domain.WorkTaskStatusPending {
|
||||
task.Status = domain.WorkTaskStatusRunning
|
||||
task.WorkerID = workerID
|
||||
now := time.Now()
|
||||
task.StartedAt = &now
|
||||
return task, nil
|
||||
}
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockWorkQueue) Complete(_ context.Context, taskID string, result *domain.WorkResult) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
task, ok := m.tasks[taskID]
|
||||
if !ok {
|
||||
return domain.ErrWorkTaskNotFound
|
||||
}
|
||||
task.Status = domain.WorkTaskStatusCompleted
|
||||
task.Result = result
|
||||
now := time.Now()
|
||||
task.CompletedAt = &now
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockWorkQueue) Fail(_ context.Context, taskID string, errMsg string) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
task, ok := m.tasks[taskID]
|
||||
if !ok {
|
||||
return domain.ErrWorkTaskNotFound
|
||||
}
|
||||
task.RetryCount++
|
||||
if task.RetryCount >= task.MaxRetries {
|
||||
task.Status = domain.WorkTaskStatusFailed
|
||||
task.Error = errMsg
|
||||
} else {
|
||||
task.Status = domain.WorkTaskStatusPending
|
||||
task.WorkerID = ""
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockWorkQueue) Cancel(_ context.Context, taskID string) error { return nil }
|
||||
func (m *mockWorkQueue) GetTask(_ context.Context, taskID string) (*domain.WorkTask, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
task, ok := m.tasks[taskID]
|
||||
if !ok {
|
||||
return nil, domain.ErrWorkTaskNotFound
|
||||
}
|
||||
return task, nil
|
||||
}
|
||||
func (m *mockWorkQueue) ListByProject(_ context.Context, _ string, _ *domain.WorkTaskStatus, _ domain.WorkListOptions) (*domain.WorkListResult, error) {
|
||||
return &domain.WorkListResult{}, nil
|
||||
}
|
||||
func (m *mockWorkQueue) GetStats(_ context.Context) (*domain.WorkQueueStats, error) {
|
||||
return &domain.WorkQueueStats{}, nil
|
||||
}
|
||||
func (m *mockWorkQueue) CleanupOld(_ context.Context, _ time.Duration) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
func (m *mockWorkQueue) RequeueStale(_ context.Context, _ time.Duration) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
type mockWorkerRegistry struct {
|
||||
mu sync.Mutex
|
||||
workers map[string]*domain.Worker
|
||||
err error
|
||||
}
|
||||
|
||||
func newMockWorkerRegistry() *mockWorkerRegistry {
|
||||
return &mockWorkerRegistry{workers: make(map[string]*domain.Worker)}
|
||||
}
|
||||
|
||||
func (m *mockWorkerRegistry) Register(_ context.Context, worker *domain.Worker) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if m.err != nil {
|
||||
return m.err
|
||||
}
|
||||
m.workers[worker.ID] = worker
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockWorkerRegistry) Heartbeat(_ context.Context, workerID string) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
w, ok := m.workers[workerID]
|
||||
if !ok {
|
||||
return domain.ErrWorkerNotFound
|
||||
}
|
||||
w.LastHeartbeat = time.Now()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockWorkerRegistry) UpdateStatus(_ context.Context, workerID string, status domain.WorkerStatus, taskID string) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
w, ok := m.workers[workerID]
|
||||
if !ok {
|
||||
return domain.ErrWorkerNotFound
|
||||
}
|
||||
w.Status = status
|
||||
w.CurrentTask = taskID
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockWorkerRegistry) Deregister(_ context.Context, workerID string) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
delete(m.workers, workerID)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockWorkerRegistry) Get(_ context.Context, workerID string) (*domain.Worker, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
w, ok := m.workers[workerID]
|
||||
if !ok {
|
||||
return nil, domain.ErrWorkerNotFound
|
||||
}
|
||||
return w, nil
|
||||
}
|
||||
|
||||
func (m *mockWorkerRegistry) List(_ context.Context, filter port.WorkerFilter) ([]*domain.Worker, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
var result []*domain.Worker
|
||||
for _, w := range m.workers {
|
||||
if filter.Status != nil && w.Status != *filter.Status {
|
||||
continue
|
||||
}
|
||||
result = append(result, w)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (m *mockWorkerRegistry) MarkStaleOffline(_ context.Context, _ time.Duration) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
type mockBuildAudit struct {
|
||||
mu sync.Mutex
|
||||
entries map[string]*domain.BuildAuditEntry
|
||||
}
|
||||
|
||||
func newMockBuildAudit() *mockBuildAudit {
|
||||
return &mockBuildAudit{entries: make(map[string]*domain.BuildAuditEntry)}
|
||||
}
|
||||
|
||||
func (m *mockBuildAudit) Record(_ context.Context, entry *domain.BuildAuditEntry) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.entries[entry.TaskID] = entry
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockBuildAudit) Update(_ context.Context, taskID string, result *domain.BuildResult) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
entry, ok := m.entries[taskID]
|
||||
if !ok {
|
||||
return domain.ErrBuildNotFound
|
||||
}
|
||||
entry.Result = result
|
||||
if result.Success {
|
||||
entry.Status = domain.BuildStatusCompleted
|
||||
} else {
|
||||
entry.Status = domain.BuildStatusFailed
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockBuildAudit) Get(_ context.Context, taskID string) (*domain.BuildAuditEntry, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
entry, ok := m.entries[taskID]
|
||||
if !ok {
|
||||
return nil, domain.ErrBuildNotFound
|
||||
}
|
||||
return entry, nil
|
||||
}
|
||||
|
||||
func (m *mockBuildAudit) List(_ context.Context, _ port.BuildAuditFilter) ([]*domain.BuildAuditEntry, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
type mockCodeAgent struct {
|
||||
result *domain.AgentResult
|
||||
err error
|
||||
}
|
||||
|
||||
func (m *mockCodeAgent) Name() string { return "mock-agent" }
|
||||
func (m *mockCodeAgent) Provider() domain.AgentProvider { return "mock" }
|
||||
func (m *mockCodeAgent) Cancel(_ context.Context, _ string) error { return nil }
|
||||
func (m *mockCodeAgent) Capabilities() domain.AgentCapabilities {
|
||||
return domain.AgentCapabilities{Provider: "mock"}
|
||||
}
|
||||
func (m *mockCodeAgent) Available(_ context.Context) bool { return true }
|
||||
func (m *mockCodeAgent) Execute(_ context.Context, req *domain.AgentRequest, handler domain.AgentEventHandler) (*domain.AgentResult, error) {
|
||||
if handler != nil {
|
||||
handler(domain.AgentEvent{
|
||||
Type: domain.AgentEventOutput,
|
||||
Content: "mock output for: " + req.Prompt,
|
||||
})
|
||||
}
|
||||
if m.err != nil {
|
||||
return nil, m.err
|
||||
}
|
||||
return m.result, nil
|
||||
}
|
||||
|
||||
type mockCodeAgentRegistry struct {
|
||||
agent port.CodeAgent
|
||||
}
|
||||
|
||||
func (m *mockCodeAgentRegistry) Register(agent port.CodeAgent) { m.agent = agent }
|
||||
func (m *mockCodeAgentRegistry) Get(_ domain.AgentProvider) port.CodeAgent { return m.agent }
|
||||
func (m *mockCodeAgentRegistry) Default() port.CodeAgent { return m.agent }
|
||||
func (m *mockCodeAgentRegistry) DefaultProvider() domain.AgentProvider { return "mock" }
|
||||
func (m *mockCodeAgentRegistry) SetDefault(_ domain.AgentProvider) error { return nil }
|
||||
func (m *mockCodeAgentRegistry) Available() []domain.AgentProvider {
|
||||
return []domain.AgentProvider{"mock"}
|
||||
}
|
||||
func (m *mockCodeAgentRegistry) AvailableAgents(_ context.Context) []port.CodeAgent {
|
||||
return []port.CodeAgent{m.agent}
|
||||
}
|
||||
func (m *mockCodeAgentRegistry) Count() int { return 1 }
|
||||
|
||||
// =============================================================================
|
||||
// Helper to build test dependencies
|
||||
// =============================================================================
|
||||
|
||||
type testDeps struct {
|
||||
queue *mockWorkQueue
|
||||
registry *mockWorkerRegistry
|
||||
audit *mockBuildAudit
|
||||
agent *mockCodeAgent
|
||||
|
||||
workerSvc *service.WorkerService
|
||||
workSvc *service.WorkService
|
||||
buildExec *BuildExecutor
|
||||
}
|
||||
|
||||
func newTestDeps() *testDeps {
|
||||
queue := newMockWorkQueue()
|
||||
registry := newMockWorkerRegistry()
|
||||
audit := newMockBuildAudit()
|
||||
agent := &mockCodeAgent{
|
||||
result: &domain.AgentResult{
|
||||
ExitCode: 0,
|
||||
DurationMs: 1000,
|
||||
},
|
||||
}
|
||||
agentRegistry := &mockCodeAgentRegistry{agent: agent}
|
||||
|
||||
workerSvc := service.NewWorkerService(registry, queue, nil).
|
||||
WithBuildAudit(audit)
|
||||
workSvc := service.NewWorkService(queue, service.WorkServiceConfig{})
|
||||
|
||||
buildExec := NewBuildExecutor(agentRegistry, nil, nil)
|
||||
|
||||
return &testDeps{
|
||||
queue: queue,
|
||||
registry: registry,
|
||||
audit: audit,
|
||||
agent: agent,
|
||||
workerSvc: workerSvc,
|
||||
workSvc: workSvc,
|
||||
buildExec: buildExec,
|
||||
}
|
||||
}
|
||||
241
internal/worker/queue_maintenance.go
Normal file
241
internal/worker/queue_maintenance.go
Normal file
@ -0,0 +1,241 @@
|
||||
package worker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/orchard9/rdev/internal/metrics"
|
||||
"github.com/orchard9/rdev/internal/port"
|
||||
)
|
||||
|
||||
// QueueMaintenance runs periodic maintenance tasks on the work queue
|
||||
// and worker registry: stale task recovery, stale worker marking,
|
||||
// old task cleanup, and queue metrics refresh.
|
||||
type QueueMaintenance struct {
|
||||
queue port.WorkQueue
|
||||
registry port.WorkerRegistry
|
||||
logger *slog.Logger
|
||||
|
||||
// Intervals
|
||||
staleTaskTimeout time.Duration
|
||||
staleWorkerTimeout time.Duration
|
||||
cleanupAge time.Duration
|
||||
maintenancePeriod time.Duration
|
||||
metricsPeriod time.Duration
|
||||
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
// QueueMaintenanceConfig holds configuration for queue maintenance.
|
||||
type QueueMaintenanceConfig struct {
|
||||
// StaleTaskTimeout is how long a running task can be silent before requeue.
|
||||
// Default: 30 minutes.
|
||||
StaleTaskTimeout time.Duration
|
||||
|
||||
// StaleWorkerTimeout is how long without heartbeat before marking offline.
|
||||
// Default: 2 minutes.
|
||||
StaleWorkerTimeout time.Duration
|
||||
|
||||
// CleanupAge is the minimum age for completed/failed/cancelled tasks to be cleaned up.
|
||||
// Default: 7 days.
|
||||
CleanupAge time.Duration
|
||||
|
||||
// MaintenancePeriod is how often to run maintenance tasks.
|
||||
// Default: 1 minute.
|
||||
MaintenancePeriod time.Duration
|
||||
|
||||
// MetricsPeriod is how often to refresh queue depth metrics.
|
||||
// Default: 15 seconds.
|
||||
MetricsPeriod time.Duration
|
||||
|
||||
Logger *slog.Logger
|
||||
}
|
||||
|
||||
// DefaultQueueMaintenanceConfig returns sensible defaults.
|
||||
func DefaultQueueMaintenanceConfig() *QueueMaintenanceConfig {
|
||||
return &QueueMaintenanceConfig{
|
||||
StaleTaskTimeout: 30 * time.Minute,
|
||||
StaleWorkerTimeout: 2 * time.Minute,
|
||||
CleanupAge: 7 * 24 * time.Hour,
|
||||
MaintenancePeriod: 1 * time.Minute,
|
||||
MetricsPeriod: 15 * time.Second,
|
||||
Logger: slog.Default(),
|
||||
}
|
||||
}
|
||||
|
||||
// NewQueueMaintenance creates a new queue maintenance worker.
|
||||
func NewQueueMaintenance(
|
||||
queue port.WorkQueue,
|
||||
registry port.WorkerRegistry,
|
||||
cfg *QueueMaintenanceConfig,
|
||||
) *QueueMaintenance {
|
||||
if cfg == nil {
|
||||
cfg = DefaultQueueMaintenanceConfig()
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
return &QueueMaintenance{
|
||||
queue: queue,
|
||||
registry: registry,
|
||||
logger: cfg.Logger.With("component", "queue-maintenance"),
|
||||
staleTaskTimeout: cfg.StaleTaskTimeout,
|
||||
staleWorkerTimeout: cfg.StaleWorkerTimeout,
|
||||
cleanupAge: cfg.CleanupAge,
|
||||
maintenancePeriod: cfg.MaintenancePeriod,
|
||||
metricsPeriod: cfg.MetricsPeriod,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
}
|
||||
|
||||
// Start begins the maintenance and metrics loops.
|
||||
func (m *QueueMaintenance) Start() {
|
||||
m.logger.Info("queue maintenance started",
|
||||
"maintenance_period", m.maintenancePeriod,
|
||||
"metrics_period", m.metricsPeriod,
|
||||
"stale_task_timeout", m.staleTaskTimeout,
|
||||
"stale_worker_timeout", m.staleWorkerTimeout,
|
||||
"cleanup_age", m.cleanupAge,
|
||||
)
|
||||
|
||||
m.wg.Add(2)
|
||||
go m.maintenanceLoop()
|
||||
go m.metricsLoop()
|
||||
}
|
||||
|
||||
// Stop gracefully shuts down the maintenance worker.
|
||||
func (m *QueueMaintenance) Stop() {
|
||||
m.logger.Info("queue maintenance stopping")
|
||||
m.cancel()
|
||||
m.wg.Wait()
|
||||
m.logger.Info("queue maintenance stopped")
|
||||
}
|
||||
|
||||
// maintenanceLoop runs periodic maintenance: stale recovery, worker health, cleanup.
|
||||
func (m *QueueMaintenance) maintenanceLoop() {
|
||||
defer m.wg.Done()
|
||||
|
||||
// Run immediately on start
|
||||
m.runMaintenance()
|
||||
|
||||
ticker := time.NewTicker(m.maintenancePeriod)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-m.ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
m.runMaintenance()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// metricsLoop refreshes queue depth metrics on a faster cadence.
|
||||
func (m *QueueMaintenance) metricsLoop() {
|
||||
defer m.wg.Done()
|
||||
|
||||
// Run immediately on start
|
||||
m.refreshMetrics()
|
||||
|
||||
ticker := time.NewTicker(m.metricsPeriod)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-m.ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
m.refreshMetrics()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// runMaintenance executes all maintenance tasks.
|
||||
func (m *QueueMaintenance) runMaintenance() {
|
||||
ctx, cancel := context.WithTimeout(m.ctx, 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
m.requeueStaleTasks(ctx)
|
||||
m.markStaleWorkers(ctx)
|
||||
m.cleanupOldTasks(ctx)
|
||||
}
|
||||
|
||||
// requeueStaleTasks requeues tasks that have been running too long
|
||||
// (the worker likely crashed without reporting).
|
||||
func (m *QueueMaintenance) requeueStaleTasks(ctx context.Context) {
|
||||
count, err := m.queue.RequeueStale(ctx, m.staleTaskTimeout)
|
||||
if err != nil {
|
||||
m.logger.Warn("failed to requeue stale tasks", "error", err)
|
||||
return
|
||||
}
|
||||
if count > 0 {
|
||||
m.logger.Info("requeued stale tasks", "count", count)
|
||||
}
|
||||
}
|
||||
|
||||
// markStaleWorkers marks workers without recent heartbeats as offline.
|
||||
func (m *QueueMaintenance) markStaleWorkers(ctx context.Context) {
|
||||
count, err := m.registry.MarkStaleOffline(ctx, m.staleWorkerTimeout)
|
||||
if err != nil {
|
||||
m.logger.Warn("failed to mark stale workers offline", "error", err)
|
||||
return
|
||||
}
|
||||
if count > 0 {
|
||||
m.logger.Info("marked stale workers offline", "count", count)
|
||||
}
|
||||
}
|
||||
|
||||
// cleanupOldTasks removes completed/failed/cancelled tasks older than cleanup age.
|
||||
func (m *QueueMaintenance) cleanupOldTasks(ctx context.Context) {
|
||||
count, err := m.queue.CleanupOld(ctx, m.cleanupAge)
|
||||
if err != nil {
|
||||
m.logger.Warn("failed to cleanup old tasks", "error", err)
|
||||
return
|
||||
}
|
||||
if count > 0 {
|
||||
m.logger.Info("cleaned up old tasks", "count", count)
|
||||
}
|
||||
}
|
||||
|
||||
// refreshMetrics fetches queue stats and updates Prometheus gauges.
|
||||
func (m *QueueMaintenance) refreshMetrics() {
|
||||
ctx, cancel := context.WithTimeout(m.ctx, 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
stats, err := m.queue.GetStats(ctx)
|
||||
if err != nil {
|
||||
m.logger.Warn("failed to get queue stats for metrics", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
metrics.SetWorkQueueDepth("pending", stats.Pending)
|
||||
metrics.SetWorkQueueDepth("running", stats.Running)
|
||||
metrics.SetWorkQueueDepth("completed", stats.Completed)
|
||||
metrics.SetWorkQueueDepth("failed", stats.Failed)
|
||||
metrics.SetWorkQueueDepth("cancelled", stats.Cancelled)
|
||||
|
||||
// Worker counts
|
||||
workers, err := m.registry.List(ctx, port.WorkerFilter{})
|
||||
if err != nil {
|
||||
m.logger.Warn("failed to list workers for metrics", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
counts := map[string]int{
|
||||
"idle": 0, "busy": 0, "draining": 0, "offline": 0,
|
||||
}
|
||||
for _, w := range workers {
|
||||
counts[string(w.Status)]++
|
||||
age := time.Since(w.LastHeartbeat).Seconds()
|
||||
metrics.RecordWorkerHeartbeat(w.ID, age)
|
||||
}
|
||||
for status, count := range counts {
|
||||
metrics.SetWorkerCount(status, count)
|
||||
}
|
||||
}
|
||||
297
internal/worker/queue_maintenance_test.go
Normal file
297
internal/worker/queue_maintenance_test.go
Normal file
@ -0,0 +1,297 @@
|
||||
package worker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/orchard9/rdev/internal/domain"
|
||||
"github.com/orchard9/rdev/internal/port"
|
||||
)
|
||||
|
||||
// mockMaintenanceQueue implements port.WorkQueue for maintenance tests.
|
||||
type mockMaintenanceQueue struct {
|
||||
mu sync.Mutex
|
||||
requeueCalls int
|
||||
cleanupCalls int
|
||||
statsCalls int
|
||||
requeueCount int64
|
||||
cleanupCount int64
|
||||
stats *domain.WorkQueueStats
|
||||
err error
|
||||
}
|
||||
|
||||
func newMockMaintenanceQueue() *mockMaintenanceQueue {
|
||||
return &mockMaintenanceQueue{
|
||||
stats: &domain.WorkQueueStats{
|
||||
Pending: 5,
|
||||
Running: 2,
|
||||
Completed: 100,
|
||||
Failed: 3,
|
||||
Cancelled: 1,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockMaintenanceQueue) Enqueue(_ context.Context, _ *domain.WorkTask) (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func (m *mockMaintenanceQueue) Dequeue(_ context.Context, _ string) (*domain.WorkTask, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockMaintenanceQueue) Complete(_ context.Context, _ string, _ *domain.WorkResult) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockMaintenanceQueue) Fail(_ context.Context, _ string, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockMaintenanceQueue) Cancel(_ context.Context, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockMaintenanceQueue) GetTask(_ context.Context, _ string) (*domain.WorkTask, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockMaintenanceQueue) ListByProject(_ context.Context, _ string, _ *domain.WorkTaskStatus, _ domain.WorkListOptions) (*domain.WorkListResult, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockMaintenanceQueue) GetStats(_ context.Context) (*domain.WorkQueueStats, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.statsCalls++
|
||||
if m.err != nil {
|
||||
return nil, m.err
|
||||
}
|
||||
return m.stats, nil
|
||||
}
|
||||
|
||||
func (m *mockMaintenanceQueue) CleanupOld(_ context.Context, _ time.Duration) (int64, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.cleanupCalls++
|
||||
if m.err != nil {
|
||||
return 0, m.err
|
||||
}
|
||||
return m.cleanupCount, nil
|
||||
}
|
||||
|
||||
func (m *mockMaintenanceQueue) RequeueStale(_ context.Context, _ time.Duration) (int64, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.requeueCalls++
|
||||
if m.err != nil {
|
||||
return 0, m.err
|
||||
}
|
||||
return m.requeueCount, nil
|
||||
}
|
||||
|
||||
// mockMaintenanceRegistry implements port.WorkerRegistry for maintenance tests.
|
||||
type mockMaintenanceRegistry struct {
|
||||
mu sync.Mutex
|
||||
markStaleCalls int
|
||||
markStaleCount int
|
||||
workers []*domain.Worker
|
||||
err error
|
||||
}
|
||||
|
||||
func newMockMaintenanceRegistry() *mockMaintenanceRegistry {
|
||||
return &mockMaintenanceRegistry{
|
||||
workers: []*domain.Worker{
|
||||
{
|
||||
ID: "worker-1",
|
||||
Status: domain.WorkerStatusIdle,
|
||||
LastHeartbeat: time.Now(),
|
||||
},
|
||||
{
|
||||
ID: "worker-2",
|
||||
Status: domain.WorkerStatusBusy,
|
||||
LastHeartbeat: time.Now().Add(-5 * time.Minute),
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockMaintenanceRegistry) Register(_ context.Context, _ *domain.Worker) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockMaintenanceRegistry) Heartbeat(_ context.Context, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockMaintenanceRegistry) UpdateStatus(_ context.Context, _ string, _ domain.WorkerStatus, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockMaintenanceRegistry) Deregister(_ context.Context, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockMaintenanceRegistry) Get(_ context.Context, _ string) (*domain.Worker, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockMaintenanceRegistry) List(_ context.Context, _ port.WorkerFilter) ([]*domain.Worker, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if m.err != nil {
|
||||
return nil, m.err
|
||||
}
|
||||
return m.workers, nil
|
||||
}
|
||||
|
||||
func (m *mockMaintenanceRegistry) MarkStaleOffline(_ context.Context, _ time.Duration) (int, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.markStaleCalls++
|
||||
if m.err != nil {
|
||||
return 0, m.err
|
||||
}
|
||||
return m.markStaleCount, nil
|
||||
}
|
||||
|
||||
func TestQueueMaintenance_DefaultConfig(t *testing.T) {
|
||||
cfg := DefaultQueueMaintenanceConfig()
|
||||
|
||||
if cfg.StaleTaskTimeout != 30*time.Minute {
|
||||
t.Errorf("got StaleTaskTimeout=%v, want 30m", cfg.StaleTaskTimeout)
|
||||
}
|
||||
if cfg.StaleWorkerTimeout != 2*time.Minute {
|
||||
t.Errorf("got StaleWorkerTimeout=%v, want 2m", cfg.StaleWorkerTimeout)
|
||||
}
|
||||
if cfg.CleanupAge != 7*24*time.Hour {
|
||||
t.Errorf("got CleanupAge=%v, want 7d", cfg.CleanupAge)
|
||||
}
|
||||
if cfg.MaintenancePeriod != 1*time.Minute {
|
||||
t.Errorf("got MaintenancePeriod=%v, want 1m", cfg.MaintenancePeriod)
|
||||
}
|
||||
if cfg.MetricsPeriod != 15*time.Second {
|
||||
t.Errorf("got MetricsPeriod=%v, want 15s", cfg.MetricsPeriod)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueueMaintenance_RunMaintenance(t *testing.T) {
|
||||
queue := newMockMaintenanceQueue()
|
||||
queue.requeueCount = 2
|
||||
queue.cleanupCount = 5
|
||||
|
||||
registry := newMockMaintenanceRegistry()
|
||||
registry.markStaleCount = 1
|
||||
|
||||
cfg := &QueueMaintenanceConfig{
|
||||
StaleTaskTimeout: 30 * time.Minute,
|
||||
StaleWorkerTimeout: 2 * time.Minute,
|
||||
CleanupAge: 7 * 24 * time.Hour,
|
||||
MaintenancePeriod: 1 * time.Hour, // won't fire in test
|
||||
MetricsPeriod: 1 * time.Hour, // won't fire in test
|
||||
Logger: slog.Default(),
|
||||
}
|
||||
|
||||
m := NewQueueMaintenance(queue, registry, cfg)
|
||||
|
||||
// Run maintenance directly
|
||||
m.runMaintenance()
|
||||
|
||||
queue.mu.Lock()
|
||||
defer queue.mu.Unlock()
|
||||
registry.mu.Lock()
|
||||
defer registry.mu.Unlock()
|
||||
|
||||
if queue.requeueCalls != 1 {
|
||||
t.Errorf("got requeueCalls=%d, want 1", queue.requeueCalls)
|
||||
}
|
||||
if queue.cleanupCalls != 1 {
|
||||
t.Errorf("got cleanupCalls=%d, want 1", queue.cleanupCalls)
|
||||
}
|
||||
if registry.markStaleCalls != 1 {
|
||||
t.Errorf("got markStaleCalls=%d, want 1", registry.markStaleCalls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueueMaintenance_RefreshMetrics(t *testing.T) {
|
||||
queue := newMockMaintenanceQueue()
|
||||
registry := newMockMaintenanceRegistry()
|
||||
|
||||
cfg := &QueueMaintenanceConfig{
|
||||
StaleTaskTimeout: 30 * time.Minute,
|
||||
StaleWorkerTimeout: 2 * time.Minute,
|
||||
CleanupAge: 7 * 24 * time.Hour,
|
||||
MaintenancePeriod: 1 * time.Hour,
|
||||
MetricsPeriod: 1 * time.Hour,
|
||||
Logger: slog.Default(),
|
||||
}
|
||||
|
||||
m := NewQueueMaintenance(queue, registry, cfg)
|
||||
|
||||
// Run metrics refresh directly
|
||||
m.refreshMetrics()
|
||||
|
||||
queue.mu.Lock()
|
||||
if queue.statsCalls != 1 {
|
||||
t.Errorf("got statsCalls=%d, want 1", queue.statsCalls)
|
||||
}
|
||||
queue.mu.Unlock()
|
||||
}
|
||||
|
||||
func TestQueueMaintenance_StartStop(t *testing.T) {
|
||||
queue := newMockMaintenanceQueue()
|
||||
registry := newMockMaintenanceRegistry()
|
||||
|
||||
cfg := &QueueMaintenanceConfig{
|
||||
StaleTaskTimeout: 30 * time.Minute,
|
||||
StaleWorkerTimeout: 2 * time.Minute,
|
||||
CleanupAge: 7 * 24 * time.Hour,
|
||||
MaintenancePeriod: 50 * time.Millisecond,
|
||||
MetricsPeriod: 50 * time.Millisecond,
|
||||
Logger: slog.Default(),
|
||||
}
|
||||
|
||||
m := NewQueueMaintenance(queue, registry, cfg)
|
||||
m.Start()
|
||||
|
||||
// Poll until maintenance has run at least once (runs immediately on start)
|
||||
deadline := time.After(2 * time.Second)
|
||||
for {
|
||||
queue.mu.Lock()
|
||||
rCalls := queue.requeueCalls
|
||||
sCalls := queue.statsCalls
|
||||
queue.mu.Unlock()
|
||||
registry.mu.Lock()
|
||||
mCalls := registry.markStaleCalls
|
||||
registry.mu.Unlock()
|
||||
|
||||
if rCalls >= 1 && sCalls >= 1 && mCalls >= 1 {
|
||||
break
|
||||
}
|
||||
|
||||
select {
|
||||
case <-deadline:
|
||||
t.Fatalf("timed out waiting for maintenance to run: requeue=%d stats=%d markStale=%d", rCalls, sCalls, mCalls)
|
||||
default:
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
|
||||
m.Stop()
|
||||
}
|
||||
|
||||
func TestQueueMaintenance_NilConfig(t *testing.T) {
|
||||
queue := newMockMaintenanceQueue()
|
||||
registry := newMockMaintenanceRegistry()
|
||||
|
||||
m := NewQueueMaintenance(queue, registry, nil)
|
||||
if m.staleTaskTimeout != 30*time.Minute {
|
||||
t.Errorf("expected default stale task timeout, got %v", m.staleTaskTimeout)
|
||||
}
|
||||
if m.metricsPeriod != 15*time.Second {
|
||||
t.Errorf("expected default metrics period, got %v", m.metricsPeriod)
|
||||
}
|
||||
}
|
||||
289
internal/worker/work_executor.go
Normal file
289
internal/worker/work_executor.go
Normal file
@ -0,0 +1,289 @@
|
||||
// Package worker provides background workers for async task processing.
|
||||
package worker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/orchard9/rdev/internal/domain"
|
||||
"github.com/orchard9/rdev/internal/metrics"
|
||||
"github.com/orchard9/rdev/internal/service"
|
||||
)
|
||||
|
||||
// WorkExecutor is a background daemon that polls the work queue for tasks
|
||||
// and executes them via task-type-specific handlers. It self-registers as
|
||||
// a worker, sends heartbeats, and reports results.
|
||||
type WorkExecutor struct {
|
||||
workerSvc *service.WorkerService
|
||||
workSvc *service.WorkService
|
||||
buildExec *BuildExecutor
|
||||
logger *slog.Logger
|
||||
|
||||
workerID string
|
||||
hostname string
|
||||
version string
|
||||
capabilities []string
|
||||
pollPeriod time.Duration
|
||||
hbPeriod time.Duration
|
||||
taskTimeout time.Duration
|
||||
|
||||
started int32 // atomic flag to prevent double-start
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
// WorkExecutorConfig holds configuration for the work executor.
|
||||
type WorkExecutorConfig struct {
|
||||
// WorkerID uniquely identifies this executor instance.
|
||||
// Defaults to HOSTNAME env var or "rdev-worker-0".
|
||||
WorkerID string
|
||||
|
||||
// Version reported to the worker registry.
|
||||
Version string
|
||||
|
||||
// Capabilities reported to the worker registry.
|
||||
Capabilities []string
|
||||
|
||||
// PollPeriod is how often to check for new tasks.
|
||||
PollPeriod time.Duration
|
||||
|
||||
// HeartbeatPeriod is how often to send heartbeats.
|
||||
HeartbeatPeriod time.Duration
|
||||
|
||||
// TaskTimeout is the maximum time a single task may run.
|
||||
// Default: 15 minutes.
|
||||
TaskTimeout time.Duration
|
||||
|
||||
Logger *slog.Logger
|
||||
}
|
||||
|
||||
// DefaultWorkExecutorConfig returns sensible defaults.
|
||||
func DefaultWorkExecutorConfig() *WorkExecutorConfig {
|
||||
workerID := os.Getenv("HOSTNAME")
|
||||
if workerID == "" {
|
||||
workerID = "rdev-worker-0"
|
||||
}
|
||||
return &WorkExecutorConfig{
|
||||
WorkerID: workerID,
|
||||
Capabilities: []string{"build"},
|
||||
PollPeriod: 5 * time.Second,
|
||||
HeartbeatPeriod: 30 * time.Second,
|
||||
TaskTimeout: 15 * time.Minute,
|
||||
Logger: slog.Default(),
|
||||
}
|
||||
}
|
||||
|
||||
// NewWorkExecutor creates a new work executor daemon.
|
||||
func NewWorkExecutor(
|
||||
workerSvc *service.WorkerService,
|
||||
workSvc *service.WorkService,
|
||||
buildExec *BuildExecutor,
|
||||
cfg *WorkExecutorConfig,
|
||||
) *WorkExecutor {
|
||||
if cfg == nil {
|
||||
cfg = DefaultWorkExecutorConfig()
|
||||
}
|
||||
|
||||
hostname, _ := os.Hostname()
|
||||
if hostname == "" {
|
||||
hostname = cfg.WorkerID
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
capabilities := cfg.Capabilities
|
||||
if len(capabilities) == 0 {
|
||||
capabilities = []string{"build"}
|
||||
}
|
||||
|
||||
taskTimeout := cfg.TaskTimeout
|
||||
if taskTimeout == 0 {
|
||||
taskTimeout = 15 * time.Minute
|
||||
}
|
||||
|
||||
return &WorkExecutor{
|
||||
workerSvc: workerSvc,
|
||||
workSvc: workSvc,
|
||||
buildExec: buildExec,
|
||||
logger: cfg.Logger.With("component", "work-executor"),
|
||||
workerID: cfg.WorkerID,
|
||||
hostname: hostname,
|
||||
version: cfg.Version,
|
||||
capabilities: capabilities,
|
||||
pollPeriod: cfg.PollPeriod,
|
||||
hbPeriod: cfg.HeartbeatPeriod,
|
||||
taskTimeout: taskTimeout,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
}
|
||||
|
||||
// Start registers the worker and begins the poll and heartbeat loops.
|
||||
func (e *WorkExecutor) Start() error {
|
||||
if !atomic.CompareAndSwapInt32(&e.started, 0, 1) {
|
||||
return fmt.Errorf("executor already started")
|
||||
}
|
||||
|
||||
// Register this worker in the pool
|
||||
worker := &domain.Worker{
|
||||
ID: e.workerID,
|
||||
Hostname: e.hostname,
|
||||
Capabilities: e.capabilities,
|
||||
Version: e.version,
|
||||
}
|
||||
if err := e.workerSvc.Register(e.ctx, worker); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
e.logger.Info("work executor started",
|
||||
"worker_id", e.workerID,
|
||||
"poll_period", e.pollPeriod,
|
||||
"heartbeat_period", e.hbPeriod,
|
||||
)
|
||||
|
||||
// Start heartbeat loop
|
||||
e.wg.Add(1)
|
||||
go e.heartbeatLoop()
|
||||
|
||||
// Start poll loop
|
||||
e.wg.Add(1)
|
||||
go e.pollLoop()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop gracefully shuts down the executor.
|
||||
func (e *WorkExecutor) Stop() {
|
||||
e.logger.Info("work executor stopping", "worker_id", e.workerID)
|
||||
e.cancel()
|
||||
e.wg.Wait()
|
||||
|
||||
// Deregister (best-effort, context is cancelled so use a fresh one)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if err := e.workerSvc.Deregister(ctx, e.workerID); err != nil {
|
||||
e.logger.Warn("failed to deregister worker", "error", err)
|
||||
}
|
||||
|
||||
e.logger.Info("work executor stopped", "worker_id", e.workerID)
|
||||
}
|
||||
|
||||
// WorkerID returns the executor's worker ID.
|
||||
func (e *WorkExecutor) WorkerID() string {
|
||||
return e.workerID
|
||||
}
|
||||
|
||||
// Running returns true if the executor context has not been cancelled.
|
||||
func (e *WorkExecutor) Running() bool {
|
||||
return e.ctx.Err() == nil
|
||||
}
|
||||
|
||||
// heartbeatLoop sends periodic heartbeats to the worker registry.
|
||||
func (e *WorkExecutor) heartbeatLoop() {
|
||||
defer e.wg.Done()
|
||||
|
||||
ticker := time.NewTicker(e.hbPeriod)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-e.ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := e.workerSvc.Heartbeat(e.ctx, e.workerID); err != nil {
|
||||
e.logger.Warn("heartbeat failed", "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// pollLoop checks for available tasks on a ticker.
|
||||
func (e *WorkExecutor) pollLoop() {
|
||||
defer e.wg.Done()
|
||||
|
||||
ticker := time.NewTicker(e.pollPeriod)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-e.ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
e.tryClaimAndExecute()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// tryClaimAndExecute attempts to claim a task and execute it.
|
||||
func (e *WorkExecutor) tryClaimAndExecute() {
|
||||
task, err := e.workerSvc.ClaimTask(e.ctx, e.workerID)
|
||||
if err != nil {
|
||||
e.logger.Warn("failed to claim task", "error", err)
|
||||
return
|
||||
}
|
||||
if task == nil {
|
||||
return // No tasks available
|
||||
}
|
||||
|
||||
e.logger.Info("executing task",
|
||||
"task_id", task.ID,
|
||||
"project_id", task.ProjectID,
|
||||
"type", task.Type,
|
||||
)
|
||||
|
||||
taskCtx, taskCancel := context.WithTimeout(e.ctx, e.taskTimeout)
|
||||
defer taskCancel()
|
||||
|
||||
result := e.executeTask(taskCtx, task)
|
||||
|
||||
// Record build metrics
|
||||
status := "success"
|
||||
if !result.Success {
|
||||
status = "failed"
|
||||
}
|
||||
metrics.RecordBuild(task.ProjectID, status, result.DurationMs)
|
||||
|
||||
if result.Success {
|
||||
if err := e.workerSvc.CompleteTask(e.ctx, e.workerID, task.ID, result); err != nil {
|
||||
e.logger.Error("failed to complete task",
|
||||
"task_id", task.ID,
|
||||
"error", err,
|
||||
)
|
||||
}
|
||||
} else {
|
||||
// Fail the task through work service (handles retry logic)
|
||||
errMsg := result.Error
|
||||
if errMsg == "" {
|
||||
errMsg = "execution failed"
|
||||
}
|
||||
if err := e.workSvc.FailTask(e.ctx, task.ID, errMsg); err != nil {
|
||||
e.logger.Error("failed to record task failure",
|
||||
"task_id", task.ID,
|
||||
"error", err,
|
||||
)
|
||||
}
|
||||
// Return worker to idle regardless
|
||||
if err := e.workerSvc.Heartbeat(e.ctx, e.workerID); err != nil {
|
||||
e.logger.Warn("failed to heartbeat after failure", "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// executeTask routes a task to the appropriate handler based on its type.
|
||||
func (e *WorkExecutor) executeTask(ctx context.Context, task *domain.WorkTask) *domain.BuildResult {
|
||||
switch task.Type {
|
||||
case domain.WorkTaskTypeBuild:
|
||||
return e.buildExec.Execute(ctx, task)
|
||||
default:
|
||||
return &domain.BuildResult{
|
||||
Success: false,
|
||||
Error: "unsupported task type: " + string(task.Type),
|
||||
}
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user