from __future__ import annotations

import logging
from typing import Any, Sequence

import httpx

from app.exceptions import ProviderAPICallError
from app.schemas.rag import RagDeleteRequest, RagItemPayload, RagRetrieveRequest
from app.settings import RAG_API_AUTH_TOKEN, RAG_API_BASE_URL


logger = logging.getLogger(__name__)


class RagAPIClient:
    """Async helper for POSTing to the RAG endpoints (see doc/rag-api.md).

    The client does not own an ``httpx.AsyncClient``; every call receives one
    from the caller. It only stores the base URL and the optional bearer
    token used to build request URLs and headers.
    """

    def __init__(self, *, base_url: str | None = None, auth_token: str | None = None) -> None:
        """Store connection settings, falling back to the values in ``app.settings``.

        Args:
            base_url: Root URL of the RAG API. When falsy, ``RAG_API_BASE_URL``
                is used. Trailing slashes are stripped.
            auth_token: Bearer token. When falsy, ``RAG_API_AUTH_TOKEN`` is used.
        """
        self._base_url = (base_url or RAG_API_BASE_URL).rstrip("/")
        self._auth_token = auth_token or RAG_API_AUTH_TOKEN

    def _headers(self) -> dict[str, str]:
        """Return the request headers, adding ``Authorization`` only when a token is set."""
        if not self._auth_token:
            return {"Content-Type": "application/json"}
        return {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self._auth_token}",
        }

    @staticmethod
    def _to_body(model: RagItemPayload | RagDeleteRequest | RagRetrieveRequest) -> Any:
        """Dump a request model using its aliases and omitting ``None`` fields."""
        return model.model_dump(by_alias=True, exclude_none=True)

    async def _post(
        self,
        client: httpx.AsyncClient,
        path: str,
        payload: Any,
    ) -> Any:
        """POST ``payload`` as JSON to ``path`` and return the decoded response.

        Args:
            client: The ``httpx.AsyncClient`` used to send the request.
            path: Endpoint path appended verbatim to the base URL.
            payload: Object passed to httpx as the ``json=`` request body.

        Returns:
            The parsed JSON body, or ``{"raw": <response text>}`` when the
            body cannot be decoded as JSON.

        Raises:
            ProviderAPICallError: On an HTTP error status (carrying the status
                code and response text) or on any other ``httpx.HTTPError``.
        """
        url = f"{self._base_url}{path}"
        try:
            response = await client.post(url, json=payload, headers=self._headers())
            response.raise_for_status()
        except httpx.HTTPStatusError as exc:
            # Mirror the defensive truthiness check on the attached response.
            failed = exc.response
            if failed:
                status_code, response_text = failed.status_code, failed.text
            else:
                status_code, response_text = None, ""
            logger.exception(
                "RAG API responded with an error (%s) for %s: %s",
                status_code,
                url,
                response_text,
            )
            raise ProviderAPICallError(
                "RAG API call failed.",
                status_code=status_code,
                response_text=response_text,
            ) from exc
        except httpx.HTTPError as exc:
            logger.exception("Transport error calling RAG API %s: %s", url, exc)
            raise ProviderAPICallError(f"RAG API call failed: {exc}") from exc

        try:
            decoded = response.json()
        except ValueError:
            # A successful response without a JSON body is passed through as text.
            logger.warning("RAG API returned non-JSON response for %s; returning raw text.", url)
            return {"raw": response.text}
        return decoded

    async def add(self, client: httpx.AsyncClient, payload: RagItemPayload) -> Any:
        """POST a single item to ``/rag/add`` and return the API response."""
        return await self._post(client, "/rag/add", self._to_body(payload))

    async def add_batch(self, client: httpx.AsyncClient, items: Sequence[RagItemPayload]) -> Any:
        """POST a list of items to ``/rag/addBatch`` and return the API response."""
        bodies = list(map(self._to_body, items))
        return await self._post(client, "/rag/addBatch", bodies)

    async def update(self, client: httpx.AsyncClient, payload: RagItemPayload) -> Any:
        """POST a single item to ``/rag/update`` and return the API response."""
        return await self._post(client, "/rag/update", self._to_body(payload))

    async def delete(self, client: httpx.AsyncClient, payload: RagDeleteRequest) -> Any:
        """POST a delete request to ``/rag/delete`` and return the API response."""
        return await self._post(client, "/rag/delete", self._to_body(payload))

    async def retrieve(self, client: httpx.AsyncClient, payload: RagRetrieveRequest) -> Any:
        """POST a retrieval request to ``/rag/retrieve`` and return the API response."""
        return await self._post(client, "/rag/retrieve", self._to_body(payload))