Добавил интеграцию с LLM
This commit is contained in:
@@ -1,4 +0,0 @@
|
||||
// Package llm — провайдер LLM за интерфейсом (дискриминатор type).
|
||||
//
|
||||
// Заглушка: реализация в фазе Ф2 (см. docs/specs/recognition.md).
|
||||
package llm
|
||||
@@ -0,0 +1,81 @@
|
||||
package llm_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.vakhrushev.me/av/jellybit/internal/llm"
|
||||
)
|
||||
|
||||
// TestIntegration_OpenAICompat бьётся в реальный эндпоинт. По умолчанию
|
||||
// пропускается; включается переменными окружения:
|
||||
//
|
||||
// JELLYBIT_LLM_BASE_URL=https://bothub.chat/api/v2/openai/v1 \
|
||||
// JELLYBIT_LLM_API_KEY=... \
|
||||
// JELLYBIT_LLM_MODEL=deepseek-v4-flash \
|
||||
// go test ./internal/llm/ -run Integration -v
|
||||
func TestIntegration_OpenAICompat(t *testing.T) {
|
||||
base := os.Getenv("JELLYBIT_LLM_BASE_URL")
|
||||
key := os.Getenv("JELLYBIT_LLM_API_KEY")
|
||||
model := os.Getenv("JELLYBIT_LLM_MODEL")
|
||||
if base == "" || model == "" {
|
||||
t.Skip("set JELLYBIT_LLM_BASE_URL and JELLYBIT_LLM_MODEL to run")
|
||||
}
|
||||
|
||||
p, err := llm.New(llm.Config{
|
||||
Type: "openai-compat",
|
||||
BaseURL: base,
|
||||
APIKey: key,
|
||||
Model: model,
|
||||
Timeout: 90 * time.Second,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("New: %v", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 90*time.Second)
|
||||
defer cancel()
|
||||
|
||||
zero := 0.0
|
||||
resp, err := p.Complete(ctx, llm.Request{
|
||||
JSONMode: true,
|
||||
Temperature: &zero,
|
||||
MaxTokens: 2000,
|
||||
Messages: []llm.Message{
|
||||
{Role: llm.RoleSystem, Content: `Распознай медиа. Ответь только JSON по схеме: ` +
|
||||
`{"kind":"movie|series","title":string,"year":number,"season":number|null}`},
|
||||
{Role: llm.RoleUser, Content: "Аватар: Легенда об Аанге / Avatar: The Last Airbender / " +
|
||||
"Книга 2: Земля [2006, США, DVDRip-AVC]"},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Complete: %v", err)
|
||||
}
|
||||
t.Logf("model=%s usage=%+v content=%s", resp.Model, resp.Usage, resp.Content)
|
||||
|
||||
raw, err := llm.ExtractJSONObject(resp.Content)
|
||||
if err != nil {
|
||||
t.Fatalf("ExtractJSONObject: %v (content: %q)", err, resp.Content)
|
||||
}
|
||||
var plan struct {
|
||||
Kind string `json:"kind"`
|
||||
Title string `json:"title"`
|
||||
Year int `json:"year"`
|
||||
Season *int `json:"season"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(raw), &plan); err != nil {
|
||||
t.Fatalf("unmarshal plan: %v (raw: %s)", err, raw)
|
||||
}
|
||||
if plan.Kind != "series" {
|
||||
t.Errorf("kind = %q, want series", plan.Kind)
|
||||
}
|
||||
if plan.Year != 2006 {
|
||||
t.Errorf("year = %d, want 2006", plan.Year)
|
||||
}
|
||||
if plan.Season == nil || *plan.Season != 2 {
|
||||
t.Errorf("season = %v, want 2", plan.Season)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,72 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ErrNoJSON — в тексте не нашлось JSON-объекта.
|
||||
var ErrNoJSON = errors.New("llm: no JSON object in response")
|
||||
|
||||
// ExtractJSONObject вытаскивает первый верхнеуровневый JSON-объект из ответа
|
||||
// модели. Модели часто оборачивают JSON в ```-ограждения или добавляют
|
||||
// пояснения до/после — здесь это срезается, а скобки считаются с учётом
|
||||
// строковых литералов и экранирования, чтобы `{` внутри строки не сбивал
|
||||
// баланс. Возвращает срез исходного текста (валидацию делает вызывающий).
|
||||
func ExtractJSONObject(s string) (string, error) {
|
||||
s = stripCodeFences(s)
|
||||
|
||||
start := strings.IndexByte(s, '{')
|
||||
if start < 0 {
|
||||
return "", ErrNoJSON
|
||||
}
|
||||
|
||||
depth := 0
|
||||
inStr := false
|
||||
esc := false
|
||||
for i := start; i < len(s); i++ {
|
||||
c := s[i]
|
||||
if inStr {
|
||||
switch {
|
||||
case esc:
|
||||
esc = false
|
||||
case c == '\\':
|
||||
esc = true
|
||||
case c == '"':
|
||||
inStr = false
|
||||
}
|
||||
continue
|
||||
}
|
||||
switch c {
|
||||
case '"':
|
||||
inStr = true
|
||||
case '{':
|
||||
depth++
|
||||
case '}':
|
||||
depth--
|
||||
if depth == 0 {
|
||||
return s[start : i+1], nil
|
||||
}
|
||||
}
|
||||
}
|
||||
return "", ErrNoJSON
|
||||
}
|
||||
|
||||
// stripCodeFences убирает markdown-ограждение ```...``` (с опциональным
|
||||
// языковым тегом вроде ```json), если ответ обёрнут в него целиком.
|
||||
func stripCodeFences(s string) string {
|
||||
t := strings.TrimSpace(s)
|
||||
if !strings.HasPrefix(t, "```") {
|
||||
return s
|
||||
}
|
||||
// Отрезаем первую строку с открывающим ``` и языковым тегом.
|
||||
if nl := strings.IndexByte(t, '\n'); nl >= 0 {
|
||||
t = t[nl+1:]
|
||||
} else {
|
||||
return s
|
||||
}
|
||||
if end := strings.LastIndex(t, "```"); end >= 0 {
|
||||
t = t[:end]
|
||||
}
|
||||
return t
|
||||
}
|
||||
@@ -0,0 +1,39 @@
|
||||
package llm
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestExtractJSONObject(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
in string
|
||||
want string
|
||||
}{
|
||||
{"plain", `{"a":1}`, `{"a":1}`},
|
||||
{"with prose", "Конечно, вот:\n{\"a\":1}\nготово", `{"a":1}`},
|
||||
{"fenced json", "```json\n{\"a\":1}\n```", `{"a":1}`},
|
||||
{"fenced bare", "```\n{\"a\":1}\n```", `{"a":1}`},
|
||||
{"nested", `{"a":{"b":2},"c":3}`, `{"a":{"b":2},"c":3}`},
|
||||
{"brace in string", `{"path":"a{b}c","n":1}`, `{"path":"a{b}c","n":1}`},
|
||||
{"escaped quote in string", `{"q":"he said \"hi\" {x}"}`, `{"q":"he said \"hi\" {x}"}`},
|
||||
{"trailing after object", `{"a":1} extra {ignored}`, `{"a":1}`},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := ExtractJSONObject(tt.in)
|
||||
if err != nil {
|
||||
t.Fatalf("ExtractJSONObject(%q): %v", tt.in, err)
|
||||
}
|
||||
if got != tt.want {
|
||||
t.Errorf("ExtractJSONObject(%q) = %q, want %q", tt.in, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractJSONObject_Errors(t *testing.T) {
|
||||
for _, in := range []string{"", "no json here", "just text", `{"unbalanced":1`} {
|
||||
if _, err := ExtractJSONObject(in); err == nil {
|
||||
t.Errorf("ExtractJSONObject(%q): want error", in)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,87 @@
|
||||
// Package llm — провайдер LLM за интерфейсом (дискриминатор type).
|
||||
//
|
||||
// Реализация выбирается полем [llm].type (см. docs/specs/recognition.md).
|
||||
// Первый и пока единственный тип — "openai-compat": OpenAI-совместимый Chat
|
||||
// Completions API (локальные серверы LM Studio/llama.cpp/Ollama и облачные
|
||||
// совместимые провайдеры — DeepSeek, Qwen и др.).
|
||||
//
|
||||
// Пакет отвечает за один вызов модели и транспортную устойчивость (ретраи
|
||||
// на сетевые сбои, 429 и 5xx). Бюджет переразбора ответа со схемой-в-промпте
|
||||
// (llm.max_retries) принадлежит вызывающему recognize, а не провайдеру.
|
||||
package llm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Role — роль сообщения в диалоге.
|
||||
type Role string
|
||||
|
||||
const (
|
||||
RoleSystem Role = "system"
|
||||
RoleUser Role = "user"
|
||||
RoleAssistant Role = "assistant"
|
||||
)
|
||||
|
||||
// Message — одно сообщение запроса.
|
||||
type Message struct {
|
||||
Role Role
|
||||
Content string
|
||||
}
|
||||
|
||||
// Request — запрос к модели.
|
||||
type Request struct {
|
||||
Messages []Message
|
||||
JSONMode bool // response_format: json_object (структурированный вывод)
|
||||
Temperature *float64 // nil — не передаём, у модели остаётся её дефолт
|
||||
MaxTokens int // 0 — не передаём
|
||||
}
|
||||
|
||||
// Usage — расход токенов и стоимость (если провайдер их сообщает).
|
||||
type Usage struct {
|
||||
PromptTokens int
|
||||
CompletionTokens int
|
||||
TotalTokens int
|
||||
Cost float64
|
||||
}
|
||||
|
||||
// Response — ответ модели.
|
||||
type Response struct {
|
||||
Content string
|
||||
Model string
|
||||
Usage Usage
|
||||
}
|
||||
|
||||
// Provider — абстракция доступа к LLM. recognize работает только с ним и не
|
||||
// знает про конкретный транспорт.
|
||||
type Provider interface {
|
||||
Complete(ctx context.Context, req Request) (Response, error)
|
||||
}
|
||||
|
||||
// Config — параметры провайдера (подмножество [llm] из конфига).
|
||||
type Config struct {
|
||||
Type string
|
||||
BaseURL string
|
||||
APIKey string
|
||||
Model string
|
||||
Proxy string
|
||||
Timeout time.Duration
|
||||
}
|
||||
|
||||
// ErrUnknownType — запрошенный [llm].type не поддерживается.
|
||||
var ErrUnknownType = errors.New("llm: unknown provider type")
|
||||
|
||||
// New собирает провайдер по дискриминатору cfg.Type.
|
||||
func New(cfg Config) (Provider, error) {
|
||||
switch cfg.Type {
|
||||
case "openai-compat":
|
||||
return newOpenAICompat(cfg)
|
||||
case "":
|
||||
return nil, fmt.Errorf("%w: %q (укажите [llm].type)", ErrUnknownType, cfg.Type)
|
||||
default:
|
||||
return nil, fmt.Errorf("%w: %q", ErrUnknownType, cfg.Type)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,236 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// newTestProvider собирает openai-compat клиент на адрес стенда и убирает
|
||||
// паузы между ретраями, чтобы тесты не висели.
|
||||
func newTestProvider(t *testing.T, baseURL, apiKey string) *openAICompat {
|
||||
t.Helper()
|
||||
p, err := newOpenAICompat(Config{
|
||||
Type: "openai-compat",
|
||||
BaseURL: baseURL,
|
||||
APIKey: apiKey,
|
||||
Model: "test-model",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("newOpenAICompat: %v", err)
|
||||
}
|
||||
p.retryWait = 0
|
||||
return p
|
||||
}
|
||||
|
||||
func TestComplete_Success(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if got := r.Header.Get("Authorization"); got != "Bearer secret" {
|
||||
t.Errorf("Authorization = %q, want Bearer secret", got)
|
||||
}
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
var req chatRequest
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
t.Fatalf("decode request: %v", err)
|
||||
}
|
||||
if req.Model != "test-model" {
|
||||
t.Errorf("model = %q, want test-model", req.Model)
|
||||
}
|
||||
if req.ResponseFormat == nil || req.ResponseFormat.Type != "json_object" {
|
||||
t.Errorf("response_format = %+v, want json_object", req.ResponseFormat)
|
||||
}
|
||||
_, _ = io.WriteString(w, `{"model":"resolved-model",
|
||||
"choices":[{"message":{"content":"{\"ok\":true}"},"finish_reason":"stop"}],
|
||||
"usage":{"prompt_tokens":10,"completion_tokens":3,"total_tokens":13,"cost":0.0001}}`)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
p := newTestProvider(t, srv.URL, "secret")
|
||||
resp, err := p.Complete(context.Background(), Request{
|
||||
Messages: []Message{{Role: RoleUser, Content: "hi"}},
|
||||
JSONMode: true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Complete: %v", err)
|
||||
}
|
||||
if resp.Content != `{"ok":true}` {
|
||||
t.Errorf("content = %q", resp.Content)
|
||||
}
|
||||
if resp.Model != "resolved-model" {
|
||||
t.Errorf("model = %q", resp.Model)
|
||||
}
|
||||
if resp.Usage.TotalTokens != 13 || resp.Usage.Cost != 0.0001 {
|
||||
t.Errorf("usage = %+v", resp.Usage)
|
||||
}
|
||||
}
|
||||
|
||||
func TestComplete_NoJSONModeOmitsResponseFormat(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
var req chatRequest
|
||||
_ = json.Unmarshal(body, &req)
|
||||
if req.ResponseFormat != nil {
|
||||
t.Errorf("response_format should be omitted, got %+v", req.ResponseFormat)
|
||||
}
|
||||
if !strings.Contains(string(body), `"temperature":0`) {
|
||||
t.Errorf("temperature 0 must be sent explicitly, body: %s", body)
|
||||
}
|
||||
_, _ = io.WriteString(w, `{"choices":[{"message":{"content":"hello"}}]}`)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
p := newTestProvider(t, srv.URL, "")
|
||||
zero := 0.0
|
||||
resp, err := p.Complete(context.Background(), Request{
|
||||
Messages: []Message{{Role: RoleUser, Content: "hi"}},
|
||||
Temperature: &zero,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Complete: %v", err)
|
||||
}
|
||||
if resp.Content != "hello" {
|
||||
t.Errorf("content = %q", resp.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestComplete_RetriesOn5xxThenSucceeds(t *testing.T) {
|
||||
var calls atomic.Int32
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if calls.Add(1) <= 2 {
|
||||
w.WriteHeader(http.StatusBadGateway)
|
||||
_, _ = io.WriteString(w, "upstream down")
|
||||
return
|
||||
}
|
||||
_, _ = io.WriteString(w, `{"choices":[{"message":{"content":"recovered"}}]}`)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
p := newTestProvider(t, srv.URL, "")
|
||||
resp, err := p.Complete(context.Background(), Request{
|
||||
Messages: []Message{{Role: RoleUser, Content: "hi"}},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Complete: %v", err)
|
||||
}
|
||||
if resp.Content != "recovered" {
|
||||
t.Errorf("content = %q", resp.Content)
|
||||
}
|
||||
if got := calls.Load(); got != 3 {
|
||||
t.Errorf("calls = %d, want 3", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestComplete_429IsRetryable(t *testing.T) {
|
||||
var calls atomic.Int32
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if calls.Add(1) == 1 {
|
||||
w.WriteHeader(http.StatusTooManyRequests)
|
||||
return
|
||||
}
|
||||
_, _ = io.WriteString(w, `{"choices":[{"message":{"content":"ok"}}]}`)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
p := newTestProvider(t, srv.URL, "")
|
||||
if _, err := p.Complete(context.Background(), Request{
|
||||
Messages: []Message{{Role: RoleUser, Content: "hi"}},
|
||||
}); err != nil {
|
||||
t.Fatalf("Complete: %v", err)
|
||||
}
|
||||
if got := calls.Load(); got != 2 {
|
||||
t.Errorf("calls = %d, want 2", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestComplete_4xxNotRetried(t *testing.T) {
|
||||
var calls atomic.Int32
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
calls.Add(1)
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
_, _ = io.WriteString(w, `{"error":{"message":"bad model"}}`)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
p := newTestProvider(t, srv.URL, "")
|
||||
_, err := p.Complete(context.Background(), Request{
|
||||
Messages: []Message{{Role: RoleUser, Content: "hi"}},
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("Complete: want error on 400")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "status 400") {
|
||||
t.Errorf("err = %v, want status 400", err)
|
||||
}
|
||||
if got := calls.Load(); got != 1 {
|
||||
t.Errorf("calls = %d, want 1 (no retry on 4xx)", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestComplete_ExhaustsRetries(t *testing.T) {
|
||||
var calls atomic.Int32
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
calls.Add(1)
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
p := newTestProvider(t, srv.URL, "")
|
||||
_, err := p.Complete(context.Background(), Request{
|
||||
Messages: []Message{{Role: RoleUser, Content: "hi"}},
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("want error after exhausting retries")
|
||||
}
|
||||
if got := calls.Load(); got != maxAttempts {
|
||||
t.Errorf("calls = %d, want %d", got, maxAttempts)
|
||||
}
|
||||
}
|
||||
|
||||
func TestComplete_ProviderErrorInBody(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// HTTP 200, но ошибка в теле — частый паттерн шлюзов.
|
||||
_, _ = io.WriteString(w, `{"error":{"message":"context length exceeded"}}`)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
p := newTestProvider(t, srv.URL, "")
|
||||
_, err := p.Complete(context.Background(), Request{
|
||||
Messages: []Message{{Role: RoleUser, Content: "hi"}},
|
||||
})
|
||||
if err == nil || !strings.Contains(err.Error(), "context length exceeded") {
|
||||
t.Fatalf("err = %v, want provider error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestComplete_EmptyMessages(t *testing.T) {
|
||||
p := newTestProvider(t, "http://example.invalid", "")
|
||||
if _, err := p.Complete(context.Background(), Request{}); err == nil {
|
||||
t.Fatal("want error on empty messages")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNew_UnknownType(t *testing.T) {
|
||||
if _, err := New(Config{Type: "anthropic", Model: "x", BaseURL: "http://x"}); err == nil {
|
||||
t.Fatal("want error for unknown type")
|
||||
}
|
||||
if _, err := New(Config{Type: ""}); err == nil {
|
||||
t.Fatal("want error for empty type")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNew_OpenAICompatValidation(t *testing.T) {
|
||||
if _, err := New(Config{Type: "openai-compat", Model: "x"}); err == nil {
|
||||
t.Fatal("want error for empty base_url")
|
||||
}
|
||||
if _, err := New(Config{Type: "openai-compat", BaseURL: "http://x"}); err == nil {
|
||||
t.Fatal("want error for empty model")
|
||||
}
|
||||
if _, err := New(Config{Type: "openai-compat", BaseURL: "http://x", Model: "m"}); err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,227 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultTimeout = 120 * time.Second
|
||||
maxAttempts = 3 // первая попытка + 2 ретрая на транзиентных сбоях
|
||||
baseRetryWait = 500 * time.Millisecond
|
||||
maxResponseBody = 8 << 20 // 8 MiB — потолок на тело ответа
|
||||
)
|
||||
|
||||
// openAICompat — клиент OpenAI-совместимого Chat Completions API.
|
||||
type openAICompat struct {
|
||||
endpoint string // полный URL .../chat/completions
|
||||
hc *http.Client
|
||||
apiKey string
|
||||
model string
|
||||
retryWait time.Duration // базовая пауза между ретраями (0 в тестах)
|
||||
}
|
||||
|
||||
// newOpenAICompat собирает клиент из конфига.
|
||||
func newOpenAICompat(cfg Config) (*openAICompat, error) {
|
||||
if cfg.BaseURL == "" {
|
||||
return nil, fmt.Errorf("llm: empty base_url")
|
||||
}
|
||||
if cfg.Model == "" {
|
||||
return nil, fmt.Errorf("llm: empty model")
|
||||
}
|
||||
|
||||
timeout := cfg.Timeout
|
||||
if timeout == 0 {
|
||||
timeout = defaultTimeout
|
||||
}
|
||||
|
||||
transport := http.DefaultTransport
|
||||
if cfg.Proxy != "" {
|
||||
proxyURL, err := url.Parse(cfg.Proxy)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("llm: parse proxy %q: %w", cfg.Proxy, err)
|
||||
}
|
||||
transport = &http.Transport{Proxy: http.ProxyURL(proxyURL)}
|
||||
}
|
||||
|
||||
return &openAICompat{
|
||||
endpoint: strings.TrimRight(cfg.BaseURL, "/") + "/chat/completions",
|
||||
hc: &http.Client{Timeout: timeout, Transport: transport},
|
||||
apiKey: cfg.APIKey,
|
||||
model: cfg.Model,
|
||||
retryWait: baseRetryWait,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// chatRequest — тело запроса Chat Completions.
|
||||
type chatRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []chatMessage `json:"messages"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
ResponseFormat *respFormat `json:"response_format,omitempty"`
|
||||
}
|
||||
|
||||
type chatMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type respFormat struct {
|
||||
Type string `json:"type"`
|
||||
}
|
||||
|
||||
// chatResponse — нужное подмножество ответа Chat Completions.
|
||||
type chatResponse struct {
|
||||
Model string `json:"model"`
|
||||
Choices []struct {
|
||||
Message struct {
|
||||
Content string `json:"content"`
|
||||
} `json:"message"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
} `json:"choices"`
|
||||
Usage struct {
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
Cost float64 `json:"cost"`
|
||||
} `json:"usage"`
|
||||
// Некоторые шлюзы кладут ошибку в тело при HTTP 200.
|
||||
Error *struct {
|
||||
Message string `json:"message"`
|
||||
} `json:"error"`
|
||||
}
|
||||
|
||||
// Complete выполняет один вызов модели с транзиентными ретраями на сетевых
|
||||
// сбоях, 429 и 5xx. 4xx (кроме 429) — постоянная ошибка, без ретраев.
|
||||
func (c *openAICompat) Complete(ctx context.Context, req Request) (Response, error) {
|
||||
if len(req.Messages) == 0 {
|
||||
return Response{}, fmt.Errorf("llm: empty messages")
|
||||
}
|
||||
|
||||
body, err := json.Marshal(c.buildRequest(req))
|
||||
if err != nil {
|
||||
return Response{}, fmt.Errorf("llm: marshal request: %w", err)
|
||||
}
|
||||
|
||||
var lastErr error
|
||||
for attempt := 1; attempt <= maxAttempts; attempt++ {
|
||||
if attempt > 1 {
|
||||
if err := c.wait(ctx, attempt); err != nil {
|
||||
return Response{}, err
|
||||
}
|
||||
}
|
||||
|
||||
resp, retryable, err := c.do(ctx, body)
|
||||
if err == nil {
|
||||
return resp, nil
|
||||
}
|
||||
lastErr = err
|
||||
if !retryable {
|
||||
return Response{}, err
|
||||
}
|
||||
}
|
||||
return Response{}, fmt.Errorf("llm: exhausted %d attempts: %w", maxAttempts, lastErr)
|
||||
}
|
||||
|
||||
func (c *openAICompat) buildRequest(req Request) chatRequest {
|
||||
msgs := make([]chatMessage, len(req.Messages))
|
||||
for i, m := range req.Messages {
|
||||
msgs[i] = chatMessage{Role: string(m.Role), Content: m.Content}
|
||||
}
|
||||
cr := chatRequest{
|
||||
Model: c.model,
|
||||
Messages: msgs,
|
||||
Temperature: req.Temperature,
|
||||
MaxTokens: req.MaxTokens,
|
||||
}
|
||||
if req.JSONMode {
|
||||
cr.ResponseFormat = &respFormat{Type: "json_object"}
|
||||
}
|
||||
return cr
|
||||
}
|
||||
|
||||
// do делает один HTTP-запрос. Второй результат — можно ли ретраить ошибку.
|
||||
func (c *openAICompat) do(ctx context.Context, body []byte) (Response, bool, error) {
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.endpoint, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return Response{}, false, fmt.Errorf("llm: build request: %w", err)
|
||||
}
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
if c.apiKey != "" {
|
||||
httpReq.Header.Set("Authorization", "Bearer "+c.apiKey)
|
||||
}
|
||||
|
||||
resp, err := c.hc.Do(httpReq)
|
||||
if err != nil {
|
||||
// Сетевой сбой — ретраибелен (ctx.Err проверит wait на следующем заходе).
|
||||
return Response{}, true, fmt.Errorf("llm: request: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
raw, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBody))
|
||||
if err != nil {
|
||||
return Response{}, true, fmt.Errorf("llm: read body: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
retryable := resp.StatusCode == http.StatusTooManyRequests || resp.StatusCode >= 500
|
||||
return Response{}, retryable, fmt.Errorf("llm: status %d: %s",
|
||||
resp.StatusCode, snippet(raw))
|
||||
}
|
||||
|
||||
var cr chatResponse
|
||||
if err := json.Unmarshal(raw, &cr); err != nil {
|
||||
return Response{}, false, fmt.Errorf("llm: decode response: %w (body: %s)", err, snippet(raw))
|
||||
}
|
||||
if cr.Error != nil && cr.Error.Message != "" {
|
||||
return Response{}, false, fmt.Errorf("llm: provider error: %s", cr.Error.Message)
|
||||
}
|
||||
if len(cr.Choices) == 0 {
|
||||
return Response{}, false, fmt.Errorf("llm: empty choices (body: %s)", snippet(raw))
|
||||
}
|
||||
|
||||
return Response{
|
||||
Content: cr.Choices[0].Message.Content,
|
||||
Model: cr.Model,
|
||||
Usage: Usage{
|
||||
PromptTokens: cr.Usage.PromptTokens,
|
||||
CompletionTokens: cr.Usage.CompletionTokens,
|
||||
TotalTokens: cr.Usage.TotalTokens,
|
||||
Cost: cr.Usage.Cost,
|
||||
},
|
||||
}, false, nil
|
||||
}
|
||||
|
||||
// wait выдерживает паузу перед ретраем (линейный backoff), уважая ctx.
|
||||
func (c *openAICompat) wait(ctx context.Context, attempt int) error {
|
||||
d := c.retryWait * time.Duration(attempt-1)
|
||||
if d <= 0 {
|
||||
return ctx.Err()
|
||||
}
|
||||
t := time.NewTimer(d)
|
||||
defer t.Stop()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return fmt.Errorf("llm: %w", ctx.Err())
|
||||
case <-t.C:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// snippet обрезает тело для сообщения об ошибке.
|
||||
func snippet(b []byte) string {
|
||||
const max = 300
|
||||
s := strings.TrimSpace(string(b))
|
||||
if len(s) > max {
|
||||
return s[:max] + "…"
|
||||
}
|
||||
return s
|
||||
}
|
||||
Reference in New Issue
Block a user