package auth

import (
	"context"
	"errors"
	"net/http"
	"strings"

	"github.com/orchard9/rdev/pkg/api"
)

// HeaderAPIKey is the request header from which Middleware reads the API key.
const HeaderAPIKey = "X-API-Key"

// contextKey is an unexported type for context keys set by this package, so
// they cannot collide with context keys defined in other packages.
type contextKey string

// Context keys.
const (
	contextKeyAPIKey contextKey = "api_key"
)

// GetAPIKey returns the *APIKey that Middleware stored in ctx, or nil if no
// key is present (e.g. the request was served on one of Middleware's public
// paths, or Middleware did not run).
func GetAPIKey(ctx context.Context) *APIKey {
	// Comma-ok form: a missing value (or one of another type) yields nil
	// instead of panicking.
	key, _ := ctx.Value(contextKeyAPIKey).(*APIKey)
	return key
}

// Middleware returns HTTP middleware that authenticates requests with an API
// key. Requests to /health, /ready, /docs and /openapi.json skip
// authentication entirely.
//
// The key is taken from the X-API-Key header or, when that header is empty,
// from an "Authorization: Bearer <key>" header (scheme matched
// case-insensitively). It is checked with svc.Validate; on success the
// returned *APIKey is stored in the request context (see GetAPIKey) and next is
// called. Otherwise an error response is written and next is not called:
//
//   - no key supplied               -> 401 UNAUTHORIZED
//   - ErrKeyNotFound                -> 401 UNAUTHORIZED
//   - ErrKeyRevoked                 -> 401 KEY_REVOKED
//   - ErrKeyExpired                 -> 401 KEY_EXPIRED
//   - any other error from Validate -> 500 AUTH_ERROR
func Middleware(svc *Service) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			// Health probes and API documentation are public.
			switch r.URL.Path {
			case "/health", "/ready", "/docs", "/openapi.json":
				next.ServeHTTP(w, r)
				return
			}

			key := r.Header.Get(HeaderAPIKey)
			if key == "" {
				// Fall back to "Authorization: Bearer <key>". Auth-scheme names
				// are case-insensitive (RFC 7235 §2.1), so "bearer <key>" and
				// "BEARER <key>" are accepted as well.
				const bearer = "Bearer "
				auth := r.Header.Get("Authorization")
				if len(auth) >= len(bearer) && strings.EqualFold(auth[:len(bearer)], bearer) {
					key = strings.TrimSpace(auth[len(bearer):])
				}
			}
			if key == "" {
				api.WriteError(w, r, http.StatusUnauthorized, "UNAUTHORIZED", "Missing API key")
				return
			}

			apiKey, err := svc.Validate(r.Context(), key)
			if err != nil {
				switch {
				case errors.Is(err, ErrKeyNotFound):
					api.WriteError(w, r, http.StatusUnauthorized, "UNAUTHORIZED", "Invalid API key")
				case errors.Is(err, ErrKeyRevoked):
					api.WriteError(w, r, http.StatusUnauthorized, "KEY_REVOKED", "API key has been revoked")
				case errors.Is(err, ErrKeyExpired):
					api.WriteError(w, r, http.StatusUnauthorized, "KEY_EXPIRED", "API key has expired")
				default:
					// Any other Validate error: the cause is not sent to the
					// client. NOTE(review): it is not logged here either —
					// confirm it is recorded somewhere (e.g. inside Validate).
					api.WriteError(w, r, http.StatusInternalServerError, "AUTH_ERROR", "Authentication failed")
				}
				return
			}

			// Make the validated key available to downstream handlers.
			ctx := context.WithValue(r.Context(), contextKeyAPIKey, apiKey)
			next.ServeHTTP(w, r.WithContext(ctx))
		})
	}
}

// RequireScope returns middleware that lets a request through only when the
// API key stored by Middleware has at least one of the required scopes, as
// decided by HasAnyScope. It reads the key via GetAPIKey, so it must run after
// Middleware: a request without a key gets 401 UNAUTHORIZED, and a key that
// lacks the scopes gets 403 FORBIDDEN naming the required scopes.
func RequireScope(required ...Scope) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			apiKey := GetAPIKey(r.Context())
			if apiKey == nil {
				api.WriteError(w, r, http.StatusUnauthorized, "UNAUTHORIZED", "Not authenticated")
				return
			}

			if !HasAnyScope(apiKey.Scopes, required...) {
				api.WriteError(w, r, http.StatusForbidden, "FORBIDDEN", "Insufficient permissions. Required: "+scopesToString(required))
				return
			}

			next.ServeHTTP(w, r)
		})
	}
}

// RequireProjectAccess returns middleware intended to restrict a request to
// the project named by the URL parameter projectIDParam.
//
// SECURITY: project restrictions are NOT enforced yet. projectIDParam is
// unused; the middleware only rejects requests without an API key in context
// (401 UNAUTHORIZED) and otherwise passes every authenticated key through,
// whether or not it has ScopeAdmin. Handlers must perform their own
// project-level checks until this is implemented.
func RequireProjectAccess(projectIDParam string) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			apiKey := GetAPIKey(r.Context())
			if apiKey == nil {
				api.WriteError(w, r, http.StatusUnauthorized, "UNAUTHORIZED", "Not authenticated")
				return
			}

			// Admin has access to everything.
			if HasScope(apiKey.Scopes, ScopeAdmin) {
				next.ServeHTTP(w, r)
				return
			}

			// TODO: read the project ID from projectIDParam (chi's URLParam
			// would require importing chi here) and reject keys that are not
			// allowed to access that project. Until then, non-admin keys are
			// let through unchecked.
			next.ServeHTTP(w, r)
		})
	}
}

// scopesToString joins scopes into a comma-separated list for error messages.
func scopesToString(scopes []Scope) string {
	ss := make([]string, len(scopes))
	for i, s := range scopes {
		ss[i] = string(s)
	}
	return strings.Join(ss, ", ")
}