"""Integration tests for the chat session, turn and retrieval HTTP endpoints, run against MySQL."""
from __future__ import annotations
|
|
|
|
import os
|
|
import random
|
|
from pathlib import Path
|
|
from typing import Generator, List
|
|
import sys
|
|
|
|
import pytest
|
|
from fastapi.testclient import TestClient
|
|
from sqlalchemy import text
|
|
from sqlalchemy.exc import SQLAlchemyError
|
|
|
|
# Ensure the project root is importable when running directly via python.
# ROOT is the parent of the directory that contains this file.
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    # Insert at the front so the project's own packages are found first.
    sys.path.insert(0, str(ROOT))
|
|
|
|
from app import db
|
|
from app.main import create_app
|
|
|
|
|
|
# User id attached to every session these tests create; the fixture teardown
# deletes rows by this id, so it should not collide with a real user.
TEST_USER_ID = 872341

# Relative path to the chat schema SQL file.
# NOTE(review): only the commented-out schema bootstrap in the `client`
# fixture appears to need this; it is unused in the visible code.
SCHEMA_PATH = Path("file/tableschema/chat.sql")

# Fallback connection URL used when TEST_DATABASE_URL is not set.
DEFAULT_MYSQL_URL = "mysql+pymysql://root:12345678@127.0.0.1:3306/data-ge?charset=utf8mb4"
|
|
|
|
|
|
def _delete_test_rows(engine) -> None:
    """Delete every chat row owned by TEST_USER_ID in a single transaction.

    Args:
        engine: The SQLAlchemy engine returned by ``db.get_engine()``.

    Rows are removed child-first (retrievals, then turns, then sessions)
    because each statement selects its targets through the parent tables.
    """
    params = {"uid": TEST_USER_ID}
    with engine.begin() as conn:
        # remove retrievals and turns tied to test sessions
        conn.execute(
            text(
                """
                DELETE FROM chat_turn_retrieval
                WHERE turn_id IN (
                    SELECT id FROM chat_turn WHERE session_id IN (SELECT id FROM chat_session WHERE user_id=:uid)
                )
                """
            ),
            params,
        )
        conn.execute(
            text("DELETE FROM chat_turn WHERE session_id IN (SELECT id FROM chat_session WHERE user_id=:uid)"),
            params,
        )
        conn.execute(text("DELETE FROM chat_session WHERE user_id=:uid"), params)


@pytest.fixture(scope="module")
def client() -> Generator[TestClient, None, None]:
    """Yield a TestClient for the app, bound to the MySQL test database.

    The URL comes from TEST_DATABASE_URL (falling back to DEFAULT_MYSQL_URL)
    and is exported as DATABASE_URL before the engine is created. If a
    ``SELECT 1`` probe fails with SQLAlchemyError, the tests using this
    fixture are skipped.

    Teardown deletes all rows created for TEST_USER_ID, restores the previous
    DATABASE_URL value and clears the engine cache, even if skipping or
    client shutdown raises.

    Yields:
        A started ``TestClient`` wrapping ``create_app()``.
    """
    mysql_url = os.getenv("TEST_DATABASE_URL", DEFAULT_MYSQL_URL)
    previous_url = os.environ.get("DATABASE_URL")
    os.environ["DATABASE_URL"] = mysql_url
    # Drop any cached engine so get_engine() is re-evaluated after the
    # DATABASE_URL change above.
    db.get_engine.cache_clear()
    try:
        engine = db.get_engine()
        try:
            # Quick connectivity check
            with engine.connect() as conn:
                conn.execute(text("SELECT 1"))
        except SQLAlchemyError:
            pytest.skip(f"Cannot connect to MySQL at {mysql_url}")

        # NOTE(review): the original has `_ensure_chat_schema(engine)` commented
        # out here, so the chat tables are presumably expected to exist already.

        app = create_app()
        try:
            with TestClient(app) as test_client:
                yield test_client
        finally:
            # cleanup test artifacts
            _delete_test_rows(engine)
    finally:
        # Do not leak this module's database settings into other test modules.
        if previous_url is None:
            os.environ.pop("DATABASE_URL", None)
        else:
            os.environ["DATABASE_URL"] = previous_url
        db.get_engine.cache_clear()
|
|
|
|
|
|
def test_session_lifecycle_mysql(client: TestClient) -> None:
    """Walk one chat session through create, get, list, update and close."""
    # Create a session
    resp = client.post("/api/v1/chat/sessions", json={"user_id": TEST_USER_ID})
    assert resp.status_code == 200, resp.text
    session = resp.json()
    session_id = session["id"]
    assert session["status"] == "OPEN"

    # Get session
    resp = client.get(f"/api/v1/chat/sessions/{session_id}")
    assert resp.status_code == 200, resp.text

    # List sessions (filter by user)
    resp = client.get("/api/v1/chat/sessions", params={"user_id": TEST_USER_ID})
    assert resp.status_code == 200, resp.text
    assert any(item["id"] == session_id for item in resp.json())

    # Update status
    resp = client.post(f"/api/v1/chat/sessions/{session_id}/update", json={"status": "PAUSED"})
    assert resp.status_code == 200, resp.text
    assert resp.json()["status"] == "PAUSED"

    # Close session
    resp = client.post(f"/api/v1/chat/sessions/{session_id}/close")
    assert resp.status_code == 200, resp.text
    assert resp.json()["status"] == "CLOSED"
|
|
|
|
|
|
def test_turns_and_retrievals_mysql(client: TestClient) -> None:
    """Create a turn in a new session, then attach and list its retrievals."""
    # Check the create call before indexing into the body, so a failed create
    # reports the response text instead of a KeyError.
    resp = client.post("/api/v1/chat/sessions", json={"user_id": TEST_USER_ID})
    assert resp.status_code == 200, resp.text
    session_id = resp.json()["id"]

    turn_payload = {
        "user_id": TEST_USER_ID,
        "user_query": "展示昨天订单GMV",
        "intent": "METRIC_QUERY",
        "ast_json": {"select": ["gmv"], "where": {"dt": "yesterday"}},
        "main_metric_ids": [random.randint(1000, 9999)],
        "created_metric_ids": [],
    }
    resp = client.post(f"/api/v1/chat/sessions/{session_id}/turns", json=turn_payload)
    assert resp.status_code == 200, resp.text
    turn = resp.json()
    turn_id = turn["id"]
    # First turn of a freshly created session.
    assert turn["turn_no"] == 1

    # Fetch turn
    resp = client.get(f"/api/v1/chat/turns/{turn_id}")
    assert resp.status_code == 200, resp.text

    # List turns under session
    resp = client.get(f"/api/v1/chat/sessions/{session_id}/turns")
    assert resp.status_code == 200, resp.text
    assert any(t["id"] == turn_id for t in resp.json())

    # Insert retrievals
    retrievals_payload = {
        "retrievals": [
            {"item_type": "METRIC", "item_id": "metric_foo", "used_in_sql": True, "rank_no": 1},
            {"item_type": "SNIPPET", "item_id": "snpt_bar", "similarity_score": 0.77, "rank_no": 2},
        ]
    }
    resp = client.post(f"/api/v1/chat/turns/{turn_id}/retrievals", json=retrievals_payload)
    assert resp.status_code == 200, resp.text
    assert resp.json()["inserted"] == 2

    # List retrievals
    resp = client.get(f"/api/v1/chat/turns/{turn_id}/retrievals")
    assert resp.status_code == 200, resp.text
    items = resp.json()
    assert len(items) == 2
    assert {item["item_type"] for item in items} == {"METRIC", "SNIPPET"}
|
|
|
|
|
|
if __name__ == "__main__":
    # Allow `python <this file>` to run just this module's tests; pytest's
    # return value becomes the process exit status.
    import pytest as _pytest

    raise SystemExit(_pytest.main([__file__]))
|