Files
data-ge/app/services/table_snippet.py
2025-12-09 00:36:02 +08:00

628 lines
22 KiB
Python

from __future__ import annotations
import hashlib
import json
import logging
from datetime import datetime
from typing import Any, Dict, List, Optional, Sequence, Tuple
from sqlalchemy import text
from sqlalchemy.engine import Engine
from sqlalchemy.exc import SQLAlchemyError
from app.db import get_engine
from app.models import ActionType, TableSnippetUpsertRequest, TableSnippetUpsertResponse
from app.schemas.rag import RagItemPayload
from app.services.rag_client import RagAPIClient
# Module-level logger used by the functions in this module.
logger = logging.getLogger(__name__)
def _serialize_json(value: Any) -> Tuple[str | None, int | None]:
    """Serialize *value* to JSON text and report its UTF-8 size.

    Strings are treated as already-serialized JSON and passed through unchanged;
    any other value is encoded with ``json.dumps(..., ensure_ascii=False)``.

    Returns:
        ``(text, byte_length)``, or ``(None, None)`` when *value* is ``None``.
    """
    logger.debug("Serializing JSON payload: %s", value)
    if value is None:
        return None, None
    text_form = value if isinstance(value, str) else json.dumps(value, ensure_ascii=False)
    return text_form, len(text_form.encode("utf-8"))
def _prepare_table_schema(value: Any) -> str:
    """Return the table schema as JSON text.

    A string is assumed to already be JSON and is returned as-is; anything else is
    serialized with ``json.dumps(..., ensure_ascii=False)``.
    """
    logger.debug("Preparing table_schema payload.")
    return value if isinstance(value, str) else json.dumps(value, ensure_ascii=False)
def _prepare_model_params(params: Dict[str, Any] | None) -> str | None:
    """Serialize model parameters to JSON text, or return ``None`` when empty/absent."""
    if params:
        return _serialize_json(params)[0]
    return None
def _collect_common_columns(request: TableSnippetUpsertRequest) -> Dict[str, Any]:
    """Build the ``action_results`` column mapping shared by every action type.

    Every action-specific payload column is present with a ``None`` value so that
    the generated upsert writes NULL to it unless the action-specific step fills it
    in afterwards. Optional request fields (``llm_usage``, error fields, timing
    fields, ``result_checksum``) are only included when they are not ``None``.

    Args:
        request: The validated upsert request.

    Returns:
        A new dict mapping column name to value; its key order determines the column
        order of the generated SQL.
    """
    logger.debug(
        "Collecting common columns for table_id=%s version_ts=%s action_type=%s",
        request.table_id,
        request.version_ts,
        request.action_type,
    )
    columns: Dict[str, Any] = dict(
        table_id=request.table_id,
        version_ts=request.version_ts,
        action_type=request.action_type.value,
        status=request.status.value,
        callback_url=str(request.callback_url),
        table_schema_version_id=request.table_schema_version_id,
        table_schema=_prepare_table_schema(request.table_schema),
        model=request.model,
        model_provider=request.model_provider,
    )
    # Action-specific columns start as None; the action-specific step overwrites
    # only the ones relevant to the request's action type.
    columns.update(
        dict.fromkeys(
            (
                "ge_profiling_json",
                "ge_profiling_json_size_bytes",
                "ge_profiling_summary",
                "ge_profiling_summary_size_bytes",
                "ge_profiling_total_size_bytes",
                "ge_profiling_html_report_url",
                "ge_result_desc_json",
                "ge_result_desc_json_size_bytes",
                "snippet_json",
                "snippet_json_size_bytes",
                "snippet_alias_json",
                "snippet_alias_json_size_bytes",
            )
        )
    )
    columns["model_params"] = _prepare_model_params(request.model_params)

    if request.llm_usage is not None:
        usage_text = _serialize_json(request.llm_usage)[0]
        if usage_text is not None:
            columns["llm_usage"] = usage_text

    if request.error_code is not None:
        logger.debug("Adding error_code: %s", request.error_code)
        columns["error_code"] = request.error_code
    if request.error_message is not None:
        logger.debug("Adding error_message: %s", request.error_message)
        columns["error_message"] = request.error_message

    # Plain pass-through fields, copied only when provided.
    for optional_name in ("started_at", "finished_at", "duration_ms", "result_checksum"):
        optional_value = getattr(request, optional_name)
        if optional_value is not None:
            columns[optional_name] = optional_value

    logger.debug("Collected common payload: %s", columns)
    return columns
def _apply_action_payload(
    request: TableSnippetUpsertRequest,
    payload: Dict[str, Any],
) -> None:
    """Fill the action-specific columns of *payload* in place.

    ``GE_PROFILING`` writes the full and summary profiling JSON plus their byte
    sizes, a total size (explicit from the request, else the sum of the computed
    sizes) and the HTML report URL. ``GE_RESULT_DESC``, ``SNIPPET`` and
    ``SNIPPET_ALIAS`` each write one JSON column and its ``*_size_bytes`` column,
    reading the request attribute of the same name as the column.

    Raises:
        ValueError: If ``request.action_type`` is none of the supported types.
    """
    action = request.action_type
    logger.debug("Applying action-specific payload for action_type=%s", action)

    if action == ActionType.GE_PROFILING:
        profile_text, profile_bytes = _serialize_json(request.ge_profiling_json)
        summary_text, summary_bytes = _serialize_json(request.ge_profiling_summary)
        if profile_text is not None:
            payload["ge_profiling_json"] = profile_text
            payload["ge_profiling_json_size_bytes"] = profile_bytes
        if summary_text is not None:
            payload["ge_profiling_summary"] = summary_text
            payload["ge_profiling_summary_size_bytes"] = summary_bytes
        explicit_total = request.ge_profiling_total_size_bytes
        if explicit_total is not None:
            payload["ge_profiling_total_size_bytes"] = explicit_total
        elif profile_bytes is not None or summary_bytes is not None:
            payload["ge_profiling_total_size_bytes"] = (profile_bytes or 0) + (summary_bytes or 0)
        report_url = request.ge_profiling_html_report_url
        if report_url:
            payload["ge_profiling_html_report_url"] = report_url
    else:
        # Actions that carry a single JSON document; the request attribute and the
        # target column share the same name.
        single_document_actions = (
            (ActionType.GE_RESULT_DESC, "ge_result_desc_json"),
            (ActionType.SNIPPET, "snippet_json"),
            (ActionType.SNIPPET_ALIAS, "snippet_alias_json"),
        )
        for candidate, column in single_document_actions:
            if action == candidate:
                document_text, document_bytes = _serialize_json(getattr(request, column))
                if document_text is not None:
                    payload[column] = document_text
                    payload[f"{column}_size_bytes"] = document_bytes
                break
        else:
            logger.error("Unsupported action type encountered: %s", action)
            raise ValueError(f"Unsupported action type '{action}'.")

    logger.debug("Payload after applying action-specific data: %s", payload)
def _build_insert_statement(columns: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]:
    """Render an ``INSERT ... ON DUPLICATE KEY UPDATE`` for ``action_results``.

    Every column except ``table_id``, ``version_ts`` and ``action_type`` is
    refreshed from the inserted values on conflict, and ``updated_at`` is bumped.
    Column names are interpolated directly into the SQL, so callers must only pass
    trusted, fixed column names; the values are bound as named parameters.

    Returns:
        ``(sql, columns)``; the second element is the same dict, for use as bind
        parameters.
    """
    names = list(columns)
    logger.debug("Building insert statement for columns: %s", names)
    key_columns = {"table_id", "version_ts", "action_type"}
    assignments = [f"{name}=VALUES({name})" for name in names if name not in key_columns]
    assignments.append("updated_at=CURRENT_TIMESTAMP")
    column_sql = ", ".join(names)
    placeholder_sql = ", ".join(f":{name}" for name in names)
    update_sql = ", ".join(assignments)
    sql = (
        f"INSERT INTO action_results ({column_sql}) VALUES ({placeholder_sql}) "
        f"ON DUPLICATE KEY UPDATE {update_sql}"
    )
    logger.debug("Generated SQL: %s", sql)
    return sql, columns
def _execute_upsert(engine: Engine, sql: str, params: Dict[str, Any]) -> int:
    """Run *sql* with *params* in its own transaction and return the affected-row count."""
    logger.info(
        "Executing upsert for table_id=%s version_ts=%s action_type=%s",
        params.get("table_id"),
        params.get("version_ts"),
        params.get("action_type"),
    )
    with engine.begin() as conn:
        affected = conn.execute(text(sql), params).rowcount
        logger.info("Rows affected: %s", affected)
    return affected
def upsert_action_result(request: TableSnippetUpsertRequest) -> TableSnippetUpsertResponse:
    """Insert or update the ``action_results`` row for one (table, version, action).

    Builds the column set from *request* (common columns plus the action-specific
    payload) and executes an ``INSERT ... ON DUPLICATE KEY UPDATE``.

    Args:
        request: The validated upsert request.

    Returns:
        A ``TableSnippetUpsertResponse`` echoing the request identifiers and status.
        Its ``updated`` flag is set when the database reported more than one
        affected row.

    Raises:
        ValueError: If ``request.action_type`` is not a supported action type.
        RuntimeError: If the database operation fails (wraps ``SQLAlchemyError``).
    """
    logger.info(
        "Received upsert request: table_id=%s version_ts=%s action_type=%s status=%s",
        request.table_id,
        request.version_ts,
        request.action_type,
        request.status,
    )
    # model_dump() copies the entire request, including the (possibly large)
    # profiling/snippet JSON payloads. Logging arguments are evaluated eagerly, so
    # only pay that cost when DEBUG output is actually enabled.
    if logger.isEnabledFor(logging.DEBUG):
        logger.debug("Request payload: %s", request.model_dump())

    columns = _collect_common_columns(request)
    _apply_action_payload(request, columns)
    sql, params = _build_insert_statement(columns)
    logger.debug("Final SQL params: %s", params)

    engine = get_engine()
    try:
        rowcount = _execute_upsert(engine, sql, params)
    except SQLAlchemyError as exc:
        logger.exception(
            "Failed to upsert action result: table_id=%s version_ts=%s action_type=%s",
            request.table_id,
            request.version_ts,
            request.action_type,
        )
        raise RuntimeError(f"Database operation failed: {exc}") from exc

    # Per the MySQL manual, ON DUPLICATE KEY UPDATE reports 1 affected row for a
    # fresh insert and 2 when an existing row was changed. A no-op update reports
    # 0, or 1 with CLIENT_FOUND_ROWS, so it is not counted as "updated".
    updated = rowcount > 1
    return TableSnippetUpsertResponse(
        table_id=request.table_id,
        version_ts=request.version_ts,
        action_type=request.action_type,
        status=request.status,
        updated=updated,
    )
def _decode_json_field(value: Any) -> Any:
    """Decode a JSON column value returned by the database driver.

    Depending on the driver/column type, JSON columns may arrive already decoded
    (``dict``/``list``), as ``bytes``/``bytearray``, or as ``str``.

    Returns:
        The decoded value; ``dict``/``list`` inputs are returned unchanged. Returns
        ``None`` for ``None``, for bytes that are not valid UTF-8, for text that is
        not valid JSON, and for any other input type. Decoding failures are logged
        as warnings (with the offending value truncated) rather than raised.
    """
    # Cap how much of an undecodable value is echoed into the log; these columns
    # may hold large JSON documents.
    preview_chars = 200
    if value is None:
        return None
    if isinstance(value, (dict, list)):
        return value
    if isinstance(value, (bytes, bytearray)):
        try:
            value = value.decode("utf-8")
        except UnicodeDecodeError:
            # Previously swallowed silently by a broad ``except Exception``; this is
            # the only error a strict UTF-8 decode can raise.
            logger.warning(
                "Failed to decode JSON field as UTF-8: %r", bytes(value[:preview_chars])
            )
            return None
    if isinstance(value, str):
        try:
            return json.loads(value)
        except json.JSONDecodeError:
            logger.warning("Failed to decode JSON field: %s", value[:preview_chars])
            return None
    return None
def _coerce_json_array(value: Any) -> List[Any]:
    """Decode *value* as JSON and return it only when it is an array, else ``[]``."""
    decoded = _decode_json_field(value)
    if isinstance(decoded, list):
        return decoded
    return []
def _fetch_action_payload(
    engine: Engine, table_id: int, version_ts: int, action_type: ActionType
) -> Optional[Dict[str, Any]]:
    """Load the best ``action_results`` row for one (table, version, action).

    Only rows with status ``success`` or ``partial`` are considered; ``success``
    rows win, then the most recently updated row.

    Returns:
        A dict with ``action_result_id``, ``snippet_json``, ``snippet_alias_json``,
        ``updated_at`` and ``status``, or ``None`` when no row matches.
    """
    query = text(
        """
        SELECT id AS action_result_id, snippet_json, snippet_alias_json, updated_at, status
        FROM action_results
        WHERE table_id = :table_id
        AND version_ts = :version_ts
        AND action_type = :action_type
        AND status IN ('success', 'partial')
        ORDER BY CASE status WHEN 'success' THEN 0 ELSE 1 END, updated_at DESC
        LIMIT 1
        """
    )
    bind_params = {
        "table_id": table_id,
        "version_ts": version_ts,
        "action_type": action_type.value,
    }
    with engine.connect() as conn:
        found = conn.execute(query, bind_params).mappings().first()
    if not found:
        return None
    return dict(found)
def _load_snippet_sources(
    engine: Engine, table_id: int, version_ts: int
) -> Tuple[List[Any], List[Any], Optional[datetime], Optional[int], Optional[int]]:
    """Fetch the SNIPPET_ALIAS and SNIPPET action rows and extract their payloads.

    Returns:
        ``(snippets, aliases, updated_at, alias_action_id, snippet_action_id)``.
        ``snippets`` comes from the alias row's ``snippet_json`` when that is a
        non-empty array, otherwise from the SNIPPET row. ``aliases`` always comes
        from the alias row. When the snippets fall back to the SNIPPET row, a
        missing ``updated_at`` or alias action id is filled from that row too.
    """
    alias_row = _fetch_action_payload(engine, table_id, version_ts, ActionType.SNIPPET_ALIAS)
    snippet_row = _fetch_action_payload(engine, table_id, version_ts, ActionType.SNIPPET)
    alias_fields: Dict[str, Any] = alias_row or {}

    snippets = _coerce_json_array(alias_fields.get("snippet_json"))
    aliases = _coerce_json_array(alias_fields.get("snippet_alias_json"))
    updated_at: Optional[datetime] = alias_fields.get("updated_at")
    alias_action_id: Optional[int] = alias_fields.get("action_result_id")
    snippet_action_id: Optional[int] = (snippet_row or {}).get("action_result_id")

    if snippet_row and not snippets:
        snippets = _coerce_json_array(snippet_row.get("snippet_json"))
        if updated_at is None:
            updated_at = snippet_row.get("updated_at")
        if alias_action_id is None:
            alias_action_id = snippet_action_id
    if alias_row and not updated_at:
        updated_at = alias_row.get("updated_at")
    return snippets, aliases, updated_at, alias_action_id, snippet_action_id
def _normalize_aliases(raw_aliases: Any) -> List[Dict[str, Any]]:
    """Normalize a raw alias payload into a de-duplicated list of alias dicts.

    *raw_aliases* must be a list; anything else yields ``[]``. Items may be
    ``{"text": ..., "tone": ...}`` dicts or bare strings. Items whose text is
    missing, empty or not a string are skipped, and only the first occurrence of
    each text is kept (order preserved).

    Returns:
        Dict items become ``{"text": ..., "tone": ...}`` (``tone`` is ``None`` when
        absent); bare-string items become ``{"text": ...}``.
    """
    aliases: List[Dict[str, Any]] = []
    if not raw_aliases or not isinstance(raw_aliases, list):
        return aliases
    seen: set[str] = set()
    for item in raw_aliases:
        if isinstance(item, dict):
            text_val = item.get("text")
            # Non-string text (e.g. a nested list/dict from model output) is not a
            # usable alias and, if unhashable, would crash the ``seen`` lookup.
            if not isinstance(text_val, str) or not text_val or text_val in seen:
                continue
            seen.add(text_val)
            aliases.append({"text": text_val, "tone": item.get("tone")})
        elif isinstance(item, str):
            # Mirror the dict branch: an empty string carries no alias.
            if not item or item in seen:
                continue
            seen.add(item)
            aliases.append({"text": item})
    return aliases
def _normalize_str_list(values: Any) -> List[str]:
    """Return the string items of a list, de-duplicated in first-seen order.

    Non-list inputs (including falsy values) yield ``[]``; non-string items are
    dropped.
    """
    if not values or not isinstance(values, list):
        return []
    # dict.fromkeys keeps insertion order, so this is an ordered de-duplication.
    return list(dict.fromkeys(entry for entry in values if isinstance(entry, str)))
def _merge_alias_lists(primary: List[Dict[str, Any]], secondary: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Concatenate two alias lists, keeping the first alias seen for each text.

    Non-dict entries and entries with empty/missing ``text`` are dropped. Each
    output entry is a fresh ``{"text": ..., "tone": ...}`` dict.
    """
    by_text: Dict[Any, Dict[str, Any]] = {}
    for entry in [*primary, *secondary]:
        if not isinstance(entry, dict):
            continue
        label = entry.get("text")
        if label and label not in by_text:
            by_text[label] = {"text": label, "tone": entry.get("tone")}
    return list(by_text.values())
def _merge_str_lists(primary: List[str], secondary: List[str]) -> List[str]:
    """Concatenate two lists and drop repeats, keeping first-seen order."""
    return list(dict.fromkeys([*primary, *secondary]))
def _build_alias_map(alias_payload: List[Any]) -> Dict[str, Dict[str, Any]]:
    """Group alias entries by snippet ``id`` and merge repeated ids.

    Entries that are not dicts or have a falsy ``id`` are ignored. For each id the
    result holds normalized, de-duplicated ``aliases``, ``keywords`` and
    ``intent_tags`` accumulated in payload order.
    """
    alias_map: Dict[str, Dict[str, Any]] = {}
    for entry in alias_payload:
        if not isinstance(entry, dict):
            continue
        entry_id = entry.get("id")
        if not entry_id:
            continue
        bucket = alias_map.get(entry_id)
        if bucket is None:
            bucket = {"aliases": [], "keywords": [], "intent_tags": []}
            alias_map[entry_id] = bucket
        bucket["aliases"] = _merge_alias_lists(
            bucket["aliases"], _normalize_aliases(entry.get("aliases"))
        )
        bucket["keywords"] = _merge_str_lists(
            bucket["keywords"], _normalize_str_list(entry.get("keywords"))
        )
        bucket["intent_tags"] = _merge_str_lists(
            bucket["intent_tags"], _normalize_str_list(entry.get("intent_tags"))
        )
    return alias_map
def merge_snippet_records_from_db(
    table_id: int,
    version_ts: int,
    *,
    engine: Optional[Engine] = None,
) -> List[Dict[str, Any]]:
    """
    Load snippet + snippet_alias JSON from action_results and merge them per snippet id.

    Each snippet dict with a truthy ``id`` becomes a record whose own aliases,
    keywords and intent tags are normalized and then extended with the alias
    payload for the same id. Alias entries whose id matches no snippet are appended
    as ``source="alias_only"`` records, provided at least one action id is known.
    Every record is stamped with ``table_id``, ``version_ts``,
    ``updated_at_from_action``, ``source`` and ``action_result_id``.

    Args:
        table_id: Table whose action results are read.
        version_ts: Version of those action results.
        engine: Engine to query; defaults to ``get_engine()``.

    Returns:
        The merged records, snippets first, then alias-only records.
    """
    engine = engine or get_engine()
    snippets, aliases, updated_at, alias_action_id, snippet_action_id = _load_snippet_sources(
        engine, table_id, version_ts
    )
    alias_map = _build_alias_map(aliases)
    # The alias action's id is preferred; the snippet action's id is the fallback.
    source_action_id = alias_action_id or snippet_action_id

    merged: List[Dict[str, Any]] = []
    covered_ids: set = set()
    for snippet in snippets:
        if not isinstance(snippet, dict):
            continue
        snippet_id = snippet.get("id")
        if not snippet_id:
            continue
        extra = alias_map.get(snippet_id)
        record = dict(snippet)
        merged_aliases = _normalize_aliases(record.get("aliases"))
        merged_keywords = _normalize_str_list(record.get("keywords"))
        merged_intents = _normalize_str_list(record.get("intent_tags"))
        if extra:
            merged_aliases = _merge_alias_lists(merged_aliases, extra["aliases"])
            merged_keywords = _merge_str_lists(merged_keywords, extra["keywords"])
            merged_intents = _merge_str_lists(merged_intents, extra["intent_tags"])
        # Key order matters downstream (the record is JSON-serialized as-is).
        record.update(
            aliases=merged_aliases,
            keywords=merged_keywords,
            intent_tags=merged_intents,
            table_id=table_id,
            version_ts=version_ts,
            updated_at_from_action=updated_at,
            source="snippet",
            action_result_id=source_action_id,
        )
        merged.append(record)
        covered_ids.add(snippet_id)

    # Alias-only records are emitted only when at least one action id is known.
    if alias_action_id is None and snippet_action_id is None:
        return merged
    for alias_id, info in alias_map.items():
        if alias_id in covered_ids:
            continue
        merged.append(
            {
                "id": alias_id,
                "aliases": info["aliases"],
                "keywords": info["keywords"],
                "intent_tags": info["intent_tags"],
                "table_id": table_id,
                "version_ts": version_ts,
                "updated_at_from_action": updated_at,
                "source": "alias_only",
                "action_result_id": source_action_id,
            }
        )
    return merged
def _stable_rag_item_id(table_id: int, version_ts: int, snippet_id: str) -> int:
    """Derive a deterministic RAG item id from (table_id, version_ts, snippet_id).

    Uses the first 8 bytes of the MD5 digest as a big-endian integer, reduced
    modulo 9e18 so the result stays below the signed 64-bit maximum.
    """
    key = f"{table_id}:{version_ts}:{snippet_id}".encode("utf-8")
    prefix = int.from_bytes(hashlib.md5(key).digest()[:8], "big")
    return prefix % 9_000_000_000_000_000_000
def _build_rag_text(snippet: Dict[str, Any]) -> str:
    """Build the deterministic embedding text for a merged snippet record.

    Emits one ``"Label: value"`` line per field, in a fixed order. Lists are
    joined with ``", "`` (falsy items dropped, others ``str()``-ed), dicts are
    JSON-encoded, and fields whose rendered value is falsy are omitted.
    """

    def _render(raw: Any) -> Any:
        if isinstance(raw, list):
            return ", ".join([str(item) for item in raw if item])
        if isinstance(raw, dict):
            return json.dumps(raw, ensure_ascii=False)
        return raw

    alias_texts = [
        alias.get("text") for alias in snippet.get("aliases") or [] if isinstance(alias, dict)
    ]
    labelled_fields = (
        ("Title", snippet.get("title") or snippet.get("id")),
        ("Description", snippet.get("desc")),
        ("Business", snippet.get("business_caliber")),
        ("Type", snippet.get("type")),
        ("Examples", snippet.get("examples") or []),
        ("Aliases", alias_texts),
        ("Keywords", snippet.get("keywords") or []),
        ("IntentTags", snippet.get("intent_tags") or []),
        ("Applicability", snippet.get("applicability")),
        ("DialectSQL", snippet.get("dialect_sql")),
    )
    lines: List[str] = []
    for label, raw_value in labelled_fields:
        rendered = _render(raw_value)
        if rendered:
            lines.append(f"{label}: {rendered}")
    return "\n".join(lines)
def _merged_json_default(value: Any) -> Any:
    """``json.dumps`` fallback for merged snippet records: datetimes become ISO-8601.

    Raises:
        TypeError: For any other non-JSON-serializable object, as ``json.dumps``
            would without a fallback.
    """
    if isinstance(value, datetime):
        return value.isoformat()
    raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")


def _prepare_rag_payloads(
    snippets: List[Dict[str, Any]],
    table_id: int,
    version_ts: int,
    workspace_id: int,
    rag_item_type: str = "SNIPPET",
) -> Tuple[List[Dict[str, Any]], List[RagItemPayload]]:
    """Convert merged snippet records into ``rag_snippet`` rows and RAG API payloads.

    Records without an ``id`` are skipped silently. Records without an
    ``action_result_id`` are skipped with a warning. ``rows`` and ``payloads`` are
    built in lockstep, so they have the same length and order.

    Args:
        snippets: Merged snippet records.
        table_id: Table the snippets belong to.
        version_ts: Version of the snippets.
        workspace_id: RAG workspace the items are pushed to.
        rag_item_type: RAG item type; falls back to ``"SNIPPET"`` when falsy.

    Returns:
        ``(rows, payloads)``: row dicts for the ``rag_snippet`` table and the
        matching ``RagItemPayload`` objects.
    """
    rows: List[Dict[str, Any]] = []
    payloads: List[RagItemPayload] = []
    # Naive UTC fallback for records without a usable updated_at.
    now = datetime.utcnow()
    for snippet in snippets:
        snippet_id = snippet.get("id")
        if not snippet_id:
            continue
        action_result_id = snippet.get("action_result_id")
        if action_result_id is None:
            logger.warning(
                "Skipping snippet without action_result_id for RAG ingestion (table_id=%s version_ts=%s snippet_id=%s)",
                table_id,
                version_ts,
                snippet_id,
            )
            continue
        rag_item_id = _stable_rag_item_id(table_id, version_ts, snippet_id)
        rag_text = _build_rag_text(snippet)
        # Merged records carry ``updated_at_from_action`` as read from the database
        # (a datetime). Plain json.dumps raises TypeError on it, so render
        # datetimes as ISO strings.
        merged_json = json.dumps(snippet, ensure_ascii=False, default=_merged_json_default)
        updated_at_raw = snippet.get("updated_at_from_action") or now
        if isinstance(updated_at_raw, str):
            try:
                updated_at = datetime.fromisoformat(updated_at_raw)
            except ValueError:
                updated_at = now
        else:
            updated_at = updated_at_raw if isinstance(updated_at_raw, datetime) else now
        created_at = updated_at
        row = {
            "rag_item_id": rag_item_id,
            "workspace_id": workspace_id,
            "table_id": table_id,
            "version_ts": version_ts,
            "created_at": created_at,
            "action_result_id": action_result_id,
            "snippet_id": snippet_id,
            "rag_text": rag_text,
            "merged_json": merged_json,
            "updated_at": updated_at,
        }
        rows.append(row)
        payloads.append(
            RagItemPayload(
                id=rag_item_id,
                workspaceId=workspace_id,
                name=snippet.get("title") or snippet_id,
                embeddingData=rag_text,
                type=rag_item_type or "SNIPPET",
            )
        )
    return rows, payloads
def _upsert_rag_snippet_rows(engine: Engine, rows: Sequence[Dict[str, Any]]) -> None:
    """Replace ``rag_snippet`` rows keyed by ``rag_item_id`` in a single transaction.

    For each row, in order, any existing row with the same ``rag_item_id`` is
    deleted and the new row is inserted, so the stored row is fully replaced.
    Does nothing when *rows* is empty.
    """
    if not rows:
        return
    delete_stmt = text("DELETE FROM rag_snippet WHERE rag_item_id=:rag_item_id")
    insert_stmt = text(
        """
        INSERT INTO rag_snippet (
        rag_item_id,
        workspace_id,
        table_id,
        version_ts,
        created_at,
        action_result_id,
        snippet_id,
        rag_text,
        merged_json,
        updated_at
        ) VALUES (
        :rag_item_id,
        :workspace_id,
        :table_id,
        :version_ts,
        :created_at,
        :action_result_id,
        :snippet_id,
        :rag_text,
        :merged_json,
        :updated_at
        )
        """
    )
    with engine.begin() as conn:
        for params in rows:
            # Delete-then-insert per row (not batched) so duplicate ids in *rows*
            # resolve to the last occurrence instead of a key conflict.
            for statement in (delete_stmt, insert_stmt):
                conn.execute(statement, params)
async def ingest_snippet_rag_from_db(
    table_id: int,
    version_ts: int,
    *,
    workspace_id: int,
    rag_item_type: str = "SNIPPET",
    client,
    engine: Optional[Engine] = None,
    rag_client: Optional[RagAPIClient] = None,
) -> List[int]:
    """
    Merge snippet + alias JSON from action_results, persist to rag_snippet, then push to RAG via addBatch.

    Args:
        table_id: Table whose snippet action results are ingested.
        version_ts: Version of those action results.
        workspace_id: RAG workspace the items are pushed to.
        rag_item_type: RAG item type sent with each payload.
        client: Passed through unchanged to ``RagAPIClient.add_batch`` (its type is
            not visible here; presumably an HTTP client — confirm with callers).
        engine: Engine to use; defaults to ``get_engine()``.
        rag_client: RAG API client; a new ``RagAPIClient()`` is created when omitted.

    Returns:
        The ``rag_item_id`` of every ingested item; ``[]`` when nothing was ingested.
    """
    engine = engine or get_engine()
    snippets = merge_snippet_records_from_db(table_id, version_ts, engine=engine)
    if not snippets:
        logger.info(
            "No snippets available for RAG ingestion (table_id=%s version_ts=%s)",
            table_id,
            version_ts,
        )
        return []
    rows, payloads = _prepare_rag_payloads(
        snippets,
        table_id=table_id,
        version_ts=version_ts,
        workspace_id=workspace_id,
        rag_item_type=rag_item_type,
    )
    # Every snippet may have been filtered out (e.g. missing action_result_id);
    # don't call addBatch with an empty batch in that case.
    if not payloads:
        logger.info(
            "No RAG payloads prepared after filtering (table_id=%s version_ts=%s)",
            table_id,
            version_ts,
        )
        return []
    _upsert_rag_snippet_rows(engine, rows)
    rag_client = rag_client or RagAPIClient()
    await rag_client.add_batch(client, payloads)
    return [row["rag_item_id"] for row in rows]