"""Pydantic models for chat sessions, chat turns, and per-turn retrieval records."""

from __future__ import annotations

from datetime import datetime
from typing import Any, List, Optional

from pydantic import BaseModel, Field


class ChatSessionCreate(BaseModel):
    """Create a chat session to group multiple turns for a user."""

    user_id: int = Field(..., description="User ID owning the session.")
    # Left as None when not supplied; this model does not generate a UUID itself.
    session_uuid: Optional[str] = Field(None, description="Optional externally provided UUID.")
    # Free-form string: only the default "OPEN" is visible here, other values are not
    # validated by this model.
    status: Optional[str] = Field("OPEN", description="Session status, default OPEN.")
    end_time: Optional[datetime] = Field(None, description="Optional end time.")
    ext_context: Optional[dict[str, Any]] = Field(None, description="Arbitrary business context.")


class ChatSessionUpdate(BaseModel):
    """Partial update for a chat session."""

    # Every field defaults to None, so "not provided" and "explicitly null" look the same
    # by value. NOTE(review): presumably the caller applies only explicitly-set fields
    # (e.g. via exclude_unset) — confirm in the service layer.
    status: Optional[str] = Field(None, description="New session status.")
    end_time: Optional[datetime] = Field(None, description="Close time override.")
    last_turn_id: Optional[int] = Field(None, description="Pointer to last chat turn.")
    ext_context: Optional[dict[str, Any]] = Field(None, description="Context patch.")


class ChatTurnCreate(BaseModel):
    """Create a single chat turn with intent/SQL context."""

    # NOTE(review): there is no session identifier field; presumably the owning session
    # is supplied outside the request body (e.g. a path parameter) — confirm.
    user_id: int = Field(..., description="User ID for this turn.")
    user_query: str = Field(..., description="Raw user query content.")
    intent: Optional[str] = Field(None, description="Intent tag such as METRIC_QUERY.")
    ast_json: Optional[dict[str, Any]] = Field(None, description="Parsed AST payload.")
    generated_sql: Optional[str] = Field(None, description="Final SQL text, if generated.")
    sql_status: Optional[str] = Field(None, description="SQL generation/execution status.")
    error_msg: Optional[str] = Field(None, description="Error message when SQL failed.")
    main_metric_ids: Optional[list[int]] = Field(
        None, description="Metric IDs referenced in this turn."
    )
    created_metric_ids: Optional[list[int]] = Field(
        None, description="Metric IDs created in this turn."
    )
    end_time: Optional[datetime] = Field(None, description="Turn end time.")


class ChatTurnRetrievalItem(BaseModel):
    """Record of one retrieved item contributing to a turn."""

    item_type: str = Field(..., description="METRIC/SNIPPET/CHAT etc.")
    # Typed as str so identifiers of different item types can share one field.
    item_id: str = Field(..., description="Identifier such as metric_id or snippet_id.")
    item_extra: Optional[dict[str, Any]] = Field(
        None, description="Additional context like column name."
    )
    similarity_score: Optional[float] = Field(None, description="Similarity score.")
    rank_no: Optional[int] = Field(None, description="Ranking position.")
    # Omitted -> False, but because the type is Optional an explicit null is also
    # accepted and kept as None.
    used_in_reasoning: Optional[bool] = Field(False, description="Flag if used in reasoning.")
    used_in_sql: Optional[bool] = Field(False, description="Flag if used in final SQL.")


class ChatTurnRetrievalBatch(BaseModel):
    """Batch insert wrapper for retrieval records."""

    # Required, but no length constraint is applied: an empty list validates.
    retrievals: list[ChatTurnRetrievalItem] = Field(
        ..., description="Retrieval records to insert in one batch."
    )