from __future__ import annotations import hashlib import json import logging from datetime import datetime from typing import Any, Dict, Iterable, List, Optional from uuid import uuid4 from sqlalchemy import text from sqlalchemy.engine import Row from app.db import get_engine from app.schemas.chat import ( ChatSessionCreate, ChatSessionUpdate, ChatTurnCreate, ChatTurnRetrievalItem, ) from app.schemas.metrics import ( MetricCreate, MetricResultItem, MetricResultsWriteRequest, MetricRunTrigger, MetricScheduleCreate, MetricScheduleUpdate, MetricUpdate, ) logger = logging.getLogger(__name__) # Common helpers def _json_dump(value: Any) -> Optional[str]: """Safe JSON dumper; returns None on failure to keep DB writes simple.""" if value is None: return None if isinstance(value, str): return value try: return json.dumps(value, ensure_ascii=False) except (TypeError, ValueError): return None def _parse_json_fields(payload: Dict[str, Any], fields: Iterable[str]) -> Dict[str, Any]: """Parse select fields from JSON strings into dict/list for responses.""" for field in fields: raw = payload.get(field) if raw is None or isinstance(raw, (dict, list)): continue if isinstance(raw, (bytes, bytearray)): raw = raw.decode("utf-8", errors="ignore") if isinstance(raw, str): try: payload[field] = json.loads(raw) except ValueError: pass return payload def _row_to_dict(row: Row[Any]) -> Dict[str, Any]: return dict(row._mapping) # Chat sessions & turns def create_chat_session(payload: ChatSessionCreate) -> Dict[str, Any]: """Create a chat session row with optional external UUID.""" engine = get_engine() session_uuid = payload.session_uuid or str(uuid4()) now = datetime.utcnow() params = { "user_id": payload.user_id, "session_uuid": session_uuid, "end_time": payload.end_time, "status": payload.status or "OPEN", "ext_context": _json_dump(payload.ext_context), } with engine.begin() as conn: result = conn.execute( text( """ INSERT INTO chat_session (user_id, session_uuid, end_time, status, 
ext_context) VALUES (:user_id, :session_uuid, :end_time, :status, :ext_context) """ ), params, ) session_id = result.lastrowid row = conn.execute( text("SELECT * FROM chat_session WHERE id=:id"), {"id": session_id} ).first() if not row: raise RuntimeError("Failed to create chat session.") data = _row_to_dict(row) _parse_json_fields(data, ["ext_context"]) return data def update_chat_session(session_id: int, payload: ChatSessionUpdate) -> Dict[str, Any]: """Patch selected chat session fields.""" updates = {} if payload.status is not None: updates["status"] = payload.status if payload.end_time is not None: updates["end_time"] = payload.end_time if payload.last_turn_id is not None: updates["last_turn_id"] = payload.last_turn_id if payload.ext_context is not None: updates["ext_context"] = _json_dump(payload.ext_context) if not updates: current = get_chat_session(session_id) if not current: raise KeyError(f"Session {session_id} not found.") return current set_clause = ", ".join(f"{key}=:{key}" for key in updates.keys()) params = dict(updates) params["id"] = session_id engine = get_engine() with engine.begin() as conn: conn.execute( text(f"UPDATE chat_session SET {set_clause} WHERE id=:id"), params, ) row = conn.execute( text("SELECT * FROM chat_session WHERE id=:id"), {"id": session_id} ).first() if not row: raise KeyError(f"Session {session_id} not found.") data = _row_to_dict(row) _parse_json_fields(data, ["ext_context"]) return data def close_chat_session(session_id: int) -> Dict[str, Any]: """Mark a chat session as CLOSED with end_time.""" now = datetime.utcnow() return update_chat_session( session_id, ChatSessionUpdate(status="CLOSED", end_time=now), ) def get_chat_session(session_id: int) -> Optional[Dict[str, Any]]: engine = get_engine() with engine.begin() as conn: row = conn.execute( text("SELECT * FROM chat_session WHERE id=:id"), {"id": session_id} ).first() if not row: return None data = _row_to_dict(row) _parse_json_fields(data, ["ext_context"]) return data 
def list_chat_sessions(
    *,
    user_id: Optional[int] = None,
    status: Optional[str] = None,
    start_from: Optional[datetime] = None,
    start_to: Optional[datetime] = None,
    limit: int = 50,
    offset: int = 0,
) -> List[Dict[str, Any]]:
    """List chat sessions with optional filters and pagination."""
    conditions = []
    params: Dict[str, Any] = {"limit": limit, "offset": offset}
    if user_id is not None:
        conditions.append("user_id=:user_id")
        params["user_id"] = user_id
    if status is not None:
        conditions.append("status=:status")
        params["status"] = status
    if start_from is not None:
        conditions.append("created_at>=:start_from")
        params["start_from"] = start_from
    if start_to is not None:
        conditions.append("created_at<=:start_to")
        params["start_to"] = start_to
    where_clause = "WHERE " + " AND ".join(conditions) if conditions else ""
    engine = get_engine()
    with engine.begin() as conn:
        rows = conn.execute(
            text(
                f"SELECT * FROM chat_session {where_clause} "
                "ORDER BY created_at DESC LIMIT :limit OFFSET :offset"
            ),
            params,
        ).fetchall()
    results: List[Dict[str, Any]] = []
    for row in rows:
        data = _row_to_dict(row)
        _parse_json_fields(data, ["ext_context"])
        results.append(data)
    return results


def _next_turn_no(conn, session_id: int) -> int:
    """Return the next turn number (1-based) within a session.

    NOTE(review): MAX+1 is not race-safe under concurrent inserts for the
    same session; acceptable only if turns for a session are serialized.
    """
    row = conn.execute(
        text("SELECT COALESCE(MAX(turn_no), 0) + 1 AS next_no FROM chat_turn WHERE session_id=:sid"),
        {"sid": session_id},
    ).first()
    if not row:
        return 1
    return int(row._mapping["next_no"])


def create_chat_turn(session_id: int, payload: ChatTurnCreate) -> Dict[str, Any]:
    """Insert a chat turn and auto-increment turn number within the session.

    Raises:
        RuntimeError: if the row cannot be read back after insert.
    """
    engine = get_engine()
    params = {
        "session_id": session_id,
        "user_id": payload.user_id,
        "user_query": payload.user_query,
        "intent": payload.intent,
        "ast_json": _json_dump(payload.ast_json),
        "generated_sql": payload.generated_sql,
        "sql_status": payload.sql_status,
        "error_msg": payload.error_msg,
        "main_metric_ids": _json_dump(payload.main_metric_ids),
        "created_metric_ids": _json_dump(payload.created_metric_ids),
        "end_time": payload.end_time,
    }
    with engine.begin() as conn:
        # Turn numbering and insert happen in one transaction.
        turn_no = _next_turn_no(conn, session_id)
        params["turn_no"] = turn_no
        result = conn.execute(
            text(
                """
                INSERT INTO chat_turn (
                    session_id, turn_no, user_id, user_query, intent, ast_json,
                    generated_sql, sql_status, error_msg, main_metric_ids,
                    created_metric_ids, end_time
                ) VALUES (
                    :session_id, :turn_no, :user_id, :user_query, :intent, :ast_json,
                    :generated_sql, :sql_status, :error_msg, :main_metric_ids,
                    :created_metric_ids, :end_time
                )
                """
            ),
            params,
        )
        turn_id = result.lastrowid
        row = conn.execute(
            text("SELECT * FROM chat_turn WHERE id=:id"), {"id": turn_id}
        ).first()
    if not row:
        raise RuntimeError("Failed to create chat turn.")
    data = _row_to_dict(row)
    _parse_json_fields(data, ["ast_json", "main_metric_ids", "created_metric_ids"])
    return data


def get_chat_turn(turn_id: int) -> Optional[Dict[str, Any]]:
    """Fetch one chat turn by id, or None if missing."""
    engine = get_engine()
    with engine.begin() as conn:
        row = conn.execute(
            text("SELECT * FROM chat_turn WHERE id=:id"), {"id": turn_id}
        ).first()
    if not row:
        return None
    data = _row_to_dict(row)
    _parse_json_fields(data, ["ast_json", "main_metric_ids", "created_metric_ids"])
    return data


def list_chat_turns(session_id: int) -> List[Dict[str, Any]]:
    """List all turns of a session ordered by turn number."""
    engine = get_engine()
    with engine.begin() as conn:
        rows = conn.execute(
            text(
                "SELECT * FROM chat_turn WHERE session_id=:session_id ORDER BY turn_no ASC"
            ),
            {"session_id": session_id},
        ).fetchall()
    results: List[Dict[str, Any]] = []
    for row in rows:
        data = _row_to_dict(row)
        _parse_json_fields(data, ["ast_json", "main_metric_ids", "created_metric_ids"])
        results.append(data)
    return results


def create_retrievals(turn_id: int, retrievals: List[ChatTurnRetrievalItem]) -> int:
    """Batch insert retrieval records for a turn.

    Returns the number of rows written (0 for an empty list).
    """
    if not retrievals:
        return 0
    engine = get_engine()
    params_list = []
    for item in retrievals:
        params_list.append(
            {
                "turn_id": turn_id,
                "item_type": item.item_type,
                "item_id": item.item_id,
                "item_extra": _json_dump(item.item_extra),
                "similarity_score": item.similarity_score,
                "rank_no": item.rank_no,
                "used_in_reasoning": 1 if item.used_in_reasoning else 0,
                "used_in_sql": 1 if item.used_in_sql else 0,
            }
        )
    with engine.begin() as conn:
        # executemany-style batch insert via a list of parameter dicts.
        conn.execute(
            text(
                """
                INSERT INTO chat_turn_retrieval (
                    turn_id, item_type, item_id, item_extra, similarity_score,
                    rank_no, used_in_reasoning, used_in_sql
                ) VALUES (
                    :turn_id, :item_type, :item_id, :item_extra, :similarity_score,
                    :rank_no, :used_in_reasoning, :used_in_sql
                )
                """
            ),
            params_list,
        )
    return len(retrievals)


def list_retrievals(turn_id: int) -> List[Dict[str, Any]]:
    """List retrieval records of a turn with boolean flags normalized."""
    engine = get_engine()
    with engine.begin() as conn:
        rows = conn.execute(
            text(
                "SELECT * FROM chat_turn_retrieval WHERE turn_id=:turn_id ORDER BY created_at ASC, rank_no ASC"
            ),
            {"turn_id": turn_id},
        ).fetchall()
    results: List[Dict[str, Any]] = []
    for row in rows:
        data = _row_to_dict(row)
        _parse_json_fields(data, ["item_extra"])
        data["used_in_reasoning"] = bool(data.get("used_in_reasoning"))
        data["used_in_sql"] = bool(data.get("used_in_sql"))
        results.append(data)
    return results


# Metric registry

def _metric_sql_hash(sql_text: str) -> str:
    """Compute a stable hash to detect SQL definition changes.

    MD5 is used purely as a change-detection fingerprint, not for security.
    """
    return hashlib.md5(sql_text.encode("utf-8")).hexdigest()


def create_metric(payload: MetricCreate) -> Dict[str, Any]:
    """Insert a new metric definition; version starts at 1.

    Raises:
        RuntimeError: if the row cannot be read back after insert.
    """
    engine = get_engine()
    now = datetime.utcnow()
    sql_hash = _metric_sql_hash(payload.base_sql)
    params = {
        "metric_code": payload.metric_code,
        "metric_name": payload.metric_name,
        "metric_aliases": _json_dump(payload.metric_aliases),
        "biz_domain": payload.biz_domain,
        "biz_desc": payload.biz_desc,
        "chat_turn_id": payload.chat_turn_id,
        "tech_desc": payload.tech_desc,
        "formula_expr": payload.formula_expr,
        "base_sql": payload.base_sql,
        "time_grain": payload.time_grain,
        "dim_binding": _json_dump(payload.dim_binding),
        "update_strategy": payload.update_strategy,
        "schedule_id": payload.schedule_id,
        "schedule_type": payload.schedule_type,
        "version": 1,
        "is_active": 1 if payload.is_active else 0,
        "sql_hash": sql_hash,
        "created_by": payload.created_by,
        "updated_by": payload.updated_by,
        "created_at": now,
        "updated_at": now,
    }
    with engine.begin() as conn:
        result = conn.execute(
            text(
                """
                INSERT INTO metric_def (
                    metric_code, metric_name, metric_aliases, biz_domain, biz_desc,
                    chat_turn_id, tech_desc, formula_expr, base_sql, time_grain,
                    dim_binding, update_strategy, schedule_id, schedule_type,
                    version, is_active, sql_hash, created_by, updated_by,
                    created_at, updated_at
                ) VALUES (
                    :metric_code, :metric_name, :metric_aliases, :biz_domain, :biz_desc,
                    :chat_turn_id, :tech_desc, :formula_expr, :base_sql, :time_grain,
                    :dim_binding, :update_strategy, :schedule_id, :schedule_type,
                    :version, :is_active, :sql_hash, :created_by, :updated_by,
                    :created_at, :updated_at
                )
                """
            ),
            params,
        )
        metric_id = result.lastrowid
        row = conn.execute(
            text("SELECT * FROM metric_def WHERE id=:id"), {"id": metric_id}
        ).first()
    if not row:
        raise RuntimeError("Failed to create metric definition.")
    data = _row_to_dict(row)
    _parse_json_fields(data, ["metric_aliases", "dim_binding"])
    data["is_active"] = bool(data.get("is_active"))
    return data


def update_metric(metric_id: int, payload: MetricUpdate) -> Dict[str, Any]:
    """Update mutable fields of a metric definition and refresh sql_hash when needed.

    Raises:
        KeyError: if the metric does not exist.
    """
    updates: Dict[str, Any] = {}
    # Scalar fields copied through verbatim when present.
    for field in (
        "metric_name",
        "biz_domain",
        "biz_desc",
        "tech_desc",
        "formula_expr",
        "base_sql",
        "time_grain",
        "update_strategy",
        "schedule_id",
        "schedule_type",
        "updated_by",
    ):
        value = getattr(payload, field)
        if value is not None:
            updates[field] = value
    if payload.metric_aliases is not None:
        updates["metric_aliases"] = _json_dump(payload.metric_aliases)
    if payload.dim_binding is not None:
        updates["dim_binding"] = _json_dump(payload.dim_binding)
    if payload.is_active is not None:
        updates["is_active"] = 1 if payload.is_active else 0
    if payload.base_sql is not None:
        # SQL definition changed: recompute the fingerprint alongside it.
        updates["sql_hash"] = _metric_sql_hash(payload.base_sql)
    if not updates:
        current = get_metric(metric_id)
        if not current:
            raise KeyError(f"Metric {metric_id} not found.")
        return current
    updates["updated_at"] = datetime.utcnow()
    # Safe to interpolate: keys come from the fixed whitelist above.
    set_clause = ", ".join(f"{key}=:{key}" for key in updates.keys())
    params = dict(updates)
    params["id"] = metric_id
    engine = get_engine()
    with engine.begin() as conn:
        conn.execute(
            text(f"UPDATE metric_def SET {set_clause} WHERE id=:id"),
            params,
        )
        row = conn.execute(
            text("SELECT * FROM metric_def WHERE id=:id"), {"id": metric_id}
        ).first()
    if not row:
        raise KeyError(f"Metric {metric_id} not found.")
    data = _row_to_dict(row)
    _parse_json_fields(data, ["metric_aliases", "dim_binding"])
    data["is_active"] = bool(data.get("is_active"))
    return data


def get_metric(metric_id: int) -> Optional[Dict[str, Any]]:
    """Fetch one metric definition by id, or None if missing."""
    engine = get_engine()
    with engine.begin() as conn:
        row = conn.execute(
            text("SELECT * FROM metric_def WHERE id=:id"), {"id": metric_id}
        ).first()
    if not row:
        return None
    data = _row_to_dict(row)
    _parse_json_fields(data, ["metric_aliases", "dim_binding"])
    data["is_active"] = bool(data.get("is_active"))
    return data


def list_metrics(
    *,
    biz_domain: Optional[str] = None,
    is_active: Optional[bool] = None,
    keyword: Optional[str] = None,
    limit: int = 100,
    offset: int = 0,
) -> List[Dict[str, Any]]:
    """List metric definitions with simple filters and pagination."""
    conditions = []
    params: Dict[str, Any] = {"limit": limit, "offset": offset}
    if biz_domain:
        conditions.append("biz_domain=:biz_domain")
        params["biz_domain"] = biz_domain
    if is_active is not None:
        conditions.append("is_active=:is_active")
        params["is_active"] = 1 if is_active else 0
    if keyword:
        conditions.append("(metric_code LIKE :kw OR metric_name LIKE :kw)")
        params["kw"] = f"%{keyword}%"
    where_clause = "WHERE " + " AND ".join(conditions) if conditions else ""
    engine = get_engine()
    with engine.begin() as conn:
        rows = conn.execute(
            text(
                f"SELECT * FROM metric_def {where_clause} "
                "ORDER BY updated_at DESC LIMIT :limit OFFSET :offset"
            ),
            params,
        ).fetchall()
    results: List[Dict[str, Any]] = []
    for row in rows:
        data = _row_to_dict(row)
        _parse_json_fields(data, ["metric_aliases", "dim_binding"])
        data["is_active"] = bool(data.get("is_active"))
        results.append(data)
    return results


# Metric schedules

def create_metric_schedule(payload: MetricScheduleCreate) -> Dict[str, Any]:
    """Create a schedule record for a metric.

    Raises:
        RuntimeError: if the row cannot be read back after insert.
    """
    engine = get_engine()
    params = {
        "metric_id": payload.metric_id,
        "cron_expr": payload.cron_expr,
        "enabled": 1 if payload.enabled else 0,
        "priority": payload.priority,
        "backfill_allowed": 1 if payload.backfill_allowed else 0,
        "max_runtime_sec": payload.max_runtime_sec,
        "retry_times": payload.retry_times,
        "owner_team": payload.owner_team,
        "owner_user_id": payload.owner_user_id,
    }
    with engine.begin() as conn:
        result = conn.execute(
            text(
                """
                INSERT INTO metric_schedule (
                    metric_id, cron_expr, enabled, priority, backfill_allowed,
                    max_runtime_sec, retry_times, owner_team, owner_user_id
                ) VALUES (
                    :metric_id, :cron_expr, :enabled, :priority, :backfill_allowed,
                    :max_runtime_sec, :retry_times, :owner_team, :owner_user_id
                )
                """
            ),
            params,
        )
        schedule_id = result.lastrowid
        row = conn.execute(
            text("SELECT * FROM metric_schedule WHERE id=:id"), {"id": schedule_id}
        ).first()
    if not row:
        raise RuntimeError("Failed to create metric schedule.")
    data = _row_to_dict(row)
    data["enabled"] = bool(data.get("enabled"))
    data["backfill_allowed"] = bool(data.get("backfill_allowed"))
    return data


def update_metric_schedule(schedule_id: int, payload: MetricScheduleUpdate) -> Dict[str, Any]:
    """Patch selected fields of a schedule.

    Raises:
        KeyError: if the schedule does not exist.
    """
    updates: Dict[str, Any] = {}
    for field in (
        "cron_expr",
        "priority",
        "max_runtime_sec",
        "retry_times",
        "owner_team",
        "owner_user_id",
    ):
        value = getattr(payload, field)
        if value is not None:
            updates[field] = value
    if payload.enabled is not None:
        updates["enabled"] = 1 if payload.enabled else 0
    if payload.backfill_allowed is not None:
        updates["backfill_allowed"] = 1 if payload.backfill_allowed else 0
    if not updates:
        current = list_schedules_for_metric(schedule_id=schedule_id)
        if current:
            return current[0]
        raise KeyError(f"Schedule {schedule_id} not found.")
    # Safe to interpolate: keys come from the fixed whitelist above.
    set_clause = ", ".join(f"{key}=:{key}" for key in updates.keys())
    params = dict(updates)
    params["id"] = schedule_id
    engine = get_engine()
    with engine.begin() as conn:
        conn.execute(
            text(f"UPDATE metric_schedule SET {set_clause} WHERE id=:id"),
            params,
        )
        row = conn.execute(
            text("SELECT * FROM metric_schedule WHERE id=:id"), {"id": schedule_id}
        ).first()
    if not row:
        raise KeyError(f"Schedule {schedule_id} not found.")
    data = _row_to_dict(row)
    data["enabled"] = bool(data.get("enabled"))
    data["backfill_allowed"] = bool(data.get("backfill_allowed"))
    return data


def list_schedules_for_metric(metric_id: Optional[int] = None, schedule_id: Optional[int] = None) -> List[Dict[str, Any]]:
    """List schedules filtered by metric_id and/or schedule id."""
    conditions = []
    params: Dict[str, Any] = {}
    if metric_id is not None:
        conditions.append("metric_id=:metric_id")
        params["metric_id"] = metric_id
    if schedule_id is not None:
        conditions.append("id=:id")
        params["id"] = schedule_id
    where_clause = "WHERE " + " AND ".join(conditions) if conditions else ""
    engine = get_engine()
    with engine.begin() as conn:
        rows = conn.execute(
            text(f"SELECT * FROM metric_schedule {where_clause} ORDER BY id DESC"),
            params,
        ).fetchall()
    results: List[Dict[str, Any]] = []
    for row in rows:
        data = _row_to_dict(row)
        data["enabled"] = bool(data.get("enabled"))
        data["backfill_allowed"] = bool(data.get("backfill_allowed"))
        results.append(data)
    return results


# Metric runs

def trigger_metric_run(payload: MetricRunTrigger) -> Dict[str, Any]:
    """Create a metric_job_run entry; execution is orchestrated elsewhere.

    Snapshots the metric's current version/SQL unless the payload overrides
    them, and inserts the run in RUNNING status.

    Raises:
        KeyError: if the metric does not exist.
        RuntimeError: if the row cannot be read back after insert.
    """
    metric = get_metric(payload.metric_id)
    if not metric:
        raise KeyError(f"Metric {payload.metric_id} not found.")
    metric_version = payload.metric_version or metric.get("version", 1)
    base_sql_snapshot = payload.base_sql_snapshot or metric.get("base_sql")
    triggered_at = payload.triggered_at or datetime.utcnow()
    params = {
        "metric_id": payload.metric_id,
        "schedule_id": payload.schedule_id,
        "source_turn_id": payload.source_turn_id,
        "data_time_from": payload.data_time_from,
        "data_time_to": payload.data_time_to,
        "metric_version": metric_version,
        "base_sql_snapshot": base_sql_snapshot,
        "status": "RUNNING",
        "error_msg": None,
        "affected_rows": None,
        "runtime_ms": None,
        "triggered_by": payload.triggered_by,
        "triggered_at": triggered_at,
        "started_at": None,
        "finished_at": None,
    }
    engine = get_engine()
    with engine.begin() as conn:
        result = conn.execute(
            text(
                """
                INSERT INTO metric_job_run (
                    metric_id, schedule_id, source_turn_id, data_time_from,
                    data_time_to, metric_version, base_sql_snapshot, status,
                    error_msg, affected_rows, runtime_ms, triggered_by,
                    triggered_at, started_at, finished_at
                ) VALUES (
                    :metric_id, :schedule_id, :source_turn_id, :data_time_from,
                    :data_time_to, :metric_version, :base_sql_snapshot, :status,
                    :error_msg, :affected_rows, :runtime_ms, :triggered_by,
                    :triggered_at, :started_at, :finished_at
                )
                """
            ),
            params,
        )
        run_id = result.lastrowid
        row = conn.execute(
            text("SELECT * FROM metric_job_run WHERE id=:id"), {"id": run_id}
        ).first()
    if not row:
        raise RuntimeError("Failed to create metric job run.")
    return _row_to_dict(row)


def get_metric_run(run_id: int) -> Optional[Dict[str, Any]]:
    """Fetch one metric job run by id, or None if missing."""
    engine = get_engine()
    with engine.begin() as conn:
        row = conn.execute(
            text("SELECT * FROM metric_job_run WHERE id=:id"), {"id": run_id}
        ).first()
    if not row:
        return None
    return _row_to_dict(row)


def list_metric_runs(
    *,
    metric_id: Optional[int] = None,
    status: Optional[str] = None,
    limit: int = 100,
    offset: int = 0,
) -> List[Dict[str, Any]]:
    """List metric job runs with optional filters, newest first."""
    conditions = []
    params: Dict[str, Any] = {"limit": limit, "offset": offset}
    if metric_id is not None:
        conditions.append("metric_id=:metric_id")
        params["metric_id"] = metric_id
    if status is not None:
        conditions.append("status=:status")
        params["status"] = status
    where_clause = "WHERE " + " AND ".join(conditions) if conditions else ""
    engine = get_engine()
    with engine.begin() as conn:
        rows = conn.execute(
            text(
                f"SELECT * FROM metric_job_run {where_clause} "
                "ORDER BY triggered_at DESC LIMIT :limit OFFSET :offset"
            ),
            params,
        ).fetchall()
    return [_row_to_dict(row) for row in rows]


# Metric results

def write_metric_results(payload: MetricResultsWriteRequest) -> int:
    """Bulk insert metric_result rows for a metric/version.

    Missing per-item version/load_time default to the metric's current
    version and "now". Returns the number of rows written.

    Raises:
        KeyError: if the metric does not exist.
    """
    metric = get_metric(payload.metric_id)
    if not metric:
        raise KeyError(f"Metric {payload.metric_id} not found.")
    default_version = metric.get("version", 1)
    now = datetime.utcnow()
    rows: List[Dict[str, Any]] = []
    for item in payload.results:
        rows.append(
            {
                "metric_id": payload.metric_id,
                "metric_version": item.metric_version or default_version,
                "stat_time": item.stat_time,
                "extra_dims": _json_dump(item.extra_dims),
                "metric_value": item.metric_value,
                "load_time": item.load_time or now,
                "data_version": item.data_version,
            }
        )
    if not rows:
        return 0
    engine = get_engine()
    with engine.begin() as conn:
        conn.execute(
            text(
                """
                INSERT INTO metric_result (
                    metric_id, metric_version, stat_time, extra_dims,
                    metric_value, load_time, data_version
                ) VALUES (
                    :metric_id, :metric_version, :stat_time, :extra_dims,
                    :metric_value, :load_time, :data_version
                )
                """
            ),
            rows,
        )
    return len(rows)


def query_metric_results(
    *,
    metric_id: int,
    stat_from: Optional[datetime] = None,
    stat_to: Optional[datetime] = None,
    limit: int = 200,
    offset: int = 0,
) -> List[Dict[str, Any]]:
    """Query result rows for a metric in a stat_time window, newest first."""
    conditions = ["metric_id=:metric_id"]
    params: Dict[str, Any] = {
        "metric_id": metric_id,
        "limit": limit,
        "offset": offset,
    }
    if stat_from is not None:
        conditions.append("stat_time>=:stat_from")
        params["stat_from"] = stat_from
    if stat_to is not None:
        conditions.append("stat_time<=:stat_to")
        params["stat_to"] = stat_to
    where_clause = "WHERE " + " AND ".join(conditions)
    engine = get_engine()
    with engine.begin() as conn:
        rows = conn.execute(
            text(
                f"SELECT * FROM metric_result {where_clause} "
                "ORDER BY stat_time DESC LIMIT :limit OFFSET :offset"
            ),
            params,
        ).fetchall()
    results: List[Dict[str, Any]] = []
    for row in rows:
        data = _row_to_dict(row)
        _parse_json_fields(data, ["extra_dims"])
        results.append(data)
    return results


def latest_metric_result(metric_id: int) -> Optional[Dict[str, Any]]:
    """Fetch the most recent result row for a metric, or None if none exist."""
    engine = get_engine()
    with engine.begin() as conn:
        row = conn.execute(
            text(
                """
                SELECT * FROM metric_result
                WHERE metric_id=:metric_id
                ORDER BY stat_time DESC
                LIMIT 1
                """
            ),
            {"metric_id": metric_id},
        ).first()
    if not row:
        return None
    data = _row_to_dict(row)
    _parse_json_fields(data, ["extra_dims"])
    return data