Data import schema analysis endpoint and test cases

zhaoawd
2025-10-29 22:35:29 +08:00
parent 76b8c9d79b
commit f43590585b
5 changed files with 380 additions and 37 deletions

View File

@@ -1,5 +1,6 @@
from __future__ import annotations
import asyncio
from contextlib import asynccontextmanager
import httpx
@@ -7,13 +8,13 @@ from fastapi import Depends, FastAPI, HTTPException, Request
from app.exceptions import ProviderAPICallError, ProviderConfigurationError
from app.models import (
    DataImportAnalysisRequest,
    DataImportAnalysisResponse,
    DataImportAnalysisJobAck,
    DataImportAnalysisJobRequest,
    LLMRequest,
    LLMResponse,
)
from app.services import LLMGateway
from app.services.import_analysis import build_import_messages, resolve_provider_from_model
from app.services.import_analysis import process_import_analysis_job
@asynccontextmanager
@@ -54,40 +55,22 @@ def create_app() -> FastAPI:
    @application.post(
        "/v1/import/analyze",
        response_model=DataImportAnalysisResponse,
        summary="Analyze import sample data via configured LLM",
        response_model=DataImportAnalysisJobAck,
        summary="Schedule async import analysis and notify via callback",
        status_code=200,
    )
    async def analyze_import_data(
        payload: DataImportAnalysisRequest,
        gateway: LLMGateway = Depends(get_gateway),
        payload: DataImportAnalysisJobRequest,
        client: httpx.AsyncClient = Depends(get_http_client),
    ) -> DataImportAnalysisResponse:
        try:
            provider, model_name = resolve_provider_from_model(payload.llm_model)
        except ValueError as exc:
            raise HTTPException(status_code=422, detail=str(exc)) from exc
    ) -> DataImportAnalysisJobAck:
        request_copy = payload.model_copy(deep=True)
        messages = build_import_messages(payload)

        async def _runner() -> None:
            await process_import_analysis_job(request_copy, client)

        llm_request = LLMRequest(
            provider=provider,
            model=model_name,
            messages=messages,
            temperature=payload.temperature if payload.temperature is not None else 0.2,
            max_tokens=payload.max_tokens,
        )
        # NOTE: fire-and-forget; consider holding a reference to the task so it
        # cannot be garbage-collected before it completes.
        asyncio.create_task(_runner())
        try:
            llm_response = await gateway.chat(llm_request, client)
        except ProviderConfigurationError as exc:
            raise HTTPException(status_code=422, detail=str(exc)) from exc
        except ProviderAPICallError as exc:
            raise HTTPException(status_code=502, detail=str(exc)) from exc
        return DataImportAnalysisResponse(
            import_record_id=payload.import_record_id,
            llm_response=llm_response,
        )
        return DataImportAnalysisJobAck(import_record_id=payload.import_record_id, status="accepted")

    return application
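
For reference, a minimal client-side sketch of the new contract (hostname, identifiers, and callback URL below are hypothetical; the endpoint only acknowledges the job, and results are delivered to callback_url rather than in the HTTP response):

import asyncio
import httpx

async def demo() -> None:
    job = {
        "import_record_id": "imp-001",  # hypothetical identifier
        "rows": [{"name": "Alice", "age": 30}, {"name": "Bob", "age": 41}],
        "headers": ["name", "age"],
        "callback_url": "https://example.com/import/callback",  # hypothetical receiver
        "llm_model": "gpt-4.1-mini",
    }
    async with httpx.AsyncClient() as client:
        resp = await client.post("http://localhost:8000/v1/import/analyze", json=job)
        resp.raise_for_status()
        print(resp.json())  # {"import_record_id": "imp-001", "status": "accepted"}

asyncio.run(demo())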

View File

@@ -1,9 +1,9 @@
from __future__ import annotations
from enum import Enum
from typing import Any, List, Optional
from typing import Any, Dict, List, Optional, Union
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, HttpUrl
class LLMRole(str, Enum):
@@ -90,3 +90,46 @@ class DataImportAnalysisRequest(BaseModel):
class DataImportAnalysisResponse(BaseModel):
    import_record_id: str
    llm_response: LLMResponse
class DataImportAnalysisJobRequest(BaseModel):
    import_record_id: str = Field(
        ..., description="Unique identifier for this import request run."
    )
    rows: List[Union[Dict[str, Any], List[Any]]] = Field(
        ...,
        description="Sample rows from the import payload. Accepts list of dicts or list of lists.",
    )
    headers: Optional[List[str]] = Field(
        None,
        description="Ordered list of table headers associated with the data sample.",
    )
    raw_csv: Optional[str] = Field(
        None,
        description="Optional raw CSV representation of the sample rows, if already prepared.",
    )
    table_schema: Optional[Any] = Field(
        None,
        description="Optional schema description for the table. Can be a string or JSON-serialisable structure.",
    )
    callback_url: HttpUrl = Field(
        ...,
        description="URL to notify when the analysis completes. Receives JSON payload with status/results.",
    )
    llm_model: str = Field(
        "gpt-4.1-mini",
        description="Target LLM model identifier. Defaults to gpt-4.1-mini.",
    )
    temperature: Optional[float] = Field(
        None,
        description="Optional override for model temperature when generating analysis output.",
    )
    max_output_tokens: Optional[int] = Field(
        None,
        description="Optional maximum number of tokens to generate in the analysis response.",
    )


class DataImportAnalysisJobAck(BaseModel):
    import_record_id: str = Field(..., description="Echo of the import record identifier.")
    status: str = Field("accepted", description="Processing status acknowledgement.")
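
A quick validation sketch for the new models (illustrative values; pydantic v2 semantics assumed):

from app.models import DataImportAnalysisJobAck, DataImportAnalysisJobRequest

req = DataImportAnalysisJobRequest(
    import_record_id="imp-001",             # hypothetical identifier
    rows=[["Alice", 30], ["Bob", 41]],      # list-of-lists row variant
    headers=["name", "age"],
    callback_url="https://example.com/cb",  # HttpUrl rejects malformed URLs at parse time
)
ack = DataImportAnalysisJobAck(import_record_id=req.import_record_id)
assert ack.status == "accepted"             # default acknowledgement status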

View File

@@ -1,16 +1,32 @@
from __future__ import annotations
import base64
import csv
import json
import logging
import os
from functools import lru_cache
from io import StringIO
from pathlib import Path
from typing import List, Tuple
from typing import Any, Dict, List, Sequence, Tuple
import httpx
from app.exceptions import ProviderAPICallError
from app.models import (
    DataImportAnalysisJobRequest,
    DataImportAnalysisRequest,
    LLMChoice,
    LLMMessage,
    LLMProvider,
    LLMResponse,
    LLMRole,
)
logger = logging.getLogger(__name__)
OPENAI_RESPONSES_URL = "https://api.openai.com/v1/responses"
def resolve_provider_from_model(llm_model: str) -> Tuple[LLMProvider, str]:
    """Resolve provider based on the llm_model string.
@@ -95,3 +111,264 @@ def load_import_template() -> str:
    if not template_path.exists():
        raise FileNotFoundError(f"Prompt template not found at {template_path}")
    return template_path.read_text(encoding="utf-8").strip()
def derive_headers(rows: Sequence[Any], provided_headers: Sequence[str] | None) -> List[str]:
    if provided_headers:
        return list(provided_headers)
    seen: List[str] = []
    for row in rows:
        if isinstance(row, dict):
            for key in row.keys():
                if key not in seen:
                    seen.append(str(key))
    return seen
def _stringify_cell(value: Any) -> str:
    if value is None:
        return ""
    if isinstance(value, (str, int, float, bool)):
        return str(value)
    try:
        return json.dumps(value, ensure_ascii=False)
    except (TypeError, ValueError):
        return str(value)
def rows_to_csv_bytes(rows: Sequence[Any], headers: List[str]) -> bytes:
    buffer = StringIO()
    writer = csv.writer(buffer)
    if headers:
        writer.writerow(headers)
    for row in rows:
        if isinstance(row, dict):
            writer.writerow([_stringify_cell(row.get(header)) for header in headers])
        elif isinstance(row, (list, tuple)):
            writer.writerow([_stringify_cell(item) for item in row])
        else:
            writer.writerow([_stringify_cell(row)])
    return buffer.getvalue().encode("utf-8")
def build_file_part(filename: str, payload: bytes, mime_type: str) -> Dict[str, Any]:
    encoded = base64.b64encode(payload).decode("ascii")
    return {
        "type": "input_file",
        "input_file": {
            "file_data": {
                "filename": filename,
                "file_name": filename,
                "mime_type": mime_type,
                "b64_json": encoded,
            }
        },
    }
def build_schema_part(
    request: DataImportAnalysisJobRequest, headers: List[str]
) -> Dict[str, Any] | None:
    if request.table_schema is not None:
        if isinstance(request.table_schema, str):
            schema_bytes = request.table_schema.encode("utf-8")
            return build_file_part("table_schema.txt", schema_bytes, "text/plain")
        try:
            schema_serialised = json.dumps(
                request.table_schema, ensure_ascii=False, indent=2
            )
        except (TypeError, ValueError) as exc:
            logger.warning("Failed to serialise table_schema for %s: %s", request.import_record_id, exc)
            schema_serialised = str(request.table_schema)
        return build_file_part(
            "table_schema.json",
            schema_serialised.encode("utf-8"),
            "application/json",
        )
    if headers:
        headers_payload = json.dumps({"headers": headers}, ensure_ascii=False, indent=2)
        return build_file_part(
            "table_headers.json",
            headers_payload.encode("utf-8"),
            "application/json",
        )
    return None
def build_openai_input_payload(
    request: DataImportAnalysisJobRequest,
) -> Dict[str, Any]:
    headers = derive_headers(request.rows, request.headers)
    csv_bytes = rows_to_csv_bytes(request.rows, headers)
    csv_part = build_file_part("sample_rows.csv", csv_bytes, "text/csv")
    schema_part = build_schema_part(request, headers)
    prompt = load_import_template()
    context_lines = [
        f"导入记录ID: {request.import_record_id}",  # import record ID
        f"样本数据行数: {len(request.rows)}",  # number of sample rows
        "请参考附件 `sample_rows.csv` 获取原始样本数据。",  # raw sample data is in the attached sample_rows.csv
    ]
    if schema_part:
        context_lines.append("附加结构信息来自第二个附件,请结合使用。")  # schema info is in the second attachment
    else:
        context_lines.append("未提供表头或Schema,请依据数据自行推断字段信息。")  # no headers/schema given; infer fields from the data
    user_content = [
        {"type": "input_text", "text": "\n".join(context_lines)},
        csv_part,
    ]
    if schema_part:
        user_content.append(schema_part)
    payload: Dict[str, Any] = {
        "model": request.llm_model,
        "input": [
            {"role": "system", "content": [{"type": "input_text", "text": prompt}]},
            {"role": "user", "content": user_content},
        ],
    }
    if request.temperature is not None:
        payload["temperature"] = request.temperature
    if request.max_output_tokens is not None:
        payload["max_output_tokens"] = request.max_output_tokens
    return payload
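# For reference, the payload assembled above has roughly this shape
# (values abbreviated; the file-part layout is the one build_file_part assumes):
#
# {
#     "model": "gpt-4.1-mini",
#     "input": [
#         {"role": "system", "content": [{"type": "input_text", "text": "<prompt>"}]},
#         {"role": "user", "content": [{"type": "input_text", "text": "<context>"},
#                                      <csv file part>, <optional schema part>]},
#     ],
#     "temperature": 0.2,            # only when request.temperature is set
#     "max_output_tokens": 1024,     # only when request.max_output_tokens is set
# }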
def parse_openai_responses_payload(
    data: Dict[str, Any], fallback_model: str
) -> LLMResponse:
    output_blocks = data.get("output", [])
    choices: List[LLMChoice] = []
    for idx, block in enumerate(output_blocks):
        if block.get("type") != "message":
            continue
        content_items = block.get("content", [])
        text_fragments: List[str] = []
        for item in content_items:
            if item.get("type") == "output_text":
                text_fragments.append(item.get("text", ""))
        if not text_fragments and data.get("output_text"):
            text_fragments.append(data.get("output_text", ""))
        message = LLMMessage(role=LLMRole.ASSISTANT, content="\n".join(text_fragments))
        choices.append(LLMChoice(index=idx, message=message))
    if not choices and data.get("output_text"):
        message = LLMMessage(role=LLMRole.ASSISTANT, content=data.get("output_text", ""))
        choices.append(LLMChoice(index=0, message=message))
    return LLMResponse(
        provider=LLMProvider.OPENAI,
        model=data.get("model", fallback_model),
        choices=choices,
        raw=data,
    )
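# A minimal sketch of the Responses-style payload shape the parser above
# expects (illustrative only; not verified against the live OpenAI API):
#
# {
#     "model": "gpt-4.1-mini",
#     "output": [
#         {"type": "message",
#          "content": [{"type": "output_text", "text": "<analysis text>"}]},
#     ],
# }
#
# Passing this to parse_openai_responses_payload(payload, "gpt-4.1-mini")
# yields one assistant LLMChoice whose content is the output_text fragment.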
async def call_openai_import_analysis(
    request: DataImportAnalysisJobRequest,
    client: httpx.AsyncClient,
    *,
    api_key: str | None = None,
) -> LLMResponse:
    openai_api_key = api_key or os.getenv("OPENAI_API_KEY")
    if not openai_api_key:
        raise ProviderAPICallError("OPENAI_API_KEY must be set to process import analysis.")
    payload = build_openai_input_payload(request)
    headers = {
        "Authorization": f"Bearer {openai_api_key}",
        "Content-Type": "application/json",
    }
    # Keep the POST inside the try block so transport errors (timeouts,
    # connection failures) surface as ProviderAPICallError as well.
    try:
        response = await client.post(OPENAI_RESPONSES_URL, json=payload, headers=headers)
        response.raise_for_status()
    except httpx.HTTPError as exc:
        raise ProviderAPICallError(f"OpenAI response API call failed: {exc}") from exc
    data: Dict[str, Any] = response.json()
    return parse_openai_responses_payload(data, request.llm_model)
async def dispatch_import_analysis_job(
    request: DataImportAnalysisJobRequest,
    client: httpx.AsyncClient,
    *,
    api_key: str | None = None,
) -> Dict[str, Any]:
    logger.info("Starting import analysis job %s", request.import_record_id)
    llm_response = await call_openai_import_analysis(request, client, api_key=api_key)
    result = {
        "import_record_id": request.import_record_id,
        "status": "succeeded",
        "llm_response": llm_response.model_dump(),
    }
    logger.info("Completed import analysis job %s", request.import_record_id)
    return result
async def notify_import_analysis_callback(
    callback_url: str,
    payload: Dict[str, Any],
    client: httpx.AsyncClient,
) -> None:
    try:
        response = await client.post(callback_url, json=payload)
        response.raise_for_status()
    except httpx.HTTPError as exc:
        logger.error(
            "Failed to deliver import analysis callback to %s: %s",
            callback_url,
            exc,
        )
async def process_import_analysis_job(
    request: DataImportAnalysisJobRequest,
    client: httpx.AsyncClient,
    *,
    api_key: str | None = None,
) -> None:
    try:
        payload = await dispatch_import_analysis_job(
            request,
            client,
            api_key=api_key,
        )
    except ProviderAPICallError as exc:
        payload = {
            "import_record_id": request.import_record_id,
            "status": "failed",
            "error": str(exc),
        }
    except Exception as exc:  # pragma: no cover - defensive logging
        logger.exception(
            "Unexpected failure while processing import analysis job %s",
            request.import_record_id,
        )
        payload = {
            "import_record_id": request.import_record_id,
            "status": "failed",
            "error": str(exc),
        }
    # HttpUrl is not a plain str under pydantic v2; cast before handing to httpx.
    await notify_import_analysis_callback(str(request.callback_url), payload, client)
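
The commit message mentions test cases; the test files themselves are not shown in this excerpt, but a unit test for the pure helpers above might look like this (pytest assumed):

from app.services.import_analysis import derive_headers, rows_to_csv_bytes

def test_rows_to_csv_bytes_handles_dict_and_list_rows() -> None:
    rows = [{"name": "Alice", "age": 30}, ["Bob", 41]]
    headers = derive_headers(rows, None)
    assert headers == ["name", "age"]

    lines = rows_to_csv_bytes(rows, headers).decode("utf-8").splitlines()
    assert lines[0] == "name,age"
    assert lines[1] == "Alice,30"
    assert lines[2] == "Bob,41"  # list rows are written positionally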