会话及轮次等相关api

This commit is contained in:
zhaoawd
2025-12-08 23:15:04 +08:00
parent f261121845
commit 509dae3270
6 changed files with 532 additions and 0 deletions

102
app/routers/chat.py Normal file
View File

@ -0,0 +1,102 @@
from __future__ import annotations
from datetime import datetime
from typing import Any, List, Optional
from fastapi import APIRouter, HTTPException, Query
from app.schemas.chat import (
ChatSessionCreate,
ChatSessionUpdate,
ChatTurnCreate,
ChatTurnRetrievalBatch,
)
from app.services import metric_store
# Router for the chat session / turn / retrieval endpoints. Every route declared
# on it is served under the /api/v1/chat prefix and grouped under the "chat" tag.
router = APIRouter(prefix="/api/v1/chat", tags=["chat"])
@router.post("/sessions")
def create_session(payload: ChatSessionCreate) -> Any:
    """Create a chat session.

    The validated request body is handed to
    ``metric_store.create_chat_session`` and whatever the store returns is
    sent back as the response body.
    """
    created = metric_store.create_chat_session(payload)
    return created
@router.post("/sessions/{session_id}/update")
def update_session(session_id: int, payload: ChatSessionUpdate) -> Any:
    """Update an existing chat session.

    Args:
        session_id: Identifier of the session, taken from the URL path.
        payload: Fields to change (see ``ChatSessionUpdate``).

    Returns:
        Whatever ``metric_store.update_chat_session`` returns.

    Raises:
        HTTPException: 404 when the store raises ``KeyError`` for the session.
    """
    try:
        return metric_store.update_chat_session(session_id, payload)
    except KeyError as exc:
        # Chain the original error (as create_turn in this module does) so the
        # underlying KeyError stays visible in tracebacks and logs.
        raise HTTPException(status_code=404, detail="Session not found") from exc
@router.post("/sessions/{session_id}/close")
def close_session(session_id: int) -> Any:
    """Close a chat session and stamp end_time.

    Args:
        session_id: Identifier of the session, taken from the URL path.

    Returns:
        Whatever ``metric_store.close_chat_session`` returns.

    Raises:
        HTTPException: 404 when the store raises ``KeyError`` for the session.
    """
    try:
        return metric_store.close_chat_session(session_id)
    except KeyError as exc:
        # Chain the original error (as create_turn in this module does) so the
        # underlying KeyError stays visible in tracebacks and logs.
        raise HTTPException(status_code=404, detail="Session not found") from exc
@router.get("/sessions/{session_id}")
def get_session(session_id: int) -> Any:
    """Fetch one session.

    Returns the value produced by ``metric_store.get_chat_session``; a falsy
    result is reported to the client as HTTP 404.
    """
    record = metric_store.get_chat_session(session_id)
    if record:
        return record
    # Any falsy lookup result is treated as "no such session".
    raise HTTPException(status_code=404, detail="Session not found")
@router.get("/sessions")
def list_sessions(
    user_id: Optional[int] = None,
    status: Optional[str] = None,
    start_from: Optional[datetime] = Query(None, description="Filter by start time lower bound."),
    start_to: Optional[datetime] = Query(None, description="Filter by start time upper bound."),
    limit: int = Query(50, ge=1, le=500),
    offset: int = Query(0, ge=0),
) -> List[Any]:
    """List chat sessions, optionally filtered and paginated.

    Args:
        user_id: Optional user filter.
        status: Optional status filter.
        start_from: Optional lower bound on the session start time.
        start_to: Optional upper bound on the session start time.
        limit: Page size, validated to lie in 1..500 (default 50).
        offset: Number of rows to skip, validated to be >= 0 (default 0).

    Returns:
        The list produced by ``metric_store.list_chat_sessions``; all query
        parameters are forwarded to it unchanged as keyword arguments.
    """
    query_args = {
        "user_id": user_id,
        "status": status,
        "start_from": start_from,
        "start_to": start_to,
        "limit": limit,
        "offset": offset,
    }
    return metric_store.list_chat_sessions(**query_args)
@router.post("/sessions/{session_id}/turns")
def create_turn(session_id: int, payload: ChatTurnCreate) -> Any:
    """Create a turn under a session.

    Delegates to ``metric_store.create_chat_turn``. Any exception raised by
    the store is converted into an HTTP 400 whose detail is the exception text.
    """
    # NOTE(review): catching Exception also maps infrastructure failures to 400
    # and exposes str(err) to the client -- confirm which exception types the
    # store raises for bad input before narrowing this.
    try:
        turn = metric_store.create_chat_turn(session_id, payload)
    except Exception as err:
        raise HTTPException(status_code=400, detail=str(err)) from err
    return turn
@router.get("/sessions/{session_id}/turns")
def list_turns(session_id: int) -> List[Any]:
    """List the turns of one session.

    Returns the result of ``metric_store.list_chat_turns`` for the session id
    taken from the URL path.
    """
    turns = metric_store.list_chat_turns(session_id)
    return turns
@router.get("/turns/{turn_id}")
def get_turn(turn_id: int) -> Any:
    """Fetch one turn.

    Returns the value produced by ``metric_store.get_chat_turn``; a falsy
    result is reported to the client as HTTP 404.
    """
    record = metric_store.get_chat_turn(turn_id)
    if record:
        return record
    # Any falsy lookup result is treated as "no such turn".
    raise HTTPException(status_code=404, detail="Turn not found")
@router.post("/turns/{turn_id}/retrievals")
def write_retrievals(turn_id: int, payload: ChatTurnRetrievalBatch) -> Any:
    """Batch write retrieval records for a turn.

    Passes ``payload.retrievals`` to ``metric_store.create_retrievals`` and
    reports the value it returns under the ``inserted`` key, next to the
    ``turn_id`` from the URL path.
    """
    return {
        "turn_id": turn_id,
        "inserted": metric_store.create_retrievals(turn_id, payload.retrievals),
    }
@router.get("/turns/{turn_id}/retrievals")
def list_retrievals(turn_id: int) -> List[Any]:
    """List the retrieval records of one turn.

    Returns the result of ``metric_store.list_retrievals`` for the turn id
    taken from the URL path.
    """
    records = metric_store.list_retrievals(turn_id)
    return records

53
app/schemas/chat.py Normal file
View File

@ -0,0 +1,53 @@
from __future__ import annotations
from datetime import datetime
from typing import Any, List, Optional
from pydantic import BaseModel, Field
class ChatSessionCreate(BaseModel):
    """Create a chat session to group multiple turns for a user."""

    # The only field a caller is required to send.
    user_id: int = Field(default=..., description="User ID owning the session.")
    # NOTE(review): chat_session.session_uuid is NOT NULL in
    # file/tableschema/chat.sql -- presumably the store generates one when this
    # is omitted; confirm.
    session_uuid: Optional[str] = Field(
        default=None,
        description="Optional externally provided UUID.",
    )
    # NOTE(review): an explicit null is accepted here although
    # chat_session.status is NOT NULL -- confirm the store coalesces None.
    status: Optional[str] = Field(
        default="OPEN",
        description="Session status, default OPEN.",
    )
    end_time: Optional[datetime] = Field(
        default=None,
        description="Optional end time.",
    )
    ext_context: Optional[dict[str, Any]] = Field(
        default=None,
        description="Arbitrary business context.",
    )
class ChatSessionUpdate(BaseModel):
    """Partial update for a chat session."""

    # Every field defaults to None, so a request may carry any subset of them.
    status: Optional[str] = Field(
        default=None,
        description="New session status.",
    )
    end_time: Optional[datetime] = Field(
        default=None,
        description="Close time override.",
    )
    last_turn_id: Optional[int] = Field(
        default=None,
        description="Pointer to last chat turn.",
    )
    ext_context: Optional[dict[str, Any]] = Field(
        default=None,
        description="Context patch.",
    )
class ChatTurnCreate(BaseModel):
    """Create a single chat turn with intent/SQL context."""

    # Required fields.
    user_id: int = Field(default=..., description="User ID for this turn.")
    user_query: str = Field(default=..., description="Raw user query content.")

    # Optional parsing / SQL-generation context.
    intent: Optional[str] = Field(
        default=None,
        description="Intent tag such as METRIC_QUERY.",
    )
    ast_json: Optional[dict[str, Any]] = Field(
        default=None,
        description="Parsed AST payload.",
    )
    generated_sql: Optional[str] = Field(
        default=None,
        description="Final SQL text, if generated.",
    )
    sql_status: Optional[str] = Field(
        default=None,
        description="SQL generation/execution status.",
    )
    error_msg: Optional[str] = Field(
        default=None,
        description="Error message when SQL failed.",
    )

    # Optional metric bookkeeping for the turn.
    main_metric_ids: Optional[List[int]] = Field(
        default=None,
        description="Metric IDs referenced in this turn.",
    )
    created_metric_ids: Optional[List[int]] = Field(
        default=None,
        description="Metric IDs created in this turn.",
    )
    end_time: Optional[datetime] = Field(
        default=None,
        description="Turn end time.",
    )
class ChatTurnRetrievalItem(BaseModel):
    """Record of one retrieved item contributing to a turn."""

    # Required identification of the retrieved item.
    item_type: str = Field(default=..., description="METRIC/SNIPPET/CHAT etc.")
    item_id: str = Field(
        default=...,
        description="Identifier such as metric_id or snippet_id.",
    )

    # Optional ranking details.
    item_extra: Optional[dict[str, Any]] = Field(
        default=None,
        description="Additional context like column name.",
    )
    similarity_score: Optional[float] = Field(
        default=None,
        description="Similarity score.",
    )
    rank_no: Optional[int] = Field(
        default=None,
        description="Ranking position.",
    )

    # NOTE(review): both flags accept an explicit null although the matching
    # chat_turn_retrieval columns are NOT NULL DEFAULT 0 -- confirm the store
    # coalesces None to False.
    used_in_reasoning: Optional[bool] = Field(
        default=False,
        description="Flag if used in reasoning.",
    )
    used_in_sql: Optional[bool] = Field(
        default=False,
        description="Flag if used in final SQL.",
    )
class ChatTurnRetrievalBatch(BaseModel):
    """Batch insert wrapper for retrieval records."""

    # Items to record in one request. The body carries no turn id: the router
    # takes it from the URL path (/turns/{turn_id}/retrievals).
    retrievals: List[ChatTurnRetrievalItem]

49
doc/会话api.md Normal file
View File

@ -0,0 +1,49 @@
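> 说明:以下示例中的 URL 只写了路径部分,执行前请在路径前补上服务地址(如 `http://<host>:<port>`),并预先设置环境变量 `CHAT_USER_ID`;`{session_id}`、`{turn_id}` 需替换为实际的 ID。
# 创建会话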
curl -X POST "/api/v1/chat/sessions" \
-H "Content-Type: application/json" \
-d "{\"user_id\": $CHAT_USER_ID}"
# 获取会话
curl "/api/v1/chat/sessions/{session_id}"
# 按用户列出会话
curl "/api/v1/chat/sessions?user_id=$CHAT_USER_ID"
# 更新会话状态
curl -X POST "/api/v1/chat/sessions/{session_id}/update" \
-H "Content-Type: application/json" \
-d '{"status":"PAUSED"}'
# 关闭会话
curl -X POST "/api/v1/chat/sessions/{session_id}/close"
# 创建对话轮次
curl -X POST "/api/v1/chat/sessions/{session_id}/turns" \
-H "Content-Type: application/json" \
-d '{
"user_id": '"$CHAT_USER_ID"',
"user_query": "展示昨天订单GMV",
"intent": "METRIC_QUERY",
"ast_json": {"select":["gmv"],"where":{"dt":"yesterday"}},
"main_metric_ids": [1234],
"created_metric_ids": []
}'
# 获取单条对话轮次
curl "/api/v1/chat/turns/{turn_id}"
# 列出会话下的轮次
curl "/api/v1/chat/sessions/{session_id}/turns"
# 写入检索结果
curl -X POST "/api/v1/chat/turns/{turn_id}/retrievals" \
-H "Content-Type: application/json" \
-d '{
"retrievals": [
{"item_type":"METRIC","item_id":"metric_foo","used_in_sql":true,"rank_no":1},
{"item_type":"SNIPPET","item_id":"snpt_bar","similarity_score":0.77,"rank_no":2}
]
}'
# 列出轮次的检索结果
curl "/api/v1/chat/turns/{turn_id}/retrievals"

83
doc/指标生成.md Normal file
View File

@ -0,0 +1,83 @@
某个用户的一句问话 → 解析成某轮 chat_turn → 这轮用了哪些指标/知识/会话(chat_turn_retrieval)
是否产生了新的指标(metric_def)
是否触发了指标调度运行(metric_job_run.turn_id)
最终产生了哪些指标结果(metric_result.metric_id + stat_time)
会话域
schema
会话表 chat_session
会话轮次表 chat_turn
会话轮次检索关联表 chat_turn_retrieval
API
1. 创建会话
POST /api/v1/chat/sessions
2. 更新会话
POST /api/v1/chat/sessions/{session_id}/update
3. 结束会话
POST /api/v1/chat/sessions/{session_id}/close
4. 查询会话
GET /api/v1/chat/sessions/{session_id}
5. 会话列表查询(按用户、时间)
GET /api/v1/chat/sessions
6. 创建问答轮次(用户发起 query)
POST /api/v1/chat/sessions/{session_id}/turns
7. 查询某会话的所有轮次
GET /api/v1/chat/sessions/{session_id}/turns
8. 查看单轮问答详情
GET /api/v1/chat/turns/{turn_id}
9. 批量写入某轮的检索结果
POST /api/v1/chat/turns/{turn_id}/retrievals
10. 查询某轮的检索记录
GET /api/v1/chat/turns/{turn_id}/retrievals
11. 更新某轮的检索记录(规划中,尚未实现)
POST /api/v1/chat/turns/{turn_id}/retrievals/update
元数据域
schema
指标定义表 metric_def
API
12. 创建指标(来自问答或传统定义)
POST /api/v1/metrics
13. 更新指标
POST /api/v1/metrics/{id}
14. 获取指标详情
GET /api/v1/metrics/{id}
执行调度域(暂定 Airflow)
schema
指标调度配置表 metric_schedule
调度运行记录表 metric_job_run
API
1. 创建调度配置
POST /api/v1/metric-schedules
2. 更新调度配置
POST /api/v1/metric-schedules/{id}
3. 查询指标调度配置详情
GET /api/v1/metrics/{metric_id}/schedules
4. 手动触发一次指标运行(例如来自问数)
POST /api/v1/metric-runs/trigger
5. 查询运行记录列表
GET /api/v1/metric-runs
6. 查询单次运行详情
GET /api/v1/metric-runs/{run_id}
数据域
schema
指标结果表(纵表) metric_result
API
1. 查询指标结果(按时间段 & 维度)
GET /api/v1/metric-results
2. 单点查询(最新值)
GET /api/v1/metric-results/latest
3. 批量写入指标结果
POST /api/v1/metric-results/{metric_id}

103
file/tableschema/chat.sql Normal file
View File

@ -0,0 +1,103 @@
-- chat_session: one row per chat session; chat_turn.session_id refers to id.
CREATE TABLE IF NOT EXISTS chat_session (
id BIGINT AUTO_INCREMENT PRIMARY KEY,
user_id BIGINT NOT NULL,
session_uuid CHAR(36) NOT NULL, -- UUID that can be shown externally as the session ID
end_time DATETIME NULL,
status VARCHAR(16) NOT NULL DEFAULT 'OPEN', -- OPEN/CLOSED/ABANDONED (NOTE(review): the API docs and tests also set PAUSED)
last_turn_id BIGINT NULL, -- points to chat_turn.id
ext_context JSON NULL, -- business context
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
UNIQUE KEY uk_session_uuid (session_uuid),
KEY idx_user_time (user_id, created_at),
KEY idx_status_time (status, created_at),
KEY idx_last_turn (last_turn_id)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
-- chat_turn: one row per question/answer turn of a session, range-partitioned
-- by created_at.
CREATE TABLE IF NOT EXISTS chat_turn (
id BIGINT AUTO_INCREMENT,
session_id BIGINT NOT NULL, -- references chat_session.id
turn_no INT NOT NULL, -- turn sequence number within the session (1, 2, 3, ...)
user_id BIGINT NOT NULL,
user_query TEXT NOT NULL, -- raw user question
intent VARCHAR(64) NULL, -- e.g. METRIC_QUERY / METRIC_EXPLAIN
ast_json JSON NULL, -- parsed AST
generated_sql MEDIUMTEXT NULL, -- final generated SQL
sql_status VARCHAR(32) NULL, -- SUCCESS/FAILED/SKIPPED
error_msg TEXT NULL, -- SQL generation/execution error message
main_metric_ids JSON NULL, -- list of metric IDs involved in this turn
created_metric_ids JSON NULL, -- list of metric IDs created in this turn
end_time DATETIME NULL,
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
-- Composite primary key: it must include created_at because that is the
-- partitioning column (MySQL requires every unique key, the primary key
-- included, to contain all partitioning columns).
PRIMARY KEY (id, created_at),
KEY idx_session_turn (session_id, turn_no),
KEY idx_session_time (session_id, created_at),
KEY idx_intent_time (intent, created_at),
KEY idx_user_time (user_id, created_at)
)
ENGINE=InnoDB
DEFAULT CHARSET=utf8mb4
PARTITION BY RANGE COLUMNS(created_at) (
-- Partitions for historical data (adjust to actual needs)
PARTITION p202511 VALUES LESS THAN ('2025-12-01'),
PARTITION p202512 VALUES LESS THAN ('2026-01-01'),
-- Monthly partitions for 2026
PARTITION p202601 VALUES LESS THAN ('2026-02-01'),
PARTITION p202602 VALUES LESS THAN ('2026-03-01'),
PARTITION p202603 VALUES LESS THAN ('2026-04-01'),
PARTITION p202604 VALUES LESS THAN ('2026-05-01'),
PARTITION p202605 VALUES LESS THAN ('2026-06-01'),
PARTITION p202606 VALUES LESS THAN ('2026-07-01'),
-- ... more months can be pre-created here ...
-- Catch-all partition that holds future data so inserts do not fail
PARTITION p_future VALUES LESS THAN (MAXVALUE)
);
-- chat_turn_retrieval: items retrieved for a turn, range-partitioned by
-- created_at.
CREATE TABLE IF NOT EXISTS chat_turn_retrieval (
id BIGINT AUTO_INCREMENT,
turn_id BIGINT NOT NULL, -- references chat_turn.id (the original comment said qa_turn.id, a table not defined in this file)
item_type VARCHAR(32) NOT NULL, -- METRIC/SNIPPET/CHAT
item_id VARCHAR(128) NOT NULL, -- metric_id / snippet_id / table_name, etc.
item_extra JSON NULL, -- extra information such as a column name
similarity_score DECIMAL(10,6) NULL, -- similarity score
rank_no INT NULL, -- retrieval rank
used_in_reasoning TINYINT(1) NOT NULL DEFAULT 0, -- whether the item took part in reasoning
used_in_sql TINYINT(1) NOT NULL DEFAULT 0, -- whether the item influenced the final SQL
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
-- Composite primary key: it must include created_at because that is the
-- partitioning column (MySQL requires every unique key, the primary key
-- included, to contain all partitioning columns).
PRIMARY KEY (id, created_at),
KEY idx_turn (turn_id),
KEY idx_turn_type (turn_id, item_type),
KEY idx_item (item_type, item_id)
)
ENGINE=InnoDB
DEFAULT CHARSET=utf8mb4
PARTITION BY RANGE COLUMNS(created_at) (
-- Partitions for historical data (adjust to actual needs)
PARTITION p202511 VALUES LESS THAN ('2025-12-01'),
PARTITION p202512 VALUES LESS THAN ('2026-01-01'),
-- Monthly partitions for 2026
PARTITION p202601 VALUES LESS THAN ('2026-02-01'),
PARTITION p202602 VALUES LESS THAN ('2026-03-01'),
PARTITION p202603 VALUES LESS THAN ('2026-04-01'),
PARTITION p202604 VALUES LESS THAN ('2026-05-01'),
PARTITION p202605 VALUES LESS THAN ('2026-06-01'),
PARTITION p202606 VALUES LESS THAN ('2026-07-01'),
-- ... more months can be pre-created here ...
-- Catch-all partition that holds future data so inserts do not fail
PARTITION p_future VALUES LESS THAN (MAXVALUE)
);

142
test/test_chat_api_mysql.py Normal file
View File

@ -0,0 +1,142 @@
from __future__ import annotations
import os
import random
from pathlib import Path
from typing import Generator, List
import sys
import pytest
from fastapi.testclient import TestClient
from sqlalchemy import text
from sqlalchemy.exc import SQLAlchemyError
# Put the directory one level above this file's folder at the front of sys.path
# so that `app` can be imported when the file is run directly with python.
ROOT = Path(__file__).resolve().parents[1]
_root_entry = str(ROOT)
if _root_entry not in sys.path:
    sys.path.insert(0, _root_entry)
from app import db
from app.main import create_app
# User id attached to every row these tests create; the fixture teardown
# deletes the sessions (and their turns/retrievals) owned by this user id.
TEST_USER_ID = 872341
# Location of the chat DDL, relative to the working directory.
# NOTE(review): not referenced by any active code in this file.
SCHEMA_PATH = Path("file/tableschema/chat.sql")
# Fallback connection URL used when TEST_DATABASE_URL is not set.
# NOTE(review): embeds a database password in source -- confirm it is only a
# local development credential.
DEFAULT_MYSQL_URL = "mysql+pymysql://root:12345678@127.0.0.1:3306/data-ge?charset=utf8mb4"
def _purge_test_rows(engine) -> None:
    """Delete every chat row owned by TEST_USER_ID, children before parents."""
    with engine.begin() as conn:
        # Retrievals hang off turns, and turns hang off sessions, so delete in
        # that order while the parent rows can still be looked up.
        conn.execute(
            text(
                """
                DELETE FROM chat_turn_retrieval
                WHERE turn_id IN (
                    SELECT id FROM chat_turn WHERE session_id IN (SELECT id FROM chat_session WHERE user_id=:uid)
                )
                """
            ),
            {"uid": TEST_USER_ID},
        )
        conn.execute(
            text("DELETE FROM chat_turn WHERE session_id IN (SELECT id FROM chat_session WHERE user_id=:uid)"),
            {"uid": TEST_USER_ID},
        )
        conn.execute(text("DELETE FROM chat_session WHERE user_id=:uid"), {"uid": TEST_USER_ID})


@pytest.fixture(scope="module")
def client() -> Generator[TestClient, None, None]:
    """Yield a TestClient backed by a real MySQL database, then clean up.

    The connection URL comes from ``TEST_DATABASE_URL`` (falling back to
    ``DEFAULT_MYSQL_URL``) and is exported as ``DATABASE_URL`` for the
    application. The module is skipped when the database cannot be reached.

    On teardown the rows created for ``TEST_USER_ID`` are deleted, the previous
    ``DATABASE_URL`` value is restored and the cached engine is dropped, so the
    fixture leaves no state behind for other test modules.
    """
    mysql_url = os.getenv("TEST_DATABASE_URL", DEFAULT_MYSQL_URL)
    previous_url = os.environ.get("DATABASE_URL")
    os.environ["DATABASE_URL"] = mysql_url
    # Drop any engine cached before DATABASE_URL was overridden.
    db.get_engine.cache_clear()
    try:
        engine = db.get_engine()
        try:
            # Quick connectivity check
            with engine.connect() as conn:
                conn.execute(text("SELECT 1"))
        except SQLAlchemyError:
            pytest.skip(f"Cannot connect to MySQL at {mysql_url}")

        # The chat tables are expected to exist already; this fixture does not
        # apply the DDL.
        app = create_app()
        try:
            with TestClient(app) as test_client:
                yield test_client
        finally:
            # Runs even if leaving the TestClient context raises.
            _purge_test_rows(engine)
    finally:
        # Also reached on the skip path: undo the environment override and
        # forget the engine bound to the test URL.
        if previous_url is None:
            os.environ.pop("DATABASE_URL", None)
        else:
            os.environ["DATABASE_URL"] = previous_url
        db.get_engine.cache_clear()
def test_session_lifecycle_mysql(client: TestClient) -> None:
    """Walk one session through create -> get -> list -> update -> close."""
    # Create a session
    resp = client.post("/api/v1/chat/sessions", json={"user_id": TEST_USER_ID})
    assert resp.status_code == 200, resp.text
    session = resp.json()
    session_id = session["id"]
    assert session["status"] == "OPEN"

    # Get session
    assert client.get(f"/api/v1/chat/sessions/{session_id}").status_code == 200

    # List sessions (filter by user); plain string, there is nothing to format.
    resp = client.get("/api/v1/chat/sessions", params={"user_id": TEST_USER_ID})
    assert resp.status_code == 200
    assert any(item["id"] == session_id for item in resp.json())

    # Update status
    resp = client.post(f"/api/v1/chat/sessions/{session_id}/update", json={"status": "PAUSED"})
    assert resp.status_code == 200
    assert resp.json()["status"] == "PAUSED"

    # Close session
    resp = client.post(f"/api/v1/chat/sessions/{session_id}/close")
    assert resp.status_code == 200
    assert resp.json()["status"] == "CLOSED"
def test_turns_and_retrievals_mysql(client: TestClient) -> None:
    """Create a turn in a fresh session, attach two retrievals, read both back."""
    # A new session to hold the turn.
    created = client.post("/api/v1/chat/sessions", json={"user_id": TEST_USER_ID})
    session_id = created.json()["id"]

    # Create the first turn of that session.
    turn_payload = {
        "user_id": TEST_USER_ID,
        "user_query": "展示昨天订单GMV",
        "intent": "METRIC_QUERY",
        "ast_json": {"select": ["gmv"], "where": {"dt": "yesterday"}},
        "main_metric_ids": [random.randint(1000, 9999)],
        "created_metric_ids": [],
    }
    turn_resp = client.post(f"/api/v1/chat/sessions/{session_id}/turns", json=turn_payload)
    assert turn_resp.status_code == 200, turn_resp.text
    turn = turn_resp.json()
    turn_id = turn["id"]
    assert turn["turn_no"] == 1

    # The turn can be fetched on its own ...
    single = client.get(f"/api/v1/chat/turns/{turn_id}")
    assert single.status_code == 200

    # ... and shows up in the session's turn listing.
    listing = client.get(f"/api/v1/chat/sessions/{session_id}/turns")
    assert listing.status_code == 200
    assert any(t["id"] == turn_id for t in listing.json())

    # Attach two retrieval records to the turn.
    metric_hit = {"item_type": "METRIC", "item_id": "metric_foo", "used_in_sql": True, "rank_no": 1}
    snippet_hit = {"item_type": "SNIPPET", "item_id": "snpt_bar", "similarity_score": 0.77, "rank_no": 2}
    write_resp = client.post(
        f"/api/v1/chat/turns/{turn_id}/retrievals",
        json={"retrievals": [metric_hit, snippet_hit]},
    )
    assert write_resp.status_code == 200
    assert write_resp.json()["inserted"] == 2

    # Both records come back, one of each type.
    read_resp = client.get(f"/api/v1/chat/turns/{turn_id}/retrievals")
    assert read_resp.status_code == 200
    items = read_resp.json()
    assert len(items) == 2
    assert {item["item_type"] for item in items} == {"METRIC", "SNIPPET"}
if __name__ == "__main__":
    # Running the file directly executes just this module's tests and uses
    # pytest's return code as the process exit status.
    sys.exit(pytest.main([__file__]))