"""Anthropic Messages API provider client.

Translates the app's provider-neutral ``LLMRequest`` into a POST against
Anthropic's ``/v1/messages`` endpoint and maps the reply back into an
``LLMResponse``.
"""

from __future__ import annotations

import logging
from typing import Any, Dict, List, Tuple

import httpx

from app.exceptions import ProviderAPICallError
from app.models import (
    LLMChoice,
    LLMMessage,
    LLMProvider,
    LLMRequest,
    LLMResponse,
    LLMRole,
)
from app.providers.base import LLMProviderClient

logger = logging.getLogger(__name__)


class AnthropicProvider(LLMProviderClient):
    """Provider client for Anthropic's Messages API (non-streaming)."""

    name = LLMProvider.ANTHROPIC.value
    # NOTE(review): presumably the base class resolves ``self.api_key`` from
    # this environment variable — confirm in app.providers.base.
    api_key_env = "ANTHROPIC_API_KEY"
    base_url = "https://api.anthropic.com/v1/messages"
    # Sent as the ``anthropic-version`` header on every request.
    anthropic_version = "2023-06-01"
    # Fallback used when the request leaves ``max_tokens`` unset (or 0); the
    # Messages API requires the field to be present.
    default_max_tokens = 1024

    async def chat(
        self, request: LLMRequest, client: httpx.AsyncClient
    ) -> LLMResponse:
        """Send *request* to the Messages API and return the parsed reply.

        Args:
            request: Provider-neutral request. System-role messages are lifted
                into Anthropic's top-level ``system`` field.
            client: Async HTTP client used to issue the POST.

        Returns:
            An ``LLMResponse`` holding a single choice with the reply's text.

        Raises:
            ProviderAPICallError: On a non-2xx status, a transport failure, or
                a success response whose body is not a JSON object.
        """
        self.ensure_stream_supported(request.stream)

        system_prompt, chat_messages = self._convert_messages(request.messages)
        base_payload: Dict[str, Any] = {
            "model": request.model,
            "messages": chat_messages,
            "max_tokens": request.max_tokens or self.default_max_tokens,
        }
        # Only include sampling parameters that were actually set, so an unset
        # value is left to the API default instead of being sent as JSON null.
        sampling = {"temperature": request.temperature, "top_p": request.top_p}
        base_payload.update(
            {key: value for key, value in sampling.items() if value is not None}
        )
        payload = self.merge_payload(base_payload, request.extra_params)
        if system_prompt:
            # Applied after merging so the converted system prompt always wins.
            payload["system"] = system_prompt

        headers = {
            "x-api-key": self.api_key,
            "anthropic-version": self.anthropic_version,
            "content-type": "application/json",
        }

        try:
            response = await client.post(self.base_url, json=payload, headers=headers)
            response.raise_for_status()
        except httpx.HTTPStatusError as exc:
            # Must precede the generic HTTPError handler: HTTPStatusError is a
            # subclass, and here the upstream status and body are available.
            status_code = exc.response.status_code
            body = exc.response.text
            logger.error(
                "Anthropic upstream returned %s: %s", status_code, body, exc_info=True
            )
            raise ProviderAPICallError(
                f"Anthropic request failed with status {status_code}",
                status_code=status_code,
                response_text=body,
            ) from exc
        except httpx.HTTPError as exc:
            logger.error("Anthropic transport error: %s", exc, exc_info=True)
            raise ProviderAPICallError(f"Anthropic request failed: {exc}") from exc

        data = self._decode_body(response)
        message = self._build_message(data)
        return LLMResponse(
            provider=LLMProvider.ANTHROPIC,
            # Fall back to the requested model if the reply omits it or sends null.
            model=data.get("model") or request.model,
            choices=[LLMChoice(index=0, message=message)],
            raw=data,
        )

    @staticmethod
    def _decode_body(response: httpx.Response) -> Dict[str, Any]:
        """Decode a successful response body, requiring a top-level JSON object.

        Raises:
            ProviderAPICallError: If the body is not valid JSON, or decodes to
                something other than an object.
        """
        try:
            data = response.json()
        except ValueError as exc:  # JSONDecodeError / UnicodeDecodeError
            logger.error(
                "Anthropic returned a non-JSON body: %s", response.text, exc_info=True
            )
            raise ProviderAPICallError(
                "Anthropic returned an invalid JSON response",
                status_code=response.status_code,
                response_text=response.text,
            ) from exc
        if not isinstance(data, dict):
            logger.error("Anthropic returned a non-object body: %s", response.text)
            raise ProviderAPICallError(
                "Anthropic returned an unexpected response shape",
                status_code=response.status_code,
                response_text=response.text,
            )
        return data

    @staticmethod
    def _convert_messages(
        messages: List[LLMMessage],
    ) -> Tuple[str | None, List[dict[str, Any]]]:
        """Split *messages* into a system prompt and Anthropic chat turns.

        System messages are joined with blank lines into one prompt (or
        ``None`` if there are none). USER messages become ``"user"`` turns;
        every other non-system role is sent as ``"assistant"``. Each turn
        carries its content as a single text block.
        """
        system_parts: List[str] = []
        chat_payload: List[dict[str, Any]] = []
        for msg in messages:
            if msg.role == LLMRole.SYSTEM:
                system_parts.append(msg.content)
                continue
            role = "user" if msg.role == LLMRole.USER else "assistant"
            chat_payload.append(
                {"role": role, "content": [{"type": "text", "text": msg.content}]}
            )
        system_prompt = "\n\n".join(system_parts) if system_parts else None
        return system_prompt, chat_payload

    @staticmethod
    def _build_message(data: Dict[str, Any]) -> LLMMessage:
        """Build an ``LLMMessage`` from a decoded Messages API reply.

        Only content blocks of type ``"text"`` are kept; other block types are
        ignored. Non-empty texts are joined with blank lines. A missing or
        null ``role``/``content`` falls back to ``"assistant"`` / no blocks.
        """
        role = data.get("role") or "assistant"
        content_blocks = data.get("content") or []
        text_parts = [
            block.get("text", "")
            for block in content_blocks
            if isinstance(block, dict) and block.get("type") == "text"
        ]
        content = "\n\n".join(part for part in text_parts if part)
        return LLMMessage(role=role, content=content)