# data-ge/app/services/import_analysis.py

from __future__ import annotations
import base64
import csv
import json
import logging
import os
from functools import lru_cache
from io import StringIO
from pathlib import Path
from typing import Any, Dict, List, Sequence, Tuple
import httpx
from app.exceptions import ProviderAPICallError
from app.models import (
    DataImportAnalysisJobRequest,
    DataImportAnalysisRequest,
    LLMChoice,
    LLMMessage,
    LLMProvider,
    LLMResponse,
    LLMRole,
)

logger = logging.getLogger(__name__)

OPENAI_RESPONSES_URL = "https://api.openai.com/v1/responses"

def resolve_provider_from_model(llm_model: str) -> Tuple[LLMProvider, str]:
    """Resolve the provider from an ``llm_model`` string.

    The model may be given as 'provider:model', 'provider/model', or
    'provider|model'. If no provider prefix is present, fall back to an
    educated guess based on common model-name patterns.
    """
    normalized = llm_model.strip()
    provider_hint: str | None = None
    model_name = normalized
    for delimiter in (":", "/", "|"):
        if delimiter in normalized:
            provider_hint, model_name = normalized.split(delimiter, 1)
            provider_hint = provider_hint.strip().lower()
            model_name = model_name.strip()
            break
    provider_map = {provider.value: provider for provider in LLMProvider}
    if provider_hint:
        if provider_hint not in provider_map:
            raise ValueError(
                f"Unsupported provider '{provider_hint}'. "
                f"Expected one of: {', '.join(provider_map.keys())}."
            )
        return provider_map[provider_hint], model_name
    return _guess_provider_from_model(model_name), model_name

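# A quick usage sketch for the resolver. The enum values shown assume
# LLMProvider uses lowercase provider names such as "openai":
#
#     resolve_provider_from_model("openai:gpt-4o")   # -> (LLMProvider.OPENAI, "gpt-4o")
#     resolve_provider_from_model("claude-3-opus")   # -> (LLMProvider.ANTHROPIC, "claude-3-opus")
#     resolve_provider_from_model("mystery-model")   # -> raises ValueError
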
def _guess_provider_from_model(model_name: str) -> LLMProvider:
    lowered = model_name.lower()
    if lowered.startswith(("gpt", "o1", "text-", "dall-e", "whisper")):
        return LLMProvider.OPENAI
    if lowered.startswith(("claude", "anthropic")):
        return LLMProvider.ANTHROPIC
    if lowered.startswith(("gemini", "models/gemini")):
        return LLMProvider.GEMINI
    if lowered.startswith("qwen"):
        return LLMProvider.QWEN
    if lowered.startswith("deepseek"):
        return LLMProvider.DEEPSEEK
    if lowered.startswith(("openrouter", "router-")):
        return LLMProvider.OPENROUTER
    supported = ", ".join(provider.value for provider in LLMProvider)
    raise ValueError(
        f"Unable to infer provider from model '{model_name}'. "
        f"Please prefix with 'provider:model'. Supported providers: {supported}."
    )

def build_import_messages(
    request: DataImportAnalysisRequest,
) -> List[LLMMessage]:
    """Create the system and user messages for the import analysis prompt."""
    headers_formatted = "\n".join(f"- {header}" for header in request.table_headers)
    system_prompt = load_import_template()
    # The user block is written in Chinese to match the prompt template:
    # "Import record ID", "Table headers:", "Sample data:".
    data_block = (
        f"导入记录ID: {request.import_record_id}\n\n"
        "表头信息:\n"
        f"{headers_formatted}\n\n"
        "示例数据:\n"
        f"{request.example_data}"
    )
    return [
        LLMMessage(role=LLMRole.SYSTEM, content=system_prompt),
        LLMMessage(role=LLMRole.USER, content=data_block),
    ]

@lru_cache(maxsize=1)
def load_import_template() -> str:
    """Load and cache the markdown prompt template shipped with the app."""
    template_path = (
        Path(__file__).resolve().parents[2] / "prompt" / "data_import_analysis.md"
    )
    if not template_path.exists():
        raise FileNotFoundError(f"Prompt template not found at {template_path}")
    return template_path.read_text(encoding="utf-8").strip()

def derive_headers(rows: Sequence[Any], provided_headers: Sequence[str] | None) -> List[str]:
    """Return explicit headers if provided, else collect dict keys in first-seen order."""
    if provided_headers:
        return list(provided_headers)
    seen: List[str] = []
    for row in rows:
        if isinstance(row, dict):
            for key in row.keys():
                # Compare the stringified key so non-string keys are not
                # appended once per row.
                key_str = str(key)
                if key_str not in seen:
                    seen.append(key_str)
    return seen

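# Illustrative behaviour (hypothetical rows, not from the codebase):
#
#     derive_headers([{"name": "a", "age": 1}, {"name": "b", "city": "x"}], None)
#     # -> ["name", "age", "city"]   (first-seen order across all dict rows)
#     derive_headers([{"name": "a"}], ["col_a", "col_b"])
#     # -> ["col_a", "col_b"]        (explicit headers win)
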
def _stringify_cell(value: Any) -> str:
    if value is None:
        return ""
    if isinstance(value, (str, int, float, bool)):
        return str(value)
    try:
        return json.dumps(value, ensure_ascii=False)
    except (TypeError, ValueError):
        return str(value)

def rows_to_csv_bytes(rows: Sequence[Any], headers: List[str]) -> bytes:
    buffer = StringIO()
    writer = csv.writer(buffer)
    if headers:
        writer.writerow(headers)
    for row in rows:
        if isinstance(row, dict):
            writer.writerow([_stringify_cell(row.get(header)) for header in headers])
        elif isinstance(row, (list, tuple)):
            writer.writerow([_stringify_cell(item) for item in row])
        else:
            writer.writerow([_stringify_cell(row)])
    return buffer.getvalue().encode("utf-8")

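# Minimal sketch of the CSV helper on hypothetical rows:
#
#     rows = [{"id": 1, "tags": ["a", "b"]}]
#     rows_to_csv_bytes(rows, derive_headers(rows, None))
#     # -> b'id,tags\r\n1,"[""a"", ""b""]"\r\n'  (nested values are JSON-encoded)
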
def build_file_part(filename: str, payload: bytes, mime_type: str) -> Dict[str, Any]:
    encoded = base64.b64encode(payload).decode("ascii")
    return {
        "type": "input_file",
        "input_file": {
            "file_data": {
                "filename": filename,
                "file_name": filename,
                "mime_type": mime_type,
                "b64_json": encoded,
            }
        },
    }

def build_schema_part(
    request: DataImportAnalysisJobRequest, headers: List[str]
) -> Dict[str, Any] | None:
    if request.table_schema is not None:
        if isinstance(request.table_schema, str):
            schema_bytes = request.table_schema.encode("utf-8")
            return build_file_part("table_schema.txt", schema_bytes, "text/plain")
        try:
            schema_serialised = json.dumps(
                request.table_schema, ensure_ascii=False, indent=2
            )
        except (TypeError, ValueError) as exc:
            logger.warning(
                "Failed to serialise table_schema for %s: %s",
                request.import_record_id,
                exc,
            )
            schema_serialised = str(request.table_schema)
        return build_file_part(
            "table_schema.json",
            schema_serialised.encode("utf-8"),
            "application/json",
        )
    if headers:
        headers_payload = json.dumps({"headers": headers}, ensure_ascii=False, indent=2)
        return build_file_part(
            "table_headers.json",
            headers_payload.encode("utf-8"),
            "application/json",
        )
    return None

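# Attachment selection above, in order of precedence:
#   table_schema as str    -> table_schema.txt   (text/plain)
#   table_schema as object -> table_schema.json  (application/json)
#   derived headers only   -> table_headers.json (application/json)
#   none of the above      -> no second attachment (returns None)
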
def build_openai_input_payload(
    request: DataImportAnalysisJobRequest,
) -> Dict[str, Any]:
    headers = derive_headers(request.rows, request.headers)
    csv_bytes = rows_to_csv_bytes(request.rows, headers)
    csv_part = build_file_part("sample_rows.csv", csv_bytes, "text/csv")
    schema_part = build_schema_part(request, headers)
    prompt = load_import_template()
    # Context lines are in Chinese to match the prompt template; they give the
    # import record ID and sample row count, and point at the CSV attachment.
    context_lines = [
        f"导入记录ID: {request.import_record_id}",
        f"样本数据行数: {len(request.rows)}",
        "请参考附件 `sample_rows.csv` 获取原始样本数据。",
    ]
    if schema_part:
        # "The structure info comes from the second attachment; use it together."
        context_lines.append("附加结构信息来自第二个附件,请结合使用。")
    else:
        # "No headers or schema provided; infer field info from the data."
        context_lines.append("未提供表头或Schema,请依据数据自行推断字段信息。")
    user_content = [
        {"type": "input_text", "text": "\n".join(context_lines)},
        csv_part,
    ]
    if schema_part:
        user_content.append(schema_part)
    payload: Dict[str, Any] = {
        "model": request.llm_model,
        "input": [
            {"role": "system", "content": [{"type": "input_text", "text": prompt}]},
            {"role": "user", "content": user_content},
        ],
    }
    if request.temperature is not None:
        payload["temperature"] = request.temperature
    if request.max_output_tokens is not None:
        payload["max_output_tokens"] = request.max_output_tokens
    return payload

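# Resulting payload shape, roughly (values abbreviated):
#
#     {
#         "model": "...",
#         "input": [
#             {"role": "system", "content": [{"type": "input_text", "text": "<template>"}]},
#             {"role": "user", "content": [<context text>, <sample_rows.csv part>, <schema part?>]},
#         ],
#         "temperature": ...,          # only if set on the request
#         "max_output_tokens": ...,    # only if set on the request
#     }
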
def parse_openai_responses_payload(
    data: Dict[str, Any], fallback_model: str
) -> LLMResponse:
    """Normalise a Responses API payload into the internal LLMResponse model."""
    output_blocks = data.get("output", [])
    choices: List[LLMChoice] = []
    for idx, block in enumerate(output_blocks):
        if block.get("type") != "message":
            continue
        content_items = block.get("content", [])
        text_fragments: List[str] = []
        for item in content_items:
            if item.get("type") == "output_text":
                text_fragments.append(item.get("text", ""))
        # Fall back to the top-level convenience field when a message block
        # carries no output_text items of its own.
        if not text_fragments and data.get("output_text"):
            text_fragments.append(data.get("output_text", ""))
        message = LLMMessage(role=LLMRole.ASSISTANT, content="\n".join(text_fragments))
        choices.append(LLMChoice(index=idx, message=message))
    if not choices and data.get("output_text"):
        message = LLMMessage(role=LLMRole.ASSISTANT, content=data.get("output_text", ""))
        choices.append(LLMChoice(index=0, message=message))
    return LLMResponse(
        provider=LLMProvider.OPENAI,
        model=data.get("model", fallback_model),
        choices=choices,
        raw=data,
    )

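# Parsing sketch on a minimal, hypothetical Responses payload:
#
#     parse_openai_responses_payload(
#         {
#             "model": "gpt-4o",
#             "output": [
#                 {"type": "message",
#                  "content": [{"type": "output_text", "text": "{...analysis...}"}]},
#             ],
#         },
#         fallback_model="gpt-4o",
#     )
#     # -> LLMResponse with one assistant choice whose content is "{...analysis...}"
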
async def call_openai_import_analysis(
    request: DataImportAnalysisJobRequest,
    client: httpx.AsyncClient,
    *,
    api_key: str | None = None,
) -> LLMResponse:
    openai_api_key = api_key or os.getenv("OPENAI_API_KEY")
    if not openai_api_key:
        raise ProviderAPICallError("OPENAI_API_KEY must be set to process import analysis.")
    payload = build_openai_input_payload(request)
    headers = {
        "Authorization": f"Bearer {openai_api_key}",
        "Content-Type": "application/json",
    }
    response = await client.post(OPENAI_RESPONSES_URL, json=payload, headers=headers)
    try:
        response.raise_for_status()
    except httpx.HTTPError as exc:
        raise ProviderAPICallError(f"OpenAI Responses API call failed: {exc}") from exc
    data: Dict[str, Any] = response.json()
    return parse_openai_responses_payload(data, request.llm_model)

async def dispatch_import_analysis_job(
    request: DataImportAnalysisJobRequest,
    client: httpx.AsyncClient,
    *,
    api_key: str | None = None,
) -> Dict[str, Any]:
    logger.info("Starting import analysis job %s", request.import_record_id)
    llm_response = await call_openai_import_analysis(request, client, api_key=api_key)
    result = {
        "import_record_id": request.import_record_id,
        "status": "succeeded",
        "llm_response": llm_response.model_dump(),
    }
    logger.info("Completed import analysis job %s", request.import_record_id)
    return result

async def notify_import_analysis_callback(
    callback_url: str,
    payload: Dict[str, Any],
    client: httpx.AsyncClient,
) -> None:
    """Deliver the job result to the caller; delivery failures are logged, not raised."""
    try:
        response = await client.post(callback_url, json=payload)
        response.raise_for_status()
    except httpx.HTTPError as exc:
        logger.error(
            "Failed to deliver import analysis callback to %s: %s",
            callback_url,
            exc,
        )

async def process_import_analysis_job(
    request: DataImportAnalysisJobRequest,
    client: httpx.AsyncClient,
    *,
    api_key: str | None = None,
) -> None:
    """Run the job end-to-end and always report the outcome to the callback URL."""
    try:
        payload = await dispatch_import_analysis_job(
            request,
            client,
            api_key=api_key,
        )
    except ProviderAPICallError as exc:
        payload = {
            "import_record_id": request.import_record_id,
            "status": "failed",
            "error": str(exc),
        }
    except Exception as exc:  # pragma: no cover - defensive logging
        logger.exception(
            "Unexpected failure while processing import analysis job %s",
            request.import_record_id,
        )
        payload = {
            "import_record_id": request.import_record_id,
            "status": "failed",
            "error": str(exc),
        }
    await notify_import_analysis_callback(request.callback_url, payload, client)

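# End-to-end usage sketch. The DataImportAnalysisJobRequest fields shown
# (import_record_id, rows, headers, llm_model, callback_url) follow the
# attribute accesses in this module; keyword construction is an assumption.
# With api_key omitted, the OPENAI_API_KEY environment variable must be set.
#
#     import asyncio
#
#     async def main() -> None:
#         request = DataImportAnalysisJobRequest(
#             import_record_id="rec_123",
#             rows=[{"name": "Alice", "age": 30}],
#             headers=None,
#             llm_model="gpt-4o",
#             callback_url="https://example.com/callbacks/import-analysis",
#         )
#         async with httpx.AsyncClient(timeout=60.0) as client:
#             await process_import_analysis_job(request, client)
#
#     asyncio.run(main())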