//! LLM-based regex generation for pattern promotion. //! //! Uses the Gemini API to generate regex patterns from learned pattern examples. use tracing::{debug, info, warn}; use crate::extractors::{DeclarativeClaimDef, DeclarativeExtractorDef, DeclarativeValue}; use crate::learning::{LearnedPattern, ValueType}; use crate::llm::GeminiClient; use crate::types::Language; use crate::AphoriaError; /// System prompt for regex generation. const REGEX_GEN_SYSTEM_PROMPT: &str = r#"You are an expert regex engineer. Your task is to generate a regex pattern that matches code examples. REQUIREMENTS: 1. The regex MUST match the example code shown 2. Use named capture groups for dynamic values when value_from_match is needed (e.g., (?P...)) 3. Avoid catastrophic backtracking (no nested quantifiers like (a+)+ or (.*)+) 4. Keep the regex readable and maintainable 5. Use case-insensitive matching (?i) when appropriate 6. Escape special regex characters in literal strings IMPORTANT: - Return ONLY the regex pattern as a single line - No explanation, no markdown, no code blocks - Just the raw regex pattern"#; /// Generates regex patterns from learned pattern examples. pub struct RegexGenerator<'a> { /// The Gemini client for LLM calls. client: &'a GeminiClient, } impl<'a> RegexGenerator<'a> { /// Create a new regex generator with the given client. pub fn new(client: &'a GeminiClient) -> Self { Self { client } } /// Generate a declarative extractor definition from a learned pattern. /// /// Uses the LLM to generate an appropriate regex pattern based on /// the example code and claim template. pub fn generate( &self, pattern: &LearnedPattern, ) -> Result { let prompt = self.build_prompt(pattern); debug!( pattern_id = %pattern.id, example = %truncate(&pattern.example_code, 100), "Generating regex for pattern" ); // Call LLM to generate regex let result = self.client.complete(REGEX_GEN_SYSTEM_PROMPT, &prompt)?; // Clean and validate the response let regex_pattern = clean_regex_response(&result.response_text); if regex_pattern.is_empty() { return Err(AphoriaError::RegexGeneration( "LLM returned empty regex pattern".to_string(), )); } // Validate that the regex compiles if let Err(e) = regex::Regex::new(®ex_pattern) { warn!( pattern = %regex_pattern, error = %e, "LLM generated invalid regex" ); return Err(AphoriaError::RegexGeneration(format!( "LLM generated invalid regex: {}", e ))); } info!( pattern_id = %pattern.id, regex = %regex_pattern, "Generated regex pattern" ); // Build the extractor definition let extractor = self.build_extractor_def(pattern, regex_pattern); Ok(extractor) } /// Build the prompt for regex generation. fn build_prompt(&self, pattern: &LearnedPattern) -> String { let value_type_desc = match pattern.claim_template.value_type { ValueType::Text => "text/string", ValueType::Number => "numeric", ValueType::Boolean => "boolean", }; format!( r#"Generate a regex pattern that matches the following code example. EXAMPLE CODE: ``` {} ``` NORMALIZED PATTERN: {} CLAIM TO EXTRACT: - Subject: {} - Predicate: {} - Value type: {} Return ONLY the regex pattern as a single line, no explanation."#, pattern.example_code, pattern.normalized_pattern, pattern.claim_template.subject_template, pattern.claim_template.predicate, value_type_desc ) } /// Build the extractor definition from the pattern and generated regex. fn build_extractor_def( &self, pattern: &LearnedPattern, regex_pattern: String, ) -> DeclarativeExtractorDef { let name = generate_extractor_name(pattern); let description = pattern.claim_template.description_template.clone(); // Determine value specification based on value type let value = match pattern.claim_template.value_type { // For text values, use value_from_match to capture the dynamic value ValueType::Text => DeclarativeValue::MatchedText { value_from_match: true }, // For numbers, also capture from match ValueType::Number => DeclarativeValue::MatchedText { value_from_match: true }, // For booleans, we typically want the matched value ValueType::Boolean => DeclarativeValue::MatchedText { value_from_match: true }, }; DeclarativeExtractorDef { name, description, languages: vec![language_to_string(pattern.language)], pattern: regex_pattern, claim: DeclarativeClaimDef { subject: pattern.claim_template.subject_template.clone(), predicate: pattern.claim_template.predicate.clone(), value, }, confidence: pattern.avg_confidence, source: None, } } } /// Generate a unique extractor name from a pattern. pub fn generate_extractor_name(pattern: &LearnedPattern) -> String { // Build name from subject and predicate let base = format!( "learned_{}_{}", sanitize_name_part(&pattern.claim_template.subject_template), sanitize_name_part(&pattern.claim_template.predicate) ); // Truncate if too long if base.len() > 50 { format!("{}_{}", &base[..45], &pattern.id.to_string()[..8]) } else { base } } /// Sanitize a string for use in an extractor name. fn sanitize_name_part(s: &str) -> String { s.chars() .map(|c| if c.is_alphanumeric() { c.to_ascii_lowercase() } else { '_' }) .collect::() .trim_matches('_') .to_string() } /// Clean the LLM response to extract just the regex pattern. fn clean_regex_response(response: &str) -> String { let cleaned = response.trim(); // Remove markdown code blocks if present let cleaned = if cleaned.starts_with("```") { cleaned .lines() .skip(1) // Skip opening ``` .take_while(|line| !line.starts_with("```")) .collect::>() .join("") .trim() .to_string() } else { cleaned.to_string() }; // Remove surrounding quotes (only if string starts AND ends with same quote) let cleaned = if (cleaned.starts_with('"') && cleaned.ends_with('"')) || (cleaned.starts_with('\'') && cleaned.ends_with('\'')) { &cleaned[1..cleaned.len() - 1] } else { &cleaned }; // Take only the first line if multiple lines cleaned.lines().next().unwrap_or("").trim().to_string() } /// Truncate a string for logging. fn truncate(s: &str, max: usize) -> String { if s.len() <= max { s.to_string() } else { format!("{}...", &s[..max.saturating_sub(3)]) } } /// Convert a Language to its string representation. fn language_to_string(lang: Language) -> String { match lang { Language::Rust => "rust", Language::Go => "go", Language::Python => "python", Language::TypeScript => "typescript", Language::JavaScript => "javascript", Language::C => "c", Language::Cpp => "cpp", Language::Java => "java", Language::Php => "php", Language::Ruby => "ruby", Language::CSharp => "csharp", Language::Yaml => "yaml", Language::Toml => "toml", Language::Json => "json", Language::Ini => "ini", Language::Properties => "properties", Language::Dotenv => "dotenv", Language::Docker => "docker", Language::CargoManifest => "cargo", Language::GoMod => "gomod", Language::NpmManifest => "npm", Language::PythonManifest => "pip", Language::Terraform => "terraform", Language::Unknown => "unknown", } .to_string() } #[cfg(test)] mod tests { use super::*; use crate::learning::ClaimTemplate; use crate::types::Language; fn create_test_pattern() -> LearnedPattern { LearnedPattern::new( "const TLS_MIN_VERSION = \"1.0\"", "const TLS_MIN_VERSION = ", ClaimTemplate::new( "tls/min_version", "version", ValueType::Text, "TLS minimum version set to deprecated value", ), Language::Rust, "project1", 0.9, ) } #[test] fn test_generate_extractor_name() { let pattern = create_test_pattern(); let name = generate_extractor_name(&pattern); assert!(name.starts_with("learned_")); assert!(name.contains("tls")); assert!(name.contains("version")); } #[test] fn test_generate_extractor_name_long_subject() { let mut pattern = create_test_pattern(); pattern.claim_template.subject_template = "very/long/subject/path/that/exceeds/the/maximum/length/allowed".to_string(); let name = generate_extractor_name(&pattern); assert!(name.len() <= 60); // Should be truncated with UUID suffix } #[test] fn test_sanitize_name_part() { assert_eq!(sanitize_name_part("simple"), "simple"); assert_eq!(sanitize_name_part("with/slashes"), "with_slashes"); assert_eq!(sanitize_name_part("MixedCase"), "mixedcase"); assert_eq!(sanitize_name_part("has spaces"), "has_spaces"); assert_eq!(sanitize_name_part("_leading_trailing_"), "leading_trailing"); } #[test] fn test_clean_regex_response_simple() { let response = r#"(?i)tls_min_version\s*=\s*"([^"]+)""#; let cleaned = clean_regex_response(response); assert_eq!(cleaned, response); } #[test] fn test_clean_regex_response_with_markdown() { let response = "```regex\n(?i)tls_min_version\n```"; let cleaned = clean_regex_response(response); assert_eq!(cleaned, "(?i)tls_min_version"); } #[test] fn test_clean_regex_response_with_quotes() { let response = r#""(?i)pattern""#; let cleaned = clean_regex_response(response); assert_eq!(cleaned, "(?i)pattern"); } #[test] fn test_clean_regex_response_multiline() { let response = "first line\nsecond line\nthird line"; let cleaned = clean_regex_response(response); assert_eq!(cleaned, "first line"); } #[test] fn test_clean_regex_response_whitespace() { let response = " \n pattern \n "; let cleaned = clean_regex_response(response); assert_eq!(cleaned, "pattern"); } #[test] fn test_truncate() { assert_eq!(truncate("short", 10), "short"); assert_eq!(truncate("longer string here", 10), "longer ..."); } }