103 lines
3.1 KiB
Python
103 lines
3.1 KiB
Python
from __future__ import annotations
|
|
|
|
from datetime import datetime
|
|
from typing import Any, List, Optional
|
|
|
|
from fastapi import APIRouter, HTTPException, Query
|
|
|
|
from app.schemas.chat import (
|
|
ChatSessionCreate,
|
|
ChatSessionUpdate,
|
|
ChatTurnCreate,
|
|
ChatTurnRetrievalBatch,
|
|
)
|
|
from app.services import metric_store
|
|
|
|
|
|
# Every chat endpoint in this module is registered on this router, which
# mounts them under /api/v1/chat and tags them "chat".
router = APIRouter(tags=["chat"], prefix="/api/v1/chat")
|
|
|
|
|
|
@router.post("/sessions")
def create_session(payload: ChatSessionCreate) -> Any:
    """Persist a new chat session described by ``payload``.

    The work is delegated entirely to ``metric_store.create_chat_session``.
    Its return value is used unchanged as the response body.
    """
    created = metric_store.create_chat_session(payload)
    return created
|
|
|
|
|
|
@router.post("/sessions/{session_id}/update")
def update_session(session_id: int, payload: ChatSessionUpdate) -> Any:
    """Apply ``payload`` to an existing chat session.

    Args:
        session_id: Identifier of the session to modify (path parameter).
        payload: Update request body, validated as ``ChatSessionUpdate``.

    Returns:
        Whatever ``metric_store.update_chat_session`` returns.

    Raises:
        HTTPException: 404 "Session not found" when the store raises
            ``KeyError``.
    """
    try:
        return metric_store.update_chat_session(session_id, payload)
    except KeyError as exc:
        # Chain explicitly so server-side tracebacks record the store error as
        # the direct cause, rather than as an unrelated exception raised
        # "during handling" of another one.
        raise HTTPException(status_code=404, detail="Session not found") from exc
|
|
|
|
|
|
@router.post("/sessions/{session_id}/close")
def close_session(session_id: int) -> Any:
    """Close a chat session and stamp end_time.

    Args:
        session_id: Identifier of the session to close (path parameter).

    Returns:
        Whatever ``metric_store.close_chat_session`` returns.

    Raises:
        HTTPException: 404 "Session not found" when the store raises
            ``KeyError``.
    """
    try:
        return metric_store.close_chat_session(session_id)
    except KeyError as exc:
        # Explicit chaining keeps the store's KeyError as the recorded cause
        # of the HTTP error in server logs.
        raise HTTPException(status_code=404, detail="Session not found") from exc
|
|
|
|
|
|
@router.get("/sessions/{session_id}")
def get_session(session_id: int) -> Any:
    """Return one chat session by id, or respond 404.

    Any falsy result from ``metric_store.get_chat_session`` (for example
    ``None``) is treated as "not found".
    """
    found = metric_store.get_chat_session(session_id)
    if found:
        return found
    raise HTTPException(status_code=404, detail="Session not found")
|
|
|
|
|
|
@router.get("/sessions")
def list_sessions(
    user_id: Optional[int] = None,
    status: Optional[str] = None,
    start_from: Optional[datetime] = Query(None, description="Filter by start time lower bound."),
    start_to: Optional[datetime] = Query(None, description="Filter by start time upper bound."),
    limit: int = Query(50, ge=1, le=500),
    offset: int = Query(0, ge=0),
) -> List[Any]:
    """List chat sessions matching the optional filters, one page at a time.

    Every filter is forwarded unchanged to
    ``metric_store.list_chat_sessions``. ``None`` means "not supplied", and
    the store decides how to interpret it.

    FastAPI's ``Query`` validation restricts ``limit`` to 1..500 and
    ``offset`` to values >= 0.
    """
    filters = dict(
        user_id=user_id,
        status=status,
        start_from=start_from,
        start_to=start_to,
    )
    return metric_store.list_chat_sessions(limit=limit, offset=offset, **filters)
|
|
|
|
|
|
@router.post("/sessions/{session_id}/turns")
def create_turn(session_id: int, payload: ChatTurnCreate) -> Any:
    """Create a turn under a session.

    Args:
        session_id: Identifier of the parent session (path parameter).
        payload: Turn request body, validated as ``ChatTurnCreate``.

    Returns:
        Whatever ``metric_store.create_chat_turn`` returns.

    Raises:
        HTTPException: in one of three ways.
            - Propagated unchanged if raised further down.
            - 404 "Session not found" when the store raises ``KeyError``.
              This matches how the session update/close endpoints in this
              module interpret a store ``KeyError``.
            - 400 carrying the exception text for any other error.
    """
    try:
        return metric_store.create_chat_turn(session_id, payload)
    except HTTPException:
        # Already an HTTP error with a deliberate status code; re-wrapping it
        # would silently downgrade it to a 400.
        raise
    except KeyError as exc:
        raise HTTPException(status_code=404, detail="Session not found") from exc
    except Exception as exc:
        # NOTE(review): this catch-all still reports unexpected server-side
        # failures as client 400s and echoes their message to the caller.
        # Narrow it once the store's exception contract is known.
        raise HTTPException(status_code=400, detail=str(exc)) from exc
|
|
|
|
|
|
@router.get("/sessions/{session_id}/turns")
def list_turns(session_id: int) -> List[Any]:
    """Return the turns recorded for ``session_id``, exactly as the store lists them.

    The session's existence is not checked here. The response is whatever
    ``metric_store.list_chat_turns`` returns.
    """
    turns = metric_store.list_chat_turns(session_id)
    return turns
|
|
|
|
|
|
@router.get("/turns/{turn_id}")
def get_turn(turn_id: int) -> Any:
    """Return one turn by id, or respond 404.

    Any falsy result from ``metric_store.get_chat_turn`` (for example
    ``None``) is treated as "not found".
    """
    found = metric_store.get_chat_turn(turn_id)
    if found:
        return found
    raise HTTPException(status_code=404, detail="Turn not found")
|
|
|
|
|
|
@router.post("/turns/{turn_id}/retrievals")
def write_retrievals(turn_id: int, payload: ChatTurnRetrievalBatch) -> Any:
    """Store a batch of retrieval records attached to ``turn_id``.

    Returns a summary dict holding the turn id and the value reported by
    ``metric_store.create_retrievals`` (presumably the number of rows written).

    NOTE(review): the turn's existence is not checked here. Whatever the
    store does for an unknown ``turn_id`` surfaces unchanged.
    """
    inserted = metric_store.create_retrievals(turn_id, payload.retrievals)
    summary = {"turn_id": turn_id, "inserted": inserted}
    return summary
|
|
|
|
|
|
@router.get("/turns/{turn_id}/retrievals")
def list_retrievals(turn_id: int) -> List[Any]:
    """Return the retrieval records stored for ``turn_id``, as the store lists them."""
    records = metric_store.list_retrievals(turn_id)
    return records
|