From c2a08e4134c5e72a19820c3c3bb05a15da73e670 Mon Sep 17 00:00:00 2001
From: zhaoawd
Date: Mon, 3 Nov 2025 00:18:26 +0800
Subject: [PATCH] table profiling feature development
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 app/db.py                       |  26 +
 app/main.py                     | 113 ++++-
 app/models.py                   |  87 ++++
 app/services/table_profiling.py | 832 ++++++++++++++++++++++++++++++++
 app/services/table_snippet.py   | 184 +++++++
 table_snippet.sql               |  54 +++
 6 files changed, 1280 insertions(+), 16 deletions(-)
 create mode 100644 app/db.py
 create mode 100644 app/services/table_profiling.py
 create mode 100644 app/services/table_snippet.py
 create mode 100644 table_snippet.sql

diff --git a/app/db.py b/app/db.py
new file mode 100644
index 0000000..af9739f
--- /dev/null
+++ b/app/db.py
@@ -0,0 +1,26 @@
+from __future__ import annotations
+
+import os
+from functools import lru_cache
+
+from sqlalchemy import create_engine
+from sqlalchemy.engine import Engine
+
+
+@lru_cache(maxsize=1)
+def get_engine() -> Engine:
+    """Return a cached SQLAlchemy engine configured from DATABASE_URL."""
+    database_url = os.getenv(
+        "DATABASE_URL",
+        "mysql+pymysql://root:12345678@localhost:3306/data-ge?charset=utf8mb4",
+    )
+    connect_args = {}
+    if database_url.startswith("sqlite"):
+        connect_args["check_same_thread"] = False
+
+    return create_engine(
+        database_url,
+        pool_pre_ping=True,
+        future=True,
+        connect_args=connect_args,
+    )
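+
+# Usage sketch (illustrative, not part of the service code): callers share the
+# cached engine, so the connection pool is created once per process.
+#
+#   from sqlalchemy import text
+#   from app.db import get_engine
+#
+#   with get_engine().connect() as conn:
+#       conn.execute(text("SELECT 1"))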
diff --git a/app/main.py b/app/main.py
index b7c0c0a..7b7d55b 100644
--- a/app/main.py
+++ b/app/main.py
@@ -2,12 +2,17 @@ from __future__ import annotations
 
 import asyncio
 import logging
+import logging.config
 import os
 from contextlib import asynccontextmanager
 from typing import Any
 
+import yaml
+
 import httpx
 from fastapi import Depends, FastAPI, HTTPException, Request
+from fastapi.exceptions import RequestValidationError
+from fastapi.responses import JSONResponse
 
 from app.exceptions import ProviderAPICallError, ProviderConfigurationError
 from app.models import (
@@ -15,30 +20,42 @@
     DataImportAnalysisJobRequest,
     LLMRequest,
     LLMResponse,
+    TableProfilingJobAck,
+    TableProfilingJobRequest,
+    TableSnippetUpsertRequest,
+    TableSnippetUpsertResponse,
 )
 from app.services import LLMGateway
 from app.services.import_analysis import process_import_analysis_job
+from app.services.table_profiling import process_table_profiling_job
+from app.services.table_snippet import upsert_action_result
+
+
+def _ensure_log_directories(config: dict[str, Any]) -> None:
+    handlers = config.get("handlers", {})
+    for handler_config in handlers.values():
+        filename = handler_config.get("filename")
+        if not filename:
+            continue
+        directory = os.path.dirname(filename)
+        if directory and not os.path.exists(directory):
+            os.makedirs(directory, exist_ok=True)
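+
+# A minimal logging.yaml accepted by this loader might look like the sketch
+# below (illustrative; any dictConfig-compatible mapping works, and the file
+# handler's directory is created by _ensure_log_directories):
+#
+#   version: 1
+#   formatters:
+#     default:
+#       format: "%(asctime)s %(levelname)s %(name)s:%(lineno)d %(message)s"
+#   handlers:
+#     file:
+#       class: logging.FileHandler
+#       filename: logs/app.log
+#       formatter: default
+#   root:
+#     level: INFO
+#     handlers: [file]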
 
 
 def _configure_logging() -> None:
-    level_name = os.getenv("LOG_LEVEL", "INFO").upper()
-    level = getattr(logging, level_name, logging.INFO)
-    log_format = os.getenv(
-        "LOG_FORMAT",
-        "%(asctime)s %(levelname)s %(name)s:%(lineno)d %(message)s",
+    config_path = os.getenv("LOGGING_CONFIG", "logging.yaml")
+    if os.path.exists(config_path):
+        with open(config_path, "r", encoding="utf-8") as fh:
+            config = yaml.safe_load(fh)
+        if isinstance(config, dict):
+            _ensure_log_directories(config)
+            logging.config.dictConfig(config)
+            return
+    logging.basicConfig(
+        level=logging.INFO,
+        format="%(asctime)s %(levelname)s %(name)s:%(lineno)d %(message)s",
     )
-    root = logging.getLogger()
-
-    if not root.handlers:
-        logging.basicConfig(level=level, format=log_format)
-    else:
-        root.setLevel(level)
-        formatter = logging.Formatter(log_format)
-        for handler in root.handlers:
-            handler.setLevel(level)
-            handler.setFormatter(formatter)
-
 
 _configure_logging()
 
 logger = logging.getLogger(__name__)
@@ -119,6 +136,24 @@ def create_app() -> FastAPI:
         lifespan=lifespan,
     )
 
+    @application.exception_handler(RequestValidationError)
+    async def request_validation_exception_handler(
+        request: Request, exc: RequestValidationError
+    ) -> JSONResponse:
+        try:
+            raw_body = await request.body()
+        except Exception:  # pragma: no cover - defensive
+            raw_body = b""
+        truncated_body = raw_body[:4096]
+        logger.warning(
+            "Validation error on %s %s: %s | body preview=%s",
+            request.method,
+            request.url.path,
+            exc.errors(),
+            truncated_body.decode("utf-8", errors="ignore"),
+        )
+        return JSONResponse(status_code=422, content={"detail": exc.errors()})
+
     @application.post(
         "/v1/chat/completions",
         response_model=LLMResponse,
@@ -164,6 +199,52 @@
 
         return DataImportAnalysisJobAck(import_record_id=payload.import_record_id, status="accepted")
 
+    @application.post(
+        "/v1/table/profiling",
+        response_model=TableProfilingJobAck,
+        summary="Run end-to-end GE profiling pipeline and notify via callback per action",
+        status_code=202,
+    )
+    async def run_table_profiling(
+        payload: TableProfilingJobRequest,
+        gateway: LLMGateway = Depends(get_gateway),
+        client: httpx.AsyncClient = Depends(get_http_client),
+    ) -> TableProfilingJobAck:
+        request_copy = payload.model_copy(deep=True)
+
+        async def _runner() -> None:
+            await process_table_profiling_job(request_copy, gateway, client)
+
+        asyncio.create_task(_runner())
+
+        return TableProfilingJobAck(
+            table_id=payload.table_id,
+            version_ts=payload.version_ts,
+            status="accepted",
+        )
+
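+    # Example invocation (illustrative; host, port, and callback URL are
+    # placeholders, and the pipeline runs in the background after the 202 ack):
+    #
+    #   curl -X POST http://localhost:8000/v1/table/profiling \
+    #     -H 'Content-Type: application/json' \
+    #     -d '{"table_id": "1001", "version_ts": "20251103001826",
+    #          "callback_url": "http://localhost:8000/__mock__/import-callback"}'
+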
+    @application.post(
+        "/v1/table/snippet",
+        response_model=TableSnippetUpsertResponse,
+        summary="Persist or update action results, such as table snippets.",
+    )
+    async def upsert_table_snippet(
+        payload: TableSnippetUpsertRequest,
+    ) -> TableSnippetUpsertResponse:
+        request_copy = payload.model_copy(deep=True)
+
+        try:
+            return await asyncio.to_thread(upsert_action_result, request_copy)
+        except Exception as exc:
+            logger.error(
+                "Failed to upsert table snippet: table_id=%s version_ts=%s action_type=%s",
+                payload.table_id,
+                payload.version_ts,
+                payload.action_type,
+                exc_info=True,
+            )
+            raise HTTPException(status_code=500, detail=str(exc)) from exc
+
     @application.post("/__mock__/import-callback")
     async def mock_import_callback(payload: dict[str, Any]) -> dict[str, str]:
         logger.info("Received import analysis callback: %s", payload)
diff --git a/app/models.py b/app/models.py
index 0aadb81..fbcf096 100644
--- a/app/models.py
+++ b/app/models.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+from datetime import datetime
 from enum import Enum
 from typing import Any, Dict, List, Optional, Union
 
@@ -135,3 +136,89 @@
 class DataImportAnalysisJobAck(BaseModel):
     import_record_id: str = Field(..., description="Echo of the import record identifier")
     status: str = Field("accepted", description="Processing status acknowledgement.")
+
+
+class TableProfilingJobRequest(BaseModel):
+    table_id: str = Field(..., description="Unique identifier for the table to profile.")
+    version_ts: str = Field(
+        ...,
+        pattern=r"^\d{14}$",
+        description="Version timestamp expressed as fourteen digit string (yyyyMMddHHmmss).",
+    )
+    callback_url: HttpUrl = Field(
+        ...,
+        description="Callback endpoint invoked after each pipeline action completes.",
+    )
+    table_schema: Optional[Any] = Field(
+        None,
+        description="Schema structure snapshot for the current table version.",
+    )
+    table_schema_version_id: Optional[str] = Field(
+        None,
+        description="Identifier for the schema snapshot provided in table_schema.",
+    )
+    table_link_info: Optional[Dict[str, Any]] = Field(
+        None,
+        description=(
+            "Information describing how to locate the source table for profiling. "
+            "For example: {'type': 'sql', 'connection_string': 'mysql+pymysql://user:pass@host/db', "
+            "'table': 'schema.table_name'}."
+        ),
+    )
+    table_access_info: Optional[Dict[str, Any]] = Field(
+        None,
+        description=(
+            "Credentials or supplemental parameters required to access the table described in table_link_info. "
+            "These values can be merged into the connection string using Python format placeholders."
+        ),
+    )
+    ge_batch_request: Optional[Dict[str, Any]] = Field(
+        None,
+        description="Optional Great Expectations batch request payload used for profiling.",
+    )
+    ge_expectation_suite_name: Optional[str] = Field(
+        None,
+        description="Expectation suite name used during profiling. Created automatically when absent.",
+    )
+    ge_data_context_root: Optional[str] = Field(
+        None,
+        description="Custom root directory for the Great Expectations data context. Defaults to project ./gx.",
+    )
+    ge_datasource_name: Optional[str] = Field(
+        None,
+        description="Datasource name registered inside the GE context when batch_request is not supplied.",
+    )
+    ge_data_asset_name: Optional[str] = Field(
+        None,
+        description="Data asset reference used when inferring batch request from datasource configuration.",
+    )
+    ge_profiler_type: str = Field(
+        "user_configurable",
+        description="Profiler implementation identifier. Currently supports 'user_configurable' or 'data_assistant'.",
+    )
+    llm_model: Optional[str] = Field(
+        None,
+        description="Default LLM model spec applied to prompt-based actions when overrides are omitted.",
+    )
+    result_desc_model: Optional[str] = Field(
+        None,
+        description="LLM model override used for GE result description (action 2).",
+    )
+    snippet_model: Optional[str] = Field(
+        None,
+        description="LLM model override used for snippet generation (action 3).",
+    )
+    snippet_alias_model: Optional[str] = Field(
+        None,
+        description="LLM model override used for snippet alias enrichment (action 4).",
+    )
+    extra_options: Optional[Dict[str, Any]] = Field(
+        None,
+        description="Miscellaneous execution flags applied across pipeline steps.",
+    )
+
+
+class TableProfilingJobAck(BaseModel):
+    table_id: str = Field(..., description="Echo of the table identifier.")
+    version_ts: str = Field(..., description="Echo of the profiling version timestamp (yyyyMMddHHmmss).")
+    status: str = Field("accepted", description="Processing acknowledgement status.")
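+
+
+# NOTE: sketch of the snippet upsert models that app/main.py and
+# app/services/table_snippet.py import from this module but which the patch
+# does not show; names and field shapes are assumed from their usage there
+# and from the columns in table_snippet.sql.
+class ActionType(str, Enum):
+    GE_PROFILING = "ge_profiling"
+    GE_RESULT_DESC = "ge_result_desc"
+    SNIPPET = "snippet"
+    SNIPPET_ALIAS = "snippet_alias"
+
+
+class ActionStatus(str, Enum):
+    PENDING = "pending"
+    RUNNING = "running"
+    SUCCESS = "success"
+    FAILED = "failed"
+    PARTIAL = "partial"
+
+
+class TableSnippetUpsertRequest(BaseModel):
+    table_id: str = Field(..., description="Identifier of the profiled table.")
+    version_ts: str = Field(..., description="Version timestamp (yyyyMMddHHmmss).")
+    action_type: ActionType = Field(..., description="Pipeline action that produced this result.")
+    status: ActionStatus = Field(..., description="Execution status reported for the action.")
+    callback_url: HttpUrl = Field(..., description="Callback endpoint associated with the job.")
+    table_schema_version_id: str = Field(..., description="Identifier of the schema snapshot.")
+    table_schema: Any = Field(..., description="Schema structure snapshot for this table version.")
+    result_json: Optional[Any] = Field(None, description="Full result payload for the action.")
+    result_summary_json: Optional[Any] = Field(None, description="Summary payload (GE profiling only).")
+    html_report_url: Optional[str] = Field(None, description="GE HTML report path or URL.")
+    error_code: Optional[str] = Field(None, description="Error code when the action failed.")
+    error_message: Optional[str] = Field(None, description="Error message when the action failed.")
+    started_at: Optional[datetime] = Field(None, description="Action start time.")
+    finished_at: Optional[datetime] = Field(None, description="Action finish time.")
+    duration_ms: Optional[int] = Field(None, description="Action duration in milliseconds.")
+    result_checksum: Optional[str] = Field(None, description="Checksum of the action payload.")
+
+
+class TableSnippetUpsertResponse(BaseModel):
+    table_id: str = Field(..., description="Echo of the table identifier.")
+    version_ts: str = Field(..., description="Echo of the version timestamp.")
+    action_type: ActionType = Field(..., description="Echo of the action type.")
+    status: ActionStatus = Field(..., description="Echo of the execution status.")
+    updated: bool = Field(..., description="True when an existing row was updated rather than inserted.")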
diff --git a/app/services/table_profiling.py b/app/services/table_profiling.py
new file mode 100644
index 0000000..ea6cc04
--- /dev/null
+++ b/app/services/table_profiling.py
@@ -0,0 +1,832 @@
+from __future__ import annotations
+
+import asyncio
+import json
+import logging
+import os
+import re
+from datetime import date, datetime
+from dataclasses import asdict, dataclass, is_dataclass
+from functools import lru_cache
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Tuple
+
+import httpx
+import great_expectations as gx
+from great_expectations.core.batch import RuntimeBatchRequest
+from great_expectations.core.expectation_suite import ExpectationSuite
+from great_expectations.data_context import AbstractDataContext
+from great_expectations.exceptions import DataContextError, MetricResolutionError
+
+from app.exceptions import ProviderAPICallError
+from app.models import TableProfilingJobRequest
+from app.services import LLMGateway
+from app.settings import DEFAULT_IMPORT_MODEL
+from app.services.import_analysis import (
+    IMPORT_GATEWAY_BASE_URL,
+    resolve_provider_from_model,
+)
+
+
+logger = logging.getLogger(__name__)
+
+
+GE_REPORT_RELATIVE_PATH = Path("uncommitted") / "data_docs" / "local_site" / "index.html"
+PROMPT_FILENAMES = {
+    "ge_result_desc": "ge_result_desc_prompt.md",
+    "snippet_generator": "snippet_generator.md",
+    "snippet_alias": "snippet_alias_generator.md",
+}
+DEFAULT_CHAT_TIMEOUT_SECONDS = 90.0
+
+
+@dataclass
+class GEProfilingArtifacts:
+    profiling_result: Dict[str, Any]
+    profiling_summary: Dict[str, Any]
+    docs_path: str
+
+
+class PipelineActionType:
+    GE_PROFILING = "ge_profiling"
+    GE_RESULT_DESC = "ge_result_desc"
+    SNIPPET = "snippet"
+    SNIPPET_ALIAS = "snippet_alias"
+
+
+def _project_root() -> Path:
+    return Path(__file__).resolve().parents[2]
+
+
+def _prompt_dir() -> Path:
+    return _project_root() / "prompt"
+
+
+@lru_cache(maxsize=None)
+def _load_prompt_parts(filename: str) -> Tuple[str, str]:
+    prompt_path = _prompt_dir() / filename
+    if not prompt_path.exists():
+        raise FileNotFoundError(f"Prompt template not found: {prompt_path}")
+
+    raw = prompt_path.read_text(encoding="utf-8")
+    splitter = "用户消息(User)"
+    if splitter not in raw:
+        raise ValueError(f"Prompt template '{filename}' missing separator '{splitter}'.")
+
+    system_raw, user_raw = raw.split(splitter, maxsplit=1)
+    system_text = system_raw.replace("系统角色(System)", "").strip()
+    user_text = user_raw.strip()
+    return system_text, user_text
+
+
+def _render_prompt(template_key: str, replacements: Dict[str, str]) -> Tuple[str, str]:
+    filename = PROMPT_FILENAMES[template_key]
+    system_text, user_template = _load_prompt_parts(filename)
+
+    rendered_user = user_template
+    for key, value in replacements.items():
+        rendered_user = rendered_user.replace(key, value)
+
+    return system_text, rendered_user
+
+
+def _extract_timeout_seconds(options: Optional[Dict[str, Any]]) -> Optional[float]:
+    if not options:
+        return None
+    value = options.get("llm_timeout_seconds")
+    if value is None:
+        return None
+    try:
+        timeout = float(value)
+        if timeout <= 0:
+            raise ValueError
+        return timeout
+    except (TypeError, ValueError):
+        logger.warning(
+            "Invalid llm_timeout_seconds value in extra_options: %r. Falling back to default.",
+            value,
+        )
+        return DEFAULT_CHAT_TIMEOUT_SECONDS
+
+
+def _extract_json_payload(content: str) -> str:
+    fenced = re.search(
+        r"```(?:json)?\s*([\s\S]+?)```",
+        content,
+        flags=re.IGNORECASE,
+    )
+    if fenced:
+        snippet = fenced.group(1).strip()
+        if snippet:
+            return snippet
+
+    stripped = content.strip()
+    if not stripped:
+        raise ValueError("Empty LLM content.")
+
+    for opener, closer in (("{", "}"), ("[", "]")):
+        start = stripped.find(opener)
+        end = stripped.rfind(closer)
+        if start != -1 and end != -1 and end > start:
+            candidate = stripped[start : end + 1].strip()
+            return candidate
+
+    return stripped
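+
+# Behaviour sketch (illustrative): a reply such as
+#
+#   Here is the result:
+#   ```json
+#   {"columns": []}
+#   ```
+#
+# yields '{"columns": []}', while a bare reply like 'Result: {"a": 1} done'
+# falls through to the brace scan and yields '{"a": 1}'.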
+
+
+def _parse_completion_payload(response_payload: Dict[str, Any]) -> Any:
+    choices = response_payload.get("choices") or []
+    if not choices:
+        raise ProviderAPICallError("LLM response did not contain choices to parse.")
+    message = choices[0].get("message") or {}
+    content = message.get("content") or ""
+    if not content.strip():
+        raise ProviderAPICallError("LLM response content is empty.")
+    json_payload = _extract_json_payload(content)
+    try:
+        return json.loads(json_payload)
+    except json.JSONDecodeError as exc:
+        preview = json_payload[:800]
+        logger.error("Failed to parse JSON from LLM response: %s", preview, exc_info=True)
+        raise ProviderAPICallError("LLM response JSON parsing failed.") from exc
+
+
+async def _post_callback(callback_url: str, payload: Dict[str, Any], client: httpx.AsyncClient) -> None:
+    safe_payload = _normalize_for_json(payload)
+    try:
+        logger.info(
+            "Posting pipeline action callback to %s: %s",
+            callback_url,
+            json.dumps(safe_payload, ensure_ascii=False),
+        )
+        response = await client.post(callback_url, json=safe_payload)
+        response.raise_for_status()
+    except httpx.HTTPError as exc:
+        logger.error("Callback delivery to %s failed: %s", callback_url, exc, exc_info=True)
+
+
+def _sanitize_value_set(value: Any, max_values: int) -> Tuple[Any, Optional[Dict[str, int]]]:
+    if not isinstance(value, list):
+        return value, None
+    original_len = len(value)
+    if original_len <= max_values:
+        return value, None
+    trimmed = value[:max_values]
+    return trimmed, {"original_length": original_len, "retained": max_values}
+
+
+def _sanitize_expectation_suite(suite: ExpectationSuite, max_value_set_values: int = 100) -> Dict[str, Any]:
+    suite_dict = suite.to_json_dict()
+    remarks: List[Dict[str, Any]] = []
+
+    for expectation in suite_dict.get("expectations", []):
+        kwargs = expectation.get("kwargs", {})
+        if "value_set" in kwargs:
+            sanitized_value, note = _sanitize_value_set(kwargs["value_set"], max_value_set_values)
+            kwargs["value_set"] = sanitized_value
+            if note:
+                expectation.setdefault("meta", {})
+                expectation["meta"]["value_set_truncated"] = note
+                remarks.append(
+                    {
+                        "column": kwargs.get("column"),
+                        "expectation": expectation.get("expectation_type"),
+                        "note": note,
+                    }
+                )
+
+    if remarks:
+        suite_dict.setdefault("meta", {})
+        suite_dict["meta"]["value_set_truncations"] = remarks
+
+    return suite_dict
+
+
+def _summarize_expectation_suite(suite_dict: Dict[str, Any]) -> Dict[str, Any]:
+    column_map: Dict[str, Dict[str, Any]] = {}
+    table_expectations: List[Dict[str, Any]] = []
+
+    for expectation in suite_dict.get("expectations", []):
+        expectation_type = expectation.get("expectation_type")
+        kwargs = expectation.get("kwargs", {})
+        column = kwargs.get("column")
+        summary_entry: Dict[str, Any] = {"expectation": expectation_type}
+
+        if "value_set" in kwargs and isinstance(kwargs["value_set"], list):
+            summary_entry["value_set_size"] = len(kwargs["value_set"])
+            summary_entry["value_set_preview"] = kwargs["value_set"][:5]
+
+        if column:
+            column_entry = column_map.setdefault(
+                column,
+                {"name": column, "expectations": []},
+            )
+            column_entry["expectations"].append(summary_entry)
+        else:
+            table_expectations.append(summary_entry)
+
+    summary = {
+        "column_profiles": list(column_map.values()),
+        "table_level_expectations": table_expectations,
+        "total_expectations": len(suite_dict.get("expectations", [])),
+    }
+    return summary
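+
+# Shape sketch of the summary (illustrative values): for a suite holding one
+# column expectation on "status", the function returns roughly
+#
+#   {
+#       "column_profiles": [
+#           {"name": "status",
+#            "expectations": [{"expectation": "expect_column_values_to_be_in_set",
+#                              "value_set_size": 3,
+#                              "value_set_preview": ["a", "b", "c"]}]}
+#       ],
+#       "table_level_expectations": [],
+#       "total_expectations": 1,
+#   }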
+
+
+def _sanitize_identifier(raw: Optional[str], fallback: str) -> str:
+    if not raw:
+        return fallback
+    candidate = re.sub(r"[^0-9A-Za-z_]+", "_", raw).strip("_")
+    return candidate or fallback
+
+
+def _format_connection_string(template: str, access_info: Dict[str, Any]) -> str:
+    if not access_info:
+        return template
+    try:
+        return template.format_map({k: v for k, v in access_info.items()})
+    except KeyError as exc:
+        missing = exc.args[0]
+        raise ValueError(f"table_access_info missing key '{missing}' required by connection_string.") from exc
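+
+# Placeholder-merge sketch (illustrative credentials):
+#
+#   _format_connection_string(
+#       "mysql+pymysql://{user}:{password}@db-host:3306/sales",
+#       {"user": "profiler", "password": "secret"},
+#   )
+#   -> "mysql+pymysql://profiler:secret@db-host:3306/sales"
+#
+# A template key with no matching access_info entry raises ValueError.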
+
+
+def _ensure_sql_runtime_datasource(
+    context: AbstractDataContext,
+    datasource_name: str,
+    connection_string: str,
+) -> None:
+    try:
+        datasource = context.get_datasource(datasource_name)
+    except (DataContextError, ValueError) as exc:
+        message = str(exc)
+        if "Could not find a datasource" in message or "Unable to load datasource" in message:
+            datasource = None
+        else:  # pragma: no cover - defensive
+            raise RuntimeError(f"Failed to inspect datasource '{datasource_name}'.") from exc
+    except Exception as exc:  # pragma: no cover - defensive
+        raise RuntimeError(f"Failed to inspect datasource '{datasource_name}'.") from exc
+
+    if datasource is not None:
+        execution_engine = getattr(datasource, "execution_engine", None)
+        current_conn = getattr(execution_engine, "connection_string", None)
+        if current_conn and current_conn != connection_string:
+            logger.info(
+                "Existing datasource %s uses different connection string; creating dedicated runtime datasource.",
+                datasource_name,
+            )
+            try:
+                context.delete_datasource(datasource_name)
+            except Exception as exc:  # pragma: no cover - defensive
+                logger.warning(
+                    "Failed to delete datasource %s before recreation: %s",
+                    datasource_name,
+                    exc,
+                )
+            else:
+                datasource = None
+
+    if datasource is not None:
+        return
+
+    runtime_datasource_config = {
+        "name": datasource_name,
+        "class_name": "Datasource",
+        "execution_engine": {
+            "class_name": "SqlAlchemyExecutionEngine",
+            "connection_string": connection_string,
+        },
+        "data_connectors": {
+            "runtime_connector": {
+                "class_name": "RuntimeDataConnector",
+                "batch_identifiers": ["default_identifier_name"],
+            }
+        },
+    }
+    try:
+        context.add_datasource(**runtime_datasource_config)
+    except Exception as exc:  # pragma: no cover - defensive
+        raise RuntimeError(f"Failed to create runtime datasource '{datasource_name}'.") from exc
+
+
+def _build_sql_runtime_batch_request(
+    context: AbstractDataContext,
+    request: TableProfilingJobRequest,
+) -> RuntimeBatchRequest:
+    link_info = request.table_link_info or {}
+    access_info = request.table_access_info or {}
+
+    connection_template = link_info.get("connection_string")
+    if not connection_template:
+        raise ValueError("table_link_info.connection_string is required when using table_link_info.")
+
+    connection_string = _format_connection_string(connection_template, access_info)
+
+    source_type = (link_info.get("type") or "sql").lower()
+    if source_type != "sql":
+        raise ValueError(f"Unsupported table_link_info.type='{source_type}'. Only 'sql' is supported.")
+
+    query = link_info.get("query")
+    table_name = link_info.get("table") or link_info.get("table_name")
+    schema_name = link_info.get("schema")
+
+    if not query and not table_name:
+        raise ValueError("Either table_link_info.query or table_link_info.table must be provided.")
+
+    if not query:
+        if not table_name:
+            raise ValueError("table_link_info.table must be provided when query is omitted.")
+
+        identifier = re.compile(r"^[A-Za-z_][A-Za-z0-9_$]*$")
+
+        def _quote(name: str) -> str:
+            if identifier.match(name):
+                return name
+            return f"`{name.replace('`', '``')}`"
+
+        if schema_name:
+            schema_part = schema_name if "." not in schema_name else schema_name.split(".")[-1]
+            table_part = table_name if "." not in table_name else table_name.split(".")[-1]
+            qualified_table = f"{_quote(schema_part)}.{_quote(table_part)}"
+        else:
+            qualified_table = _quote(table_name)
+
+        query = f"SELECT * FROM {qualified_table}"
+        limit = link_info.get("limit")
+        if isinstance(limit, int) and limit > 0:
+            query = f"{query} LIMIT {limit}"
+
+    datasource_name = request.ge_datasource_name or _sanitize_identifier(
+        f"{request.table_id}_runtime_ds", "runtime_ds"
+    )
+    data_asset_name = request.ge_data_asset_name or _sanitize_identifier(
+        table_name or "runtime_query", "runtime_query"
+    )
+
+    _ensure_sql_runtime_datasource(context, datasource_name, connection_string)
+
+    batch_identifiers = {
+        "default_identifier_name": f"{request.table_id}:{request.version_ts}",
+    }
+
+    return RuntimeBatchRequest(
+        datasource_name=datasource_name,
+        data_connector_name="runtime_connector",
+        data_asset_name=data_asset_name,
+        runtime_parameters={"query": query},
+        batch_identifiers=batch_identifiers,
+    )
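+
+# Input sketch (illustrative): a request carrying
+#
+#   table_link_info = {
+#       "type": "sql",
+#       "connection_string": "mysql+pymysql://{user}:{password}@db-host:3306/sales",
+#       "table": "orders",
+#       "limit": 10000,
+#   }
+#   table_access_info = {"user": "profiler", "password": "secret"}
+#
+# registers a runtime SQL datasource and profiles
+# "SELECT * FROM orders LIMIT 10000".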
+
+
+def _run_onboarding_assistant(
+    context: AbstractDataContext,
+    batch_request: Any,
+    suite_name: str,
+) -> Tuple[ExpectationSuite, Any]:
+    assistant = context.assistants.onboarding
+    assistant_result = assistant.run(batch_request=batch_request)
+    suite = assistant_result.get_expectation_suite(expectation_suite_name=suite_name)
+    context.save_expectation_suite(suite, expectation_suite_name=suite_name)
+    validation_getter = getattr(assistant_result, "get_validation_result", None)
+    if callable(validation_getter):
+        validation_result = validation_getter()
+    else:
+        validation_result = getattr(assistant_result, "validation_result", None)
+    if validation_result is None:
+        # Fallback: rerun validation using the freshly generated expectation suite.
+        validator = context.get_validator(
+            batch_request=batch_request,
+            expectation_suite_name=suite_name,
+        )
+        validation_result = validator.validate()
+    return suite, validation_result
+
+
+def _resolve_context(request: TableProfilingJobRequest) -> AbstractDataContext:
+    context_kwargs: Dict[str, Any] = {}
+    if request.ge_data_context_root:
+        context_kwargs["project_root_dir"] = request.ge_data_context_root
+    elif os.environ.get("GE_DATA_CONTEXT_ROOT"):
+        context_kwargs["project_root_dir"] = os.environ["GE_DATA_CONTEXT_ROOT"]
+    else:
+        context_kwargs["project_root_dir"] = str(_project_root())
+
+    return gx.get_context(**context_kwargs)
+
+
+def _build_batch_request(
+    context: AbstractDataContext,
+    request: TableProfilingJobRequest,
+) -> Any:
+    if request.ge_batch_request:
+        from great_expectations.core.batch import BatchRequest
+
+        return BatchRequest(**request.ge_batch_request)
+
+    if request.table_link_info:
+        return _build_sql_runtime_batch_request(context, request)
+
+    if not request.ge_datasource_name or not request.ge_data_asset_name:
+        raise ValueError(
+            "ge_batch_request or (ge_datasource_name and ge_data_asset_name) must be provided."
+        )
+
+    datasource = context.get_datasource(request.ge_datasource_name)
+    data_asset = datasource.get_asset(request.ge_data_asset_name)
+    return data_asset.build_batch_request()
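+
+# A hand-supplied ge_batch_request might look like the sketch below
+# (illustrative names; the keys mirror Great Expectations BatchRequest fields):
+#
+#   {
+#       "datasource_name": "warehouse_ds",
+#       "data_connector_name": "default_inferred_data_connector_name",
+#       "data_asset_name": "sales.orders",
+#   }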
+
+
+async def _run_ge_profiling(request: TableProfilingJobRequest) -> GEProfilingArtifacts:
+    def _execute() -> GEProfilingArtifacts:
+        context = _resolve_context(request)
+        suite_name = (
+            request.ge_expectation_suite_name
+            or f"{request.table_id}_profiling"
+        )
+
+        batch_request = _build_batch_request(context, request)
+        try:
+            context.get_expectation_suite(suite_name)
+        except DataContextError:
+            context.add_expectation_suite(suite_name)
+
+        validator = context.get_validator(
+            batch_request=batch_request,
+            expectation_suite_name=suite_name,
+        )
+
+        profiler_type = (request.ge_profiler_type or "user_configurable").lower()
+
+        if profiler_type == "data_assistant":
+            suite, validation_result = _run_onboarding_assistant(
+                context,
+                batch_request,
+                suite_name,
+            )
+        else:
+            try:
+                from great_expectations.profile.user_configurable_profiler import (
+                    UserConfigurableProfiler,
+                )
+            except ImportError as err:  # pragma: no cover - dependency guard
+                raise RuntimeError(
+                    "UserConfigurableProfiler is unavailable; install great_expectations profiling extra or switch profiler."
+                ) from err
+
+            profiler = UserConfigurableProfiler(profile_dataset=validator)
+            try:
+                suite = profiler.build_suite()
+                context.save_expectation_suite(suite, expectation_suite_name=suite_name)
+                validator.expectation_suite = suite
+                validation_result = validator.validate()
+            except MetricResolutionError as exc:
+                logger.warning(
+                    "UserConfigurableProfiler failed (%s); falling back to data assistant profiling.",
+                    exc,
+                )
+                suite, validation_result = _run_onboarding_assistant(
+                    context,
+                    batch_request,
+                    suite_name,
+                )
+
+        sanitized_suite = _sanitize_expectation_suite(suite)
+        summary = _summarize_expectation_suite(sanitized_suite)
+        validation_dict = validation_result.to_json_dict()
+
+        context.build_data_docs()
+        docs_path = Path(context.root_directory) / GE_REPORT_RELATIVE_PATH
+
+        profiling_result = {
+            "expectation_suite": sanitized_suite,
+            "validation_result": validation_dict,
+            "batch_request": getattr(batch_request, "to_json_dict", lambda: None)() or getattr(batch_request, "dict", lambda: None)(),
+        }
+
+        return GEProfilingArtifacts(
+            profiling_result=profiling_result,
+            profiling_summary=summary,
+            docs_path=str(docs_path),
+        )
+
+    return await asyncio.to_thread(_execute)
+
+
+async def _call_chat_completions(
+    *,
+    model_spec: str,
+    system_prompt: str,
+    user_prompt: str,
+    client: httpx.AsyncClient,
+    temperature: float = 0.2,
+    timeout_seconds: Optional[float] = None,
+) -> Any:
+    provider, model_name = resolve_provider_from_model(model_spec)
+    payload = {
+        "provider": provider.value,
+        "model": model_name,
+        "messages": [
+            {"role": "system", "content": system_prompt},
+            {"role": "user", "content": user_prompt},
+        ],
+        "temperature": temperature,
+    }
+    payload_size_bytes = len(json.dumps(payload, ensure_ascii=False).encode("utf-8"))
+
+    url = f"{IMPORT_GATEWAY_BASE_URL.rstrip('/')}/v1/chat/completions"
+    try:
+        # Log the full request before dispatching it.
+        logger.info(
+            "Calling chat completions API %s with model %s and size %s and payload %s",
+            url,
+            model_name,
+            payload_size_bytes,
+            payload,
+        )
+        response = await client.post(url, json=payload, timeout=timeout_seconds)
+        response.raise_for_status()
+    except httpx.HTTPError as exc:
+        error_name = exc.__class__.__name__
+        detail = str(exc).strip()
+        if detail:
+            message = f"Chat completions request failed ({error_name}): {detail}"
+        else:
+            message = f"Chat completions request failed ({error_name})."
+        raise ProviderAPICallError(message) from exc
+
+    try:
+        response_payload = response.json()
+    except ValueError as exc:
+        raise ProviderAPICallError("Chat completions response was not valid JSON.") from exc
+
+    return _parse_completion_payload(response_payload)
+
+
+def _normalize_for_json(value: Any) -> Any:
+    if value is None or isinstance(value, (str, int, float, bool)):
+        return value
+    if isinstance(value, (datetime, date)):
+        return str(value)
+    if hasattr(value, "model_dump"):
+        try:
+            return value.model_dump()
+        except Exception:  # pragma: no cover - defensive
+            pass
+    if is_dataclass(value):
+        return asdict(value)
+    if isinstance(value, dict):
+        return {k: _normalize_for_json(v) for k, v in value.items()}
+    if isinstance(value, (list, tuple, set)):
+        return [_normalize_for_json(v) for v in value]
+    if hasattr(value, "to_json_dict"):
+        try:
+            return value.to_json_dict()
+        except Exception:  # pragma: no cover - defensive
+            pass
+    if hasattr(value, "__dict__"):
+        return _normalize_for_json(value.__dict__)
+    return repr(value)
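+
+# Behaviour sketch (illustrative): datetimes collapse to strings and nested
+# containers are walked recursively, e.g.
+#
+#   _normalize_for_json({"at": datetime(2025, 11, 3), "tags": {"a"}})
+#   -> {"at": "2025-11-03 00:00:00", "tags": ["a"]}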
+
+
+def _json_dumps(data: Any) -> str:
+    normalised = _normalize_for_json(data)
+    return json.dumps(normalised, ensure_ascii=False, indent=2)
+
+
+def _preview_for_log(data: Any) -> str:
+    try:
+        serialised = _json_dumps(data)
+    except Exception:
+        serialised = repr(data)
+
+    return serialised
+
+
+def _profiling_request_for_log(request: TableProfilingJobRequest) -> Dict[str, Any]:
+    payload = request.model_dump()
+    access_info = payload.get("table_access_info")
+    if isinstance(access_info, dict):
+        payload["table_access_info"] = {key: "***" for key in access_info.keys()}
+    return payload
+
+
+async def _execute_result_desc(
+    profiling_json: Dict[str, Any],
+    _request: TableProfilingJobRequest,
+    llm_model: str,
+    client: httpx.AsyncClient,
+    timeout_seconds: Optional[float],
+) -> Dict[str, Any]:
+    system_prompt, user_prompt = _render_prompt(
+        "ge_result_desc",
+        {"{{GE_RESULT_JSON}}": _json_dumps(profiling_json)},
+    )
+    llm_output = await _call_chat_completions(
+        model_spec=llm_model,
+        system_prompt=system_prompt,
+        user_prompt=user_prompt,
+        client=client,
+        timeout_seconds=timeout_seconds,
+    )
+    if not isinstance(llm_output, dict):
+        raise ProviderAPICallError("GE result description payload must be a JSON object.")
+    return llm_output
+
+
+async def _execute_snippet_generation(
+    table_desc_json: Dict[str, Any],
+    _request: TableProfilingJobRequest,
+    llm_model: str,
+    client: httpx.AsyncClient,
+    timeout_seconds: Optional[float],
+) -> List[Dict[str, Any]]:
+    system_prompt, user_prompt = _render_prompt(
+        "snippet_generator",
+        {"{{TABLE_PROFILE_JSON}}": _json_dumps(table_desc_json)},
+    )
+    llm_output = await _call_chat_completions(
+        model_spec=llm_model,
+        system_prompt=system_prompt,
+        user_prompt=user_prompt,
+        client=client,
+        timeout_seconds=timeout_seconds,
+    )
+    if not isinstance(llm_output, list):
+        raise ProviderAPICallError("Snippet generator must return a JSON array.")
+    return llm_output
+
+
+async def _execute_snippet_alias(
+    snippets_json: List[Dict[str, Any]],
+    _request: TableProfilingJobRequest,
+    llm_model: str,
+    client: httpx.AsyncClient,
+    timeout_seconds: Optional[float],
+) -> List[Dict[str, Any]]:
+    system_prompt, user_prompt = _render_prompt(
+        "snippet_alias",
+        {"{{SNIPPET_ARRAY}}": _json_dumps(snippets_json)},
+    )
+    llm_output = await _call_chat_completions(
+        model_spec=llm_model,
+        system_prompt=system_prompt,
+        user_prompt=user_prompt,
+        client=client,
+        timeout_seconds=timeout_seconds,
+    )
+    if not isinstance(llm_output, list):
+        raise ProviderAPICallError("Snippet alias generator must return a JSON array.")
+    return llm_output
+
+
+async def _run_action_with_callback(
+    *,
+    action_type: str,
+    runner,
+    callback_base: Dict[str, Any],
+    client: httpx.AsyncClient,
+    callback_url: str,
+    input_payload: Any = None,
+    model_spec: Optional[str] = None,
+) -> Any:
+    if input_payload is not None:
+        logger.info(
+            "Pipeline action %s input: %s",
+            action_type,
+            _preview_for_log(input_payload),
+        )
+    try:
+        result = await runner()
+    except Exception as exc:
+        failure_payload = dict(callback_base)
+        failure_payload.update(
+            {
+                "status": "failed",
+                "action_type": action_type,
+                "error": str(exc),
+            }
+        )
+        if model_spec is not None:
+            failure_payload["model"] = model_spec
+        await _post_callback(callback_url, failure_payload, client)
+        raise
+
+    success_payload = dict(callback_base)
+    success_payload.update(
+        {
+            "status": "success",
+            "action_type": action_type,
+        }
+    )
+    if model_spec is not None:
+        success_payload["model"] = model_spec
+
+    logger.info(
+        "Pipeline action %s output: %s",
+        action_type,
+        _preview_for_log(result),
+    )
+
+    if action_type == PipelineActionType.GE_PROFILING:
+        artifacts: GEProfilingArtifacts = result
+        success_payload["profiling_json"] = artifacts.profiling_result
+        success_payload["profiling_summary"] = artifacts.profiling_summary
+        success_payload["ge_report_path"] = artifacts.docs_path
+    elif action_type == PipelineActionType.GE_RESULT_DESC:
+        success_payload["table_desc_json"] = result
+    elif action_type == PipelineActionType.SNIPPET:
+        success_payload["snippet_json"] = result
+    elif action_type == PipelineActionType.SNIPPET_ALIAS:
+        success_payload["snippet_alias_json"] = result
+
+    await _post_callback(callback_url, success_payload, client)
+    return result
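+
+# Callback shape sketch (illustrative): a successful snippet action posts the
+# base fields plus the action result, roughly
+#
+#   {
+#       "table_id": "1001",
+#       "version_ts": "20251103001826",
+#       "status": "success",
+#       "action_type": "snippet",
+#       "model": "...",
+#       "snippet_json": [...],
+#       ...
+#   }
+#
+# Failures carry status "failed" and an "error" string instead of a result.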
+
+
+async def process_table_profiling_job(
+    request: TableProfilingJobRequest,
+    _gateway: LLMGateway,
+    client: httpx.AsyncClient,
+) -> None:
+    """Sequentially execute the four-step profiling pipeline and emit callbacks per action."""
+
+    timeout_seconds = _extract_timeout_seconds(request.extra_options)
+    if timeout_seconds is None:
+        timeout_seconds = DEFAULT_CHAT_TIMEOUT_SECONDS
+
+    # Resolve the per-action model overrides, falling back to the request-wide
+    # default and finally to the gateway default model.
+    default_model = request.llm_model or DEFAULT_IMPORT_MODEL
+    result_desc_model = request.result_desc_model or default_model
+    snippet_model = request.snippet_model or default_model
+    snippet_alias_model = request.snippet_alias_model or default_model
+
+    base_payload = {
+        "table_id": request.table_id,
+        "version_ts": request.version_ts,
+        "callback_url": str(request.callback_url),
+        "table_schema": request.table_schema,
+        "table_schema_version_id": request.table_schema_version_id,
+        "llm_model": default_model,
+        "llm_timeout_seconds": timeout_seconds,
+    }
+
+    logging_request_payload = _profiling_request_for_log(request)
+
+    try:
+        artifacts: GEProfilingArtifacts = await _run_action_with_callback(
+            action_type=PipelineActionType.GE_PROFILING,
+            runner=lambda: _run_ge_profiling(request),
+            callback_base=base_payload,
+            client=client,
+            callback_url=str(request.callback_url),
+            input_payload=logging_request_payload,
+            model_spec=default_model,
+        )
+
+        table_desc_json: Dict[str, Any] = await _run_action_with_callback(
+            action_type=PipelineActionType.GE_RESULT_DESC,
+            runner=lambda: _execute_result_desc(
+                artifacts.profiling_result,
+                request,
+                result_desc_model,
+                client,
+                timeout_seconds,
+            ),
+            callback_base=base_payload,
+            client=client,
+            callback_url=str(request.callback_url),
+            input_payload=artifacts.profiling_result,
+            model_spec=result_desc_model,
+        )
+
+        snippet_json: List[Dict[str, Any]] = await _run_action_with_callback(
+            action_type=PipelineActionType.SNIPPET,
+            runner=lambda: _execute_snippet_generation(
+                table_desc_json,
+                request,
+                snippet_model,
+                client,
+                timeout_seconds,
+            ),
+            callback_base=base_payload,
+            client=client,
+            callback_url=str(request.callback_url),
+            input_payload=table_desc_json,
+            model_spec=snippet_model,
+        )
+
+        await _run_action_with_callback(
+            action_type=PipelineActionType.SNIPPET_ALIAS,
+            runner=lambda: _execute_snippet_alias(
+                snippet_json,
+                request,
+                snippet_alias_model,
+                client,
+                timeout_seconds,
+            ),
+            callback_base=base_payload,
+            client=client,
+            callback_url=str(request.callback_url),
+            input_payload=snippet_json,
+            model_spec=snippet_alias_model,
+        )
+    except Exception:  # pragma: no cover - defensive catch
+        logger.exception(
+            "Table profiling pipeline failed for table_id=%s version_ts=%s",
+            request.table_id,
+            request.version_ts,
+        )
diff --git a/app/services/table_snippet.py b/app/services/table_snippet.py
new file mode 100644
index 0000000..e7e2c95
--- /dev/null
+++ b/app/services/table_snippet.py
@@ -0,0 +1,184 @@
+from __future__ import annotations
+
+import json
+import logging
+from typing import Any, Dict, Tuple
+
+from sqlalchemy import text
+from sqlalchemy.engine import Engine
+from sqlalchemy.exc import SQLAlchemyError
+
+from app.db import get_engine
+from app.models import (
+    ActionType,
+    TableSnippetUpsertRequest,
+    TableSnippetUpsertResponse,
+)
+
+
+logger = logging.getLogger(__name__)
+
+
+def _serialize_json(value: Any) -> Tuple[str | None, int | None]:
+    logger.debug("Serializing JSON payload: %s", value)
+    if value is None:
+        return None, None
+    if isinstance(value, str):
+        encoded = value.encode("utf-8")
+        return value, len(encoded)
+    serialized = json.dumps(value, ensure_ascii=False)
+    encoded = serialized.encode("utf-8")
+    return serialized, len(encoded)
+
+
+def _prepare_table_schema(value: Any) -> str:
+    logger.debug("Preparing table_schema payload.")
+    if isinstance(value, str):
+        return value
+    return json.dumps(value, ensure_ascii=False)
+
+
+def _collect_common_columns(request: TableSnippetUpsertRequest) -> Dict[str, Any]:
+    logger.debug(
+        "Collecting common columns for table_id=%s version_ts=%s action_type=%s",
+        request.table_id,
+        request.version_ts,
+        request.action_type,
+    )
+    payload: Dict[str, Any] = {
+        "table_id": request.table_id,
+        "version_ts": request.version_ts,
+        "action_type": request.action_type.value,
+        "status": request.status.value,
+        "callback_url": str(request.callback_url),
+        "table_schema_version_id": request.table_schema_version_id,
+        "table_schema": _prepare_table_schema(request.table_schema),
+    }
+
+    if request.error_code is not None:
+        logger.debug("Adding error_code: %s", request.error_code)
+        payload["error_code"] = request.error_code
+    if request.error_message is not None:
+        logger.debug("Adding error_message: %s", request.error_message)
+        payload["error_message"] = request.error_message
+    if request.started_at is not None:
+        payload["started_at"] = request.started_at
+    if request.finished_at is not None:
+        payload["finished_at"] = request.finished_at
+    if request.duration_ms is not None:
+        payload["duration_ms"] = request.duration_ms
+    if request.result_checksum is not None:
+        payload["result_checksum"] = request.result_checksum
+
+    logger.debug("Collected common payload: %s", payload)
+    return payload
+
+
+def _apply_action_payload(
+    request: TableSnippetUpsertRequest,
+    payload: Dict[str, Any],
+) -> None:
+    logger.debug("Applying action-specific payload for action_type=%s", request.action_type)
+    if request.action_type == ActionType.GE_PROFILING:
+        full_json, full_size = _serialize_json(request.result_json)
+        summary_json, summary_size = _serialize_json(request.result_summary_json)
+        if full_json is not None:
+            payload["ge_profiling_full"] = full_json
+            payload["ge_profiling_full_size_bytes"] = full_size
+        if summary_json is not None:
+            payload["ge_profiling_summary"] = summary_json
+            payload["ge_profiling_summary_size_bytes"] = summary_size
+        if full_size is not None or summary_size is not None:
+            payload["ge_profiling_total_size_bytes"] = (full_size or 0) + (
+                summary_size or 0
+            )
+        if request.html_report_url:
+            payload["ge_profiling_html_report_url"] = request.html_report_url
+    elif request.action_type == ActionType.GE_RESULT_DESC:
+        full_json, full_size = _serialize_json(request.result_json)
+        if full_json is not None:
+            payload["ge_result_desc_full"] = full_json
+            payload["ge_result_desc_full_size_bytes"] = full_size
+    elif request.action_type == ActionType.SNIPPET:
+        full_json, full_size = _serialize_json(request.result_json)
+        if full_json is not None:
+            payload["snippet_full"] = full_json
+            payload["snippet_full_size_bytes"] = full_size
+    elif request.action_type == ActionType.SNIPPET_ALIAS:
+        full_json, full_size = _serialize_json(request.result_json)
+        if full_json is not None:
+            payload["snippet_alias_full"] = full_json
+            payload["snippet_alias_full_size_bytes"] = full_size
+    else:
+        logger.error("Unsupported action type encountered: %s", request.action_type)
+        raise ValueError(f"Unsupported action type '{request.action_type}'.")
+
+    logger.debug("Payload after applying action-specific data: %s", payload)
+
+
+def _build_insert_statement(columns: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]:
+    logger.debug("Building insert statement for columns: %s", list(columns.keys()))
+    column_names = list(columns.keys())
+    placeholders = [f":{name}" for name in column_names]
+    update_assignments = [
+        f"{name}=VALUES({name})"
+        for name in column_names
+        if name not in {"table_id", "version_ts", "action_type"}
+    ]
+    update_assignments.append("updated_at=CURRENT_TIMESTAMP")
+
+    sql = (
+        "INSERT INTO action_results ({cols}) VALUES ({vals}) "
+        "ON DUPLICATE KEY UPDATE {updates}"
+    ).format(
+        cols=", ".join(column_names),
+        vals=", ".join(placeholders),
+        updates=", ".join(update_assignments),
+    )
+    logger.debug("Generated SQL: %s", sql)
+    return sql, columns
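+
+# Statement sketch (illustrative, abridged column list): for a snippet action
+# the builder emits roughly
+#
+#   INSERT INTO action_results (table_id, version_ts, action_type, status, ...)
+#   VALUES (:table_id, :version_ts, :action_type, :status, ...)
+#   ON DUPLICATE KEY UPDATE status=VALUES(status), ..., updated_at=CURRENT_TIMESTAMP
+#
+# The unique key (table_id, version_ts, action_type) decides insert vs update.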
RuntimeError(f"Database operation failed: {exc}") from exc + + updated = rowcount > 1 + return TableSnippetUpsertResponse( + table_id=request.table_id, + version_ts=request.version_ts, + action_type=request.action_type, + status=request.status, + updated=updated, + ) diff --git a/table_snippet.sql b/table_snippet.sql new file mode 100644 index 0000000..b9fb19b --- /dev/null +++ b/table_snippet.sql @@ -0,0 +1,54 @@ +CREATE TABLE IF NOT EXISTS action_results ( + id BIGINT NOT NULL AUTO_INCREMENT COMMENT '主键', + table_id BIGINT NOT NULL COMMENT '表ID', + version_ts BIGINT NOT NULL COMMENT '版本时间戳(版本号)', + action_type ENUM('ge_profiling','ge_result_desc','snippet','snippet_alias') NOT NULL COMMENT '动作类型', + + status ENUM('pending','running','success','failed','partial') NOT NULL DEFAULT 'pending' COMMENT '执行状态', + error_code VARCHAR(128) NULL, + error_message TEXT NULL, + + -- 回调 & 观测 + callback_url VARCHAR(1024) NOT NULL, + started_at DATETIME NULL, + finished_at DATETIME NULL, + duration_ms INT NULL, + + -- 本次schema信息 + table_schema_version_id BIGINT NOT NULL, + table_schema JSON NOT NULL, + + -- ===== 动作1:GE Profiling ===== + ge_profiling_full JSON NULL COMMENT 'Profiling完整结果JSON', + ge_profiling_full_size_bytes BIGINT NULL, + ge_profiling_summary JSON NULL COMMENT 'Profiling摘要(剔除大value_set等)', + ge_profiling_summary_size_bytes BIGINT NULL, + ge_profiling_total_size_bytes BIGINT NULL COMMENT '上两者合计', + ge_profiling_html_report_url VARCHAR(1024) NULL COMMENT 'GE报告HTML路径/URL', + + -- ===== 动作2:GE Result Desc ===== + ge_result_desc_full JSON NULL COMMENT '表描述结果JSON', + ge_result_desc_full_size_bytes BIGINT NULL, + + -- ===== 动作3:Snippet 生成 ===== + snippet_full JSON NULL COMMENT 'SQL知识片段结果JSON', + snippet_full_size_bytes BIGINT NULL, + + -- ===== 动作4:Snippet Alias 改写 ===== + snippet_alias_full JSON NULL COMMENT 'SQL片段改写/丰富结果JSON', + snippet_alias_full_size_bytes BIGINT NULL, + + -- 通用可选指标 + result_checksum VARBINARY(32) NULL COMMENT '对当前action有效载荷计算的MD5/xxhash', + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + + PRIMARY KEY (id), + UNIQUE KEY uq_table_ver_action (table_id, version_ts, action_type), + KEY idx_status (status), + KEY idx_table (table_id, updated_at), + KEY idx_action_time (action_type, version_ts), + KEY idx_schema_version (table_schema_version_id) +) ENGINE=InnoDB + ROW_FORMAT=DYNAMIC + COMMENT='数据分析知识片段表';