指标生成和查询相关功能api
This commit is contained in:
91
test/test_rag_client.py
Normal file
91
test/test_rag_client.py
Normal file
@ -0,0 +1,91 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from app.exceptions import ProviderAPICallError
|
||||
from app.schemas.rag import RagDeleteRequest, RagItemPayload, RagRetrieveRequest
|
||||
from app.services.rag_client import RagAPIClient
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_sends_payload_and_headers() -> None:
|
||||
rag_client = RagAPIClient(base_url="http://rag.test", auth_token="secret-token")
|
||||
|
||||
def handler(request: httpx.Request) -> httpx.Response:
|
||||
assert request.method == "POST"
|
||||
assert str(request.url) == "http://rag.test/rag/add"
|
||||
assert request.headers["Authorization"] == "Bearer secret-token"
|
||||
payload = json.loads(request.content.decode())
|
||||
assert payload == {
|
||||
"id": 1,
|
||||
"workspaceId": 2,
|
||||
"name": "demo",
|
||||
"embeddingData": "vector",
|
||||
"type": "METRIC",
|
||||
}
|
||||
return httpx.Response(200, json={"ok": True, "echo": payload})
|
||||
|
||||
transport = httpx.MockTransport(handler)
|
||||
async with httpx.AsyncClient(transport=transport) as client:
|
||||
result = await rag_client.add(
|
||||
client,
|
||||
RagItemPayload(id=1, workspaceId=2, name="demo", embeddingData="vector", type="METRIC"),
|
||||
)
|
||||
assert result["ok"] is True
|
||||
assert result["echo"]["name"] == "demo"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_batch_serializes_list() -> None:
|
||||
rag_client = RagAPIClient(base_url="http://rag.test", auth_token=None)
|
||||
|
||||
def handler(request: httpx.Request) -> httpx.Response:
|
||||
payload = json.loads(request.content.decode())
|
||||
assert request.url.path == "/rag/addBatch"
|
||||
assert isinstance(payload, list) and len(payload) == 2
|
||||
return httpx.Response(200, json={"received": len(payload)})
|
||||
|
||||
items = [
|
||||
RagItemPayload(id=1, workspaceId=2, name="a", embeddingData="vec-a", type="METRIC"),
|
||||
RagItemPayload(id=2, workspaceId=2, name="b", embeddingData="vec-b", type="METRIC"),
|
||||
]
|
||||
transport = httpx.MockTransport(handler)
|
||||
async with httpx.AsyncClient(transport=transport) as client:
|
||||
result = await rag_client.add_batch(client, items)
|
||||
assert result == {"received": 2}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_http_error_raises_provider_error() -> None:
|
||||
rag_client = RagAPIClient(base_url="http://rag.test")
|
||||
|
||||
def handler(request: httpx.Request) -> httpx.Response:
|
||||
return httpx.Response(500, text="boom")
|
||||
|
||||
transport = httpx.MockTransport(handler)
|
||||
async with httpx.AsyncClient(transport=transport) as client:
|
||||
with pytest.raises(ProviderAPICallError) as excinfo:
|
||||
await rag_client.delete(client, RagDeleteRequest(id=1, type="METRIC"))
|
||||
|
||||
err = excinfo.value
|
||||
assert err.status_code == 500
|
||||
assert "boom" in (err.response_text or "")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_json_response_returns_raw_text() -> None:
|
||||
rag_client = RagAPIClient(base_url="http://rag.test")
|
||||
|
||||
def handler(request: httpx.Request) -> httpx.Response:
|
||||
return httpx.Response(200, text="plain-text-body")
|
||||
|
||||
transport = httpx.MockTransport(handler)
|
||||
async with httpx.AsyncClient(transport=transport) as client:
|
||||
result = await rag_client.retrieve(
|
||||
client, RagRetrieveRequest(query="foo", num=1, workspaceId=1, type="METRIC")
|
||||
)
|
||||
assert result == {"raw": "plain-text-body"}
|
||||
|
||||
Reference in New Issue
Block a user