Switch LLM calls to the new-api approach

This commit is contained in:
zhaoawd
2025-12-08 23:11:43 +08:00
parent eefaf91ed1
commit f261121845
7 changed files with 145 additions and 57 deletions

View File

@@ -1,53 +1,93 @@
 from __future__ import annotations

-import os
-from typing import Dict, Type
 import logging

 import httpx
+from pydantic import ValidationError

-from app.exceptions import ProviderConfigurationError
-from app.models import LLMProvider, LLMRequest, LLMResponse
-from app.providers import (
-    AnthropicProvider,
-    DeepSeekProvider,
-    GeminiProvider,
-    LLMProviderClient,
-    OpenAIProvider,
-    OpenRouterProvider,
-    QwenProvider,
-)
+from app.exceptions import ProviderAPICallError
+from app.models import LLMChoice, LLMMessage, LLMRequest, LLMResponse
+from app.settings import NEW_API_AUTH_TOKEN, NEW_API_BASE_URL

 logger = logging.getLogger(__name__)


 class LLMGateway:
-    """Simple registry that dispatches chat requests to provider clients."""
+    """Forward chat requests to the configured new-api component."""

-    def __init__(self) -> None:
-        self._providers: Dict[LLMProvider, LLMProviderClient] = {}
-        self._factory: Dict[LLMProvider, Type[LLMProviderClient]] = {
-            LLMProvider.OPENAI: OpenAIProvider,
-            LLMProvider.ANTHROPIC: AnthropicProvider,
-            LLMProvider.OPENROUTER: OpenRouterProvider,
-            LLMProvider.GEMINI: GeminiProvider,
-            LLMProvider.QWEN: QwenProvider,
-            LLMProvider.DEEPSEEK: DeepSeekProvider,
-        }
-
-    def get_provider(self, provider: LLMProvider) -> LLMProviderClient:
-        if provider not in self._factory:
-            raise ProviderConfigurationError(f"Unsupported provider '{provider.value}'.")
-        if provider not in self._providers:
-            self._providers[provider] = self._build_provider(provider)
-        return self._providers[provider]
-
-    def _build_provider(self, provider: LLMProvider) -> LLMProviderClient:
-        provider_cls = self._factory[provider]
-        api_key_env = getattr(provider_cls, "api_key_env", None)
-        api_key = os.getenv(api_key_env) if api_key_env else None
-        return provider_cls(api_key)
+    def __init__(
+        self,
+        *,
+        base_url: str | None = None,
+        auth_token: str | None = None,
+    ) -> None:
+        resolved_base = base_url or NEW_API_BASE_URL
+        self._base_url = resolved_base.rstrip("/")
+        self._auth_token = auth_token or NEW_API_AUTH_TOKEN

     async def chat(
         self, request: LLMRequest, client: httpx.AsyncClient
     ) -> LLMResponse:
-        provider_client = self.get_provider(request.provider)
-        return await provider_client.chat(request, client)
+        url = f"{self._base_url}/v1/chat/completions"
+        payload = request.model_dump(mode="json", exclude_none=True)
+        headers = {"Content-Type": "application/json"}
+        if self._auth_token:
+            headers["Authorization"] = f"Bearer {self._auth_token}"
+
+        logger.info("Forwarding chat request to new-api at %s", url)
+        try:
+            response = await client.post(url, json=payload, headers=headers)
+            response.raise_for_status()
+        except httpx.HTTPStatusError as exc:
+            status_code = exc.response.status_code if exc.response else None
+            response_text = exc.response.text if exc.response else ""
+            logger.error(
+                "new-api upstream returned %s: %s",
+                status_code,
+                response_text,
+                exc_info=True,
+            )
+            raise ProviderAPICallError(
+                "Chat completion request failed.",
+                status_code=status_code,
+                response_text=response_text,
+            ) from exc
+        except httpx.HTTPError as exc:
+            logger.error("new-api transport error: %s", exc, exc_info=True)
+            raise ProviderAPICallError(f"Chat completion request failed: {exc}") from exc
+
+        try:
+            data = response.json()
+        except ValueError as exc:
+            logger.error("new-api responded with invalid JSON.", exc_info=True)
+            raise ProviderAPICallError(
+                "Chat completion response was not valid JSON."
+            ) from exc
+
+        logger.info("new-api payload: %s", data)
+
+        normalized_choices: list[LLMChoice] = []
+        for idx, choice in enumerate(data.get("choices", []) or []):
+            message_payload = choice.get("message") or {}
+            message = LLMMessage(
+                role=message_payload.get("role", "assistant"),
+                content=message_payload.get("content", ""),
+            )
+            normalized_choices.append(
+                LLMChoice(index=choice.get("index", idx), message=message)
+            )
+
+        try:
+            normalized_response = LLMResponse(
+                provider=request.provider,
+                model=data.get("model", request.model),
+                choices=normalized_choices,
+                raw=data,
+            )
+            return normalized_response
+        except ValidationError as exc:
+            logger.error(
+                "new-api response did not match expected schema: %s", data, exc_info=True
+            )
+            raise ProviderAPICallError(
+                "Chat completion response was not in the expected format."
+            ) from exc
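With the provider registry removed, callers construct the gateway (optionally overriding the endpoint and token, e.g. in tests) and supply their own httpx.AsyncClient. A minimal usage sketch: the exact LLMRequest fields beyond provider and model are an assumption, and the URL and token values are placeholders.

import asyncio

import httpx

from app.models import LLMMessage, LLMProvider, LLMRequest
from app.services import LLMGateway


async def main() -> None:
    # base_url/auth_token fall back to NEW_API_BASE_URL / NEW_API_AUTH_TOKEN;
    # both values below are placeholders for local experimentation.
    gateway = LLMGateway(base_url="http://localhost:3000", auth_token="sk-placeholder")
    request = LLMRequest(
        provider=LLMProvider.DEEPSEEK,
        model="deepseek-chat",
        # `messages` is assumed here, mirroring the OpenAI-style payload
        # that request.model_dump() produces for the forwarded call.
        messages=[LLMMessage(role="user", content="ping")],
    )
    async with httpx.AsyncClient() as client:
        response = await gateway.chat(request, client)
    print(response.choices[0].message.content)


asyncio.run(main())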

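For reference, the normalization above relies only on the OpenAI-style fields model, choices[].index, and choices[].message.{role,content}; anything else the upstream returns is kept verbatim in LLMResponse.raw. A minimal response body the gateway would accept, with illustrative values:

# Illustrative new-api response body; only these fields are read by chat().
sample = {
    "model": "deepseek-chat",
    "choices": [
        {"index": 0, "message": {"role": "assistant", "content": "pong"}},
    ],
    # Extra keys such as "usage" are not parsed here; they remain in `raw`.
}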
View File

@@ -22,14 +22,24 @@ from app.models import (
     LLMResponse,
     LLMRole,
 )
-from app.settings import DEFAULT_IMPORT_MODEL, get_supported_import_models
+from app.settings import (
+    DEFAULT_IMPORT_MODEL,
+    NEW_API_AUTH_TOKEN,
+    NEW_API_BASE_URL,
+    get_supported_import_models,
+)
 from app.utils.llm_usage import extract_usage

 logger = logging.getLogger(__name__)

-IMPORT_GATEWAY_BASE_URL = os.getenv(
-    "IMPORT_GATEWAY_BASE_URL", "http://localhost:8000"
-)
+IMPORT_GATEWAY_BASE_URL = os.getenv("IMPORT_GATEWAY_BASE_URL", NEW_API_BASE_URL)
+
+
+def build_import_gateway_headers() -> dict[str, str]:
+    headers = {"Content-Type": "application/json"}
+    if NEW_API_AUTH_TOKEN:
+        headers["Authorization"] = f"Bearer {NEW_API_AUTH_TOKEN}"
+    return headers


 def _env_float(name: str, default: float) -> float:
@@ -314,16 +324,18 @@ async def dispatch_import_analysis_job(
     url = f"{IMPORT_GATEWAY_BASE_URL.rstrip('/')}/v1/chat/completions"
     logger.info(
-        "Dispatching import %s to %s: %s",
+        "Dispatching import %s to %s using provider=%s model=%s",
         request.import_record_id,
         url,
-        json.dumps(payload, ensure_ascii=False),
+        payload.get("provider"),
+        payload.get("model"),
     )
     timeout = httpx.Timeout(IMPORT_CHAT_TIMEOUT_SECONDS)
+    headers = build_import_gateway_headers()
     try:
-        response = await client.post(url, json=payload, timeout=timeout)
+        response = await client.post(url, json=payload, timeout=timeout, headers=headers)
         response.raise_for_status()
     except httpx.HTTPStatusError as exc:
         body_preview = ""
@@ -348,9 +360,10 @@ async def dispatch_import_analysis_job(
         response.status_code,
     )
     logger.info(
-        "LLM response for %s: %s",
+        "LLM response received for %s (status %s, choices=%s)",
         request.import_record_id,
-        json.dumps(response_data, ensure_ascii=False),
+        response.status_code,
+        len(response_data.get("choices") or []),
     )
     try:
@@ -404,6 +417,7 @@ async def process_import_analysis_job(
     request: DataImportAnalysisJobRequest,
     client: httpx.AsyncClient,
 ) -> None:
+    # Run the import analysis and ensure the callback fires regardless of success/failure.
     try:
         payload = await dispatch_import_analysis_job(request, client)
     except ProviderAPICallError as exc:
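dispatch_import_analysis_job now targets the same new-api endpoint unless IMPORT_GATEWAY_BASE_URL overrides it, attaches the bearer token on every call, and logs compact summaries instead of full request/response bodies. The header helper's behavior, shown with a placeholder token value:

from app.services.import_analysis import build_import_gateway_headers

# With NEW_API_AUTH_TOKEN unset:
#   {"Content-Type": "application/json"}
# With NEW_API_AUTH_TOKEN="sk-placeholder" (placeholder value):
#   {"Content-Type": "application/json", "Authorization": "Bearer sk-placeholder"}
headers = build_import_gateway_headers()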

View File

@@ -24,6 +24,7 @@ from app.services import LLMGateway
 from app.settings import DEFAULT_IMPORT_MODEL
 from app.services.import_analysis import (
     IMPORT_GATEWAY_BASE_URL,
+    build_import_gateway_headers,
     resolve_provider_from_model,
 )
 from app.utils.llm_usage import extract_usage as extract_llm_usage
@@ -532,6 +533,7 @@ async def _call_chat_completions(
     temperature: float = 0.2,
     timeout_seconds: Optional[float] = None,
 ) -> Any:
+    # Normalize the model spec to provider+model and issue the unified chat call.
     provider, model_name = resolve_provider_from_model(model_spec)
     payload = {
         "provider": provider.value,
@@ -545,16 +547,17 @@ async def _call_chat_completions(
     payload_size_bytes = len(json.dumps(payload, ensure_ascii=False).encode("utf-8"))

     url = f"{IMPORT_GATEWAY_BASE_URL.rstrip('/')}/v1/chat/completions"
+    headers = build_import_gateway_headers()
     try:
-        # log the request whole info
         logger.info(
-            "Calling chat completions API %s with model %s and size %s and payload %s",
+            "Calling chat completions API %s with model=%s payload_size=%sB",
             url,
             model_name,
             payload_size_bytes,
-            payload,
         )
-        response = await client.post(url, json=payload, timeout=timeout_seconds)
+        response = await client.post(
+            url, json=payload, timeout=timeout_seconds, headers=headers
+        )
         response.raise_for_status()
     except httpx.HTTPError as exc:
@@ -703,6 +706,7 @@ async def _run_action_with_callback(
     input_payload: Any = None,
     model_spec: Optional[str] = None,
 ) -> Any:
+    # Execute a pipeline action and always emit a callback capturing success/failure.
     if input_payload is not None:
         logger.info(
             "Pipeline action %s input: %s",

View File

@@ -20,7 +20,11 @@ PROVIDER_KEY_ENV_MAP: Dict[str, str] = {
 }

-DEFAULT_IMPORT_MODEL = os.getenv("DEFAULT_IMPORT_MODEL", "openai:gpt-4.1-mini")
+DEFAULT_IMPORT_MODEL = os.getenv("DEFAULT_IMPORT_MODEL", "deepseek:deepseek-chat")
+
+NEW_API_BASE_URL = os.getenv("NEW_API_BASE_URL")
+NEW_API_AUTH_TOKEN = os.getenv("NEW_API_AUTH_TOKEN")

 RAG_API_BASE_URL = os.getenv("RAG_API_BASE_URL", "http://127.0.0.1:8000")
 RAG_API_AUTH_TOKEN = os.getenv("RAG_API_AUTH_TOKEN")


 @lru_cache(maxsize=1)
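With these settings the whole LLM path is environment-driven. Note that NEW_API_BASE_URL has no default and LLMGateway calls .rstrip() on it, so it must be set (or a base_url passed explicitly) before the gateway is used. A local-development sketch; every value is a placeholder and must be in the environment before app.settings is imported:

import os

# Placeholder values for local development.
os.environ.setdefault("NEW_API_BASE_URL", "http://localhost:3000")          # new-api endpoint (required)
os.environ.setdefault("NEW_API_AUTH_TOKEN", "sk-placeholder")               # optional bearer token
os.environ.setdefault("DEFAULT_IMPORT_MODEL", "deepseek:deepseek-chat")     # provider:model spec
os.environ.setdefault("IMPORT_GATEWAY_BASE_URL", "http://localhost:3000")   # optional dispatch override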