from __future__ import annotations import csv import json import logging import os import re from functools import lru_cache from io import StringIO from pathlib import Path from typing import Any, Dict, List, Sequence, Tuple import httpx from pydantic import ValidationError from app.exceptions import ProviderAPICallError from app.models import ( DataImportAnalysisJobRequest, DataImportAnalysisRequest, LLMMessage, LLMProvider, LLMResponse, LLMRole, ) from app.settings import DEFAULT_IMPORT_MODEL, get_supported_import_models logger = logging.getLogger(__name__) IMPORT_GATEWAY_BASE_URL = os.getenv( "IMPORT_GATEWAY_BASE_URL", "http://localhost:8000" ) def _env_float(name: str, default: float) -> float: raw = os.getenv(name) if raw is None: return default try: return float(raw) except ValueError: logger.warning("Invalid value for %s=%r, falling back to %.2f", name, raw, default) return default IMPORT_CHAT_TIMEOUT_SECONDS = _env_float("IMPORT_CHAT_TIMEOUT_SECONDS", 120.0) SUPPORTED_IMPORT_MODELS = get_supported_import_models() def resolve_provider_from_model(llm_model: str) -> Tuple[LLMProvider, str]: """Resolve provider based on the llm_model string. The llm_model may be provided as 'provider:model' or 'provider/model'. If no provider prefix is present, try an educated guess from common model name patterns. """ normalized = llm_model.strip() provider_hint: str | None = None model_name = normalized for delimiter in (":", "/", "|"): if delimiter in normalized: provider_hint, model_name = normalized.split(delimiter, 1) provider_hint = provider_hint.strip().lower() model_name = model_name.strip() break provider_map = {provider.value: provider for provider in LLMProvider} if provider_hint: if provider_hint not in provider_map: raise ValueError( f"Unsupported provider '{provider_hint}'. Expected one of: {', '.join(provider_map.keys())}." ) return provider_map[provider_hint], model_name return _guess_provider_from_model(model_name), model_name def _guess_provider_from_model(model_name: str) -> LLMProvider: lowered = model_name.lower() if lowered.startswith(("gpt", "o1", "text-", "dall-e", "whisper")): return LLMProvider.OPENAI if lowered.startswith(("claude", "anthropic")): return LLMProvider.ANTHROPIC if lowered.startswith(("gemini", "models/gemini")): return LLMProvider.GEMINI if lowered.startswith("qwen"): return LLMProvider.QWEN if lowered.startswith("deepseek"): return LLMProvider.DEEPSEEK if lowered.startswith(("openrouter", "router-")): return LLMProvider.OPENROUTER supported = ", ".join(provider.value for provider in LLMProvider) raise ValueError( f"Unable to infer provider from model '{model_name}'. " f"Please prefix with 'provider:model'. Supported providers: {supported}." ) def build_import_messages( request: DataImportAnalysisRequest, ) -> List[LLMMessage]: """Create system and user messages for the import analysis prompt.""" headers_formatted = "\n".join(f"- {header}" for header in request.table_headers) system_prompt = load_import_template() data_block = ( f"导入记录ID: {request.import_record_id}\n\n" "表头信息:\n" f"{headers_formatted}\n\n" "示例数据:\n" f"{request.example_data}" ) return [ LLMMessage(role=LLMRole.SYSTEM, content=system_prompt), LLMMessage(role=LLMRole.USER, content=data_block), ] @lru_cache(maxsize=1) def load_import_template() -> str: template_path = ( Path(__file__).resolve().parents[2] / "prompt" / "data_import_analysis.md" ) if not template_path.exists(): raise FileNotFoundError(f"Prompt template not found at {template_path}") return template_path.read_text(encoding="utf-8").strip() def derive_headers( rows: Sequence[Any], provided_headers: Sequence[str] | None ) -> List[str]: if provided_headers: return [str(header) for header in provided_headers] collected: List[str] = [] list_lengths: List[int] = [] for row in rows: if isinstance(row, dict): for key in row.keys(): key_str = str(key) if key_str not in collected: collected.append(key_str) elif isinstance(row, (list, tuple)): list_lengths.append(len(row)) if collected: return collected if list_lengths: max_len = max(list_lengths) return [f"column_{idx + 1}" for idx in range(max_len)] return ["column_1"] def _stringify_cell(value: Any) -> str: if value is None: return "" if isinstance(value, (str, int, float, bool)): return str(value) try: return json.dumps(value, ensure_ascii=False) except (TypeError, ValueError): return str(value) def rows_to_csv_text( rows: Sequence[Any], headers: Sequence[str], *, max_rows: int = 50 ) -> str: buffer = StringIO() writer = csv.writer(buffer) if headers: writer.writerow(headers) for idx, row in enumerate(rows): if max_rows and idx >= max_rows: break if isinstance(row, dict): writer.writerow([_stringify_cell(row.get(header)) for header in headers]) elif isinstance(row, (list, tuple)): writer.writerow([_stringify_cell(item) for item in row]) else: writer.writerow([_stringify_cell(row)]) return buffer.getvalue().strip() def format_table_schema(schema: Any) -> str: if schema is None: return "" if isinstance(schema, str): return schema.strip() try: return json.dumps(schema, ensure_ascii=False, indent=2) except (TypeError, ValueError): return str(schema) def build_analysis_request( request: DataImportAnalysisJobRequest, ) -> DataImportAnalysisRequest: headers = derive_headers(request.rows, request.headers) if request.raw_csv: csv_text = request.raw_csv.strip() else: csv_text = rows_to_csv_text(request.rows, headers) sections: List[str] = [] if csv_text: sections.append("CSV样本预览:\n" + csv_text) schema_text = format_table_schema(request.table_schema) if schema_text: sections.append("附加结构信息:\n" + schema_text) example_data = "\n\n".join(sections) if sections else "未提供样本数据。" max_length = 30000 if len(example_data) > max_length: example_data = example_data[: max_length - 3] + "..." return DataImportAnalysisRequest( import_record_id=request.import_record_id, example_data=example_data, table_headers=headers, llm_model=request.llm_model or DEFAULT_IMPORT_MODEL, ) def build_chat_payload(request: DataImportAnalysisJobRequest) -> Dict[str, Any]: llm_input = request.llm_model or DEFAULT_IMPORT_MODEL provider, model_name = resolve_provider_from_model(llm_input) normalized_model = f"{provider.value}:{model_name}" if SUPPORTED_IMPORT_MODELS and normalized_model not in SUPPORTED_IMPORT_MODELS: raise ProviderAPICallError( "Model '{model}' is not allowed. Allowed models: {allowed}".format( model=normalized_model, allowed=", ".join(sorted(SUPPORTED_IMPORT_MODELS)), ) ) analysis_request = build_analysis_request(request) messages = build_import_messages(analysis_request) payload: Dict[str, Any] = { "provider": provider.value, "model": model_name, "messages": [message.model_dump() for message in messages], "temperature": request.temperature if request.temperature is not None else 0.2, } if request.max_output_tokens is not None: payload["max_tokens"] = request.max_output_tokens return payload def _extract_json_payload(content: str) -> str: """Try to pull a JSON object from an LLM content string.""" # Prefer fenced code blocks such as ```json { ... } ``` fenced = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", content, flags=re.DOTALL | re.IGNORECASE) if fenced: return fenced.group(1).strip() stripped = content.strip() if stripped.startswith("{") and stripped.endswith("}"): return stripped start = stripped.find("{") end = stripped.rfind("}") if start != -1 and end != -1 and end > start: return stripped[start : end + 1].strip() return stripped def parse_llm_analysis_json(llm_response: LLMResponse) -> Dict[str, Any]: """Extract and parse the structured JSON payload from an LLM response.""" if not llm_response.choices: raise ProviderAPICallError("LLM response did not include any choices to parse.") content = llm_response.choices[0].message.content or "" if not content.strip(): raise ProviderAPICallError("LLM response content is empty.") json_payload = _extract_json_payload(content) try: return json.loads(json_payload) except json.JSONDecodeError as exc: preview = json_payload[:10000] logger.error("Failed to parse JSON from LLM response content: %s", preview, exc_info=True) raise ProviderAPICallError("LLM response JSON could not be parsed.") from exc async def dispatch_import_analysis_job( request: DataImportAnalysisJobRequest, client: httpx.AsyncClient, ) -> Dict[str, Any]: logger.info("Starting import analysis job %s", request.import_record_id) payload = build_chat_payload(request) url = f"{IMPORT_GATEWAY_BASE_URL.rstrip('/')}/v1/chat/completions" logger.info( "Dispatching import %s to %s: %s", request.import_record_id, url, json.dumps(payload, ensure_ascii=False), ) timeout = httpx.Timeout(IMPORT_CHAT_TIMEOUT_SECONDS) try: response = await client.post(url, json=payload, timeout=timeout) response.raise_for_status() except httpx.HTTPStatusError as exc: body_preview = "" if exc.response is not None: body_preview = exc.response.text[:400] raise ProviderAPICallError( f"Failed to invoke chat completions endpoint: {exc}. Response body: {body_preview}" ) from exc except httpx.HTTPError as exc: raise ProviderAPICallError( f"HTTP call to chat completions endpoint failed: {exc}" ) from exc try: response_data = response.json() except ValueError as exc: raise ProviderAPICallError("Chat completions endpoint returned invalid JSON") from exc logger.info( "LLM HTTP status for %s: %s", request.import_record_id, response.status_code, ) logger.info( "LLM response for %s: %s", request.import_record_id, json.dumps(response_data, ensure_ascii=False), ) try: llm_response = LLMResponse.model_validate(response_data) except ValidationError as exc: raise ProviderAPICallError( "Chat completions endpoint returned unexpected schema" ) from exc structured_json = parse_llm_analysis_json(llm_response) usage_data = extract_usage(llm_response.raw) logger.info("Completed import analysis job %s", request.import_record_id) result: Dict[str, Any] = { "import_record_id": request.import_record_id, "status": "succeeded", "llm_response": llm_response.model_dump(), "analysis": structured_json } if usage_data: result["usage"] = usage_data return result # 兼容处理多模型的使用量字段提取 def extract_usage(resp_json: dict) -> dict: usage = resp_json.get("usage") or resp_json.get("usageMetadata") or {} return { "prompt_tokens": usage.get("prompt_tokens") or usage.get("input_tokens") or usage.get("promptTokenCount"), "completion_tokens": usage.get("completion_tokens") or usage.get("output_tokens") or usage.get("candidatesTokenCount"), "total_tokens": usage.get("total_tokens") or usage.get("totalTokenCount") or ( (usage.get("prompt_tokens") or usage.get("input_tokens") or 0) + (usage.get("completion_tokens") or usage.get("output_tokens") or 0) ) } async def notify_import_analysis_callback( callback_url: str, payload: Dict[str, Any], client: httpx.AsyncClient, ) -> None: callback_target = str(callback_url) logger.info( "Posting import analysis callback to %s: %s", callback_target, json.dumps(payload, ensure_ascii=False), ) try: response = await client.post(callback_target, json=payload) response.raise_for_status() except httpx.HTTPError as exc: logger.error( "Failed to deliver import analysis callback to %s: %s", callback_target, exc, ) async def process_import_analysis_job( request: DataImportAnalysisJobRequest, client: httpx.AsyncClient, ) -> None: try: payload = await dispatch_import_analysis_job(request, client) except ProviderAPICallError as exc: logger.error( "LLM call failed for %s: %s", request.import_record_id, exc, ) payload = { "import_record_id": request.import_record_id, "status": "failed", "error": str(exc), } except Exception as exc: # pragma: no cover - defensive logging logger.exception( "Unexpected failure while processing import analysis job %s", request.import_record_id, ) payload = { "import_record_id": request.import_record_id, "status": "failed", "error": str(exc), } await notify_import_analysis_callback(request.callback_url, payload, client)