diff --git a/app/main.py b/app/main.py
index 03cbcb5..0d8fed0 100644
--- a/app/main.py
+++ b/app/main.py
@@ -57,7 +57,7 @@ def create_app() -> FastAPI:
         "/v1/import/analyze",
         response_model=DataImportAnalysisJobAck,
         summary="Schedule async import analysis and notify via callback",
-        status_code=200,
+        status_code=202,
     )
     async def analyze_import_data(
         payload: DataImportAnalysisJobRequest,
diff --git a/app/services/import_analysis.py b/app/services/import_analysis.py
index e249066..aa86c1f 100644
--- a/app/services/import_analysis.py
+++ b/app/services/import_analysis.py
@@ -1,6 +1,5 @@
 from __future__ import annotations
 
-import base64
 import csv
 import json
 import logging
@@ -11,21 +10,26 @@ from pathlib import Path
 from typing import Any, Dict, List, Sequence, Tuple
 
 import httpx
+from pydantic import ValidationError
 
 from app.exceptions import ProviderAPICallError
 from app.models import (
     DataImportAnalysisJobRequest,
     DataImportAnalysisRequest,
-    LLMChoice,
     LLMMessage,
     LLMProvider,
     LLMResponse,
     LLMRole,
 )
+from app.settings import DEFAULT_IMPORT_MODEL, get_supported_import_models
 
 logger = logging.getLogger(__name__)
 
-OPENAI_RESPONSES_URL = "https://api.openai.com/v1/responses"
+IMPORT_GATEWAY_BASE_URL = os.getenv(
+    "IMPORT_GATEWAY_BASE_URL", "http://localhost:8000"
+)
+
+SUPPORTED_IMPORT_MODELS = get_supported_import_models()
 
 
 def resolve_provider_from_model(llm_model: str) -> Tuple[LLMProvider, str]:
@@ -113,17 +117,32 @@ def load_import_template() -> str:
     return template_path.read_text(encoding="utf-8").strip()
 
 
-def derive_headers(rows: Sequence[Any], provided_headers: Sequence[str] | None) -> List[str]:
+def derive_headers(
+    rows: Sequence[Any], provided_headers: Sequence[str] | None
+) -> List[str]:
     if provided_headers:
-        return list(provided_headers)
+        return [str(header) for header in provided_headers]
+
+    collected: List[str] = []
+    list_lengths: List[int] = []
 
-    seen: List[str] = []
     for row in rows:
         if isinstance(row, dict):
             for key in row.keys():
-                if key not in seen:
-                    seen.append(str(key))
-    return seen
+                key_str = str(key)
+                if key_str not in collected:
+                    collected.append(key_str)
+        elif isinstance(row, (list, tuple)):
+            list_lengths.append(len(row))
+
+    if collected:
+        return collected
+
+    if list_lengths:
+        max_len = max(list_lengths)
+        return [f"column_{idx + 1}" for idx in range(max_len)]
+
+    return ["column_1"]
 
 
 def _stringify_cell(value: Any) -> str:
@@ -137,14 +156,18 @@ def _stringify_cell(value: Any) -> str:
     return str(value)
 
 
-def rows_to_csv_bytes(rows: Sequence[Any], headers: List[str]) -> bytes:
+def rows_to_csv_text(
+    rows: Sequence[Any], headers: Sequence[str], *, max_rows: int = 50
+) -> str:
     buffer = StringIO()
     writer = csv.writer(buffer)
     if headers:
         writer.writerow(headers)
 
-    for row in rows:
+    for idx, row in enumerate(rows):
+        if max_rows and idx >= max_rows:
+            break
         if isinstance(row, dict):
             writer.writerow([_stringify_cell(row.get(header)) for header in headers])
         elif isinstance(row, (list, tuple)):
@@ -152,192 +175,153 @@ def rows_to_csv_bytes(rows: Sequence[Any], headers: List[str]) -> bytes:
         else:
             writer.writerow([_stringify_cell(row)])
 
-    return buffer.getvalue().encode("utf-8")
+    return buffer.getvalue().strip()
 
 
-def build_file_part(filename: str, payload: bytes, mime_type: str) -> Dict[str, Any]:
-    encoded = base64.b64encode(payload).decode("ascii")
-    return {
-        "type": "input_file",
-        "input_file": {
-            "file_data": {
-                "filename": filename,
-                "file_name": filename,
-                "mime_type": mime_type,
-                "b64_json": encoded,
-            }
-        },
-    }
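+# Render the client-supplied table_schema (None, a raw string, or any
+# JSON-serialisable object) as readable text for the prompt body.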
+def format_table_schema(schema: Any) -> str:
+    if schema is None:
+        return ""
+    if isinstance(schema, str):
+        return schema.strip()
+    try:
+        return json.dumps(schema, ensure_ascii=False, indent=2)
+    except (TypeError, ValueError):
+        return str(schema)
 
 
-def build_schema_part(
-    request: DataImportAnalysisJobRequest, headers: List[str]
-) -> Dict[str, Any] | None:
-    if request.table_schema is not None:
-        if isinstance(request.table_schema, str):
-            schema_bytes = request.table_schema.encode("utf-8")
-            return build_file_part("table_schema.txt", schema_bytes, "text/plain")
-
-        try:
-            schema_serialised = json.dumps(
-                request.table_schema, ensure_ascii=False, indent=2
-            )
-        except (TypeError, ValueError) as exc:
-            logger.warning("Failed to serialise table_schema for %s: %s", request.import_record_id, exc)
-            schema_serialised = str(request.table_schema)
-
-        return build_file_part(
-            "table_schema.json",
-            schema_serialised.encode("utf-8"),
-            "application/json",
-        )
-
-    if headers:
-        headers_payload = json.dumps({"headers": headers}, ensure_ascii=False, indent=2)
-        return build_file_part(
-            "table_headers.json",
-            headers_payload.encode("utf-8"),
-            "application/json",
-        )
-
-    return None
-
-
-def build_openai_input_payload(
+def build_analysis_request(
     request: DataImportAnalysisJobRequest,
-) -> Dict[str, Any]:
+) -> DataImportAnalysisRequest:
     headers = derive_headers(request.rows, request.headers)
-    csv_bytes = rows_to_csv_bytes(request.rows, headers)
-    csv_part = build_file_part("sample_rows.csv", csv_bytes, "text/csv")
-    schema_part = build_schema_part(request, headers)
-    prompt = load_import_template()
-
-    context_lines = [
-        f"Import record ID: {request.import_record_id}",
-        f"Number of sample rows: {len(request.rows)}",
-        "See the attached `sample_rows.csv` for the raw sample data.",
-    ]
-
-    if schema_part:
-        context_lines.append("Additional schema information comes from the second attachment; use both together.")
+    if request.raw_csv:
+        csv_text = request.raw_csv.strip()
     else:
-        context_lines.append("No headers or schema were provided; infer the fields from the data.")
+        csv_text = rows_to_csv_text(request.rows, headers)
 
-    user_content = [
-        {"type": "input_text", "text": "\n".join(context_lines)},
-        csv_part,
-    ]
+    sections: List[str] = []
+    if csv_text:
+        sections.append("CSV sample preview:\n" + csv_text)
+    schema_text = format_table_schema(request.table_schema)
+    if schema_text:
+        sections.append("Additional schema information:\n" + schema_text)
 
-    if schema_part:
-        user_content.append(schema_part)
+    example_data = "\n\n".join(sections) if sections else "No sample data provided."
 
-    payload: Dict[str, Any] = {
-        "model": request.llm_model,
-        "input": [
-            {"role": "system", "content": [{"type": "input_text", "text": prompt}]},
-            {"role": "user", "content": user_content},
-        ],
-    }
+    max_length = 30_000
+    if len(example_data) > max_length:
+        example_data = example_data[: max_length - 3] + "..."
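 
+    # Everything the prompt needs is now in one normalised object; this is
+    # what build_import_messages turns into the chat messages for the gateway.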
-    if request.temperature is not None:
-        payload["temperature"] = request.temperature
-    if request.max_output_tokens is not None:
-        payload["max_output_tokens"] = request.max_output_tokens
-
-    return payload
-
-
-def parse_openai_responses_payload(
-    data: Dict[str, Any], fallback_model: str
-) -> LLMResponse:
-    output_blocks = data.get("output", [])
-    choices: List[LLMChoice] = []
-
-    for idx, block in enumerate(output_blocks):
-        if block.get("type") != "message":
-            continue
-        content_items = block.get("content", [])
-        text_fragments: List[str] = []
-        for item in content_items:
-            if item.get("type") == "output_text":
-                text_fragments.append(item.get("text", ""))
-
-        if not text_fragments and data.get("output_text"):
-            text_fragments.append(data.get("output_text", ""))
-
-        message = LLMMessage(role=LLMRole.ASSISTANT, content="\n".join(text_fragments))
-        choices.append(LLMChoice(index=idx, message=message))
-
-    if not choices and data.get("output_text"):
-        message = LLMMessage(role=LLMRole.ASSISTANT, content=data.get("output_text", ""))
-        choices.append(LLMChoice(index=0, message=message))
-
-    return LLMResponse(
-        provider=LLMProvider.OPENAI,
-        model=data.get("model", fallback_model),
-        choices=choices,
-        raw=data,
+    return DataImportAnalysisRequest(
+        import_record_id=request.import_record_id,
+        example_data=example_data,
+        table_headers=headers,
+        llm_model=request.llm_model or DEFAULT_IMPORT_MODEL,
     )
 
 
-async def call_openai_import_analysis(
-    request: DataImportAnalysisJobRequest,
-    client: httpx.AsyncClient,
-    *,
-    api_key: str | None = None,
-) -> LLMResponse:
-    openai_api_key = api_key or os.getenv("OPENAI_API_KEY")
-    if not openai_api_key:
-        raise ProviderAPICallError("OPENAI_API_KEY must be set to process import analysis.")
+def build_chat_payload(request: DataImportAnalysisJobRequest) -> Dict[str, Any]:
+    llm_input = request.llm_model or DEFAULT_IMPORT_MODEL
+    provider, model_name = resolve_provider_from_model(llm_input)
+    normalized_model = f"{provider.value}:{model_name}"
 
-    payload = build_openai_input_payload(request)
-    headers = {
-        "Authorization": f"Bearer {openai_api_key}",
-        "Content-Type": "application/json",
+    if SUPPORTED_IMPORT_MODELS and normalized_model not in SUPPORTED_IMPORT_MODELS:
+        raise ProviderAPICallError(
+            "Model '{model}' is not allowed. Allowed models: {allowed}".format(
+                model=normalized_model,
+                allowed=", ".join(sorted(SUPPORTED_IMPORT_MODELS)),
+            )
+        )
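+
+    # Build the gateway payload in the /v1/chat/completions shape: provider
+    # and model split apart, prompt rendered into chat messages.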
+    analysis_request = build_analysis_request(request)
+
+    messages = build_import_messages(analysis_request)
+
+    payload: Dict[str, Any] = {
+        "provider": provider.value,
+        "model": model_name,
+        "messages": [message.model_dump() for message in messages],
+        "temperature": request.temperature if request.temperature is not None else 0.2,
     }
 
-    response = await client.post(OPENAI_RESPONSES_URL, json=payload, headers=headers)
-    try:
-        response.raise_for_status()
-    except httpx.HTTPError as exc:
-        raise ProviderAPICallError(f"OpenAI response API call failed: {exc}") from exc
+    if request.max_output_tokens is not None:
+        payload["max_tokens"] = request.max_output_tokens
 
-    data: Dict[str, Any] = response.json()
-    return parse_openai_responses_payload(data, request.llm_model)
+    return payload
 
 
 async def dispatch_import_analysis_job(
     request: DataImportAnalysisJobRequest,
     client: httpx.AsyncClient,
-    *,
-    api_key: str | None = None,
 ) -> Dict[str, Any]:
     logger.info("Starting import analysis job %s", request.import_record_id)
-    llm_response = await call_openai_import_analysis(request, client, api_key=api_key)
 
-    result = {
+    payload = build_chat_payload(request)
+    url = f"{IMPORT_GATEWAY_BASE_URL.rstrip('/')}/v1/chat/completions"
+
+    logger.debug(
+        "Dispatching import %s to %s: %s",
+        request.import_record_id,
+        url,
+        json.dumps(payload, ensure_ascii=False),
+    )
+
+    try:
+        response = await client.post(url, json=payload)
+        response.raise_for_status()
+    except httpx.HTTPStatusError as exc:
+        body_preview = ""
+        if exc.response is not None:
+            body_preview = exc.response.text[:400]
+        raise ProviderAPICallError(
+            f"Failed to invoke chat completions endpoint: {exc}. "
+            f"Response body: {body_preview}"
+        ) from exc
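+    # Transport-level failures (timeouts, connection errors) arrive as plain
+    # httpx.HTTPError and are wrapped the same way, without a body preview.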
+    except httpx.HTTPError as exc:
+        raise ProviderAPICallError(
+            f"HTTP call to chat completions endpoint failed: {exc}"
+        ) from exc
+
+    try:
+        response_data = response.json()
+    except ValueError as exc:
+        raise ProviderAPICallError("Chat completions endpoint returned invalid JSON") from exc
+
+    logger.debug(
+        "LLM HTTP status for %s: %s", request.import_record_id, response.status_code
+    )
+    logger.debug(
+        "LLM response for %s: %s",
+        request.import_record_id,
+        json.dumps(response_data, ensure_ascii=False),
+    )
+
+    try:
+        llm_response = LLMResponse.model_validate(response_data)
+    except ValidationError as exc:
+        raise ProviderAPICallError(
+            "Chat completions endpoint returned unexpected schema"
+        ) from exc
+
+    logger.info("Completed import analysis job %s", request.import_record_id)
+    return {
         "import_record_id": request.import_record_id,
         "status": "succeeded",
         "llm_response": llm_response.model_dump(),
     }
 
-    logger.info("Completed import analysis job %s", request.import_record_id)
-    return result
-
 
 async def notify_import_analysis_callback(
     callback_url: str,
     payload: Dict[str, Any],
     client: httpx.AsyncClient,
 ) -> None:
+    callback_target = str(callback_url)
+
     try:
-        response = await client.post(callback_url, json=payload)
+        response = await client.post(callback_target, json=payload)
         response.raise_for_status()
     except httpx.HTTPError as exc:
         logger.error(
             "Failed to deliver import analysis callback to %s: %s",
-            callback_url,
+            callback_target,
             exc,
         )
@@ -345,16 +329,13 @@
 async def process_import_analysis_job(
     request: DataImportAnalysisJobRequest,
     client: httpx.AsyncClient,
-    *,
-    api_key: str | None = None,
 ) -> None:
     try:
-        payload = await dispatch_import_analysis_job(
-            request,
-            client,
-            api_key=api_key,
-        )
+        payload = await dispatch_import_analysis_job(request, client)
     except ProviderAPICallError as exc:
+        logger.error(
+            "LLM call failed for import %s: %s", request.import_record_id, exc
+        )
         payload = {
             "import_record_id": request.import_record_id,
             "status": "failed",
@@ -365,6 +346,9 @@ async def process_import_analysis_job(
             "Unexpected failure while processing import analysis job %s",
             request.import_record_id,
         )
+        logger.error(
+            "Unexpected error for import %s: %s", request.import_record_id, exc
+        )
        payload = {
             "import_record_id": request.import_record_id,
             "status": "failed",