diff --git a/app/routers/chat.py b/app/routers/chat.py
new file mode 100644
index 0000000..f663259
--- /dev/null
+++ b/app/routers/chat.py
@@ -0,0 +1,111 @@
+from __future__ import annotations
+
+from datetime import datetime
+from typing import Any, List, Optional
+
+from fastapi import APIRouter, HTTPException, Query
+
+from app.schemas.chat import (
+    ChatSessionCreate,
+    ChatSessionUpdate,
+    ChatTurnCreate,
+    ChatTurnRetrievalBatch,
+)
+from app.services import metric_store
+
+
+router = APIRouter(prefix="/api/v1/chat", tags=["chat"])
+
+
+@router.post("/sessions")
+def create_session(payload: ChatSessionCreate) -> Any:
+    """Create a chat session."""
+    return metric_store.create_chat_session(payload)
+
+
+@router.post("/sessions/{session_id}/update")
+def update_session(session_id: int, payload: ChatSessionUpdate) -> Any:
+    """Update a chat session; respond 404 when the store raises KeyError."""
+    try:
+        return metric_store.update_chat_session(session_id, payload)
+    except KeyError as exc:
+        # Chain the cause so the original lookup failure stays in the traceback.
+        raise HTTPException(status_code=404, detail="Session not found") from exc
+
+
+@router.post("/sessions/{session_id}/close")
+def close_session(session_id: int) -> Any:
+    """Close a chat session and stamp end_time."""
+    try:
+        return metric_store.close_chat_session(session_id)
+    except KeyError as exc:
+        raise HTTPException(status_code=404, detail="Session not found") from exc
+
+
+@router.get("/sessions/{session_id}")
+def get_session(session_id: int) -> Any:
+    """Fetch one session; respond 404 when the store returns a falsy value."""
+    session = metric_store.get_chat_session(session_id)
+    if not session:
+        raise HTTPException(status_code=404, detail="Session not found")
+    return session
+
+
+@router.get("/sessions")
+def list_sessions(
+    user_id: Optional[int] = None,
+    status: Optional[str] = None,
+    start_from: Optional[datetime] = Query(None, description="Filter by start time lower bound."),
+    start_to: Optional[datetime] = Query(None, description="Filter by start time upper bound."),
+    limit: int = Query(50, ge=1, le=500),
+    offset: int = Query(0, ge=0),
+) -> List[Any]:
+    """List sessions, passing the user/status/start-time filters and paging to the store."""
+    return metric_store.list_chat_sessions(
+        user_id=user_id,
+        status=status,
+        start_from=start_from,
+        start_to=start_to,
+        limit=limit,
+        offset=offset,
+    )
+
+
+@router.post("/sessions/{session_id}/turns")
+def create_turn(session_id: int, payload: ChatTurnCreate) -> Any:
+    """Create a turn under a session."""
+    try:
+        return metric_store.create_chat_turn(session_id, payload)
+    except Exception as exc:
+        # NOTE(review): every store failure is reported as 400 and the raw exception
+        # text is returned to the client; narrow this once the store's exception
+        # types are settled.
+        raise HTTPException(status_code=400, detail=str(exc)) from exc
+
+
+@router.get("/sessions/{session_id}/turns")
+def list_turns(session_id: int) -> List[Any]:
+    """List the turns the store returns for a session."""
+    return metric_store.list_chat_turns(session_id)
+
+
+@router.get("/turns/{turn_id}")
+def get_turn(turn_id: int) -> Any:
+    """Fetch one turn; respond 404 when the store returns a falsy value."""
+    turn = metric_store.get_chat_turn(turn_id)
+    if not turn:
+        raise HTTPException(status_code=404, detail="Turn not found")
+    return turn
+
+
+@router.post("/turns/{turn_id}/retrievals")
+def write_retrievals(turn_id: int, payload: ChatTurnRetrievalBatch) -> Any:
+    """Batch write retrieval records for a turn."""
+    count = metric_store.create_retrievals(turn_id, payload.retrievals)
+    return {"turn_id": turn_id, "inserted": count}
+
+
+@router.get("/turns/{turn_id}/retrievals")
+def list_retrievals(turn_id: int) -> List[Any]:
+    """List the retrieval records the store returns for a turn."""
+    return metric_store.list_retrievals(turn_id)
diff --git a/app/schemas/chat.py b/app/schemas/chat.py
new file mode 100644
index 0000000..b00c600
--- /dev/null
+++ b/app/schemas/chat.py
@@ -0,0 +1,56 @@
+from __future__ import annotations
+
+from datetime import datetime
+from typing import Any, List, Optional
+
+from pydantic import BaseModel, Field
+
+# String length limits below mirror the column widths declared in
+# file/tableschema/chat.sql so over-long values fail validation before reaching MySQL.
+
+
+class ChatSessionCreate(BaseModel):
+    """Create a chat session to group multiple turns for a user."""
+    user_id: int = Field(..., description="User ID owning the session.")
+    session_uuid: Optional[str] = Field(None, max_length=36, description="Optional externally provided UUID.")
+    status: Optional[str] = Field("OPEN", max_length=16, description="Session status, default OPEN.")
+    end_time: Optional[datetime] = Field(None, description="Optional end time.")
+    ext_context: Optional[dict[str, Any]] = Field(None, description="Arbitrary business context.")
+
+
+class ChatSessionUpdate(BaseModel):
+    """Partial update for a chat session."""
+    status: Optional[str] = Field(None, max_length=16, description="New session status.")
+    end_time: Optional[datetime] = Field(None, description="Close time override.")
+    last_turn_id: Optional[int] = Field(None, description="Pointer to last chat turn.")
+    ext_context: Optional[dict[str, Any]] = Field(None, description="Context patch.")
+
+
+class ChatTurnCreate(BaseModel):
+    """Create a single chat turn with intent/SQL context."""
+    user_id: int = Field(..., description="User ID for this turn.")
+    user_query: str = Field(..., description="Raw user query content.")
+    intent: Optional[str] = Field(None, max_length=64, description="Intent tag such as METRIC_QUERY.")
+    ast_json: Optional[dict[str, Any]] = Field(None, description="Parsed AST payload.")
+    generated_sql: Optional[str] = Field(None, description="Final SQL text, if generated.")
+    sql_status: Optional[str] = Field(None, max_length=32, description="SQL generation/execution status.")
+    error_msg: Optional[str] = Field(None, description="Error message when SQL failed.")
+    main_metric_ids: Optional[List[int]] = Field(None, description="Metric IDs referenced in this turn.")
+    created_metric_ids: Optional[List[int]] = Field(None, description="Metric IDs created in this turn.")
+    end_time: Optional[datetime] = Field(None, description="Turn end time.")
+
+
+class ChatTurnRetrievalItem(BaseModel):
+    """Record of one retrieved item contributing to a turn."""
+    item_type: str = Field(..., max_length=32, description="METRIC/SNIPPET/CHAT etc.")
+    item_id: str = Field(..., max_length=128, description="Identifier such as metric_id or snippet_id.")
+    item_extra: Optional[dict[str, Any]] = Field(None, description="Additional context like column name.")
+    similarity_score: Optional[float] = Field(None, description="Similarity score.")
+    rank_no: Optional[int] = Field(None, description="Ranking position.")
+    used_in_reasoning: Optional[bool] = Field(False, description="Flag if used in reasoning.")
+    used_in_sql: Optional[bool] = Field(False, description="Flag if used in final SQL.")
+
+
+class ChatTurnRetrievalBatch(BaseModel):
+    """Batch insert wrapper for retrieval records."""
+    retrievals: List[ChatTurnRetrievalItem] = Field(..., description="Retrieval records to write for the turn.")
diff --git a/doc/会话api.md b/doc/会话api.md
new file mode 100644
index 0000000..9a4c73c
--- /dev/null
+++ b/doc/会话api.md
@@ -0,0 +1,49 @@
+# 创建会话
+curl -X POST "/api/v1/chat/sessions" \
+  -H "Content-Type: application/json" \
+  -d "{\"user_id\": $CHAT_USER_ID}"
+
+# 获取会话
+curl "/api/v1/chat/sessions/{session_id}"
+
+# 按用户列出会话
+curl "/api/v1/chat/sessions?user_id=$CHAT_USER_ID"
+
+# 更新会话状态
+curl -X POST "/api/v1/chat/sessions/{session_id}/update" \
+  -H "Content-Type: application/json" \
+  -d '{"status":"PAUSED"}'
+
+# 关闭会话
+curl -X POST "/api/v1/chat/sessions/{session_id}/close"
+
+# 创建对话轮次
+curl -X POST "/api/v1/chat/sessions/{session_id}/turns" \
+  -H "Content-Type: application/json" \
+  -d '{
+    "user_id": '"$CHAT_USER_ID"',
+    "user_query": "展示昨天订单GMV",
+    "intent": "METRIC_QUERY",
+    "ast_json": {"select":["gmv"],"where":{"dt":"yesterday"}},
+    "main_metric_ids": [1234],
+    "created_metric_ids": []
+  }'
+
+# 获取单条对话轮次
+curl "/api/v1/chat/turns/{turn_id}"
+
+# 列出会话下的轮次
+curl "/api/v1/chat/sessions/{session_id}/turns"
+
+# 写入检索结果
+curl -X POST "/api/v1/chat/turns/{turn_id}/retrievals" \
+  -H "Content-Type: application/json" \
+  -d '{
+    "retrievals": [
+      {"item_type":"METRIC","item_id":"metric_foo","used_in_sql":true,"rank_no":1},
+      {"item_type":"SNIPPET","item_id":"snpt_bar","similarity_score":0.77,"rank_no":2}
+    ]
+  }'
+
+# 列出轮次的检索结果
+curl "/api/v1/chat/turns/{turn_id}/retrievals"
\ No newline at end of file
diff --git a/doc/指标生成.md b/doc/指标生成.md
new file mode 100644
index 0000000..e0fdd15
--- /dev/null
+++ b/doc/指标生成.md
@@ -0,0 +1,83 @@
+某个用户的一句问话 → 解析成某轮 chat_turn → 这轮用了哪些指标/知识/会话(chat_turn_retrieval) →
+是否产生了新的指标(metric_def) →
+是否触发了指标调度运行(metric_job_run.turn_id) →
+最终产生了哪些指标结果(metric_result.metric_id + stat_time)。 + +会话域 +schema +会话表 chat_session + +会话轮次表 chat_turn + +会话轮次检索关联表 chat_turn_retrieval + + +API +1. 创建会话 +POST /api/v1/chat/sessions +2. 更新会话 +POST /api/v1/chat/sessions/{session_id}/update +3. 结束会话 +POST /api/v1/chat/sessions/{session_id}/close +4. 查询会话 +GET /api/v1/chat/sessions/{session_id} +5. 会话列表查询(按用户、时间) +GET /api/v1/chat/sessions +6. 创建问答轮次(用户发起 query) +POST /api/v1/chat/sessions/{session_id}/turns +7. 查询某会话的所有轮次 +GET /api/v1/chat/sessions/{session_id}/turns +8. 查看单轮问答详情 +GET /api/v1/chat/turns/{turn_id} +9. 批量写入某轮的检索结果 +POST /api/v1/chat/turns/{turn_id}/retrievals +10. 查询某轮的检索记录 +GET /api/v1/chat/turns/{turn_id}/retrievals +11. 更新某轮的检索记录(in future) +POST /api/v1/chat/turns/{turn_id}/retrievals/update + +元数据域 +schema +指标定义表 metric_def + + +API +12. 创建指标(来自问答或传统定义) +POST /api/v1/metrics +13. 更新指标 +POST /api/v1/metrics/{id} +14. 获取指标详情 +GET /api/v1/metrics/{id} + +执行调度域(暂定airflow) +schema +指标调度配置表 metric_schedule + +调度运行记录表 metric_job_run + +API +1. 创建调度配置 +POST /api/v1/metric-schedules +2. 更新调度配置 +POST /api/v1/metric-schedules/{id} +3. 查询指标调度配置详情 +GET /api/v1/metrics/{metric_id}/schedules +4. 手动触发一次指标运行(例如来自问数) +POST /api/v1/metric-runs/trigger +5. 查询运行记录列表 +GET /api/v1/metric-runs +6. 查询单次运行详情 +GET /api/v1/metric-runs/{run_id} + +数据域 +schema +指标结果表(纵表)metric_result + + +API +1. 查询指标结果(按时间段 & 维度) +GET /api/v1/metric-results +2. 单点查询(最新值) +GET /api/v1/metric-results/latest +3. 
批量写入指标结果
+POST /api/v1/metric-results/{metrics_id}
\ No newline at end of file
diff --git a/file/tableschema/chat.sql b/file/tableschema/chat.sql
new file mode 100644
index 0000000..81dd821
--- /dev/null
+++ b/file/tableschema/chat.sql
@@ -0,0 +1,103 @@
+CREATE TABLE IF NOT EXISTS chat_session (
+    id BIGINT AUTO_INCREMENT PRIMARY KEY,
+    user_id BIGINT NOT NULL,
+    session_uuid CHAR(36) NOT NULL, -- ID that can be shown externally (UUID)
+    end_time DATETIME NULL,
+    status VARCHAR(16) NOT NULL DEFAULT 'OPEN', -- OPEN/PAUSED/CLOSED/ABANDONED
+    last_turn_id BIGINT NULL, -- points to chat_turn.id
+    ext_context JSON NULL, -- business context
+    created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+    updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+    UNIQUE KEY uk_session_uuid (session_uuid),
+    KEY idx_user_time (user_id, created_at),
+    KEY idx_status_time (status, created_at),
+    KEY idx_last_turn (last_turn_id)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+
+CREATE TABLE IF NOT EXISTS chat_turn (
+    id BIGINT AUTO_INCREMENT,
+    session_id BIGINT NOT NULL, -- references chat_session.id
+    turn_no INT NOT NULL, -- turn sequence number within the session (1, 2, 3, ...)
+    user_id BIGINT NOT NULL,
+
+    user_query TEXT NOT NULL, -- raw user question
+    intent VARCHAR(64) NULL, -- METRIC_QUERY/METRIC_EXPLAIN etc.
+    ast_json JSON NULL, -- parsed AST
+
+    generated_sql MEDIUMTEXT NULL, -- final generated SQL
+    sql_status VARCHAR(32) NULL, -- SUCCESS/FAILED/SKIPPED
+    error_msg TEXT NULL, -- SQL generation/execution error message
+
+    main_metric_ids JSON NULL, -- metric IDs involved in this turn
+    created_metric_ids JSON NULL, -- metric IDs created in this turn
+
+    end_time DATETIME NULL,
+
+    created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+    updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+    -- Composite primary key: a partitioned table's primary key must include the partitioning column created_at
+    PRIMARY KEY (id, created_at),
+    KEY idx_session_turn (session_id, turn_no),
+    KEY idx_session_time (session_id, created_at),
+    KEY idx_intent_time (intent, created_at),
+    KEY idx_user_time (user_id, created_at)
+)
+ENGINE=InnoDB
+DEFAULT CHARSET=utf8mb4
+PARTITION BY RANGE COLUMNS(created_at) (
+    -- Historical partitions (adjust to actual needs)
+    PARTITION p202511 VALUES LESS THAN ('2025-12-01'),
+    PARTITION p202512 VALUES LESS THAN ('2026-01-01'),
+    -- Monthly partitions for 2026
+    PARTITION p202601 VALUES LESS THAN ('2026-02-01'),
+    PARTITION p202602 VALUES LESS THAN ('2026-03-01'),
+    PARTITION p202603 VALUES LESS THAN ('2026-04-01'),
+    PARTITION p202604 VALUES LESS THAN ('2026-05-01'),
+    PARTITION p202605 VALUES LESS THAN ('2026-06-01'),
+    PARTITION p202606 VALUES LESS THAN ('2026-07-01'),
+    -- ... more months can be pre-created here ...
+
+    -- Catch-all partition for future data so inserts never fail
+    PARTITION p_future VALUES LESS THAN (MAXVALUE)
+);
+
+
+CREATE TABLE IF NOT EXISTS chat_turn_retrieval (
+    id BIGINT AUTO_INCREMENT,
+    turn_id BIGINT NOT NULL, -- references chat_turn.id
+
+    item_type VARCHAR(32) NOT NULL, -- METRIC/SNIPPET/CHAT
+    item_id VARCHAR(128) NOT NULL, -- metric_id/snippet_id/table_name etc.
+    item_extra JSON NULL, -- extra info such as column names
+
+    similarity_score DECIMAL(10,6) NULL, -- similarity score
+    rank_no INT NULL, -- retrieval rank
+    used_in_reasoning TINYINT(1) NOT NULL DEFAULT 0, -- whether it was used in reasoning
+    used_in_sql TINYINT(1) NOT NULL DEFAULT 0, -- whether it influenced the final SQL
+
+    created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+    updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+    -- Composite primary key: a partitioned table's primary key must include the partitioning column created_at
+    PRIMARY KEY (id, created_at),
+    KEY idx_turn (turn_id),
+    KEY idx_turn_type (turn_id, item_type),
+    KEY idx_item (item_type, item_id)
+)
+ENGINE=InnoDB
+DEFAULT CHARSET=utf8mb4
+PARTITION BY RANGE COLUMNS(created_at) (
+    -- Historical partitions (adjust to actual needs)
+    PARTITION p202511 VALUES LESS THAN ('2025-12-01'),
+    PARTITION p202512 VALUES LESS THAN ('2026-01-01'),
+    -- Monthly partitions for 2026
+    PARTITION p202601 VALUES LESS THAN ('2026-02-01'),
+    PARTITION p202602 VALUES LESS THAN ('2026-03-01'),
+    PARTITION p202603 VALUES LESS THAN ('2026-04-01'),
+    PARTITION p202604 VALUES LESS THAN ('2026-05-01'),
+    PARTITION p202605 VALUES LESS THAN ('2026-06-01'),
+    PARTITION p202606 VALUES LESS THAN ('2026-07-01'),
+    -- ... more months can be pre-created here ...
+
+    -- Catch-all partition for future data so inserts never fail
+    PARTITION p_future VALUES LESS THAN (MAXVALUE)
+);
\ No newline at end of file
diff --git a/test/test_chat_api_mysql.py b/test/test_chat_api_mysql.py
new file mode 100644
index 0000000..0b22ae6
--- /dev/null
+++ b/test/test_chat_api_mysql.py
@@ -0,0 +1,159 @@
+from __future__ import annotations
+
+import os
+import random
+import sys
+from pathlib import Path
+from typing import Generator, List
+
+import pytest
+from fastapi.testclient import TestClient
+from sqlalchemy import text
+from sqlalchemy.exc import SQLAlchemyError
+
+# Ensure the project root is importable when running directly via python.
+ROOT = Path(__file__).resolve().parents[1]
+if str(ROOT) not in sys.path:
+    sys.path.insert(0, str(ROOT))
+
+from app import db
+from app.main import create_app
+
+
+TEST_USER_ID = 872341
+# Anchored to the project root so the path does not depend on the working directory.
+SCHEMA_PATH = ROOT / "file" / "tableschema" / "chat.sql"
+# Local development default only; set TEST_DATABASE_URL to target another database.
+DEFAULT_MYSQL_URL = "mysql+pymysql://root:12345678@127.0.0.1:3306/data-ge?charset=utf8mb4"
+
+
+def _purge_test_rows(engine) -> None:
+    """Delete the chat rows tied to TEST_USER_ID: retrievals, then turns, then sessions."""
+    with engine.begin() as conn:
+        conn.execute(
+            text(
+                """
+                DELETE FROM chat_turn_retrieval
+                WHERE turn_id IN (
+                    SELECT id FROM chat_turn WHERE session_id IN (SELECT id FROM chat_session WHERE user_id=:uid)
+                )
+                """
+            ),
+            {"uid": TEST_USER_ID},
+        )
+        conn.execute(
+            text("DELETE FROM chat_turn WHERE session_id IN (SELECT id FROM chat_session WHERE user_id=:uid)"),
+            {"uid": TEST_USER_ID},
+        )
+        conn.execute(text("DELETE FROM chat_session WHERE user_id=:uid"), {"uid": TEST_USER_ID})
+
+
+@pytest.fixture(scope="module")
+def client() -> Generator[TestClient, None, None]:
+    """Yield a TestClient after pointing DATABASE_URL at the MySQL test database.
+
+    Skips the module when the database is unreachable. The chat tables are not created
+    here and must already exist. Rows tied to TEST_USER_ID are deleted on teardown.
+    """
+    mysql_url = os.getenv("TEST_DATABASE_URL", DEFAULT_MYSQL_URL)
+    previous_url = os.environ.get("DATABASE_URL")
+    os.environ["DATABASE_URL"] = mysql_url
+    db.get_engine.cache_clear()
+    try:
+        engine = db.get_engine()
+        try:
+            # Quick connectivity check
+            with engine.connect() as conn:
+                conn.execute(text("SELECT 1"))
+        except SQLAlchemyError:
+            pytest.skip(f"Cannot connect to MySQL at {mysql_url}")
+
+        try:
+            with TestClient(create_app()) as test_client:
+                yield test_client
+        finally:
+            # Purge even if the client shutdown raises, so reruns start clean.
+            _purge_test_rows(engine)
+    finally:
+        # Undo the process-wide changes made above, including on the skip path.
+        if previous_url is None:
+            os.environ.pop("DATABASE_URL", None)
+        else:
+            os.environ["DATABASE_URL"] = previous_url
+        db.get_engine.cache_clear()
+
+
+def test_session_lifecycle_mysql(client: TestClient) -> None:
+    """Create, fetch, list, update and close a session through the HTTP API."""
+    # Create a session
+    resp = client.post("/api/v1/chat/sessions", json={"user_id": TEST_USER_ID})
+    assert resp.status_code == 200, resp.text
+    session = resp.json()
+    session_id = session["id"]
+    assert session["status"] == "OPEN"
+
+    # Get session
+    assert client.get(f"/api/v1/chat/sessions/{session_id}").status_code == 200
+
+    # List sessions (filter by user)
+    resp = client.get("/api/v1/chat/sessions", params={"user_id": TEST_USER_ID})
+    assert resp.status_code == 200
+    assert any(item["id"] == session_id for item in resp.json())
+
+    # Update status
+    resp = client.post(f"/api/v1/chat/sessions/{session_id}/update", json={"status": "PAUSED"})
+    assert resp.status_code == 200
+    assert resp.json()["status"] == "PAUSED"
+
+    # Close session
+    resp = client.post(f"/api/v1/chat/sessions/{session_id}/close")
+    assert resp.status_code == 200
+    assert resp.json()["status"] == "CLOSED"
+
+
+def test_turns_and_retrievals_mysql(client: TestClient) -> None:
+    """Create a turn, then write and read back its retrieval records."""
+    session_id = client.post("/api/v1/chat/sessions", json={"user_id": TEST_USER_ID}).json()["id"]
+    turn_payload = {
+        "user_id": TEST_USER_ID,
+        "user_query": "展示昨天订单GMV",
+        "intent": "METRIC_QUERY",
+        "ast_json": {"select": ["gmv"], "where": {"dt": "yesterday"}},
+        "main_metric_ids": [random.randint(1000, 9999)],
+        "created_metric_ids": [],
+    }
+    resp = client.post(f"/api/v1/chat/sessions/{session_id}/turns", json=turn_payload)
+    assert resp.status_code == 200, resp.text
+    turn = resp.json()
+    turn_id = turn["id"]
+    assert turn["turn_no"] == 1
+
+    # Fetch turn
+    assert client.get(f"/api/v1/chat/turns/{turn_id}").status_code == 200
+
+    # List turns under session
+    resp = client.get(f"/api/v1/chat/sessions/{session_id}/turns")
+    assert resp.status_code == 200
+    assert any(t["id"] == turn_id for t in resp.json())
+
+    # Insert retrievals
+    retrievals_payload = {
+        "retrievals": [
+            {"item_type": "METRIC", "item_id": "metric_foo", "used_in_sql": True, "rank_no": 1},
+            {"item_type": "SNIPPET", "item_id": "snpt_bar", "similarity_score": 0.77, "rank_no": 2},
+        ]
+    }
+    resp = client.post(f"/api/v1/chat/turns/{turn_id}/retrievals", json=retrievals_payload)
+    assert resp.status_code == 200
+    assert resp.json()["inserted"] == 2
+
+    # List retrievals
+    resp = client.get(f"/api/v1/chat/turns/{turn_id}/retrievals")
+    assert resp.status_code == 200
+    items = resp.json()
+    assert len(items) == 2
+    assert {item["item_type"] for item in items} == {"METRIC", "SNIPPET"}
+
+
+if __name__ == "__main__":
+    raise SystemExit(pytest.main([__file__]))