Table profiling feature development (table profiling功能开发)

This commit is contained in:
zhaoawd
2025-11-03 00:18:26 +08:00
parent 557efc4bf1
commit c2a08e4134
6 changed files with 1280 additions and 16 deletions

View File

@ -2,12 +2,17 @@ from __future__ import annotations
import asyncio
import logging
import logging.config
import os
from contextlib import asynccontextmanager
from typing import Any
import yaml
import httpx
from fastapi import Depends, FastAPI, HTTPException, Request
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from app.exceptions import ProviderAPICallError, ProviderConfigurationError
from app.models import (
@ -15,30 +20,42 @@ from app.models import (
DataImportAnalysisJobRequest,
LLMRequest,
LLMResponse,
TableProfilingJobAck,
TableProfilingJobRequest,
TableSnippetUpsertRequest,
TableSnippetUpsertResponse,
)
from app.services import LLMGateway
from app.services.import_analysis import process_import_analysis_job
from app.services.table_profiling import process_table_profiling_job
from app.services.table_snippet import upsert_action_result
def _ensure_log_directories(config: dict[str, Any]) -> None:
    """Create the parent directory for every file-based logging handler.

    Walks the ``handlers`` section of a logging dictConfig mapping and ensures
    the directory of each handler's ``filename`` exists, so ``dictConfig``
    does not fail when opening a FileHandler whose log directory is missing.
    """
    for handler_config in config.get("handlers", {}).values():
        filename = handler_config.get("filename")
        if not filename:
            # Non-file handlers (e.g. console) have no filename to prepare.
            continue
        directory = os.path.dirname(filename)
        if directory:
            # exist_ok=True makes this idempotent and avoids the LBYL race
            # of the previous os.path.exists() pre-check.
            os.makedirs(directory, exist_ok=True)
def _configure_logging() -> None:
    """Configure application logging.

    Preference order:
    1. If the file named by ``LOGGING_CONFIG`` (default ``logging.yaml``)
       exists and parses to a mapping, apply it via ``logging.config.dictConfig``
       after ensuring its log directories exist.
    2. Otherwise fall back to ``LOG_LEVEL`` / ``LOG_FORMAT`` environment
       variables: ``basicConfig`` when the root logger has no handlers yet,
       or re-level/re-format the existing root handlers.

    NOTE(review): the diff view this was recovered from interleaved the old
    and new implementations (an ``os.getenv(`` call was left unclosed and an
    unconditional ``basicConfig`` duplicated the fallback); this body merges
    both paths coherently — confirm against the actual commit.
    """
    level_name = os.getenv("LOG_LEVEL", "INFO").upper()
    # getattr falls back to INFO for unknown level names instead of raising.
    level = getattr(logging, level_name, logging.INFO)
    log_format = os.getenv(
        "LOG_FORMAT",
        "%(asctime)s %(levelname)s %(name)s:%(lineno)d %(message)s",
    )
    config_path = os.getenv("LOGGING_CONFIG", "logging.yaml")
    if os.path.exists(config_path):
        with open(config_path, "r", encoding="utf-8") as fh:
            config = yaml.safe_load(fh)
        if isinstance(config, dict):
            _ensure_log_directories(config)
            logging.config.dictConfig(config)
            return
    root = logging.getLogger()
    if not root.handlers:
        logging.basicConfig(level=level, format=log_format)
    else:
        # Re-apply level/format to handlers installed by an earlier config.
        root.setLevel(level)
        formatter = logging.Formatter(log_format)
        for handler in root.handlers:
            handler.setLevel(level)
            handler.setFormatter(formatter)
# Configure logging once at import time so module-level loggers inherit it.
_configure_logging()
logger = logging.getLogger(__name__)
@ -119,6 +136,24 @@ def create_app() -> FastAPI:
lifespan=lifespan,
)
@application.exception_handler(RequestValidationError)
async def request_validation_exception_handler(
    request: Request, exc: RequestValidationError
) -> JSONResponse:
    """Log 422 validation failures with a truncated request-body preview.

    Returns the standard FastAPI-shaped ``{"detail": [...]}`` payload so
    clients see the usual validation error format.
    """
    try:
        body_bytes = await request.body()
    except Exception:  # pragma: no cover - defensive
        body_bytes = b"<unavailable>"
    # Cap the preview at 4 KiB to keep log lines bounded.
    preview = body_bytes[:4096].decode("utf-8", errors="ignore")
    logger.warning(
        "Validation error on %s %s: %s | body preview=%s",
        request.method,
        request.url.path,
        exc.errors(),
        preview,
    )
    return JSONResponse(status_code=422, content={"detail": exc.errors()})
@application.post(
"/v1/chat/completions",
response_model=LLMResponse,
@ -164,6 +199,52 @@ def create_app() -> FastAPI:
return DataImportAnalysisJobAck(import_record_id=payload.import_record_id, status="accepted")
# Strong references to in-flight background jobs: the event loop holds only
# weak references to tasks, so an unreferenced fire-and-forget task can be
# garbage-collected before it finishes.
background_tasks: set = set()

@application.post(
    "/v1/table/profiling",
    response_model=TableProfilingJobAck,
    summary="Run end-to-end GE profiling pipeline and notify via callback per action",
    status_code=202,
)
async def run_table_profiling(
    payload: TableProfilingJobRequest,
    gateway: LLMGateway = Depends(get_gateway),
    client: httpx.AsyncClient = Depends(get_http_client),
) -> TableProfilingJobAck:
    """Accept a profiling job, run it in the background, and ack with 202."""
    # Deep-copy so the background task is isolated from later mutation of payload.
    request_copy = payload.model_copy(deep=True)

    async def _runner() -> None:
        try:
            await process_table_profiling_job(request_copy, gateway, client)
        except Exception:  # pragma: no cover - background failures must be visible
            # Without this, the exception would sit unretrieved on the task.
            logger.exception(
                "Table profiling job failed: table_id=%s version_ts=%s",
                request_copy.table_id,
                request_copy.version_ts,
            )

    task = asyncio.create_task(_runner())
    background_tasks.add(task)
    task.add_done_callback(background_tasks.discard)
    return TableProfilingJobAck(
        table_id=payload.table_id,
        version_ts=payload.version_ts,
        status="accepted",
    )
@application.post(
    "/v1/table/snippet",
    response_model=TableSnippetUpsertResponse,
    summary="Persist or update action results, such as table snippets.",
)
async def upsert_table_snippet(
    payload: TableSnippetUpsertRequest,
) -> TableSnippetUpsertResponse:
    """Persist an action result off the event loop; surface failures as HTTP 500."""
    # Deep-copy so the worker thread operates on an isolated snapshot.
    snapshot = payload.model_copy(deep=True)
    try:
        # upsert_action_result is blocking; run it in a thread to keep the
        # event loop responsive.
        result = await asyncio.to_thread(upsert_action_result, snapshot)
    except Exception as exc:
        logger.error(
            "Failed to upsert table snippet: table_id=%s version_ts=%s action_type=%s",
            payload.table_id,
            payload.version_ts,
            payload.action_type,
            exc_info=True,
        )
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    return result
@application.post("/__mock__/import-callback")
async def mock_import_callback(payload: dict[str, Any]) -> dict[str, str]:
logger.info("Received import analysis callback: %s", payload)