Develop table profiling feature
This commit is contained in:
113
app/main.py
113
app/main.py
@ -2,12 +2,17 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import logging.config
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any
|
||||
|
||||
import yaml
|
||||
|
||||
import httpx
|
||||
from fastapi import Depends, FastAPI, HTTPException, Request
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from app.exceptions import ProviderAPICallError, ProviderConfigurationError
|
||||
from app.models import (
|
||||
@ -15,30 +20,42 @@ from app.models import (
|
||||
DataImportAnalysisJobRequest,
|
||||
LLMRequest,
|
||||
LLMResponse,
|
||||
TableProfilingJobAck,
|
||||
TableProfilingJobRequest,
|
||||
TableSnippetUpsertRequest,
|
||||
TableSnippetUpsertResponse,
|
||||
)
|
||||
from app.services import LLMGateway
|
||||
from app.services.import_analysis import process_import_analysis_job
|
||||
from app.services.table_profiling import process_table_profiling_job
|
||||
from app.services.table_snippet import upsert_action_result
|
||||
|
||||
|
||||
def _ensure_log_directories(config: dict[str, Any]) -> None:
|
||||
handlers = config.get("handlers", {})
|
||||
for handler_config in handlers.values():
|
||||
filename = handler_config.get("filename")
|
||||
if not filename:
|
||||
continue
|
||||
directory = os.path.dirname(filename)
|
||||
if directory and not os.path.exists(directory):
|
||||
os.makedirs(directory, exist_ok=True)
|
||||
|
||||
|
||||
def _configure_logging() -> None:
    """Configure application logging.

    Preference order:
    1. If the file named by ``LOGGING_CONFIG`` (default ``logging.yaml``)
       exists and parses to a mapping, apply it via
       ``logging.config.dictConfig`` (creating any log directories first).
    2. Otherwise fall back to env-driven configuration: ``LOG_LEVEL``
       (default ``INFO``) and ``LOG_FORMAT``.

    Fix: the extracted version left the ``os.getenv("LOG_FORMAT", ...)`` call
    unterminated and contained a dead, unconditional ``logging.basicConfig``
    block that duplicated (and would shadow) the env-driven fallback below;
    both are repaired here.
    """
    level_name = os.getenv("LOG_LEVEL", "INFO").upper()
    # Unknown level names silently fall back to INFO.
    level = getattr(logging, level_name, logging.INFO)
    log_format = os.getenv(
        "LOG_FORMAT",
        "%(asctime)s %(levelname)s %(name)s:%(lineno)d %(message)s",
    )

    config_path = os.getenv("LOGGING_CONFIG", "logging.yaml")
    if os.path.exists(config_path):
        with open(config_path, "r", encoding="utf-8") as fh:
            config = yaml.safe_load(fh)
        if isinstance(config, dict):
            _ensure_log_directories(config)
            logging.config.dictConfig(config)
            return

    root = logging.getLogger()
    if not root.handlers:
        # Nothing configured yet: let basicConfig install a stream handler.
        logging.basicConfig(level=level, format=log_format)
    else:
        # Handlers already exist (e.g. installed by a test runner or uvicorn):
        # retune them in place instead of stacking a duplicate handler.
        root.setLevel(level)
        formatter = logging.Formatter(log_format)
        for handler in root.handlers:
            handler.setLevel(level)
            handler.setFormatter(formatter)
|
||||
|
||||
|
||||
# Configure logging once at import time so every module-level logger
# created afterwards (including this one) inherits the configuration.
_configure_logging()
logger = logging.getLogger(__name__)
|
||||
@ -119,6 +136,24 @@ def create_app() -> FastAPI:
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
@application.exception_handler(RequestValidationError)
async def request_validation_exception_handler(
    request: Request, exc: RequestValidationError
) -> JSONResponse:
    """Log a capped preview of the offending request body, then return
    FastAPI's standard 422 payload (``{"detail": exc.errors()}``)."""
    try:
        body_bytes = await request.body()
    except Exception:  # pragma: no cover - defensive
        body_bytes = b"<unavailable>"
    # Cap the preview at 4 KiB so huge payloads cannot flood the log.
    preview = body_bytes[:4096].decode("utf-8", errors="ignore")
    logger.warning(
        "Validation error on %s %s: %s | body preview=%s",
        request.method,
        request.url.path,
        exc.errors(),
        preview,
    )
    return JSONResponse(status_code=422, content={"detail": exc.errors()})
|
||||
|
||||
@application.post(
|
||||
"/v1/chat/completions",
|
||||
response_model=LLMResponse,
|
||||
@ -164,6 +199,52 @@ def create_app() -> FastAPI:
|
||||
|
||||
return DataImportAnalysisJobAck(import_record_id=payload.import_record_id, status="accepted")
|
||||
|
||||
@application.post(
    "/v1/table/profiling",
    response_model=TableProfilingJobAck,
    summary="Run end-to-end GE profiling pipeline and notify via callback per action",
    status_code=202,
)
async def run_table_profiling(
    payload: TableProfilingJobRequest,
    gateway: LLMGateway = Depends(get_gateway),
    client: httpx.AsyncClient = Depends(get_http_client),
) -> TableProfilingJobAck:
    """Accept a profiling job, schedule it in the background, and ack with 202.

    The request is deep-copied so the background task cannot observe
    mutations made after the response is returned.

    Fix: ``asyncio.create_task`` only holds a weak reference to its task, so
    the fire-and-forget job could be garbage-collected mid-run and its
    exceptions silently lost. We now keep a strong reference on
    ``application.state`` until the task completes.
    """
    request_copy = payload.model_copy(deep=True)

    async def _runner() -> None:
        await process_table_profiling_job(request_copy, gateway, client)

    task = asyncio.create_task(_runner())
    # Keep a strong reference for the task's lifetime (asyncio docs require
    # this for background tasks), dropping it automatically on completion.
    pending = getattr(application.state, "profiling_tasks", None)
    if pending is None:
        pending = set()
        application.state.profiling_tasks = pending
    pending.add(task)
    task.add_done_callback(pending.discard)

    return TableProfilingJobAck(
        table_id=payload.table_id,
        version_ts=payload.version_ts,
        status="accepted",
    )
|
||||
|
||||
@application.post(
    "/v1/table/snippet",
    response_model=TableSnippetUpsertResponse,
    summary="Persist or update action results, such as table snippets.",
)
async def upsert_table_snippet(
    payload: TableSnippetUpsertRequest,
) -> TableSnippetUpsertResponse:
    """Persist an action result, running the blocking upsert off the event loop.

    A deep copy of the payload is handed to the worker thread so the handler
    and the thread never share mutable state. Any failure is logged with full
    context and surfaced to the caller as an HTTP 500.
    """
    snapshot = payload.model_copy(deep=True)

    try:
        # upsert_action_result is synchronous; to_thread keeps the loop free.
        return await asyncio.to_thread(upsert_action_result, snapshot)
    except Exception as exc:
        logger.error(
            "Failed to upsert table snippet: table_id=%s version_ts=%s action_type=%s",
            payload.table_id,
            payload.version_ts,
            payload.action_type,
            exc_info=True,
        )
        raise HTTPException(status_code=500, detail=str(exc)) from exc
|
||||
|
||||
@application.post("/__mock__/import-callback")
|
||||
async def mock_import_callback(payload: dict[str, Any]) -> dict[str, str]:
|
||||
logger.info("Received import analysis callback: %s", payload)
|
||||
|
||||
Reference in New Issue
Block a user