92 lines
3.3 KiB
Python
92 lines
3.3 KiB
Python
from __future__ import annotations
|
|
|
|
import json
|
|
|
|
import httpx
|
|
import pytest
|
|
|
|
from app.exceptions import ProviderAPICallError
|
|
from app.schemas.rag import RagDeleteRequest, RagItemPayload, RagRetrieveRequest
|
|
from app.services.rag_client import RagAPIClient
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_add_sends_payload_and_headers() -> None:
    """``add`` POSTs one serialized item to ``/rag/add`` with a bearer token.

    The mock transport only records the outgoing request and echoes its JSON
    body. All request assertions run *after* ``add`` returns, so a mismatch
    surfaces as a plain ``AssertionError`` in the test body. It is not raised
    inside the transport, where the client's own error handling could wrap
    or hide it. Capturing also proves the handler was invoked exactly once.
    """
    rag_client = RagAPIClient(base_url="http://rag.test", auth_token="secret-token")
    captured: list[httpx.Request] = []

    def handler(request: httpx.Request) -> httpx.Response:
        # Record the request; echo the decoded JSON body back to the caller.
        captured.append(request)
        payload = json.loads(request.content.decode())
        return httpx.Response(200, json={"ok": True, "echo": payload})

    transport = httpx.MockTransport(handler)
    async with httpx.AsyncClient(transport=transport) as client:
        result = await rag_client.add(
            client,
            RagItemPayload(id=1, workspaceId=2, name="demo", embeddingData="vector", type="METRIC"),
        )

    assert len(captured) == 1
    request = captured[0]
    assert request.method == "POST"
    assert str(request.url) == "http://rag.test/rag/add"
    assert request.headers["Authorization"] == "Bearer secret-token"
    assert json.loads(request.content.decode()) == {
        "id": 1,
        "workspaceId": 2,
        "name": "demo",
        "embeddingData": "vector",
        "type": "METRIC",
    }

    assert result["ok"] is True
    assert result["echo"]["name"] == "demo"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_add_batch_serializes_list() -> None:
    """``add_batch`` sends every item, in order, as one JSON array to ``/rag/addBatch``.

    The request is captured by the mock transport and inspected after the call
    returns, so assertion failures are reported directly rather than from
    inside the client's HTTP call.
    """
    rag_client = RagAPIClient(base_url="http://rag.test", auth_token=None)
    captured: list[httpx.Request] = []

    def handler(request: httpx.Request) -> httpx.Response:
        captured.append(request)
        payload = json.loads(request.content.decode())
        # Reply with the count actually received so the result reflects the body.
        return httpx.Response(200, json={"received": len(payload)})

    items = [
        RagItemPayload(id=1, workspaceId=2, name="a", embeddingData="vec-a", type="METRIC"),
        RagItemPayload(id=2, workspaceId=2, name="b", embeddingData="vec-b", type="METRIC"),
    ]
    transport = httpx.MockTransport(handler)
    async with httpx.AsyncClient(transport=transport) as client:
        result = await rag_client.add_batch(client, items)

    assert len(captured) == 1
    request = captured[0]
    assert request.url.path == "/rag/addBatch"
    payload = json.loads(request.content.decode())
    assert isinstance(payload, list) and len(payload) == 2
    # Item content and order must survive serialization, not just the length.
    assert [item["name"] for item in payload] == ["a", "b"]
    assert [item["embeddingData"] for item in payload] == ["vec-a", "vec-b"]

    assert result == {"received": 2}
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_http_error_raises_provider_error() -> None:
    """A 500 reply from the RAG service is raised as ``ProviderAPICallError``.

    The raised error must expose the HTTP status code and the response body text.
    """
    rag_client = RagAPIClient(base_url="http://rag.test")
    failing_transport = httpx.MockTransport(
        lambda _request: httpx.Response(500, text="boom")
    )

    async with httpx.AsyncClient(transport=failing_transport) as http_client:
        with pytest.raises(ProviderAPICallError) as raised:
            await rag_client.delete(http_client, RagDeleteRequest(id=1, type="METRIC"))

    error = raised.value
    assert error.status_code == 500
    # response_text may be None; treat that as an empty body.
    body_text = error.response_text or ""
    assert "boom" in body_text
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_non_json_response_returns_raw_text() -> None:
    """A 200 reply whose body is not JSON comes back wrapped as ``{"raw": <text>}``."""
    rag_client = RagAPIClient(base_url="http://rag.test")
    plain_text_transport = httpx.MockTransport(
        lambda _request: httpx.Response(200, text="plain-text-body")
    )
    retrieve_request = RagRetrieveRequest(query="foo", num=1, workspaceId=1, type="METRIC")

    async with httpx.AsyncClient(transport=plain_text_transport) as http_client:
        result = await rag_client.retrieve(http_client, retrieve_request)

    assert result == {"raw": "plain-text-body"}
|
|
|