package auth import ( "context" "errors" "net/http" "strings" "github.com/orchard9/rdev/pkg/api" ) // getClientIP extracts the client IP from the request. func getClientIP(r *http.Request) string { // Check X-Forwarded-For header (set by proxies/load balancers) if xff := r.Header.Get("X-Forwarded-For"); xff != "" { // Take the first IP in the chain for i := 0; i < len(xff); i++ { if xff[i] == ',' { return strings.TrimSpace(xff[:i]) } } return strings.TrimSpace(xff) } // Check X-Real-IP header if xri := r.Header.Get("X-Real-IP"); xri != "" { return strings.TrimSpace(xri) } // Fall back to RemoteAddr // RemoteAddr is "IP:port", so strip the port addr := r.RemoteAddr // Handle IPv6 addresses like "[::1]:8080" if strings.HasPrefix(addr, "[") { if idx := strings.LastIndex(addr, "]:"); idx != -1 { return addr[1:idx] } return strings.Trim(addr, "[]") } // Handle IPv4 addresses like "192.168.1.1:8080" if idx := strings.LastIndex(addr, ":"); idx != -1 { return addr[:idx] } return addr } // Header for API key authentication. const HeaderAPIKey = "X-API-Key" // Context keys. type contextKey string const ( contextKeyAPIKey contextKey = "api_key" ) // GetAPIKey retrieves the authenticated API key from the request context. func GetAPIKey(ctx context.Context) *APIKey { key, _ := ctx.Value(contextKeyAPIKey).(*APIKey) return key } // WithAPIKey returns a context with the given API key set. // This is primarily useful for testing. func WithAPIKey(ctx context.Context, apiKey *APIKey) context.Context { return context.WithValue(ctx, contextKeyAPIKey, apiKey) } // Middleware creates an authentication middleware. func Middleware(svc *Service) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Skip auth for health endpoints if r.URL.Path == "/health" || r.URL.Path == "/ready" { next.ServeHTTP(w, r) return } // Skip auth for docs if r.URL.Path == "/docs" || r.URL.Path == "/openapi.json" { next.ServeHTTP(w, r) return } // Skip auth for metrics if r.URL.Path == "/metrics" { next.ServeHTTP(w, r) return } // Get key from header key := r.Header.Get(HeaderAPIKey) if key == "" { // Also check Authorization: Bearer auth := r.Header.Get("Authorization") if strings.HasPrefix(auth, "Bearer ") { key = strings.TrimPrefix(auth, "Bearer ") } } if key == "" { api.WriteError(w, r, http.StatusUnauthorized, "UNAUTHORIZED", "Missing API key") return } // Validate key apiKey, err := svc.Validate(r.Context(), key) if err != nil { if errors.Is(err, ErrKeyNotFound) { api.WriteError(w, r, http.StatusUnauthorized, "UNAUTHORIZED", "Invalid API key") return } if errors.Is(err, ErrKeyRevoked) { api.WriteError(w, r, http.StatusUnauthorized, "KEY_REVOKED", "API key has been revoked") return } if errors.Is(err, ErrKeyExpired) { api.WriteError(w, r, http.StatusUnauthorized, "KEY_EXPIRED", "API key has expired") return } api.WriteError(w, r, http.StatusInternalServerError, "AUTH_ERROR", "Authentication failed") return } // Check IP allowlist clientIP := getClientIP(r) if !apiKey.IsIPAllowed(clientIP) { api.WriteError(w, r, http.StatusForbidden, "IP_NOT_ALLOWED", "IP address not allowed for this API key") return } // Add key to context ctx := context.WithValue(r.Context(), contextKeyAPIKey, apiKey) next.ServeHTTP(w, r.WithContext(ctx)) }) } } // RequireScope creates a middleware that checks for required scopes. func RequireScope(required ...Scope) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { apiKey := GetAPIKey(r.Context()) if apiKey == nil { api.WriteError(w, r, http.StatusUnauthorized, "UNAUTHORIZED", "Not authenticated") return } if !HasAnyScope(apiKey.Scopes, required...) { api.WriteError(w, r, http.StatusForbidden, "FORBIDDEN", "Insufficient permissions. Required: "+scopesToString(required)) return } next.ServeHTTP(w, r) }) } } // RequireProjectAccess creates a middleware that checks project access. // projectIDParam is the URL parameter name containing the project ID. func RequireProjectAccess(projectIDParam string) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { apiKey := GetAPIKey(r.Context()) if apiKey == nil { api.WriteError(w, r, http.StatusUnauthorized, "UNAUTHORIZED", "Not authenticated") return } // Admin has access to everything if HasScope(apiKey.Scopes, ScopeAdmin) { next.ServeHTTP(w, r) return } // Get project ID from URL // Using chi's URLParam would require importing chi here // Instead, we'll extract from path in the handler // This middleware just validates the key has project restrictions next.ServeHTTP(w, r) }) } } func scopesToString(scopes []Scope) string { ss := make([]string, len(scopes)) for i, s := range scopes { ss[i] = string(s) } return strings.Join(ss, ", ") }