tidaldb/tidal/tests/m6_cohort.rs
2026-02-23 22:41:16 -07:00

719 lines
21 KiB
Rust

//! Milestone 6 Phase 1 Integration Tests: Cohort Engine + Cohort-Scoped Trending.
//!
//! Exercises the complete M6P1 scenario end-to-end:
//!
//! 1. Define cohorts (tech, sports) via predicates on user metadata.
//! 2. Write users with metadata placing them in different cohorts.
//! 3. Write items and signals with user context (triggering cohort attribution).
//! 4. Verify cohort signal state is independent per cohort.
//! 5. Execute RETRIEVE with `.cohort("tech_en")` and verify cohort-scoped ranking.
//! 6. Execute RETRIEVE with `.profile("cohort_trending")` + cohort clause.
//! 7. Verify that signals from non-cohort users do not appear in cohort state.
#![allow(clippy::unwrap_used, clippy::cast_precision_loss)]
use std::collections::HashMap;
use std::time::Duration;
#[cfg(feature = "test-utils")]
use tidaldb::TempTidalHome;
use tidaldb::TidalDb;
use tidaldb::cohort::{CohortDef, Predicate};
use tidaldb::query::retrieve::Retrieve;
use tidaldb::schema::{DecaySpec, EntityId, EntityKind, SchemaBuilder, Timestamp, Window};
// ── Schema ──────────────────────────────────────────────────────────────────
/// Builds the schema shared by every test in this file.
///
/// Registers five item-level signals. Each one has exponential decay (half-life
/// expressed in days), the 1h / 24h / 7d / all-time windows, and velocity
/// tracking enabled.
fn m6_schema() -> tidaldb::schema::Schema {
    const SECONDS_PER_DAY: u64 = 24 * 3600;
    const SIGNALS: [(&str, u64); 5] = [
        ("view", 7),
        ("like", 14),
        ("share", 7),
        ("skip", 1),
        ("completion", 14),
    ];

    let mut builder = SchemaBuilder::new();
    for (signal_name, days) in SIGNALS.iter().copied() {
        let decay = DecaySpec::Exponential {
            half_life: Duration::from_secs(days * SECONDS_PER_DAY),
        };
        let _ = builder
            .signal(signal_name, EntityKind::Item, decay)
            .windows(&[
                Window::OneHour,
                Window::TwentyFourHours,
                Window::SevenDays,
                Window::AllTime,
            ])
            .velocity(true)
            .add();
    }
    builder.build().expect("m6 schema must be valid")
}
// ── Helpers ─────────────────────────────────────────────────────────────────
/// Builds user metadata containing exactly the `locale` and `primary_category`
/// keys that the cohort predicates in these tests match against.
fn user_metadata(locale: &str, primary_category: &str) -> HashMap<String, String> {
    [("locale", locale), ("primary_category", primary_category)]
        .iter()
        .map(|&(key, value)| (key.to_string(), value.to_string()))
        .collect()
}
/// Builds item metadata with these four keys:
/// - `category` — the given category;
/// - `format` — always `"video"`;
/// - `creator_id` — the given creator id;
/// - `created_at` — `Timestamp::now()` rendered through `as_nanos()`.
fn item_metadata(category: &str, creator_id: u64) -> HashMap<String, String> {
    HashMap::from([
        ("category".to_string(), category.to_string()),
        ("format".to_string(), "video".to_string()),
        ("creator_id".to_string(), creator_id.to_string()),
        (
            "created_at".to_string(),
            Timestamp::now().as_nanos().to_string(),
        ),
    ])
}
/// Opens an `ephemeral()` database (no data directory is supplied) configured
/// with the m6 test schema. Panics if the open fails.
fn open_ephemeral_db() -> TidalDb {
    let builder = TidalDb::builder().ephemeral();
    builder
        .with_schema(m6_schema())
        .open()
        .expect("db open")
}
// ── Tests ───────────────────────────────────────────────────────────────────
/// Defining two distinctly-named cohorts succeeds. Re-defining an existing
/// name is rejected, whatever predicate is supplied.
#[test]
fn define_cohorts_and_verify_registry() {
    let db = open_ephemeral_db();

    // A conjunctive cohort and a single-equality cohort.
    let tech_en = CohortDef {
        name: "tech_en".to_string(),
        predicate: Predicate::And(vec![
            Predicate::Eq {
                field: "locale".into(),
                value: "en".into(),
            },
            Predicate::Eq {
                field: "primary_category".into(),
                value: "tech".into(),
            },
        ]),
    };
    let sports = CohortDef {
        name: "sports".to_string(),
        predicate: Predicate::Eq {
            field: "primary_category".into(),
            value: "sports".into(),
        },
    };
    db.define_cohort(tech_en).unwrap();
    db.define_cohort(sports).unwrap();

    // Same name, different predicate: the registry must refuse it.
    let clashing = CohortDef {
        name: "tech_en".to_string(),
        predicate: Predicate::Eq {
            field: "x".into(),
            value: "y".into(),
        },
    };
    assert!(db.define_cohort(clashing).is_err());

    db.close().unwrap();
}
/// Signals sent with a user context are attributed to every cohort whose
/// predicate the user's metadata satisfies, and to no other cohort.
#[test]
fn cohort_signal_attribution_on_signal_with_context() {
    let db = open_ephemeral_db();

    // `tech_en` needs locale=en AND primary_category=tech.
    // `sports` needs only primary_category=sports.
    db.define_cohort(CohortDef {
        name: "tech_en".to_string(),
        predicate: Predicate::And(vec![
            Predicate::Eq {
                field: "locale".into(),
                value: "en".into(),
            },
            Predicate::Eq {
                field: "primary_category".into(),
                value: "tech".into(),
            },
        ]),
    })
    .unwrap();
    db.define_cohort(CohortDef {
        name: "sports".to_string(),
        predicate: Predicate::Eq {
            field: "primary_category".into(),
            value: "sports".into(),
        },
    })
    .unwrap();

    // Membership by predicate:
    // - 1001 (en/tech) matches tech_en only;
    // - 1002 (en/sports) matches sports only;
    // - 1003 (fr/tech) matches neither cohort.
    let user_1 = EntityId::new(1001);
    let user_2 = EntityId::new(1002);
    let user_3 = EntityId::new(1003);
    db.write_user(user_1, &user_metadata("en", "tech")).unwrap();
    db.write_user(user_2, &user_metadata("en", "sports"))
        .unwrap();
    db.write_user(user_3, &user_metadata("fr", "tech")).unwrap();

    for i in 1..=5u64 {
        db.write_item_with_metadata(EntityId::new(i), &item_metadata("tech", 100))
            .unwrap();
    }

    let now = Timestamp::now();
    // User 1001 (tech_en) views items 1, 2, 3.
    for i in 1..=3u64 {
        db.signal_with_context("view", EntityId::new(i), 1.0, now, Some(1001), Some(100))
            .unwrap();
    }
    // User 1002 (sports) views items 2, 4.
    for &i in &[2u64, 4] {
        db.signal_with_context("view", EntityId::new(i), 1.0, now, Some(1002), Some(100))
            .unwrap();
    }
    // User 1003 (fr/tech, no cohort) views item 5.
    db.signal_with_context("view", EntityId::new(5), 1.0, now, Some(1003), Some(100))
        .unwrap();

    let cohort_ledger = db.cohort_ledger();
    // All-time "view" count for `item` as recorded under cohort `cohort`.
    let views = |cohort: &'static str, item: u64| {
        cohort_ledger
            .read_windowed_count(cohort, EntityId::new(item), "view", Window::AllTime)
            .unwrap()
    };

    // tech_en: items 1, 2 and 3 each got exactly one view from user 1001.
    for item in 1..=3u64 {
        assert_eq!(
            views("tech_en", item),
            1,
            "tech_en item {item} should have 1 view"
        );
    }
    assert_eq!(
        views("tech_en", 4),
        0,
        "tech_en item 4 should have 0 views (sports user)"
    );
    assert_eq!(
        views("tech_en", 5),
        0,
        "tech_en item 5 should have 0 views (user 3 is fr locale)"
    );

    // sports: items 2 and 4 each got one view from user 1002; nothing else.
    assert_eq!(views("sports", 2), 1, "sports item 2 should have 1 view");
    assert_eq!(views("sports", 4), 1, "sports item 4 should have 1 view");
    assert_eq!(
        views("sports", 1),
        0,
        "sports item 1 should have 0 views (tech user)"
    );
    // User 1003's primary_category is tech, so it is not a sports member either.
    assert_eq!(
        views("sports", 5),
        0,
        "sports item 5 should have 0 views (user 3 is not a sports user)"
    );

    db.close().unwrap();
}
/// A cohort-scoped `cohort_trending` retrieval ranks items by the cohort's own
/// view activity.
#[test]
fn retrieve_with_named_cohort() {
    let db = open_ephemeral_db();

    db.define_cohort(CohortDef {
        name: "tech_en".to_string(),
        predicate: Predicate::And(vec![
            Predicate::Eq {
                field: "locale".into(),
                value: "en".into(),
            },
            Predicate::Eq {
                field: "primary_category".into(),
                value: "tech".into(),
            },
        ]),
    })
    .unwrap();

    // The only user, and a tech_en member.
    db.write_user(EntityId::new(1001), &user_metadata("en", "tech"))
        .unwrap();

    for i in 1..=5u64 {
        db.write_item_with_metadata(EntityId::new(i), &item_metadata("tech", 100))
            .unwrap();
    }

    let now = Timestamp::now();
    // User 1001 gives item 1 ten views, item 2 five and item 3 one.
    // Items 4 and 5 receive none.
    for (item, view_count) in [(1u64, 10), (2, 5), (3, 1)] {
        for _ in 0..view_count {
            db.signal_with_context("view", EntityId::new(item), 1.0, now, Some(1001), Some(100))
                .unwrap();
        }
    }

    let query = Retrieve::builder()
        .profile("cohort_trending")
        .cohort("tech_en")
        .limit(5)
        .build()
        .unwrap();
    let results = db.retrieve(&query).unwrap();

    // Only non-emptiness is required. This test does not pin down whether
    // items with no cohort signals (4, 5) are also returned.
    assert!(
        !results.items.is_empty(),
        "cohort query should return items"
    );

    // Item 1 has the most cohort views and must rank first.
    assert_eq!(
        results.items[0].entity_id,
        EntityId::new(1),
        "item 1 should be ranked first (10 cohort views)"
    );

    // Whenever both appear, item 2 (5 views) must rank ahead of item 3 (1 view).
    let rank_of = |id: u64| {
        results
            .items
            .iter()
            .position(|item| item.entity_id == EntityId::new(id))
    };
    if let (Some(rank_2), Some(rank_3)) = (rank_of(2), rank_of(3)) {
        assert!(
            rank_2 < rank_3,
            "item 2 (5 cohort views) should outrank item 3 (1 cohort view)"
        );
    }

    db.close().unwrap();
}
/// A retrieval naming a cohort that was never defined returns an error.
#[test]
fn retrieve_with_nonexistent_cohort_fails() {
    let db = open_ephemeral_db();

    // Seed one item so the query pipeline has something to run over.
    let seed_meta = item_metadata("tech", 100);
    db.write_item_with_metadata(EntityId::new(1), &seed_meta)
        .unwrap();

    let query = Retrieve::builder()
        .profile("cohort_trending")
        .cohort("nonexistent_cohort")
        .limit(5)
        .build()
        .unwrap();
    assert!(
        db.retrieve(&query).is_err(),
        "query with nonexistent cohort should fail"
    );

    db.close().unwrap();
}
/// The `cohort_trending` profile is usable without any prior setup and without
/// a cohort clause.
#[test]
fn cohort_trending_profile_exists() {
    let db = open_ephemeral_db();

    let query = Retrieve::builder()
        .profile("cohort_trending")
        .limit(5)
        .build()
        .unwrap();

    // A successful `retrieve` shows the profile was found.
    // The result is empty because no items were written.
    let results = db.retrieve(&query).unwrap();
    assert!(results.items.is_empty());

    db.close().unwrap();
}
/// A weighted signal from a cohort member yields a present, positive cohort
/// decay score.
#[test]
fn cohort_ledger_decay_score() {
    let db = open_ephemeral_db();

    db.define_cohort(CohortDef {
        name: "test".to_string(),
        predicate: Predicate::Eq {
            field: "primary_category".into(),
            value: "test".into(),
        },
    })
    .unwrap();

    // User 2001 matches the `test` cohort via primary_category=test.
    db.write_user(EntityId::new(2001), &user_metadata("en", "test"))
        .unwrap();
    db.write_item_with_metadata(EntityId::new(1), &item_metadata("test", 200))
        .unwrap();

    let now = Timestamp::now();
    db.signal_with_context("view", EntityId::new(1), 5.0, now, Some(2001), Some(200))
        .unwrap();

    let cohort_ledger = db.cohort_ledger();
    // NOTE(review): this file does not show what the trailing `0` argument
    // means. Confirm it against `read_decay_score`'s signature.
    let score = cohort_ledger
        .read_decay_score("test", EntityId::new(1), "view", 0)
        .unwrap()
        .expect("cohort decay score should be present");
    assert!(score > 0.0, "cohort decay score should be positive");

    db.close().unwrap();
}
/// Ten signals from a cohort member produce a positive one-hour cohort velocity.
#[test]
fn cohort_ledger_velocity() {
    let db = open_ephemeral_db();

    let test_cohort = CohortDef {
        name: "test".to_string(),
        predicate: Predicate::Eq {
            field: "primary_category".into(),
            value: "test".into(),
        },
    };
    db.define_cohort(test_cohort).unwrap();

    db.write_user(EntityId::new(2001), &user_metadata("en", "test"))
        .unwrap();
    db.write_item_with_metadata(EntityId::new(1), &item_metadata("test", 200))
        .unwrap();

    // Ten unit-weight views of item 1 from cohort member 2001, all stamped `now`.
    let now = Timestamp::now();
    (0..10).for_each(|_| {
        db.signal_with_context("view", EntityId::new(1), 1.0, now, Some(2001), Some(200))
            .unwrap();
    });

    let velocity = db
        .cohort_ledger()
        .read_velocity("test", EntityId::new(1), "view", Window::OneHour)
        .unwrap();
    assert!(
        velocity > 0.0,
        "cohort velocity should be positive after 10 signals"
    );

    db.close().unwrap();
}
/// A signal written without user context is not attributed to any cohort,
/// even when a matching cohort member exists. The same signal sent *with*
/// that member's context is attributed.
#[test]
fn signal_without_user_context_no_cohort_attribution() {
    let db = open_ephemeral_db();

    db.define_cohort(CohortDef {
        name: "test".to_string(),
        predicate: Predicate::Eq {
            field: "primary_category".into(),
            value: "test".into(),
        },
    })
    .unwrap();

    // A `test` cohort member exists, so the zero count below cannot be
    // explained by the cohort simply having no members.
    db.write_user(EntityId::new(2001), &user_metadata("en", "test"))
        .unwrap();
    db.write_item_with_metadata(EntityId::new(1), &item_metadata("test", 200))
        .unwrap();

    let now = Timestamp::now();

    // Signal without user context: must NOT be attributed to any cohort.
    db.signal("view", EntityId::new(1), 1.0, now).unwrap();
    // Read through a temporary so no ledger handle is held across the next write.
    let count = db
        .cohort_ledger()
        .read_windowed_count("test", EntityId::new(1), "view", Window::AllTime)
        .unwrap();
    assert_eq!(
        count, 0,
        "signals without user context should not be attributed to cohorts"
    );

    // Positive control: the same signal carrying the member's context IS
    // attributed. This proves the zero above is not just an inert ledger.
    db.signal_with_context("view", EntityId::new(1), 1.0, now, Some(2001), Some(200))
        .unwrap();
    let count = db
        .cohort_ledger()
        .read_windowed_count("test", EntityId::new(1), "view", Window::AllTime)
        .unwrap();
    assert_eq!(
        count, 1,
        "a signal with a cohort member's context should be attributed to the cohort"
    );

    db.close().unwrap();
}
/// Cohort definitions persist across close/reopen cycles, including ones added
/// after a restart.
#[test]
#[cfg(feature = "test-utils")]
fn cohort_definition_survives_restart() {
    let home = TempTidalHome::new().unwrap();

    // Every session opens a new handle over the same data directory.
    let open = || {
        TidalDb::builder()
            .with_data_dir(home.path())
            .with_schema(m6_schema())
            .open()
            .unwrap()
    };

    // Registry probe. Re-defining an existing name is rejected, so an `Err`
    // from `define_cohort` means `name` is already registered.
    let already_registered = |db: &TidalDb, name: &str| {
        db.define_cohort(CohortDef {
            name: name.to_string(),
            predicate: Predicate::Eq {
                field: "x".into(),
                value: "y".into(),
            },
        })
        .is_err()
    };

    // Session 1: two cohorts with different predicate shapes (And vs. Any).
    {
        let db = open();
        db.define_cohort(CohortDef {
            name: "tech_fans".to_string(),
            predicate: Predicate::And(vec![
                Predicate::Eq {
                    field: "locale".into(),
                    value: "en".into(),
                },
                Predicate::Eq {
                    field: "primary_category".into(),
                    value: "tech".into(),
                },
            ]),
        })
        .unwrap();
        db.define_cohort(CohortDef {
            name: "gamers".to_string(),
            predicate: Predicate::Any {
                field: "interest".into(),
                values: vec!["gaming".into(), "esports".into()],
            },
        })
        .unwrap();
        db.close().unwrap();
    }

    // Session 2: both definitions were reloaded, and new names are still accepted.
    {
        let db = open();
        for name in ["tech_fans", "gamers"] {
            assert!(
                already_registered(&db, name),
                "{name} should already be registered after restart"
            );
        }
        db.define_cohort(CohortDef {
            name: "new_cohort".to_string(),
            predicate: Predicate::Eq {
                field: "locale".into(),
                value: "fr".into(),
            },
        })
        .unwrap();
        db.close().unwrap();
    }

    // Session 3: the cohort added during session 2 was persisted too.
    {
        let db = open();
        assert!(
            already_registered(&db, "new_cohort"),
            "new_cohort should be registered after second restart"
        );
        db.close().unwrap();
    }
}
/// After a restart, a persisted cohort definition can still be referenced by a
/// cohort-scoped `cohort_trending` retrieval.
#[test]
#[cfg(feature = "test-utils")]
fn cohort_definition_survives_restart_and_queries_work() {
    let home = TempTidalHome::new().unwrap();

    // Every session opens a new handle over the same data directory.
    let open = || {
        TidalDb::builder()
            .with_data_dir(home.path())
            .with_schema(m6_schema())
            .open()
            .unwrap()
    };

    // Session 1: define the cohort, then write one member, three items, and
    // five cohort-attributed views of item 1.
    {
        let db = open();
        db.define_cohort(CohortDef {
            name: "tech_en".to_string(),
            predicate: Predicate::And(vec![
                Predicate::Eq {
                    field: "locale".into(),
                    value: "en".into(),
                },
                Predicate::Eq {
                    field: "primary_category".into(),
                    value: "tech".into(),
                },
            ]),
        })
        .unwrap();
        db.write_user(EntityId::new(1001), &user_metadata("en", "tech"))
            .unwrap();
        for item in 1..=3u64 {
            db.write_item_with_metadata(EntityId::new(item), &item_metadata("tech", 100))
                .unwrap();
        }
        let now = Timestamp::now();
        for _ in 0..5 {
            db.signal_with_context("view", EntityId::new(1), 1.0, now, Some(1001), Some(100))
                .unwrap();
        }
        db.close().unwrap();
    }

    // Session 2: the definition was reloaded, and the cohort name resolves in a query.
    {
        let db = open();

        let redefined = db.define_cohort(CohortDef {
            name: "tech_en".to_string(),
            predicate: Predicate::Eq {
                field: "x".into(),
                value: "y".into(),
            },
        });
        assert!(
            redefined.is_err(),
            "tech_en cohort definition should survive restart"
        );

        let query = Retrieve::builder()
            .profile("cohort_trending")
            .cohort("tech_en")
            .limit(5)
            .build()
            .unwrap();
        // Restored per-cohort signal state is deliberately not asserted here.
        // The point is that the query does not fail on a missing cohort
        // definition, and returns no more than the three items written.
        let results = db.retrieve(&query).unwrap();
        assert!(
            results.items.len() <= 3,
            "at most 3 items should be returned"
        );

        db.close().unwrap();
    }
}
/// A user matching several cohorts has each signal attributed to every one of
/// them.
#[test]
fn multiple_cohort_memberships() {
    let db = open_ephemeral_db();

    // Two overlapping single-field cohorts: one on locale, one on category.
    for (cohort_name, field, value) in [
        ("english", "locale", "en"),
        ("tech", "primary_category", "tech"),
    ] {
        db.define_cohort(CohortDef {
            name: cohort_name.to_string(),
            predicate: Predicate::Eq {
                field: field.into(),
                value: value.into(),
            },
        })
        .unwrap();
    }

    // en/tech places user 3001 in BOTH cohorts.
    db.write_user(EntityId::new(3001), &user_metadata("en", "tech"))
        .unwrap();
    db.write_item_with_metadata(EntityId::new(1), &item_metadata("tech", 300))
        .unwrap();

    let now = Timestamp::now();
    db.signal_with_context("view", EntityId::new(1), 1.0, now, Some(3001), Some(300))
        .unwrap();

    // The single signal must be counted once in each cohort.
    let cohort_ledger = db.cohort_ledger();
    for (cohort, message) in [
        (
            "english",
            "english cohort should have the signal from en-locale user",
        ),
        ("tech", "tech cohort should have the signal from tech user"),
    ] {
        let count = cohort_ledger
            .read_windowed_count(cohort, EntityId::new(1), "view", Window::AllTime)
            .unwrap();
        assert_eq!(count, 1, "{}", message);
    }

    db.close().unwrap();
}