132 lines
3.7 KiB
Go
132 lines
3.7 KiB
Go
package gemini
|
|
|
|
import (
|
|
"errors"
|
|
"strings"
|
|
)
|
|
|
|
// Sentinel errors for the failure categories this package distinguishes.
// Compare against them with errors.Is rather than by message text.
var (
	// Transient failures: IsRetryableError reports true for these three.

	// ErrRateLimit indicates that a rate limit was exceeded.
	ErrRateLimit = errors.New("rate limit exceeded")
	// ErrServerError indicates a server-side error.
	ErrServerError = errors.New("server error")
	// ErrTimeout indicates that a request timed out.
	ErrTimeout = errors.New("request timeout")

	// Failures that IsRetryableError does not treat as retryable.

	// ErrInvalidConfig indicates that configuration validation failed.
	ErrInvalidConfig = errors.New("invalid configuration")
	// ErrInvalidRequest indicates a client-side error in the request.
	ErrInvalidRequest = errors.New("invalid request")
	// ErrUnauthorized indicates that authentication failed.
	ErrUnauthorized = errors.New("unauthorized")
	// ErrQuotaExceeded indicates that a quota has been exceeded.
	ErrQuotaExceeded = errors.New("quota exceeded")
	// ErrContentBlocked indicates that content was blocked by safety filters.
	ErrContentBlocked = errors.New("content blocked by safety filters")
)
|
|
|
|
// classifyError attempts to classify an error from the Gemini API
|
|
// into one of our sentinel errors
|
|
func classifyError(err error) error {
|
|
if err == nil {
|
|
return nil
|
|
}
|
|
|
|
errStr := strings.ToLower(err.Error())
|
|
|
|
// Check for rate limiting
|
|
if strings.Contains(errStr, "rate limit") ||
|
|
strings.Contains(errStr, "resource exhausted") ||
|
|
strings.Contains(errStr, "too many requests") {
|
|
return errors.Join(ErrRateLimit, err)
|
|
}
|
|
|
|
// Check for quota exceeded
|
|
if strings.Contains(errStr, "quota") ||
|
|
strings.Contains(errStr, "billing") {
|
|
return errors.Join(ErrQuotaExceeded, err)
|
|
}
|
|
|
|
// Check for authentication errors
|
|
if strings.Contains(errStr, "unauthorized") ||
|
|
strings.Contains(errStr, "unauthenticated") ||
|
|
strings.Contains(errStr, "invalid api key") ||
|
|
strings.Contains(errStr, "permission denied") {
|
|
return errors.Join(ErrUnauthorized, err)
|
|
}
|
|
|
|
// Check for content blocked
|
|
if strings.Contains(errStr, "blocked") ||
|
|
strings.Contains(errStr, "safety") ||
|
|
strings.Contains(errStr, "harmful") {
|
|
return errors.Join(ErrContentBlocked, err)
|
|
}
|
|
|
|
// Check for timeout
|
|
if strings.Contains(errStr, "timeout") ||
|
|
strings.Contains(errStr, "deadline exceeded") {
|
|
return errors.Join(ErrTimeout, err)
|
|
}
|
|
|
|
// Check for server errors
|
|
if strings.Contains(errStr, "internal") ||
|
|
strings.Contains(errStr, "unavailable") ||
|
|
strings.Contains(errStr, "server error") {
|
|
return errors.Join(ErrServerError, err)
|
|
}
|
|
|
|
// Check for invalid request
|
|
if strings.Contains(errStr, "invalid") ||
|
|
strings.Contains(errStr, "bad request") ||
|
|
strings.Contains(errStr, "malformed") {
|
|
return errors.Join(ErrInvalidRequest, err)
|
|
}
|
|
|
|
// Return original error if no classification matches
|
|
return err
|
|
}
|
|
|
|
// IsRateLimitError checks if the error is a rate limit error
|
|
func IsRateLimitError(err error) bool {
|
|
return errors.Is(err, ErrRateLimit)
|
|
}
|
|
|
|
// IsQuotaExceededError checks if the error is a quota exceeded error
|
|
func IsQuotaExceededError(err error) bool {
|
|
return errors.Is(err, ErrQuotaExceeded)
|
|
}
|
|
|
|
// IsUnauthorizedError checks if the error is an unauthorized error
|
|
func IsUnauthorizedError(err error) bool {
|
|
return errors.Is(err, ErrUnauthorized)
|
|
}
|
|
|
|
// IsContentBlockedError checks if the error is a content blocked error
|
|
func IsContentBlockedError(err error) bool {
|
|
return errors.Is(err, ErrContentBlocked)
|
|
}
|
|
|
|
// IsTimeoutError checks if the error is a timeout error
|
|
func IsTimeoutError(err error) bool {
|
|
return errors.Is(err, ErrTimeout)
|
|
}
|
|
|
|
// IsServerError checks if the error is a server error
|
|
func IsServerError(err error) bool {
|
|
return errors.Is(err, ErrServerError)
|
|
}
|
|
|
|
// IsRetryableError checks if the error should trigger a retry
|
|
func IsRetryableError(err error) bool {
|
|
return errors.Is(err, ErrRateLimit) ||
|
|
errors.Is(err, ErrServerError) ||
|
|
errors.Is(err, ErrTimeout)
|
|
}
|