diff --git a/app/models.py b/app/models.py
index f5f9280..9bae3d0 100644
--- a/app/models.py
+++ b/app/models.py
@@ -254,17 +254,45 @@ class TableSnippetUpsertRequest(BaseModel):
     callback_url: HttpUrl = Field(..., description="Callback URL associated with the action run.")
     table_schema_version_id: int = Field(..., ge=0, description="Identifier for the schema snapshot.")
     table_schema: Any = Field(..., description="Schema snapshot payload for the table.")
-    result_json: Optional[Any] = Field(
+    llm_usage: Optional[Any] = Field(
         None,
-        description="Primary result payload for the action (e.g., profiling output, snippet array).",
+        description="Optional token usage metrics reported by the LLM provider.",
     )
-    result_summary_json: Optional[Any] = Field(
-        None,
-        description="Optional summary payload (e.g., profiling summary) for the action.",
+    ge_profiling_json: Optional[Any] = Field(
+        None, description="Full GE profiling result payload for the profiling action."
     )
-    html_report_url: Optional[str] = Field(
-        None,
-        description="Optional HTML report URL generated by the action.",
+    ge_profiling_json_size_bytes: Optional[int] = Field(
+        None, ge=0, description="Size in bytes of the GE profiling result JSON."
+    )
+    ge_profiling_summary: Optional[Any] = Field(
+        None, description="Sanitised GE profiling summary payload."
+    )
+    ge_profiling_summary_size_bytes: Optional[int] = Field(
+        None, ge=0, description="Size in bytes of the GE profiling summary JSON."
+    )
+    ge_profiling_total_size_bytes: Optional[int] = Field(
+        None, ge=0, description="Combined size (bytes) of profiling result + summary."
+    )
+    ge_profiling_html_report_url: Optional[str] = Field(
+        None, description="Optional URL to the generated GE profiling HTML report."
+    )
+    ge_result_desc_json: Optional[Any] = Field(
+        None, description="Result JSON for the GE result description action."
+    )
+    ge_result_desc_json_size_bytes: Optional[int] = Field(
+        None, ge=0, description="Size in bytes of the GE result description JSON."
+    )
+    snippet_json: Optional[Any] = Field(
+        None, description="Snippet generation action result JSON."
+    )
+    snippet_json_size_bytes: Optional[int] = Field(
+        None, ge=0, description="Size in bytes of the snippet result JSON."
+    )
+    snippet_alias_json: Optional[Any] = Field(
+        None, description="Snippet alias expansion result JSON."
+    )
+    snippet_alias_json_size_bytes: Optional[int] = Field(
+        None, ge=0, description="Size in bytes of the snippet alias result JSON."
     )
     error_code: Optional[str] = Field(None, description="Optional error code when status indicates a failure.")
     error_message: Optional[str] = Field(None, description="Optional error message when status indicates a failure.")
diff --git a/app/services/import_analysis.py b/app/services/import_analysis.py
index 0aef6fc..c9a55f2 100644
--- a/app/services/import_analysis.py
+++ b/app/services/import_analysis.py
@@ -23,6 +23,7 @@ from app.models import (
     LLMRole,
 )
 from app.settings import DEFAULT_IMPORT_MODEL, get_supported_import_models
+from app.utils.llm_usage import extract_usage
 
 logger = logging.getLogger(__name__)
 
@@ -375,18 +376,6 @@ async def dispatch_import_analysis_job(
     return result
 
 
-# 兼容处理多模型的使用量字段提取
-def extract_usage(resp_json: dict) -> dict:
-    usage = resp_json.get("usage") or resp_json.get("usageMetadata") or {}
-    return {
-        "prompt_tokens": usage.get("prompt_tokens") or usage.get("input_tokens") or usage.get("promptTokenCount"),
-        "completion_tokens": usage.get("completion_tokens") or usage.get("output_tokens") or usage.get("candidatesTokenCount"),
-        "total_tokens": usage.get("total_tokens") or usage.get("totalTokenCount") or (
-            (usage.get("prompt_tokens") or usage.get("input_tokens") or 0)
-            + (usage.get("completion_tokens") or usage.get("output_tokens") or 0)
-        )
-    }
-
 async def notify_import_analysis_callback(
     callback_url: str,
     payload: Dict[str, Any],
diff --git a/app/services/table_profiling.py b/app/services/table_profiling.py
index ea6cc04..f412238 100644
--- a/app/services/table_profiling.py
+++ b/app/services/table_profiling.py
@@ -26,6 +26,7 @@ from app.services.import_analysis import (
     IMPORT_GATEWAY_BASE_URL,
     resolve_provider_from_model,
 )
+from app.utils.llm_usage import extract_usage as extract_llm_usage
 
 logger = logging.getLogger(__name__)
 
@@ -37,7 +38,7 @@ PROMPT_FILENAMES = {
     "snippet_generator": "snippet_generator.md",
     "snippet_alias": "snippet_alias_generator.md",
 }
-DEFAULT_CHAT_TIMEOUT_SECONDS = 90.0
+DEFAULT_CHAT_TIMEOUT_SECONDS = 180.0
 
 
 @dataclass
@@ -47,6 +48,12 @@ class GEProfilingArtifacts:
     docs_path: str
 
 
+@dataclass
+class LLMCallResult:
+    data: Any
+    usage: Optional[Dict[str, Any]] = None
+
+
 class PipelineActionType:
     GE_PROFILING = "ge_profiling"
     GE_RESULT_DESC = "ge_result_desc"
@@ -124,11 +131,16 @@ def _extract_json_payload(content: str) -> str:
     if not stripped:
         raise ValueError("Empty LLM content.")
 
-    for opener, closer in (("{", "}"), ("[", "]")):
-        start = stripped.find(opener)
-        end = stripped.rfind(closer)
-        if start != -1 and end != -1 and end > start:
-            candidate = stripped[start : end + 1].strip()
+    decoder = json.JSONDecoder()
+    for idx, char in enumerate(stripped):
+        if char not in {"{", "["}:
+            continue
+        try:
+            _, end = decoder.raw_decode(stripped[idx:])
+        except json.JSONDecodeError:
+            continue
+        candidate = stripped[idx : idx + end].strip()
+        if candidate:
             return candidate
 
     return stripped
@@ -559,7 +571,9 @@ async def _call_chat_completions(
     except ValueError as exc:
         raise ProviderAPICallError("Chat completions response was not valid JSON.") from exc
 
-    return _parse_completion_payload(response_payload)
+    parsed_payload = _parse_completion_payload(response_payload)
+    usage_info = extract_llm_usage(response_payload)
+    return LLMCallResult(data=parsed_payload, usage=usage_info)
 
 
 def _normalize_for_json(value: Any) -> Any:
@@ -628,7 +642,7 @@ async def _execute_result_desc(
         client=client,
         timeout_seconds=timeout_seconds,
     )
-    if not isinstance(llm_output, dict):
+    if not isinstance(llm_output.data, dict):
         raise ProviderAPICallError("GE result description payload must be a JSON object.")
     return llm_output
 
@@ -651,7 +665,7 @@ async def _execute_snippet_generation(
         client=client,
         timeout_seconds=timeout_seconds,
     )
-    if not isinstance(llm_output, list):
+    if not isinstance(llm_output.data, list):
         raise ProviderAPICallError("Snippet generator must return a JSON array.")
     return llm_output
 
@@ -674,7 +688,7 @@ async def _execute_snippet_alias(
         client=client,
         timeout_seconds=timeout_seconds,
     )
-    if not isinstance(llm_output, list):
+    if not isinstance(llm_output.data, list):
         raise ProviderAPICallError("Snippet alias generator must return a JSON array.")
     return llm_output
 
@@ -711,6 +725,12 @@ async def _run_action_with_callback(
         await _post_callback(callback_url, failure_payload, client)
         raise
 
+    usage_info: Optional[Dict[str, Any]] = None
+    result_payload = result
+    if isinstance(result, LLMCallResult):
+        usage_info = result.usage
+        result_payload = result.data
+
     success_payload = dict(callback_base)
     success_payload.update(
         {
@@ -724,23 +744,26 @@
     logger.info(
         "Pipeline action %s output: %s",
         action_type,
-        _preview_for_log(result),
+        _preview_for_log(result_payload),
     )
 
     if action_type == PipelineActionType.GE_PROFILING:
-        artifacts: GEProfilingArtifacts = result
-        success_payload["profiling_json"] = artifacts.profiling_result
-        success_payload["profiling_summary"] = artifacts.profiling_summary
+        artifacts: GEProfilingArtifacts = result_payload
+        success_payload["ge_profiling_json"] = artifacts.profiling_result
+        success_payload["ge_profiling_summary"] = artifacts.profiling_summary
         success_payload["ge_report_path"] = artifacts.docs_path
     elif action_type == PipelineActionType.GE_RESULT_DESC:
-        success_payload["table_desc_json"] = result
+        success_payload["ge_result_desc_json"] = result_payload
     elif action_type == PipelineActionType.SNIPPET:
-        success_payload["snippet_json"] = result
+        success_payload["snippet_json"] = result_payload
     elif action_type == PipelineActionType.SNIPPET_ALIAS:
-        success_payload["snippet_alias_json"] = result
+        success_payload["snippet_alias_json"] = result_payload
+
+    if usage_info:
+        success_payload["llm_usage"] = usage_info
 
     await _post_callback(callback_url, success_payload, client)
-    return result
+    return result_payload
 
 
 async def process_table_profiling_job(
diff --git a/app/services/table_snippet.py b/app/services/table_snippet.py
index e7e2c95..be72b0e 100644
--- a/app/services/table_snippet.py
+++ b/app/services/table_snippet.py
@@ -55,6 +55,28 @@ def _collect_common_columns(request: TableSnippetUpsertRequest) -> Dict[str, Any
         "table_schema": _prepare_table_schema(request.table_schema),
     }
 
+    payload.update(
+        {
+            "ge_profiling_json": None,
+            "ge_profiling_json_size_bytes": None,
+            "ge_profiling_summary": None,
+            "ge_profiling_summary_size_bytes": None,
+            "ge_profiling_total_size_bytes": None,
+            "ge_profiling_html_report_url": None,
+            "ge_result_desc_json": None,
+            "ge_result_desc_json_size_bytes": None,
+            "snippet_json": None,
+            "snippet_json_size_bytes": None,
+            "snippet_alias_json": None,
+            "snippet_alias_json_size_bytes": None,
+        }
+    )
+
+    if request.llm_usage is not None:
+        llm_usage_json, _ = _serialize_json(request.llm_usage)
+        if llm_usage_json is not None:
+            payload["llm_usage"] = llm_usage_json
+
     if request.error_code is not None:
         logger.debug("Adding error_code: %s", request.error_code)
         payload["error_code"] = request.error_code
@@ -80,35 +102,35 @@ def _apply_action_payload(
 ) -> None:
     logger.debug("Applying action-specific payload for action_type=%s", request.action_type)
     if request.action_type == ActionType.GE_PROFILING:
-        full_json, full_size = _serialize_json(request.result_json)
-        summary_json, summary_size = _serialize_json(request.result_summary_json)
+        full_json, full_size = _serialize_json(request.ge_profiling_json)
+        summary_json, summary_size = _serialize_json(request.ge_profiling_summary)
         if full_json is not None:
-            payload["ge_profiling_full"] = full_json
-            payload["ge_profiling_full_size_bytes"] = full_size
+            payload["ge_profiling_json"] = full_json
+            payload["ge_profiling_json_size_bytes"] = full_size
         if summary_json is not None:
             payload["ge_profiling_summary"] = summary_json
             payload["ge_profiling_summary_size_bytes"] = summary_size
-        if full_size is not None or summary_size is not None:
-            payload["ge_profiling_total_size_bytes"] = (full_size or 0) + (
-                summary_size or 0
-            )
-        if request.html_report_url:
-            payload["ge_profiling_html_report_url"] = request.html_report_url
+        if request.ge_profiling_total_size_bytes is not None:
+            payload["ge_profiling_total_size_bytes"] = request.ge_profiling_total_size_bytes
+        elif full_size is not None or summary_size is not None:
+            payload["ge_profiling_total_size_bytes"] = (full_size or 0) + (summary_size or 0)
+        if request.ge_profiling_html_report_url:
+            payload["ge_profiling_html_report_url"] = request.ge_profiling_html_report_url
     elif request.action_type == ActionType.GE_RESULT_DESC:
-        full_json, full_size = _serialize_json(request.result_json)
+        full_json, full_size = _serialize_json(request.ge_result_desc_json)
         if full_json is not None:
-            payload["ge_result_desc_full"] = full_json
-            payload["ge_result_desc_full_size_bytes"] = full_size
+            payload["ge_result_desc_json"] = full_json
+            payload["ge_result_desc_json_size_bytes"] = full_size
     elif request.action_type == ActionType.SNIPPET:
-        full_json, full_size = _serialize_json(request.result_json)
+        full_json, full_size = _serialize_json(request.snippet_json)
         if full_json is not None:
-            payload["snippet_full"] = full_json
-            payload["snippet_full_size_bytes"] = full_size
+            payload["snippet_json"] = full_json
+            payload["snippet_json_size_bytes"] = full_size
     elif request.action_type == ActionType.SNIPPET_ALIAS:
-        full_json, full_size = _serialize_json(request.result_json)
+        full_json, full_size = _serialize_json(request.snippet_alias_json)
         if full_json is not None:
-            payload["snippet_alias_full"] = full_json
-            payload["snippet_alias_full_size_bytes"] = full_size
+            payload["snippet_alias_json"] = full_json
+            payload["snippet_alias_json_size_bytes"] = full_size
     else:
         logger.error("Unsupported action type encountered: %s", request.action_type)
         raise ValueError(f"Unsupported action type '{request.action_type}'.")
diff --git a/app/utils/llm_usage.py b/app/utils/llm_usage.py
new file mode 100644
index 0000000..4424e8f
--- /dev/null
+++ b/app/utils/llm_usage.py
@@ -0,0 +1,116 @@
+from __future__ import annotations
+
+from typing import Any, Dict, Iterable, Optional
+
+
+PROMPT_TOKEN_KEYS: tuple[str, ...] = ("prompt_tokens", "input_tokens", "promptTokenCount")
+COMPLETION_TOKEN_KEYS: tuple[str, ...] = (
+    "completion_tokens",
+    "output_tokens",
+    "candidatesTokenCount",
+)
+TOTAL_TOKEN_KEYS: tuple[str, ...] = ("total_tokens", "totalTokenCount")
+USAGE_CONTAINER_KEYS: tuple[str, ...] = ("usage", "usageMetadata", "usage_metadata")
+
+
+def _normalize_usage_value(value: Any) -> Any:
+    if isinstance(value, (int, float)):
+        return int(value)
+
+    if isinstance(value, str):
+        stripped = value.strip()
+        if not stripped:
+            return None
+        try:
+            numeric = float(stripped)
+        except ValueError:
+            return None
+        return int(numeric)
+
+    if isinstance(value, dict):
+        normalized: Dict[str, Any] = {}
+        for key, nested_value in value.items():
+            normalized_value = _normalize_usage_value(nested_value)
+            if normalized_value is not None:
+                normalized[key] = normalized_value
+        return normalized or None
+
+    if isinstance(value, (list, tuple, set)):
+        normalized_list = [
+            item for item in (_normalize_usage_value(element) for element in value) if item is not None
+        ]
+        return normalized_list or None
+
+    return None
+
+
+def _first_numeric(payload: Dict[str, Any], keys: Iterable[str]) -> Optional[int]:
+    for key in keys:
+        value = payload.get(key)
+        if isinstance(value, (int, float)):
+            return int(value)
+    return None
+
+
+def _canonicalize_counts(payload: Dict[str, Any]) -> None:
+    prompt = _first_numeric(payload, PROMPT_TOKEN_KEYS)
+    completion = _first_numeric(payload, COMPLETION_TOKEN_KEYS)
+    total = _first_numeric(payload, TOTAL_TOKEN_KEYS)
+
+    if prompt is not None:
+        payload["prompt_tokens"] = prompt
+    else:
+        payload.pop("prompt_tokens", None)
+
+    if completion is not None:
+        payload["completion_tokens"] = completion
+    else:
+        payload.pop("completion_tokens", None)
+
+    if total is not None:
+        payload["total_tokens"] = total
+    elif prompt is not None and completion is not None:
+        payload["total_tokens"] = prompt + completion
+    else:
+        payload.pop("total_tokens", None)
+
+    for alias in PROMPT_TOKEN_KEYS[1:]:
+        payload.pop(alias, None)
+    for alias in COMPLETION_TOKEN_KEYS[1:]:
+        payload.pop(alias, None)
+    for alias in TOTAL_TOKEN_KEYS[1:]:
+        payload.pop(alias, None)
+
+
+def _extract_usage_container(candidate: Any) -> Optional[Dict[str, Any]]:
+    if not isinstance(candidate, dict):
+        return None
+    for key in USAGE_CONTAINER_KEYS:
+        value = candidate.get(key)
+        if isinstance(value, dict):
+            return value
+    return None
+
+
+def extract_usage(payload: Dict[str, Any]) -> Optional[Dict[str, Any]]:
+    """Unified helper to parse token usage metadata from diverse provider responses."""
+    if not isinstance(payload, dict):
+        return None
+
+    usage_candidate = _extract_usage_container(payload)
+    if usage_candidate is None:
+        raw_section = payload.get("raw")
+        usage_candidate = _extract_usage_container(raw_section)
+
+    if usage_candidate is None:
+        return None
+
+    normalized = _normalize_usage_value(usage_candidate)
+    if not isinstance(normalized, dict):
+        return None
+
+    _canonicalize_counts(normalized)
+    return normalized or None
+
+
+__all__ = ["extract_usage"]
diff --git a/table_snippet.sql b/table_snippet.sql
index b9fb19b..a4540ea 100644
--- a/table_snippet.sql
+++ b/table_snippet.sql
@@ -1,54 +1,37 @@
-CREATE TABLE IF NOT EXISTS action_results (
-    id BIGINT NOT NULL AUTO_INCREMENT COMMENT '主键',
-    table_id BIGINT NOT NULL COMMENT '表ID',
-    version_ts BIGINT NOT NULL COMMENT '版本时间戳(版本号)',
-    action_type ENUM('ge_profiling','ge_result_desc','snippet','snippet_alias') NOT NULL COMMENT '动作类型',
-
-    status ENUM('pending','running','success','failed','partial') NOT NULL DEFAULT 'pending' COMMENT '执行状态',
-    error_code VARCHAR(128) NULL,
-    error_message TEXT NULL,
-
-    -- 回调 & 观测
-    callback_url VARCHAR(1024) NOT NULL,
-    started_at DATETIME NULL,
-    finished_at DATETIME NULL,
-    duration_ms INT NULL,
-
-    -- 本次schema信息
-    table_schema_version_id BIGINT NOT NULL,
-    table_schema JSON NOT NULL,
-
-    -- ===== 动作1:GE Profiling =====
-    ge_profiling_full JSON NULL COMMENT 'Profiling完整结果JSON',
-    ge_profiling_full_size_bytes BIGINT NULL,
-    ge_profiling_summary JSON NULL COMMENT 'Profiling摘要(剔除大value_set等)',
-    ge_profiling_summary_size_bytes BIGINT NULL,
-    ge_profiling_total_size_bytes BIGINT NULL COMMENT '上两者合计',
-    ge_profiling_html_report_url VARCHAR(1024) NULL COMMENT 'GE报告HTML路径/URL',
-
-    -- ===== 动作2:GE Result Desc =====
-    ge_result_desc_full JSON NULL COMMENT '表描述结果JSON',
-    ge_result_desc_full_size_bytes BIGINT NULL,
-
-    -- ===== 动作3:Snippet 生成 =====
-    snippet_full JSON NULL COMMENT 'SQL知识片段结果JSON',
-    snippet_full_size_bytes BIGINT NULL,
-
-    -- ===== 动作4:Snippet Alias 改写 =====
-    snippet_alias_full JSON NULL COMMENT 'SQL片段改写/丰富结果JSON',
-    snippet_alias_full_size_bytes BIGINT NULL,
-
-    -- 通用可选指标
-    result_checksum VARBINARY(32) NULL COMMENT '对当前action有效载荷计算的MD5/xxhash',
-    created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
-    updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
-
-    PRIMARY KEY (id),
-    UNIQUE KEY uq_table_ver_action (table_id, version_ts, action_type),
-    KEY idx_status (status),
-    KEY idx_table (table_id, updated_at),
-    KEY idx_action_time (action_type, version_ts),
-    KEY idx_schema_version (table_schema_version_id)
-) ENGINE=InnoDB
-  ROW_FORMAT=DYNAMIC
-  COMMENT='数据分析知识片段表';
+CREATE TABLE `action_results` (
+  `id` bigint NOT NULL AUTO_INCREMENT COMMENT '主键',
+  `table_id` bigint NOT NULL COMMENT '表ID',
+  `version_ts` bigint NOT NULL COMMENT '版本时间戳(版本号)',
+  `action_type` enum('ge_profiling','ge_result_desc','snippet','snippet_alias') COLLATE utf8mb4_bin NOT NULL COMMENT '动作类型',
+  `status` enum('pending','running','success','failed','partial') COLLATE utf8mb4_bin NOT NULL DEFAULT 'pending' COMMENT '执行状态',
+  `llm_usage` json DEFAULT NULL COMMENT 'LLM token usage统计',
+  `error_code` varchar(128) COLLATE utf8mb4_bin DEFAULT NULL,
+  `error_message` text COLLATE utf8mb4_bin,
+  `started_at` datetime DEFAULT NULL,
+  `finished_at` datetime DEFAULT NULL,
+  `duration_ms` int DEFAULT NULL,
+  `table_schema_version_id` varchar(19) COLLATE utf8mb4_bin NOT NULL,
+  `table_schema` json NOT NULL,
+  `ge_profiling_json` json DEFAULT NULL COMMENT 'Profiling完整结果JSON',
+  `ge_profiling_json_size_bytes` bigint DEFAULT NULL,
+  `ge_profiling_summary` json DEFAULT NULL COMMENT 'Profiling摘要(剔除大value_set等)',
+  `ge_profiling_summary_size_bytes` bigint DEFAULT NULL,
+  `ge_profiling_total_size_bytes` bigint DEFAULT NULL COMMENT '上两者合计',
+  `ge_profiling_html_report_url` varchar(1024) COLLATE utf8mb4_bin DEFAULT NULL COMMENT 'GE报告HTML路径/URL',
+  `ge_result_desc_json` json DEFAULT NULL COMMENT '表描述结果JSON',
+  `ge_result_desc_json_size_bytes` bigint DEFAULT NULL,
+  `snippet_json` json DEFAULT NULL COMMENT 'SQL知识片段结果JSON',
+  `snippet_json_size_bytes` bigint DEFAULT NULL,
+  `snippet_alias_json` json DEFAULT NULL COMMENT 'SQL片段改写/丰富结果JSON',
+  `snippet_alias_json_size_bytes` bigint DEFAULT NULL,
+  `callback_url` varchar(1024) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin NOT NULL,
+  `result_checksum` varbinary(32) DEFAULT NULL COMMENT '对当前action有效载荷计算的MD5/xxhash',
+  `created_at` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP,
+  `updated_at` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+  PRIMARY KEY (`id`),
+  UNIQUE KEY `uq_table_ver_action` (`table_id`,`version_ts`,`action_type`),
+  KEY `idx_status` (`status`),
+  KEY `idx_table` (`table_id`,`updated_at`),
+  KEY `idx_action_time` (`action_type`,`version_ts`),
+  KEY `idx_schema_version` (`table_schema_version_id`)
+) ENGINE=InnoDB AUTO_INCREMENT=53 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin ROW_FORMAT=DYNAMIC COMMENT='数据分析知识片段表';
\ No newline at end of file
diff --git a/test/test_table_profiling_parsing.py b/test/test_table_profiling_parsing.py
new file mode 100644
index 0000000..19a6758
--- /dev/null
+++ b/test/test_table_profiling_parsing.py
@@ -0,0 +1,74 @@
+from __future__ import annotations
+
+from app.services.table_profiling import _parse_completion_payload
+from app.utils.llm_usage import extract_usage
+
+
+def test_parse_completion_payload_handles_array_with_trailing_text() -> None:
+    response_payload = {
+        "choices": [
+            {
+                "message": {
+                    "content": """
+结果如下:
+[
+    {"id": "snpt_a"},
+    {"id": "snpt_b"}
+]
+附加说明:模型可能会输出额外文本。
+""".strip()
+                }
+            }
+        ]
+    }
+
+    parsed = _parse_completion_payload(response_payload)
+
+    assert isinstance(parsed, list)
+    assert [item["id"] for item in parsed] == ["snpt_a", "snpt_b"]
+
+
+def test_extract_usage_info_normalizes_numeric_fields() -> None:
+    response_payload = {
+        "raw": {
+            "usage": {
+                "prompt_tokens": 12.7,
+                "completion_tokens": 3,
+                "total_tokens": 15.7,
+                "prompt_tokens_details": {"cached_tokens": 8.9, "other": None},
+                "non_numeric": "ignored",
+            }
+        }
+    }
+
+    usage = extract_usage(response_payload)
+
+    assert usage == {
+        "prompt_tokens": 12,
+        "completion_tokens": 3,
+        "total_tokens": 15,
+        "prompt_tokens_details": {"cached_tokens": 8},
+    }
+
+
+def test_extract_usage_handles_alias_keys() -> None:
+    response_payload = {
+        "raw": {
+            "usageMetadata": {
+                "input_tokens": 20,
+                "output_tokens": 4,
+            }
+        }
+    }
+
+    usage = extract_usage(response_payload)
+
+    assert usage == {
+        "prompt_tokens": 20,
+        "completion_tokens": 4,
+        "total_tokens": 24,
+    }
+
+
+def test_extract_usage_returns_none_when_missing() -> None:
+    assert extract_usage({"raw": {}}) is None