84 lines
3.3 KiB
Python
84 lines
3.3 KiB
Python
from __future__ import annotations
|
|
|
|
import logging
|
|
from typing import Any, Sequence
|
|
|
|
import httpx
|
|
|
|
from app.exceptions import ProviderAPICallError
|
|
from app.schemas.rag import RagDeleteRequest, RagItemPayload, RagRetrieveRequest
|
|
from app.settings import RAG_API_AUTH_TOKEN, RAG_API_BASE_URL
|
|
|
|
|
|
# Module-level logger named after this module; RagAPIClient._post uses it to
# report HTTP error responses, transport failures and non-JSON bodies.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class RagAPIClient:
    """Thin async client wrapper around the RAG endpoints described in doc/rag-api.md.

    The wrapper holds only the endpoint root and an optional bearer token.
    Every public method takes an ``httpx.AsyncClient`` supplied by the caller,
    so connection pooling, timeouts and client lifecycle are managed outside
    this class.
    """

    def __init__(self, *, base_url: str | None = None, auth_token: str | None = None) -> None:
        """Configure the endpoint root and optional bearer token.

        Args:
            base_url: Root URL of the RAG service. Falls back to
                ``RAG_API_BASE_URL`` when omitted or empty. Trailing ``/``
                characters are stripped so endpoint paths (which start with
                ``/``) can be appended directly.
            auth_token: Bearer token sent in the ``Authorization`` header.
                Falls back to ``RAG_API_AUTH_TOKEN`` when omitted or empty.
        """
        resolved_base = base_url or RAG_API_BASE_URL
        # NOTE(review): if RAG_API_BASE_URL is also None, .rstrip raises
        # AttributeError here — confirm settings always provide a string.
        self._base_url = resolved_base.rstrip("/")
        self._auth_token = auth_token or RAG_API_AUTH_TOKEN

    def _headers(self) -> dict[str, str]:
        """Build request headers; ``Authorization`` is added only when a token is set."""
        headers = {"Content-Type": "application/json"}
        if self._auth_token:
            headers["Authorization"] = f"Bearer {self._auth_token}"
        return headers

    @staticmethod
    def _dump(model: Any) -> Any:
        """Serialize a request model using field aliases and omitting ``None`` values."""
        return model.model_dump(by_alias=True, exclude_none=True)

    async def _post(
        self,
        client: httpx.AsyncClient,
        path: str,
        payload: Any,
    ) -> Any:
        """POST ``payload`` as JSON to ``base_url + path`` and decode the reply.

        Args:
            client: Caller-owned async HTTP client used to send the request.
            path: Endpoint path, expected to start with ``/``.
            payload: JSON-serializable request body.

        Returns:
            The decoded JSON body, or ``{"raw": <response text>}`` when the
            body is not valid JSON.

        Raises:
            ProviderAPICallError: On a non-2xx status (carrying the status code
                and response text) or on any other ``httpx.HTTPError``.
        """
        url = f"{self._base_url}{path}"
        try:
            response = await client.post(url, json=payload, headers=self._headers())
            response.raise_for_status()
        except httpx.HTTPStatusError as exc:
            # Explicit None check: a Response object's truthiness is not a
            # reliable "is present" test (requests-style responses are falsy
            # for error statuses).
            error_response = exc.response
            status_code = error_response.status_code if error_response is not None else None
            response_text = error_response.text if error_response is not None else ""
            logger.exception(
                "RAG API responded with an error (%s) for %s: %s",
                status_code,
                url,
                response_text,
            )
            raise ProviderAPICallError(
                "RAG API call failed.",
                status_code=status_code,
                response_text=response_text,
            ) from exc
        except httpx.HTTPError as exc:
            # Status errors were handled above; what reaches here are the
            # remaining httpx failures (connection, timeout, protocol, ...).
            logger.exception("Transport error calling RAG API %s: %s", url, exc)
            raise ProviderAPICallError(f"RAG API call failed: {exc}") from exc

        try:
            return response.json()
        except ValueError:
            # json.JSONDecodeError subclasses ValueError; degrade to raw text
            # instead of failing a call whose HTTP status was successful.
            logger.warning("RAG API returned non-JSON response for %s; returning raw text.", url)
            return {"raw": response.text}

    async def add(self, client: httpx.AsyncClient, payload: RagItemPayload) -> Any:
        """POST a single item to ``/rag/add``; returns the decoded response body."""
        return await self._post(client, "/rag/add", self._dump(payload))

    async def add_batch(self, client: httpx.AsyncClient, items: Sequence[RagItemPayload]) -> Any:
        """POST ``items`` as a JSON array to ``/rag/addBatch``; returns the decoded response body."""
        return await self._post(client, "/rag/addBatch", [self._dump(item) for item in items])

    async def update(self, client: httpx.AsyncClient, payload: RagItemPayload) -> Any:
        """POST an item to ``/rag/update``; returns the decoded response body."""
        return await self._post(client, "/rag/update", self._dump(payload))

    async def delete(self, client: httpx.AsyncClient, payload: RagDeleteRequest) -> Any:
        """POST a delete request to ``/rag/delete``; returns the decoded response body."""
        return await self._post(client, "/rag/delete", self._dump(payload))

    async def retrieve(self, client: httpx.AsyncClient, payload: RagRetrieveRequest) -> Any:
        """POST a retrieval query to ``/rag/retrieve``; returns the decoded response body."""
        return await self._post(client, "/rag/retrieve", self._dump(payload))
|