会话及轮次等相关api

This commit is contained in:
zhaoawd
2025-12-08 23:15:04 +08:00
parent f261121845
commit 509dae3270
6 changed files with 532 additions and 0 deletions

142
test/test_chat_api_mysql.py Normal file
View File

@ -0,0 +1,142 @@
from __future__ import annotations
import os
import random
from pathlib import Path
from typing import Generator, List
import sys
import pytest
from fastapi.testclient import TestClient
from sqlalchemy import text
from sqlalchemy.exc import SQLAlchemyError
# Ensure the project root is importable when running directly via python.
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
sys.path.insert(0, str(ROOT))
from app import db
from app.main import create_app
TEST_USER_ID = 872341
SCHEMA_PATH = Path("file/tableschema/chat.sql")
DEFAULT_MYSQL_URL = "mysql+pymysql://root:12345678@127.0.0.1:3306/data-ge?charset=utf8mb4"
@pytest.fixture(scope="module")
def client() -> Generator[TestClient, None, None]:
mysql_url = os.getenv("TEST_DATABASE_URL", DEFAULT_MYSQL_URL)
os.environ["DATABASE_URL"] = mysql_url
db.get_engine.cache_clear()
engine = db.get_engine()
try:
# Quick connectivity check
with engine.connect() as conn:
conn.execute(text("SELECT 1"))
except SQLAlchemyError:
pytest.skip(f"Cannot connect to MySQL at {mysql_url}")
#_ensure_chat_schema(engine)
app = create_app()
with TestClient(app) as test_client:
yield test_client
# cleanup test artifacts
with engine.begin() as conn:
# remove retrievals and turns tied to test sessions
conn.execute(
text(
"""
DELETE FROM chat_turn_retrieval
WHERE turn_id IN (
SELECT id FROM chat_turn WHERE session_id IN (SELECT id FROM chat_session WHERE user_id=:uid)
)
"""
),
{"uid": TEST_USER_ID},
)
conn.execute(
text("DELETE FROM chat_turn WHERE session_id IN (SELECT id FROM chat_session WHERE user_id=:uid)"),
{"uid": TEST_USER_ID},
)
conn.execute(text("DELETE FROM chat_session WHERE user_id=:uid"), {"uid": TEST_USER_ID})
db.get_engine.cache_clear()
def test_session_lifecycle_mysql(client: TestClient) -> None:
# Create a session
resp = client.post("/api/v1/chat/sessions", json={"user_id": TEST_USER_ID})
assert resp.status_code == 200, resp.text
session = resp.json()
session_id = session["id"]
assert session["status"] == "OPEN"
# Get session
assert client.get(f"/api/v1/chat/sessions/{session_id}").status_code == 200
# List sessions (filter by user)
resp = client.get(f"/api/v1/chat/sessions", params={"user_id": TEST_USER_ID})
assert resp.status_code == 200
assert any(item["id"] == session_id for item in resp.json())
# Update status
resp = client.post(f"/api/v1/chat/sessions/{session_id}/update", json={"status": "PAUSED"})
assert resp.status_code == 200
assert resp.json()["status"] == "PAUSED"
# Close session
resp = client.post(f"/api/v1/chat/sessions/{session_id}/close")
assert resp.status_code == 200
assert resp.json()["status"] == "CLOSED"
def test_turns_and_retrievals_mysql(client: TestClient) -> None:
session_id = client.post("/api/v1/chat/sessions", json={"user_id": TEST_USER_ID}).json()["id"]
turn_payload = {
"user_id": TEST_USER_ID,
"user_query": "展示昨天订单GMV",
"intent": "METRIC_QUERY",
"ast_json": {"select": ["gmv"], "where": {"dt": "yesterday"}},
"main_metric_ids": [random.randint(1000, 9999)],
"created_metric_ids": [],
}
resp = client.post(f"/api/v1/chat/sessions/{session_id}/turns", json=turn_payload)
assert resp.status_code == 200, resp.text
turn = resp.json()
turn_id = turn["id"]
assert turn["turn_no"] == 1
# Fetch turn
assert client.get(f"/api/v1/chat/turns/{turn_id}").status_code == 200
# List turns under session
resp = client.get(f"/api/v1/chat/sessions/{session_id}/turns")
assert resp.status_code == 200
assert any(t["id"] == turn_id for t in resp.json())
# Insert retrievals
retrievals_payload = {
"retrievals": [
{"item_type": "METRIC", "item_id": "metric_foo", "used_in_sql": True, "rank_no": 1},
{"item_type": "SNIPPET", "item_id": "snpt_bar", "similarity_score": 0.77, "rank_no": 2},
]
}
resp = client.post(f"/api/v1/chat/turns/{turn_id}/retrievals", json=retrievals_payload)
assert resp.status_code == 200
assert resp.json()["inserted"] == 2
# List retrievals
resp = client.get(f"/api/v1/chat/turns/{turn_id}/retrievals")
assert resp.status_code == 200
items = resp.json()
assert len(items) == 2
assert {item["item_type"] for item in items} == {"METRIC", "SNIPPET"}
if __name__ == "__main__":
import pytest as _pytest
raise SystemExit(_pytest.main([__file__]))