use std::collections::HashMap; use std::net::SocketAddr; use std::path::PathBuf; use std::sync::Arc; use axum::extract::{Query, State}; use axum::http::StatusCode; use axum::response::IntoResponse; use axum::routing::{get, post}; use axum::{Json, Router}; use iknowyou_engine::{ AuxMemory, FeedbackAction, FeedbackEvent, IkyEngine, NoopAuxMemory, PersonalizationItem, RetrievedItem, }; use serde::{Deserialize, Serialize}; use tidaldb::session::SessionHandle; #[derive(Clone)] struct AppState { engine: Arc, sessions: Arc>>, } #[derive(Debug, Serialize)] struct ErrorResponse { error: String, } #[derive(Debug, Deserialize)] struct UpsertUserRequest { user_id: u64, #[serde(default)] metadata: HashMap, } #[derive(Debug, Deserialize)] struct UpsertItemRequest { item_id: u64, creator_id: u64, title: String, #[serde(default = "default_message_category")] category: String, } fn default_message_category() -> String { "message".to_string() } #[derive(Debug, Deserialize)] struct FeedbackRequest { user_id: u64, item_id: u64, creator_id: Option, action: FeedbackAction, } #[derive(Debug, Deserialize)] struct RetrieveQuery { user_id: u64, #[serde(default = "default_limit")] limit: usize, } fn default_limit() -> usize { 20 } #[derive(Debug, Serialize)] struct RetrieveResponse { items: Vec, } #[derive(Debug, Deserialize)] struct StartSessionRequest { conversation_id: String, user_id: u64, #[serde(default = "default_agent_id")] agent_id: String, } fn default_agent_id() -> String { "aeries".to_string() } #[derive(Debug, Deserialize)] struct SessionSignalRequest { conversation_id: String, signal_type: String, item_id: u64, #[serde(default = "default_weight")] weight: f64, annotation: Option, } fn default_weight() -> f64 { 1.0 } #[derive(Debug, Deserialize)] struct CloseSessionRequest { conversation_id: String, } #[derive(Debug, Deserialize)] struct ObservationRequest { person_id: u64, observation: String, } #[derive(Debug, Serialize)] struct OkResponse { ok: bool, } #[derive(Debug, Serialize)] struct StartSessionResponse { ok: bool, session_id: String, } #[tokio::main] async fn main() -> Result<(), Box> { let data_dir = std::env::var("IKY_ENGINE_DATA_DIR") .map(PathBuf::from) .unwrap_or_else(|_| std::env::temp_dir().join("iknowyou_engine_data")); let aux: Arc = build_aux_memory()?; let engine = Arc::new( IkyEngine::builder() .data_dir(&data_dir) .with_aux_memory(aux) .open()?, ); let state = AppState { engine, sessions: Arc::new(tokio::sync::Mutex::new(HashMap::new())), }; let app = Router::new() .route("/healthz", get(healthz)) .route("/v1/users/upsert", post(upsert_user)) .route("/v1/items/upsert", post(upsert_item)) .route("/v1/feedback", post(record_feedback)) .route("/v1/retrieve", get(retrieve_for_user)) .route("/v1/sessions/start", post(start_session)) .route("/v1/sessions/signal", post(session_signal)) .route("/v1/sessions/close", post(close_session)) .route("/v1/aux/observation", post(aux_observation)) .with_state(state); let bind_addr = std::env::var("IKY_ENGINE_BIND") .unwrap_or_else(|_| "127.0.0.1:7777".to_string()) .parse::()?; let listener = tokio::net::TcpListener::bind(bind_addr).await?; println!("iknowyou-engine server listening on {bind_addr}"); println!("data_dir: {}", data_dir.display()); axum::serve(listener, app).await?; Ok(()) } fn build_aux_memory() -> Result, Box> { #[cfg(feature = "synap-aux")] { let base = std::env::var("SYNAP_URL").ok(); let key = std::env::var("SYNAP_API_KEY").ok(); if let (Some(base), Some(key)) = (base, key) && !base.is_empty() && !key.is_empty() { let aux = iknowyou_engine::SynapAuxMemory::new(base, key)?; return Ok(Arc::new(aux)); } } Ok(Arc::new(NoopAuxMemory)) } async fn healthz() -> Json { Json(OkResponse { ok: true }) } async fn upsert_user( State(state): State, Json(req): Json, ) -> Result, (StatusCode, Json)> { state .engine .upsert_user(req.user_id, &req.metadata) .map_err(internal_error)?; Ok(Json(OkResponse { ok: true })) } async fn upsert_item( State(state): State, Json(req): Json, ) -> Result, (StatusCode, Json)> { let item = PersonalizationItem { item_id: req.item_id, creator_id: req.creator_id, title: req.title, category: req.category, embedding: None, }; state.engine.upsert_item(&item).map_err(internal_error)?; Ok(Json(OkResponse { ok: true })) } async fn record_feedback( State(state): State, Json(req): Json, ) -> Result, (StatusCode, Json)> { let event = FeedbackEvent::now(req.user_id, req.item_id, req.creator_id, req.action); state .engine .record_feedback(event) .map_err(internal_error)?; Ok(Json(OkResponse { ok: true })) } async fn retrieve_for_user( State(state): State, Query(query): Query, ) -> Result, (StatusCode, Json)> { let items = state .engine .retrieve_for_user_items(query.user_id, query.limit) .map_err(internal_error)?; Ok(Json(RetrieveResponse { items })) } async fn start_session( State(state): State, Json(req): Json, ) -> Result, (StatusCode, Json)> { let mut sessions = state.sessions.lock().await; if let Some(handle) = sessions.get(&req.conversation_id) { return Ok(Json(StartSessionResponse { ok: true, session_id: handle.id.to_string(), })); } let handle = state .engine .start_session(req.user_id, &req.agent_id, HashMap::new()) .map_err(internal_error)?; let session_id = handle.id.to_string(); sessions.insert(req.conversation_id, handle); Ok(Json(StartSessionResponse { ok: true, session_id, })) } async fn session_signal( State(state): State, Json(req): Json, ) -> Result, (StatusCode, Json)> { let sessions = state.sessions.lock().await; let handle = sessions.get(&req.conversation_id).ok_or_else(|| { ( StatusCode::NOT_FOUND, Json(ErrorResponse { error: "session not found".to_string(), }), ) })?; state .engine .session_signal( handle, &req.signal_type, req.item_id, req.weight, req.annotation, ) .map_err(internal_error)?; Ok(Json(OkResponse { ok: true })) } async fn close_session( State(state): State, Json(req): Json, ) -> Result, (StatusCode, Json)> { let mut sessions = state.sessions.lock().await; let handle = sessions.remove(&req.conversation_id).ok_or_else(|| { ( StatusCode::NOT_FOUND, Json(ErrorResponse { error: "session not found".to_string(), }), ) })?; state.engine.close_session(handle).map_err(internal_error)?; Ok(Json(OkResponse { ok: true })) } async fn aux_observation( State(state): State, Json(req): Json, ) -> Result, (StatusCode, Json)> { state .engine .remember_aux_observation(req.person_id, &req.observation) .map_err(internal_error)?; Ok(Json(OkResponse { ok: true })) } fn internal_error(err: E) -> (StatusCode, Json) { ( StatusCode::INTERNAL_SERVER_ERROR, Json(ErrorResponse { error: err.to_string(), }), ) } impl IntoResponse for ErrorResponse { fn into_response(self) -> axum::response::Response { (StatusCode::INTERNAL_SERVER_ERROR, Json(self)).into_response() } }