package adk import ( "context" "encoding/json" "fmt" "slices" ) // CallbackContext provides context for agent callbacks. // // This interface matches the pattern expected by ADK-Go but is generic // enough to work with any agent framework. type CallbackContext interface { Context() context.Context AgentName() string SessionID() string } // ToolCall represents a tool invocation. type ToolCall struct { Name string `json:"name"` Input json.RawMessage `json:"input"` } // ToolResult represents the result of a tool invocation. type ToolResult struct { Output json.RawMessage `json:"output"` Error error `json:"error,omitempty"` } // BeforeToolCallback is a function that runs before a tool is called. // // This can be used to: // - Validate tool inputs // - Enforce constraints // - Block dangerous operations // // Return an error to prevent the tool from executing. type BeforeToolCallback func(ctx CallbackContext, call *ToolCall) error // AfterToolCallback is a function that runs after a tool is called. // // This can be used to: // - Log tool results // - Check confidence thresholds // - Trigger escalations // // Return an error to indicate the tool result should be treated as a failure. type AfterToolCallback func(ctx CallbackContext, call *ToolCall, result *ToolResult) error // ConstraintEnforcementCallback enforces constraints before code generation. // // This callback is designed for Implementation Agent. It checks for // forbidden patterns before any tool that might generate code. // // Example usage: // // agent.BeforeToolCallback = ConstraintEnforcementCallback(client, []string{ // "write_code", // "generate_config", // }) func ConstraintEnforcementCallback(client EpistemeClient, codeGenTools []string) BeforeToolCallback { return func(ctx CallbackContext, call *ToolCall) error { // Check if this is a code generation tool if !slices.Contains(codeGenTools, call.Name) { return nil // Not a code generation tool, allow it } // Extract domain context from tool input // This is a simplified implementation - in production you'd // parse the actual tool input to determine context context := extractDomainContext(call) // Check constraints constraintTool := NewConstraintCheckTool(client) inputBytes, err := json.Marshal(ConstraintCheckInput{ Context: context, }) if err != nil { return fmt.Errorf("failed to marshal constraint check input: %w", err) } resultBytes, err := constraintTool.Execute(ctx.Context(), inputBytes) if err != nil { return fmt.Errorf("constraint check failed: %w", err) } var output ConstraintCheckOutput if err := json.Unmarshal(resultBytes, &output); err != nil { return fmt.Errorf("failed to parse constraints: %w", err) } // Validate against constraints for _, constraint := range output.Constraints { if constraint.Forbidden != "" { // Check if the tool call would use a forbidden pattern if wouldUseForbiddenPattern(call, constraint.Forbidden) { return fmt.Errorf("blocked: %s is forbidden - %s", constraint.Forbidden, constraint.Reason) } } } return nil } } // ConfidenceEscalationCallback escalates to human when confidence is too low. // // This callback is designed for Lead Orchestrator. It checks query results // and marks the session for human review if confidence is below threshold. // // Example usage: // // agent.AfterToolCallback = ConfidenceEscalationCallback(0.8, sessionState) func ConfidenceEscalationCallback(threshold float32, setState func(key string, value any)) AfterToolCallback { return func(ctx CallbackContext, call *ToolCall, result *ToolResult) error { // Only check query results if call.Name != "episteme_query" { return nil } if result.Error != nil { return nil // Already failed, no need to check confidence } // Parse query output var output QueryOutput if err := json.Unmarshal(result.Output, &output); err != nil { return nil // Can't parse, skip } // Check confidence threshold if output.Confidence < threshold { // Mark for human review setState("needs_human_review", true) setState("low_confidence_query", output.QueryID) setState("escalation_reason", fmt.Sprintf( "Confidence %.2f below threshold %.2f", output.Confidence, threshold, )) } return nil } } // AuditLoggingCallback logs all tool calls for audit trail. // // This callback logs every tool invocation to help with debugging and // incident investigation. // // Example usage: // // agent.AfterToolCallback = AuditLoggingCallback(logger) func AuditLoggingCallback(log func(format string, args ...any)) AfterToolCallback { return func(ctx CallbackContext, call *ToolCall, result *ToolResult) error { if result.Error != nil { log("[%s] Tool %s failed: %v", ctx.AgentName(), call.Name, result.Error) } else { log("[%s] Tool %s succeeded", ctx.AgentName(), call.Name) } return nil } } // ChainCallbacks chains multiple callbacks into one. // // This allows combining multiple callback behaviors. // // Example usage: // // agent.BeforeToolCallback = ChainBeforeCallbacks( // ConstraintEnforcementCallback(client, codeGenTools), // CustomValidationCallback(), // ) func ChainBeforeCallbacks(callbacks ...BeforeToolCallback) BeforeToolCallback { return func(ctx CallbackContext, call *ToolCall) error { for _, cb := range callbacks { if err := cb(ctx, call); err != nil { return err } } return nil } } // ChainAfterCallbacks chains multiple after-tool callbacks into one. func ChainAfterCallbacks(callbacks ...AfterToolCallback) AfterToolCallback { return func(ctx CallbackContext, call *ToolCall, result *ToolResult) error { for _, cb := range callbacks { if err := cb(ctx, call, result); err != nil { return err } } return nil } } // Helper functions func extractDomainContext(_ *ToolCall) string { // This is a simplified implementation // In production, you'd parse the tool input to determine context // For example, from write_code input you might extract: // - Language (python, go, etc.) // - Domain (http, auth, database, etc.) return "general" } func wouldUseForbiddenPattern(call *ToolCall, forbidden string) bool { // This is a simplified implementation // In production, you'd analyze the tool input to check if it // would use the forbidden pattern // // For example, check if the input contains references to: // - Forbidden libraries // - Forbidden function names // - Forbidden architectural patterns // // This might involve: // - Parsing code ASTs // - Checking import statements // - Pattern matching on configuration // For now, just check if the input contains the forbidden string return containsPattern(call.Input, forbidden) } func containsPattern(_ json.RawMessage, pattern string) bool { // Simple string containment check // In production, use more sophisticated pattern matching return len(pattern) > 0 // Placeholder }