305 lines
9.1 KiB
Rust
305 lines
9.1 KiB
Rust
//! Unsafe and atomic patterns extractor for Rust.
//!
//! Tracks `unsafe` blocks and `Ordering::*` patterns for correctness conventions.
//! Enables the learning loop to establish patterns like:
//! - "All wallet operations use Ordering::SeqCst"
//! - "Unsafe code requires documented safety invariants"
|
|
|
|
use regex::Regex;
|
|
use stemedb_core::types::ObjectValue;
|
|
|
|
use super::Extractor;
|
|
use crate::types::{ExtractedClaim, Language};
|
|
|
|
/// Extractor for unsafe blocks and atomic ordering patterns.
|
|
///
|
|
/// Detects safety-critical patterns in Rust code to enable
|
|
/// correctness conventions.
|
|
pub struct UnsafeAtomicExtractor {
|
|
/// Matches: Ordering::SeqCst, Ordering::Relaxed, etc.
|
|
ordering_pattern: Regex,
|
|
/// Matches: unsafe { ... } or unsafe fn
|
|
unsafe_keyword: Regex,
|
|
}
|
|
|
|
impl Default for UnsafeAtomicExtractor {
|
|
fn default() -> Self {
|
|
Self::new()
|
|
}
|
|
}
|
|
|
|
impl UnsafeAtomicExtractor {
|
|
/// Create a new unsafe/atomic extractor.
|
|
///
|
|
/// # Panics
|
|
/// Panics if any regex pattern is invalid (programmer error).
|
|
#[allow(clippy::expect_used)]
|
|
pub fn new() -> Self {
|
|
Self {
|
|
// Ordering::SeqCst, Ordering::Relaxed, etc.
|
|
ordering_pattern: Regex::new(r"Ordering::(SeqCst|Acquire|Release|AcqRel|Relaxed)")
|
|
.expect("valid regex"),
|
|
|
|
// unsafe keyword (blocks or functions)
|
|
unsafe_keyword: Regex::new(r"\b(unsafe)\s*(\{|fn)").expect("valid regex"),
|
|
}
|
|
}
|
|
|
|
/// Determine confidence based on context.
|
|
fn confidence_for_file(&self, file: &str) -> f32 {
|
|
if file.contains("test") || file.contains("example") || file.contains("bench") {
|
|
0.5
|
|
} else {
|
|
1.0
|
|
}
|
|
}
|
|
}
|
|
|
|
impl Extractor for UnsafeAtomicExtractor {
|
|
fn name(&self) -> &str {
|
|
"unsafe_atomic"
|
|
}
|
|
|
|
fn languages(&self) -> &[Language] {
|
|
&[Language::Rust]
|
|
}
|
|
|
|
fn extract(
|
|
&self,
|
|
path_segments: &[String],
|
|
content: &str,
|
|
_language: Language,
|
|
file: &str,
|
|
) -> Vec<ExtractedClaim> {
|
|
let mut claims = Vec::new();
|
|
let confidence = self.confidence_for_file(file);
|
|
|
|
// Track unique patterns to avoid excessive claims
|
|
let mut seen_orderings = std::collections::HashSet::new();
|
|
let mut unsafe_count = 0;
|
|
|
|
for (line_idx, line) in content.lines().enumerate() {
|
|
let line_num = line_idx + 1;
|
|
|
|
// Check for atomic ordering patterns
|
|
if let Some(cap) = self.ordering_pattern.captures(line) {
|
|
let ordering = cap.get(1).map_or("", |m| m.as_str());
|
|
|
|
if !seen_orderings.contains(ordering) {
|
|
seen_orderings.insert(ordering.to_string());
|
|
|
|
let mut concept_path = path_segments.to_vec();
|
|
concept_path.push("atomics".to_string());
|
|
concept_path.push("ordering".to_string());
|
|
|
|
claims.push(ExtractedClaim {
|
|
concept_path: format!("code://{}", concept_path.join("/")),
|
|
predicate: "pattern".to_string(),
|
|
value: ObjectValue::Text(ordering.to_string()),
|
|
file: file.to_string(),
|
|
line: line_num,
|
|
matched_text: line.trim().to_string(),
|
|
confidence,
|
|
description: format!("Atomic operation uses Ordering::{}", ordering),
|
|
});
|
|
}
|
|
}
|
|
|
|
// Check for unsafe blocks/functions
|
|
if self.unsafe_keyword.is_match(line) {
|
|
unsafe_count += 1;
|
|
}
|
|
}
|
|
|
|
// Add a summary claim for unsafe usage if found
|
|
if unsafe_count > 0 {
|
|
let mut concept_path = path_segments.to_vec();
|
|
concept_path.push("unsafe".to_string());
|
|
concept_path.push("count".to_string());
|
|
|
|
claims.push(ExtractedClaim {
|
|
concept_path: format!("code://{}", concept_path.join("/")),
|
|
predicate: "occurrences".to_string(),
|
|
value: ObjectValue::Number(unsafe_count as f64),
|
|
file: file.to_string(),
|
|
line: 1,
|
|
matched_text: format!("{} unsafe blocks/functions", unsafe_count),
|
|
confidence: confidence * 0.9, // Slightly lower as this is a summary
|
|
description: format!(
|
|
"File contains {} unsafe block(s) or function(s)",
|
|
unsafe_count
|
|
),
|
|
});
|
|
}
|
|
|
|
claims
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Run a fresh extractor over `content` as if it came from `file`.
    fn run(segments: &[&str], content: &str, file: &str) -> Vec<ExtractedClaim> {
        let path: Vec<String> = segments.iter().map(|s| s.to_string()).collect();
        UnsafeAtomicExtractor::new().extract(&path, content, Language::Rust, file)
    }

    /// Number of claims whose concept path contains `needle`.
    fn count_claims(claims: &[ExtractedClaim], needle: &str) -> usize {
        claims.iter().filter(|c| c.concept_path.contains(needle)).count()
    }

    #[test]
    fn test_atomic_ordering() {
        let content = r#"
let balance = self.balance.load(Ordering::SeqCst);
self.balance.store(new_balance, Ordering::SeqCst);
"#;

        let claims = run(&["rust", "maxwell", "wallet"], content, "src/wallet.rs");

        assert!(claims.iter().any(|c| {
            c.concept_path.contains("atomics/ordering")
                && c.value == ObjectValue::Text("SeqCst".to_string())
        }));

        // Two uses of SeqCst must collapse into exactly one claim.
        assert_eq!(count_claims(&claims, "atomics/ordering"), 1);
    }

    #[test]
    fn test_multiple_orderings() {
        let content = r#"
let a = atomic.load(Ordering::Acquire);
let b = atomic.load(Ordering::Relaxed);
atomic.store(x, Ordering::Release);
"#;

        let claims = run(&["rust"], content, "src/sync.rs");

        // Should have 3 distinct ordering claims (Acquire, Relaxed, Release)
        assert_eq!(count_claims(&claims, "ordering"), 3);
    }

    #[test]
    fn test_unsafe_block() {
        let content = r#"
unsafe {
    let ptr = mem::transmute(addr);
}
"#;

        let claims = run(&["rust"], content, "src/lib.rs");

        // Should have one unsafe count claim
        let unsafe_claim = claims.iter().find(|c| c.concept_path.contains("unsafe/count"));
        assert!(unsafe_claim.is_some());
        assert_eq!(unsafe_claim.unwrap().value, ObjectValue::Number(1.0));
    }

    #[test]
    fn test_unsafe_fn() {
        let content = r#"
unsafe fn read_msr(reg: u32) -> u64 {
    // ...
}
"#;

        let claims = run(&["rust"], content, "src/msr.rs");

        // A single `unsafe fn` yields a summary claim with a count of one.
        let unsafe_claim = claims.iter().find(|c| c.concept_path.contains("unsafe/count"));
        assert!(unsafe_claim.is_some());
        assert_eq!(unsafe_claim.unwrap().value, ObjectValue::Number(1.0));
    }

    #[test]
    fn test_multiple_unsafe_blocks() {
        let content = r#"
unsafe fn foo() {}

fn bar() {
    unsafe {
        // block 1
    }

    unsafe {
        // block 2
    }
}
"#;

        let claims = run(&["rust"], content, "src/lib.rs");

        let unsafe_claim = claims.iter().find(|c| c.concept_path.contains("unsafe/count")).unwrap();
        assert_eq!(unsafe_claim.value, ObjectValue::Number(3.0)); // 1 fn + 2 blocks
    }

    #[test]
    fn test_confidence_in_test_file() {
        let content = r#"
unsafe { test_something(); }
"#;

        let claims = run(&["rust"], content, "src/test.rs");

        assert!(!claims.is_empty());
        // Confidence should be reduced for test files
        assert!(claims.iter().all(|c| c.confidence <= 0.5));
    }

    #[test]
    fn test_real_world_wallet() {
        let content = r#"
//! Wallet with atomic balance tracking

use std::sync::atomic::{AtomicU64, Ordering};

pub struct Wallet {
    balance: AtomicU64,
}

impl Wallet {
    pub fn deposit(&self, amount: u64) {
        self.balance.fetch_add(amount, Ordering::SeqCst);
    }

    pub fn withdraw(&self, amount: u64) -> bool {
        let current = self.balance.load(Ordering::SeqCst);
        if current >= amount {
            self.balance.fetch_sub(amount, Ordering::SeqCst);
            true
        } else {
            false
        }
    }

    pub fn balance(&self) -> u64 {
        self.balance.load(Ordering::SeqCst)
    }
}
"#;

        let claims = run(&["rust", "maxwell", "wallet"], content, "src/wallet.rs");

        // Should detect SeqCst ordering (all wallet ops use it consistently)
        assert!(claims.iter().any(|c| c.concept_path.contains("ordering")
            && c.value == ObjectValue::Text("SeqCst".to_string())));

        // Consistent use of one ordering produces exactly one ordering claim.
        assert_eq!(count_claims(&claims, "ordering"), 1);

        // Should NOT have unsafe claims (no unsafe code)
        assert!(!claims.iter().any(|c| c.concept_path.contains("unsafe")));
    }
}
|