Data import schema analysis: feature API and test cases

zhaoawd
2025-10-29 22:35:29 +08:00
parent 76b8c9d79b
commit f43590585b
5 changed files with 380 additions and 37 deletions


@@ -1,16 +1,32 @@
from __future__ import annotations
import base64
import csv
import json
import logging
import os
from functools import lru_cache
from io import StringIO
from pathlib import Path
from typing import List, Tuple
from typing import Any, Dict, List, Sequence, Tuple
import httpx
from app.exceptions import ProviderAPICallError
from app.models import (
    DataImportAnalysisJobRequest,
    DataImportAnalysisRequest,
    LLMChoice,
    LLMMessage,
    LLMProvider,
    LLMResponse,
    LLMRole,
)
logger = logging.getLogger(__name__)
OPENAI_RESPONSES_URL = "https://api.openai.com/v1/responses"

def resolve_provider_from_model(llm_model: str) -> Tuple[LLMProvider, str]:
    """Resolve provider based on the llm_model string.
@@ -95,3 +111,264 @@ def load_import_template() -> str:
    if not template_path.exists():
        raise FileNotFoundError(f"Prompt template not found at {template_path}")
    return template_path.read_text(encoding="utf-8").strip()

def derive_headers(rows: Sequence[Any], provided_headers: Sequence[str] | None) -> List[str]:
    if provided_headers:
        return list(provided_headers)
    # No explicit headers: collect dict keys across rows in first-seen order.
    seen: List[str] = []
    for row in rows:
        if isinstance(row, dict):
            for key in row.keys():
                if key not in seen:
                    seen.append(str(key))
    return seen
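
# Usage sketch (illustrative only, not part of this commit): with no explicit
# headers, dict keys are collected across rows in first-seen order.
#
#   rows = [{"id": 1, "name": "a"}, {"id": 2, "city": "x"}]
#   derive_headers(rows, None)    # -> ["id", "name", "city"]
#   derive_headers(rows, ["id"])  # -> ["id"]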

def _stringify_cell(value: Any) -> str:
    if value is None:
        return ""
    if isinstance(value, (str, int, float, bool)):
        return str(value)
    # Nested structures are serialised as JSON where possible.
    try:
        return json.dumps(value, ensure_ascii=False)
    except (TypeError, ValueError):
        return str(value)

def rows_to_csv_bytes(rows: Sequence[Any], headers: List[str]) -> bytes:
    buffer = StringIO()
    writer = csv.writer(buffer)
    if headers:
        writer.writerow(headers)
    for row in rows:
        # Dict rows are projected onto the header order; list/tuple rows are
        # written positionally; anything else becomes a single-cell row.
        if isinstance(row, dict):
            writer.writerow([_stringify_cell(row.get(header)) for header in headers])
        elif isinstance(row, (list, tuple)):
            writer.writerow([_stringify_cell(item) for item in row])
        else:
            writer.writerow([_stringify_cell(row)])
    return buffer.getvalue().encode("utf-8")
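
# Illustrative example (assumed inputs, not part of this commit): mixed row
# shapes are normalised into one CSV, dict rows following the header order.
#
#   rows_to_csv_bytes([{"id": 1}, [2, "b"], "raw"], ["id"])
#   # -> b'id\r\n1\r\n2,b\r\nraw\r\n'  (csv.writer's default \r\n terminator)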

def build_file_part(filename: str, payload: bytes, mime_type: str) -> Dict[str, Any]:
    encoded = base64.b64encode(payload).decode("ascii")
    # Both the "filename" and "file_name" spellings are included in the part.
    return {
        "type": "input_file",
        "input_file": {
            "file_data": {
                "filename": filename,
                "file_name": filename,
                "mime_type": mime_type,
                "b64_json": encoded,
            }
        },
    }

def build_schema_part(
    request: DataImportAnalysisJobRequest, headers: List[str]
) -> Dict[str, Any] | None:
    if request.table_schema is not None:
        # A string schema is attached verbatim; anything else is serialised
        # to JSON, falling back to str() if it is not JSON-serialisable.
        if isinstance(request.table_schema, str):
            schema_bytes = request.table_schema.encode("utf-8")
            return build_file_part("table_schema.txt", schema_bytes, "text/plain")
        try:
            schema_serialised = json.dumps(
                request.table_schema, ensure_ascii=False, indent=2
            )
        except (TypeError, ValueError) as exc:
            logger.warning("Failed to serialise table_schema for %s: %s", request.import_record_id, exc)
            schema_serialised = str(request.table_schema)
        return build_file_part(
            "table_schema.json",
            schema_serialised.encode("utf-8"),
            "application/json",
        )
    if headers:
        headers_payload = json.dumps({"headers": headers}, ensure_ascii=False, indent=2)
        return build_file_part(
            "table_headers.json",
            headers_payload.encode("utf-8"),
            "application/json",
        )
    return None
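
# Behaviour summary (sketch): an explicit table_schema wins (str ->
# table_schema.txt, object -> table_schema.json); otherwise derived headers
# are attached as table_headers.json; with neither, no schema part is sent.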

def build_openai_input_payload(
    request: DataImportAnalysisJobRequest,
) -> Dict[str, Any]:
    headers = derive_headers(request.rows, request.headers)
    csv_bytes = rows_to_csv_bytes(request.rows, headers)
    csv_part = build_file_part("sample_rows.csv", csv_bytes, "text/csv")
    schema_part = build_schema_part(request, headers)
    prompt = load_import_template()
    # Context shown to the model (in Chinese): the import record ID, the
    # sample row count, and a pointer to the attached sample_rows.csv.
    context_lines = [
        f"导入记录ID: {request.import_record_id}",
        f"样本数据行数: {len(request.rows)}",
        "请参考附件 `sample_rows.csv` 获取原始样本数据。",
    ]
    if schema_part:
        # "The extra structure information comes from the second attachment."
        context_lines.append("附加结构信息来自第二个附件,请结合使用。")
    else:
        # "No headers or schema provided; infer the fields from the data."
        context_lines.append("未提供表头或Schema,请依据数据自行推断字段信息。")
    user_content = [
        {"type": "input_text", "text": "\n".join(context_lines)},
        csv_part,
    ]
    if schema_part:
        user_content.append(schema_part)
    payload: Dict[str, Any] = {
        "model": request.llm_model,
        "input": [
            {"role": "system", "content": [{"type": "input_text", "text": prompt}]},
            {"role": "user", "content": user_content},
        ],
    }
    if request.temperature is not None:
        payload["temperature"] = request.temperature
    if request.max_output_tokens is not None:
        payload["max_output_tokens"] = request.max_output_tokens
    return payload
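
# Resulting payload shape (minimal sketch; the model name and option values
# here are assumptions, not part of this commit):
#
# {
#     "model": "gpt-4.1",
#     "input": [
#         {"role": "system", "content": [{"type": "input_text", "text": "<prompt>"}]},
#         {"role": "user", "content": [<context input_text>, <csv file part>, <schema part, optional>]},
#     ],
#     "temperature": 0.2,          # only when request.temperature is set
#     "max_output_tokens": 1024,   # only when request.max_output_tokens is set
# }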

def parse_openai_responses_payload(
    data: Dict[str, Any], fallback_model: str
) -> LLMResponse:
    output_blocks = data.get("output", [])
    choices: List[LLMChoice] = []
    for idx, block in enumerate(output_blocks):
        if block.get("type") != "message":
            continue
        content_items = block.get("content", [])
        text_fragments: List[str] = []
        for item in content_items:
            if item.get("type") == "output_text":
                text_fragments.append(item.get("text", ""))
        # Fall back to the top-level output_text when a message block carries
        # no output_text items of its own.
        if not text_fragments and data.get("output_text"):
            text_fragments.append(data.get("output_text", ""))
        message = LLMMessage(role=LLMRole.ASSISTANT, content="\n".join(text_fragments))
        choices.append(LLMChoice(index=idx, message=message))
    if not choices and data.get("output_text"):
        message = LLMMessage(role=LLMRole.ASSISTANT, content=data.get("output_text", ""))
        choices.append(LLMChoice(index=0, message=message))
    return LLMResponse(
        provider=LLMProvider.OPENAI,
        model=data.get("model", fallback_model),
        choices=choices,
        raw=data,
    )
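
# Parsing sketch (hypothetical response data): a single message block with one
# output_text item yields one assistant choice.
#
#   data = {
#       "model": "gpt-4.1",
#       "output": [
#           {"type": "message",
#            "content": [{"type": "output_text", "text": "schema analysis"}]},
#       ],
#   }
#   parse_openai_responses_payload(data, "fallback").choices[0].message.content
#   # -> "schema analysis"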

async def call_openai_import_analysis(
    request: DataImportAnalysisJobRequest,
    client: httpx.AsyncClient,
    *,
    api_key: str | None = None,
) -> LLMResponse:
    openai_api_key = api_key or os.getenv("OPENAI_API_KEY")
    if not openai_api_key:
        raise ProviderAPICallError("OPENAI_API_KEY must be set to process import analysis.")
    payload = build_openai_input_payload(request)
    headers = {
        "Authorization": f"Bearer {openai_api_key}",
        "Content-Type": "application/json",
    }
    response = await client.post(OPENAI_RESPONSES_URL, json=payload, headers=headers)
    try:
        response.raise_for_status()
    except httpx.HTTPError as exc:
        raise ProviderAPICallError(f"OpenAI response API call failed: {exc}") from exc
    data: Dict[str, Any] = response.json()
    return parse_openai_responses_payload(data, request.llm_model)

async def dispatch_import_analysis_job(
    request: DataImportAnalysisJobRequest,
    client: httpx.AsyncClient,
    *,
    api_key: str | None = None,
) -> Dict[str, Any]:
    logger.info("Starting import analysis job %s", request.import_record_id)
    llm_response = await call_openai_import_analysis(request, client, api_key=api_key)
    result = {
        "import_record_id": request.import_record_id,
        "status": "succeeded",
        "llm_response": llm_response.model_dump(),
    }
    logger.info("Completed import analysis job %s", request.import_record_id)
    return result

async def notify_import_analysis_callback(
    callback_url: str,
    payload: Dict[str, Any],
    client: httpx.AsyncClient,
) -> None:
    try:
        response = await client.post(callback_url, json=payload)
        response.raise_for_status()
    except httpx.HTTPError as exc:
        # Callback delivery failures are logged but never re-raised, so a dead
        # callback endpoint cannot crash the job worker.
        logger.error(
            "Failed to deliver import analysis callback to %s: %s",
            callback_url,
            exc,
        )

async def process_import_analysis_job(
    request: DataImportAnalysisJobRequest,
    client: httpx.AsyncClient,
    *,
    api_key: str | None = None,
) -> None:
    try:
        payload = await dispatch_import_analysis_job(
            request,
            client,
            api_key=api_key,
        )
    except ProviderAPICallError as exc:
        payload = {
            "import_record_id": request.import_record_id,
            "status": "failed",
            "error": str(exc),
        }
    except Exception as exc:  # pragma: no cover - defensive logging
        logger.exception(
            "Unexpected failure while processing import analysis job %s",
            request.import_record_id,
        )
        payload = {
            "import_record_id": request.import_record_id,
            "status": "failed",
            "error": str(exc),
        }
    # The callback fires for both success and failure payloads.
    await notify_import_analysis_callback(request.callback_url, payload, client)
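
# End-to-end usage sketch (hypothetical; the request field values and callback
# URL below are assumptions, not part of this commit):
#
#   async def run_job() -> None:
#       request = DataImportAnalysisJobRequest(
#           import_record_id="rec-123",
#           rows=[{"id": 1, "name": "alice"}],
#           llm_model="gpt-4.1",
#           callback_url="https://example.com/callbacks/import-analysis",
#       )
#       async with httpx.AsyncClient(timeout=60) as client:
#           await process_import_analysis_job(request, client, api_key="sk-...")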