数据知识回调入库

This commit is contained in:
zhaoawd
2025-11-04 20:28:50 +08:00
parent 0b765e6719
commit 7eb3c059a1
7 changed files with 346 additions and 111 deletions

View File

@ -26,6 +26,7 @@ from app.services.import_analysis import (
IMPORT_GATEWAY_BASE_URL,
resolve_provider_from_model,
)
from app.utils.llm_usage import extract_usage as extract_llm_usage
logger = logging.getLogger(__name__)
@ -37,7 +38,7 @@ PROMPT_FILENAMES = {
"snippet_generator": "snippet_generator.md",
"snippet_alias": "snippet_alias_generator.md",
}
DEFAULT_CHAT_TIMEOUT_SECONDS = 90.0
DEFAULT_CHAT_TIMEOUT_SECONDS = 180.0
@dataclass
@ -47,6 +48,12 @@ class GEProfilingArtifacts:
docs_path: str
@dataclass
class LLMCallResult:
data: Any
usage: Optional[Dict[str, Any]] = None
class PipelineActionType:
GE_PROFILING = "ge_profiling"
GE_RESULT_DESC = "ge_result_desc"
@ -124,11 +131,16 @@ def _extract_json_payload(content: str) -> str:
if not stripped:
raise ValueError("Empty LLM content.")
for opener, closer in (("{", "}"), ("[", "]")):
start = stripped.find(opener)
end = stripped.rfind(closer)
if start != -1 and end != -1 and end > start:
candidate = stripped[start : end + 1].strip()
decoder = json.JSONDecoder()
for idx, char in enumerate(stripped):
if char not in {"{", "["}:
continue
try:
_, end = decoder.raw_decode(stripped[idx:])
except json.JSONDecodeError:
continue
candidate = stripped[idx : idx + end].strip()
if candidate:
return candidate
return stripped
@ -559,7 +571,9 @@ async def _call_chat_completions(
except ValueError as exc:
raise ProviderAPICallError("Chat completions response was not valid JSON.") from exc
return _parse_completion_payload(response_payload)
parsed_payload = _parse_completion_payload(response_payload)
usage_info = extract_llm_usage(response_payload)
return LLMCallResult(data=parsed_payload, usage=usage_info)
def _normalize_for_json(value: Any) -> Any:
@ -628,7 +642,7 @@ async def _execute_result_desc(
client=client,
timeout_seconds=timeout_seconds,
)
if not isinstance(llm_output, dict):
if not isinstance(llm_output.data, dict):
raise ProviderAPICallError("GE result description payload must be a JSON object.")
return llm_output
@ -651,7 +665,7 @@ async def _execute_snippet_generation(
client=client,
timeout_seconds=timeout_seconds,
)
if not isinstance(llm_output, list):
if not isinstance(llm_output.data, list):
raise ProviderAPICallError("Snippet generator must return a JSON array.")
return llm_output
@ -674,7 +688,7 @@ async def _execute_snippet_alias(
client=client,
timeout_seconds=timeout_seconds,
)
if not isinstance(llm_output, list):
if not isinstance(llm_output.data, list):
raise ProviderAPICallError("Snippet alias generator must return a JSON array.")
return llm_output
@ -711,6 +725,12 @@ async def _run_action_with_callback(
await _post_callback(callback_url, failure_payload, client)
raise
usage_info: Optional[Dict[str, Any]] = None
result_payload = result
if isinstance(result, LLMCallResult):
usage_info = result.usage
result_payload = result.data
success_payload = dict(callback_base)
success_payload.update(
{
@ -724,23 +744,26 @@ async def _run_action_with_callback(
logger.info(
"Pipeline action %s output: %s",
action_type,
_preview_for_log(result),
_preview_for_log(result_payload),
)
if action_type == PipelineActionType.GE_PROFILING:
artifacts: GEProfilingArtifacts = result
success_payload["profiling_json"] = artifacts.profiling_result
success_payload["profiling_summary"] = artifacts.profiling_summary
artifacts: GEProfilingArtifacts = result_payload
success_payload["ge_profiling_json"] = artifacts.profiling_result
success_payload["ge_profiling_summary"] = artifacts.profiling_summary
success_payload["ge_report_path"] = artifacts.docs_path
elif action_type == PipelineActionType.GE_RESULT_DESC:
success_payload["table_desc_json"] = result
success_payload["ge_result_desc_json"] = result_payload
elif action_type == PipelineActionType.SNIPPET:
success_payload["snippet_json"] = result
success_payload["snippet_json"] = result_payload
elif action_type == PipelineActionType.SNIPPET_ALIAS:
success_payload["snippet_alias_json"] = result
success_payload["snippet_alias_json"] = result_payload
if usage_info:
success_payload["llm_usage"] = usage_info
await _post_callback(callback_url, success_payload, client)
return result
return result_payload
async def process_table_profiling_job(