from __future__ import annotations import asyncio import logging import os from contextlib import asynccontextmanager from typing import Any import httpx from fastapi import Depends, FastAPI, HTTPException, Request from app.exceptions import ProviderAPICallError, ProviderConfigurationError from app.models import ( DataImportAnalysisJobAck, DataImportAnalysisJobRequest, LLMRequest, LLMResponse, ) from app.services import LLMGateway from app.services.import_analysis import process_import_analysis_job def _configure_logging() -> None: level_name = os.getenv("LOG_LEVEL", "INFO").upper() level = getattr(logging, level_name, logging.INFO) log_format = os.getenv( "LOG_FORMAT", "%(asctime)s %(levelname)s %(name)s:%(lineno)d %(message)s", ) root = logging.getLogger() if not root.handlers: logging.basicConfig(level=level, format=log_format) else: root.setLevel(level) formatter = logging.Formatter(log_format) for handler in root.handlers: handler.setLevel(level) handler.setFormatter(formatter) _configure_logging() logger = logging.getLogger(__name__) def _env_bool(name: str, default: bool) -> bool: raw = os.getenv(name) if raw is None: return default return raw.strip().lower() in {"1", "true", "yes", "on"} def _env_float(name: str, default: float) -> float: raw = os.getenv(name) if raw is None: return default try: return float(raw) except ValueError: logger.warning("Invalid value for %s=%r, using default %.2f", name, raw, default) return default def _parse_proxy_config(raw: str | None) -> dict[str, str] | str | None: if raw is None: return None cleaned = raw.strip() if not cleaned: return None # Support comma-separated key=value pairs for scheme-specific proxies. if "=" in cleaned: proxies: dict[str, str] = {} for part in cleaned.split(","): key, sep, value = part.partition("=") if not sep: continue key = key.strip() value = value.strip() if key and value: proxies[key] = value if proxies: return proxies return cleaned def _create_http_client() -> httpx.AsyncClient: timeout_seconds = _env_float("HTTP_CLIENT_TIMEOUT", 30.0) trust_env = _env_bool("HTTP_CLIENT_TRUST_ENV", True) proxies = _parse_proxy_config(os.getenv("HTTP_CLIENT_PROXY")) client_kwargs: dict[str, object] = { "timeout": httpx.Timeout(timeout_seconds), "trust_env": trust_env, } if proxies: client_kwargs["proxies"] = proxies return httpx.AsyncClient(**client_kwargs) @asynccontextmanager async def lifespan(app: FastAPI): client = _create_http_client() gateway = LLMGateway() try: app.state.http_client = client # type: ignore[attr-defined] app.state.gateway = gateway # type: ignore[attr-defined] yield finally: await client.aclose() def create_app() -> FastAPI: application = FastAPI( title="Unified LLM Gateway", version="0.1.0", lifespan=lifespan, ) @application.post( "/v1/chat/completions", response_model=LLMResponse, summary="Dispatch chat completion to upstream provider", ) async def create_chat_completion( payload: LLMRequest, gateway: LLMGateway = Depends(get_gateway), client: httpx.AsyncClient = Depends(get_http_client), ) -> LLMResponse: try: return await gateway.chat(payload, client) except ProviderConfigurationError as exc: logger.error("Provider configuration error: %s", exc, exc_info=True) raise HTTPException(status_code=422, detail=str(exc)) from exc except ProviderAPICallError as exc: status_code = exc.status_code or 502 log_detail = exc.response_text or str(exc) logger.error( "Provider API call error (status %s): %s", status_code, log_detail, exc_info=True, ) raise HTTPException(status_code=status_code, detail=str(exc)) from exc @application.post( "/v1/import/analyze", response_model=DataImportAnalysisJobAck, summary="Schedule async import analysis and notify via callback", status_code=202, ) async def analyze_import_data( payload: DataImportAnalysisJobRequest, client: httpx.AsyncClient = Depends(get_http_client), ) -> DataImportAnalysisJobAck: request_copy = payload.model_copy(deep=True) async def _runner() -> None: await process_import_analysis_job(request_copy, client) asyncio.create_task(_runner()) return DataImportAnalysisJobAck(import_record_id=payload.import_record_id, status="accepted") @application.post("/__mock__/import-callback") async def mock_import_callback(payload: dict[str, Any]) -> dict[str, str]: logger.info("Received import analysis callback: %s", payload) return {"status": "received"} return application async def get_gateway(request: Request) -> LLMGateway: return request.app.state.gateway # type: ignore[return-value, attr-defined] async def get_http_client(request: Request) -> httpx.AsyncClient: return request.app.state.http_client # type: ignore[return-value, attr-defined] app = create_app()