Compare commits

...

7 Commits

Author SHA1 Message Date
0b765e6719 Merge branch 'main' of ssh://git.agentcarrier.cn:2222/zhaoqingliang/data-ge
no conflicts
2025-11-03 00:29:42 +08:00
cf2e6aedc7 忽略gx和logs 2025-11-03 00:20:36 +08:00
9a7710a70a file和demo 2025-11-03 00:20:00 +08:00
799b9f8154 增加日志 2025-11-03 00:19:43 +08:00
fe1de87696 部分参数调整 2025-11-03 00:19:23 +08:00
c2a08e4134 table profiling功能开发 2025-11-03 00:18:26 +08:00
557efc4bf1 添加数据知识生成链路的三个prompt 2025-10-31 15:55:07 +08:00
18 changed files with 11842 additions and 23 deletions

2
.env
View File

@ -17,7 +17,7 @@ DEFAULT_IMPORT_MODEL=deepseek:deepseek-chat
IMPORT_GATEWAY_BASE_URL=http://localhost:8000
# HTTP client configuration
HTTP_CLIENT_TIMEOUT=30
HTTP_CLIENT_TIMEOUT=60
HTTP_CLIENT_TRUST_ENV=false
# HTTP_CLIENT_PROXY=

4
.gitignore vendored
View File

@ -3,4 +3,6 @@ gx/uncommitted/
.vscode/
**/__pycache__/
*.pyc
.DS_Store
.DS_Store
gx/
logs/

26
app/db.py Normal file
View File

@ -0,0 +1,26 @@
from __future__ import annotations
import os
from functools import lru_cache
from sqlalchemy import create_engine
from sqlalchemy.engine import Engine
@lru_cache(maxsize=1)
def get_engine() -> Engine:
    """Return a cached SQLAlchemy engine configured from DATABASE_URL.

    The engine is created once per process (``lru_cache(maxsize=1)``) and
    reused by all callers. The URL comes from the DATABASE_URL environment
    variable, falling back to a local MySQL instance.
    """
    # NOTE(review): the fallback URL embeds plaintext credentials for a local
    # dev MySQL instance — confirm this default can never reach production.
    database_url = os.getenv(
        "DATABASE_URL",
        "mysql+pymysql://root:12345678@localhost:3306/data-ge?charset=utf8mb4",
    )
    connect_args = {}
    if database_url.startswith("sqlite"):
        # SQLite connections are thread-bound by default; allow cross-thread
        # use since web frameworks may touch the engine from worker threads.
        connect_args["check_same_thread"] = False
    return create_engine(
        database_url,
        pool_pre_ping=True,  # validate pooled connections before handing them out
        future=True,  # SQLAlchemy 2.0-style API
        connect_args=connect_args,
    )

View File

@ -2,12 +2,17 @@ from __future__ import annotations
import asyncio
import logging
import logging.config
import os
from contextlib import asynccontextmanager
from typing import Any
import yaml
import httpx
from fastapi import Depends, FastAPI, HTTPException, Request
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from app.exceptions import ProviderAPICallError, ProviderConfigurationError
from app.models import (
@ -15,30 +20,42 @@ from app.models import (
DataImportAnalysisJobRequest,
LLMRequest,
LLMResponse,
TableProfilingJobAck,
TableProfilingJobRequest,
TableSnippetUpsertRequest,
TableSnippetUpsertResponse,
)
from app.services import LLMGateway
from app.services.import_analysis import process_import_analysis_job
from app.services.table_profiling import process_table_profiling_job
from app.services.table_snippet import upsert_action_result
def _ensure_log_directories(config: dict[str, Any]) -> None:
handlers = config.get("handlers", {})
for handler_config in handlers.values():
filename = handler_config.get("filename")
if not filename:
continue
directory = os.path.dirname(filename)
if directory and not os.path.exists(directory):
os.makedirs(directory, exist_ok=True)
def _configure_logging() -> None:
    """Configure logging from a YAML dictConfig file, or fall back to basicConfig.

    The config path is read from the LOGGING_CONFIG env var (default
    "logging.yaml"). When the file exists and parses to a dict, directories
    for file handlers are created and the dict is applied via
    ``logging.config.dictConfig``. Otherwise a plain INFO-level basicConfig
    is installed.

    Note: this span previously interleaved the removed env-var-driven
    implementation with the new YAML-driven one; only the YAML-driven
    version is kept here.
    """
    config_path = os.getenv("LOGGING_CONFIG", "logging.yaml")
    if os.path.exists(config_path):
        with open(config_path, "r", encoding="utf-8") as fh:
            config = yaml.safe_load(fh)
        if isinstance(config, dict):
            _ensure_log_directories(config)
            logging.config.dictConfig(config)
            return
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(levelname)s %(name)s:%(lineno)d %(message)s",
    )
_configure_logging()
logger = logging.getLogger(__name__)
@ -119,6 +136,24 @@ def create_app() -> FastAPI:
lifespan=lifespan,
)
@application.exception_handler(RequestValidationError)
async def request_validation_exception_handler(
    request: Request, exc: RequestValidationError
) -> JSONResponse:
    """Log 422 validation failures together with a preview of the offending body."""
    try:
        raw_body = await request.body()
    except Exception:  # pragma: no cover - defensive
        raw_body = b"<unavailable>"
    # Cap the logged body at 4 KiB so oversized uploads do not flood the log.
    truncated_body = raw_body[:4096]
    logger.warning(
        "Validation error on %s %s: %s | body preview=%s",
        request.method,
        request.url.path,
        exc.errors(),
        truncated_body.decode("utf-8", errors="ignore"),
    )
    return JSONResponse(status_code=422, content={"detail": exc.errors()})
@application.post(
"/v1/chat/completions",
response_model=LLMResponse,
@ -164,6 +199,52 @@ def create_app() -> FastAPI:
return DataImportAnalysisJobAck(import_record_id=payload.import_record_id, status="accepted")
@application.post(
    "/v1/table/profiling",
    response_model=TableProfilingJobAck,
    summary="Run end-to-end GE profiling pipeline and notify via callback per action",
    status_code=202,
)
async def run_table_profiling(
    payload: TableProfilingJobRequest,
    gateway: LLMGateway = Depends(get_gateway),
    client: httpx.AsyncClient = Depends(get_http_client),
) -> TableProfilingJobAck:
    """Accept a profiling job, run it in the background, and ack with 202 immediately."""
    # Deep copy so the background task is isolated from any mutation of the
    # payload after the response has been returned.
    request_copy = payload.model_copy(deep=True)

    async def _runner() -> None:
        await process_table_profiling_job(request_copy, gateway, client)

    # NOTE(review): the task returned by create_task is not stored anywhere;
    # per the asyncio docs a fire-and-forget task may be garbage-collected
    # mid-run — consider keeping a strong reference. TODO confirm.
    asyncio.create_task(_runner())
    return TableProfilingJobAck(
        table_id=payload.table_id,
        version_ts=payload.version_ts,
        status="accepted",
    )
@application.post(
    "/v1/table/snippet",
    response_model=TableSnippetUpsertResponse,
    summary="Persist or update action results, such as table snippets.",
)
async def upsert_table_snippet(
    payload: TableSnippetUpsertRequest,
) -> TableSnippetUpsertResponse:
    """Upsert an action result, running the blocking DB work off the event loop."""
    request_copy = payload.model_copy(deep=True)
    try:
        # upsert_action_result does blocking DB I/O; run it in a worker thread.
        return await asyncio.to_thread(upsert_action_result, request_copy)
    except Exception as exc:
        logger.error(
            "Failed to upsert table snippet: table_id=%s version_ts=%s action_type=%s",
            payload.table_id,
            payload.version_ts,
            payload.action_type,
            exc_info=True,
        )
        # NOTE(review): str(exc) is returned to the client in the 500 detail —
        # confirm internal error text may be exposed.
        raise HTTPException(status_code=500, detail=str(exc)) from exc
@application.post("/__mock__/import-callback")
async def mock_import_callback(payload: dict[str, Any]) -> dict[str, str]:
logger.info("Received import analysis callback: %s", payload)

View File

@ -1,5 +1,6 @@
from __future__ import annotations
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional, Union
@ -76,8 +77,8 @@ class DataImportAnalysisRequest(BaseModel):
description="Ordered list of table headers associated with the data.",
)
llm_model: str = Field(
...,
description="Model identifier. Accepts 'provider:model' format or plain model name.",
None,
description="Model identifier. Accepts 'provider:model_name' format or custom model alias.",
)
temperature: Optional[float] = Field(
None,
@ -135,3 +136,158 @@ class DataImportAnalysisJobRequest(BaseModel):
class DataImportAnalysisJobAck(BaseModel):
    """Acknowledgement returned when an import-analysis job is accepted."""

    import_record_id: str = Field(..., description="Echo of the import record identifier")
    status: str = Field("accepted", description="Processing status acknowledgement.")
class ActionType(str, Enum):
    """Pipeline action types recorded for table-profiling results."""

    GE_PROFILING = "ge_profiling"
    GE_RESULT_DESC = "ge_result_desc"
    SNIPPET = "snippet"
    SNIPPET_ALIAS = "snippet_alias"
class ActionStatus(str, Enum):
    """Execution status values for a pipeline action record."""

    PENDING = "pending"
    RUNNING = "running"
    SUCCESS = "success"
    FAILED = "failed"
    PARTIAL = "partial"
class TableProfilingJobRequest(BaseModel):
    """Payload for POST /v1/table/profiling.

    Identifies the table/version to profile, the callback endpoint for
    per-action notifications, how to reach the source table, and optional
    Great Expectations / LLM overrides for the pipeline steps.
    """

    table_id: str = Field(..., description="Unique identifier for the table to profile.")
    version_ts: str = Field(
        ...,
        pattern=r"^\d{14}$",
        description="Version timestamp expressed as fourteen digit string (yyyyMMddHHmmss).",
    )
    callback_url: HttpUrl = Field(
        ...,
        description="Callback endpoint invoked after each pipeline action completes.",
    )
    llm_model: Optional[str] = Field(
        None,
        description="Default LLM model spec applied to prompt-based actions when overrides are omitted.",
    )
    table_schema: Optional[Any] = Field(
        None,
        description="Schema structure snapshot for the current table version.",
    )
    table_schema_version_id: Optional[str] = Field(
        None,
        description="Identifier for the schema snapshot provided in table_schema.",
    )
    table_link_info: Optional[Dict[str, Any]] = Field(
        None,
        description=(
            "Information describing how to locate the source table for profiling. "
            "For example: {'type': 'sql', 'connection_string': 'mysql+pymysql://user:pass@host/db', "
            "'table': 'schema.table_name'}."
        ),
    )
    table_access_info: Optional[Dict[str, Any]] = Field(
        None,
        description=(
            "Credentials or supplemental parameters required to access the table described in table_link_info. "
            "These values can be merged into the connection string using Python format placeholders."
        ),
    )
    ge_batch_request: Optional[Dict[str, Any]] = Field(
        None,
        description="Optional Great Expectations batch request payload used for profiling.",
    )
    ge_expectation_suite_name: Optional[str] = Field(
        None,
        description="Expectation suite name used during profiling. Created automatically when absent.",
    )
    ge_data_context_root: Optional[str] = Field(
        None,
        description="Custom root directory for the Great Expectations data context. Defaults to project ./gx.",
    )
    ge_datasource_name: Optional[str] = Field(
        None,
        description="Datasource name registered inside the GE context when batch_request is not supplied.",
    )
    ge_data_asset_name: Optional[str] = Field(
        None,
        description="Data asset reference used when inferring batch request from datasource configuration.",
    )
    ge_profiler_type: str = Field(
        "user_configurable",
        description="Profiler implementation identifier. Currently supports 'user_configurable' or 'data_assistant'.",
    )
    result_desc_model: Optional[str] = Field(
        None,
        description="LLM model override used for GE result description (action 2).",
    )
    snippet_model: Optional[str] = Field(
        None,
        description="LLM model override used for snippet generation (action 3).",
    )
    snippet_alias_model: Optional[str] = Field(
        None,
        description="LLM model override used for snippet alias enrichment (action 4).",
    )
    extra_options: Optional[Dict[str, Any]] = Field(
        None,
        description="Miscellaneous execution flags applied across pipeline steps.",
    )
class TableProfilingJobAck(BaseModel):
    """202 acknowledgement returned by POST /v1/table/profiling."""

    table_id: str = Field(..., description="Echo of the table identifier.")
    version_ts: str = Field(..., description="Echo of the profiling version timestamp (yyyyMMddHHmmss).")
    status: str = Field("accepted", description="Processing acknowledgement status.")
class TableSnippetUpsertRequest(BaseModel):
    """Payload for POST /v1/table/snippet: one action-result record to persist.

    Note: table_id/version_ts are integers here, while TableProfilingJobRequest
    uses strings — callers must convert consistently.
    """

    table_id: int = Field(..., ge=1, description="Unique identifier for the table.")
    version_ts: int = Field(
        ...,
        ge=0,
        description="Version timestamp aligned with the pipeline (yyyyMMddHHmmss as integer).",
    )
    action_type: ActionType = Field(..., description="Pipeline action type for this record.")
    status: ActionStatus = Field(
        ActionStatus.SUCCESS, description="Execution status for the action."
    )
    callback_url: HttpUrl = Field(..., description="Callback URL associated with the action run.")
    table_schema_version_id: int = Field(..., ge=0, description="Identifier for the schema snapshot.")
    table_schema: Any = Field(..., description="Schema snapshot payload for the table.")
    result_json: Optional[Any] = Field(
        None,
        description="Primary result payload for the action (e.g., profiling output, snippet array).",
    )
    result_summary_json: Optional[Any] = Field(
        None,
        description="Optional summary payload (e.g., profiling summary) for the action.",
    )
    html_report_url: Optional[str] = Field(
        None,
        description="Optional HTML report URL generated by the action.",
    )
    error_code: Optional[str] = Field(None, description="Optional error code when status indicates a failure.")
    error_message: Optional[str] = Field(None, description="Optional error message when status indicates a failure.")
    started_at: Optional[datetime] = Field(
        None, description="Timestamp when the action started executing."
    )
    finished_at: Optional[datetime] = Field(
        None, description="Timestamp when the action finished executing."
    )
    duration_ms: Optional[int] = Field(
        None,
        ge=0,
        description="Optional execution duration in milliseconds.",
    )
    result_checksum: Optional[str] = Field(
        None,
        description="Optional checksum for the result payload (e.g., MD5).",
    )
class TableSnippetUpsertResponse(BaseModel):
    """Outcome of an action-result upsert returned by POST /v1/table/snippet."""

    table_id: int  # echoed table identifier
    version_ts: int  # echoed version timestamp
    action_type: ActionType  # echoed action type
    status: ActionStatus  # persisted execution status
    updated: bool  # True when an existing row was updated rather than inserted

View File

@ -42,7 +42,7 @@ def _env_float(name: str, default: float) -> float:
return default
IMPORT_CHAT_TIMEOUT_SECONDS = _env_float("IMPORT_CHAT_TIMEOUT_SECONDS", 90.0)
IMPORT_CHAT_TIMEOUT_SECONDS = _env_float("IMPORT_CHAT_TIMEOUT_SECONDS", 120.0)
SUPPORTED_IMPORT_MODELS = get_supported_import_models()
@ -298,7 +298,7 @@ def parse_llm_analysis_json(llm_response: LLMResponse) -> Dict[str, Any]:
try:
return json.loads(json_payload)
except json.JSONDecodeError as exc:
preview = json_payload[:2000]
preview = json_payload[:10000]
logger.error("Failed to parse JSON from LLM response content: %s", preview, exc_info=True)
raise ProviderAPICallError("LLM response JSON could not be parsed.") from exc

View File

@ -0,0 +1,832 @@
from __future__ import annotations
import asyncio
import json
import logging
import os
import re
from datetime import date, datetime
from dataclasses import asdict, dataclass, is_dataclass
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
import httpx
import great_expectations as gx
from great_expectations.core.batch import RuntimeBatchRequest
from great_expectations.core.expectation_suite import ExpectationSuite
from great_expectations.data_context import AbstractDataContext
from great_expectations.exceptions import DataContextError, MetricResolutionError
from app.exceptions import ProviderAPICallError
from app.models import TableProfilingJobRequest
from app.services import LLMGateway
from app.settings import DEFAULT_IMPORT_MODEL
from app.services.import_analysis import (
IMPORT_GATEWAY_BASE_URL,
resolve_provider_from_model,
)
logger = logging.getLogger(__name__)
GE_REPORT_RELATIVE_PATH = Path("uncommitted") / "data_docs" / "local_site" / "index.html"
PROMPT_FILENAMES = {
"ge_result_desc": "ge_result_desc_prompt.md",
"snippet_generator": "snippet_generator.md",
"snippet_alias": "snippet_alias_generator.md",
}
DEFAULT_CHAT_TIMEOUT_SECONDS = 90.0
@dataclass
class GEProfilingArtifacts:
    """Artifacts produced by the GE profiling action (action 1)."""

    # Sanitized expectation suite + validation result + serialized batch request.
    profiling_result: Dict[str, Any]
    # Compact per-column / table-level summary of the expectation suite.
    profiling_summary: Dict[str, Any]
    # Local filesystem path to the generated data-docs index.html.
    docs_path: str
class PipelineActionType:
    """String constants naming the four pipeline actions.

    NOTE(review): the values duplicate app.models.ActionType — consider reusing
    that enum so the two cannot drift apart.
    """

    GE_PROFILING = "ge_profiling"
    GE_RESULT_DESC = "ge_result_desc"
    SNIPPET = "snippet"
    SNIPPET_ALIAS = "snippet_alias"
def _project_root() -> Path:
return Path(__file__).resolve().parents[2]
def _prompt_dir() -> Path:
    """Directory that holds the LLM prompt templates."""
    return Path(_project_root(), "prompt")
@lru_cache(maxsize=None)
def _load_prompt_parts(filename: str) -> Tuple[str, str]:
    """Load a prompt template and split it into (system, user) sections.

    Templates live under ``<project>/prompt`` and separate the system and user
    sections with the literal markers "系统角色System" and "用户消息User".
    Results are cached per filename for the process lifetime.

    Raises:
        FileNotFoundError: when the template file does not exist.
        ValueError: when the user-section separator is missing.
    """
    prompt_path = _prompt_dir() / filename
    if not prompt_path.exists():
        raise FileNotFoundError(f"Prompt template not found: {prompt_path}")
    raw = prompt_path.read_text(encoding="utf-8")
    splitter = "用户消息User"
    if splitter not in raw:
        # Name the actual template in the error; the previous message used a
        # hard-coded "(unknown)" placeholder, hiding which file was broken.
        raise ValueError(f"Prompt template '{filename}' missing separator '{splitter}'.")
    system_raw, user_raw = raw.split(splitter, maxsplit=1)
    system_text = system_raw.replace("系统角色System", "").strip()
    user_text = user_raw.strip()
    return system_text, user_text
def _render_prompt(template_key: str, replacements: Dict[str, str]) -> Tuple[str, str]:
    """Resolve a template by key and substitute placeholder tokens in the user part."""
    system_text, user_text = _load_prompt_parts(PROMPT_FILENAMES[template_key])
    for token, substitute in replacements.items():
        user_text = user_text.replace(token, substitute)
    return system_text, user_text
def _extract_timeout_seconds(options: Optional[Dict[str, Any]]) -> Optional[float]:
    """Read a positive float ``llm_timeout_seconds`` from extra_options.

    Returns None when the options dict or the key is absent; when a value is
    present but not a positive number, logs a warning and returns the module
    default timeout instead.
    """
    if not options:
        return None
    raw = options.get("llm_timeout_seconds")
    if raw is None:
        return None
    try:
        seconds = float(raw)
        if seconds <= 0:
            raise ValueError
        return seconds
    except (TypeError, ValueError):
        logger.warning(
            "Invalid llm_timeout_seconds value in extra_options: %r. Falling back to default.",
            raw,
        )
        return DEFAULT_CHAT_TIMEOUT_SECONDS
def _extract_json_payload(content: str) -> str:
fenced = re.search(
r"```(?:json)?\s*([\s\S]+?)```",
content,
flags=re.IGNORECASE,
)
if fenced:
snippet = fenced.group(1).strip()
if snippet:
return snippet
stripped = content.strip()
if not stripped:
raise ValueError("Empty LLM content.")
for opener, closer in (("{", "}"), ("[", "]")):
start = stripped.find(opener)
end = stripped.rfind(closer)
if start != -1 and end != -1 and end > start:
candidate = stripped[start : end + 1].strip()
return candidate
return stripped
def _parse_completion_payload(response_payload: Dict[str, Any]) -> Any:
    """Extract and JSON-decode the first choice's message content.

    Raises:
        ProviderAPICallError: when choices/content are missing or the content
            cannot be parsed as JSON.
    """
    choices = response_payload.get("choices") or []
    if not choices:
        raise ProviderAPICallError("LLM response did not contain choices to parse.")
    content = (choices[0].get("message") or {}).get("content") or ""
    if not content.strip():
        raise ProviderAPICallError("LLM response content is empty.")
    snippet = _extract_json_payload(content)
    try:
        return json.loads(snippet)
    except json.JSONDecodeError as exc:
        logger.error("Failed to parse JSON from LLM response: %s", snippet[:800], exc_info=True)
        raise ProviderAPICallError("LLM response JSON parsing failed.") from exc
async def _post_callback(callback_url: str, payload: Dict[str, Any], client: httpx.AsyncClient) -> None:
    """POST a JSON-normalised payload to the callback URL; failures are logged, never raised."""
    body = _normalize_for_json(payload)
    try:
        logger.info(
            "Posting pipeline action callback to %s: %s",
            callback_url,
            json.dumps(body, ensure_ascii=False),
        )
        reply = await client.post(callback_url, json=body)
        reply.raise_for_status()
    except httpx.HTTPError as exc:
        logger.error("Callback delivery to %s failed: %s", callback_url, exc, exc_info=True)
def _sanitize_value_set(value: Any, max_values: int) -> Tuple[Any, Optional[Dict[str, int]]]:
if not isinstance(value, list):
return value, None
original_len = len(value)
if original_len <= max_values:
return value, None
trimmed = value[:max_values]
return trimmed, {"original_length": original_len, "retained": max_values}
def _sanitize_expectation_suite(suite: ExpectationSuite, max_value_set_values: int = 100) -> Dict[str, Any]:
    """Serialise a suite to a JSON dict, truncating oversized value_set kwargs.

    Each truncation is recorded on the expectation's own meta and aggregated
    under the suite-level meta["value_set_truncations"].
    """
    payload = suite.to_json_dict()
    truncation_notes: List[Dict[str, Any]] = []
    for exp in payload.get("expectations", []):
        exp_kwargs = exp.get("kwargs", {})
        if "value_set" not in exp_kwargs:
            continue
        trimmed, note = _sanitize_value_set(exp_kwargs["value_set"], max_value_set_values)
        exp_kwargs["value_set"] = trimmed
        if note is None:
            continue
        exp.setdefault("meta", {})["value_set_truncated"] = note
        truncation_notes.append(
            {
                "column": exp_kwargs.get("column"),
                "expectation": exp.get("expectation_type"),
                "note": note,
            }
        )
    if truncation_notes:
        payload.setdefault("meta", {})["value_set_truncations"] = truncation_notes
    return payload
def _summarize_expectation_suite(suite_dict: Dict[str, Any]) -> Dict[str, Any]:
column_map: Dict[str, Dict[str, Any]] = {}
table_expectations: List[Dict[str, Any]] = []
for expectation in suite_dict.get("expectations", []):
expectation_type = expectation.get("expectation_type")
kwargs = expectation.get("kwargs", {})
column = kwargs.get("column")
summary_entry: Dict[str, Any] = {"expectation": expectation_type}
if "value_set" in kwargs and isinstance(kwargs["value_set"], list):
summary_entry["value_set_size"] = len(kwargs["value_set"])
summary_entry["value_set_preview"] = kwargs["value_set"][:5]
if column:
column_entry = column_map.setdefault(
column,
{"name": column, "expectations": []},
)
column_entry["expectations"].append(summary_entry)
else:
table_expectations.append(summary_entry)
summary = {
"column_profiles": list(column_map.values()),
"table_level_expectations": table_expectations,
"total_expectations": len(suite_dict.get("expectations", [])),
}
return summary
def _sanitize_identifier(raw: Optional[str], fallback: str) -> str:
if not raw:
return fallback
candidate = re.sub(r"[^0-9A-Za-z_]+", "_", raw).strip("_")
return candidate or fallback
def _format_connection_string(template: str, access_info: Dict[str, Any]) -> str:
if not access_info:
return template
try:
return template.format_map({k: v for k, v in access_info.items()})
except KeyError as exc:
missing = exc.args[0]
raise ValueError(f"table_access_info missing key '{missing}' required by connection_string.") from exc
def _ensure_sql_runtime_datasource(
    context: AbstractDataContext,
    datasource_name: str,
    connection_string: str,
) -> None:
    """Ensure a runtime SQL datasource with this name/connection exists in the GE context.

    If a datasource with the same name but a different connection string
    exists, it is deleted and recreated. Lookup errors whose message matches
    GE's "not found" wordings are treated as absence; anything else is
    re-raised as RuntimeError.
    """
    try:
        datasource = context.get_datasource(datasource_name)
    except (DataContextError, ValueError) as exc:
        message = str(exc)
        # GE versions word the "missing datasource" error differently; match
        # both known phrasings and treat them as "datasource absent".
        if "Could not find a datasource" in message or "Unable to load datasource" in message:
            datasource = None
        else:  # pragma: no cover - defensive
            raise RuntimeError(f"Failed to inspect datasource '{datasource_name}'.") from exc
    except Exception as exc:  # pragma: no cover - defensive
        raise RuntimeError(f"Failed to inspect datasource '{datasource_name}'.") from exc
    if datasource is not None:
        execution_engine = getattr(datasource, "execution_engine", None)
        current_conn = getattr(execution_engine, "connection_string", None)
        if current_conn and current_conn != connection_string:
            # Connection string changed (e.g. new credentials/host): rebuild
            # rather than silently profiling against the old database.
            logger.info(
                "Existing datasource %s uses different connection string; creating dedicated runtime datasource.",
                datasource_name,
            )
            try:
                context.delete_datasource(datasource_name)
            except Exception as exc:  # pragma: no cover - defensive
                logger.warning(
                    "Failed to delete datasource %s before recreation: %s",
                    datasource_name,
                    exc,
                )
            else:
                # Deletion succeeded — mark absent so it gets recreated below.
                datasource = None
    if datasource is not None:
        # Existing datasource already points at the requested connection.
        return
    runtime_datasource_config = {
        "name": datasource_name,
        "class_name": "Datasource",
        "execution_engine": {
            "class_name": "SqlAlchemyExecutionEngine",
            "connection_string": connection_string,
        },
        "data_connectors": {
            "runtime_connector": {
                "class_name": "RuntimeDataConnector",
                "batch_identifiers": ["default_identifier_name"],
            }
        },
    }
    try:
        context.add_datasource(**runtime_datasource_config)
    except Exception as exc:  # pragma: no cover - defensive
        raise RuntimeError(f"Failed to create runtime datasource '{datasource_name}'.") from exc
def _build_sql_runtime_batch_request(
    context: AbstractDataContext,
    request: TableProfilingJobRequest,
) -> RuntimeBatchRequest:
    """Construct a GE RuntimeBatchRequest from the request's table_link_info.

    Only link type 'sql' is supported. When no explicit query is supplied, a
    ``SELECT *`` query is synthesised from the (schema-qualified) table name,
    optionally capped with LIMIT.

    Raises:
        ValueError: for a missing connection_string, unsupported link type, or
            when neither query nor table name is provided.
    """
    link_info = request.table_link_info or {}
    access_info = request.table_access_info or {}
    connection_template = link_info.get("connection_string")
    if not connection_template:
        raise ValueError("table_link_info.connection_string is required when using table_link_info.")
    connection_string = _format_connection_string(connection_template, access_info)
    source_type = (link_info.get("type") or "sql").lower()
    if source_type != "sql":
        raise ValueError(f"Unsupported table_link_info.type='{source_type}'. Only 'sql' is supported.")
    query = link_info.get("query")
    table_name = link_info.get("table") or link_info.get("table_name")
    schema_name = link_info.get("schema")
    if not query and not table_name:
        raise ValueError("Either table_link_info.query or table_link_info.table must be provided.")
    if not query:
        if not table_name:
            raise ValueError("table_link_info.table must be provided when query is omitted.")
        identifier = re.compile(r"^[A-Za-z_][A-Za-z0-9_$]*$")

        def _quote(name: str) -> str:
            # Backtick-quote anything that is not a plain identifier.
            # NOTE(review): backtick quoting is MySQL syntax — confirm other
            # SQL dialects are not expected through this path.
            if identifier.match(name):
                return name
            return f"`{name.replace('`', '``')}`"

        if schema_name:
            # Keep only the last dotted component of schema/table before quoting.
            schema_part = schema_name if "." not in schema_name else schema_name.split(".")[-1]
            table_part = table_name if "." not in table_name else table_name.split(".")[-1]
            qualified_table = f"{_quote(schema_part)}.{_quote(table_part)}"
        else:
            qualified_table = _quote(table_name)
        query = f"SELECT * FROM {qualified_table}"
        limit = link_info.get("limit")
        if isinstance(limit, int) and limit > 0:
            query = f"{query} LIMIT {limit}"
    datasource_name = request.ge_datasource_name or _sanitize_identifier(
        f"{request.table_id}_runtime_ds", "runtime_ds"
    )
    data_asset_name = request.ge_data_asset_name or _sanitize_identifier(
        table_name or "runtime_query", "runtime_query"
    )
    _ensure_sql_runtime_datasource(context, datasource_name, connection_string)
    batch_identifiers = {
        "default_identifier_name": f"{request.table_id}:{request.version_ts}",
    }
    return RuntimeBatchRequest(
        datasource_name=datasource_name,
        data_connector_name="runtime_connector",
        data_asset_name=data_asset_name,
        runtime_parameters={"query": query},
        batch_identifiers=batch_identifiers,
    )
def _run_onboarding_assistant(
    context: AbstractDataContext,
    batch_request: Any,
    suite_name: str,
) -> Tuple[ExpectationSuite, Any]:
    """Profile with GE's onboarding data assistant; return (suite, validation result).

    The generated suite is persisted under suite_name. The validation result
    is taken from the assistant result when available; otherwise the suite is
    re-validated explicitly.
    """
    assistant = context.assistants.onboarding
    assistant_result = assistant.run(batch_request=batch_request)
    suite = assistant_result.get_expectation_suite(expectation_suite_name=suite_name)
    context.save_expectation_suite(suite, expectation_suite_name=suite_name)
    # GE versions differ: some expose get_validation_result(), others only an
    # attribute — probe for both before falling back to a fresh validation.
    validation_getter = getattr(assistant_result, "get_validation_result", None)
    if callable(validation_getter):
        validation_result = validation_getter()
    else:
        validation_result = getattr(assistant_result, "validation_result", None)
    if validation_result is None:
        # Fallback: rerun validation using the freshly generated expectation suite.
        validator = context.get_validator(
            batch_request=batch_request,
            expectation_suite_name=suite_name,
        )
        validation_result = validator.validate()
    return suite, validation_result
def _resolve_context(request: TableProfilingJobRequest) -> AbstractDataContext:
    """Open the GE data context rooted at the request override, the
    GE_DATA_CONTEXT_ROOT env var, or the project root (in that order)."""
    root = (
        request.ge_data_context_root
        or os.environ.get("GE_DATA_CONTEXT_ROOT")
        or str(_project_root())
    )
    return gx.get_context(project_root_dir=root)
def _build_batch_request(
    context: AbstractDataContext,
    request: TableProfilingJobRequest,
) -> Any:
    """Resolve the GE batch request from the job payload.

    Precedence: explicit ge_batch_request, then table_link_info (runtime SQL),
    then a named datasource/asset pair.

    Raises:
        ValueError: when none of the three sources is available.
    """
    if request.ge_batch_request:
        from great_expectations.core.batch import BatchRequest

        return BatchRequest(**request.ge_batch_request)
    if request.table_link_info:
        return _build_sql_runtime_batch_request(context, request)
    if request.ge_datasource_name and request.ge_data_asset_name:
        datasource = context.get_datasource(request.ge_datasource_name)
        return datasource.get_asset(request.ge_data_asset_name).build_batch_request()
    raise ValueError(
        "ge_batch_request or (ge_datasource_name and ge_data_asset_name) must be provided."
    )
async def _run_ge_profiling(request: TableProfilingJobRequest) -> GEProfilingArtifacts:
    """Action 1: run GE profiling for the request and collect its artifacts.

    The synchronous GE work runs in a worker thread (asyncio.to_thread) so the
    event loop is not blocked. Returns the sanitized suite + validation
    result, a compact summary, and the local data-docs path.
    """

    def _execute() -> GEProfilingArtifacts:
        context = _resolve_context(request)
        suite_name = (
            request.ge_expectation_suite_name
            or f"{request.table_id}_profiling"
        )
        batch_request = _build_batch_request(context, request)
        # Create the suite on first use; GE raises DataContextError when absent.
        try:
            context.get_expectation_suite(suite_name)
        except DataContextError:
            context.add_expectation_suite(suite_name)
        validator = context.get_validator(
            batch_request=batch_request,
            expectation_suite_name=suite_name,
        )
        profiler_type = (request.ge_profiler_type or "user_configurable").lower()
        if profiler_type == "data_assistant":
            suite, validation_result = _run_onboarding_assistant(
                context,
                batch_request,
                suite_name,
            )
        else:
            try:
                from great_expectations.profile.user_configurable_profiler import (
                    UserConfigurableProfiler,
                )
            except ImportError as err:  # pragma: no cover - dependency guard
                raise RuntimeError(
                    "UserConfigurableProfiler is unavailable; install great_expectations profiling extra or switch profiler."
                ) from err
            profiler = UserConfigurableProfiler(profile_dataset=validator)
            try:
                suite = profiler.build_suite()
                context.save_expectation_suite(suite, expectation_suite_name=suite_name)
                validator.expectation_suite = suite
                validation_result = validator.validate()
            except MetricResolutionError as exc:
                # Some metrics are unsupported by the legacy profiler for this
                # engine; degrade gracefully to the data assistant instead.
                logger.warning(
                    "UserConfigurableProfiler failed (%s); falling back to data assistant profiling.",
                    exc,
                )
                suite, validation_result = _run_onboarding_assistant(
                    context,
                    batch_request,
                    suite_name,
                )
        sanitized_suite = _sanitize_expectation_suite(suite)
        summary = _summarize_expectation_suite(sanitized_suite)
        validation_dict = validation_result.to_json_dict()
        context.build_data_docs()
        docs_path = Path(context.root_directory) / GE_REPORT_RELATIVE_PATH
        profiling_result = {
            "expectation_suite": sanitized_suite,
            "validation_result": validation_dict,
            # Serialize the batch request with whichever API the object offers.
            "batch_request": getattr(batch_request, "to_json_dict", lambda: None)() or getattr(batch_request, "dict", lambda: None)(),
        }
        return GEProfilingArtifacts(
            profiling_result=profiling_result,
            profiling_summary=summary,
            docs_path=str(docs_path),
        )

    return await asyncio.to_thread(_execute)
async def _call_chat_completions(
    *,
    model_spec: str,
    system_prompt: str,
    user_prompt: str,
    client: httpx.AsyncClient,
    temperature: float = 0.2,
    timeout_seconds: Optional[float] = None,
) -> Any:
    """POST a chat-completion request to the local gateway and parse the JSON answer.

    Raises:
        ProviderAPICallError: on transport failure, non-2xx status, a non-JSON
            response body, or unparsable completion content.
    """
    provider, model_name = resolve_provider_from_model(model_spec)
    payload = {
        "provider": provider.value,
        "model": model_name,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        "temperature": temperature,
    }
    payload_size_bytes = len(json.dumps(payload, ensure_ascii=False).encode("utf-8"))
    url = f"{IMPORT_GATEWAY_BASE_URL.rstrip('/')}/v1/chat/completions"
    try:
        # log the request whole info
        # NOTE(review): this logs the full prompt payload at INFO level —
        # confirm prompts never carry sensitive data, or trim this line.
        logger.info(
            "Calling chat completions API %s with model %s and size %s and payload %s",
            url,
            model_name,
            payload_size_bytes,
            payload,
        )
        response = await client.post(url, json=payload, timeout=timeout_seconds)
        response.raise_for_status()
    except httpx.HTTPError as exc:
        # Normalise transport/status errors into the provider error type.
        error_name = exc.__class__.__name__
        detail = str(exc).strip()
        if detail:
            message = f"Chat completions request failed ({error_name}): {detail}"
        else:
            message = f"Chat completions request failed ({error_name})."
        raise ProviderAPICallError(message) from exc
    try:
        response_payload = response.json()
    except ValueError as exc:
        raise ProviderAPICallError("Chat completions response was not valid JSON.") from exc
    return _parse_completion_payload(response_payload)
def _normalize_for_json(value: Any) -> Any:
if value is None or isinstance(value, (str, int, float, bool)):
return value
if isinstance(value, (datetime, date)):
return str(value)
if hasattr(value, "model_dump"):
try:
return value.model_dump()
except Exception: # pragma: no cover - defensive
pass
if is_dataclass(value):
return asdict(value)
if isinstance(value, dict):
return {k: _normalize_for_json(v) for k, v in value.items()}
if isinstance(value, (list, tuple, set)):
return [_normalize_for_json(v) for v in value]
if hasattr(value, "to_json_dict"):
try:
return value.to_json_dict()
except Exception: # pragma: no cover - defensive
pass
if hasattr(value, "__dict__"):
return _normalize_for_json(value.__dict__)
return repr(value)
def _json_dumps(data: Any) -> str:
    """Serialise data as pretty-printed, non-ASCII-escaped JSON after normalisation."""
    return json.dumps(_normalize_for_json(data), ensure_ascii=False, indent=2)
def _preview_for_log(data: Any) -> str:
    """Best-effort string rendering of data for log output; falls back to repr."""
    try:
        return _json_dumps(data)
    except Exception:
        return repr(data)
def _profiling_request_for_log(request: TableProfilingJobRequest) -> Dict[str, Any]:
payload = request.model_dump()
access_info = payload.get("table_access_info")
if isinstance(access_info, dict):
payload["table_access_info"] = {key: "***" for key in access_info.keys()}
return payload
async def _execute_result_desc(
    profiling_json: Dict[str, Any],
    _request: TableProfilingJobRequest,
    llm_model: str,
    client: httpx.AsyncClient,
    timeout_seconds: Optional[float],
) -> Dict[str, Any]:
    """Action 2: ask the LLM to describe GE profiling output; must yield a JSON object."""
    system_text, user_text = _render_prompt(
        "ge_result_desc",
        {"{{GE_RESULT_JSON}}": _json_dumps(profiling_json)},
    )
    parsed = await _call_chat_completions(
        model_spec=llm_model,
        system_prompt=system_text,
        user_prompt=user_text,
        client=client,
        timeout_seconds=timeout_seconds,
    )
    if isinstance(parsed, dict):
        return parsed
    raise ProviderAPICallError("GE result description payload must be a JSON object.")
async def _execute_snippet_generation(
    table_desc_json: Dict[str, Any],
    _request: TableProfilingJobRequest,
    llm_model: str,
    client: httpx.AsyncClient,
    timeout_seconds: Optional[float],
) -> List[Dict[str, Any]]:
    """Action 3: generate snippets from the table description; must yield a JSON array."""
    system_text, user_text = _render_prompt(
        "snippet_generator",
        {"{{TABLE_PROFILE_JSON}}": _json_dumps(table_desc_json)},
    )
    parsed = await _call_chat_completions(
        model_spec=llm_model,
        system_prompt=system_text,
        user_prompt=user_text,
        client=client,
        timeout_seconds=timeout_seconds,
    )
    if isinstance(parsed, list):
        return parsed
    raise ProviderAPICallError("Snippet generator must return a JSON array.")
async def _execute_snippet_alias(
    snippets_json: List[Dict[str, Any]],
    _request: TableProfilingJobRequest,
    llm_model: str,
    client: httpx.AsyncClient,
    timeout_seconds: Optional[float],
) -> List[Dict[str, Any]]:
    """Ask the LLM to enrich each snippet with aliases, keywords and intent tags."""
    prompts = _render_prompt(
        "snippet_alias",
        {"{{SNIPPET_ARRAY}}": _json_dumps(snippets_json)},
    )
    enriched = await _call_chat_completions(
        model_spec=llm_model,
        system_prompt=prompts[0],
        user_prompt=prompts[1],
        client=client,
        timeout_seconds=timeout_seconds,
    )
    if isinstance(enriched, list):
        return enriched
    raise ProviderAPICallError("Snippet alias generator must return a JSON array.")
async def _run_action_with_callback(
    *,
    action_type: str,
    runner,
    callback_base: Dict[str, Any],
    client: httpx.AsyncClient,
    callback_url: str,
    input_payload: Any = None,
    model_spec: Optional[str] = None,
) -> Any:
    """Run one pipeline action and POST a status callback whether it succeeds or fails.

    Args:
        action_type: Pipeline action identifier; compared against
            ``PipelineActionType`` members below (assumes those members compare
            equal to plain strings — presumably a str-valued enum; TODO confirm).
        runner: Zero-argument coroutine factory executing the action.
        callback_base: Common fields copied into every callback payload.
        client: Shared HTTP client used for the callback POST.
        callback_url: Destination for the status callback.
        input_payload: Optional input logged (via ``_preview_for_log``) before running.
        model_spec: Optional model identifier echoed into callback payloads.

    Returns:
        Whatever ``runner()`` returned.

    Raises:
        Re-raises any exception from ``runner()`` after posting a "failed" callback.
    """
    if input_payload is not None:
        logger.info(
            "Pipeline action %s input: %s",
            action_type,
            _preview_for_log(input_payload),
        )
    try:
        result = await runner()
    except Exception as exc:
        # Notify the caller of the failure before propagating the exception.
        failure_payload = dict(callback_base)
        failure_payload.update(
            {
                "status": "failed",
                "action_type": action_type,
                "error": str(exc),
            }
        )
        if model_spec is not None:
            failure_payload["model"] = model_spec
        await _post_callback(callback_url, failure_payload, client)
        raise
    success_payload = dict(callback_base)
    success_payload.update(
        {
            "status": "success",
            "action_type": action_type,
        }
    )
    if model_spec is not None:
        success_payload["model"] = model_spec
    logger.info(
        "Pipeline action %s output: %s",
        action_type,
        _preview_for_log(result),
    )
    # Attach the action-specific result under the key the callback consumer expects.
    if action_type == PipelineActionType.GE_PROFILING:
        artifacts: GEProfilingArtifacts = result
        success_payload["profiling_json"] = artifacts.profiling_result
        success_payload["profiling_summary"] = artifacts.profiling_summary
        success_payload["ge_report_path"] = artifacts.docs_path
    elif action_type == PipelineActionType.GE_RESULT_DESC:
        success_payload["table_desc_json"] = result
    elif action_type == PipelineActionType.SNIPPET:
        success_payload["snippet_json"] = result
    elif action_type == PipelineActionType.SNIPPET_ALIAS:
        success_payload["snippet_alias_json"] = result
    await _post_callback(callback_url, success_payload, client)
    return result
async def process_table_profiling_job(
    request: TableProfilingJobRequest,
    _gateway: LLMGateway,
    client: httpx.AsyncClient,
) -> None:
    """Sequentially execute the four-step profiling pipeline and emit callbacks per action.

    Steps (each feeding its output into the next):
      1. GE profiling of the table (``_run_ge_profiling``),
      2. LLM description of the profiling result,
      3. LLM snippet generation from the description,
      4. LLM alias enrichment of the snippets.

    Each step posts its own success/failure callback via
    ``_run_action_with_callback``; the first failure aborts the remaining steps
    and is logged here (the exception is swallowed — callers get no result).
    """
    timeout_seconds = _extract_timeout_seconds(request.extra_options)
    if timeout_seconds is None:
        timeout_seconds = DEFAULT_CHAT_TIMEOUT_SECONDS
    # Common fields included in every callback payload for this job.
    base_payload = {
        "table_id": request.table_id,
        "version_ts": request.version_ts,
        "callback_url": str(request.callback_url),
        "table_schema": request.table_schema,
        "table_schema_version_id": request.table_schema_version_id,
        "llm_model": request.llm_model,
        "llm_timeout_seconds": timeout_seconds,
    }
    # Credential-masked copy of the request, used only for logging.
    logging_request_payload = _profiling_request_for_log(request)
    try:
        # Step 1: run Great Expectations profiling.
        artifacts: GEProfilingArtifacts = await _run_action_with_callback(
            action_type=PipelineActionType.GE_PROFILING,
            runner=lambda: _run_ge_profiling(request),
            callback_base=base_payload,
            client=client,
            callback_url=str(request.callback_url),
            input_payload=logging_request_payload,
            model_spec=request.llm_model,
        )
        # Step 2: describe the profiling result via the LLM.
        table_desc_json: Dict[str, Any] = await _run_action_with_callback(
            action_type=PipelineActionType.GE_RESULT_DESC,
            runner=lambda: _execute_result_desc(
                artifacts.profiling_result,
                request,
                request.llm_model,
                client,
                timeout_seconds,
            ),
            callback_base=base_payload,
            client=client,
            callback_url=str(request.callback_url),
            input_payload=artifacts.profiling_result,
            model_spec=request.llm_model,
        )
        # Step 3: generate SQL snippets from the table description.
        snippet_json: List[Dict[str, Any]] = await _run_action_with_callback(
            action_type=PipelineActionType.SNIPPET,
            runner=lambda: _execute_snippet_generation(
                table_desc_json,
                request,
                request.llm_model,
                client,
                timeout_seconds,
            ),
            callback_base=base_payload,
            client=client,
            callback_url=str(request.callback_url),
            input_payload=table_desc_json,
            model_spec=request.llm_model,
        )
        # Step 4: enrich snippets with aliases/keywords/intent tags.
        await _run_action_with_callback(
            action_type=PipelineActionType.SNIPPET_ALIAS,
            runner=lambda: _execute_snippet_alias(
                snippet_json,
                request,
                request.llm_model,
                client,
                timeout_seconds,
            ),
            callback_base=base_payload,
            client=client,
            callback_url=str(request.callback_url),
            input_payload=snippet_json,
            model_spec=request.llm_model,
        )
    except Exception:  # pragma: no cover - defensive catch
        # Per-action failure callbacks were already sent; here we only log.
        logger.exception(
            "Table profiling pipeline failed for table_id=%s version_ts=%s",
            request.table_id,
            request.version_ts,
        )

View File

@ -0,0 +1,184 @@
from __future__ import annotations
import json
import logging
from typing import Any, Dict, Tuple
from sqlalchemy import text
from sqlalchemy.engine import Engine
from sqlalchemy.exc import SQLAlchemyError
from app.db import get_engine
from app.models import (
ActionType,
TableSnippetUpsertRequest,
TableSnippetUpsertResponse,
)
logger = logging.getLogger(__name__)
def _serialize_json(value: Any) -> Tuple[str | None, int | None]:
    """Return (serialized_text, utf8_byte_length); (None, None) when *value* is absent.

    Strings pass through unmodified; everything else is JSON-encoded.
    """
    logger.debug("Serializing JSON payload: %s", value)
    if value is None:
        return None, None
    text_value = value if isinstance(value, str) else json.dumps(value, ensure_ascii=False)
    return text_value, len(text_value.encode("utf-8"))
def _prepare_table_schema(value: Any) -> str:
    """Return table_schema as text, JSON-encoding any non-string payload."""
    logger.debug("Preparing table_schema payload.")
    return value if isinstance(value, str) else json.dumps(value, ensure_ascii=False)
def _collect_common_columns(request: TableSnippetUpsertRequest) -> Dict[str, Any]:
    """Build the column dict shared by every action type from the request."""
    logger.debug(
        "Collecting common columns for table_id=%s version_ts=%s action_type=%s",
        request.table_id,
        request.version_ts,
        request.action_type,
    )
    row: Dict[str, Any] = {
        "table_id": request.table_id,
        "version_ts": request.version_ts,
        "action_type": request.action_type.value,
        "status": request.status.value,
        "callback_url": str(request.callback_url),
        "table_schema_version_id": request.table_schema_version_id,
        "table_schema": _prepare_table_schema(request.table_schema),
    }
    if request.error_code is not None:
        logger.debug("Adding error_code: %s", request.error_code)
        row["error_code"] = request.error_code
    if request.error_message is not None:
        logger.debug("Adding error_message: %s", request.error_message)
        row["error_message"] = request.error_message
    # Optional timing/integrity fields are copied over only when present.
    for attr in ("started_at", "finished_at", "duration_ms", "result_checksum"):
        candidate = getattr(request, attr)
        if candidate is not None:
            row[attr] = candidate
    logger.debug("Collected common payload: %s", row)
    return row
def _apply_action_payload(
    request: TableSnippetUpsertRequest,
    payload: Dict[str, Any],
) -> None:
    """Merge action-specific result columns into *payload* in place.

    Raises:
        ValueError: when ``request.action_type`` is not a supported action.
    """
    logger.debug("Applying action-specific payload for action_type=%s", request.action_type)
    if request.action_type == ActionType.GE_PROFILING:
        full_json, full_size = _serialize_json(request.result_json)
        summary_json, summary_size = _serialize_json(request.result_summary_json)
        if full_json is not None:
            payload["ge_profiling_full"] = full_json
            payload["ge_profiling_full_size_bytes"] = full_size
        if summary_json is not None:
            payload["ge_profiling_summary"] = summary_json
            payload["ge_profiling_summary_size_bytes"] = summary_size
        if full_size is not None or summary_size is not None:
            payload["ge_profiling_total_size_bytes"] = (full_size or 0) + (summary_size or 0)
        if request.html_report_url:
            payload["ge_profiling_html_report_url"] = request.html_report_url
    else:
        # The remaining actions all store a single full-result column.
        column_by_action = {
            ActionType.GE_RESULT_DESC: "ge_result_desc_full",
            ActionType.SNIPPET: "snippet_full",
            ActionType.SNIPPET_ALIAS: "snippet_alias_full",
        }
        column = column_by_action.get(request.action_type)
        if column is None:
            logger.error("Unsupported action type encountered: %s", request.action_type)
            raise ValueError(f"Unsupported action type '{request.action_type}'.")
        full_json, full_size = _serialize_json(request.result_json)
        if full_json is not None:
            payload[column] = full_json
            payload[f"{column}_size_bytes"] = full_size
    logger.debug("Payload after applying action-specific data: %s", payload)
def _build_insert_statement(columns: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]:
    """Render a MySQL upsert statement covering exactly the supplied columns.

    The natural key (table_id, version_ts, action_type) is never overwritten
    on duplicate; updated_at is always refreshed.
    """
    logger.debug("Building insert statement for columns: %s", list(columns.keys()))
    names = list(columns.keys())
    immutable = {"table_id", "version_ts", "action_type"}
    assignments = [f"{name}=VALUES({name})" for name in names if name not in immutable]
    assignments.append("updated_at=CURRENT_TIMESTAMP")
    sql = (
        "INSERT INTO action_results ({cols}) VALUES ({vals}) "
        "ON DUPLICATE KEY UPDATE {updates}"
    ).format(
        cols=", ".join(names),
        vals=", ".join(f":{name}" for name in names),
        updates=", ".join(assignments),
    )
    logger.debug("Generated SQL: %s", sql)
    return sql, columns
def _execute_upsert(engine: Engine, sql: str, params: Dict[str, Any]) -> int:
    """Run the upsert inside a transaction and return the affected row count."""
    logger.info(
        "Executing upsert for table_id=%s version_ts=%s action_type=%s",
        params.get("table_id"),
        params.get("version_ts"),
        params.get("action_type"),
    )
    with engine.begin() as conn:
        outcome = conn.execute(text(sql), params)
        logger.info("Rows affected: %s", outcome.rowcount)
        return outcome.rowcount
def upsert_action_result(request: TableSnippetUpsertRequest) -> TableSnippetUpsertResponse:
    """Persist one action result row (insert or update) and report the outcome.

    Raises:
        RuntimeError: wrapping any SQLAlchemy error raised by the upsert.
    """
    logger.info(
        "Received upsert request: table_id=%s version_ts=%s action_type=%s status=%s",
        request.table_id,
        request.version_ts,
        request.action_type,
        request.status,
    )
    logger.debug("Request payload: %s", request.model_dump())
    columns = _collect_common_columns(request)
    _apply_action_payload(request, columns)
    sql, params = _build_insert_statement(columns)
    logger.debug("Final SQL params: %s", params)
    try:
        affected = _execute_upsert(get_engine(), sql, params)
    except SQLAlchemyError as exc:
        logger.exception(
            "Failed to upsert action result: table_id=%s version_ts=%s action_type=%s",
            request.table_id,
            request.version_ts,
            request.action_type,
        )
        raise RuntimeError(f"Database operation failed: {exc}") from exc
    # MySQL reports rowcount 1 for a fresh insert and 2 when an existing
    # row was updated by ON DUPLICATE KEY UPDATE.
    return TableSnippetUpsertResponse(
        table_id=request.table_id,
        version_ts=request.version_ts,
        action_type=request.action_type,
        status=request.status,
        updated=affected > 1,
    )

File diff suppressed because it is too large Load Diff

View File

@ -121,7 +121,7 @@ def clean_value(value: Any) -> Any:
if isinstance(value, (np.generic,)):
return value.item()
if isinstance(value, pd.Timestamp):
return value.isoformat()
return str(value)
if pd.isna(value):
return None
return value

30
logging.yaml Normal file
View File

@ -0,0 +1,30 @@
version: 1
formatters:
standard:
format: "%(asctime)s %(levelname)s %(name)s:%(lineno)d %(message)s"
handlers:
console:
class: logging.StreamHandler
level: INFO
formatter: standard
stream: ext://sys.stdout
file:
class: logging.handlers.RotatingFileHandler
level: INFO
formatter: standard
filename: logs/app.log
maxBytes: 10485760 # 10 MB
backupCount: 5
encoding: utf-8
loggers:
app:
level: INFO
handlers:
- console
- file
propagate: no
root:
level: INFO
handlers:
- console
- file

View File

@ -0,0 +1,47 @@
系统角色System
你是“数据画像抽取器”。输入是一段 Great Expectations 的 profiling/validation 结果 JSON
可能包含列级期望expect_*)、统计、样例值、类型推断等;也可能带表级/批次元数据。
请将其归一化为一个可被程序消费的“表画像”JSON对不确定项给出置信度与理由。
禁止臆造不存在的列、时间范围或数值。
用户消息User
【输入GE结果JSON】
{{GE_RESULT_JSON}}
【输出要求只输出JSON不要解释文字
{
"table": "<库.表 表名>",
"row_count": <int|null>, // 若未知可为 null
"role": "fact|dimension|unknown", // 依据指标/维度占比与唯一性启发式
"grain": ["<列1>", "<列2>", ...], // 事实粒度猜测(如含 dt/店/类目)
"time": { "column": "<name>|null", "granularity": "day|week|month|unknown", "range": ["YYYY-MM-DD","YYYY-MM-DD"]|null, "has_gaps": true|false|null },
"columns": [
{
"name": "<col>",
"dtype": "<ge推断/物理类型>",
"semantic_type": "dimension|metric|time|text|id|unknown",
"null_rate": <0~1|null>,
"distinct_count": <int|null>,
"distinct_ratio": <0~1|null>,
"stats": { "min": <number|string|null>,"max": <number|string|null>,"mean": <number|null>,"std": <number|null>,"skewness": <number|null> },
"enumish": true|false|null, // 低熵/可枚举
"top_values": [{"value":"<v>","pct":<0~1>}, ...],// 取前K个≤10
"pk_candidate_score": <0~1>, // 唯一性+非空综合评分
"metric_candidate_score": <0~1>, // 数值/偏态/业务词命中
"comment": "<列注释或GE描述|可为空>"
}
],
"primary_key_candidates": [["colA","colB"], ...], // 依据 unique/compound unique 期望
"fk_candidates": [{"from":"<col>","to":"<dim_table(col)>","confidence":<0~1>}],
"quality": {
"failed_expectations": [{"name":"<expect_*>","column":"<col|table>","summary":"<一句话>"}],
"warning_hints": ["空值率>0.2的列: ...", "时间列存在缺口: ..."]
},
"confidence_notes": ["<为什么判定role/grain/time列>"]
}
【判定规则(简要)】
- time列类型为日期/时间 OR 命中 dt/date/day 等命名;若有 min/max 可给出 range若间隔缺口≥1天记 has_gaps=true。
- semantic_type数值+右偏/方差大→更偏 metric高唯一/ID命名→id高基数+文本→text低熵+有限取值→dimension。
- rolemetric列占比高且存在time列→倾向 fact几乎全是枚举/ID且少数值→dimension。
- 置信不高时给出 null 或 unknown并写入 confidence_notes。

View File

@ -0,0 +1,52 @@
系统角色System
你是“SQL片段别名生成器”。
输入为一个或多个 SQL 片段对象(来自 snippet.json输出为针对每个片段生成的多样化别名口语 / 中性 / 专业)、关键词与意图标签。
要求逐个处理所有片段对象,输出同样数量的 JSON 元素。
用户消息User
【上下文】
SQL片段对象数组{{SNIPPET_ARRAY}} // snippet.json中的一个或多个片段
【任务要求】
请针对输入数组中的 每个 SQL 片段,输出一个 JSON 对象,结构如下:
{
"id": "<与输入片段id一致>",
"aliases": [
{"text": "…", "tone": "口语|中性|专业"},
{"text": "…", "tone": "专业"}
],
"keywords": [
"GMV","销售额","TopN","category","类目","趋势","同比","客户","订单","质量","异常检测","join","过滤","sample"
],
"intent_tags": ["aggregate","trend","topn","ratio","quality","join","sample","filter","by_dimension"]
}
生成逻辑规范
1.逐条输出输入数组中每个片段对应一个输出对象id 保持一致)。
2.aliases生成
至少 3 个别名,分别覆盖语气类型:口语 / 中性 / 专业。
≤20字语义需等价不得添加不存在的字段或业务口径。
示例:
GMV趋势分析中性
每天卖多少钱(口语)
按日GMV曲线专业
3.keywords生成
8~15个关键词需涵盖片段核心维度、指标、分析类型和语义近义词。
中英文混合(如 "GMV"/"销售额"、"同比"/"YoY"、"类目"/"category" 等)。
包含用于匹配的分析意图关键词(如 “趋势”、“排行”、“占比”、“质量检查”、“过滤” 等)。
4.intent_tags生成
从以下集合中选取与片段type及用途一致
["aggregate","trend","topn","ratio","quality","join","sample","filter","by_dimension"]
若为条件片段WHERE句型补充 "filter";若含维度分组逻辑,补充 "by_dimension"。
5.语言与内容要求
保持正式书面风格,不添加解释说明。
只输出JSON数组不包含文字描述或额外文本。

View File

@ -0,0 +1,46 @@
系统角色System
你是“SQL片段生成器”。只能基于给定“表画像”生成可复用的分析片段。
为每个片段产出标题、用途描述、片段类型、变量、适用条件、SQL模板mysql方言并注明业务口径与安全限制。
不要发明画像里没有的列。时间/维度/指标须与画像匹配。
用户消息User
【表画像JSON】
{{TABLE_PROFILE_JSON}}
【输出要求只输出JSON数组
[
{
"id": "snpt_<slug>",
"title": "中文标题≤16字",
"desc": "一句话用途",
"type": "aggregate|trend|topn|ratio|quality|join|sample",
"applicability": {
"required_columns": ["<col>", ...],
"time_column": "<dt|nullable>",
"constraints": {
"dim_cardinality_hint": <int|null>, // 用于TopN限制与性能提示
"fk_join_available": true|false,
"notes": ["高基数维度建议LIMIT<=50", "..."]
}
},
"variables": [
{"name":"start_date","type":"date"},
{"name":"end_date","type":"date"},
{"name":"top_n","type":"int","default":10}
],
"dialect_sql": {
"mysql": ""
},
"business_caliber": "清晰口径说明,如 UV以device_id去重粒度=日-类目",
"examples": ["示例问法1","示例问法2"]
}
]
【片段选择建议】
- 若存在 time 列:生成 trend_by_day / yoy_qoq / moving_avg。
- 若存在 enumish 维度distinct 5~200生成 topn_by_dimension / share_of_total。
- 若 metric 列:生成 sum/avg/max、分位数/异常检测3σ/箱线)。
- 有主键/唯一:生成 去重/明细抽样/质量检查。
- 有 fk_candidates同时生成“join维表命名版”和“纯ID版”。
- 高枚举维度:在 constraints.notes 中强调 LIMIT 建议与可能的性能风险。
- 除了完整的sql片段还有sql里部分内容的sql片段比如 where payment_method = 'Credit Card' and delivery_status = 'Delivered' 的含义是支付方式为信用卡且配送状态是已送达

View File

@ -9,3 +9,5 @@ numpy>=1.24
openpyxl>=3.1
httpx==0.27.2
python-dotenv==1.0.1
requests>=2.31.0
PyYAML>=6.0.1

View File

@ -0,0 +1,226 @@
import argparse
import logging
import os
from typing import Dict, Iterable, List, Optional
import datasets
from datasets import DownloadConfig
from huggingface_hub import snapshot_download
# 批量下载 Hugging Face 上的数据集和模型
# 支持通过命令行参数配置代理和下载参数如超时和重试次数支持批量循环下载存储到file目录下dataset和model子目录
def _parse_id_list(values: Iterable[str]) -> List[str]:
"""将多次传入以及逗号分隔的标识整理为列表."""
ids: List[str] = []
for value in values:
value = value.strip()
if not value:
continue
if "," in value:
ids.extend(v.strip() for v in value.split(",") if v.strip())
else:
ids.append(value)
return ids
def _parse_proxy_args(proxy_args: Iterable[str]) -> Dict[str, str]:
"""解析命令行传入的代理设置,格式 scheme=url."""
proxies: Dict[str, str] = {}
for item in proxy_args:
raw = item.strip()
if not raw:
continue
if "=" not in raw:
logging.warning("代理参数 %s 缺少 '=' 分隔符,将忽略该项", raw)
continue
key, value = raw.split("=", 1)
key = key.strip()
value = value.strip()
if not key or not value:
logging.warning("代理参数 %s 解析失败,将忽略该项", raw)
continue
proxies[key] = value
return proxies
def _sanitize_dir_name(name: str) -> str:
return name.replace("/", "__")
def _ensure_dirs(root_dir: str) -> Dict[str, str]:
paths = {
"dataset": os.path.join(root_dir, "dataset"),
"model": os.path.join(root_dir, "model"),
}
for path in paths.values():
os.makedirs(path, exist_ok=True)
return paths
def _build_download_config(cache_dir: str, retries: Optional[int], proxies: Dict[str, str]) -> DownloadConfig:
    """Assemble a datasets.DownloadConfig, passing only explicitly-set options."""
    options: Dict[str, object] = {"cache_dir": cache_dir}
    if retries is not None:
        options["max_retries"] = retries
    if proxies:
        options["proxies"] = proxies
    return DownloadConfig(**options)
def _apply_timeout(timeout: Optional[float]) -> None:
if timeout is None:
return
str_timeout = str(timeout)
os.environ.setdefault("HF_DATASETS_HTTP_TIMEOUT", str_timeout)
os.environ.setdefault("HF_HUB_HTTP_TIMEOUT", str_timeout)
def _resolve_log_level(level_name: str) -> int:
if isinstance(level_name, int):
return level_name
upper_name = str(level_name).upper()
return getattr(logging, upper_name, logging.INFO)
def _build_argument_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(
description="批量下载 Hugging Face 数据集和模型并存储到指定目录。"
)
parser.add_argument(
"-d",
"--dataset",
action="append",
default=[],
help="要下载的数据集 ID可重复使用或传入逗号分隔列表。",
)
parser.add_argument(
"-m",
"--model",
action="append",
default=[],
help="要下载的模型 ID可重复使用或传入逗号分隔列表。",
)
parser.add_argument(
"-r",
"--root",
default="file",
help="存储根目录,默认 file。",
)
parser.add_argument(
"--retries",
type=int,
default=None,
help="失败后的重试次数,默认不重试。",
)
parser.add_argument(
"--timeout",
type=float,
default=None,
help="HTTP 超时时间(秒),默认跟随库设置。",
)
parser.add_argument(
"-p",
"--proxy",
action="append",
default=[],
help="代理设置,格式 scheme=url可多次传入例如 --proxy http=http://127.0.0.1:7890",
)
parser.add_argument(
"--log-level",
default="INFO",
help="日志级别,默认 INFO。",
)
return parser
def download_datasets(dataset_ids: Iterable[str], root_dir: str, retries: Optional[int], proxies: Dict[str, str]) -> None:
    """Download each dataset into *root_dir*, logging failures without aborting the batch."""
    if not dataset_ids:
        return
    download_config = _build_download_config(root_dir, retries, proxies)
    for dataset_id in dataset_ids:
        try:
            logging.info("开始下载数据集 %s", dataset_id)
            # load_dataset populates the cache; save_to_disk then writes a
            # stable on-disk copy under a sanitized directory name.
            loaded = datasets.load_dataset(
                dataset_id,
                cache_dir=root_dir,
                download_config=download_config,
                download_mode="reuse_cache_if_exists",
            )
            destination = os.path.join(root_dir, _sanitize_dir_name(dataset_id))
            loaded.save_to_disk(destination)
            logging.info("数据集 %s 下载完成,存储于 %s", dataset_id, destination)
        except Exception as exc:  # pylint: disable=broad-except
            logging.error("下载数据集 %s 失败: %s", dataset_id, exc)
def download_models(
    model_ids: Iterable[str],
    target_dir: str,
    retries: Optional[int],
    proxies: Dict[str, str],
    timeout: Optional[float],
) -> None:
    """Snapshot-download each model into *target_dir* with per-model retries."""
    if not model_ids:
        return
    max_attempts = (retries or 0) + 1
    hub_kwargs = {
        "local_dir": target_dir,
        "local_dir_use_symlinks": False,
        "max_workers": os.cpu_count() or 4,
    }
    if proxies:
        hub_kwargs["proxies"] = proxies
    if timeout is not None:
        hub_kwargs["timeout"] = timeout
    for model_id in model_ids:
        for attempt in range(1, max_attempts + 1):
            try:
                logging.info("开始下载模型 %s (尝试 %s/%s)", model_id, attempt, max_attempts)
                snapshot_download(
                    repo_id=model_id,
                    **hub_kwargs,
                )
                logging.info("模型 %s 下载完成,存储于 %s", model_id, target_dir)
                break
            except Exception as exc:  # pylint: disable=broad-except
                logging.error("下载模型 %s 失败: %s", model_id, exc)
                if attempt >= max_attempts:
                    logging.error("模型 %s 在重试后仍未成功下载", model_id)
def main() -> None:
    """CLI entry point: parse arguments, prepare directories, run downloads."""
    args = _build_argument_parser().parse_args()
    logging.basicConfig(
        level=_resolve_log_level(args.log_level),
        format="%(asctime)s - %(levelname)s - %(message)s",
    )
    dataset_ids = _parse_id_list(args.dataset)
    model_ids = _parse_id_list(args.model)
    proxies = _parse_proxy_args(args.proxy)
    _apply_timeout(args.timeout)
    if not (dataset_ids or model_ids):
        logging.warning(
            "未配置任何数据集或模型,"
            "请通过参数 --dataset / --model 指定 Hugging Face ID"
        )
        return
    dirs = _ensure_dirs(args.root)
    download_datasets(dataset_ids, dirs["dataset"], args.retries, proxies)
    download_models(model_ids, dirs["model"], args.retries, proxies, args.timeout)
if __name__ == "__main__":  # pragma: no cover - script entry point
    main()

View File

@ -0,0 +1,80 @@
from __future__ import annotations

import json
import os
import sys
from datetime import datetime, timezone
from typing import Any, Dict

import requests
def build_demo_payload() -> Dict[str, Any]:
    """Construct a representative 'snippet' upsert payload for the demo request.

    Returns:
        A dict matching the /v1/table/snippet request schema, with both
        timestamps set to the current UTC second (ISO-8601 with a 'Z' suffix).
    """
    # datetime.utcnow() is deprecated since Python 3.12; derive the same
    # naive-UTC value from an aware datetime so the "...Z" suffix stays valid.
    now = datetime.now(timezone.utc).replace(tzinfo=None, microsecond=0)
    stamp = now.isoformat() + "Z"
    return {
        "table_id": 42,
        "version_ts": 20251101200000,
        "action_type": "snippet",
        "status": "success",
        "callback_url": "http://localhost:9999/dummy-callback",
        "table_schema_version_id": 7,
        "table_schema": {
            "columns": [
                {"name": "order_id", "type": "bigint"},
                {"name": "order_dt", "type": "date"},
                {"name": "gmv", "type": "decimal(18,2)"},
            ]
        },
        "result_json": [
            {
                "id": "snpt_daily_gmv",
                "title": "按日GMV",
                "desc": "统计每日GMV总额",
                "type": "trend",
                "dialect_sql": {
                    "mysql": "SELECT order_dt, SUM(gmv) AS total_gmv FROM orders GROUP BY order_dt ORDER BY order_dt"
                },
            }
        ],
        "result_summary_json": {"total_snippets": 1},
        "html_report_url": None,
        "error_code": None,
        "error_message": None,
        "started_at": stamp,
        "finished_at": stamp,
        "duration_ms": 1234,
        "result_checksum": "demo-checksum",
    }
def main() -> int:
    """Send the demo payload to the snippet endpoint and print the response.

    Returns a process exit code: 0 on HTTP success, 1 otherwise.
    """
    base_url = os.getenv("TABLE_SNIPPET_DEMO_BASE_URL", "http://localhost:8000")
    endpoint = f"{base_url.rstrip('/')}/v1/table/snippet"
    payload = build_demo_payload()
    print(f"POST {endpoint}")
    print(json.dumps(payload, ensure_ascii=False, indent=2))
    try:
        response = requests.post(endpoint, json=payload, timeout=30)
    except requests.RequestException as exc:
        print(f"Request failed: {exc}", file=sys.stderr)
        return 1
    print(f"\nStatus: {response.status_code}")
    try:
        body = response.json()
    except ValueError:
        print("Response Text:")
        print(response.text)
    else:
        print("Response JSON:")
        print(json.dumps(body, ensure_ascii=False, indent=2))
    return 0 if response.ok else 1
if __name__ == "__main__":  # pragma: no cover - manual demo script
    raise SystemExit(main())

54
table_snippet.sql Normal file
View File

@ -0,0 +1,54 @@
-- One row per (table_id, version_ts, action_type); re-runs update the row
-- in place via the unique key uq_table_ver_action.
CREATE TABLE IF NOT EXISTS action_results (
    id BIGINT NOT NULL AUTO_INCREMENT COMMENT '主键',
    table_id BIGINT NOT NULL COMMENT '表ID',
    version_ts BIGINT NOT NULL COMMENT '版本时间戳(版本号)',
    action_type ENUM('ge_profiling','ge_result_desc','snippet','snippet_alias') NOT NULL COMMENT '动作类型',
    status ENUM('pending','running','success','failed','partial') NOT NULL DEFAULT 'pending' COMMENT '执行状态',
    error_code VARCHAR(128) NULL,
    error_message TEXT NULL,
    -- Callback & observability fields
    callback_url VARCHAR(1024) NOT NULL,
    started_at DATETIME NULL,
    finished_at DATETIME NULL,
    duration_ms INT NULL,
    -- Schema information captured for this run
    table_schema_version_id BIGINT NOT NULL,
    table_schema JSON NOT NULL,
    -- ===== Action 1: GE profiling =====
    ge_profiling_full JSON NULL COMMENT 'Profiling完整结果JSON',
    ge_profiling_full_size_bytes BIGINT NULL,
    ge_profiling_summary JSON NULL COMMENT 'Profiling摘要剔除大value_set等',
    ge_profiling_summary_size_bytes BIGINT NULL,
    ge_profiling_total_size_bytes BIGINT NULL COMMENT '上两者合计',
    ge_profiling_html_report_url VARCHAR(1024) NULL COMMENT 'GE报告HTML路径/URL',
    -- ===== Action 2: GE result description =====
    ge_result_desc_full JSON NULL COMMENT '表描述结果JSON',
    ge_result_desc_full_size_bytes BIGINT NULL,
    -- ===== Action 3: snippet generation =====
    snippet_full JSON NULL COMMENT 'SQL知识片段结果JSON',
    snippet_full_size_bytes BIGINT NULL,
    -- ===== Action 4: snippet alias rewriting =====
    snippet_alias_full JSON NULL COMMENT 'SQL片段改写/丰富结果JSON',
    snippet_alias_full_size_bytes BIGINT NULL,
    -- Generic optional metrics
    result_checksum VARBINARY(32) NULL COMMENT '对当前action有效载荷计算的MD5/xxhash',
    created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
    updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
    PRIMARY KEY (id),
    UNIQUE KEY uq_table_ver_action (table_id, version_ts, action_type),
    KEY idx_status (status),
    KEY idx_table (table_id, updated_at),
    KEY idx_action_time (action_type, version_ts),
    KEY idx_schema_version (table_schema_version_id)
) ENGINE=InnoDB
ROW_FORMAT=DYNAMIC
COMMENT='数据分析知识片段表';