数据导入分析接口调整

This commit is contained in:
zhaoawd
2025-10-30 22:38:05 +08:00
parent 39911d78ab
commit 455b884551
6 changed files with 141 additions and 29 deletions

15
.env
View File

@ -1,12 +1,13 @@
# LLM provider API keys
OPENAI_API_KEY=
ANTHROPIC_API_KEY=
OPENROUTER_API_KEY=
# SECURITY: a real API key was committed here — revoke/rotate it and keep secrets out of version control
OPENROUTER_API_KEY=
OPENROUTER_SITE_URL=
OPENROUTER_APP_NAME=
GEMINI_API_KEY=
QWEN_API_KEY=
# SECURITY: a real API key was committed here — revoke/rotate it and keep secrets out of version control
DEEPSEEK_API_KEY=
DEEPSEEK_TIMEOUT_SECONDS=120
# Data import analysis defaults
IMPORT_SUPPORTED_MODELS=openai:gpt-5,deepseek:deepseek-chat,openrouter:anthropic/claude-4.0-sonnet
@ -14,3 +15,15 @@ DEFAULT_IMPORT_MODEL=deepseek:deepseek-chat
# Service configuration
IMPORT_GATEWAY_BASE_URL=http://localhost:8000
# HTTP client configuration
HTTP_CLIENT_TIMEOUT=30
HTTP_CLIENT_TRUST_ENV=false
# HTTP_CLIENT_PROXY=
# Import analysis configuration
IMPORT_CHAT_TIMEOUT_SECONDS=120
# Logging
LOG_LEVEL=INFO
# LOG_FORMAT=%(asctime)s %(levelname)s %(name)s:%(lineno)d %(message)s

View File

@ -1,6 +1,7 @@
from __future__ import annotations
import logging
import os
from typing import Any, Dict, List
import httpx
@ -13,6 +14,23 @@ from app.providers.base import LLMProviderClient
logger = logging.getLogger(__name__)
def _resolve_timeout_seconds() -> float:
raw = os.getenv("DEEPSEEK_TIMEOUT_SECONDS")
if raw is None:
return 60.0
try:
return float(raw)
except ValueError:
logger.warning(
"Invalid value for DEEPSEEK_TIMEOUT_SECONDS=%r, falling back to 60 seconds",
raw,
)
return 60.0
DEEPSEEK_TIMEOUT_SECONDS = _resolve_timeout_seconds()
class DeepSeekProvider(LLMProviderClient):
name = LLMProvider.DEEPSEEK.value
api_key_env = "DEEPSEEK_API_KEY"
@ -40,9 +58,12 @@ class DeepSeekProvider(LLMProviderClient):
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
}
timeout = httpx.Timeout(DEEPSEEK_TIMEOUT_SECONDS)
try:
response = await client.post(self.base_url, json=payload, headers=headers)
response = await client.post(
self.base_url, json=payload, headers=headers, timeout=timeout
)
response.raise_for_status()
except httpx.HTTPStatusError as exc:
status_code = exc.response.status_code

View File

@ -4,6 +4,7 @@ import csv
import json
import logging
import os
import re
from functools import lru_cache
from io import StringIO
from pathlib import Path
@ -29,6 +30,20 @@ IMPORT_GATEWAY_BASE_URL = os.getenv(
"IMPORT_GATEWAY_BASE_URL", "http://localhost:8000"
)
def _env_float(name: str, default: float) -> float:
raw = os.getenv(name)
if raw is None:
return default
try:
return float(raw)
except ValueError:
logger.warning("Invalid value for %s=%r, falling back to %.2f", name, raw, default)
return default
IMPORT_CHAT_TIMEOUT_SECONDS = _env_float("IMPORT_CHAT_TIMEOUT_SECONDS", 90.0)
SUPPORTED_IMPORT_MODELS = get_supported_import_models()
@ -208,7 +223,7 @@ def build_analysis_request(
example_data = "\n\n".join(sections) if sections else "未提供样本数据。"
max_length = 30_000
max_length = 30000
if len(example_data) > max_length:
example_data = example_data[: max_length - 3] + "..."
@ -250,6 +265,44 @@ def build_chat_payload(request: DataImportAnalysisJobRequest) -> Dict[str, Any]:
return payload
def _extract_json_payload(content: str) -> str:
"""Try to pull a JSON object from an LLM content string."""
# Prefer fenced code blocks such as ```json { ... } ```
fenced = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", content, flags=re.DOTALL | re.IGNORECASE)
if fenced:
return fenced.group(1).strip()
stripped = content.strip()
if stripped.startswith("{") and stripped.endswith("}"):
return stripped
start = stripped.find("{")
end = stripped.rfind("}")
if start != -1 and end != -1 and end > start:
return stripped[start : end + 1].strip()
return stripped
def parse_llm_analysis_json(llm_response: LLMResponse) -> Dict[str, Any]:
    """Extract and parse the structured JSON payload from an LLM response."""
    if not llm_response.choices:
        raise ProviderAPICallError("LLM response did not include any choices to parse.")

    content = llm_response.choices[0].message.content or ""
    if not content.strip():
        raise ProviderAPICallError("LLM response content is empty.")

    candidate = _extract_json_payload(content)
    try:
        parsed = json.loads(candidate)
    except json.JSONDecodeError as exc:
        # Log a bounded preview so a huge malformed payload cannot flood the logs.
        logger.error("Failed to parse JSON from LLM response content: %s", candidate[:2000], exc_info=True)
        raise ProviderAPICallError("LLM response JSON could not be parsed.") from exc
    return parsed
async def dispatch_import_analysis_job(
request: DataImportAnalysisJobRequest,
client: httpx.AsyncClient,
@ -259,13 +312,17 @@ async def dispatch_import_analysis_job(
payload = build_chat_payload(request)
url = f"{IMPORT_GATEWAY_BASE_URL.rstrip('/')}/v1/chat/completions"
print(
f"[ImportAnalysis] Dispatching import {request.import_record_id} to {url}: "
f"{json.dumps(payload, ensure_ascii=False)}"
logger.info(
"Dispatching import %s to %s: %s",
request.import_record_id,
url,
json.dumps(payload, ensure_ascii=False),
)
timeout = httpx.Timeout(IMPORT_CHAT_TIMEOUT_SECONDS)
try:
response = await client.post(url, json=payload)
response = await client.post(url, json=payload, timeout=timeout)
response.raise_for_status()
except httpx.HTTPStatusError as exc:
body_preview = ""
@ -284,13 +341,15 @@ async def dispatch_import_analysis_job(
except ValueError as exc:
raise ProviderAPICallError("Chat completions endpoint returned invalid JSON") from exc
print(
f"[ImportAnalysis] LLM HTTP status for {request.import_record_id}: "
f"{response.status_code}"
logger.info(
"LLM HTTP status for %s: %s",
request.import_record_id,
response.status_code,
)
print(
f"[ImportAnalysis] LLM response for {request.import_record_id}: "
f"{json.dumps(response_data, ensure_ascii=False)}"
logger.info(
"LLM response for %s: %s",
request.import_record_id,
json.dumps(response_data, ensure_ascii=False),
)
try:
@ -300,13 +359,33 @@ async def dispatch_import_analysis_job(
"Chat completions endpoint returned unexpected schema"
) from exc
structured_json = parse_llm_analysis_json(llm_response)
usage_data = extract_usage(llm_response.raw)
logger.info("Completed import analysis job %s", request.import_record_id)
return {
result: Dict[str, Any] = {
"import_record_id": request.import_record_id,
"status": "succeeded",
"llm_response": llm_response.model_dump(),
"analysis": structured_json
}
if usage_data:
result["usage"] = usage_data
return result
def extract_usage(resp_json: dict) -> dict:
    """Normalize token-usage figures across provider response schemas.

    Providers disagree on field names (OpenAI/DeepSeek: ``usage.prompt_tokens``,
    Anthropic: ``input_tokens``/``output_tokens``, Gemini:
    ``usageMetadata.promptTokenCount``/``candidatesTokenCount``).  Returns a
    dict with ``prompt_tokens``/``completion_tokens``/``total_tokens``; any
    field the provider did not report is ``None``.

    Fixes over the previous version: the computed total now includes the
    Gemini-style count fields, a legitimate ``0`` count is no longer dropped
    by ``or``-chaining, and a missing usage block yields ``None`` totals
    instead of a fabricated ``0``.
    """
    usage = resp_json.get("usage") or resp_json.get("usageMetadata") or {}

    def _first(*keys: str):
        # First key present with a non-None value; preserves a real 0 count.
        for key in keys:
            value = usage.get(key)
            if value is not None:
                return value
        return None

    prompt = _first("prompt_tokens", "input_tokens", "promptTokenCount")
    completion = _first("completion_tokens", "output_tokens", "candidatesTokenCount")
    total = _first("total_tokens", "totalTokenCount")
    if total is None and (prompt is not None or completion is not None):
        # Derive the total only when the provider reported at least one part.
        total = (prompt or 0) + (completion or 0)
    return {
        "prompt_tokens": prompt,
        "completion_tokens": completion,
        "total_tokens": total,
    }
async def notify_import_analysis_callback(
callback_url: str,
@ -314,6 +393,12 @@ async def notify_import_analysis_callback(
client: httpx.AsyncClient,
) -> None:
callback_target = str(callback_url)
logger.info(
"Posting import analysis callback to %s: %s",
callback_target,
json.dumps(payload, ensure_ascii=False),
)
try:
response = await client.post(callback_target, json=payload)
@ -333,8 +418,10 @@ async def process_import_analysis_job(
try:
payload = await dispatch_import_analysis_job(request, client)
except ProviderAPICallError as exc:
print(
f"[ImportAnalysis] LLM call failed for {request.import_record_id}: {exc}"
logger.error(
"LLM call failed for %s: %s",
request.import_record_id,
exc,
)
payload = {
"import_record_id": request.import_record_id,
@ -346,9 +433,6 @@ async def process_import_analysis_job(
"Unexpected failure while processing import analysis job %s",
request.import_record_id,
)
print(
f"[ImportAnalysis] Unexpected error for {request.import_record_id}: {exc}"
)
payload = {
"import_record_id": request.import_record_id,
"status": "failed",

View File

@ -12,7 +12,6 @@
方向 2字段数据类型与格式推断
针对每列:输出推断数据类型(如 varchar(n) / int / bigint / tinyint / float / double / decimal(p,s) / date / datetime / text
说明推断依据:样本值分布、长度范围、格式正则、是否存在空值、是否数值但含前导零等。
指出数据质量初步观察:缺失率、是否有异常/离群值(简单规则即可)、是否需标准化(如去空格、去重、枚举值归一)。
给出“建议处理动作”:如 trim、cast_float、cast_int、cast_double、cast_date、cast_time、cast_datetime适用于将样本数据转换成数据库表字段兼容的格式。
若为“可能是枚举”的字段,列出候选枚举值及占比。
@ -23,12 +22,8 @@
"columns": [{
"original_name": "原始名称",
"standard_name": "标准化后的名称: 下划线命名,大小写字母、数字、下划线",
"data_type": "数据类型限制为number/string/datetime",
"db_type": "数据库字段类型",
"java_type": "java字段类型限制为: int/long/double/string/date",
"data_type": "",
"nullable": true/false,
"distinct_count_sample": number,
"null_ratio_sample": 0.x,
"is_enum_candidate": true/false,
"description": "字段简短描述",
"date_format": "转换成Date类型的pattern"

View File

@ -36,14 +36,13 @@ async def main() -> None:
},
"llm_model": "deepseek:deepseek-chat",
"temperature": 0.2,
"max_output_tokens": 256,
"callback_url": CALLBACK_URL,
}
async with httpx.AsyncClient(timeout=httpx.Timeout(15.0)) as client:
async with httpx.AsyncClient(timeout=httpx.Timeout(60.0)) as client:
response = await client.post(API_URL, json=payload)
print("Status:", response.status_code)
print("Body:", response.json())
if __name__ == "__main__":

View File

@ -35,7 +35,7 @@ async def main() -> None:
}
],
"temperature": 0.1,
"max_tokens": 1024,
"max_tokens": 2048,
}
async with httpx.AsyncClient(timeout=httpx.Timeout(15.0)) as client: