from __future__ import annotations import asyncio import logging import logging.config import os from contextlib import asynccontextmanager from typing import Any import yaml import httpx from fastapi import Depends, FastAPI, HTTPException, Request from fastapi.exceptions import RequestValidationError from fastapi.responses import JSONResponse from app.exceptions import ProviderAPICallError, ProviderConfigurationError from app.models import ( DataImportAnalysisJobAck, DataImportAnalysisJobRequest, LLMRequest, LLMResponse, TableProfilingJobAck, TableProfilingJobRequest, TableSnippetUpsertRequest, TableSnippetUpsertResponse, ) from app.services import LLMGateway from app.services.import_analysis import process_import_analysis_job from app.services.table_profiling import process_table_profiling_job from app.services.table_snippet import upsert_action_result def _ensure_log_directories(config: dict[str, Any]) -> None: handlers = config.get("handlers", {}) for handler_config in handlers.values(): filename = handler_config.get("filename") if not filename: continue directory = os.path.dirname(filename) if directory and not os.path.exists(directory): os.makedirs(directory, exist_ok=True) def _configure_logging() -> None: config_path = os.getenv("LOGGING_CONFIG", "logging.yaml") if os.path.exists(config_path): with open(config_path, "r", encoding="utf-8") as fh: config = yaml.safe_load(fh) if isinstance(config, dict): _ensure_log_directories(config) logging.config.dictConfig(config) return logging.basicConfig( level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s:%(lineno)d %(message)s", ) _configure_logging() logger = logging.getLogger(__name__) def _env_bool(name: str, default: bool) -> bool: raw = os.getenv(name) if raw is None: return default return raw.strip().lower() in {"1", "true", "yes", "on"} def _env_float(name: str, default: float) -> float: raw = os.getenv(name) if raw is None: return default try: return float(raw) except ValueError: logger.warning("Invalid value for %s=%r, using default %.2f", name, raw, default) return default def _parse_proxy_config(raw: str | None) -> dict[str, str] | str | None: if raw is None: return None cleaned = raw.strip() if not cleaned: return None # Support comma-separated key=value pairs for scheme-specific proxies. if "=" in cleaned: proxies: dict[str, str] = {} for part in cleaned.split(","): key, sep, value = part.partition("=") if not sep: continue key = key.strip() value = value.strip() if key and value: proxies[key] = value if proxies: return proxies return cleaned def _create_http_client() -> httpx.AsyncClient: timeout_seconds = _env_float("HTTP_CLIENT_TIMEOUT", 30.0) trust_env = _env_bool("HTTP_CLIENT_TRUST_ENV", True) proxies = _parse_proxy_config(os.getenv("HTTP_CLIENT_PROXY")) client_kwargs: dict[str, object] = { "timeout": httpx.Timeout(timeout_seconds), "trust_env": trust_env, } if proxies: client_kwargs["proxies"] = proxies return httpx.AsyncClient(**client_kwargs) @asynccontextmanager async def lifespan(app: FastAPI): client = _create_http_client() gateway = LLMGateway() try: app.state.http_client = client # type: ignore[attr-defined] app.state.gateway = gateway # type: ignore[attr-defined] yield finally: await client.aclose() def create_app() -> FastAPI: application = FastAPI( title="Unified LLM Gateway", version="0.1.0", lifespan=lifespan, ) @application.exception_handler(RequestValidationError) async def request_validation_exception_handler( request: Request, exc: RequestValidationError ) -> JSONResponse: try: raw_body = await request.body() except Exception: # pragma: no cover - defensive raw_body = b"" truncated_body = raw_body[:4096] logger.warning( "Validation error on %s %s: %s | body preview=%s", request.method, request.url.path, exc.errors(), truncated_body.decode("utf-8", errors="ignore"), ) return JSONResponse(status_code=422, content={"detail": exc.errors()}) @application.post( "/v1/chat/completions", response_model=LLMResponse, summary="Dispatch chat completion to upstream provider", ) async def create_chat_completion( payload: LLMRequest, gateway: LLMGateway = Depends(get_gateway), client: httpx.AsyncClient = Depends(get_http_client), ) -> LLMResponse: try: return await gateway.chat(payload, client) except ProviderConfigurationError as exc: logger.error("Provider configuration error: %s", exc, exc_info=True) raise HTTPException(status_code=422, detail=str(exc)) from exc except ProviderAPICallError as exc: status_code = exc.status_code or 502 log_detail = exc.response_text or str(exc) logger.error( "Provider API call error (status %s): %s", status_code, log_detail, exc_info=True, ) raise HTTPException(status_code=status_code, detail=str(exc)) from exc @application.post( "/v1/import/analyze", response_model=DataImportAnalysisJobAck, summary="Schedule async import analysis and notify via callback", status_code=202, ) async def analyze_import_data( payload: DataImportAnalysisJobRequest, client: httpx.AsyncClient = Depends(get_http_client), ) -> DataImportAnalysisJobAck: request_copy = payload.model_copy(deep=True) async def _runner() -> None: await process_import_analysis_job(request_copy, client) asyncio.create_task(_runner()) return DataImportAnalysisJobAck(import_record_id=payload.import_record_id, status="accepted") @application.post( "/v1/table/profiling", response_model=TableProfilingJobAck, summary="Run end-to-end GE profiling pipeline and notify via callback per action", status_code=202, ) async def run_table_profiling( payload: TableProfilingJobRequest, gateway: LLMGateway = Depends(get_gateway), client: httpx.AsyncClient = Depends(get_http_client), ) -> TableProfilingJobAck: request_copy = payload.model_copy(deep=True) async def _runner() -> None: await process_table_profiling_job(request_copy, gateway, client) asyncio.create_task(_runner()) return TableProfilingJobAck( table_id=payload.table_id, version_ts=payload.version_ts, status="accepted", ) @application.post( "/v1/table/snippet", response_model=TableSnippetUpsertResponse, summary="Persist or update action results, such as table snippets.", ) async def upsert_table_snippet( payload: TableSnippetUpsertRequest, ) -> TableSnippetUpsertResponse: request_copy = payload.model_copy(deep=True) try: return await asyncio.to_thread(upsert_action_result, request_copy) except Exception as exc: logger.error( "Failed to upsert table snippet: table_id=%s version_ts=%s action_type=%s", payload.table_id, payload.version_ts, payload.action_type, exc_info=True, ) raise HTTPException(status_code=500, detail=str(exc)) from exc @application.post("/__mock__/import-callback") async def mock_import_callback(payload: dict[str, Any]) -> dict[str, str]: logger.info("Received import analysis callback: %s", payload) return {"status": "received"} return application async def get_gateway(request: Request) -> LLMGateway: return request.app.state.gateway # type: ignore[return-value, attr-defined] async def get_http_client(request: Request) -> httpx.AsyncClient: return request.app.state.http_client # type: ignore[return-value, attr-defined] app = create_app()