导入分析接口使用项目chat接口
This commit is contained in:
@ -57,7 +57,7 @@ def create_app() -> FastAPI:
|
||||
"/v1/import/analyze",
|
||||
response_model=DataImportAnalysisJobAck,
|
||||
summary="Schedule async import analysis and notify via callback",
|
||||
status_code=200,
|
||||
status_code=202,
|
||||
)
|
||||
async def analyze_import_data(
|
||||
payload: DataImportAnalysisJobRequest,
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import csv
|
||||
import json
|
||||
import logging
|
||||
@ -11,21 +10,26 @@ from pathlib import Path
|
||||
from typing import Any, Dict, List, Sequence, Tuple
|
||||
|
||||
import httpx
|
||||
from pydantic import ValidationError
|
||||
|
||||
from app.exceptions import ProviderAPICallError
|
||||
from app.models import (
|
||||
DataImportAnalysisJobRequest,
|
||||
DataImportAnalysisRequest,
|
||||
LLMChoice,
|
||||
LLMMessage,
|
||||
LLMProvider,
|
||||
LLMResponse,
|
||||
LLMRole,
|
||||
)
|
||||
from app.settings import DEFAULT_IMPORT_MODEL, get_supported_import_models
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
OPENAI_RESPONSES_URL = "https://api.openai.com/v1/responses"
|
||||
IMPORT_GATEWAY_BASE_URL = os.getenv(
|
||||
"IMPORT_GATEWAY_BASE_URL", "http://localhost:8000"
|
||||
)
|
||||
|
||||
SUPPORTED_IMPORT_MODELS = get_supported_import_models()
|
||||
|
||||
|
||||
def resolve_provider_from_model(llm_model: str) -> Tuple[LLMProvider, str]:
|
||||
@ -113,17 +117,32 @@ def load_import_template() -> str:
|
||||
return template_path.read_text(encoding="utf-8").strip()
|
||||
|
||||
|
||||
def derive_headers(rows: Sequence[Any], provided_headers: Sequence[str] | None) -> List[str]:
|
||||
def derive_headers(
|
||||
rows: Sequence[Any], provided_headers: Sequence[str] | None
|
||||
) -> List[str]:
|
||||
if provided_headers:
|
||||
return list(provided_headers)
|
||||
return [str(header) for header in provided_headers]
|
||||
|
||||
collected: List[str] = []
|
||||
list_lengths: List[int] = []
|
||||
|
||||
seen: List[str] = []
|
||||
for row in rows:
|
||||
if isinstance(row, dict):
|
||||
for key in row.keys():
|
||||
if key not in seen:
|
||||
seen.append(str(key))
|
||||
return seen
|
||||
key_str = str(key)
|
||||
if key_str not in collected:
|
||||
collected.append(key_str)
|
||||
elif isinstance(row, (list, tuple)):
|
||||
list_lengths.append(len(row))
|
||||
|
||||
if collected:
|
||||
return collected
|
||||
|
||||
if list_lengths:
|
||||
max_len = max(list_lengths)
|
||||
return [f"column_{idx + 1}" for idx in range(max_len)]
|
||||
|
||||
return ["column_1"]
|
||||
|
||||
|
||||
def _stringify_cell(value: Any) -> str:
|
||||
@ -137,14 +156,18 @@ def _stringify_cell(value: Any) -> str:
|
||||
return str(value)
|
||||
|
||||
|
||||
def rows_to_csv_bytes(rows: Sequence[Any], headers: List[str]) -> bytes:
|
||||
def rows_to_csv_text(
|
||||
rows: Sequence[Any], headers: Sequence[str], *, max_rows: int = 50
|
||||
) -> str:
|
||||
buffer = StringIO()
|
||||
writer = csv.writer(buffer)
|
||||
|
||||
if headers:
|
||||
writer.writerow(headers)
|
||||
|
||||
for row in rows:
|
||||
for idx, row in enumerate(rows):
|
||||
if max_rows and idx >= max_rows:
|
||||
break
|
||||
if isinstance(row, dict):
|
||||
writer.writerow([_stringify_cell(row.get(header)) for header in headers])
|
||||
elif isinstance(row, (list, tuple)):
|
||||
@ -152,192 +175,153 @@ def rows_to_csv_bytes(rows: Sequence[Any], headers: List[str]) -> bytes:
|
||||
else:
|
||||
writer.writerow([_stringify_cell(row)])
|
||||
|
||||
return buffer.getvalue().encode("utf-8")
|
||||
return buffer.getvalue().strip()
|
||||
|
||||
|
||||
def build_file_part(filename: str, payload: bytes, mime_type: str) -> Dict[str, Any]:
|
||||
encoded = base64.b64encode(payload).decode("ascii")
|
||||
return {
|
||||
"type": "input_file",
|
||||
"input_file": {
|
||||
"file_data": {
|
||||
"filename": filename,
|
||||
"file_name": filename,
|
||||
"mime_type": mime_type,
|
||||
"b64_json": encoded,
|
||||
}
|
||||
},
|
||||
}
|
||||
def format_table_schema(schema: Any) -> str:
|
||||
if schema is None:
|
||||
return ""
|
||||
if isinstance(schema, str):
|
||||
return schema.strip()
|
||||
try:
|
||||
return json.dumps(schema, ensure_ascii=False, indent=2)
|
||||
except (TypeError, ValueError):
|
||||
return str(schema)
|
||||
|
||||
|
||||
def build_schema_part(
|
||||
request: DataImportAnalysisJobRequest, headers: List[str]
|
||||
) -> Dict[str, Any] | None:
|
||||
if request.table_schema is not None:
|
||||
if isinstance(request.table_schema, str):
|
||||
schema_bytes = request.table_schema.encode("utf-8")
|
||||
return build_file_part("table_schema.txt", schema_bytes, "text/plain")
|
||||
|
||||
try:
|
||||
schema_serialised = json.dumps(
|
||||
request.table_schema, ensure_ascii=False, indent=2
|
||||
)
|
||||
except (TypeError, ValueError) as exc:
|
||||
logger.warning("Failed to serialise table_schema for %s: %s", request.import_record_id, exc)
|
||||
schema_serialised = str(request.table_schema)
|
||||
|
||||
return build_file_part(
|
||||
"table_schema.json",
|
||||
schema_serialised.encode("utf-8"),
|
||||
"application/json",
|
||||
)
|
||||
|
||||
if headers:
|
||||
headers_payload = json.dumps({"headers": headers}, ensure_ascii=False, indent=2)
|
||||
return build_file_part(
|
||||
"table_headers.json",
|
||||
headers_payload.encode("utf-8"),
|
||||
"application/json",
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def build_openai_input_payload(
|
||||
def build_analysis_request(
|
||||
request: DataImportAnalysisJobRequest,
|
||||
) -> Dict[str, Any]:
|
||||
) -> DataImportAnalysisRequest:
|
||||
headers = derive_headers(request.rows, request.headers)
|
||||
csv_bytes = rows_to_csv_bytes(request.rows, headers)
|
||||
csv_part = build_file_part("sample_rows.csv", csv_bytes, "text/csv")
|
||||
schema_part = build_schema_part(request, headers)
|
||||
|
||||
prompt = load_import_template()
|
||||
|
||||
context_lines = [
|
||||
f"导入记录ID: {request.import_record_id}",
|
||||
f"样本数据行数: {len(request.rows)}",
|
||||
"请参考附件 `sample_rows.csv` 获取原始样本数据。",
|
||||
]
|
||||
|
||||
if schema_part:
|
||||
context_lines.append("附加结构信息来自第二个附件,请结合使用。")
|
||||
if request.raw_csv:
|
||||
csv_text = request.raw_csv.strip()
|
||||
else:
|
||||
context_lines.append("未提供表头或Schema,请依据数据自行推断字段信息。")
|
||||
csv_text = rows_to_csv_text(request.rows, headers)
|
||||
|
||||
user_content = [
|
||||
{"type": "input_text", "text": "\n".join(context_lines)},
|
||||
csv_part,
|
||||
]
|
||||
sections: List[str] = []
|
||||
if csv_text:
|
||||
sections.append("CSV样本预览:\n" + csv_text)
|
||||
schema_text = format_table_schema(request.table_schema)
|
||||
if schema_text:
|
||||
sections.append("附加结构信息:\n" + schema_text)
|
||||
|
||||
if schema_part:
|
||||
user_content.append(schema_part)
|
||||
example_data = "\n\n".join(sections) if sections else "未提供样本数据。"
|
||||
|
||||
payload: Dict[str, Any] = {
|
||||
"model": request.llm_model,
|
||||
"input": [
|
||||
{"role": "system", "content": [{"type": "input_text", "text": prompt}]},
|
||||
{"role": "user", "content": user_content},
|
||||
],
|
||||
}
|
||||
max_length = 30_000
|
||||
if len(example_data) > max_length:
|
||||
example_data = example_data[: max_length - 3] + "..."
|
||||
|
||||
if request.temperature is not None:
|
||||
payload["temperature"] = request.temperature
|
||||
if request.max_output_tokens is not None:
|
||||
payload["max_output_tokens"] = request.max_output_tokens
|
||||
|
||||
return payload
|
||||
|
||||
|
||||
def parse_openai_responses_payload(
|
||||
data: Dict[str, Any], fallback_model: str
|
||||
) -> LLMResponse:
|
||||
output_blocks = data.get("output", [])
|
||||
choices: List[LLMChoice] = []
|
||||
|
||||
for idx, block in enumerate(output_blocks):
|
||||
if block.get("type") != "message":
|
||||
continue
|
||||
content_items = block.get("content", [])
|
||||
text_fragments: List[str] = []
|
||||
for item in content_items:
|
||||
if item.get("type") == "output_text":
|
||||
text_fragments.append(item.get("text", ""))
|
||||
|
||||
if not text_fragments and data.get("output_text"):
|
||||
text_fragments.append(data.get("output_text", ""))
|
||||
|
||||
message = LLMMessage(role=LLMRole.ASSISTANT, content="\n".join(text_fragments))
|
||||
choices.append(LLMChoice(index=idx, message=message))
|
||||
|
||||
if not choices and data.get("output_text"):
|
||||
message = LLMMessage(role=LLMRole.ASSISTANT, content=data.get("output_text", ""))
|
||||
choices.append(LLMChoice(index=0, message=message))
|
||||
|
||||
return LLMResponse(
|
||||
provider=LLMProvider.OPENAI,
|
||||
model=data.get("model", fallback_model),
|
||||
choices=choices,
|
||||
raw=data,
|
||||
return DataImportAnalysisRequest(
|
||||
import_record_id=request.import_record_id,
|
||||
example_data=example_data,
|
||||
table_headers=headers,
|
||||
llm_model=request.llm_model or DEFAULT_IMPORT_MODEL,
|
||||
)
|
||||
|
||||
|
||||
async def call_openai_import_analysis(
|
||||
request: DataImportAnalysisJobRequest,
|
||||
client: httpx.AsyncClient,
|
||||
*,
|
||||
api_key: str | None = None,
|
||||
) -> LLMResponse:
|
||||
openai_api_key = api_key or os.getenv("OPENAI_API_KEY")
|
||||
if not openai_api_key:
|
||||
raise ProviderAPICallError("OPENAI_API_KEY must be set to process import analysis.")
|
||||
def build_chat_payload(request: DataImportAnalysisJobRequest) -> Dict[str, Any]:
|
||||
llm_input = request.llm_model or DEFAULT_IMPORT_MODEL
|
||||
provider, model_name = resolve_provider_from_model(llm_input)
|
||||
normalized_model = f"{provider.value}:{model_name}"
|
||||
|
||||
payload = build_openai_input_payload(request)
|
||||
headers = {
|
||||
"Authorization": f"Bearer {openai_api_key}",
|
||||
"Content-Type": "application/json",
|
||||
if SUPPORTED_IMPORT_MODELS and normalized_model not in SUPPORTED_IMPORT_MODELS:
|
||||
raise ProviderAPICallError(
|
||||
"Model '{model}' is not allowed. Allowed models: {allowed}".format(
|
||||
model=normalized_model,
|
||||
allowed=", ".join(sorted(SUPPORTED_IMPORT_MODELS)),
|
||||
)
|
||||
)
|
||||
|
||||
analysis_request = build_analysis_request(request)
|
||||
|
||||
messages = build_import_messages(analysis_request)
|
||||
|
||||
payload: Dict[str, Any] = {
|
||||
"provider": provider.value,
|
||||
"model": model_name,
|
||||
"messages": [message.model_dump() for message in messages],
|
||||
"temperature": request.temperature if request.temperature is not None else 0.2,
|
||||
}
|
||||
|
||||
response = await client.post(OPENAI_RESPONSES_URL, json=payload, headers=headers)
|
||||
try:
|
||||
response.raise_for_status()
|
||||
except httpx.HTTPError as exc:
|
||||
raise ProviderAPICallError(f"OpenAI response API call failed: {exc}") from exc
|
||||
if request.max_output_tokens is not None:
|
||||
payload["max_tokens"] = request.max_output_tokens
|
||||
|
||||
data: Dict[str, Any] = response.json()
|
||||
return parse_openai_responses_payload(data, request.llm_model)
|
||||
return payload
|
||||
|
||||
|
||||
async def dispatch_import_analysis_job(
|
||||
request: DataImportAnalysisJobRequest,
|
||||
client: httpx.AsyncClient,
|
||||
*,
|
||||
api_key: str | None = None,
|
||||
) -> Dict[str, Any]:
|
||||
logger.info("Starting import analysis job %s", request.import_record_id)
|
||||
llm_response = await call_openai_import_analysis(request, client, api_key=api_key)
|
||||
|
||||
result = {
|
||||
payload = build_chat_payload(request)
|
||||
url = f"{IMPORT_GATEWAY_BASE_URL.rstrip('/')}/v1/chat/completions"
|
||||
|
||||
print(
|
||||
f"[ImportAnalysis] Dispatching import {request.import_record_id} to {url}: "
|
||||
f"{json.dumps(payload, ensure_ascii=False)}"
|
||||
)
|
||||
|
||||
try:
|
||||
response = await client.post(url, json=payload)
|
||||
response.raise_for_status()
|
||||
except httpx.HTTPStatusError as exc:
|
||||
body_preview = ""
|
||||
if exc.response is not None:
|
||||
body_preview = exc.response.text[:400]
|
||||
raise ProviderAPICallError(
|
||||
f"Failed to invoke chat completions endpoint: {exc}. Response body: {body_preview}"
|
||||
) from exc
|
||||
except httpx.HTTPError as exc:
|
||||
raise ProviderAPICallError(
|
||||
f"HTTP call to chat completions endpoint failed: {exc}"
|
||||
) from exc
|
||||
|
||||
try:
|
||||
response_data = response.json()
|
||||
except ValueError as exc:
|
||||
raise ProviderAPICallError("Chat completions endpoint returned invalid JSON") from exc
|
||||
|
||||
print(
|
||||
f"[ImportAnalysis] LLM HTTP status for {request.import_record_id}: "
|
||||
f"{response.status_code}"
|
||||
)
|
||||
print(
|
||||
f"[ImportAnalysis] LLM response for {request.import_record_id}: "
|
||||
f"{json.dumps(response_data, ensure_ascii=False)}"
|
||||
)
|
||||
|
||||
try:
|
||||
llm_response = LLMResponse.model_validate(response_data)
|
||||
except ValidationError as exc:
|
||||
raise ProviderAPICallError(
|
||||
"Chat completions endpoint returned unexpected schema"
|
||||
) from exc
|
||||
|
||||
logger.info("Completed import analysis job %s", request.import_record_id)
|
||||
return {
|
||||
"import_record_id": request.import_record_id,
|
||||
"status": "succeeded",
|
||||
"llm_response": llm_response.model_dump(),
|
||||
}
|
||||
|
||||
logger.info("Completed import analysis job %s", request.import_record_id)
|
||||
return result
|
||||
|
||||
|
||||
async def notify_import_analysis_callback(
|
||||
callback_url: str,
|
||||
payload: Dict[str, Any],
|
||||
client: httpx.AsyncClient,
|
||||
) -> None:
|
||||
callback_target = str(callback_url)
|
||||
|
||||
try:
|
||||
response = await client.post(callback_url, json=payload)
|
||||
response = await client.post(callback_target, json=payload)
|
||||
response.raise_for_status()
|
||||
except httpx.HTTPError as exc:
|
||||
logger.error(
|
||||
"Failed to deliver import analysis callback to %s: %s",
|
||||
callback_url,
|
||||
callback_target,
|
||||
exc,
|
||||
)
|
||||
|
||||
@ -345,16 +329,13 @@ async def notify_import_analysis_callback(
|
||||
async def process_import_analysis_job(
|
||||
request: DataImportAnalysisJobRequest,
|
||||
client: httpx.AsyncClient,
|
||||
*,
|
||||
api_key: str | None = None,
|
||||
) -> None:
|
||||
try:
|
||||
payload = await dispatch_import_analysis_job(
|
||||
request,
|
||||
client,
|
||||
api_key=api_key,
|
||||
)
|
||||
payload = await dispatch_import_analysis_job(request, client)
|
||||
except ProviderAPICallError as exc:
|
||||
print(
|
||||
f"[ImportAnalysis] LLM call failed for {request.import_record_id}: {exc}"
|
||||
)
|
||||
payload = {
|
||||
"import_record_id": request.import_record_id,
|
||||
"status": "failed",
|
||||
@ -365,6 +346,9 @@ async def process_import_analysis_job(
|
||||
"Unexpected failure while processing import analysis job %s",
|
||||
request.import_record_id,
|
||||
)
|
||||
print(
|
||||
f"[ImportAnalysis] Unexpected error for {request.import_record_id}: {exc}"
|
||||
)
|
||||
payload = {
|
||||
"import_record_id": request.import_record_id,
|
||||
"status": "failed",
|
||||
|
||||
Reference in New Issue
Block a user