Import analysis endpoint now uses the project's chat endpoint

This commit is contained in:
zhaoawd
2025-10-29 23:42:42 +08:00
parent f43590585b
commit 59c9efa5d8
2 changed files with 146 additions and 162 deletions

View File

@@ -57,7 +57,7 @@ def create_app() -> FastAPI:
         "/v1/import/analyze",
         response_model=DataImportAnalysisJobAck,
         summary="Schedule async import analysis and notify via callback",
-        status_code=200,
+        status_code=202,
     )
     async def analyze_import_data(
         payload: DataImportAnalysisJobRequest,

View File

@@ -1,6 +1,5 @@
 from __future__ import annotations

-import base64
 import csv
 import json
 import logging
@@ -11,21 +10,26 @@ from pathlib import Path
 from typing import Any, Dict, List, Sequence, Tuple

 import httpx
+from pydantic import ValidationError

 from app.exceptions import ProviderAPICallError
 from app.models import (
     DataImportAnalysisJobRequest,
     DataImportAnalysisRequest,
-    LLMChoice,
     LLMMessage,
     LLMProvider,
     LLMResponse,
     LLMRole,
 )
+from app.settings import DEFAULT_IMPORT_MODEL, get_supported_import_models

 logger = logging.getLogger(__name__)

-OPENAI_RESPONSES_URL = "https://api.openai.com/v1/responses"
+IMPORT_GATEWAY_BASE_URL = os.getenv(
+    "IMPORT_GATEWAY_BASE_URL", "http://localhost:8000"
+)
+SUPPORTED_IMPORT_MODELS = get_supported_import_models()


 def resolve_provider_from_model(llm_model: str) -> Tuple[LLMProvider, str]:
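With OPENAI_RESPONSES_URL gone, the job handler no longer calls OpenAI directly; it posts to the service's own chat endpoint, with the base URL read from the IMPORT_GATEWAY_BASE_URL environment variable (default http://localhost:8000). The rstrip('/') at the call site below tolerates a trailing slash in the configured value; a quick sketch of the resolution (env values illustrative):

import os

base = os.getenv("IMPORT_GATEWAY_BASE_URL", "http://localhost:8000")
url = f"{base.rstrip('/')}/v1/chat/completions"
# "http://gateway:8000/"  -> "http://gateway:8000/v1/chat/completions"
# "http://localhost:8000" -> "http://localhost:8000/v1/chat/completions"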
@@ -113,17 +117,32 @@ def load_import_template() -> str:
     return template_path.read_text(encoding="utf-8").strip()


-def derive_headers(rows: Sequence[Any], provided_headers: Sequence[str] | None) -> List[str]:
+def derive_headers(
+    rows: Sequence[Any], provided_headers: Sequence[str] | None
+) -> List[str]:
     if provided_headers:
-        return list(provided_headers)
+        return [str(header) for header in provided_headers]

-    seen: List[str] = []
+    collected: List[str] = []
+    list_lengths: List[int] = []
     for row in rows:
         if isinstance(row, dict):
             for key in row.keys():
-                if key not in seen:
-                    seen.append(str(key))
-    return seen
+                key_str = str(key)
+                if key_str not in collected:
+                    collected.append(key_str)
+        elif isinstance(row, (list, tuple)):
+            list_lengths.append(len(row))
+
+    if collected:
+        return collected
+    if list_lengths:
+        max_len = max(list_lengths)
+        return [f"column_{idx + 1}" for idx in range(max_len)]
+    return ["column_1"]


 def _stringify_cell(value: Any) -> str:
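The reworked derive_headers handles three row shapes instead of only dicts: dict rows yield keys in first-seen order, list/tuple rows fall back to positional column_N names sized to the widest row, and empty input still produces one placeholder column. Expected behaviour, sketched:

derive_headers([{"id": 1, "name": "a"}, {"id": 2, "mail": "x"}], None)
# -> ["id", "name", "mail"]   (first-seen key order, duplicates dropped)

derive_headers([[1, 2, 3], [4, 5]], None)
# -> ["column_1", "column_2", "column_3"]   (widest list row wins)

derive_headers([], None)
# -> ["column_1"]

derive_headers([], [1, "b"])
# -> ["1", "b"]   (provided headers are coerced to str)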
@@ -137,14 +156,18 @@ def _stringify_cell(value: Any) -> str:
     return str(value)


-def rows_to_csv_bytes(rows: Sequence[Any], headers: List[str]) -> bytes:
+def rows_to_csv_text(
+    rows: Sequence[Any], headers: Sequence[str], *, max_rows: int = 50
+) -> str:
     buffer = StringIO()
     writer = csv.writer(buffer)
     if headers:
         writer.writerow(headers)
-    for row in rows:
+    for idx, row in enumerate(rows):
+        if max_rows and idx >= max_rows:
+            break
         if isinstance(row, dict):
             writer.writerow([_stringify_cell(row.get(header)) for header in headers])
         elif isinstance(row, (list, tuple)):
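rows_to_csv_text replaces the byte-oriented rows_to_csv_bytes because the sample is now embedded in the prompt as text rather than attached as a file, and the new max_rows keyword (default 50) caps how much of a large import ends up in the preview. A sketch, assuming _stringify_cell falls through to str() for plain ints:

rows_to_csv_text([{"a": 1, "b": 2}, {"a": 3, "b": 4}], ["a", "b"])
# -> "a,b\r\n1,2\r\n3,4"   (csv.writer's default line endings; trailing newline stripped)

many = [{"a": i} for i in range(200)]
rows_to_csv_text(many, ["a"])
# serialises only the first 50 rows; the remaining 150 are skipped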
@@ -152,192 +175,153 @@ def rows_to_csv_bytes(rows: Sequence[Any], headers: List[str]) -> bytes:
         else:
             writer.writerow([_stringify_cell(row)])
-    return buffer.getvalue().encode("utf-8")
+    return buffer.getvalue().strip()


-def build_file_part(filename: str, payload: bytes, mime_type: str) -> Dict[str, Any]:
-    encoded = base64.b64encode(payload).decode("ascii")
-    return {
-        "type": "input_file",
-        "input_file": {
-            "file_data": {
-                "filename": filename,
-                "file_name": filename,
-                "mime_type": mime_type,
-                "b64_json": encoded,
-            }
-        },
-    }
-
-
-def build_schema_part(
-    request: DataImportAnalysisJobRequest, headers: List[str]
-) -> Dict[str, Any] | None:
-    if request.table_schema is not None:
-        if isinstance(request.table_schema, str):
-            schema_bytes = request.table_schema.encode("utf-8")
-            return build_file_part("table_schema.txt", schema_bytes, "text/plain")
-        try:
-            schema_serialised = json.dumps(
-                request.table_schema, ensure_ascii=False, indent=2
-            )
-        except (TypeError, ValueError) as exc:
-            logger.warning("Failed to serialise table_schema for %s: %s", request.import_record_id, exc)
-            schema_serialised = str(request.table_schema)
-        return build_file_part(
-            "table_schema.json",
-            schema_serialised.encode("utf-8"),
-            "application/json",
-        )
-    if headers:
-        headers_payload = json.dumps({"headers": headers}, ensure_ascii=False, indent=2)
-        return build_file_part(
-            "table_headers.json",
-            headers_payload.encode("utf-8"),
-            "application/json",
-        )
-    return None
+def format_table_schema(schema: Any) -> str:
+    if schema is None:
+        return ""
+    if isinstance(schema, str):
+        return schema.strip()
+    try:
+        return json.dumps(schema, ensure_ascii=False, indent=2)
+    except (TypeError, ValueError):
+        return str(schema)
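format_table_schema collapses the old base64 file-part plumbing into plain text: strings pass through trimmed, JSON-serialisable values are pretty-printed, and anything else degrades to str(). Roughly:

format_table_schema(None)              # -> ""
format_table_schema("  id BIGINT  ")   # -> "id BIGINT"
format_table_schema({"id": "bigint"})  # -> '{\n  "id": "bigint"\n}'
format_table_schema({1, 2})            # a set is not JSON-serialisable -> "{1, 2}"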


-def build_openai_input_payload(
+def build_analysis_request(
     request: DataImportAnalysisJobRequest,
-) -> Dict[str, Any]:
+) -> DataImportAnalysisRequest:
     headers = derive_headers(request.rows, request.headers)
-    csv_bytes = rows_to_csv_bytes(request.rows, headers)
-    csv_part = build_file_part("sample_rows.csv", csv_bytes, "text/csv")
-    schema_part = build_schema_part(request, headers)
-    prompt = load_import_template()
-    context_lines = [
-        f"导入记录ID: {request.import_record_id}",
-        f"样本数据行数: {len(request.rows)}",
-        "请参考附件 `sample_rows.csv` 获取原始样本数据。",
-    ]
-    if schema_part:
-        context_lines.append("附加结构信息来自第二个附件,请结合使用。")
-    else:
-        context_lines.append("未提供表头或Schema请依据数据自行推断字段信息。")
-    user_content = [
-        {"type": "input_text", "text": "\n".join(context_lines)},
-        csv_part,
-    ]
-    if schema_part:
-        user_content.append(schema_part)
+    if request.raw_csv:
+        csv_text = request.raw_csv.strip()
+    else:
+        csv_text = rows_to_csv_text(request.rows, headers)
+    sections: List[str] = []
+    if csv_text:
+        sections.append("CSV样本预览:\n" + csv_text)
+    schema_text = format_table_schema(request.table_schema)
+    if schema_text:
+        sections.append("附加结构信息:\n" + schema_text)
+    example_data = "\n\n".join(sections) if sections else "未提供样本数据。"
+    max_length = 30_000
+    if len(example_data) > max_length:
+        example_data = example_data[: max_length - 3] + "..."
+    return DataImportAnalysisRequest(
+        import_record_id=request.import_record_id,
+        example_data=example_data,
+        table_headers=headers,
+        llm_model=request.llm_model or DEFAULT_IMPORT_MODEL,
+    )
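The 30,000-character cap keeps example_data bounded no matter how large the sample is; slicing to max_length - 3 before appending "..." makes the truncated string exactly max_length characters:

max_length = 30_000
example_data = "x" * 40_000                      # oversized sample, illustrative
capped = example_data[: max_length - 3] + "..."
assert len(capped) == max_length                 # 29,997 data chars + "..."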
+
+
+def build_chat_payload(request: DataImportAnalysisJobRequest) -> Dict[str, Any]:
+    llm_input = request.llm_model or DEFAULT_IMPORT_MODEL
+    provider, model_name = resolve_provider_from_model(llm_input)
+    normalized_model = f"{provider.value}:{model_name}"
+    if SUPPORTED_IMPORT_MODELS and normalized_model not in SUPPORTED_IMPORT_MODELS:
+        raise ProviderAPICallError(
+            "Model '{model}' is not allowed. Allowed models: {allowed}".format(
+                model=normalized_model,
+                allowed=", ".join(sorted(SUPPORTED_IMPORT_MODELS)),
+            )
+        )
+    analysis_request = build_analysis_request(request)
+    messages = build_import_messages(analysis_request)
     payload: Dict[str, Any] = {
-        "model": request.llm_model,
-        "input": [
-            {"role": "system", "content": [{"type": "input_text", "text": prompt}]},
-            {"role": "user", "content": user_content},
-        ],
+        "provider": provider.value,
+        "model": model_name,
+        "messages": [message.model_dump() for message in messages],
+        "temperature": request.temperature if request.temperature is not None else 0.2,
     }
-    if request.temperature is not None:
-        payload["temperature"] = request.temperature
     if request.max_output_tokens is not None:
-        payload["max_output_tokens"] = request.max_output_tokens
+        payload["max_tokens"] = request.max_output_tokens
     return payload
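For reference, the body build_chat_payload posts to the internal /v1/chat/completions endpoint comes out roughly like this (model name and message contents are illustrative; build_import_messages is assumed to render the import-analysis prompt defined elsewhere in this module):

payload = {
    "provider": "openai",      # from resolve_provider_from_model
    "model": "gpt-4o-mini",    # illustrative model name
    "messages": [
        {"role": "system", "content": "<import analysis template>"},
        {"role": "user", "content": "<example_data preview>"},
    ],
    "temperature": 0.2,        # default when the request omits temperature
    "max_tokens": 1024,        # present only when max_output_tokens was set
}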


-def parse_openai_responses_payload(
-    data: Dict[str, Any], fallback_model: str
-) -> LLMResponse:
-    output_blocks = data.get("output", [])
-    choices: List[LLMChoice] = []
-    for idx, block in enumerate(output_blocks):
-        if block.get("type") != "message":
-            continue
-        content_items = block.get("content", [])
-        text_fragments: List[str] = []
-        for item in content_items:
-            if item.get("type") == "output_text":
-                text_fragments.append(item.get("text", ""))
-        if not text_fragments and data.get("output_text"):
-            text_fragments.append(data.get("output_text", ""))
-        message = LLMMessage(role=LLMRole.ASSISTANT, content="\n".join(text_fragments))
-        choices.append(LLMChoice(index=idx, message=message))
-    if not choices and data.get("output_text"):
-        message = LLMMessage(role=LLMRole.ASSISTANT, content=data.get("output_text", ""))
-        choices.append(LLMChoice(index=0, message=message))
-    return LLMResponse(
-        provider=LLMProvider.OPENAI,
-        model=data.get("model", fallback_model),
-        choices=choices,
-        raw=data,
-    )
-
-
-async def call_openai_import_analysis(
-    request: DataImportAnalysisJobRequest,
-    client: httpx.AsyncClient,
-    *,
-    api_key: str | None = None,
-) -> LLMResponse:
-    openai_api_key = api_key or os.getenv("OPENAI_API_KEY")
-    if not openai_api_key:
-        raise ProviderAPICallError("OPENAI_API_KEY must be set to process import analysis.")
-    payload = build_openai_input_payload(request)
-    headers = {
-        "Authorization": f"Bearer {openai_api_key}",
-        "Content-Type": "application/json",
-    }
-    response = await client.post(OPENAI_RESPONSES_URL, json=payload, headers=headers)
-    try:
-        response.raise_for_status()
-    except httpx.HTTPError as exc:
-        raise ProviderAPICallError(f"OpenAI response API call failed: {exc}") from exc
-    data: Dict[str, Any] = response.json()
-    return parse_openai_responses_payload(data, request.llm_model)


 async def dispatch_import_analysis_job(
     request: DataImportAnalysisJobRequest,
     client: httpx.AsyncClient,
-    *,
-    api_key: str | None = None,
 ) -> Dict[str, Any]:
     logger.info("Starting import analysis job %s", request.import_record_id)
-    llm_response = await call_openai_import_analysis(request, client, api_key=api_key)
-    result = {
+    payload = build_chat_payload(request)
+    url = f"{IMPORT_GATEWAY_BASE_URL.rstrip('/')}/v1/chat/completions"
+    print(
+        f"[ImportAnalysis] Dispatching import {request.import_record_id} to {url}: "
+        f"{json.dumps(payload, ensure_ascii=False)}"
+    )
+    try:
+        response = await client.post(url, json=payload)
+        response.raise_for_status()
+    except httpx.HTTPStatusError as exc:
+        body_preview = ""
+        if exc.response is not None:
+            body_preview = exc.response.text[:400]
+        raise ProviderAPICallError(
+            f"Failed to invoke chat completions endpoint: {exc}. Response body: {body_preview}"
+        ) from exc
+    except httpx.HTTPError as exc:
+        raise ProviderAPICallError(
+            f"HTTP call to chat completions endpoint failed: {exc}"
+        ) from exc
+    try:
+        response_data = response.json()
+    except ValueError as exc:
+        raise ProviderAPICallError("Chat completions endpoint returned invalid JSON") from exc
+    print(
+        f"[ImportAnalysis] LLM HTTP status for {request.import_record_id}: "
+        f"{response.status_code}"
+    )
+    print(
+        f"[ImportAnalysis] LLM response for {request.import_record_id}: "
+        f"{json.dumps(response_data, ensure_ascii=False)}"
+    )
+    try:
+        llm_response = LLMResponse.model_validate(response_data)
+    except ValidationError as exc:
+        raise ProviderAPICallError(
+            "Chat completions endpoint returned unexpected schema"
+        ) from exc
+    logger.info("Completed import analysis job %s", request.import_record_id)
+    return {
         "import_record_id": request.import_record_id,
         "status": "succeeded",
         "llm_response": llm_response.model_dump(),
     }
-    logger.info("Completed import analysis job %s", request.import_record_id)
-    return result
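Validating the gateway response with LLMResponse.model_validate, and converting ValidationError into ProviderAPICallError, turns schema drift in the chat endpoint into a clean, catchable failure instead of an AttributeError later in the callback path. The guard in isolation, with a response shape guessed from the fields the removed parser used (provider, model, choices, raw):

from pydantic import ValidationError

data = {"provider": "openai", "model": "gpt-4o-mini", "choices": [], "raw": {}}  # illustrative
try:
    llm_response = LLMResponse.model_validate(data)  # raises on missing or mistyped fields
except ValidationError as exc:
    raise ProviderAPICallError("Chat completions endpoint returned unexpected schema") from exc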


 async def notify_import_analysis_callback(
     callback_url: str,
     payload: Dict[str, Any],
     client: httpx.AsyncClient,
 ) -> None:
+    callback_target = str(callback_url)
     try:
-        response = await client.post(callback_url, json=payload)
+        response = await client.post(callback_target, json=payload)
         response.raise_for_status()
     except httpx.HTTPError as exc:
         logger.error(
             "Failed to deliver import analysis callback to %s: %s",
-            callback_url,
+            callback_target,
             exc,
         )
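Both the success path above and the failure paths below deliver the same envelope to callback_url: import_record_id, a status of "succeeded" or "failed", and the result or error details. A minimal receiver sketch for the consuming side (route path and handling are hypothetical; the failure payload's extra fields are elided in this diff):

from fastapi import FastAPI

app = FastAPI()

@app.post("/import-callback")  # hypothetical route on the caller's side
async def import_callback(payload: dict) -> dict:
    record_id = payload["import_record_id"]
    if payload["status"] == "succeeded":
        analysis = payload["llm_response"]  # serialized LLMResponse
        ...  # persist the analysis for record_id
    else:
        ...  # mark record_id as failed
    return {"ok": True}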
@@ -345,16 +329,13 @@ async def notify_import_analysis_callback(

 async def process_import_analysis_job(
     request: DataImportAnalysisJobRequest,
     client: httpx.AsyncClient,
-    *,
-    api_key: str | None = None,
 ) -> None:
     try:
-        payload = await dispatch_import_analysis_job(
-            request,
-            client,
-            api_key=api_key,
-        )
+        payload = await dispatch_import_analysis_job(request, client)
     except ProviderAPICallError as exc:
+        print(
+            f"[ImportAnalysis] LLM call failed for {request.import_record_id}: {exc}"
+        )
         payload = {
             "import_record_id": request.import_record_id,
             "status": "failed",
@@ -365,6 +346,9 @@ async def process_import_analysis_job(
             "Unexpected failure while processing import analysis job %s",
             request.import_record_id,
         )
+        print(
+            f"[ImportAnalysis] Unexpected error for {request.import_record_id}: {exc}"
+        )
         payload = {
             "import_record_id": request.import_record_id,
             "status": "failed",