Data import analysis API adjustments
@@ -4,6 +4,7 @@ import csv
 import json
 import logging
 import os
+import re
 from functools import lru_cache
 from io import StringIO
 from pathlib import Path
@@ -29,6 +30,20 @@ IMPORT_GATEWAY_BASE_URL = os.getenv(
     "IMPORT_GATEWAY_BASE_URL", "http://localhost:8000"
 )
 
+
+def _env_float(name: str, default: float) -> float:
+    raw = os.getenv(name)
+    if raw is None:
+        return default
+    try:
+        return float(raw)
+    except ValueError:
+        logger.warning("Invalid value for %s=%r, falling back to %.2f", name, raw, default)
+        return default
+
+
+IMPORT_CHAT_TIMEOUT_SECONDS = _env_float("IMPORT_CHAT_TIMEOUT_SECONDS", 90.0)
+
 SUPPORTED_IMPORT_MODELS = get_supported_import_models()
 
 
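For context, `_env_float` only overrides the 90-second default when the variable is both set and parseable; otherwise it logs a warning and keeps the default. A minimal sketch of that contract, assuming the helper from the hunk above is in scope:

```python
import os

# Set and parseable: the override wins.
os.environ["IMPORT_CHAT_TIMEOUT_SECONDS"] = "120"
assert _env_float("IMPORT_CHAT_TIMEOUT_SECONDS", 90.0) == 120.0

# Set but unparseable: a warning is logged and the default is kept.
os.environ["IMPORT_CHAT_TIMEOUT_SECONDS"] = "ninety"
assert _env_float("IMPORT_CHAT_TIMEOUT_SECONDS", 90.0) == 90.0

# Unset: the default is used silently.
del os.environ["IMPORT_CHAT_TIMEOUT_SECONDS"]
assert _env_float("IMPORT_CHAT_TIMEOUT_SECONDS", 90.0) == 90.0
```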
@@ -208,7 +223,7 @@ def build_analysis_request(
 
     example_data = "\n\n".join(sections) if sections else "未提供样本数据。"
 
-    max_length = 30_000
+    max_length = 30000
     if len(example_data) > max_length:
         example_data = example_data[: max_length - 3] + "..."
 
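Note the slice stops at `max_length - 3` so that, after appending the three-dot ellipsis, the truncated sample is exactly `max_length` characters:

```python
max_length = 30000
example_data = "x" * 50000
example_data = example_data[: max_length - 3] + "..."
assert len(example_data) == max_length  # 29997 kept + 3 ellipsis characters
```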
@@ -250,6 +265,44 @@ def build_chat_payload(request: DataImportAnalysisJobRequest) -> Dict[str, Any]:
     return payload
 
 
+def _extract_json_payload(content: str) -> str:
+    """Try to pull a JSON object from an LLM content string."""
+    # Prefer fenced code blocks such as ```json { ... } ```
+    fenced = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", content, flags=re.DOTALL | re.IGNORECASE)
+    if fenced:
+        return fenced.group(1).strip()
+
+    stripped = content.strip()
+    if stripped.startswith("{") and stripped.endswith("}"):
+        return stripped
+
+    start = stripped.find("{")
+    end = stripped.rfind("}")
+    if start != -1 and end != -1 and end > start:
+        return stripped[start : end + 1].strip()
+
+    return stripped
+
+
+def parse_llm_analysis_json(llm_response: LLMResponse) -> Dict[str, Any]:
+    """Extract and parse the structured JSON payload from an LLM response."""
+    if not llm_response.choices:
+        raise ProviderAPICallError("LLM response did not include any choices to parse.")
+
+    content = llm_response.choices[0].message.content or ""
+    if not content.strip():
+        raise ProviderAPICallError("LLM response content is empty.")
+
+    json_payload = _extract_json_payload(content)
+
+    try:
+        return json.loads(json_payload)
+    except json.JSONDecodeError as exc:
+        preview = json_payload[:2000]
+        logger.error("Failed to parse JSON from LLM response content: %s", preview, exc_info=True)
+        raise ProviderAPICallError("LLM response JSON could not be parsed.") from exc
+
+
 async def dispatch_import_analysis_job(
     request: DataImportAnalysisJobRequest,
     client: httpx.AsyncClient,
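To make the extraction order concrete (fenced code block first, then a bare object, then the widest brace span), here is a hypothetical round-trip against the new helper; the sample strings are invented:

```python
fenced = 'Sure:\n```json\n{"tables": []}\n```'
bare = '  {"tables": []}  '
embedded = 'Result: {"tables": []} end of answer'

# All three shapes reduce to the same JSON object string.
for sample in (fenced, bare, embedded):
    assert _extract_json_payload(sample) == '{"tables": []}'
```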
@@ -259,13 +312,17 @@ async def dispatch_import_analysis_job(
     payload = build_chat_payload(request)
     url = f"{IMPORT_GATEWAY_BASE_URL.rstrip('/')}/v1/chat/completions"
 
-    print(
-        f"[ImportAnalysis] Dispatching import {request.import_record_id} to {url}: "
-        f"{json.dumps(payload, ensure_ascii=False)}"
+    logger.info(
+        "Dispatching import %s to %s: %s",
+        request.import_record_id,
+        url,
+        json.dumps(payload, ensure_ascii=False),
     )
 
+    timeout = httpx.Timeout(IMPORT_CHAT_TIMEOUT_SECONDS)
+
     try:
-        response = await client.post(url, json=payload)
+        response = await client.post(url, json=payload, timeout=timeout)
         response.raise_for_status()
     except httpx.HTTPStatusError as exc:
         body_preview = ""
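Passing `timeout=` per request leaves the shared client's defaults untouched for other calls; only the chat-completion call gets the longer budget. A rough sketch of the pattern (the URL and payload are placeholders):

```python
import httpx

async def sketch() -> None:
    async with httpx.AsyncClient() as client:
        # Applies 90s to connect, read, write, and pool acquisition
        # for this request only; other requests keep the client default.
        timeout = httpx.Timeout(90.0)
        response = await client.post(
            "http://localhost:8000/v1/chat/completions",  # placeholder
            json={"model": "example", "messages": []},
            timeout=timeout,
        )
        response.raise_for_status()
```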
@@ -284,13 +341,15 @@ async def dispatch_import_analysis_job(
     except ValueError as exc:
         raise ProviderAPICallError("Chat completions endpoint returned invalid JSON") from exc
 
-    print(
-        f"[ImportAnalysis] LLM HTTP status for {request.import_record_id}: "
-        f"{response.status_code}"
+    logger.info(
+        "LLM HTTP status for %s: %s",
+        request.import_record_id,
+        response.status_code,
     )
-    print(
-        f"[ImportAnalysis] LLM response for {request.import_record_id}: "
-        f"{json.dumps(response_data, ensure_ascii=False)}"
+    logger.info(
+        "LLM response for %s: %s",
+        request.import_record_id,
+        json.dumps(response_data, ensure_ascii=False),
     )
 
     try:
@@ -300,13 +359,33 @@ async def dispatch_import_analysis_job(
             "Chat completions endpoint returned unexpected schema"
         ) from exc
 
+    structured_json = parse_llm_analysis_json(llm_response)
+    usage_data = extract_usage(llm_response.raw)
+
     logger.info("Completed import analysis job %s", request.import_record_id)
-    return {
+    result: Dict[str, Any] = {
         "import_record_id": request.import_record_id,
         "status": "succeeded",
        "llm_response": llm_response.model_dump(),
+        "analysis": structured_json
     }
 
+    if usage_data:
+        result["usage"] = usage_data
+
+    return result
+
+# Handle usage-field extraction across multiple model formats
+def extract_usage(resp_json: dict) -> dict:
+    usage = resp_json.get("usage") or resp_json.get("usageMetadata") or {}
+    return {
+        "prompt_tokens": usage.get("prompt_tokens") or usage.get("input_tokens") or usage.get("promptTokenCount"),
+        "completion_tokens": usage.get("completion_tokens") or usage.get("output_tokens") or usage.get("candidatesTokenCount"),
+        "total_tokens": usage.get("total_tokens") or usage.get("totalTokenCount") or (
+            (usage.get("prompt_tokens") or usage.get("input_tokens") or 0)
+            + (usage.get("completion_tokens") or usage.get("output_tokens") or 0)
+        )
+    }
 
 async def notify_import_analysis_callback(
     callback_url: str,
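`extract_usage` folds the OpenAI-style `usage` block, generic `input_tokens`/`output_tokens` names, and Gemini-style `usageMetadata` into one shape, and sums prompt plus completion tokens when no total is reported. Illustrative payloads (invented for the example):

```python
openai_style = {"usage": {"prompt_tokens": 12, "completion_tokens": 34, "total_tokens": 46}}
gemini_style = {"usageMetadata": {"promptTokenCount": 12, "candidatesTokenCount": 34, "totalTokenCount": 46}}
no_total = {"usage": {"input_tokens": 12, "output_tokens": 34}}

assert extract_usage(openai_style)["total_tokens"] == 46
assert extract_usage(gemini_style)["total_tokens"] == 46
assert extract_usage(no_total)["total_tokens"] == 46  # summed fallback
```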
@@ -314,6 +393,12 @@ async def notify_import_analysis_callback(
     client: httpx.AsyncClient,
 ) -> None:
     callback_target = str(callback_url)
+    logger.info(
+        "Posting import analysis callback to %s: %s",
+        callback_target,
+        json.dumps(payload, ensure_ascii=False),
+    )
+
 
     try:
         response = await client.post(callback_target, json=payload)
@@ -333,8 +418,10 @@ async def process_import_analysis_job(
     try:
         payload = await dispatch_import_analysis_job(request, client)
     except ProviderAPICallError as exc:
-        print(
-            f"[ImportAnalysis] LLM call failed for {request.import_record_id}: {exc}"
+        logger.error(
+            "LLM call failed for %s: %s",
+            request.import_record_id,
+            exc,
         )
         payload = {
             "import_record_id": request.import_record_id,
@@ -346,9 +433,6 @@ async def process_import_analysis_job(
             "Unexpected failure while processing import analysis job %s",
             request.import_record_id,
         )
-        print(
-            f"[ImportAnalysis] Unexpected error for {request.import_record_id}: {exc}"
-        )
         payload = {
             "import_record_id": request.import_record_id,
             "status": "failed",