Compare commits


14 Commits

SHA1 Message Date
7eb3c059a1 Persist data-knowledge callbacks to the database 2025-11-04 20:28:50 +08:00
0b765e6719 Merge branch 'main' of ssh://git.agentcarrier.cn:2222/zhaoqingliang/data-ge (no conflicts) 2025-11-03 00:29:42 +08:00
cf2e6aedc7 Ignore gx and logs 2025-11-03 00:20:36 +08:00
9a7710a70a file and demo 2025-11-03 00:20:00 +08:00
799b9f8154 Add logging 2025-11-03 00:19:43 +08:00
fe1de87696 Adjust some parameters 2025-11-03 00:19:23 +08:00
c2a08e4134 Develop the table profiling feature 2025-11-03 00:18:26 +08:00
557efc4bf1 Add the three prompts for the data-knowledge generation pipeline 2025-10-31 15:55:07 +08:00
72f9735000 Update Dockerfile 2025-10-31 09:44:09 +08:00
ae567a996a Update docker-compose.yml 2025-10-31 09:43:56 +08:00
d27a003bb0 Add docker-compose.yml 2025-10-31 09:29:36 +08:00
a199129ada Add Dockerfile 2025-10-31 09:28:42 +08:00
d17d850d67 Add the import analysis API endpoint content 2025-10-30 23:25:55 +08:00
4ff3a1f081 Guide for installing and starting the data analysis and governance service 2025-10-30 23:19:41 +08:00
23 changed files with 12122 additions and 35 deletions

.env (2 changes)

@@ -17,7 +17,7 @@ DEFAULT_IMPORT_MODEL=deepseek:deepseek-chat
IMPORT_GATEWAY_BASE_URL=http://localhost:8000

# HTTP client configuration
-HTTP_CLIENT_TIMEOUT=30
+HTTP_CLIENT_TIMEOUT=60
HTTP_CLIENT_TRUST_ENV=false
# HTTP_CLIENT_PROXY=

.gitignore (2 changes)

@@ -4,3 +4,5 @@ gx/uncommitted/
**/__pycache__/
*.pyc
.DS_Store
+gx/
+logs/

Dockerfile (new file, 17 lines)

@@ -0,0 +1,17 @@
FROM python:3.11-slim
# Configure pip to use a China PyPI mirror globally
ENV PIP_INDEX_URL=https://pypi.tuna.tsinghua.edu.cn/simple/
ENV PIP_TRUSTED_HOST=pypi.tuna.tsinghua.edu.cn
WORKDIR /app
COPY requirements.txt .
RUN pip install --no-cache-dir --upgrade pip && \
pip install --no-cache-dir -r requirements.txt
COPY . .
EXPOSE 8000
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]

README.md

@@ -77,6 +77,9 @@ nohup uvicorn app.main:app --host 0.0.0.0 --port 8000 > server.log 2>&1 &
Or use a process manager such as `pm2`, `supervisor`, or systemd for production deployments.

+## API List
+1. Import analysis schema endpoint: http://localhost:8000/v1/import/analyze
+
## Additional Commands
- Run the data import analysis example: `python test/data_import_analysis_example.py`
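
The new API List entry can be exercised directly. A minimal sketch follows; the field names are illustrative (the authoritative request schema is DataImportAnalysisJobRequest in app/models.py, not reproduced in full here):

```python
# Hypothetical request body; adjust fields to match DataImportAnalysisJobRequest.
import httpx

resp = httpx.post(
    "http://localhost:8000/v1/import/analyze",
    json={
        "import_record_id": "demo-001",           # echoed back in the ack
        "headers": ["order_id", "amount", "dt"],  # ordered table headers
        "llm_model": "deepseek:deepseek-chat",    # 'provider:model_name' spec
    },
    timeout=30.0,
)
print(resp.status_code, resp.json())
```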

app/db.py (new file, 26 lines)

@@ -0,0 +1,26 @@
from __future__ import annotations
import os
from functools import lru_cache
from sqlalchemy import create_engine
from sqlalchemy.engine import Engine
@lru_cache(maxsize=1)
def get_engine() -> Engine:
"""Return a cached SQLAlchemy engine configured from DATABASE_URL."""
database_url = os.getenv(
"DATABASE_URL",
"mysql+pymysql://root:12345678@localhost:3306/data-ge?charset=utf8mb4",
)
connect_args = {}
if database_url.startswith("sqlite"):
connect_args["check_same_thread"] = False
return create_engine(
database_url,
pool_pre_ping=True,
future=True,
connect_args=connect_args,
)
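
A minimal usage sketch for the new module (assumes DATABASE_URL points at a reachable database; otherwise the MySQL default baked into get_engine is used):

```python
from sqlalchemy import text

from app.db import get_engine

engine = get_engine()  # lru_cache: repeated calls return the same Engine
with engine.connect() as conn:
    # pool_pre_ping=True revalidates stale pooled connections before use
    print(conn.execute(text("SELECT 1")).scalar())
```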

app/main.py

@@ -2,12 +2,17 @@ from __future__ import annotations
import asyncio
import logging
+import logging.config
import os
from contextlib import asynccontextmanager
from typing import Any
+import yaml
import httpx
from fastapi import Depends, FastAPI, HTTPException, Request
+from fastapi.exceptions import RequestValidationError
+from fastapi.responses import JSONResponse
from app.exceptions import ProviderAPICallError, ProviderConfigurationError
from app.models import (
@@ -15,30 +20,42 @@ from app.models import (
    DataImportAnalysisJobRequest,
    LLMRequest,
    LLMResponse,
+    TableProfilingJobAck,
+    TableProfilingJobRequest,
+    TableSnippetUpsertRequest,
+    TableSnippetUpsertResponse,
)
from app.services import LLMGateway
from app.services.import_analysis import process_import_analysis_job
+from app.services.table_profiling import process_table_profiling_job
+from app.services.table_snippet import upsert_action_result

+def _ensure_log_directories(config: dict[str, Any]) -> None:
+    handlers = config.get("handlers", {})
+    for handler_config in handlers.values():
+        filename = handler_config.get("filename")
+        if not filename:
+            continue
+        directory = os.path.dirname(filename)
+        if directory and not os.path.exists(directory):
+            os.makedirs(directory, exist_ok=True)

def _configure_logging() -> None:
-    level_name = os.getenv("LOG_LEVEL", "INFO").upper()
-    level = getattr(logging, level_name, logging.INFO)
-    log_format = os.getenv(
-        "LOG_FORMAT",
-        "%(asctime)s %(levelname)s %(name)s:%(lineno)d %(message)s",
-    )
-    root = logging.getLogger()
-    if not root.handlers:
-        logging.basicConfig(level=level, format=log_format)
-    else:
-        root.setLevel(level)
-        formatter = logging.Formatter(log_format)
-        for handler in root.handlers:
-            handler.setLevel(level)
-            handler.setFormatter(formatter)
+    config_path = os.getenv("LOGGING_CONFIG", "logging.yaml")
+    if os.path.exists(config_path):
+        with open(config_path, "r", encoding="utf-8") as fh:
+            config = yaml.safe_load(fh)
+        if isinstance(config, dict):
+            _ensure_log_directories(config)
+            logging.config.dictConfig(config)
+            return
+    logging.basicConfig(
+        level=logging.INFO,
+        format="%(asctime)s %(levelname)s %(name)s:%(lineno)d %(message)s",
+    )

_configure_logging()
logger = logging.getLogger(__name__)
@@ -119,6 +136,24 @@ def create_app() -> FastAPI:
        lifespan=lifespan,
    )

+    @application.exception_handler(RequestValidationError)
+    async def request_validation_exception_handler(
+        request: Request, exc: RequestValidationError
+    ) -> JSONResponse:
+        try:
+            raw_body = await request.body()
+        except Exception:  # pragma: no cover - defensive
+            raw_body = b"<unavailable>"
+        truncated_body = raw_body[:4096]
+        logger.warning(
+            "Validation error on %s %s: %s | body preview=%s",
+            request.method,
+            request.url.path,
+            exc.errors(),
+            truncated_body.decode("utf-8", errors="ignore"),
+        )
+        return JSONResponse(status_code=422, content={"detail": exc.errors()})

    @application.post(
        "/v1/chat/completions",
        response_model=LLMResponse,
@@ -164,6 +199,52 @@ def create_app() -> FastAPI:
        return DataImportAnalysisJobAck(import_record_id=payload.import_record_id, status="accepted")

+    @application.post(
+        "/v1/table/profiling",
+        response_model=TableProfilingJobAck,
+        summary="Run end-to-end GE profiling pipeline and notify via callback per action",
+        status_code=202,
+    )
+    async def run_table_profiling(
+        payload: TableProfilingJobRequest,
+        gateway: LLMGateway = Depends(get_gateway),
+        client: httpx.AsyncClient = Depends(get_http_client),
+    ) -> TableProfilingJobAck:
+        request_copy = payload.model_copy(deep=True)
+
+        async def _runner() -> None:
+            await process_table_profiling_job(request_copy, gateway, client)
+
+        asyncio.create_task(_runner())
+        return TableProfilingJobAck(
+            table_id=payload.table_id,
+            version_ts=payload.version_ts,
+            status="accepted",
+        )
+
+    @application.post(
+        "/v1/table/snippet",
+        response_model=TableSnippetUpsertResponse,
+        summary="Persist or update action results, such as table snippets.",
+    )
+    async def upsert_table_snippet(
+        payload: TableSnippetUpsertRequest,
+    ) -> TableSnippetUpsertResponse:
+        request_copy = payload.model_copy(deep=True)
+        try:
+            return await asyncio.to_thread(upsert_action_result, request_copy)
+        except Exception as exc:
+            logger.error(
+                "Failed to upsert table snippet: table_id=%s version_ts=%s action_type=%s",
+                payload.table_id,
+                payload.version_ts,
+                payload.action_type,
+                exc_info=True,
+            )
+            raise HTTPException(status_code=500, detail=str(exc)) from exc

    @application.post("/__mock__/import-callback")
    async def mock_import_callback(payload: dict[str, Any]) -> dict[str, str]:
        logger.info("Received import analysis callback: %s", payload)

app/models.py

@@ -1,5 +1,6 @@
from __future__ import annotations

+from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional, Union
@@ -76,8 +77,8 @@ class DataImportAnalysisRequest(BaseModel):
        description="Ordered list of table headers associated with the data.",
    )
    llm_model: str = Field(
-        ...,
-        description="Model identifier. Accepts 'provider:model' format or plain model name.",
+        None,
+        description="Model identifier. Accepts 'provider:model_name' format or custom model alias.",
    )
    temperature: Optional[float] = Field(
        None,
@@ -135,3 +136,186 @@ class DataImportAnalysisJobRequest(BaseModel):
class DataImportAnalysisJobAck(BaseModel):
    import_record_id: str = Field(..., description="Echo of the import record identifier")
    status: str = Field("accepted", description="Processing status acknowledgement.")
class ActionType(str, Enum):
GE_PROFILING = "ge_profiling"
GE_RESULT_DESC = "ge_result_desc"
SNIPPET = "snippet"
SNIPPET_ALIAS = "snippet_alias"
class ActionStatus(str, Enum):
PENDING = "pending"
RUNNING = "running"
SUCCESS = "success"
FAILED = "failed"
PARTIAL = "partial"
class TableProfilingJobRequest(BaseModel):
table_id: str = Field(..., description="Unique identifier for the table to profile.")
version_ts: str = Field(
...,
pattern=r"^\d{14}$",
description="Version timestamp expressed as fourteen digit string (yyyyMMddHHmmss).",
)
callback_url: HttpUrl = Field(
...,
description="Callback endpoint invoked after each pipeline action completes.",
)
llm_model: Optional[str] = Field(
None,
description="Default LLM model spec applied to prompt-based actions when overrides are omitted.",
)
table_schema: Optional[Any] = Field(
None,
description="Schema structure snapshot for the current table version.",
)
table_schema_version_id: Optional[str] = Field(
None,
description="Identifier for the schema snapshot provided in table_schema.",
)
table_link_info: Optional[Dict[str, Any]] = Field(
None,
description=(
"Information describing how to locate the source table for profiling. "
"For example: {'type': 'sql', 'connection_string': 'mysql+pymysql://user:pass@host/db', "
"'table': 'schema.table_name'}."
),
)
table_access_info: Optional[Dict[str, Any]] = Field(
None,
description=(
"Credentials or supplemental parameters required to access the table described in table_link_info. "
"These values can be merged into the connection string using Python format placeholders."
),
)
ge_batch_request: Optional[Dict[str, Any]] = Field(
None,
description="Optional Great Expectations batch request payload used for profiling.",
)
ge_expectation_suite_name: Optional[str] = Field(
None,
description="Expectation suite name used during profiling. Created automatically when absent.",
)
ge_data_context_root: Optional[str] = Field(
None,
description="Custom root directory for the Great Expectations data context. Defaults to project ./gx.",
)
ge_datasource_name: Optional[str] = Field(
None,
description="Datasource name registered inside the GE context when batch_request is not supplied.",
)
ge_data_asset_name: Optional[str] = Field(
None,
description="Data asset reference used when inferring batch request from datasource configuration.",
)
ge_profiler_type: str = Field(
"user_configurable",
description="Profiler implementation identifier. Currently supports 'user_configurable' or 'data_assistant'.",
)
result_desc_model: Optional[str] = Field(
None,
description="LLM model override used for GE result description (action 2).",
)
snippet_model: Optional[str] = Field(
None,
description="LLM model override used for snippet generation (action 3).",
)
snippet_alias_model: Optional[str] = Field(
None,
description="LLM model override used for snippet alias enrichment (action 4).",
)
extra_options: Optional[Dict[str, Any]] = Field(
None,
description="Miscellaneous execution flags applied across pipeline steps.",
)
class TableProfilingJobAck(BaseModel):
table_id: str = Field(..., description="Echo of the table identifier.")
version_ts: str = Field(..., description="Echo of the profiling version timestamp (yyyyMMddHHmmss).")
status: str = Field("accepted", description="Processing acknowledgement status.")
class TableSnippetUpsertRequest(BaseModel):
table_id: int = Field(..., ge=1, description="Unique identifier for the table.")
version_ts: int = Field(
...,
ge=0,
description="Version timestamp aligned with the pipeline (yyyyMMddHHmmss as integer).",
)
action_type: ActionType = Field(..., description="Pipeline action type for this record.")
status: ActionStatus = Field(
ActionStatus.SUCCESS, description="Execution status for the action."
)
callback_url: HttpUrl = Field(..., description="Callback URL associated with the action run.")
table_schema_version_id: int = Field(..., ge=0, description="Identifier for the schema snapshot.")
table_schema: Any = Field(..., description="Schema snapshot payload for the table.")
llm_usage: Optional[Any] = Field(
None,
description="Optional token usage metrics reported by the LLM provider.",
)
ge_profiling_json: Optional[Any] = Field(
None, description="Full GE profiling result payload for the profiling action."
)
ge_profiling_json_size_bytes: Optional[int] = Field(
None, ge=0, description="Size in bytes of the GE profiling result JSON."
)
ge_profiling_summary: Optional[Any] = Field(
None, description="Sanitised GE profiling summary payload."
)
ge_profiling_summary_size_bytes: Optional[int] = Field(
None, ge=0, description="Size in bytes of the GE profiling summary JSON."
)
ge_profiling_total_size_bytes: Optional[int] = Field(
None, ge=0, description="Combined size (bytes) of profiling result + summary."
)
ge_profiling_html_report_url: Optional[str] = Field(
None, description="Optional URL to the generated GE profiling HTML report."
)
ge_result_desc_json: Optional[Any] = Field(
None, description="Result JSON for the GE result description action."
)
ge_result_desc_json_size_bytes: Optional[int] = Field(
None, ge=0, description="Size in bytes of the GE result description JSON."
)
snippet_json: Optional[Any] = Field(
None, description="Snippet generation action result JSON."
)
snippet_json_size_bytes: Optional[int] = Field(
None, ge=0, description="Size in bytes of the snippet result JSON."
)
snippet_alias_json: Optional[Any] = Field(
None, description="Snippet alias expansion result JSON."
)
snippet_alias_json_size_bytes: Optional[int] = Field(
None, ge=0, description="Size in bytes of the snippet alias result JSON."
)
error_code: Optional[str] = Field(None, description="Optional error code when status indicates a failure.")
error_message: Optional[str] = Field(None, description="Optional error message when status indicates a failure.")
started_at: Optional[datetime] = Field(
None, description="Timestamp when the action started executing."
)
finished_at: Optional[datetime] = Field(
None, description="Timestamp when the action finished executing."
)
duration_ms: Optional[int] = Field(
None,
ge=0,
description="Optional execution duration in milliseconds.",
)
result_checksum: Optional[str] = Field(
None,
description="Optional checksum for the result payload (e.g., MD5).",
)
class TableSnippetUpsertResponse(BaseModel):
table_id: int
version_ts: int
action_type: ActionType
status: ActionStatus
updated: bool
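
The fourteen-digit version_ts constraint is easy to trip over; a quick sanity check (assumes pydantic v2, consistent with the model_copy/model_dump calls elsewhere in this diff):

```python
from pydantic import ValidationError

from app.models import TableProfilingJobRequest

try:
    TableProfilingJobRequest(
        table_id="orders",
        version_ts="2025-11-03",  # wrong shape: rejected by pattern ^\d{14}$
        callback_url="http://localhost:9000/cb",
    )
except ValidationError as exc:
    print(exc.error_count(), "validation error(s)")

ok = TableProfilingJobRequest(
    table_id="orders",
    version_ts="20251103002000",
    callback_url="http://localhost:9000/cb",
)
print(ok.version_ts)
```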

app/services/import_analysis.py

@@ -23,6 +23,7 @@ from app.models import (
    LLMRole,
)
from app.settings import DEFAULT_IMPORT_MODEL, get_supported_import_models
+from app.utils.llm_usage import extract_usage

logger = logging.getLogger(__name__)
@@ -42,7 +43,7 @@ def _env_float(name: str, default: float) -> float:
    return default

-IMPORT_CHAT_TIMEOUT_SECONDS = _env_float("IMPORT_CHAT_TIMEOUT_SECONDS", 90.0)
+IMPORT_CHAT_TIMEOUT_SECONDS = _env_float("IMPORT_CHAT_TIMEOUT_SECONDS", 120.0)
SUPPORTED_IMPORT_MODELS = get_supported_import_models()
@@ -298,7 +299,7 @@ def parse_llm_analysis_json(llm_response: LLMResponse) -> Dict[str, Any]:
    try:
        return json.loads(json_payload)
    except json.JSONDecodeError as exc:
-        preview = json_payload[:2000]
+        preview = json_payload[:10000]
        logger.error("Failed to parse JSON from LLM response content: %s", preview, exc_info=True)
        raise ProviderAPICallError("LLM response JSON could not be parsed.") from exc
@@ -375,18 +376,6 @@ async def dispatch_import_analysis_job(
    return result

-# Cross-provider extraction of the usage fields returned by different models
-def extract_usage(resp_json: dict) -> dict:
-    usage = resp_json.get("usage") or resp_json.get("usageMetadata") or {}
-    return {
-        "prompt_tokens": usage.get("prompt_tokens") or usage.get("input_tokens") or usage.get("promptTokenCount"),
-        "completion_tokens": usage.get("completion_tokens") or usage.get("output_tokens") or usage.get("candidatesTokenCount"),
-        "total_tokens": usage.get("total_tokens") or usage.get("totalTokenCount") or (
-            (usage.get("prompt_tokens") or usage.get("input_tokens") or 0)
-            + (usage.get("completion_tokens") or usage.get("output_tokens") or 0)
-        )
-    }

async def notify_import_analysis_callback(
    callback_url: str,
    payload: Dict[str, Any],

app/services/table_profiling.py (new file)

@@ -0,0 +1,855 @@
from __future__ import annotations
import asyncio
import json
import logging
import os
import re
from datetime import date, datetime
from dataclasses import asdict, dataclass, is_dataclass
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
import httpx
import great_expectations as gx
from great_expectations.core.batch import RuntimeBatchRequest
from great_expectations.core.expectation_suite import ExpectationSuite
from great_expectations.data_context import AbstractDataContext
from great_expectations.exceptions import DataContextError, MetricResolutionError
from app.exceptions import ProviderAPICallError
from app.models import TableProfilingJobRequest
from app.services import LLMGateway
from app.settings import DEFAULT_IMPORT_MODEL
from app.services.import_analysis import (
IMPORT_GATEWAY_BASE_URL,
resolve_provider_from_model,
)
from app.utils.llm_usage import extract_usage as extract_llm_usage
logger = logging.getLogger(__name__)
GE_REPORT_RELATIVE_PATH = Path("uncommitted") / "data_docs" / "local_site" / "index.html"
PROMPT_FILENAMES = {
"ge_result_desc": "ge_result_desc_prompt.md",
"snippet_generator": "snippet_generator.md",
"snippet_alias": "snippet_alias_generator.md",
}
DEFAULT_CHAT_TIMEOUT_SECONDS = 180.0
@dataclass
class GEProfilingArtifacts:
profiling_result: Dict[str, Any]
profiling_summary: Dict[str, Any]
docs_path: str
@dataclass
class LLMCallResult:
data: Any
usage: Optional[Dict[str, Any]] = None
class PipelineActionType:
GE_PROFILING = "ge_profiling"
GE_RESULT_DESC = "ge_result_desc"
SNIPPET = "snippet"
SNIPPET_ALIAS = "snippet_alias"
def _project_root() -> Path:
return Path(__file__).resolve().parents[2]
def _prompt_dir() -> Path:
return _project_root() / "prompt"
@lru_cache(maxsize=None)
def _load_prompt_parts(filename: str) -> Tuple[str, str]:
prompt_path = _prompt_dir() / filename
if not prompt_path.exists():
raise FileNotFoundError(f"Prompt template not found: {prompt_path}")
raw = prompt_path.read_text(encoding="utf-8")
splitter = "用户消息User"
if splitter not in raw:
raise ValueError(f"Prompt template '{filename}' missing separator '{splitter}'.")
system_raw, user_raw = raw.split(splitter, maxsplit=1)
system_text = system_raw.replace("系统角色System", "").strip()
user_text = user_raw.strip()
return system_text, user_text
def _render_prompt(template_key: str, replacements: Dict[str, str]) -> Tuple[str, str]:
filename = PROMPT_FILENAMES[template_key]
system_text, user_template = _load_prompt_parts(filename)
rendered_user = user_template
for key, value in replacements.items():
rendered_user = rendered_user.replace(key, value)
return system_text, rendered_user
def _extract_timeout_seconds(options: Optional[Dict[str, Any]]) -> Optional[float]:
if not options:
return None
value = options.get("llm_timeout_seconds")
if value is None:
return None
try:
timeout = float(value)
if timeout <= 0:
raise ValueError
return timeout
except (TypeError, ValueError):
logger.warning(
"Invalid llm_timeout_seconds value in extra_options: %r. Falling back to default.",
value,
)
return DEFAULT_CHAT_TIMEOUT_SECONDS
def _extract_json_payload(content: str) -> str:
fenced = re.search(
r"```(?:json)?\s*([\s\S]+?)```",
content,
flags=re.IGNORECASE,
)
if fenced:
snippet = fenced.group(1).strip()
if snippet:
return snippet
stripped = content.strip()
if not stripped:
raise ValueError("Empty LLM content.")
decoder = json.JSONDecoder()
for idx, char in enumerate(stripped):
if char not in {"{", "["}:
continue
try:
_, end = decoder.raw_decode(stripped[idx:])
except json.JSONDecodeError:
continue
candidate = stripped[idx : idx + end].strip()
if candidate:
return candidate
return stripped
def _parse_completion_payload(response_payload: Dict[str, Any]) -> Any:
choices = response_payload.get("choices") or []
if not choices:
raise ProviderAPICallError("LLM response did not contain choices to parse.")
message = choices[0].get("message") or {}
content = message.get("content") or ""
if not content.strip():
raise ProviderAPICallError("LLM response content is empty.")
json_payload = _extract_json_payload(content)
try:
return json.loads(json_payload)
except json.JSONDecodeError as exc:
preview = json_payload[:800]
logger.error("Failed to parse JSON from LLM response: %s", preview, exc_info=True)
raise ProviderAPICallError("LLM response JSON parsing failed.") from exc
async def _post_callback(callback_url: str, payload: Dict[str, Any], client: httpx.AsyncClient) -> None:
safe_payload = _normalize_for_json(payload)
try:
logger.info(
"Posting pipeline action callback to %s: %s",
callback_url,
json.dumps(safe_payload, ensure_ascii=False),
)
response = await client.post(callback_url, json=safe_payload)
response.raise_for_status()
except httpx.HTTPError as exc:
logger.error("Callback delivery to %s failed: %s", callback_url, exc, exc_info=True)
def _sanitize_value_set(value: Any, max_values: int) -> Tuple[Any, Optional[Dict[str, int]]]:
if not isinstance(value, list):
return value, None
original_len = len(value)
if original_len <= max_values:
return value, None
trimmed = value[:max_values]
return trimmed, {"original_length": original_len, "retained": max_values}
def _sanitize_expectation_suite(suite: ExpectationSuite, max_value_set_values: int = 100) -> Dict[str, Any]:
suite_dict = suite.to_json_dict()
remarks: List[Dict[str, Any]] = []
for expectation in suite_dict.get("expectations", []):
kwargs = expectation.get("kwargs", {})
if "value_set" in kwargs:
sanitized_value, note = _sanitize_value_set(kwargs["value_set"], max_value_set_values)
kwargs["value_set"] = sanitized_value
if note:
expectation.setdefault("meta", {})
expectation["meta"]["value_set_truncated"] = note
remarks.append(
{
"column": kwargs.get("column"),
"expectation": expectation.get("expectation_type"),
"note": note,
}
)
if remarks:
suite_dict.setdefault("meta", {})
suite_dict["meta"]["value_set_truncations"] = remarks
return suite_dict
def _summarize_expectation_suite(suite_dict: Dict[str, Any]) -> Dict[str, Any]:
column_map: Dict[str, Dict[str, Any]] = {}
table_expectations: List[Dict[str, Any]] = []
for expectation in suite_dict.get("expectations", []):
expectation_type = expectation.get("expectation_type")
kwargs = expectation.get("kwargs", {})
column = kwargs.get("column")
summary_entry: Dict[str, Any] = {"expectation": expectation_type}
if "value_set" in kwargs and isinstance(kwargs["value_set"], list):
summary_entry["value_set_size"] = len(kwargs["value_set"])
summary_entry["value_set_preview"] = kwargs["value_set"][:5]
if column:
column_entry = column_map.setdefault(
column,
{"name": column, "expectations": []},
)
column_entry["expectations"].append(summary_entry)
else:
table_expectations.append(summary_entry)
summary = {
"column_profiles": list(column_map.values()),
"table_level_expectations": table_expectations,
"total_expectations": len(suite_dict.get("expectations", [])),
}
return summary
def _sanitize_identifier(raw: Optional[str], fallback: str) -> str:
if not raw:
return fallback
candidate = re.sub(r"[^0-9A-Za-z_]+", "_", raw).strip("_")
return candidate or fallback
def _format_connection_string(template: str, access_info: Dict[str, Any]) -> str:
if not access_info:
return template
try:
return template.format_map({k: v for k, v in access_info.items()})
except KeyError as exc:
missing = exc.args[0]
raise ValueError(f"table_access_info missing key '{missing}' required by connection_string.") from exc
def _ensure_sql_runtime_datasource(
context: AbstractDataContext,
datasource_name: str,
connection_string: str,
) -> None:
try:
datasource = context.get_datasource(datasource_name)
except (DataContextError, ValueError) as exc:
message = str(exc)
if "Could not find a datasource" in message or "Unable to load datasource" in message:
datasource = None
else: # pragma: no cover - defensive
raise RuntimeError(f"Failed to inspect datasource '{datasource_name}'.") from exc
except Exception as exc: # pragma: no cover - defensive
raise RuntimeError(f"Failed to inspect datasource '{datasource_name}'.") from exc
if datasource is not None:
execution_engine = getattr(datasource, "execution_engine", None)
current_conn = getattr(execution_engine, "connection_string", None)
if current_conn and current_conn != connection_string:
logger.info(
"Existing datasource %s uses different connection string; creating dedicated runtime datasource.",
datasource_name,
)
try:
context.delete_datasource(datasource_name)
except Exception as exc: # pragma: no cover - defensive
logger.warning(
"Failed to delete datasource %s before recreation: %s",
datasource_name,
exc,
)
else:
datasource = None
if datasource is not None:
return
runtime_datasource_config = {
"name": datasource_name,
"class_name": "Datasource",
"execution_engine": {
"class_name": "SqlAlchemyExecutionEngine",
"connection_string": connection_string,
},
"data_connectors": {
"runtime_connector": {
"class_name": "RuntimeDataConnector",
"batch_identifiers": ["default_identifier_name"],
}
},
}
try:
context.add_datasource(**runtime_datasource_config)
except Exception as exc: # pragma: no cover - defensive
raise RuntimeError(f"Failed to create runtime datasource '{datasource_name}'.") from exc
def _build_sql_runtime_batch_request(
context: AbstractDataContext,
request: TableProfilingJobRequest,
) -> RuntimeBatchRequest:
link_info = request.table_link_info or {}
access_info = request.table_access_info or {}
connection_template = link_info.get("connection_string")
if not connection_template:
raise ValueError("table_link_info.connection_string is required when using table_link_info.")
connection_string = _format_connection_string(connection_template, access_info)
source_type = (link_info.get("type") or "sql").lower()
if source_type != "sql":
raise ValueError(f"Unsupported table_link_info.type='{source_type}'. Only 'sql' is supported.")
query = link_info.get("query")
table_name = link_info.get("table") or link_info.get("table_name")
schema_name = link_info.get("schema")
if not query and not table_name:
raise ValueError("Either table_link_info.query or table_link_info.table must be provided.")
if not query:
if not table_name:
raise ValueError("table_link_info.table must be provided when query is omitted.")
identifier = re.compile(r"^[A-Za-z_][A-Za-z0-9_$]*$")
def _quote(name: str) -> str:
if identifier.match(name):
return name
return f"`{name.replace('`', '``')}`"
if schema_name:
schema_part = schema_name if "." not in schema_name else schema_name.split(".")[-1]
table_part = table_name if "." not in table_name else table_name.split(".")[-1]
qualified_table = f"{_quote(schema_part)}.{_quote(table_part)}"
else:
qualified_table = _quote(table_name)
query = f"SELECT * FROM {qualified_table}"
limit = link_info.get("limit")
if isinstance(limit, int) and limit > 0:
query = f"{query} LIMIT {limit}"
datasource_name = request.ge_datasource_name or _sanitize_identifier(
f"{request.table_id}_runtime_ds", "runtime_ds"
)
data_asset_name = request.ge_data_asset_name or _sanitize_identifier(
table_name or "runtime_query", "runtime_query"
)
_ensure_sql_runtime_datasource(context, datasource_name, connection_string)
batch_identifiers = {
"default_identifier_name": f"{request.table_id}:{request.version_ts}",
}
return RuntimeBatchRequest(
datasource_name=datasource_name,
data_connector_name="runtime_connector",
data_asset_name=data_asset_name,
runtime_parameters={"query": query},
batch_identifiers=batch_identifiers,
)
def _run_onboarding_assistant(
context: AbstractDataContext,
batch_request: Any,
suite_name: str,
) -> Tuple[ExpectationSuite, Any]:
assistant = context.assistants.onboarding
assistant_result = assistant.run(batch_request=batch_request)
suite = assistant_result.get_expectation_suite(expectation_suite_name=suite_name)
context.save_expectation_suite(suite, expectation_suite_name=suite_name)
validation_getter = getattr(assistant_result, "get_validation_result", None)
if callable(validation_getter):
validation_result = validation_getter()
else:
validation_result = getattr(assistant_result, "validation_result", None)
if validation_result is None:
# Fallback: rerun validation using the freshly generated expectation suite.
validator = context.get_validator(
batch_request=batch_request,
expectation_suite_name=suite_name,
)
validation_result = validator.validate()
return suite, validation_result
def _resolve_context(request: TableProfilingJobRequest) -> AbstractDataContext:
context_kwargs: Dict[str, Any] = {}
if request.ge_data_context_root:
context_kwargs["project_root_dir"] = request.ge_data_context_root
elif os.environ.get("GE_DATA_CONTEXT_ROOT"):
context_kwargs["project_root_dir"] = os.environ["GE_DATA_CONTEXT_ROOT"]
else:
context_kwargs["project_root_dir"] = str(_project_root())
return gx.get_context(**context_kwargs)
def _build_batch_request(
context: AbstractDataContext,
request: TableProfilingJobRequest,
) -> Any:
if request.ge_batch_request:
from great_expectations.core.batch import BatchRequest
return BatchRequest(**request.ge_batch_request)
if request.table_link_info:
return _build_sql_runtime_batch_request(context, request)
if not request.ge_datasource_name or not request.ge_data_asset_name:
raise ValueError(
"ge_batch_request or (ge_datasource_name and ge_data_asset_name) must be provided."
)
datasource = context.get_datasource(request.ge_datasource_name)
data_asset = datasource.get_asset(request.ge_data_asset_name)
return data_asset.build_batch_request()
async def _run_ge_profiling(request: TableProfilingJobRequest) -> GEProfilingArtifacts:
def _execute() -> GEProfilingArtifacts:
context = _resolve_context(request)
suite_name = (
request.ge_expectation_suite_name
or f"{request.table_id}_profiling"
)
batch_request = _build_batch_request(context, request)
try:
context.get_expectation_suite(suite_name)
except DataContextError:
context.add_expectation_suite(suite_name)
validator = context.get_validator(
batch_request=batch_request,
expectation_suite_name=suite_name,
)
profiler_type = (request.ge_profiler_type or "user_configurable").lower()
if profiler_type == "data_assistant":
suite, validation_result = _run_onboarding_assistant(
context,
batch_request,
suite_name,
)
else:
try:
from great_expectations.profile.user_configurable_profiler import (
UserConfigurableProfiler,
)
except ImportError as err: # pragma: no cover - dependency guard
raise RuntimeError(
"UserConfigurableProfiler is unavailable; install great_expectations profiling extra or switch profiler."
) from err
profiler = UserConfigurableProfiler(profile_dataset=validator)
try:
suite = profiler.build_suite()
context.save_expectation_suite(suite, expectation_suite_name=suite_name)
validator.expectation_suite = suite
validation_result = validator.validate()
except MetricResolutionError as exc:
logger.warning(
"UserConfigurableProfiler failed (%s); falling back to data assistant profiling.",
exc,
)
suite, validation_result = _run_onboarding_assistant(
context,
batch_request,
suite_name,
)
sanitized_suite = _sanitize_expectation_suite(suite)
summary = _summarize_expectation_suite(sanitized_suite)
validation_dict = validation_result.to_json_dict()
context.build_data_docs()
docs_path = Path(context.root_directory) / GE_REPORT_RELATIVE_PATH
profiling_result = {
"expectation_suite": sanitized_suite,
"validation_result": validation_dict,
"batch_request": getattr(batch_request, "to_json_dict", lambda: None)() or getattr(batch_request, "dict", lambda: None)(),
}
return GEProfilingArtifacts(
profiling_result=profiling_result,
profiling_summary=summary,
docs_path=str(docs_path),
)
return await asyncio.to_thread(_execute)
async def _call_chat_completions(
*,
model_spec: str,
system_prompt: str,
user_prompt: str,
client: httpx.AsyncClient,
temperature: float = 0.2,
timeout_seconds: Optional[float] = None,
) -> Any:
provider, model_name = resolve_provider_from_model(model_spec)
payload = {
"provider": provider.value,
"model": model_name,
"messages": [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
],
"temperature": temperature,
}
payload_size_bytes = len(json.dumps(payload, ensure_ascii=False).encode("utf-8"))
url = f"{IMPORT_GATEWAY_BASE_URL.rstrip('/')}/v1/chat/completions"
try:
        # Log the full request (URL, model, payload size, body) before sending.
logger.info(
"Calling chat completions API %s with model %s and size %s and payload %s",
url,
model_name,
payload_size_bytes,
payload,
)
response = await client.post(url, json=payload, timeout=timeout_seconds)
response.raise_for_status()
except httpx.HTTPError as exc:
error_name = exc.__class__.__name__
detail = str(exc).strip()
if detail:
message = f"Chat completions request failed ({error_name}): {detail}"
else:
message = f"Chat completions request failed ({error_name})."
raise ProviderAPICallError(message) from exc
try:
response_payload = response.json()
except ValueError as exc:
raise ProviderAPICallError("Chat completions response was not valid JSON.") from exc
parsed_payload = _parse_completion_payload(response_payload)
usage_info = extract_llm_usage(response_payload)
return LLMCallResult(data=parsed_payload, usage=usage_info)
def _normalize_for_json(value: Any) -> Any:
if value is None or isinstance(value, (str, int, float, bool)):
return value
if isinstance(value, (datetime, date)):
return str(value)
if hasattr(value, "model_dump"):
try:
return value.model_dump()
except Exception: # pragma: no cover - defensive
pass
if is_dataclass(value):
return asdict(value)
if isinstance(value, dict):
return {k: _normalize_for_json(v) for k, v in value.items()}
if isinstance(value, (list, tuple, set)):
return [_normalize_for_json(v) for v in value]
if hasattr(value, "to_json_dict"):
try:
return value.to_json_dict()
except Exception: # pragma: no cover - defensive
pass
if hasattr(value, "__dict__"):
return _normalize_for_json(value.__dict__)
return repr(value)
def _json_dumps(data: Any) -> str:
normalised = _normalize_for_json(data)
return json.dumps(normalised, ensure_ascii=False, indent=2)
def _preview_for_log(data: Any) -> str:
try:
serialised = _json_dumps(data)
except Exception:
serialised = repr(data)
return serialised
def _profiling_request_for_log(request: TableProfilingJobRequest) -> Dict[str, Any]:
payload = request.model_dump()
access_info = payload.get("table_access_info")
if isinstance(access_info, dict):
payload["table_access_info"] = {key: "***" for key in access_info.keys()}
return payload
async def _execute_result_desc(
profiling_json: Dict[str, Any],
_request: TableProfilingJobRequest,
llm_model: str,
client: httpx.AsyncClient,
timeout_seconds: Optional[float],
) -> LLMCallResult:
system_prompt, user_prompt = _render_prompt(
"ge_result_desc",
{"{{GE_RESULT_JSON}}": _json_dumps(profiling_json)},
)
llm_output = await _call_chat_completions(
model_spec=llm_model,
system_prompt=system_prompt,
user_prompt=user_prompt,
client=client,
timeout_seconds=timeout_seconds,
)
if not isinstance(llm_output.data, dict):
raise ProviderAPICallError("GE result description payload must be a JSON object.")
return llm_output
async def _execute_snippet_generation(
table_desc_json: Dict[str, Any],
_request: TableProfilingJobRequest,
llm_model: str,
client: httpx.AsyncClient,
timeout_seconds: Optional[float],
) -> LLMCallResult:
system_prompt, user_prompt = _render_prompt(
"snippet_generator",
{"{{TABLE_PROFILE_JSON}}": _json_dumps(table_desc_json)},
)
llm_output = await _call_chat_completions(
model_spec=llm_model,
system_prompt=system_prompt,
user_prompt=user_prompt,
client=client,
timeout_seconds=timeout_seconds,
)
if not isinstance(llm_output.data, list):
raise ProviderAPICallError("Snippet generator must return a JSON array.")
return llm_output
async def _execute_snippet_alias(
snippets_json: List[Dict[str, Any]],
_request: TableProfilingJobRequest,
llm_model: str,
client: httpx.AsyncClient,
timeout_seconds: Optional[float],
) -> LLMCallResult:
system_prompt, user_prompt = _render_prompt(
"snippet_alias",
{"{{SNIPPET_ARRAY}}": _json_dumps(snippets_json)},
)
llm_output = await _call_chat_completions(
model_spec=llm_model,
system_prompt=system_prompt,
user_prompt=user_prompt,
client=client,
timeout_seconds=timeout_seconds,
)
if not isinstance(llm_output.data, list):
raise ProviderAPICallError("Snippet alias generator must return a JSON array.")
return llm_output
async def _run_action_with_callback(
*,
action_type: str,
runner,
callback_base: Dict[str, Any],
client: httpx.AsyncClient,
callback_url: str,
input_payload: Any = None,
model_spec: Optional[str] = None,
) -> Any:
if input_payload is not None:
logger.info(
"Pipeline action %s input: %s",
action_type,
_preview_for_log(input_payload),
)
try:
result = await runner()
except Exception as exc:
failure_payload = dict(callback_base)
failure_payload.update(
{
"status": "failed",
"action_type": action_type,
"error": str(exc),
}
)
if model_spec is not None:
failure_payload["model"] = model_spec
await _post_callback(callback_url, failure_payload, client)
raise
usage_info: Optional[Dict[str, Any]] = None
result_payload = result
if isinstance(result, LLMCallResult):
usage_info = result.usage
result_payload = result.data
success_payload = dict(callback_base)
success_payload.update(
{
"status": "success",
"action_type": action_type,
}
)
if model_spec is not None:
success_payload["model"] = model_spec
logger.info(
"Pipeline action %s output: %s",
action_type,
_preview_for_log(result_payload),
)
if action_type == PipelineActionType.GE_PROFILING:
artifacts: GEProfilingArtifacts = result_payload
success_payload["ge_profiling_json"] = artifacts.profiling_result
success_payload["ge_profiling_summary"] = artifacts.profiling_summary
success_payload["ge_report_path"] = artifacts.docs_path
elif action_type == PipelineActionType.GE_RESULT_DESC:
success_payload["ge_result_desc_json"] = result_payload
elif action_type == PipelineActionType.SNIPPET:
success_payload["snippet_json"] = result_payload
elif action_type == PipelineActionType.SNIPPET_ALIAS:
success_payload["snippet_alias_json"] = result_payload
if usage_info:
success_payload["llm_usage"] = usage_info
await _post_callback(callback_url, success_payload, client)
return result_payload
async def process_table_profiling_job(
request: TableProfilingJobRequest,
_gateway: LLMGateway,
client: httpx.AsyncClient,
) -> None:
"""Sequentially execute the four-step profiling pipeline and emit callbacks per action."""
timeout_seconds = _extract_timeout_seconds(request.extra_options)
if timeout_seconds is None:
timeout_seconds = DEFAULT_CHAT_TIMEOUT_SECONDS
base_payload = {
"table_id": request.table_id,
"version_ts": request.version_ts,
"callback_url": str(request.callback_url),
"table_schema": request.table_schema,
"table_schema_version_id": request.table_schema_version_id,
"llm_model": request.llm_model,
"llm_timeout_seconds": timeout_seconds,
}
logging_request_payload = _profiling_request_for_log(request)
try:
artifacts: GEProfilingArtifacts = await _run_action_with_callback(
action_type=PipelineActionType.GE_PROFILING,
runner=lambda: _run_ge_profiling(request),
callback_base=base_payload,
client=client,
callback_url=str(request.callback_url),
input_payload=logging_request_payload,
model_spec=request.llm_model,
)
table_desc_json: Dict[str, Any] = await _run_action_with_callback(
action_type=PipelineActionType.GE_RESULT_DESC,
runner=lambda: _execute_result_desc(
artifacts.profiling_result,
request,
request.llm_model,
client,
timeout_seconds,
),
callback_base=base_payload,
client=client,
callback_url=str(request.callback_url),
input_payload=artifacts.profiling_result,
model_spec=request.llm_model,
)
snippet_json: List[Dict[str, Any]] = await _run_action_with_callback(
action_type=PipelineActionType.SNIPPET,
runner=lambda: _execute_snippet_generation(
table_desc_json,
request,
request.llm_model,
client,
timeout_seconds,
),
callback_base=base_payload,
client=client,
callback_url=str(request.callback_url),
input_payload=table_desc_json,
model_spec=request.llm_model,
)
await _run_action_with_callback(
action_type=PipelineActionType.SNIPPET_ALIAS,
runner=lambda: _execute_snippet_alias(
snippet_json,
request,
request.llm_model,
client,
timeout_seconds,
),
callback_base=base_payload,
client=client,
callback_url=str(request.callback_url),
input_payload=snippet_json,
model_spec=request.llm_model,
)
except Exception: # pragma: no cover - defensive catch
logger.exception(
"Table profiling pipeline failed for table_id=%s version_ts=%s",
request.table_id,
request.version_ts,
)
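
For reference, a success callback emitted per action has roughly this shape, reconstructed from _run_action_with_callback and the base_payload built above; all concrete values are illustrative:

```python
callback_payload = {
    # copied from base_payload in process_table_profiling_job
    "table_id": "orders",
    "version_ts": "20251103002000",
    "callback_url": "http://localhost:9000/cb",
    "table_schema": None,
    "table_schema_version_id": None,
    "llm_model": "deepseek:deepseek-chat",
    "llm_timeout_seconds": 180.0,
    # added per action by _run_action_with_callback
    "status": "success",       # failures carry status="failed" plus an "error" string
    "action_type": "snippet",  # ge_profiling | ge_result_desc | snippet | snippet_alias
    "model": "deepseek:deepseek-chat",
    # the result key depends on the action: ge_profiling_json / ge_profiling_summary /
    # ge_report_path, ge_result_desc_json, snippet_json, or snippet_alias_json
    "snippet_json": [{"id": "snpt_demo", "title": "..."}],
    "llm_usage": {"prompt_tokens": 1200, "completion_tokens": 300, "total_tokens": 1500},
}
```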

app/services/table_snippet.py (new file)

@@ -0,0 +1,206 @@
from __future__ import annotations
import json
import logging
from typing import Any, Dict, Tuple
from sqlalchemy import text
from sqlalchemy.engine import Engine
from sqlalchemy.exc import SQLAlchemyError
from app.db import get_engine
from app.models import (
ActionType,
TableSnippetUpsertRequest,
TableSnippetUpsertResponse,
)
logger = logging.getLogger(__name__)
def _serialize_json(value: Any) -> Tuple[str | None, int | None]:
logger.debug("Serializing JSON payload: %s", value)
if value is None:
return None, None
if isinstance(value, str):
encoded = value.encode("utf-8")
return value, len(encoded)
serialized = json.dumps(value, ensure_ascii=False)
encoded = serialized.encode("utf-8")
return serialized, len(encoded)
def _prepare_table_schema(value: Any) -> str:
logger.debug("Preparing table_schema payload.")
if isinstance(value, str):
return value
return json.dumps(value, ensure_ascii=False)
def _collect_common_columns(request: TableSnippetUpsertRequest) -> Dict[str, Any]:
logger.debug(
"Collecting common columns for table_id=%s version_ts=%s action_type=%s",
request.table_id,
request.version_ts,
request.action_type,
)
payload: Dict[str, Any] = {
"table_id": request.table_id,
"version_ts": request.version_ts,
"action_type": request.action_type.value,
"status": request.status.value,
"callback_url": str(request.callback_url),
"table_schema_version_id": request.table_schema_version_id,
"table_schema": _prepare_table_schema(request.table_schema),
}
payload.update(
{
"ge_profiling_json": None,
"ge_profiling_json_size_bytes": None,
"ge_profiling_summary": None,
"ge_profiling_summary_size_bytes": None,
"ge_profiling_total_size_bytes": None,
"ge_profiling_html_report_url": None,
"ge_result_desc_json": None,
"ge_result_desc_json_size_bytes": None,
"snippet_json": None,
"snippet_json_size_bytes": None,
"snippet_alias_json": None,
"snippet_alias_json_size_bytes": None,
}
)
if request.llm_usage is not None:
llm_usage_json, _ = _serialize_json(request.llm_usage)
if llm_usage_json is not None:
payload["llm_usage"] = llm_usage_json
if request.error_code is not None:
logger.debug("Adding error_code: %s", request.error_code)
payload["error_code"] = request.error_code
if request.error_message is not None:
logger.debug("Adding error_message: %s", request.error_message)
payload["error_message"] = request.error_message
if request.started_at is not None:
payload["started_at"] = request.started_at
if request.finished_at is not None:
payload["finished_at"] = request.finished_at
if request.duration_ms is not None:
payload["duration_ms"] = request.duration_ms
if request.result_checksum is not None:
payload["result_checksum"] = request.result_checksum
logger.debug("Collected common payload: %s", payload)
return payload
def _apply_action_payload(
request: TableSnippetUpsertRequest,
payload: Dict[str, Any],
) -> None:
logger.debug("Applying action-specific payload for action_type=%s", request.action_type)
if request.action_type == ActionType.GE_PROFILING:
full_json, full_size = _serialize_json(request.ge_profiling_json)
summary_json, summary_size = _serialize_json(request.ge_profiling_summary)
if full_json is not None:
payload["ge_profiling_json"] = full_json
payload["ge_profiling_json_size_bytes"] = full_size
if summary_json is not None:
payload["ge_profiling_summary"] = summary_json
payload["ge_profiling_summary_size_bytes"] = summary_size
if request.ge_profiling_total_size_bytes is not None:
payload["ge_profiling_total_size_bytes"] = request.ge_profiling_total_size_bytes
elif full_size is not None or summary_size is not None:
payload["ge_profiling_total_size_bytes"] = (full_size or 0) + (summary_size or 0)
if request.ge_profiling_html_report_url:
payload["ge_profiling_html_report_url"] = request.ge_profiling_html_report_url
elif request.action_type == ActionType.GE_RESULT_DESC:
full_json, full_size = _serialize_json(request.ge_result_desc_json)
if full_json is not None:
payload["ge_result_desc_json"] = full_json
payload["ge_result_desc_json_size_bytes"] = full_size
elif request.action_type == ActionType.SNIPPET:
full_json, full_size = _serialize_json(request.snippet_json)
if full_json is not None:
payload["snippet_json"] = full_json
payload["snippet_json_size_bytes"] = full_size
elif request.action_type == ActionType.SNIPPET_ALIAS:
full_json, full_size = _serialize_json(request.snippet_alias_json)
if full_json is not None:
payload["snippet_alias_json"] = full_json
payload["snippet_alias_json_size_bytes"] = full_size
else:
logger.error("Unsupported action type encountered: %s", request.action_type)
raise ValueError(f"Unsupported action type '{request.action_type}'.")
logger.debug("Payload after applying action-specific data: %s", payload)
def _build_insert_statement(columns: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]:
logger.debug("Building insert statement for columns: %s", list(columns.keys()))
column_names = list(columns.keys())
placeholders = [f":{name}" for name in column_names]
update_assignments = [
f"{name}=VALUES({name})"
for name in column_names
if name not in {"table_id", "version_ts", "action_type"}
]
update_assignments.append("updated_at=CURRENT_TIMESTAMP")
sql = (
"INSERT INTO action_results ({cols}) VALUES ({vals}) "
"ON DUPLICATE KEY UPDATE {updates}"
).format(
cols=", ".join(column_names),
vals=", ".join(placeholders),
updates=", ".join(update_assignments),
)
logger.debug("Generated SQL: %s", sql)
return sql, columns
def _execute_upsert(engine: Engine, sql: str, params: Dict[str, Any]) -> int:
logger.info("Executing upsert for table_id=%s version_ts=%s action_type=%s", params.get("table_id"), params.get("version_ts"), params.get("action_type"))
with engine.begin() as conn:
result = conn.execute(text(sql), params)
logger.info("Rows affected: %s", result.rowcount)
return result.rowcount
def upsert_action_result(request: TableSnippetUpsertRequest) -> TableSnippetUpsertResponse:
logger.info(
"Received upsert request: table_id=%s version_ts=%s action_type=%s status=%s",
request.table_id,
request.version_ts,
request.action_type,
request.status,
)
logger.debug("Request payload: %s", request.model_dump())
columns = _collect_common_columns(request)
_apply_action_payload(request, columns)
sql, params = _build_insert_statement(columns)
logger.debug("Final SQL params: %s", params)
engine = get_engine()
try:
rowcount = _execute_upsert(engine, sql, params)
except SQLAlchemyError as exc:
logger.exception(
"Failed to upsert action result: table_id=%s version_ts=%s action_type=%s",
request.table_id,
request.version_ts,
request.action_type,
)
raise RuntimeError(f"Database operation failed: {exc}") from exc
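    # MySQL reports rowcount 1 for a fresh insert and 2 when ON DUPLICATE KEY
    # UPDATE rewrote an existing row, so rowcount > 1 signals an update.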
updated = rowcount > 1
return TableSnippetUpsertResponse(
table_id=request.table_id,
version_ts=request.version_ts,
action_type=request.action_type,
status=request.status,
updated=updated,
)
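
A hedged end-to-end sketch of the upsert path. It assumes a reachable DATABASE_URL and an action_results table whose unique key spans (table_id, version_ts, action_type), which the ON DUPLICATE KEY UPDATE statement above implies but whose DDL is not part of this diff:

```python
from app.models import ActionStatus, ActionType, TableSnippetUpsertRequest
from app.services.table_snippet import upsert_action_result

response = upsert_action_result(
    TableSnippetUpsertRequest(
        table_id=1,
        version_ts=20251103002000,
        action_type=ActionType.SNIPPET,
        status=ActionStatus.SUCCESS,
        callback_url="http://localhost:9000/cb",
        table_schema_version_id=0,
        table_schema={"columns": ["order_id", "amount", "dt"]},
        snippet_json=[{"id": "snpt_demo", "title": "demo"}],
    )
)
print(response.updated)  # False on first insert, True once the row already existed
```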

app/utils/llm_usage.py (new file, 116 lines)

@@ -0,0 +1,116 @@
from __future__ import annotations
from typing import Any, Dict, Iterable, Optional
PROMPT_TOKEN_KEYS: tuple[str, ...] = ("prompt_tokens", "input_tokens", "promptTokenCount")
COMPLETION_TOKEN_KEYS: tuple[str, ...] = (
"completion_tokens",
"output_tokens",
"candidatesTokenCount",
)
TOTAL_TOKEN_KEYS: tuple[str, ...] = ("total_tokens", "totalTokenCount")
USAGE_CONTAINER_KEYS: tuple[str, ...] = ("usage", "usageMetadata", "usage_metadata")
def _normalize_usage_value(value: Any) -> Any:
if isinstance(value, (int, float)):
return int(value)
if isinstance(value, str):
stripped = value.strip()
if not stripped:
return None
try:
numeric = float(stripped)
except ValueError:
return None
return int(numeric)
if isinstance(value, dict):
normalized: Dict[str, Any] = {}
for key, nested_value in value.items():
normalized_value = _normalize_usage_value(nested_value)
if normalized_value is not None:
normalized[key] = normalized_value
return normalized or None
if isinstance(value, (list, tuple, set)):
normalized_list = [
item for item in (_normalize_usage_value(element) for element in value) if item is not None
]
return normalized_list or None
return None
def _first_numeric(payload: Dict[str, Any], keys: Iterable[str]) -> Optional[int]:
for key in keys:
value = payload.get(key)
if isinstance(value, (int, float)):
return int(value)
return None
def _canonicalize_counts(payload: Dict[str, Any]) -> None:
prompt = _first_numeric(payload, PROMPT_TOKEN_KEYS)
completion = _first_numeric(payload, COMPLETION_TOKEN_KEYS)
total = _first_numeric(payload, TOTAL_TOKEN_KEYS)
if prompt is not None:
payload["prompt_tokens"] = prompt
else:
payload.pop("prompt_tokens", None)
if completion is not None:
payload["completion_tokens"] = completion
else:
payload.pop("completion_tokens", None)
if total is not None:
payload["total_tokens"] = total
elif prompt is not None and completion is not None:
payload["total_tokens"] = prompt + completion
else:
payload.pop("total_tokens", None)
for alias in PROMPT_TOKEN_KEYS[1:]:
payload.pop(alias, None)
for alias in COMPLETION_TOKEN_KEYS[1:]:
payload.pop(alias, None)
for alias in TOTAL_TOKEN_KEYS[1:]:
payload.pop(alias, None)
def _extract_usage_container(candidate: Any) -> Optional[Dict[str, Any]]:
if not isinstance(candidate, dict):
return None
for key in USAGE_CONTAINER_KEYS:
value = candidate.get(key)
if isinstance(value, dict):
return value
return None
def extract_usage(payload: Dict[str, Any]) -> Optional[Dict[str, Any]]:
"""Unified helper to parse token usage metadata from diverse provider responses."""
if not isinstance(payload, dict):
return None
usage_candidate = _extract_usage_container(payload)
if usage_candidate is None:
raw_section = payload.get("raw")
usage_candidate = _extract_usage_container(raw_section)
if usage_candidate is None:
return None
normalized = _normalize_usage_value(usage_candidate)
if not isinstance(normalized, dict):
return None
_canonicalize_counts(normalized)
return normalized or None
__all__ = ["extract_usage"]
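
A quick check of the canonicalization with OpenAI-style and Gemini-style payloads (field names per those providers' public response formats):

```python
from app.utils.llm_usage import extract_usage

openai_like = {"usage": {"prompt_tokens": 10, "completion_tokens": 5}}
gemini_like = {"usageMetadata": {"promptTokenCount": 10, "candidatesTokenCount": 5}}

# Both normalize to the same canonical keys; total_tokens is derived when absent.
print(extract_usage(openai_like))  # {'prompt_tokens': 10, 'completion_tokens': 5, 'total_tokens': 15}
print(extract_usage(gemini_like))  # {'prompt_tokens': 10, 'completion_tokens': 5, 'total_tokens': 15}
```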

docker-compose.yml (new file, 13 lines)

@@ -0,0 +1,13 @@
services:
app:
build: .
ports:
- "8060:8000"
volumes:
- .:/app
environment:
- PYTHONUNBUFFERED=1
    # Development mode: enable --reload
    command: uvicorn app.main:app --reload --host 0.0.0.0 --port 8000
    # Production mode: comment out the command above and uncomment the line below
    # command: uvicorn app.main:app --host 0.0.0.0 --port 8000 --workers 4

File diff suppressed because it is too large.

(modified file, path not shown)

@@ -121,7 +121,7 @@ def clean_value(value: Any) -> Any:
    if isinstance(value, (np.generic,)):
        return value.item()
    if isinstance(value, pd.Timestamp):
-        return value.isoformat()
+        return str(value)
    if pd.isna(value):
        return None
    return value

logging.yaml (new file, 30 lines)

@@ -0,0 +1,30 @@
version: 1
formatters:
standard:
format: "%(asctime)s %(levelname)s %(name)s:%(lineno)d %(message)s"
handlers:
console:
class: logging.StreamHandler
level: INFO
formatter: standard
stream: ext://sys.stdout
file:
class: logging.handlers.RotatingFileHandler
level: INFO
formatter: standard
filename: logs/app.log
maxBytes: 10485760 # 10 MB
backupCount: 5
encoding: utf-8
loggers:
app:
level: INFO
handlers:
- console
- file
propagate: no
root:
level: INFO
handlers:
- console
- file
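
For context, this file is consumed by _configure_logging in app/main.py above; a standalone equivalent (assuming logging.yaml sits in the working directory) is:

```python
import logging.config
import os

import yaml

with open(os.getenv("LOGGING_CONFIG", "logging.yaml"), encoding="utf-8") as fh:
    config = yaml.safe_load(fh)

os.makedirs("logs", exist_ok=True)  # what _ensure_log_directories automates
logging.config.dictConfig(config)
logging.getLogger("app").info("configured")  # hits both console and logs/app.log
```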

prompt/ge_result_desc_prompt.md (new file)

@@ -0,0 +1,47 @@
系统角色System
You are a "table profile extractor". The input is a Great Expectations profiling/validation result JSON,
which may contain column-level expectations (expect_*), statistics, sample values, and type inference; it may also carry table-level/batch metadata.
Normalize it into a single program-consumable "table profile" JSON; for uncertain items, give a confidence score and rationale.
Do not invent columns, time ranges, or numeric values that do not exist.
用户消息User
[Input: GE result JSON]
{{GE_RESULT_JSON}}
[Output requirements (output JSON only, no explanatory text)]
{
"table": "<db.table table name>",
"row_count": <int|null>, // may be null if unknown
"role": "fact|dimension|unknown", // heuristic from metric/dimension share and uniqueness
"grain": ["<col1>", "<col2>", ...], // guessed fact grain (e.g., dt/store/category)
"time": { "column": "<name>|null", "granularity": "day|week|month|unknown", "range": ["YYYY-MM-DD","YYYY-MM-DD"]|null, "has_gaps": true|false|null },
"columns": [
{
"name": "<col>",
"dtype": "<GE-inferred/physical type>",
"semantic_type": "dimension|metric|time|text|id|unknown",
"null_rate": <0~1|null>,
"distinct_count": <int|null>,
"distinct_ratio": <0~1|null>,
"stats": { "min": <number|string|null>,"max": <number|string|null>,"mean": <number|null>,"std": <number|null>,"skewness": <number|null> },
"enumish": true|false|null, // low entropy / enumerable
"top_values": [{"value":"<v>","pct":<0~1>}, ...],// top K values (≤10)
"pk_candidate_score": <0~1>, // combined uniqueness + non-null score
"metric_candidate_score": <0~1>, // numeric type / skewness / business-term hits
"comment": "<column comment or GE description | may be empty>"
}
],
"primary_key_candidates": [["colA","colB"], ...], // from unique/compound-unique expectations
"fk_candidates": [{"from":"<col>","to":"<dim_table(col)>","confidence":<0~1>}],
"quality": {
"failed_expectations": [{"name":"<expect_*>","column":"<col|table>","summary":"<one sentence>"}],
"warning_hints": ["columns with null rate > 0.2: ...", "time column has gaps: ..."]
},
"confidence_notes": ["<why the role/grain/time column was decided>"]
}
[Decision rules (brief)]
- time column: typed as date/time OR named dt/date/day, etc.; when min/max are available, provide range; if gaps of ≥1 day exist, set has_gaps=true.
- semantic_type: numeric + right-skewed/high variance → lean metric; high uniqueness/ID-style name → id; high cardinality + text → text; low entropy + limited values → dimension.
- role: high share of metric columns plus a time column → lean fact; almost all enums/IDs with few numerics → dimension.
- When confidence is low, output null or unknown and record the reason in confidence_notes.

prompt/snippet_alias_generator.md (new file)

@@ -0,0 +1,52 @@
系统角色System
You are a "SQL snippet alias generator".
The input is one or more SQL snippet objects (from snippet.json); the output is, for each snippet, a set of diverse aliases (colloquial / neutral / professional), plus keywords and intent tags.
Process every snippet object and output the same number of JSON elements.
用户消息User
[Context]
SQL snippet object array: {{SNIPPET_ARRAY}} // one or more snippets from snippet.json
[Task requirements]
For each SQL snippet in the input array, output one JSON object with the following structure:
{
"id": "<same as the input snippet id>",
"aliases": [
{"text": "…", "tone": "口语|中性|专业"},
{"text": "…", "tone": "专业"}
],
"keywords": [
"GMV","销售额","TopN","category","类目","趋势","同比","客户","订单","质量","异常检测","join","过滤","sample"
],
"intent_tags": ["aggregate","trend","topn","ratio","quality","join","sample","filter","by_dimension"]
}
Generation rules:
1. Process the array item by item: each input snippet yields exactly one output object (id unchanged).
2. aliases:
At least 3 aliases, covering the tones 口语 (colloquial), 中性 (neutral), and 专业 (professional).
Each ≤20 characters; semantically equivalent; do not add fields or business definitions that do not exist.
Examples:
GMV趋势分析 (neutral)
每天卖多少钱 (colloquial)
按日GMV曲线 (professional)
3. keywords:
8~15 keywords covering the snippet's core dimensions, metrics, analysis type, and semantic synonyms.
Mix Chinese and English (e.g., "GMV"/"销售额", "同比"/"YoY", "类目"/"category").
Include analysis-intent keywords used for matching (e.g., trend, ranking, share, quality check, filter).
4. intent_tags:
Choose from the following set, consistent with the snippet type and purpose:
["aggregate","trend","topn","ratio","quality","join","sample","filter","by_dimension"]
For condition snippets (WHERE clauses), add "filter"; if the snippet contains dimension-grouping logic, add "by_dimension".
5. Language and content requirements:
Keep a formal written style; do not add explanations.
Output only the JSON array, with no descriptions or extra text.

prompt/snippet_generator.md (new file)

@ -0,0 +1,46 @@
System role (System):
You are a "SQL snippet generator". Generate reusable analysis snippets based solely on the given "table profile".
For each snippet produce a title, purpose description, snippet type, variables, applicability conditions, and a SQL template (MySQL dialect), and note the business definition and safety limits.
Do not invent columns absent from the profile. Time/dimension/metric references must match the profile.
User message (User):
[Table profile JSON]
{{TABLE_PROFILE_JSON}}
[Output requirements (output a JSON array only)]
[
  {
    "id": "snpt_<slug>",
    "title": "Chinese title, ≤16 chars",
    "desc": "one-sentence purpose",
    "type": "aggregate|trend|topn|ratio|quality|join|sample",
    "applicability": {
      "required_columns": ["<col>", ...],
      "time_column": "<dt|nullable>",
      "constraints": {
        "dim_cardinality_hint": <int|null>,  // for TopN limits and performance hints
        "fk_join_available": true|false,
        "notes": ["high-cardinality dimensions should use LIMIT<=50", "..."]
      }
    },
    "variables": [
      {"name":"start_date","type":"date"},
      {"name":"end_date","type":"date"},
      {"name":"top_n","type":"int","default":10}
    ],
    "dialect_sql": {
      "mysql": ""
    },
    "business_caliber": "clear definition, e.g. UV deduplicated by device_id, grain = day-category",
    "examples": ["sample question 1","sample question 2"]
  }
]
[Snippet selection guidelines]
- If a time column exists: generate trend_by_day / yoy_qoq / moving_avg.
- If an enumish dimension exists (distinct 5–200): generate topn_by_dimension / share_of_total.
- If metric columns exist: generate sum/avg/max, quantiles, and anomaly detection (3σ/boxplot).
- If a primary key/unique constraint exists: generate dedup / detail sampling / quality checks.
- If fk_candidates exist: generate both a "join-the-dimension-table named" version and a "plain ID" version.
- For high-cardinality dimensions: stress LIMIT suggestions and potential performance risks in constraints.notes.
- Besides complete SQL snippets, also emit snippets for partial SQL, e.g. where payment_method = 'Credit Card' and delivery_status = 'Delivered' means the payment method is credit card and the delivery status is delivered.
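
The variables array implies the MySQL template carries named placeholders bound at query time. A hedged sketch of executing such a snippet with driver-side parameter binding (the %(name)s placeholder style is an assumption; the actual template format is whatever the generator emits):

import pymysql

def run_snippet(conn: pymysql.connections.Connection, snippet: dict, params: dict):
    # e.g. "SELECT order_dt, SUM(gmv) FROM orders WHERE order_dt BETWEEN %(start_date)s AND %(end_date)s GROUP BY order_dt"
    sql = snippet["dialect_sql"]["mysql"]
    with conn.cursor() as cur:
        cur.execute(sql, params)  # binding via the driver avoids injection through variables
        return cur.fetchall()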

View File

@ -9,3 +9,5 @@ numpy>=1.24
openpyxl>=3.1
httpx==0.27.2
python-dotenv==1.0.1
requests>=2.31.0
PyYAML>=6.0.1

View File

@ -0,0 +1,226 @@
import argparse
import logging
import os
from typing import Dict, Iterable, List, Optional
import datasets
from datasets import DownloadConfig
from huggingface_hub import snapshot_download
# Batch-download datasets and models from Hugging Face.
# Proxies and download parameters (timeout, retry count) are configurable via CLI
# arguments; items are downloaded in a loop into the dataset and model
# subdirectories under the file directory.
def _parse_id_list(values: Iterable[str]) -> List[str]:
"""将多次传入以及逗号分隔的标识整理为列表."""
ids: List[str] = []
for value in values:
value = value.strip()
if not value:
continue
if "," in value:
ids.extend(v.strip() for v in value.split(",") if v.strip())
else:
ids.append(value)
return ids
def _parse_proxy_args(proxy_args: Iterable[str]) -> Dict[str, str]:
"""解析命令行传入的代理设置,格式 scheme=url."""
proxies: Dict[str, str] = {}
for item in proxy_args:
raw = item.strip()
if not raw:
continue
if "=" not in raw:
logging.warning("代理参数 %s 缺少 '=' 分隔符,将忽略该项", raw)
continue
key, value = raw.split("=", 1)
key = key.strip()
value = value.strip()
if not key or not value:
logging.warning("代理参数 %s 解析失败,将忽略该项", raw)
continue
proxies[key] = value
return proxies
def _sanitize_dir_name(name: str) -> str:
return name.replace("/", "__")
def _ensure_dirs(root_dir: str) -> Dict[str, str]:
paths = {
"dataset": os.path.join(root_dir, "dataset"),
"model": os.path.join(root_dir, "model"),
}
for path in paths.values():
os.makedirs(path, exist_ok=True)
return paths
def _build_download_config(cache_dir: str, retries: Optional[int], proxies: Dict[str, str]) -> DownloadConfig:
config_kwargs = {"cache_dir": cache_dir}
if retries is not None:
config_kwargs["max_retries"] = retries
if proxies:
config_kwargs["proxies"] = proxies
return DownloadConfig(**config_kwargs)
def _apply_timeout(timeout: Optional[float]) -> None:
if timeout is None:
return
str_timeout = str(timeout)
os.environ.setdefault("HF_DATASETS_HTTP_TIMEOUT", str_timeout)
os.environ.setdefault("HF_HUB_HTTP_TIMEOUT", str_timeout)
def _resolve_log_level(level_name: str) -> int:
if isinstance(level_name, int):
return level_name
upper_name = str(level_name).upper()
return getattr(logging, upper_name, logging.INFO)
def _build_argument_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(
        description="Batch-download Hugging Face datasets and models into a target directory."
)
parser.add_argument(
"-d",
"--dataset",
action="append",
default=[],
help="要下载的数据集 ID可重复使用或传入逗号分隔列表。",
)
parser.add_argument(
"-m",
"--model",
action="append",
default=[],
help="要下载的模型 ID可重复使用或传入逗号分隔列表。",
)
parser.add_argument(
"-r",
"--root",
default="file",
help="存储根目录,默认 file。",
)
parser.add_argument(
"--retries",
type=int,
default=None,
help="失败后的重试次数,默认不重试。",
)
parser.add_argument(
"--timeout",
type=float,
default=None,
help="HTTP 超时时间(秒),默认跟随库设置。",
)
parser.add_argument(
"-p",
"--proxy",
action="append",
default=[],
help="代理设置,格式 scheme=url可多次传入例如 --proxy http=http://127.0.0.1:7890",
)
parser.add_argument(
"--log-level",
default="INFO",
help="日志级别,默认 INFO。",
)
return parser
def download_datasets(dataset_ids: Iterable[str], root_dir: str, retries: Optional[int], proxies: Dict[str, str]) -> None:
if not dataset_ids:
return
cache_dir = root_dir
download_config = _build_download_config(cache_dir, retries, proxies)
for dataset_id in dataset_ids:
try:
logging.info("开始下载数据集 %s", dataset_id)
# 使用 load_dataset 触发缓存下载
dataset = datasets.load_dataset(
dataset_id,
cache_dir=cache_dir,
download_config=download_config,
download_mode="reuse_cache_if_exists",
)
target_path = os.path.join(root_dir, _sanitize_dir_name(dataset_id))
dataset.save_to_disk(target_path)
logging.info("数据集 %s 下载完成,存储于 %s", dataset_id, target_path)
except Exception as exc: # pylint: disable=broad-except
logging.error("下载数据集 %s 失败: %s", dataset_id, exc)
def download_models(
model_ids: Iterable[str],
target_dir: str,
retries: Optional[int],
proxies: Dict[str, str],
timeout: Optional[float],
) -> None:
if not model_ids:
return
max_attempts = (retries or 0) + 1
    hub_kwargs = {
        "local_dir_use_symlinks": False,
        "max_workers": os.cpu_count() or 4,
    }
    if proxies:
        hub_kwargs["proxies"] = proxies
    if timeout is not None:
        # snapshot_download has no plain `timeout` kwarg; etag_timeout is the
        # closest per-request knob it exposes.
        hub_kwargs["etag_timeout"] = timeout
    for model_id in model_ids:
        # Give each model its own subdirectory so downloads do not overwrite each other.
        model_dir = os.path.join(target_dir, _sanitize_dir_name(model_id))
        attempt = 0
        while attempt < max_attempts:
            attempt += 1
            try:
                logging.info("Downloading model %s (attempt %s/%s)", model_id, attempt, max_attempts)
                snapshot_download(
                    repo_id=model_id,
                    local_dir=model_dir,
                    **hub_kwargs,
                )
                logging.info("Model %s downloaded; stored at %s", model_id, model_dir)
                break
            except Exception as exc:  # pylint: disable=broad-except
                logging.error("Failed to download model %s: %s", model_id, exc)
                if attempt >= max_attempts:
                    logging.error("Model %s still failed after all retries", model_id)
def main() -> None:
parser = _build_argument_parser()
args = parser.parse_args()
logging.basicConfig(
level=_resolve_log_level(args.log_level),
format="%(asctime)s - %(levelname)s - %(message)s",
)
dataset_ids = _parse_id_list(args.dataset)
model_ids = _parse_id_list(args.model)
retries = args.retries
timeout = args.timeout
proxies = _parse_proxy_args(args.proxy)
_apply_timeout(timeout)
if not dataset_ids and not model_ids:
        logging.warning(
            "No datasets or models configured; "
            "pass Hugging Face IDs via --dataset / --model"
        )
return
dirs = _ensure_dirs(args.root)
download_datasets(dataset_ids, dirs["dataset"], retries, proxies)
download_models(model_ids, dirs["model"], retries, proxies, timeout)
if __name__ == "__main__":
main()
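
Assuming the script is saved as download_hf.py (its filename is not shown in this diff), a typical invocation might look like:

python download_hf.py --dataset squad --model bert-base-uncased --retries 2 --timeout 60 --proxy https=http://127.0.0.1:7890

Datasets are cached and then materialized under file/dataset/<id> via save_to_disk; each model snapshot lands under file/model/<id with "/" replaced by "__">.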

View File

@ -0,0 +1,80 @@
from __future__ import annotations
import json
import os
import sys
from datetime import datetime
from typing import Any, Dict
import requests
def build_demo_payload() -> Dict[str, Any]:
now = datetime.utcnow()
started_at = now.replace(microsecond=0).isoformat() + "Z"
finished_at = now.replace(microsecond=0).isoformat() + "Z"
return {
"table_id": 42,
"version_ts": 20251101200000,
"action_type": "snippet",
"status": "success",
"callback_url": "http://localhost:9999/dummy-callback",
"table_schema_version_id": 7,
"table_schema": {
"columns": [
{"name": "order_id", "type": "bigint"},
{"name": "order_dt", "type": "date"},
{"name": "gmv", "type": "decimal(18,2)"},
]
},
"result_json": [
{
"id": "snpt_daily_gmv",
"title": "按日GMV",
"desc": "统计每日GMV总额",
"type": "trend",
"dialect_sql": {
"mysql": "SELECT order_dt, SUM(gmv) AS total_gmv FROM orders GROUP BY order_dt ORDER BY order_dt"
},
}
],
"result_summary_json": {"total_snippets": 1},
"html_report_url": None,
"error_code": None,
"error_message": None,
"started_at": started_at,
"finished_at": finished_at,
"duration_ms": 1234,
"result_checksum": "demo-checksum",
}
def main() -> int:
base_url = os.getenv("TABLE_SNIPPET_DEMO_BASE_URL", "http://localhost:8000")
endpoint = f"{base_url.rstrip('/')}/v1/table/snippet"
payload = build_demo_payload()
print(f"POST {endpoint}")
print(json.dumps(payload, ensure_ascii=False, indent=2))
try:
response = requests.post(endpoint, json=payload, timeout=30)
except requests.RequestException as exc:
print(f"Request failed: {exc}", file=sys.stderr)
return 1
print(f"\nStatus: {response.status_code}")
try:
data = response.json()
print("Response JSON:")
print(json.dumps(data, ensure_ascii=False, indent=2))
except ValueError:
print("Response Text:")
print(response.text)
return 0 if response.ok else 1
if __name__ == "__main__":
raise SystemExit(main())

37
table_snippet.sql Normal file
View File

@ -0,0 +1,37 @@
CREATE TABLE `action_results` (
  `id` bigint NOT NULL AUTO_INCREMENT COMMENT 'Primary key',
  `table_id` bigint NOT NULL COMMENT 'Table ID',
  `version_ts` bigint NOT NULL COMMENT 'Version timestamp (version number)',
  `action_type` enum('ge_profiling','ge_result_desc','snippet','snippet_alias') COLLATE utf8mb4_bin NOT NULL COMMENT 'Action type',
  `status` enum('pending','running','success','failed','partial') COLLATE utf8mb4_bin NOT NULL DEFAULT 'pending' COMMENT 'Execution status',
  `llm_usage` json DEFAULT NULL COMMENT 'LLM token usage stats',
  `error_code` varchar(128) COLLATE utf8mb4_bin DEFAULT NULL,
  `error_message` text COLLATE utf8mb4_bin,
  `started_at` datetime DEFAULT NULL,
  `finished_at` datetime DEFAULT NULL,
  `duration_ms` int DEFAULT NULL,
  `table_schema_version_id` varchar(19) COLLATE utf8mb4_bin NOT NULL,
  `table_schema` json NOT NULL,
  `ge_profiling_json` json DEFAULT NULL COMMENT 'Full profiling result JSON',
  `ge_profiling_json_size_bytes` bigint DEFAULT NULL,
  `ge_profiling_summary` json DEFAULT NULL COMMENT 'Profiling summary (large value_set entries removed)',
  `ge_profiling_summary_size_bytes` bigint DEFAULT NULL,
  `ge_profiling_total_size_bytes` bigint DEFAULT NULL COMMENT 'Sum of the two fields above',
  `ge_profiling_html_report_url` varchar(1024) COLLATE utf8mb4_bin DEFAULT NULL COMMENT 'GE report HTML path/URL',
  `ge_result_desc_json` json DEFAULT NULL COMMENT 'Table description result JSON',
  `ge_result_desc_json_size_bytes` bigint DEFAULT NULL,
  `snippet_json` json DEFAULT NULL COMMENT 'SQL knowledge snippet result JSON',
  `snippet_json_size_bytes` bigint DEFAULT NULL,
  `snippet_alias_json` json DEFAULT NULL COMMENT 'SQL snippet rewrite/enrichment result JSON',
  `snippet_alias_json_size_bytes` bigint DEFAULT NULL,
  `callback_url` varchar(1024) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin NOT NULL,
  `result_checksum` varbinary(32) DEFAULT NULL COMMENT 'MD5/xxhash computed over the current action payload',
  `created_at` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP,
  `updated_at` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
  PRIMARY KEY (`id`),
  UNIQUE KEY `uq_table_ver_action` (`table_id`,`version_ts`,`action_type`),
  KEY `idx_status` (`status`),
  KEY `idx_table` (`table_id`,`updated_at`),
  KEY `idx_action_time` (`action_type`,`version_ts`),
  KEY `idx_schema_version` (`table_schema_version_id`)
) ENGINE=InnoDB AUTO_INCREMENT=53 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin ROW_FORMAT=DYNAMIC COMMENT='Data analysis knowledge snippet table';
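
Given the uq_table_ver_action unique key, callback ingestion can be made idempotent with INSERT ... ON DUPLICATE KEY UPDATE. A hedged sketch for the snippet action (the column subset and connection handling are illustrative, not the service's actual ingestion code):

import json

import pymysql

def upsert_snippet_result(conn: pymysql.connections.Connection, payload: dict) -> None:
    # (table_id, version_ts, action_type) collides on the unique key, so a retry
    # of the same callback updates the row instead of inserting a duplicate.
    sql = """
        INSERT INTO action_results
            (table_id, version_ts, action_type, status,
             table_schema_version_id, table_schema, snippet_json, callback_url)
        VALUES (%s, %s, 'snippet', %s, %s, %s, %s, %s)
        ON DUPLICATE KEY UPDATE
            status = VALUES(status),
            snippet_json = VALUES(snippet_json)
    """
    with conn.cursor() as cur:
        cur.execute(sql, (
            payload["table_id"],
            payload["version_ts"],
            payload["status"],
            str(payload["table_schema_version_id"]),
            json.dumps(payload["table_schema"], ensure_ascii=False),
            json.dumps(payload["result_json"], ensure_ascii=False),
            payload["callback_url"],
        ))
    conn.commit()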

View File

@ -0,0 +1,74 @@
from __future__ import annotations
from app.services.table_profiling import _parse_completion_payload
from app.utils.llm_usage import extract_usage
def test_parse_completion_payload_handles_array_with_trailing_text() -> None:
response_payload = {
"choices": [
{
"message": {
"content": """
Here are the results:
[
{"id": "snpt_a"},
{"id": "snpt_b"}
]
Note: the model may emit extra trailing text.
""".strip()
}
}
]
}
parsed = _parse_completion_payload(response_payload)
assert isinstance(parsed, list)
assert [item["id"] for item in parsed] == ["snpt_a", "snpt_b"]
def test_extract_usage_info_normalizes_numeric_fields() -> None:
response_payload = {
"raw": {
"usage": {
"prompt_tokens": 12.7,
"completion_tokens": 3,
"total_tokens": 15.7,
"prompt_tokens_details": {"cached_tokens": 8.9, "other": None},
"non_numeric": "ignored",
}
}
}
usage = extract_usage(response_payload)
assert usage == {
"prompt_tokens": 12,
"completion_tokens": 3,
"total_tokens": 15,
"prompt_tokens_details": {"cached_tokens": 8},
}
def test_extract_usage_handles_alias_keys() -> None:
response_payload = {
"raw": {
"usageMetadata": {
"input_tokens": 20,
"output_tokens": 4,
}
}
}
usage = extract_usage(response_payload)
assert usage == {
"prompt_tokens": 20,
"completion_tokens": 4,
"total_tokens": 24,
}
def test_extract_usage_returns_none_when_missing() -> None:
assert extract_usage({"raw": {}}) is None