Table profiling feature development

zhaoawd
2025-11-03 00:18:26 +08:00
parent 557efc4bf1
commit c2a08e4134
6 changed files with 1280 additions and 16 deletions


@@ -0,0 +1,832 @@
from __future__ import annotations
import asyncio
import json
import logging
import os
import re
from datetime import date, datetime
from dataclasses import asdict, dataclass, is_dataclass
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
import httpx
import great_expectations as gx
from great_expectations.core.batch import RuntimeBatchRequest
from great_expectations.core.expectation_suite import ExpectationSuite
from great_expectations.data_context import AbstractDataContext
from great_expectations.exceptions import DataContextError, MetricResolutionError
from app.exceptions import ProviderAPICallError
from app.models import TableProfilingJobRequest
from app.services import LLMGateway
from app.settings import DEFAULT_IMPORT_MODEL
from app.services.import_analysis import (
IMPORT_GATEWAY_BASE_URL,
resolve_provider_from_model,
)
logger = logging.getLogger(__name__)
GE_REPORT_RELATIVE_PATH = Path("uncommitted") / "data_docs" / "local_site" / "index.html"
PROMPT_FILENAMES = {
"ge_result_desc": "ge_result_desc_prompt.md",
"snippet_generator": "snippet_generator.md",
"snippet_alias": "snippet_alias_generator.md",
}
DEFAULT_CHAT_TIMEOUT_SECONDS = 90.0
@dataclass
class GEProfilingArtifacts:
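    """Outputs of the Great Expectations profiling step: the full profiling
    result, a compact summary, and the path to the generated Data Docs page."""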
profiling_result: Dict[str, Any]
profiling_summary: Dict[str, Any]
docs_path: str
class PipelineActionType:
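    """String constants naming the four pipeline actions reported in callbacks."""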
GE_PROFILING = "ge_profiling"
GE_RESULT_DESC = "ge_result_desc"
SNIPPET = "snippet"
SNIPPET_ALIAS = "snippet_alias"
def _project_root() -> Path:
return Path(__file__).resolve().parents[2]
def _prompt_dir() -> Path:
return _project_root() / "prompt"
@lru_cache(maxsize=None)
def _load_prompt_parts(filename: str) -> Tuple[str, str]:
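    """Load a prompt template and split it into (system, user) parts.

    Templates are expected to contain a "系统角色System" header followed by a
    "用户消息User" separator; parsed results are cached per filename.
    """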
prompt_path = _prompt_dir() / filename
if not prompt_path.exists():
raise FileNotFoundError(f"Prompt template not found: {prompt_path}")
raw = prompt_path.read_text(encoding="utf-8")
splitter = "用户消息User"
if splitter not in raw:
raise ValueError(f"Prompt template '{filename}' missing separator '{splitter}'.")
system_raw, user_raw = raw.split(splitter, maxsplit=1)
system_text = system_raw.replace("系统角色System", "").strip()
user_text = user_raw.strip()
return system_text, user_text
def _render_prompt(template_key: str, replacements: Dict[str, str]) -> Tuple[str, str]:
filename = PROMPT_FILENAMES[template_key]
system_text, user_template = _load_prompt_parts(filename)
rendered_user = user_template
for key, value in replacements.items():
rendered_user = rendered_user.replace(key, value)
return system_text, rendered_user
def _extract_timeout_seconds(options: Optional[Dict[str, Any]]) -> Optional[float]:
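    """Read llm_timeout_seconds from extra_options.

    Returns None when the option is absent (the caller applies the default),
    the parsed value when it is a positive number, and
    DEFAULT_CHAT_TIMEOUT_SECONDS when the value is present but invalid.
    """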
if not options:
return None
value = options.get("llm_timeout_seconds")
if value is None:
return None
try:
timeout = float(value)
if timeout <= 0:
raise ValueError
return timeout
except (TypeError, ValueError):
logger.warning(
"Invalid llm_timeout_seconds value in extra_options: %r. Falling back to default.",
value,
)
return DEFAULT_CHAT_TIMEOUT_SECONDS
def _extract_json_payload(content: str) -> str:
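    """Extract a JSON payload from LLM output: prefer a fenced code block,
    then the outermost {...} or [...] span, falling back to the stripped text."""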
fenced = re.search(
r"```(?:json)?\s*([\s\S]+?)```",
content,
flags=re.IGNORECASE,
)
if fenced:
snippet = fenced.group(1).strip()
if snippet:
return snippet
stripped = content.strip()
if not stripped:
raise ValueError("Empty LLM content.")
for opener, closer in (("{", "}"), ("[", "]")):
start = stripped.find(opener)
end = stripped.rfind(closer)
if start != -1 and end != -1 and end > start:
candidate = stripped[start : end + 1].strip()
return candidate
return stripped
def _parse_completion_payload(response_payload: Dict[str, Any]) -> Any:
choices = response_payload.get("choices") or []
if not choices:
raise ProviderAPICallError("LLM response did not contain choices to parse.")
message = choices[0].get("message") or {}
content = message.get("content") or ""
if not content.strip():
raise ProviderAPICallError("LLM response content is empty.")
json_payload = _extract_json_payload(content)
try:
return json.loads(json_payload)
except json.JSONDecodeError as exc:
preview = json_payload[:800]
logger.error("Failed to parse JSON from LLM response: %s", preview, exc_info=True)
raise ProviderAPICallError("LLM response JSON parsing failed.") from exc
async def _post_callback(callback_url: str, payload: Dict[str, Any], client: httpx.AsyncClient) -> None:
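    """POST an action callback. Delivery failures are logged but never raised,
    so an unreachable callback endpoint cannot abort the pipeline."""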
safe_payload = _normalize_for_json(payload)
try:
logger.info(
"Posting pipeline action callback to %s: %s",
callback_url,
json.dumps(safe_payload, ensure_ascii=False),
)
response = await client.post(callback_url, json=safe_payload)
response.raise_for_status()
except httpx.HTTPError as exc:
logger.error("Callback delivery to %s failed: %s", callback_url, exc, exc_info=True)
def _sanitize_value_set(value: Any, max_values: int) -> Tuple[Any, Optional[Dict[str, int]]]:
if not isinstance(value, list):
return value, None
original_len = len(value)
if original_len <= max_values:
return value, None
trimmed = value[:max_values]
return trimmed, {"original_length": original_len, "retained": max_values}
def _sanitize_expectation_suite(suite: ExpectationSuite, max_value_set_values: int = 100) -> Dict[str, Any]:
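    """Serialize an expectation suite, trimming oversized value_set kwargs and
    recording each truncation in the expectation and suite meta blocks."""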
suite_dict = suite.to_json_dict()
remarks: List[Dict[str, Any]] = []
for expectation in suite_dict.get("expectations", []):
kwargs = expectation.get("kwargs", {})
if "value_set" in kwargs:
sanitized_value, note = _sanitize_value_set(kwargs["value_set"], max_value_set_values)
kwargs["value_set"] = sanitized_value
if note:
expectation.setdefault("meta", {})
expectation["meta"]["value_set_truncated"] = note
remarks.append(
{
"column": kwargs.get("column"),
"expectation": expectation.get("expectation_type"),
"note": note,
}
)
if remarks:
suite_dict.setdefault("meta", {})
suite_dict["meta"]["value_set_truncations"] = remarks
return suite_dict
def _summarize_expectation_suite(suite_dict: Dict[str, Any]) -> Dict[str, Any]:
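    """Condense a suite dict into per-column and table-level expectation
    summaries, previewing at most five value_set entries per expectation."""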
column_map: Dict[str, Dict[str, Any]] = {}
table_expectations: List[Dict[str, Any]] = []
for expectation in suite_dict.get("expectations", []):
expectation_type = expectation.get("expectation_type")
kwargs = expectation.get("kwargs", {})
column = kwargs.get("column")
summary_entry: Dict[str, Any] = {"expectation": expectation_type}
if "value_set" in kwargs and isinstance(kwargs["value_set"], list):
summary_entry["value_set_size"] = len(kwargs["value_set"])
summary_entry["value_set_preview"] = kwargs["value_set"][:5]
if column:
column_entry = column_map.setdefault(
column,
{"name": column, "expectations": []},
)
column_entry["expectations"].append(summary_entry)
else:
table_expectations.append(summary_entry)
summary = {
"column_profiles": list(column_map.values()),
"table_level_expectations": table_expectations,
"total_expectations": len(suite_dict.get("expectations", [])),
}
return summary
def _sanitize_identifier(raw: Optional[str], fallback: str) -> str:
if not raw:
return fallback
candidate = re.sub(r"[^0-9A-Za-z_]+", "_", raw).strip("_")
return candidate or fallback
def _format_connection_string(template: str, access_info: Dict[str, Any]) -> str:
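    """Fill {placeholder} fields in the connection string template from
    table_access_info; for example (illustrative only), a template such as
    "mysql+pymysql://{user}:{password}@{host}:{port}/{db}" would be completed
    from matching keys. Raises ValueError when a required key is missing."""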
if not access_info:
return template
try:
        return template.format_map(access_info)
except KeyError as exc:
missing = exc.args[0]
raise ValueError(f"table_access_info missing key '{missing}' required by connection_string.") from exc
def _ensure_sql_runtime_datasource(
context: AbstractDataContext,
datasource_name: str,
connection_string: str,
) -> None:
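    """Ensure a runtime SQL datasource with the given connection string exists.

    An existing datasource pointing at a different connection string is deleted
    and recreated; if deletion fails, the stale datasource is reused and a
    warning is logged.
    """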
try:
datasource = context.get_datasource(datasource_name)
except (DataContextError, ValueError) as exc:
message = str(exc)
if "Could not find a datasource" in message or "Unable to load datasource" in message:
datasource = None
else: # pragma: no cover - defensive
raise RuntimeError(f"Failed to inspect datasource '{datasource_name}'.") from exc
except Exception as exc: # pragma: no cover - defensive
raise RuntimeError(f"Failed to inspect datasource '{datasource_name}'.") from exc
if datasource is not None:
execution_engine = getattr(datasource, "execution_engine", None)
current_conn = getattr(execution_engine, "connection_string", None)
if current_conn and current_conn != connection_string:
logger.info(
"Existing datasource %s uses different connection string; creating dedicated runtime datasource.",
datasource_name,
)
try:
context.delete_datasource(datasource_name)
except Exception as exc: # pragma: no cover - defensive
logger.warning(
"Failed to delete datasource %s before recreation: %s",
datasource_name,
exc,
)
else:
datasource = None
if datasource is not None:
return
runtime_datasource_config = {
"name": datasource_name,
"class_name": "Datasource",
"execution_engine": {
"class_name": "SqlAlchemyExecutionEngine",
"connection_string": connection_string,
},
"data_connectors": {
"runtime_connector": {
"class_name": "RuntimeDataConnector",
"batch_identifiers": ["default_identifier_name"],
}
},
}
try:
context.add_datasource(**runtime_datasource_config)
except Exception as exc: # pragma: no cover - defensive
raise RuntimeError(f"Failed to create runtime datasource '{datasource_name}'.") from exc
def _build_sql_runtime_batch_request(
context: AbstractDataContext,
request: TableProfilingJobRequest,
) -> RuntimeBatchRequest:
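    """Build a RuntimeBatchRequest from table_link_info/table_access_info.

    When no explicit query is given, a "SELECT * FROM <table>" statement is
    generated with backtick identifier quoting and an optional LIMIT clause.
    Only type='sql' sources are supported.
    """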
link_info = request.table_link_info or {}
access_info = request.table_access_info or {}
connection_template = link_info.get("connection_string")
if not connection_template:
raise ValueError("table_link_info.connection_string is required when using table_link_info.")
connection_string = _format_connection_string(connection_template, access_info)
source_type = (link_info.get("type") or "sql").lower()
if source_type != "sql":
raise ValueError(f"Unsupported table_link_info.type='{source_type}'. Only 'sql' is supported.")
query = link_info.get("query")
table_name = link_info.get("table") or link_info.get("table_name")
schema_name = link_info.get("schema")
if not query and not table_name:
raise ValueError("Either table_link_info.query or table_link_info.table must be provided.")
if not query:
if not table_name:
raise ValueError("table_link_info.table must be provided when query is omitted.")
identifier = re.compile(r"^[A-Za-z_][A-Za-z0-9_$]*$")
def _quote(name: str) -> str:
if identifier.match(name):
return name
return f"`{name.replace('`', '``')}`"
if schema_name:
schema_part = schema_name if "." not in schema_name else schema_name.split(".")[-1]
table_part = table_name if "." not in table_name else table_name.split(".")[-1]
qualified_table = f"{_quote(schema_part)}.{_quote(table_part)}"
else:
qualified_table = _quote(table_name)
query = f"SELECT * FROM {qualified_table}"
limit = link_info.get("limit")
if isinstance(limit, int) and limit > 0:
query = f"{query} LIMIT {limit}"
datasource_name = request.ge_datasource_name or _sanitize_identifier(
f"{request.table_id}_runtime_ds", "runtime_ds"
)
data_asset_name = request.ge_data_asset_name or _sanitize_identifier(
table_name or "runtime_query", "runtime_query"
)
_ensure_sql_runtime_datasource(context, datasource_name, connection_string)
batch_identifiers = {
"default_identifier_name": f"{request.table_id}:{request.version_ts}",
}
return RuntimeBatchRequest(
datasource_name=datasource_name,
data_connector_name="runtime_connector",
data_asset_name=data_asset_name,
runtime_parameters={"query": query},
batch_identifiers=batch_identifiers,
)
def _run_onboarding_assistant(
context: AbstractDataContext,
batch_request: Any,
suite_name: str,
) -> Tuple[ExpectationSuite, Any]:
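    """Run the GE onboarding data assistant, persist the resulting suite, and
    return it with a validation result (re-validating against the new suite if
    the assistant result exposes none)."""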
assistant = context.assistants.onboarding
assistant_result = assistant.run(batch_request=batch_request)
suite = assistant_result.get_expectation_suite(expectation_suite_name=suite_name)
context.save_expectation_suite(suite, expectation_suite_name=suite_name)
validation_getter = getattr(assistant_result, "get_validation_result", None)
if callable(validation_getter):
validation_result = validation_getter()
else:
validation_result = getattr(assistant_result, "validation_result", None)
if validation_result is None:
# Fallback: rerun validation using the freshly generated expectation suite.
validator = context.get_validator(
batch_request=batch_request,
expectation_suite_name=suite_name,
)
validation_result = validator.validate()
return suite, validation_result
def _resolve_context(request: TableProfilingJobRequest) -> AbstractDataContext:
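    """Resolve the GE data context root: the request field first, then the
    GE_DATA_CONTEXT_ROOT environment variable, then the project root."""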
context_kwargs: Dict[str, Any] = {}
if request.ge_data_context_root:
context_kwargs["project_root_dir"] = request.ge_data_context_root
elif os.environ.get("GE_DATA_CONTEXT_ROOT"):
context_kwargs["project_root_dir"] = os.environ["GE_DATA_CONTEXT_ROOT"]
else:
context_kwargs["project_root_dir"] = str(_project_root())
return gx.get_context(**context_kwargs)
def _build_batch_request(
context: AbstractDataContext,
request: TableProfilingJobRequest,
) -> Any:
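    """Build a batch request by precedence: an explicit ge_batch_request, then
    table_link_info (runtime SQL), then a named datasource/data asset pair."""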
if request.ge_batch_request:
from great_expectations.core.batch import BatchRequest
return BatchRequest(**request.ge_batch_request)
if request.table_link_info:
return _build_sql_runtime_batch_request(context, request)
if not request.ge_datasource_name or not request.ge_data_asset_name:
raise ValueError(
"ge_batch_request or (ge_datasource_name and ge_data_asset_name) must be provided."
)
datasource = context.get_datasource(request.ge_datasource_name)
data_asset = datasource.get_asset(request.ge_data_asset_name)
return data_asset.build_batch_request()
async def _run_ge_profiling(request: TableProfilingJobRequest) -> GEProfilingArtifacts:
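    """Profile the table with Great Expectations in a worker thread.

    Uses UserConfigurableProfiler by default and falls back to the onboarding
    data assistant when the profiler is unavailable or metric resolution fails.
    """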
def _execute() -> GEProfilingArtifacts:
context = _resolve_context(request)
suite_name = (
request.ge_expectation_suite_name
or f"{request.table_id}_profiling"
)
batch_request = _build_batch_request(context, request)
try:
context.get_expectation_suite(suite_name)
except DataContextError:
context.add_expectation_suite(suite_name)
validator = context.get_validator(
batch_request=batch_request,
expectation_suite_name=suite_name,
)
profiler_type = (request.ge_profiler_type or "user_configurable").lower()
if profiler_type == "data_assistant":
suite, validation_result = _run_onboarding_assistant(
context,
batch_request,
suite_name,
)
else:
try:
from great_expectations.profile.user_configurable_profiler import (
UserConfigurableProfiler,
)
except ImportError as err: # pragma: no cover - dependency guard
raise RuntimeError(
"UserConfigurableProfiler is unavailable; install great_expectations profiling extra or switch profiler."
) from err
profiler = UserConfigurableProfiler(profile_dataset=validator)
try:
suite = profiler.build_suite()
context.save_expectation_suite(suite, expectation_suite_name=suite_name)
validator.expectation_suite = suite
validation_result = validator.validate()
except MetricResolutionError as exc:
logger.warning(
"UserConfigurableProfiler failed (%s); falling back to data assistant profiling.",
exc,
)
suite, validation_result = _run_onboarding_assistant(
context,
batch_request,
suite_name,
)
sanitized_suite = _sanitize_expectation_suite(suite)
summary = _summarize_expectation_suite(sanitized_suite)
validation_dict = validation_result.to_json_dict()
context.build_data_docs()
docs_path = Path(context.root_directory) / GE_REPORT_RELATIVE_PATH
profiling_result = {
"expectation_suite": sanitized_suite,
"validation_result": validation_dict,
"batch_request": getattr(batch_request, "to_json_dict", lambda: None)() or getattr(batch_request, "dict", lambda: None)(),
}
return GEProfilingArtifacts(
profiling_result=profiling_result,
profiling_summary=summary,
docs_path=str(docs_path),
)
return await asyncio.to_thread(_execute)
async def _call_chat_completions(
*,
model_spec: str,
system_prompt: str,
user_prompt: str,
client: httpx.AsyncClient,
temperature: float = 0.2,
timeout_seconds: Optional[float] = None,
) -> Any:
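    """Call the gateway's /v1/chat/completions endpoint and return the parsed
    JSON payload extracted from the first choice's message content."""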
provider, model_name = resolve_provider_from_model(model_spec)
payload = {
"provider": provider.value,
"model": model_name,
"messages": [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
],
"temperature": temperature,
}
payload_size_bytes = len(json.dumps(payload, ensure_ascii=False).encode("utf-8"))
url = f"{IMPORT_GATEWAY_BASE_URL.rstrip('/')}/v1/chat/completions"
try:
        # Log the full request (URL, model, payload size, payload) for traceability.
logger.info(
"Calling chat completions API %s with model %s and size %s and payload %s",
url,
model_name,
payload_size_bytes,
payload,
)
response = await client.post(url, json=payload, timeout=timeout_seconds)
response.raise_for_status()
except httpx.HTTPError as exc:
error_name = exc.__class__.__name__
detail = str(exc).strip()
if detail:
message = f"Chat completions request failed ({error_name}): {detail}"
else:
message = f"Chat completions request failed ({error_name})."
raise ProviderAPICallError(message) from exc
try:
response_payload = response.json()
except ValueError as exc:
raise ProviderAPICallError("Chat completions response was not valid JSON.") from exc
return _parse_completion_payload(response_payload)
def _normalize_for_json(value: Any) -> Any:
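    """Best-effort conversion to JSON-safe primitives: dates to strings, then
    pydantic model_dump, dataclasses, mappings and sequences (recursively),
    to_json_dict, __dict__, and finally repr as a last resort."""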
if value is None or isinstance(value, (str, int, float, bool)):
return value
if isinstance(value, (datetime, date)):
return str(value)
if hasattr(value, "model_dump"):
try:
return value.model_dump()
except Exception: # pragma: no cover - defensive
pass
if is_dataclass(value):
return asdict(value)
if isinstance(value, dict):
return {k: _normalize_for_json(v) for k, v in value.items()}
if isinstance(value, (list, tuple, set)):
return [_normalize_for_json(v) for v in value]
if hasattr(value, "to_json_dict"):
try:
return value.to_json_dict()
except Exception: # pragma: no cover - defensive
pass
if hasattr(value, "__dict__"):
return _normalize_for_json(value.__dict__)
return repr(value)
def _json_dumps(data: Any) -> str:
normalised = _normalize_for_json(data)
return json.dumps(normalised, ensure_ascii=False, indent=2)
def _preview_for_log(data: Any) -> str:
try:
serialised = _json_dumps(data)
except Exception:
serialised = repr(data)
return serialised
def _profiling_request_for_log(request: TableProfilingJobRequest) -> Dict[str, Any]:
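    """Serialize the request for logging with table_access_info values masked."""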
payload = request.model_dump()
access_info = payload.get("table_access_info")
if isinstance(access_info, dict):
payload["table_access_info"] = {key: "***" for key in access_info.keys()}
return payload
async def _execute_result_desc(
profiling_json: Dict[str, Any],
_request: TableProfilingJobRequest,
llm_model: str,
client: httpx.AsyncClient,
timeout_seconds: Optional[float],
) -> Dict[str, Any]:
system_prompt, user_prompt = _render_prompt(
"ge_result_desc",
{"{{GE_RESULT_JSON}}": _json_dumps(profiling_json)},
)
llm_output = await _call_chat_completions(
model_spec=llm_model,
system_prompt=system_prompt,
user_prompt=user_prompt,
client=client,
timeout_seconds=timeout_seconds,
)
if not isinstance(llm_output, dict):
raise ProviderAPICallError("GE result description payload must be a JSON object.")
return llm_output
async def _execute_snippet_generation(
table_desc_json: Dict[str, Any],
_request: TableProfilingJobRequest,
llm_model: str,
client: httpx.AsyncClient,
timeout_seconds: Optional[float],
) -> List[Dict[str, Any]]:
system_prompt, user_prompt = _render_prompt(
"snippet_generator",
{"{{TABLE_PROFILE_JSON}}": _json_dumps(table_desc_json)},
)
llm_output = await _call_chat_completions(
model_spec=llm_model,
system_prompt=system_prompt,
user_prompt=user_prompt,
client=client,
timeout_seconds=timeout_seconds,
)
if not isinstance(llm_output, list):
raise ProviderAPICallError("Snippet generator must return a JSON array.")
return llm_output
async def _execute_snippet_alias(
snippets_json: List[Dict[str, Any]],
_request: TableProfilingJobRequest,
llm_model: str,
client: httpx.AsyncClient,
timeout_seconds: Optional[float],
) -> List[Dict[str, Any]]:
system_prompt, user_prompt = _render_prompt(
"snippet_alias",
{"{{SNIPPET_ARRAY}}": _json_dumps(snippets_json)},
)
llm_output = await _call_chat_completions(
model_spec=llm_model,
system_prompt=system_prompt,
user_prompt=user_prompt,
client=client,
timeout_seconds=timeout_seconds,
)
if not isinstance(llm_output, list):
raise ProviderAPICallError("Snippet alias generator must return a JSON array.")
return llm_output
async def _run_action_with_callback(
*,
action_type: str,
runner,
callback_base: Dict[str, Any],
client: httpx.AsyncClient,
callback_url: str,
input_payload: Any = None,
model_spec: Optional[str] = None,
) -> Any:
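    """Run one pipeline action, always posting a callback.

    On failure a "failed" callback carrying the error is posted and the
    exception re-raised; on success the result is attached under its
    action-specific key (profiling_json, table_desc_json, snippet_json, or
    snippet_alias_json).
    """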
if input_payload is not None:
logger.info(
"Pipeline action %s input: %s",
action_type,
_preview_for_log(input_payload),
)
try:
result = await runner()
except Exception as exc:
failure_payload = dict(callback_base)
failure_payload.update(
{
"status": "failed",
"action_type": action_type,
"error": str(exc),
}
)
if model_spec is not None:
failure_payload["model"] = model_spec
await _post_callback(callback_url, failure_payload, client)
raise
success_payload = dict(callback_base)
success_payload.update(
{
"status": "success",
"action_type": action_type,
}
)
if model_spec is not None:
success_payload["model"] = model_spec
logger.info(
"Pipeline action %s output: %s",
action_type,
_preview_for_log(result),
)
if action_type == PipelineActionType.GE_PROFILING:
artifacts: GEProfilingArtifacts = result
success_payload["profiling_json"] = artifacts.profiling_result
success_payload["profiling_summary"] = artifacts.profiling_summary
success_payload["ge_report_path"] = artifacts.docs_path
elif action_type == PipelineActionType.GE_RESULT_DESC:
success_payload["table_desc_json"] = result
elif action_type == PipelineActionType.SNIPPET:
success_payload["snippet_json"] = result
elif action_type == PipelineActionType.SNIPPET_ALIAS:
success_payload["snippet_alias_json"] = result
await _post_callback(callback_url, success_payload, client)
return result
async def process_table_profiling_job(
request: TableProfilingJobRequest,
_gateway: LLMGateway,
client: httpx.AsyncClient,
) -> None:
"""Sequentially execute the four-step profiling pipeline and emit callbacks per action."""
timeout_seconds = _extract_timeout_seconds(request.extra_options)
if timeout_seconds is None:
timeout_seconds = DEFAULT_CHAT_TIMEOUT_SECONDS
base_payload = {
"table_id": request.table_id,
"version_ts": request.version_ts,
"callback_url": str(request.callback_url),
"table_schema": request.table_schema,
"table_schema_version_id": request.table_schema_version_id,
"llm_model": request.llm_model,
"llm_timeout_seconds": timeout_seconds,
}
logging_request_payload = _profiling_request_for_log(request)
try:
artifacts: GEProfilingArtifacts = await _run_action_with_callback(
action_type=PipelineActionType.GE_PROFILING,
runner=lambda: _run_ge_profiling(request),
callback_base=base_payload,
client=client,
callback_url=str(request.callback_url),
input_payload=logging_request_payload,
model_spec=request.llm_model,
)
table_desc_json: Dict[str, Any] = await _run_action_with_callback(
action_type=PipelineActionType.GE_RESULT_DESC,
runner=lambda: _execute_result_desc(
artifacts.profiling_result,
request,
request.llm_model,
client,
timeout_seconds,
),
callback_base=base_payload,
client=client,
callback_url=str(request.callback_url),
input_payload=artifacts.profiling_result,
model_spec=request.llm_model,
)
snippet_json: List[Dict[str, Any]] = await _run_action_with_callback(
action_type=PipelineActionType.SNIPPET,
runner=lambda: _execute_snippet_generation(
table_desc_json,
request,
request.llm_model,
client,
timeout_seconds,
),
callback_base=base_payload,
client=client,
callback_url=str(request.callback_url),
input_payload=table_desc_json,
model_spec=request.llm_model,
)
await _run_action_with_callback(
action_type=PipelineActionType.SNIPPET_ALIAS,
runner=lambda: _execute_snippet_alias(
snippet_json,
request,
request.llm_model,
client,
timeout_seconds,
),
callback_base=base_payload,
client=client,
callback_url=str(request.callback_url),
input_payload=snippet_json,
model_spec=request.llm_model,
)
except Exception: # pragma: no cover - defensive catch
logger.exception(
"Table profiling pipeline failed for table_id=%s version_ts=%s",
request.table_id,
request.version_ts,
)