# data-ge/app/main.py
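"""FastAPI entry point for the unified LLM gateway.

Configures logging, owns the shared httpx.AsyncClient and LLMGateway, and
exposes the chat-completion, import-analysis, table-profiling, and
table-snippet endpoints.
"""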

from __future__ import annotations

import asyncio
import logging
import logging.config
import os
from collections.abc import Coroutine
from contextlib import asynccontextmanager
from typing import Any

import httpx
import yaml
from fastapi import Depends, FastAPI, HTTPException, Request
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse

from app.exceptions import ProviderAPICallError, ProviderConfigurationError
from app.models import (
    ActionStatus,
    ActionType,
    DataImportAnalysisJobAck,
    DataImportAnalysisJobRequest,
    LLMRequest,
    LLMResponse,
    TableProfilingJobAck,
    TableProfilingJobRequest,
    TableSnippetUpsertRequest,
    TableSnippetUpsertResponse,
)
from app.routers import chat_router, metrics_router
from app.services import LLMGateway
from app.services.import_analysis import process_import_analysis_job
from app.services.table_profiling import process_table_profiling_job
from app.services.table_snippet import ingest_snippet_rag_from_db, upsert_action_result


def _ensure_log_directories(config: dict[str, Any]) -> None:
    """Create parent directories for any file-based logging handlers."""
    handlers = config.get("handlers", {})
    for handler_config in handlers.values():
        filename = handler_config.get("filename")
        if not filename:
            continue
        directory = os.path.dirname(filename)
        if directory:
            # exist_ok makes a separate existence pre-check unnecessary.
            os.makedirs(directory, exist_ok=True)


def _configure_logging() -> None:
    """Apply the YAML logging config if present, else fall back to basicConfig."""
    config_path = os.getenv("LOGGING_CONFIG", "logging.yaml")
    if os.path.exists(config_path):
        with open(config_path, "r", encoding="utf-8") as fh:
            config = yaml.safe_load(fh)
        if isinstance(config, dict):
            _ensure_log_directories(config)
            logging.config.dictConfig(config)
            return
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(levelname)s %(name)s:%(lineno)d %(message)s",
    )


_configure_logging()
logger = logging.getLogger(__name__)
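

# asyncio.create_task only keeps a weak reference to the task it returns, so a
# fire-and-forget task can be garbage-collected (and silently cancelled)
# mid-execution. The endpoints below register background jobs here instead of
# calling asyncio.create_task directly.
_background_tasks: set[asyncio.Task[None]] = set()


def _spawn_background(coro: Coroutine[Any, Any, None]) -> None:
    """Schedule *coro* as a task and hold a strong reference until it finishes."""
    task = asyncio.create_task(coro)
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)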


def _env_bool(name: str, default: bool) -> bool:
    """Read a boolean env var, accepting 1/true/yes/on (case-insensitive)."""
    raw = os.getenv(name)
    if raw is None:
        return default
    return raw.strip().lower() in {"1", "true", "yes", "on"}


def _env_float(name: str, default: float) -> float:
    """Read a float env var, falling back to *default* on a malformed value."""
    raw = os.getenv(name)
    if raw is None:
        return default
    try:
        return float(raw)
    except ValueError:
        logger.warning("Invalid value for %s=%r, using default %.2f", name, raw, default)
        return default


def _parse_proxy_config(raw: str | None) -> dict[str, str] | str | None:
    """Parse HTTP_CLIENT_PROXY into a single proxy URL or a scheme-keyed map.

    A plain URL (e.g. "http://proxy:3128") is returned unchanged, while
    comma-separated key=value pairs (e.g. "http://=http://proxy:3128") are
    parsed into a dict of scheme-specific proxies.
    """
    if raw is None:
        return None
    cleaned = raw.strip()
    if not cleaned:
        return None
    # Support comma-separated key=value pairs for scheme-specific proxies.
    if "=" in cleaned:
        proxies: dict[str, str] = {}
        for part in cleaned.split(","):
            key, sep, value = part.partition("=")
            if not sep:
                continue
            key = key.strip()
            value = value.strip()
            if key and value:
                proxies[key] = value
        if proxies:
            return proxies
    return cleaned


def _create_http_client() -> httpx.AsyncClient:
    """Build the shared AsyncClient from HTTP_CLIENT_* environment variables."""
    timeout_seconds = _env_float("HTTP_CLIENT_TIMEOUT", 30.0)
    trust_env = _env_bool("HTTP_CLIENT_TRUST_ENV", True)
    proxies = _parse_proxy_config(os.getenv("HTTP_CLIENT_PROXY"))
    client_kwargs: dict[str, object] = {
        "timeout": httpx.Timeout(timeout_seconds),
        "trust_env": trust_env,
    }
    if proxies:
        # NOTE: the "proxies" argument was removed in httpx 0.28; on newer
        # versions pass "proxy" (single URL) or "mounts" (per-scheme) instead.
        client_kwargs["proxies"] = proxies
    return httpx.AsyncClient(**client_kwargs)


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create shared resources at startup and dispose of them on shutdown."""
    client = _create_http_client()
    gateway = LLMGateway()
    try:
        app.state.http_client = client  # type: ignore[attr-defined]
        app.state.gateway = gateway  # type: ignore[attr-defined]
        yield
    finally:
        await client.aclose()


async def get_gateway(request: Request) -> LLMGateway:
    """Resolve the process-wide LLMGateway stored on app state."""
    return request.app.state.gateway  # type: ignore[return-value, attr-defined]


async def get_http_client(request: Request) -> httpx.AsyncClient:
    """Resolve the shared httpx.AsyncClient stored on app state."""
    return request.app.state.http_client  # type: ignore[return-value, attr-defined]


def create_app() -> FastAPI:
    """Assemble the FastAPI application, its routers, and its endpoints."""
    application = FastAPI(
        title="Unified LLM Gateway",
        version="0.1.0",
        lifespan=lifespan,
    )

    # Chat/metric management APIs
    application.include_router(chat_router)
    application.include_router(metrics_router)

@application.exception_handler(RequestValidationError)
async def request_validation_exception_handler(
request: Request, exc: RequestValidationError
) -> JSONResponse:
try:
raw_body = await request.body()
except Exception: # pragma: no cover - defensive
raw_body = b"<unavailable>"
truncated_body = raw_body[:4096]
logger.warning(
"Validation error on %s %s: %s | body preview=%s",
request.method,
request.url.path,
exc.errors(),
truncated_body.decode("utf-8", errors="ignore"),
)
return JSONResponse(status_code=422, content={"detail": exc.errors()})

    @application.post(
"/v1/chat/completions",
response_model=LLMResponse,
summary="Dispatch chat completion to upstream provider",
)
async def create_chat_completion(
payload: LLMRequest,
gateway: LLMGateway = Depends(get_gateway),
client: httpx.AsyncClient = Depends(get_http_client),
) -> LLMResponse:
try:
return await gateway.chat(payload, client)
except ProviderConfigurationError as exc:
logger.error("Provider configuration error: %s", exc, exc_info=True)
raise HTTPException(status_code=422, detail=str(exc)) from exc
except ProviderAPICallError as exc:
status_code = exc.status_code or 502
log_detail = exc.response_text or str(exc)
logger.error(
"Provider API call error (status %s): %s",
status_code,
log_detail,
exc_info=True,
)
raise HTTPException(status_code=status_code, detail=str(exc)) from exc

    @application.post(
        "/v1/import/analyze",
        response_model=DataImportAnalysisJobAck,
        summary="Schedule async import analysis and notify via callback",
        status_code=202,
    )
    async def analyze_import_data(
        payload: DataImportAnalysisJobRequest,
        client: httpx.AsyncClient = Depends(get_http_client),
    ) -> DataImportAnalysisJobAck:
        # Deep-copy the payload so the background job is isolated from the request.
        request_copy = payload.model_copy(deep=True)

        async def _runner() -> None:
            await process_import_analysis_job(request_copy, client)

        _spawn_background(_runner())
        return DataImportAnalysisJobAck(
            import_record_id=payload.import_record_id,
            status="accepted",
        )

    @application.post(
        "/v1/table/profiling",
        response_model=TableProfilingJobAck,
        summary="Run end-to-end GE profiling pipeline and notify via callback per action",
        status_code=202,
    )
    async def run_table_profiling(
        payload: TableProfilingJobRequest,
        gateway: LLMGateway = Depends(get_gateway),
        client: httpx.AsyncClient = Depends(get_http_client),
    ) -> TableProfilingJobAck:
        request_copy = payload.model_copy(deep=True)

        async def _runner() -> None:
            await process_table_profiling_job(request_copy, gateway, client)

        _spawn_background(_runner())
        return TableProfilingJobAck(
            table_id=payload.table_id,
            version_ts=payload.version_ts,
            status="accepted",
        )

    @application.post(
        "/v1/table/snippet",
        response_model=TableSnippetUpsertResponse,
        summary="Persist or update action results, such as table snippets",
    )
    async def upsert_table_snippet(
        payload: TableSnippetUpsertRequest,
        client: httpx.AsyncClient = Depends(get_http_client),
    ) -> TableSnippetUpsertResponse:
        request_copy = payload.model_copy(deep=True)
        try:
            # upsert_action_result performs blocking DB I/O; keep it off the event loop.
            response = await asyncio.to_thread(upsert_action_result, request_copy)
except Exception as exc:
logger.error(
"Failed to upsert table snippet: table_id=%s version_ts=%s action_type=%s",
payload.table_id,
payload.version_ts,
payload.action_type,
exc_info=True,
)
raise HTTPException(status_code=500, detail=str(exc)) from exc
        else:
            # Successful alias-snippet results are additionally ingested into
            # the RAG store when a workspace is specified on the request.
            if (
                payload.action_type == ActionType.SNIPPET_ALIAS
                and payload.status == ActionStatus.SUCCESS
                and payload.rag_workspace_id is not None
            ):
try:
await ingest_snippet_rag_from_db(
table_id=payload.table_id,
version_ts=payload.version_ts,
workspace_id=payload.rag_workspace_id,
rag_item_type=payload.rag_item_type or "SNIPPET",
client=client,
)
except Exception:
logger.exception(
"Failed to ingest snippet RAG artifacts",
extra={
"table_id": payload.table_id,
"version_ts": payload.version_ts,
},
)
return response

    @application.post("/__mock__/import-callback")
    async def mock_import_callback(payload: dict[str, Any]) -> dict[str, str]:
        """Development stub that logs import-analysis callbacks."""
        logger.info("Received import analysis callback: %s", payload)
        return {"status": "received"}

    return application


app = create_app()
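

if __name__ == "__main__":
    # Local-development convenience launcher: a minimal sketch assuming uvicorn
    # is installed and port 8000 is free (both are assumptions, not part of the
    # deployment setup). Deployments typically run `uvicorn app.main:app`.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)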