# data-ge/app/services/import_analysis.py
from __future__ import annotations

import csv
import json
import logging
import os
import re
from functools import lru_cache
from io import StringIO
from pathlib import Path
from typing import Any, Dict, List, Sequence, Tuple

import httpx
from pydantic import ValidationError

from app.exceptions import ProviderAPICallError
from app.models import (
    DataImportAnalysisJobRequest,
    DataImportAnalysisRequest,
    LLMMessage,
    LLMProvider,
    LLMResponse,
    LLMRole,
)
from app.settings import DEFAULT_IMPORT_MODEL, get_supported_import_models

logger = logging.getLogger(__name__)

IMPORT_GATEWAY_BASE_URL = os.getenv(
    "IMPORT_GATEWAY_BASE_URL", "http://localhost:8000"
)

def _env_float(name: str, default: float) -> float:
    """Read a float from the environment, falling back to *default* on bad input."""
    raw = os.getenv(name)
    if raw is None:
        return default
    try:
        return float(raw)
    except ValueError:
        logger.warning("Invalid value for %s=%r, falling back to %.2f", name, raw, default)
        return default


IMPORT_CHAT_TIMEOUT_SECONDS = _env_float("IMPORT_CHAT_TIMEOUT_SECONDS", 120.0)
SUPPORTED_IMPORT_MODELS = get_supported_import_models()
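
# Illustrative sketch of the env fallback above: with
# IMPORT_CHAT_TIMEOUT_SECONDS="300" the timeout becomes 300.0; with
# IMPORT_CHAT_TIMEOUT_SECONDS="abc" a warning is logged and it falls
# back to 120.0.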

def resolve_provider_from_model(llm_model: str) -> Tuple[LLMProvider, str]:
    """Resolve provider based on the llm_model string.

    The llm_model may be given as 'provider:model', 'provider/model', or
    'provider|model'. If no provider prefix is present, make an educated
    guess from common model-name patterns.
    """
    normalized = llm_model.strip()
    provider_hint: str | None = None
    model_name = normalized
    for delimiter in (":", "/", "|"):
        if delimiter in normalized:
            provider_hint, model_name = normalized.split(delimiter, 1)
            provider_hint = provider_hint.strip().lower()
            model_name = model_name.strip()
            break
    provider_map = {provider.value: provider for provider in LLMProvider}
    if provider_hint:
        if provider_hint not in provider_map:
            raise ValueError(
                f"Unsupported provider '{provider_hint}'. "
                f"Expected one of: {', '.join(provider_map.keys())}."
            )
        return provider_map[provider_hint], model_name
    return _guess_provider_from_model(model_name), model_name

def _guess_provider_from_model(model_name: str) -> LLMProvider:
    lowered = model_name.lower()
    if lowered.startswith(("gpt", "o1", "text-", "dall-e", "whisper")):
        return LLMProvider.OPENAI
    if lowered.startswith(("claude", "anthropic")):
        return LLMProvider.ANTHROPIC
    if lowered.startswith(("gemini", "models/gemini")):
        return LLMProvider.GEMINI
    if lowered.startswith("qwen"):
        return LLMProvider.QWEN
    if lowered.startswith("deepseek"):
        return LLMProvider.DEEPSEEK
    if lowered.startswith(("openrouter", "router-")):
        return LLMProvider.OPENROUTER
    supported = ", ".join(provider.value for provider in LLMProvider)
    raise ValueError(
        f"Unable to infer provider from model '{model_name}'. "
        f"Please prefix with 'provider:model'. Supported providers: {supported}."
    )
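
# Illustrative sketch (not part of the module) of how the two resolvers
# above behave, based on the prefix table they encode:
#
#     resolve_provider_from_model("openai:gpt-4o")
#     # -> (LLMProvider.OPENAI, "gpt-4o")                 explicit prefix
#     resolve_provider_from_model("claude-3-5-sonnet")
#     # -> (LLMProvider.ANTHROPIC, "claude-3-5-sonnet")   guessed from name
#     resolve_provider_from_model("mystery-model")
#     # -> raises ValueError (no prefix, no recognizable pattern)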

def build_import_messages(
    request: DataImportAnalysisRequest,
) -> List[LLMMessage]:
    """Create system and user messages for the import analysis prompt."""
    headers_formatted = "\n".join(f"- {header}" for header in request.table_headers)
    system_prompt = load_import_template()
    data_block = (
        f"导入记录ID: {request.import_record_id}\n\n"
        "表头信息:\n"
        f"{headers_formatted}\n\n"
        "示例数据:\n"
        f"{request.example_data}"
    )
    return [
        LLMMessage(role=LLMRole.SYSTEM, content=system_prompt),
        LLMMessage(role=LLMRole.USER, content=data_block),
    ]

@lru_cache(maxsize=1)
def load_import_template() -> str:
    """Load (and cache) the Markdown prompt template for import analysis."""
    template_path = (
        Path(__file__).resolve().parents[2] / "prompt" / "data_import_analysis.md"
    )
    if not template_path.exists():
        raise FileNotFoundError(f"Prompt template not found at {template_path}")
    return template_path.read_text(encoding="utf-8").strip()

def derive_headers(
    rows: Sequence[Any], provided_headers: Sequence[str] | None
) -> List[str]:
    """Derive column headers from explicit input, dict keys, or positional rows."""
    if provided_headers:
        return [str(header) for header in provided_headers]
    collected: List[str] = []
    list_lengths: List[int] = []
    for row in rows:
        if isinstance(row, dict):
            for key in row.keys():
                key_str = str(key)
                if key_str not in collected:
                    collected.append(key_str)
        elif isinstance(row, (list, tuple)):
            list_lengths.append(len(row))
    if collected:
        return collected
    if list_lengths:
        max_len = max(list_lengths)
        return [f"column_{idx + 1}" for idx in range(max_len)]
    return ["column_1"]

def _stringify_cell(value: Any) -> str:
    if value is None:
        return ""
    if isinstance(value, (str, int, float, bool)):
        return str(value)
    try:
        return json.dumps(value, ensure_ascii=False)
    except (TypeError, ValueError):
        return str(value)

def rows_to_csv_text(
    rows: Sequence[Any], headers: Sequence[str], *, max_rows: int = 50
) -> str:
    """Render up to *max_rows* rows as CSV text, with a header row when headers are given."""
    buffer = StringIO()
    writer = csv.writer(buffer)
    if headers:
        writer.writerow(headers)
    for idx, row in enumerate(rows):
        if max_rows and idx >= max_rows:
            break
        if isinstance(row, dict):
            writer.writerow([_stringify_cell(row.get(header)) for header in headers])
        elif isinstance(row, (list, tuple)):
            writer.writerow([_stringify_cell(item) for item in row])
        else:
            writer.writerow([_stringify_cell(row)])
    return buffer.getvalue().strip()
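
# Illustrative: dict rows are keyed by the headers, and missing keys
# become empty cells (csv.writer uses \r\n line endings by default).
#
#     rows_to_csv_text([{"id": 1, "name": "a"}, {"id": 2}], ["id", "name"])
#     # -> 'id,name\r\n1,a\r\n2,'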

def format_table_schema(schema: Any) -> str:
    if schema is None:
        return ""
    if isinstance(schema, str):
        return schema.strip()
    try:
        return json.dumps(schema, ensure_ascii=False, indent=2)
    except (TypeError, ValueError):
        return str(schema)

def build_analysis_request(
    request: DataImportAnalysisJobRequest,
) -> DataImportAnalysisRequest:
    """Assemble the prompt body (CSV preview plus optional schema) for analysis."""
    headers = derive_headers(request.rows, request.headers)
    if request.raw_csv:
        csv_text = request.raw_csv.strip()
    else:
        csv_text = rows_to_csv_text(request.rows, headers)
    sections: List[str] = []
    if csv_text:
        sections.append("CSV样本预览:\n" + csv_text)
    schema_text = format_table_schema(request.table_schema)
    if schema_text:
        sections.append("附加结构信息:\n" + schema_text)
    example_data = "\n\n".join(sections) if sections else "未提供样本数据。"
    # Truncate oversized previews so the prompt stays within a safe budget.
    max_length = 30000
    if len(example_data) > max_length:
        example_data = example_data[: max_length - 3] + "..."
    return DataImportAnalysisRequest(
        import_record_id=request.import_record_id,
        example_data=example_data,
        table_headers=headers,
        llm_model=request.llm_model or DEFAULT_IMPORT_MODEL,
    )

def build_chat_payload(request: DataImportAnalysisJobRequest) -> Dict[str, Any]:
    """Build the JSON payload for the gateway's chat completions endpoint."""
    llm_input = request.llm_model or DEFAULT_IMPORT_MODEL
    provider, model_name = resolve_provider_from_model(llm_input)
    normalized_model = f"{provider.value}:{model_name}"
    if SUPPORTED_IMPORT_MODELS and normalized_model not in SUPPORTED_IMPORT_MODELS:
        raise ProviderAPICallError(
            f"Model '{normalized_model}' is not allowed. "
            f"Allowed models: {', '.join(sorted(SUPPORTED_IMPORT_MODELS))}"
        )
    analysis_request = build_analysis_request(request)
    messages = build_import_messages(analysis_request)
    payload: Dict[str, Any] = {
        "provider": provider.value,
        "model": model_name,
        "messages": [message.model_dump() for message in messages],
        "temperature": request.temperature if request.temperature is not None else 0.2,
    }
    if request.max_output_tokens is not None:
        payload["max_tokens"] = request.max_output_tokens
    return payload
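
# Illustrative payload shape produced above, assuming llm_model="openai:gpt-4o",
# no explicit temperature or max_output_tokens, and that LLMRole serializes to
# lowercase strings (an assumption; the exact message fields come from
# LLMMessage.model_dump()):
#
#     {
#         "provider": "openai",
#         "model": "gpt-4o",
#         "messages": [{"role": "system", ...}, {"role": "user", ...}],
#         "temperature": 0.2,
#     }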

def _extract_json_payload(content: str) -> str:
    """Try to pull a JSON object from an LLM content string."""
    # Prefer fenced code blocks such as ```json { ... } ```
    fenced = re.search(
        r"```(?:json)?\s*(\{.*?\})\s*```", content, flags=re.DOTALL | re.IGNORECASE
    )
    if fenced:
        return fenced.group(1).strip()
    stripped = content.strip()
    if stripped.startswith("{") and stripped.endswith("}"):
        return stripped
    # Fall back to the outermost brace pair anywhere in the text.
    start = stripped.find("{")
    end = stripped.rfind("}")
    if start != -1 and end != -1 and end > start:
        return stripped[start : end + 1].strip()
    return stripped
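
# Illustrative: the three extraction strategies above, in order of preference.
#
#     _extract_json_payload('```json\n{"a": 1}\n```')   # -> '{"a": 1}'  (fenced block)
#     _extract_json_payload('{"a": 1}')                 # -> '{"a": 1}'  (bare object)
#     _extract_json_payload('Sure! {"a": 1} Done.')     # -> '{"a": 1}'  (outermost braces)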

def parse_llm_analysis_json(llm_response: LLMResponse) -> Dict[str, Any]:
    """Extract and parse the structured JSON payload from an LLM response."""
    if not llm_response.choices:
        raise ProviderAPICallError("LLM response did not include any choices to parse.")
    content = llm_response.choices[0].message.content or ""
    if not content.strip():
        raise ProviderAPICallError("LLM response content is empty.")
    json_payload = _extract_json_payload(content)
    try:
        return json.loads(json_payload)
    except json.JSONDecodeError as exc:
        preview = json_payload[:10000]
        logger.error(
            "Failed to parse JSON from LLM response content: %s", preview, exc_info=True
        )
        raise ProviderAPICallError("LLM response JSON could not be parsed.") from exc

async def dispatch_import_analysis_job(
    request: DataImportAnalysisJobRequest,
    client: httpx.AsyncClient,
) -> Dict[str, Any]:
    """Call the chat completions gateway and return the parsed analysis result."""
    logger.info("Starting import analysis job %s", request.import_record_id)
    payload = build_chat_payload(request)
    url = f"{IMPORT_GATEWAY_BASE_URL.rstrip('/')}/v1/chat/completions"
    logger.info(
        "Dispatching import %s to %s: %s",
        request.import_record_id,
        url,
        json.dumps(payload, ensure_ascii=False),
    )
    timeout = httpx.Timeout(IMPORT_CHAT_TIMEOUT_SECONDS)
    try:
        response = await client.post(url, json=payload, timeout=timeout)
        response.raise_for_status()
    except httpx.HTTPStatusError as exc:
        body_preview = ""
        if exc.response is not None:
            body_preview = exc.response.text[:400]
        raise ProviderAPICallError(
            f"Failed to invoke chat completions endpoint: {exc}. Response body: {body_preview}"
        ) from exc
    except httpx.HTTPError as exc:
        raise ProviderAPICallError(
            f"HTTP call to chat completions endpoint failed: {exc}"
        ) from exc
    try:
        response_data = response.json()
    except ValueError as exc:
        raise ProviderAPICallError("Chat completions endpoint returned invalid JSON") from exc
    logger.info(
        "LLM HTTP status for %s: %s",
        request.import_record_id,
        response.status_code,
    )
    logger.info(
        "LLM response for %s: %s",
        request.import_record_id,
        json.dumps(response_data, ensure_ascii=False),
    )
    try:
        llm_response = LLMResponse.model_validate(response_data)
    except ValidationError as exc:
        raise ProviderAPICallError(
            "Chat completions endpoint returned unexpected schema"
        ) from exc
    structured_json = parse_llm_analysis_json(llm_response)
    usage_data = extract_usage(llm_response.raw)
    logger.info("Completed import analysis job %s", request.import_record_id)
    result: Dict[str, Any] = {
        "import_record_id": request.import_record_id,
        "status": "succeeded",
        "llm_response": llm_response.model_dump(),
        "analysis": structured_json,
    }
    if usage_data:
        result["usage"] = usage_data
    return result

# Extract usage fields in a provider-agnostic way: OpenAI, Anthropic, and
# Gemini all name their token-count fields differently.
def extract_usage(resp_json: dict) -> dict:
    usage = resp_json.get("usage") or resp_json.get("usageMetadata") or {}
    if not usage:
        # No usage block at all: return {} so the caller's `if usage_data`
        # guard can skip the field entirely.
        return {}
    return {
        "prompt_tokens": usage.get("prompt_tokens")
        or usage.get("input_tokens")
        or usage.get("promptTokenCount"),
        "completion_tokens": usage.get("completion_tokens")
        or usage.get("output_tokens")
        or usage.get("candidatesTokenCount"),
        "total_tokens": usage.get("total_tokens")
        or usage.get("totalTokenCount")
        or (
            (usage.get("prompt_tokens") or usage.get("input_tokens") or 0)
            + (usage.get("completion_tokens") or usage.get("output_tokens") or 0)
        ),
    }
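
# Illustrative: the same call normalizes OpenAI- and Gemini-style usage blocks.
#
#     extract_usage({"usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}})
#     # -> {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}
#     extract_usage({"usageMetadata": {"promptTokenCount": 10, "candidatesTokenCount": 5, "totalTokenCount": 15}})
#     # -> {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}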

async def notify_import_analysis_callback(
    callback_url: str,
    payload: Dict[str, Any],
    client: httpx.AsyncClient,
) -> None:
    """POST the analysis result to the caller's callback URL; log (not raise) on failure."""
    callback_target = str(callback_url)
    logger.info(
        "Posting import analysis callback to %s: %s",
        callback_target,
        json.dumps(payload, ensure_ascii=False),
    )
    try:
        response = await client.post(callback_target, json=payload)
        response.raise_for_status()
    except httpx.HTTPError as exc:
        logger.error(
            "Failed to deliver import analysis callback to %s: %s",
            callback_target,
            exc,
        )

async def process_import_analysis_job(
    request: DataImportAnalysisJobRequest,
    client: httpx.AsyncClient,
) -> None:
    """Run one analysis job end to end and always report back via the callback."""
    try:
        payload = await dispatch_import_analysis_job(request, client)
    except ProviderAPICallError as exc:
        logger.error(
            "LLM call failed for %s: %s",
            request.import_record_id,
            exc,
        )
        payload = {
            "import_record_id": request.import_record_id,
            "status": "failed",
            "error": str(exc),
        }
    except Exception as exc:  # pragma: no cover - defensive logging
        logger.exception(
            "Unexpected failure while processing import analysis job %s",
            request.import_record_id,
        )
        payload = {
            "import_record_id": request.import_record_id,
            "status": "failed",
            "error": str(exc),
        }
    await notify_import_analysis_callback(request.callback_url, payload, client)
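
# Illustrative end-to-end usage sketch (not part of the module). The exact
# field names on DataImportAnalysisJobRequest beyond those referenced above
# (import_record_id, rows, headers, llm_model, callback_url) are assumptions.
#
#     import asyncio
#
#     async def main() -> None:
#         request = DataImportAnalysisJobRequest(
#             import_record_id="imp-123",
#             rows=[{"name": "Alice", "age": 30}],
#             headers=None,
#             llm_model="openai:gpt-4o",
#             callback_url="http://localhost:9000/callbacks/import-analysis",
#         )
#         async with httpx.AsyncClient() as client:
#             await process_import_analysis_job(request, client)
#
#     asyncio.run(main())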