//! API Key authentication middleware (P4.2 Authentication). //! //! This middleware enforces API key authentication for protected endpoints. //! By default, only `/v1/admin/*` endpoints require authentication, but this //! can be configured to require authentication for all endpoints. //! //! # Request Flow //! //! 1. Check if path is public (health, swagger, etc.) - skip auth //! 2. Check if path requires auth based on config //! 3. Extract `X-API-Key` header //! 4. Validate key against store (exists, enabled, not expired) //! 5. Check role permissions for method + path //! 6. Check per-key rate limit //! 7. Inject `ApiKeyExtension` for downstream handlers //! 8. Return 401/403/429 on failure //! //! # Headers //! //! | Header | Direction | Description | //! |--------|-----------|-------------| //! | `X-API-Key` | Request | The API key (e.g., `steme_live_abc123...`) | //! | `X-Rate-Limit-Remaining` | Response | Remaining requests in current window | //! | `X-Rate-Limit-Reset` | Response | Unix timestamp when window resets | use axum::{ body::Body, http::{Method, Request, Response, StatusCode}, response::IntoResponse, Json, }; use futures::future::BoxFuture; use serde::Serialize; use std::sync::Arc; use std::task::{Context, Poll}; use stemedb_storage::{ApiKeyRecord, ApiKeyRole, ApiKeyStore, DEFAULT_API_KEY_RATE_LIMIT}; use tower::{Layer, Service}; use tracing::{debug, info, warn}; /// Header name for API key. pub const API_KEY_HEADER: &str = "x-api-key"; /// Header name for rate limit remaining. pub const RATE_LIMIT_REMAINING_HEADER: &str = "x-rate-limit-remaining"; /// Header name for rate limit reset timestamp. pub const RATE_LIMIT_RESET_HEADER: &str = "x-rate-limit-reset"; /// Configuration for API key authentication. #[derive(Debug, Clone)] pub struct ApiKeyAuthConfig { /// Master switch: when false, all endpoints are open (local dev mode). /// When true, authentication is enforced per the rules below. pub enabled: bool, /// Require API key for all endpoints (not just admin). pub require_for_all: bool, /// Paths that never require authentication. pub public_paths: Vec, } impl Default for ApiKeyAuthConfig { fn default() -> Self { Self { enabled: false, // Open mode by default (local dev) require_for_all: false, public_paths: vec![ "/health".to_string(), "/v1/health".to_string(), "/swagger-ui".to_string(), "/api-docs".to_string(), "/metrics".to_string(), ], } } } /// Request extension containing authenticated API key info. /// /// This is injected into the request extensions after successful authentication, /// allowing downstream handlers to access key information. #[derive(Debug, Clone)] pub struct ApiKeyExtension { /// BLAKE3 hash of the key. pub key_hash: [u8; 32], /// Access role for this key. pub role: ApiKeyRole, /// Human-readable label for this key. pub label: String, } /// Error response for authentication failures. #[derive(Debug, Serialize)] struct AuthError { error: String, code: String, } /// Tower Layer for API key authentication. #[derive(Clone)] pub struct ApiKeyAuthLayer { api_key_store: Arc, config: ApiKeyAuthConfig, } impl ApiKeyAuthLayer { /// Create a new ApiKeyAuthLayer with default configuration. /// /// Default: Only `/v1/admin/*` endpoints require authentication. pub fn new(api_key_store: Arc) -> Self { Self { api_key_store, config: ApiKeyAuthConfig::default() } } /// Create a new ApiKeyAuthLayer with custom configuration. pub fn with_config(api_key_store: Arc, config: ApiKeyAuthConfig) -> Self { Self { api_key_store, config } } /// Configure to require API key for all endpoints. pub fn require_for_all(mut self) -> Self { self.config.require_for_all = true; self } /// Add a public path that doesn't require authentication. pub fn public_path(mut self, path: impl Into) -> Self { self.config.public_paths.push(path.into()); self } } impl Layer for ApiKeyAuthLayer where A: Clone, { type Service = ApiKeyAuthService; fn layer(&self, inner: S) -> Self::Service { ApiKeyAuthService { inner, api_key_store: Arc::clone(&self.api_key_store), config: self.config.clone(), } } } /// Tower Service for API key authentication. #[derive(Clone)] pub struct ApiKeyAuthService { inner: S, api_key_store: Arc, config: ApiKeyAuthConfig, } impl ApiKeyAuthService { /// Check if a path is public (never requires auth). fn is_public_path(&self, path: &str) -> bool { self.config.public_paths.iter().any(|p| path.starts_with(p)) } /// Check if a path requires authentication. fn requires_auth(&self, path: &str) -> bool { // Open mode: no auth required for any endpoint (local dev) if !self.config.enabled { return false; } if self.is_public_path(path) { return false; } if self.config.require_for_all { return true; } // By default, only admin endpoints require auth path.starts_with("/v1/admin") } /// Check if the role can access the given method + path. fn check_permission(role: ApiKeyRole, method: &Method, path: &str) -> bool { // Admin endpoints require Admin role if path.starts_with("/v1/admin") { return role.can_admin(); } // Write endpoints require WriteAgent or Admin role if method == Method::POST || method == Method::PUT || method == Method::DELETE { // These specific paths are write operations if path.starts_with("/v1/assert") || path.starts_with("/v1/vote") || path.starts_with("/v1/supersede") { return role.can_write(); } } // Read operations are allowed for all authenticated users true } /// Extract API key from request headers. fn extract_api_key(req: &Request) -> Option { req.headers().get(API_KEY_HEADER).and_then(|v| v.to_str().ok()).map(|s| s.to_string()) } /// Hash an API key using BLAKE3. fn hash_api_key(raw_key: &str) -> [u8; 32] { *blake3::hash(raw_key.as_bytes()).as_bytes() } } impl Service> for ApiKeyAuthService where S: Service, Response = Response> + Clone + Send + 'static, S::Future: Send, A: ApiKeyStore + 'static, { type Response = Response; type Error = S::Error; type Future = BoxFuture<'static, Result>; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { self.inner.poll_ready(cx) } fn call(&mut self, mut req: Request) -> Self::Future { let path = req.uri().path().to_string(); let method = req.method().clone(); let api_key_store = Arc::clone(&self.api_key_store); // Check if auth is required let requires_auth = self.requires_auth(&path); // Clone the inner service for the async block let mut inner = self.inner.clone(); Box::pin(async move { // Skip auth for public paths or when not required if !requires_auth { debug!(path = %path, "Skipping API key auth for path"); return inner.call(req).await; } // Extract API key let raw_key = match Self::extract_api_key(&req) { Some(key) => key, None => { warn!(path = %path, "Missing API key"); let error = AuthError { error: "Missing API key".to_string(), code: "UNAUTHORIZED".to_string(), }; return Ok((StatusCode::UNAUTHORIZED, Json(error)).into_response()); } }; // Validate key format (basic check) if !raw_key.starts_with("steme_") { warn!(path = %path, "Invalid API key format"); let error = AuthError { error: "Invalid API key format".to_string(), code: "UNAUTHORIZED".to_string(), }; return Ok((StatusCode::UNAUTHORIZED, Json(error)).into_response()); } // Hash the key let key_hash = Self::hash_api_key(&raw_key); // Get current timestamp let now = std::time::SystemTime::now() .duration_since(std::time::UNIX_EPOCH) .map(|d| d.as_secs()) .unwrap_or(0); // Validate key against store let record: ApiKeyRecord = match api_key_store.validate_key(&key_hash, now).await { Ok(Some(r)) => r, Ok(None) => { warn!(path = %path, key_hash = %hex::encode(&key_hash[..8]), "Invalid or expired API key"); let error = AuthError { error: "Invalid or expired API key".to_string(), code: "UNAUTHORIZED".to_string(), }; return Ok((StatusCode::UNAUTHORIZED, Json(error)).into_response()); } Err(e) => { warn!(path = %path, error = %e, "API key validation failed"); // Fail closed for security let error = AuthError { error: "Authentication service unavailable".to_string(), code: "UNAUTHORIZED".to_string(), }; return Ok((StatusCode::UNAUTHORIZED, Json(error)).into_response()); } }; // Check role permissions if !Self::check_permission(record.role, &method, &path) { warn!( path = %path, method = %method, role = %record.role, label = %record.label, "Insufficient permissions" ); let error = AuthError { error: format!( "Insufficient permissions. Role '{}' cannot access {} {}", record.role, method, path ), code: "FORBIDDEN".to_string(), }; return Ok((StatusCode::FORBIDDEN, Json(error)).into_response()); } // Check per-key rate limit let rate_limit = record.rate_limit.unwrap_or(DEFAULT_API_KEY_RATE_LIMIT); let rate_result = match api_key_store.check_rate_limit(&key_hash, rate_limit, now).await { Ok(r) => r, Err(e) => { warn!(error = %e, "Rate limit check failed, allowing request"); // Fail open for rate limiting (availability over strictness) stemedb_storage::RateLimitResult { allowed: true, remaining: rate_limit, limit: rate_limit, reset_at: now + 3600, } } }; if !rate_result.allowed { warn!( path = %path, label = %record.label, "API key rate limited" ); let error = AuthError { error: format!( "Rate limit exceeded. Limit: {} requests/hour. Resets at {}", rate_result.limit, rate_result.reset_at ), code: "RATE_LIMITED".to_string(), }; let mut response = (StatusCode::TOO_MANY_REQUESTS, Json(error)).into_response(); // Add rate limit headers let headers = response.headers_mut(); if let Ok(v) = rate_result.remaining.to_string().parse() { headers.insert(RATE_LIMIT_REMAINING_HEADER, v); } if let Ok(v) = rate_result.reset_at.to_string().parse() { headers.insert(RATE_LIMIT_RESET_HEADER, v); } return Ok(response); } // Update last_used_at (fire and forget - don't block on this) let touch_store = Arc::clone(&api_key_store); let touch_hash = key_hash; tokio::spawn(async move { if let Err(e) = touch_store.touch_key(&touch_hash, now).await { debug!(error = %e, "Failed to update API key last_used_at"); } }); // Inject extension for downstream handlers let extension = ApiKeyExtension { key_hash, role: record.role, label: record.label.clone() }; req.extensions_mut().insert(extension); info!( path = %path, label = %record.label, role = %record.role, "API key authenticated" ); // Call inner service let mut response = inner.call(req).await?; // Add rate limit headers to response let headers = response.headers_mut(); if let Ok(v) = rate_result.remaining.to_string().parse() { headers.insert(RATE_LIMIT_REMAINING_HEADER, v); } if let Ok(v) = rate_result.reset_at.to_string().parse() { headers.insert(RATE_LIMIT_RESET_HEADER, v); } Ok(response) }) } } #[cfg(test)] mod tests { use super::*; #[test] fn test_is_public_path() { let service = ApiKeyAuthService::<(), ()> { inner: (), api_key_store: Arc::new(()), config: ApiKeyAuthConfig { enabled: true, ..Default::default() }, }; assert!(service.is_public_path("/health")); assert!(service.is_public_path("/v1/health")); assert!(service.is_public_path("/swagger-ui/index.html")); assert!(service.is_public_path("/api-docs/openapi.json")); assert!(!service.is_public_path("/v1/admin/api-keys")); assert!(!service.is_public_path("/v1/assert")); } #[test] fn test_requires_auth_disabled() { let service = ApiKeyAuthService::<(), ()> { inner: (), api_key_store: Arc::new(()), config: ApiKeyAuthConfig::default(), // enabled: false }; // Everything is open when auth is disabled assert!(!service.requires_auth("/health")); assert!(!service.requires_auth("/swagger-ui")); assert!(!service.requires_auth("/v1/admin/api-keys")); assert!(!service.requires_auth("/v1/admin/quarantine")); assert!(!service.requires_auth("/v1/assert")); assert!(!service.requires_auth("/v1/query")); } #[test] fn test_requires_auth_enabled() { let service = ApiKeyAuthService::<(), ()> { inner: (), api_key_store: Arc::new(()), config: ApiKeyAuthConfig { enabled: true, ..Default::default() }, }; // Public paths don't require auth assert!(!service.requires_auth("/health")); assert!(!service.requires_auth("/swagger-ui")); // Admin paths require auth assert!(service.requires_auth("/v1/admin/api-keys")); assert!(service.requires_auth("/v1/admin/quarantine")); // Non-admin paths don't require auth by default assert!(!service.requires_auth("/v1/assert")); assert!(!service.requires_auth("/v1/query")); } #[test] fn test_requires_auth_all() { let service = ApiKeyAuthService::<(), ()> { inner: (), api_key_store: Arc::new(()), config: ApiKeyAuthConfig { enabled: true, require_for_all: true, ..Default::default() }, }; // Public paths still don't require auth assert!(!service.requires_auth("/health")); // All other paths require auth assert!(service.requires_auth("/v1/admin/api-keys")); assert!(service.requires_auth("/v1/assert")); assert!(service.requires_auth("/v1/query")); } #[test] fn test_check_permission_admin() { // Admin can do everything assert!(ApiKeyAuthService::<(), ()>::check_permission( ApiKeyRole::Admin, &Method::GET, "/v1/admin/api-keys" )); assert!(ApiKeyAuthService::<(), ()>::check_permission( ApiKeyRole::Admin, &Method::POST, "/v1/assert" )); assert!(ApiKeyAuthService::<(), ()>::check_permission( ApiKeyRole::Admin, &Method::GET, "/v1/query" )); } #[test] fn test_check_permission_write_agent() { // WriteAgent can write but not admin assert!(!ApiKeyAuthService::<(), ()>::check_permission( ApiKeyRole::WriteAgent, &Method::GET, "/v1/admin/api-keys" )); assert!(ApiKeyAuthService::<(), ()>::check_permission( ApiKeyRole::WriteAgent, &Method::POST, "/v1/assert" )); assert!(ApiKeyAuthService::<(), ()>::check_permission( ApiKeyRole::WriteAgent, &Method::GET, "/v1/query" )); } #[test] fn test_check_permission_read_only() { // ReadOnly can only read assert!(!ApiKeyAuthService::<(), ()>::check_permission( ApiKeyRole::ReadOnly, &Method::GET, "/v1/admin/api-keys" )); assert!(!ApiKeyAuthService::<(), ()>::check_permission( ApiKeyRole::ReadOnly, &Method::POST, "/v1/assert" )); assert!(ApiKeyAuthService::<(), ()>::check_permission( ApiKeyRole::ReadOnly, &Method::GET, "/v1/query" )); } #[test] fn test_hash_api_key() { let key = "steme_test_abcdef123456"; let hash = ApiKeyAuthService::<(), ()>::hash_api_key(key); // Hash should be deterministic let hash2 = ApiKeyAuthService::<(), ()>::hash_api_key(key); assert_eq!(hash, hash2); // Different keys should have different hashes let other_key = "steme_test_different"; let other_hash = ApiKeyAuthService::<(), ()>::hash_api_key(other_key); assert_ne!(hash, other_hash); } }