from __future__ import annotations import base64 import csv import json import logging import os from functools import lru_cache from io import StringIO from pathlib import Path from typing import Any, Dict, List, Sequence, Tuple import httpx from app.exceptions import ProviderAPICallError from app.models import ( DataImportAnalysisJobRequest, DataImportAnalysisRequest, LLMChoice, LLMMessage, LLMProvider, LLMResponse, LLMRole, ) logger = logging.getLogger(__name__) OPENAI_RESPONSES_URL = "https://api.openai.com/v1/responses" def resolve_provider_from_model(llm_model: str) -> Tuple[LLMProvider, str]: """Resolve provider based on the llm_model string. The llm_model may be provided as 'provider:model' or 'provider/model'. If no provider prefix is present, try an educated guess from common model name patterns. """ normalized = llm_model.strip() provider_hint: str | None = None model_name = normalized for delimiter in (":", "/", "|"): if delimiter in normalized: provider_hint, model_name = normalized.split(delimiter, 1) provider_hint = provider_hint.strip().lower() model_name = model_name.strip() break provider_map = {provider.value: provider for provider in LLMProvider} if provider_hint: if provider_hint not in provider_map: raise ValueError( f"Unsupported provider '{provider_hint}'. Expected one of: {', '.join(provider_map.keys())}." ) return provider_map[provider_hint], model_name return _guess_provider_from_model(model_name), model_name def _guess_provider_from_model(model_name: str) -> LLMProvider: lowered = model_name.lower() if lowered.startswith(("gpt", "o1", "text-", "dall-e", "whisper")): return LLMProvider.OPENAI if lowered.startswith(("claude", "anthropic")): return LLMProvider.ANTHROPIC if lowered.startswith(("gemini", "models/gemini")): return LLMProvider.GEMINI if lowered.startswith("qwen"): return LLMProvider.QWEN if lowered.startswith("deepseek"): return LLMProvider.DEEPSEEK if lowered.startswith(("openrouter", "router-")): return LLMProvider.OPENROUTER supported = ", ".join(provider.value for provider in LLMProvider) raise ValueError( f"Unable to infer provider from model '{model_name}'. " f"Please prefix with 'provider:model'. Supported providers: {supported}." ) def build_import_messages( request: DataImportAnalysisRequest, ) -> List[LLMMessage]: """Create system and user messages for the import analysis prompt.""" headers_formatted = "\n".join(f"- {header}" for header in request.table_headers) system_prompt = load_import_template() data_block = ( f"导入记录ID: {request.import_record_id}\n\n" "表头信息:\n" f"{headers_formatted}\n\n" "示例数据:\n" f"{request.example_data}" ) return [ LLMMessage(role=LLMRole.SYSTEM, content=system_prompt), LLMMessage(role=LLMRole.USER, content=data_block), ] @lru_cache(maxsize=1) def load_import_template() -> str: template_path = ( Path(__file__).resolve().parents[2] / "prompt" / "data_import_analysis.md" ) if not template_path.exists(): raise FileNotFoundError(f"Prompt template not found at {template_path}") return template_path.read_text(encoding="utf-8").strip() def derive_headers(rows: Sequence[Any], provided_headers: Sequence[str] | None) -> List[str]: if provided_headers: return list(provided_headers) seen: List[str] = [] for row in rows: if isinstance(row, dict): for key in row.keys(): if key not in seen: seen.append(str(key)) return seen def _stringify_cell(value: Any) -> str: if value is None: return "" if isinstance(value, (str, int, float, bool)): return str(value) try: return json.dumps(value, ensure_ascii=False) except (TypeError, ValueError): return str(value) def rows_to_csv_bytes(rows: Sequence[Any], headers: List[str]) -> bytes: buffer = StringIO() writer = csv.writer(buffer) if headers: writer.writerow(headers) for row in rows: if isinstance(row, dict): writer.writerow([_stringify_cell(row.get(header)) for header in headers]) elif isinstance(row, (list, tuple)): writer.writerow([_stringify_cell(item) for item in row]) else: writer.writerow([_stringify_cell(row)]) return buffer.getvalue().encode("utf-8") def build_file_part(filename: str, payload: bytes, mime_type: str) -> Dict[str, Any]: encoded = base64.b64encode(payload).decode("ascii") return { "type": "input_file", "input_file": { "file_data": { "filename": filename, "file_name": filename, "mime_type": mime_type, "b64_json": encoded, } }, } def build_schema_part( request: DataImportAnalysisJobRequest, headers: List[str] ) -> Dict[str, Any] | None: if request.table_schema is not None: if isinstance(request.table_schema, str): schema_bytes = request.table_schema.encode("utf-8") return build_file_part("table_schema.txt", schema_bytes, "text/plain") try: schema_serialised = json.dumps( request.table_schema, ensure_ascii=False, indent=2 ) except (TypeError, ValueError) as exc: logger.warning("Failed to serialise table_schema for %s: %s", request.import_record_id, exc) schema_serialised = str(request.table_schema) return build_file_part( "table_schema.json", schema_serialised.encode("utf-8"), "application/json", ) if headers: headers_payload = json.dumps({"headers": headers}, ensure_ascii=False, indent=2) return build_file_part( "table_headers.json", headers_payload.encode("utf-8"), "application/json", ) return None def build_openai_input_payload( request: DataImportAnalysisJobRequest, ) -> Dict[str, Any]: headers = derive_headers(request.rows, request.headers) csv_bytes = rows_to_csv_bytes(request.rows, headers) csv_part = build_file_part("sample_rows.csv", csv_bytes, "text/csv") schema_part = build_schema_part(request, headers) prompt = load_import_template() context_lines = [ f"导入记录ID: {request.import_record_id}", f"样本数据行数: {len(request.rows)}", "请参考附件 `sample_rows.csv` 获取原始样本数据。", ] if schema_part: context_lines.append("附加结构信息来自第二个附件,请结合使用。") else: context_lines.append("未提供表头或Schema,请依据数据自行推断字段信息。") user_content = [ {"type": "input_text", "text": "\n".join(context_lines)}, csv_part, ] if schema_part: user_content.append(schema_part) payload: Dict[str, Any] = { "model": request.llm_model, "input": [ {"role": "system", "content": [{"type": "input_text", "text": prompt}]}, {"role": "user", "content": user_content}, ], } if request.temperature is not None: payload["temperature"] = request.temperature if request.max_output_tokens is not None: payload["max_output_tokens"] = request.max_output_tokens return payload def parse_openai_responses_payload( data: Dict[str, Any], fallback_model: str ) -> LLMResponse: output_blocks = data.get("output", []) choices: List[LLMChoice] = [] for idx, block in enumerate(output_blocks): if block.get("type") != "message": continue content_items = block.get("content", []) text_fragments: List[str] = [] for item in content_items: if item.get("type") == "output_text": text_fragments.append(item.get("text", "")) if not text_fragments and data.get("output_text"): text_fragments.append(data.get("output_text", "")) message = LLMMessage(role=LLMRole.ASSISTANT, content="\n".join(text_fragments)) choices.append(LLMChoice(index=idx, message=message)) if not choices and data.get("output_text"): message = LLMMessage(role=LLMRole.ASSISTANT, content=data.get("output_text", "")) choices.append(LLMChoice(index=0, message=message)) return LLMResponse( provider=LLMProvider.OPENAI, model=data.get("model", fallback_model), choices=choices, raw=data, ) async def call_openai_import_analysis( request: DataImportAnalysisJobRequest, client: httpx.AsyncClient, *, api_key: str | None = None, ) -> LLMResponse: openai_api_key = api_key or os.getenv("OPENAI_API_KEY") if not openai_api_key: raise ProviderAPICallError("OPENAI_API_KEY must be set to process import analysis.") payload = build_openai_input_payload(request) headers = { "Authorization": f"Bearer {openai_api_key}", "Content-Type": "application/json", } response = await client.post(OPENAI_RESPONSES_URL, json=payload, headers=headers) try: response.raise_for_status() except httpx.HTTPError as exc: raise ProviderAPICallError(f"OpenAI response API call failed: {exc}") from exc data: Dict[str, Any] = response.json() return parse_openai_responses_payload(data, request.llm_model) async def dispatch_import_analysis_job( request: DataImportAnalysisJobRequest, client: httpx.AsyncClient, *, api_key: str | None = None, ) -> Dict[str, Any]: logger.info("Starting import analysis job %s", request.import_record_id) llm_response = await call_openai_import_analysis(request, client, api_key=api_key) result = { "import_record_id": request.import_record_id, "status": "succeeded", "llm_response": llm_response.model_dump(), } logger.info("Completed import analysis job %s", request.import_record_id) return result async def notify_import_analysis_callback( callback_url: str, payload: Dict[str, Any], client: httpx.AsyncClient, ) -> None: try: response = await client.post(callback_url, json=payload) response.raise_for_status() except httpx.HTTPError as exc: logger.error( "Failed to deliver import analysis callback to %s: %s", callback_url, exc, ) async def process_import_analysis_job( request: DataImportAnalysisJobRequest, client: httpx.AsyncClient, *, api_key: str | None = None, ) -> None: try: payload = await dispatch_import_analysis_job( request, client, api_key=api_key, ) except ProviderAPICallError as exc: payload = { "import_record_id": request.import_record_id, "status": "failed", "error": str(exc), } except Exception as exc: # pragma: no cover - defensive logging logger.exception( "Unexpected failure while processing import analysis job %s", request.import_record_id, ) payload = { "import_record_id": request.import_record_id, "status": "failed", "error": str(exc), } await notify_import_analysis_callback(request.callback_url, payload, client)