table profiling功能开发

This commit is contained in:
zhaoawd
2025-11-03 00:18:26 +08:00
parent 557efc4bf1
commit c2a08e4134
6 changed files with 1280 additions and 16 deletions

26
app/db.py Normal file
View File

@ -0,0 +1,26 @@
from __future__ import annotations
import os
from functools import lru_cache
from sqlalchemy import create_engine
from sqlalchemy.engine import Engine
@lru_cache(maxsize=1)
def get_engine() -> Engine:
    """Build the SQLAlchemy engine on first call and return the cached instance afterwards.

    The connection URL comes from the ``DATABASE_URL`` environment variable; a
    local MySQL URL is used when the variable is unset.
    """
    fallback_url = "mysql+pymysql://root:12345678@localhost:3306/data-ge?charset=utf8mb4"
    url = os.getenv("DATABASE_URL", fallback_url)
    # SQLite URLs get ``check_same_thread`` disabled; other backends receive no extra args.
    extra_connect_args = {"check_same_thread": False} if url.startswith("sqlite") else {}
    return create_engine(
        url,
        pool_pre_ping=True,
        future=True,
        connect_args=extra_connect_args,
    )

View File

@ -2,12 +2,17 @@ from __future__ import annotations
import asyncio
import logging
import logging.config
import os
from contextlib import asynccontextmanager
from typing import Any
import yaml
import httpx
from fastapi import Depends, FastAPI, HTTPException, Request
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from app.exceptions import ProviderAPICallError, ProviderConfigurationError
from app.models import (
@ -15,30 +20,42 @@ from app.models import (
DataImportAnalysisJobRequest,
LLMRequest,
LLMResponse,
TableProfilingJobAck,
TableProfilingJobRequest,
TableSnippetUpsertRequest,
TableSnippetUpsertResponse,
)
from app.services import LLMGateway
from app.services.import_analysis import process_import_analysis_job
from app.services.table_profiling import process_table_profiling_job
from app.services.table_snippet import upsert_action_result
def _ensure_log_directories(config: dict[str, Any]) -> None:
    """Create the parent directory of every file-backed handler in a logging dictConfig.

    Args:
        config: Parsed logging configuration; only its ``handlers`` mapping is read.
    """
    # An empty ``handlers:`` key in YAML parses to None, so guard before iterating.
    handlers = config.get("handlers") or {}
    if not isinstance(handlers, dict):
        return
    for handler_config in handlers.values():
        if not isinstance(handler_config, dict):
            continue
        filename = handler_config.get("filename")
        if not filename:
            continue
        directory = os.path.dirname(filename)
        if directory and not os.path.exists(directory):
            os.makedirs(directory, exist_ok=True)
def _configure_logging() -> None:
    """Configure application logging.

    A YAML dictConfig file takes precedence: its path is read from
    ``LOGGING_CONFIG`` (default ``logging.yaml``) and it is applied when the file
    exists and parses to a mapping. Otherwise logging is configured from the
    ``LOG_LEVEL`` / ``LOG_FORMAT`` environment variables, which default to INFO
    and the standard ``asctime level name:lineno message`` layout.
    """
    config_path = os.getenv("LOGGING_CONFIG", "logging.yaml")
    if os.path.exists(config_path):
        with open(config_path, "r", encoding="utf-8") as fh:
            config = yaml.safe_load(fh)
        if isinstance(config, dict):
            _ensure_log_directories(config)
            logging.config.dictConfig(config)
            return

    level_name = os.getenv("LOG_LEVEL", "INFO").upper()
    level = getattr(logging, level_name, logging.INFO)
    if not isinstance(level, int):
        # Names such as BASIC_FORMAT resolve to non-level attributes of ``logging``.
        level = logging.INFO
    log_format = os.getenv(
        "LOG_FORMAT",
        "%(asctime)s %(levelname)s %(name)s:%(lineno)d %(message)s",
    )

    root = logging.getLogger()
    if not root.handlers:
        logging.basicConfig(level=level, format=log_format)
        return
    # basicConfig is a no-op once handlers exist, so adjust them in place.
    root.setLevel(level)
    formatter = logging.Formatter(log_format)
    for handler in root.handlers:
        handler.setLevel(level)
        handler.setFormatter(formatter)
_configure_logging()
logger = logging.getLogger(__name__)
@ -119,6 +136,24 @@ def create_app() -> FastAPI:
lifespan=lifespan,
)
@application.exception_handler(RequestValidationError)
async def request_validation_exception_handler(
    request: Request, exc: RequestValidationError
) -> JSONResponse:
    """Log a rejected request with a body preview and return the 422 error payload."""
    # Imported locally so this handler needs no new file-level import.
    from fastapi.encoders import jsonable_encoder

    try:
        raw_body = await request.body()
    except Exception:  # pragma: no cover - defensive
        raw_body = b"<unavailable>"
    # Cap the preview so oversized payloads do not flood the log.
    # NOTE(review): the preview is logged verbatim and may include credentials
    # sent in the request body - confirm this is acceptable.
    truncated_body = raw_body[:4096]
    errors = exc.errors()
    logger.warning(
        "Validation error on %s %s: %s | body preview=%s",
        request.method,
        request.url.path,
        errors,
        truncated_body.decode("utf-8", errors="ignore"),
    )
    # Error entries can hold values that are not JSON serialisable (for example
    # exception objects under "ctx"), so encode them before building the response.
    return JSONResponse(status_code=422, content={"detail": jsonable_encoder(errors)})
@application.post(
"/v1/chat/completions",
response_model=LLMResponse,
@ -164,6 +199,52 @@ def create_app() -> FastAPI:
return DataImportAnalysisJobAck(import_record_id=payload.import_record_id, status="accepted")
@application.post(
    "/v1/table/profiling",
    response_model=TableProfilingJobAck,
    summary="Run end-to-end GE profiling pipeline and notify via callback per action",
    status_code=202,
)
async def run_table_profiling(
    payload: TableProfilingJobRequest,
    gateway: LLMGateway = Depends(get_gateway),
    client: httpx.AsyncClient = Depends(get_http_client),
) -> TableProfilingJobAck:
    """Schedule the profiling pipeline as a background task and acknowledge immediately.

    Returns:
        An acknowledgement echoing ``table_id`` and ``version_ts`` with status ``accepted``.
    """
    # The background job works on its own deep copy of the request model.
    request_copy = payload.model_copy(deep=True)

    # The event loop keeps only weak references to tasks. Hold a strong reference
    # until the task finishes so it cannot be garbage-collected mid-run.
    pending_tasks = getattr(application.state, "table_profiling_tasks", None)
    if pending_tasks is None:
        pending_tasks = set()
        application.state.table_profiling_tasks = pending_tasks
    task = asyncio.create_task(process_table_profiling_job(request_copy, gateway, client))
    pending_tasks.add(task)
    task.add_done_callback(pending_tasks.discard)

    return TableProfilingJobAck(
        table_id=payload.table_id,
        version_ts=payload.version_ts,
        status="accepted",
    )
@application.post(
    "/v1/table/snippet",
    response_model=TableSnippetUpsertResponse,
    summary="Persist or update action results, such as table snippets.",
)
async def upsert_table_snippet(
    payload: TableSnippetUpsertRequest,
) -> TableSnippetUpsertResponse:
    """Persist an action result in a worker thread; any failure becomes an HTTP 500."""
    snapshot = payload.model_copy(deep=True)
    try:
        # asyncio.to_thread keeps the upsert call off the event loop.
        outcome = await asyncio.to_thread(upsert_action_result, snapshot)
    except Exception as exc:
        logger.error(
            "Failed to upsert table snippet: table_id=%s version_ts=%s action_type=%s",
            payload.table_id,
            payload.version_ts,
            payload.action_type,
            exc_info=True,
        )
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    return outcome
@application.post("/__mock__/import-callback")
async def mock_import_callback(payload: dict[str, Any]) -> dict[str, str]:
logger.info("Received import analysis callback: %s", payload)

View File

@ -1,5 +1,6 @@
from __future__ import annotations
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional, Union
@ -135,3 +136,89 @@ class DataImportAnalysisJobRequest(BaseModel):
class DataImportAnalysisJobAck(BaseModel):
    # Acknowledgement body for an import analysis job.
    import_record_id: str = Field(
        default=...,
        description="Echo of the import record identifier",
    )
    status: str = Field(
        default="accepted",
        description="Processing status acknowledgement.",
    )
class TableProfilingJobRequest(BaseModel):
    # Request body for the table profiling pipeline.

    # --- Identity and callback -------------------------------------------------
    table_id: str = Field(default=..., description="Unique identifier for the table to profile.")
    version_ts: str = Field(
        default=...,
        pattern=r"^\d{14}$",
        description="Version timestamp expressed as fourteen digit string (yyyyMMddHHmmss).",
    )
    callback_url: HttpUrl = Field(
        default=...,
        description="Callback endpoint invoked after each pipeline action completes.",
    )

    # --- Schema snapshot -------------------------------------------------------
    table_schema: Optional[Any] = Field(
        default=None,
        description="Schema structure snapshot for the current table version.",
    )
    table_schema_version_id: Optional[str] = Field(
        default=None,
        description="Identifier for the schema snapshot provided in table_schema.",
    )

    # --- Source table location and access ---------------------------------------
    table_link_info: Optional[Dict[str, Any]] = Field(
        default=None,
        description=(
            "Information describing how to locate the source table for profiling. "
            "For example: {'type': 'sql', 'connection_string': 'mysql+pymysql://user:pass@host/db', "
            "'table': 'schema.table_name'}."
        ),
    )
    table_access_info: Optional[Dict[str, Any]] = Field(
        default=None,
        description=(
            "Credentials or supplemental parameters required to access the table described in table_link_info. "
            "These values can be merged into the connection string using Python format placeholders."
        ),
    )

    # --- Great Expectations settings ---------------------------------------------
    ge_batch_request: Optional[Dict[str, Any]] = Field(
        default=None,
        description="Optional Great Expectations batch request payload used for profiling.",
    )
    ge_expectation_suite_name: Optional[str] = Field(
        default=None,
        description="Expectation suite name used during profiling. Created automatically when absent.",
    )
    ge_data_context_root: Optional[str] = Field(
        default=None,
        description="Custom root directory for the Great Expectations data context. Defaults to project ./gx.",
    )
    ge_datasource_name: Optional[str] = Field(
        default=None,
        description="Datasource name registered inside the GE context when batch_request is not supplied.",
    )
    ge_data_asset_name: Optional[str] = Field(
        default=None,
        description="Data asset reference used when inferring batch request from datasource configuration.",
    )
    ge_profiler_type: str = Field(
        default="user_configurable",
        description="Profiler implementation identifier. Currently supports 'user_configurable' or 'data_assistant'.",
    )

    # --- LLM model selection -------------------------------------------------------
    llm_model: Optional[str] = Field(
        default=None,
        description="Default LLM model spec applied to prompt-based actions when overrides are omitted.",
    )
    result_desc_model: Optional[str] = Field(
        default=None,
        description="LLM model override used for GE result description (action 2).",
    )
    snippet_model: Optional[str] = Field(
        default=None,
        description="LLM model override used for snippet generation (action 3).",
    )
    snippet_alias_model: Optional[str] = Field(
        default=None,
        description="LLM model override used for snippet alias enrichment (action 4).",
    )

    # --- Miscellaneous ---------------------------------------------------------------
    extra_options: Optional[Dict[str, Any]] = Field(
        default=None,
        description="Miscellaneous execution flags applied across pipeline steps.",
    )
class TableProfilingJobAck(BaseModel):
    # Acknowledgement body for a table profiling job.
    table_id: str = Field(default=..., description="Echo of the table identifier.")
    version_ts: str = Field(
        default=...,
        description="Echo of the profiling version timestamp (yyyyMMddHHmmss).",
    )
    status: str = Field(default="accepted", description="Processing acknowledgement status.")

View File

@ -0,0 +1,832 @@
from __future__ import annotations
import asyncio
import json
import logging
import os
import re
from datetime import date, datetime
from dataclasses import asdict, dataclass, is_dataclass
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
import httpx
import great_expectations as gx
from great_expectations.core.batch import RuntimeBatchRequest
from great_expectations.core.expectation_suite import ExpectationSuite
from great_expectations.data_context import AbstractDataContext
from great_expectations.exceptions import DataContextError, MetricResolutionError
from app.exceptions import ProviderAPICallError
from app.models import TableProfilingJobRequest
from app.services import LLMGateway
from app.settings import DEFAULT_IMPORT_MODEL
from app.services.import_analysis import (
IMPORT_GATEWAY_BASE_URL,
resolve_provider_from_model,
)
logger = logging.getLogger(__name__)
# Location of the Data Docs index page relative to the GE context root directory.
GE_REPORT_RELATIVE_PATH = Path("uncommitted") / "data_docs" / "local_site" / "index.html"
# Prompt template file name for each LLM-backed pipeline step, keyed by template key.
PROMPT_FILENAMES = {
    "ge_result_desc": "ge_result_desc_prompt.md",
    "snippet_generator": "snippet_generator.md",
    "snippet_alias": "snippet_alias_generator.md",
}
# Chat-completion timeout in seconds used when the request supplies no valid override.
DEFAULT_CHAT_TIMEOUT_SECONDS = 90.0
@dataclass
class GEProfilingArtifacts:
    """Outputs of the GE profiling step, forwarded to later steps and the callback."""

    # Expectation suite, validation result and batch request, all as plain dicts.
    profiling_result: Dict[str, Any]
    # Condensed per-column and table-level view of the expectation suite.
    profiling_summary: Dict[str, Any]
    # Path of the Data Docs index page.
    docs_path: str
class PipelineActionType:
    """String identifiers of the four pipeline actions, in execution order."""

    GE_PROFILING = "ge_profiling"  # step 1: Great Expectations profiling
    GE_RESULT_DESC = "ge_result_desc"  # step 2: LLM description of the profiling result
    SNIPPET = "snippet"  # step 3: snippet generation
    SNIPPET_ALIAS = "snippet_alias"  # step 4: snippet alias enrichment
def _project_root() -> Path:
    """Return the directory two levels above the directory containing this module."""
    module_path = Path(__file__).resolve()
    return module_path.parents[2]
def _prompt_dir() -> Path:
    """Return the ``prompt`` directory located directly under the project root."""
    return _project_root().joinpath("prompt")
@lru_cache(maxsize=None)
def _load_prompt_parts(filename: str) -> Tuple[str, str]:
    """Read a prompt template and split it into ``(system_text, user_text)``.

    The text before the ``用户消息User`` marker is the system part (with the
    ``系统角色System`` heading removed); the text after it is the user part.
    Results are cached per file name.

    Raises:
        FileNotFoundError: The template file does not exist.
        ValueError: The template lacks the user-message marker.
    """
    template_file = _prompt_dir() / filename
    if not template_file.exists():
        raise FileNotFoundError(f"Prompt template not found: {template_file}")

    splitter = "用户消息User"
    system_section, marker, user_section = template_file.read_text(encoding="utf-8").partition(splitter)
    if not marker:
        raise ValueError(f"Prompt template '{filename}' missing separator '{splitter}'.")

    return system_section.replace("系统角色System", "").strip(), user_section.strip()
def _render_prompt(template_key: str, replacements: Dict[str, str]) -> Tuple[str, str]:
    """Return ``(system_text, user_text)`` with each placeholder in the user text substituted.

    Args:
        template_key: Key into ``PROMPT_FILENAMES``.
        replacements: Mapping of literal placeholder text to its replacement.
    """
    system_text, user_text = _load_prompt_parts(PROMPT_FILENAMES[template_key])
    for placeholder, substitute in replacements.items():
        user_text = user_text.replace(placeholder, substitute)
    return system_text, user_text
def _extract_timeout_seconds(options: Optional[Dict[str, Any]]) -> Optional[float]:
    """Read ``llm_timeout_seconds`` from the request's extra options.

    Args:
        options: The ``extra_options`` mapping, or None.

    Returns:
        None when no options or no timeout value are supplied; the value as a
        float when it is a finite positive number; otherwise
        ``DEFAULT_CHAT_TIMEOUT_SECONDS`` after logging a warning.
    """
    import math

    if not options:
        return None
    value = options.get("llm_timeout_seconds")
    if value is None:
        return None
    try:
        timeout = float(value)
    except (TypeError, ValueError):
        timeout = None
    # NaN and infinity pass a plain ``<= 0`` check, so require a finite value.
    if timeout is not None and math.isfinite(timeout) and timeout > 0:
        return timeout
    logger.warning(
        "Invalid llm_timeout_seconds value in extra_options: %r. Falling back to default.",
        value,
    )
    return DEFAULT_CHAT_TIMEOUT_SECONDS
def _extract_json_payload(content: str) -> str:
    """Extract the JSON text from an LLM reply.

    A fenced code block (optionally tagged ``json``) wins when present and
    non-empty. Otherwise the brace-delimited and bracket-delimited spans are
    both considered and the longest one that parses as JSON is returned, so an
    array of objects is not cut down to its inner braces.

    Args:
        content: Raw message content returned by the model.

    Returns:
        The extracted JSON text, or the stripped content when no candidate
        span exists.

    Raises:
        ValueError: The content is empty or whitespace only.
    """
    fenced = re.search(
        r"```(?:json)?\s*([\s\S]+?)```",
        content,
        flags=re.IGNORECASE,
    )
    if fenced:
        snippet = fenced.group(1).strip()
        if snippet:
            return snippet

    stripped = content.strip()
    if not stripped:
        raise ValueError("Empty LLM content.")

    candidates: List[str] = []
    for opener, closer in (("{", "}"), ("[", "]")):
        start = stripped.find(opener)
        end = stripped.rfind(closer)
        if start != -1 and end > start:
            candidates.append(stripped[start : end + 1].strip())
    if not candidates:
        return stripped

    # The outermost structure is the longest span; prefer it when it is valid JSON.
    for candidate in sorted(candidates, key=len, reverse=True):
        try:
            json.loads(candidate)
        except ValueError:
            continue
        return candidate
    # Nothing parsed: return the first candidate so the caller reports the failure.
    return candidates[0]
def _parse_completion_payload(response_payload: Dict[str, Any]) -> Any:
    """Decode the JSON carried in the first choice of a chat-completions response.

    Raises:
        ProviderAPICallError: No choices, empty content, or content that is not valid JSON.
    """
    choices = response_payload.get("choices") or []
    if not choices:
        raise ProviderAPICallError("LLM response did not contain choices to parse.")

    first_message = choices[0].get("message") or {}
    text = first_message.get("content") or ""
    if not text.strip():
        raise ProviderAPICallError("LLM response content is empty.")

    extracted = _extract_json_payload(text)
    try:
        decoded = json.loads(extracted)
    except json.JSONDecodeError as exc:
        # Only the first 800 characters are logged.
        logger.error("Failed to parse JSON from LLM response: %s", extracted[:800], exc_info=True)
        raise ProviderAPICallError("LLM response JSON parsing failed.") from exc
    return decoded
async def _post_callback(callback_url: str, payload: Dict[str, Any], client: httpx.AsyncClient) -> None:
    """POST a JSON-safe copy of ``payload`` to ``callback_url``.

    ``httpx.HTTPError`` failures (transport errors and non-2xx statuses) are
    logged and not re-raised.
    """
    body = _normalize_for_json(payload)
    try:
        logger.info(
            "Posting pipeline action callback to %s: %s",
            callback_url,
            json.dumps(body, ensure_ascii=False),
        )
        reply = await client.post(callback_url, json=body)
        reply.raise_for_status()
    except httpx.HTTPError as exc:
        logger.error("Callback delivery to %s failed: %s", callback_url, exc, exc_info=True)
def _sanitize_value_set(value: Any, max_values: int) -> Tuple[Any, Optional[Dict[str, int]]]:
    """Truncate an over-long list to ``max_values`` items.

    Returns:
        ``(value, None)`` when ``value`` is not a list or is short enough;
        otherwise ``(truncated_list, note)`` where ``note`` records the original
        and retained lengths.
    """
    if isinstance(value, list) and len(value) > max_values:
        note = {"original_length": len(value), "retained": max_values}
        return value[:max_values], note
    return value, None
def _sanitize_expectation_suite(suite: ExpectationSuite, max_value_set_values: int = 100) -> Dict[str, Any]:
    """Convert a suite to a JSON dict and cap the length of every ``value_set`` list.

    Each truncation is recorded in the expectation's ``meta`` and collected under
    the suite-level ``meta["value_set_truncations"]``.
    """
    suite_dict = suite.to_json_dict()
    truncations: List[Dict[str, Any]] = []

    for expectation in suite_dict.get("expectations", []):
        kwargs = expectation.get("kwargs", {})
        if "value_set" not in kwargs:
            continue
        trimmed, note = _sanitize_value_set(kwargs["value_set"], max_value_set_values)
        kwargs["value_set"] = trimmed
        if not note:
            continue
        expectation.setdefault("meta", {})
        expectation["meta"]["value_set_truncated"] = note
        truncations.append(
            {
                "column": kwargs.get("column"),
                "expectation": expectation.get("expectation_type"),
                "note": note,
            }
        )

    if truncations:
        suite_dict.setdefault("meta", {})
        suite_dict["meta"]["value_set_truncations"] = truncations
    return suite_dict
def _summarize_expectation_suite(suite_dict: Dict[str, Any]) -> Dict[str, Any]:
    """Condense a suite dict into per-column and table-level expectation summaries.

    Expectations whose kwargs name a ``column`` are grouped under that column,
    in first-seen order; the rest are listed as table-level. List-valued
    ``value_set`` kwargs are reduced to a size and a five-item preview.
    """
    expectations = suite_dict.get("expectations", [])
    per_column: Dict[str, Dict[str, Any]] = {}
    table_level: List[Dict[str, Any]] = []

    for item in expectations:
        params = item.get("kwargs", {})
        entry: Dict[str, Any] = {"expectation": item.get("expectation_type")}

        values = params.get("value_set")
        if isinstance(values, list):
            entry["value_set_size"] = len(values)
            entry["value_set_preview"] = values[:5]

        column_name = params.get("column")
        if not column_name:
            table_level.append(entry)
            continue
        if column_name not in per_column:
            per_column[column_name] = {"name": column_name, "expectations": []}
        per_column[column_name]["expectations"].append(entry)

    return {
        "column_profiles": list(per_column.values()),
        "table_level_expectations": table_level,
        "total_expectations": len(expectations),
    }
def _sanitize_identifier(raw: Optional[str], fallback: str) -> str:
    """Collapse runs of characters outside ``[0-9A-Za-z_]`` into ``_`` and trim edge underscores.

    Returns ``fallback`` when ``raw`` is empty or nothing remains after cleaning.
    """
    if not raw:
        return fallback
    cleaned = re.sub(r"[^0-9A-Za-z_]+", "_", raw).strip("_")
    return cleaned if cleaned else fallback
def _format_connection_string(template: str, access_info: Dict[str, Any]) -> str:
    """Fill ``{name}`` placeholders in ``template`` from ``access_info``.

    The template is returned untouched when ``access_info`` is empty.

    Raises:
        ValueError: The template references a key missing from ``access_info``.
    """
    if not access_info:
        return template
    values = dict(access_info)
    try:
        rendered = template.format_map(values)
    except KeyError as exc:
        missing = exc.args[0]
        raise ValueError(f"table_access_info missing key '{missing}' required by connection_string.") from exc
    return rendered
def _ensure_sql_runtime_datasource(
    context: AbstractDataContext,
    datasource_name: str,
    connection_string: str,
) -> None:
    """Make sure a runtime SQL datasource named ``datasource_name`` exists in ``context``.

    An existing datasource is kept unless its connection string differs from
    ``connection_string``; in that case it is deleted and recreated. When the
    deletion fails, a warning is logged and the existing datasource is left in place.

    Raises:
        RuntimeError: The datasource cannot be inspected or created.
    """
    try:
        existing = context.get_datasource(datasource_name)
    except (DataContextError, ValueError) as exc:
        text = str(exc)
        not_found = "Could not find a datasource" in text or "Unable to load datasource" in text
        if not not_found:  # pragma: no cover - defensive
            raise RuntimeError(f"Failed to inspect datasource '{datasource_name}'.") from exc
        existing = None
    except Exception as exc:  # pragma: no cover - defensive
        raise RuntimeError(f"Failed to inspect datasource '{datasource_name}'.") from exc

    if existing is not None:
        engine = getattr(existing, "execution_engine", None)
        configured = getattr(engine, "connection_string", None)
        if configured and configured != connection_string:
            logger.info(
                "Existing datasource %s uses different connection string; creating dedicated runtime datasource.",
                datasource_name,
            )
            try:
                context.delete_datasource(datasource_name)
            except Exception as exc:  # pragma: no cover - defensive
                logger.warning(
                    "Failed to delete datasource %s before recreation: %s",
                    datasource_name,
                    exc,
                )
            else:
                existing = None

    if existing is not None:
        return

    try:
        context.add_datasource(
            name=datasource_name,
            class_name="Datasource",
            execution_engine={
                "class_name": "SqlAlchemyExecutionEngine",
                "connection_string": connection_string,
            },
            data_connectors={
                "runtime_connector": {
                    "class_name": "RuntimeDataConnector",
                    "batch_identifiers": ["default_identifier_name"],
                }
            },
        )
    except Exception as exc:  # pragma: no cover - defensive
        raise RuntimeError(f"Failed to create runtime datasource '{datasource_name}'.") from exc
def _quote_sql_identifier(name: str) -> str:
    """Return ``name`` unchanged when it is a plain identifier, otherwise backtick-quote it."""
    if re.match(r"^[A-Za-z_][A-Za-z0-9_$]*$", name):
        return name
    return f"`{name.replace('`', '``')}`"


def _qualified_table_reference(table_name: str, schema_name: Optional[str]) -> str:
    """Build the table reference placed after ``FROM`` in the generated query."""
    if schema_name:
        # With an explicit schema only the last dotted segment of each name is used.
        schema_part = schema_name.split(".")[-1]
        table_part = table_name.split(".")[-1]
        return f"{_quote_sql_identifier(schema_part)}.{_quote_sql_identifier(table_part)}"
    if "`" in table_name:
        # Backticks in the name are escaped as a whole rather than split on dots.
        return _quote_sql_identifier(table_name)
    # A dotted name such as "schema.table" is a qualified reference: quote each
    # segment separately instead of quoting the whole string as one identifier.
    segments = [segment for segment in table_name.split(".") if segment] or [table_name]
    return ".".join(_quote_sql_identifier(segment) for segment in segments)


def _build_sql_runtime_batch_request(
    context: AbstractDataContext,
    request: TableProfilingJobRequest,
) -> RuntimeBatchRequest:
    """Create a runtime batch request for the SQL source described by ``table_link_info``.

    The connection string template is filled from ``table_access_info``. When no
    explicit ``query`` is given, a ``SELECT * FROM <table>`` query is generated,
    optionally followed by ``LIMIT``. The runtime datasource is created or
    refreshed before the batch request is returned.

    Raises:
        ValueError: The connection string is missing, the source type is not
            ``sql``, or neither a query nor a table name is supplied.
    """
    link_info = request.table_link_info or {}
    access_info = request.table_access_info or {}

    connection_template = link_info.get("connection_string")
    if not connection_template:
        raise ValueError("table_link_info.connection_string is required when using table_link_info.")
    connection_string = _format_connection_string(connection_template, access_info)

    source_type = (link_info.get("type") or "sql").lower()
    if source_type != "sql":
        raise ValueError(f"Unsupported table_link_info.type='{source_type}'. Only 'sql' is supported.")

    query = link_info.get("query")
    table_name = link_info.get("table") or link_info.get("table_name")
    schema_name = link_info.get("schema")
    if not query and not table_name:
        raise ValueError("Either table_link_info.query or table_link_info.table must be provided.")

    if not query:
        query = f"SELECT * FROM {_qualified_table_reference(table_name, schema_name)}"
        # NOTE(review): LIMIT is applied only to the generated query, not to a
        # caller-supplied one - confirm against the intended behaviour.
        limit = link_info.get("limit")
        # bool is a subclass of int; exclude it so True does not become "LIMIT True".
        if isinstance(limit, int) and not isinstance(limit, bool) and limit > 0:
            query = f"{query} LIMIT {limit}"

    datasource_name = request.ge_datasource_name or _sanitize_identifier(
        f"{request.table_id}_runtime_ds", "runtime_ds"
    )
    data_asset_name = request.ge_data_asset_name or _sanitize_identifier(
        table_name or "runtime_query", "runtime_query"
    )
    _ensure_sql_runtime_datasource(context, datasource_name, connection_string)

    batch_identifiers = {
        "default_identifier_name": f"{request.table_id}:{request.version_ts}",
    }
    return RuntimeBatchRequest(
        datasource_name=datasource_name,
        data_connector_name="runtime_connector",
        data_asset_name=data_asset_name,
        runtime_parameters={"query": query},
        batch_identifiers=batch_identifiers,
    )
def _run_onboarding_assistant(
    context: AbstractDataContext,
    batch_request: Any,
    suite_name: str,
) -> Tuple[ExpectationSuite, Any]:
    """Profile with the GE onboarding assistant and return ``(suite, validation_result)``.

    The generated suite is saved under ``suite_name``. When the assistant result
    exposes no validation result, the batch is validated against the new suite.
    """
    outcome = context.assistants.onboarding.run(batch_request=batch_request)
    suite = outcome.get_expectation_suite(expectation_suite_name=suite_name)
    context.save_expectation_suite(suite, expectation_suite_name=suite_name)

    getter = getattr(outcome, "get_validation_result", None)
    validation_result = getter() if callable(getter) else getattr(outcome, "validation_result", None)

    if validation_result is None:
        # Fallback: rerun validation using the freshly generated expectation suite.
        validation_result = context.get_validator(
            batch_request=batch_request,
            expectation_suite_name=suite_name,
        ).validate()
    return suite, validation_result
def _resolve_context(request: TableProfilingJobRequest) -> AbstractDataContext:
    """Open the GE data context.

    The root directory is the first non-empty value among the request's
    ``ge_data_context_root``, the ``GE_DATA_CONTEXT_ROOT`` environment variable,
    and the project root.
    """
    root_dir = (
        request.ge_data_context_root
        or os.environ.get("GE_DATA_CONTEXT_ROOT")
        or str(_project_root())
    )
    return gx.get_context(project_root_dir=root_dir)
def _build_batch_request(
    context: AbstractDataContext,
    request: TableProfilingJobRequest,
) -> Any:
    """Build the batch request to profile.

    Sources are tried in order: the explicit ``ge_batch_request`` payload, then
    ``table_link_info``, then the named datasource and data asset.

    Raises:
        ValueError: None of the three sources is fully specified.
    """
    explicit = request.ge_batch_request
    if explicit:
        from great_expectations.core.batch import BatchRequest

        return BatchRequest(**explicit)

    if request.table_link_info:
        return _build_sql_runtime_batch_request(context, request)

    datasource_name = request.ge_datasource_name
    asset_name = request.ge_data_asset_name
    if not (datasource_name and asset_name):
        raise ValueError(
            "ge_batch_request or (ge_datasource_name and ge_data_asset_name) must be provided."
        )
    asset = context.get_datasource(datasource_name).get_asset(asset_name)
    return asset.build_batch_request()
async def _run_ge_profiling(request: TableProfilingJobRequest) -> GEProfilingArtifacts:
    """Run Great Expectations profiling for the requested table in a worker thread.

    The inner ``_execute`` function resolves the data context and batch request,
    makes sure the expectation suite exists, builds a suite with the selected
    profiler, validates the batch, builds Data Docs, and packages the outputs.

    Returns:
        The sanitized suite, validation result and batch request as dicts, a
        condensed summary, and the path of the Data Docs index page.
    """
    def _execute() -> GEProfilingArtifacts:
        context = _resolve_context(request)
        # Default suite name is derived from the table id.
        suite_name = (
            request.ge_expectation_suite_name
            or f"{request.table_id}_profiling"
        )
        batch_request = _build_batch_request(context, request)
        # Create the suite when the context does not know it yet.
        try:
            context.get_expectation_suite(suite_name)
        except DataContextError:
            context.add_expectation_suite(suite_name)
        validator = context.get_validator(
            batch_request=batch_request,
            expectation_suite_name=suite_name,
        )
        profiler_type = (request.ge_profiler_type or "user_configurable").lower()
        if profiler_type == "data_assistant":
            suite, validation_result = _run_onboarding_assistant(
                context,
                batch_request,
                suite_name,
            )
        else:
            # Imported here so the failure is reported only when this profiler is selected.
            try:
                from great_expectations.profile.user_configurable_profiler import (
                    UserConfigurableProfiler,
                )
            except ImportError as err: # pragma: no cover - dependency guard
                raise RuntimeError(
                    "UserConfigurableProfiler is unavailable; install great_expectations profiling extra or switch profiler."
                ) from err
            profiler = UserConfigurableProfiler(profile_dataset=validator)
            try:
                suite = profiler.build_suite()
                context.save_expectation_suite(suite, expectation_suite_name=suite_name)
                validator.expectation_suite = suite
                validation_result = validator.validate()
            except MetricResolutionError as exc:
                # Metric resolution failures fall back to the onboarding assistant.
                logger.warning(
                    "UserConfigurableProfiler failed (%s); falling back to data assistant profiling.",
                    exc,
                )
                suite, validation_result = _run_onboarding_assistant(
                    context,
                    batch_request,
                    suite_name,
                )
        # Cap oversized value_set lists before the suite is summarised and returned.
        sanitized_suite = _sanitize_expectation_suite(suite)
        summary = _summarize_expectation_suite(sanitized_suite)
        validation_dict = validation_result.to_json_dict()
        context.build_data_docs()
        docs_path = Path(context.root_directory) / GE_REPORT_RELATIVE_PATH
        profiling_result = {
            "expectation_suite": sanitized_suite,
            "validation_result": validation_dict,
            # Use to_json_dict() when available and non-empty, otherwise dict(); None when neither exists.
            "batch_request": getattr(batch_request, "to_json_dict", lambda: None)() or getattr(batch_request, "dict", lambda: None)(),
        }
        return GEProfilingArtifacts(
            profiling_result=profiling_result,
            profiling_summary=summary,
            docs_path=str(docs_path),
        )
    # The GE calls are synchronous, so run them off the event loop.
    return await asyncio.to_thread(_execute)
async def _call_chat_completions(
    *,
    model_spec: str,
    system_prompt: str,
    user_prompt: str,
    client: httpx.AsyncClient,
    temperature: float = 0.2,
    timeout_seconds: Optional[float] = None,
) -> Any:
    """Send a system and user prompt to the gateway's chat-completions endpoint.

    Args:
        model_spec: Model spec resolved to a provider and model name.
        system_prompt: Content of the system message.
        user_prompt: Content of the user message.
        client: HTTP client used for the request.
        temperature: Sampling temperature sent in the payload.
        timeout_seconds: Request timeout; ``DEFAULT_CHAT_TIMEOUT_SECONDS`` when None.

    Returns:
        The JSON value decoded from the first choice's message content.

    Raises:
        ProviderAPICallError: The request fails, the response is not JSON, or
            the completion content cannot be parsed.
    """
    provider, model_name = resolve_provider_from_model(model_spec)
    payload = {
        "provider": provider.value,
        "model": model_name,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        "temperature": temperature,
    }
    payload_size_bytes = len(json.dumps(payload, ensure_ascii=False).encode("utf-8"))
    url = f"{IMPORT_GATEWAY_BASE_URL.rstrip('/')}/v1/chat/completions"
    # Passing timeout=None to httpx disables timeouts, so an omitted value
    # falls back to the module default instead of waiting without limit.
    effective_timeout = DEFAULT_CHAT_TIMEOUT_SECONDS if timeout_seconds is None else timeout_seconds
    try:
        # The complete request payload is logged on purpose.
        logger.info(
            "Calling chat completions API %s with model %s and size %s and payload %s",
            url,
            model_name,
            payload_size_bytes,
            payload,
        )
        response = await client.post(url, json=payload, timeout=effective_timeout)
        response.raise_for_status()
    except httpx.HTTPError as exc:
        error_name = exc.__class__.__name__
        detail = str(exc).strip()
        if detail:
            message = f"Chat completions request failed ({error_name}): {detail}"
        else:
            message = f"Chat completions request failed ({error_name})."
        raise ProviderAPICallError(message) from exc
    try:
        response_payload = response.json()
    except ValueError as exc:
        raise ProviderAPICallError("Chat completions response was not valid JSON.") from exc
    return _parse_completion_payload(response_payload)
def _normalize_for_json(value: Any) -> Any:
    """Recursively convert ``value`` into something ``json.dumps`` accepts.

    Primitives pass through; dates and datetimes become ``str(value)``; objects
    exposing ``model_dump()`` and dataclass instances are converted to dicts and
    normalised again; mappings and sequences are normalised element by element;
    ``to_json_dict()`` and ``__dict__`` are tried next; anything else becomes
    its ``repr``.
    """
    if value is None or isinstance(value, (str, int, float, bool)):
        return value
    if isinstance(value, (datetime, date)):
        return str(value)
    if hasattr(value, "model_dump"):
        try:
            dumped = value.model_dump()
        except Exception:  # pragma: no cover - defensive
            pass
        else:
            # The dumped mapping can still contain dates or other rich objects.
            return _normalize_for_json(dumped)
    # is_dataclass() is also true for dataclass types, which asdict() rejects.
    if is_dataclass(value) and not isinstance(value, type):
        return _normalize_for_json(asdict(value))
    if isinstance(value, dict):
        return {
            # json.dumps only accepts primitive keys; stringify anything else.
            (k if k is None or isinstance(k, (str, int, float, bool)) else str(k)): _normalize_for_json(v)
            for k, v in value.items()
        }
    if isinstance(value, (list, tuple, set)):
        return [_normalize_for_json(v) for v in value]
    if hasattr(value, "to_json_dict"):
        try:
            return value.to_json_dict()
        except Exception:  # pragma: no cover - defensive
            pass
    if hasattr(value, "__dict__"):
        return _normalize_for_json(value.__dict__)
    return repr(value)
def _json_dumps(data: Any) -> str:
    """Serialise ``data`` as indented, non-ASCII-escaped JSON after normalising it."""
    return json.dumps(_normalize_for_json(data), ensure_ascii=False, indent=2)
def _preview_for_log(data: Any) -> str:
    """Return the JSON rendering of ``data`` for logging, or its ``repr`` when that fails."""
    try:
        return _json_dumps(data)
    except Exception:
        return repr(data)
def _profiling_request_for_log(request: TableProfilingJobRequest) -> Dict[str, Any]:
    """Return a dump of the request that is safe to log.

    Every value in ``table_access_info`` is replaced by ``"***"``, and the
    password part of ``table_link_info.connection_string`` is masked, because
    connection URLs may embed credentials as ``user:password@host``.
    """
    payload = request.model_dump()

    access_info = payload.get("table_access_info")
    if isinstance(access_info, dict):
        payload["table_access_info"] = {key: "***" for key in access_info}

    link_info = payload.get("table_link_info")
    if isinstance(link_info, dict):
        connection_string = link_info.get("connection_string")
        if isinstance(connection_string, str):
            # Mask everything between "user:" and the last "@" before the first "/".
            masked = re.sub(r"(://[^:/@\s]*):[^/\s]*@", r"\1:***@", connection_string)
            payload["table_link_info"] = {**link_info, "connection_string": masked}
    return payload
async def _execute_result_desc(
    profiling_json: Dict[str, Any],
    _request: TableProfilingJobRequest,
    llm_model: str,
    client: httpx.AsyncClient,
    timeout_seconds: Optional[float],
) -> Dict[str, Any]:
    """Ask the LLM to describe the GE profiling result.

    Raises:
        ProviderAPICallError: The model's reply is not a JSON object.
    """
    prompts = _render_prompt(
        "ge_result_desc",
        {"{{GE_RESULT_JSON}}": _json_dumps(profiling_json)},
    )
    description = await _call_chat_completions(
        model_spec=llm_model,
        system_prompt=prompts[0],
        user_prompt=prompts[1],
        client=client,
        timeout_seconds=timeout_seconds,
    )
    if isinstance(description, dict):
        return description
    raise ProviderAPICallError("GE result description payload must be a JSON object.")
async def _execute_snippet_generation(
    table_desc_json: Dict[str, Any],
    _request: TableProfilingJobRequest,
    llm_model: str,
    client: httpx.AsyncClient,
    timeout_seconds: Optional[float],
) -> List[Dict[str, Any]]:
    """Ask the LLM to generate snippets from the table description.

    Raises:
        ProviderAPICallError: The model's reply is not a JSON array.
    """
    prompts = _render_prompt(
        "snippet_generator",
        {"{{TABLE_PROFILE_JSON}}": _json_dumps(table_desc_json)},
    )
    snippets = await _call_chat_completions(
        model_spec=llm_model,
        system_prompt=prompts[0],
        user_prompt=prompts[1],
        client=client,
        timeout_seconds=timeout_seconds,
    )
    if isinstance(snippets, list):
        return snippets
    raise ProviderAPICallError("Snippet generator must return a JSON array.")
async def _execute_snippet_alias(
    snippets_json: List[Dict[str, Any]],
    _request: TableProfilingJobRequest,
    llm_model: str,
    client: httpx.AsyncClient,
    timeout_seconds: Optional[float],
) -> List[Dict[str, Any]]:
    """Ask the LLM to enrich the generated snippets with aliases.

    Raises:
        ProviderAPICallError: The model's reply is not a JSON array.
    """
    prompts = _render_prompt(
        "snippet_alias",
        {"{{SNIPPET_ARRAY}}": _json_dumps(snippets_json)},
    )
    aliases = await _call_chat_completions(
        model_spec=llm_model,
        system_prompt=prompts[0],
        user_prompt=prompts[1],
        client=client,
        timeout_seconds=timeout_seconds,
    )
    if isinstance(aliases, list):
        return aliases
    raise ProviderAPICallError("Snippet alias generator must return a JSON array.")
async def _run_action_with_callback(
    *,
    action_type: str,
    runner,
    callback_base: Dict[str, Any],
    client: httpx.AsyncClient,
    callback_url: str,
    input_payload: Any = None,
    model_spec: Optional[str] = None,
) -> Any:
    """Await ``runner()`` and report its outcome to ``callback_url``.

    On failure a ``failed`` callback carrying the error text is posted and the
    exception is re-raised. On success a ``success`` callback carrying the
    action's result under its action-specific key(s) is posted and the result
    is returned.
    """
    if input_payload is not None:
        logger.info(
            "Pipeline action %s input: %s",
            action_type,
            _preview_for_log(input_payload),
        )

    def _envelope(status: str, **extra: Any) -> Dict[str, Any]:
        # Shared callback shape: base fields, status, action type, extras, then model.
        body = dict(callback_base)
        body.update({"status": status, "action_type": action_type, **extra})
        if model_spec is not None:
            body["model"] = model_spec
        return body

    try:
        result = await runner()
    except Exception as exc:
        await _post_callback(callback_url, _envelope("failed", error=str(exc)), client)
        raise

    success_payload = _envelope("success")
    logger.info(
        "Pipeline action %s output: %s",
        action_type,
        _preview_for_log(result),
    )

    single_result_keys = {
        PipelineActionType.GE_RESULT_DESC: "table_desc_json",
        PipelineActionType.SNIPPET: "snippet_json",
        PipelineActionType.SNIPPET_ALIAS: "snippet_alias_json",
    }
    if action_type == PipelineActionType.GE_PROFILING:
        artifacts: GEProfilingArtifacts = result
        success_payload["profiling_json"] = artifacts.profiling_result
        success_payload["profiling_summary"] = artifacts.profiling_summary
        success_payload["ge_report_path"] = artifacts.docs_path
    elif action_type in single_result_keys:
        success_payload[single_result_keys[action_type]] = result

    await _post_callback(callback_url, success_payload, client)
    return result
async def process_table_profiling_job(
    request: TableProfilingJobRequest,
    _gateway: LLMGateway,
    client: httpx.AsyncClient,
) -> None:
    """Sequentially execute the four-step profiling pipeline and emit callbacks per action.

    Steps: GE profiling, LLM result description, snippet generation, snippet
    alias enrichment. Each LLM step uses its per-action model override when
    given, then ``llm_model``, then ``DEFAULT_IMPORT_MODEL``. Any exception ends
    the pipeline; it is logged here and not propagated.
    """
    timeout_seconds = _extract_timeout_seconds(request.extra_options)
    if timeout_seconds is None:
        timeout_seconds = DEFAULT_CHAT_TIMEOUT_SECONDS

    # Resolve the model for each LLM step so a missing llm_model never reaches
    # the chat-completions call as None and the documented overrides are honoured.
    # NOTE(review): assumes DEFAULT_IMPORT_MODEL is a model spec accepted by
    # resolve_provider_from_model - confirm against app.settings.
    default_model = request.llm_model or DEFAULT_IMPORT_MODEL
    result_desc_model = request.result_desc_model or default_model
    snippet_model = request.snippet_model or default_model
    snippet_alias_model = request.snippet_alias_model or default_model

    callback_url = str(request.callback_url)
    base_payload = {
        "table_id": request.table_id,
        "version_ts": request.version_ts,
        "callback_url": callback_url,
        "table_schema": request.table_schema,
        "table_schema_version_id": request.table_schema_version_id,
        "llm_model": request.llm_model,
        "llm_timeout_seconds": timeout_seconds,
    }
    logging_request_payload = _profiling_request_for_log(request)
    try:
        artifacts: GEProfilingArtifacts = await _run_action_with_callback(
            action_type=PipelineActionType.GE_PROFILING,
            runner=lambda: _run_ge_profiling(request),
            callback_base=base_payload,
            client=client,
            callback_url=callback_url,
            input_payload=logging_request_payload,
            model_spec=request.llm_model,
        )
        table_desc_json: Dict[str, Any] = await _run_action_with_callback(
            action_type=PipelineActionType.GE_RESULT_DESC,
            runner=lambda: _execute_result_desc(
                artifacts.profiling_result,
                request,
                result_desc_model,
                client,
                timeout_seconds,
            ),
            callback_base=base_payload,
            client=client,
            callback_url=callback_url,
            input_payload=artifacts.profiling_result,
            model_spec=result_desc_model,
        )
        snippet_json: List[Dict[str, Any]] = await _run_action_with_callback(
            action_type=PipelineActionType.SNIPPET,
            runner=lambda: _execute_snippet_generation(
                table_desc_json,
                request,
                snippet_model,
                client,
                timeout_seconds,
            ),
            callback_base=base_payload,
            client=client,
            callback_url=callback_url,
            input_payload=table_desc_json,
            model_spec=snippet_model,
        )
        await _run_action_with_callback(
            action_type=PipelineActionType.SNIPPET_ALIAS,
            runner=lambda: _execute_snippet_alias(
                snippet_json,
                request,
                snippet_alias_model,
                client,
                timeout_seconds,
            ),
            callback_base=base_payload,
            client=client,
            callback_url=callback_url,
            input_payload=snippet_json,
            model_spec=snippet_alias_model,
        )
    except Exception:  # pragma: no cover - defensive catch
        # The failing action has already posted its "failed" callback.
        logger.exception(
            "Table profiling pipeline failed for table_id=%s version_ts=%s",
            request.table_id,
            request.version_ts,
        )

View File

@ -0,0 +1,184 @@
from __future__ import annotations
import json
import logging
from typing import Any, Dict, Tuple
from sqlalchemy import text
from sqlalchemy.engine import Engine
from sqlalchemy.exc import SQLAlchemyError
from app.db import get_engine
from app.models import (
ActionType,
TableSnippetUpsertRequest,
TableSnippetUpsertResponse,
)
# Module-level logger used by every helper in this module.
logger = logging.getLogger(__name__)
def _serialize_json(value: Any) -> Tuple[str | None, int | None]:
    """Return ``value`` as JSON text together with its UTF-8 byte length.

    Args:
        value: ``None``, an already-serialized string, or any object that
            ``json.dumps`` accepts.

    Returns:
        ``(None, None)`` when ``value`` is ``None``; otherwise a tuple of the
        text and the number of bytes it occupies once encoded as UTF-8.
        Strings are passed through unchanged and are not checked for being
        valid JSON.
    """
    logger.debug("Serializing JSON payload: %s", value)
    if value is None:
        return None, None

    # ensure_ascii=False keeps non-ASCII characters as-is, so the byte length
    # must be measured on the UTF-8 encoding rather than on len(text).
    text_form = value if isinstance(value, str) else json.dumps(value, ensure_ascii=False)
    return text_form, len(text_form.encode("utf-8"))
def _prepare_table_schema(value: Any) -> str:
    """Return the table schema as text suitable for the ``table_schema`` column.

    A string is returned unchanged; any other value is serialized with
    ``json.dumps`` (non-ASCII characters are kept verbatim).
    """
    logger.debug("Preparing table_schema payload.")
    return value if isinstance(value, str) else json.dumps(value, ensure_ascii=False)
def _collect_common_columns(request: TableSnippetUpsertRequest) -> Dict[str, Any]:
    """Build the column values that every action type shares.

    The identifying and mandatory columns are always present. Optional fields
    (error details, timing, checksum) are added only when the request carries
    a non-``None`` value for them.

    NOTE(review): because ``None`` fields are left out, they are also absent
    from the ``ON DUPLICATE KEY UPDATE`` list built by
    ``_build_insert_statement``, so a previously stored value (for example an
    old ``error_message``) is kept on a later upsert -- confirm this is
    intended.

    Returns:
        A new dict mapping column name to value, in insertion order.
    """
    logger.debug(
        "Collecting common columns for table_id=%s version_ts=%s action_type=%s",
        request.table_id,
        request.version_ts,
        request.action_type,
    )

    columns: Dict[str, Any] = {
        "table_id": request.table_id,
        "version_ts": request.version_ts,
        "action_type": request.action_type.value,
        "status": request.status.value,
        "callback_url": str(request.callback_url),
        "table_schema_version_id": request.table_schema_version_id,
        "table_schema": _prepare_table_schema(request.table_schema),
    }

    # Error details are logged individually as they are picked up.
    error_code = request.error_code
    if error_code is not None:
        logger.debug("Adding error_code: %s", error_code)
        columns["error_code"] = error_code

    error_message = request.error_message
    if error_message is not None:
        logger.debug("Adding error_message: %s", error_message)
        columns["error_message"] = error_message

    # Remaining optional fields map one-to-one onto same-named columns.
    for field_name in ("started_at", "finished_at", "duration_ms", "result_checksum"):
        field_value = getattr(request, field_name)
        if field_value is not None:
            columns[field_name] = field_value

    logger.debug("Collected common payload: %s", columns)
    return columns
def _apply_action_payload(
    request: TableSnippetUpsertRequest,
    payload: Dict[str, Any],
) -> None:
    """Add the result columns specific to ``request.action_type`` to ``payload``.

    ``payload`` is modified in place. GE profiling stores a full document, a
    summary document, their combined size and an optional HTML report URL;
    the other three actions each store one document and its size. A column is
    written only when its source value is present.

    Raises:
        ValueError: If the action type is not one of the four supported ones.
    """
    logger.debug("Applying action-specific payload for action_type=%s", request.action_type)
    action = request.action_type

    # Actions that persist a single JSON document: (document column, size column).
    single_document_columns = {
        ActionType.GE_RESULT_DESC: (
            "ge_result_desc_full",
            "ge_result_desc_full_size_bytes",
        ),
        ActionType.SNIPPET: (
            "snippet_full",
            "snippet_full_size_bytes",
        ),
        ActionType.SNIPPET_ALIAS: (
            "snippet_alias_full",
            "snippet_alias_full_size_bytes",
        ),
    }

    if action == ActionType.GE_PROFILING:
        # Serialize both documents up front so nothing is written to payload
        # if either serialization fails.
        full_text, full_bytes = _serialize_json(request.result_json)
        summary_text, summary_bytes = _serialize_json(request.result_summary_json)

        if full_text is not None:
            payload["ge_profiling_full"] = full_text
            payload["ge_profiling_full_size_bytes"] = full_bytes
        if summary_text is not None:
            payload["ge_profiling_summary"] = summary_text
            payload["ge_profiling_summary_size_bytes"] = summary_bytes

        # The total is recorded as soon as at least one document exists.
        known_sizes = [size for size in (full_bytes, summary_bytes) if size is not None]
        if known_sizes:
            payload["ge_profiling_total_size_bytes"] = sum(known_sizes)

        report_url = request.html_report_url
        if report_url:
            payload["ge_profiling_html_report_url"] = report_url
    elif action in single_document_columns:
        document_column, size_column = single_document_columns[action]
        document_text, document_bytes = _serialize_json(request.result_json)
        if document_text is not None:
            payload[document_column] = document_text
            payload[size_column] = document_bytes
    else:
        logger.error("Unsupported action type encountered: %s", request.action_type)
        raise ValueError(f"Unsupported action type '{request.action_type}'.")

    logger.debug("Payload after applying action-specific data: %s", payload)
def _build_insert_statement(columns: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]:
    """Build the MySQL upsert statement for ``action_results``.

    Args:
        columns: Column name to value mapping; the names become both the
            insert column list and the named bind parameters.

    Returns:
        ``(sql, params)`` where ``sql`` is an ``INSERT ... ON DUPLICATE KEY
        UPDATE`` statement and ``params`` is the ``columns`` dict itself.
    """
    names = list(columns.keys())
    logger.debug("Building insert statement for columns: %s", names)

    # The unique-key columns identify the row and are never reassigned.
    key_columns = {"table_id", "version_ts", "action_type"}
    assignments = [f"{name}=VALUES({name})" for name in names if name not in key_columns]
    assignments.append("updated_at=CURRENT_TIMESTAMP")

    column_list = ", ".join(names)
    bind_list = ", ".join(f":{name}" for name in names)
    update_list = ", ".join(assignments)

    template = (
        "INSERT INTO action_results ({cols}) VALUES ({vals}) "
        "ON DUPLICATE KEY UPDATE {updates}"
    )
    sql = template.format(cols=column_list, vals=bind_list, updates=update_list)

    logger.debug("Generated SQL: %s", sql)
    return sql, columns
def _execute_upsert(engine: Engine, sql: str, params: Dict[str, Any]) -> int:
    """Run the upsert inside a transaction and return the reported row count.

    Args:
        engine: SQLAlchemy engine to obtain the connection from.
        sql: Statement text with named bind parameters.
        params: Values for the bind parameters.

    Returns:
        The ``rowcount`` reported for the executed statement.
    """
    logger.info(
        "Executing upsert for table_id=%s version_ts=%s action_type=%s",
        params.get("table_id"),
        params.get("version_ts"),
        params.get("action_type"),
    )
    # engine.begin() scopes the statement to a transaction for this block.
    with engine.begin() as connection:
        affected = connection.execute(text(sql), params).rowcount
        logger.info("Rows affected: %s", affected)
        return affected
def upsert_action_result(request: TableSnippetUpsertRequest) -> TableSnippetUpsertResponse:
    """Insert or update the ``action_results`` row described by ``request``.

    The column set is assembled from the request, written with a single
    upsert statement, and the request's identifying fields are echoed back
    together with an ``updated`` flag.

    Args:
        request: The action result to persist.

    Returns:
        A response carrying ``table_id``, ``version_ts``, ``action_type`` and
        ``status`` from the request, plus ``updated``.

    Raises:
        ValueError: If the action type is not supported (raised while building
            the action-specific columns).
        RuntimeError: If the database call fails; the original
            ``SQLAlchemyError`` is attached as the cause.
    """
    logger.info(
        "Received upsert request: table_id=%s version_ts=%s action_type=%s status=%s",
        request.table_id,
        request.version_ts,
        request.action_type,
        request.status,
    )
    # model_dump() copies the whole request (result_json included), so only
    # pay for it when DEBUG output will actually be emitted.
    if logger.isEnabledFor(logging.DEBUG):
        logger.debug("Request payload: %s", request.model_dump())

    columns = _collect_common_columns(request)
    _apply_action_payload(request, columns)
    sql, params = _build_insert_statement(columns)
    logger.debug("Final SQL params: %s", params)

    engine = get_engine()
    try:
        rowcount = _execute_upsert(engine, sql, params)
    except SQLAlchemyError as exc:
        logger.exception(
            "Failed to upsert action result: table_id=%s version_ts=%s action_type=%s",
            request.table_id,
            request.version_ts,
            request.action_type,
        )
        raise RuntimeError(f"Database operation failed: {exc}") from exc

    # MySQL reports 1 affected row for a fresh insert and 2 when ON DUPLICATE
    # KEY UPDATE modified an existing row.
    # NOTE(review): an existing row that ends up unchanged reports 0 or 1 and
    # is therefore returned as updated=False -- confirm that is acceptable.
    updated = rowcount > 1
    return TableSnippetUpsertResponse(
        table_id=request.table_id,
        version_ts=request.version_ts,
        action_type=request.action_type,
        status=request.status,
        updated=updated,
    )

54
table_snippet.sql Normal file
View File

@ -0,0 +1,54 @@
-- Stores the outcome of each pipeline action, one row per
-- (table_id, version_ts, action_type) as enforced by uq_table_ver_action.
-- Each action type has its own group of result columns below.
CREATE TABLE IF NOT EXISTS action_results (
id BIGINT NOT NULL AUTO_INCREMENT COMMENT '主键',
table_id BIGINT NOT NULL COMMENT '表ID',
version_ts BIGINT NOT NULL COMMENT '版本时间戳(版本号)',
action_type ENUM('ge_profiling','ge_result_desc','snippet','snippet_alias') NOT NULL COMMENT '动作类型',
status ENUM('pending','running','success','failed','partial') NOT NULL DEFAULT 'pending' COMMENT '执行状态',
error_code VARCHAR(128) NULL,
error_message TEXT NULL,
-- Callback & observability
callback_url VARCHAR(1024) NOT NULL,
started_at DATETIME NULL,
finished_at DATETIME NULL,
duration_ms INT NULL,
-- Schema information for this run
table_schema_version_id BIGINT NOT NULL,
table_schema JSON NOT NULL,
-- ===== Action 1: GE Profiling =====
ge_profiling_full JSON NULL COMMENT 'Profiling完整结果JSON',
ge_profiling_full_size_bytes BIGINT NULL,
ge_profiling_summary JSON NULL COMMENT 'Profiling摘要剔除大value_set等',
ge_profiling_summary_size_bytes BIGINT NULL,
ge_profiling_total_size_bytes BIGINT NULL COMMENT '上两者合计',
ge_profiling_html_report_url VARCHAR(1024) NULL COMMENT 'GE报告HTML路径/URL',
-- ===== Action 2: GE Result Desc =====
ge_result_desc_full JSON NULL COMMENT '表描述结果JSON',
ge_result_desc_full_size_bytes BIGINT NULL,
-- ===== Action 3: Snippet generation =====
snippet_full JSON NULL COMMENT 'SQL知识片段结果JSON',
snippet_full_size_bytes BIGINT NULL,
-- ===== Action 4: Snippet Alias rewrite =====
snippet_alias_full JSON NULL COMMENT 'SQL片段改写/丰富结果JSON',
snippet_alias_full_size_bytes BIGINT NULL,
-- Common optional metrics
result_checksum VARBINARY(32) NULL COMMENT '对当前action有效载荷计算的MD5/xxhash',
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
PRIMARY KEY (id),
-- Unique key that the application's ON DUPLICATE KEY UPDATE upsert relies on.
UNIQUE KEY uq_table_ver_action (table_id, version_ts, action_type),
KEY idx_status (status),
KEY idx_table (table_id, updated_at),
KEY idx_action_time (action_type, version_ts),
KEY idx_schema_version (table_schema_version_id)
) ENGINE=InnoDB
ROW_FORMAT=DYNAMIC
COMMENT='数据分析知识片段表';