441 lines
12 KiB
Go
441 lines
12 KiB
Go
// Package laozhang provides a Go client for the LaoZhang API gateway.
//
// LaoZhang is an OpenAI-compatible API gateway that provides access to various
// AI models, including chat completion, image generation (Nano Banana Pro), and
// video generation (Veo 3.1).
//
// Basic usage:
//
//	client, err := laozhang.NewClient(laozhang.Config{
//		APIKey: os.Getenv("LAOZHANG_API_KEY"),
//	})
//	if err != nil {
//		log.Fatal(err)
//	}
//
//	// Generate an image
//	resp, err := client.GenerateImage(ctx, laozhang.ImageRequest{
//		Prompt: "A serene Japanese garden",
//	})
//
// The client automatically handles retries for server errors (5xx) and rate
// limits (429) with exponential backoff. Streaming chat completions
// (ChatCompletionStream) bypass the retrying client and are not retried. Use
// the sentinel errors (ErrRateLimit, ErrServerError, etc.) with errors.Is for
// programmatic error handling.
|
|
package laozhang
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"log/slog"
|
|
"net/http"
|
|
"net/url"
|
|
"strings"
|
|
"time"
|
|
|
|
"git.threesix.ai/jordan/persona-community-2/pkg/httpclient"
|
|
)
|
|
|
|
// Defaults that NewClient substitutes for zero-valued Config fields.
const (
	// defaultBaseURL is the OpenAI-compatible root of the LaoZhang gateway.
	defaultBaseURL = "https://api.laozhang.ai/v1"

	// defaultTimeout bounds standard (text/image) requests.
	defaultTimeout = 60 * time.Second

	// defaultVideoTimeout bounds video generation, which runs for minutes.
	defaultVideoTimeout = 5 * time.Minute

	// defaultMaxRetries is the retry budget handed to the HTTP clients.
	defaultMaxRetries = 3
)
|
|
|
|
// Config configures a Client built by NewClient.
//
// Only APIKey is required; NewClient replaces every other zero-valued field
// with a package default.
type Config struct {
	// APIKey is sent as a Bearer token on every request. Required.
	APIKey string

	// BaseURL is the API root. Default: https://api.laozhang.ai/v1.
	BaseURL string

	// Timeout bounds standard (text/image) requests. Default: 60s.
	Timeout time.Duration

	// VideoTimeout bounds video generation, which takes roughly 2-5
	// minutes. Default: 5m.
	VideoTimeout time.Duration

	// MaxRetries is passed to the underlying HTTP clients. Default: 3.
	// Because 0 selects the default, retries cannot be disabled by
	// setting this field to 0.
	MaxRetries int

	// Logger receives client diagnostics. Default: slog.Default().
	Logger *slog.Logger
}
|
|
|
|
// Client is the LaoZhang API client
|
|
type Client struct {
|
|
httpClient *httpclient.Client // Standard timeout (60s) for text/image
|
|
videoHTTPClient *httpclient.Client // Long timeout (5m) for video generation
|
|
config *Config
|
|
logger *slog.Logger
|
|
}
|
|
|
|
// NewClient creates a new LaoZhang API client
|
|
func NewClient(config Config) (*Client, error) {
|
|
if config.APIKey == "" {
|
|
return nil, fmt.Errorf("%w: API key is required", ErrInvalidConfig)
|
|
}
|
|
|
|
if config.BaseURL == "" {
|
|
config.BaseURL = defaultBaseURL
|
|
}
|
|
|
|
if config.Timeout == 0 {
|
|
config.Timeout = defaultTimeout
|
|
}
|
|
|
|
if config.VideoTimeout == 0 {
|
|
config.VideoTimeout = defaultVideoTimeout
|
|
}
|
|
|
|
if config.MaxRetries == 0 {
|
|
config.MaxRetries = defaultMaxRetries
|
|
}
|
|
|
|
if config.Logger == nil {
|
|
config.Logger = slog.Default()
|
|
}
|
|
|
|
// Validate base URL
|
|
if _, err := url.Parse(config.BaseURL); err != nil {
|
|
return nil, fmt.Errorf("%w: invalid base URL: %v", ErrInvalidConfig, err)
|
|
}
|
|
|
|
return &Client{
|
|
httpClient: httpclient.New(httpclient.Config{
|
|
Timeout: config.Timeout,
|
|
MaxRetries: config.MaxRetries,
|
|
Logger: config.Logger,
|
|
}),
|
|
videoHTTPClient: httpclient.New(httpclient.Config{
|
|
Timeout: config.VideoTimeout,
|
|
MaxRetries: config.MaxRetries,
|
|
Logger: config.Logger,
|
|
}),
|
|
config: &config,
|
|
logger: config.Logger,
|
|
}, nil
|
|
}
|
|
|
|
// Health checks the health of the LaoZhang API
|
|
func (c *Client) Health(ctx context.Context) error {
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.config.BaseURL+"/models", nil)
|
|
if err != nil {
|
|
return fmt.Errorf("create health check request: %w", err)
|
|
}
|
|
|
|
req.Header.Set("Authorization", "Bearer "+c.config.APIKey)
|
|
|
|
resp, err := c.httpClient.Do(req)
|
|
if err != nil {
|
|
return fmt.Errorf("health check request failed: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
body, _ := io.ReadAll(resp.Body)
|
|
baseErr := classifyHTTPError(resp.StatusCode)
|
|
return NewAPIError(resp.StatusCode, string(body), "", "", baseErr)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// ChatCompletion sends a chat completion request to the LaoZhang API
|
|
func (c *Client) ChatCompletion(ctx context.Context, req ChatCompletionRequest) (*ChatCompletionResponse, error) {
|
|
if req.Model == "" {
|
|
return nil, fmt.Errorf("%w: model is required", ErrInvalidConfig)
|
|
}
|
|
if len(req.Messages) == 0 {
|
|
return nil, fmt.Errorf("%w: messages are required", ErrInvalidConfig)
|
|
}
|
|
|
|
respBody, err := c.doRequest(ctx, http.MethodPost, "/chat/completions", req)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var chatResp ChatCompletionResponse
|
|
if err := json.Unmarshal(respBody, &chatResp); err != nil {
|
|
return nil, fmt.Errorf("unmarshal response: %w", err)
|
|
}
|
|
|
|
return &chatResp, nil
|
|
}
|
|
|
|
// ChatCompletionStream sends a streaming chat completion request.
|
|
// Returns chunks via the onChunk callback as tokens arrive from the API.
|
|
// The request is sent with stream=true and the response body is read as SSE.
|
|
func (c *Client) ChatCompletionStream(ctx context.Context, req ChatCompletionRequest, onChunk func(StreamChunk)) error {
|
|
if req.Model == "" {
|
|
return fmt.Errorf("%w: model is required", ErrInvalidConfig)
|
|
}
|
|
if len(req.Messages) == 0 {
|
|
return fmt.Errorf("%w: messages are required", ErrInvalidConfig)
|
|
}
|
|
|
|
req.Stream = true
|
|
|
|
jsonBody, err := json.Marshal(req)
|
|
if err != nil {
|
|
return fmt.Errorf("marshal request: %w", err)
|
|
}
|
|
|
|
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.config.BaseURL+"/chat/completions", bytes.NewReader(jsonBody))
|
|
if err != nil {
|
|
return fmt.Errorf("create request: %w", err)
|
|
}
|
|
httpReq.Header.Set("Authorization", "Bearer "+c.config.APIKey)
|
|
httpReq.Header.Set("Content-Type", "application/json")
|
|
|
|
// Use a raw http.Client with timeout for streaming (httpClient.Do may buffer)
|
|
rawClient := &http.Client{Timeout: c.config.Timeout}
|
|
resp, err := rawClient.Do(httpReq)
|
|
if err != nil {
|
|
return fmt.Errorf("stream request: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
body, _ := io.ReadAll(resp.Body)
|
|
baseErr := classifyHTTPError(resp.StatusCode)
|
|
return NewAPIError(resp.StatusCode, string(body), "", "", baseErr)
|
|
}
|
|
|
|
// Read SSE stream line by line
|
|
scanner := bufio.NewScanner(resp.Body)
|
|
for scanner.Scan() {
|
|
line := scanner.Text()
|
|
|
|
// Skip empty lines and comments
|
|
if line == "" || line[0] == ':' {
|
|
continue
|
|
}
|
|
|
|
// Parse SSE data lines
|
|
if !strings.HasPrefix(line, "data: ") {
|
|
continue
|
|
}
|
|
data := strings.TrimPrefix(line, "data: ")
|
|
|
|
// [DONE] marks the end of the stream
|
|
if data == "[DONE]" {
|
|
onChunk(StreamChunk{Done: true})
|
|
return nil
|
|
}
|
|
|
|
var chunk streamResponse
|
|
if err := json.Unmarshal([]byte(data), &chunk); err != nil {
|
|
c.logger.Debug("skipping unparseable stream chunk", "error", err)
|
|
continue
|
|
}
|
|
|
|
if len(chunk.Choices) > 0 && chunk.Choices[0].Delta.Content != "" {
|
|
onChunk(StreamChunk{Text: chunk.Choices[0].Delta.Content})
|
|
}
|
|
}
|
|
|
|
if err := scanner.Err(); err != nil {
|
|
return fmt.Errorf("read stream: %w", err)
|
|
}
|
|
|
|
// If we reach here without [DONE], send a final chunk
|
|
onChunk(StreamChunk{Done: true})
|
|
return nil
|
|
}
|
|
|
|
// StreamChunk is one event delivered to the ChatCompletionStream callback:
// either a fragment of generated text or, with Done set, the end of the
// stream.
type StreamChunk struct {
	Text string // incremental text taken from choices[0].delta.content
	Done bool   // set only on the final chunk, which carries no Text
}

// streamResponse is the subset of a streamed chat completion JSON payload
// (one SSE "data:" line) that the streaming reader decodes.
type streamResponse struct {
	Choices []streamChoice `json:"choices"`
}

// streamChoice is one element of streamResponse.Choices.
type streamChoice struct {
	// Delta holds the incremental message fragment for this choice.
	Delta struct {
		Content string `json:"content"`
	} `json:"delta"`

	// FinishReason is decoded from "finish_reason"; it is nil when that
	// field is absent or null.
	FinishReason *string `json:"finish_reason"`
}
|
|
|
|
// doRequest is a helper method for making HTTP requests
|
|
func (c *Client) doRequest(ctx context.Context, method, path string, bodyData interface{}) ([]byte, error) {
|
|
var reqBody io.Reader
|
|
if bodyData != nil {
|
|
jsonBody, err := json.Marshal(bodyData)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("marshal request: %w", err)
|
|
}
|
|
reqBody = bytes.NewReader(jsonBody)
|
|
}
|
|
|
|
req, err := http.NewRequestWithContext(ctx, method, c.config.BaseURL+path, reqBody)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("create request: %w", err)
|
|
}
|
|
|
|
req.Header.Set("Authorization", "Bearer "+c.config.APIKey)
|
|
if bodyData != nil {
|
|
req.Header.Set("Content-Type", "application/json")
|
|
}
|
|
|
|
resp, err := c.httpClient.Do(req)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
respBody, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("read response body: %w", err)
|
|
}
|
|
|
|
// Success response
|
|
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
|
|
return respBody, nil
|
|
}
|
|
|
|
// Parse error response
|
|
var errResp ErrorResponse
|
|
baseErr := classifyHTTPError(resp.StatusCode)
|
|
|
|
if err := json.Unmarshal(respBody, &errResp); err != nil {
|
|
// Failed to parse error response
|
|
return nil, NewAPIError(
|
|
resp.StatusCode,
|
|
string(respBody),
|
|
"",
|
|
"",
|
|
baseErr,
|
|
)
|
|
}
|
|
|
|
// Successfully parsed error response
|
|
return nil, NewAPIError(
|
|
resp.StatusCode,
|
|
errResp.Error.Message,
|
|
errResp.Error.Type,
|
|
errResp.Error.Code,
|
|
baseErr,
|
|
)
|
|
}
|
|
|
|
// doRequestVideo is like doRequest but uses the video HTTP client with a longer timeout.
|
|
// Video generation (Veo) takes 2-5 minutes, exceeding the standard 60s client timeout.
|
|
func (c *Client) doRequestVideo(ctx context.Context, method, path string, bodyData interface{}) ([]byte, error) {
|
|
var reqBody io.Reader
|
|
if bodyData != nil {
|
|
jsonBody, err := json.Marshal(bodyData)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("marshal request: %w", err)
|
|
}
|
|
reqBody = bytes.NewReader(jsonBody)
|
|
}
|
|
|
|
req, err := http.NewRequestWithContext(ctx, method, c.config.BaseURL+path, reqBody)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("create request: %w", err)
|
|
}
|
|
|
|
req.Header.Set("Authorization", "Bearer "+c.config.APIKey)
|
|
if bodyData != nil {
|
|
req.Header.Set("Content-Type", "application/json")
|
|
}
|
|
|
|
resp, err := c.videoHTTPClient.Do(req)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
respBody, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("read response body: %w", err)
|
|
}
|
|
|
|
// Success response
|
|
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
|
|
return respBody, nil
|
|
}
|
|
|
|
// Parse error response
|
|
var errResp ErrorResponse
|
|
baseErr := classifyHTTPError(resp.StatusCode)
|
|
|
|
if err := json.Unmarshal(respBody, &errResp); err != nil {
|
|
return nil, NewAPIError(
|
|
resp.StatusCode,
|
|
string(respBody),
|
|
"",
|
|
"",
|
|
baseErr,
|
|
)
|
|
}
|
|
|
|
return nil, NewAPIError(
|
|
resp.StatusCode,
|
|
errResp.Error.Message,
|
|
errResp.Error.Type,
|
|
errResp.Error.Code,
|
|
baseErr,
|
|
)
|
|
}
|
|
|
|
// geminiBaseURL returns the base URL for Gemini API endpoints (without /v1 suffix)
|
|
func (c *Client) geminiBaseURL() string {
|
|
// Strip /v1 suffix if present to get the root URL for Gemini endpoints
|
|
baseURL := c.config.BaseURL
|
|
if len(baseURL) > 3 && baseURL[len(baseURL)-3:] == "/v1" {
|
|
return baseURL[:len(baseURL)-3]
|
|
}
|
|
return baseURL
|
|
}
|
|
|
|
// doRequestGemini is similar to doRequest but uses the Gemini base URL format
|
|
func (c *Client) doRequestGemini(ctx context.Context, method, path string, bodyData interface{}) ([]byte, error) {
|
|
var reqBody io.Reader
|
|
if bodyData != nil {
|
|
jsonBody, err := json.Marshal(bodyData)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("marshal request: %w", err)
|
|
}
|
|
reqBody = bytes.NewReader(jsonBody)
|
|
}
|
|
|
|
// Use Gemini base URL (without /v1)
|
|
fullURL := c.geminiBaseURL() + path
|
|
req, err := http.NewRequestWithContext(ctx, method, fullURL, reqBody)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("create request: %w", err)
|
|
}
|
|
|
|
req.Header.Set("Authorization", "Bearer "+c.config.APIKey)
|
|
if bodyData != nil {
|
|
req.Header.Set("Content-Type", "application/json")
|
|
}
|
|
|
|
resp, err := c.httpClient.Do(req)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
respBody, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("read response body: %w", err)
|
|
}
|
|
|
|
// Success response
|
|
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
|
|
return respBody, nil
|
|
}
|
|
|
|
// Parse error response (Gemini format may differ)
|
|
baseErr := classifyHTTPError(resp.StatusCode)
|
|
return nil, NewAPIError(
|
|
resp.StatusCode,
|
|
string(respBody),
|
|
"",
|
|
"",
|
|
baseErr,
|
|
)
|
|
}
|