# Unified LLM Gateway -- FastAPI service entry module.
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import logging
|
|
import logging.config
|
|
import os
|
|
from contextlib import asynccontextmanager
|
|
from typing import Any
|
|
|
|
import yaml
|
|
|
|
import httpx
|
|
from fastapi import Depends, FastAPI, HTTPException, Request
|
|
from fastapi.exceptions import RequestValidationError
|
|
from fastapi.responses import JSONResponse
|
|
|
|
from app.exceptions import ProviderAPICallError, ProviderConfigurationError
|
|
from app.models import (
|
|
ActionStatus,
|
|
ActionType,
|
|
DataImportAnalysisJobAck,
|
|
DataImportAnalysisJobRequest,
|
|
LLMRequest,
|
|
LLMResponse,
|
|
TableProfilingJobAck,
|
|
TableProfilingJobRequest,
|
|
TableSnippetUpsertRequest,
|
|
TableSnippetUpsertResponse,
|
|
)
|
|
from app.routers import chat_router, metrics_router
|
|
from app.services import LLMGateway
|
|
from app.services.import_analysis import process_import_analysis_job
|
|
from app.services.table_profiling import process_table_profiling_job
|
|
from app.services.table_snippet import ingest_snippet_rag_from_db, upsert_action_result
|
|
|
|
|
|
def _ensure_log_directories(config: dict[str, Any]) -> None:
|
|
handlers = config.get("handlers", {})
|
|
for handler_config in handlers.values():
|
|
filename = handler_config.get("filename")
|
|
if not filename:
|
|
continue
|
|
directory = os.path.dirname(filename)
|
|
if directory and not os.path.exists(directory):
|
|
os.makedirs(directory, exist_ok=True)
|
|
|
|
|
|
def _configure_logging() -> None:
|
|
config_path = os.getenv("LOGGING_CONFIG", "logging.yaml")
|
|
if os.path.exists(config_path):
|
|
with open(config_path, "r", encoding="utf-8") as fh:
|
|
config = yaml.safe_load(fh)
|
|
if isinstance(config, dict):
|
|
_ensure_log_directories(config)
|
|
logging.config.dictConfig(config)
|
|
return
|
|
logging.basicConfig(
|
|
level=logging.INFO,
|
|
format="%(asctime)s %(levelname)s %(name)s:%(lineno)d %(message)s",
|
|
)
|
|
|
|
|
|
# Configure logging at import time so every module-level logger (including
# the one below) emits through the configured handlers from first use.
_configure_logging()

logger = logging.getLogger(__name__)
|
|
|
|
|
|
def _env_bool(name: str, default: bool) -> bool:
|
|
raw = os.getenv(name)
|
|
if raw is None:
|
|
return default
|
|
return raw.strip().lower() in {"1", "true", "yes", "on"}
|
|
|
|
|
|
def _env_float(name: str, default: float) -> float:
|
|
raw = os.getenv(name)
|
|
if raw is None:
|
|
return default
|
|
try:
|
|
return float(raw)
|
|
except ValueError:
|
|
logger.warning("Invalid value for %s=%r, using default %.2f", name, raw, default)
|
|
return default
|
|
|
|
|
|
def _parse_proxy_config(raw: str | None) -> dict[str, str] | str | None:
|
|
if raw is None:
|
|
return None
|
|
|
|
cleaned = raw.strip()
|
|
if not cleaned:
|
|
return None
|
|
|
|
# Support comma-separated key=value pairs for scheme-specific proxies.
|
|
if "=" in cleaned:
|
|
proxies: dict[str, str] = {}
|
|
for part in cleaned.split(","):
|
|
key, sep, value = part.partition("=")
|
|
if not sep:
|
|
continue
|
|
key = key.strip()
|
|
value = value.strip()
|
|
if key and value:
|
|
proxies[key] = value
|
|
if proxies:
|
|
return proxies
|
|
|
|
return cleaned
|
|
|
|
|
|
def _create_http_client() -> httpx.AsyncClient:
    """Build the shared AsyncClient from ``HTTP_CLIENT_*`` env variables.

    Honours ``HTTP_CLIENT_TIMEOUT`` (seconds), ``HTTP_CLIENT_TRUST_ENV``
    and ``HTTP_CLIENT_PROXY``; the proxy setting is passed through only
    when present and non-empty.
    """
    kwargs: dict[str, object] = {
        "timeout": httpx.Timeout(_env_float("HTTP_CLIENT_TIMEOUT", 30.0)),
        "trust_env": _env_bool("HTTP_CLIENT_TRUST_ENV", True),
    }
    proxy_config = _parse_proxy_config(os.getenv("HTTP_CLIENT_PROXY"))
    if proxy_config:
        # NOTE(review): httpx removed the ``proxies`` keyword in 0.28 in
        # favour of ``proxy``/``mounts`` -- confirm the pinned httpx version
        # still accepts it.
        kwargs["proxies"] = proxy_config
    return httpx.AsyncClient(**kwargs)
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manage application-scoped resources for the FastAPI app.

    Creates the shared httpx client and the LLM gateway, publishes them on
    ``app.state`` for the dependency helpers, and closes the HTTP client on
    shutdown. The gateway is constructed inside the ``try`` block so that a
    failing ``LLMGateway()`` constructor cannot leak the already-created
    HTTP client (previously it was built before the ``try``).
    """
    client = _create_http_client()
    try:
        gateway = LLMGateway()
        app.state.http_client = client  # type: ignore[attr-defined]
        app.state.gateway = gateway  # type: ignore[attr-defined]
        yield
    finally:
        # Always release pooled connections, even if startup failed.
        await client.aclose()
|
|
|
|
|
|
def create_app() -> FastAPI:
    """Create the FastAPI application with all routes and handlers wired up.

    Returns a fully configured app; resource lifecycle (HTTP client, LLM
    gateway) is handled by :func:`lifespan`.
    """
    application = FastAPI(
        title="Unified LLM Gateway",
        version="0.1.0",
        lifespan=lifespan,
    )
    # Chat/metric management APIs
    application.include_router(chat_router)
    application.include_router(metrics_router)

    # asyncio.create_task keeps only a weak reference to its task, so a
    # fire-and-forget job whose Task object is discarded can be garbage
    # collected before it finishes. Hold strong references here until each
    # task completes.
    background_tasks: set[asyncio.Task[None]] = set()

    def _spawn_background(coro) -> None:
        """Schedule *coro* as a task and retain it until it is done."""
        task = asyncio.create_task(coro)
        background_tasks.add(task)
        task.add_done_callback(background_tasks.discard)

    @application.exception_handler(RequestValidationError)
    async def request_validation_exception_handler(
        request: Request, exc: RequestValidationError
    ) -> JSONResponse:
        """Log a truncated preview of the offending body with the 422 details."""
        try:
            raw_body = await request.body()
        except Exception:  # pragma: no cover - defensive
            raw_body = b"<unavailable>"
        truncated_body = raw_body[:4096]  # cap log size for large payloads
        logger.warning(
            "Validation error on %s %s: %s | body preview=%s",
            request.method,
            request.url.path,
            exc.errors(),
            truncated_body.decode("utf-8", errors="ignore"),
        )
        return JSONResponse(status_code=422, content={"detail": exc.errors()})

    @application.post(
        "/v1/chat/completions",
        response_model=LLMResponse,
        summary="Dispatch chat completion to upstream provider",
    )
    async def create_chat_completion(
        payload: LLMRequest,
        gateway: LLMGateway = Depends(get_gateway),
        client: httpx.AsyncClient = Depends(get_http_client),
    ) -> LLMResponse:
        """Proxy a chat completion to the configured upstream provider.

        Raises 422 for provider configuration problems; mirrors the
        upstream status (default 502) for provider API call failures.
        """
        try:
            return await gateway.chat(payload, client)
        except ProviderConfigurationError as exc:
            logger.error("Provider configuration error: %s", exc, exc_info=True)
            raise HTTPException(status_code=422, detail=str(exc)) from exc
        except ProviderAPICallError as exc:
            status_code = exc.status_code or 502
            log_detail = exc.response_text or str(exc)
            logger.error(
                "Provider API call error (status %s): %s",
                status_code,
                log_detail,
                exc_info=True,
            )
            raise HTTPException(status_code=status_code, detail=str(exc)) from exc

    @application.post(
        "/v1/import/analyze",
        response_model=DataImportAnalysisJobAck,
        summary="Schedule async import analysis and notify via callback",
        status_code=202,
    )
    async def analyze_import_data(
        payload: DataImportAnalysisJobRequest,
        client: httpx.AsyncClient = Depends(get_http_client),
    ) -> DataImportAnalysisJobAck:
        """Acknowledge immediately; run the import analysis in the background."""
        # Deep copy so the background job is isolated from any later
        # mutation of the request object.
        request_copy = payload.model_copy(deep=True)
        _spawn_background(process_import_analysis_job(request_copy, client))
        return DataImportAnalysisJobAck(import_record_id=payload.import_record_id, status="accepted")

    @application.post(
        "/v1/table/profiling",
        response_model=TableProfilingJobAck,
        summary="Run end-to-end GE profiling pipeline and notify via callback per action",
        status_code=202,
    )
    async def run_table_profiling(
        payload: TableProfilingJobRequest,
        gateway: LLMGateway = Depends(get_gateway),
        client: httpx.AsyncClient = Depends(get_http_client),
    ) -> TableProfilingJobAck:
        """Acknowledge immediately; run the profiling pipeline in the background."""
        request_copy = payload.model_copy(deep=True)
        _spawn_background(process_table_profiling_job(request_copy, gateway, client))
        return TableProfilingJobAck(
            table_id=payload.table_id,
            version_ts=payload.version_ts,
            status="accepted",
        )

    @application.post(
        "/v1/table/snippet",
        response_model=TableSnippetUpsertResponse,
        summary="Persist or update action results, such as table snippets.",
    )
    async def upsert_table_snippet(
        payload: TableSnippetUpsertRequest,
        client: httpx.AsyncClient = Depends(get_http_client),
    ) -> TableSnippetUpsertResponse:
        """Persist an action result, then optionally ingest RAG artifacts.

        RAG ingestion runs only for successful SNIPPET_ALIAS actions that
        carry a workspace id; its failures are logged but never fail the
        request, since the upsert has already been persisted.
        """
        request_copy = payload.model_copy(deep=True)
        try:
            # upsert_action_result is synchronous work; run it in a worker
            # thread to keep the event loop responsive.
            response = await asyncio.to_thread(upsert_action_result, request_copy)
        except Exception as exc:
            logger.error(
                "Failed to upsert table snippet: table_id=%s version_ts=%s action_type=%s",
                payload.table_id,
                payload.version_ts,
                payload.action_type,
                exc_info=True,
            )
            raise HTTPException(status_code=500, detail=str(exc)) from exc
        else:
            if (
                payload.action_type == ActionType.SNIPPET_ALIAS
                and payload.status == ActionStatus.SUCCESS
                and payload.rag_workspace_id is not None
            ):
                try:
                    await ingest_snippet_rag_from_db(
                        table_id=payload.table_id,
                        version_ts=payload.version_ts,
                        workspace_id=payload.rag_workspace_id,
                        rag_item_type=payload.rag_item_type or "SNIPPET",
                        client=client,
                    )
                except Exception:
                    # Best-effort: the snippet itself is already stored.
                    logger.exception(
                        "Failed to ingest snippet RAG artifacts",
                        extra={
                            "table_id": payload.table_id,
                            "version_ts": payload.version_ts,
                        },
                    )
            return response

    @application.post("/__mock__/import-callback")
    async def mock_import_callback(payload: dict[str, Any]) -> dict[str, str]:
        """Test-only endpoint that logs and acknowledges import callbacks."""
        logger.info("Received import analysis callback: %s", payload)
        return {"status": "received"}

    return application
|
|
|
|
|
|
async def get_gateway(request: Request) -> LLMGateway:
    """FastAPI dependency: fetch the shared LLMGateway placed on app state
    by :func:`lifespan`."""
    gateway = request.app.state.gateway  # type: ignore[attr-defined]
    return gateway  # type: ignore[return-value]
|
|
|
|
|
|
async def get_http_client(request: Request) -> httpx.AsyncClient:
    """FastAPI dependency: fetch the shared httpx client placed on app state
    by :func:`lifespan`."""
    client = request.app.state.http_client  # type: ignore[attr-defined]
    return client  # type: ignore[return-value]
|
|
|
|
|
|
# Module-level ASGI application object served by the ASGI server
# (e.g. `uvicorn <module>:app`).
app = create_app()
|