Import analysis endpoint now uses the project's chat endpoint

zhaoawd
2025-10-29 23:42:42 +08:00
parent f43590585b
commit 59c9efa5d8
2 changed files with 146 additions and 162 deletions

View File

@@ -57,7 +57,7 @@ def create_app() -> FastAPI:
"/v1/import/analyze",
response_model=DataImportAnalysisJobAck,
summary="Schedule async import analysis and notify via callback",
status_code=200,
status_code=202,
)
async def analyze_import_data(
payload: DataImportAnalysisJobRequest,
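Note on the 200 → 202 change: 202 Accepted is the conventional status for work that is queued and reported later via callback, which matches the new async flow. A minimal sketch of the pattern (illustrative names, not the project's actual wiring):

from fastapi import BackgroundTasks, FastAPI
from pydantic import BaseModel

app = FastAPI()

class JobRequest(BaseModel):
    # Stand-in for DataImportAnalysisJobRequest; the real fields live in app.models.
    import_record_id: str
    callback_url: str

async def run_job(payload: JobRequest) -> None:
    ...  # dispatch the analysis, then POST the result to payload.callback_url

@app.post("/v1/import/analyze", status_code=202)
async def analyze(payload: JobRequest, background_tasks: BackgroundTasks):
    # Return immediately; the job result arrives via the callback URL later.
    background_tasks.add_task(run_job, payload)
    return {"import_record_id": payload.import_record_id, "status": "accepted"}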

View File

@@ -1,6 +1,5 @@
from __future__ import annotations
import base64
import csv
import json
import logging
@@ -11,21 +10,26 @@ from pathlib import Path
from typing import Any, Dict, List, Sequence, Tuple
import httpx
from pydantic import ValidationError
from app.exceptions import ProviderAPICallError
from app.models import (
DataImportAnalysisJobRequest,
DataImportAnalysisRequest,
LLMChoice,
LLMMessage,
LLMProvider,
LLMResponse,
LLMRole,
)
from app.settings import DEFAULT_IMPORT_MODEL, get_supported_import_models
logger = logging.getLogger(__name__)
OPENAI_RESPONSES_URL = "https://api.openai.com/v1/responses"
IMPORT_GATEWAY_BASE_URL = os.getenv(
"IMPORT_GATEWAY_BASE_URL", "http://localhost:8000"
)
SUPPORTED_IMPORT_MODELS = get_supported_import_models()
def resolve_provider_from_model(llm_model: str) -> Tuple[LLMProvider, str]:
@@ -113,17 +117,32 @@ def load_import_template() -> str:
return template_path.read_text(encoding="utf-8").strip()
def derive_headers(rows: Sequence[Any], provided_headers: Sequence[str] | None) -> List[str]:
def derive_headers(
rows: Sequence[Any], provided_headers: Sequence[str] | None
) -> List[str]:
if provided_headers:
return list(provided_headers)
return [str(header) for header in provided_headers]
collected: List[str] = []
list_lengths: List[int] = []
seen: List[str] = []
for row in rows:
if isinstance(row, dict):
for key in row.keys():
if key not in seen:
seen.append(str(key))
return seen
key_str = str(key)
if key_str not in collected:
collected.append(key_str)
elif isinstance(row, (list, tuple)):
list_lengths.append(len(row))
if collected:
return collected
if list_lengths:
max_len = max(list_lengths)
return [f"column_{idx + 1}" for idx in range(max_len)]
return ["column_1"]
def _stringify_cell(value: Any) -> str:
@@ -137,14 +156,18 @@ def _stringify_cell(value: Any) -> str:
return str(value)
def rows_to_csv_bytes(rows: Sequence[Any], headers: List[str]) -> bytes:
def rows_to_csv_text(
rows: Sequence[Any], headers: Sequence[str], *, max_rows: int = 50
) -> str:
buffer = StringIO()
writer = csv.writer(buffer)
if headers:
writer.writerow(headers)
for row in rows:
for idx, row in enumerate(rows):
if max_rows and idx >= max_rows:
break
if isinstance(row, dict):
writer.writerow([_stringify_cell(row.get(header)) for header in headers])
elif isinstance(row, (list, tuple)):
@@ -152,192 +175,153 @@ def rows_to_csv_bytes(rows: Sequence[Any], headers: List[str]) -> bytes:
else:
writer.writerow([_stringify_cell(row)])
return buffer.getvalue().encode("utf-8")
return buffer.getvalue().strip()
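The helper now returns text rather than bytes and caps the sample; for example:

rows = [{"id": 1, "name": "Ada"}, {"id": 2, "name": "Bob"}]
rows_to_csv_text(rows, ["id", "name"])
# -> "id,name\r\n1,Ada\r\n2,Bob"   (csv.writer emits \r\n; the trailing newline is stripped)
rows_to_csv_text(rows, ["id", "name"], max_rows=1)
# -> "id,name\r\n1,Ada"            (max_rows=0 would disable the cap entirely)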
def build_file_part(filename: str, payload: bytes, mime_type: str) -> Dict[str, Any]:
encoded = base64.b64encode(payload).decode("ascii")
return {
"type": "input_file",
"input_file": {
"file_data": {
"filename": filename,
"file_name": filename,
"mime_type": mime_type,
"b64_json": encoded,
}
},
}
def format_table_schema(schema: Any) -> str:
if schema is None:
return ""
if isinstance(schema, str):
return schema.strip()
try:
return json.dumps(schema, ensure_ascii=False, indent=2)
except (TypeError, ValueError):
return str(schema)
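In effect the new helper accepts None, a string, or any JSON-serialisable object:

format_table_schema(None)                                # -> ""
format_table_schema("  CREATE TABLE t (id BIGINT);  ")   # -> "CREATE TABLE t (id BIGINT);"
format_table_schema({"id": "bigint"})                    # -> '{\n  "id": "bigint"\n}'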
def build_schema_part(
request: DataImportAnalysisJobRequest, headers: List[str]
) -> Dict[str, Any] | None:
if request.table_schema is not None:
if isinstance(request.table_schema, str):
schema_bytes = request.table_schema.encode("utf-8")
return build_file_part("table_schema.txt", schema_bytes, "text/plain")
try:
schema_serialised = json.dumps(
request.table_schema, ensure_ascii=False, indent=2
)
except (TypeError, ValueError) as exc:
logger.warning("Failed to serialise table_schema for %s: %s", request.import_record_id, exc)
schema_serialised = str(request.table_schema)
return build_file_part(
"table_schema.json",
schema_serialised.encode("utf-8"),
"application/json",
)
if headers:
headers_payload = json.dumps({"headers": headers}, ensure_ascii=False, indent=2)
return build_file_part(
"table_headers.json",
headers_payload.encode("utf-8"),
"application/json",
)
return None
def build_openai_input_payload(
def build_analysis_request(
request: DataImportAnalysisJobRequest,
) -> Dict[str, Any]:
) -> DataImportAnalysisRequest:
headers = derive_headers(request.rows, request.headers)
csv_bytes = rows_to_csv_bytes(request.rows, headers)
csv_part = build_file_part("sample_rows.csv", csv_bytes, "text/csv")
schema_part = build_schema_part(request, headers)
prompt = load_import_template()
context_lines = [
f"导入记录ID: {request.import_record_id}",
f"样本数据行数: {len(request.rows)}",
"请参考附件 `sample_rows.csv` 获取原始样本数据。",
]
if schema_part:
context_lines.append("附加结构信息来自第二个附件,请结合使用。")
if request.raw_csv:
csv_text = request.raw_csv.strip()
else:
context_lines.append("未提供表头或Schema请依据数据自行推断字段信息。")
csv_text = rows_to_csv_text(request.rows, headers)
user_content = [
{"type": "input_text", "text": "\n".join(context_lines)},
csv_part,
]
sections: List[str] = []
if csv_text:
sections.append("CSV样本预览:\n" + csv_text)
schema_text = format_table_schema(request.table_schema)
if schema_text:
sections.append("附加结构信息:\n" + schema_text)
if schema_part:
user_content.append(schema_part)
example_data = "\n\n".join(sections) if sections else "未提供样本数据。"
payload: Dict[str, Any] = {
"model": request.llm_model,
"input": [
{"role": "system", "content": [{"type": "input_text", "text": prompt}]},
{"role": "user", "content": user_content},
],
}
max_length = 30_000
if len(example_data) > max_length:
example_data = example_data[: max_length - 3] + "..."
if request.temperature is not None:
payload["temperature"] = request.temperature
if request.max_output_tokens is not None:
payload["max_output_tokens"] = request.max_output_tokens
return payload
def parse_openai_responses_payload(
data: Dict[str, Any], fallback_model: str
) -> LLMResponse:
output_blocks = data.get("output", [])
choices: List[LLMChoice] = []
for idx, block in enumerate(output_blocks):
if block.get("type") != "message":
continue
content_items = block.get("content", [])
text_fragments: List[str] = []
for item in content_items:
if item.get("type") == "output_text":
text_fragments.append(item.get("text", ""))
if not text_fragments and data.get("output_text"):
text_fragments.append(data.get("output_text", ""))
message = LLMMessage(role=LLMRole.ASSISTANT, content="\n".join(text_fragments))
choices.append(LLMChoice(index=idx, message=message))
if not choices and data.get("output_text"):
message = LLMMessage(role=LLMRole.ASSISTANT, content=data.get("output_text", ""))
choices.append(LLMChoice(index=0, message=message))
return LLMResponse(
provider=LLMProvider.OPENAI,
model=data.get("model", fallback_model),
choices=choices,
raw=data,
return DataImportAnalysisRequest(
import_record_id=request.import_record_id,
example_data=example_data,
table_headers=headers,
llm_model=request.llm_model or DEFAULT_IMPORT_MODEL,
)
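Taken together, build_analysis_request collapses everything the old file-attachment path carried into one prompt-ready object (a sketch; field values follow from the code above):

analysis = build_analysis_request(job_request)
analysis.example_data    # "CSV样本预览:\n..." plus optional "附加结构信息:\n...", truncated at 30,000 chars
analysis.table_headers   # headers derived from the rows, or the provided headers
analysis.llm_model       # request.llm_model, falling back to DEFAULT_IMPORT_MODEL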
async def call_openai_import_analysis(
request: DataImportAnalysisJobRequest,
client: httpx.AsyncClient,
*,
api_key: str | None = None,
) -> LLMResponse:
openai_api_key = api_key or os.getenv("OPENAI_API_KEY")
if not openai_api_key:
raise ProviderAPICallError("OPENAI_API_KEY must be set to process import analysis.")
def build_chat_payload(request: DataImportAnalysisJobRequest) -> Dict[str, Any]:
llm_input = request.llm_model or DEFAULT_IMPORT_MODEL
provider, model_name = resolve_provider_from_model(llm_input)
normalized_model = f"{provider.value}:{model_name}"
payload = build_openai_input_payload(request)
headers = {
"Authorization": f"Bearer {openai_api_key}",
"Content-Type": "application/json",
if SUPPORTED_IMPORT_MODELS and normalized_model not in SUPPORTED_IMPORT_MODELS:
raise ProviderAPICallError(
"Model '{model}' is not allowed. Allowed models: {allowed}".format(
model=normalized_model,
allowed=", ".join(sorted(SUPPORTED_IMPORT_MODELS)),
)
)
analysis_request = build_analysis_request(request)
messages = build_import_messages(analysis_request)
payload: Dict[str, Any] = {
"provider": provider.value,
"model": model_name,
"messages": [message.model_dump() for message in messages],
"temperature": request.temperature if request.temperature is not None else 0.2,
}
response = await client.post(OPENAI_RESPONSES_URL, json=payload, headers=headers)
try:
response.raise_for_status()
except httpx.HTTPError as exc:
raise ProviderAPICallError(f"OpenAI response API call failed: {exc}") from exc
if request.max_output_tokens is not None:
payload["max_tokens"] = request.max_output_tokens
data: Dict[str, Any] = response.json()
return parse_openai_responses_payload(data, request.llm_model)
return payload
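With an illustrative model id of "openai:gpt-4o", the body posted to the gateway would look like this (the messages themselves come from build_import_messages, defined outside this hunk):

{
    "provider": "openai",
    "model": "gpt-4o",
    "messages": [...],   # system + user messages built from the analysis request
    "temperature": 0.2,  # default applied when request.temperature is None
    "max_tokens": 512,   # only present when request.max_output_tokens was given
}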
async def dispatch_import_analysis_job(
request: DataImportAnalysisJobRequest,
client: httpx.AsyncClient,
*,
api_key: str | None = None,
) -> Dict[str, Any]:
logger.info("Starting import analysis job %s", request.import_record_id)
llm_response = await call_openai_import_analysis(request, client, api_key=api_key)
result = {
payload = build_chat_payload(request)
url = f"{IMPORT_GATEWAY_BASE_URL.rstrip('/')}/v1/chat/completions"
print(
f"[ImportAnalysis] Dispatching import {request.import_record_id} to {url}: "
f"{json.dumps(payload, ensure_ascii=False)}"
)
try:
response = await client.post(url, json=payload)
response.raise_for_status()
except httpx.HTTPStatusError as exc:
body_preview = ""
if exc.response is not None:
body_preview = exc.response.text[:400]
raise ProviderAPICallError(
f"Failed to invoke chat completions endpoint: {exc}. Response body: {body_preview}"
) from exc
except httpx.HTTPError as exc:
raise ProviderAPICallError(
f"HTTP call to chat completions endpoint failed: {exc}"
) from exc
try:
response_data = response.json()
except ValueError as exc:
raise ProviderAPICallError("Chat completions endpoint returned invalid JSON") from exc
print(
f"[ImportAnalysis] LLM HTTP status for {request.import_record_id}: "
f"{response.status_code}"
)
print(
f"[ImportAnalysis] LLM response for {request.import_record_id}: "
f"{json.dumps(response_data, ensure_ascii=False)}"
)
try:
llm_response = LLMResponse.model_validate(response_data)
except ValidationError as exc:
raise ProviderAPICallError(
"Chat completions endpoint returned unexpected schema"
) from exc
logger.info("Completed import analysis job %s", request.import_record_id)
return {
"import_record_id": request.import_record_id,
"status": "succeeded",
"llm_response": llm_response.model_dump(),
}
logger.info("Completed import analysis job %s", request.import_record_id)
return result
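On success the job resolves to a callback body of this shape, which notify_import_analysis_callback below POSTs as-is:

{
    "import_record_id": "imp-123",   # illustrative id
    "status": "succeeded",
    "llm_response": {...},           # LLMResponse.model_dump(): provider, model, choices, raw
}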
async def notify_import_analysis_callback(
callback_url: str,
payload: Dict[str, Any],
client: httpx.AsyncClient,
) -> None:
callback_target = str(callback_url)
try:
response = await client.post(callback_url, json=payload)
response = await client.post(callback_target, json=payload)
response.raise_for_status()
except httpx.HTTPError as exc:
logger.error(
"Failed to deliver import analysis callback to %s: %s",
callback_url,
callback_target,
exc,
)
@@ -345,16 +329,13 @@ async def notify_import_analysis_callback(
async def process_import_analysis_job(
request: DataImportAnalysisJobRequest,
client: httpx.AsyncClient,
*,
api_key: str | None = None,
) -> None:
try:
payload = await dispatch_import_analysis_job(
request,
client,
api_key=api_key,
)
payload = await dispatch_import_analysis_job(request, client)
except ProviderAPICallError as exc:
print(
f"[ImportAnalysis] LLM call failed for {request.import_record_id}: {exc}"
)
payload = {
"import_record_id": request.import_record_id,
"status": "failed",
@@ -365,6 +346,9 @@ async def process_import_analysis_job(
"Unexpected failure while processing import analysis job %s",
request.import_record_id,
)
print(
f"[ImportAnalysis] Unexpected error for {request.import_record_id}: {exc}"
)
payload = {
"import_record_id": request.import_record_id,
"status": "failed",