from __future__ import annotations import asyncio import json import logging import os import re from datetime import date, datetime from dataclasses import asdict, dataclass, is_dataclass from functools import lru_cache from pathlib import Path from typing import Any, Dict, List, Optional, Tuple import httpx import great_expectations as gx from great_expectations.core.batch import RuntimeBatchRequest from great_expectations.core.expectation_suite import ExpectationSuite from great_expectations.data_context import AbstractDataContext from great_expectations.exceptions import DataContextError, MetricResolutionError from app.exceptions import ProviderAPICallError from app.models import TableProfilingJobRequest from app.services import LLMGateway from app.settings import DEFAULT_IMPORT_MODEL from app.services.import_analysis import ( IMPORT_GATEWAY_BASE_URL, build_import_gateway_headers, resolve_provider_from_model, ) from app.utils.llm_usage import extract_usage as extract_llm_usage logger = logging.getLogger(__name__) GE_REPORT_RELATIVE_PATH = Path("uncommitted") / "data_docs" / "local_site" / "index.html" PROMPT_FILENAMES = { "ge_result_desc": "ge_result_desc_prompt.md", "snippet_generator": "snippet_generator.md", "snippet_alias": "snippet_alias_generator.md", } DEFAULT_CHAT_TIMEOUT_SECONDS = 180.0 @dataclass class GEProfilingArtifacts: profiling_result: Dict[str, Any] profiling_summary: Dict[str, Any] docs_path: str @dataclass class LLMCallResult: data: Any usage: Optional[Dict[str, Any]] = None class PipelineActionType: GE_PROFILING = "ge_profiling" GE_RESULT_DESC = "ge_result_desc" SNIPPET = "snippet" SNIPPET_ALIAS = "snippet_alias" def _project_root() -> Path: return Path(__file__).resolve().parents[2] def _prompt_dir() -> Path: return _project_root() / "prompt" @lru_cache(maxsize=None) def _load_prompt_parts(filename: str) -> Tuple[str, str]: prompt_path = _prompt_dir() / filename if not prompt_path.exists(): raise FileNotFoundError(f"Prompt template not found: {prompt_path}") raw = prompt_path.read_text(encoding="utf-8") splitter = "用户消息(User)" if splitter not in raw: raise ValueError(f"Prompt template '{filename}' missing separator '{splitter}'.") system_raw, user_raw = raw.split(splitter, maxsplit=1) system_text = system_raw.replace("系统角色(System)", "").strip() user_text = user_raw.strip() return system_text, user_text def _render_prompt(template_key: str, replacements: Dict[str, str]) -> Tuple[str, str]: filename = PROMPT_FILENAMES[template_key] system_text, user_template = _load_prompt_parts(filename) rendered_user = user_template for key, value in replacements.items(): rendered_user = rendered_user.replace(key, value) return system_text, rendered_user def _extract_timeout_seconds(options: Optional[Dict[str, Any]]) -> Optional[float]: if not options: return None value = options.get("llm_timeout_seconds") if value is None: return None try: timeout = float(value) if timeout <= 0: raise ValueError return timeout except (TypeError, ValueError): logger.warning( "Invalid llm_timeout_seconds value in extra_options: %r. Falling back to default.", value, ) return DEFAULT_CHAT_TIMEOUT_SECONDS def _extract_json_payload(content: str) -> str: fenced = re.search( r"```(?:json)?\s*([\s\S]+?)```", content, flags=re.IGNORECASE, ) if fenced: snippet = fenced.group(1).strip() if snippet: return snippet stripped = content.strip() if not stripped: raise ValueError("Empty LLM content.") decoder = json.JSONDecoder() for idx, char in enumerate(stripped): if char not in {"{", "["}: continue try: _, end = decoder.raw_decode(stripped[idx:]) except json.JSONDecodeError: continue candidate = stripped[idx : idx + end].strip() if candidate: return candidate return stripped def _parse_completion_payload(response_payload: Dict[str, Any]) -> Any: choices = response_payload.get("choices") or [] if not choices: raise ProviderAPICallError("LLM response did not contain choices to parse.") message = choices[0].get("message") or {} content = message.get("content") or "" if not content.strip(): raise ProviderAPICallError("LLM response content is empty.") json_payload = _extract_json_payload(content) try: return json.loads(json_payload) except json.JSONDecodeError as exc: preview = json_payload[:800] logger.error("Failed to parse JSON from LLM response: %s", preview, exc_info=True) raise ProviderAPICallError("LLM response JSON parsing failed.") from exc async def _post_callback(callback_url: str, payload: Dict[str, Any], client: httpx.AsyncClient) -> None: safe_payload = _normalize_for_json(payload) try: logger.info( "Posting pipeline action callback to %s: %s", callback_url, json.dumps(safe_payload, ensure_ascii=False), ) response = await client.post(callback_url, json=safe_payload) response.raise_for_status() except httpx.HTTPError as exc: logger.error("Callback delivery to %s failed: %s", callback_url, exc, exc_info=True) def _sanitize_value_set(value: Any, max_values: int) -> Tuple[Any, Optional[Dict[str, int]]]: if not isinstance(value, list): return value, None original_len = len(value) if original_len <= max_values: return value, None trimmed = value[:max_values] return trimmed, {"original_length": original_len, "retained": max_values} def _sanitize_expectation_suite(suite: ExpectationSuite, max_value_set_values: int = 100) -> Dict[str, Any]: suite_dict = suite.to_json_dict() remarks: List[Dict[str, Any]] = [] for expectation in suite_dict.get("expectations", []): kwargs = expectation.get("kwargs", {}) if "value_set" in kwargs: sanitized_value, note = _sanitize_value_set(kwargs["value_set"], max_value_set_values) kwargs["value_set"] = sanitized_value if note: expectation.setdefault("meta", {}) expectation["meta"]["value_set_truncated"] = note remarks.append( { "column": kwargs.get("column"), "expectation": expectation.get("expectation_type"), "note": note, } ) if remarks: suite_dict.setdefault("meta", {}) suite_dict["meta"]["value_set_truncations"] = remarks return suite_dict def _summarize_expectation_suite(suite_dict: Dict[str, Any]) -> Dict[str, Any]: column_map: Dict[str, Dict[str, Any]] = {} table_expectations: List[Dict[str, Any]] = [] for expectation in suite_dict.get("expectations", []): expectation_type = expectation.get("expectation_type") kwargs = expectation.get("kwargs", {}) column = kwargs.get("column") summary_entry: Dict[str, Any] = {"expectation": expectation_type} if "value_set" in kwargs and isinstance(kwargs["value_set"], list): summary_entry["value_set_size"] = len(kwargs["value_set"]) summary_entry["value_set_preview"] = kwargs["value_set"][:5] if column: column_entry = column_map.setdefault( column, {"name": column, "expectations": []}, ) column_entry["expectations"].append(summary_entry) else: table_expectations.append(summary_entry) summary = { "column_profiles": list(column_map.values()), "table_level_expectations": table_expectations, "total_expectations": len(suite_dict.get("expectations", [])), } return summary def _sanitize_identifier(raw: Optional[str], fallback: str) -> str: if not raw: return fallback candidate = re.sub(r"[^0-9A-Za-z_]+", "_", raw).strip("_") return candidate or fallback def _format_connection_string(template: str, access_info: Dict[str, Any]) -> str: if not access_info: return template try: return template.format_map({k: v for k, v in access_info.items()}) except KeyError as exc: missing = exc.args[0] raise ValueError(f"table_access_info missing key '{missing}' required by connection_string.") from exc def _ensure_sql_runtime_datasource( context: AbstractDataContext, datasource_name: str, connection_string: str, ) -> None: try: datasource = context.get_datasource(datasource_name) except (DataContextError, ValueError) as exc: message = str(exc) if "Could not find a datasource" in message or "Unable to load datasource" in message: datasource = None else: # pragma: no cover - defensive raise RuntimeError(f"Failed to inspect datasource '{datasource_name}'.") from exc except Exception as exc: # pragma: no cover - defensive raise RuntimeError(f"Failed to inspect datasource '{datasource_name}'.") from exc if datasource is not None: execution_engine = getattr(datasource, "execution_engine", None) current_conn = getattr(execution_engine, "connection_string", None) if current_conn and current_conn != connection_string: logger.info( "Existing datasource %s uses different connection string; creating dedicated runtime datasource.", datasource_name, ) try: context.delete_datasource(datasource_name) except Exception as exc: # pragma: no cover - defensive logger.warning( "Failed to delete datasource %s before recreation: %s", datasource_name, exc, ) else: datasource = None if datasource is not None: return runtime_datasource_config = { "name": datasource_name, "class_name": "Datasource", "execution_engine": { "class_name": "SqlAlchemyExecutionEngine", "connection_string": connection_string, }, "data_connectors": { "runtime_connector": { "class_name": "RuntimeDataConnector", "batch_identifiers": ["default_identifier_name"], } }, } try: context.add_datasource(**runtime_datasource_config) except Exception as exc: # pragma: no cover - defensive raise RuntimeError(f"Failed to create runtime datasource '{datasource_name}'.") from exc def _build_sql_runtime_batch_request( context: AbstractDataContext, request: TableProfilingJobRequest, ) -> RuntimeBatchRequest: link_info = request.table_link_info or {} access_info = request.table_access_info or {} connection_template = link_info.get("connection_string") if not connection_template: raise ValueError("table_link_info.connection_string is required when using table_link_info.") connection_string = _format_connection_string(connection_template, access_info) source_type = (link_info.get("type") or "sql").lower() if source_type != "sql": raise ValueError(f"Unsupported table_link_info.type='{source_type}'. Only 'sql' is supported.") query = link_info.get("query") table_name = link_info.get("table") or link_info.get("table_name") schema_name = link_info.get("schema") if not query and not table_name: raise ValueError("Either table_link_info.query or table_link_info.table must be provided.") if not query: if not table_name: raise ValueError("table_link_info.table must be provided when query is omitted.") identifier = re.compile(r"^[A-Za-z_][A-Za-z0-9_$]*$") def _quote(name: str) -> str: if identifier.match(name): return name return f"`{name.replace('`', '``')}`" if schema_name: schema_part = schema_name if "." not in schema_name else schema_name.split(".")[-1] table_part = table_name if "." not in table_name else table_name.split(".")[-1] qualified_table = f"{_quote(schema_part)}.{_quote(table_part)}" else: qualified_table = _quote(table_name) query = f"SELECT * FROM {qualified_table}" limit = link_info.get("limit") if isinstance(limit, int) and limit > 0: query = f"{query} LIMIT {limit}" datasource_name = request.ge_datasource_name or _sanitize_identifier( f"{request.table_id}_runtime_ds", "runtime_ds" ) data_asset_name = request.ge_data_asset_name or _sanitize_identifier( table_name or "runtime_query", "runtime_query" ) _ensure_sql_runtime_datasource(context, datasource_name, connection_string) batch_identifiers = { "default_identifier_name": f"{request.table_id}:{request.version_ts}", } return RuntimeBatchRequest( datasource_name=datasource_name, data_connector_name="runtime_connector", data_asset_name=data_asset_name, runtime_parameters={"query": query}, batch_identifiers=batch_identifiers, ) def _run_onboarding_assistant( context: AbstractDataContext, batch_request: Any, suite_name: str, ) -> Tuple[ExpectationSuite, Any]: assistant = context.assistants.onboarding assistant_result = assistant.run(batch_request=batch_request) suite = assistant_result.get_expectation_suite(expectation_suite_name=suite_name) context.save_expectation_suite(suite, expectation_suite_name=suite_name) validation_getter = getattr(assistant_result, "get_validation_result", None) if callable(validation_getter): validation_result = validation_getter() else: validation_result = getattr(assistant_result, "validation_result", None) if validation_result is None: # Fallback: rerun validation using the freshly generated expectation suite. validator = context.get_validator( batch_request=batch_request, expectation_suite_name=suite_name, ) validation_result = validator.validate() return suite, validation_result def _resolve_context(request: TableProfilingJobRequest) -> AbstractDataContext: context_kwargs: Dict[str, Any] = {} if request.ge_data_context_root: context_kwargs["project_root_dir"] = request.ge_data_context_root elif os.environ.get("GE_DATA_CONTEXT_ROOT"): context_kwargs["project_root_dir"] = os.environ["GE_DATA_CONTEXT_ROOT"] else: context_kwargs["project_root_dir"] = str(_project_root()) return gx.get_context(**context_kwargs) def _build_batch_request( context: AbstractDataContext, request: TableProfilingJobRequest, ) -> Any: if request.ge_batch_request: from great_expectations.core.batch import BatchRequest return BatchRequest(**request.ge_batch_request) if request.table_link_info: return _build_sql_runtime_batch_request(context, request) if not request.ge_datasource_name or not request.ge_data_asset_name: raise ValueError( "ge_batch_request or (ge_datasource_name and ge_data_asset_name) must be provided." ) datasource = context.get_datasource(request.ge_datasource_name) data_asset = datasource.get_asset(request.ge_data_asset_name) return data_asset.build_batch_request() async def _run_ge_profiling(request: TableProfilingJobRequest) -> GEProfilingArtifacts: def _execute() -> GEProfilingArtifacts: context = _resolve_context(request) suite_name = ( request.ge_expectation_suite_name or f"{request.table_id}_profiling" ) batch_request = _build_batch_request(context, request) try: context.get_expectation_suite(suite_name) except DataContextError: context.add_expectation_suite(suite_name) validator = context.get_validator( batch_request=batch_request, expectation_suite_name=suite_name, ) profiler_type = (request.ge_profiler_type or "user_configurable").lower() if profiler_type == "data_assistant": suite, validation_result = _run_onboarding_assistant( context, batch_request, suite_name, ) else: try: from great_expectations.profile.user_configurable_profiler import ( UserConfigurableProfiler, ) except ImportError as err: # pragma: no cover - dependency guard raise RuntimeError( "UserConfigurableProfiler is unavailable; install great_expectations profiling extra or switch profiler." ) from err profiler = UserConfigurableProfiler(profile_dataset=validator) try: suite = profiler.build_suite() context.save_expectation_suite(suite, expectation_suite_name=suite_name) validator.expectation_suite = suite validation_result = validator.validate() except MetricResolutionError as exc: logger.warning( "UserConfigurableProfiler failed (%s); falling back to data assistant profiling.", exc, ) suite, validation_result = _run_onboarding_assistant( context, batch_request, suite_name, ) sanitized_suite = _sanitize_expectation_suite(suite) summary = _summarize_expectation_suite(sanitized_suite) validation_dict = validation_result.to_json_dict() context.build_data_docs() docs_path = Path(context.root_directory) / GE_REPORT_RELATIVE_PATH profiling_result = { "expectation_suite": sanitized_suite, "validation_result": validation_dict, "batch_request": getattr(batch_request, "to_json_dict", lambda: None)() or getattr(batch_request, "dict", lambda: None)(), } return GEProfilingArtifacts( profiling_result=profiling_result, profiling_summary=summary, docs_path=str(docs_path), ) return await asyncio.to_thread(_execute) async def _call_chat_completions( *, model_spec: str, system_prompt: str, user_prompt: str, client: httpx.AsyncClient, temperature: float = 0.2, timeout_seconds: Optional[float] = None, ) -> Any: # Normalize model spec to provider+model and issue the unified chat call. provider, model_name = resolve_provider_from_model(model_spec) payload = { "provider": provider.value, "model": model_name, "messages": [ {"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}, ], "temperature": temperature, } payload_size_bytes = len(json.dumps(payload, ensure_ascii=False).encode("utf-8")) url = f"{IMPORT_GATEWAY_BASE_URL.rstrip('/')}/v1/chat/completions" headers = build_import_gateway_headers() try: logger.info( "Calling chat completions API %s with model=%s payload_size=%sB", url, model_name, payload_size_bytes, ) response = await client.post( url, json=payload, timeout=timeout_seconds, headers=headers ) response.raise_for_status() except httpx.HTTPError as exc: error_name = exc.__class__.__name__ detail = str(exc).strip() if detail: message = f"Chat completions request failed ({error_name}): {detail}" else: message = f"Chat completions request failed ({error_name})." raise ProviderAPICallError(message) from exc try: response_payload = response.json() except ValueError as exc: raise ProviderAPICallError("Chat completions response was not valid JSON.") from exc parsed_payload = _parse_completion_payload(response_payload) usage_info = extract_llm_usage(response_payload) return LLMCallResult(data=parsed_payload, usage=usage_info) def _normalize_for_json(value: Any) -> Any: if value is None or isinstance(value, (str, int, float, bool)): return value if isinstance(value, (datetime, date)): return str(value) if hasattr(value, "model_dump"): try: return value.model_dump() except Exception: # pragma: no cover - defensive pass if is_dataclass(value): return asdict(value) if isinstance(value, dict): return {k: _normalize_for_json(v) for k, v in value.items()} if isinstance(value, (list, tuple, set)): return [_normalize_for_json(v) for v in value] if hasattr(value, "to_json_dict"): try: return value.to_json_dict() except Exception: # pragma: no cover - defensive pass if hasattr(value, "__dict__"): return _normalize_for_json(value.__dict__) return repr(value) def _json_dumps(data: Any) -> str: normalised = _normalize_for_json(data) return json.dumps(normalised, ensure_ascii=False, indent=2) def _preview_for_log(data: Any) -> str: try: serialised = _json_dumps(data) except Exception: serialised = repr(data) return serialised def _profiling_request_for_log(request: TableProfilingJobRequest) -> Dict[str, Any]: payload = request.model_dump() access_info = payload.get("table_access_info") if isinstance(access_info, dict): payload["table_access_info"] = {key: "***" for key in access_info.keys()} return payload async def _execute_result_desc( profiling_json: Dict[str, Any], _request: TableProfilingJobRequest, llm_model: str, client: httpx.AsyncClient, timeout_seconds: Optional[float], ) -> Dict[str, Any]: system_prompt, user_prompt = _render_prompt( "ge_result_desc", {"{{GE_RESULT_JSON}}": _json_dumps(profiling_json)}, ) llm_output = await _call_chat_completions( model_spec=llm_model, system_prompt=system_prompt, user_prompt=user_prompt, client=client, timeout_seconds=timeout_seconds, ) if not isinstance(llm_output.data, dict): raise ProviderAPICallError("GE result description payload must be a JSON object.") return llm_output async def _execute_snippet_generation( table_desc_json: Dict[str, Any], _request: TableProfilingJobRequest, llm_model: str, client: httpx.AsyncClient, timeout_seconds: Optional[float], ) -> List[Dict[str, Any]]: system_prompt, user_prompt = _render_prompt( "snippet_generator", {"{{TABLE_PROFILE_JSON}}": _json_dumps(table_desc_json)}, ) llm_output = await _call_chat_completions( model_spec=llm_model, system_prompt=system_prompt, user_prompt=user_prompt, client=client, timeout_seconds=timeout_seconds, ) if not isinstance(llm_output.data, list): raise ProviderAPICallError("Snippet generator must return a JSON array.") return llm_output async def _execute_snippet_alias( snippets_json: List[Dict[str, Any]], _request: TableProfilingJobRequest, llm_model: str, client: httpx.AsyncClient, timeout_seconds: Optional[float], ) -> List[Dict[str, Any]]: system_prompt, user_prompt = _render_prompt( "snippet_alias", {"{{SNIPPET_ARRAY}}": _json_dumps(snippets_json)}, ) llm_output = await _call_chat_completions( model_spec=llm_model, system_prompt=system_prompt, user_prompt=user_prompt, client=client, timeout_seconds=timeout_seconds, ) if not isinstance(llm_output.data, list): raise ProviderAPICallError("Snippet alias generator must return a JSON array.") return llm_output async def _run_action_with_callback( *, action_type: str, runner, callback_base: Dict[str, Any], client: httpx.AsyncClient, callback_url: str, input_payload: Any = None, model_spec: Optional[str] = None, ) -> Any: # Execute a pipeline action and always emit a callback capturing success/failure. if input_payload is not None: logger.info( "Pipeline action %s input: %s", action_type, _preview_for_log(input_payload), ) try: result = await runner() except Exception as exc: failure_payload = dict(callback_base) failure_payload.update( { "status": "failed", "action_type": action_type, "error": str(exc), } ) if model_spec is not None: failure_payload["model"] = model_spec await _post_callback(callback_url, failure_payload, client) raise usage_info: Optional[Dict[str, Any]] = None result_payload = result if isinstance(result, LLMCallResult): usage_info = result.usage result_payload = result.data success_payload = dict(callback_base) success_payload.update( { "status": "success", "action_type": action_type, } ) if model_spec is not None: success_payload["model"] = model_spec logger.info( "Pipeline action %s output: %s", action_type, _preview_for_log(result_payload), ) if action_type == PipelineActionType.GE_PROFILING: artifacts: GEProfilingArtifacts = result_payload success_payload["ge_profiling_json"] = artifacts.profiling_result success_payload["ge_profiling_summary"] = artifacts.profiling_summary success_payload["ge_report_path"] = artifacts.docs_path elif action_type == PipelineActionType.GE_RESULT_DESC: success_payload["ge_result_desc_json"] = result_payload elif action_type == PipelineActionType.SNIPPET: success_payload["snippet_json"] = result_payload elif action_type == PipelineActionType.SNIPPET_ALIAS: success_payload["snippet_alias_json"] = result_payload if usage_info: success_payload["llm_usage"] = usage_info await _post_callback(callback_url, success_payload, client) return result_payload async def process_table_profiling_job( request: TableProfilingJobRequest, _gateway: LLMGateway, client: httpx.AsyncClient, ) -> None: """Sequentially execute the four-step profiling pipeline and emit callbacks per action.""" timeout_seconds = _extract_timeout_seconds(request.extra_options) if timeout_seconds is None: timeout_seconds = DEFAULT_CHAT_TIMEOUT_SECONDS base_payload = { "table_id": request.table_id, "version_ts": request.version_ts, "callback_url": str(request.callback_url), "table_schema": request.table_schema, "table_schema_version_id": request.table_schema_version_id, "llm_model": request.llm_model, "llm_timeout_seconds": timeout_seconds, } logging_request_payload = _profiling_request_for_log(request) try: artifacts: GEProfilingArtifacts = await _run_action_with_callback( action_type=PipelineActionType.GE_PROFILING, runner=lambda: _run_ge_profiling(request), callback_base=base_payload, client=client, callback_url=str(request.callback_url), input_payload=logging_request_payload, model_spec=request.llm_model, ) table_desc_json: Dict[str, Any] = await _run_action_with_callback( action_type=PipelineActionType.GE_RESULT_DESC, runner=lambda: _execute_result_desc( artifacts.profiling_result, request, request.llm_model, client, timeout_seconds, ), callback_base=base_payload, client=client, callback_url=str(request.callback_url), input_payload=artifacts.profiling_result, model_spec=request.llm_model, ) snippet_json: List[Dict[str, Any]] = await _run_action_with_callback( action_type=PipelineActionType.SNIPPET, runner=lambda: _execute_snippet_generation( table_desc_json, request, request.llm_model, client, timeout_seconds, ), callback_base=base_payload, client=client, callback_url=str(request.callback_url), input_payload=table_desc_json, model_spec=request.llm_model, ) await _run_action_with_callback( action_type=PipelineActionType.SNIPPET_ALIAS, runner=lambda: _execute_snippet_alias( snippet_json, request, request.llm_model, client, timeout_seconds, ), callback_base=base_payload, client=client, callback_url=str(request.callback_url), input_payload=snippet_json, model_spec=request.llm_model, ) except Exception: # pragma: no cover - defensive catch logger.exception( "Table profiling pipeline failed for table_id=%s version_ts=%s", request.table_id, request.version_ts, )