237 lines
7.1 KiB
Go
237 lines
7.1 KiB
Go
package llm
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"io"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"sync/atomic"
|
|
"testing"
|
|
)
|
|
|
|
// newTestProvider собирает openai-compat клиент на адрес стенда и убирает
|
|
// паузы между ретраями, чтобы тесты не висели.
|
|
func newTestProvider(t *testing.T, baseURL, apiKey string) *openAICompat {
|
|
t.Helper()
|
|
p, err := newOpenAICompat(Config{
|
|
Type: "openai-compat",
|
|
BaseURL: baseURL,
|
|
APIKey: apiKey,
|
|
Model: "test-model",
|
|
}, nil)
|
|
if err != nil {
|
|
t.Fatalf("newOpenAICompat: %v", err)
|
|
}
|
|
p.retryWait = 0
|
|
return p
|
|
}
|
|
|
|
func TestComplete_Success(t *testing.T) {
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if got := r.Header.Get("Authorization"); got != "Bearer secret" {
|
|
t.Errorf("Authorization = %q, want Bearer secret", got)
|
|
}
|
|
body, _ := io.ReadAll(r.Body)
|
|
var req chatRequest
|
|
if err := json.Unmarshal(body, &req); err != nil {
|
|
t.Fatalf("decode request: %v", err)
|
|
}
|
|
if req.Model != "test-model" {
|
|
t.Errorf("model = %q, want test-model", req.Model)
|
|
}
|
|
if req.ResponseFormat == nil || req.ResponseFormat.Type != "json_object" {
|
|
t.Errorf("response_format = %+v, want json_object", req.ResponseFormat)
|
|
}
|
|
_, _ = io.WriteString(w, `{"model":"resolved-model",
|
|
"choices":[{"message":{"content":"{\"ok\":true}"},"finish_reason":"stop"}],
|
|
"usage":{"prompt_tokens":10,"completion_tokens":3,"total_tokens":13,"cost":0.0001}}`)
|
|
}))
|
|
defer srv.Close()
|
|
|
|
p := newTestProvider(t, srv.URL, "secret")
|
|
resp, err := p.Complete(context.Background(), Request{
|
|
Messages: []Message{{Role: RoleUser, Content: "hi"}},
|
|
JSONMode: true,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Complete: %v", err)
|
|
}
|
|
if resp.Content != `{"ok":true}` {
|
|
t.Errorf("content = %q", resp.Content)
|
|
}
|
|
if resp.Model != "resolved-model" {
|
|
t.Errorf("model = %q", resp.Model)
|
|
}
|
|
if resp.Usage.TotalTokens != 13 || resp.Usage.Cost != 0.0001 {
|
|
t.Errorf("usage = %+v", resp.Usage)
|
|
}
|
|
}
|
|
|
|
func TestComplete_NoJSONModeOmitsResponseFormat(t *testing.T) {
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
body, _ := io.ReadAll(r.Body)
|
|
var req chatRequest
|
|
_ = json.Unmarshal(body, &req)
|
|
if req.ResponseFormat != nil {
|
|
t.Errorf("response_format should be omitted, got %+v", req.ResponseFormat)
|
|
}
|
|
if !strings.Contains(string(body), `"temperature":0`) {
|
|
t.Errorf("temperature 0 must be sent explicitly, body: %s", body)
|
|
}
|
|
_, _ = io.WriteString(w, `{"choices":[{"message":{"content":"hello"}}]}`)
|
|
}))
|
|
defer srv.Close()
|
|
|
|
p := newTestProvider(t, srv.URL, "")
|
|
zero := 0.0
|
|
resp, err := p.Complete(context.Background(), Request{
|
|
Messages: []Message{{Role: RoleUser, Content: "hi"}},
|
|
Temperature: &zero,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Complete: %v", err)
|
|
}
|
|
if resp.Content != "hello" {
|
|
t.Errorf("content = %q", resp.Content)
|
|
}
|
|
}
|
|
|
|
func TestComplete_RetriesOn5xxThenSucceeds(t *testing.T) {
|
|
var calls atomic.Int32
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if calls.Add(1) <= 2 {
|
|
w.WriteHeader(http.StatusBadGateway)
|
|
_, _ = io.WriteString(w, "upstream down")
|
|
return
|
|
}
|
|
_, _ = io.WriteString(w, `{"choices":[{"message":{"content":"recovered"}}]}`)
|
|
}))
|
|
defer srv.Close()
|
|
|
|
p := newTestProvider(t, srv.URL, "")
|
|
resp, err := p.Complete(context.Background(), Request{
|
|
Messages: []Message{{Role: RoleUser, Content: "hi"}},
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Complete: %v", err)
|
|
}
|
|
if resp.Content != "recovered" {
|
|
t.Errorf("content = %q", resp.Content)
|
|
}
|
|
if got := calls.Load(); got != 3 {
|
|
t.Errorf("calls = %d, want 3", got)
|
|
}
|
|
}
|
|
|
|
func TestComplete_429IsRetryable(t *testing.T) {
|
|
var calls atomic.Int32
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if calls.Add(1) == 1 {
|
|
w.WriteHeader(http.StatusTooManyRequests)
|
|
return
|
|
}
|
|
_, _ = io.WriteString(w, `{"choices":[{"message":{"content":"ok"}}]}`)
|
|
}))
|
|
defer srv.Close()
|
|
|
|
p := newTestProvider(t, srv.URL, "")
|
|
if _, err := p.Complete(context.Background(), Request{
|
|
Messages: []Message{{Role: RoleUser, Content: "hi"}},
|
|
}); err != nil {
|
|
t.Fatalf("Complete: %v", err)
|
|
}
|
|
if got := calls.Load(); got != 2 {
|
|
t.Errorf("calls = %d, want 2", got)
|
|
}
|
|
}
|
|
|
|
func TestComplete_4xxNotRetried(t *testing.T) {
|
|
var calls atomic.Int32
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
calls.Add(1)
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
_, _ = io.WriteString(w, `{"error":{"message":"bad model"}}`)
|
|
}))
|
|
defer srv.Close()
|
|
|
|
p := newTestProvider(t, srv.URL, "")
|
|
_, err := p.Complete(context.Background(), Request{
|
|
Messages: []Message{{Role: RoleUser, Content: "hi"}},
|
|
})
|
|
if err == nil {
|
|
t.Fatal("Complete: want error on 400")
|
|
}
|
|
if !strings.Contains(err.Error(), "status 400") {
|
|
t.Errorf("err = %v, want status 400", err)
|
|
}
|
|
if got := calls.Load(); got != 1 {
|
|
t.Errorf("calls = %d, want 1 (no retry on 4xx)", got)
|
|
}
|
|
}
|
|
|
|
func TestComplete_ExhaustsRetries(t *testing.T) {
|
|
var calls atomic.Int32
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
calls.Add(1)
|
|
w.WriteHeader(http.StatusServiceUnavailable)
|
|
}))
|
|
defer srv.Close()
|
|
|
|
p := newTestProvider(t, srv.URL, "")
|
|
_, err := p.Complete(context.Background(), Request{
|
|
Messages: []Message{{Role: RoleUser, Content: "hi"}},
|
|
})
|
|
if err == nil {
|
|
t.Fatal("want error after exhausting retries")
|
|
}
|
|
if got := calls.Load(); got != maxAttempts {
|
|
t.Errorf("calls = %d, want %d", got, maxAttempts)
|
|
}
|
|
}
|
|
|
|
func TestComplete_ProviderErrorInBody(t *testing.T) {
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
// HTTP 200, но ошибка в теле — частый паттерн шлюзов.
|
|
_, _ = io.WriteString(w, `{"error":{"message":"context length exceeded"}}`)
|
|
}))
|
|
defer srv.Close()
|
|
|
|
p := newTestProvider(t, srv.URL, "")
|
|
_, err := p.Complete(context.Background(), Request{
|
|
Messages: []Message{{Role: RoleUser, Content: "hi"}},
|
|
})
|
|
if err == nil || !strings.Contains(err.Error(), "context length exceeded") {
|
|
t.Fatalf("err = %v, want provider error", err)
|
|
}
|
|
}
|
|
|
|
func TestComplete_EmptyMessages(t *testing.T) {
|
|
p := newTestProvider(t, "http://example.invalid", "")
|
|
if _, err := p.Complete(context.Background(), Request{}); err == nil {
|
|
t.Fatal("want error on empty messages")
|
|
}
|
|
}
|
|
|
|
func TestNew_UnknownType(t *testing.T) {
|
|
if _, err := New(Config{Type: "anthropic", Model: "x", BaseURL: "http://x"}, nil); err == nil {
|
|
t.Fatal("want error for unknown type")
|
|
}
|
|
if _, err := New(Config{Type: ""}, nil); err == nil {
|
|
t.Fatal("want error for empty type")
|
|
}
|
|
}
|
|
|
|
func TestNew_OpenAICompatValidation(t *testing.T) {
|
|
if _, err := New(Config{Type: "openai-compat", Model: "x"}, nil); err == nil {
|
|
t.Fatal("want error for empty base_url")
|
|
}
|
|
if _, err := New(Config{Type: "openai-compat", BaseURL: "http://x"}, nil); err == nil {
|
|
t.Fatal("want error for empty model")
|
|
}
|
|
if _, err := New(Config{Type: "openai-compat", BaseURL: "http://x", Model: "m"}, nil); err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
}
|