# data-ge/app/providers/anthropic.py
# Snapshot: 2025-10-30 18:25:29 +08:00 (114 lines, 3.6 KiB, Python)

from __future__ import annotations
import logging
from typing import Any, Dict, List, Tuple
import httpx
from app.exceptions import ProviderAPICallError
from app.models import (
LLMChoice,
LLMMessage,
LLMProvider,
LLMRequest,
LLMResponse,
LLMRole,
)
from app.providers.base import LLMProviderClient
# Module-level logger; AnthropicProvider.chat reports upstream HTTP failures through it.
logger = logging.getLogger(__name__)
class AnthropicProvider(LLMProviderClient):
    """LLM provider client for the Anthropic Messages API.

    Translates a provider-agnostic ``LLMRequest`` into a ``/v1/messages``
    payload and maps the reply back into an ``LLMResponse``.
    """

    name = LLMProvider.ANTHROPIC.value
    # Environment variable holding the API key; presumably read by the base
    # class to populate ``self.api_key`` -- confirm in LLMProviderClient.
    api_key_env = "ANTHROPIC_API_KEY"
    base_url = "https://api.anthropic.com/v1/messages"
    # Sent as the ``anthropic-version`` request header.
    anthropic_version = "2023-06-01"
    # Used when the request leaves ``max_tokens`` unset (or 0), so the payload
    # always carries one; the Messages API requires this field.
    default_max_tokens = 1024

    async def chat(
        self, request: LLMRequest, client: httpx.AsyncClient
    ) -> LLMResponse:
        """Send ``request`` to the Anthropic Messages API and return the reply.

        Args:
            request: Provider-agnostic chat request. System-role messages are
                folded into Anthropic's top-level ``system`` field.
            client: Async HTTP client used to issue the POST.

        Returns:
            An ``LLMResponse`` with a single choice built from the text blocks
            of the reply, plus the decoded JSON payload as ``raw``.

        Raises:
            ProviderAPICallError: On a non-2xx status, a request-level httpx
                failure, or a success response whose body is not a JSON object.
        """
        self.ensure_stream_supported(request.stream)
        system_prompt, chat_messages = self._convert_messages(request.messages)
        # NOTE(review): temperature/top_p are forwarded even when None; whether
        # merge_payload drops None values is not visible here -- confirm, since
        # an explicit null may be rejected upstream.
        payload = self.merge_payload(
            {
                "model": request.model,
                "messages": chat_messages,
                "max_tokens": request.max_tokens or self.default_max_tokens,
                "temperature": request.temperature,
                "top_p": request.top_p,
            },
            request.extra_params,
        )
        # Set after merge_payload, so it overrides any "system" key that came
        # in through extra_params.
        if system_prompt:
            payload["system"] = system_prompt
        headers = {
            "x-api-key": self.api_key,
            "anthropic-version": self.anthropic_version,
            "content-type": "application/json",
        }
        try:
            response = await client.post(self.base_url, json=payload, headers=headers)
            response.raise_for_status()
        except httpx.HTTPStatusError as exc:
            status_code = exc.response.status_code
            body = exc.response.text
            logger.error(
                "Anthropic upstream returned %s: %s", status_code, body, exc_info=True
            )
            raise ProviderAPICallError(
                f"Anthropic request failed with status {status_code}",
                status_code=status_code,
                response_text=body,
            ) from exc
        except httpx.HTTPError as exc:
            # HTTPStatusError subclasses HTTPError and is caught above, so this
            # branch only sees request-level failures (connection errors,
            # timeouts, ...).
            logger.error("Anthropic transport error: %s", exc, exc_info=True)
            raise ProviderAPICallError(f"Anthropic request failed: {exc}") from exc
        data = self._decode_json(response)
        message = self._build_message(data)
        return LLMResponse(
            provider=LLMProvider.ANTHROPIC,
            # ``or`` (not a .get default) so an explicit null also falls back.
            model=data.get("model") or request.model,
            choices=[LLMChoice(index=0, message=message)],
            raw=data,
        )

    @staticmethod
    def _decode_json(response: httpx.Response) -> Dict[str, Any]:
        """Decode a successful response body into its top-level JSON object.

        Keeps malformed success bodies under the same ``ProviderAPICallError``
        contract as HTTP failures instead of leaking ``ValueError`` or a later
        ``AttributeError`` to callers.

        Raises:
            ProviderAPICallError: If the body is not valid JSON or its top
                level is not a JSON object.
        """
        try:
            data = response.json()
        except ValueError as exc:
            # httpx parses with the stdlib json module; JSONDecodeError and
            # UnicodeDecodeError are both ValueError subclasses.
            logger.error(
                "Anthropic returned a non-JSON body: %s", response.text, exc_info=True
            )
            raise ProviderAPICallError(
                "Anthropic returned a non-JSON response body",
                status_code=response.status_code,
                response_text=response.text,
            ) from exc
        if not isinstance(data, dict):
            logger.error("Anthropic returned an unexpected payload: %s", response.text)
            raise ProviderAPICallError(
                "Anthropic returned an unexpected response payload",
                status_code=response.status_code,
                response_text=response.text,
            )
        return data

    @staticmethod
    def _convert_messages(
        messages: List[LLMMessage],
    ) -> Tuple[str | None, List[dict[str, Any]]]:
        """Split generic messages into a ``system`` string and chat turns.

        Every ``LLMRole.SYSTEM`` message is pulled out, and the collected texts
        are joined with blank lines; ``chat`` sends the result as the top-level
        ``system`` field. Every other message becomes a turn with a single text
        block. ``LLMRole.USER`` maps to ``"user"``, and any other role maps to
        ``"assistant"``.

        Returns:
            ``(system_prompt, chat_messages)``. ``system_prompt`` is ``None``
            when there were no system messages.
        """
        system_parts: List[str] = []
        chat_payload: List[dict[str, Any]] = []
        for msg in messages:
            if msg.role == LLMRole.SYSTEM:
                system_parts.append(msg.content)
                continue
            role = "user" if msg.role == LLMRole.USER else "assistant"
            chat_payload.append(
                {"role": role, "content": [{"type": "text", "text": msg.content}]}
            )
        system_prompt = "\n\n".join(system_parts) if system_parts else None
        return system_prompt, chat_payload

    @staticmethod
    def _build_message(data: Dict[str, Any]) -> LLMMessage:
        """Build an ``LLMMessage`` from a decoded Messages API response.

        Only dict content blocks with ``type == "text"`` contribute. Their
        non-empty ``text`` values are joined with blank lines. A missing or
        null ``role`` defaults to ``"assistant"``, and a missing or null
        ``content`` yields an empty message.
        """
        role = data.get("role") or "assistant"
        # ``or []`` so an explicit ``"content": null`` does not break iteration.
        content_blocks = data.get("content") or []
        text_parts = [
            block.get("text", "")
            for block in content_blocks
            if isinstance(block, dict) and block.get("type") == "text"
        ]
        # NOTE(review): if the API splits one answer across several text blocks
        # (e.g. with citations), the blank-line separator inserts breaks
        # mid-text -- confirm the intended separator.
        content = "\n\n".join(part for part in text_parts if part)
        return LLMMessage(role=role, content=content)