导入分析接口使用项目chat接口
This commit is contained in:
@@ -57,7 +57,7 @@ def create_app() -> FastAPI:
|
|||||||
"/v1/import/analyze",
|
"/v1/import/analyze",
|
||||||
response_model=DataImportAnalysisJobAck,
|
response_model=DataImportAnalysisJobAck,
|
||||||
summary="Schedule async import analysis and notify via callback",
|
summary="Schedule async import analysis and notify via callback",
|
||||||
status_code=200,
|
status_code=202,
|
||||||
)
|
)
|
||||||
async def analyze_import_data(
|
async def analyze_import_data(
|
||||||
payload: DataImportAnalysisJobRequest,
|
payload: DataImportAnalysisJobRequest,
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import base64
|
|
||||||
import csv
|
import csv
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
@@ -11,21 +10,26 @@ from pathlib import Path
|
|||||||
from typing import Any, Dict, List, Sequence, Tuple
|
from typing import Any, Dict, List, Sequence, Tuple
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
from pydantic import ValidationError
|
||||||
|
|
||||||
from app.exceptions import ProviderAPICallError
|
from app.exceptions import ProviderAPICallError
|
||||||
from app.models import (
|
from app.models import (
|
||||||
DataImportAnalysisJobRequest,
|
DataImportAnalysisJobRequest,
|
||||||
DataImportAnalysisRequest,
|
DataImportAnalysisRequest,
|
||||||
LLMChoice,
|
|
||||||
LLMMessage,
|
LLMMessage,
|
||||||
LLMProvider,
|
LLMProvider,
|
||||||
LLMResponse,
|
LLMResponse,
|
||||||
LLMRole,
|
LLMRole,
|
||||||
)
|
)
|
||||||
|
from app.settings import DEFAULT_IMPORT_MODEL, get_supported_import_models
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
OPENAI_RESPONSES_URL = "https://api.openai.com/v1/responses"
|
IMPORT_GATEWAY_BASE_URL = os.getenv(
|
||||||
|
"IMPORT_GATEWAY_BASE_URL", "http://localhost:8000"
|
||||||
|
)
|
||||||
|
|
||||||
|
SUPPORTED_IMPORT_MODELS = get_supported_import_models()
|
||||||
|
|
||||||
|
|
||||||
def resolve_provider_from_model(llm_model: str) -> Tuple[LLMProvider, str]:
|
def resolve_provider_from_model(llm_model: str) -> Tuple[LLMProvider, str]:
|
||||||
@@ -113,17 +117,32 @@ def load_import_template() -> str:
|
|||||||
return template_path.read_text(encoding="utf-8").strip()
|
return template_path.read_text(encoding="utf-8").strip()
|
||||||
|
|
||||||
|
|
||||||
def derive_headers(rows: Sequence[Any], provided_headers: Sequence[str] | None) -> List[str]:
|
def derive_headers(
|
||||||
|
rows: Sequence[Any], provided_headers: Sequence[str] | None
|
||||||
|
) -> List[str]:
|
||||||
if provided_headers:
|
if provided_headers:
|
||||||
return list(provided_headers)
|
return [str(header) for header in provided_headers]
|
||||||
|
|
||||||
|
collected: List[str] = []
|
||||||
|
list_lengths: List[int] = []
|
||||||
|
|
||||||
seen: List[str] = []
|
|
||||||
for row in rows:
|
for row in rows:
|
||||||
if isinstance(row, dict):
|
if isinstance(row, dict):
|
||||||
for key in row.keys():
|
for key in row.keys():
|
||||||
if key not in seen:
|
key_str = str(key)
|
||||||
seen.append(str(key))
|
if key_str not in collected:
|
||||||
return seen
|
collected.append(key_str)
|
||||||
|
elif isinstance(row, (list, tuple)):
|
||||||
|
list_lengths.append(len(row))
|
||||||
|
|
||||||
|
if collected:
|
||||||
|
return collected
|
||||||
|
|
||||||
|
if list_lengths:
|
||||||
|
max_len = max(list_lengths)
|
||||||
|
return [f"column_{idx + 1}" for idx in range(max_len)]
|
||||||
|
|
||||||
|
return ["column_1"]
|
||||||
|
|
||||||
|
|
||||||
def _stringify_cell(value: Any) -> str:
|
def _stringify_cell(value: Any) -> str:
|
||||||
@@ -137,14 +156,18 @@ def _stringify_cell(value: Any) -> str:
|
|||||||
return str(value)
|
return str(value)
|
||||||
|
|
||||||
|
|
||||||
def rows_to_csv_bytes(rows: Sequence[Any], headers: List[str]) -> bytes:
|
def rows_to_csv_text(
|
||||||
|
rows: Sequence[Any], headers: Sequence[str], *, max_rows: int = 50
|
||||||
|
) -> str:
|
||||||
buffer = StringIO()
|
buffer = StringIO()
|
||||||
writer = csv.writer(buffer)
|
writer = csv.writer(buffer)
|
||||||
|
|
||||||
if headers:
|
if headers:
|
||||||
writer.writerow(headers)
|
writer.writerow(headers)
|
||||||
|
|
||||||
for row in rows:
|
for idx, row in enumerate(rows):
|
||||||
|
if max_rows and idx >= max_rows:
|
||||||
|
break
|
||||||
if isinstance(row, dict):
|
if isinstance(row, dict):
|
||||||
writer.writerow([_stringify_cell(row.get(header)) for header in headers])
|
writer.writerow([_stringify_cell(row.get(header)) for header in headers])
|
||||||
elif isinstance(row, (list, tuple)):
|
elif isinstance(row, (list, tuple)):
|
||||||
@@ -152,192 +175,153 @@ def rows_to_csv_bytes(rows: Sequence[Any], headers: List[str]) -> bytes:
|
|||||||
else:
|
else:
|
||||||
writer.writerow([_stringify_cell(row)])
|
writer.writerow([_stringify_cell(row)])
|
||||||
|
|
||||||
return buffer.getvalue().encode("utf-8")
|
return buffer.getvalue().strip()
|
||||||
|
|
||||||
|
|
||||||
def build_file_part(filename: str, payload: bytes, mime_type: str) -> Dict[str, Any]:
|
def format_table_schema(schema: Any) -> str:
|
||||||
encoded = base64.b64encode(payload).decode("ascii")
|
if schema is None:
|
||||||
return {
|
return ""
|
||||||
"type": "input_file",
|
if isinstance(schema, str):
|
||||||
"input_file": {
|
return schema.strip()
|
||||||
"file_data": {
|
try:
|
||||||
"filename": filename,
|
return json.dumps(schema, ensure_ascii=False, indent=2)
|
||||||
"file_name": filename,
|
except (TypeError, ValueError):
|
||||||
"mime_type": mime_type,
|
return str(schema)
|
||||||
"b64_json": encoded,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def build_schema_part(
|
def build_analysis_request(
|
||||||
request: DataImportAnalysisJobRequest, headers: List[str]
|
|
||||||
) -> Dict[str, Any] | None:
|
|
||||||
if request.table_schema is not None:
|
|
||||||
if isinstance(request.table_schema, str):
|
|
||||||
schema_bytes = request.table_schema.encode("utf-8")
|
|
||||||
return build_file_part("table_schema.txt", schema_bytes, "text/plain")
|
|
||||||
|
|
||||||
try:
|
|
||||||
schema_serialised = json.dumps(
|
|
||||||
request.table_schema, ensure_ascii=False, indent=2
|
|
||||||
)
|
|
||||||
except (TypeError, ValueError) as exc:
|
|
||||||
logger.warning("Failed to serialise table_schema for %s: %s", request.import_record_id, exc)
|
|
||||||
schema_serialised = str(request.table_schema)
|
|
||||||
|
|
||||||
return build_file_part(
|
|
||||||
"table_schema.json",
|
|
||||||
schema_serialised.encode("utf-8"),
|
|
||||||
"application/json",
|
|
||||||
)
|
|
||||||
|
|
||||||
if headers:
|
|
||||||
headers_payload = json.dumps({"headers": headers}, ensure_ascii=False, indent=2)
|
|
||||||
return build_file_part(
|
|
||||||
"table_headers.json",
|
|
||||||
headers_payload.encode("utf-8"),
|
|
||||||
"application/json",
|
|
||||||
)
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def build_openai_input_payload(
|
|
||||||
request: DataImportAnalysisJobRequest,
|
request: DataImportAnalysisJobRequest,
|
||||||
) -> Dict[str, Any]:
|
) -> DataImportAnalysisRequest:
|
||||||
headers = derive_headers(request.rows, request.headers)
|
headers = derive_headers(request.rows, request.headers)
|
||||||
csv_bytes = rows_to_csv_bytes(request.rows, headers)
|
|
||||||
csv_part = build_file_part("sample_rows.csv", csv_bytes, "text/csv")
|
|
||||||
schema_part = build_schema_part(request, headers)
|
|
||||||
|
|
||||||
prompt = load_import_template()
|
if request.raw_csv:
|
||||||
|
csv_text = request.raw_csv.strip()
|
||||||
context_lines = [
|
|
||||||
f"导入记录ID: {request.import_record_id}",
|
|
||||||
f"样本数据行数: {len(request.rows)}",
|
|
||||||
"请参考附件 `sample_rows.csv` 获取原始样本数据。",
|
|
||||||
]
|
|
||||||
|
|
||||||
if schema_part:
|
|
||||||
context_lines.append("附加结构信息来自第二个附件,请结合使用。")
|
|
||||||
else:
|
else:
|
||||||
context_lines.append("未提供表头或Schema,请依据数据自行推断字段信息。")
|
csv_text = rows_to_csv_text(request.rows, headers)
|
||||||
|
|
||||||
user_content = [
|
sections: List[str] = []
|
||||||
{"type": "input_text", "text": "\n".join(context_lines)},
|
if csv_text:
|
||||||
csv_part,
|
sections.append("CSV样本预览:\n" + csv_text)
|
||||||
]
|
schema_text = format_table_schema(request.table_schema)
|
||||||
|
if schema_text:
|
||||||
|
sections.append("附加结构信息:\n" + schema_text)
|
||||||
|
|
||||||
if schema_part:
|
example_data = "\n\n".join(sections) if sections else "未提供样本数据。"
|
||||||
user_content.append(schema_part)
|
|
||||||
|
|
||||||
payload: Dict[str, Any] = {
|
max_length = 30_000
|
||||||
"model": request.llm_model,
|
if len(example_data) > max_length:
|
||||||
"input": [
|
example_data = example_data[: max_length - 3] + "..."
|
||||||
{"role": "system", "content": [{"type": "input_text", "text": prompt}]},
|
|
||||||
{"role": "user", "content": user_content},
|
|
||||||
],
|
|
||||||
}
|
|
||||||
|
|
||||||
if request.temperature is not None:
|
return DataImportAnalysisRequest(
|
||||||
payload["temperature"] = request.temperature
|
import_record_id=request.import_record_id,
|
||||||
if request.max_output_tokens is not None:
|
example_data=example_data,
|
||||||
payload["max_output_tokens"] = request.max_output_tokens
|
table_headers=headers,
|
||||||
|
llm_model=request.llm_model or DEFAULT_IMPORT_MODEL,
|
||||||
return payload
|
|
||||||
|
|
||||||
|
|
||||||
def parse_openai_responses_payload(
|
|
||||||
data: Dict[str, Any], fallback_model: str
|
|
||||||
) -> LLMResponse:
|
|
||||||
output_blocks = data.get("output", [])
|
|
||||||
choices: List[LLMChoice] = []
|
|
||||||
|
|
||||||
for idx, block in enumerate(output_blocks):
|
|
||||||
if block.get("type") != "message":
|
|
||||||
continue
|
|
||||||
content_items = block.get("content", [])
|
|
||||||
text_fragments: List[str] = []
|
|
||||||
for item in content_items:
|
|
||||||
if item.get("type") == "output_text":
|
|
||||||
text_fragments.append(item.get("text", ""))
|
|
||||||
|
|
||||||
if not text_fragments and data.get("output_text"):
|
|
||||||
text_fragments.append(data.get("output_text", ""))
|
|
||||||
|
|
||||||
message = LLMMessage(role=LLMRole.ASSISTANT, content="\n".join(text_fragments))
|
|
||||||
choices.append(LLMChoice(index=idx, message=message))
|
|
||||||
|
|
||||||
if not choices and data.get("output_text"):
|
|
||||||
message = LLMMessage(role=LLMRole.ASSISTANT, content=data.get("output_text", ""))
|
|
||||||
choices.append(LLMChoice(index=0, message=message))
|
|
||||||
|
|
||||||
return LLMResponse(
|
|
||||||
provider=LLMProvider.OPENAI,
|
|
||||||
model=data.get("model", fallback_model),
|
|
||||||
choices=choices,
|
|
||||||
raw=data,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def call_openai_import_analysis(
|
def build_chat_payload(request: DataImportAnalysisJobRequest) -> Dict[str, Any]:
|
||||||
request: DataImportAnalysisJobRequest,
|
llm_input = request.llm_model or DEFAULT_IMPORT_MODEL
|
||||||
client: httpx.AsyncClient,
|
provider, model_name = resolve_provider_from_model(llm_input)
|
||||||
*,
|
normalized_model = f"{provider.value}:{model_name}"
|
||||||
api_key: str | None = None,
|
|
||||||
) -> LLMResponse:
|
|
||||||
openai_api_key = api_key or os.getenv("OPENAI_API_KEY")
|
|
||||||
if not openai_api_key:
|
|
||||||
raise ProviderAPICallError("OPENAI_API_KEY must be set to process import analysis.")
|
|
||||||
|
|
||||||
payload = build_openai_input_payload(request)
|
if SUPPORTED_IMPORT_MODELS and normalized_model not in SUPPORTED_IMPORT_MODELS:
|
||||||
headers = {
|
raise ProviderAPICallError(
|
||||||
"Authorization": f"Bearer {openai_api_key}",
|
"Model '{model}' is not allowed. Allowed models: {allowed}".format(
|
||||||
"Content-Type": "application/json",
|
model=normalized_model,
|
||||||
|
allowed=", ".join(sorted(SUPPORTED_IMPORT_MODELS)),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
analysis_request = build_analysis_request(request)
|
||||||
|
|
||||||
|
messages = build_import_messages(analysis_request)
|
||||||
|
|
||||||
|
payload: Dict[str, Any] = {
|
||||||
|
"provider": provider.value,
|
||||||
|
"model": model_name,
|
||||||
|
"messages": [message.model_dump() for message in messages],
|
||||||
|
"temperature": request.temperature if request.temperature is not None else 0.2,
|
||||||
}
|
}
|
||||||
|
|
||||||
response = await client.post(OPENAI_RESPONSES_URL, json=payload, headers=headers)
|
if request.max_output_tokens is not None:
|
||||||
try:
|
payload["max_tokens"] = request.max_output_tokens
|
||||||
response.raise_for_status()
|
|
||||||
except httpx.HTTPError as exc:
|
|
||||||
raise ProviderAPICallError(f"OpenAI response API call failed: {exc}") from exc
|
|
||||||
|
|
||||||
data: Dict[str, Any] = response.json()
|
return payload
|
||||||
return parse_openai_responses_payload(data, request.llm_model)
|
|
||||||
|
|
||||||
|
|
||||||
async def dispatch_import_analysis_job(
|
async def dispatch_import_analysis_job(
|
||||||
request: DataImportAnalysisJobRequest,
|
request: DataImportAnalysisJobRequest,
|
||||||
client: httpx.AsyncClient,
|
client: httpx.AsyncClient,
|
||||||
*,
|
|
||||||
api_key: str | None = None,
|
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
logger.info("Starting import analysis job %s", request.import_record_id)
|
logger.info("Starting import analysis job %s", request.import_record_id)
|
||||||
llm_response = await call_openai_import_analysis(request, client, api_key=api_key)
|
|
||||||
|
|
||||||
result = {
|
payload = build_chat_payload(request)
|
||||||
|
url = f"{IMPORT_GATEWAY_BASE_URL.rstrip('/')}/v1/chat/completions"
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"[ImportAnalysis] Dispatching import {request.import_record_id} to {url}: "
|
||||||
|
f"{json.dumps(payload, ensure_ascii=False)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = await client.post(url, json=payload)
|
||||||
|
response.raise_for_status()
|
||||||
|
except httpx.HTTPStatusError as exc:
|
||||||
|
body_preview = ""
|
||||||
|
if exc.response is not None:
|
||||||
|
body_preview = exc.response.text[:400]
|
||||||
|
raise ProviderAPICallError(
|
||||||
|
f"Failed to invoke chat completions endpoint: {exc}. Response body: {body_preview}"
|
||||||
|
) from exc
|
||||||
|
except httpx.HTTPError as exc:
|
||||||
|
raise ProviderAPICallError(
|
||||||
|
f"HTTP call to chat completions endpoint failed: {exc}"
|
||||||
|
) from exc
|
||||||
|
|
||||||
|
try:
|
||||||
|
response_data = response.json()
|
||||||
|
except ValueError as exc:
|
||||||
|
raise ProviderAPICallError("Chat completions endpoint returned invalid JSON") from exc
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"[ImportAnalysis] LLM HTTP status for {request.import_record_id}: "
|
||||||
|
f"{response.status_code}"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f"[ImportAnalysis] LLM response for {request.import_record_id}: "
|
||||||
|
f"{json.dumps(response_data, ensure_ascii=False)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
llm_response = LLMResponse.model_validate(response_data)
|
||||||
|
except ValidationError as exc:
|
||||||
|
raise ProviderAPICallError(
|
||||||
|
"Chat completions endpoint returned unexpected schema"
|
||||||
|
) from exc
|
||||||
|
|
||||||
|
logger.info("Completed import analysis job %s", request.import_record_id)
|
||||||
|
return {
|
||||||
"import_record_id": request.import_record_id,
|
"import_record_id": request.import_record_id,
|
||||||
"status": "succeeded",
|
"status": "succeeded",
|
||||||
"llm_response": llm_response.model_dump(),
|
"llm_response": llm_response.model_dump(),
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.info("Completed import analysis job %s", request.import_record_id)
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
async def notify_import_analysis_callback(
|
async def notify_import_analysis_callback(
|
||||||
callback_url: str,
|
callback_url: str,
|
||||||
payload: Dict[str, Any],
|
payload: Dict[str, Any],
|
||||||
client: httpx.AsyncClient,
|
client: httpx.AsyncClient,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
callback_target = str(callback_url)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = await client.post(callback_url, json=payload)
|
response = await client.post(callback_target, json=payload)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
except httpx.HTTPError as exc:
|
except httpx.HTTPError as exc:
|
||||||
logger.error(
|
logger.error(
|
||||||
"Failed to deliver import analysis callback to %s: %s",
|
"Failed to deliver import analysis callback to %s: %s",
|
||||||
callback_url,
|
callback_target,
|
||||||
exc,
|
exc,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -345,16 +329,13 @@ async def notify_import_analysis_callback(
|
|||||||
async def process_import_analysis_job(
|
async def process_import_analysis_job(
|
||||||
request: DataImportAnalysisJobRequest,
|
request: DataImportAnalysisJobRequest,
|
||||||
client: httpx.AsyncClient,
|
client: httpx.AsyncClient,
|
||||||
*,
|
|
||||||
api_key: str | None = None,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
try:
|
try:
|
||||||
payload = await dispatch_import_analysis_job(
|
payload = await dispatch_import_analysis_job(request, client)
|
||||||
request,
|
|
||||||
client,
|
|
||||||
api_key=api_key,
|
|
||||||
)
|
|
||||||
except ProviderAPICallError as exc:
|
except ProviderAPICallError as exc:
|
||||||
|
print(
|
||||||
|
f"[ImportAnalysis] LLM call failed for {request.import_record_id}: {exc}"
|
||||||
|
)
|
||||||
payload = {
|
payload = {
|
||||||
"import_record_id": request.import_record_id,
|
"import_record_id": request.import_record_id,
|
||||||
"status": "failed",
|
"status": "failed",
|
||||||
@@ -365,6 +346,9 @@ async def process_import_analysis_job(
|
|||||||
"Unexpected failure while processing import analysis job %s",
|
"Unexpected failure while processing import analysis job %s",
|
||||||
request.import_record_id,
|
request.import_record_id,
|
||||||
)
|
)
|
||||||
|
print(
|
||||||
|
f"[ImportAnalysis] Unexpected error for {request.import_record_id}: {exc}"
|
||||||
|
)
|
||||||
payload = {
|
payload = {
|
||||||
"import_record_id": request.import_record_id,
|
"import_record_id": request.import_record_id,
|
||||||
"status": "failed",
|
"status": "failed",
|
||||||
|
|||||||
Reference in New Issue
Block a user