"""Helpers for running LLM-based data import analysis jobs and callbacks."""

from __future__ import annotations

import base64
import csv
import json
import logging
import os
from functools import lru_cache
from io import StringIO
from pathlib import Path
from typing import Any, Dict, List, Sequence, Tuple

import httpx

from app.exceptions import ProviderAPICallError
from app.models import (
    DataImportAnalysisJobRequest,
    DataImportAnalysisRequest,
    LLMChoice,
    LLMMessage,
    LLMProvider,
    LLMResponse,
    LLMRole,
)

logger = logging.getLogger(__name__)

OPENAI_RESPONSES_URL = "https://api.openai.com/v1/responses"


def resolve_provider_from_model(llm_model: str) -> Tuple[LLMProvider, str]:
    """Resolve the provider for an llm_model string.

    The llm_model may be given as 'provider:model', 'provider/model', or
    'provider|model'. If no provider prefix is present, the provider is
    inferred from common model name patterns.
    """

    normalized = llm_model.strip()
    provider_hint: str | None = None
    model_name = normalized

    for delimiter in (":", "/", "|"):
        if delimiter in normalized:
            provider_hint, model_name = normalized.split(delimiter, 1)
            provider_hint = provider_hint.strip().lower()
            model_name = model_name.strip()
            break

    provider_map = {provider.value: provider for provider in LLMProvider}

    if provider_hint:
        if provider_hint not in provider_map:
            raise ValueError(
                f"Unsupported provider '{provider_hint}'. Expected one of: {', '.join(provider_map.keys())}."
            )
        return provider_map[provider_hint], model_name

    return _guess_provider_from_model(model_name), model_name
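
# Illustrative resolutions (a sketch: model names are placeholders, and this
# assumes LLMProvider member values such as "openai"):
#
#     resolve_provider_from_model("openai:gpt-4o")   # -> (LLMProvider.OPENAI, "gpt-4o")
#     resolve_provider_from_model("claude-3-haiku")  # -> (LLMProvider.ANTHROPIC, "claude-3-haiku")
#     resolve_provider_from_model("unknown-model")   # -> raises ValueError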


def _guess_provider_from_model(model_name: str) -> LLMProvider:
    """Best-effort provider inference from well-known model name prefixes."""
    lowered = model_name.lower()

    if lowered.startswith(("gpt", "o1", "text-", "dall-e", "whisper")):
        return LLMProvider.OPENAI
    if lowered.startswith(("claude", "anthropic")):
        return LLMProvider.ANTHROPIC
    if lowered.startswith(("gemini", "models/gemini")):
        return LLMProvider.GEMINI
    if lowered.startswith("qwen"):
        return LLMProvider.QWEN
    if lowered.startswith("deepseek"):
        return LLMProvider.DEEPSEEK
    if lowered.startswith(("openrouter", "router-")):
        return LLMProvider.OPENROUTER

    supported = ", ".join(provider.value for provider in LLMProvider)
    raise ValueError(
        f"Unable to infer provider from model '{model_name}'. "
        f"Please prefix with 'provider:model'. Supported providers: {supported}."
    )


def build_import_messages(
    request: DataImportAnalysisRequest,
) -> List[LLMMessage]:
    """Create system and user messages for the import analysis prompt."""
    headers_formatted = "\n".join(f"- {header}" for header in request.table_headers)

    system_prompt = load_import_template()

    # Chinese labels, kept verbatim for the prompt: 导入记录ID = import record
    # ID; 表头信息 = table headers; 示例数据 = sample data.
    data_block = (
        f"导入记录ID: {request.import_record_id}\n\n"
        "表头信息:\n"
        f"{headers_formatted}\n\n"
        "示例数据:\n"
        f"{request.example_data}"
    )

    return [
        LLMMessage(role=LLMRole.SYSTEM, content=system_prompt),
        LLMMessage(role=LLMRole.USER, content=data_block),
    ]
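
# The returned pair follows a conventional chat-style exchange (a sketch):
#
#     messages = build_import_messages(request)
#     # messages[0] -> system message carrying the prompt template
#     # messages[1] -> user message carrying record ID, headers, and sample data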


@lru_cache(maxsize=1)
def load_import_template() -> str:
    template_path = (
        Path(__file__).resolve().parents[2] / "prompt" / "data_import_analysis.md"
    )
    if not template_path.exists():
        raise FileNotFoundError(f"Prompt template not found at {template_path}")
    return template_path.read_text(encoding="utf-8").strip()


def derive_headers(rows: Sequence[Any], provided_headers: Sequence[str] | None) -> List[str]:
    """Return the provided headers, or collect dict keys from the rows in order."""
    if provided_headers:
        return list(provided_headers)

    seen: List[str] = []
    for row in rows:
        if isinstance(row, dict):
            for key in row.keys():
                if key not in seen:
                    seen.append(str(key))
    return seen
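
# Illustrative behaviour:
#
#     derive_headers([{"name": "Alice"}, {"name": "Bob", "age": 30}], None)
#     # -> ["name", "age"]
#
# Rows that are lists or scalars contribute no headers; in that case the CSV
# below is simply written without a header row.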


def _stringify_cell(value: Any) -> str:
    if value is None:
        return ""
    if isinstance(value, (str, int, float, bool)):
        return str(value)
    try:
        return json.dumps(value, ensure_ascii=False)
    except (TypeError, ValueError):
        return str(value)


def rows_to_csv_bytes(rows: Sequence[Any], headers: List[str]) -> bytes:
    """Serialise sample rows to UTF-8 CSV, aligning dict rows to the headers."""
    buffer = StringIO()
    writer = csv.writer(buffer)

    if headers:
        writer.writerow(headers)

    for row in rows:
        if isinstance(row, dict):
            writer.writerow([_stringify_cell(row.get(header)) for header in headers])
        elif isinstance(row, (list, tuple)):
            writer.writerow([_stringify_cell(item) for item in row])
        else:
            writer.writerow([_stringify_cell(row)])

    return buffer.getvalue().encode("utf-8")
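
# For example (a sketch):
#
#     rows_to_csv_bytes([{"name": "Alice", "age": 30}], ["name", "age"])
#     # -> b"name,age\r\nAlice,30\r\n"
#
# Note that csv.writer emits \r\n line terminators by default.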


def build_file_part(filename: str, payload: bytes, mime_type: str) -> Dict[str, Any]:
    """Wrap raw bytes as an inline, base64-encoded file part of the request payload."""
    encoded = base64.b64encode(payload).decode("ascii")
    return {
        "type": "input_file",
        "input_file": {
            "file_data": {
                # The filename is supplied under both key spellings.
                "filename": filename,
                "file_name": filename,
                "mime_type": mime_type,
                "b64_json": encoded,
            }
        },
    }


def build_schema_part(
    request: DataImportAnalysisJobRequest, headers: List[str]
) -> Dict[str, Any] | None:
    """Build an attachment describing the table schema, falling back to the headers."""
    if request.table_schema is not None:
        if isinstance(request.table_schema, str):
            schema_bytes = request.table_schema.encode("utf-8")
            return build_file_part("table_schema.txt", schema_bytes, "text/plain")

        try:
            schema_serialised = json.dumps(
                request.table_schema, ensure_ascii=False, indent=2
            )
        except (TypeError, ValueError) as exc:
            logger.warning("Failed to serialise table_schema for %s: %s", request.import_record_id, exc)
            schema_serialised = str(request.table_schema)

        return build_file_part(
            "table_schema.json",
            schema_serialised.encode("utf-8"),
            "application/json",
        )

    if headers:
        headers_payload = json.dumps({"headers": headers}, ensure_ascii=False, indent=2)
        return build_file_part(
            "table_headers.json",
            headers_payload.encode("utf-8"),
            "application/json",
        )

    return None


def build_openai_input_payload(
    request: DataImportAnalysisJobRequest,
) -> Dict[str, Any]:
    """Assemble the OpenAI Responses API payload with CSV and schema attachments."""
    headers = derive_headers(request.rows, request.headers)
    csv_bytes = rows_to_csv_bytes(request.rows, headers)
    csv_part = build_file_part("sample_rows.csv", csv_bytes, "text/csv")
    schema_part = build_schema_part(request, headers)

    prompt = load_import_template()

    # Context lines are in Chinese to match the prompt. They read:
    # "Import record ID: ...", "Number of sample rows: ...", and
    # "See the attached `sample_rows.csv` for the raw sample data."
    context_lines = [
        f"导入记录ID: {request.import_record_id}",
        f"样本数据行数: {len(request.rows)}",
        "请参考附件 `sample_rows.csv` 获取原始样本数据。",
    ]

    if schema_part:
        # "Additional structure information comes from the second attachment;
        # use it together with the sample data."
        context_lines.append("附加结构信息来自第二个附件,请结合使用。")
    else:
        # "No headers or schema were provided; infer the field information
        # from the data itself."
        context_lines.append("未提供表头或Schema,请依据数据自行推断字段信息。")

    user_content = [
        {"type": "input_text", "text": "\n".join(context_lines)},
        csv_part,
    ]

    if schema_part:
        user_content.append(schema_part)

    payload: Dict[str, Any] = {
        "model": request.llm_model,
        "input": [
            {"role": "system", "content": [{"type": "input_text", "text": prompt}]},
            {"role": "user", "content": user_content},
        ],
    }

    if request.temperature is not None:
        payload["temperature"] = request.temperature
    if request.max_output_tokens is not None:
        payload["max_output_tokens"] = request.max_output_tokens

    return payload
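
# Shape of the assembled payload (a sketch, values elided):
#
#     {
#         "model": "<llm_model>",
#         "input": [
#             {"role": "system", "content": [{"type": "input_text", "text": "<prompt>"}]},
#             {"role": "user", "content": [
#                 {"type": "input_text", "text": "<context lines>"},
#                 {"type": "input_file", "input_file": {...}},  # sample_rows.csv
#                 {"type": "input_file", "input_file": {...}},  # schema, if present
#             ]},
#         ],
#         "temperature": ...,          # only if set on the request
#         "max_output_tokens": ...,    # only if set on the request
#     }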


def parse_openai_responses_payload(
    data: Dict[str, Any], fallback_model: str
) -> LLMResponse:
    """Convert a Responses API payload into an LLMResponse, one choice per message block."""
    output_blocks = data.get("output", [])
    choices: List[LLMChoice] = []

    for idx, block in enumerate(output_blocks):
        if block.get("type") != "message":
            continue
        content_items = block.get("content", [])
        text_fragments: List[str] = []
        for item in content_items:
            if item.get("type") == "output_text":
                text_fragments.append(item.get("text", ""))

        # Fall back to the top-level convenience field when a message block
        # carries no output_text items of its own.
        if not text_fragments and data.get("output_text"):
            text_fragments.append(data.get("output_text", ""))

        message = LLMMessage(role=LLMRole.ASSISTANT, content="\n".join(text_fragments))
        choices.append(LLMChoice(index=idx, message=message))

    # No message blocks at all: synthesise a single choice from output_text.
    if not choices and data.get("output_text"):
        message = LLMMessage(role=LLMRole.ASSISTANT, content=data.get("output_text", ""))
        choices.append(LLMChoice(index=0, message=message))

    return LLMResponse(
        provider=LLMProvider.OPENAI,
        model=data.get("model", fallback_model),
        choices=choices,
        raw=data,
    )
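
# Minimal example of the input this parser expects (a sketch of the Responses
# API shape, not an exhaustive one):
#
#     parse_openai_responses_payload(
#         {
#             "model": "gpt-4o-mini",
#             "output": [
#                 {
#                     "type": "message",
#                     "content": [{"type": "output_text", "text": "{...analysis...}"}],
#                 }
#             ],
#         },
#         fallback_model="gpt-4o-mini",
#     )
#     # -> LLMResponse with a single choice whose content is "{...analysis...}"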


async def call_openai_import_analysis(
    request: DataImportAnalysisJobRequest,
    client: httpx.AsyncClient,
    *,
    api_key: str | None = None,
) -> LLMResponse:
    """Send the import analysis request to the OpenAI Responses API."""
    openai_api_key = api_key or os.getenv("OPENAI_API_KEY")
    if not openai_api_key:
        raise ProviderAPICallError("OPENAI_API_KEY must be set to process import analysis.")

    payload = build_openai_input_payload(request)
    headers = {
        "Authorization": f"Bearer {openai_api_key}",
        "Content-Type": "application/json",
    }

    # The POST itself sits inside the try block so that transport errors
    # (timeouts, connection failures) also surface as ProviderAPICallError.
    try:
        response = await client.post(OPENAI_RESPONSES_URL, json=payload, headers=headers)
        response.raise_for_status()
    except httpx.HTTPError as exc:
        raise ProviderAPICallError(f"OpenAI response API call failed: {exc}") from exc

    data: Dict[str, Any] = response.json()
    return parse_openai_responses_payload(data, request.llm_model)


async def dispatch_import_analysis_job(
    request: DataImportAnalysisJobRequest,
    client: httpx.AsyncClient,
    *,
    api_key: str | None = None,
) -> Dict[str, Any]:
    """Run the analysis call and wrap the result in the callback payload format."""
    logger.info("Starting import analysis job %s", request.import_record_id)
    llm_response = await call_openai_import_analysis(request, client, api_key=api_key)

    result = {
        "import_record_id": request.import_record_id,
        "status": "succeeded",
        "llm_response": llm_response.model_dump(),
    }

    logger.info("Completed import analysis job %s", request.import_record_id)
    return result


async def notify_import_analysis_callback(
    callback_url: str,
    payload: Dict[str, Any],
    client: httpx.AsyncClient,
) -> None:
    """POST the result payload to the callback URL; failures are logged, not raised."""
    try:
        response = await client.post(callback_url, json=payload)
        response.raise_for_status()
    except httpx.HTTPError as exc:
        logger.error(
            "Failed to deliver import analysis callback to %s: %s",
            callback_url,
            exc,
        )


async def process_import_analysis_job(
    request: DataImportAnalysisJobRequest,
    client: httpx.AsyncClient,
    *,
    api_key: str | None = None,
) -> None:
    """Run an import analysis job end to end, always delivering a callback."""
    try:
        payload = await dispatch_import_analysis_job(
            request,
            client,
            api_key=api_key,
        )
    except ProviderAPICallError as exc:
        payload = {
            "import_record_id": request.import_record_id,
            "status": "failed",
            "error": str(exc),
        }
    except Exception as exc:  # pragma: no cover - defensive logging
        logger.exception(
            "Unexpected failure while processing import analysis job %s",
            request.import_record_id,
        )
        payload = {
            "import_record_id": request.import_record_id,
            "status": "failed",
            "error": str(exc),
        }

    await notify_import_analysis_callback(request.callback_url, payload, client)
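

# A minimal driver sketch. The DataImportAnalysisJobRequest fields below are
# assumptions based on how they are used in this module; adjust them to the
# actual model definition in app.models, and set OPENAI_API_KEY before running.
#
#     import asyncio
#
#     async def _demo() -> None:
#         async with httpx.AsyncClient(timeout=60.0) as client:
#             request = DataImportAnalysisJobRequest(
#                 import_record_id="demo-001",
#                 rows=[{"name": "Alice", "age": 30}, {"name": "Bob", "age": 25}],
#                 headers=None,
#                 table_schema=None,
#                 llm_model="gpt-4o-mini",
#                 callback_url="https://example.com/import-callback",
#             )
#             await process_import_analysis_job(request, client)
#
#     asyncio.run(_demo())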