228 lines
6.6 KiB
Go
228 lines
6.6 KiB
Go
package llm
|
|
|
|
import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"strings"
	"time"
	"unicode/utf8"
)
|
|
|
|
const (
|
|
defaultTimeout = 120 * time.Second
|
|
maxAttempts = 3 // первая попытка + 2 ретрая на транзиентных сбоях
|
|
baseRetryWait = 500 * time.Millisecond
|
|
maxResponseBody = 8 << 20 // 8 MiB — потолок на тело ответа
|
|
)
|
|
|
|
// openAICompat — клиент OpenAI-совместимого Chat Completions API.
|
|
type openAICompat struct {
|
|
endpoint string // полный URL .../chat/completions
|
|
hc *http.Client
|
|
apiKey string
|
|
model string
|
|
retryWait time.Duration // базовая пауза между ретраями (0 в тестах)
|
|
}
|
|
|
|
// newOpenAICompat builds a client from cfg.
//
// Validation:
//   - cfg.BaseURL and cfg.Model must be non-empty.
//   - cfg.Proxy, when set, must parse as an absolute URL with both a scheme
//     and a host (e.g. "http://127.0.0.1:8080", "socks5://proxy:1080").
//
// A zero cfg.Timeout is replaced by defaultTimeout. Any other value is handed
// to http.Client as-is.
//
// The request endpoint is cfg.BaseURL with trailing slashes trimmed, plus
// "/chat/completions".
func newOpenAICompat(cfg Config) (*openAICompat, error) {
	if cfg.BaseURL == "" {
		return nil, fmt.Errorf("llm: empty base_url")
	}
	if cfg.Model == "" {
		return nil, fmt.Errorf("llm: empty model")
	}

	timeout := cfg.Timeout
	if timeout == 0 {
		timeout = defaultTimeout
	}

	transport := http.DefaultTransport
	if cfg.Proxy != "" {
		proxyURL, err := url.Parse(cfg.Proxy)
		if err != nil {
			return nil, fmt.Errorf("llm: parse proxy %q: %w", cfg.Proxy, err)
		}
		// url.Parse accepts strings like "localhost:8080" (scheme "localhost",
		// no host) or "proxy.example.com" (path only). Those would only fail
		// later at dial time with a confusing error, so reject them here.
		if proxyURL.Scheme == "" || proxyURL.Host == "" {
			return nil, fmt.Errorf("llm: proxy %q must be an absolute URL like http://host:port", cfg.Proxy)
		}
		// Clone the default transport instead of starting from a bare one. A
		// zero http.Transport has no dial/TLS-handshake/idle timeouts and no
		// HTTP/2, while DefaultTransport has all of them.
		if base, ok := http.DefaultTransport.(*http.Transport); ok {
			t := base.Clone()
			t.Proxy = http.ProxyURL(proxyURL)
			transport = t
		} else {
			// DefaultTransport was replaced by something we cannot clone.
			transport = &http.Transport{Proxy: http.ProxyURL(proxyURL)}
		}
	}

	return &openAICompat{
		endpoint:  strings.TrimRight(cfg.BaseURL, "/") + "/chat/completions",
		hc:        &http.Client{Timeout: timeout, Transport: transport},
		apiKey:    cfg.APIKey,
		model:     cfg.Model,
		retryWait: baseRetryWait,
	}, nil
}
|
|
|
|
// chatRequest — тело запроса Chat Completions.
|
|
type chatRequest struct {
|
|
Model string `json:"model"`
|
|
Messages []chatMessage `json:"messages"`
|
|
Temperature *float64 `json:"temperature,omitempty"`
|
|
MaxTokens int `json:"max_tokens,omitempty"`
|
|
ResponseFormat *respFormat `json:"response_format,omitempty"`
|
|
}
|
|
|
|
// chatMessage is one element of chatRequest.Messages on the wire.
type chatMessage struct {
	// Role is the message author's role as a plain string.
	Role string `json:"role"`
	// Content is the message text.
	Content string `json:"content"`
}
|
|
|
|
// respFormat is the "response_format" object of a Chat Completions request.
type respFormat struct {
	// Type names the requested format, e.g. "json_object".
	Type string `json:"type"`
}
|
|
|
|
// chatResponse holds the subset of a Chat Completions response this client
// reads: the model name, the choices, token usage and an optional error
// object. Fields not declared here are ignored by encoding/json.
type chatResponse struct {
	Model string `json:"model"`

	// Choices lists the generated alternatives.
	Choices []struct {
		Message struct {
			Content string `json:"content"`
		} `json:"message"`
		FinishReason string `json:"finish_reason"`
	} `json:"choices"`

	// Usage reports token counts and, when the provider includes it, cost.
	Usage struct {
		PromptTokens     int     `json:"prompt_tokens"`
		CompletionTokens int     `json:"completion_tokens"`
		TotalTokens      int     `json:"total_tokens"`
		Cost             float64 `json:"cost"`
	} `json:"usage"`

	// Error is non-nil when the body contains an "error" object. Some
	// gateways send one even with HTTP 200.
	Error *struct {
		Message string `json:"message"`
	} `json:"error"`
}
|
|
|
|
// Complete sends req to the model and returns the first choice.
//
// Network failures, HTTP 429 and 5xx responses are treated as transient and
// retried, for at most maxAttempts calls in total, with a pause from wait
// before every retry. Any other 4xx, or a successful but unusable response,
// is returned immediately without retrying.
//
// Errors:
//   - "llm: empty messages" when req has no messages.
//   - The error from wait if ctx ends during a pause.
//   - The last transient error wrapped as "exhausted N attempts" once every
//     attempt has failed.
func (c *openAICompat) Complete(ctx context.Context, req Request) (Response, error) {
	if len(req.Messages) == 0 {
		return Response{}, fmt.Errorf("llm: empty messages")
	}

	payload, err := json.Marshal(c.buildRequest(req))
	if err != nil {
		return Response{}, fmt.Errorf("llm: marshal request: %w", err)
	}

	var lastErr error
	for i := 0; i < maxAttempts; i++ {
		if i > 0 {
			// wait numbers attempts from 1, so retry i is attempt i+1.
			if err := c.wait(ctx, i+1); err != nil {
				return Response{}, err
			}
		}

		resp, retryable, err := c.do(ctx, payload)
		switch {
		case err == nil:
			return resp, nil
		case !retryable:
			return Response{}, err
		default:
			lastErr = err
		}
	}
	return Response{}, fmt.Errorf("llm: exhausted %d attempts: %w", maxAttempts, lastErr)
}
|
|
|
|
func (c *openAICompat) buildRequest(req Request) chatRequest {
|
|
msgs := make([]chatMessage, len(req.Messages))
|
|
for i, m := range req.Messages {
|
|
msgs[i] = chatMessage{Role: string(m.Role), Content: m.Content}
|
|
}
|
|
cr := chatRequest{
|
|
Model: c.model,
|
|
Messages: msgs,
|
|
Temperature: req.Temperature,
|
|
MaxTokens: req.MaxTokens,
|
|
}
|
|
if req.JSONMode {
|
|
cr.ResponseFormat = &respFormat{Type: "json_object"}
|
|
}
|
|
return cr
|
|
}
|
|
|
|
// do performs one HTTP round trip to the Chat Completions endpoint.
//
// The second result reports whether a failure may be retried:
//   - Retryable: network errors, errors while reading the body, HTTP 429 and
//     HTTP 5xx.
//   - Not retryable: any other non-200 status, a 200 body larger than
//     maxResponseBody, a 200 body that is not valid JSON, a provider error
//     embedded in a 200 body, and a response without choices.
//
// On success it returns the first choice's content, the model name and the
// usage counters.
func (c *openAICompat) do(ctx context.Context, body []byte) (Response, bool, error) {
	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.endpoint, bytes.NewReader(body))
	if err != nil {
		return Response{}, false, fmt.Errorf("llm: build request: %w", err)
	}
	httpReq.Header.Set("Content-Type", "application/json")
	if c.apiKey != "" {
		httpReq.Header.Set("Authorization", "Bearer "+c.apiKey)
	}

	resp, err := c.hc.Do(httpReq)
	if err != nil {
		// Network failure: retryable. If ctx is already done, the caller's
		// wait reports that before the next attempt.
		return Response{}, true, fmt.Errorf("llm: request: %w", err)
	}
	defer func() { _ = resp.Body.Close() }()

	// Read one byte past the cap. An oversized body is then detectable rather
	// than silently truncated into JSON that fails to decode.
	raw, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBody+1))
	if err != nil {
		return Response{}, true, fmt.Errorf("llm: read body: %w", err)
	}

	if resp.StatusCode != http.StatusOK {
		// Only a snippet of the body is used here, so a truncated body is fine.
		retryable := resp.StatusCode == http.StatusTooManyRequests || resp.StatusCode >= 500
		return Response{}, retryable, fmt.Errorf("llm: status %d: %s",
			resp.StatusCode, snippet(raw))
	}
	if len(raw) > maxResponseBody {
		return Response{}, false, fmt.Errorf("llm: response body exceeds %d bytes", maxResponseBody)
	}

	var cr chatResponse
	if err := json.Unmarshal(raw, &cr); err != nil {
		return Response{}, false, fmt.Errorf("llm: decode response: %w (body: %s)", err, snippet(raw))
	}
	if cr.Error != nil && cr.Error.Message != "" {
		return Response{}, false, fmt.Errorf("llm: provider error: %s", cr.Error.Message)
	}
	if len(cr.Choices) == 0 {
		return Response{}, false, fmt.Errorf("llm: empty choices (body: %s)", snippet(raw))
	}

	return Response{
		Content: cr.Choices[0].Message.Content,
		Model:   cr.Model,
		Usage: Usage{
			PromptTokens:     cr.Usage.PromptTokens,
			CompletionTokens: cr.Usage.CompletionTokens,
			TotalTokens:      cr.Usage.TotalTokens,
			Cost:             cr.Usage.Cost,
		},
	}, false, nil
}
|
|
|
|
// wait pauses before the given attempt (attempt >= 2), using a linear backoff
// of retryWait*(attempt-1). It returns early if ctx is done.
//
// It returns nil once the pause elapses. If ctx is done, it returns ctx.Err()
// wrapped with the "llm: " prefix. This is consistent whether ctx was done
// before the pause or ended during it, and errors.Is still matches
// context.Canceled and context.DeadlineExceeded. With a non-positive delay
// (e.g. retryWait == 0 in tests) no timer is started and only ctx is checked.
func (c *openAICompat) wait(ctx context.Context, attempt int) error {
	d := c.retryWait * time.Duration(attempt-1)
	if d <= 0 {
		if err := ctx.Err(); err != nil {
			return fmt.Errorf("llm: %w", err)
		}
		return nil
	}

	t := time.NewTimer(d)
	defer t.Stop()
	select {
	case <-ctx.Done():
		return fmt.Errorf("llm: %w", ctx.Err())
	case <-t.C:
		return nil
	}
}
|
|
|
|
// snippet prepares a response body for an error message.
//
// It trims surrounding whitespace. If more than 300 bytes remain, it cuts the
// text to at most 300 bytes and appends "…". The cut point is moved back to
// the start of a UTF-8 rune, so a multi-byte character (e.g. Cyrillic text) is
// never split into invalid UTF-8.
func snippet(b []byte) string {
	const maxLen = 300 // named to avoid shadowing the builtin max
	s := strings.TrimSpace(string(b))
	if len(s) <= maxLen {
		return s
	}
	cut := maxLen
	for cut > 0 && !utf8.RuneStart(s[cut]) {
		cut--
	}
	return s[:cut] + "…"
}
|