860 lines
29 KiB
Python
860 lines
29 KiB
Python
from __future__ import annotations
|
||
|
||
import asyncio
|
||
import json
|
||
import logging
|
||
import os
|
||
import re
|
||
from datetime import date, datetime
|
||
from dataclasses import asdict, dataclass, is_dataclass
|
||
from functools import lru_cache
|
||
from pathlib import Path
|
||
from typing import Any, Dict, List, Optional, Tuple
|
||
|
||
import httpx
|
||
import great_expectations as gx
|
||
from great_expectations.core.batch import RuntimeBatchRequest
|
||
from great_expectations.core.expectation_suite import ExpectationSuite
|
||
from great_expectations.data_context import AbstractDataContext
|
||
from great_expectations.exceptions import DataContextError, MetricResolutionError
|
||
|
||
from app.exceptions import ProviderAPICallError
|
||
from app.models import TableProfilingJobRequest
|
||
from app.services import LLMGateway
|
||
from app.settings import DEFAULT_IMPORT_MODEL
|
||
from app.services.import_analysis import (
|
||
IMPORT_GATEWAY_BASE_URL,
|
||
build_import_gateway_headers,
|
||
resolve_provider_from_model,
|
||
)
|
||
from app.utils.llm_usage import extract_usage as extract_llm_usage
|
||
|
||
|
||
logger = logging.getLogger(__name__)


# Great Expectations data-docs landing page, relative to the data-context root.
GE_REPORT_RELATIVE_PATH = Path("uncommitted", "data_docs", "local_site", "index.html")

# Logical prompt key -> template file name inside the prompt directory.
PROMPT_FILENAMES = dict(
    ge_result_desc="ge_result_desc_prompt.md",
    snippet_generator="snippet_generator.md",
    snippet_alias="snippet_alias_generator.md",
)

# Chat-completion timeout (seconds) used when no valid llm_timeout_seconds is supplied.
DEFAULT_CHAT_TIMEOUT_SECONDS = 180.0
|
||
|
||
|
||
@dataclass
class GEProfilingArtifacts:
    """Outputs of one Great Expectations profiling run.

    Built by ``_run_ge_profiling`` and unpacked into the GE_PROFILING callback.
    """

    # Sanitised expectation suite, validation result and serialised batch request.
    profiling_result: Dict[str, Any]
    # Compact per-column / table-level digest of the expectation suite.
    profiling_summary: Dict[str, Any]
    # Path of the generated data-docs index page, as a string.
    docs_path: str
|
||
|
||
|
||
@dataclass
class LLMCallResult:
    """Decoded chat-completion output paired with the response's usage block."""

    # JSON value decoded from the model's message content.
    data: Any
    # Usage metadata extracted from the raw response; defaults to None.
    usage: Optional[Dict[str, Any]] = None
|
||
|
||
|
||
class PipelineActionType:
    """String identifiers of the pipeline steps, echoed as ``action_type`` in callbacks."""

    # Step 1: Great Expectations profiling of the source data.
    GE_PROFILING = "ge_profiling"
    # Step 2: LLM description of the GE profiling result.
    GE_RESULT_DESC = "ge_result_desc"
    # Step 3: LLM snippet generation from the table description.
    SNIPPET = "snippet"
    # Step 4: LLM alias generation for the snippets.
    SNIPPET_ALIAS = "snippet_alias"
|
||
|
||
|
||
def _project_root() -> Path:
|
||
return Path(__file__).resolve().parents[2]
|
||
|
||
|
||
def _prompt_dir() -> Path:
    """Return the ``prompt`` directory under the project root (template location)."""
    root = _project_root()
    return root.joinpath("prompt")
|
||
|
||
|
||
@lru_cache(maxsize=None)
def _load_prompt_parts(filename: str) -> Tuple[str, str]:
    """Load a prompt template and split it into ``(system_text, user_text)``.

    The file must contain the user-section marker ``用户消息(User)``. Text before
    the marker, with the ``系统角色(System)`` heading removed, is the system prompt.
    Text after it is the user prompt template. Both parts are stripped.
    Results are memoised per filename.

    Raises:
        FileNotFoundError: if the template file does not exist.
        ValueError: if the user-section marker is missing.
    """
    template_path = _prompt_dir() / filename
    if not template_path.exists():
        raise FileNotFoundError(f"Prompt template not found: {template_path}")

    text = template_path.read_text(encoding="utf-8")
    marker = "用户消息(User)"
    head, found, tail = text.partition(marker)
    if not found:
        raise ValueError(f"Prompt template '{filename}' missing separator '{marker}'.")

    return head.replace("系统角色(System)", "").strip(), tail.strip()
|
||
|
||
|
||
def _render_prompt(template_key: str, replacements: Dict[str, str]) -> Tuple[str, str]:
    """Return ``(system_prompt, user_prompt)`` for *template_key* with placeholders filled.

    Each key of *replacements* is substituted literally (``str.replace``) in the
    user section, in dict iteration order. The system section is returned as-is.
    An unknown *template_key* raises KeyError.
    """
    system_text, user_text = _load_prompt_parts(PROMPT_FILENAMES[template_key])
    for placeholder, substitution in replacements.items():
        user_text = user_text.replace(placeholder, substitution)
    return system_text, user_text
|
||
|
||
|
||
def _extract_timeout_seconds(options: Optional[Dict[str, Any]]) -> Optional[float]:
|
||
if not options:
|
||
return None
|
||
value = options.get("llm_timeout_seconds")
|
||
if value is None:
|
||
return None
|
||
try:
|
||
timeout = float(value)
|
||
if timeout <= 0:
|
||
raise ValueError
|
||
return timeout
|
||
except (TypeError, ValueError):
|
||
logger.warning(
|
||
"Invalid llm_timeout_seconds value in extra_options: %r. Falling back to default.",
|
||
value,
|
||
)
|
||
return DEFAULT_CHAT_TIMEOUT_SECONDS
|
||
|
||
|
||
def _extract_json_payload(content: str) -> str:
|
||
fenced = re.search(
|
||
r"```(?:json)?\s*([\s\S]+?)```",
|
||
content,
|
||
flags=re.IGNORECASE,
|
||
)
|
||
if fenced:
|
||
snippet = fenced.group(1).strip()
|
||
if snippet:
|
||
return snippet
|
||
|
||
stripped = content.strip()
|
||
if not stripped:
|
||
raise ValueError("Empty LLM content.")
|
||
|
||
decoder = json.JSONDecoder()
|
||
for idx, char in enumerate(stripped):
|
||
if char not in {"{", "["}:
|
||
continue
|
||
try:
|
||
_, end = decoder.raw_decode(stripped[idx:])
|
||
except json.JSONDecodeError:
|
||
continue
|
||
candidate = stripped[idx : idx + end].strip()
|
||
if candidate:
|
||
return candidate
|
||
|
||
return stripped
|
||
|
||
|
||
def _parse_completion_payload(response_payload: Dict[str, Any]) -> Any:
    """Decode the JSON value carried in the first choice of a chat-completion body.

    Expects an OpenAI-style body ``{"choices": [{"message": {"content": ...}}]}``.
    ``content`` may be a string or a list of content parts (strings, or dicts
    with a string ``text`` field); parts are concatenated in order.

    Raises:
        ProviderAPICallError: when the choices/message/content have an unexpected
            shape, the content is empty, or the embedded JSON cannot be parsed.
    """
    choices = response_payload.get("choices") or []
    if not isinstance(choices, list) or not choices:
        raise ProviderAPICallError("LLM response did not contain choices to parse.")

    first_choice = choices[0]
    if not isinstance(first_choice, dict):
        raise ProviderAPICallError("LLM response choice has an unexpected shape.")
    message = first_choice.get("message") or {}
    if not isinstance(message, dict):
        raise ProviderAPICallError("LLM response message has an unexpected shape.")

    content = message.get("content") or ""
    if isinstance(content, list):
        # Multi-part content: keep only the textual parts.
        text_parts = []
        for part in content:
            if isinstance(part, str):
                text_parts.append(part)
            elif isinstance(part, dict) and isinstance(part.get("text"), str):
                text_parts.append(part["text"])
        content = "".join(text_parts)
    if not isinstance(content, str):
        raise ProviderAPICallError("LLM response content has an unexpected type.")
    if not content.strip():
        raise ProviderAPICallError("LLM response content is empty.")

    json_payload = _extract_json_payload(content)
    try:
        return json.loads(json_payload)
    except json.JSONDecodeError as exc:
        preview = json_payload[:800]
        logger.error("Failed to parse JSON from LLM response: %s", preview, exc_info=True)
        raise ProviderAPICallError("LLM response JSON parsing failed.") from exc
|
||
|
||
|
||
async def _post_callback(callback_url: str, payload: Dict[str, Any], client: httpx.AsyncClient) -> None:
    """POST *payload* (JSON-normalised) to *callback_url* on a best-effort basis.

    Delivery problems are logged and swallowed. A failing callback therefore
    never aborts the pipeline or masks the exception that triggered a failure
    callback. This covers:
    - HTTP/transport errors (``httpx.HTTPError``);
    - payload serialisation errors (TypeError/ValueError), e.g. an object json
      cannot encode, or non-finite floats that the HTTP client's JSON encoder
      may reject.
    """
    try:
        safe_payload = _normalize_for_json(payload)
        serialized = json.dumps(safe_payload, ensure_ascii=False)
    except (TypeError, ValueError) as exc:
        logger.error(
            "Callback payload for %s could not be serialised: %s", callback_url, exc, exc_info=True
        )
        return

    logger.info("Posting pipeline action callback to %s: %s", callback_url, serialized)
    try:
        response = await client.post(callback_url, json=safe_payload)
        response.raise_for_status()
    except httpx.HTTPError as exc:
        logger.error("Callback delivery to %s failed: %s", callback_url, exc, exc_info=True)
    except (TypeError, ValueError) as exc:
        # Raised while the client encodes the JSON body.
        logger.error(
            "Callback payload for %s could not be encoded: %s", callback_url, exc, exc_info=True
        )
|
||
|
||
|
||
def _sanitize_value_set(value: Any, max_values: int) -> Tuple[Any, Optional[Dict[str, int]]]:
|
||
if not isinstance(value, list):
|
||
return value, None
|
||
original_len = len(value)
|
||
if original_len <= max_values:
|
||
return value, None
|
||
trimmed = value[:max_values]
|
||
return trimmed, {"original_length": original_len, "retained": max_values}
|
||
|
||
|
||
def _sanitize_expectation_suite(suite: ExpectationSuite, max_value_set_values: int = 100) -> Dict[str, Any]:
    """Serialise *suite* to a dict, truncating oversized ``value_set`` kwargs.

    Each truncated expectation gets ``meta["value_set_truncated"]`` holding the
    note from ``_sanitize_value_set``. If anything was truncated, the suite-level
    ``meta["value_set_truncations"]`` lists (column, expectation type, note).
    The dict from ``suite.to_json_dict()`` is modified in place and returned.
    """
    suite_dict = suite.to_json_dict()
    truncations: List[Dict[str, Any]] = []

    for exp in suite_dict.get("expectations", []):
        exp_kwargs = exp.get("kwargs", {})
        if "value_set" not in exp_kwargs:
            continue
        trimmed, note = _sanitize_value_set(exp_kwargs["value_set"], max_value_set_values)
        exp_kwargs["value_set"] = trimmed
        if not note:
            continue
        exp.setdefault("meta", {})["value_set_truncated"] = note
        truncations.append(
            {
                "column": exp_kwargs.get("column"),
                "expectation": exp.get("expectation_type"),
                "note": note,
            }
        )

    if truncations:
        suite_dict.setdefault("meta", {})["value_set_truncations"] = truncations
    return suite_dict
|
||
|
||
|
||
def _summarize_expectation_suite(suite_dict: Dict[str, Any]) -> Dict[str, Any]:
|
||
column_map: Dict[str, Dict[str, Any]] = {}
|
||
table_expectations: List[Dict[str, Any]] = []
|
||
|
||
for expectation in suite_dict.get("expectations", []):
|
||
expectation_type = expectation.get("expectation_type")
|
||
kwargs = expectation.get("kwargs", {})
|
||
column = kwargs.get("column")
|
||
summary_entry: Dict[str, Any] = {"expectation": expectation_type}
|
||
|
||
if "value_set" in kwargs and isinstance(kwargs["value_set"], list):
|
||
summary_entry["value_set_size"] = len(kwargs["value_set"])
|
||
summary_entry["value_set_preview"] = kwargs["value_set"][:5]
|
||
|
||
if column:
|
||
column_entry = column_map.setdefault(
|
||
column,
|
||
{"name": column, "expectations": []},
|
||
)
|
||
column_entry["expectations"].append(summary_entry)
|
||
else:
|
||
table_expectations.append(summary_entry)
|
||
|
||
summary = {
|
||
"column_profiles": list(column_map.values()),
|
||
"table_level_expectations": table_expectations,
|
||
"total_expectations": len(suite_dict.get("expectations", [])),
|
||
}
|
||
return summary
|
||
|
||
|
||
def _sanitize_identifier(raw: Optional[str], fallback: str) -> str:
|
||
if not raw:
|
||
return fallback
|
||
candidate = re.sub(r"[^0-9A-Za-z_]+", "_", raw).strip("_")
|
||
return candidate or fallback
|
||
|
||
|
||
def _format_connection_string(template: str, access_info: Dict[str, Any]) -> str:
|
||
if not access_info:
|
||
return template
|
||
try:
|
||
return template.format_map({k: v for k, v in access_info.items()})
|
||
except KeyError as exc:
|
||
missing = exc.args[0]
|
||
raise ValueError(f"table_access_info missing key '{missing}' required by connection_string.") from exc
|
||
|
||
|
||
def _ensure_sql_runtime_datasource(
    context: AbstractDataContext,
    datasource_name: str,
    connection_string: str,
) -> None:
    """Make sure *context* has a SQL runtime datasource named *datasource_name*.

    An existing datasource is reused when its execution engine reports the same
    connection string, or none at all. If it reports a different one, the
    datasource is deleted and recreated. A missing datasource is created with a
    SqlAlchemyExecutionEngine and a RuntimeDataConnector named ``runtime_connector``.

    Raises:
        RuntimeError: if the datasource cannot be inspected or created.
    """
    try:
        datasource = context.get_datasource(datasource_name)
    except (DataContextError, ValueError) as exc:
        message = str(exc)
        # A missing datasource is recognised by its error text.
        if "Could not find a datasource" in message or "Unable to load datasource" in message:
            datasource = None
        else:  # pragma: no cover - defensive
            raise RuntimeError(f"Failed to inspect datasource '{datasource_name}'.") from exc
    except Exception as exc:  # pragma: no cover - defensive
        raise RuntimeError(f"Failed to inspect datasource '{datasource_name}'.") from exc

    if datasource is not None:
        execution_engine = getattr(datasource, "execution_engine", None)
        current_conn = getattr(execution_engine, "connection_string", None)
        if not current_conn or current_conn == connection_string:
            # Same (or unknown) target: keep using the existing datasource.
            return
        logger.info(
            "Existing datasource %s uses different connection string; creating dedicated runtime datasource.",
            datasource_name,
        )
        try:
            context.delete_datasource(datasource_name)
        except Exception as exc:  # pragma: no cover - defensive
            # Never fall back to the stale datasource: it targets another database.
            # Recreation is still attempted below, and a failure there surfaces as
            # RuntimeError instead of silently profiling the wrong source.
            logger.warning(
                "Failed to delete datasource %s before recreation: %s",
                datasource_name,
                exc,
            )

    runtime_datasource_config = {
        "name": datasource_name,
        "class_name": "Datasource",
        "execution_engine": {
            "class_name": "SqlAlchemyExecutionEngine",
            "connection_string": connection_string,
        },
        "data_connectors": {
            "runtime_connector": {
                "class_name": "RuntimeDataConnector",
                "batch_identifiers": ["default_identifier_name"],
            }
        },
    }
    try:
        context.add_datasource(**runtime_datasource_config)
    except Exception as exc:  # pragma: no cover - defensive
        raise RuntimeError(f"Failed to create runtime datasource '{datasource_name}'.") from exc
|
||
|
||
|
||
def _quote_sql_identifier(name: str) -> str:
    """Return *name* bare if it is a plain identifier, else backtick-quoted."""
    # NOTE(review): backtick quoting is MySQL/MariaDB syntax; other dialects
    # (e.g. PostgreSQL) would need double quotes.
    if re.match(r"^[A-Za-z_][A-Za-z0-9_$]*$", name):
        return name
    return f"`{name.replace('`', '``')}`"


def _qualified_sql_table(table_name: str, schema_name: Optional[str]) -> str:
    """Build the quoted, optionally schema-qualified table reference for a SELECT."""
    if schema_name:
        # With an explicit schema only the last dotted component of each name is used.
        schema_part = schema_name.split(".")[-1]
        table_part = table_name.split(".")[-1]
        return f"{_quote_sql_identifier(schema_part)}.{_quote_sql_identifier(table_part)}"
    parts = table_name.split(".")
    if len(parts) > 1 and all(parts):
        # "schema.table" without an explicit schema: quote each component separately
        # rather than wrapping the whole dotted name as a single identifier.
        return ".".join(_quote_sql_identifier(part) for part in parts)
    return _quote_sql_identifier(table_name)


def _build_sql_runtime_batch_request(
    context: AbstractDataContext,
    request: TableProfilingJobRequest,
) -> RuntimeBatchRequest:
    """Build a RuntimeBatchRequest for the SQL source in ``request.table_link_info``.

    - ``connection_string`` (required) is filled from ``request.table_access_info``.
    - ``type`` must be "sql" (the default).
    - ``query`` is used verbatim when present. Otherwise ``SELECT * FROM <table>``
      is generated from ``table``/``table_name`` and an optional ``schema``, with
      an optional positive integer ``limit``.
    The datasource is named ``request.ge_datasource_name`` or a sanitised
    ``<table_id>_runtime_ds``. It is created or refreshed before the request is built.

    Raises:
        ValueError: for a missing connection string, an unsupported type, or when
            neither query nor table is given.
    """
    link_info = request.table_link_info or {}
    access_info = request.table_access_info or {}

    connection_template = link_info.get("connection_string")
    if not connection_template:
        raise ValueError("table_link_info.connection_string is required when using table_link_info.")

    connection_string = _format_connection_string(connection_template, access_info)

    source_type = (link_info.get("type") or "sql").lower()
    if source_type != "sql":
        raise ValueError(f"Unsupported table_link_info.type='{source_type}'. Only 'sql' is supported.")

    # NOTE(review): a caller-supplied query is executed as-is against the source.
    query = link_info.get("query")
    table_name = link_info.get("table") or link_info.get("table_name")
    schema_name = link_info.get("schema")

    if not query and not table_name:
        raise ValueError("Either table_link_info.query or table_link_info.table must be provided.")

    if not query:
        query = f"SELECT * FROM {_qualified_sql_table(table_name, schema_name)}"
        limit = link_info.get("limit")
        # bool is an int subclass; exclude it so True does not render as "LIMIT True".
        if isinstance(limit, int) and not isinstance(limit, bool) and limit > 0:
            query = f"{query} LIMIT {limit}"

    datasource_name = request.ge_datasource_name or _sanitize_identifier(
        f"{request.table_id}_runtime_ds", "runtime_ds"
    )
    data_asset_name = request.ge_data_asset_name or _sanitize_identifier(
        table_name or "runtime_query", "runtime_query"
    )

    _ensure_sql_runtime_datasource(context, datasource_name, connection_string)

    batch_identifiers = {
        "default_identifier_name": f"{request.table_id}:{request.version_ts}",
    }

    return RuntimeBatchRequest(
        datasource_name=datasource_name,
        data_connector_name="runtime_connector",
        data_asset_name=data_asset_name,
        runtime_parameters={"query": query},
        batch_identifiers=batch_identifiers,
    )
|
||
|
||
|
||
def _run_onboarding_assistant(
    context: AbstractDataContext,
    batch_request: Any,
    suite_name: str,
) -> Tuple[ExpectationSuite, Any]:
    """Profile *batch_request* with GE's onboarding data assistant.

    The generated expectation suite is saved under *suite_name*. The validation
    result comes from the assistant result: ``get_validation_result()`` if it is
    callable, otherwise a ``validation_result`` attribute. If neither yields one,
    the batch is validated again against the saved suite.

    Returns:
        ``(expectation_suite, validation_result)``.
    """
    result = context.assistants.onboarding.run(batch_request=batch_request)
    suite = result.get_expectation_suite(expectation_suite_name=suite_name)
    context.save_expectation_suite(suite, expectation_suite_name=suite_name)

    getter = getattr(result, "get_validation_result", None)
    validation = getter() if callable(getter) else getattr(result, "validation_result", None)
    if validation is not None:
        return suite, validation

    # Nothing usable from the assistant: validate the batch against the fresh suite.
    validator = context.get_validator(
        batch_request=batch_request,
        expectation_suite_name=suite_name,
    )
    return suite, validator.validate()
|
||
|
||
|
||
def _resolve_context(request: TableProfilingJobRequest) -> AbstractDataContext:
    """Open the GE data context for *request*.

    The project root is chosen in priority order, skipping empty values:
    ``request.ge_data_context_root``, then the ``GE_DATA_CONTEXT_ROOT``
    environment variable, then this project's root directory.
    """
    project_root_dir = (
        request.ge_data_context_root
        or os.environ.get("GE_DATA_CONTEXT_ROOT")
        or str(_project_root())
    )
    return gx.get_context(project_root_dir=project_root_dir)
|
||
|
||
|
||
def _build_batch_request(
    context: AbstractDataContext,
    request: TableProfilingJobRequest,
) -> Any:
    """Pick the batch request describing the data to profile.

    Priority:
    1. an explicit ``ge_batch_request`` dict, turned into a GE ``BatchRequest``;
    2. ``table_link_info``, as a runtime SQL batch request;
    3. the configured datasource/asset named by ``ge_datasource_name`` and
       ``ge_data_asset_name``.

    Raises:
        ValueError: if none of these inputs is provided.
    """
    explicit = request.ge_batch_request
    if explicit:
        from great_expectations.core.batch import BatchRequest

        return BatchRequest(**explicit)

    if request.table_link_info:
        return _build_sql_runtime_batch_request(context, request)

    datasource_name = request.ge_datasource_name
    asset_name = request.ge_data_asset_name
    if not (datasource_name and asset_name):
        raise ValueError(
            "ge_batch_request or (ge_datasource_name and ge_data_asset_name) must be provided."
        )

    asset = context.get_datasource(datasource_name).get_asset(asset_name)
    return asset.build_batch_request()
|
||
|
||
|
||
async def _run_ge_profiling(request: TableProfilingJobRequest) -> GEProfilingArtifacts:
    """Profile the requested data with Great Expectations (pipeline step 1).

    The blocking GE work runs in a worker thread via ``asyncio.to_thread``.

    - The expectation suite is named ``request.ge_expectation_suite_name`` or
      ``"<table_id>_profiling"``, and is created if missing.
    - ``ge_profiler_type`` chooses the profiler. ``"data_assistant"`` runs the
      onboarding data assistant. Anything else (default ``"user_configurable"``)
      runs UserConfigurableProfiler, which falls back to the onboarding assistant
      on MetricResolutionError.
    - Data docs are rebuilt afterwards.

    Returns:
        GEProfilingArtifacts holding:
        - the sanitised suite, validation result and serialised batch request;
        - a suite summary;
        - the data-docs index path.

    Raises:
        RuntimeError: if UserConfigurableProfiler cannot be imported.
        Errors from GE or batch-request construction propagate unchanged.
    """

    def _execute() -> GEProfilingArtifacts:
        context = _resolve_context(request)
        suite_name = request.ge_expectation_suite_name or f"{request.table_id}_profiling"

        batch_request = _build_batch_request(context, request)
        try:
            context.get_expectation_suite(suite_name)
        except DataContextError:
            context.add_expectation_suite(suite_name)

        profiler_type = (request.ge_profiler_type or "user_configurable").lower()

        if profiler_type == "data_assistant":
            suite, validation_result = _run_onboarding_assistant(
                context,
                batch_request,
                suite_name,
            )
        else:
            try:
                from great_expectations.profile.user_configurable_profiler import (
                    UserConfigurableProfiler,
                )
            except ImportError as err:  # pragma: no cover - dependency guard
                raise RuntimeError(
                    "UserConfigurableProfiler is unavailable; install great_expectations profiling extra or switch profiler."
                ) from err

            # The validator binds the batch (presumably reading from the source),
            # so it is only built on the path that actually uses it.
            validator = context.get_validator(
                batch_request=batch_request,
                expectation_suite_name=suite_name,
            )
            profiler = UserConfigurableProfiler(profile_dataset=validator)
            try:
                suite = profiler.build_suite()
                context.save_expectation_suite(suite, expectation_suite_name=suite_name)
                validator.expectation_suite = suite
                validation_result = validator.validate()
            except MetricResolutionError as exc:
                logger.warning(
                    "UserConfigurableProfiler failed (%s); falling back to data assistant profiling.",
                    exc,
                )
                suite, validation_result = _run_onboarding_assistant(
                    context,
                    batch_request,
                    suite_name,
                )

        sanitized_suite = _sanitize_expectation_suite(suite)
        summary = _summarize_expectation_suite(sanitized_suite)
        validation_dict = validation_result.to_json_dict()

        context.build_data_docs()
        docs_path = Path(context.root_directory) / GE_REPORT_RELATIVE_PATH

        # Batch requests expose either to_json_dict() or dict(); try them in that order.
        to_json_dict = getattr(batch_request, "to_json_dict", lambda: None)
        as_dict = getattr(batch_request, "dict", lambda: None)
        serialized_batch_request = to_json_dict() or as_dict()

        profiling_result = {
            "expectation_suite": sanitized_suite,
            "validation_result": validation_dict,
            "batch_request": serialized_batch_request,
        }

        return GEProfilingArtifacts(
            profiling_result=profiling_result,
            profiling_summary=summary,
            docs_path=str(docs_path),
        )

    return await asyncio.to_thread(_execute)
|
||
|
||
|
||
async def _call_chat_completions(
    *,
    model_spec: str,
    system_prompt: str,
    user_prompt: str,
    client: httpx.AsyncClient,
    temperature: float = 0.2,
    timeout_seconds: Optional[float] = None,
) -> LLMCallResult:
    """Send a system+user chat to the import gateway and decode its JSON answer.

    *model_spec* is resolved to ``(provider, model)`` with ``resolve_provider_from_model``.
    A *timeout_seconds* of None means DEFAULT_CHAT_TIMEOUT_SECONDS; passing None
    straight to httpx would disable the timeout entirely.

    Returns:
        LLMCallResult with the decoded JSON content and the response's usage info.

    Raises:
        ProviderAPICallError: on transport/HTTP errors (non-2xx replies include a
            truncated response body), a response body that is not a JSON object,
            or message content that cannot be parsed.
    """
    provider, model_name = resolve_provider_from_model(model_spec)
    payload = {
        "provider": provider.value,
        "model": model_name,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        "temperature": temperature,
    }
    payload_size_bytes = len(json.dumps(payload, ensure_ascii=False).encode("utf-8"))

    url = f"{IMPORT_GATEWAY_BASE_URL.rstrip('/')}/v1/chat/completions"
    headers = build_import_gateway_headers()
    effective_timeout = DEFAULT_CHAT_TIMEOUT_SECONDS if timeout_seconds is None else timeout_seconds
    try:
        logger.info(
            "Calling chat completions API %s with model=%s payload_size=%sB",
            url,
            model_name,
            payload_size_bytes,
        )
        response = await client.post(
            url, json=payload, timeout=effective_timeout, headers=headers
        )
        response.raise_for_status()
    except httpx.HTTPError as exc:
        error_name = exc.__class__.__name__
        detail = str(exc).strip()
        if isinstance(exc, httpx.HTTPStatusError):
            # The gateway's error body usually says *why* (bad model, quota, ...).
            body = (exc.response.text or "").strip()[:500]
            if body:
                detail = f"{detail} | response body: {body}" if detail else f"response body: {body}"
        if detail:
            message = f"Chat completions request failed ({error_name}): {detail}"
        else:
            message = f"Chat completions request failed ({error_name})."
        raise ProviderAPICallError(message) from exc

    try:
        response_payload = response.json()
    except ValueError as exc:
        raise ProviderAPICallError("Chat completions response was not valid JSON.") from exc
    if not isinstance(response_payload, dict):
        raise ProviderAPICallError("Chat completions response JSON was not an object.")

    parsed_payload = _parse_completion_payload(response_payload)
    usage_info = extract_llm_usage(response_payload)
    return LLMCallResult(data=parsed_payload, usage=usage_info)
|
||
|
||
|
||
def _normalize_for_json(value: Any) -> Any:
|
||
if value is None or isinstance(value, (str, int, float, bool)):
|
||
return value
|
||
if isinstance(value, (datetime, date)):
|
||
return str(value)
|
||
if hasattr(value, "model_dump"):
|
||
try:
|
||
return value.model_dump()
|
||
except Exception: # pragma: no cover - defensive
|
||
pass
|
||
if is_dataclass(value):
|
||
return asdict(value)
|
||
if isinstance(value, dict):
|
||
return {k: _normalize_for_json(v) for k, v in value.items()}
|
||
if isinstance(value, (list, tuple, set)):
|
||
return [_normalize_for_json(v) for v in value]
|
||
if hasattr(value, "to_json_dict"):
|
||
try:
|
||
return value.to_json_dict()
|
||
except Exception: # pragma: no cover - defensive
|
||
pass
|
||
if hasattr(value, "__dict__"):
|
||
return _normalize_for_json(value.__dict__)
|
||
return repr(value)
|
||
|
||
|
||
def _json_dumps(data: Any) -> str:
    """Serialise *data* as 2-space-indented JSON that keeps non-ASCII characters.

    *data* is first passed through ``_normalize_for_json``.
    """
    return json.dumps(_normalize_for_json(data), indent=2, ensure_ascii=False)
|
||
|
||
|
||
def _preview_for_log(data: Any, max_chars: Optional[int] = 4000) -> str:
    """Render *data* for a log line, truncated to *max_chars* characters.

    Uses ``_json_dumps`` and falls back to ``repr`` if serialisation fails.
    Pipeline inputs/outputs (e.g. whole GE profiling results) can be very large,
    so longer text is cut and suffixed with its original length. Pass
    ``max_chars=None`` to log the full text.
    """
    try:
        rendered = _json_dumps(data)
    except Exception:
        rendered = repr(data)

    if max_chars is None or len(rendered) <= max(max_chars, 0):
        return rendered
    return f"{rendered[:max(max_chars, 0)]}... [truncated, {len(rendered)} chars total]"
|
||
|
||
|
||
def _profiling_request_for_log(request: TableProfilingJobRequest) -> Dict[str, Any]:
|
||
payload = request.model_dump()
|
||
access_info = payload.get("table_access_info")
|
||
if isinstance(access_info, dict):
|
||
payload["table_access_info"] = {key: "***" for key in access_info.keys()}
|
||
return payload
|
||
|
||
|
||
async def _execute_result_desc(
    profiling_json: Dict[str, Any],
    _request: TableProfilingJobRequest,
    llm_model: str,
    client: httpx.AsyncClient,
    timeout_seconds: Optional[float],
) -> LLMCallResult:
    """Ask the LLM to describe a GE profiling result (pipeline step 2).

    Renders the ``ge_result_desc`` prompt, with *profiling_json* serialised in
    place of ``{{GE_RESULT_JSON}}``. ``_request`` is accepted for signature
    parity with the other steps and is not used.

    Returns:
        The LLMCallResult (not a bare dict); its ``data`` is a dict.

    Raises:
        ProviderAPICallError: if the call fails or the model's JSON is not an object.
    """
    system_prompt, user_prompt = _render_prompt(
        "ge_result_desc",
        {"{{GE_RESULT_JSON}}": _json_dumps(profiling_json)},
    )
    llm_output = await _call_chat_completions(
        model_spec=llm_model,
        system_prompt=system_prompt,
        user_prompt=user_prompt,
        client=client,
        timeout_seconds=timeout_seconds,
    )
    if not isinstance(llm_output.data, dict):
        raise ProviderAPICallError("GE result description payload must be a JSON object.")
    return llm_output
|
||
|
||
|
||
async def _execute_snippet_generation(
    table_desc_json: Dict[str, Any],
    _request: TableProfilingJobRequest,
    llm_model: str,
    client: httpx.AsyncClient,
    timeout_seconds: Optional[float],
) -> LLMCallResult:
    """Ask the LLM for snippets derived from the table description (pipeline step 3).

    Renders the ``snippet_generator`` prompt, with *table_desc_json* serialised in
    place of ``{{TABLE_PROFILE_JSON}}``. ``_request`` is accepted for signature
    parity and is not used.

    Returns:
        The LLMCallResult (not a bare list); its ``data`` is a list.

    Raises:
        ProviderAPICallError: if the call fails or the model's JSON is not an array.
    """
    system_prompt, user_prompt = _render_prompt(
        "snippet_generator",
        {"{{TABLE_PROFILE_JSON}}": _json_dumps(table_desc_json)},
    )
    llm_output = await _call_chat_completions(
        model_spec=llm_model,
        system_prompt=system_prompt,
        user_prompt=user_prompt,
        client=client,
        timeout_seconds=timeout_seconds,
    )
    if not isinstance(llm_output.data, list):
        raise ProviderAPICallError("Snippet generator must return a JSON array.")
    return llm_output
|
||
|
||
|
||
async def _execute_snippet_alias(
    snippets_json: List[Dict[str, Any]],
    _request: TableProfilingJobRequest,
    llm_model: str,
    client: httpx.AsyncClient,
    timeout_seconds: Optional[float],
) -> LLMCallResult:
    """Ask the LLM to generate aliases for the generated snippets (pipeline step 4).

    Renders the ``snippet_alias`` prompt, with *snippets_json* serialised in place
    of ``{{SNIPPET_ARRAY}}``. ``_request`` is accepted for signature parity and
    is not used.

    Returns:
        The LLMCallResult (not a bare list); its ``data`` is a list.

    Raises:
        ProviderAPICallError: if the call fails or the model's JSON is not an array.
    """
    system_prompt, user_prompt = _render_prompt(
        "snippet_alias",
        {"{{SNIPPET_ARRAY}}": _json_dumps(snippets_json)},
    )
    llm_output = await _call_chat_completions(
        model_spec=llm_model,
        system_prompt=system_prompt,
        user_prompt=user_prompt,
        client=client,
        timeout_seconds=timeout_seconds,
    )
    if not isinstance(llm_output.data, list):
        raise ProviderAPICallError("Snippet alias generator must return a JSON array.")
    return llm_output
|
||
|
||
|
||
async def _run_action_with_callback(
    *,
    action_type: str,
    runner,
    callback_base: Dict[str, Any],
    client: httpx.AsyncClient,
    callback_url: str,
    input_payload: Any = None,
    model_spec: Optional[str] = None,
) -> Any:
    """Run one pipeline action and always report its outcome to *callback_url*.

    *runner* is a zero-argument callable returning an awaitable.

    On failure: posts *callback_base* plus status "failed", ``action_type``,
    ``error`` (and ``model`` when given), then re-raises the exception.

    On success: posts status "success" with the action-specific result field(s),
    plus ``llm_usage`` when the runner returned an LLMCallResult with usage.
    Returns the unwrapped result (``LLMCallResult.data`` for LLM actions).
    """
    if input_payload is not None:
        logger.info(
            "Pipeline action %s input: %s",
            action_type,
            _preview_for_log(input_payload),
        )

    try:
        result = await runner()
    except Exception as exc:
        failure_payload = dict(callback_base)
        failure_payload.update(
            {
                "status": "failed",
                "action_type": action_type,
                # Some exceptions (e.g. a bare TimeoutError()) stringify to "",
                # which would leave the callback without any diagnostic.
                "error": str(exc) or exc.__class__.__name__,
            }
        )
        if model_spec is not None:
            failure_payload["model"] = model_spec
        await _post_callback(callback_url, failure_payload, client)
        raise

    usage_info: Optional[Dict[str, Any]] = None
    result_payload = result
    if isinstance(result, LLMCallResult):
        usage_info = result.usage
        result_payload = result.data

    success_payload = dict(callback_base)
    success_payload.update(
        {
            "status": "success",
            "action_type": action_type,
        }
    )
    if model_spec is not None:
        success_payload["model"] = model_spec

    logger.info(
        "Pipeline action %s output: %s",
        action_type,
        _preview_for_log(result_payload),
    )

    if action_type == PipelineActionType.GE_PROFILING:
        artifacts: GEProfilingArtifacts = result_payload
        success_payload["ge_profiling_json"] = artifacts.profiling_result
        success_payload["ge_profiling_summary"] = artifacts.profiling_summary
        success_payload["ge_report_path"] = artifacts.docs_path
    else:
        # LLM actions publish their decoded JSON under one action-specific key.
        result_key = {
            PipelineActionType.GE_RESULT_DESC: "ge_result_desc_json",
            PipelineActionType.SNIPPET: "snippet_json",
            PipelineActionType.SNIPPET_ALIAS: "snippet_alias_json",
        }.get(action_type)
        if result_key is not None:
            success_payload[result_key] = result_payload

    if usage_info:
        success_payload["llm_usage"] = usage_info

    await _post_callback(callback_url, success_payload, client)
    return result_payload
|
||
|
||
|
||
async def process_table_profiling_job(
    request: TableProfilingJobRequest,
    _gateway: LLMGateway,
    client: httpx.AsyncClient,
) -> None:
    """Sequentially execute the four-step profiling pipeline and emit callbacks per action.

    The steps run in order, each feeding the next:
    GE profiling -> GE result description -> snippet generation -> snippet aliases.
    Every step posts its own success/failure callback to ``request.callback_url``.
    A failing step stops the pipeline; the exception is logged here and not
    re-raised. ``_gateway`` is accepted but not used.
    """
    extracted_timeout = _extract_timeout_seconds(request.extra_options)
    timeout_seconds = DEFAULT_CHAT_TIMEOUT_SECONDS if extracted_timeout is None else extracted_timeout

    callback_url = str(request.callback_url)
    base_payload = {
        "table_id": request.table_id,
        "version_ts": request.version_ts,
        "callback_url": callback_url,
        "table_schema": request.table_schema,
        "table_schema_version_id": request.table_schema_version_id,
        "llm_model": request.llm_model,
        "llm_timeout_seconds": timeout_seconds,
    }

    logging_request_payload = _profiling_request_for_log(request)

    async def _run_step(action_type: str, runner, input_payload: Any) -> Any:
        # All steps share the callback target, base payload and model tag.
        return await _run_action_with_callback(
            action_type=action_type,
            runner=runner,
            callback_base=base_payload,
            client=client,
            callback_url=callback_url,
            input_payload=input_payload,
            model_spec=request.llm_model,
        )

    try:
        artifacts: GEProfilingArtifacts = await _run_step(
            PipelineActionType.GE_PROFILING,
            lambda: _run_ge_profiling(request),
            logging_request_payload,
        )
        profiling_json = artifacts.profiling_result

        table_desc_json: Dict[str, Any] = await _run_step(
            PipelineActionType.GE_RESULT_DESC,
            lambda: _execute_result_desc(
                profiling_json, request, request.llm_model, client, timeout_seconds
            ),
            profiling_json,
        )

        snippet_json: List[Dict[str, Any]] = await _run_step(
            PipelineActionType.SNIPPET,
            lambda: _execute_snippet_generation(
                table_desc_json, request, request.llm_model, client, timeout_seconds
            ),
            table_desc_json,
        )

        await _run_step(
            PipelineActionType.SNIPPET_ALIAS,
            lambda: _execute_snippet_alias(
                snippet_json, request, request.llm_model, client, timeout_seconds
            ),
            snippet_json,
        )
    except Exception:  # pragma: no cover - defensive catch
        logger.exception(
            "Table profiling pipeline failed for table_id=%s version_ts=%s",
            request.table_id,
            request.version_ts,
        )
|