//! Custom axum extractors for the StemeDB API. use axum::{ async_trait, extract::FromRequestParts, http::{request::Parts, StatusCode}, response::{IntoResponse, Response}, }; use serde::de::DeserializeOwned; use std::fmt; /// Rejection type for QsQuery extraction failures. #[derive(Debug)] pub struct QsQueryRejection { message: String, } impl fmt::Display for QsQueryRejection { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "Failed to deserialize query string: {}", self.message) } } impl std::error::Error for QsQueryRejection {} impl IntoResponse for QsQueryRejection { fn into_response(self) -> Response { (StatusCode::BAD_REQUEST, self.message).into_response() } } /// Query string extractor that supports bracket notation (e.g., `?sources[]=value1&sources[]=value2`). /// /// This extractor uses `serde_qs` in **non-strict mode** to properly handle /// array parameters with bracket notation (both literal `[]` and URL-encoded `%5B%5D`), /// which is the standard format used by JavaScript's URLSearchParams and web browsers. /// /// # When to Use QsQuery vs Query /// /// **Use `QsQuery` when:** /// - Your request DTO contains `Vec` or `Option>` fields /// - The endpoint is called by the dashboard or JavaScript clients /// - You need bracket notation support: `?filters[]=a&filters[]=b` /// /// **Use standard `axum::extract::Query` when:** /// - All query parameters are scalars (String, usize, bool, `Option`, etc.) /// - No array/vector parameters needed /// - Simpler and lighter weight for non-array cases /// /// # Example /// /// ```rust,ignore /// use stemedb_api::extractors::QsQuery; /// use serde::Deserialize; /// /// #[derive(Deserialize)] /// struct MyRequest { /// sources: Option>, // Array parameter /// limit: usize, // Scalar parameter /// } /// /// // ✅ Correct - QsQuery handles both array and scalar params /// async fn handler(QsQuery(params): QsQuery) { /// // Dashboard sends: ?sources[]=rfc&sources[]=community&limit=10 /// // params.sources = Some(vec!["rfc", "community"]) /// // params.limit = 10 /// } /// /// // ❌ Wrong - standard Query can't parse bracket notation /// async fn wrong_handler(Query(params): Query) { /// // Dashboard sends: ?sources[]=rfc&sources[]=community /// // Result: params.sources = None (silently fails!) /// } /// ``` /// /// # Dashboard Compatibility /// /// The StemeDB Dashboard uses JavaScript's `URLSearchParams.append()` which generates /// bracket notation for arrays: /// /// ```javascript /// // Dashboard code /// params.sources.forEach(s => searchParams.append("sources[]", s)); /// // Generates: ?sources[]=rfc&sources[]=owasp&sources[]=community /// ``` /// /// If you use standard `Query` for array parameters, the dashboard filters will appear /// to work but silently fail (returning all results instead of filtered results). #[derive(Debug, Clone, Copy, Default)] pub struct QsQuery(pub T); #[async_trait] impl FromRequestParts for QsQuery where T: DeserializeOwned, S: Send + Sync, { type Rejection = QsQueryRejection; async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { let query = parts.uri.query().unwrap_or_default(); // Use non-strict mode to accept both encoded (%5B%5D) and literal ([]) brackets. // Browsers URL-encode brackets, so sources[] becomes sources%5B%5D in the query string. let config = serde_qs::Config::new(5, false); let value = config .deserialize_str(query) .map_err(|err| QsQueryRejection { message: err.to_string() })?; Ok(QsQuery(value)) } } impl std::ops::Deref for QsQuery { type Target = T; fn deref(&self) -> &Self::Target { &self.0 } } impl std::ops::DerefMut for QsQuery { fn deref_mut(&mut self) -> &mut Self::Target { &mut self.0 } } #[cfg(test)] mod tests { use super::*; use axum::http::{Request, Uri}; use serde::Deserialize; #[derive(Debug, Deserialize, PartialEq)] struct TestParams { sources: Option>, limit: Option, } #[tokio::test] async fn test_bracket_notation() { let uri: Uri = "http://example.com?sources[]=rfc&sources[]=community&limit=10".parse().unwrap(); let mut parts = Request::builder().uri(uri).body(()).unwrap().into_parts().0; let QsQuery(params): QsQuery = QsQuery::from_request_parts(&mut parts, &()).await.unwrap(); assert_eq!( params, TestParams { sources: Some(vec!["rfc".to_string(), "community".to_string()]), limit: Some(10), } ); } #[tokio::test] async fn test_no_brackets() { let uri: Uri = "http://example.com?limit=5".parse().unwrap(); let mut parts = Request::builder().uri(uri).body(()).unwrap().into_parts().0; let QsQuery(params): QsQuery = QsQuery::from_request_parts(&mut parts, &()).await.unwrap(); assert_eq!(params, TestParams { sources: None, limit: Some(5) }); } #[tokio::test] async fn test_empty_query() { let uri: Uri = "http://example.com".parse().unwrap(); let mut parts = Request::builder().uri(uri).body(()).unwrap().into_parts().0; let QsQuery(params): QsQuery = QsQuery::from_request_parts(&mut parts, &()).await.unwrap(); assert_eq!(params, TestParams { sources: None, limit: None }); } #[tokio::test] async fn test_encoded_brackets() { // Test URL-encoded brackets (%5B = '[', %5D = ']') // This is what browsers send when using URLSearchParams let uri: Uri = "http://example.com?sources%5B%5D=rfc&sources%5B%5D=owasp&sources%5B%5D=community&limit=100" .parse() .unwrap(); let mut parts = Request::builder().uri(uri).body(()).unwrap().into_parts().0; let QsQuery(params): QsQuery = QsQuery::from_request_parts(&mut parts, &()).await.unwrap(); assert_eq!( params, TestParams { sources: Some(vec![ "rfc".to_string(), "owasp".to_string(), "community".to_string() ]), limit: Some(100), } ); } }