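"""FastAPI entry point for the Unified LLM Gateway.

Wires up logging, a shared httpx.AsyncClient (configured via HTTP_CLIENT_*
environment variables), and the chat-completion and import-analysis routes.
"""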
from __future__ import annotations

import asyncio
import logging
import os
from contextlib import asynccontextmanager
from typing import Any

import httpx
from fastapi import Depends, FastAPI, HTTPException, Request

from app.exceptions import ProviderAPICallError, ProviderConfigurationError
from app.models import (
    DataImportAnalysisJobAck,
    DataImportAnalysisJobRequest,
    LLMRequest,
    LLMResponse,
)
from app.services import LLMGateway
from app.services.import_analysis import process_import_analysis_job


def _configure_logging() -> None:
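    """Configure root logging from the LOG_LEVEL and LOG_FORMAT environment variables."""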
    level_name = os.getenv("LOG_LEVEL", "INFO").upper()
    level = getattr(logging, level_name, logging.INFO)
    log_format = os.getenv(
        "LOG_FORMAT",
        "%(asctime)s %(levelname)s %(name)s:%(lineno)d %(message)s",
    )

    root = logging.getLogger()

    if not root.handlers:
        logging.basicConfig(level=level, format=log_format)
    else:
        root.setLevel(level)
        formatter = logging.Formatter(log_format)
        for handler in root.handlers:
            handler.setLevel(level)
            handler.setFormatter(formatter)


_configure_logging()
logger = logging.getLogger(__name__)


def _env_bool(name: str, default: bool) -> bool:
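    """Read a boolean env var; "1", "true", "yes", and "on" (any case) count as True."""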
    raw = os.getenv(name)
    if raw is None:
        return default
    return raw.strip().lower() in {"1", "true", "yes", "on"}


def _env_float(name: str, default: float) -> float:
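    """Read a float env var, logging a warning and using *default* on invalid input."""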
    raw = os.getenv(name)
    if raw is None:
        return default
    try:
        return float(raw)
    except ValueError:
        logger.warning("Invalid value for %s=%r, using default %.2f", name, raw, default)
        return default


def _parse_proxy_config(raw: str | None) -> dict[str, str] | str | None:
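    """Parse HTTP_CLIENT_PROXY into proxy settings for httpx.

    Accepts either a single proxy URL (e.g. "http://proxy:3128") or
    comma-separated key=value pairs for scheme-specific proxies
    (e.g. "http://=http://proxy:3128,https://=http://proxy:3129").
    """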
    if raw is None:
        return None

    cleaned = raw.strip()
    if not cleaned:
        return None

    # Support comma-separated key=value pairs for scheme-specific proxies.
    if "=" in cleaned:
        proxies: dict[str, str] = {}
        for part in cleaned.split(","):
            key, sep, value = part.partition("=")
            if not sep:
                continue
            key = key.strip()
            value = value.strip()
            if key and value:
                proxies[key] = value
        if proxies:
            return proxies

    return cleaned


def _create_http_client() -> httpx.AsyncClient:
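    """Build the shared httpx.AsyncClient from the HTTP_CLIENT_TIMEOUT,
    HTTP_CLIENT_TRUST_ENV, and HTTP_CLIENT_PROXY environment variables.
    """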
    timeout_seconds = _env_float("HTTP_CLIENT_TIMEOUT", 30.0)
    trust_env = _env_bool("HTTP_CLIENT_TRUST_ENV", True)
    proxies = _parse_proxy_config(os.getenv("HTTP_CLIENT_PROXY"))
    client_kwargs: dict[str, object] = {
        "timeout": httpx.Timeout(timeout_seconds),
        "trust_env": trust_env,
    }
    if proxies:
        # NOTE: the `proxies` argument was deprecated in httpx 0.26 and removed
        # in 0.28; on newer httpx, use `proxy=` (single URL) or `mounts=` instead.
        client_kwargs["proxies"] = proxies
    return httpx.AsyncClient(**client_kwargs)


@asynccontextmanager
async def lifespan(app: FastAPI):
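    """Create the shared HTTP client and gateway at startup; close the client on shutdown."""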
    client = _create_http_client()
    gateway = LLMGateway()
    try:
        app.state.http_client = client  # type: ignore[attr-defined]
        app.state.gateway = gateway  # type: ignore[attr-defined]
        yield
    finally:
        await client.aclose()


def create_app() -> FastAPI:
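    """Build and configure the FastAPI application with all routes."""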
    application = FastAPI(
        title="Unified LLM Gateway",
        version="0.1.0",
        lifespan=lifespan,
    )

    # Hold strong references to fire-and-forget tasks: the event loop keeps
    # only weak references, so an otherwise-unreferenced task can be
    # garbage-collected before it finishes.
    background_tasks: set[asyncio.Task[None]] = set()

    @application.post(
        "/v1/chat/completions",
        response_model=LLMResponse,
        summary="Dispatch chat completion to upstream provider",
    )
    async def create_chat_completion(
        payload: LLMRequest,
        gateway: LLMGateway = Depends(get_gateway),
        client: httpx.AsyncClient = Depends(get_http_client),
    ) -> LLMResponse:
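        """Proxy the request to the configured provider, mapping provider errors to HTTP errors."""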
        try:
            return await gateway.chat(payload, client)
        except ProviderConfigurationError as exc:
            logger.error("Provider configuration error: %s", exc, exc_info=True)
            raise HTTPException(status_code=422, detail=str(exc)) from exc
        except ProviderAPICallError as exc:
            status_code = exc.status_code or 502
            log_detail = exc.response_text or str(exc)
            logger.error(
                "Provider API call error (status %s): %s",
                status_code,
                log_detail,
                exc_info=True,
            )
            raise HTTPException(status_code=status_code, detail=str(exc)) from exc

    @application.post(
        "/v1/import/analyze",
        response_model=DataImportAnalysisJobAck,
        summary="Schedule async import analysis and notify via callback",
        status_code=202,
    )
    async def analyze_import_data(
        payload: DataImportAnalysisJobRequest,
        client: httpx.AsyncClient = Depends(get_http_client),
    ) -> DataImportAnalysisJobAck:
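        """Acknowledge the job immediately and run the analysis in a background task."""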
        # Deep-copy the payload so the background task is isolated from any
        # later mutation of the request object.
        request_copy = payload.model_copy(deep=True)

        async def _runner() -> None:
            await process_import_analysis_job(request_copy, client)

        # Keep a reference to the task until it completes (see the note on
        # background_tasks above), then drop it so the set does not grow.
        task = asyncio.create_task(_runner())
        background_tasks.add(task)
        task.add_done_callback(background_tasks.discard)

        return DataImportAnalysisJobAck(
            import_record_id=payload.import_record_id, status="accepted"
        )

    @application.post("/__mock__/import-callback")
    async def mock_import_callback(payload: dict[str, Any]) -> dict[str, str]:
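        """Mock endpoint that logs and acknowledges import-analysis callbacks (for local testing)."""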
        logger.info("Received import analysis callback: %s", payload)
        return {"status": "received"}

    return application


async def get_gateway(request: Request) -> LLMGateway:
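    """FastAPI dependency returning the LLMGateway stored on application state."""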
    return request.app.state.gateway  # type: ignore[return-value, attr-defined]


async def get_http_client(request: Request) -> httpx.AsyncClient:
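    """FastAPI dependency returning the shared httpx.AsyncClient stored on application state."""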
    return request.app.state.http_client  # type: ignore[return-value, attr-defined]


app = create_app()