rdev/internal/auth/middleware.go
jordan 4f01015132
All checks were successful
ci/woodpecker/push/woodpecker Pipeline was successful
feat: implement project access enforcement and management API
- Fix no-op RequireProjectAccess middleware to enforce project_ids
- Apply project access middleware to all project-scoped routes
- Filter GET /projects by allowed project IDs for restricted keys
- Add GET /me endpoint with key identity, scopes, and project access info
- Add PATCH /keys/{id} for partial key updates (name, scopes, project_ids, allowed_ips, expires_in)
- Add GET/POST/DELETE /projects/{id}/access for project-centric access management
- Auto-grant creating key access when using POST /project/create-and-build
- Accept grant_to_key_ids in create-and-build to grant multiple keys on project creation
- Move newProvisionerWithDeps test helper from production code to test file

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-02-21 15:38:37 -07:00

196 lines
5.1 KiB
Go

package auth
import (
	"context"
	"errors"
	"net"
	"net/http"
	"strings"

	"github.com/go-chi/chi/v5"
	"github.com/orchard9/rdev/internal/domain"
	"github.com/orchard9/rdev/pkg/api"
)
// getClientIP extracts the client IP address from the request.
//
// Resolution order:
//  1. The first (left-most) entry of the X-Forwarded-For header, trimmed.
//  2. The X-Real-IP header, trimmed.
//  3. The host part of r.RemoteAddr ("host:port" or "[ipv6]:port"); when
//     RemoteAddr carries no port, the raw value with IPv6 brackets removed.
//
// NOTE(review): X-Forwarded-For and X-Real-IP are set by the client unless a
// trusted reverse proxy overwrites them. Because the returned value is used
// for the per-key IP allowlist, a client that can reach this server directly
// can spoof an allowed address. Confirm the deployment strips/overwrites these
// headers, or restrict header trust to known proxy addresses.
func getClientIP(r *http.Request) string {
	if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
		// The header is a comma-separated hop list; the first hop is the
		// originating client as reported by the first proxy.
		first, _, _ := strings.Cut(xff, ",")
		return strings.TrimSpace(first)
	}
	if xri := r.Header.Get("X-Real-IP"); xri != "" {
		return strings.TrimSpace(xri)
	}
	// SplitHostPort handles both "1.2.3.4:80" and "[::1]:80" correctly, unlike
	// a naive split on the last ':' which truncates port-less IPv6 addresses.
	if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
		return host
	}
	// No port (or otherwise unparsable): return the address itself, minus any
	// IPv6 brackets.
	return strings.Trim(r.RemoteAddr, "[]")
}
// HeaderAPIKey is the request header from which Middleware reads the API key.
const HeaderAPIKey = "X-API-Key"

// contextKey is an unexported type for context keys owned by this package;
// a distinct type keeps them from colliding with keys set by other packages.
type contextKey string

// contextKeyAPIKey is the context key under which the authenticated *APIKey
// is stored.
const contextKeyAPIKey contextKey = "api_key"
// GetAPIKey returns the *APIKey stored in ctx under this package's API-key
// context key, or nil when no value is present or the stored value is not an
// *APIKey. Handlers behind Middleware can use it to identify the caller.
func GetAPIKey(ctx context.Context) *APIKey {
	if apiKey, ok := ctx.Value(contextKeyAPIKey).(*APIKey); ok {
		return apiKey
	}
	return nil
}
// WithAPIKey returns a derived context carrying apiKey, so that GetAPIKey on
// the returned context yields it. Middleware performs the same injection for
// authenticated requests; this helper lets tests build such a context directly.
func WithAPIKey(ctx context.Context, apiKey *APIKey) context.Context {
	withKey := context.WithValue(ctx, contextKeyAPIKey, apiKey)
	return withKey
}
// Middleware returns HTTP middleware that authenticates requests by API key.
//
// Requests to /health, /ready, /docs, /openapi.json and /metrics bypass
// authentication entirely. Every other request must supply a key, either in
// the X-API-Key header or as "Authorization: Bearer <key>". X-API-Key wins
// when both are present. The key is checked with svc.Validate and then
// against the key's IP allowlist (apiKey.IsIPAllowed on getClientIP(r)). On
// success the *APIKey is stored in the request context, where GetAPIKey
// retrieves it.
//
// Rejections:
//   - api.WriteUnauthorized: missing key, or svc.Validate returned ErrKeyNotFound.
//   - 401 KEY_REVOKED / KEY_EXPIRED: ErrKeyRevoked / ErrKeyExpired.
//   - 500 AUTH_ERROR: any other validation error.
//   - 403 IP_NOT_ALLOWED: client IP rejected by the key's allowlist.
func Middleware(svc *Service) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			// Unauthenticated infrastructure endpoints: health probes, API
			// docs and metrics.
			switch r.URL.Path {
			case "/health", "/ready", "/docs", "/openapi.json", "/metrics":
				next.ServeHTTP(w, r)
				return
			}

			key := r.Header.Get(HeaderAPIKey)
			if key == "" {
				// Fall back to "Authorization: Bearer <key>". The auth scheme
				// is case-insensitive (RFC 7235 §2.1), so "bearer" is accepted
				// too. Surrounding whitespace around the token is dropped.
				scheme, token, ok := strings.Cut(r.Header.Get("Authorization"), " ")
				if ok && strings.EqualFold(scheme, "Bearer") {
					key = strings.TrimSpace(token)
				}
			}
			if key == "" {
				api.WriteUnauthorized(w, r, "Missing API key")
				return
			}

			apiKey, err := svc.Validate(r.Context(), key)
			if err != nil {
				switch {
				case errors.Is(err, ErrKeyNotFound):
					api.WriteUnauthorized(w, r, "Invalid API key")
				case errors.Is(err, ErrKeyRevoked):
					api.WriteError(w, r, http.StatusUnauthorized, "KEY_REVOKED", "API key has been revoked")
				case errors.Is(err, ErrKeyExpired):
					api.WriteError(w, r, http.StatusUnauthorized, "KEY_EXPIRED", "API key has expired")
				default:
					api.WriteError(w, r, http.StatusInternalServerError, "AUTH_ERROR", "Authentication failed")
				}
				return
			}

			// Enforce the key's IP allowlist; see getClientIP for how the
			// client address is derived.
			if !apiKey.IsIPAllowed(getClientIP(r)) {
				api.WriteError(w, r, http.StatusForbidden, "IP_NOT_ALLOWED", "IP address not allowed for this API key")
				return
			}

			ctx := context.WithValue(r.Context(), contextKeyAPIKey, apiKey)
			next.ServeHTTP(w, r.WithContext(ctx))
		})
	}
}
// RequireScope returns middleware that admits a request only when the
// authenticated key passes HasAnyScope(key.Scopes, required...). Presumably
// this means the key holds at least one of the required scopes; confirm
// against HasAnyScope.
//
// It reads the key via GetAPIKey, so it must run after Middleware. Without an
// authenticated key the request is rejected as "Not authenticated". A key
// lacking the scopes gets a forbidden response listing the required scopes.
func RequireScope(required ...Scope) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		check := func(w http.ResponseWriter, r *http.Request) {
			key := GetAPIKey(r.Context())
			switch {
			case key == nil:
				api.WriteUnauthorized(w, r, "Not authenticated")
			case !HasAnyScope(key.Scopes, required...):
				api.WriteForbidden(w, r, "Insufficient permissions. Required: "+scopesToString(required))
			default:
				next.ServeHTTP(w, r)
			}
		}
		return http.HandlerFunc(check)
	}
}
// RequireProjectAccess returns middleware that admits a request only when the
// authenticated key may access the project named in the chi URL parameter
// projectIDParam.
//
// Keys with ScopeAdmin are admitted without consulting the URL. Every other
// key is checked with HasProjectAccess. The middleware reads the key via
// GetAPIKey, so it must run after Middleware. A request without an
// authenticated key is rejected as "Not authenticated".
//
// NOTE(review): chi.URLParam returns "" when projectIDParam does not match a
// route parameter. Access is then evaluated for an empty project ID. Confirm
// HasProjectAccess rejects "" so a mis-wired route fails closed.
func RequireProjectAccess(projectIDParam string) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			key := GetAPIKey(r.Context())
			switch {
			case key == nil:
				api.WriteUnauthorized(w, r, "Not authenticated")
			// Short-circuit: admins never reach the URL-param lookup.
			case HasScope(key.Scopes, ScopeAdmin) ||
				key.HasProjectAccess(domain.ProjectID(chi.URLParam(r, projectIDParam))):
				next.ServeHTTP(w, r)
			default:
				api.WriteForbidden(w, r, "Access denied to this project")
			}
		})
	}
}
// scopesToString renders scopes as a ", "-separated list (e.g. "read, write").
// It is used in permission-denied messages and returns "" for no scopes.
func scopesToString(scopes []Scope) string {
	var b strings.Builder
	for idx, scope := range scopes {
		if idx > 0 {
			b.WriteString(", ")
		}
		b.WriteString(string(scope))
	}
	return b.String()
}