"""Integration tests for the chat session / turn / retrieval HTTP API on MySQL.

The tests run against a real MySQL database (``TEST_DATABASE_URL``, falling
back to ``DEFAULT_MYSQL_URL``) and are skipped when it cannot be reached.
Every session created here belongs to ``TEST_USER_ID``; the module-scoped
``client`` fixture deletes those sessions and their turns/retrievals on
teardown.
"""
from __future__ import annotations

import os
import random
import sys
from pathlib import Path
from typing import Generator, List

import pytest
from fastapi.testclient import TestClient
from sqlalchemy import text
from sqlalchemy.exc import SQLAlchemyError

# Ensure the project root is importable when running directly via python.
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

from app import db  # noqa: E402  (must follow the sys.path tweak above)
from app.main import create_app  # noqa: E402

# All rows written by these tests are owned by this user id; teardown keys its
# DELETE statements on it.
TEST_USER_ID = 872341
# DDL for the chat tables. NOTE: the fixture does not apply it -- the tables are
# presumably expected to exist already in the target database.
SCHEMA_PATH = Path("file/tableschema/chat.sql")
# Local-development fallback; set TEST_DATABASE_URL to target another database.
DEFAULT_MYSQL_URL = "mysql+pymysql://root:12345678@127.0.0.1:3306/data-ge?charset=utf8mb4"


def _restore_env(name: str, value: str | None) -> None:
    """Set environment variable *name* back to *value*, or unset it if None."""
    if value is None:
        os.environ.pop(name, None)
    else:
        os.environ[name] = value


def _delete_test_rows(engine) -> None:
    """Delete all chat rows tied to TEST_USER_ID in a single transaction.

    Rows are removed child-first (retrievals, then turns, then sessions),
    presumably so foreign-key constraints between the tables are respected.
    """
    params = {"uid": TEST_USER_ID}
    with engine.begin() as conn:
        # remove retrievals and turns tied to test sessions
        conn.execute(
            text(
                """
                DELETE FROM chat_turn_retrieval
                WHERE turn_id IN (
                    SELECT id FROM chat_turn WHERE session_id IN (SELECT id FROM chat_session WHERE user_id=:uid)
                )
                """
            ),
            params,
        )
        conn.execute(
            text("DELETE FROM chat_turn WHERE session_id IN (SELECT id FROM chat_session WHERE user_id=:uid)"),
            params,
        )
        conn.execute(text("DELETE FROM chat_session WHERE user_id=:uid"), params)


@pytest.fixture(scope="module")
def client() -> Generator[TestClient, None, None]:
    """Yield a TestClient for an app bound to the test MySQL database.

    For the duration of the module, ``DATABASE_URL`` is pointed at
    ``TEST_DATABASE_URL`` (or ``DEFAULT_MYSQL_URL``). The module is skipped if
    ``SELECT 1`` fails. Teardown always runs, even on skip or error. It
    deletes the rows created for ``TEST_USER_ID``, disposes the engine, clears
    the ``db.get_engine`` cache, and restores the caller's original
    ``DATABASE_URL``.
    """
    mysql_url = os.getenv("TEST_DATABASE_URL", DEFAULT_MYSQL_URL)
    previous_url = os.environ.get("DATABASE_URL")
    os.environ["DATABASE_URL"] = mysql_url
    # db.get_engine is cached (it exposes cache_clear); clear it so the engine
    # is rebuilt for the URL set above (presumably read from DATABASE_URL).
    db.get_engine.cache_clear()

    engine = None
    try:
        engine = db.get_engine()
        try:
            # Quick connectivity check
            with engine.connect() as conn:
                conn.execute(text("SELECT 1"))
        except SQLAlchemyError:
            pytest.skip(f"Cannot connect to MySQL at {mysql_url}")

        app = create_app()
        try:
            with TestClient(app) as test_client:
                yield test_client
        finally:
            # cleanup test artifacts even if the client's shutdown raised
            _delete_test_rows(engine)
    finally:
        # Runs on success, skip, and error alike so no state leaks into
        # other test modules sharing this process.
        if engine is not None:
            engine.dispose()
        db.get_engine.cache_clear()
        _restore_env("DATABASE_URL", previous_url)


def test_session_lifecycle_mysql(client: TestClient) -> None:
    """Create, fetch, list, pause, and close a chat session."""
    # Create a session
    resp = client.post("/api/v1/chat/sessions", json={"user_id": TEST_USER_ID})
    assert resp.status_code == 200, resp.text
    session = resp.json()
    session_id = session["id"]
    assert session["status"] == "OPEN"

    # Get session
    resp = client.get(f"/api/v1/chat/sessions/{session_id}")
    assert resp.status_code == 200, resp.text
    assert resp.json()["id"] == session_id

    # List sessions (filter by user)
    resp = client.get("/api/v1/chat/sessions", params={"user_id": TEST_USER_ID})
    assert resp.status_code == 200, resp.text
    assert any(item["id"] == session_id for item in resp.json())

    # Update status
    resp = client.post(f"/api/v1/chat/sessions/{session_id}/update", json={"status": "PAUSED"})
    assert resp.status_code == 200, resp.text
    assert resp.json()["status"] == "PAUSED"

    # Close session
    resp = client.post(f"/api/v1/chat/sessions/{session_id}/close")
    assert resp.status_code == 200, resp.text
    assert resp.json()["status"] == "CLOSED"


def test_turns_and_retrievals_mysql(client: TestClient) -> None:
    """Add a turn to a fresh session, then attach and list its retrievals."""
    resp = client.post("/api/v1/chat/sessions", json={"user_id": TEST_USER_ID})
    assert resp.status_code == 200, resp.text
    session_id = resp.json()["id"]

    turn_payload = {
        "user_id": TEST_USER_ID,
        "user_query": "展示昨天订单GMV",
        "intent": "METRIC_QUERY",
        "ast_json": {"select": ["gmv"], "where": {"dt": "yesterday"}},
        "main_metric_ids": [random.randint(1000, 9999)],
        "created_metric_ids": [],
    }
    resp = client.post(f"/api/v1/chat/sessions/{session_id}/turns", json=turn_payload)
    assert resp.status_code == 200, resp.text
    turn = resp.json()
    turn_id = turn["id"]
    # The session was just created, so this is its first turn.
    assert turn["turn_no"] == 1

    # Fetch turn
    resp = client.get(f"/api/v1/chat/turns/{turn_id}")
    assert resp.status_code == 200, resp.text
    assert resp.json()["id"] == turn_id

    # List turns under session
    resp = client.get(f"/api/v1/chat/sessions/{session_id}/turns")
    assert resp.status_code == 200, resp.text
    assert any(t["id"] == turn_id for t in resp.json())

    # Insert retrievals
    retrievals_payload = {
        "retrievals": [
            {"item_type": "METRIC", "item_id": "metric_foo", "used_in_sql": True, "rank_no": 1},
            {"item_type": "SNIPPET", "item_id": "snpt_bar", "similarity_score": 0.77, "rank_no": 2},
        ]
    }
    resp = client.post(f"/api/v1/chat/turns/{turn_id}/retrievals", json=retrievals_payload)
    assert resp.status_code == 200, resp.text
    assert resp.json()["inserted"] == 2

    # List retrievals
    resp = client.get(f"/api/v1/chat/turns/{turn_id}/retrievals")
    assert resp.status_code == 200, resp.text
    items = resp.json()
    assert len(items) == 2
    assert {item["item_type"] for item in items} == {"METRIC", "SNIPPET"}


if __name__ == "__main__":
    raise SystemExit(pytest.main([__file__]))