会话及轮次等相关api
This commit is contained in:
102
app/routers/chat.py
Normal file
102
app/routers/chat.py
Normal file
@ -0,0 +1,102 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any, List, Optional
|
||||||
|
|
||||||
|
from fastapi import APIRouter, HTTPException, Query
|
||||||
|
|
||||||
|
from app.schemas.chat import (
|
||||||
|
ChatSessionCreate,
|
||||||
|
ChatSessionUpdate,
|
||||||
|
ChatTurnCreate,
|
||||||
|
ChatTurnRetrievalBatch,
|
||||||
|
)
|
||||||
|
from app.services import metric_store
|
||||||
|
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api/v1/chat", tags=["chat"])
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/sessions")
|
||||||
|
def create_session(payload: ChatSessionCreate) -> Any:
|
||||||
|
"""Create a chat session."""
|
||||||
|
return metric_store.create_chat_session(payload)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/sessions/{session_id}/update")
|
||||||
|
def update_session(session_id: int, payload: ChatSessionUpdate) -> Any:
|
||||||
|
try:
|
||||||
|
return metric_store.update_chat_session(session_id, payload)
|
||||||
|
except KeyError:
|
||||||
|
raise HTTPException(status_code=404, detail="Session not found")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/sessions/{session_id}/close")
|
||||||
|
def close_session(session_id: int) -> Any:
|
||||||
|
"""Close a chat session and stamp end_time."""
|
||||||
|
try:
|
||||||
|
return metric_store.close_chat_session(session_id)
|
||||||
|
except KeyError:
|
||||||
|
raise HTTPException(status_code=404, detail="Session not found")
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/sessions/{session_id}")
|
||||||
|
def get_session(session_id: int) -> Any:
|
||||||
|
"""Fetch one session."""
|
||||||
|
session = metric_store.get_chat_session(session_id)
|
||||||
|
if not session:
|
||||||
|
raise HTTPException(status_code=404, detail="Session not found")
|
||||||
|
return session
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/sessions")
|
||||||
|
def list_sessions(
|
||||||
|
user_id: Optional[int] = None,
|
||||||
|
status: Optional[str] = None,
|
||||||
|
start_from: Optional[datetime] = Query(None, description="Filter by start time lower bound."),
|
||||||
|
start_to: Optional[datetime] = Query(None, description="Filter by start time upper bound."),
|
||||||
|
limit: int = Query(50, ge=1, le=500),
|
||||||
|
offset: int = Query(0, ge=0),
|
||||||
|
) -> List[Any]:
|
||||||
|
return metric_store.list_chat_sessions(
|
||||||
|
user_id=user_id,
|
||||||
|
status=status,
|
||||||
|
start_from=start_from,
|
||||||
|
start_to=start_to,
|
||||||
|
limit=limit,
|
||||||
|
offset=offset,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/sessions/{session_id}/turns")
|
||||||
|
def create_turn(session_id: int, payload: ChatTurnCreate) -> Any:
|
||||||
|
"""Create a turn under a session."""
|
||||||
|
try:
|
||||||
|
return metric_store.create_chat_turn(session_id, payload)
|
||||||
|
except Exception as exc:
|
||||||
|
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/sessions/{session_id}/turns")
|
||||||
|
def list_turns(session_id: int) -> List[Any]:
|
||||||
|
return metric_store.list_chat_turns(session_id)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/turns/{turn_id}")
|
||||||
|
def get_turn(turn_id: int) -> Any:
|
||||||
|
turn = metric_store.get_chat_turn(turn_id)
|
||||||
|
if not turn:
|
||||||
|
raise HTTPException(status_code=404, detail="Turn not found")
|
||||||
|
return turn
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/turns/{turn_id}/retrievals")
|
||||||
|
def write_retrievals(turn_id: int, payload: ChatTurnRetrievalBatch) -> Any:
|
||||||
|
"""Batch write retrieval records for a turn."""
|
||||||
|
count = metric_store.create_retrievals(turn_id, payload.retrievals)
|
||||||
|
return {"turn_id": turn_id, "inserted": count}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/turns/{turn_id}/retrievals")
|
||||||
|
def list_retrievals(turn_id: int) -> List[Any]:
|
||||||
|
return metric_store.list_retrievals(turn_id)
|
||||||
53
app/schemas/chat.py
Normal file
53
app/schemas/chat.py
Normal file
@ -0,0 +1,53 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any, List, Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class ChatSessionCreate(BaseModel):
|
||||||
|
"""Create a chat session to group multiple turns for a user."""
|
||||||
|
user_id: int = Field(..., description="User ID owning the session.")
|
||||||
|
session_uuid: Optional[str] = Field(None, description="Optional externally provided UUID.")
|
||||||
|
status: Optional[str] = Field("OPEN", description="Session status, default OPEN.")
|
||||||
|
end_time: Optional[datetime] = Field(None, description="Optional end time.")
|
||||||
|
ext_context: Optional[dict[str, Any]] = Field(None, description="Arbitrary business context.")
|
||||||
|
|
||||||
|
|
||||||
|
class ChatSessionUpdate(BaseModel):
|
||||||
|
"""Partial update for a chat session."""
|
||||||
|
status: Optional[str] = Field(None, description="New session status.")
|
||||||
|
end_time: Optional[datetime] = Field(None, description="Close time override.")
|
||||||
|
last_turn_id: Optional[int] = Field(None, description="Pointer to last chat turn.")
|
||||||
|
ext_context: Optional[dict[str, Any]] = Field(None, description="Context patch.")
|
||||||
|
|
||||||
|
|
||||||
|
class ChatTurnCreate(BaseModel):
|
||||||
|
"""Create a single chat turn with intent/SQL context."""
|
||||||
|
user_id: int = Field(..., description="User ID for this turn.")
|
||||||
|
user_query: str = Field(..., description="Raw user query content.")
|
||||||
|
intent: Optional[str] = Field(None, description="Intent tag such as METRIC_QUERY.")
|
||||||
|
ast_json: Optional[dict[str, Any]] = Field(None, description="Parsed AST payload.")
|
||||||
|
generated_sql: Optional[str] = Field(None, description="Final SQL text, if generated.")
|
||||||
|
sql_status: Optional[str] = Field(None, description="SQL generation/execution status.")
|
||||||
|
error_msg: Optional[str] = Field(None, description="Error message when SQL failed.")
|
||||||
|
main_metric_ids: Optional[List[int]] = Field(None, description="Metric IDs referenced in this turn.")
|
||||||
|
created_metric_ids: Optional[List[int]] = Field(None, description="Metric IDs created in this turn.")
|
||||||
|
end_time: Optional[datetime] = Field(None, description="Turn end time.")
|
||||||
|
|
||||||
|
|
||||||
|
class ChatTurnRetrievalItem(BaseModel):
|
||||||
|
"""Record of one retrieved item contributing to a turn."""
|
||||||
|
item_type: str = Field(..., description="METRIC/SNIPPET/CHAT etc.")
|
||||||
|
item_id: str = Field(..., description="Identifier such as metric_id or snippet_id.")
|
||||||
|
item_extra: Optional[dict[str, Any]] = Field(None, description="Additional context like column name.")
|
||||||
|
similarity_score: Optional[float] = Field(None, description="Similarity score.")
|
||||||
|
rank_no: Optional[int] = Field(None, description="Ranking position.")
|
||||||
|
used_in_reasoning: Optional[bool] = Field(False, description="Flag if used in reasoning.")
|
||||||
|
used_in_sql: Optional[bool] = Field(False, description="Flag if used in final SQL.")
|
||||||
|
|
||||||
|
|
||||||
|
class ChatTurnRetrievalBatch(BaseModel):
|
||||||
|
"""Batch insert wrapper for retrieval records."""
|
||||||
|
retrievals: List[ChatTurnRetrievalItem]
|
||||||
49
doc/会话api.md
Normal file
49
doc/会话api.md
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
# 创建会话
|
||||||
|
curl -X POST "/api/v1/chat/sessions" \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d "{\"user_id\": $CHAT_USER_ID}"
|
||||||
|
|
||||||
|
# 获取会话
|
||||||
|
curl "/api/v1/chat/sessions/{session_id}"
|
||||||
|
|
||||||
|
# 按用户列出会话
|
||||||
|
curl "/api/v1/chat/sessions?user_id=$CHAT_USER_ID"
|
||||||
|
|
||||||
|
# 更新会话状态
|
||||||
|
curl -X POST "/api/v1/chat/sessions/{session_id}/update" \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d '{"status":"PAUSED"}'
|
||||||
|
|
||||||
|
# 关闭会话
|
||||||
|
curl -X POST "/api/v1/chat/sessions/{session_id}/close"
|
||||||
|
|
||||||
|
# 创建对话轮次
|
||||||
|
curl -X POST "/api/v1/chat/sessions/{session_id}/turns" \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d '{
|
||||||
|
"user_id": '"$CHAT_USER_ID"',
|
||||||
|
"user_query": "展示昨天订单GMV",
|
||||||
|
"intent": "METRIC_QUERY",
|
||||||
|
"ast_json": {"select":["gmv"],"where":{"dt":"yesterday"}},
|
||||||
|
"main_metric_ids": [1234],
|
||||||
|
"created_metric_ids": []
|
||||||
|
}'
|
||||||
|
|
||||||
|
# 获取单条对话轮次
|
||||||
|
curl "/api/v1/chat/turns/{turn_id}"
|
||||||
|
|
||||||
|
# 列出会话下的轮次
|
||||||
|
curl "/api/v1/chat/sessions/{session_id}/turns"
|
||||||
|
|
||||||
|
# 写入检索结果
|
||||||
|
curl -X POST "/api/v1/chat/turns/{turn_id}/retrievals" \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d '{
|
||||||
|
"retrievals": [
|
||||||
|
{"item_type":"METRIC","item_id":"metric_foo","used_in_sql":true,"rank_no":1},
|
||||||
|
{"item_type":"SNIPPET","item_id":"snpt_bar","similarity_score":0.77,"rank_no":2}
|
||||||
|
]
|
||||||
|
}'
|
||||||
|
|
||||||
|
# 列出轮次的检索结果
|
||||||
|
curl "/api/v1/chat/turns/{turn_id}/retrievals"
|
||||||
83
doc/指标生成.md
Normal file
83
doc/指标生成.md
Normal file
@ -0,0 +1,83 @@
|
|||||||
|
某个用户的一句问话 → 解析成某轮 chat_turn → 这轮用了哪些指标/知识/会话(chat_turn_retrieval) →
|
||||||
|
是否产生了新的指标(metric_def) →
|
||||||
|
是否触发了指标调度运行(metric_job_run.turn_id) →
|
||||||
|
最终产生了哪些指标结果(metric_result.metric_id + stat_time)。
|
||||||
|
|
||||||
|
会话域
|
||||||
|
schema
|
||||||
|
会话表 chat_session
|
||||||
|
|
||||||
|
会话轮次表 chat_turn
|
||||||
|
|
||||||
|
会话轮次检索关联表 chat_turn_retrieval
|
||||||
|
|
||||||
|
|
||||||
|
API
|
||||||
|
1. 创建会话
|
||||||
|
POST /api/v1/chat/sessions
|
||||||
|
2. 更新会话轮次
|
||||||
|
POST /api/v1/chat/sessions/{session_id}/update
|
||||||
|
3. 结束会话
|
||||||
|
POST /api/v1/chat/sessions/{session_id}/close
|
||||||
|
4. 查询会话
|
||||||
|
GET /api/v1/chat/sessions/{session_id}
|
||||||
|
5. 会话列表查询(按用户、时间)
|
||||||
|
GET /api/v1/chat/sessions
|
||||||
|
6. 创建问答轮次(用户发起 query)
|
||||||
|
POST /api/v1/chat/sessions/{session_id}/turns
|
||||||
|
7. 查询某会话的所有轮次
|
||||||
|
GET /api/v1/chat/sessions/{session_id}/turns
|
||||||
|
8. 查看单轮问答详情
|
||||||
|
GET /api/v1/chat/turns/{turn_id}
|
||||||
|
9. 批量写入某轮的检索结果
|
||||||
|
POST /api/v1/chat/turns/{turn_id}/retrievals
|
||||||
|
10. 查询某轮的检索记录
|
||||||
|
GET /api/v1/chat/turns/{turn_id}/retrievals
|
||||||
|
11. 更新某轮的检索记录(in future)
|
||||||
|
POST /api/v1/chat/turns/{turn_id}/retrievals/update
|
||||||
|
|
||||||
|
元数据域
|
||||||
|
schema
|
||||||
|
指标定义表 metric_def
|
||||||
|
|
||||||
|
|
||||||
|
API
|
||||||
|
12. 创建指标(来自问答或传统定义)
|
||||||
|
POST /api/v1/metrics
|
||||||
|
13. 更新指标
|
||||||
|
POST /api/v1/metrics/{id}
|
||||||
|
14. 获取指标详情
|
||||||
|
GET /api/v1/metrics
|
||||||
|
|
||||||
|
执行调度域(暂定airflow)
|
||||||
|
schema
|
||||||
|
指标调度配置表 metric_schedule
|
||||||
|
|
||||||
|
调度运行记录表 metric_job_run
|
||||||
|
|
||||||
|
API
|
||||||
|
1. 创建调度配置
|
||||||
|
POST /api/v1/metric-schedules
|
||||||
|
2. 更新调度配置
|
||||||
|
POST /api/v1/metric-schedules/{id}
|
||||||
|
3. 查询指标调度配置详情
|
||||||
|
GET /api/v1/metrics/{metric_id}/schedules
|
||||||
|
4. 手动触发一次指标运行(例如来自问数)
|
||||||
|
POST /api/v1/metric-runs/trigger
|
||||||
|
5. 查询运行记录列表
|
||||||
|
GET /api/v1/metric-runs
|
||||||
|
6. 查询单次运行详情
|
||||||
|
GET /api/metric-runs/{run_id}
|
||||||
|
|
||||||
|
数据域
|
||||||
|
schema
|
||||||
|
指标结果表(纵表)metric_result
|
||||||
|
|
||||||
|
|
||||||
|
API
|
||||||
|
1. 查询指标结果(按时间段 & 维度)
|
||||||
|
GET /api/metric-results
|
||||||
|
2. 单点查询(最新值)
|
||||||
|
GET /api/metric-results/latest
|
||||||
|
3. 批量写入指标结果
|
||||||
|
POST /api/v1/metric-results/{metrics_id}
|
||||||
103
file/tableschema/chat.sql
Normal file
103
file/tableschema/chat.sql
Normal file
@ -0,0 +1,103 @@
|
|||||||
|
CREATE TABLE IF NOT EXISTS chat_session (
|
||||||
|
id BIGINT AUTO_INCREMENT PRIMARY KEY,
|
||||||
|
user_id BIGINT NOT NULL,
|
||||||
|
session_uuid CHAR(36) NOT NULL, -- 可用于对外展示的ID(UUID)
|
||||||
|
end_time DATETIME NULL,
|
||||||
|
status VARCHAR(16) NOT NULL DEFAULT 'OPEN', -- OPEN/CLOSED/ABANDONED
|
||||||
|
last_turn_id BIGINT NULL, -- 指向 chat_turn.id
|
||||||
|
ext_context JSON NULL, -- 业务上下文
|
||||||
|
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
|
||||||
|
UNIQUE KEY uk_session_uuid (session_uuid),
|
||||||
|
KEY idx_user_time (user_id, created_at),
|
||||||
|
KEY idx_status_time (status, created_at),
|
||||||
|
KEY idx_last_turn (last_turn_id)
|
||||||
|
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS chat_turn (
|
||||||
|
id BIGINT AUTO_INCREMENT,
|
||||||
|
session_id BIGINT NOT NULL, -- 关联 chat_session.id
|
||||||
|
turn_no INT NOT NULL, -- 会话内轮次序号(1,2,3...)
|
||||||
|
user_id BIGINT NOT NULL,
|
||||||
|
|
||||||
|
user_query TEXT NOT NULL, -- 原始用户问句
|
||||||
|
intent VARCHAR(64) NULL, -- METRIC_QUERY/METRIC_EXPLAIN 等
|
||||||
|
ast_json JSON NULL, -- 解析出来的 AST
|
||||||
|
|
||||||
|
generated_sql MEDIUMTEXT NULL, -- 生成的最终SQL
|
||||||
|
sql_status VARCHAR(32) NULL, -- SUCCESS/FAILED/SKIPPED
|
||||||
|
error_msg TEXT NULL, -- SQL生成/执行错误信息
|
||||||
|
|
||||||
|
main_metric_ids JSON NULL, -- 本轮涉及的指标ID列表
|
||||||
|
created_metric_ids JSON NULL, -- 本轮新建指标ID列表
|
||||||
|
|
||||||
|
end_time DATETIME NULL,
|
||||||
|
|
||||||
|
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
|
||||||
|
-- 主键改为联合主键,必须包含 created_at
|
||||||
|
PRIMARY KEY (id, created_at),
|
||||||
|
KEY idx_session_turn (session_id, turn_no),
|
||||||
|
KEY idx_session_time (session_id, created_at),
|
||||||
|
KEY idx_intent_time (intent, created_at),
|
||||||
|
KEY idx_user_time (user_id, created_at)
|
||||||
|
)
|
||||||
|
ENGINE=InnoDB
|
||||||
|
DEFAULT CHARSET=utf8mb4
|
||||||
|
PARTITION BY RANGE COLUMNS(created_at) (
|
||||||
|
-- 历史数据分区(根据实际需求调整)
|
||||||
|
PARTITION p202511 VALUES LESS THAN ('2025-12-01'),
|
||||||
|
PARTITION p202512 VALUES LESS THAN ('2026-01-01'),
|
||||||
|
-- 2026年按月分区
|
||||||
|
PARTITION p202601 VALUES LESS THAN ('2026-02-01'),
|
||||||
|
PARTITION p202602 VALUES LESS THAN ('2026-03-01'),
|
||||||
|
PARTITION p202603 VALUES LESS THAN ('2026-04-01'),
|
||||||
|
PARTITION p202604 VALUES LESS THAN ('2026-05-01'),
|
||||||
|
PARTITION p202605 VALUES LESS THAN ('2026-06-01'),
|
||||||
|
PARTITION p202606 VALUES LESS THAN ('2026-07-01'),
|
||||||
|
-- ... 可以预建几个月 ...
|
||||||
|
|
||||||
|
-- 兜底分区,存放未来的数据,防止插入报错
|
||||||
|
PARTITION p_future VALUES LESS THAN (MAXVALUE)
|
||||||
|
);
|
||||||
|
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS chat_turn_retrieval (
|
||||||
|
id BIGINT AUTO_INCREMENT,
|
||||||
|
turn_id BIGINT NOT NULL, -- 关联 qa_turn.id
|
||||||
|
|
||||||
|
item_type VARCHAR(32) NOT NULL, -- METRIC/SNIPPET/CHAT
|
||||||
|
item_id VARCHAR(128) NOT NULL, -- metric_id/snippet_id/table_name 等
|
||||||
|
item_extra JSON NULL, -- 附加信息,如字段名等
|
||||||
|
|
||||||
|
similarity_score DECIMAL(10,6) NULL, -- 相似度
|
||||||
|
rank_no INT NULL, -- 检索排名
|
||||||
|
used_in_reasoning TINYINT(1) NOT NULL DEFAULT 0, -- 是否参与推理
|
||||||
|
used_in_sql TINYINT(1) NOT NULL DEFAULT 0, -- 是否影响最终SQL
|
||||||
|
|
||||||
|
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
|
||||||
|
-- 主键改为联合主键,必须包含 created_at
|
||||||
|
PRIMARY KEY (id, created_at),
|
||||||
|
KEY idx_turn (turn_id),
|
||||||
|
KEY idx_turn_type (turn_id, item_type),
|
||||||
|
KEY idx_item (item_type, item_id)
|
||||||
|
)
|
||||||
|
ENGINE=InnoDB
|
||||||
|
DEFAULT CHARSET=utf8mb4
|
||||||
|
PARTITION BY RANGE COLUMNS(created_at) (
|
||||||
|
-- 历史数据分区(根据实际需求调整)
|
||||||
|
PARTITION p202511 VALUES LESS THAN ('2025-12-01'),
|
||||||
|
PARTITION p202512 VALUES LESS THAN ('2026-01-01'),
|
||||||
|
-- 2026年按月分区
|
||||||
|
PARTITION p202601 VALUES LESS THAN ('2026-02-01'),
|
||||||
|
PARTITION p202602 VALUES LESS THAN ('2026-03-01'),
|
||||||
|
PARTITION p202603 VALUES LESS THAN ('2026-04-01'),
|
||||||
|
PARTITION p202604 VALUES LESS THAN ('2026-05-01'),
|
||||||
|
PARTITION p202605 VALUES LESS THAN ('2026-06-01'),
|
||||||
|
PARTITION p202606 VALUES LESS THAN ('2026-07-01'),
|
||||||
|
-- ... 可以预建几个月 ...
|
||||||
|
|
||||||
|
-- 兜底分区,存放未来的数据,防止插入报错
|
||||||
|
PARTITION p_future VALUES LESS THAN (MAXVALUE)
|
||||||
|
);
|
||||||
142
test/test_chat_api_mysql.py
Normal file
142
test/test_chat_api_mysql.py
Normal file
@ -0,0 +1,142 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Generator, List
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
from sqlalchemy import text
|
||||||
|
from sqlalchemy.exc import SQLAlchemyError
|
||||||
|
|
||||||
|
# Ensure the project root is importable when running directly via python.
|
||||||
|
ROOT = Path(__file__).resolve().parents[1]
|
||||||
|
if str(ROOT) not in sys.path:
|
||||||
|
sys.path.insert(0, str(ROOT))
|
||||||
|
|
||||||
|
from app import db
|
||||||
|
from app.main import create_app
|
||||||
|
|
||||||
|
|
||||||
|
TEST_USER_ID = 872341
|
||||||
|
SCHEMA_PATH = Path("file/tableschema/chat.sql")
|
||||||
|
DEFAULT_MYSQL_URL = "mysql+pymysql://root:12345678@127.0.0.1:3306/data-ge?charset=utf8mb4"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def client() -> Generator[TestClient, None, None]:
|
||||||
|
mysql_url = os.getenv("TEST_DATABASE_URL", DEFAULT_MYSQL_URL)
|
||||||
|
os.environ["DATABASE_URL"] = mysql_url
|
||||||
|
db.get_engine.cache_clear()
|
||||||
|
engine = db.get_engine()
|
||||||
|
try:
|
||||||
|
# Quick connectivity check
|
||||||
|
with engine.connect() as conn:
|
||||||
|
conn.execute(text("SELECT 1"))
|
||||||
|
except SQLAlchemyError:
|
||||||
|
pytest.skip(f"Cannot connect to MySQL at {mysql_url}")
|
||||||
|
|
||||||
|
#_ensure_chat_schema(engine)
|
||||||
|
|
||||||
|
app = create_app()
|
||||||
|
with TestClient(app) as test_client:
|
||||||
|
yield test_client
|
||||||
|
|
||||||
|
# cleanup test artifacts
|
||||||
|
with engine.begin() as conn:
|
||||||
|
# remove retrievals and turns tied to test sessions
|
||||||
|
conn.execute(
|
||||||
|
text(
|
||||||
|
"""
|
||||||
|
DELETE FROM chat_turn_retrieval
|
||||||
|
WHERE turn_id IN (
|
||||||
|
SELECT id FROM chat_turn WHERE session_id IN (SELECT id FROM chat_session WHERE user_id=:uid)
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
{"uid": TEST_USER_ID},
|
||||||
|
)
|
||||||
|
conn.execute(
|
||||||
|
text("DELETE FROM chat_turn WHERE session_id IN (SELECT id FROM chat_session WHERE user_id=:uid)"),
|
||||||
|
{"uid": TEST_USER_ID},
|
||||||
|
)
|
||||||
|
conn.execute(text("DELETE FROM chat_session WHERE user_id=:uid"), {"uid": TEST_USER_ID})
|
||||||
|
db.get_engine.cache_clear()
|
||||||
|
|
||||||
|
|
||||||
|
def test_session_lifecycle_mysql(client: TestClient) -> None:
|
||||||
|
# Create a session
|
||||||
|
resp = client.post("/api/v1/chat/sessions", json={"user_id": TEST_USER_ID})
|
||||||
|
assert resp.status_code == 200, resp.text
|
||||||
|
session = resp.json()
|
||||||
|
session_id = session["id"]
|
||||||
|
assert session["status"] == "OPEN"
|
||||||
|
|
||||||
|
# Get session
|
||||||
|
assert client.get(f"/api/v1/chat/sessions/{session_id}").status_code == 200
|
||||||
|
|
||||||
|
# List sessions (filter by user)
|
||||||
|
resp = client.get(f"/api/v1/chat/sessions", params={"user_id": TEST_USER_ID})
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert any(item["id"] == session_id for item in resp.json())
|
||||||
|
|
||||||
|
# Update status
|
||||||
|
resp = client.post(f"/api/v1/chat/sessions/{session_id}/update", json={"status": "PAUSED"})
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert resp.json()["status"] == "PAUSED"
|
||||||
|
|
||||||
|
# Close session
|
||||||
|
resp = client.post(f"/api/v1/chat/sessions/{session_id}/close")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert resp.json()["status"] == "CLOSED"
|
||||||
|
|
||||||
|
|
||||||
|
def test_turns_and_retrievals_mysql(client: TestClient) -> None:
|
||||||
|
session_id = client.post("/api/v1/chat/sessions", json={"user_id": TEST_USER_ID}).json()["id"]
|
||||||
|
turn_payload = {
|
||||||
|
"user_id": TEST_USER_ID,
|
||||||
|
"user_query": "展示昨天订单GMV",
|
||||||
|
"intent": "METRIC_QUERY",
|
||||||
|
"ast_json": {"select": ["gmv"], "where": {"dt": "yesterday"}},
|
||||||
|
"main_metric_ids": [random.randint(1000, 9999)],
|
||||||
|
"created_metric_ids": [],
|
||||||
|
}
|
||||||
|
resp = client.post(f"/api/v1/chat/sessions/{session_id}/turns", json=turn_payload)
|
||||||
|
assert resp.status_code == 200, resp.text
|
||||||
|
turn = resp.json()
|
||||||
|
turn_id = turn["id"]
|
||||||
|
assert turn["turn_no"] == 1
|
||||||
|
|
||||||
|
# Fetch turn
|
||||||
|
assert client.get(f"/api/v1/chat/turns/{turn_id}").status_code == 200
|
||||||
|
|
||||||
|
# List turns under session
|
||||||
|
resp = client.get(f"/api/v1/chat/sessions/{session_id}/turns")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert any(t["id"] == turn_id for t in resp.json())
|
||||||
|
|
||||||
|
# Insert retrievals
|
||||||
|
retrievals_payload = {
|
||||||
|
"retrievals": [
|
||||||
|
{"item_type": "METRIC", "item_id": "metric_foo", "used_in_sql": True, "rank_no": 1},
|
||||||
|
{"item_type": "SNIPPET", "item_id": "snpt_bar", "similarity_score": 0.77, "rank_no": 2},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
resp = client.post(f"/api/v1/chat/turns/{turn_id}/retrievals", json=retrievals_payload)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert resp.json()["inserted"] == 2
|
||||||
|
|
||||||
|
# List retrievals
|
||||||
|
resp = client.get(f"/api/v1/chat/turns/{turn_id}/retrievals")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
items = resp.json()
|
||||||
|
assert len(items) == 2
|
||||||
|
assert {item["item_type"] for item in items} == {"METRIC", "SNIPPET"}
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import pytest as _pytest
|
||||||
|
|
||||||
|
raise SystemExit(_pytest.main([__file__]))
|
||||||
Reference in New Issue
Block a user