98 lines
3.1 KiB
Python
98 lines
3.1 KiB
Python
from __future__ import annotations
|
|
|
|
from functools import lru_cache
|
|
from pathlib import Path
|
|
from typing import List, Tuple
|
|
|
|
from app.models import (
|
|
DataImportAnalysisRequest,
|
|
LLMMessage,
|
|
LLMProvider,
|
|
LLMRole,
|
|
)
|
|
|
|
|
|
def resolve_provider_from_model(llm_model: str) -> Tuple[LLMProvider, str]:
    """Resolve the LLM provider and bare model name from an ``llm_model`` string.

    The string may be given as ``provider:model``, ``provider/model`` or
    ``provider|model``.  When no recognised provider prefix is present, the
    provider is inferred from common model-name patterns.

    Args:
        llm_model: Raw model identifier, optionally prefixed with a provider.

    Returns:
        A ``(provider, model_name)`` tuple.

    Raises:
        ValueError: If an explicit provider prefix is unsupported and the
            provider cannot be inferred from the model name either.
    """
    normalized = llm_model.strip()
    provider_hint: str | None = None
    model_name = normalized

    # Split on the first supported delimiter only, so the remainder of the
    # model name may itself contain delimiter characters.
    for delimiter in (":", "/", "|"):
        if delimiter in normalized:
            provider_hint, model_name = normalized.split(delimiter, 1)
            provider_hint = provider_hint.strip().lower()
            model_name = model_name.strip()
            break

    provider_map = {provider.value: provider for provider in LLMProvider}

    if provider_hint:
        if provider_hint in provider_map:
            return provider_map[provider_hint], model_name
        # The prefix is not a known provider.  Identifiers such as
        # 'models/gemini-pro' legitimately contain '/', so try pattern-based
        # inference on the *full* string before rejecting the input.
        # (Previously this raised immediately, making the 'models/gemini'
        # pattern in _guess_provider_from_model unreachable.)
        try:
            return _guess_provider_from_model(normalized), normalized
        except ValueError:
            raise ValueError(
                f"Unsupported provider '{provider_hint}'. Expected one of: {', '.join(provider_map.keys())}."
            ) from None

    return _guess_provider_from_model(model_name), model_name
|
|
|
|
|
|
def _guess_provider_from_model(model_name: str) -> LLMProvider:
    """Infer the provider for *model_name* from well-known naming patterns.

    Args:
        model_name: Bare model identifier (no provider prefix).

    Returns:
        The matching :class:`LLMProvider` member.

    Raises:
        ValueError: When no known naming pattern matches the model name.
    """
    lowered = model_name.lower()

    # Ordered (prefixes, provider) pairs; the first matching entry wins.
    prefix_table = (
        (("gpt", "o1", "text-", "dall-e", "whisper"), LLMProvider.OPENAI),
        (("claude", "anthropic"), LLMProvider.ANTHROPIC),
        (("gemini", "models/gemini"), LLMProvider.GEMINI),
        (("qwen",), LLMProvider.QWEN),
        (("deepseek",), LLMProvider.DEEPSEEK),
        (("openrouter", "router-"), LLMProvider.OPENROUTER),
    )
    for prefixes, provider in prefix_table:
        if lowered.startswith(prefixes):
            return provider

    supported = ", ".join(provider.value for provider in LLMProvider)
    raise ValueError(
        f"Unable to infer provider from model '{model_name}'. "
        f"Please prefix with 'provider:model'. Supported providers: {supported}."
    )
|
|
|
|
|
|
def build_import_messages(
    request: DataImportAnalysisRequest,
) -> List[LLMMessage]:
    """Create system and user messages for the import analysis prompt."""
    system_prompt = load_import_template()

    # Render each table header as a markdown-style bullet line.
    header_lines = "\n".join(f"- {header}" for header in request.table_headers)

    # Assemble the user payload: record id, header listing, then sample rows.
    user_parts = [
        f"导入记录ID: {request.import_record_id}\n\n",
        "表头信息:\n",
        f"{header_lines}\n\n",
        "示例数据:\n",
        f"{request.example_data}",
    ]

    return [
        LLMMessage(role=LLMRole.SYSTEM, content=system_prompt),
        LLMMessage(role=LLMRole.USER, content="".join(user_parts)),
    ]
|
|
|
|
|
|
@lru_cache(maxsize=1)
def load_import_template() -> str:
    """Load and cache the data-import analysis prompt template.

    The template is read once from ``<repo>/prompt/data_import_analysis.md``
    (two directories above this module) and cached for the process lifetime.

    Returns:
        The template text with surrounding whitespace stripped.

    Raises:
        FileNotFoundError: If the template file does not exist.
    """
    prompt_dir = Path(__file__).resolve().parents[2] / "prompt"
    template_file = prompt_dir / "data_import_analysis.md"
    if not template_file.exists():
        raise FileNotFoundError(f"Prompt template not found at {template_file}")
    content = template_file.read_text(encoding="utf-8")
    return content.strip()
|