切换成new-api方式进行llm调用

This commit is contained in:
zhaoawd
2025-12-08 23:11:43 +08:00
parent eefaf91ed1
commit f261121845
7 changed files with 145 additions and 57 deletions

View File

@ -1,53 +1,93 @@
from __future__ import annotations
import os
from typing import Dict, Type
import logging
import httpx
from pydantic import ValidationError
from app.exceptions import ProviderConfigurationError
from app.models import LLMProvider, LLMRequest, LLMResponse
from app.providers import (
AnthropicProvider,
DeepSeekProvider,
GeminiProvider,
LLMProviderClient,
OpenAIProvider,
OpenRouterProvider,
QwenProvider,
)
from app.exceptions import ProviderAPICallError
from app.models import LLMChoice, LLMMessage, LLMRequest, LLMResponse
from app.settings import NEW_API_AUTH_TOKEN, NEW_API_BASE_URL
logger = logging.getLogger(__name__)
class LLMGateway:
"""Simple registry that dispatches chat requests to provider clients."""
"""Forward chat requests to the configured new-api component."""
def __init__(self) -> None:
self._providers: Dict[LLMProvider, LLMProviderClient] = {}
self._factory: Dict[LLMProvider, Type[LLMProviderClient]] = {
LLMProvider.OPENAI: OpenAIProvider,
LLMProvider.ANTHROPIC: AnthropicProvider,
LLMProvider.OPENROUTER: OpenRouterProvider,
LLMProvider.GEMINI: GeminiProvider,
LLMProvider.QWEN: QwenProvider,
LLMProvider.DEEPSEEK: DeepSeekProvider,
}
def get_provider(self, provider: LLMProvider) -> LLMProviderClient:
if provider not in self._factory:
raise ProviderConfigurationError(f"Unsupported provider '{provider.value}'.")
if provider not in self._providers:
self._providers[provider] = self._build_provider(provider)
return self._providers[provider]
def _build_provider(self, provider: LLMProvider) -> LLMProviderClient:
provider_cls = self._factory[provider]
api_key_env = getattr(provider_cls, "api_key_env", None)
api_key = os.getenv(api_key_env) if api_key_env else None
return provider_cls(api_key)
def __init__(
self,
*,
base_url: str | None = None,
auth_token: str | None = None,
) -> None:
resolved_base = base_url or NEW_API_BASE_URL
self._base_url = resolved_base.rstrip("/")
self._auth_token = auth_token or NEW_API_AUTH_TOKEN
async def chat(
self, request: LLMRequest, client: httpx.AsyncClient
) -> LLMResponse:
provider_client = self.get_provider(request.provider)
return await provider_client.chat(request, client)
url = f"{self._base_url}/v1/chat/completions"
payload = request.model_dump(mode="json", exclude_none=True)
headers = {"Content-Type": "application/json"}
if self._auth_token:
headers["Authorization"] = f"Bearer {self._auth_token}"
logger.info("Forwarding chat request to new-api at %s", url)
try:
response = await client.post(url, json=payload, headers=headers)
response.raise_for_status()
except httpx.HTTPStatusError as exc:
status_code = exc.response.status_code if exc.response else None
response_text = exc.response.text if exc.response else ""
logger.error(
"new-api upstream returned %s: %s",
status_code,
response_text,
exc_info=True,
)
raise ProviderAPICallError(
"Chat completion request failed.",
status_code=status_code,
response_text=response_text,
) from exc
except httpx.HTTPError as exc:
logger.error("new-api transport error: %s", exc, exc_info=True)
raise ProviderAPICallError(f"Chat completion request failed: {exc}") from exc
try:
data = response.json()
except ValueError as exc:
logger.error("new-api responded with invalid JSON.", exc_info=True)
raise ProviderAPICallError(
"Chat completion response was not valid JSON."
) from exc
logger.info("new-api payload: %s", data)
normalized_choices: list[LLMChoice] = []
for idx, choice in enumerate(data.get("choices", []) or []):
message_payload = choice.get("message") or {}
message = LLMMessage(
role=message_payload.get("role", "assistant"),
content=message_payload.get("content", ""),
)
normalized_choices.append(
LLMChoice(index=choice.get("index", idx), message=message)
)
try:
normalized_response = LLMResponse(
provider=request.provider,
model=data.get("model", request.model),
choices=normalized_choices,
raw=data,
)
return normalized_response
except ValidationError as exc:
logger.error(
"new-api response did not match expected schema: %s", data, exc_info=True
)
raise ProviderAPICallError(
"Chat completion response was not in the expected format."
) from exc