Files
data-ge/app/main.py
2025-10-30 22:38:23 +08:00

184 lines
5.4 KiB
Python

from __future__ import annotations
import asyncio
import logging
import os
from contextlib import asynccontextmanager
from typing import Any
import httpx
from fastapi import Depends, FastAPI, HTTPException, Request
from app.exceptions import ProviderAPICallError, ProviderConfigurationError
from app.models import (
DataImportAnalysisJobAck,
DataImportAnalysisJobRequest,
LLMRequest,
LLMResponse,
)
from app.services import LLMGateway
from app.services.import_analysis import process_import_analysis_job
def _configure_logging() -> None:
level_name = os.getenv("LOG_LEVEL", "INFO").upper()
level = getattr(logging, level_name, logging.INFO)
log_format = os.getenv(
"LOG_FORMAT",
"%(asctime)s %(levelname)s %(name)s:%(lineno)d %(message)s",
)
root = logging.getLogger()
if not root.handlers:
logging.basicConfig(level=level, format=log_format)
else:
root.setLevel(level)
formatter = logging.Formatter(log_format)
for handler in root.handlers:
handler.setLevel(level)
handler.setFormatter(formatter)
# Configure logging at import time so every module-level logger inherits it.
_configure_logging()
logger = logging.getLogger(__name__)
def _env_bool(name: str, default: bool) -> bool:
raw = os.getenv(name)
if raw is None:
return default
return raw.strip().lower() in {"1", "true", "yes", "on"}
def _env_float(name: str, default: float) -> float:
raw = os.getenv(name)
if raw is None:
return default
try:
return float(raw)
except ValueError:
logger.warning("Invalid value for %s=%r, using default %.2f", name, raw, default)
return default
def _parse_proxy_config(raw: str | None) -> dict[str, str] | str | None:
if raw is None:
return None
cleaned = raw.strip()
if not cleaned:
return None
# Support comma-separated key=value pairs for scheme-specific proxies.
if "=" in cleaned:
proxies: dict[str, str] = {}
for part in cleaned.split(","):
key, sep, value = part.partition("=")
if not sep:
continue
key = key.strip()
value = value.strip()
if key and value:
proxies[key] = value
if proxies:
return proxies
return cleaned
def _create_http_client() -> httpx.AsyncClient:
    """Build the shared AsyncClient from HTTP_CLIENT_* environment variables.

    - HTTP_CLIENT_TIMEOUT: request timeout in seconds (default 30).
    - HTTP_CLIENT_TRUST_ENV: honour HTTP(S)_PROXY/NO_PROXY etc. (default on).
    - HTTP_CLIENT_PROXY: a single proxy URL, or comma-separated
      ``pattern=url`` pairs for per-pattern proxies.
    """
    timeout_seconds = _env_float("HTTP_CLIENT_TIMEOUT", 30.0)
    trust_env = _env_bool("HTTP_CLIENT_TRUST_ENV", True)
    proxy_config = _parse_proxy_config(os.getenv("HTTP_CLIENT_PROXY"))
    client_kwargs: dict[str, object] = {
        "timeout": httpx.Timeout(timeout_seconds),
        "trust_env": trust_env,
    }
    # httpx deprecated the ``proxies`` keyword in 0.26 and removed it in
    # 0.28, so passing it raises TypeError on current releases. A single
    # URL now goes through ``proxy=`` and a per-pattern mapping through
    # ``mounts=`` with one transport per pattern.
    if isinstance(proxy_config, str):
        client_kwargs["proxy"] = proxy_config
    elif proxy_config:
        client_kwargs["mounts"] = {
            pattern: httpx.AsyncHTTPTransport(proxy=url)
            for pattern, url in proxy_config.items()
        }
    return httpx.AsyncClient(**client_kwargs)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create the shared HTTP client and gateway at startup.

    Both objects are published on ``app.state``; the HTTP client is closed
    when the application shuts down.
    """
    http_client = _create_http_client()
    llm_gateway = LLMGateway()
    try:
        app.state.http_client = http_client  # type: ignore[attr-defined]
        app.state.gateway = llm_gateway  # type: ignore[attr-defined]
        yield
    finally:
        await http_client.aclose()
def create_app() -> FastAPI:
    """Assemble the FastAPI application with all routes registered."""
    application = FastAPI(
        title="Unified LLM Gateway",
        version="0.1.0",
        lifespan=lifespan,
    )
    # Strong references to in-flight background jobs. asyncio keeps only a
    # weak reference to tasks, so discarding create_task()'s result lets the
    # job be garbage-collected mid-flight and loses its exception.
    background_tasks: set[asyncio.Task[None]] = set()

    @application.post(
        "/v1/chat/completions",
        response_model=LLMResponse,
        summary="Dispatch chat completion to upstream provider",
    )
    async def create_chat_completion(
        payload: LLMRequest,
        gateway: LLMGateway = Depends(get_gateway),
        client: httpx.AsyncClient = Depends(get_http_client),
    ) -> LLMResponse:
        """Proxy a chat completion request to the configured provider."""
        try:
            return await gateway.chat(payload, client)
        except ProviderConfigurationError as exc:
            logger.error("Provider configuration error: %s", exc, exc_info=True)
            raise HTTPException(status_code=422, detail=str(exc)) from exc
        except ProviderAPICallError as exc:
            # Mirror the upstream status when known; otherwise report 502.
            status_code = exc.status_code or 502
            log_detail = exc.response_text or str(exc)
            logger.error(
                "Provider API call error (status %s): %s",
                status_code,
                log_detail,
                exc_info=True,
            )
            raise HTTPException(status_code=status_code, detail=str(exc)) from exc

    @application.post(
        "/v1/import/analyze",
        response_model=DataImportAnalysisJobAck,
        summary="Schedule async import analysis and notify via callback",
        status_code=202,
    )
    async def analyze_import_data(
        payload: DataImportAnalysisJobRequest,
        client: httpx.AsyncClient = Depends(get_http_client),
    ) -> DataImportAnalysisJobAck:
        """Acknowledge the job immediately; process it in the background."""
        # Deep copy so the background job is isolated from any later
        # mutation of the request object.
        request_copy = payload.model_copy(deep=True)

        async def _runner() -> None:
            try:
                await process_import_analysis_job(request_copy, client)
            except Exception:  # top-level task boundary: log, nothing to re-raise into
                logger.exception(
                    "Import analysis job failed for import_record_id=%s",
                    request_copy.import_record_id,
                )

        task = asyncio.create_task(_runner())
        # Hold a strong reference until the task completes (see note above).
        background_tasks.add(task)
        task.add_done_callback(background_tasks.discard)
        return DataImportAnalysisJobAck(import_record_id=payload.import_record_id, status="accepted")

    @application.post("/__mock__/import-callback")
    async def mock_import_callback(payload: dict[str, Any]) -> dict[str, str]:
        """Test helper endpoint: log the callback body and acknowledge it."""
        logger.info("Received import analysis callback: %s", payload)
        return {"status": "received"}

    return application
async def get_gateway(request: Request) -> LLMGateway:
    """FastAPI dependency: the process-wide LLMGateway stored on app state."""
    gateway = request.app.state.gateway  # type: ignore[attr-defined]
    return gateway  # type: ignore[return-value]
async def get_http_client(request: Request) -> httpx.AsyncClient:
    """FastAPI dependency: the shared httpx client stored on app state."""
    client = request.app.state.http_client  # type: ignore[attr-defined]
    return client  # type: ignore[return-value]
app = create_app()