fix: restore session metadata and annotations after crash recovery
deserialize_start_record now returns the full (session_id, user_id,
started_at_ns, metadata) tuple — the metadata bytes were already written
by serialize_start_record but silently discarded on read.
restore_session_wal_events looks up the persisted start record in storage
for each open session and uses the deserialized metadata instead of
HashMap::new(), so fields like {"tool":"planner"} survive a crash.
The signal replay loop no longer discards _annotation — annotations are
now pushed into state.annotations during WAL replay, restoring preference
hints like "more jazz today" so FOR SESSION ranking works post-restart.
Two new integration tests in session_durability.rs verify both fixes
against a real persistent store with simulated crash (drop without
close_session). session/serde.rs split into serde/mod.rs + serde/start_record.rs
to satisfy the 600-line limit.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
192c473f55
commit
fd95dfc2be
@ -4,9 +4,9 @@ use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
|
||||
|
||||
use crate::schema::Timestamp;
|
||||
use crate::schema::{EntityId, Timestamp};
|
||||
use crate::session::{self as session_mod, AgentId, SessionId, SessionState};
|
||||
use crate::storage::{Tag, parse_key};
|
||||
use crate::storage::{Tag, encode_key, parse_key};
|
||||
|
||||
use super::TidalDb;
|
||||
|
||||
@ -52,7 +52,7 @@ impl TidalDb {
|
||||
}
|
||||
}
|
||||
b"start" => {
|
||||
if let Some((session_id, _user_id, _started_at_ns)) =
|
||||
if let Some((session_id, _user_id, _started_at_ns, _metadata)) =
|
||||
session_mod::deserialize_start_record(&value)
|
||||
{
|
||||
tracing::warn!(
|
||||
@ -120,6 +120,19 @@ impl TidalDb {
|
||||
|
||||
let closed = Arc::new(AtomicBool::new(false));
|
||||
|
||||
// Rehydrate metadata from the persisted start record in storage.
|
||||
// The WAL Start event does not carry metadata; the storage start record does.
|
||||
let metadata = self
|
||||
.storage
|
||||
.as_ref()
|
||||
.and_then(|storage| {
|
||||
let key = encode_key(EntityId::new(session_id), Tag::Session, b"start");
|
||||
storage.items_engine().get(&key).ok().flatten()
|
||||
})
|
||||
.and_then(|bytes| session_mod::deserialize_start_record(&bytes))
|
||||
.map(|(_, _, _, meta)| meta)
|
||||
.unwrap_or_default();
|
||||
|
||||
let state = Arc::new(SessionState {
|
||||
id: SessionId::from_raw(session_id),
|
||||
user_id,
|
||||
@ -128,7 +141,7 @@ impl TidalDb {
|
||||
// We lost the exact monotonic Instant -- approximate with "now".
|
||||
started_at: std::time::Instant::now(),
|
||||
started_at_ns,
|
||||
metadata: HashMap::new(), // metadata not persisted in the session journal
|
||||
metadata,
|
||||
signals: dashmap::DashMap::new(),
|
||||
signaled_entities: dashmap::DashMap::new(),
|
||||
annotations: std::sync::Mutex::new(Vec::new()),
|
||||
@ -138,9 +151,9 @@ impl TidalDb {
|
||||
closed,
|
||||
});
|
||||
|
||||
// Replay signals into the restored session state.
|
||||
// Replay signals and annotations into the restored session state.
|
||||
if let Some(signals) = session_signals.get(&session_id) {
|
||||
for (entity_id, weight, ts_ns, signal_name, _annotation) in signals {
|
||||
for (entity_id, weight, ts_ns, signal_name, annotation) in signals {
|
||||
let lambda = schema
|
||||
.signal(signal_name)
|
||||
.and_then(|def| def.decay().lambda())
|
||||
@ -155,6 +168,14 @@ impl TidalDb {
|
||||
|
||||
state.signaled_entities.insert(*entity_id, ());
|
||||
state.signals_written.fetch_add(1, Ordering::Relaxed);
|
||||
|
||||
// Replay annotation if present.
|
||||
if let Some(ann) = annotation
|
||||
&& let Ok(mut anns) = state.annotations.lock()
|
||||
&& anns.len() < session_mod::MAX_ANNOTATIONS
|
||||
{
|
||||
anns.push((*ts_ns, ann.clone()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -6,10 +6,13 @@
|
||||
//! a snapshot when the session closes.
|
||||
//! - **Audit log** — standalone serialisation used for the separate audit keyspace.
|
||||
|
||||
mod start_record;
|
||||
|
||||
pub use start_record::{deserialize_start_record, serialize_start_record};
|
||||
|
||||
use super::audit::AuditEntry;
|
||||
use super::signal_state::SignalSnapEntry;
|
||||
use super::snapshot::SessionSnapshot;
|
||||
use super::state::SessionState;
|
||||
use super::types::SessionId;
|
||||
|
||||
/// Format version byte for snapshot serialization.
|
||||
@ -275,51 +278,6 @@ pub fn deserialize_snapshot(bytes: &[u8]) -> Option<SessionSnapshot> {
|
||||
})
|
||||
}
|
||||
|
||||
// ── Start record ──────────────────────────────────────────────────────────────
|
||||
|
||||
/// Serialize a compact session start record.
|
||||
///
|
||||
/// Written to storage on `start_session`; deleted when the session is closed
|
||||
/// and replaced by a snapshot record.
|
||||
///
|
||||
/// Format: `[session_id: 8][user_id: 8][started_at_ns: 8][metadata_count: u32][...kv...]`
|
||||
#[must_use]
|
||||
#[allow(clippy::cast_possible_truncation)]
|
||||
pub fn serialize_start_record(state: &SessionState) -> Vec<u8> {
|
||||
let mut buf = Vec::new();
|
||||
buf.extend_from_slice(&state.id.as_u64().to_le_bytes());
|
||||
buf.extend_from_slice(&state.user_id.to_le_bytes());
|
||||
buf.extend_from_slice(&state.started_at_ns.to_le_bytes());
|
||||
buf.extend_from_slice(&(state.metadata.len() as u32).to_le_bytes());
|
||||
for (k, v) in &state.metadata {
|
||||
buf.extend_from_slice(&(k.len() as u32).to_le_bytes());
|
||||
buf.extend_from_slice(k.as_bytes());
|
||||
buf.extend_from_slice(&(v.len() as u32).to_le_bytes());
|
||||
buf.extend_from_slice(v.as_bytes());
|
||||
}
|
||||
buf
|
||||
}
|
||||
|
||||
/// Deserialize a session start record (reads `session_id`, `user_id`, `started_at_ns`).
|
||||
///
|
||||
/// Returns `None` if the bytes are malformed.
|
||||
#[must_use]
|
||||
pub fn deserialize_start_record(bytes: &[u8]) -> Option<(SessionId, u64, u64)> {
|
||||
if bytes.len() < 24 {
|
||||
return None;
|
||||
}
|
||||
let session_id = u64::from_le_bytes([
|
||||
bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7],
|
||||
]);
|
||||
let user_id = u64::from_le_bytes([
|
||||
bytes[8], bytes[9], bytes[10], bytes[11], bytes[12], bytes[13], bytes[14], bytes[15],
|
||||
]);
|
||||
let started_at_ns = u64::from_le_bytes([
|
||||
bytes[16], bytes[17], bytes[18], bytes[19], bytes[20], bytes[21], bytes[22], bytes[23],
|
||||
]);
|
||||
Some((SessionId::from_raw(session_id), user_id, started_at_ns))
|
||||
}
|
||||
|
||||
// ── Audit log ─────────────────────────────────────────────────────────────────
|
||||
|
||||
/// Serialize an audit log to bytes for storage archival.
|
||||
202
tidal/src/session/serde/start_record.rs
Normal file
202
tidal/src/session/serde/start_record.rs
Normal file
@ -0,0 +1,202 @@
|
||||
//! Binary encoding and decoding for session start records.
|
||||
//!
|
||||
//! A start record is written to storage on `start_session` and deleted (replaced
|
||||
//! by a snapshot) when the session closes. On restart, surviving start records
|
||||
//! indicate sessions that were active at crash time; their metadata is
|
||||
//! rehydrated from this record during WAL replay.
|
||||
|
||||
use super::super::state::SessionState;
|
||||
use super::super::types::SessionId;
|
||||
|
||||
/// Serialize a compact session start record.
|
||||
///
|
||||
/// Written to storage on `start_session`; deleted when the session is closed
|
||||
/// and replaced by a snapshot record.
|
||||
///
|
||||
/// Format: `[session_id: 8][user_id: 8][started_at_ns: 8][metadata_count: u32][...kv...]`
|
||||
#[must_use]
|
||||
#[allow(clippy::cast_possible_truncation)]
|
||||
pub fn serialize_start_record(state: &SessionState) -> Vec<u8> {
|
||||
let mut buf = Vec::new();
|
||||
buf.extend_from_slice(&state.id.as_u64().to_le_bytes());
|
||||
buf.extend_from_slice(&state.user_id.to_le_bytes());
|
||||
buf.extend_from_slice(&state.started_at_ns.to_le_bytes());
|
||||
buf.extend_from_slice(&(state.metadata.len() as u32).to_le_bytes());
|
||||
for (k, v) in &state.metadata {
|
||||
buf.extend_from_slice(&(k.len() as u32).to_le_bytes());
|
||||
buf.extend_from_slice(k.as_bytes());
|
||||
buf.extend_from_slice(&(v.len() as u32).to_le_bytes());
|
||||
buf.extend_from_slice(v.as_bytes());
|
||||
}
|
||||
buf
|
||||
}
|
||||
|
||||
/// Deserialize a session start record.
|
||||
///
|
||||
/// Returns `(session_id, user_id, started_at_ns, metadata)`.
|
||||
/// Older records written before metadata support was added return an empty map
|
||||
/// rather than `None` — partial reads are treated as empty metadata, not errors.
|
||||
///
|
||||
/// Returns `None` only if the mandatory first 24 bytes are missing.
|
||||
#[must_use]
|
||||
#[allow(clippy::cast_possible_truncation)]
|
||||
pub fn deserialize_start_record(
|
||||
bytes: &[u8],
|
||||
) -> Option<(
|
||||
SessionId,
|
||||
u64,
|
||||
u64,
|
||||
std::collections::HashMap<String, String>,
|
||||
)> {
|
||||
if bytes.len() < 24 {
|
||||
return None;
|
||||
}
|
||||
let session_id = u64::from_le_bytes([
|
||||
bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7],
|
||||
]);
|
||||
let user_id = u64::from_le_bytes([
|
||||
bytes[8], bytes[9], bytes[10], bytes[11], bytes[12], bytes[13], bytes[14], bytes[15],
|
||||
]);
|
||||
let started_at_ns = u64::from_le_bytes([
|
||||
bytes[16], bytes[17], bytes[18], bytes[19], bytes[20], bytes[21], bytes[22], bytes[23],
|
||||
]);
|
||||
|
||||
// Metadata section: [count: u32][key_len: u32][key bytes][val_len: u32][val bytes]...
|
||||
// Records written before metadata support stop at byte 24 — return empty map.
|
||||
let mut pos = 24usize;
|
||||
let read_u32 = |p: &mut usize| -> Option<u32> {
|
||||
if *p + 4 > bytes.len() {
|
||||
return None;
|
||||
}
|
||||
let v = u32::from_le_bytes([bytes[*p], bytes[*p + 1], bytes[*p + 2], bytes[*p + 3]]);
|
||||
*p += 4;
|
||||
Some(v)
|
||||
};
|
||||
let metadata = 'parse: {
|
||||
let Some(count) = read_u32(&mut pos) else {
|
||||
break 'parse std::collections::HashMap::new();
|
||||
};
|
||||
let mut map = std::collections::HashMap::with_capacity(count as usize);
|
||||
for _ in 0..count {
|
||||
let Some(key_len) = read_u32(&mut pos) else {
|
||||
break 'parse map;
|
||||
};
|
||||
let key_len = key_len as usize;
|
||||
if pos + key_len > bytes.len() {
|
||||
break 'parse map;
|
||||
}
|
||||
let key = String::from_utf8_lossy(&bytes[pos..pos + key_len]).to_string();
|
||||
pos += key_len;
|
||||
|
||||
let Some(val_len) = read_u32(&mut pos) else {
|
||||
break 'parse map;
|
||||
};
|
||||
let val_len = val_len as usize;
|
||||
if pos + val_len > bytes.len() {
|
||||
break 'parse map;
|
||||
}
|
||||
let val = String::from_utf8_lossy(&bytes[pos..pos + val_len]).to_string();
|
||||
pos += val_len;
|
||||
|
||||
map.insert(key, val);
|
||||
}
|
||||
map
|
||||
};
|
||||
|
||||
Some((
|
||||
SessionId::from_raw(session_id),
|
||||
user_id,
|
||||
started_at_ns,
|
||||
metadata,
|
||||
))
|
||||
}
|
||||
|
||||
// ── Unit tests ────────────────────────────────────────────────────────────────
|
||||
|
||||
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use std::collections::HashMap;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicBool, AtomicU64};

    use super::{deserialize_start_record, serialize_start_record};
    use crate::session::audit::AuditLog;
    use crate::session::state::SessionState;
    use crate::session::types::{AgentId, SessionId};

    /// Fixture: session 42 owned by user 7, started at 123_456_789 ns, carrying
    /// the supplied metadata and otherwise empty state.
    fn state_with_metadata(metadata: HashMap<String, String>) -> SessionState {
        SessionState {
            id: SessionId(42),
            user_id: 7,
            agent_id: AgentId::new("agent-a").unwrap(),
            policy_name: "default".to_string(),
            started_at: std::time::Instant::now(),
            started_at_ns: 123_456_789,
            metadata,
            signals: dashmap::DashMap::new(),
            signaled_entities: dashmap::DashMap::new(),
            annotations: std::sync::Mutex::new(Vec::new()),
            signals_written: AtomicU64::new(0),
            signals_rejected: AtomicU64::new(0),
            audit_log: std::sync::Mutex::new(AuditLog::new()),
            closed: Arc::new(AtomicBool::new(false)),
        }
    }

    /// Header fields and both metadata entries come back after a round trip.
    #[test]
    fn start_record_roundtrip_with_metadata() {
        let original = HashMap::from([
            ("tool".to_string(), "planner".to_string()),
            ("context".to_string(), "daily-feed".to_string()),
        ]);

        let encoded = serialize_start_record(&state_with_metadata(original));
        let decoded = deserialize_start_record(&encoded).unwrap();

        assert_eq!(decoded.0, SessionId(42));
        assert_eq!(decoded.1, 7);
        assert_eq!(decoded.2, 123_456_789);

        let restored = decoded.3;
        assert_eq!(restored.get("tool").map(|v| v.as_str()), Some("planner"));
        assert_eq!(
            restored.get("context").map(|v| v.as_str()),
            Some("daily-feed")
        );
    }

    /// A state without metadata round-trips to an empty map.
    #[test]
    fn start_record_roundtrip_empty_metadata() {
        let encoded = serialize_start_record(&state_with_metadata(HashMap::new()));
        let decoded = deserialize_start_record(&encoded).unwrap();

        assert_eq!(decoded.0, SessionId(42));
        assert_eq!(decoded.1, 7);
        assert_eq!(decoded.2, 123_456_789);
        assert!(decoded.3.is_empty());
    }

    /// A record holding only the 24-byte header (the pre-metadata format) must
    /// still parse, producing an empty metadata map.
    #[test]
    fn start_record_backward_compat_no_metadata() {
        // session_id, user_id, started_at_ns — and nothing after them.
        let header_only: Vec<u8> = [42u64, 7, 999]
            .iter()
            .flat_map(|word| word.to_le_bytes())
            .collect();

        let decoded = deserialize_start_record(&header_only).unwrap();

        assert_eq!(decoded.0, SessionId(42));
        assert_eq!(decoded.1, 7);
        assert_eq!(decoded.2, 999);
        assert!(
            decoded.3.is_empty(),
            "backward-compat: metadata should be empty"
        );
    }
}
|
||||
@ -441,7 +441,165 @@ fn active_session_state_restored_after_crash() {
|
||||
}
|
||||
}
|
||||
|
||||
// ── Test 7: WAL replay preserves signal counts exactly ───────────────────────
|
||||
// ── Test 6: Session metadata survives crash recovery ─────────────────────────
|
||||
|
||||
/// Crash-recovery check for session metadata.
///
/// Starts a session with `{"tool": "planner", "context": "daily-feed"}`, drops
/// the database without calling `close_session`, reopens it, and expects
/// `session_snapshot().metadata` to contain both entries again.
#[test]
fn metadata_survives_crash() {
    let dir = tempfile::tempdir().unwrap();
    let schema = test_schema();

    // First lifetime: open a session with metadata, then "crash".
    let session_id = {
        let db = TidalDb::builder()
            .with_data_dir(dir.path())
            .with_schema(schema.clone())
            .open()
            .unwrap();

        let session_meta = HashMap::from([
            ("tool".to_string(), "planner".to_string()),
            ("context".to_string(), "daily-feed".to_string()),
        ]);
        let handle = db
            .start_session(10, "agent-meta", "default_policy", session_meta)
            .unwrap();
        let id = handle.id;

        // Record one signal so the WAL also holds a SessionSignal entry.
        let item_meta = HashMap::from([("title".to_string(), "item-1".to_string())]);
        db.write_item_with_metadata(EntityId::new(1), &item_meta)
            .unwrap();
        db.session_signal(
            &handle,
            "reward",
            EntityId::new(1),
            1.0,
            Timestamp::now(),
            None,
        )
        .unwrap();

        // Dropping the database without close_session() stands in for a crash.
        drop(db);
        id
    };

    // Second lifetime: the restored session must still carry its metadata.
    let db = TidalDb::builder()
        .with_data_dir(dir.path())
        .with_schema(schema)
        .open()
        .unwrap();

    let restored = db
        .active_sessions()
        .iter()
        .any(|info| info.id == session_id);
    assert!(restored, "session must be restored as active after crash");

    let snap = db.session_snapshot(session_id).unwrap();
    let expectations = [
        ("tool", "planner", "tool metadata must survive crash recovery"),
        (
            "context",
            "daily-feed",
            "context metadata must survive crash recovery",
        ),
    ];
    for (key, want, message) in expectations {
        assert_eq!(
            snap.metadata.get(key).map(String::as_str),
            Some(want),
            "{}",
            message
        );
    }

    db.close().unwrap();
}
|
||||
|
||||
// ── Test 7: Preference annotations survive crash recovery ────────────────────
|
||||
|
||||
/// Crash-recovery check for session signal annotations.
///
/// Writes two annotated signals ("more jazz today", "acoustic vibes"), drops
/// the database without closing the session, reopens it, and expects the
/// session snapshot to list both annotations again. FOR SESSION ranking is
/// described as depending on these annotations after a restart.
#[test]
fn annotations_survive_crash() {
    let dir = tempfile::tempdir().unwrap();
    let schema = test_schema();

    // First lifetime: write annotated signals, then "crash".
    let session_id = {
        let db = TidalDb::builder()
            .with_data_dir(dir.path())
            .with_schema(schema.clone())
            .open()
            .unwrap();

        let handle = db
            .start_session(11, "agent-ann", "default_policy", HashMap::new())
            .unwrap();
        let id = handle.id;

        let item_meta = HashMap::from([("genre".to_string(), "jazz".to_string())]);
        db.write_item_with_metadata(EntityId::new(1), &item_meta)
            .unwrap();

        // Both signals share one timestamp and target the same item.
        let ts = Timestamp::now();
        let annotated_signals = [
            ("reward", 1.0, "more jazz today"),
            ("view", 0.5, "acoustic vibes"),
        ];
        for (signal, weight, note) in annotated_signals {
            db.session_signal(
                &handle,
                signal,
                EntityId::new(1),
                weight,
                ts,
                Some(note.to_string()),
            )
            .unwrap();
        }

        // Dropping the database without close_session() stands in for a crash.
        drop(db);
        id
    };

    // Second lifetime: the annotations must have been replayed from the WAL.
    let db = TidalDb::builder()
        .with_data_dir(dir.path())
        .with_schema(schema)
        .open()
        .unwrap();

    let restored = db
        .active_sessions()
        .iter()
        .any(|info| info.id == session_id);
    assert!(restored, "session must be restored as active after crash");

    let snap = db.session_snapshot(session_id).unwrap();
    assert_eq!(
        snap.annotations.len(),
        2,
        "both annotations must be replayed from WAL"
    );

    let has_annotation = |needle: &str| {
        snap.annotations
            .iter()
            .any(|(_, text)| text.as_str() == needle)
    };
    assert!(
        has_annotation("more jazz today"),
        "jazz annotation must survive crash recovery"
    );
    assert!(
        has_annotation("acoustic vibes"),
        "acoustic annotation must survive crash recovery"
    );

    db.close().unwrap();
}
|
||||
|
||||
// ── Test 8: WAL replay preserves signal counts exactly ───────────────────────
|
||||
|
||||
/// Property-like correctness test: write exactly K signals of one type into
|
||||
/// an active session, "crash" (drop without close_session), reopen, and
|
||||
|
||||
Loading…
Reference in New Issue
Block a user