// Package sanitize provides command input sanitization to prevent injection attacks. package sanitize import ( "fmt" "regexp" "strings" ) // Package-level compiled regex patterns for performance. var ( // Dangerous shell command patterns dangerousCommandPatterns = []*dangerousPattern{ {regexp.MustCompile(`(?i)\brm\s+(-[rf]+\s+)*(/|~|\.\.|/etc|/var|/usr|/home|/root)`), "destructive rm command"}, {regexp.MustCompile(`(?i)\bdd\s+`), "dd command"}, {regexp.MustCompile(`(?i)\bmkfs\b`), "mkfs command"}, {regexp.MustCompile(`(?i)\bfdisk\b`), "fdisk command"}, {regexp.MustCompile(`(?i)\bshutdown\b`), "shutdown command"}, {regexp.MustCompile(`(?i)\breboot\b`), "reboot command"}, {regexp.MustCompile(`(?i)\bsystemctl\s+(stop|disable|mask|halt|poweroff)`), "dangerous systemctl command"}, {regexp.MustCompile(`(?i)\bkill\s+-9\s+(-1|1)\b`), "kill all processes"}, {regexp.MustCompile(`(?i)\bchmod\s+(-R\s+)?(777|666)\s+/`), "dangerous chmod on root"}, {regexp.MustCompile(`(?i)\bchown\s+-R\s+\S+\s+/`), "dangerous chown on root"}, {regexp.MustCompile(`(?i):\(\)\s*\{\s*:\s*\|\s*:\s*&\s*\}\s*;`), "fork bomb"}, {regexp.MustCompile(`(?i)\bcurl\s+.*\|\s*(ba)?sh`), "remote code execution via curl"}, {regexp.MustCompile(`(?i)\bwget\s+.*\|\s*(ba)?sh`), "remote code execution via wget"}, } // Stream ID validation pattern streamIDPattern = regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9_-]*$`) ) // dangerousPattern pairs a compiled regex with its error reason. type dangerousPattern struct { pattern *regexp.Regexp reason string } // Error types for sanitization failures. type Error struct { Reason string Input string Pattern string } func (e *Error) Error() string { if e.Pattern != "" { return fmt.Sprintf("sanitization failed: %s (matched pattern: %q, input: %q)", e.Reason, e.Pattern, e.Input) } return fmt.Sprintf("sanitization failed: %s (input: %q)", e.Reason, e.Input) } // ShellCommand validates and sanitizes a shell command. // Returns an error if the command contains dangerous patterns. func ShellCommand(cmd string) error { if strings.TrimSpace(cmd) == "" { return &Error{Reason: "empty command", Input: cmd} } // Check for null bytes if strings.ContainsRune(cmd, '\x00') { return &Error{Reason: "contains null byte", Input: cmd} } // Dangerous command chaining patterns chainPatterns := []string{ `;`, // Command separator `&&`, // AND operator `||`, // OR operator `|`, // Pipe "`", // Backtick command substitution `$(`, // Command substitution `${`, // Variable expansion that could be exploited `>(`, // Process substitution `<(`, // Process substitution `\n`, // Newline (command separator in shell) `\r`, // Carriage return } for _, pattern := range chainPatterns { if strings.Contains(cmd, pattern) { return &Error{ Reason: "contains command chaining operator", Input: cmd, Pattern: pattern, } } } // Dangerous redirect patterns redirectPatterns := []string{ `>`, // Output redirect `>>`, // Append redirect `<`, // Input redirect } for _, pattern := range redirectPatterns { if strings.Contains(cmd, pattern) { return &Error{ Reason: "contains redirect operator", Input: cmd, Pattern: pattern, } } } // Check against pre-compiled dangerous command patterns for _, dp := range dangerousCommandPatterns { if dp.pattern.MatchString(cmd) { return &Error{ Reason: dp.reason, Input: cmd, Pattern: dp.pattern.String(), } } } return nil } // GitArgs validates git command arguments. // Returns an error if any argument contains dangerous patterns. func GitArgs(args []string) error { if len(args) == 0 { return &Error{Reason: "empty git args"} } // Dangerous git subcommands dangerousSubcommands := map[string]string{ "config": "git config can modify system settings", "remote": "git remote can add malicious remotes", "push": "git push can modify remote repositories", } // First arg is the subcommand subcommand := strings.ToLower(args[0]) if reason, dangerous := dangerousSubcommands[subcommand]; dangerous { // push is allowed for specific use cases, block only with --force if subcommand == "push" { for _, arg := range args[1:] { if arg == "-f" || arg == "--force" || arg == "--force-with-lease" { return &Error{Reason: "force push not allowed", Input: strings.Join(args, " ")} } } } else { return &Error{Reason: reason, Input: strings.Join(args, " ")} } } // Check all args for shell injection for _, arg := range args { if err := validateArg(arg); err != nil { return err } } return nil } // ClaudePrompt validates a Claude prompt. // This is relatively permissive since prompts are passed as a single argument to claude CLI. func ClaudePrompt(prompt string) error { if strings.TrimSpace(prompt) == "" { return &Error{Reason: "empty prompt"} } // Check for null bytes if strings.ContainsRune(prompt, '\x00') { return &Error{Reason: "contains null byte", Input: prompt} } // Limit length to prevent resource exhaustion const maxPromptLength = 100000 // 100KB if len(prompt) > maxPromptLength { return &Error{ Reason: fmt.Sprintf("prompt too long (max %d bytes)", maxPromptLength), Input: prompt[:100] + "...", } } return nil } // validateArg validates a single command argument for shell injection. func validateArg(arg string) error { // Check for null bytes if strings.ContainsRune(arg, '\x00') { return &Error{Reason: "argument contains null byte", Input: arg} } // Check for shell metacharacters that could break out of quoting shellMetachars := []string{ "`", // Backtick `$(`, // Command substitution `${`, // Variable expansion } for _, meta := range shellMetachars { if strings.Contains(arg, meta) { return &Error{ Reason: "argument contains shell metacharacter", Input: arg, Pattern: meta, } } } return nil } // StreamID validates a stream ID. // Stream IDs should be alphanumeric with hyphens and underscores only. func StreamID(id string) error { if id == "" { return nil // Empty is allowed (will be auto-generated) } // Must be reasonable length if len(id) > 64 { return &Error{Reason: "stream ID too long (max 64 chars)", Input: id} } // Must match safe pattern (uses pre-compiled regex) if !streamIDPattern.MatchString(id) { return &Error{ Reason: "stream ID must be alphanumeric with hyphens/underscores", Input: id, Pattern: streamIDPattern.String(), } } return nil }