Files
2026-06-14 19:37:09 +03:00

249 lines
7.5 KiB
Go

package llm
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"log/slog"
"net/http"
"net/url"
"strings"
"time"
)
// Tunables for the Chat Completions client.
const (
	// defaultTimeout is used as the http.Client timeout when Config.Timeout
	// is not set.
	defaultTimeout = 120 * time.Second
	// baseRetryWait is the unit of the linear backoff between attempts.
	baseRetryWait = 500 * time.Millisecond

	// maxAttempts counts the first try plus two retries on transient failures.
	maxAttempts = 3
	// maxResponseBody is the cap on a response body size (8 MiB).
	maxResponseBody = 8 * 1024 * 1024
)
// openAICompat is a client for an OpenAI-compatible Chat Completions API.
// Instances are constructed by newOpenAICompat.
type openAICompat struct {
	// endpoint is the full .../chat/completions URL that requests are POSTed to.
	endpoint string
	// hc performs the HTTP round trips.
	hc *http.Client
	// apiKey, when non-empty, is sent as a Bearer token.
	apiKey string
	// model is the model name placed into every request.
	model string
	// retryWait is the base pause between retries (0 in tests).
	retryWait time.Duration
	// log receives request/response diagnostics.
	log *slog.Logger
}
// newOpenAICompat builds a client from cfg.
//
// BaseURL and Model are required. A non-positive Timeout falls back to
// defaultTimeout (a negative http.Client timeout would otherwise mean "no
// timeout at all"). When Proxy is set it must be an absolute URL
// (scheme://host[:port]); requests then go through it using a clone of
// http.DefaultTransport, so the stock dial/TLS/idle-connection settings are
// kept. Without a proxy http.DefaultTransport is used as-is. A nil logger
// falls back to slog.Default().
func newOpenAICompat(cfg Config, logger *slog.Logger) (*openAICompat, error) {
	if cfg.BaseURL == "" {
		return nil, fmt.Errorf("llm: empty base_url")
	}
	if cfg.Model == "" {
		return nil, fmt.Errorf("llm: empty model")
	}

	timeout := cfg.Timeout
	if timeout <= 0 {
		timeout = defaultTimeout
	}

	transport := http.DefaultTransport
	if cfg.Proxy != "" {
		proxyURL, err := url.Parse(cfg.Proxy)
		if err != nil {
			return nil, fmt.Errorf("llm: parse proxy %q: %w", cfg.Proxy, err)
		}
		// url.Parse accepts "host:port" as scheme "host" with no host part,
		// which only fails later at request time; reject it up front.
		if proxyURL.Scheme == "" || proxyURL.Host == "" {
			return nil, fmt.Errorf("llm: proxy %q must be an absolute URL like http://host:port", cfg.Proxy)
		}
		var t *http.Transport
		if base, ok := http.DefaultTransport.(*http.Transport); ok {
			// Clone keeps the default dial, TLS-handshake and idle-connection
			// timeouts plus HTTP/2 support that a bare Transport lacks.
			t = base.Clone()
		} else {
			// DefaultTransport was replaced by something else; start from zero.
			t = &http.Transport{}
		}
		t.Proxy = http.ProxyURL(proxyURL)
		transport = t
	}

	if logger == nil {
		logger = slog.Default()
	}
	return &openAICompat{
		endpoint:  strings.TrimRight(cfg.BaseURL, "/") + "/chat/completions",
		hc:        &http.Client{Timeout: timeout, Transport: transport},
		apiKey:    cfg.APIKey,
		model:     cfg.Model,
		retryWait: baseRetryWait,
		log:       logger,
	}, nil
}
// Wire types for the OpenAI-compatible Chat Completions API. Only the fields
// this client sends or reads are modelled.
type (
	// chatRequest is the JSON body POSTed to .../chat/completions.
	chatRequest struct {
		Model    string        `json:"model"`
		Messages []chatMessage `json:"messages"`
		// Temperature is a pointer so that nil (omitted from the JSON) is
		// distinguishable from an explicit 0.
		Temperature    *float64    `json:"temperature,omitempty"`
		MaxTokens      int         `json:"max_tokens,omitempty"`
		ResponseFormat *respFormat `json:"response_format,omitempty"`
	}

	// chatMessage is one role/content entry of the conversation.
	chatMessage struct {
		Role    string `json:"role"`
		Content string `json:"content"`
	}

	// respFormat selects the response format, e.g. "json_object".
	respFormat struct {
		Type string `json:"type"`
	}

	// chatResponse is the subset of a Chat Completions response this client
	// consumes.
	chatResponse struct {
		Model   string `json:"model"`
		Choices []struct {
			Message struct {
				Content string `json:"content"`
			} `json:"message"`
			FinishReason string `json:"finish_reason"`
		} `json:"choices"`
		Usage struct {
			PromptTokens     int     `json:"prompt_tokens"`
			CompletionTokens int     `json:"completion_tokens"`
			TotalTokens      int     `json:"total_tokens"`
			Cost             float64 `json:"cost"`
		} `json:"usage"`
		// Error is populated by gateways that report failures inside the
		// body of an HTTP 200 response.
		Error *struct {
			Message string `json:"message"`
		} `json:"error"`
	}
)
// Complete performs a single model call, retrying transient failures
// (network errors, HTTP 429 and 5xx). Any other 4xx is a permanent error and
// is returned immediately without retrying.
//
// The request body is marshalled once and re-sent on every attempt. Before
// attempts 2..maxAttempts, wait applies a backoff and returns an error if ctx
// is done. If every attempt fails with a retryable error, the last one is
// returned wrapped as "llm: exhausted N attempts: ...".
func (c *openAICompat) Complete(ctx context.Context, req Request) (Response, error) {
	if len(req.Messages) == 0 {
		return Response{}, fmt.Errorf("llm: empty messages")
	}
	body, err := json.Marshal(c.buildRequest(req))
	if err != nil {
		return Response{}, fmt.Errorf("llm: marshal request: %w", err)
	}

	var lastErr error
	for attempt := 1; attempt <= maxAttempts; attempt++ {
		if attempt > 1 {
			if err := c.wait(ctx, attempt); err != nil {
				return Response{}, err
			}
		}
		c.log.Debug("llm: request",
			"endpoint", c.endpoint, "model", c.model,
			"attempt", attempt, "max_attempts", maxAttempts)

		start := time.Now()
		resp, retryable, err := c.do(ctx, body)
		if err == nil {
			c.log.Debug("llm: response ok",
				"model", resp.Model, "attempt", attempt,
				"duration", time.Since(start),
				"total_tokens", resp.Usage.TotalTokens, "cost", resp.Usage.Cost)
			return resp, nil
		}
		lastErr = err
		if !retryable {
			c.log.Error("llm: request failed (non-retryable)",
				"model", c.model, "attempt", attempt, "duration", time.Since(start), "err", err)
			return Response{}, err
		}
		// Announce a retry only when one will actually follow; the final
		// failure is reported by the "all attempts exhausted" log below.
		if attempt < maxAttempts {
			c.log.Warn("llm: request failed, will retry",
				"model", c.model, "attempt", attempt, "max_attempts", maxAttempts,
				"duration", time.Since(start), "err", err)
		}
	}
	c.log.Error("llm: all attempts exhausted",
		"model", c.model, "max_attempts", maxAttempts, "err", lastErr)
	return Response{}, fmt.Errorf("llm: exhausted %d attempts: %w", maxAttempts, lastErr)
}
// buildRequest maps a provider-neutral Request onto the Chat Completions wire
// format, always using the client's configured model. JSON mode is expressed
// as response_format {"type":"json_object"}.
func (c *openAICompat) buildRequest(req Request) chatRequest {
	out := chatRequest{
		Model:       c.model,
		Messages:    make([]chatMessage, 0, len(req.Messages)),
		Temperature: req.Temperature,
		MaxTokens:   req.MaxTokens,
	}
	for _, msg := range req.Messages {
		out.Messages = append(out.Messages, chatMessage{
			Role:    string(msg.Role),
			Content: msg.Content,
		})
	}
	if !req.JSONMode {
		return out
	}
	out.ResponseFormat = &respFormat{Type: "json_object"}
	return out
}
// do performs one HTTP round trip to c.endpoint with the pre-marshalled body.
//
// The bool result reports whether a failure is worth retrying. Transport
// errors, 429 and 5xx are retryable. Everything else is not: other non-200
// statuses, a body larger than maxResponseBody, undecodable JSON, a provider
// error embedded in a 200 body, or an empty choices list. On success the bool
// is false.
func (c *openAICompat) do(ctx context.Context, body []byte) (Response, bool, error) {
	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.endpoint, bytes.NewReader(body))
	if err != nil {
		return Response{}, false, fmt.Errorf("llm: build request: %w", err)
	}
	httpReq.Header.Set("Content-Type", "application/json")
	if c.apiKey != "" {
		httpReq.Header.Set("Authorization", "Bearer "+c.apiKey)
	}

	resp, err := c.hc.Do(httpReq)
	if err != nil {
		// Network failure: retryable (wait checks ctx before the next attempt).
		return Response{}, true, fmt.Errorf("llm: request: %w", err)
	}
	defer func() { _ = resp.Body.Close() }()

	ok := resp.StatusCode == http.StatusOK
	statusRetryable := resp.StatusCode == http.StatusTooManyRequests || resp.StatusCode >= 500

	// Read one byte past the cap so an oversized body is detected instead of
	// being silently truncated into invalid JSON.
	raw, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBody+1))
	if err != nil {
		// The server's verdict still stands: a broken body on a permanent
		// error status (e.g. 401) must not turn it into a retryable failure.
		return Response{}, ok || statusRetryable, fmt.Errorf("llm: read body: %w", err)
	}
	if !ok {
		return Response{}, statusRetryable, fmt.Errorf("llm: status %d: %s",
			resp.StatusCode, snippet(raw))
	}
	if len(raw) > maxResponseBody {
		return Response{}, false, fmt.Errorf("llm: response body exceeds %d bytes", maxResponseBody)
	}

	var cr chatResponse
	if err := json.Unmarshal(raw, &cr); err != nil {
		return Response{}, false, fmt.Errorf("llm: decode response: %w (body: %s)", err, snippet(raw))
	}
	if cr.Error != nil && cr.Error.Message != "" {
		return Response{}, false, fmt.Errorf("llm: provider error: %s", cr.Error.Message)
	}
	if len(cr.Choices) == 0 {
		return Response{}, false, fmt.Errorf("llm: empty choices (body: %s)", snippet(raw))
	}
	return Response{
		Content: cr.Choices[0].Message.Content,
		Model:   cr.Model,
		Usage: Usage{
			PromptTokens:     cr.Usage.PromptTokens,
			CompletionTokens: cr.Usage.CompletionTokens,
			TotalTokens:      cr.Usage.TotalTokens,
			Cost:             cr.Usage.Cost,
		},
	}, false, nil
}
// wait pauses before retry number attempt using a linear backoff of
// retryWait × (attempt−1), returning early if ctx is done.
//
// A done ctx is always reported as "llm: <ctx error>". It is wrapped with %w,
// so errors.Is(err, context.Canceled) and errors.Is(err,
// context.DeadlineExceeded) keep working. This holds even when no pause is
// configured (retryWait == 0).
func (c *openAICompat) wait(ctx context.Context, attempt int) error {
	d := c.retryWait * time.Duration(attempt-1)
	if d <= 0 {
		// No pause configured: still honour cancellation, with the same
		// wrapping as the timer path below.
		if err := ctx.Err(); err != nil {
			return fmt.Errorf("llm: %w", err)
		}
		return nil
	}
	t := time.NewTimer(d)
	defer t.Stop()
	select {
	case <-ctx.Done():
		return fmt.Errorf("llm: %w", ctx.Err())
	case <-t.C:
		return nil
	}
}
// snippet trims b for inclusion in an error message. Surrounding whitespace
// is removed, and anything past 300 bytes is replaced with "…". The cut is
// moved back to a rune boundary, so a multi-byte UTF-8 character is never
// split; splitting one would put invalid UTF-8 into errors and logs.
func snippet(b []byte) string {
	const maxLen = 300 // named to avoid shadowing the builtin max
	s := strings.TrimSpace(string(b))
	if len(s) <= maxLen {
		return s
	}
	// Ranging over a string yields the byte offset where each rune starts;
	// keep the last start that is <= maxLen so s[:cut] holds whole runes.
	cut := 0
	for i := range s {
		if i > maxLen {
			break
		}
		cut = i
	}
	return s[:cut] + "…"
}