Files
data-ge/app/services/rag_client.py
2025-12-09 00:15:22 +08:00

84 lines
3.3 KiB
Python

from __future__ import annotations
import logging
from typing import Any, Sequence
import httpx
from app.exceptions import ProviderAPICallError
from app.schemas.rag import RagDeleteRequest, RagItemPayload, RagRetrieveRequest
from app.settings import RAG_API_AUTH_TOKEN, RAG_API_BASE_URL
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
class RagAPIClient:
    """Async convenience wrapper for the RAG HTTP endpoints (see doc/rag-api.md).

    The instance only keeps a base URL and an optional bearer token; the
    ``httpx.AsyncClient`` used for each request is passed in by the caller.
    """

    def __init__(self, *, base_url: str | None = None, auth_token: str | None = None) -> None:
        """Store connection settings, falling back to the imported defaults.

        Args:
            base_url: Root URL of the RAG API. A falsy value selects
                ``RAG_API_BASE_URL``. Trailing slashes are removed so request
                paths can be appended directly.
            auth_token: Bearer token. A falsy value selects
                ``RAG_API_AUTH_TOKEN``.
        """
        self._base_url = (base_url or RAG_API_BASE_URL).rstrip("/")
        self._auth_token = auth_token or RAG_API_AUTH_TOKEN

    def _headers(self) -> dict[str, str]:
        """Return request headers; ``Authorization`` is added only when a token is set."""
        if not self._auth_token:
            return {"Content-Type": "application/json"}
        return {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self._auth_token}",
        }

    @staticmethod
    def _dump(model: Any) -> Any:
        """Return ``model.model_dump(by_alias=True, exclude_none=True)``."""
        return model.model_dump(by_alias=True, exclude_none=True)

    async def _post(
        self,
        client: httpx.AsyncClient,
        path: str,
        payload: Any,
    ) -> Any:
        """POST ``payload`` as JSON to ``path`` and return the decoded reply.

        Args:
            client: Caller-owned HTTP client used to send the request.
            path: Endpoint path appended verbatim to the stored base URL.
            payload: JSON-serializable request body.

        Returns:
            The value of ``response.json()``; when the body cannot be decoded
            as JSON, ``{"raw": <response text>}`` instead.

        Raises:
            ProviderAPICallError: When ``raise_for_status()`` reports an HTTP
                error status (status code and response text are attached), or
                when any other ``httpx.HTTPError`` is raised by the request.
        """
        url = f"{self._base_url}{path}"

        try:
            response = await client.post(url, json=payload, headers=self._headers())
            response.raise_for_status()
        except httpx.HTTPStatusError as exc:
            failed = exc.response
            if failed:
                status_code, response_text = failed.status_code, failed.text
            else:
                status_code, response_text = None, ""
            logger.error(
                "RAG API responded with an error (%s) for %s: %s",
                status_code,
                url,
                response_text,
                exc_info=True,
            )
            raise ProviderAPICallError(
                "RAG API call failed.",
                status_code=status_code,
                response_text=response_text,
            ) from exc
        except httpx.HTTPError as exc:
            # Remaining httpx.HTTPError subclasses: no HTTP status to report.
            logger.error("Transport error calling RAG API %s: %s", url, exc, exc_info=True)
            raise ProviderAPICallError(f"RAG API call failed: {exc}") from exc

        try:
            decoded = response.json()
        except ValueError:
            # Successful status but a non-JSON body: hand back the raw text.
            logger.warning("RAG API returned non-JSON response for %s; returning raw text.", url)
            decoded = {"raw": response.text}
        return decoded

    async def add(self, client: httpx.AsyncClient, payload: RagItemPayload) -> Any:
        """POST a single serialized item to ``/rag/add``."""
        return await self._post(client, "/rag/add", self._dump(payload))

    async def add_batch(self, client: httpx.AsyncClient, items: Sequence[RagItemPayload]) -> Any:
        """POST a list of serialized items to ``/rag/addBatch``."""
        serialized = [self._dump(entry) for entry in items]
        return await self._post(client, "/rag/addBatch", serialized)

    async def update(self, client: httpx.AsyncClient, payload: RagItemPayload) -> Any:
        """POST a single serialized item to ``/rag/update``."""
        return await self._post(client, "/rag/update", self._dump(payload))

    async def delete(self, client: httpx.AsyncClient, payload: RagDeleteRequest) -> Any:
        """POST a serialized delete request to ``/rag/delete``."""
        return await self._post(client, "/rag/delete", self._dump(payload))

    async def retrieve(self, client: httpx.AsyncClient, payload: RagRetrieveRequest) -> Any:
        """POST a serialized retrieve request to ``/rag/retrieve``."""
        return await self._post(client, "/rag/retrieve", self._dump(payload))