# Registered inside create_app(); `application` is the FastAPI instance that
# the enclosing factory builds and returns.
@application.post(
    "/v1/import/analyze",
    response_model=DataImportAnalysisJobAck,
    summary="Schedule async import analysis and notify via callback",
    status_code=200,
)
async def analyze_import_data(
    payload: DataImportAnalysisJobRequest,
    client: httpx.AsyncClient = Depends(get_http_client),
) -> DataImportAnalysisJobAck:
    """Schedule an import-analysis job and acknowledge it immediately.

    The analysis itself runs in a background task via
    ``process_import_analysis_job``; this handler only returns an
    acknowledgement echoing the import record id.

    Args:
        payload: Validated job request (sample rows, callback URL, model...).
        client: Shared HTTP client handed to the background job.

    Returns:
        An acknowledgement with ``status="accepted"``.
    """
    # Deep copy so the background job never shares mutable state with the
    # request object owned by this handler.
    request_copy = payload.model_copy(deep=True)

    task = asyncio.create_task(process_import_analysis_job(request_copy, client))

    # The event loop keeps only weak references to tasks, so a task whose
    # result is discarded may be garbage-collected before it finishes.
    # Hold a strong reference on the application state until it is done.
    pending_tasks = getattr(application.state, "import_analysis_tasks", None)
    if pending_tasks is None:
        pending_tasks = set()
        application.state.import_analysis_tasks = pending_tasks
    pending_tasks.add(task)
    task.add_done_callback(pending_tasks.discard)

    return DataImportAnalysisJobAck(import_record_id=payload.import_record_id, status="accepted")
class DataImportAnalysisJobRequest(BaseModel):
    # Request body for scheduling an asynchronous import-analysis job.
    # Only import_record_id, rows and callback_url are required; every other
    # field is optional or carries a default.

    import_record_id: str = Field(
        description="Unique identifier for this import request run.",
    )
    rows: List[Union[Dict[str, Any], List[Any]]] = Field(
        description="Sample rows from the import payload. Accepts list of dicts or list of lists.",
    )
    headers: Optional[List[str]] = Field(
        default=None,
        description="Ordered list of table headers associated with the data sample.",
    )
    raw_csv: Optional[str] = Field(
        default=None,
        description="Optional raw CSV representation of the sample rows, if already prepared.",
    )
    table_schema: Optional[Any] = Field(
        default=None,
        description="Optional schema description for the table. Can be a string or JSON-serialisable structure.",
    )
    callback_url: HttpUrl = Field(
        description="URL to notify when the analysis completes. Receives JSON payload with status/results.",
    )
    llm_model: str = Field(
        default="gpt-4.1-mini",
        description="Target LLM model identifier. Defaults to gpt-4.1-mini.",
    )
    temperature: Optional[float] = Field(
        default=None,
        description="Optional override for model temperature when generating analysis output.",
    )
    max_output_tokens: Optional[int] = Field(
        default=None,
        description="Optional maximum number of tokens to generate in the analysis response.",
    )


class DataImportAnalysisJobAck(BaseModel):
    # Immediate acknowledgement returned once a job has been scheduled.

    import_record_id: str = Field(description="Echo of the import record identifier")
    status: str = Field(default="accepted", description="Processing status acknowledgement.")
LLMChoice, LLMMessage, LLMProvider, + LLMResponse, LLMRole, ) +logger = logging.getLogger(__name__) + +OPENAI_RESPONSES_URL = "https://api.openai.com/v1/responses" + def resolve_provider_from_model(llm_model: str) -> Tuple[LLMProvider, str]: """Resolve provider based on the llm_model string. @@ -95,3 +111,264 @@ def load_import_template() -> str: if not template_path.exists(): raise FileNotFoundError(f"Prompt template not found at {template_path}") return template_path.read_text(encoding="utf-8").strip() + + +def derive_headers(rows: Sequence[Any], provided_headers: Sequence[str] | None) -> List[str]: + if provided_headers: + return list(provided_headers) + + seen: List[str] = [] + for row in rows: + if isinstance(row, dict): + for key in row.keys(): + if key not in seen: + seen.append(str(key)) + return seen + + +def _stringify_cell(value: Any) -> str: + if value is None: + return "" + if isinstance(value, (str, int, float, bool)): + return str(value) + try: + return json.dumps(value, ensure_ascii=False) + except (TypeError, ValueError): + return str(value) + + +def rows_to_csv_bytes(rows: Sequence[Any], headers: List[str]) -> bytes: + buffer = StringIO() + writer = csv.writer(buffer) + + if headers: + writer.writerow(headers) + + for row in rows: + if isinstance(row, dict): + writer.writerow([_stringify_cell(row.get(header)) for header in headers]) + elif isinstance(row, (list, tuple)): + writer.writerow([_stringify_cell(item) for item in row]) + else: + writer.writerow([_stringify_cell(row)]) + + return buffer.getvalue().encode("utf-8") + + +def build_file_part(filename: str, payload: bytes, mime_type: str) -> Dict[str, Any]: + encoded = base64.b64encode(payload).decode("ascii") + return { + "type": "input_file", + "input_file": { + "file_data": { + "filename": filename, + "file_name": filename, + "mime_type": mime_type, + "b64_json": encoded, + } + }, + } + + +def build_schema_part( + request: DataImportAnalysisJobRequest, headers: List[str] +) -> Dict[str, 
def build_schema_part(
    request: DataImportAnalysisJobRequest, headers: List[str]
) -> Dict[str, Any] | None:
    """Build the optional second attachment describing the table structure.

    Preference order: the request's ``table_schema`` (as text when it is a
    string, otherwise as JSON), then the derived ``headers`` as a small JSON
    document.

    Returns:
        An ``input_file`` part, or ``None`` when neither a schema nor any
        headers are available.
    """
    schema = request.table_schema
    if schema is not None:
        if isinstance(schema, str):
            return build_file_part("table_schema.txt", schema.encode("utf-8"), "text/plain")

        try:
            schema_serialised = json.dumps(schema, ensure_ascii=False, indent=2)
        except (TypeError, ValueError) as exc:
            logger.warning("Failed to serialise table_schema for %s: %s", request.import_record_id, exc)
            # str() output is not JSON, so ship it as plain text instead of
            # mislabelling it as application/json.
            return build_file_part(
                "table_schema.txt",
                str(schema).encode("utf-8"),
                "text/plain",
            )

        return build_file_part(
            "table_schema.json",
            schema_serialised.encode("utf-8"),
            "application/json",
        )

    if headers:
        headers_payload = json.dumps({"headers": headers}, ensure_ascii=False, indent=2)
        return build_file_part(
            "table_headers.json",
            headers_payload.encode("utf-8"),
            "application/json",
        )

    return None


def build_openai_input_payload(
    request: DataImportAnalysisJobRequest,
) -> Dict[str, Any]:
    """Assemble the JSON body sent to the OpenAI endpoint for one job.

    The system message carries the prompt template; the user message carries
    a short context text, the sample rows as ``sample_rows.csv`` and, when
    available, a schema/headers attachment. ``temperature`` and
    ``max_output_tokens`` are included only when set on the request.
    """
    headers = derive_headers(request.rows, request.headers)

    # Honour a caller-prepared CSV when one is supplied; otherwise render the
    # sample rows ourselves.
    if request.raw_csv:
        csv_bytes = request.raw_csv.encode("utf-8")
    else:
        csv_bytes = rows_to_csv_bytes(request.rows, headers)
    csv_part = build_file_part("sample_rows.csv", csv_bytes, "text/csv")
    schema_part = build_schema_part(request, headers)

    prompt = load_import_template()

    context_lines = [
        f"导入记录ID: {request.import_record_id}",
        f"样本数据行数: {len(request.rows)}",
        "请参考附件 `sample_rows.csv` 获取原始样本数据。",
    ]

    if schema_part:
        context_lines.append("附加结构信息来自第二个附件,请结合使用。")
    else:
        context_lines.append("未提供表头或Schema,请依据数据自行推断字段信息。")

    user_content = [
        {"type": "input_text", "text": "\n".join(context_lines)},
        csv_part,
    ]

    if schema_part:
        user_content.append(schema_part)

    payload: Dict[str, Any] = {
        "model": request.llm_model,
        "input": [
            {"role": "system", "content": [{"type": "input_text", "text": prompt}]},
            {"role": "user", "content": user_content},
        ],
    }

    if request.temperature is not None:
        payload["temperature"] = request.temperature
    if request.max_output_tokens is not None:
        payload["max_output_tokens"] = request.max_output_tokens

    return payload
def parse_openai_responses_payload(
    data: Dict[str, Any], fallback_model: str
) -> LLMResponse:
    """Convert a decoded response body into an ``LLMResponse``.

    Every ``output`` block of type ``"message"`` becomes one choice whose
    content is its ``output_text`` items joined by newlines. When a message
    has no such items, or no message block exists at all, the top-level
    ``output_text`` value is used instead (if present).

    Args:
        data: Decoded JSON body of the API response.
        fallback_model: Model name to report when ``data`` has no ``model``.
    """
    # `or []` also covers an explicit `"output": null`.
    output_blocks = data.get("output") or []
    choices: List[LLMChoice] = []

    for idx, block in enumerate(output_blocks):
        if block.get("type") != "message":
            continue
        text_fragments: List[str] = [
            item.get("text", "")
            for item in block.get("content") or []
            if item.get("type") == "output_text"
        ]

        if not text_fragments and data.get("output_text"):
            text_fragments.append(data.get("output_text", ""))

        message = LLMMessage(role=LLMRole.ASSISTANT, content="\n".join(text_fragments))
        # NOTE(review): index is the position within `output`, so it can skip
        # numbers when non-message blocks precede a message.
        choices.append(LLMChoice(index=idx, message=message))

    if not choices and data.get("output_text"):
        message = LLMMessage(role=LLMRole.ASSISTANT, content=data.get("output_text", ""))
        choices.append(LLMChoice(index=0, message=message))

    return LLMResponse(
        provider=LLMProvider.OPENAI,
        model=data.get("model", fallback_model),
        choices=choices,
        raw=data,
    )


async def call_openai_import_analysis(
    request: DataImportAnalysisJobRequest,
    client: httpx.AsyncClient,
    *,
    api_key: str | None = None,
) -> LLMResponse:
    """POST the analysis payload to ``OPENAI_RESPONSES_URL`` and parse the reply.

    Args:
        request: The job to analyse.
        client: HTTP client used for the call.
        api_key: Explicit API key; falls back to the ``OPENAI_API_KEY``
            environment variable.

    Raises:
        ProviderAPICallError: If no API key is available, the HTTP call fails
            (transport error or error status), or the body is not valid JSON.
    """
    openai_api_key = api_key or os.getenv("OPENAI_API_KEY")
    if not openai_api_key:
        raise ProviderAPICallError("OPENAI_API_KEY must be set to process import analysis.")

    payload = build_openai_input_payload(request)
    headers = {
        "Authorization": f"Bearer {openai_api_key}",
        "Content-Type": "application/json",
    }

    # The POST itself sits inside the try so connection/timeout errors
    # surface as ProviderAPICallError, not only error status codes.
    try:
        response = await client.post(OPENAI_RESPONSES_URL, json=payload, headers=headers)
        response.raise_for_status()
    except httpx.HTTPError as exc:
        raise ProviderAPICallError(f"OpenAI response API call failed: {exc}") from exc

    try:
        data: Dict[str, Any] = response.json()
    except ValueError as exc:  # json.JSONDecodeError is a ValueError
        raise ProviderAPICallError(
            f"OpenAI response API returned invalid JSON: {exc}"
        ) from exc

    return parse_openai_responses_payload(data, request.llm_model)
async def dispatch_import_analysis_job(
    request: DataImportAnalysisJobRequest,
    client: httpx.AsyncClient,
    *,
    api_key: str | None = None,
) -> Dict[str, Any]:
    """Run the analysis for one job and build the success callback payload.

    Returns:
        A dict with ``import_record_id``, ``status="succeeded"`` and the
        dumped LLM response. Errors from the analysis call propagate.
    """
    logger.info("Starting import analysis job %s", request.import_record_id)
    llm_response = await call_openai_import_analysis(request, client, api_key=api_key)

    result = {
        "import_record_id": request.import_record_id,
        "status": "succeeded",
        "llm_response": llm_response.model_dump(),
    }

    logger.info("Completed import analysis job %s", request.import_record_id)
    return result


async def notify_import_analysis_callback(
    callback_url: str,
    payload: Dict[str, Any],
    client: httpx.AsyncClient,
) -> None:
    """POST ``payload`` as JSON to ``callback_url`` on a best-effort basis.

    HTTP failures (transport errors or error statuses) are logged and
    swallowed; nothing is returned or raised for them.
    """
    # Callers may pass a URL object instead of a plain string; the HTTP
    # client is always given a str.
    target = str(callback_url)
    try:
        response = await client.post(target, json=payload)
        response.raise_for_status()
    except httpx.HTTPError as exc:
        logger.error(
            "Failed to deliver import analysis callback to %s: %s",
            target,
            exc,
        )


async def process_import_analysis_job(
    request: DataImportAnalysisJobRequest,
    client: httpx.AsyncClient,
    *,
    api_key: str | None = None,
) -> None:
    """Run one import-analysis job end to end and report via the callback.

    Any failure while running the job is converted into a ``status="failed"``
    payload carrying the error text, so the callback is attempted whether the
    job succeeded or not.
    """
    try:
        payload = await dispatch_import_analysis_job(
            request,
            client,
            api_key=api_key,
        )
    except ProviderAPICallError as exc:
        payload = {
            "import_record_id": request.import_record_id,
            "status": "failed",
            "error": str(exc),
        }
    except Exception as exc:  # pragma: no cover - defensive logging
        logger.exception(
            "Unexpected failure while processing import analysis job %s",
            request.import_record_id,
        )
        payload = {
            "import_record_id": request.import_record_id,
            "status": "failed",
            "error": str(exc),
        }

    # callback_url is a validated URL object on the request model, not a str;
    # convert explicitly before handing it to the HTTP client.
    await notify_import_analysis_callback(str(request.callback_url), payload, client)
解析表明:根据表头、Origin Table Name、Orign File Name生成表名,表名需要有意义 @@ -40,5 +40,3 @@ 若信息不足,请显式指出“信息不足”并给出补充数据需求清单。 避免武断结论,用“可能 / 候选 / 建议”字样。 不要捏造样本未出现的值。 - -数据块 \ No newline at end of file diff --git a/test/data_import_analysis_example.py b/test/data_import_analysis_example.py new file mode 100644 index 0000000..dd894f2 --- /dev/null +++ b/test/data_import_analysis_example.py @@ -0,0 +1,42 @@ +"""Minimal example for hitting the /v1/import/analyze endpoint with Excel data.""" + +from __future__ import annotations + +import asyncio +from pathlib import Path + +import httpx +import pandas as pd + + +API_URL = "http://localhost:8000/v1/import/analyze" +CALLBACK_URL = "http://localhost:8000/__mock__/import-callback" +EXCEL_PATH = Path(__file__).resolve().parents[1] / "file" / "全国品牌.xlsx" + + +async def main() -> None: + excel = pd.ExcelFile(EXCEL_PATH) + sheet_name = excel.sheet_names[0] + df = excel.parse(sheet_name) + sampled = df.head(10) + + rows = sampled.to_dict(orient="records") + headers = [str(column) for column in sampled.columns] + + payload = { + "import_record_id": "demo-import-001", + "rows": rows, + "struce": headers, + "llm_model": "deepseek:deepseek-chat", + "temperature": 0.2, + "callback_url": CALLBACK_URL, + } + + async with httpx.AsyncClient(timeout=httpx.Timeout(15.0)) as client: + response = await client.post(API_URL, json=payload) + print("Status:", response.status_code) + print("Body:", response.json()) + + +if __name__ == "__main__": + asyncio.run(main())