rag snippet生成入库和写rag
This commit is contained in:
83
app/services/rag_client.py
Normal file
83
app/services/rag_client.py
Normal file
@ -0,0 +1,83 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, Sequence
|
||||
|
||||
import httpx
|
||||
|
||||
from app.exceptions import ProviderAPICallError
|
||||
from app.schemas.rag import RagDeleteRequest, RagItemPayload, RagRetrieveRequest
|
||||
from app.settings import RAG_API_AUTH_TOKEN, RAG_API_BASE_URL
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RagAPIClient:
|
||||
"""Thin async client wrapper around the RAG endpoints described in doc/rag-api.md."""
|
||||
|
||||
def __init__(self, *, base_url: str | None = None, auth_token: str | None = None) -> None:
|
||||
resolved_base = base_url or RAG_API_BASE_URL
|
||||
self._base_url = resolved_base.rstrip("/")
|
||||
self._auth_token = auth_token or RAG_API_AUTH_TOKEN
|
||||
|
||||
def _headers(self) -> dict[str, str]:
|
||||
headers = {"Content-Type": "application/json"}
|
||||
if self._auth_token:
|
||||
headers["Authorization"] = f"Bearer {self._auth_token}"
|
||||
return headers
|
||||
|
||||
async def _post(
|
||||
self,
|
||||
client: httpx.AsyncClient,
|
||||
path: str,
|
||||
payload: Any,
|
||||
) -> Any:
|
||||
url = f"{self._base_url}{path}"
|
||||
try:
|
||||
response = await client.post(url, json=payload, headers=self._headers())
|
||||
response.raise_for_status()
|
||||
except httpx.HTTPStatusError as exc:
|
||||
status_code = exc.response.status_code if exc.response else None
|
||||
response_text = exc.response.text if exc.response else ""
|
||||
logger.error(
|
||||
"RAG API responded with an error (%s) for %s: %s",
|
||||
status_code,
|
||||
url,
|
||||
response_text,
|
||||
exc_info=True,
|
||||
)
|
||||
raise ProviderAPICallError(
|
||||
"RAG API call failed.",
|
||||
status_code=status_code,
|
||||
response_text=response_text,
|
||||
) from exc
|
||||
except httpx.HTTPError as exc:
|
||||
logger.error("Transport error calling RAG API %s: %s", url, exc, exc_info=True)
|
||||
raise ProviderAPICallError(f"RAG API call failed: {exc}") from exc
|
||||
|
||||
try:
|
||||
return response.json()
|
||||
except ValueError:
|
||||
logger.warning("RAG API returned non-JSON response for %s; returning raw text.", url)
|
||||
return {"raw": response.text}
|
||||
|
||||
async def add(self, client: httpx.AsyncClient, payload: RagItemPayload) -> Any:
|
||||
body = payload.model_dump(by_alias=True, exclude_none=True)
|
||||
return await self._post(client, "/rag/add", body)
|
||||
|
||||
async def add_batch(self, client: httpx.AsyncClient, items: Sequence[RagItemPayload]) -> Any:
|
||||
body = [item.model_dump(by_alias=True, exclude_none=True) for item in items]
|
||||
return await self._post(client, "/rag/addBatch", body)
|
||||
|
||||
async def update(self, client: httpx.AsyncClient, payload: RagItemPayload) -> Any:
|
||||
body = payload.model_dump(by_alias=True, exclude_none=True)
|
||||
return await self._post(client, "/rag/update", body)
|
||||
|
||||
async def delete(self, client: httpx.AsyncClient, payload: RagDeleteRequest) -> Any:
|
||||
body = payload.model_dump(by_alias=True, exclude_none=True)
|
||||
return await self._post(client, "/rag/delete", body)
|
||||
|
||||
async def retrieve(self, client: httpx.AsyncClient, payload: RagRetrieveRequest) -> Any:
|
||||
body = payload.model_dump(by_alias=True, exclude_none=True)
|
||||
return await self._post(client, "/rag/retrieve", body)
|
||||
Reference in New Issue
Block a user