stemedb/applications/aphoria/src/promotion/regex_gen.rs
jordan 28fc3b5391 feat(aphoria): add C language support and streamline documentation
Add Language::C variant with file detection (.c, Makefile, CMakeLists.txt)
and integration across prompts, regex_gen, and path_mapper. Simplify
README and guides to be more concise and scannable.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-12 03:02:33 -07:00

350 lines
11 KiB
Rust

//! LLM-based regex generation for pattern promotion.
//!
//! Uses the Gemini API to generate regex patterns from learned pattern examples.
use tracing::{debug, info, warn};
use crate::extractors::{DeclarativeClaimDef, DeclarativeExtractorDef, DeclarativeValue};
use crate::learning::{LearnedPattern, ValueType};
use crate::llm::GeminiClient;
use crate::types::Language;
use crate::AphoriaError;
/// System prompt for regex generation.
///
/// Passed as the system-prompt argument of the LLM completion call in
/// `RegexGenerator::generate`. It asks the model to reply with nothing but a
/// single-line raw regex; the reply is still post-processed by
/// `clean_regex_response` in case the model wraps it in markdown or quotes.
const REGEX_GEN_SYSTEM_PROMPT: &str = r#"You are an expert regex engineer. Your task is to generate a regex pattern that matches code examples.
REQUIREMENTS:
1. The regex MUST match the example code shown
2. Use named capture groups for dynamic values when value_from_match is needed (e.g., (?P<value>...))
3. Avoid catastrophic backtracking (no nested quantifiers like (a+)+ or (.*)+)
4. Keep the regex readable and maintainable
5. Use case-insensitive matching (?i) when appropriate
6. Escape special regex characters in literal strings
IMPORTANT:
- Return ONLY the regex pattern as a single line
- No explanation, no markdown, no code blocks
- Just the raw regex pattern"#;
/// Generates regex patterns from learned pattern examples.
///
/// The generator does not own its LLM client: it borrows a [`GeminiClient`]
/// for the lifetime `'a`, so the client must outlive the generator.
pub struct RegexGenerator<'a> {
/// The Gemini client for LLM calls.
client: &'a GeminiClient,
}
impl<'a> RegexGenerator<'a> {
/// Create a new regex generator with the given client.
pub fn new(client: &'a GeminiClient) -> Self {
Self { client }
}
/// Generate a declarative extractor definition from a learned pattern.
///
/// Uses the LLM to generate an appropriate regex pattern based on
/// the example code and claim template.
pub fn generate(
&self,
pattern: &LearnedPattern,
) -> Result<DeclarativeExtractorDef, AphoriaError> {
let prompt = self.build_prompt(pattern);
debug!(
pattern_id = %pattern.id,
example = %truncate(&pattern.example_code, 100),
"Generating regex for pattern"
);
// Call LLM to generate regex
let result = self.client.complete(REGEX_GEN_SYSTEM_PROMPT, &prompt)?;
// Clean and validate the response
let regex_pattern = clean_regex_response(&result.response_text);
if regex_pattern.is_empty() {
return Err(AphoriaError::RegexGeneration(
"LLM returned empty regex pattern".to_string(),
));
}
// Validate that the regex compiles
if let Err(e) = regex::Regex::new(&regex_pattern) {
warn!(
pattern = %regex_pattern,
error = %e,
"LLM generated invalid regex"
);
return Err(AphoriaError::RegexGeneration(format!(
"LLM generated invalid regex: {}",
e
)));
}
info!(
pattern_id = %pattern.id,
regex = %regex_pattern,
"Generated regex pattern"
);
// Build the extractor definition
let extractor = self.build_extractor_def(pattern, regex_pattern);
Ok(extractor)
}
/// Build the prompt for regex generation.
fn build_prompt(&self, pattern: &LearnedPattern) -> String {
let value_type_desc = match pattern.claim_template.value_type {
ValueType::Text => "text/string",
ValueType::Number => "numeric",
ValueType::Boolean => "boolean",
};
format!(
r#"Generate a regex pattern that matches the following code example.
EXAMPLE CODE:
```
{}
```
NORMALIZED PATTERN:
{}
CLAIM TO EXTRACT:
- Subject: {}
- Predicate: {}
- Value type: {}
Return ONLY the regex pattern as a single line, no explanation."#,
pattern.example_code,
pattern.normalized_pattern,
pattern.claim_template.subject_template,
pattern.claim_template.predicate,
value_type_desc
)
}
/// Build the extractor definition from the pattern and generated regex.
fn build_extractor_def(
&self,
pattern: &LearnedPattern,
regex_pattern: String,
) -> DeclarativeExtractorDef {
let name = generate_extractor_name(pattern);
let description = pattern.claim_template.description_template.clone();
// Determine value specification based on value type
let value = match pattern.claim_template.value_type {
// For text values, use value_from_match to capture the dynamic value
ValueType::Text => DeclarativeValue::MatchedText { value_from_match: true },
// For numbers, also capture from match
ValueType::Number => DeclarativeValue::MatchedText { value_from_match: true },
// For booleans, we typically want the matched value
ValueType::Boolean => DeclarativeValue::MatchedText { value_from_match: true },
};
DeclarativeExtractorDef {
name,
description,
languages: vec![language_to_string(pattern.language)],
pattern: regex_pattern,
claim: DeclarativeClaimDef {
subject: pattern.claim_template.subject_template.clone(),
predicate: pattern.claim_template.predicate.clone(),
value,
},
confidence: pattern.avg_confidence,
source: None,
}
}
}
/// Generate a unique extractor name from a pattern.
///
/// The name is `learned_<subject>_<predicate>`, with both parts passed through
/// `sanitize_name_part`. When that exceeds 50 bytes it is shortened to at most
/// 45 bytes and suffixed with `_` plus the first 8 bytes of the pattern id's
/// string form, so long names stay distinguishable.
///
/// Sanitized parts may still contain non-ASCII alphanumerics, so the cut is
/// moved back to the nearest UTF-8 character boundary instead of slicing at a
/// fixed byte offset (which would panic mid-character).
pub fn generate_extractor_name(pattern: &LearnedPattern) -> String {
    // Build name from subject and predicate
    let base = format!(
        "learned_{}_{}",
        sanitize_name_part(&pattern.claim_template.subject_template),
        sanitize_name_part(&pattern.claim_template.predicate)
    );

    if base.len() <= 50 {
        return base;
    }

    // Back off to a char boundary; index 0 is always one, so this terminates.
    let mut cut = 45;
    while !base.is_char_boundary(cut) {
        cut -= 1;
    }

    // Use a checked slice so an id that renders shorter than 8 bytes (or has a
    // multi-byte character at the cut) falls back to the whole id.
    let id = pattern.id.to_string();
    let id_prefix = id.get(..8).unwrap_or(&id);

    format!("{}_{}", &base[..cut], id_prefix)
}
/// Sanitize a string for use in an extractor name.
///
/// Alphanumeric characters are kept (ASCII letters are lower-cased), every
/// other character becomes `_`, and leading/trailing underscores are removed.
fn sanitize_name_part(s: &str) -> String {
    let mut mapped = String::with_capacity(s.len());
    for ch in s.chars() {
        if ch.is_alphanumeric() {
            mapped.push(ch.to_ascii_lowercase());
        } else {
            mapped.push('_');
        }
    }
    mapped.trim_matches('_').to_string()
}
/// Clean the LLM response to extract just the regex pattern.
///
/// Steps, in order:
/// 1. Trim surrounding whitespace.
/// 2. If the response opens with a markdown fence (three backticks), drop the
///    opening fence line and keep the lines up to the closing fence,
///    concatenated without a separator.
/// 3. Remove one pair of matching surrounding quotes (`"…"` or `'…'`).
/// 4. Keep only the first remaining line, trimmed.
///
/// A response consisting of a single quote character is left untouched: it is
/// not a quoted string, and slicing one byte off each end of it would panic.
fn clean_regex_response(response: &str) -> String {
    let trimmed = response.trim();

    // Remove markdown code blocks if present
    let unfenced = if trimmed.starts_with("```") {
        trimmed
            .lines()
            .skip(1) // Skip opening ```
            .take_while(|line| !line.starts_with("```"))
            .collect::<String>()
            .trim()
            .to_string()
    } else {
        trimmed.to_string()
    };

    // Remove surrounding quotes only when the text both starts AND ends with
    // the same quote and those are two distinct characters. `strip_suffix`
    // runs on what is left after `strip_prefix`, so a lone `"` never matches.
    let unquoted = ['"', '\'']
        .iter()
        .find_map(|&quote| {
            unfenced
                .strip_prefix(quote)
                .and_then(|rest| rest.strip_suffix(quote))
        })
        .unwrap_or(unfenced.as_str());

    // Take only the first line if multiple lines
    unquoted.lines().next().unwrap_or("").trim().to_string()
}
/// Truncate a string for logging.
///
/// Strings of at most `max` bytes are returned unchanged. Longer strings are
/// cut to at most `max - 3` bytes and suffixed with `...`. The cut is moved
/// back to the nearest UTF-8 character boundary so that multi-byte characters
/// in the input cannot cause a slicing panic.
///
/// For `max < 3` the result is just `...` (and therefore longer than `max`).
fn truncate(s: &str, max: usize) -> String {
    if s.len() <= max {
        return s.to_string();
    }

    // `cut < s.len()` here; index 0 is always a boundary, so the loop ends.
    let mut cut = max.saturating_sub(3);
    while !s.is_char_boundary(cut) {
        cut -= 1;
    }

    format!("{}...", &s[..cut])
}
/// Convert a Language to its string representation.
///
/// Returns the lowercase identifier stored in an extractor definition's
/// `languages` list. The match is exhaustive on purpose: adding a `Language`
/// variant must be handled here explicitly.
fn language_to_string(lang: Language) -> String {
    let label: &'static str = match lang {
        // Programming languages
        Language::Rust => "rust",
        Language::Go => "go",
        Language::Python => "python",
        Language::TypeScript => "typescript",
        Language::JavaScript => "javascript",
        Language::C => "c",
        Language::Cpp => "cpp",
        Language::Java => "java",
        Language::Php => "php",
        Language::Ruby => "ruby",
        Language::CSharp => "csharp",
        // Configuration formats
        Language::Yaml => "yaml",
        Language::Toml => "toml",
        Language::Json => "json",
        Language::Ini => "ini",
        Language::Properties => "properties",
        Language::Dotenv => "dotenv",
        Language::Docker => "docker",
        // Package manifests
        Language::CargoManifest => "cargo",
        Language::GoMod => "gomod",
        Language::NpmManifest => "npm",
        Language::PythonManifest => "pip",
        // Infrastructure and fallback
        Language::Terraform => "terraform",
        Language::Unknown => "unknown",
    };
    String::from(label)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::learning::ClaimTemplate;
    use crate::types::Language;

    /// Fixture: a learned pattern for a Rust TLS-version constant.
    fn sample_pattern() -> LearnedPattern {
        let template = ClaimTemplate::new(
            "tls/min_version",
            "version",
            ValueType::Text,
            "TLS minimum version set to deprecated value",
        );
        LearnedPattern::new(
            "const TLS_MIN_VERSION = \"1.0\"",
            "const TLS_MIN_VERSION = <string:version>",
            template,
            Language::Rust,
            "project1",
            0.9,
        )
    }

    #[test]
    fn test_generate_extractor_name() {
        let generated = generate_extractor_name(&sample_pattern());
        assert!(generated.starts_with("learned_"));
        for fragment in ["tls", "version"] {
            assert!(generated.contains(fragment));
        }
    }

    #[test]
    fn test_generate_extractor_name_long_subject() {
        let mut long_pattern = sample_pattern();
        long_pattern.claim_template.subject_template =
            "very/long/subject/path/that/exceeds/the/maximum/length/allowed".to_string();
        // Should be truncated with UUID suffix
        assert!(generate_extractor_name(&long_pattern).len() <= 60);
    }

    #[test]
    fn test_sanitize_name_part() {
        let cases = [
            ("simple", "simple"),
            ("with/slashes", "with_slashes"),
            ("MixedCase", "mixedcase"),
            ("has spaces", "has_spaces"),
            ("_leading_trailing_", "leading_trailing"),
        ];
        for (input, expected) in cases {
            assert_eq!(sanitize_name_part(input), expected);
        }
    }

    #[test]
    fn test_clean_regex_response_simple() {
        // A bare regex must pass through untouched.
        let raw = r#"(?i)tls_min_version\s*=\s*"([^"]+)""#;
        assert_eq!(clean_regex_response(raw), raw);
    }

    #[test]
    fn test_clean_regex_response_with_markdown() {
        assert_eq!(
            clean_regex_response("```regex\n(?i)tls_min_version\n```"),
            "(?i)tls_min_version"
        );
    }

    #[test]
    fn test_clean_regex_response_with_quotes() {
        assert_eq!(clean_regex_response(r#""(?i)pattern""#), "(?i)pattern");
    }

    #[test]
    fn test_clean_regex_response_multiline() {
        assert_eq!(
            clean_regex_response("first line\nsecond line\nthird line"),
            "first line"
        );
    }

    #[test]
    fn test_clean_regex_response_whitespace() {
        assert_eq!(clean_regex_response(" \n pattern \n "), "pattern");
    }

    #[test]
    fn test_truncate() {
        let cases = [
            ("short", 10, "short"),
            ("longer string here", 10, "longer ..."),
        ];
        for (input, limit, expected) in cases {
            assert_eq!(truncate(input, limit), expected);
        }
    }
}