from __future__ import annotations import hashlib import json import logging from datetime import datetime from typing import Any, Dict, List, Optional, Sequence, Tuple from sqlalchemy import text from sqlalchemy.engine import Engine from sqlalchemy.exc import SQLAlchemyError from app.db import get_engine from app.models import ActionType, TableSnippetUpsertRequest, TableSnippetUpsertResponse from app.schemas.rag import RagItemPayload from app.services.rag_client import RagAPIClient logger = logging.getLogger(__name__) def _serialize_json(value: Any) -> Tuple[str | None, int | None]: logger.debug("Serializing JSON payload: %s", value) if value is None: return None, None if isinstance(value, str): encoded = value.encode("utf-8") return value, len(encoded) serialized = json.dumps(value, ensure_ascii=False) encoded = serialized.encode("utf-8") return serialized, len(encoded) def _prepare_table_schema(value: Any) -> str: logger.debug("Preparing table_schema payload.") if isinstance(value, str): return value return json.dumps(value, ensure_ascii=False) def _prepare_model_params(params: Dict[str, Any] | None) -> str | None: if not params: return None serialized, _ = _serialize_json(params) return serialized def _collect_common_columns(request: TableSnippetUpsertRequest) -> Dict[str, Any]: # Build the base column set shared by all action types; action-specific fields are populated later. logger.debug( "Collecting common columns for table_id=%s version_ts=%s action_type=%s", request.table_id, request.version_ts, request.action_type, ) payload: Dict[str, Any] = { "table_id": request.table_id, "version_ts": request.version_ts, "action_type": request.action_type.value, "status": request.status.value, "callback_url": str(request.callback_url), "table_schema_version_id": request.table_schema_version_id, "table_schema": _prepare_table_schema(request.table_schema), "model": request.model, "model_provider": request.model_provider, } payload.update( { "ge_profiling_json": None, "ge_profiling_json_size_bytes": None, "ge_profiling_summary": None, "ge_profiling_summary_size_bytes": None, "ge_profiling_total_size_bytes": None, "ge_profiling_html_report_url": None, "ge_result_desc_json": None, "ge_result_desc_json_size_bytes": None, "snippet_json": None, "snippet_json_size_bytes": None, "snippet_alias_json": None, "snippet_alias_json_size_bytes": None, } ) payload["model_params"] = _prepare_model_params(request.model_params) if request.llm_usage is not None: llm_usage_json, _ = _serialize_json(request.llm_usage) if llm_usage_json is not None: payload["llm_usage"] = llm_usage_json if request.error_code is not None: logger.debug("Adding error_code: %s", request.error_code) payload["error_code"] = request.error_code if request.error_message is not None: logger.debug("Adding error_message: %s", request.error_message) payload["error_message"] = request.error_message if request.started_at is not None: payload["started_at"] = request.started_at if request.finished_at is not None: payload["finished_at"] = request.finished_at if request.duration_ms is not None: payload["duration_ms"] = request.duration_ms if request.result_checksum is not None: payload["result_checksum"] = request.result_checksum logger.debug("Collected common payload: %s", payload) return payload def _apply_action_payload( request: TableSnippetUpsertRequest, payload: Dict[str, Any], ) -> None: logger.debug("Applying action-specific payload for action_type=%s", request.action_type) if request.action_type == ActionType.GE_PROFILING: full_json, full_size = _serialize_json(request.ge_profiling_json) summary_json, summary_size = _serialize_json(request.ge_profiling_summary) if full_json is not None: payload["ge_profiling_json"] = full_json payload["ge_profiling_json_size_bytes"] = full_size if summary_json is not None: payload["ge_profiling_summary"] = summary_json payload["ge_profiling_summary_size_bytes"] = summary_size if request.ge_profiling_total_size_bytes is not None: payload["ge_profiling_total_size_bytes"] = request.ge_profiling_total_size_bytes elif full_size is not None or summary_size is not None: payload["ge_profiling_total_size_bytes"] = (full_size or 0) + (summary_size or 0) if request.ge_profiling_html_report_url: payload["ge_profiling_html_report_url"] = request.ge_profiling_html_report_url elif request.action_type == ActionType.GE_RESULT_DESC: full_json, full_size = _serialize_json(request.ge_result_desc_json) if full_json is not None: payload["ge_result_desc_json"] = full_json payload["ge_result_desc_json_size_bytes"] = full_size elif request.action_type == ActionType.SNIPPET: full_json, full_size = _serialize_json(request.snippet_json) if full_json is not None: payload["snippet_json"] = full_json payload["snippet_json_size_bytes"] = full_size elif request.action_type == ActionType.SNIPPET_ALIAS: full_json, full_size = _serialize_json(request.snippet_alias_json) if full_json is not None: payload["snippet_alias_json"] = full_json payload["snippet_alias_json_size_bytes"] = full_size else: logger.error("Unsupported action type encountered: %s", request.action_type) raise ValueError(f"Unsupported action type '{request.action_type}'.") logger.debug("Payload after applying action-specific data: %s", payload) def _build_insert_statement(columns: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]: logger.debug("Building insert statement for columns: %s", list(columns.keys())) column_names = list(columns.keys()) placeholders = [f":{name}" for name in column_names] update_assignments = [ f"{name}=VALUES({name})" for name in column_names if name not in {"table_id", "version_ts", "action_type"} ] update_assignments.append("updated_at=CURRENT_TIMESTAMP") sql = ( "INSERT INTO action_results ({cols}) VALUES ({vals}) " "ON DUPLICATE KEY UPDATE {updates}" ).format( cols=", ".join(column_names), vals=", ".join(placeholders), updates=", ".join(update_assignments), ) logger.debug("Generated SQL: %s", sql) return sql, columns def _execute_upsert(engine: Engine, sql: str, params: Dict[str, Any]) -> int: logger.info("Executing upsert for table_id=%s version_ts=%s action_type=%s", params.get("table_id"), params.get("version_ts"), params.get("action_type")) with engine.begin() as conn: result = conn.execute(text(sql), params) logger.info("Rows affected: %s", result.rowcount) return result.rowcount def upsert_action_result(request: TableSnippetUpsertRequest) -> TableSnippetUpsertResponse: logger.info( "Received upsert request: table_id=%s version_ts=%s action_type=%s status=%s", request.table_id, request.version_ts, request.action_type, request.status, ) logger.debug("Request payload: %s", request.model_dump()) columns = _collect_common_columns(request) _apply_action_payload(request, columns) sql, params = _build_insert_statement(columns) logger.debug("Final SQL params: %s", params) engine = get_engine() try: rowcount = _execute_upsert(engine, sql, params) except SQLAlchemyError as exc: logger.exception( "Failed to upsert action result: table_id=%s version_ts=%s action_type=%s", request.table_id, request.version_ts, request.action_type, ) raise RuntimeError(f"Database operation failed: {exc}") from exc updated = rowcount > 1 return TableSnippetUpsertResponse( table_id=request.table_id, version_ts=request.version_ts, action_type=request.action_type, status=request.status, updated=updated, ) def _decode_json_field(value: Any) -> Any: """Decode JSON columns that may be returned as str/bytes/dicts/lists.""" if value is None: return None if isinstance(value, (dict, list)): return value if isinstance(value, (bytes, bytearray)): try: value = value.decode("utf-8") except Exception: # pragma: no cover - defensive return None if isinstance(value, str): try: return json.loads(value) except json.JSONDecodeError: logger.warning("Failed to decode JSON field: %s", value) return None return None def _coerce_json_array(value: Any) -> List[Any]: decoded = _decode_json_field(value) return decoded if isinstance(decoded, list) else [] def _fetch_action_payload( engine: Engine, table_id: int, version_ts: int, action_type: ActionType ) -> Optional[Dict[str, Any]]: sql = text( """ SELECT id AS action_result_id, snippet_json, snippet_alias_json, updated_at, status FROM action_results WHERE table_id = :table_id AND version_ts = :version_ts AND action_type = :action_type AND status IN ('success', 'partial') ORDER BY CASE status WHEN 'success' THEN 0 ELSE 1 END, updated_at DESC LIMIT 1 """ ) with engine.connect() as conn: row = conn.execute( sql, { "table_id": table_id, "version_ts": version_ts, "action_type": action_type.value, }, ).mappings().first() return dict(row) if row else None def _load_snippet_sources( engine: Engine, table_id: int, version_ts: int ) -> Tuple[List[Any], List[Any], Optional[datetime], Optional[int], Optional[int]]: alias_row = _fetch_action_payload(engine, table_id, version_ts, ActionType.SNIPPET_ALIAS) snippet_row = _fetch_action_payload(engine, table_id, version_ts, ActionType.SNIPPET) snippet_json = _coerce_json_array(alias_row.get("snippet_json") if alias_row else None) alias_json = _coerce_json_array(alias_row.get("snippet_alias_json") if alias_row else None) updated_at: Optional[datetime] = alias_row.get("updated_at") if alias_row else None alias_action_id: Optional[int] = alias_row.get("action_result_id") if alias_row else None snippet_action_id: Optional[int] = snippet_row.get("action_result_id") if snippet_row else None if not snippet_json and snippet_row: snippet_json = _coerce_json_array(snippet_row.get("snippet_json")) if updated_at is None: updated_at = snippet_row.get("updated_at") if alias_action_id is None: alias_action_id = snippet_action_id if not updated_at and alias_row: updated_at = alias_row.get("updated_at") return snippet_json, alias_json, updated_at, alias_action_id, snippet_action_id def _normalize_aliases(raw_aliases: Any) -> List[Dict[str, Any]]: aliases: List[Dict[str, Any]] = [] seen: set[str] = set() if not raw_aliases: return aliases if not isinstance(raw_aliases, list): return aliases for item in raw_aliases: if isinstance(item, dict): text_val = item.get("text") if not text_val or text_val in seen: continue seen.add(text_val) aliases.append({"text": text_val, "tone": item.get("tone")}) elif isinstance(item, str): if item in seen: continue seen.add(item) aliases.append({"text": item}) return aliases def _normalize_str_list(values: Any) -> List[str]: if not values: return [] if not isinstance(values, list): return [] seen: set[str] = set() normalised: List[str] = [] for val in values: if not isinstance(val, str): continue if val in seen: continue seen.add(val) normalised.append(val) return normalised def _merge_alias_lists(primary: List[Dict[str, Any]], secondary: List[Dict[str, Any]]) -> List[Dict[str, Any]]: merged: List[Dict[str, Any]] = [] seen: set[str] = set() for source in (primary, secondary): for item in source: if not isinstance(item, dict): continue text_val = item.get("text") if not text_val or text_val in seen: continue seen.add(text_val) merged.append({"text": text_val, "tone": item.get("tone")}) return merged def _merge_str_lists(primary: List[str], secondary: List[str]) -> List[str]: merged: List[str] = [] seen: set[str] = set() for source in (primary, secondary): for item in source: if item in seen: continue seen.add(item) merged.append(item) return merged def _build_alias_map(alias_payload: List[Any]) -> Dict[str, Dict[str, Any]]: alias_map: Dict[str, Dict[str, Any]] = {} for item in alias_payload: if not isinstance(item, dict): continue alias_id = item.get("id") if not alias_id: continue existing = alias_map.setdefault( alias_id, {"aliases": [], "keywords": [], "intent_tags": []}, ) existing["aliases"] = _merge_alias_lists( existing["aliases"], _normalize_aliases(item.get("aliases")) ) existing["keywords"] = _merge_str_lists( existing["keywords"], _normalize_str_list(item.get("keywords")) ) existing["intent_tags"] = _merge_str_lists( existing["intent_tags"], _normalize_str_list(item.get("intent_tags")) ) return alias_map def merge_snippet_records_from_db( table_id: int, version_ts: int, *, engine: Optional[Engine] = None, ) -> List[Dict[str, Any]]: """ Load snippet + snippet_alias JSON from action_results after snippet_alias is stored, then merge into a unified snippet object list ready for downstream RAG. """ engine = engine or get_engine() snippets, aliases, updated_at, alias_action_id, snippet_action_id = _load_snippet_sources( engine, table_id, version_ts ) alias_map = _build_alias_map(aliases) merged: List[Dict[str, Any]] = [] seen_ids: set[str] = set() for snippet in snippets: if not isinstance(snippet, dict): continue snippet_id = snippet.get("id") if not snippet_id: continue alias_info = alias_map.get(snippet_id) record = dict(snippet) record_aliases = _normalize_aliases(record.get("aliases")) record_keywords = _normalize_str_list(record.get("keywords")) record_intents = _normalize_str_list(record.get("intent_tags")) if alias_info: record_aliases = _merge_alias_lists(record_aliases, alias_info["aliases"]) record_keywords = _merge_str_lists(record_keywords, alias_info["keywords"]) record_intents = _merge_str_lists(record_intents, alias_info["intent_tags"]) record["aliases"] = record_aliases record["keywords"] = record_keywords record["intent_tags"] = record_intents record["table_id"] = table_id record["version_ts"] = version_ts record["updated_at_from_action"] = updated_at record["source"] = "snippet" record["action_result_id"] = alias_action_id or snippet_action_id merged.append(record) seen_ids.add(snippet_id) for alias_id, alias_info in alias_map.items(): if alias_id in seen_ids: continue if alias_action_id is None and snippet_action_id is None: continue merged.append( { "id": alias_id, "aliases": alias_info["aliases"], "keywords": alias_info["keywords"], "intent_tags": alias_info["intent_tags"], "table_id": table_id, "version_ts": version_ts, "updated_at_from_action": updated_at, "source": "alias_only", "action_result_id": alias_action_id or snippet_action_id, } ) return merged def _stable_rag_item_id(table_id: int, version_ts: int, snippet_id: str) -> int: digest = hashlib.md5(f"{table_id}:{version_ts}:{snippet_id}".encode("utf-8")).hexdigest() return int(digest[:16], 16) % 9_000_000_000_000_000_000 def _to_serializable(value: Any) -> Any: if value is None or isinstance(value, (str, int, float, bool)): return value if isinstance(value, datetime): return value.isoformat() if isinstance(value, dict): return {k: _to_serializable(v) for k, v in value.items()} if isinstance(value, list): return [_to_serializable(v) for v in value] return str(value) def _build_rag_text(snippet: Dict[str, Any]) -> str: # Deterministic text concatenation for embedding input. parts: List[str] = [] def _add(label: str, value: Any) -> None: if value is None: return if isinstance(value, list): value = ", ".join([str(v) for v in value if v]) elif isinstance(value, dict): value = json.dumps(value, ensure_ascii=False) if value: parts.append(f"{label}: {value}") _add("Title", snippet.get("title") or snippet.get("id")) _add("Description", snippet.get("desc")) _add("Business", snippet.get("business_caliber")) _add("Type", snippet.get("type")) _add("Examples", snippet.get("examples") or []) _add("Aliases", [a.get("text") for a in snippet.get("aliases") or [] if isinstance(a, dict)]) _add("Keywords", snippet.get("keywords") or []) _add("IntentTags", snippet.get("intent_tags") or []) _add("Applicability", snippet.get("applicability")) _add("DialectSQL", snippet.get("dialect_sql")) return "\n".join(parts) def _prepare_rag_payloads( snippets: List[Dict[str, Any]], table_id: int, version_ts: int, workspace_id: int, rag_item_type: str = "SNIPPET", ) -> Tuple[List[Dict[str, Any]], List[RagItemPayload]]: rows: List[Dict[str, Any]] = [] payloads: List[RagItemPayload] = [] now = datetime.utcnow() for snippet in snippets: snippet_id = snippet.get("id") if not snippet_id: continue action_result_id = snippet.get("action_result_id") if action_result_id is None: logger.warning( "Skipping snippet without action_result_id for RAG ingestion (table_id=%s version_ts=%s snippet_id=%s)", table_id, version_ts, snippet_id, ) continue rag_item_id = _stable_rag_item_id(table_id, version_ts, snippet_id) rag_text = _build_rag_text(snippet) serializable_snippet = _to_serializable(snippet) merged_json = json.dumps(serializable_snippet, ensure_ascii=False) updated_at_raw = snippet.get("updated_at_from_action") or now if isinstance(updated_at_raw, str): try: updated_at = datetime.fromisoformat(updated_at_raw) except ValueError: updated_at = now else: updated_at = updated_at_raw if isinstance(updated_at_raw, datetime) else now created_at = updated_at row = { "rag_item_id": rag_item_id, "workspace_id": workspace_id, "table_id": table_id, "version_ts": version_ts, "created_at": created_at, "action_result_id": action_result_id, "snippet_id": snippet_id, "rag_text": rag_text, "merged_json": merged_json, "updated_at": updated_at, } rows.append(row) payloads.append( RagItemPayload( id=rag_item_id, workspaceId=workspace_id, name=snippet.get("title") or snippet_id, embeddingData=rag_text, type=rag_item_type or "SNIPPET", ) ) return rows, payloads def _upsert_rag_snippet_rows(engine: Engine, rows: Sequence[Dict[str, Any]]) -> None: if not rows: return delete_sql = text("DELETE FROM rag_snippet WHERE rag_item_id=:rag_item_id") insert_sql = text( """ INSERT INTO rag_snippet ( rag_item_id, workspace_id, table_id, version_ts, created_at, action_result_id, snippet_id, rag_text, merged_json, updated_at ) VALUES ( :rag_item_id, :workspace_id, :table_id, :version_ts, :created_at, :action_result_id, :snippet_id, :rag_text, :merged_json, :updated_at ) """ ) with engine.begin() as conn: for row in rows: conn.execute(delete_sql, row) conn.execute(insert_sql, row) async def ingest_snippet_rag_from_db( table_id: int, version_ts: int, *, workspace_id: int, rag_item_type: str = "SNIPPET", client, engine: Optional[Engine] = None, rag_client: Optional[RagAPIClient] = None, ) -> List[int]: """ Merge snippet + alias JSON from action_results, persist to rag_snippet, then push to RAG via addBatch. Returns list of rag_item_id ingested. """ engine = engine or get_engine() snippets = merge_snippet_records_from_db(table_id, version_ts, engine=engine) if not snippets: logger.info( "No snippets available for RAG ingestion (table_id=%s version_ts=%s)", table_id, version_ts, ) return [] rows, payloads = _prepare_rag_payloads( snippets, table_id=table_id, version_ts=version_ts, workspace_id=workspace_id, rag_item_type=rag_item_type, ) _upsert_rag_snippet_rows(engine, rows) rag_client = rag_client or RagAPIClient() await rag_client.add_batch(client, payloads) return [row["rag_item_id"] for row in rows]