会话及轮次等相关api
This commit is contained in:
102
app/routers/chat.py
Normal file
102
app/routers/chat.py
Normal file
@ -0,0 +1,102 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Query
|
||||
|
||||
from app.schemas.chat import (
|
||||
ChatSessionCreate,
|
||||
ChatSessionUpdate,
|
||||
ChatTurnCreate,
|
||||
ChatTurnRetrievalBatch,
|
||||
)
|
||||
from app.services import metric_store
|
||||
|
||||
|
||||
router = APIRouter(prefix="/api/v1/chat", tags=["chat"])
|
||||
|
||||
|
||||
@router.post("/sessions")
|
||||
def create_session(payload: ChatSessionCreate) -> Any:
|
||||
"""Create a chat session."""
|
||||
return metric_store.create_chat_session(payload)
|
||||
|
||||
|
||||
@router.post("/sessions/{session_id}/update")
|
||||
def update_session(session_id: int, payload: ChatSessionUpdate) -> Any:
|
||||
try:
|
||||
return metric_store.update_chat_session(session_id, payload)
|
||||
except KeyError:
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
|
||||
|
||||
@router.post("/sessions/{session_id}/close")
|
||||
def close_session(session_id: int) -> Any:
|
||||
"""Close a chat session and stamp end_time."""
|
||||
try:
|
||||
return metric_store.close_chat_session(session_id)
|
||||
except KeyError:
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
|
||||
|
||||
@router.get("/sessions/{session_id}")
|
||||
def get_session(session_id: int) -> Any:
|
||||
"""Fetch one session."""
|
||||
session = metric_store.get_chat_session(session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
return session
|
||||
|
||||
|
||||
@router.get("/sessions")
|
||||
def list_sessions(
|
||||
user_id: Optional[int] = None,
|
||||
status: Optional[str] = None,
|
||||
start_from: Optional[datetime] = Query(None, description="Filter by start time lower bound."),
|
||||
start_to: Optional[datetime] = Query(None, description="Filter by start time upper bound."),
|
||||
limit: int = Query(50, ge=1, le=500),
|
||||
offset: int = Query(0, ge=0),
|
||||
) -> List[Any]:
|
||||
return metric_store.list_chat_sessions(
|
||||
user_id=user_id,
|
||||
status=status,
|
||||
start_from=start_from,
|
||||
start_to=start_to,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/sessions/{session_id}/turns")
|
||||
def create_turn(session_id: int, payload: ChatTurnCreate) -> Any:
|
||||
"""Create a turn under a session."""
|
||||
try:
|
||||
return metric_store.create_chat_turn(session_id, payload)
|
||||
except Exception as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
|
||||
|
||||
@router.get("/sessions/{session_id}/turns")
|
||||
def list_turns(session_id: int) -> List[Any]:
|
||||
return metric_store.list_chat_turns(session_id)
|
||||
|
||||
|
||||
@router.get("/turns/{turn_id}")
|
||||
def get_turn(turn_id: int) -> Any:
|
||||
turn = metric_store.get_chat_turn(turn_id)
|
||||
if not turn:
|
||||
raise HTTPException(status_code=404, detail="Turn not found")
|
||||
return turn
|
||||
|
||||
|
||||
@router.post("/turns/{turn_id}/retrievals")
|
||||
def write_retrievals(turn_id: int, payload: ChatTurnRetrievalBatch) -> Any:
|
||||
"""Batch write retrieval records for a turn."""
|
||||
count = metric_store.create_retrievals(turn_id, payload.retrievals)
|
||||
return {"turn_id": turn_id, "inserted": count}
|
||||
|
||||
|
||||
@router.get("/turns/{turn_id}/retrievals")
|
||||
def list_retrievals(turn_id: int) -> List[Any]:
|
||||
return metric_store.list_retrievals(turn_id)
|
||||
53
app/schemas/chat.py
Normal file
53
app/schemas/chat.py
Normal file
@ -0,0 +1,53 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ChatSessionCreate(BaseModel):
|
||||
"""Create a chat session to group multiple turns for a user."""
|
||||
user_id: int = Field(..., description="User ID owning the session.")
|
||||
session_uuid: Optional[str] = Field(None, description="Optional externally provided UUID.")
|
||||
status: Optional[str] = Field("OPEN", description="Session status, default OPEN.")
|
||||
end_time: Optional[datetime] = Field(None, description="Optional end time.")
|
||||
ext_context: Optional[dict[str, Any]] = Field(None, description="Arbitrary business context.")
|
||||
|
||||
|
||||
class ChatSessionUpdate(BaseModel):
|
||||
"""Partial update for a chat session."""
|
||||
status: Optional[str] = Field(None, description="New session status.")
|
||||
end_time: Optional[datetime] = Field(None, description="Close time override.")
|
||||
last_turn_id: Optional[int] = Field(None, description="Pointer to last chat turn.")
|
||||
ext_context: Optional[dict[str, Any]] = Field(None, description="Context patch.")
|
||||
|
||||
|
||||
class ChatTurnCreate(BaseModel):
|
||||
"""Create a single chat turn with intent/SQL context."""
|
||||
user_id: int = Field(..., description="User ID for this turn.")
|
||||
user_query: str = Field(..., description="Raw user query content.")
|
||||
intent: Optional[str] = Field(None, description="Intent tag such as METRIC_QUERY.")
|
||||
ast_json: Optional[dict[str, Any]] = Field(None, description="Parsed AST payload.")
|
||||
generated_sql: Optional[str] = Field(None, description="Final SQL text, if generated.")
|
||||
sql_status: Optional[str] = Field(None, description="SQL generation/execution status.")
|
||||
error_msg: Optional[str] = Field(None, description="Error message when SQL failed.")
|
||||
main_metric_ids: Optional[List[int]] = Field(None, description="Metric IDs referenced in this turn.")
|
||||
created_metric_ids: Optional[List[int]] = Field(None, description="Metric IDs created in this turn.")
|
||||
end_time: Optional[datetime] = Field(None, description="Turn end time.")
|
||||
|
||||
|
||||
class ChatTurnRetrievalItem(BaseModel):
|
||||
"""Record of one retrieved item contributing to a turn."""
|
||||
item_type: str = Field(..., description="METRIC/SNIPPET/CHAT etc.")
|
||||
item_id: str = Field(..., description="Identifier such as metric_id or snippet_id.")
|
||||
item_extra: Optional[dict[str, Any]] = Field(None, description="Additional context like column name.")
|
||||
similarity_score: Optional[float] = Field(None, description="Similarity score.")
|
||||
rank_no: Optional[int] = Field(None, description="Ranking position.")
|
||||
used_in_reasoning: Optional[bool] = Field(False, description="Flag if used in reasoning.")
|
||||
used_in_sql: Optional[bool] = Field(False, description="Flag if used in final SQL.")
|
||||
|
||||
|
||||
class ChatTurnRetrievalBatch(BaseModel):
|
||||
"""Batch insert wrapper for retrieval records."""
|
||||
retrievals: List[ChatTurnRetrievalItem]
|
||||
Reference in New Issue
Block a user