table profiling功能开发
This commit is contained in:
26
app/db.py
Normal file
26
app/db.py
Normal file
@ -0,0 +1,26 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from functools import lru_cache
|
||||
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.engine import Engine
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_engine() -> Engine:
|
||||
"""Return a cached SQLAlchemy engine configured from DATABASE_URL."""
|
||||
database_url = os.getenv(
|
||||
"DATABASE_URL",
|
||||
"mysql+pymysql://root:12345678@localhost:3306/data-ge?charset=utf8mb4",
|
||||
)
|
||||
connect_args = {}
|
||||
if database_url.startswith("sqlite"):
|
||||
connect_args["check_same_thread"] = False
|
||||
|
||||
return create_engine(
|
||||
database_url,
|
||||
pool_pre_ping=True,
|
||||
future=True,
|
||||
connect_args=connect_args,
|
||||
)
|
||||
113
app/main.py
113
app/main.py
@ -2,12 +2,17 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import logging.config
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any
|
||||
|
||||
import yaml
|
||||
|
||||
import httpx
|
||||
from fastapi import Depends, FastAPI, HTTPException, Request
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from app.exceptions import ProviderAPICallError, ProviderConfigurationError
|
||||
from app.models import (
|
||||
@ -15,30 +20,42 @@ from app.models import (
|
||||
DataImportAnalysisJobRequest,
|
||||
LLMRequest,
|
||||
LLMResponse,
|
||||
TableProfilingJobAck,
|
||||
TableProfilingJobRequest,
|
||||
TableSnippetUpsertRequest,
|
||||
TableSnippetUpsertResponse,
|
||||
)
|
||||
from app.services import LLMGateway
|
||||
from app.services.import_analysis import process_import_analysis_job
|
||||
from app.services.table_profiling import process_table_profiling_job
|
||||
from app.services.table_snippet import upsert_action_result
|
||||
|
||||
|
||||
def _ensure_log_directories(config: dict[str, Any]) -> None:
|
||||
handlers = config.get("handlers", {})
|
||||
for handler_config in handlers.values():
|
||||
filename = handler_config.get("filename")
|
||||
if not filename:
|
||||
continue
|
||||
directory = os.path.dirname(filename)
|
||||
if directory and not os.path.exists(directory):
|
||||
os.makedirs(directory, exist_ok=True)
|
||||
|
||||
|
||||
def _configure_logging() -> None:
|
||||
level_name = os.getenv("LOG_LEVEL", "INFO").upper()
|
||||
level = getattr(logging, level_name, logging.INFO)
|
||||
log_format = os.getenv(
|
||||
"LOG_FORMAT",
|
||||
"%(asctime)s %(levelname)s %(name)s:%(lineno)d %(message)s",
|
||||
config_path = os.getenv("LOGGING_CONFIG", "logging.yaml")
|
||||
if os.path.exists(config_path):
|
||||
with open(config_path, "r", encoding="utf-8") as fh:
|
||||
config = yaml.safe_load(fh)
|
||||
if isinstance(config, dict):
|
||||
_ensure_log_directories(config)
|
||||
logging.config.dictConfig(config)
|
||||
return
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s %(levelname)s %(name)s:%(lineno)d %(message)s",
|
||||
)
|
||||
|
||||
root = logging.getLogger()
|
||||
|
||||
if not root.handlers:
|
||||
logging.basicConfig(level=level, format=log_format)
|
||||
else:
|
||||
root.setLevel(level)
|
||||
formatter = logging.Formatter(log_format)
|
||||
for handler in root.handlers:
|
||||
handler.setLevel(level)
|
||||
handler.setFormatter(formatter)
|
||||
|
||||
|
||||
_configure_logging()
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -119,6 +136,24 @@ def create_app() -> FastAPI:
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
@application.exception_handler(RequestValidationError)
|
||||
async def request_validation_exception_handler(
|
||||
request: Request, exc: RequestValidationError
|
||||
) -> JSONResponse:
|
||||
try:
|
||||
raw_body = await request.body()
|
||||
except Exception: # pragma: no cover - defensive
|
||||
raw_body = b"<unavailable>"
|
||||
truncated_body = raw_body[:4096]
|
||||
logger.warning(
|
||||
"Validation error on %s %s: %s | body preview=%s",
|
||||
request.method,
|
||||
request.url.path,
|
||||
exc.errors(),
|
||||
truncated_body.decode("utf-8", errors="ignore"),
|
||||
)
|
||||
return JSONResponse(status_code=422, content={"detail": exc.errors()})
|
||||
|
||||
@application.post(
|
||||
"/v1/chat/completions",
|
||||
response_model=LLMResponse,
|
||||
@ -164,6 +199,52 @@ def create_app() -> FastAPI:
|
||||
|
||||
return DataImportAnalysisJobAck(import_record_id=payload.import_record_id, status="accepted")
|
||||
|
||||
@application.post(
|
||||
"/v1/table/profiling",
|
||||
response_model=TableProfilingJobAck,
|
||||
summary="Run end-to-end GE profiling pipeline and notify via callback per action",
|
||||
status_code=202,
|
||||
)
|
||||
async def run_table_profiling(
|
||||
payload: TableProfilingJobRequest,
|
||||
gateway: LLMGateway = Depends(get_gateway),
|
||||
client: httpx.AsyncClient = Depends(get_http_client),
|
||||
) -> TableProfilingJobAck:
|
||||
request_copy = payload.model_copy(deep=True)
|
||||
|
||||
async def _runner() -> None:
|
||||
await process_table_profiling_job(request_copy, gateway, client)
|
||||
|
||||
asyncio.create_task(_runner())
|
||||
|
||||
return TableProfilingJobAck(
|
||||
table_id=payload.table_id,
|
||||
version_ts=payload.version_ts,
|
||||
status="accepted",
|
||||
)
|
||||
|
||||
@application.post(
|
||||
"/v1/table/snippet",
|
||||
response_model=TableSnippetUpsertResponse,
|
||||
summary="Persist or update action results, such as table snippets.",
|
||||
)
|
||||
async def upsert_table_snippet(
|
||||
payload: TableSnippetUpsertRequest,
|
||||
) -> TableSnippetUpsertResponse:
|
||||
request_copy = payload.model_copy(deep=True)
|
||||
|
||||
try:
|
||||
return await asyncio.to_thread(upsert_action_result, request_copy)
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
"Failed to upsert table snippet: table_id=%s version_ts=%s action_type=%s",
|
||||
payload.table_id,
|
||||
payload.version_ts,
|
||||
payload.action_type,
|
||||
exc_info=True,
|
||||
)
|
||||
raise HTTPException(status_code=500, detail=str(exc)) from exc
|
||||
|
||||
@application.post("/__mock__/import-callback")
|
||||
async def mock_import_callback(payload: dict[str, Any]) -> dict[str, str]:
|
||||
logger.info("Received import analysis callback: %s", payload)
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
@ -135,3 +136,89 @@ class DataImportAnalysisJobRequest(BaseModel):
|
||||
class DataImportAnalysisJobAck(BaseModel):
|
||||
import_record_id: str = Field(..., description="Echo of the import record identifier")
|
||||
status: str = Field("accepted", description="Processing status acknowledgement.")
|
||||
|
||||
|
||||
class TableProfilingJobRequest(BaseModel):
|
||||
table_id: str = Field(..., description="Unique identifier for the table to profile.")
|
||||
version_ts: str = Field(
|
||||
...,
|
||||
pattern=r"^\d{14}$",
|
||||
description="Version timestamp expressed as fourteen digit string (yyyyMMddHHmmss).",
|
||||
)
|
||||
callback_url: HttpUrl = Field(
|
||||
...,
|
||||
description="Callback endpoint invoked after each pipeline action completes.",
|
||||
)
|
||||
table_schema: Optional[Any] = Field(
|
||||
None,
|
||||
description="Schema structure snapshot for the current table version.",
|
||||
)
|
||||
table_schema_version_id: Optional[str] = Field(
|
||||
None,
|
||||
description="Identifier for the schema snapshot provided in table_schema.",
|
||||
)
|
||||
table_link_info: Optional[Dict[str, Any]] = Field(
|
||||
None,
|
||||
description=(
|
||||
"Information describing how to locate the source table for profiling. "
|
||||
"For example: {'type': 'sql', 'connection_string': 'mysql+pymysql://user:pass@host/db', "
|
||||
"'table': 'schema.table_name'}."
|
||||
),
|
||||
)
|
||||
table_access_info: Optional[Dict[str, Any]] = Field(
|
||||
None,
|
||||
description=(
|
||||
"Credentials or supplemental parameters required to access the table described in table_link_info. "
|
||||
"These values can be merged into the connection string using Python format placeholders."
|
||||
),
|
||||
)
|
||||
ge_batch_request: Optional[Dict[str, Any]] = Field(
|
||||
None,
|
||||
description="Optional Great Expectations batch request payload used for profiling.",
|
||||
)
|
||||
ge_expectation_suite_name: Optional[str] = Field(
|
||||
None,
|
||||
description="Expectation suite name used during profiling. Created automatically when absent.",
|
||||
)
|
||||
ge_data_context_root: Optional[str] = Field(
|
||||
None,
|
||||
description="Custom root directory for the Great Expectations data context. Defaults to project ./gx.",
|
||||
)
|
||||
ge_datasource_name: Optional[str] = Field(
|
||||
None,
|
||||
description="Datasource name registered inside the GE context when batch_request is not supplied.",
|
||||
)
|
||||
ge_data_asset_name: Optional[str] = Field(
|
||||
None,
|
||||
description="Data asset reference used when inferring batch request from datasource configuration.",
|
||||
)
|
||||
ge_profiler_type: str = Field(
|
||||
"user_configurable",
|
||||
description="Profiler implementation identifier. Currently supports 'user_configurable' or 'data_assistant'.",
|
||||
)
|
||||
llm_model: Optional[str] = Field(
|
||||
None,
|
||||
description="Default LLM model spec applied to prompt-based actions when overrides are omitted.",
|
||||
)
|
||||
result_desc_model: Optional[str] = Field(
|
||||
None,
|
||||
description="LLM model override used for GE result description (action 2).",
|
||||
)
|
||||
snippet_model: Optional[str] = Field(
|
||||
None,
|
||||
description="LLM model override used for snippet generation (action 3).",
|
||||
)
|
||||
snippet_alias_model: Optional[str] = Field(
|
||||
None,
|
||||
description="LLM model override used for snippet alias enrichment (action 4).",
|
||||
)
|
||||
extra_options: Optional[Dict[str, Any]] = Field(
|
||||
None,
|
||||
description="Miscellaneous execution flags applied across pipeline steps.",
|
||||
)
|
||||
|
||||
|
||||
class TableProfilingJobAck(BaseModel):
|
||||
table_id: str = Field(..., description="Echo of the table identifier.")
|
||||
version_ts: str = Field(..., description="Echo of the profiling version timestamp (yyyyMMddHHmmss).")
|
||||
status: str = Field("accepted", description="Processing acknowledgement status.")
|
||||
|
||||
832
app/services/table_profiling.py
Normal file
832
app/services/table_profiling.py
Normal file
@ -0,0 +1,832 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from datetime import date, datetime
|
||||
from dataclasses import asdict, dataclass, is_dataclass
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import httpx
|
||||
import great_expectations as gx
|
||||
from great_expectations.core.batch import RuntimeBatchRequest
|
||||
from great_expectations.core.expectation_suite import ExpectationSuite
|
||||
from great_expectations.data_context import AbstractDataContext
|
||||
from great_expectations.exceptions import DataContextError, MetricResolutionError
|
||||
|
||||
from app.exceptions import ProviderAPICallError
|
||||
from app.models import TableProfilingJobRequest
|
||||
from app.services import LLMGateway
|
||||
from app.settings import DEFAULT_IMPORT_MODEL
|
||||
from app.services.import_analysis import (
|
||||
IMPORT_GATEWAY_BASE_URL,
|
||||
resolve_provider_from_model,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
GE_REPORT_RELATIVE_PATH = Path("uncommitted") / "data_docs" / "local_site" / "index.html"
|
||||
PROMPT_FILENAMES = {
|
||||
"ge_result_desc": "ge_result_desc_prompt.md",
|
||||
"snippet_generator": "snippet_generator.md",
|
||||
"snippet_alias": "snippet_alias_generator.md",
|
||||
}
|
||||
DEFAULT_CHAT_TIMEOUT_SECONDS = 90.0
|
||||
|
||||
|
||||
@dataclass
|
||||
class GEProfilingArtifacts:
|
||||
profiling_result: Dict[str, Any]
|
||||
profiling_summary: Dict[str, Any]
|
||||
docs_path: str
|
||||
|
||||
|
||||
class PipelineActionType:
|
||||
GE_PROFILING = "ge_profiling"
|
||||
GE_RESULT_DESC = "ge_result_desc"
|
||||
SNIPPET = "snippet"
|
||||
SNIPPET_ALIAS = "snippet_alias"
|
||||
|
||||
|
||||
def _project_root() -> Path:
|
||||
return Path(__file__).resolve().parents[2]
|
||||
|
||||
|
||||
def _prompt_dir() -> Path:
|
||||
return _project_root() / "prompt"
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def _load_prompt_parts(filename: str) -> Tuple[str, str]:
|
||||
prompt_path = _prompt_dir() / filename
|
||||
if not prompt_path.exists():
|
||||
raise FileNotFoundError(f"Prompt template not found: {prompt_path}")
|
||||
|
||||
raw = prompt_path.read_text(encoding="utf-8")
|
||||
splitter = "用户消息(User)"
|
||||
if splitter not in raw:
|
||||
raise ValueError(f"Prompt template '{filename}' missing separator '{splitter}'.")
|
||||
|
||||
system_raw, user_raw = raw.split(splitter, maxsplit=1)
|
||||
system_text = system_raw.replace("系统角色(System)", "").strip()
|
||||
user_text = user_raw.strip()
|
||||
return system_text, user_text
|
||||
|
||||
|
||||
def _render_prompt(template_key: str, replacements: Dict[str, str]) -> Tuple[str, str]:
|
||||
filename = PROMPT_FILENAMES[template_key]
|
||||
system_text, user_template = _load_prompt_parts(filename)
|
||||
|
||||
rendered_user = user_template
|
||||
for key, value in replacements.items():
|
||||
rendered_user = rendered_user.replace(key, value)
|
||||
|
||||
return system_text, rendered_user
|
||||
|
||||
|
||||
def _extract_timeout_seconds(options: Optional[Dict[str, Any]]) -> Optional[float]:
|
||||
if not options:
|
||||
return None
|
||||
value = options.get("llm_timeout_seconds")
|
||||
if value is None:
|
||||
return None
|
||||
try:
|
||||
timeout = float(value)
|
||||
if timeout <= 0:
|
||||
raise ValueError
|
||||
return timeout
|
||||
except (TypeError, ValueError):
|
||||
logger.warning(
|
||||
"Invalid llm_timeout_seconds value in extra_options: %r. Falling back to default.",
|
||||
value,
|
||||
)
|
||||
return DEFAULT_CHAT_TIMEOUT_SECONDS
|
||||
|
||||
|
||||
def _extract_json_payload(content: str) -> str:
|
||||
fenced = re.search(
|
||||
r"```(?:json)?\s*([\s\S]+?)```",
|
||||
content,
|
||||
flags=re.IGNORECASE,
|
||||
)
|
||||
if fenced:
|
||||
snippet = fenced.group(1).strip()
|
||||
if snippet:
|
||||
return snippet
|
||||
|
||||
stripped = content.strip()
|
||||
if not stripped:
|
||||
raise ValueError("Empty LLM content.")
|
||||
|
||||
for opener, closer in (("{", "}"), ("[", "]")):
|
||||
start = stripped.find(opener)
|
||||
end = stripped.rfind(closer)
|
||||
if start != -1 and end != -1 and end > start:
|
||||
candidate = stripped[start : end + 1].strip()
|
||||
return candidate
|
||||
|
||||
return stripped
|
||||
|
||||
|
||||
def _parse_completion_payload(response_payload: Dict[str, Any]) -> Any:
|
||||
choices = response_payload.get("choices") or []
|
||||
if not choices:
|
||||
raise ProviderAPICallError("LLM response did not contain choices to parse.")
|
||||
message = choices[0].get("message") or {}
|
||||
content = message.get("content") or ""
|
||||
if not content.strip():
|
||||
raise ProviderAPICallError("LLM response content is empty.")
|
||||
json_payload = _extract_json_payload(content)
|
||||
try:
|
||||
return json.loads(json_payload)
|
||||
except json.JSONDecodeError as exc:
|
||||
preview = json_payload[:800]
|
||||
logger.error("Failed to parse JSON from LLM response: %s", preview, exc_info=True)
|
||||
raise ProviderAPICallError("LLM response JSON parsing failed.") from exc
|
||||
|
||||
|
||||
async def _post_callback(callback_url: str, payload: Dict[str, Any], client: httpx.AsyncClient) -> None:
|
||||
safe_payload = _normalize_for_json(payload)
|
||||
try:
|
||||
logger.info(
|
||||
"Posting pipeline action callback to %s: %s",
|
||||
callback_url,
|
||||
json.dumps(safe_payload, ensure_ascii=False),
|
||||
)
|
||||
response = await client.post(callback_url, json=safe_payload)
|
||||
response.raise_for_status()
|
||||
except httpx.HTTPError as exc:
|
||||
logger.error("Callback delivery to %s failed: %s", callback_url, exc, exc_info=True)
|
||||
|
||||
|
||||
def _sanitize_value_set(value: Any, max_values: int) -> Tuple[Any, Optional[Dict[str, int]]]:
|
||||
if not isinstance(value, list):
|
||||
return value, None
|
||||
original_len = len(value)
|
||||
if original_len <= max_values:
|
||||
return value, None
|
||||
trimmed = value[:max_values]
|
||||
return trimmed, {"original_length": original_len, "retained": max_values}
|
||||
|
||||
|
||||
def _sanitize_expectation_suite(suite: ExpectationSuite, max_value_set_values: int = 100) -> Dict[str, Any]:
|
||||
suite_dict = suite.to_json_dict()
|
||||
remarks: List[Dict[str, Any]] = []
|
||||
|
||||
for expectation in suite_dict.get("expectations", []):
|
||||
kwargs = expectation.get("kwargs", {})
|
||||
if "value_set" in kwargs:
|
||||
sanitized_value, note = _sanitize_value_set(kwargs["value_set"], max_value_set_values)
|
||||
kwargs["value_set"] = sanitized_value
|
||||
if note:
|
||||
expectation.setdefault("meta", {})
|
||||
expectation["meta"]["value_set_truncated"] = note
|
||||
remarks.append(
|
||||
{
|
||||
"column": kwargs.get("column"),
|
||||
"expectation": expectation.get("expectation_type"),
|
||||
"note": note,
|
||||
}
|
||||
)
|
||||
|
||||
if remarks:
|
||||
suite_dict.setdefault("meta", {})
|
||||
suite_dict["meta"]["value_set_truncations"] = remarks
|
||||
|
||||
return suite_dict
|
||||
|
||||
|
||||
def _summarize_expectation_suite(suite_dict: Dict[str, Any]) -> Dict[str, Any]:
|
||||
column_map: Dict[str, Dict[str, Any]] = {}
|
||||
table_expectations: List[Dict[str, Any]] = []
|
||||
|
||||
for expectation in suite_dict.get("expectations", []):
|
||||
expectation_type = expectation.get("expectation_type")
|
||||
kwargs = expectation.get("kwargs", {})
|
||||
column = kwargs.get("column")
|
||||
summary_entry: Dict[str, Any] = {"expectation": expectation_type}
|
||||
|
||||
if "value_set" in kwargs and isinstance(kwargs["value_set"], list):
|
||||
summary_entry["value_set_size"] = len(kwargs["value_set"])
|
||||
summary_entry["value_set_preview"] = kwargs["value_set"][:5]
|
||||
|
||||
if column:
|
||||
column_entry = column_map.setdefault(
|
||||
column,
|
||||
{"name": column, "expectations": []},
|
||||
)
|
||||
column_entry["expectations"].append(summary_entry)
|
||||
else:
|
||||
table_expectations.append(summary_entry)
|
||||
|
||||
summary = {
|
||||
"column_profiles": list(column_map.values()),
|
||||
"table_level_expectations": table_expectations,
|
||||
"total_expectations": len(suite_dict.get("expectations", [])),
|
||||
}
|
||||
return summary
|
||||
|
||||
|
||||
def _sanitize_identifier(raw: Optional[str], fallback: str) -> str:
|
||||
if not raw:
|
||||
return fallback
|
||||
candidate = re.sub(r"[^0-9A-Za-z_]+", "_", raw).strip("_")
|
||||
return candidate or fallback
|
||||
|
||||
|
||||
def _format_connection_string(template: str, access_info: Dict[str, Any]) -> str:
|
||||
if not access_info:
|
||||
return template
|
||||
try:
|
||||
return template.format_map({k: v for k, v in access_info.items()})
|
||||
except KeyError as exc:
|
||||
missing = exc.args[0]
|
||||
raise ValueError(f"table_access_info missing key '{missing}' required by connection_string.") from exc
|
||||
|
||||
|
||||
def _ensure_sql_runtime_datasource(
|
||||
context: AbstractDataContext,
|
||||
datasource_name: str,
|
||||
connection_string: str,
|
||||
) -> None:
|
||||
try:
|
||||
datasource = context.get_datasource(datasource_name)
|
||||
except (DataContextError, ValueError) as exc:
|
||||
message = str(exc)
|
||||
if "Could not find a datasource" in message or "Unable to load datasource" in message:
|
||||
datasource = None
|
||||
else: # pragma: no cover - defensive
|
||||
raise RuntimeError(f"Failed to inspect datasource '{datasource_name}'.") from exc
|
||||
except Exception as exc: # pragma: no cover - defensive
|
||||
raise RuntimeError(f"Failed to inspect datasource '{datasource_name}'.") from exc
|
||||
|
||||
if datasource is not None:
|
||||
execution_engine = getattr(datasource, "execution_engine", None)
|
||||
current_conn = getattr(execution_engine, "connection_string", None)
|
||||
if current_conn and current_conn != connection_string:
|
||||
logger.info(
|
||||
"Existing datasource %s uses different connection string; creating dedicated runtime datasource.",
|
||||
datasource_name,
|
||||
)
|
||||
try:
|
||||
context.delete_datasource(datasource_name)
|
||||
except Exception as exc: # pragma: no cover - defensive
|
||||
logger.warning(
|
||||
"Failed to delete datasource %s before recreation: %s",
|
||||
datasource_name,
|
||||
exc,
|
||||
)
|
||||
else:
|
||||
datasource = None
|
||||
|
||||
if datasource is not None:
|
||||
return
|
||||
|
||||
runtime_datasource_config = {
|
||||
"name": datasource_name,
|
||||
"class_name": "Datasource",
|
||||
"execution_engine": {
|
||||
"class_name": "SqlAlchemyExecutionEngine",
|
||||
"connection_string": connection_string,
|
||||
},
|
||||
"data_connectors": {
|
||||
"runtime_connector": {
|
||||
"class_name": "RuntimeDataConnector",
|
||||
"batch_identifiers": ["default_identifier_name"],
|
||||
}
|
||||
},
|
||||
}
|
||||
try:
|
||||
context.add_datasource(**runtime_datasource_config)
|
||||
except Exception as exc: # pragma: no cover - defensive
|
||||
raise RuntimeError(f"Failed to create runtime datasource '{datasource_name}'.") from exc
|
||||
|
||||
|
||||
def _build_sql_runtime_batch_request(
|
||||
context: AbstractDataContext,
|
||||
request: TableProfilingJobRequest,
|
||||
) -> RuntimeBatchRequest:
|
||||
link_info = request.table_link_info or {}
|
||||
access_info = request.table_access_info or {}
|
||||
|
||||
connection_template = link_info.get("connection_string")
|
||||
if not connection_template:
|
||||
raise ValueError("table_link_info.connection_string is required when using table_link_info.")
|
||||
|
||||
connection_string = _format_connection_string(connection_template, access_info)
|
||||
|
||||
source_type = (link_info.get("type") or "sql").lower()
|
||||
if source_type != "sql":
|
||||
raise ValueError(f"Unsupported table_link_info.type='{source_type}'. Only 'sql' is supported.")
|
||||
|
||||
query = link_info.get("query")
|
||||
table_name = link_info.get("table") or link_info.get("table_name")
|
||||
schema_name = link_info.get("schema")
|
||||
|
||||
if not query and not table_name:
|
||||
raise ValueError("Either table_link_info.query or table_link_info.table must be provided.")
|
||||
|
||||
if not query:
|
||||
if not table_name:
|
||||
raise ValueError("table_link_info.table must be provided when query is omitted.")
|
||||
|
||||
identifier = re.compile(r"^[A-Za-z_][A-Za-z0-9_$]*$")
|
||||
|
||||
def _quote(name: str) -> str:
|
||||
if identifier.match(name):
|
||||
return name
|
||||
return f"`{name.replace('`', '``')}`"
|
||||
|
||||
if schema_name:
|
||||
schema_part = schema_name if "." not in schema_name else schema_name.split(".")[-1]
|
||||
table_part = table_name if "." not in table_name else table_name.split(".")[-1]
|
||||
qualified_table = f"{_quote(schema_part)}.{_quote(table_part)}"
|
||||
else:
|
||||
qualified_table = _quote(table_name)
|
||||
|
||||
query = f"SELECT * FROM {qualified_table}"
|
||||
limit = link_info.get("limit")
|
||||
if isinstance(limit, int) and limit > 0:
|
||||
query = f"{query} LIMIT {limit}"
|
||||
|
||||
datasource_name = request.ge_datasource_name or _sanitize_identifier(
|
||||
f"{request.table_id}_runtime_ds", "runtime_ds"
|
||||
)
|
||||
data_asset_name = request.ge_data_asset_name or _sanitize_identifier(
|
||||
table_name or "runtime_query", "runtime_query"
|
||||
)
|
||||
|
||||
_ensure_sql_runtime_datasource(context, datasource_name, connection_string)
|
||||
|
||||
batch_identifiers = {
|
||||
"default_identifier_name": f"{request.table_id}:{request.version_ts}",
|
||||
}
|
||||
|
||||
return RuntimeBatchRequest(
|
||||
datasource_name=datasource_name,
|
||||
data_connector_name="runtime_connector",
|
||||
data_asset_name=data_asset_name,
|
||||
runtime_parameters={"query": query},
|
||||
batch_identifiers=batch_identifiers,
|
||||
)
|
||||
|
||||
|
||||
def _run_onboarding_assistant(
|
||||
context: AbstractDataContext,
|
||||
batch_request: Any,
|
||||
suite_name: str,
|
||||
) -> Tuple[ExpectationSuite, Any]:
|
||||
assistant = context.assistants.onboarding
|
||||
assistant_result = assistant.run(batch_request=batch_request)
|
||||
suite = assistant_result.get_expectation_suite(expectation_suite_name=suite_name)
|
||||
context.save_expectation_suite(suite, expectation_suite_name=suite_name)
|
||||
validation_getter = getattr(assistant_result, "get_validation_result", None)
|
||||
if callable(validation_getter):
|
||||
validation_result = validation_getter()
|
||||
else:
|
||||
validation_result = getattr(assistant_result, "validation_result", None)
|
||||
if validation_result is None:
|
||||
# Fallback: rerun validation using the freshly generated expectation suite.
|
||||
validator = context.get_validator(
|
||||
batch_request=batch_request,
|
||||
expectation_suite_name=suite_name,
|
||||
)
|
||||
validation_result = validator.validate()
|
||||
return suite, validation_result
|
||||
|
||||
|
||||
def _resolve_context(request: TableProfilingJobRequest) -> AbstractDataContext:
|
||||
context_kwargs: Dict[str, Any] = {}
|
||||
if request.ge_data_context_root:
|
||||
context_kwargs["project_root_dir"] = request.ge_data_context_root
|
||||
elif os.environ.get("GE_DATA_CONTEXT_ROOT"):
|
||||
context_kwargs["project_root_dir"] = os.environ["GE_DATA_CONTEXT_ROOT"]
|
||||
else:
|
||||
context_kwargs["project_root_dir"] = str(_project_root())
|
||||
|
||||
return gx.get_context(**context_kwargs)
|
||||
|
||||
|
||||
def _build_batch_request(
|
||||
context: AbstractDataContext,
|
||||
request: TableProfilingJobRequest,
|
||||
) -> Any:
|
||||
if request.ge_batch_request:
|
||||
from great_expectations.core.batch import BatchRequest
|
||||
|
||||
return BatchRequest(**request.ge_batch_request)
|
||||
|
||||
if request.table_link_info:
|
||||
return _build_sql_runtime_batch_request(context, request)
|
||||
|
||||
if not request.ge_datasource_name or not request.ge_data_asset_name:
|
||||
raise ValueError(
|
||||
"ge_batch_request or (ge_datasource_name and ge_data_asset_name) must be provided."
|
||||
)
|
||||
|
||||
datasource = context.get_datasource(request.ge_datasource_name)
|
||||
data_asset = datasource.get_asset(request.ge_data_asset_name)
|
||||
return data_asset.build_batch_request()
|
||||
|
||||
|
||||
async def _run_ge_profiling(request: TableProfilingJobRequest) -> GEProfilingArtifacts:
|
||||
def _execute() -> GEProfilingArtifacts:
|
||||
context = _resolve_context(request)
|
||||
suite_name = (
|
||||
request.ge_expectation_suite_name
|
||||
or f"{request.table_id}_profiling"
|
||||
)
|
||||
|
||||
batch_request = _build_batch_request(context, request)
|
||||
try:
|
||||
context.get_expectation_suite(suite_name)
|
||||
except DataContextError:
|
||||
context.add_expectation_suite(suite_name)
|
||||
|
||||
validator = context.get_validator(
|
||||
batch_request=batch_request,
|
||||
expectation_suite_name=suite_name,
|
||||
)
|
||||
|
||||
profiler_type = (request.ge_profiler_type or "user_configurable").lower()
|
||||
|
||||
if profiler_type == "data_assistant":
|
||||
suite, validation_result = _run_onboarding_assistant(
|
||||
context,
|
||||
batch_request,
|
||||
suite_name,
|
||||
)
|
||||
else:
|
||||
try:
|
||||
from great_expectations.profile.user_configurable_profiler import (
|
||||
UserConfigurableProfiler,
|
||||
)
|
||||
except ImportError as err: # pragma: no cover - dependency guard
|
||||
raise RuntimeError(
|
||||
"UserConfigurableProfiler is unavailable; install great_expectations profiling extra or switch profiler."
|
||||
) from err
|
||||
|
||||
profiler = UserConfigurableProfiler(profile_dataset=validator)
|
||||
try:
|
||||
suite = profiler.build_suite()
|
||||
context.save_expectation_suite(suite, expectation_suite_name=suite_name)
|
||||
validator.expectation_suite = suite
|
||||
validation_result = validator.validate()
|
||||
except MetricResolutionError as exc:
|
||||
logger.warning(
|
||||
"UserConfigurableProfiler failed (%s); falling back to data assistant profiling.",
|
||||
exc,
|
||||
)
|
||||
suite, validation_result = _run_onboarding_assistant(
|
||||
context,
|
||||
batch_request,
|
||||
suite_name,
|
||||
)
|
||||
|
||||
sanitized_suite = _sanitize_expectation_suite(suite)
|
||||
summary = _summarize_expectation_suite(sanitized_suite)
|
||||
validation_dict = validation_result.to_json_dict()
|
||||
|
||||
context.build_data_docs()
|
||||
docs_path = Path(context.root_directory) / GE_REPORT_RELATIVE_PATH
|
||||
|
||||
profiling_result = {
|
||||
"expectation_suite": sanitized_suite,
|
||||
"validation_result": validation_dict,
|
||||
"batch_request": getattr(batch_request, "to_json_dict", lambda: None)() or getattr(batch_request, "dict", lambda: None)(),
|
||||
}
|
||||
|
||||
return GEProfilingArtifacts(
|
||||
profiling_result=profiling_result,
|
||||
profiling_summary=summary,
|
||||
docs_path=str(docs_path),
|
||||
)
|
||||
|
||||
return await asyncio.to_thread(_execute)
|
||||
|
||||
|
||||
async def _call_chat_completions(
|
||||
*,
|
||||
model_spec: str,
|
||||
system_prompt: str,
|
||||
user_prompt: str,
|
||||
client: httpx.AsyncClient,
|
||||
temperature: float = 0.2,
|
||||
timeout_seconds: Optional[float] = None,
|
||||
) -> Any:
|
||||
provider, model_name = resolve_provider_from_model(model_spec)
|
||||
payload = {
|
||||
"provider": provider.value,
|
||||
"model": model_name,
|
||||
"messages": [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt},
|
||||
],
|
||||
"temperature": temperature,
|
||||
}
|
||||
payload_size_bytes = len(json.dumps(payload, ensure_ascii=False).encode("utf-8"))
|
||||
|
||||
url = f"{IMPORT_GATEWAY_BASE_URL.rstrip('/')}/v1/chat/completions"
|
||||
try:
|
||||
# log the request whole info
|
||||
logger.info(
|
||||
"Calling chat completions API %s with model %s and size %s and payload %s",
|
||||
url,
|
||||
model_name,
|
||||
payload_size_bytes,
|
||||
payload,
|
||||
)
|
||||
response = await client.post(url, json=payload, timeout=timeout_seconds)
|
||||
|
||||
response.raise_for_status()
|
||||
except httpx.HTTPError as exc:
|
||||
error_name = exc.__class__.__name__
|
||||
detail = str(exc).strip()
|
||||
if detail:
|
||||
message = f"Chat completions request failed ({error_name}): {detail}"
|
||||
else:
|
||||
message = f"Chat completions request failed ({error_name})."
|
||||
raise ProviderAPICallError(message) from exc
|
||||
|
||||
try:
|
||||
response_payload = response.json()
|
||||
except ValueError as exc:
|
||||
raise ProviderAPICallError("Chat completions response was not valid JSON.") from exc
|
||||
|
||||
return _parse_completion_payload(response_payload)
|
||||
|
||||
|
||||
def _normalize_for_json(value: Any) -> Any:
|
||||
if value is None or isinstance(value, (str, int, float, bool)):
|
||||
return value
|
||||
if isinstance(value, (datetime, date)):
|
||||
return str(value)
|
||||
if hasattr(value, "model_dump"):
|
||||
try:
|
||||
return value.model_dump()
|
||||
except Exception: # pragma: no cover - defensive
|
||||
pass
|
||||
if is_dataclass(value):
|
||||
return asdict(value)
|
||||
if isinstance(value, dict):
|
||||
return {k: _normalize_for_json(v) for k, v in value.items()}
|
||||
if isinstance(value, (list, tuple, set)):
|
||||
return [_normalize_for_json(v) for v in value]
|
||||
if hasattr(value, "to_json_dict"):
|
||||
try:
|
||||
return value.to_json_dict()
|
||||
except Exception: # pragma: no cover - defensive
|
||||
pass
|
||||
if hasattr(value, "__dict__"):
|
||||
return _normalize_for_json(value.__dict__)
|
||||
return repr(value)
|
||||
|
||||
|
||||
def _json_dumps(data: Any) -> str:
|
||||
normalised = _normalize_for_json(data)
|
||||
return json.dumps(normalised, ensure_ascii=False, indent=2)
|
||||
|
||||
|
||||
def _preview_for_log(data: Any) -> str:
|
||||
try:
|
||||
serialised = _json_dumps(data)
|
||||
except Exception:
|
||||
serialised = repr(data)
|
||||
|
||||
return serialised
|
||||
|
||||
|
||||
def _profiling_request_for_log(request: TableProfilingJobRequest) -> Dict[str, Any]:
|
||||
payload = request.model_dump()
|
||||
access_info = payload.get("table_access_info")
|
||||
if isinstance(access_info, dict):
|
||||
payload["table_access_info"] = {key: "***" for key in access_info.keys()}
|
||||
return payload
|
||||
|
||||
|
||||
async def _execute_result_desc(
|
||||
profiling_json: Dict[str, Any],
|
||||
_request: TableProfilingJobRequest,
|
||||
llm_model: str,
|
||||
client: httpx.AsyncClient,
|
||||
timeout_seconds: Optional[float],
|
||||
) -> Dict[str, Any]:
|
||||
system_prompt, user_prompt = _render_prompt(
|
||||
"ge_result_desc",
|
||||
{"{{GE_RESULT_JSON}}": _json_dumps(profiling_json)},
|
||||
)
|
||||
llm_output = await _call_chat_completions(
|
||||
model_spec=llm_model,
|
||||
system_prompt=system_prompt,
|
||||
user_prompt=user_prompt,
|
||||
client=client,
|
||||
timeout_seconds=timeout_seconds,
|
||||
)
|
||||
if not isinstance(llm_output, dict):
|
||||
raise ProviderAPICallError("GE result description payload must be a JSON object.")
|
||||
return llm_output
|
||||
|
||||
|
||||
async def _execute_snippet_generation(
|
||||
table_desc_json: Dict[str, Any],
|
||||
_request: TableProfilingJobRequest,
|
||||
llm_model: str,
|
||||
client: httpx.AsyncClient,
|
||||
timeout_seconds: Optional[float],
|
||||
) -> List[Dict[str, Any]]:
|
||||
system_prompt, user_prompt = _render_prompt(
|
||||
"snippet_generator",
|
||||
{"{{TABLE_PROFILE_JSON}}": _json_dumps(table_desc_json)},
|
||||
)
|
||||
llm_output = await _call_chat_completions(
|
||||
model_spec=llm_model,
|
||||
system_prompt=system_prompt,
|
||||
user_prompt=user_prompt,
|
||||
client=client,
|
||||
timeout_seconds=timeout_seconds,
|
||||
)
|
||||
if not isinstance(llm_output, list):
|
||||
raise ProviderAPICallError("Snippet generator must return a JSON array.")
|
||||
return llm_output
|
||||
|
||||
|
||||
async def _execute_snippet_alias(
|
||||
snippets_json: List[Dict[str, Any]],
|
||||
_request: TableProfilingJobRequest,
|
||||
llm_model: str,
|
||||
client: httpx.AsyncClient,
|
||||
timeout_seconds: Optional[float],
|
||||
) -> List[Dict[str, Any]]:
|
||||
system_prompt, user_prompt = _render_prompt(
|
||||
"snippet_alias",
|
||||
{"{{SNIPPET_ARRAY}}": _json_dumps(snippets_json)},
|
||||
)
|
||||
llm_output = await _call_chat_completions(
|
||||
model_spec=llm_model,
|
||||
system_prompt=system_prompt,
|
||||
user_prompt=user_prompt,
|
||||
client=client,
|
||||
timeout_seconds=timeout_seconds,
|
||||
)
|
||||
if not isinstance(llm_output, list):
|
||||
raise ProviderAPICallError("Snippet alias generator must return a JSON array.")
|
||||
return llm_output
|
||||
|
||||
|
||||
async def _run_action_with_callback(
|
||||
*,
|
||||
action_type: str,
|
||||
runner,
|
||||
callback_base: Dict[str, Any],
|
||||
client: httpx.AsyncClient,
|
||||
callback_url: str,
|
||||
input_payload: Any = None,
|
||||
model_spec: Optional[str] = None,
|
||||
) -> Any:
|
||||
if input_payload is not None:
|
||||
logger.info(
|
||||
"Pipeline action %s input: %s",
|
||||
action_type,
|
||||
_preview_for_log(input_payload),
|
||||
)
|
||||
try:
|
||||
result = await runner()
|
||||
except Exception as exc:
|
||||
failure_payload = dict(callback_base)
|
||||
failure_payload.update(
|
||||
{
|
||||
"status": "failed",
|
||||
"action_type": action_type,
|
||||
"error": str(exc),
|
||||
}
|
||||
)
|
||||
if model_spec is not None:
|
||||
failure_payload["model"] = model_spec
|
||||
await _post_callback(callback_url, failure_payload, client)
|
||||
raise
|
||||
|
||||
success_payload = dict(callback_base)
|
||||
success_payload.update(
|
||||
{
|
||||
"status": "success",
|
||||
"action_type": action_type,
|
||||
}
|
||||
)
|
||||
if model_spec is not None:
|
||||
success_payload["model"] = model_spec
|
||||
|
||||
logger.info(
|
||||
"Pipeline action %s output: %s",
|
||||
action_type,
|
||||
_preview_for_log(result),
|
||||
)
|
||||
|
||||
if action_type == PipelineActionType.GE_PROFILING:
|
||||
artifacts: GEProfilingArtifacts = result
|
||||
success_payload["profiling_json"] = artifacts.profiling_result
|
||||
success_payload["profiling_summary"] = artifacts.profiling_summary
|
||||
success_payload["ge_report_path"] = artifacts.docs_path
|
||||
elif action_type == PipelineActionType.GE_RESULT_DESC:
|
||||
success_payload["table_desc_json"] = result
|
||||
elif action_type == PipelineActionType.SNIPPET:
|
||||
success_payload["snippet_json"] = result
|
||||
elif action_type == PipelineActionType.SNIPPET_ALIAS:
|
||||
success_payload["snippet_alias_json"] = result
|
||||
|
||||
await _post_callback(callback_url, success_payload, client)
|
||||
return result
|
||||
|
||||
|
||||
async def process_table_profiling_job(
|
||||
request: TableProfilingJobRequest,
|
||||
_gateway: LLMGateway,
|
||||
client: httpx.AsyncClient,
|
||||
) -> None:
|
||||
"""Sequentially execute the four-step profiling pipeline and emit callbacks per action."""
|
||||
|
||||
timeout_seconds = _extract_timeout_seconds(request.extra_options)
|
||||
if timeout_seconds is None:
|
||||
timeout_seconds = DEFAULT_CHAT_TIMEOUT_SECONDS
|
||||
|
||||
base_payload = {
|
||||
"table_id": request.table_id,
|
||||
"version_ts": request.version_ts,
|
||||
"callback_url": str(request.callback_url),
|
||||
"table_schema": request.table_schema,
|
||||
"table_schema_version_id": request.table_schema_version_id,
|
||||
"llm_model": request.llm_model,
|
||||
"llm_timeout_seconds": timeout_seconds,
|
||||
}
|
||||
|
||||
logging_request_payload = _profiling_request_for_log(request)
|
||||
|
||||
try:
|
||||
artifacts: GEProfilingArtifacts = await _run_action_with_callback(
|
||||
action_type=PipelineActionType.GE_PROFILING,
|
||||
runner=lambda: _run_ge_profiling(request),
|
||||
callback_base=base_payload,
|
||||
client=client,
|
||||
callback_url=str(request.callback_url),
|
||||
input_payload=logging_request_payload,
|
||||
model_spec=request.llm_model,
|
||||
)
|
||||
|
||||
table_desc_json: Dict[str, Any] = await _run_action_with_callback(
|
||||
action_type=PipelineActionType.GE_RESULT_DESC,
|
||||
runner=lambda: _execute_result_desc(
|
||||
artifacts.profiling_result,
|
||||
request,
|
||||
request.llm_model,
|
||||
client,
|
||||
timeout_seconds,
|
||||
),
|
||||
callback_base=base_payload,
|
||||
client=client,
|
||||
callback_url=str(request.callback_url),
|
||||
input_payload=artifacts.profiling_result,
|
||||
model_spec=request.llm_model,
|
||||
)
|
||||
|
||||
snippet_json: List[Dict[str, Any]] = await _run_action_with_callback(
|
||||
action_type=PipelineActionType.SNIPPET,
|
||||
runner=lambda: _execute_snippet_generation(
|
||||
table_desc_json,
|
||||
request,
|
||||
request.llm_model,
|
||||
client,
|
||||
timeout_seconds,
|
||||
),
|
||||
callback_base=base_payload,
|
||||
client=client,
|
||||
callback_url=str(request.callback_url),
|
||||
input_payload=table_desc_json,
|
||||
model_spec=request.llm_model,
|
||||
)
|
||||
|
||||
await _run_action_with_callback(
|
||||
action_type=PipelineActionType.SNIPPET_ALIAS,
|
||||
runner=lambda: _execute_snippet_alias(
|
||||
snippet_json,
|
||||
request,
|
||||
request.llm_model,
|
||||
client,
|
||||
timeout_seconds,
|
||||
),
|
||||
callback_base=base_payload,
|
||||
client=client,
|
||||
callback_url=str(request.callback_url),
|
||||
input_payload=snippet_json,
|
||||
model_spec=request.llm_model,
|
||||
)
|
||||
except Exception: # pragma: no cover - defensive catch
|
||||
logger.exception(
|
||||
"Table profiling pipeline failed for table_id=%s version_ts=%s",
|
||||
request.table_id,
|
||||
request.version_ts,
|
||||
)
|
||||
184
app/services/table_snippet.py
Normal file
184
app/services/table_snippet.py
Normal file
@ -0,0 +1,184 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Dict, Tuple
|
||||
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
from app.db import get_engine
|
||||
from app.models import (
|
||||
ActionType,
|
||||
TableSnippetUpsertRequest,
|
||||
TableSnippetUpsertResponse,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _serialize_json(value: Any) -> Tuple[str | None, int | None]:
|
||||
logger.debug("Serializing JSON payload: %s", value)
|
||||
if value is None:
|
||||
return None, None
|
||||
if isinstance(value, str):
|
||||
encoded = value.encode("utf-8")
|
||||
return value, len(encoded)
|
||||
serialized = json.dumps(value, ensure_ascii=False)
|
||||
encoded = serialized.encode("utf-8")
|
||||
return serialized, len(encoded)
|
||||
|
||||
|
||||
def _prepare_table_schema(value: Any) -> str:
|
||||
logger.debug("Preparing table_schema payload.")
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
return json.dumps(value, ensure_ascii=False)
|
||||
|
||||
|
||||
def _collect_common_columns(request: TableSnippetUpsertRequest) -> Dict[str, Any]:
|
||||
logger.debug(
|
||||
"Collecting common columns for table_id=%s version_ts=%s action_type=%s",
|
||||
request.table_id,
|
||||
request.version_ts,
|
||||
request.action_type,
|
||||
)
|
||||
payload: Dict[str, Any] = {
|
||||
"table_id": request.table_id,
|
||||
"version_ts": request.version_ts,
|
||||
"action_type": request.action_type.value,
|
||||
"status": request.status.value,
|
||||
"callback_url": str(request.callback_url),
|
||||
"table_schema_version_id": request.table_schema_version_id,
|
||||
"table_schema": _prepare_table_schema(request.table_schema),
|
||||
}
|
||||
|
||||
if request.error_code is not None:
|
||||
logger.debug("Adding error_code: %s", request.error_code)
|
||||
payload["error_code"] = request.error_code
|
||||
if request.error_message is not None:
|
||||
logger.debug("Adding error_message: %s", request.error_message)
|
||||
payload["error_message"] = request.error_message
|
||||
if request.started_at is not None:
|
||||
payload["started_at"] = request.started_at
|
||||
if request.finished_at is not None:
|
||||
payload["finished_at"] = request.finished_at
|
||||
if request.duration_ms is not None:
|
||||
payload["duration_ms"] = request.duration_ms
|
||||
if request.result_checksum is not None:
|
||||
payload["result_checksum"] = request.result_checksum
|
||||
|
||||
logger.debug("Collected common payload: %s", payload)
|
||||
return payload
|
||||
|
||||
|
||||
def _apply_action_payload(
|
||||
request: TableSnippetUpsertRequest,
|
||||
payload: Dict[str, Any],
|
||||
) -> None:
|
||||
logger.debug("Applying action-specific payload for action_type=%s", request.action_type)
|
||||
if request.action_type == ActionType.GE_PROFILING:
|
||||
full_json, full_size = _serialize_json(request.result_json)
|
||||
summary_json, summary_size = _serialize_json(request.result_summary_json)
|
||||
if full_json is not None:
|
||||
payload["ge_profiling_full"] = full_json
|
||||
payload["ge_profiling_full_size_bytes"] = full_size
|
||||
if summary_json is not None:
|
||||
payload["ge_profiling_summary"] = summary_json
|
||||
payload["ge_profiling_summary_size_bytes"] = summary_size
|
||||
if full_size is not None or summary_size is not None:
|
||||
payload["ge_profiling_total_size_bytes"] = (full_size or 0) + (
|
||||
summary_size or 0
|
||||
)
|
||||
if request.html_report_url:
|
||||
payload["ge_profiling_html_report_url"] = request.html_report_url
|
||||
elif request.action_type == ActionType.GE_RESULT_DESC:
|
||||
full_json, full_size = _serialize_json(request.result_json)
|
||||
if full_json is not None:
|
||||
payload["ge_result_desc_full"] = full_json
|
||||
payload["ge_result_desc_full_size_bytes"] = full_size
|
||||
elif request.action_type == ActionType.SNIPPET:
|
||||
full_json, full_size = _serialize_json(request.result_json)
|
||||
if full_json is not None:
|
||||
payload["snippet_full"] = full_json
|
||||
payload["snippet_full_size_bytes"] = full_size
|
||||
elif request.action_type == ActionType.SNIPPET_ALIAS:
|
||||
full_json, full_size = _serialize_json(request.result_json)
|
||||
if full_json is not None:
|
||||
payload["snippet_alias_full"] = full_json
|
||||
payload["snippet_alias_full_size_bytes"] = full_size
|
||||
else:
|
||||
logger.error("Unsupported action type encountered: %s", request.action_type)
|
||||
raise ValueError(f"Unsupported action type '{request.action_type}'.")
|
||||
|
||||
logger.debug("Payload after applying action-specific data: %s", payload)
|
||||
|
||||
|
||||
def _build_insert_statement(columns: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]:
|
||||
logger.debug("Building insert statement for columns: %s", list(columns.keys()))
|
||||
column_names = list(columns.keys())
|
||||
placeholders = [f":{name}" for name in column_names]
|
||||
update_assignments = [
|
||||
f"{name}=VALUES({name})"
|
||||
for name in column_names
|
||||
if name not in {"table_id", "version_ts", "action_type"}
|
||||
]
|
||||
update_assignments.append("updated_at=CURRENT_TIMESTAMP")
|
||||
|
||||
sql = (
|
||||
"INSERT INTO action_results ({cols}) VALUES ({vals}) "
|
||||
"ON DUPLICATE KEY UPDATE {updates}"
|
||||
).format(
|
||||
cols=", ".join(column_names),
|
||||
vals=", ".join(placeholders),
|
||||
updates=", ".join(update_assignments),
|
||||
)
|
||||
logger.debug("Generated SQL: %s", sql)
|
||||
return sql, columns
|
||||
|
||||
|
||||
def _execute_upsert(engine: Engine, sql: str, params: Dict[str, Any]) -> int:
|
||||
logger.info("Executing upsert for table_id=%s version_ts=%s action_type=%s", params.get("table_id"), params.get("version_ts"), params.get("action_type"))
|
||||
with engine.begin() as conn:
|
||||
result = conn.execute(text(sql), params)
|
||||
logger.info("Rows affected: %s", result.rowcount)
|
||||
return result.rowcount
|
||||
|
||||
|
||||
def upsert_action_result(request: TableSnippetUpsertRequest) -> TableSnippetUpsertResponse:
|
||||
logger.info(
|
||||
"Received upsert request: table_id=%s version_ts=%s action_type=%s status=%s",
|
||||
request.table_id,
|
||||
request.version_ts,
|
||||
request.action_type,
|
||||
request.status,
|
||||
)
|
||||
logger.debug("Request payload: %s", request.model_dump())
|
||||
columns = _collect_common_columns(request)
|
||||
_apply_action_payload(request, columns)
|
||||
|
||||
sql, params = _build_insert_statement(columns)
|
||||
logger.debug("Final SQL params: %s", params)
|
||||
|
||||
engine = get_engine()
|
||||
try:
|
||||
rowcount = _execute_upsert(engine, sql, params)
|
||||
except SQLAlchemyError as exc:
|
||||
logger.exception(
|
||||
"Failed to upsert action result: table_id=%s version_ts=%s action_type=%s",
|
||||
request.table_id,
|
||||
request.version_ts,
|
||||
request.action_type,
|
||||
)
|
||||
raise RuntimeError(f"Database operation failed: {exc}") from exc
|
||||
|
||||
updated = rowcount > 1
|
||||
return TableSnippetUpsertResponse(
|
||||
table_id=request.table_id,
|
||||
version_ts=request.version_ts,
|
||||
action_type=request.action_type,
|
||||
status=request.status,
|
||||
updated=updated,
|
||||
)
|
||||
54
table_snippet.sql
Normal file
54
table_snippet.sql
Normal file
@ -0,0 +1,54 @@
|
||||
CREATE TABLE IF NOT EXISTS action_results (
|
||||
id BIGINT NOT NULL AUTO_INCREMENT COMMENT '主键',
|
||||
table_id BIGINT NOT NULL COMMENT '表ID',
|
||||
version_ts BIGINT NOT NULL COMMENT '版本时间戳(版本号)',
|
||||
action_type ENUM('ge_profiling','ge_result_desc','snippet','snippet_alias') NOT NULL COMMENT '动作类型',
|
||||
|
||||
status ENUM('pending','running','success','failed','partial') NOT NULL DEFAULT 'pending' COMMENT '执行状态',
|
||||
error_code VARCHAR(128) NULL,
|
||||
error_message TEXT NULL,
|
||||
|
||||
-- 回调 & 观测
|
||||
callback_url VARCHAR(1024) NOT NULL,
|
||||
started_at DATETIME NULL,
|
||||
finished_at DATETIME NULL,
|
||||
duration_ms INT NULL,
|
||||
|
||||
-- 本次schema信息
|
||||
table_schema_version_id BIGINT NOT NULL,
|
||||
table_schema JSON NOT NULL,
|
||||
|
||||
-- ===== 动作1:GE Profiling =====
|
||||
ge_profiling_full JSON NULL COMMENT 'Profiling完整结果JSON',
|
||||
ge_profiling_full_size_bytes BIGINT NULL,
|
||||
ge_profiling_summary JSON NULL COMMENT 'Profiling摘要(剔除大value_set等)',
|
||||
ge_profiling_summary_size_bytes BIGINT NULL,
|
||||
ge_profiling_total_size_bytes BIGINT NULL COMMENT '上两者合计',
|
||||
ge_profiling_html_report_url VARCHAR(1024) NULL COMMENT 'GE报告HTML路径/URL',
|
||||
|
||||
-- ===== 动作2:GE Result Desc =====
|
||||
ge_result_desc_full JSON NULL COMMENT '表描述结果JSON',
|
||||
ge_result_desc_full_size_bytes BIGINT NULL,
|
||||
|
||||
-- ===== 动作3:Snippet 生成 =====
|
||||
snippet_full JSON NULL COMMENT 'SQL知识片段结果JSON',
|
||||
snippet_full_size_bytes BIGINT NULL,
|
||||
|
||||
-- ===== 动作4:Snippet Alias 改写 =====
|
||||
snippet_alias_full JSON NULL COMMENT 'SQL片段改写/丰富结果JSON',
|
||||
snippet_alias_full_size_bytes BIGINT NULL,
|
||||
|
||||
-- 通用可选指标
|
||||
result_checksum VARBINARY(32) NULL COMMENT '对当前action有效载荷计算的MD5/xxhash',
|
||||
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
|
||||
|
||||
PRIMARY KEY (id),
|
||||
UNIQUE KEY uq_table_ver_action (table_id, version_ts, action_type),
|
||||
KEY idx_status (status),
|
||||
KEY idx_table (table_id, updated_at),
|
||||
KEY idx_action_time (action_type, version_ts),
|
||||
KEY idx_schema_version (table_schema_version_id)
|
||||
) ENGINE=InnoDB
|
||||
ROW_FORMAT=DYNAMIC
|
||||
COMMENT='数据分析知识片段表';
|
||||
Reference in New Issue
Block a user