stemedb/applications/aphoria/src/extractors/unsafe_atomic.rs
2026-02-07 19:51:05 -07:00

305 lines
9.1 KiB
Rust

//! Unsafe and atomic patterns extractor for Rust.
//!
//! Tracks `unsafe` blocks and `Ordering::*` patterns for correctness conventions.
//! Enables the learning loop to establish conventions such as:
//! - "All wallet operations use Ordering::SeqCst"
//! - "Unsafe code requires documented safety invariants"
use regex::Regex;
use stemedb_core::types::ObjectValue;
use super::Extractor;
use crate::types::{ExtractedClaim, Language};
/// Extractor for `unsafe` usage and atomic memory-ordering patterns.
///
/// Detects safety-critical patterns in Rust code to enable correctness
/// conventions. The struct only holds the compiled regexes used while
/// scanning, so it is cheap to clone and safe to print for diagnostics.
#[derive(Debug, Clone)]
pub struct UnsafeAtomicExtractor {
    /// Matches `Ordering::SeqCst`, `Ordering::Acquire`, `Ordering::Release`,
    /// `Ordering::AcqRel` and `Ordering::Relaxed`; capture group 1 is the
    /// variant name.
    ordering_pattern: Regex,
    /// Matches the `unsafe` keyword when it is followed by `{` or `fn`.
    unsafe_keyword: Regex,
}
/// `Default` simply delegates to [`UnsafeAtomicExtractor::new`].
impl Default for UnsafeAtomicExtractor {
/// Build the extractor via `Self::new()`; panics under the same conditions
/// as `new` (an invalid built-in regex).
fn default() -> Self {
Self::new()
}
}
impl UnsafeAtomicExtractor {
    /// Build an extractor with both detection regexes compiled.
    ///
    /// # Panics
    /// Panics if either hard-coded pattern fails to compile, which would be a
    /// programmer error.
    #[allow(clippy::expect_used)]
    pub fn new() -> Self {
        // Capture group 1 is the ordering variant name.
        let ordering_pattern = Regex::new(r"Ordering::(SeqCst|Acquire|Release|AcqRel|Relaxed)")
            .expect("valid regex");
        // `unsafe` followed by a block opener or the `fn` keyword.
        let unsafe_keyword = Regex::new(r"\b(unsafe)\s*(\{|fn)").expect("valid regex");

        Self {
            ordering_pattern,
            unsafe_keyword,
        }
    }

    /// Pick the confidence assigned to claims extracted from `file`.
    ///
    /// Returns `0.5` when the path contains the substring `test`, `example`
    /// or `bench`, and `1.0` otherwise.
    fn confidence_for_file(&self, file: &str) -> f32 {
        const LOW_TRUST_MARKERS: [&str; 3] = ["test", "example", "bench"];

        let low_trust = LOW_TRUST_MARKERS
            .iter()
            .any(|marker| file.contains(*marker));

        if low_trust {
            0.5
        } else {
            1.0
        }
    }
}
/// Build a `code://` concept path from the caller's prefix plus fixed
/// trailing segments, joined with `/`.
fn concept_uri(path_segments: &[String], suffix: &[&str]) -> String {
    let joined = path_segments
        .iter()
        .map(String::as_str)
        .chain(suffix.iter().copied())
        .collect::<Vec<_>>()
        .join("/");
    format!("code://{}", joined)
}

impl Extractor for UnsafeAtomicExtractor {
    /// Identifier of this extractor.
    fn name(&self) -> &str {
        "unsafe_atomic"
    }

    /// Languages this extractor handles: Rust only.
    fn languages(&self) -> &[Language] {
        &[Language::Rust]
    }

    /// Scan `content` line by line and emit claims for atomic orderings and
    /// `unsafe` usage.
    ///
    /// * One `pattern` claim per distinct `Ordering::*` variant, anchored at
    ///   the first line on which the variant appears. Every occurrence on a
    ///   line is inspected, so calls that take two orderings report both.
    /// * One summary `occurrences` claim holding the number of `unsafe {` /
    ///   `unsafe fn` matches, emitted only when at least one is found. It is
    ///   anchored at line 1 and carries 90% of the file confidence.
    ///
    /// `path_segments` prefixes the concept path of every claim; `file` is
    /// recorded on each claim and determines the confidence. The language
    /// argument is not used.
    fn extract(
        &self,
        path_segments: &[String],
        content: &str,
        _language: Language,
        file: &str,
    ) -> Vec<ExtractedClaim> {
        let mut claims = Vec::new();
        let confidence = self.confidence_for_file(file);

        // Variants already reported for this file. The entries borrow from
        // `content`, so no per-variant allocation is needed for bookkeeping.
        let mut seen_orderings: std::collections::HashSet<&str> =
            std::collections::HashSet::new();
        let mut unsafe_count: usize = 0;

        for (line_idx, line) in content.lines().enumerate() {
            let line_num = line_idx + 1;

            // Walk every match on the line, not just the first one, so that
            // e.g. a success/failure ordering pair is fully captured.
            for cap in self.ordering_pattern.captures_iter(line) {
                let ordering = cap.get(1).map_or("", |m| m.as_str());

                // `insert` returns false when the variant was already seen.
                if !seen_orderings.insert(ordering) {
                    continue;
                }

                claims.push(ExtractedClaim {
                    concept_path: concept_uri(path_segments, &["atomics", "ordering"]),
                    predicate: "pattern".to_string(),
                    value: ObjectValue::Text(ordering.to_string()),
                    file: file.to_string(),
                    line: line_num,
                    matched_text: line.trim().to_string(),
                    confidence,
                    description: format!("Atomic operation uses Ordering::{}", ordering),
                });
            }

            // Count each `unsafe {` / `unsafe fn` occurrence, including
            // several on the same line.
            unsafe_count += self.unsafe_keyword.find_iter(line).count();
        }

        // Add a summary claim for unsafe usage if any was found.
        if unsafe_count > 0 {
            claims.push(ExtractedClaim {
                concept_path: concept_uri(path_segments, &["unsafe", "count"]),
                predicate: "occurrences".to_string(),
                value: ObjectValue::Number(unsafe_count as f64),
                file: file.to_string(),
                line: 1,
                matched_text: format!("{} unsafe blocks/functions", unsafe_count),
                confidence: confidence * 0.9, // Slightly lower as this is a summary
                description: format!(
                    "File contains {} unsafe block(s) or function(s)",
                    unsafe_count
                ),
            });
        }

        claims
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Repeated use of one ordering yields exactly one claim.
    #[test]
    fn test_atomic_ordering() {
        let extractor = UnsafeAtomicExtractor::new();
        let content = r#"
let balance = self.balance.load(Ordering::SeqCst);
self.balance.store(new_balance, Ordering::SeqCst);
"#;
        let claims = extractor.extract(
            &["rust".to_string(), "maxwell".to_string(), "wallet".to_string()],
            content,
            Language::Rust,
            "src/wallet.rs",
        );
        // Should have one claim for SeqCst (deduplicated): assert the exact
        // count so a regression in deduplication is caught.
        let seqcst_claims = claims
            .iter()
            .filter(|c| {
                c.concept_path.contains("atomics/ordering")
                    && c.value == ObjectValue::Text("SeqCst".to_string())
            })
            .count();
        assert_eq!(seqcst_claims, 1);
    }

    /// Each distinct ordering variant gets its own claim.
    #[test]
    fn test_multiple_orderings() {
        let extractor = UnsafeAtomicExtractor::new();
        let content = r#"
let a = atomic.load(Ordering::Acquire);
let b = atomic.load(Ordering::Relaxed);
atomic.store(x, Ordering::Release);
"#;
        let claims =
            extractor.extract(&["rust".to_string()], content, Language::Rust, "src/sync.rs");
        // Should have 3 distinct ordering claims (Acquire, Relaxed, Release)
        let ordering_claims: Vec<_> =
            claims.iter().filter(|c| c.concept_path.contains("ordering")).collect();
        assert_eq!(ordering_claims.len(), 3);
    }

    /// A single `unsafe { ... }` block produces a count of one.
    #[test]
    fn test_unsafe_block() {
        let extractor = UnsafeAtomicExtractor::new();
        let content = r#"
unsafe {
let ptr = mem::transmute(addr);
}
"#;
        let claims =
            extractor.extract(&["rust".to_string()], content, Language::Rust, "src/lib.rs");
        // Should have one unsafe count claim
        let unsafe_claim = claims.iter().find(|c| c.concept_path.contains("unsafe/count"));
        assert!(unsafe_claim.is_some());
        assert_eq!(unsafe_claim.unwrap().value, ObjectValue::Number(1.0));
    }

    /// `unsafe fn` declarations are detected as well as blocks.
    #[test]
    fn test_unsafe_fn() {
        let extractor = UnsafeAtomicExtractor::new();
        let content = r#"
unsafe fn read_msr(reg: u32) -> u64 {
// ...
}
"#;
        let claims =
            extractor.extract(&["rust".to_string()], content, Language::Rust, "src/msr.rs");
        let unsafe_claim = claims.iter().find(|c| c.concept_path.contains("unsafe"));
        assert!(unsafe_claim.is_some());
    }

    /// Functions and blocks are summed into one summary claim.
    #[test]
    fn test_multiple_unsafe_blocks() {
        let extractor = UnsafeAtomicExtractor::new();
        let content = r#"
unsafe fn foo() {}
fn bar() {
unsafe {
// block 1
}
unsafe {
// block 2
}
}
"#;
        let claims =
            extractor.extract(&["rust".to_string()], content, Language::Rust, "src/lib.rs");
        let unsafe_claim = claims.iter().find(|c| c.concept_path.contains("unsafe/count")).unwrap();
        assert_eq!(unsafe_claim.value, ObjectValue::Number(3.0)); // 1 fn + 2 blocks
    }

    /// Paths containing `test` halve the confidence of every claim.
    #[test]
    fn test_confidence_in_test_file() {
        let extractor = UnsafeAtomicExtractor::new();
        let content = r#"
unsafe { test_something(); }
"#;
        let claims =
            extractor.extract(&["rust".to_string()], content, Language::Rust, "src/test.rs");
        assert!(!claims.is_empty());
        // Confidence should be reduced for test files
        assert!(claims.iter().all(|c| c.confidence <= 0.5));
    }

    /// End-to-end check on a realistic atomic wallet implementation.
    #[test]
    fn test_real_world_wallet() {
        let extractor = UnsafeAtomicExtractor::new();
        let content = r#"
//! Wallet with atomic balance tracking
use std::sync::atomic::{AtomicU64, Ordering};
pub struct Wallet {
balance: AtomicU64,
}
impl Wallet {
pub fn deposit(&self, amount: u64) {
self.balance.fetch_add(amount, Ordering::SeqCst);
}
pub fn withdraw(&self, amount: u64) -> bool {
let current = self.balance.load(Ordering::SeqCst);
if current >= amount {
self.balance.fetch_sub(amount, Ordering::SeqCst);
true
} else {
false
}
}
pub fn balance(&self) -> u64 {
self.balance.load(Ordering::SeqCst)
}
}
"#;
        let claims = extractor.extract(
            &["rust".to_string(), "maxwell".to_string(), "wallet".to_string()],
            content,
            Language::Rust,
            "src/wallet.rs",
        );
        // Should detect SeqCst ordering (all wallet ops use it consistently)
        assert!(claims.iter().any(|c| c.concept_path.contains("ordering")
            && c.value == ObjectValue::Text("SeqCst".to_string())));
        // Should NOT have unsafe claims (no unsafe code)
        assert!(!claims.iter().any(|c| c.concept_path.contains("unsafe")));
    }
}