Data import schema analysis feature: interfaces and test cases
@@ -1,16 +1,32 @@
from __future__ import annotations

import base64
import csv
import json
import logging
import os
from functools import lru_cache
from io import StringIO
from pathlib import Path
from typing import Any, Dict, List, Sequence, Tuple

import httpx

from app.exceptions import ProviderAPICallError
from app.models import (
    DataImportAnalysisJobRequest,
    DataImportAnalysisRequest,
    LLMChoice,
    LLMMessage,
    LLMProvider,
    LLMResponse,
    LLMRole,
)

logger = logging.getLogger(__name__)

OPENAI_RESPONSES_URL = "https://api.openai.com/v1/responses"


def resolve_provider_from_model(llm_model: str) -> Tuple[LLMProvider, str]:
    """Resolve provider based on the llm_model string.
@@ -95,3 +111,264 @@ def load_import_template() -> str:
    if not template_path.exists():
        raise FileNotFoundError(f"Prompt template not found at {template_path}")
    return template_path.read_text(encoding="utf-8").strip()


def derive_headers(rows: Sequence[Any], provided_headers: Sequence[str] | None) -> List[str]:
    if provided_headers:
        return list(provided_headers)

    seen: List[str] = []
    for row in rows:
        if isinstance(row, dict):
            for key in row.keys():
                if key not in seen:
                    seen.append(str(key))
    return seen


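# Illustration, not part of the commit: with dict rows and no explicit
# headers, derive_headers collects keys in first-seen order across rows.
def _demo_derive_headers() -> List[str]:  # hypothetical helper, illustration only
    rows = [{"name": "Ada", "age": 36}, {"city": "Xi'an", "name": "Lin"}]
    return derive_headers(rows, None)  # -> ["name", "age", "city"]

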
def _stringify_cell(value: Any) -> str:
    if value is None:
        return ""
    if isinstance(value, (str, int, float, bool)):
        return str(value)
    try:
        return json.dumps(value, ensure_ascii=False)
    except (TypeError, ValueError):
        return str(value)


def rows_to_csv_bytes(rows: Sequence[Any], headers: List[str]) -> bytes:
    buffer = StringIO()
    writer = csv.writer(buffer)

    if headers:
        writer.writerow(headers)

    for row in rows:
        if isinstance(row, dict):
            writer.writerow([_stringify_cell(row.get(header)) for header in headers])
        elif isinstance(row, (list, tuple)):
            writer.writerow([_stringify_cell(item) for item in row])
        else:
            writer.writerow([_stringify_cell(row)])

    return buffer.getvalue().encode("utf-8")


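# Illustration, not part of the commit: dict rows are keyed by header,
# sequence rows are written positionally, and non-scalar cells are
# JSON-encoded by _stringify_cell before landing in a single CSV cell.
def _demo_rows_to_csv() -> bytes:  # hypothetical helper, illustration only
    rows = [{"a": 1, "b": {"nested": True}}, [2, 3]]
    headers = derive_headers(rows, None)  # -> ["a", "b"]
    return rows_to_csv_bytes(rows, headers)

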
def build_file_part(filename: str, payload: bytes, mime_type: str) -> Dict[str, Any]:
    encoded = base64.b64encode(payload).decode("ascii")
    return {
        "type": "input_file",
        "input_file": {
            "file_data": {
                "filename": filename,
                "file_name": filename,
                "mime_type": mime_type,
                "b64_json": encoded,
            }
        },
    }


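# Illustration, not part of the commit: the attachment is base64-encoded, so
# arbitrary bytes round-trip exactly; both the "filename" and "file_name"
# spellings are emitted exactly as the commit wrote them.
def _demo_file_part_roundtrip() -> bool:  # hypothetical helper, illustration only
    part = build_file_part("sample.csv", b"a,b\n1,2\n", "text/csv")
    encoded = part["input_file"]["file_data"]["b64_json"]
    return base64.b64decode(encoded) == b"a,b\n1,2\n"

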
def build_schema_part(
    request: DataImportAnalysisJobRequest, headers: List[str]
) -> Dict[str, Any] | None:
    if request.table_schema is not None:
        if isinstance(request.table_schema, str):
            schema_bytes = request.table_schema.encode("utf-8")
            return build_file_part("table_schema.txt", schema_bytes, "text/plain")

        try:
            schema_serialised = json.dumps(
                request.table_schema, ensure_ascii=False, indent=2
            )
        except (TypeError, ValueError) as exc:
            logger.warning("Failed to serialise table_schema for %s: %s", request.import_record_id, exc)
            schema_serialised = str(request.table_schema)

        return build_file_part(
            "table_schema.json",
            schema_serialised.encode("utf-8"),
            "application/json",
        )

    if headers:
        headers_payload = json.dumps({"headers": headers}, ensure_ascii=False, indent=2)
        return build_file_part(
            "table_headers.json",
            headers_payload.encode("utf-8"),
            "application/json",
        )

    return None


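# Note, not part of the commit: build_schema_part applies a three-way
# precedence. A string table_schema is attached verbatim as table_schema.txt;
# any other non-None schema is JSON-serialised into table_schema.json (falling
# back to str() when not JSON-serialisable); with no schema, the derived
# headers ship as table_headers.json; with neither, no second attachment.

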
def build_openai_input_payload(
    request: DataImportAnalysisJobRequest,
) -> Dict[str, Any]:
    headers = derive_headers(request.rows, request.headers)
    csv_bytes = rows_to_csv_bytes(request.rows, headers)
    csv_part = build_file_part("sample_rows.csv", csv_bytes, "text/csv")
    schema_part = build_schema_part(request, headers)

    prompt = load_import_template()

    context_lines = [
        f"导入记录ID: {request.import_record_id}",
        f"样本数据行数: {len(request.rows)}",
        "请参考附件 `sample_rows.csv` 获取原始样本数据。",
    ]

    if schema_part:
        context_lines.append("附加结构信息来自第二个附件,请结合使用。")
    else:
        context_lines.append("未提供表头或Schema,请依据数据自行推断字段信息。")

    user_content = [
        {"type": "input_text", "text": "\n".join(context_lines)},
        csv_part,
    ]

    if schema_part:
        user_content.append(schema_part)

    payload: Dict[str, Any] = {
        "model": request.llm_model,
        "input": [
            {"role": "system", "content": [{"type": "input_text", "text": prompt}]},
            {"role": "user", "content": user_content},
        ],
    }

    if request.temperature is not None:
        payload["temperature"] = request.temperature
    if request.max_output_tokens is not None:
        payload["max_output_tokens"] = request.max_output_tokens

    return payload


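# Note, not part of the commit: the resulting payload shape, as built above.
# {
#     "model": request.llm_model,
#     "input": [
#         {"role": "system", "content": [{"type": "input_text", "text": prompt}]},
#         {"role": "user", "content": [<context input_text>, <sample_rows.csv>,
#                                      <schema attachment, when present>]},
#     ],
#     # "temperature" / "max_output_tokens" appear only when set on the request.
# }

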
def parse_openai_responses_payload(
    data: Dict[str, Any], fallback_model: str
) -> LLMResponse:
    output_blocks = data.get("output", [])
    choices: List[LLMChoice] = []

    for idx, block in enumerate(output_blocks):
        if block.get("type") != "message":
            continue
        content_items = block.get("content", [])
        text_fragments: List[str] = []
        for item in content_items:
            if item.get("type") == "output_text":
                text_fragments.append(item.get("text", ""))

        if not text_fragments and data.get("output_text"):
            text_fragments.append(data.get("output_text", ""))

        message = LLMMessage(role=LLMRole.ASSISTANT, content="\n".join(text_fragments))
        choices.append(LLMChoice(index=idx, message=message))

    if not choices and data.get("output_text"):
        message = LLMMessage(role=LLMRole.ASSISTANT, content=data.get("output_text", ""))
        choices.append(LLMChoice(index=0, message=message))

    return LLMResponse(
        provider=LLMProvider.OPENAI,
        model=data.get("model", fallback_model),
        choices=choices,
        raw=data,
    )


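# Illustration, not part of the commit: parsing a minimal, hand-built
# Responses-style payload; the field names mirror what the parser reads,
# and the model name here is a placeholder.
def _demo_parse_minimal_payload() -> LLMResponse:  # hypothetical helper
    data = {
        "model": "some-model",
        "output": [
            {
                "type": "message",
                "content": [{"type": "output_text", "text": "analysis result"}],
            }
        ],
    }
    return parse_openai_responses_payload(data, fallback_model="some-model")

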
async def call_openai_import_analysis(
    request: DataImportAnalysisJobRequest,
    client: httpx.AsyncClient,
    *,
    api_key: str | None = None,
) -> LLMResponse:
    openai_api_key = api_key or os.getenv("OPENAI_API_KEY")
    if not openai_api_key:
        raise ProviderAPICallError("OPENAI_API_KEY must be set to process import analysis.")

    payload = build_openai_input_payload(request)
    headers = {
        "Authorization": f"Bearer {openai_api_key}",
        "Content-Type": "application/json",
    }

    response = await client.post(OPENAI_RESPONSES_URL, json=payload, headers=headers)
    try:
        response.raise_for_status()
    except httpx.HTTPError as exc:
        raise ProviderAPICallError(f"OpenAI response API call failed: {exc}") from exc

    data: Dict[str, Any] = response.json()
    return parse_openai_responses_payload(data, request.llm_model)


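# Test sketch, not part of the commit: the commit message mentions test cases;
# one plausible shape is httpx.MockTransport, which fakes the HTTP layer so the
# call never touches the network. `make_job_request()` is a hypothetical
# fixture producing a DataImportAnalysisJobRequest.
#
# async def test_call_openai_import_analysis_parses_output_text() -> None:
#     def handler(req: httpx.Request) -> httpx.Response:
#         return httpx.Response(200, json={"model": "m", "output_text": "ok"})
#
#     async with httpx.AsyncClient(transport=httpx.MockTransport(handler)) as client:
#         resp = await call_openai_import_analysis(
#             make_job_request(), client, api_key="test-key"
#         )
#     assert resp.choices[0].message.content == "ok"

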
async def dispatch_import_analysis_job(
    request: DataImportAnalysisJobRequest,
    client: httpx.AsyncClient,
    *,
    api_key: str | None = None,
) -> Dict[str, Any]:
    logger.info("Starting import analysis job %s", request.import_record_id)
    llm_response = await call_openai_import_analysis(request, client, api_key=api_key)

    result = {
        "import_record_id": request.import_record_id,
        "status": "succeeded",
        "llm_response": llm_response.model_dump(),
    }

    logger.info("Completed import analysis job %s", request.import_record_id)
    return result


async def notify_import_analysis_callback(
    callback_url: str,
    payload: Dict[str, Any],
    client: httpx.AsyncClient,
) -> None:
    try:
        response = await client.post(callback_url, json=payload)
        response.raise_for_status()
    except httpx.HTTPError as exc:
        logger.error(
            "Failed to deliver import analysis callback to %s: %s",
            callback_url,
            exc,
        )


async def process_import_analysis_job(
    request: DataImportAnalysisJobRequest,
    client: httpx.AsyncClient,
    *,
    api_key: str | None = None,
) -> None:
    try:
        payload = await dispatch_import_analysis_job(
            request,
            client,
            api_key=api_key,
        )
    except ProviderAPICallError as exc:
        payload = {
            "import_record_id": request.import_record_id,
            "status": "failed",
            "error": str(exc),
        }
    except Exception as exc:  # pragma: no cover - defensive logging
        logger.exception(
            "Unexpected failure while processing import analysis job %s",
            request.import_record_id,
        )
        payload = {
            "import_record_id": request.import_record_id,
            "status": "failed",
            "error": str(exc),
        }

    await notify_import_analysis_callback(request.callback_url, payload, client)
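# Note, not part of the commit: success and failure converge on the same
# callback. Every outcome POSTs to request.callback_url with a payload of
#   {"import_record_id": ..., "status": "succeeded", "llm_response": {...}}
# or
#   {"import_record_id": ..., "status": "failed", "error": "..."}
# and delivery failures are logged by notify_import_analysis_callback
# rather than raised.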