"""Helpers for LLM-based data-import analysis jobs: build prompts and chat
payloads, call the LLM gateway, parse structured results, and deliver callbacks."""

from __future__ import annotations

import csv
import json
import logging
import os
import re
from functools import lru_cache
from io import StringIO
from pathlib import Path
from typing import Any, Dict, List, Sequence, Tuple

import httpx
from pydantic import ValidationError

from app.exceptions import ProviderAPICallError
from app.models import (
    DataImportAnalysisJobRequest,
    DataImportAnalysisRequest,
    LLMMessage,
    LLMProvider,
    LLMResponse,
    LLMRole,
)
from app.settings import DEFAULT_IMPORT_MODEL, get_supported_import_models

logger = logging.getLogger(__name__)

IMPORT_GATEWAY_BASE_URL = os.getenv(
    "IMPORT_GATEWAY_BASE_URL", "http://localhost:8000"
)


def _env_float(name: str, default: float) -> float:
    raw = os.getenv(name)
    if raw is None:
        return default
    try:
        return float(raw)
    except ValueError:
        logger.warning("Invalid value for %s=%r, falling back to %.2f", name, raw, default)
        return default


IMPORT_CHAT_TIMEOUT_SECONDS = _env_float("IMPORT_CHAT_TIMEOUT_SECONDS", 120.0)

# Allow-list of importable models; an empty collection disables the check in
# build_chat_payload below.
SUPPORTED_IMPORT_MODELS = get_supported_import_models()


def resolve_provider_from_model(llm_model: str) -> Tuple[LLMProvider, str]:
    """Resolve the provider from the llm_model string.

    The llm_model may be given as 'provider:model', 'provider/model', or
    'provider|model'. If no provider prefix is present, make an educated
    guess from common model name patterns.
    """
    normalized = llm_model.strip()
    provider_hint: str | None = None
    model_name = normalized

    for delimiter in (":", "/", "|"):
        if delimiter in normalized:
            provider_hint, model_name = normalized.split(delimiter, 1)
            provider_hint = provider_hint.strip().lower()
            model_name = model_name.strip()
            break

    provider_map = {provider.value: provider for provider in LLMProvider}

    if provider_hint:
        if provider_hint not in provider_map:
            raise ValueError(
                f"Unsupported provider '{provider_hint}'. Expected one of: {', '.join(provider_map.keys())}."
            )
        return provider_map[provider_hint], model_name

    return _guess_provider_from_model(model_name), model_name
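
# Illustrative resolution examples (assuming LLMProvider values are lowercase
# provider names such as "openai", as the provider_map construction implies):
#   resolve_provider_from_model("openai:gpt-4o")     -> (LLMProvider.OPENAI, "gpt-4o")
#   resolve_provider_from_model("claude-3-5-sonnet") -> (LLMProvider.ANTHROPIC, "claude-3-5-sonnet")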


def _guess_provider_from_model(model_name: str) -> LLMProvider:
    lowered = model_name.lower()

    if lowered.startswith(("gpt", "o1", "text-", "dall-e", "whisper")):
        return LLMProvider.OPENAI
    if lowered.startswith(("claude", "anthropic")):
        return LLMProvider.ANTHROPIC
    if lowered.startswith(("gemini", "models/gemini")):
        return LLMProvider.GEMINI
    if lowered.startswith("qwen"):
        return LLMProvider.QWEN
    if lowered.startswith("deepseek"):
        return LLMProvider.DEEPSEEK
    if lowered.startswith(("openrouter", "router-")):
        return LLMProvider.OPENROUTER

    supported = ", ".join(provider.value for provider in LLMProvider)
    raise ValueError(
        f"Unable to infer provider from model '{model_name}'. "
        f"Please prefix with 'provider:model'. Supported providers: {supported}."
    )


def build_import_messages(
    request: DataImportAnalysisRequest,
) -> List[LLMMessage]:
    """Create system and user messages for the import analysis prompt."""
    headers_formatted = "\n".join(f"- {header}" for header in request.table_headers)

    system_prompt = load_import_template()

    # The user block labels are in Chinese, matching the prompt template:
    # 导入记录ID = import record ID, 表头信息 = table headers, 示例数据 = sample data.
    data_block = (
        f"导入记录ID: {request.import_record_id}\n\n"
        "表头信息:\n"
        f"{headers_formatted}\n\n"
        "示例数据:\n"
        f"{request.example_data}"
    )

    return [
        LLMMessage(role=LLMRole.SYSTEM, content=system_prompt),
        LLMMessage(role=LLMRole.USER, content=data_block),
    ]
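
# Sketch of the resulting message list (directly from the construction above):
#   system: contents of prompt/data_import_analysis.md
#   user:   "导入记录ID: <id>\n\n表头信息:\n- <header>\n...\n\n示例数据:\n<example_data>"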


# Cached for the lifetime of the process: the template file is read at most once.
@lru_cache(maxsize=1)
def load_import_template() -> str:
    template_path = (
        Path(__file__).resolve().parents[2] / "prompt" / "data_import_analysis.md"
    )
    if not template_path.exists():
        raise FileNotFoundError(f"Prompt template not found at {template_path}")
    return template_path.read_text(encoding="utf-8").strip()


def derive_headers(
    rows: Sequence[Any], provided_headers: Sequence[str] | None
) -> List[str]:
    """Use provided headers if any; otherwise derive them from dict keys in
    first-seen order, or synthesize column_N names for list/tuple rows.
    """
    if provided_headers:
        return [str(header) for header in provided_headers]

    collected: List[str] = []
    list_lengths: List[int] = []

    for row in rows:
        if isinstance(row, dict):
            for key in row:
                key_str = str(key)
                if key_str not in collected:
                    collected.append(key_str)
        elif isinstance(row, (list, tuple)):
            list_lengths.append(len(row))

    if collected:
        return collected

    if list_lengths:
        max_len = max(list_lengths)
        return [f"column_{idx + 1}" for idx in range(max_len)]

    return ["column_1"]
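
# Behaviour sketch (follows directly from the code above):
#   derive_headers([{"name": "a", "age": 1}], None) -> ["name", "age"]
#   derive_headers([[1, 2], [3, 4, 5]], None)       -> ["column_1", "column_2", "column_3"]
#   derive_headers([], None)                        -> ["column_1"]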


def _stringify_cell(value: Any) -> str:
    if value is None:
        return ""
    if isinstance(value, (str, int, float, bool)):
        return str(value)
    try:
        return json.dumps(value, ensure_ascii=False)
    except (TypeError, ValueError):
        return str(value)


def rows_to_csv_text(
    rows: Sequence[Any], headers: Sequence[str], *, max_rows: int = 50
) -> str:
    """Render at most max_rows rows as CSV text; dict rows are ordered by headers."""
    buffer = StringIO()
    writer = csv.writer(buffer)

    if headers:
        writer.writerow(headers)

    for idx, row in enumerate(rows):
        if max_rows and idx >= max_rows:
            break
        if isinstance(row, dict):
            writer.writerow([_stringify_cell(row.get(header)) for header in headers])
        elif isinstance(row, (list, tuple)):
            writer.writerow([_stringify_cell(item) for item in row])
        else:
            writer.writerow([_stringify_cell(row)])

    return buffer.getvalue().strip()
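
# Example (csv.writer emits \r\n row terminators; strip() removes only the
# trailing one):
#   rows_to_csv_text([{"a": 1, "b": None}], ["a", "b"]) -> "a,b\r\n1,"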


def format_table_schema(schema: Any) -> str:
    if schema is None:
        return ""
    if isinstance(schema, str):
        return schema.strip()
    try:
        return json.dumps(schema, ensure_ascii=False, indent=2)
    except (TypeError, ValueError):
        return str(schema)


def build_analysis_request(
    request: DataImportAnalysisJobRequest,
) -> DataImportAnalysisRequest:
    headers = derive_headers(request.rows, request.headers)

    if request.raw_csv:
        csv_text = request.raw_csv.strip()
    else:
        csv_text = rows_to_csv_text(request.rows, headers)

    # Section labels are in Chinese to match the prompt: CSV样本预览 = CSV sample
    # preview, 附加结构信息 = additional schema info, 未提供样本数据。= no sample
    # data provided.
    sections: List[str] = []
    if csv_text:
        sections.append("CSV样本预览:\n" + csv_text)
    schema_text = format_table_schema(request.table_schema)
    if schema_text:
        sections.append("附加结构信息:\n" + schema_text)

    example_data = "\n\n".join(sections) if sections else "未提供样本数据。"

    # Truncate so that, including the three-character ellipsis, the sample never
    # exceeds max_length characters.
    max_length = 30000
    if len(example_data) > max_length:
        example_data = example_data[: max_length - 3] + "..."

    return DataImportAnalysisRequest(
        import_record_id=request.import_record_id,
        example_data=example_data,
        table_headers=headers,
        llm_model=request.llm_model or DEFAULT_IMPORT_MODEL,
    )


def build_chat_payload(request: DataImportAnalysisJobRequest) -> Dict[str, Any]:
    llm_input = request.llm_model or DEFAULT_IMPORT_MODEL
    provider, model_name = resolve_provider_from_model(llm_input)
    normalized_model = f"{provider.value}:{model_name}"

    # Enforce the allow-list only when it is non-empty.
    if SUPPORTED_IMPORT_MODELS and normalized_model not in SUPPORTED_IMPORT_MODELS:
        raise ProviderAPICallError(
            f"Model '{normalized_model}' is not allowed. "
            f"Allowed models: {', '.join(sorted(SUPPORTED_IMPORT_MODELS))}"
        )

    analysis_request = build_analysis_request(request)

    messages = build_import_messages(analysis_request)

    payload: Dict[str, Any] = {
        "provider": provider.value,
        "model": model_name,
        "messages": [message.model_dump() for message in messages],
        "temperature": request.temperature if request.temperature is not None else 0.2,
    }

    if request.max_output_tokens is not None:
        payload["max_tokens"] = request.max_output_tokens

    return payload
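
# Resulting payload shape (illustrative; message fields depend on LLMMessage's
# model_dump output):
#   {"provider": "openai", "model": "gpt-4o",
#    "messages": [{"role": "system", "content": ...},
#                 {"role": "user", "content": ...}],
#    "temperature": 0.2}   # plus "max_tokens" when max_output_tokens is set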


def _extract_json_payload(content: str) -> str:
    """Try to pull a JSON object from an LLM content string."""
    # Prefer fenced code blocks such as ```json { ... } ```
    fenced = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", content, flags=re.DOTALL | re.IGNORECASE)
    if fenced:
        return fenced.group(1).strip()

    stripped = content.strip()
    if stripped.startswith("{") and stripped.endswith("}"):
        return stripped

    # Fall back to the outermost brace pair anywhere in the text.
    start = stripped.find("{")
    end = stripped.rfind("}")
    if start != -1 and end != -1 and end > start:
        return stripped[start : end + 1].strip()

    return stripped
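
# Extraction examples (derived from the branches above):
#   '```json\n{"a": 1}\n```'  -> '{"a": 1}'   (fenced block)
#   '{"a": 1}'                -> '{"a": 1}'   (already bare JSON)
#   'Result: {"a": 1} done'   -> '{"a": 1}'   (outermost brace pair)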


def parse_llm_analysis_json(llm_response: LLMResponse) -> Dict[str, Any]:
    """Extract and parse the structured JSON payload from an LLM response."""
    if not llm_response.choices:
        raise ProviderAPICallError("LLM response did not include any choices to parse.")

    content = llm_response.choices[0].message.content or ""
    if not content.strip():
        raise ProviderAPICallError("LLM response content is empty.")

    json_payload = _extract_json_payload(content)

    try:
        return json.loads(json_payload)
    except json.JSONDecodeError as exc:
        preview = json_payload[:10000]
        logger.error("Failed to parse JSON from LLM response content: %s", preview, exc_info=True)
        raise ProviderAPICallError("LLM response JSON could not be parsed.") from exc


async def dispatch_import_analysis_job(
    request: DataImportAnalysisJobRequest,
    client: httpx.AsyncClient,
) -> Dict[str, Any]:
    logger.info("Starting import analysis job %s", request.import_record_id)

    payload = build_chat_payload(request)
    url = f"{IMPORT_GATEWAY_BASE_URL.rstrip('/')}/v1/chat/completions"

    logger.info(
        "Dispatching import %s to %s: %s",
        request.import_record_id,
        url,
        json.dumps(payload, ensure_ascii=False),
    )

    timeout = httpx.Timeout(IMPORT_CHAT_TIMEOUT_SECONDS)

    try:
        response = await client.post(url, json=payload, timeout=timeout)
        response.raise_for_status()
    except httpx.HTTPStatusError as exc:
        # httpx guarantees a response on HTTPStatusError; keep a short preview only.
        body_preview = exc.response.text[:400]
        raise ProviderAPICallError(
            f"Failed to invoke chat completions endpoint: {exc}. Response body: {body_preview}"
        ) from exc
    except httpx.HTTPError as exc:
        raise ProviderAPICallError(
            f"HTTP call to chat completions endpoint failed: {exc}"
        ) from exc

    try:
        response_data = response.json()
    except ValueError as exc:
        raise ProviderAPICallError("Chat completions endpoint returned invalid JSON") from exc

    logger.info(
        "LLM HTTP status for %s: %s",
        request.import_record_id,
        response.status_code,
    )
    logger.info(
        "LLM response for %s: %s",
        request.import_record_id,
        json.dumps(response_data, ensure_ascii=False),
    )

    try:
        llm_response = LLMResponse.model_validate(response_data)
    except ValidationError as exc:
        raise ProviderAPICallError(
            "Chat completions endpoint returned unexpected schema"
        ) from exc

    structured_json = parse_llm_analysis_json(llm_response)
    usage_data = extract_usage(llm_response.raw)

    logger.info("Completed import analysis job %s", request.import_record_id)
    result: Dict[str, Any] = {
        "import_record_id": request.import_record_id,
        "status": "succeeded",
        "llm_response": llm_response.model_dump(),
        "analysis": structured_json,
    }

    if usage_data:
        result["usage"] = usage_data

    return result
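
# Success payload shape delivered downstream (from the result dict above):
#   {"import_record_id": ..., "status": "succeeded",
#    "llm_response": {...}, "analysis": {...}}   # plus "usage" when available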


# Extract token usage across the different provider response shapes
# (OpenAI/Anthropic-style "usage", Gemini-style "usageMetadata").
def extract_usage(resp_json: dict | None) -> dict:
    if not resp_json:
        return {}
    usage = resp_json.get("usage") or resp_json.get("usageMetadata") or {}
    if not usage:
        # Return an empty dict so callers can use a simple truthiness check.
        return {}
    return {
        "prompt_tokens": usage.get("prompt_tokens") or usage.get("input_tokens") or usage.get("promptTokenCount"),
        "completion_tokens": usage.get("completion_tokens") or usage.get("output_tokens") or usage.get("candidatesTokenCount"),
        "total_tokens": usage.get("total_tokens") or usage.get("totalTokenCount") or (
            (usage.get("prompt_tokens") or usage.get("input_tokens") or 0)
            + (usage.get("completion_tokens") or usage.get("output_tokens") or 0)
        ),
    }
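
# Provider usage shapes the mapping above accounts for (well-known API formats):
#   OpenAI:    {"usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}}
#   Anthropic: {"usage": {"input_tokens": 10, "output_tokens": 5}}
#   Gemini:    {"usageMetadata": {"promptTokenCount": 10, "candidatesTokenCount": 5, "totalTokenCount": 15}}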


async def notify_import_analysis_callback(
    callback_url: str,
    payload: Dict[str, Any],
    client: httpx.AsyncClient,
) -> None:
    """POST the job result to the caller's callback URL.

    Delivery failures are logged rather than raised, so a broken callback
    endpoint cannot crash the worker.
    """
    callback_target = str(callback_url)
    logger.info(
        "Posting import analysis callback to %s: %s",
        callback_target,
        json.dumps(payload, ensure_ascii=False),
    )

    try:
        response = await client.post(callback_target, json=payload)
        response.raise_for_status()
    except httpx.HTTPError as exc:
        logger.error(
            "Failed to deliver import analysis callback to %s: %s",
            callback_target,
            exc,
        )


async def process_import_analysis_job(
    request: DataImportAnalysisJobRequest,
    client: httpx.AsyncClient,
) -> None:
    """Run the analysis job and always deliver a callback, success or failure."""
    try:
        payload = await dispatch_import_analysis_job(request, client)
    except ProviderAPICallError as exc:
        logger.error(
            "LLM call failed for %s: %s",
            request.import_record_id,
            exc,
        )
        payload = {
            "import_record_id": request.import_record_id,
            "status": "failed",
            "error": str(exc),
        }
    except Exception as exc:  # pragma: no cover - defensive logging
        logger.exception(
            "Unexpected failure while processing import analysis job %s",
            request.import_record_id,
        )
        payload = {
            "import_record_id": request.import_record_id,
            "status": "failed",
            "error": str(exc),
        }

    await notify_import_analysis_callback(request.callback_url, payload, client)