会话及轮次等相关api
This commit is contained in:
142
test/test_chat_api_mysql.py
Normal file
142
test/test_chat_api_mysql.py
Normal file
@ -0,0 +1,142 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import random
|
||||
from pathlib import Path
|
||||
from typing import Generator, List
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
# Ensure the project root is importable when running directly via python.
|
||||
ROOT = Path(__file__).resolve().parents[1]
|
||||
if str(ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(ROOT))
|
||||
|
||||
from app import db
|
||||
from app.main import create_app
|
||||
|
||||
|
||||
TEST_USER_ID = 872341
|
||||
SCHEMA_PATH = Path("file/tableschema/chat.sql")
|
||||
DEFAULT_MYSQL_URL = "mysql+pymysql://root:12345678@127.0.0.1:3306/data-ge?charset=utf8mb4"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def client() -> Generator[TestClient, None, None]:
|
||||
mysql_url = os.getenv("TEST_DATABASE_URL", DEFAULT_MYSQL_URL)
|
||||
os.environ["DATABASE_URL"] = mysql_url
|
||||
db.get_engine.cache_clear()
|
||||
engine = db.get_engine()
|
||||
try:
|
||||
# Quick connectivity check
|
||||
with engine.connect() as conn:
|
||||
conn.execute(text("SELECT 1"))
|
||||
except SQLAlchemyError:
|
||||
pytest.skip(f"Cannot connect to MySQL at {mysql_url}")
|
||||
|
||||
#_ensure_chat_schema(engine)
|
||||
|
||||
app = create_app()
|
||||
with TestClient(app) as test_client:
|
||||
yield test_client
|
||||
|
||||
# cleanup test artifacts
|
||||
with engine.begin() as conn:
|
||||
# remove retrievals and turns tied to test sessions
|
||||
conn.execute(
|
||||
text(
|
||||
"""
|
||||
DELETE FROM chat_turn_retrieval
|
||||
WHERE turn_id IN (
|
||||
SELECT id FROM chat_turn WHERE session_id IN (SELECT id FROM chat_session WHERE user_id=:uid)
|
||||
)
|
||||
"""
|
||||
),
|
||||
{"uid": TEST_USER_ID},
|
||||
)
|
||||
conn.execute(
|
||||
text("DELETE FROM chat_turn WHERE session_id IN (SELECT id FROM chat_session WHERE user_id=:uid)"),
|
||||
{"uid": TEST_USER_ID},
|
||||
)
|
||||
conn.execute(text("DELETE FROM chat_session WHERE user_id=:uid"), {"uid": TEST_USER_ID})
|
||||
db.get_engine.cache_clear()
|
||||
|
||||
|
||||
def test_session_lifecycle_mysql(client: TestClient) -> None:
|
||||
# Create a session
|
||||
resp = client.post("/api/v1/chat/sessions", json={"user_id": TEST_USER_ID})
|
||||
assert resp.status_code == 200, resp.text
|
||||
session = resp.json()
|
||||
session_id = session["id"]
|
||||
assert session["status"] == "OPEN"
|
||||
|
||||
# Get session
|
||||
assert client.get(f"/api/v1/chat/sessions/{session_id}").status_code == 200
|
||||
|
||||
# List sessions (filter by user)
|
||||
resp = client.get(f"/api/v1/chat/sessions", params={"user_id": TEST_USER_ID})
|
||||
assert resp.status_code == 200
|
||||
assert any(item["id"] == session_id for item in resp.json())
|
||||
|
||||
# Update status
|
||||
resp = client.post(f"/api/v1/chat/sessions/{session_id}/update", json={"status": "PAUSED"})
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["status"] == "PAUSED"
|
||||
|
||||
# Close session
|
||||
resp = client.post(f"/api/v1/chat/sessions/{session_id}/close")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["status"] == "CLOSED"
|
||||
|
||||
|
||||
def test_turns_and_retrievals_mysql(client: TestClient) -> None:
|
||||
session_id = client.post("/api/v1/chat/sessions", json={"user_id": TEST_USER_ID}).json()["id"]
|
||||
turn_payload = {
|
||||
"user_id": TEST_USER_ID,
|
||||
"user_query": "展示昨天订单GMV",
|
||||
"intent": "METRIC_QUERY",
|
||||
"ast_json": {"select": ["gmv"], "where": {"dt": "yesterday"}},
|
||||
"main_metric_ids": [random.randint(1000, 9999)],
|
||||
"created_metric_ids": [],
|
||||
}
|
||||
resp = client.post(f"/api/v1/chat/sessions/{session_id}/turns", json=turn_payload)
|
||||
assert resp.status_code == 200, resp.text
|
||||
turn = resp.json()
|
||||
turn_id = turn["id"]
|
||||
assert turn["turn_no"] == 1
|
||||
|
||||
# Fetch turn
|
||||
assert client.get(f"/api/v1/chat/turns/{turn_id}").status_code == 200
|
||||
|
||||
# List turns under session
|
||||
resp = client.get(f"/api/v1/chat/sessions/{session_id}/turns")
|
||||
assert resp.status_code == 200
|
||||
assert any(t["id"] == turn_id for t in resp.json())
|
||||
|
||||
# Insert retrievals
|
||||
retrievals_payload = {
|
||||
"retrievals": [
|
||||
{"item_type": "METRIC", "item_id": "metric_foo", "used_in_sql": True, "rank_no": 1},
|
||||
{"item_type": "SNIPPET", "item_id": "snpt_bar", "similarity_score": 0.77, "rank_no": 2},
|
||||
]
|
||||
}
|
||||
resp = client.post(f"/api/v1/chat/turns/{turn_id}/retrievals", json=retrievals_payload)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["inserted"] == 2
|
||||
|
||||
# List retrievals
|
||||
resp = client.get(f"/api/v1/chat/turns/{turn_id}/retrievals")
|
||||
assert resp.status_code == 200
|
||||
items = resp.json()
|
||||
assert len(items) == 2
|
||||
assert {item["item_type"] for item in items} == {"METRIC", "SNIPPET"}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import pytest as _pytest
|
||||
|
||||
raise SystemExit(_pytest.main([__file__]))
|
||||
Reference in New Issue
Block a user