diff --git a/tidal/src/db/session_restore.rs b/tidal/src/db/session_restore.rs index 04844de..d18824d 100644 --- a/tidal/src/db/session_restore.rs +++ b/tidal/src/db/session_restore.rs @@ -4,9 +4,9 @@ use std::collections::HashMap; use std::sync::Arc; use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; -use crate::schema::Timestamp; +use crate::schema::{EntityId, Timestamp}; use crate::session::{self as session_mod, AgentId, SessionId, SessionState}; -use crate::storage::{Tag, parse_key}; +use crate::storage::{Tag, encode_key, parse_key}; use super::TidalDb; @@ -52,7 +52,7 @@ impl TidalDb { } } b"start" => { - if let Some((session_id, _user_id, _started_at_ns)) = + if let Some((session_id, _user_id, _started_at_ns, _metadata)) = session_mod::deserialize_start_record(&value) { tracing::warn!( @@ -120,6 +120,19 @@ impl TidalDb { let closed = Arc::new(AtomicBool::new(false)); + // Rehydrate metadata from the persisted start record in storage. + // The WAL Start event does not carry metadata; the storage start record does. + let metadata = self + .storage + .as_ref() + .and_then(|storage| { + let key = encode_key(EntityId::new(session_id), Tag::Session, b"start"); + storage.items_engine().get(&key).ok().flatten() + }) + .and_then(|bytes| session_mod::deserialize_start_record(&bytes)) + .map(|(_, _, _, meta)| meta) + .unwrap_or_default(); + let state = Arc::new(SessionState { id: SessionId::from_raw(session_id), user_id, @@ -128,7 +141,7 @@ impl TidalDb { // We lost the exact monotonic Instant -- approximate with "now". started_at: std::time::Instant::now(), started_at_ns, - metadata: HashMap::new(), // metadata not persisted in the session journal + metadata, signals: dashmap::DashMap::new(), signaled_entities: dashmap::DashMap::new(), annotations: std::sync::Mutex::new(Vec::new()), @@ -138,9 +151,9 @@ impl TidalDb { closed, }); - // Replay signals into the restored session state. + // Replay signals and annotations into the restored session state. if let Some(signals) = session_signals.get(&session_id) { - for (entity_id, weight, ts_ns, signal_name, _annotation) in signals { + for (entity_id, weight, ts_ns, signal_name, annotation) in signals { let lambda = schema .signal(signal_name) .and_then(|def| def.decay().lambda()) @@ -155,6 +168,14 @@ impl TidalDb { state.signaled_entities.insert(*entity_id, ()); state.signals_written.fetch_add(1, Ordering::Relaxed); + + // Replay annotation if present. + if let Some(ann) = annotation + && let Ok(mut anns) = state.annotations.lock() + && anns.len() < session_mod::MAX_ANNOTATIONS + { + anns.push((*ts_ns, ann.clone())); + } } } diff --git a/tidal/src/session/serde.rs b/tidal/src/session/serde/mod.rs similarity index 89% rename from tidal/src/session/serde.rs rename to tidal/src/session/serde/mod.rs index 252e2f1..0aebe6c 100644 --- a/tidal/src/session/serde.rs +++ b/tidal/src/session/serde/mod.rs @@ -6,10 +6,13 @@ //! a snapshot when the session closes. //! - **Audit log** — standalone serialisation used for the separate audit keyspace. +mod start_record; + +pub use start_record::{deserialize_start_record, serialize_start_record}; + use super::audit::AuditEntry; use super::signal_state::SignalSnapEntry; use super::snapshot::SessionSnapshot; -use super::state::SessionState; use super::types::SessionId; /// Format version byte for snapshot serialization. @@ -275,51 +278,6 @@ pub fn deserialize_snapshot(bytes: &[u8]) -> Option { }) } -// ── Start record ────────────────────────────────────────────────────────────── - -/// Serialize a compact session start record. -/// -/// Written to storage on `start_session`; deleted when the session is closed -/// and replaced by a snapshot record. -/// -/// Format: `[session_id: 8][user_id: 8][started_at_ns: 8][metadata_count: u32][...kv...]` -#[must_use] -#[allow(clippy::cast_possible_truncation)] -pub fn serialize_start_record(state: &SessionState) -> Vec { - let mut buf = Vec::new(); - buf.extend_from_slice(&state.id.as_u64().to_le_bytes()); - buf.extend_from_slice(&state.user_id.to_le_bytes()); - buf.extend_from_slice(&state.started_at_ns.to_le_bytes()); - buf.extend_from_slice(&(state.metadata.len() as u32).to_le_bytes()); - for (k, v) in &state.metadata { - buf.extend_from_slice(&(k.len() as u32).to_le_bytes()); - buf.extend_from_slice(k.as_bytes()); - buf.extend_from_slice(&(v.len() as u32).to_le_bytes()); - buf.extend_from_slice(v.as_bytes()); - } - buf -} - -/// Deserialize a session start record (reads `session_id`, `user_id`, `started_at_ns`). -/// -/// Returns `None` if the bytes are malformed. -#[must_use] -pub fn deserialize_start_record(bytes: &[u8]) -> Option<(SessionId, u64, u64)> { - if bytes.len() < 24 { - return None; - } - let session_id = u64::from_le_bytes([ - bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7], - ]); - let user_id = u64::from_le_bytes([ - bytes[8], bytes[9], bytes[10], bytes[11], bytes[12], bytes[13], bytes[14], bytes[15], - ]); - let started_at_ns = u64::from_le_bytes([ - bytes[16], bytes[17], bytes[18], bytes[19], bytes[20], bytes[21], bytes[22], bytes[23], - ]); - Some((SessionId::from_raw(session_id), user_id, started_at_ns)) -} - // ── Audit log ───────────────────────────────────────────────────────────────── /// Serialize an audit log to bytes for storage archival. diff --git a/tidal/src/session/serde/start_record.rs b/tidal/src/session/serde/start_record.rs new file mode 100644 index 0000000..0f4ced5 --- /dev/null +++ b/tidal/src/session/serde/start_record.rs @@ -0,0 +1,202 @@ +//! Binary encoding and decoding for session start records. +//! +//! A start record is written to storage on `start_session` and deleted (replaced +//! by a snapshot) when the session closes. On restart, surviving start records +//! indicate sessions that were active at crash time; their metadata is +//! rehydrated from this record during WAL replay. + +use super::super::state::SessionState; +use super::super::types::SessionId; + +/// Serialize a compact session start record. +/// +/// Written to storage on `start_session`; deleted when the session is closed +/// and replaced by a snapshot record. +/// +/// Format: `[session_id: 8][user_id: 8][started_at_ns: 8][metadata_count: u32][...kv...]` +#[must_use] +#[allow(clippy::cast_possible_truncation)] +pub fn serialize_start_record(state: &SessionState) -> Vec { + let mut buf = Vec::new(); + buf.extend_from_slice(&state.id.as_u64().to_le_bytes()); + buf.extend_from_slice(&state.user_id.to_le_bytes()); + buf.extend_from_slice(&state.started_at_ns.to_le_bytes()); + buf.extend_from_slice(&(state.metadata.len() as u32).to_le_bytes()); + for (k, v) in &state.metadata { + buf.extend_from_slice(&(k.len() as u32).to_le_bytes()); + buf.extend_from_slice(k.as_bytes()); + buf.extend_from_slice(&(v.len() as u32).to_le_bytes()); + buf.extend_from_slice(v.as_bytes()); + } + buf +} + +/// Deserialize a session start record. +/// +/// Returns `(session_id, user_id, started_at_ns, metadata)`. +/// Older records written before metadata support was added return an empty map +/// rather than `None` — partial reads are treated as empty metadata, not errors. +/// +/// Returns `None` only if the mandatory first 24 bytes are missing. +#[must_use] +#[allow(clippy::cast_possible_truncation)] +pub fn deserialize_start_record( + bytes: &[u8], +) -> Option<( + SessionId, + u64, + u64, + std::collections::HashMap, +)> { + if bytes.len() < 24 { + return None; + } + let session_id = u64::from_le_bytes([ + bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7], + ]); + let user_id = u64::from_le_bytes([ + bytes[8], bytes[9], bytes[10], bytes[11], bytes[12], bytes[13], bytes[14], bytes[15], + ]); + let started_at_ns = u64::from_le_bytes([ + bytes[16], bytes[17], bytes[18], bytes[19], bytes[20], bytes[21], bytes[22], bytes[23], + ]); + + // Metadata section: [count: u32][key_len: u32][key bytes][val_len: u32][val bytes]... + // Records written before metadata support stop at byte 24 — return empty map. + let mut pos = 24usize; + let read_u32 = |p: &mut usize| -> Option { + if *p + 4 > bytes.len() { + return None; + } + let v = u32::from_le_bytes([bytes[*p], bytes[*p + 1], bytes[*p + 2], bytes[*p + 3]]); + *p += 4; + Some(v) + }; + let metadata = 'parse: { + let Some(count) = read_u32(&mut pos) else { + break 'parse std::collections::HashMap::new(); + }; + let mut map = std::collections::HashMap::with_capacity(count as usize); + for _ in 0..count { + let Some(key_len) = read_u32(&mut pos) else { + break 'parse map; + }; + let key_len = key_len as usize; + if pos + key_len > bytes.len() { + break 'parse map; + } + let key = String::from_utf8_lossy(&bytes[pos..pos + key_len]).to_string(); + pos += key_len; + + let Some(val_len) = read_u32(&mut pos) else { + break 'parse map; + }; + let val_len = val_len as usize; + if pos + val_len > bytes.len() { + break 'parse map; + } + let val = String::from_utf8_lossy(&bytes[pos..pos + val_len]).to_string(); + pos += val_len; + + map.insert(key, val); + } + map + }; + + Some(( + SessionId::from_raw(session_id), + user_id, + started_at_ns, + metadata, + )) +} + +// ── Unit tests ──────────────────────────────────────────────────────────────── + +#[cfg(test)] +#[allow(clippy::unwrap_used)] +mod tests { + use std::collections::HashMap; + use std::sync::Arc; + use std::sync::atomic::{AtomicBool, AtomicU64}; + + use super::{deserialize_start_record, serialize_start_record}; + use crate::session::audit::AuditLog; + use crate::session::state::SessionState; + use crate::session::types::{AgentId, SessionId}; + + fn make_state(metadata: HashMap) -> SessionState { + let closed = Arc::new(AtomicBool::new(false)); + SessionState { + id: SessionId(42), + user_id: 7, + agent_id: AgentId::new("agent-a").unwrap(), + policy_name: "default".to_string(), + started_at: std::time::Instant::now(), + started_at_ns: 123_456_789, + metadata, + signals: dashmap::DashMap::new(), + signaled_entities: dashmap::DashMap::new(), + annotations: std::sync::Mutex::new(Vec::new()), + signals_written: AtomicU64::new(0), + signals_rejected: AtomicU64::new(0), + audit_log: std::sync::Mutex::new(AuditLog::new()), + closed, + } + } + + #[test] + fn start_record_roundtrip_with_metadata() { + let mut metadata = HashMap::new(); + metadata.insert("tool".to_string(), "planner".to_string()); + metadata.insert("context".to_string(), "daily-feed".to_string()); + + let state = make_state(metadata); + let bytes = serialize_start_record(&state); + let (session_id, user_id, started_at_ns, metadata) = + deserialize_start_record(&bytes).unwrap(); + + assert_eq!(session_id, SessionId(42)); + assert_eq!(user_id, 7); + assert_eq!(started_at_ns, 123_456_789); + assert_eq!(metadata.get("tool").map(String::as_str), Some("planner")); + assert_eq!( + metadata.get("context").map(String::as_str), + Some("daily-feed") + ); + } + + #[test] + fn start_record_roundtrip_empty_metadata() { + let state = make_state(HashMap::new()); + let bytes = serialize_start_record(&state); + let (session_id, user_id, started_at_ns, metadata) = + deserialize_start_record(&bytes).unwrap(); + + assert_eq!(session_id, SessionId(42)); + assert_eq!(user_id, 7); + assert_eq!(started_at_ns, 123_456_789); + assert!(metadata.is_empty()); + } + + #[test] + fn start_record_backward_compat_no_metadata() { + // A record with only the 24-byte header (pre-metadata format) must parse + // without error and return an empty metadata map. + let mut buf = Vec::new(); + buf.extend_from_slice(&42u64.to_le_bytes()); // session_id + buf.extend_from_slice(&7u64.to_le_bytes()); // user_id + buf.extend_from_slice(&999u64.to_le_bytes()); // started_at_ns + // No metadata bytes — simulates a record written before this fix. + + let (session_id, user_id, started_at_ns, metadata) = + deserialize_start_record(&buf).unwrap(); + assert_eq!(session_id, SessionId(42)); + assert_eq!(user_id, 7); + assert_eq!(started_at_ns, 999); + assert!( + metadata.is_empty(), + "backward-compat: metadata should be empty" + ); + } +} diff --git a/tidal/tests/session_durability.rs b/tidal/tests/session_durability.rs index 80bdd37..7493acf 100644 --- a/tidal/tests/session_durability.rs +++ b/tidal/tests/session_durability.rs @@ -441,7 +441,165 @@ fn active_session_state_restored_after_crash() { } } -// ── Test 7: WAL replay preserves signal counts exactly ─────────────────────── +// ── Test 6: Session metadata survives crash recovery ───────────────────────── + +/// Verifies that `start_session` metadata (e.g. `{"tool": "planner"}`) is +/// rehydrated from the storage start record during WAL replay, so that +/// `session_snapshot().metadata` is populated after a crash. +#[test] +fn metadata_survives_crash() { + let dir = tempfile::tempdir().unwrap(); + let schema = test_schema(); + + let session_id; + { + let db = TidalDb::builder() + .with_data_dir(dir.path()) + .with_schema(schema.clone()) + .open() + .unwrap(); + + let mut meta = HashMap::new(); + meta.insert("tool".to_string(), "planner".to_string()); + meta.insert("context".to_string(), "daily-feed".to_string()); + + let handle = db + .start_session(10, "agent-meta", "default_policy", meta) + .unwrap(); + session_id = handle.id; + + // Write a signal so the WAL has a SessionSignal record too. + let mut item_meta = HashMap::new(); + item_meta.insert("title".to_string(), "item-1".to_string()); + db.write_item_with_metadata(EntityId::new(1), &item_meta) + .unwrap(); + let ts = Timestamp::now(); + db.session_signal(&handle, "reward", EntityId::new(1), 1.0, ts, None) + .unwrap(); + + // Simulate crash: drop without close_session(). + drop(db); + } + + // Reopen — metadata must be present in the restored active session. + { + let db = TidalDb::builder() + .with_data_dir(dir.path()) + .with_schema(schema) + .open() + .unwrap(); + + let active = db.active_sessions(); + assert!( + active.iter().any(|info| info.id == session_id), + "session must be restored as active after crash" + ); + + let snap = db.session_snapshot(session_id).unwrap(); + assert_eq!( + snap.metadata.get("tool").map(String::as_str), + Some("planner"), + "tool metadata must survive crash recovery" + ); + assert_eq!( + snap.metadata.get("context").map(String::as_str), + Some("daily-feed"), + "context metadata must survive crash recovery" + ); + + db.close().unwrap(); + } +} + +// ── Test 7: Preference annotations survive crash recovery ──────────────────── + +/// Verifies that session signal annotations ("more jazz today") are replayed +/// from the WAL during crash recovery, so that `SessionContext::keywords` is +/// populated for FOR SESSION ranking after a restart. +#[test] +fn annotations_survive_crash() { + let dir = tempfile::tempdir().unwrap(); + let schema = test_schema(); + + let session_id; + { + let db = TidalDb::builder() + .with_data_dir(dir.path()) + .with_schema(schema.clone()) + .open() + .unwrap(); + + let handle = db + .start_session(11, "agent-ann", "default_policy", HashMap::new()) + .unwrap(); + session_id = handle.id; + + let mut item_meta = HashMap::new(); + item_meta.insert("genre".to_string(), "jazz".to_string()); + db.write_item_with_metadata(EntityId::new(1), &item_meta) + .unwrap(); + let ts = Timestamp::now(); + + // Write two annotated signals. + db.session_signal( + &handle, + "reward", + EntityId::new(1), + 1.0, + ts, + Some("more jazz today".to_string()), + ) + .unwrap(); + db.session_signal( + &handle, + "view", + EntityId::new(1), + 0.5, + ts, + Some("acoustic vibes".to_string()), + ) + .unwrap(); + + // Simulate crash. + drop(db); + } + + // Reopen — annotations must be present so FOR SESSION ranking can use them. + { + let db = TidalDb::builder() + .with_data_dir(dir.path()) + .with_schema(schema) + .open() + .unwrap(); + + let active = db.active_sessions(); + assert!( + active.iter().any(|info| info.id == session_id), + "session must be restored as active after crash" + ); + + let snap = db.session_snapshot(session_id).unwrap(); + assert_eq!( + snap.annotations.len(), + 2, + "both annotations must be replayed from WAL" + ); + + let texts: Vec<&str> = snap.annotations.iter().map(|(_, s)| s.as_str()).collect(); + assert!( + texts.contains(&"more jazz today"), + "jazz annotation must survive crash recovery" + ); + assert!( + texts.contains(&"acoustic vibes"), + "acoustic annotation must survive crash recovery" + ); + + db.close().unwrap(); + } +} + +// ── Test 8: WAL replay preserves signal counts exactly ─────────────────────── /// Property-like correctness test: write exactly K signals of one type into /// an active session, "crash" (drop without close_session), reopen, and