54 lines
2.9 KiB
Python
54 lines
2.9 KiB
Python
from __future__ import annotations
|
|
|
|
from datetime import datetime
|
|
from typing import Any, List, Optional
|
|
|
|
from pydantic import BaseModel, Field
|
|
|
|
|
|
class ChatSessionCreate(BaseModel):
    """Schema for opening a chat session that groups a user's turns.

    Only ``user_id`` is required. ``status`` falls back to ``"OPEN"`` and
    every other field falls back to ``None`` when omitted.
    """

    # Required: the owner of the session.
    user_id: int = Field(
        ...,
        description="User ID owning the session.",
    )

    # Optional identifier supplied by the caller.
    session_uuid: Optional[str] = Field(
        default=None,
        description="Optional externally provided UUID.",
    )

    # Typed Optional, so an explicit None is accepted besides the default.
    status: Optional[str] = Field(
        default="OPEN",
        description="Session status, default OPEN.",
    )

    end_time: Optional[datetime] = Field(
        default=None,
        description="Optional end time.",
    )

    # Free-form mapping; values are not constrained by this schema.
    ext_context: Optional[dict[str, Any]] = Field(
        default=None,
        description="Arbitrary business context.",
    )
|
|
|
|
|
|
class ChatSessionUpdate(BaseModel):
    """Schema for a partial update of a chat session.

    Every field is optional and defaults to ``None``.
    """

    status: Optional[str] = Field(
        default=None,
        description="New session status.",
    )

    end_time: Optional[datetime] = Field(
        default=None,
        description="Close time override.",
    )

    last_turn_id: Optional[int] = Field(
        default=None,
        description="Pointer to last chat turn.",
    )

    # Free-form mapping; values are not constrained by this schema.
    ext_context: Optional[dict[str, Any]] = Field(
        default=None,
        description="Context patch.",
    )
|
|
|
|
|
|
class ChatTurnCreate(BaseModel):
    """Schema for recording one chat turn with its intent and SQL context.

    ``user_id`` and ``user_query`` are required; all remaining fields are
    optional and default to ``None``.
    """

    # --- required ---------------------------------------------------------
    user_id: int = Field(
        ...,
        description="User ID for this turn.",
    )
    user_query: str = Field(
        ...,
        description="Raw user query content.",
    )

    # --- intent / parse result --------------------------------------------
    intent: Optional[str] = Field(
        default=None,
        description="Intent tag such as METRIC_QUERY.",
    )
    ast_json: Optional[dict[str, Any]] = Field(
        default=None,
        description="Parsed AST payload.",
    )

    # --- SQL outcome ------------------------------------------------------
    generated_sql: Optional[str] = Field(
        default=None,
        description="Final SQL text, if generated.",
    )
    sql_status: Optional[str] = Field(
        default=None,
        description="SQL generation/execution status.",
    )
    error_msg: Optional[str] = Field(
        default=None,
        description="Error message when SQL failed.",
    )

    # --- metric references ------------------------------------------------
    main_metric_ids: Optional[List[int]] = Field(
        default=None,
        description="Metric IDs referenced in this turn.",
    )
    created_metric_ids: Optional[List[int]] = Field(
        default=None,
        description="Metric IDs created in this turn.",
    )

    # --- timing -----------------------------------------------------------
    end_time: Optional[datetime] = Field(
        default=None,
        description="Turn end time.",
    )
|
|
|
|
|
|
class ChatTurnRetrievalItem(BaseModel):
    """Schema describing a single retrieved item that contributed to a turn.

    ``item_type`` and ``item_id`` are required. The two ``used_in_*`` flags
    default to ``False``; the remaining fields default to ``None``.
    """

    # --- identity of the retrieved item -----------------------------------
    item_type: str = Field(
        ...,
        description="METRIC/SNIPPET/CHAT etc.",
    )
    # Declared as str, whatever kind of identifier it carries.
    item_id: str = Field(
        ...,
        description="Identifier such as metric_id or snippet_id.",
    )
    item_extra: Optional[dict[str, Any]] = Field(
        default=None,
        description="Additional context like column name.",
    )

    # --- ranking information ----------------------------------------------
    similarity_score: Optional[float] = Field(
        default=None,
        description="Similarity score.",
    )
    rank_no: Optional[int] = Field(
        default=None,
        description="Ranking position.",
    )

    # --- usage flags (typed Optional, so an explicit None is accepted) ----
    used_in_reasoning: Optional[bool] = Field(
        default=False,
        description="Flag if used in reasoning.",
    )
    used_in_sql: Optional[bool] = Field(
        default=False,
        description="Flag if used in final SQL.",
    )
|
|
|
|
|
|
class ChatTurnRetrievalBatch(BaseModel):
    """Wrapper schema carrying several retrieval records for batch insert."""

    # Required: declared without a default.
    retrievals: List[ChatTurnRetrievalItem]
|