Files
jellybit/internal/llm/llm_test.go
T

237 lines
7.0 KiB
Go

package llm
import (
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"strings"
"sync/atomic"
"testing"
)
// newTestProvider builds an openai-compat client aimed at the given test
// stand and zeroes the pause between retries so the tests never stall.
func newTestProvider(t *testing.T, baseURL, apiKey string) *openAICompat {
	t.Helper()
	cfg := Config{
		Type:    "openai-compat",
		BaseURL: baseURL,
		APIKey:  apiKey,
		Model:   "test-model",
	}
	provider, err := newOpenAICompat(cfg)
	if err != nil {
		t.Fatalf("newOpenAICompat: %v", err)
	}
	// No backoff: retry-oriented tests should complete instantly.
	provider.retryWait = 0
	return provider
}
// TestComplete_Success covers the happy path. It checks that the request
// carries the bearer token, the configured model and a json_object
// response_format. It also checks that content, the model reported by the
// server and usage (including cost) are parsed from the reply.
func TestComplete_Success(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// This handler runs on the server's goroutine, where t.Fatal/FailNow
		// must not be called. Report with t.Errorf and answer 400 (not
		// retried by the client) so Complete fails on the test goroutine.
		if got := r.Header.Get("Authorization"); got != "Bearer secret" {
			t.Errorf("Authorization = %q, want Bearer secret", got)
		}
		body, err := io.ReadAll(r.Body)
		if err != nil {
			t.Errorf("read request body: %v", err)
			http.Error(w, "read request body", http.StatusBadRequest)
			return
		}
		var req chatRequest
		if err := json.Unmarshal(body, &req); err != nil {
			t.Errorf("decode request: %v", err)
			http.Error(w, "decode request", http.StatusBadRequest)
			return
		}
		if req.Model != "test-model" {
			t.Errorf("model = %q, want test-model", req.Model)
		}
		if req.ResponseFormat == nil || req.ResponseFormat.Type != "json_object" {
			t.Errorf("response_format = %+v, want json_object", req.ResponseFormat)
		}
		_, _ = io.WriteString(w, `{"model":"resolved-model",
"choices":[{"message":{"content":"{\"ok\":true}"},"finish_reason":"stop"}],
"usage":{"prompt_tokens":10,"completion_tokens":3,"total_tokens":13,"cost":0.0001}}`)
	}))
	defer srv.Close()

	p := newTestProvider(t, srv.URL, "secret")
	resp, err := p.Complete(context.Background(), Request{
		Messages: []Message{{Role: RoleUser, Content: "hi"}},
		JSONMode: true,
	})
	if err != nil {
		t.Fatalf("Complete: %v", err)
	}
	if resp.Content != `{"ok":true}` {
		t.Errorf("content = %q", resp.Content)
	}
	if resp.Model != "resolved-model" {
		t.Errorf("model = %q", resp.Model)
	}
	// Exact float comparison is fine: both sides come from the same literal.
	if resp.Usage.TotalTokens != 13 || resp.Usage.Cost != 0.0001 {
		t.Errorf("usage = %+v", resp.Usage)
	}
}
// TestComplete_NoJSONModeOmitsResponseFormat checks two things for a request
// without JSONMode. First, no response_format is sent. Second, an explicitly
// set zero temperature is still serialized rather than being dropped as a
// zero value.
func TestComplete_NoJSONModeOmitsResponseFormat(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		body, err := io.ReadAll(r.Body)
		if err != nil {
			t.Errorf("read request body: %v", err)
			http.Error(w, "read request body", http.StatusBadRequest)
			return
		}
		var req chatRequest
		if err := json.Unmarshal(body, &req); err != nil {
			// Ignoring this would leave ResponseFormat nil and make the
			// check below pass vacuously.
			t.Errorf("decode request: %v", err)
		}
		if req.ResponseFormat != nil {
			t.Errorf("response_format should be omitted, got %+v", req.ResponseFormat)
		}
		// Decode generically so the assertion matches exactly 0. A substring
		// search for `"temperature":0` would also accept e.g. 0.7.
		var fields map[string]any
		if err := json.Unmarshal(body, &fields); err != nil {
			t.Errorf("decode request as object: %v", err)
		}
		if temp, ok := fields["temperature"]; !ok || temp != float64(0) {
			t.Errorf("temperature 0 must be sent explicitly, body: %s", body)
		}
		_, _ = io.WriteString(w, `{"choices":[{"message":{"content":"hello"}}]}`)
	}))
	defer srv.Close()

	p := newTestProvider(t, srv.URL, "")
	zero := 0.0
	resp, err := p.Complete(context.Background(), Request{
		Messages:    []Message{{Role: RoleUser, Content: "hi"}},
		Temperature: &zero,
	})
	if err != nil {
		t.Fatalf("Complete: %v", err)
	}
	if resp.Content != "hello" {
		t.Errorf("content = %q", resp.Content)
	}
}
// TestComplete_RetriesOn5xxThenSucceeds verifies that two consecutive 502
// replies are retried. It expects the content of the third, successful
// attempt to be returned.
func TestComplete_RetriesOn5xxThenSucceeds(t *testing.T) {
	const failures = 2
	var hits atomic.Int32
	handler := func(w http.ResponseWriter, _ *http.Request) {
		if hits.Add(1) > failures {
			_, _ = io.WriteString(w, `{"choices":[{"message":{"content":"recovered"}}]}`)
			return
		}
		w.WriteHeader(http.StatusBadGateway)
		_, _ = io.WriteString(w, "upstream down")
	}
	srv := httptest.NewServer(http.HandlerFunc(handler))
	defer srv.Close()

	p := newTestProvider(t, srv.URL, "")
	resp, err := p.Complete(context.Background(), Request{
		Messages: []Message{{Role: RoleUser, Content: "hi"}},
	})
	if err != nil {
		t.Fatalf("Complete: %v", err)
	}
	if resp.Content != "recovered" {
		t.Errorf("content = %q", resp.Content)
	}
	if got := hits.Load(); got != 3 {
		t.Errorf("calls = %d, want 3", got)
	}
}
// TestComplete_429IsRetryable verifies that a single 429 Too Many Requests
// reply is retried and the second attempt succeeds.
func TestComplete_429IsRetryable(t *testing.T) {
	var hits atomic.Int32
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		isFirst := hits.Add(1) == 1
		if !isFirst {
			_, _ = io.WriteString(w, `{"choices":[{"message":{"content":"ok"}}]}`)
			return
		}
		w.WriteHeader(http.StatusTooManyRequests)
	}))
	defer srv.Close()

	p := newTestProvider(t, srv.URL, "")
	_, err := p.Complete(context.Background(), Request{
		Messages: []Message{{Role: RoleUser, Content: "hi"}},
	})
	if err != nil {
		t.Fatalf("Complete: %v", err)
	}
	if got := hits.Load(); got != 2 {
		t.Errorf("calls = %d, want 2", got)
	}
}
// TestComplete_4xxNotRetried verifies that a 400 reply is surfaced
// immediately as an error mentioning the status, without any retry.
func TestComplete_4xxNotRetried(t *testing.T) {
	var hits atomic.Int32
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		hits.Add(1)
		w.WriteHeader(http.StatusBadRequest)
		_, _ = io.WriteString(w, `{"error":{"message":"bad model"}}`)
	}))
	defer srv.Close()

	p := newTestProvider(t, srv.URL, "")
	_, err := p.Complete(context.Background(), Request{
		Messages: []Message{{Role: RoleUser, Content: "hi"}},
	})
	switch {
	case err == nil:
		t.Fatal("Complete: want error on 400")
	case !strings.Contains(err.Error(), "status 400"):
		t.Errorf("err = %v, want status 400", err)
	}
	if got := hits.Load(); got != 1 {
		t.Errorf("calls = %d, want 1 (no retry on 4xx)", got)
	}
}
// TestComplete_ExhaustsRetries verifies that a permanently unavailable (503)
// server is contacted exactly maxAttempts times, after which Complete fails.
func TestComplete_ExhaustsRetries(t *testing.T) {
	var hits atomic.Int32
	alwaysUnavailable := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		hits.Add(1)
		w.WriteHeader(http.StatusServiceUnavailable)
	})
	srv := httptest.NewServer(alwaysUnavailable)
	defer srv.Close()

	p := newTestProvider(t, srv.URL, "")
	if _, err := p.Complete(context.Background(), Request{
		Messages: []Message{{Role: RoleUser, Content: "hi"}},
	}); err == nil {
		t.Fatal("want error after exhausting retries")
	}
	if got := hits.Load(); got != maxAttempts {
		t.Errorf("calls = %d, want %d", got, maxAttempts)
	}
}
// TestComplete_ProviderErrorInBody verifies that an error object returned in
// the body of an HTTP 200 reply is turned into an error carrying its message.
func TestComplete_ProviderErrorInBody(t *testing.T) {
	const providerMsg = "context length exceeded"
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		// HTTP 200 with the error in the body is a common gateway pattern.
		_, _ = io.WriteString(w, `{"error":{"message":"context length exceeded"}}`)
	}))
	defer srv.Close()

	p := newTestProvider(t, srv.URL, "")
	_, err := p.Complete(context.Background(), Request{
		Messages: []Message{{Role: RoleUser, Content: "hi"}},
	})
	if err != nil && strings.Contains(err.Error(), providerMsg) {
		return
	}
	t.Fatalf("err = %v, want provider error", err)
}
// TestComplete_EmptyMessages verifies that a request without messages is
// rejected locally, before anything is sent to the provider.
func TestComplete_EmptyMessages(t *testing.T) {
	// Use a live stand rather than an unresolvable host. With an unreachable
	// address, missing validation would still yield a network error, and the
	// test would pass vacuously. Counting hits proves the error is local.
	var hits atomic.Int32
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		hits.Add(1)
		_, _ = io.WriteString(w, `{"choices":[{"message":{"content":"unexpected"}}]}`)
	}))
	defer srv.Close()

	p := newTestProvider(t, srv.URL, "")
	if _, err := p.Complete(context.Background(), Request{}); err == nil {
		t.Fatal("want error on empty messages")
	}
	if got := hits.Load(); got != 0 {
		t.Errorf("calls = %d, want 0 (empty request must not reach the provider)", got)
	}
}
func TestNew_UnknownType(t *testing.T) {
if _, err := New(Config{Type: "anthropic", Model: "x", BaseURL: "http://x"}); err == nil {
t.Fatal("want error for unknown type")
}
if _, err := New(Config{Type: ""}); err == nil {
t.Fatal("want error for empty type")
}
}
func TestNew_OpenAICompatValidation(t *testing.T) {
if _, err := New(Config{Type: "openai-compat", Model: "x"}); err == nil {
t.Fatal("want error for empty base_url")
}
if _, err := New(Config{Type: "openai-compat", BaseURL: "http://x"}); err == nil {
t.Fatal("want error for empty model")
}
if _, err := New(Config{Type: "openai-compat", BaseURL: "http://x", Model: "m"}); err != nil {
t.Fatalf("unexpected error: %v", err)
}
}