Files
data-ge/test/test_rag_client.py
2025-12-08 23:16:13 +08:00

92 lines
3.3 KiB
Python

from __future__ import annotations
import json
import httpx
import pytest
from app.exceptions import ProviderAPICallError
from app.schemas.rag import RagDeleteRequest, RagItemPayload, RagRetrieveRequest
from app.services.rag_client import RagAPIClient
@pytest.mark.asyncio
async def test_add_sends_payload_and_headers() -> None:
    """``add`` POSTs the item as JSON to ``/rag/add`` with a bearer token.

    The mock transport validates the outgoing request (method, full URL,
    ``Authorization`` header and JSON body), then echoes the parsed body back
    so the test can also confirm that ``add`` returns the decoded response.
    """
    expected_body = {
        "id": 1,
        "workspaceId": 2,
        "name": "demo",
        "embeddingData": "vector",
        "type": "METRIC",
    }
    api = RagAPIClient(base_url="http://rag.test", auth_token="secret-token")

    def fake_rag_service(req: httpx.Request) -> httpx.Response:
        # Validate the outgoing request before replying.
        assert req.method == "POST"
        assert str(req.url) == "http://rag.test/rag/add"
        assert req.headers["Authorization"] == "Bearer secret-token"
        body = json.loads(req.content.decode())
        assert body == expected_body
        # Echo the body so the caller-side assertions can inspect it.
        return httpx.Response(200, json={"ok": True, "echo": body})

    transport = httpx.MockTransport(fake_rag_service)
    item = RagItemPayload(
        id=1, workspaceId=2, name="demo", embeddingData="vector", type="METRIC"
    )
    async with httpx.AsyncClient(transport=transport) as http_client:
        response_data = await api.add(http_client, item)

    assert response_data["ok"] is True
    assert response_data["echo"]["name"] == "demo"
@pytest.mark.asyncio
async def test_add_batch_serializes_list() -> None:
    """``add_batch`` sends all items, in order, as a JSON array to ``/rag/addBatch``.

    The client is built without an auth token. The mock service replies with
    the number of items it received, and the test checks that ``add_batch``
    hands that decoded response back unchanged.
    """
    rag_client = RagAPIClient(base_url="http://rag.test", auth_token=None)

    def handler(request: httpx.Request) -> httpx.Response:
        # Check the route first so a wrong endpoint fails with a clear message
        # instead of a JSON decode error on an unexpected body.
        assert request.url.path == "/rag/addBatch"
        payload = json.loads(request.content.decode())
        # Separate asserts so pytest reports which expectation failed.
        assert isinstance(payload, list)
        assert len(payload) == 2
        # Each element should be a serialized item (same "id"/"name" keys as
        # the single-item ``add`` payload), preserving the input order.
        assert [(entry["id"], entry["name"]) for entry in payload] == [
            (1, "a"),
            (2, "b"),
        ]
        return httpx.Response(200, json={"received": len(payload)})

    items = [
        RagItemPayload(id=1, workspaceId=2, name="a", embeddingData="vec-a", type="METRIC"),
        RagItemPayload(id=2, workspaceId=2, name="b", embeddingData="vec-b", type="METRIC"),
    ]
    transport = httpx.MockTransport(handler)
    async with httpx.AsyncClient(transport=transport) as client:
        result = await rag_client.add_batch(client, items)

    assert result == {"received": 2}
@pytest.mark.asyncio
async def test_http_error_raises_provider_error() -> None:
    """An HTTP 500 reply from the service surfaces as ``ProviderAPICallError``.

    The raised error must carry the upstream status code and the response
    body text.
    """

    def always_fail(_request: httpx.Request) -> httpx.Response:
        # Every request gets a 500 with a recognizable body.
        return httpx.Response(500, text="boom")

    rag_client = RagAPIClient(base_url="http://rag.test")
    failing_transport = httpx.MockTransport(always_fail)
    async with httpx.AsyncClient(transport=failing_transport) as http_client:
        with pytest.raises(ProviderAPICallError) as caught:
            await rag_client.delete(http_client, RagDeleteRequest(id=1, type="METRIC"))

    error = caught.value
    assert error.status_code == 500
    # response_text may be None, so fall back to "" before the substring check.
    assert "boom" in (error.response_text or "")
@pytest.mark.asyncio
async def test_non_json_response_returns_raw_text() -> None:
    """A 200 reply whose body is not JSON comes back as ``{"raw": <body text>}``."""
    rag_client = RagAPIClient(base_url="http://rag.test")
    retrieve_request = RagRetrieveRequest(query="foo", num=1, workspaceId=1, type="METRIC")

    def plain_text_service(_request: httpx.Request) -> httpx.Response:
        # A successful status with a body that cannot be parsed as JSON.
        return httpx.Response(200, text="plain-text-body")

    plain_transport = httpx.MockTransport(plain_text_service)
    async with httpx.AsyncClient(transport=plain_transport) as http_client:
        outcome = await rag_client.retrieve(http_client, retrieve_request)

    assert outcome == {"raw": "plain-text-body"}