Compare commits
7 Commits
72f9735000
...
0b765e6719
| Author | SHA1 | Date | |
|---|---|---|---|
| 0b765e6719 | |||
| cf2e6aedc7 | |||
| 9a7710a70a | |||
| 799b9f8154 | |||
| fe1de87696 | |||
| c2a08e4134 | |||
| 557efc4bf1 |
2
.env
2
.env
@ -17,7 +17,7 @@ DEFAULT_IMPORT_MODEL=deepseek:deepseek-chat
|
||||
IMPORT_GATEWAY_BASE_URL=http://localhost:8000
|
||||
|
||||
# HTTP client configuration
|
||||
HTTP_CLIENT_TIMEOUT=30
|
||||
HTTP_CLIENT_TIMEOUT=60
|
||||
HTTP_CLIENT_TRUST_ENV=false
|
||||
# HTTP_CLIENT_PROXY=
|
||||
|
||||
|
||||
4
.gitignore
vendored
4
.gitignore
vendored
@ -3,4 +3,6 @@ gx/uncommitted/
|
||||
.vscode/
|
||||
**/__pycache__/
|
||||
*.pyc
|
||||
.DS_Store
|
||||
.DS_Store
|
||||
gx/
|
||||
logs/
|
||||
26
app/db.py
Normal file
26
app/db.py
Normal file
@ -0,0 +1,26 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from functools import lru_cache
|
||||
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.engine import Engine
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
def get_engine() -> Engine:
    """Return a cached SQLAlchemy engine configured from DATABASE_URL.

    The engine is created once per process (lru_cache) and reused; pool
    pre-ping keeps long-lived pooled connections healthy.
    """
    # NOTE(review): the fallback URL embeds root credentials — fine for local
    # dev, but confirm production always sets DATABASE_URL explicitly.
    url = os.getenv(
        "DATABASE_URL",
        "mysql+pymysql://root:12345678@localhost:3306/data-ge?charset=utf8mb4",
    )

    # SQLite connections refuse cross-thread use by default; relax that so
    # the shared engine works from worker threads.
    extra_args = {"check_same_thread": False} if url.startswith("sqlite") else {}

    return create_engine(
        url,
        pool_pre_ping=True,
        future=True,
        connect_args=extra_args,
    )
|
||||
113
app/main.py
113
app/main.py
@ -2,12 +2,17 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import logging.config
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any
|
||||
|
||||
import yaml
|
||||
|
||||
import httpx
|
||||
from fastapi import Depends, FastAPI, HTTPException, Request
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from app.exceptions import ProviderAPICallError, ProviderConfigurationError
|
||||
from app.models import (
|
||||
@ -15,30 +20,42 @@ from app.models import (
|
||||
DataImportAnalysisJobRequest,
|
||||
LLMRequest,
|
||||
LLMResponse,
|
||||
TableProfilingJobAck,
|
||||
TableProfilingJobRequest,
|
||||
TableSnippetUpsertRequest,
|
||||
TableSnippetUpsertResponse,
|
||||
)
|
||||
from app.services import LLMGateway
|
||||
from app.services.import_analysis import process_import_analysis_job
|
||||
from app.services.table_profiling import process_table_profiling_job
|
||||
from app.services.table_snippet import upsert_action_result
|
||||
|
||||
|
||||
def _ensure_log_directories(config: dict[str, Any]) -> None:
|
||||
handlers = config.get("handlers", {})
|
||||
for handler_config in handlers.values():
|
||||
filename = handler_config.get("filename")
|
||||
if not filename:
|
||||
continue
|
||||
directory = os.path.dirname(filename)
|
||||
if directory and not os.path.exists(directory):
|
||||
os.makedirs(directory, exist_ok=True)
|
||||
|
||||
|
||||
def _configure_logging() -> None:
|
||||
level_name = os.getenv("LOG_LEVEL", "INFO").upper()
|
||||
level = getattr(logging, level_name, logging.INFO)
|
||||
log_format = os.getenv(
|
||||
"LOG_FORMAT",
|
||||
"%(asctime)s %(levelname)s %(name)s:%(lineno)d %(message)s",
|
||||
config_path = os.getenv("LOGGING_CONFIG", "logging.yaml")
|
||||
if os.path.exists(config_path):
|
||||
with open(config_path, "r", encoding="utf-8") as fh:
|
||||
config = yaml.safe_load(fh)
|
||||
if isinstance(config, dict):
|
||||
_ensure_log_directories(config)
|
||||
logging.config.dictConfig(config)
|
||||
return
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s %(levelname)s %(name)s:%(lineno)d %(message)s",
|
||||
)
|
||||
|
||||
root = logging.getLogger()
|
||||
|
||||
if not root.handlers:
|
||||
logging.basicConfig(level=level, format=log_format)
|
||||
else:
|
||||
root.setLevel(level)
|
||||
formatter = logging.Formatter(log_format)
|
||||
for handler in root.handlers:
|
||||
handler.setLevel(level)
|
||||
handler.setFormatter(formatter)
|
||||
|
||||
|
||||
_configure_logging()
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -119,6 +136,24 @@ def create_app() -> FastAPI:
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
@application.exception_handler(RequestValidationError)
|
||||
async def request_validation_exception_handler(
|
||||
request: Request, exc: RequestValidationError
|
||||
) -> JSONResponse:
|
||||
try:
|
||||
raw_body = await request.body()
|
||||
except Exception: # pragma: no cover - defensive
|
||||
raw_body = b"<unavailable>"
|
||||
truncated_body = raw_body[:4096]
|
||||
logger.warning(
|
||||
"Validation error on %s %s: %s | body preview=%s",
|
||||
request.method,
|
||||
request.url.path,
|
||||
exc.errors(),
|
||||
truncated_body.decode("utf-8", errors="ignore"),
|
||||
)
|
||||
return JSONResponse(status_code=422, content={"detail": exc.errors()})
|
||||
|
||||
@application.post(
|
||||
"/v1/chat/completions",
|
||||
response_model=LLMResponse,
|
||||
@ -164,6 +199,52 @@ def create_app() -> FastAPI:
|
||||
|
||||
return DataImportAnalysisJobAck(import_record_id=payload.import_record_id, status="accepted")
|
||||
|
||||
@application.post(
|
||||
"/v1/table/profiling",
|
||||
response_model=TableProfilingJobAck,
|
||||
summary="Run end-to-end GE profiling pipeline and notify via callback per action",
|
||||
status_code=202,
|
||||
)
|
||||
async def run_table_profiling(
|
||||
payload: TableProfilingJobRequest,
|
||||
gateway: LLMGateway = Depends(get_gateway),
|
||||
client: httpx.AsyncClient = Depends(get_http_client),
|
||||
) -> TableProfilingJobAck:
|
||||
request_copy = payload.model_copy(deep=True)
|
||||
|
||||
async def _runner() -> None:
|
||||
await process_table_profiling_job(request_copy, gateway, client)
|
||||
|
||||
asyncio.create_task(_runner())
|
||||
|
||||
return TableProfilingJobAck(
|
||||
table_id=payload.table_id,
|
||||
version_ts=payload.version_ts,
|
||||
status="accepted",
|
||||
)
|
||||
|
||||
@application.post(
|
||||
"/v1/table/snippet",
|
||||
response_model=TableSnippetUpsertResponse,
|
||||
summary="Persist or update action results, such as table snippets.",
|
||||
)
|
||||
async def upsert_table_snippet(
|
||||
payload: TableSnippetUpsertRequest,
|
||||
) -> TableSnippetUpsertResponse:
|
||||
request_copy = payload.model_copy(deep=True)
|
||||
|
||||
try:
|
||||
return await asyncio.to_thread(upsert_action_result, request_copy)
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
"Failed to upsert table snippet: table_id=%s version_ts=%s action_type=%s",
|
||||
payload.table_id,
|
||||
payload.version_ts,
|
||||
payload.action_type,
|
||||
exc_info=True,
|
||||
)
|
||||
raise HTTPException(status_code=500, detail=str(exc)) from exc
|
||||
|
||||
@application.post("/__mock__/import-callback")
|
||||
async def mock_import_callback(payload: dict[str, Any]) -> dict[str, str]:
|
||||
logger.info("Received import analysis callback: %s", payload)
|
||||
|
||||
160
app/models.py
160
app/models.py
@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
@ -76,8 +77,8 @@ class DataImportAnalysisRequest(BaseModel):
|
||||
description="Ordered list of table headers associated with the data.",
|
||||
)
|
||||
llm_model: str = Field(
|
||||
...,
|
||||
description="Model identifier. Accepts 'provider:model' format or plain model name.",
|
||||
None,
|
||||
description="Model identifier. Accepts 'provider:model_name' format or custom model alias.",
|
||||
)
|
||||
temperature: Optional[float] = Field(
|
||||
None,
|
||||
@ -135,3 +136,158 @@ class DataImportAnalysisJobRequest(BaseModel):
|
||||
class DataImportAnalysisJobAck(BaseModel):
    """Acknowledgement returned when an import-analysis job is accepted for background processing."""

    import_record_id: str = Field(..., description="Echo of the import record identifier")
    status: str = Field("accepted", description="Processing status acknowledgement.")
|
||||
|
||||
|
||||
class ActionType(str, Enum):
    """String-valued identifiers for the profiling pipeline's action records."""

    GE_PROFILING = "ge_profiling"
    GE_RESULT_DESC = "ge_result_desc"
    SNIPPET = "snippet"
    SNIPPET_ALIAS = "snippet_alias"
|
||||
|
||||
|
||||
class ActionStatus(str, Enum):
    """Execution-status values recorded for a pipeline action."""

    PENDING = "pending"
    RUNNING = "running"
    SUCCESS = "success"
    FAILED = "failed"
    PARTIAL = "partial"
|
||||
|
||||
|
||||
class TableProfilingJobRequest(BaseModel):
    """Input payload for the /v1/table/profiling endpoint.

    Describes which table to profile, how to reach its data, which Great
    Expectations machinery to use, and which LLM models to apply to the
    prompt-based follow-up actions. Per-field semantics live in the pydantic
    ``description`` strings below.
    """

    # --- identity / callback --------------------------------------------
    table_id: str = Field(..., description="Unique identifier for the table to profile.")
    version_ts: str = Field(
        ...,
        pattern=r"^\d{14}$",
        description="Version timestamp expressed as fourteen digit string (yyyyMMddHHmmss).",
    )
    callback_url: HttpUrl = Field(
        ...,
        description="Callback endpoint invoked after each pipeline action completes.",
    )
    llm_model: Optional[str] = Field(
        None,
        description="Default LLM model spec applied to prompt-based actions when overrides are omitted.",
    )
    # --- schema snapshot -------------------------------------------------
    table_schema: Optional[Any] = Field(
        None,
        description="Schema structure snapshot for the current table version.",
    )
    table_schema_version_id: Optional[str] = Field(
        None,
        description="Identifier for the schema snapshot provided in table_schema.",
    )
    # --- source-table access ---------------------------------------------
    table_link_info: Optional[Dict[str, Any]] = Field(
        None,
        description=(
            "Information describing how to locate the source table for profiling. "
            "For example: {'type': 'sql', 'connection_string': 'mysql+pymysql://user:pass@host/db', "
            "'table': 'schema.table_name'}."
        ),
    )
    table_access_info: Optional[Dict[str, Any]] = Field(
        None,
        description=(
            "Credentials or supplemental parameters required to access the table described in table_link_info. "
            "These values can be merged into the connection string using Python format placeholders."
        ),
    )
    # --- Great Expectations configuration --------------------------------
    ge_batch_request: Optional[Dict[str, Any]] = Field(
        None,
        description="Optional Great Expectations batch request payload used for profiling.",
    )
    ge_expectation_suite_name: Optional[str] = Field(
        None,
        description="Expectation suite name used during profiling. Created automatically when absent.",
    )
    ge_data_context_root: Optional[str] = Field(
        None,
        description="Custom root directory for the Great Expectations data context. Defaults to project ./gx.",
    )
    ge_datasource_name: Optional[str] = Field(
        None,
        description="Datasource name registered inside the GE context when batch_request is not supplied.",
    )
    ge_data_asset_name: Optional[str] = Field(
        None,
        description="Data asset reference used when inferring batch request from datasource configuration.",
    )
    ge_profiler_type: str = Field(
        "user_configurable",
        description="Profiler implementation identifier. Currently supports 'user_configurable' or 'data_assistant'.",
    )

    # --- per-action LLM model overrides ----------------------------------
    result_desc_model: Optional[str] = Field(
        None,
        description="LLM model override used for GE result description (action 2).",
    )
    snippet_model: Optional[str] = Field(
        None,
        description="LLM model override used for snippet generation (action 3).",
    )
    snippet_alias_model: Optional[str] = Field(
        None,
        description="LLM model override used for snippet alias enrichment (action 4).",
    )
    extra_options: Optional[Dict[str, Any]] = Field(
        None,
        description="Miscellaneous execution flags applied across pipeline steps.",
    )
|
||||
|
||||
|
||||
class TableProfilingJobAck(BaseModel):
    """Immediate acknowledgement returned by /v1/table/profiling (202 Accepted)."""

    table_id: str = Field(..., description="Echo of the table identifier.")
    version_ts: str = Field(..., description="Echo of the profiling version timestamp (yyyyMMddHHmmss).")
    status: str = Field("accepted", description="Processing acknowledgement status.")
|
||||
|
||||
|
||||
class TableSnippetUpsertRequest(BaseModel):
    """Payload for /v1/table/snippet: one action-result record to insert or update."""

    # Record identity: table + version + action type.
    table_id: int = Field(..., ge=1, description="Unique identifier for the table.")
    version_ts: int = Field(
        ...,
        ge=0,
        description="Version timestamp aligned with the pipeline (yyyyMMddHHmmss as integer).",
    )
    action_type: ActionType = Field(..., description="Pipeline action type for this record.")
    status: ActionStatus = Field(
        ActionStatus.SUCCESS, description="Execution status for the action."
    )
    callback_url: HttpUrl = Field(..., description="Callback URL associated with the action run.")
    table_schema_version_id: int = Field(..., ge=0, description="Identifier for the schema snapshot.")
    table_schema: Any = Field(..., description="Schema snapshot payload for the table.")
    # Result payloads produced by the action.
    result_json: Optional[Any] = Field(
        None,
        description="Primary result payload for the action (e.g., profiling output, snippet array).",
    )
    result_summary_json: Optional[Any] = Field(
        None,
        description="Optional summary payload (e.g., profiling summary) for the action.",
    )
    html_report_url: Optional[str] = Field(
        None,
        description="Optional HTML report URL generated by the action.",
    )
    # Failure diagnostics, relevant when status indicates a failure.
    error_code: Optional[str] = Field(None, description="Optional error code when status indicates a failure.")
    error_message: Optional[str] = Field(None, description="Optional error message when status indicates a failure.")
    # Execution timing metadata.
    started_at: Optional[datetime] = Field(
        None, description="Timestamp when the action started executing."
    )
    finished_at: Optional[datetime] = Field(
        None, description="Timestamp when the action finished executing."
    )
    duration_ms: Optional[int] = Field(
        None,
        ge=0,
        description="Optional execution duration in milliseconds.",
    )
    result_checksum: Optional[str] = Field(
        None,
        description="Optional checksum for the result payload (e.g., MD5).",
    )
|
||||
|
||||
|
||||
class TableSnippetUpsertResponse(BaseModel):
    """Response for /v1/table/snippet echoing the record identity and upsert outcome."""

    table_id: int
    version_ts: int
    action_type: ActionType
    status: ActionStatus
    # NOTE(review): presumably True when an existing row was updated rather
    # than newly inserted — confirm against upsert_action_result.
    updated: bool
|
||||
|
||||
@ -42,7 +42,7 @@ def _env_float(name: str, default: float) -> float:
|
||||
return default
|
||||
|
||||
|
||||
IMPORT_CHAT_TIMEOUT_SECONDS = _env_float("IMPORT_CHAT_TIMEOUT_SECONDS", 90.0)
|
||||
IMPORT_CHAT_TIMEOUT_SECONDS = _env_float("IMPORT_CHAT_TIMEOUT_SECONDS", 120.0)
|
||||
|
||||
SUPPORTED_IMPORT_MODELS = get_supported_import_models()
|
||||
|
||||
@ -298,7 +298,7 @@ def parse_llm_analysis_json(llm_response: LLMResponse) -> Dict[str, Any]:
|
||||
try:
|
||||
return json.loads(json_payload)
|
||||
except json.JSONDecodeError as exc:
|
||||
preview = json_payload[:2000]
|
||||
preview = json_payload[:10000]
|
||||
logger.error("Failed to parse JSON from LLM response content: %s", preview, exc_info=True)
|
||||
raise ProviderAPICallError("LLM response JSON could not be parsed.") from exc
|
||||
|
||||
|
||||
832
app/services/table_profiling.py
Normal file
832
app/services/table_profiling.py
Normal file
@ -0,0 +1,832 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from datetime import date, datetime
|
||||
from dataclasses import asdict, dataclass, is_dataclass
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import httpx
|
||||
import great_expectations as gx
|
||||
from great_expectations.core.batch import RuntimeBatchRequest
|
||||
from great_expectations.core.expectation_suite import ExpectationSuite
|
||||
from great_expectations.data_context import AbstractDataContext
|
||||
from great_expectations.exceptions import DataContextError, MetricResolutionError
|
||||
|
||||
from app.exceptions import ProviderAPICallError
|
||||
from app.models import TableProfilingJobRequest
|
||||
from app.services import LLMGateway
|
||||
from app.settings import DEFAULT_IMPORT_MODEL
|
||||
from app.services.import_analysis import (
|
||||
IMPORT_GATEWAY_BASE_URL,
|
||||
resolve_provider_from_model,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)


# Location of the generated GE Data Docs landing page, relative to the data
# context root.
GE_REPORT_RELATIVE_PATH = Path("uncommitted") / "data_docs" / "local_site" / "index.html"
# Prompt template filenames (resolved under the project's prompt directory)
# keyed by pipeline action.
PROMPT_FILENAMES = {
    "ge_result_desc": "ge_result_desc_prompt.md",
    "snippet_generator": "snippet_generator.md",
    "snippet_alias": "snippet_alias_generator.md",
}
# Fallback LLM chat timeout (seconds) used when extra_options carries no
# valid llm_timeout_seconds override.
DEFAULT_CHAT_TIMEOUT_SECONDS = 90.0
|
||||
|
||||
@dataclass
class GEProfilingArtifacts:
    """Outputs produced by the Great Expectations profiling step."""

    # Profiling output payload (expectation-suite dict).
    profiling_result: Dict[str, Any]
    # Condensed summary payload derived from the profiling output.
    profiling_summary: Dict[str, Any]
    # Path to the generated data-docs HTML report.
    docs_path: str
||||
|
||||
|
||||
class PipelineActionType:
    """Action-name string constants used by the profiling pipeline.

    NOTE(review): these duplicate the values of ``app.models.ActionType`` —
    consider reusing that enum to avoid drift.
    """

    GE_PROFILING = "ge_profiling"
    GE_RESULT_DESC = "ge_result_desc"
    SNIPPET = "snippet"
    SNIPPET_ALIAS = "snippet_alias"
|
||||
|
||||
|
||||
def _project_root() -> Path:
|
||||
return Path(__file__).resolve().parents[2]
|
||||
|
||||
|
||||
def _prompt_dir() -> Path:
    """Return the directory that holds the LLM prompt template files."""
    return _project_root().joinpath("prompt")
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
def _load_prompt_parts(filename: str) -> Tuple[str, str]:
    """Load a prompt template and split it into (system, user) sections.

    The template must contain the ``用户消息(User)`` marker separating the
    system section (whose ``系统角色(System)`` heading is stripped) from the
    user section. Results are cached per filename.

    Raises:
        FileNotFoundError: if the template file does not exist.
        ValueError: if the separator marker is missing.
    """
    prompt_path = _prompt_dir() / filename
    if not prompt_path.exists():
        raise FileNotFoundError(f"Prompt template not found: {prompt_path}")

    raw = prompt_path.read_text(encoding="utf-8")
    splitter = "用户消息(User)"
    if splitter not in raw:
        # Fix: name the offending template in the error; the message used to
        # hard-code '(unknown)'.
        raise ValueError(f"Prompt template '{filename}' missing separator '{splitter}'.")

    system_raw, user_raw = raw.split(splitter, maxsplit=1)
    system_text = system_raw.replace("系统角色(System)", "").strip()
    user_text = user_raw.strip()
    return system_text, user_text
|
||||
|
||||
|
||||
def _render_prompt(template_key: str, replacements: Dict[str, str]) -> Tuple[str, str]:
    """Resolve the template for *template_key* and apply literal replacements to its user part."""
    system_text, user_template = _load_prompt_parts(PROMPT_FILENAMES[template_key])

    rendered = user_template
    for placeholder, substitution in replacements.items():
        rendered = rendered.replace(placeholder, substitution)

    return system_text, rendered
|
||||
|
||||
|
||||
def _extract_timeout_seconds(options: Optional[Dict[str, Any]]) -> Optional[float]:
    """Read the ``llm_timeout_seconds`` override from extra options.

    Returns None when the option is absent, the parsed value when it is a
    positive float, and DEFAULT_CHAT_TIMEOUT_SECONDS (with a warning) when it
    is present but invalid.
    """
    if not options:
        return None
    raw = options.get("llm_timeout_seconds")
    if raw is None:
        return None

    try:
        parsed: Optional[float] = float(raw)
        if parsed <= 0:
            parsed = None
    except (TypeError, ValueError):
        parsed = None

    if parsed is not None:
        return parsed

    logger.warning(
        "Invalid llm_timeout_seconds value in extra_options: %r. Falling back to default.",
        raw,
    )
    return DEFAULT_CHAT_TIMEOUT_SECONDS
|
||||
|
||||
|
||||
def _extract_json_payload(content: str) -> str:
|
||||
fenced = re.search(
|
||||
r"```(?:json)?\s*([\s\S]+?)```",
|
||||
content,
|
||||
flags=re.IGNORECASE,
|
||||
)
|
||||
if fenced:
|
||||
snippet = fenced.group(1).strip()
|
||||
if snippet:
|
||||
return snippet
|
||||
|
||||
stripped = content.strip()
|
||||
if not stripped:
|
||||
raise ValueError("Empty LLM content.")
|
||||
|
||||
for opener, closer in (("{", "}"), ("[", "]")):
|
||||
start = stripped.find(opener)
|
||||
end = stripped.rfind(closer)
|
||||
if start != -1 and end != -1 and end > start:
|
||||
candidate = stripped[start : end + 1].strip()
|
||||
return candidate
|
||||
|
||||
return stripped
|
||||
|
||||
|
||||
def _parse_completion_payload(response_payload: Dict[str, Any]) -> Any:
    """Extract and JSON-decode the first choice's message content.

    Raises:
        ProviderAPICallError: when the payload has no choices, the content is
            empty, or the content cannot be parsed as JSON.
    """
    choices = response_payload.get("choices") or []
    if not choices:
        raise ProviderAPICallError("LLM response did not contain choices to parse.")

    content = (choices[0].get("message") or {}).get("content") or ""
    if not content.strip():
        raise ProviderAPICallError("LLM response content is empty.")

    json_payload = _extract_json_payload(content)
    try:
        return json.loads(json_payload)
    except json.JSONDecodeError as exc:
        # Log only a bounded preview of the offending payload.
        logger.error("Failed to parse JSON from LLM response: %s", json_payload[:800], exc_info=True)
        raise ProviderAPICallError("LLM response JSON parsing failed.") from exc
|
||||
|
||||
|
||||
async def _post_callback(callback_url: str, payload: Dict[str, Any], client: httpx.AsyncClient) -> None:
    """POST *payload* to the callback URL; HTTP failures are logged, never raised."""
    safe_payload = _normalize_for_json(payload)
    logger.info(
        "Posting pipeline action callback to %s: %s",
        callback_url,
        json.dumps(safe_payload, ensure_ascii=False),
    )
    try:
        response = await client.post(callback_url, json=safe_payload)
        response.raise_for_status()
    except httpx.HTTPError as exc:
        # Callback delivery is best-effort by design.
        logger.error("Callback delivery to %s failed: %s", callback_url, exc, exc_info=True)
|
||||
|
||||
|
||||
def _sanitize_value_set(value: Any, max_values: int) -> Tuple[Any, Optional[Dict[str, int]]]:
|
||||
if not isinstance(value, list):
|
||||
return value, None
|
||||
original_len = len(value)
|
||||
if original_len <= max_values:
|
||||
return value, None
|
||||
trimmed = value[:max_values]
|
||||
return trimmed, {"original_length": original_len, "retained": max_values}
|
||||
|
||||
|
||||
def _sanitize_expectation_suite(suite: ExpectationSuite, max_value_set_values: int = 100) -> Dict[str, Any]:
    """Serialize *suite* to a dict, truncating oversized value_set lists.

    Each truncation is recorded both on the affected expectation's ``meta``
    and in a suite-level ``value_set_truncations`` list.
    """
    suite_dict = suite.to_json_dict()
    truncation_notes: List[Dict[str, Any]] = []

    for exp in suite_dict.get("expectations", []):
        exp_kwargs = exp.get("kwargs", {})
        if "value_set" not in exp_kwargs:
            continue
        # Mutates the serialized dict in place via the kwargs reference.
        trimmed, note = _sanitize_value_set(exp_kwargs["value_set"], max_value_set_values)
        exp_kwargs["value_set"] = trimmed
        if note:
            exp.setdefault("meta", {})
            exp["meta"]["value_set_truncated"] = note
            truncation_notes.append(
                {
                    "column": exp_kwargs.get("column"),
                    "expectation": exp.get("expectation_type"),
                    "note": note,
                }
            )

    if truncation_notes:
        suite_dict.setdefault("meta", {})
        suite_dict["meta"]["value_set_truncations"] = truncation_notes

    return suite_dict
|
||||
|
||||
|
||||
def _summarize_expectation_suite(suite_dict: Dict[str, Any]) -> Dict[str, Any]:
|
||||
column_map: Dict[str, Dict[str, Any]] = {}
|
||||
table_expectations: List[Dict[str, Any]] = []
|
||||
|
||||
for expectation in suite_dict.get("expectations", []):
|
||||
expectation_type = expectation.get("expectation_type")
|
||||
kwargs = expectation.get("kwargs", {})
|
||||
column = kwargs.get("column")
|
||||
summary_entry: Dict[str, Any] = {"expectation": expectation_type}
|
||||
|
||||
if "value_set" in kwargs and isinstance(kwargs["value_set"], list):
|
||||
summary_entry["value_set_size"] = len(kwargs["value_set"])
|
||||
summary_entry["value_set_preview"] = kwargs["value_set"][:5]
|
||||
|
||||
if column:
|
||||
column_entry = column_map.setdefault(
|
||||
column,
|
||||
{"name": column, "expectations": []},
|
||||
)
|
||||
column_entry["expectations"].append(summary_entry)
|
||||
else:
|
||||
table_expectations.append(summary_entry)
|
||||
|
||||
summary = {
|
||||
"column_profiles": list(column_map.values()),
|
||||
"table_level_expectations": table_expectations,
|
||||
"total_expectations": len(suite_dict.get("expectations", [])),
|
||||
}
|
||||
return summary
|
||||
|
||||
|
||||
def _sanitize_identifier(raw: Optional[str], fallback: str) -> str:
|
||||
if not raw:
|
||||
return fallback
|
||||
candidate = re.sub(r"[^0-9A-Za-z_]+", "_", raw).strip("_")
|
||||
return candidate or fallback
|
||||
|
||||
|
||||
def _format_connection_string(template: str, access_info: Dict[str, Any]) -> str:
|
||||
if not access_info:
|
||||
return template
|
||||
try:
|
||||
return template.format_map({k: v for k, v in access_info.items()})
|
||||
except KeyError as exc:
|
||||
missing = exc.args[0]
|
||||
raise ValueError(f"table_access_info missing key '{missing}' required by connection_string.") from exc
|
||||
|
||||
|
||||
def _ensure_sql_runtime_datasource(
    context: AbstractDataContext,
    datasource_name: str,
    connection_string: str,
) -> None:
    """Ensure a runtime SQL datasource with the given connection string exists.

    Looks up *datasource_name* in the GE context; when it is missing, or when
    it exists with a different connection string (in which case it is deleted
    first), a SqlAlchemyExecutionEngine datasource with a RuntimeDataConnector
    is registered.

    Raises:
        RuntimeError: when datasource inspection or creation fails.
    """
    try:
        datasource = context.get_datasource(datasource_name)
    except (DataContextError, ValueError) as exc:
        message = str(exc)
        # GE raises different messages across versions for "not found";
        # treat both known phrasings as "create it below".
        if "Could not find a datasource" in message or "Unable to load datasource" in message:
            datasource = None
        else:  # pragma: no cover - defensive
            raise RuntimeError(f"Failed to inspect datasource '{datasource_name}'.") from exc
    except Exception as exc:  # pragma: no cover - defensive
        raise RuntimeError(f"Failed to inspect datasource '{datasource_name}'.") from exc

    if datasource is not None:
        execution_engine = getattr(datasource, "execution_engine", None)
        current_conn = getattr(execution_engine, "connection_string", None)
        if current_conn and current_conn != connection_string:
            logger.info(
                "Existing datasource %s uses different connection string; creating dedicated runtime datasource.",
                datasource_name,
            )
            try:
                context.delete_datasource(datasource_name)
            except Exception as exc:  # pragma: no cover - defensive
                # Deletion failure is tolerated; the stale datasource is kept.
                logger.warning(
                    "Failed to delete datasource %s before recreation: %s",
                    datasource_name,
                    exc,
                )
            else:
                # Deletion succeeded: clear the handle so recreation happens below.
                datasource = None

    if datasource is not None:
        # A usable datasource already exists; nothing to do.
        return

    # Minimal runtime datasource: SQLAlchemy engine + runtime connector keyed
    # by a single batch identifier.
    runtime_datasource_config = {
        "name": datasource_name,
        "class_name": "Datasource",
        "execution_engine": {
            "class_name": "SqlAlchemyExecutionEngine",
            "connection_string": connection_string,
        },
        "data_connectors": {
            "runtime_connector": {
                "class_name": "RuntimeDataConnector",
                "batch_identifiers": ["default_identifier_name"],
            }
        },
    }
    try:
        context.add_datasource(**runtime_datasource_config)
    except Exception as exc:  # pragma: no cover - defensive
        raise RuntimeError(f"Failed to create runtime datasource '{datasource_name}'.") from exc
|
||||
|
||||
|
||||
def _build_sql_runtime_batch_request(
    context: AbstractDataContext,
    request: TableProfilingJobRequest,
) -> RuntimeBatchRequest:
    """Build a GE RuntimeBatchRequest from the request's table_link_info.

    Formats the connection string with table_access_info, registers (or
    reuses) a runtime SQL datasource, and derives the query either from an
    explicit ``query`` entry or from the ``table``/``schema``/``limit`` keys.

    Raises:
        ValueError: when required link-info keys are missing or the source
            type is not 'sql'.
    """
    link_info = request.table_link_info or {}
    access_info = request.table_access_info or {}

    connection_template = link_info.get("connection_string")
    if not connection_template:
        raise ValueError("table_link_info.connection_string is required when using table_link_info.")

    connection_string = _format_connection_string(connection_template, access_info)

    # Only SQL sources are supported; 'sql' is the implicit default.
    source_type = (link_info.get("type") or "sql").lower()
    if source_type != "sql":
        raise ValueError(f"Unsupported table_link_info.type='{source_type}'. Only 'sql' is supported.")

    query = link_info.get("query")
    table_name = link_info.get("table") or link_info.get("table_name")
    schema_name = link_info.get("schema")

    if not query and not table_name:
        raise ValueError("Either table_link_info.query or table_link_info.table must be provided.")

    if not query:
        # NOTE(review): unreachable given the combined check above; kept as a
        # defensive guard.
        if not table_name:
            raise ValueError("table_link_info.table must be provided when query is omitted.")

        # Plain identifiers pass through; anything else gets backtick-quoted
        # (MySQL-style) with embedded backticks doubled.
        identifier = re.compile(r"^[A-Za-z_][A-Za-z0-9_$]*$")

        def _quote(name: str) -> str:
            # Quote a single identifier part for interpolation into SQL.
            if identifier.match(name):
                return name
            return f"`{name.replace('`', '``')}`"

        if schema_name:
            # Dotted inputs keep only their last path segment before quoting.
            schema_part = schema_name if "." not in schema_name else schema_name.split(".")[-1]
            table_part = table_name if "." not in table_name else table_name.split(".")[-1]
            qualified_table = f"{_quote(schema_part)}.{_quote(table_part)}"
        else:
            qualified_table = _quote(table_name)

        query = f"SELECT * FROM {qualified_table}"
        # Optional positive-integer row cap.
        limit = link_info.get("limit")
        if isinstance(limit, int) and limit > 0:
            query = f"{query} LIMIT {limit}"

    # Fall back to names derived from the table id / table name when the
    # request does not pin datasource / asset names explicitly.
    datasource_name = request.ge_datasource_name or _sanitize_identifier(
        f"{request.table_id}_runtime_ds", "runtime_ds"
    )
    data_asset_name = request.ge_data_asset_name or _sanitize_identifier(
        table_name or "runtime_query", "runtime_query"
    )

    _ensure_sql_runtime_datasource(context, datasource_name, connection_string)

    # Single batch identifier ties the batch to this table version.
    batch_identifiers = {
        "default_identifier_name": f"{request.table_id}:{request.version_ts}",
    }

    return RuntimeBatchRequest(
        datasource_name=datasource_name,
        data_connector_name="runtime_connector",
        data_asset_name=data_asset_name,
        runtime_parameters={"query": query},
        batch_identifiers=batch_identifiers,
    )
|
||||
|
||||
|
||||
def _run_onboarding_assistant(
    context: AbstractDataContext,
    batch_request: Any,
    suite_name: str,
) -> Tuple[ExpectationSuite, Any]:
    """Profile a batch with GE's onboarding data assistant.

    Runs the assistant, persists the generated expectation suite under
    *suite_name*, and returns the suite together with a validation result.

    Returns:
        A ``(suite, validation_result)`` tuple.
    """
    assistant = context.assistants.onboarding
    assistant_result = assistant.run(batch_request=batch_request)
    suite = assistant_result.get_expectation_suite(expectation_suite_name=suite_name)
    context.save_expectation_suite(suite, expectation_suite_name=suite_name)
    # Assistant result APIs differ across GE versions: prefer the getter,
    # then the attribute, then re-validate from scratch.
    validation_getter = getattr(assistant_result, "get_validation_result", None)
    if callable(validation_getter):
        validation_result = validation_getter()
    else:
        validation_result = getattr(assistant_result, "validation_result", None)
    if validation_result is None:
        # Fallback: rerun validation using the freshly generated expectation suite.
        validator = context.get_validator(
            batch_request=batch_request,
            expectation_suite_name=suite_name,
        )
        validation_result = validator.validate()
    return suite, validation_result
|
||||
|
||||
|
||||
def _resolve_context(request: TableProfilingJobRequest) -> AbstractDataContext:
    """Open the GE data context, resolving its root from the request, the
    GE_DATA_CONTEXT_ROOT environment variable, or the project root, in that
    order."""
    project_root = (
        request.ge_data_context_root
        or os.environ.get("GE_DATA_CONTEXT_ROOT")
        or str(_project_root())
    )
    return gx.get_context(project_root_dir=project_root)
|
||||
|
||||
|
||||
def _build_batch_request(
    context: AbstractDataContext,
    request: TableProfilingJobRequest,
) -> Any:
    """Derive a GE batch request from explicit config, SQL link info, or a named asset.

    Precedence: an explicit ``ge_batch_request`` payload, then
    ``table_link_info`` (runtime SQL), then a registered datasource/asset pair.

    Raises:
        ValueError: when none of the three configurations is provided.
    """
    if request.ge_batch_request:
        # Caller supplied a ready-made batch request payload.
        from great_expectations.core.batch import BatchRequest

        return BatchRequest(**request.ge_batch_request)

    if request.table_link_info:
        return _build_sql_runtime_batch_request(context, request)

    if request.ge_datasource_name and request.ge_data_asset_name:
        datasource = context.get_datasource(request.ge_datasource_name)
        return datasource.get_asset(request.ge_data_asset_name).build_batch_request()

    raise ValueError(
        "ge_batch_request or (ge_datasource_name and ge_data_asset_name) must be provided."
    )
|
||||
|
||||
|
||||
async def _run_ge_profiling(request: TableProfilingJobRequest) -> GEProfilingArtifacts:
    """Run GE profiling for one table and return suite/validation artifacts.

    All synchronous Great Expectations work happens inside ``_execute`` and is
    run via ``asyncio.to_thread`` so the event loop is not blocked.
    """

    def _execute() -> GEProfilingArtifacts:
        context = _resolve_context(request)
        # Default suite name is derived from the table id when not supplied.
        suite_name = (
            request.ge_expectation_suite_name
            or f"{request.table_id}_profiling"
        )

        batch_request = _build_batch_request(context, request)
        # Ensure the expectation suite exists before binding a validator to it.
        try:
            context.get_expectation_suite(suite_name)
        except DataContextError:
            context.add_expectation_suite(suite_name)

        validator = context.get_validator(
            batch_request=batch_request,
            expectation_suite_name=suite_name,
        )

        profiler_type = (request.ge_profiler_type or "user_configurable").lower()

        if profiler_type == "data_assistant":
            suite, validation_result = _run_onboarding_assistant(
                context,
                batch_request,
                suite_name,
            )
        else:
            try:
                from great_expectations.profile.user_configurable_profiler import (
                    UserConfigurableProfiler,
                )
            except ImportError as err:  # pragma: no cover - dependency guard
                raise RuntimeError(
                    "UserConfigurableProfiler is unavailable; install great_expectations profiling extra or switch profiler."
                ) from err

            profiler = UserConfigurableProfiler(profile_dataset=validator)
            try:
                suite = profiler.build_suite()
                context.save_expectation_suite(suite, expectation_suite_name=suite_name)
                validator.expectation_suite = suite
                validation_result = validator.validate()
            except MetricResolutionError as exc:
                # Some metrics cannot be resolved for certain datasources; the
                # onboarding data assistant is used as a best-effort fallback.
                logger.warning(
                    "UserConfigurableProfiler failed (%s); falling back to data assistant profiling.",
                    exc,
                )
                suite, validation_result = _run_onboarding_assistant(
                    context,
                    batch_request,
                    suite_name,
                )

        sanitized_suite = _sanitize_expectation_suite(suite)
        summary = _summarize_expectation_suite(sanitized_suite)
        validation_dict = validation_result.to_json_dict()

        context.build_data_docs()
        docs_path = Path(context.root_directory) / GE_REPORT_RELATIVE_PATH

        profiling_result = {
            "expectation_suite": sanitized_suite,
            "validation_result": validation_dict,
            # Batch requests expose either to_json_dict() or dict() depending on type.
            "batch_request": getattr(batch_request, "to_json_dict", lambda: None)() or getattr(batch_request, "dict", lambda: None)(),
        }

        return GEProfilingArtifacts(
            profiling_result=profiling_result,
            profiling_summary=summary,
            docs_path=str(docs_path),
        )

    return await asyncio.to_thread(_execute)
|
||||
|
||||
|
||||
async def _call_chat_completions(
    *,
    model_spec: str,
    system_prompt: str,
    user_prompt: str,
    client: httpx.AsyncClient,
    temperature: float = 0.2,
    timeout_seconds: Optional[float] = None,
) -> Any:
    """POST a chat-completions request to the import gateway and parse the reply.

    Args:
        model_spec: provider/model spec resolved via ``resolve_provider_from_model``.
        system_prompt: system message content.
        user_prompt: user message content.
        client: shared async HTTP client.
        temperature: sampling temperature forwarded to the gateway.
        timeout_seconds: per-request timeout; ``None`` uses the client default.

    Raises:
        ProviderAPICallError: on any httpx transport/HTTP error, or when the
            response body is not valid JSON.
    """
    provider, model_name = resolve_provider_from_model(model_spec)
    payload = {
        "provider": provider.value,
        "model": model_name,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        "temperature": temperature,
    }
    # Byte size is logged to help diagnose gateway payload-size limits.
    payload_size_bytes = len(json.dumps(payload, ensure_ascii=False).encode("utf-8"))

    url = f"{IMPORT_GATEWAY_BASE_URL.rstrip('/')}/v1/chat/completions"
    try:
        # NOTE(review): this logs the complete prompt payload at INFO level,
        # which can be very large and leaks prompt contents into logs —
        # confirm this is intended, or demote to DEBUG / truncate.
        logger.info(
            "Calling chat completions API %s with model %s and size %s and payload %s",
            url,
            model_name,
            payload_size_bytes,
            payload,
        )
        response = await client.post(url, json=payload, timeout=timeout_seconds)

        response.raise_for_status()
    except httpx.HTTPError as exc:
        # Normalise all transport/status errors into one domain-level error.
        error_name = exc.__class__.__name__
        detail = str(exc).strip()
        if detail:
            message = f"Chat completions request failed ({error_name}): {detail}"
        else:
            message = f"Chat completions request failed ({error_name})."
        raise ProviderAPICallError(message) from exc

    try:
        response_payload = response.json()
    except ValueError as exc:
        raise ProviderAPICallError("Chat completions response was not valid JSON.") from exc

    return _parse_completion_payload(response_payload)
|
||||
|
||||
|
||||
def _normalize_for_json(value: Any) -> Any:
|
||||
if value is None or isinstance(value, (str, int, float, bool)):
|
||||
return value
|
||||
if isinstance(value, (datetime, date)):
|
||||
return str(value)
|
||||
if hasattr(value, "model_dump"):
|
||||
try:
|
||||
return value.model_dump()
|
||||
except Exception: # pragma: no cover - defensive
|
||||
pass
|
||||
if is_dataclass(value):
|
||||
return asdict(value)
|
||||
if isinstance(value, dict):
|
||||
return {k: _normalize_for_json(v) for k, v in value.items()}
|
||||
if isinstance(value, (list, tuple, set)):
|
||||
return [_normalize_for_json(v) for v in value]
|
||||
if hasattr(value, "to_json_dict"):
|
||||
try:
|
||||
return value.to_json_dict()
|
||||
except Exception: # pragma: no cover - defensive
|
||||
pass
|
||||
if hasattr(value, "__dict__"):
|
||||
return _normalize_for_json(value.__dict__)
|
||||
return repr(value)
|
||||
|
||||
|
||||
def _json_dumps(data: Any) -> str:
    """Serialise *data* as pretty-printed JSON after normalisation.

    Non-ASCII characters are emitted verbatim (``ensure_ascii=False``).
    """
    return json.dumps(_normalize_for_json(data), ensure_ascii=False, indent=2)
|
||||
|
||||
|
||||
def _preview_for_log(data: Any) -> str:
    """Best-effort string rendering of *data* for log output; never raises."""
    try:
        return _json_dumps(data)
    except Exception:
        # Fall back to repr when the payload cannot be serialised.
        return repr(data)
|
||||
|
||||
|
||||
def _profiling_request_for_log(request: TableProfilingJobRequest) -> Dict[str, Any]:
|
||||
payload = request.model_dump()
|
||||
access_info = payload.get("table_access_info")
|
||||
if isinstance(access_info, dict):
|
||||
payload["table_access_info"] = {key: "***" for key in access_info.keys()}
|
||||
return payload
|
||||
|
||||
|
||||
async def _execute_result_desc(
    profiling_json: Dict[str, Any],
    _request: TableProfilingJobRequest,
    llm_model: str,
    client: httpx.AsyncClient,
    timeout_seconds: Optional[float],
) -> Dict[str, Any]:
    """Ask the LLM to turn raw GE profiling output into a table-profile object.

    Renders the ``ge_result_desc`` prompt with the profiling JSON embedded,
    then calls the chat-completions gateway.

    Raises:
        ProviderAPICallError: when the model output is not a JSON object.
    """
    system_prompt, user_prompt = _render_prompt(
        "ge_result_desc",
        {"{{GE_RESULT_JSON}}": _json_dumps(profiling_json)},
    )
    llm_output = await _call_chat_completions(
        model_spec=llm_model,
        system_prompt=system_prompt,
        user_prompt=user_prompt,
        client=client,
        timeout_seconds=timeout_seconds,
    )
    if not isinstance(llm_output, dict):
        raise ProviderAPICallError("GE result description payload must be a JSON object.")
    return llm_output
|
||||
|
||||
|
||||
async def _execute_snippet_generation(
    table_desc_json: Dict[str, Any],
    _request: TableProfilingJobRequest,
    llm_model: str,
    client: httpx.AsyncClient,
    timeout_seconds: Optional[float],
) -> List[Dict[str, Any]]:
    """Generate reusable SQL snippet objects from a table-profile JSON.

    Renders the ``snippet_generator`` prompt with the profile embedded, then
    calls the chat-completions gateway.

    Raises:
        ProviderAPICallError: when the model output is not a JSON array.
    """
    system_prompt, user_prompt = _render_prompt(
        "snippet_generator",
        {"{{TABLE_PROFILE_JSON}}": _json_dumps(table_desc_json)},
    )
    llm_output = await _call_chat_completions(
        model_spec=llm_model,
        system_prompt=system_prompt,
        user_prompt=user_prompt,
        client=client,
        timeout_seconds=timeout_seconds,
    )
    if not isinstance(llm_output, list):
        raise ProviderAPICallError("Snippet generator must return a JSON array.")
    return llm_output
|
||||
|
||||
|
||||
async def _execute_snippet_alias(
    snippets_json: List[Dict[str, Any]],
    _request: TableProfilingJobRequest,
    llm_model: str,
    client: httpx.AsyncClient,
    timeout_seconds: Optional[float],
) -> List[Dict[str, Any]]:
    """Generate aliases/keywords/intent tags for each SQL snippet.

    Renders the ``snippet_alias`` prompt with the snippet array embedded, then
    calls the chat-completions gateway.

    Raises:
        ProviderAPICallError: when the model output is not a JSON array.
    """
    system_prompt, user_prompt = _render_prompt(
        "snippet_alias",
        {"{{SNIPPET_ARRAY}}": _json_dumps(snippets_json)},
    )
    llm_output = await _call_chat_completions(
        model_spec=llm_model,
        system_prompt=system_prompt,
        user_prompt=user_prompt,
        client=client,
        timeout_seconds=timeout_seconds,
    )
    if not isinstance(llm_output, list):
        raise ProviderAPICallError("Snippet alias generator must return a JSON array.")
    return llm_output
|
||||
|
||||
|
||||
async def _run_action_with_callback(
    *,
    action_type: str,
    runner,
    callback_base: Dict[str, Any],
    client: httpx.AsyncClient,
    callback_url: str,
    input_payload: Any = None,
    model_spec: Optional[str] = None,
) -> Any:
    """Run one pipeline action, POSTing a success or failure callback either way.

    Args:
        action_type: pipeline step identifier (compared against PipelineActionType).
        runner: zero-argument async callable that performs the step.
        callback_base: common fields copied into every callback payload.
        client: shared async HTTP client used for the callback POST.
        callback_url: destination for the callback.
        input_payload: optional payload logged before the step runs.
        model_spec: optional model identifier echoed in the callback.

    Returns:
        The runner's result. On failure the exception is re-raised after the
        failure callback has been sent, so the caller can abort the pipeline.
    """
    if input_payload is not None:
        logger.info(
            "Pipeline action %s input: %s",
            action_type,
            _preview_for_log(input_payload),
        )
    try:
        result = await runner()
    except Exception as exc:
        # Report the failure to the callback consumer before propagating.
        failure_payload = dict(callback_base)
        failure_payload.update(
            {
                "status": "failed",
                "action_type": action_type,
                "error": str(exc),
            }
        )
        if model_spec is not None:
            failure_payload["model"] = model_spec
        await _post_callback(callback_url, failure_payload, client)
        raise

    success_payload = dict(callback_base)
    success_payload.update(
        {
            "status": "success",
            "action_type": action_type,
        }
    )
    if model_spec is not None:
        success_payload["model"] = model_spec

    logger.info(
        "Pipeline action %s output: %s",
        action_type,
        _preview_for_log(result),
    )

    # Attach the result under the key the callback consumer expects per action.
    if action_type == PipelineActionType.GE_PROFILING:
        artifacts: GEProfilingArtifacts = result
        success_payload["profiling_json"] = artifacts.profiling_result
        success_payload["profiling_summary"] = artifacts.profiling_summary
        success_payload["ge_report_path"] = artifacts.docs_path
    elif action_type == PipelineActionType.GE_RESULT_DESC:
        success_payload["table_desc_json"] = result
    elif action_type == PipelineActionType.SNIPPET:
        success_payload["snippet_json"] = result
    elif action_type == PipelineActionType.SNIPPET_ALIAS:
        success_payload["snippet_alias_json"] = result

    await _post_callback(callback_url, success_payload, client)
    return result
|
||||
|
||||
|
||||
async def process_table_profiling_job(
    request: TableProfilingJobRequest,
    _gateway: LLMGateway,
    client: httpx.AsyncClient,
) -> None:
    """Sequentially execute the four-step profiling pipeline and emit callbacks per action.

    Steps: GE profiling -> LLM result description -> snippet generation ->
    snippet alias generation. Each step reports success/failure via
    ``_run_action_with_callback``; the first failure stops the pipeline.
    """

    timeout_seconds = _extract_timeout_seconds(request.extra_options)
    if timeout_seconds is None:
        timeout_seconds = DEFAULT_CHAT_TIMEOUT_SECONDS

    # Common fields included in every per-action callback payload.
    base_payload = {
        "table_id": request.table_id,
        "version_ts": request.version_ts,
        "callback_url": str(request.callback_url),
        "table_schema": request.table_schema,
        "table_schema_version_id": request.table_schema_version_id,
        "llm_model": request.llm_model,
        "llm_timeout_seconds": timeout_seconds,
    }

    # Credentials are masked before the request is logged.
    logging_request_payload = _profiling_request_for_log(request)

    try:
        # Step 1: run Great Expectations profiling.
        artifacts: GEProfilingArtifacts = await _run_action_with_callback(
            action_type=PipelineActionType.GE_PROFILING,
            runner=lambda: _run_ge_profiling(request),
            callback_base=base_payload,
            client=client,
            callback_url=str(request.callback_url),
            input_payload=logging_request_payload,
            model_spec=request.llm_model,
        )

        # Step 2: have the LLM describe the profiling output as a table profile.
        table_desc_json: Dict[str, Any] = await _run_action_with_callback(
            action_type=PipelineActionType.GE_RESULT_DESC,
            runner=lambda: _execute_result_desc(
                artifacts.profiling_result,
                request,
                request.llm_model,
                client,
                timeout_seconds,
            ),
            callback_base=base_payload,
            client=client,
            callback_url=str(request.callback_url),
            input_payload=artifacts.profiling_result,
            model_spec=request.llm_model,
        )

        # Step 3: generate reusable SQL snippets from the table profile.
        snippet_json: List[Dict[str, Any]] = await _run_action_with_callback(
            action_type=PipelineActionType.SNIPPET,
            runner=lambda: _execute_snippet_generation(
                table_desc_json,
                request,
                request.llm_model,
                client,
                timeout_seconds,
            ),
            callback_base=base_payload,
            client=client,
            callback_url=str(request.callback_url),
            input_payload=table_desc_json,
            model_spec=request.llm_model,
        )

        # Step 4: generate aliases/keywords for each snippet.
        await _run_action_with_callback(
            action_type=PipelineActionType.SNIPPET_ALIAS,
            runner=lambda: _execute_snippet_alias(
                snippet_json,
                request,
                request.llm_model,
                client,
                timeout_seconds,
            ),
            callback_base=base_payload,
            client=client,
            callback_url=str(request.callback_url),
            input_payload=snippet_json,
            model_spec=request.llm_model,
        )
    except Exception:  # pragma: no cover - defensive catch
        # Failures were already reported via the per-action failure callback,
        # so the pipeline only logs here rather than re-raising.
        logger.exception(
            "Table profiling pipeline failed for table_id=%s version_ts=%s",
            request.table_id,
            request.version_ts,
        )
|
||||
184
app/services/table_snippet.py
Normal file
184
app/services/table_snippet.py
Normal file
@ -0,0 +1,184 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Dict, Tuple
|
||||
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
from app.db import get_engine
|
||||
from app.models import (
|
||||
ActionType,
|
||||
TableSnippetUpsertRequest,
|
||||
TableSnippetUpsertResponse,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _serialize_json(value: Any) -> Tuple[str | None, int | None]:
    """Serialise *value* to a JSON string and report its UTF-8 byte size.

    Strings pass through untouched (already serialised); ``None`` yields
    ``(None, None)`` so optional columns can be skipped entirely.
    """
    logger.debug("Serializing JSON payload: %s", value)
    if value is None:
        return None, None
    text_form = value if isinstance(value, str) else json.dumps(value, ensure_ascii=False)
    return text_form, len(text_form.encode("utf-8"))
|
||||
|
||||
|
||||
def _prepare_table_schema(value: Any) -> str:
    """Return *value* as a string, JSON-encoding non-string payloads."""
    logger.debug("Preparing table_schema payload.")
    return value if isinstance(value, str) else json.dumps(value, ensure_ascii=False)
|
||||
|
||||
|
||||
def _collect_common_columns(request: TableSnippetUpsertRequest) -> Dict[str, Any]:
    """Build the column/value map shared by every action_results upsert.

    Optional fields (error info, timing, checksum) are included only when set,
    so the generated SQL omits those columns entirely instead of writing NULLs.
    """
    logger.debug(
        "Collecting common columns for table_id=%s version_ts=%s action_type=%s",
        request.table_id,
        request.version_ts,
        request.action_type,
    )
    payload: Dict[str, Any] = {
        "table_id": request.table_id,
        "version_ts": request.version_ts,
        "action_type": request.action_type.value,
        "status": request.status.value,
        "callback_url": str(request.callback_url),
        "table_schema_version_id": request.table_schema_version_id,
        "table_schema": _prepare_table_schema(request.table_schema),
    }

    if request.error_code is not None:
        logger.debug("Adding error_code: %s", request.error_code)
        payload["error_code"] = request.error_code
    if request.error_message is not None:
        logger.debug("Adding error_message: %s", request.error_message)
        payload["error_message"] = request.error_message
    if request.started_at is not None:
        payload["started_at"] = request.started_at
    if request.finished_at is not None:
        payload["finished_at"] = request.finished_at
    if request.duration_ms is not None:
        payload["duration_ms"] = request.duration_ms
    if request.result_checksum is not None:
        payload["result_checksum"] = request.result_checksum

    logger.debug("Collected common payload: %s", payload)
    return payload
|
||||
|
||||
|
||||
def _apply_action_payload(
    request: TableSnippetUpsertRequest,
    payload: Dict[str, Any],
) -> None:
    """Merge action-type-specific result columns into *payload* in place.

    GE_PROFILING stores the full result plus a summary, total byte size and an
    optional HTML report URL; the other action types all store a single
    ``<prefix>_full`` column and its byte size, handled via one shared path
    instead of three copy-pasted branches.

    Raises:
        ValueError: for action types without a column mapping.
    """
    logger.debug("Applying action-specific payload for action_type=%s", request.action_type)

    # Column prefix for each "single full result" action type.
    simple_prefixes = {
        ActionType.GE_RESULT_DESC: "ge_result_desc",
        ActionType.SNIPPET: "snippet",
        ActionType.SNIPPET_ALIAS: "snippet_alias",
    }

    if request.action_type == ActionType.GE_PROFILING:
        full_json, full_size = _serialize_json(request.result_json)
        summary_json, summary_size = _serialize_json(request.result_summary_json)
        if full_json is not None:
            payload["ge_profiling_full"] = full_json
            payload["ge_profiling_full_size_bytes"] = full_size
        if summary_json is not None:
            payload["ge_profiling_summary"] = summary_json
            payload["ge_profiling_summary_size_bytes"] = summary_size
        if full_size is not None or summary_size is not None:
            payload["ge_profiling_total_size_bytes"] = (full_size or 0) + (
                summary_size or 0
            )
        if request.html_report_url:
            payload["ge_profiling_html_report_url"] = request.html_report_url
    elif request.action_type in simple_prefixes:
        prefix = simple_prefixes[request.action_type]
        full_json, full_size = _serialize_json(request.result_json)
        if full_json is not None:
            payload[f"{prefix}_full"] = full_json
            payload[f"{prefix}_full_size_bytes"] = full_size
    else:
        logger.error("Unsupported action type encountered: %s", request.action_type)
        raise ValueError(f"Unsupported action type '{request.action_type}'.")

    logger.debug("Payload after applying action-specific data: %s", payload)
|
||||
|
||||
|
||||
def _build_insert_statement(columns: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]:
    """Build a MySQL upsert statement for the ``action_results`` table.

    The natural key (table_id, version_ts, action_type) is excluded from the
    ON DUPLICATE KEY UPDATE assignments; ``updated_at`` is always refreshed.
    Returns the SQL text with named placeholders plus the bind-parameter map.
    """
    logger.debug("Building insert statement for columns: %s", list(columns.keys()))
    names = list(columns)
    key_columns = {"table_id", "version_ts", "action_type"}
    assignments = [f"{name}=VALUES({name})" for name in names if name not in key_columns]
    assignments.append("updated_at=CURRENT_TIMESTAMP")

    sql = (
        f"INSERT INTO action_results ({', '.join(names)}) "
        f"VALUES ({', '.join(':' + name for name in names)}) "
        f"ON DUPLICATE KEY UPDATE {', '.join(assignments)}"
    )
    logger.debug("Generated SQL: %s", sql)
    return sql, columns
|
||||
|
||||
|
||||
def _execute_upsert(engine: Engine, sql: str, params: Dict[str, Any]) -> int:
    """Execute the upsert inside one transaction and return the affected row count.

    NOTE(review): with MySQL's ON DUPLICATE KEY UPDATE the driver reports
    rowcount 1 for an insert and 2 when an existing row is updated — callers
    rely on this to distinguish insert from update.
    """
    logger.info("Executing upsert for table_id=%s version_ts=%s action_type=%s", params.get("table_id"), params.get("version_ts"), params.get("action_type"))
    # engine.begin() commits on success and rolls back on exception.
    with engine.begin() as conn:
        result = conn.execute(text(sql), params)
        logger.info("Rows affected: %s", result.rowcount)
        return result.rowcount
|
||||
|
||||
|
||||
def upsert_action_result(request: TableSnippetUpsertRequest) -> TableSnippetUpsertResponse:
    """Insert or update one pipeline action result row in ``action_results``.

    Returns:
        A response whose ``updated`` flag is True when an existing row was
        modified rather than a new row inserted.

    Raises:
        RuntimeError: wrapping any SQLAlchemy error from the upsert.
    """
    logger.info(
        "Received upsert request: table_id=%s version_ts=%s action_type=%s status=%s",
        request.table_id,
        request.version_ts,
        request.action_type,
        request.status,
    )
    logger.debug("Request payload: %s", request.model_dump())
    columns = _collect_common_columns(request)
    _apply_action_payload(request, columns)

    sql, params = _build_insert_statement(columns)
    logger.debug("Final SQL params: %s", params)

    engine = get_engine()
    try:
        rowcount = _execute_upsert(engine, sql, params)
    except SQLAlchemyError as exc:
        logger.exception(
            "Failed to upsert action result: table_id=%s version_ts=%s action_type=%s",
            request.table_id,
            request.version_ts,
            request.action_type,
        )
        raise RuntimeError(f"Database operation failed: {exc}") from exc

    # MySQL reports rowcount 2 for an ON DUPLICATE KEY UPDATE that changed an
    # existing row and 1 for a plain insert, so > 1 means "updated".
    updated = rowcount > 1
    return TableSnippetUpsertResponse(
        table_id=request.table_id,
        version_ts=request.version_ts,
        action_type=request.action_type,
        status=request.status,
        updated=updated,
    )
|
||||
10001
file/dataset/ecommerce_orders_clean.csv
Normal file
10001
file/dataset/ecommerce_orders_clean.csv
Normal file
File diff suppressed because it is too large
Load Diff
2
ge_v1.py
2
ge_v1.py
@ -121,7 +121,7 @@ def clean_value(value: Any) -> Any:
|
||||
if isinstance(value, (np.generic,)):
|
||||
return value.item()
|
||||
if isinstance(value, pd.Timestamp):
|
||||
return value.isoformat()
|
||||
return str(value)
|
||||
if pd.isna(value):
|
||||
return None
|
||||
return value
|
||||
|
||||
30
logging.yaml
Normal file
30
logging.yaml
Normal file
@ -0,0 +1,30 @@
|
||||
version: 1
|
||||
formatters:
|
||||
standard:
|
||||
format: "%(asctime)s %(levelname)s %(name)s:%(lineno)d %(message)s"
|
||||
handlers:
|
||||
console:
|
||||
class: logging.StreamHandler
|
||||
level: INFO
|
||||
formatter: standard
|
||||
stream: ext://sys.stdout
|
||||
file:
|
||||
class: logging.handlers.RotatingFileHandler
|
||||
level: INFO
|
||||
formatter: standard
|
||||
filename: logs/app.log  # NOTE: the logs/ directory must exist before logging starts, or the RotatingFileHandler fails to open
|
||||
maxBytes: 10485760 # 10 MB
|
||||
backupCount: 5
|
||||
encoding: utf-8
|
||||
loggers:
|
||||
app:
|
||||
level: INFO
|
||||
handlers:
|
||||
- console
|
||||
- file
|
||||
propagate: no
|
||||
root:
|
||||
level: INFO
|
||||
handlers:
|
||||
- console
|
||||
- file
|
||||
47
prompt/ge_result_desc_prompt.md
Normal file
47
prompt/ge_result_desc_prompt.md
Normal file
@ -0,0 +1,47 @@
|
||||
系统角色(System)
|
||||
你是“数据画像抽取器”。输入是一段 Great Expectations 的 profiling/validation 结果 JSON,
|
||||
可能包含:列级期望(expect_*)、统计、样例值、类型推断等;也可能带表级/批次元数据。
|
||||
请将其归一化为一个可被程序消费的“表画像”JSON,对不确定项给出置信度与理由。
|
||||
禁止臆造不存在的列、时间范围或数值。
|
||||
|
||||
用户消息(User)
|
||||
【输入:GE结果JSON】
|
||||
{{GE_RESULT_JSON}}
|
||||
|
||||
【输出要求(只输出JSON,不要解释文字)】
|
||||
{
|
||||
"table": "<库.表 或 表名>",
|
||||
"row_count": <int|null>, // 若未知可为 null
|
||||
"role": "fact|dimension|unknown", // 依据指标/维度占比与唯一性启发式
|
||||
"grain": ["<列1>", "<列2>", ...], // 事实粒度猜测(如含 dt/店/类目)
|
||||
"time": { "column": "<name>|null", "granularity": "day|week|month|unknown", "range": ["YYYY-MM-DD","YYYY-MM-DD"]|null, "has_gaps": true|false|null },
|
||||
"columns": [
|
||||
{
|
||||
"name": "<col>",
|
||||
"dtype": "<ge推断/物理类型>",
|
||||
"semantic_type": "dimension|metric|time|text|id|unknown",
|
||||
"null_rate": <0~1|null>,
|
||||
"distinct_count": <int|null>,
|
||||
"distinct_ratio": <0~1|null>,
|
||||
"stats": { "min": <number|string|null>,"max": <number|string|null>,"mean": <number|null>,"std": <number|null>,"skewness": <number|null> },
|
||||
"enumish": true|false|null, // 低熵/可枚举
|
||||
"top_values": [{"value":"<v>","pct":<0~1>}, ...],// 取前K个(≤10)
|
||||
"pk_candidate_score": <0~1>, // 唯一性+非空综合评分
|
||||
"metric_candidate_score": <0~1>, // 数值/偏态/业务词命中
|
||||
"comment": "<列注释或GE描述|可为空>"
|
||||
}
|
||||
],
|
||||
"primary_key_candidates": [["colA","colB"], ...], // 依据 unique/compound unique 期望
|
||||
"fk_candidates": [{"from":"<col>","to":"<dim_table(col)>","confidence":<0~1>}],
|
||||
"quality": {
|
||||
"failed_expectations": [{"name":"<expect_*>","column":"<col|table>","summary":"<一句话>"}],
|
||||
"warning_hints": ["空值率>0.2的列: ...", "时间列存在缺口: ..."]
|
||||
},
|
||||
"confidence_notes": ["<为什么判定role/grain/time列>"]
|
||||
}
|
||||
|
||||
【判定规则(简要)】
|
||||
- time列:类型为日期/时间 OR 命中 dt/date/day 等命名;若有 min/max 可给出 range;若间隔缺口≥1天记 has_gaps=true。
|
||||
- semantic_type:数值+右偏/方差大→更偏 metric;高唯一/ID命名→id;高基数+文本→text;低熵+有限取值→dimension。
|
||||
- role:metric列占比高且存在time列→倾向 fact;几乎全是枚举/ID且少数值→dimension。
|
||||
- 置信不高时给出 null 或 unknown,并写入 confidence_notes。
|
||||
52
prompt/snippet_alias_generator.md
Normal file
52
prompt/snippet_alias_generator.md
Normal file
@ -0,0 +1,52 @@
|
||||
系统角色(System)
|
||||
你是“SQL片段别名生成器”。
|
||||
输入为一个或多个 SQL 片段对象(来自 snippet.json),输出为针对每个片段生成的多样化别名(口语 / 中性 / 专业)、关键词与意图标签。
|
||||
要求逐个处理所有片段对象,输出同样数量的 JSON 元素。
|
||||
|
||||
用户消息(User)
|
||||
【上下文】
|
||||
|
||||
SQL片段对象数组:{{SNIPPET_ARRAY}} // snippet.json中的一个或多个片段
|
||||
|
||||
【任务要求】
|
||||
请针对输入数组中的 每个 SQL 片段,输出一个 JSON 对象,结构如下:
|
||||
|
||||
{
|
||||
"id": "<与输入片段id一致>",
|
||||
"aliases": [
|
||||
{"text": "…", "tone": "口语|中性|专业"},
|
||||
{"text": "…", "tone": "专业"}
|
||||
],
|
||||
"keywords": [
|
||||
"GMV","销售额","TopN","category","类目","趋势","同比","客户","订单","质量","异常检测","join","过滤","sample"
|
||||
],
|
||||
"intent_tags": ["aggregate","trend","topn","ratio","quality","join","sample","filter","by_dimension"]
|
||||
}
|
||||
|
||||
生成逻辑规范
|
||||
1.逐条输出:输入数组中每个片段对应一个输出对象(id 保持一致)。
|
||||
|
||||
2.aliases生成
|
||||
至少 3 个别名,分别覆盖语气类型:口语 / 中性 / 专业。
|
||||
≤20字,语义需等价,不得添加不存在的字段或业务口径。
|
||||
示例:
|
||||
GMV趋势分析(中性)
|
||||
每天卖多少钱(口语)
|
||||
按日GMV曲线(专业)
|
||||
3.keywords生成
|
||||
8~15个关键词,需涵盖片段核心维度、指标、分析类型和语义近义词。
|
||||
中英文混合(如 "GMV"/"销售额"、"同比"/"YoY"、"类目"/"category" 等)。
|
||||
包含用于匹配的分析意图关键词(如 “趋势”、“排行”、“占比”、“质量检查”、“过滤” 等)。
|
||||
|
||||
4.intent_tags生成
|
||||
|
||||
从以下集合中选取,与片段type及用途一致:
|
||||
["aggregate","trend","topn","ratio","quality","join","sample","filter","by_dimension"]
|
||||
|
||||
若为条件片段(WHERE句型),补充 "filter";若含维度分组逻辑,补充 "by_dimension"。
|
||||
|
||||
5.语言与内容要求
|
||||
|
||||
保持正式书面风格,不添加解释说明。
|
||||
|
||||
只输出JSON数组,不包含文字描述或额外文本。
|
||||
46
prompt/snippet_generator.md
Normal file
46
prompt/snippet_generator.md
Normal file
@ -0,0 +1,46 @@
|
||||
系统角色(System)
|
||||
你是“SQL片段生成器”。只能基于给定“表画像”生成可复用的分析片段。
|
||||
为每个片段产出:标题、用途描述、片段类型、变量、适用条件、SQL模板(mysql方言),并注明业务口径与安全限制。
|
||||
不要发明画像里没有的列。时间/维度/指标须与画像匹配。
|
||||
|
||||
用户消息(User)
|
||||
【表画像JSON】
|
||||
{{TABLE_PROFILE_JSON}}
|
||||
|
||||
【输出要求(只输出JSON数组)】
|
||||
[
|
||||
{
|
||||
"id": "snpt_<slug>",
|
||||
"title": "中文标题(≤16字)",
|
||||
"desc": "一句话用途",
|
||||
"type": "aggregate|trend|topn|ratio|quality|join|sample",
|
||||
"applicability": {
|
||||
"required_columns": ["<col>", ...],
|
||||
"time_column": "<dt|nullable>",
|
||||
"constraints": {
|
||||
"dim_cardinality_hint": <int|null>, // 用于TopN限制与性能提示
|
||||
"fk_join_available": true|false,
|
||||
"notes": ["高基数维度建议LIMIT<=50", "..."]
|
||||
}
|
||||
},
|
||||
"variables": [
|
||||
{"name":"start_date","type":"date"},
|
||||
{"name":"end_date","type":"date"},
|
||||
{"name":"top_n","type":"int","default":10}
|
||||
],
|
||||
"dialect_sql": {
|
||||
"mysql": ""
|
||||
},
|
||||
"business_caliber": "清晰口径说明,如 UV以device_id去重;粒度=日-类目",
|
||||
"examples": ["示例问法1","示例问法2"]
|
||||
}
|
||||
]
|
||||
|
||||
【片段选择建议】
|
||||
- 若存在 time 列:生成 trend_by_day / yoy_qoq / moving_avg。
|
||||
- 若存在 enumish 维度(distinct 5~200):生成 topn_by_dimension / share_of_total。
|
||||
- 若 metric 列:生成 sum/avg/max、分位数/异常检测(3σ/箱线)。
|
||||
- 有主键/唯一:生成 去重/明细抽样/质量检查。
|
||||
- 有 fk_candidates:同时生成“join维表命名版”和“纯ID版”。
|
||||
- 高枚举维度:在 constraints.notes 中强调 LIMIT 建议与可能的性能风险。
|
||||
- 除了完整的sql片段,还有sql里部分内容的sql片段,比如 where payment_method = 'Credit Card' and delivery_status = 'Deliverd' 的含义是支付方式为信用卡且配送状态是已送达
|
||||
@ -9,3 +9,5 @@ numpy>=1.24
|
||||
openpyxl>=3.1
|
||||
httpx==0.27.2
|
||||
python-dotenv==1.0.1
|
||||
requests>=2.31.0
|
||||
PyYAML>=6.0.1
|
||||
|
||||
226
scripts/huggingface_download.py
Normal file
226
scripts/huggingface_download.py
Normal file
@ -0,0 +1,226 @@
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
from typing import Dict, Iterable, List, Optional
|
||||
|
||||
import datasets
|
||||
from datasets import DownloadConfig
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
# 批量下载 Hugging Face 上的数据集和模型
|
||||
# 支持通过命令行参数配置代理和下载参数,如超时和重试次数,支持批量循环下载,存储到file目录下dataset和model子目录
|
||||
|
||||
|
||||
def _parse_id_list(values: Iterable[str]) -> List[str]:
|
||||
"""将多次传入以及逗号分隔的标识整理为列表."""
|
||||
ids: List[str] = []
|
||||
for value in values:
|
||||
value = value.strip()
|
||||
if not value:
|
||||
continue
|
||||
if "," in value:
|
||||
ids.extend(v.strip() for v in value.split(",") if v.strip())
|
||||
else:
|
||||
ids.append(value)
|
||||
return ids
|
||||
|
||||
|
||||
def _parse_proxy_args(proxy_args: Iterable[str]) -> Dict[str, str]:
|
||||
"""解析命令行传入的代理设置,格式 scheme=url."""
|
||||
proxies: Dict[str, str] = {}
|
||||
for item in proxy_args:
|
||||
raw = item.strip()
|
||||
if not raw:
|
||||
continue
|
||||
if "=" not in raw:
|
||||
logging.warning("代理参数 %s 缺少 '=' 分隔符,将忽略该项", raw)
|
||||
continue
|
||||
key, value = raw.split("=", 1)
|
||||
key = key.strip()
|
||||
value = value.strip()
|
||||
if not key or not value:
|
||||
logging.warning("代理参数 %s 解析失败,将忽略该项", raw)
|
||||
continue
|
||||
proxies[key] = value
|
||||
return proxies
|
||||
|
||||
|
||||
def _sanitize_dir_name(name: str) -> str:
|
||||
return name.replace("/", "__")
|
||||
|
||||
|
||||
def _ensure_dirs(root_dir: str) -> Dict[str, str]:
|
||||
paths = {
|
||||
"dataset": os.path.join(root_dir, "dataset"),
|
||||
"model": os.path.join(root_dir, "model"),
|
||||
}
|
||||
for path in paths.values():
|
||||
os.makedirs(path, exist_ok=True)
|
||||
return paths
|
||||
|
||||
|
||||
def _build_download_config(cache_dir: str, retries: Optional[int], proxies: Dict[str, str]) -> DownloadConfig:
    """Assemble a ``datasets`` DownloadConfig, adding retries/proxies only when set."""
    config_kwargs = {"cache_dir": cache_dir}
    if retries is not None:
        config_kwargs["max_retries"] = retries
    if proxies:
        config_kwargs["proxies"] = proxies
    return DownloadConfig(**config_kwargs)
|
||||
|
||||
|
||||
def _apply_timeout(timeout: Optional[float]) -> None:
|
||||
if timeout is None:
|
||||
return
|
||||
str_timeout = str(timeout)
|
||||
os.environ.setdefault("HF_DATASETS_HTTP_TIMEOUT", str_timeout)
|
||||
os.environ.setdefault("HF_HUB_HTTP_TIMEOUT", str_timeout)
|
||||
|
||||
|
||||
def _resolve_log_level(level_name: str) -> int:
|
||||
if isinstance(level_name, int):
|
||||
return level_name
|
||||
upper_name = str(level_name).upper()
|
||||
return getattr(logging, upper_name, logging.INFO)
|
||||
|
||||
|
||||
def _build_argument_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="批量下载 Hugging Face 数据集和模型并存储到指定目录。"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-d",
|
||||
"--dataset",
|
||||
action="append",
|
||||
default=[],
|
||||
help="要下载的数据集 ID,可重复使用或传入逗号分隔列表。",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-m",
|
||||
"--model",
|
||||
action="append",
|
||||
default=[],
|
||||
help="要下载的模型 ID,可重复使用或传入逗号分隔列表。",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-r",
|
||||
"--root",
|
||||
default="file",
|
||||
help="存储根目录,默认 file。",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--retries",
|
||||
type=int,
|
||||
default=None,
|
||||
help="失败后的重试次数,默认不重试。",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--timeout",
|
||||
type=float,
|
||||
default=None,
|
||||
help="HTTP 超时时间(秒),默认跟随库设置。",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-p",
|
||||
"--proxy",
|
||||
action="append",
|
||||
default=[],
|
||||
help="代理设置,格式 scheme=url,可多次传入,例如 --proxy http=http://127.0.0.1:7890",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--log-level",
|
||||
default="INFO",
|
||||
help="日志级别,默认 INFO。",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def download_datasets(dataset_ids: Iterable[str], root_dir: str, retries: Optional[int], proxies: Dict[str, str]) -> None:
    """Download each dataset into *root_dir*, continuing past individual failures.

    Each dataset is saved under a sanitized per-id subdirectory of *root_dir*,
    which also serves as the datasets cache directory.
    """
    if not dataset_ids:
        return
    cache_dir = root_dir
    download_config = _build_download_config(cache_dir, retries, proxies)
    for dataset_id in dataset_ids:
        try:
            logging.info("开始下载数据集 %s", dataset_id)
            # load_dataset populates the cache as a side effect of loading.
            dataset = datasets.load_dataset(
                dataset_id,
                cache_dir=cache_dir,
                download_config=download_config,
                download_mode="reuse_cache_if_exists",
            )
            target_path = os.path.join(root_dir, _sanitize_dir_name(dataset_id))
            dataset.save_to_disk(target_path)
            logging.info("数据集 %s 下载完成,存储于 %s", dataset_id, target_path)
        except Exception as exc:  # pylint: disable=broad-except
            # Best-effort batch download: log the failure and move on.
            logging.error("下载数据集 %s 失败: %s", dataset_id, exc)
|
||||
|
||||
|
||||
def download_models(
    model_ids: Iterable[str],
    target_dir: str,
    retries: Optional[int],
    proxies: Dict[str, str],
    timeout: Optional[float],
) -> None:
    """Download Hugging Face model snapshots with optional retries.

    Each model is stored in its own sub-directory of *target_dir* (mirroring
    download_datasets) so files of different models cannot collide.

    Args:
        model_ids: Hugging Face repo IDs; an empty iterable is a no-op.
        target_dir: Root directory under which models are stored.
        retries: Extra attempts after a failure; ``None``/0 means one attempt.
        proxies: Optional proxy mapping forwarded to huggingface_hub.
        timeout: Optional HTTP timeout in seconds forwarded to huggingface_hub.
    """
    if not model_ids:
        return
    max_attempts = (retries or 0) + 1
    common_kwargs = {
        "local_dir_use_symlinks": False,
        "max_workers": os.cpu_count() or 4,
    }
    if proxies:
        common_kwargs["proxies"] = proxies
    if timeout is not None:
        common_kwargs["timeout"] = timeout
    for model_id in model_ids:
        # Fix: previously every model shared the same local_dir, so downloading
        # several models mixed/overwrote their files in one directory. Use a
        # sanitized per-model sub-directory, consistent with download_datasets.
        model_dir = os.path.join(target_dir, _sanitize_dir_name(model_id))
        attempt = 0
        while attempt < max_attempts:
            attempt += 1
            try:
                logging.info("开始下载模型 %s (尝试 %s/%s)", model_id, attempt, max_attempts)
                snapshot_download(
                    repo_id=model_id,
                    local_dir=model_dir,
                    **common_kwargs,
                )
                logging.info("模型 %s 下载完成,存储于 %s", model_id, model_dir)
                break
            except Exception as exc:  # pylint: disable=broad-except
                logging.error("下载模型 %s 失败: %s", model_id, exc)
                if attempt >= max_attempts:
                    logging.error("模型 %s 在重试后仍未成功下载", model_id)
|
||||
|
||||
|
||||
def main() -> None:
    """CLI entry point: parse arguments, configure logging, run downloads."""
    args = _build_argument_parser().parse_args()

    logging.basicConfig(
        level=_resolve_log_level(args.log_level),
        format="%(asctime)s - %(levelname)s - %(message)s",
    )

    dataset_ids = _parse_id_list(args.dataset)
    model_ids = _parse_id_list(args.model)
    proxies = _parse_proxy_args(args.proxy)
    _apply_timeout(args.timeout)

    # Nothing to do without at least one dataset or model ID.
    if not (dataset_ids or model_ids):
        logging.warning(
            "未配置任何数据集或模型,"
            "请通过参数 --dataset / --model 指定 Hugging Face ID"
        )
        return

    dirs = _ensure_dirs(args.root)

    download_datasets(dataset_ids, dirs["dataset"], args.retries, proxies)
    download_models(model_ids, dirs["model"], args.retries, proxies, args.timeout)


if __name__ == "__main__":
    main()
|
||||
80
scripts/table_snippet_demo.py
Normal file
80
scripts/table_snippet_demo.py
Normal file
@ -0,0 +1,80 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
import os
import sys
from datetime import datetime, timezone
from typing import Any, Dict

import requests
|
||||
|
||||
|
||||
def build_demo_payload() -> Dict[str, Any]:
    """Build a demo callback payload for the /v1/table/snippet endpoint.

    Returns:
        A JSON-serializable dict mimicking a successful snippet-generation
        callback. ``started_at``/``finished_at`` carry the current UTC time,
        second precision, with a trailing "Z" designator.
    """
    # Fix: datetime.utcnow() is deprecated (Python 3.12+). Use an aware UTC
    # time, then drop tzinfo so isoformat() stays "YYYY-MM-DDTHH:MM:SS" and
    # the manual "Z" suffix remains correct. The original computed the same
    # value twice for started_at/finished_at; compute it once.
    now = datetime.now(timezone.utc).replace(microsecond=0, tzinfo=None)
    timestamp = now.isoformat() + "Z"
    return {
        "table_id": 42,
        "version_ts": 20251101200000,
        "action_type": "snippet",
        "status": "success",
        "callback_url": "http://localhost:9999/dummy-callback",
        "table_schema_version_id": 7,
        "table_schema": {
            "columns": [
                {"name": "order_id", "type": "bigint"},
                {"name": "order_dt", "type": "date"},
                {"name": "gmv", "type": "decimal(18,2)"},
            ]
        },
        "result_json": [
            {
                "id": "snpt_daily_gmv",
                "title": "按日GMV",
                "desc": "统计每日GMV总额",
                "type": "trend",
                "dialect_sql": {
                    "mysql": "SELECT order_dt, SUM(gmv) AS total_gmv FROM orders GROUP BY order_dt ORDER BY order_dt"
                },
            }
        ],
        "result_summary_json": {"total_snippets": 1},
        "html_report_url": None,
        "error_code": None,
        "error_message": None,
        "started_at": timestamp,
        "finished_at": timestamp,
        "duration_ms": 1234,
        "result_checksum": "demo-checksum",
    }
|
||||
|
||||
|
||||
def main() -> int:
    """POST the demo payload to the snippet endpoint and echo the response.

    Returns a shell-style exit code: 0 on a 2xx response, 1 otherwise.
    """
    base_url = os.getenv("TABLE_SNIPPET_DEMO_BASE_URL", "http://localhost:8000")
    endpoint = f"{base_url.rstrip('/')}/v1/table/snippet"
    payload = build_demo_payload()

    print(f"POST {endpoint}")
    print(json.dumps(payload, ensure_ascii=False, indent=2))

    try:
        response = requests.post(endpoint, json=payload, timeout=30)
    except requests.RequestException as exc:
        print(f"Request failed: {exc}", file=sys.stderr)
        return 1

    print(f"\nStatus: {response.status_code}")

    # Pretty-print JSON bodies; fall back to raw text for anything else.
    try:
        body = response.json()
    except ValueError:
        print("Response Text:")
        print(response.text)
    else:
        print("Response JSON:")
        print(json.dumps(body, ensure_ascii=False, indent=2))

    return 0 if response.ok else 1


if __name__ == "__main__":
    raise SystemExit(main())
|
||||
54
table_snippet.sql
Normal file
54
table_snippet.sql
Normal file
@ -0,0 +1,54 @@
|
||||
-- One row per (table_id, version_ts, action_type) pipeline run: execution
-- status, callback/observability fields, and the action-specific JSON
-- result columns (only the columns matching action_type are populated).
CREATE TABLE IF NOT EXISTS action_results (
    id BIGINT NOT NULL AUTO_INCREMENT COMMENT '主键',
    table_id BIGINT NOT NULL COMMENT '表ID',
    version_ts BIGINT NOT NULL COMMENT '版本时间戳(版本号)',
    action_type ENUM('ge_profiling','ge_result_desc','snippet','snippet_alias') NOT NULL COMMENT '动作类型',

    status ENUM('pending','running','success','failed','partial') NOT NULL DEFAULT 'pending' COMMENT '执行状态',
    error_code VARCHAR(128) NULL,
    error_message TEXT NULL,

    -- Callback & observability
    callback_url VARCHAR(1024) NOT NULL,
    started_at DATETIME NULL,
    finished_at DATETIME NULL,
    duration_ms INT NULL,

    -- Schema information for this run
    table_schema_version_id BIGINT NOT NULL,
    table_schema JSON NOT NULL,

    -- ===== Action 1: GE Profiling =====
    ge_profiling_full JSON NULL COMMENT 'Profiling完整结果JSON',
    ge_profiling_full_size_bytes BIGINT NULL,
    ge_profiling_summary JSON NULL COMMENT 'Profiling摘要(剔除大value_set等)',
    ge_profiling_summary_size_bytes BIGINT NULL,
    ge_profiling_total_size_bytes BIGINT NULL COMMENT '上两者合计',
    ge_profiling_html_report_url VARCHAR(1024) NULL COMMENT 'GE报告HTML路径/URL',

    -- ===== Action 2: GE Result Desc =====
    ge_result_desc_full JSON NULL COMMENT '表描述结果JSON',
    ge_result_desc_full_size_bytes BIGINT NULL,

    -- ===== Action 3: Snippet generation =====
    snippet_full JSON NULL COMMENT 'SQL知识片段结果JSON',
    snippet_full_size_bytes BIGINT NULL,

    -- ===== Action 4: Snippet Alias rewriting =====
    snippet_alias_full JSON NULL COMMENT 'SQL片段改写/丰富结果JSON',
    snippet_alias_full_size_bytes BIGINT NULL,

    -- Common optional metrics
    result_checksum VARBINARY(32) NULL COMMENT '对当前action有效载荷计算的MD5/xxhash',
    created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
    updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,

    -- uq_table_ver_action makes re-runs of the same action idempotent at
    -- the storage level (one result row per table/version/action).
    PRIMARY KEY (id),
    UNIQUE KEY uq_table_ver_action (table_id, version_ts, action_type),
    KEY idx_status (status),
    KEY idx_table (table_id, updated_at),
    KEY idx_action_time (action_type, version_ts),
    KEY idx_schema_version (table_schema_version_id)
) ENGINE=InnoDB
ROW_FORMAT=DYNAMIC
COMMENT='数据分析知识片段表';
|
||||
Reference in New Issue
Block a user