Files
data-ge/app/main.py
2025-10-30 18:25:29 +08:00

100 lines
3.0 KiB
Python

from __future__ import annotations
import asyncio
import logging
from contextlib import asynccontextmanager
import httpx
from fastapi import Depends, FastAPI, HTTPException, Request
from app.exceptions import ProviderAPICallError, ProviderConfigurationError
from app.models import (
DataImportAnalysisJobAck,
DataImportAnalysisJobRequest,
LLMRequest,
LLMResponse,
)
from app.services import LLMGateway
from app.services.import_analysis import process_import_analysis_job
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create the shared resources on startup and release them on shutdown.

    An ``httpx.AsyncClient`` with a 30-second timeout and an ``LLMGateway``
    are attached to ``app.state`` as ``http_client`` and ``gateway`` before
    control is yielded to the application. The HTTP client is always closed
    on the way out.

    Args:
        app: The application whose ``state`` receives the shared objects.
    """
    client = httpx.AsyncClient(timeout=httpx.Timeout(30.0))
    try:
        # Build the gateway inside the ``try`` so that a failing constructor
        # still closes the already-open HTTP client instead of leaking it.
        gateway = LLMGateway()
        app.state.http_client = client  # type: ignore[attr-defined]
        app.state.gateway = gateway  # type: ignore[attr-defined]
        yield
    finally:
        await client.aclose()
def create_app() -> FastAPI:
    """Build and configure the FastAPI application.

    Two routes are registered:

    * ``POST /v1/chat/completions`` forwards the request to the gateway and
      returns its response, translating provider errors into HTTP errors.
    * ``POST /v1/import/analyze`` schedules an import-analysis job as a
      background task and answers with status 202 right away.

    Returns:
        The configured application, wired to the module-level ``lifespan``.
    """
    application = FastAPI(
        title="Unified LLM Gateway",
        version="0.1.0",
        lifespan=lifespan,
    )

    # Strong references to in-flight background jobs. The event loop only
    # keeps weak references to tasks, so a task nobody else references can be
    # garbage-collected before it has finished running.
    background_tasks: set[asyncio.Task[None]] = set()

    def _finalize_background_task(task: asyncio.Task[None]) -> None:
        """Forget a finished job and log the exception it raised, if any."""
        background_tasks.discard(task)
        if task.cancelled():
            # ``Task.exception()`` raises on a cancelled task; nothing to log.
            return
        exc = task.exception()
        if exc is not None:
            logger.error("Import analysis job failed: %s", exc, exc_info=exc)

    @application.post(
        "/v1/chat/completions",
        response_model=LLMResponse,
        summary="Dispatch chat completion to upstream provider",
    )
    async def create_chat_completion(
        payload: LLMRequest,
        gateway: LLMGateway = Depends(get_gateway),
        client: httpx.AsyncClient = Depends(get_http_client),
    ) -> LLMResponse:
        # Documented with comments rather than a docstring: FastAPI would
        # publish a docstring as the OpenAPI description of the route.
        try:
            return await gateway.chat(payload, client)
        except ProviderConfigurationError as exc:
            # A configuration problem is reported to the caller as 422.
            logger.error("Provider configuration error: %s", exc, exc_info=True)
            raise HTTPException(status_code=422, detail=str(exc)) from exc
        except ProviderAPICallError as exc:
            # Reuse the status carried by the exception; 502 when it has none.
            status_code = exc.status_code or 502
            # Log the raw response text when available, but only return
            # ``str(exc)`` to the caller.
            log_detail = exc.response_text or str(exc)
            logger.error(
                "Provider API call error (status %s): %s",
                status_code,
                log_detail,
                exc_info=True,
            )
            raise HTTPException(status_code=status_code, detail=str(exc)) from exc

    @application.post(
        "/v1/import/analyze",
        response_model=DataImportAnalysisJobAck,
        summary="Schedule async import analysis and notify via callback",
        status_code=202,
    )
    async def analyze_import_data(
        payload: DataImportAnalysisJobRequest,
        client: httpx.AsyncClient = Depends(get_http_client),
    ) -> DataImportAnalysisJobAck:
        # Hand the job a deep copy so it does not share state with the
        # request object used to build the acknowledgement below.
        request_copy = payload.model_copy(deep=True)

        async def _runner() -> None:
            await process_import_analysis_job(request_copy, client)

        # Keep the task referenced until it completes; the done-callback
        # removes it from the set again and logs unhandled failures.
        task = asyncio.create_task(_runner())
        background_tasks.add(task)
        task.add_done_callback(_finalize_background_task)

        return DataImportAnalysisJobAck(import_record_id=payload.import_record_id, status="accepted")

    return application
async def get_gateway(request: Request) -> LLMGateway:
    """Dependency that returns the ``gateway`` object stored on ``app.state``."""
    app_state = request.app.state
    return app_state.gateway  # type: ignore[return-value, attr-defined]
async def get_http_client(request: Request) -> httpx.AsyncClient:
    """Dependency that returns the ``http_client`` stored on ``app.state``."""
    app_state = request.app.state
    return app_state.http_client  # type: ignore[return-value, attr-defined]
# Module-level application instance, built once when this module is imported.
app = create_app()