tidaldb/tidal/tests/vector_usearch.rs
2026-02-23 22:41:16 -07:00

413 lines
13 KiB
Rust

//! Integration tests for the `USearch` HNSW vector index backend.
//!
//! These tests exercise the `UsearchIndex` implementation of `VectorIndex`
//! at a scale that validates recall, persistence, and correctness properties
//! that unit tests cannot cover.
#![allow(clippy::unwrap_used)]
use rand::Rng;
use tidaldb::storage::vector::{
BruteForceIndex, DistanceMetric, QuantizationLevel, UsearchIndex, VectorIndex,
VectorIndexConfig,
};
// ---------------------------------------------------------------------------
// Helpers
// ---------------------------------------------------------------------------
/// Generate a random unit vector of the given dimensionality.
///
/// Every component is sampled as `rng.random::<f32>() - 0.5`, after which the
/// whole vector is divided by its Euclidean norm.
///
/// Degenerate cases:
/// * If the sampled norm is below `f32::EPSILON`, the first basis vector
///   (`[1.0, 0.0, ...]`) is returned instead of dividing by (almost) zero.
/// * If `dim == 0`, an empty vector is returned rather than panicking on an
///   out-of-bounds write into the fallback.
fn random_unit_vector(dim: usize, rng: &mut impl Rng) -> Vec<f32> {
    let mut v: Vec<f32> = (0..dim).map(|_| rng.random::<f32>() - 0.5).collect();
    let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();

    if norm < f32::EPSILON {
        // Extremely unlikely for high-dim input, but always the case for
        // `dim == 0`, where there is no first component to set.
        let mut fallback = vec![0.0f32; dim];
        if let Some(first) = fallback.first_mut() {
            *first = 1.0;
        }
        return fallback;
    }

    // Normalize in place; no second allocation is needed.
    for x in &mut v {
        *x /= norm;
    }
    v
}
/// Compute recall@k: fraction of the brute-force top-k that appear in the
/// HNSW result set.
///
/// Both inputs are treated as *sets* of IDs:
/// * duplicate IDs in `hnsw_results` are counted once, so the result can
///   never exceed `1.0`;
/// * duplicate IDs in `brute_results` do not inflate the denominator.
///
/// Returns `0.0` when `brute_results` is empty. Recall is undefined without
/// ground truth, and `0.0` (rather than `NaN` from `0 / 0`) keeps any
/// `recall > threshold` assertion failing with a readable number.
#[allow(clippy::cast_precision_loss)]
fn recall_at_k(brute_results: &[u64], hnsw_results: &[u64]) -> f64 {
    use std::collections::HashSet;

    let relevant: HashSet<u64> = brute_results.iter().copied().collect();
    if relevant.is_empty() {
        return 0.0;
    }

    let retrieved: HashSet<u64> = hnsw_results.iter().copied().collect();
    let found = relevant.intersection(&retrieved).count();

    found as f64 / relevant.len() as f64
}
/// Build the baseline test configuration for an index of `dimensions`
/// components: L2 metric, F16 quantization, connectivity 16 and both
/// `ef_construction` and `ef_search` set to 200.
const fn default_config(dimensions: usize) -> VectorIndexConfig {
    // Graph parameters are named up front so they can be read at a glance.
    let connectivity = 16;
    let ef_construction = 200;
    let ef_search = 200;

    VectorIndexConfig {
        quantization: QuantizationLevel::F16,
        metric: DistanceMetric::L2,
        dimensions,
        connectivity,
        ef_construction,
        ef_search,
    }
}
/// Build a test configuration for an index of `dimensions` components that
/// stores vectors at F32 quantization: L2 metric, connectivity 16 and both
/// `ef_construction` and `ef_search` set to 200.
///
/// The tests in this file pass it to `BruteForceIndex::new` to build the
/// ground-truth index.
const fn f32_config(dimensions: usize) -> VectorIndexConfig {
    // Graph parameters are named up front so they can be read at a glance.
    let connectivity = 16;
    let ef_construction = 200;
    let ef_search = 200;

    VectorIndexConfig {
        quantization: QuantizationLevel::F32,
        metric: DistanceMetric::L2,
        dimensions,
        connectivity,
        ef_construction,
        ef_search,
    }
}
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
/// Insert 1000 vectors (dim=128, F16), search with 10 queries, verify
/// recall@100 > 0.90 against `BruteForceIndex` ground truth.
///
/// The same random vectors go into both indexes; recall is averaged over
/// all queries before being compared with the threshold.
#[test]
fn usearch_insert_and_search_1000_vectors() {
    let dim = 128;
    let n = 1000;
    let k = 100;
    let num_queries: i32 = 10;

    let usearch = UsearchIndex::new(default_config(dim)).unwrap();
    let brute = BruteForceIndex::new(f32_config(dim));
    usearch.reserve(n).unwrap();

    let mut rng = rand::rng();

    // Populate both indexes with identical data so the brute-force results
    // can serve as ground truth.
    for id in 0..n as u64 {
        let sample = random_unit_vector(dim, &mut rng);
        usearch.insert(id, &sample).unwrap();
        brute.insert(id, &sample).unwrap();
    }

    assert_eq!(usearch.len(), n);
    assert_eq!(usearch.len_live(), n);

    // Accumulate per-query recall in query order.
    let total_recall = (0..num_queries).fold(0.0_f64, |acc, _| {
        let query = random_unit_vector(dim, &mut rng);

        let expected: Vec<u64> = brute
            .search(&query, k, 0)
            .unwrap()
            .iter()
            .map(|hit| hit.id)
            .collect();
        let actual: Vec<u64> = usearch
            .search(&query, k, 200)
            .unwrap()
            .iter()
            .map(|hit| hit.id)
            .collect();

        acc + recall_at_k(&expected, &actual)
    });

    let avg_recall = total_recall / f64::from(num_queries);
    assert!(
        avg_recall > 0.90,
        "recall@{k} = {avg_recall:.3}, expected > 0.90"
    );
    eprintln!("1K vectors recall@{k}: {avg_recall:.3}");
}
/// 100 vectors, filter even IDs, verify all results are even.
///
/// Also checks that the filtered search returns at least one hit, since
/// half of the inserted IDs satisfy the predicate.
#[test]
fn usearch_filtered_search_excludes_non_matching() {
    let dim = 32;
    let n = 100;
    let k = 20;

    let index = UsearchIndex::new(default_config(dim)).unwrap();
    index.reserve(n).unwrap();

    let mut rng = rand::rng();
    for id in 0..n as u64 {
        let sample = random_unit_vector(dim, &mut rng);
        index.insert(id, &sample).unwrap();
    }

    let query = random_unit_vector(dim, &mut rng);
    let results = index
        .filtered_search(&query, k, 200, &|id| id % 2 == 0)
        .unwrap();

    // Fail on the first hit that slipped past the even-only predicate.
    if let Some(offender) = results.iter().find(|hit| hit.id % 2 != 0) {
        panic!(
            "odd ID {} found in even-only filtered search",
            offender.id
        );
    }

    // Should have found some results (50 even IDs available).
    let got_hits = !results.is_empty();
    assert!(
        got_hits,
        "filtered search returned no results despite 50 eligible vectors"
    );
}
/// 50 vectors, delete ID 0, search for its vector, verify 0 absent.
///
/// The query asks for as many results as there are vectors, so the deleted
/// entry would show up if the index still served it.
#[test]
fn usearch_delete_excludes_from_results() {
    let dim = 32;
    let n = 50;
    let k = n; // ask for all

    let index = UsearchIndex::new(default_config(dim)).unwrap();
    index.reserve(n).unwrap();

    // Generate the data set first, then insert it with IDs equal to the
    // position of each vector.
    let mut rng = rand::rng();
    let vectors: Vec<Vec<f32>> = (0..n)
        .map(|_| random_unit_vector(dim, &mut rng))
        .collect();
    for (id, v) in vectors.iter().enumerate() {
        index.insert(id as u64, v).unwrap();
    }

    // Delete vector 0.
    index.delete(0).unwrap();
    assert_eq!(index.len_live(), n - 1);

    // Search for the deleted vector -- it must not appear in results.
    let results = index.search(&vectors[0], k, 200).unwrap();
    let deleted_present = results.iter().any(|hit| hit.id == 0);
    assert!(
        !deleted_present,
        "deleted vector ID 0 found in search results"
    );
}
/// 100 vectors, save, load, verify top-1 match.
///
/// The index is written to a file inside a temporary directory, reopened
/// with `UsearchIndex::load`, and every stored vector is then queried
/// against the reloaded index.
#[test]
fn usearch_save_load_roundtrip() {
    let dim = 64;
    let n = 100;

    let dir = tempfile::tempdir().unwrap();
    let path = dir.path().join("usearch_roundtrip.idx");

    let config = default_config(dim);
    let index = UsearchIndex::new(config.clone()).unwrap();
    index.reserve(n).unwrap();

    // Generate the data set first, then insert it with IDs equal to the
    // position of each vector.
    let mut rng = rand::rng();
    let vectors: Vec<Vec<f32>> = (0..n)
        .map(|_| random_unit_vector(dim, &mut rng))
        .collect();
    for (id, v) in vectors.iter().enumerate() {
        index.insert(id as u64, v).unwrap();
    }

    index.save(&path).unwrap();
    let loaded = UsearchIndex::load(&path, &config).unwrap();

    assert_eq!(loaded.len(), n);
    assert_eq!(loaded.len_live(), n);

    // Verify that each vector's nearest neighbor in the loaded index is itself.
    for (id, stored) in vectors.iter().enumerate() {
        let hits = loaded.search(stored, 1, 200).unwrap();
        let top = &hits[0];
        assert_eq!(
            top.id, id as u64,
            "top-1 mismatch after load: expected {id}, got {}",
            top.id
        );
    }
}
/// 50 vectors, save, view (mmap), search works.
///
/// After saving, the file is reopened through `UsearchIndex::view` and the
/// first stored vector is looked up to confirm that it is its own top hit.
#[test]
fn usearch_view_readonly() {
    let dim = 32;
    let n = 50;

    let dir = tempfile::tempdir().unwrap();
    let path = dir.path().join("usearch_view.idx");

    let config = default_config(dim);
    let index = UsearchIndex::new(config.clone()).unwrap();
    index.reserve(n).unwrap();

    // Generate the data set first, then insert it with IDs equal to the
    // position of each vector.
    let mut rng = rand::rng();
    let vectors: Vec<Vec<f32>> = (0..n)
        .map(|_| random_unit_vector(dim, &mut rng))
        .collect();
    for (id, v) in vectors.iter().enumerate() {
        index.insert(id as u64, v).unwrap();
    }

    index.save(&path).unwrap();
    let viewed = UsearchIndex::view(&path, &config).unwrap();
    assert_eq!(viewed.len(), n);

    // Search works on the viewed index.
    let hits = viewed.search(&vectors[0], 1, 200).unwrap();
    let top = &hits[0];
    assert_eq!(top.id, 0);
}
/// Wrong dimensions on insert and search return `DimensionMismatch`.
///
/// The index is configured for 64 dimensions; a 32-element insert and a
/// 128-element query must each be rejected, and the error must report both
/// the expected and the supplied length.
#[test]
fn usearch_dimension_mismatch() {
    use tidaldb::storage::vector::VectorError;

    let index = UsearchIndex::new(default_config(64)).unwrap();

    // Insert with wrong dimensions.
    let too_short = vec![1.0f32; 32];
    let insert_outcome = index.insert(1, &too_short);
    assert!(insert_outcome.is_err());
    match insert_outcome.unwrap_err() {
        VectorError::DimensionMismatch { expected, got } => {
            assert_eq!(expected, 64);
            assert_eq!(got, 32);
        }
        other => panic!("expected DimensionMismatch, got {other:?}"),
    }

    // Search with wrong dimensions.
    let too_long = vec![1.0f32; 128];
    let search_outcome = index.search(&too_long, 1, 200);
    assert!(search_outcome.is_err());
    match search_outcome.unwrap_err() {
        VectorError::DimensionMismatch { expected, got } => {
            assert_eq!(expected, 64);
            assert_eq!(got, 128);
        }
        other => panic!("expected DimensionMismatch, got {other:?}"),
    }
}
/// Compile-time assertion that `UsearchIndex` is `Send + Sync`.
///
/// The helpers have empty bodies; the check is performed entirely by the
/// trait bounds, so this test fails to compile rather than fails at runtime.
#[test]
fn usearch_is_send_and_sync() {
    fn require_send<T>()
    where
        T: Send,
    {
    }
    fn require_sync<T>()
    where
        T: Sync,
    {
    }

    require_send::<UsearchIndex>();
    require_sync::<UsearchIndex>();
}
/// Recall@10 correctness guard for the default `VectorIndexConfig` (M=16, ef=400).
///
/// Uses 1K vectors / 128D to verify that the production default achieves
/// recall@10 > 0.95. This test is designed to catch regressions if the
/// default parameters are changed to values that hurt recall.
///
/// Ground truth comes from a `BruteForceIndex` holding the same vectors at
/// F32 quantization; recall is averaged over 50 random queries.
///
/// Based on the m7p3 USearch grid search (see docs/profiling/usearch-tuning.md):
/// M=16, `ef_construction`=400 achieves recall@10 ≈ 0.993 at 100K vectors / 128D.
#[test]
fn recall_at_10_above_threshold() {
    let dim = 128;
    let n = 1_000u64;
    let k = 10;
    let num_queries: i32 = 50;

    // Use the production default config.
    let config = VectorIndexConfig::default_for_dim(dim);
    let brute_config = VectorIndexConfig {
        dimensions: dim,
        metric: DistanceMetric::L2,
        quantization: QuantizationLevel::F32,
        connectivity: 16,
        ef_construction: 400,
        ef_search: 400,
    };

    let usearch = UsearchIndex::new(config).unwrap();
    let brute = BruteForceIndex::new(brute_config);
    usearch.reserve(n as usize).unwrap();

    let mut rng = rand::rng();

    // Insert the same vectors into both indexes. The vectors are not needed
    // afterwards (queries are freshly sampled), so they are not retained.
    for id in 0..n {
        let v = random_unit_vector(dim, &mut rng);
        usearch.insert(id, &v).unwrap();
        brute.insert(id, &v).unwrap();
    }

    let mut total_recall = 0.0_f64;
    for _ in 0..num_queries {
        let query = random_unit_vector(dim, &mut rng);
        let brute_results = brute.search(&query, k, 400).unwrap();
        let usearch_results = usearch.search(&query, k, 400).unwrap();

        let brute_ids: Vec<u64> = brute_results.iter().map(|r| r.id).collect();
        let usearch_ids: Vec<u64> = usearch_results.iter().map(|r| r.id).collect();
        total_recall += recall_at_k(&brute_ids, &usearch_ids);
    }

    let avg_recall = total_recall / f64::from(num_queries);
    assert!(
        avg_recall > 0.95,
        "recall@{k} = {avg_recall:.3} with default config, expected > 0.95"
    );
    eprintln!("Default config (M=16, ef=400) recall@{k} at 1K vectors: {avg_recall:.3}");
}
/// 10K vectors (dim=128), recall@100 > 0.95.
///
/// Uses F32 quantization and `ef_search`=400 for the HNSW index to ensure
/// high recall at this scale. F16 quantization introduces enough precision
/// loss at 10K that recall drops below the 0.95 threshold with default
/// parameters. Production indexes would tune these parameters per-dataset.
///
/// Ground truth comes from a `BruteForceIndex` holding the same vectors;
/// recall is averaged over 10 random queries.
#[test]
fn usearch_recall_at_10k() {
    let dim = 128;
    let n = 10_000;
    let k = 100;
    let num_queries: i32 = 10;

    let config = VectorIndexConfig {
        dimensions: dim,
        metric: DistanceMetric::L2,
        quantization: QuantizationLevel::F32,
        connectivity: 16,
        ef_construction: 200,
        ef_search: 400,
    };
    let brute_config = f32_config(dim);

    let usearch = UsearchIndex::new(config).unwrap();
    let brute = BruteForceIndex::new(brute_config);
    usearch.reserve(n).unwrap();

    let mut rng = rand::rng();

    // Insert the same vectors into both indexes.
    for id in 0..n as u64 {
        let v = random_unit_vector(dim, &mut rng);
        usearch.insert(id, &v).unwrap();
        brute.insert(id, &v).unwrap();
    }

    assert_eq!(usearch.len(), n);
    assert_eq!(usearch.len_live(), n);

    // Query and measure recall.
    let mut total_recall = 0.0;
    for _ in 0..num_queries {
        let query = random_unit_vector(dim, &mut rng);
        let brute_results = brute.search(&query, k, 0).unwrap();
        // Query with 400, the `ef_search` value this test documents and
        // configures above; the call previously passed 200.
        // NOTE(review): assumes the third argument of `search` is the
        // per-query `ef_search` -- confirm against `VectorIndex::search`.
        let usearch_results = usearch.search(&query, k, 400).unwrap();

        let brute_ids: Vec<u64> = brute_results.iter().map(|r| r.id).collect();
        let usearch_ids: Vec<u64> = usearch_results.iter().map(|r| r.id).collect();
        total_recall += recall_at_k(&brute_ids, &usearch_ids);
    }

    let avg_recall = total_recall / f64::from(num_queries);
    assert!(
        avg_recall > 0.95,
        "recall@{k} = {avg_recall:.3}, expected > 0.95 for 10K vectors"
    );
    eprintln!("10K vectors recall@{k}: {avg_recall:.3}");
}