"""Tests for ``RagAPIClient`` request construction and response handling.

Every test routes the client through ``httpx.MockTransport`` so no network
traffic occurs. Handlers record the outgoing request and return a canned
response; assertions on the request run in the test body afterwards.
"""

from __future__ import annotations

import json

import httpx
import pytest

from app.exceptions import ProviderAPICallError
from app.schemas.rag import RagDeleteRequest, RagItemPayload, RagRetrieveRequest
from app.services.rag_client import RagAPIClient


@pytest.mark.asyncio
async def test_add_sends_payload_and_headers() -> None:
    """``add`` POSTs the serialized item to ``/rag/add`` with a bearer token."""
    rag_client = RagAPIClient(base_url="http://rag.test", auth_token="secret-token")
    captured: list[httpx.Request] = []

    def handler(request: httpx.Request) -> httpx.Response:
        # Only record and answer here; checking the request in the test body
        # reports mismatches directly instead of from inside the transport.
        captured.append(request)
        return httpx.Response(200, json={"ok": True, "echo": json.loads(request.content)})

    transport = httpx.MockTransport(handler)
    async with httpx.AsyncClient(transport=transport) as client:
        result = await rag_client.add(
            client,
            RagItemPayload(id=1, workspaceId=2, name="demo", embeddingData="vector", type="METRIC"),
        )

    assert len(captured) == 1
    request = captured[0]
    assert request.method == "POST"
    assert str(request.url) == "http://rag.test/rag/add"
    assert request.headers["Authorization"] == "Bearer secret-token"
    assert json.loads(request.content) == {
        "id": 1,
        "workspaceId": 2,
        "name": "demo",
        "embeddingData": "vector",
        "type": "METRIC",
    }

    assert result["ok"] is True
    assert result["echo"]["name"] == "demo"


@pytest.mark.asyncio
async def test_add_batch_serializes_list() -> None:
    """``add_batch`` sends every item, in order, as one JSON array to ``/rag/addBatch``."""
    rag_client = RagAPIClient(base_url="http://rag.test", auth_token=None)
    captured: list[httpx.Request] = []

    def handler(request: httpx.Request) -> httpx.Response:
        captured.append(request)
        return httpx.Response(200, json={"received": len(json.loads(request.content))})

    items = [
        RagItemPayload(id=1, workspaceId=2, name="a", embeddingData="vec-a", type="METRIC"),
        RagItemPayload(id=2, workspaceId=2, name="b", embeddingData="vec-b", type="METRIC"),
    ]

    transport = httpx.MockTransport(handler)
    async with httpx.AsyncClient(transport=transport) as client:
        result = await rag_client.add_batch(client, items)

    assert len(captured) == 1
    request = captured[0]
    assert request.url.path == "/rag/addBatch"
    # Same per-item field mapping as the single-item ``add`` test above,
    # so a dropped, reordered, or mis-serialized item is caught.
    assert json.loads(request.content) == [
        {"id": 1, "workspaceId": 2, "name": "a", "embeddingData": "vec-a", "type": "METRIC"},
        {"id": 2, "workspaceId": 2, "name": "b", "embeddingData": "vec-b", "type": "METRIC"},
    ]
    assert result == {"received": 2}


@pytest.mark.asyncio
async def test_http_error_raises_provider_error() -> None:
    """A non-2xx response surfaces as ``ProviderAPICallError`` with status and body."""
    rag_client = RagAPIClient(base_url="http://rag.test")

    def handler(request: httpx.Request) -> httpx.Response:
        return httpx.Response(500, text="boom")

    transport = httpx.MockTransport(handler)
    async with httpx.AsyncClient(transport=transport) as client:
        with pytest.raises(ProviderAPICallError) as excinfo:
            await rag_client.delete(client, RagDeleteRequest(id=1, type="METRIC"))

    err = excinfo.value
    assert err.status_code == 500
    assert "boom" in (err.response_text or "")


@pytest.mark.asyncio
async def test_non_json_response_returns_raw_text() -> None:
    """A successful non-JSON body is returned wrapped as ``{"raw": <text>}``."""
    rag_client = RagAPIClient(base_url="http://rag.test")

    def handler(request: httpx.Request) -> httpx.Response:
        return httpx.Response(200, text="plain-text-body")

    transport = httpx.MockTransport(handler)
    async with httpx.AsyncClient(transport=transport) as client:
        result = await rag_client.retrieve(
            client, RagRetrieveRequest(query="foo", num=1, workspaceId=1, type="METRIC")
        )

    assert result == {"raw": "plain-text-body"}