202 lines
5.1 KiB
Go
202 lines
5.1 KiB
Go
// Package adapters provides textgen provider adapters for various AI services.
|
|
package adapters
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"strings"
|
|
|
|
"git.threesix.ai/jordan/persona-community-2/pkg/textgen"
|
|
"google.golang.org/genai"
|
|
)
|
|
|
|
// defaultGeminiTextModel is the model name the provider falls back to when
// the configuration does not name one.
//
// IMPORTANT: Always use gemini-3-flash-preview for text generation.
// DO NOT use gemini-2.5-flash or any 2.x models.
const defaultGeminiTextModel = "gemini-3-flash-preview"
|
|
|
|
// GeminiTextProvider implements textgen.TextGenerator using Gemini API.
|
|
type GeminiTextProvider struct {
|
|
client *genai.Client
|
|
model string
|
|
}
|
|
|
|
// GeminiTextConfig holds configuration for the Gemini text provider.
|
|
type GeminiTextConfig struct {
|
|
// APIKey for Gemini API (required if Client is nil)
|
|
APIKey string
|
|
|
|
// Client is an existing genai.Client (optional, takes precedence over APIKey)
|
|
Client *genai.Client
|
|
|
|
// Model to use for text generation (default: gemini-3-flash-preview)
|
|
// IMPORTANT: DO NOT use gemini-2.5-flash or any 2.x models.
|
|
Model string
|
|
}
|
|
|
|
// NewGeminiTextProvider creates a new Gemini text generation provider.
|
|
func NewGeminiTextProvider(ctx context.Context, cfg GeminiTextConfig) (*GeminiTextProvider, error) {
|
|
var client *genai.Client
|
|
var err error
|
|
|
|
if cfg.Client != nil {
|
|
client = cfg.Client
|
|
} else if cfg.APIKey != "" {
|
|
client, err = genai.NewClient(ctx, &genai.ClientConfig{
|
|
APIKey: cfg.APIKey,
|
|
Backend: genai.BackendGeminiAPI,
|
|
})
|
|
if err != nil {
|
|
return nil, fmt.Errorf("create genai client: %w", err)
|
|
}
|
|
} else {
|
|
return nil, fmt.Errorf("%w: APIKey or Client required", textgen.ErrInvalidConfig)
|
|
}
|
|
|
|
model := cfg.Model
|
|
if model == "" {
|
|
model = defaultGeminiTextModel
|
|
}
|
|
|
|
return &GeminiTextProvider{
|
|
client: client,
|
|
model: model,
|
|
}, nil
|
|
}
|
|
|
|
// Name implements textgen.Provider.
|
|
func (p *GeminiTextProvider) Name() string {
|
|
return "gemini"
|
|
}
|
|
|
|
// Health implements textgen.Provider.
|
|
func (p *GeminiTextProvider) Health(ctx context.Context) error {
|
|
// Try a minimal request to verify the API is working
|
|
_, err := p.client.Models.Get(ctx, p.model, nil)
|
|
if err != nil {
|
|
return fmt.Errorf("gemini health check: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// GenerateText implements textgen.TextGenerator.
|
|
func (p *GeminiTextProvider) GenerateText(ctx context.Context, req textgen.TextRequest) (*textgen.TextResponse, error) {
|
|
model := req.Model
|
|
if model == "" {
|
|
model = p.model
|
|
}
|
|
|
|
// Build content from request
|
|
var content []*genai.Content
|
|
|
|
// Add system prompt if provided
|
|
if req.SystemPrompt != "" {
|
|
content = append(content, &genai.Content{
|
|
Role: "user",
|
|
Parts: []*genai.Part{
|
|
{Text: "System: " + req.SystemPrompt},
|
|
},
|
|
})
|
|
}
|
|
|
|
// Add messages if provided (multi-turn conversation)
|
|
if len(req.Messages) > 0 {
|
|
for _, msg := range req.Messages {
|
|
role := msg.Role
|
|
if role == "assistant" {
|
|
role = "model"
|
|
}
|
|
content = append(content, &genai.Content{
|
|
Role: role,
|
|
Parts: []*genai.Part{
|
|
{Text: msg.Content},
|
|
},
|
|
})
|
|
}
|
|
} else if req.Prompt != "" {
|
|
// Single prompt
|
|
content = append(content, &genai.Content{
|
|
Role: "user",
|
|
Parts: []*genai.Part{
|
|
{Text: req.Prompt},
|
|
},
|
|
})
|
|
}
|
|
|
|
// Configure generation
|
|
var config *genai.GenerateContentConfig
|
|
if req.MaxTokens > 0 || req.Temperature > 0 {
|
|
config = &genai.GenerateContentConfig{}
|
|
if req.MaxTokens > 0 {
|
|
config.MaxOutputTokens = int32(req.MaxTokens)
|
|
}
|
|
if req.Temperature > 0 {
|
|
temp := float32(req.Temperature)
|
|
config.Temperature = &temp
|
|
}
|
|
}
|
|
|
|
// Call Gemini API
|
|
resp, err := p.client.Models.GenerateContent(ctx, model, content, config)
|
|
if err != nil {
|
|
return nil, classifyGeminiError(err)
|
|
}
|
|
|
|
// Extract text from response
|
|
var responseText string
|
|
if resp != nil && len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
|
|
for _, part := range resp.Candidates[0].Content.Parts {
|
|
if part.Text != "" {
|
|
responseText += part.Text
|
|
}
|
|
}
|
|
}
|
|
|
|
if responseText == "" {
|
|
return nil, fmt.Errorf("empty response from Gemini")
|
|
}
|
|
|
|
// Build response
|
|
result := &textgen.TextResponse{
|
|
Text: responseText,
|
|
}
|
|
|
|
// Add usage if available
|
|
if resp.UsageMetadata != nil {
|
|
result.Usage = &textgen.Usage{
|
|
PromptTokens: int(resp.UsageMetadata.PromptTokenCount),
|
|
CompletionTokens: int(resp.UsageMetadata.CandidatesTokenCount),
|
|
TotalTokens: int(resp.UsageMetadata.TotalTokenCount),
|
|
}
|
|
}
|
|
|
|
return result, nil
|
|
}
|
|
|
|
// classifyGeminiError converts Gemini errors to textgen sentinel errors.
|
|
func classifyGeminiError(err error) error {
|
|
if err == nil {
|
|
return nil
|
|
}
|
|
|
|
errStr := strings.ToLower(err.Error())
|
|
|
|
// Check for common error patterns
|
|
switch {
|
|
case strings.Contains(errStr, "quota") || strings.Contains(errStr, "429"):
|
|
return fmt.Errorf("%w: %v", textgen.ErrQuotaExceeded, err)
|
|
case strings.Contains(errStr, "rate") || strings.Contains(errStr, "limit"):
|
|
return fmt.Errorf("%w: %v", textgen.ErrRateLimited, err)
|
|
case strings.Contains(errStr, "safety") || strings.Contains(errStr, "blocked"):
|
|
return fmt.Errorf("%w: %v", textgen.ErrContentBlocked, err)
|
|
case strings.Contains(errStr, "timeout") || strings.Contains(errStr, "deadline"):
|
|
return fmt.Errorf("%w: %v", textgen.ErrTimeout, err)
|
|
default:
|
|
return err
|
|
}
|
|
}
|
|
|
|
// Compile-time interface check
|
|
var _ textgen.TextGenerator = (*GeminiTextProvider)(nil)
|