from __future__ import annotations

import logging

import httpx
from pydantic import ValidationError

from app.exceptions import ProviderAPICallError
from app.models import LLMChoice, LLMMessage, LLMRequest, LLMResponse
from app.settings import NEW_API_AUTH_TOKEN, NEW_API_BASE_URL

logger = logging.getLogger(__name__)


class LLMGateway:
    """Forward chat requests to the configured new-api component."""

    def __init__(
        self,
        *,
        base_url: str | None = None,
        auth_token: str | None = None,
    ) -> None:
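        # Explicit arguments take precedence over the module-level settings;
        # the trailing slash is stripped so the URL join in chat() stays stable.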
        resolved_base = base_url or NEW_API_BASE_URL
        self._base_url = resolved_base.rstrip("/")
        self._auth_token = auth_token or NEW_API_AUTH_TOKEN

    async def chat(
        self, request: LLMRequest, client: httpx.AsyncClient
    ) -> LLMResponse:
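        """POST the request to new-api's chat completions endpoint.

        Raises ProviderAPICallError for HTTP error statuses, transport
        failures, invalid JSON, and responses that fail schema validation.
        """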
        url = f"{self._base_url}/v1/chat/completions"
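        # exclude_none keeps optional, unset request fields out of the JSON body.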
        payload = request.model_dump(mode="json", exclude_none=True)
        headers = {"Content-Type": "application/json"}
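        # Only attach bearer auth when a token is actually configured.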
        if self._auth_token:
            headers["Authorization"] = f"Bearer {self._auth_token}"

        logger.info("Forwarding chat request to new-api at %s", url)
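        # Surface any non-2xx status or transport failure as ProviderAPICallError
        # so callers never handle httpx exceptions directly.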
        try:
            response = await client.post(url, json=payload, headers=headers)
            response.raise_for_status()
        except httpx.HTTPStatusError as exc:
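            # httpx always sets .response on HTTPStatusError; the guards below
            # are purely defensive.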
            status_code = exc.response.status_code if exc.response else None
            response_text = exc.response.text if exc.response else ""
            logger.error(
                "new-api upstream returned %s: %s",
                status_code,
                response_text,
                exc_info=True,
            )
            raise ProviderAPICallError(
                "Chat completion request failed.",
                status_code=status_code,
                response_text=response_text,
            ) from exc
        except httpx.HTTPError as exc:
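            # The base httpx.HTTPError covers timeouts, connection errors, and
            # other transport-level failures.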
logger.error("new-api transport error: %s", exc, exc_info=True)
|
|
raise ProviderAPICallError(f"Chat completion request failed: {exc}") from exc
|
|
|
|
        try:
            data = response.json()
        except ValueError as exc:
            logger.error("new-api responded with invalid JSON.", exc_info=True)
            raise ProviderAPICallError(
                "Chat completion response was not valid JSON."
            ) from exc

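        # Full payloads can be large and may include user content, so log them
        # at debug rather than info level.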
logger.info("new-api payload: %s", data)
|
|
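        # Normalize choices defensively: fall back to sensible defaults when
        # the upstream omits role, content, or index fields.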
        normalized_choices: list[LLMChoice] = []
        for idx, choice in enumerate(data.get("choices", []) or []):
            message_payload = choice.get("message") or {}
            message = LLMMessage(
                role=message_payload.get("role", "assistant"),
                content=message_payload.get("content", ""),
            )
            normalized_choices.append(
                LLMChoice(index=choice.get("index", idx), message=message)
            )

        try:
            normalized_response = LLMResponse(
                provider=request.provider,
                model=data.get("model", request.model),
                choices=normalized_choices,
                raw=data,
            )
            return normalized_response
        except ValidationError as exc:
            logger.error(
                "new-api response did not match expected schema: %s",
                data,
                exc_info=True,
            )
            raise ProviderAPICallError(
                "Chat completion response was not in the expected format."
            ) from exc
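

# Example usage (a minimal sketch, not part of the gateway itself): provider and
# model are real LLMRequest fields used above; the messages field and its shape
# are assumptions about app.models.
#
#     import asyncio
#
#     async def main() -> None:
#         request = LLMRequest(
#             provider="openai",
#             model="gpt-4o-mini",
#             messages=[LLMMessage(role="user", content="Hello!")],
#         )
#         async with httpx.AsyncClient(timeout=30.0) as client:
#             response = await LLMGateway().chat(request, client)
#             print(response.choices[0].message.content)
#
#     asyncio.run(main())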