Generate RAG snippets, persist them to the database, and push them to the RAG service

This commit is contained in:
zhaoawd
2025-12-09 00:15:22 +08:00
parent ebd79b75bd
commit 3218e51bad
11 changed files with 1231 additions and 8 deletions

View File

@@ -0,0 +1,83 @@
from __future__ import annotations
import logging
from typing import Any, Sequence
import httpx
from app.exceptions import ProviderAPICallError
from app.schemas.rag import RagDeleteRequest, RagItemPayload, RagRetrieveRequest
from app.settings import RAG_API_AUTH_TOKEN, RAG_API_BASE_URL
logger = logging.getLogger(__name__)
class RagAPIClient:
    """Thin async client wrapper around the RAG endpoints described in doc/rag-api.md."""

    def __init__(self, *, base_url: str | None = None, auth_token: str | None = None) -> None:
        # Explicit arguments win; otherwise fall back to the settings-level defaults.
        self._base_url = (base_url or RAG_API_BASE_URL).rstrip("/")
        self._auth_token = auth_token or RAG_API_AUTH_TOKEN

    def _headers(self) -> dict[str, str]:
        # Bearer auth is attached only when a token is configured.
        if self._auth_token:
            return {
                "Content-Type": "application/json",
                "Authorization": f"Bearer {self._auth_token}",
            }
        return {"Content-Type": "application/json"}

    @staticmethod
    def _dump(model: Any) -> Any:
        # Serialise a pydantic model using its wire-format aliases, dropping None fields.
        return model.model_dump(by_alias=True, exclude_none=True)

    async def _post(
        self,
        client: httpx.AsyncClient,
        path: str,
        payload: Any,
    ) -> Any:
        """POST *payload* as JSON to *path*; raise ProviderAPICallError on any failure."""
        url = f"{self._base_url}{path}"
        try:
            resp = await client.post(url, json=payload, headers=self._headers())
            resp.raise_for_status()
        except httpx.HTTPStatusError as exc:
            status_code = exc.response.status_code if exc.response else None
            response_text = exc.response.text if exc.response else ""
            logger.error(
                "RAG API responded with an error (%s) for %s: %s",
                status_code,
                url,
                response_text,
                exc_info=True,
            )
            raise ProviderAPICallError(
                "RAG API call failed.",
                status_code=status_code,
                response_text=response_text,
            ) from exc
        except httpx.HTTPError as exc:
            # Transport-level failures (DNS, timeout, connect errors) have no response.
            logger.error("Transport error calling RAG API %s: %s", url, exc, exc_info=True)
            raise ProviderAPICallError(f"RAG API call failed: {exc}") from exc
        try:
            return resp.json()
        except ValueError:
            # Some endpoints may answer with plain text; surface it instead of failing.
            logger.warning("RAG API returned non-JSON response for %s; returning raw text.", url)
            return {"raw": resp.text}

    async def add(self, client: httpx.AsyncClient, payload: RagItemPayload) -> Any:
        return await self._post(client, "/rag/add", self._dump(payload))

    async def add_batch(self, client: httpx.AsyncClient, items: Sequence[RagItemPayload]) -> Any:
        return await self._post(client, "/rag/addBatch", [self._dump(item) for item in items])

    async def update(self, client: httpx.AsyncClient, payload: RagItemPayload) -> Any:
        return await self._post(client, "/rag/update", self._dump(payload))

    async def delete(self, client: httpx.AsyncClient, payload: RagDeleteRequest) -> Any:
        return await self._post(client, "/rag/delete", self._dump(payload))

    async def retrieve(self, client: httpx.AsyncClient, payload: RagRetrieveRequest) -> Any:
        return await self._post(client, "/rag/retrieve", self._dump(payload))

View File

@@ -1,19 +1,19 @@
from __future__ import annotations
import hashlib
import json
import logging
from typing import Any, Dict, Tuple
from datetime import datetime
from typing import Any, Dict, List, Optional, Sequence, Tuple
from sqlalchemy import text
from sqlalchemy.engine import Engine
from sqlalchemy.exc import SQLAlchemyError
from app.db import get_engine
from app.models import (
ActionType,
TableSnippetUpsertRequest,
TableSnippetUpsertResponse,
)
from app.models import ActionType, TableSnippetUpsertRequest, TableSnippetUpsertResponse
from app.schemas.rag import RagItemPayload
from app.services.rag_client import RagAPIClient
logger = logging.getLogger(__name__)
@@ -46,6 +46,7 @@ def _prepare_model_params(params: Dict[str, Any] | None) -> str | None:
def _collect_common_columns(request: TableSnippetUpsertRequest) -> Dict[str, Any]:
# Build the base column set shared by all action types; action-specific fields are populated later.
logger.debug(
"Collecting common columns for table_id=%s version_ts=%s action_type=%s",
request.table_id,
@@ -215,3 +216,405 @@ def upsert_action_result(request: TableSnippetUpsertRequest) -> TableSnippetUpse
status=request.status,
updated=updated,
)
def _decode_json_field(value: Any) -> Any:
"""Decode JSON columns that may be returned as str/bytes/dicts/lists."""
if value is None:
return None
if isinstance(value, (dict, list)):
return value
if isinstance(value, (bytes, bytearray)):
try:
value = value.decode("utf-8")
except Exception: # pragma: no cover - defensive
return None
if isinstance(value, str):
try:
return json.loads(value)
except json.JSONDecodeError:
logger.warning("Failed to decode JSON field: %s", value)
return None
return None
def _coerce_json_array(value: Any) -> List[Any]:
    """Decode *value* and return it only when it is a JSON array; otherwise an empty list."""
    decoded = _decode_json_field(value)
    if isinstance(decoded, list):
        return decoded
    return []
def _fetch_action_payload(
    engine: Engine, table_id: int, version_ts: int, action_type: ActionType
) -> Optional[Dict[str, Any]]:
    """Fetch the best action_results row for the given coordinates, or None.

    Only 'success'/'partial' rows qualify; 'success' is preferred, then the
    most recently updated row.
    """
    sql = text(
        """
        SELECT id AS action_result_id, snippet_json, snippet_alias_json, updated_at, status
        FROM action_results
        WHERE table_id = :table_id
          AND version_ts = :version_ts
          AND action_type = :action_type
          AND status IN ('success', 'partial')
        ORDER BY CASE status WHEN 'success' THEN 0 ELSE 1 END, updated_at DESC
        LIMIT 1
        """
    )
    params = {
        "table_id": table_id,
        "version_ts": version_ts,
        "action_type": action_type.value,
    }
    with engine.connect() as conn:
        row = conn.execute(sql, params).mappings().first()
    if row is None:
        return None
    return dict(row)
def _load_snippet_sources(
    engine: Engine, table_id: int, version_ts: int
) -> Tuple[List[Any], List[Any], Optional[datetime], Optional[int], Optional[int]]:
    """Load snippet/alias JSON arrays plus provenance for a (table_id, version_ts) pair.

    The snippet_alias action row is preferred as the source of both arrays
    (presumably it is written after the snippet action and may carry a refreshed
    copy of snippet_json — TODO confirm against the writer); the plain snippet
    action row is the fallback for snippet_json only.

    Returns:
        (snippet_json, alias_json, updated_at, alias_action_id, snippet_action_id)
    """
    alias_row = _fetch_action_payload(engine, table_id, version_ts, ActionType.SNIPPET_ALIAS)
    snippet_row = _fetch_action_payload(engine, table_id, version_ts, ActionType.SNIPPET)

    snippet_json = _coerce_json_array(alias_row.get("snippet_json") if alias_row else None)
    alias_json = _coerce_json_array(alias_row.get("snippet_alias_json") if alias_row else None)
    updated_at: Optional[datetime] = alias_row.get("updated_at") if alias_row else None
    alias_action_id: Optional[int] = alias_row.get("action_result_id") if alias_row else None
    snippet_action_id: Optional[int] = snippet_row.get("action_result_id") if snippet_row else None

    # Fall back to the snippet action row when the alias row had no snippet_json.
    if not snippet_json and snippet_row:
        snippet_json = _coerce_json_array(snippet_row.get("snippet_json"))
        if updated_at is None:
            updated_at = snippet_row.get("updated_at")
        if alias_action_id is None:
            alias_action_id = snippet_action_id
    # NOTE: a trailing `if not updated_at and alias_row:` re-read used to sit here;
    # it was dead code (updated_at is already seeded from alias_row above, and a
    # datetime is never falsy), so it has been removed.
    return snippet_json, alias_json, updated_at, alias_action_id, snippet_action_id
def _normalize_aliases(raw_aliases: Any) -> List[Dict[str, Any]]:
aliases: List[Dict[str, Any]] = []
seen: set[str] = set()
if not raw_aliases:
return aliases
if not isinstance(raw_aliases, list):
return aliases
for item in raw_aliases:
if isinstance(item, dict):
text_val = item.get("text")
if not text_val or text_val in seen:
continue
seen.add(text_val)
aliases.append({"text": text_val, "tone": item.get("tone")})
elif isinstance(item, str):
if item in seen:
continue
seen.add(item)
aliases.append({"text": item})
return aliases
def _normalize_str_list(values: Any) -> List[str]:
if not values:
return []
if not isinstance(values, list):
return []
seen: set[str] = set()
normalised: List[str] = []
for val in values:
if not isinstance(val, str):
continue
if val in seen:
continue
seen.add(val)
normalised.append(val)
return normalised
def _merge_alias_lists(primary: List[Dict[str, Any]], secondary: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
merged: List[Dict[str, Any]] = []
seen: set[str] = set()
for source in (primary, secondary):
for item in source:
if not isinstance(item, dict):
continue
text_val = item.get("text")
if not text_val or text_val in seen:
continue
seen.add(text_val)
merged.append({"text": text_val, "tone": item.get("tone")})
return merged
def _merge_str_lists(primary: List[str], secondary: List[str]) -> List[str]:
merged: List[str] = []
seen: set[str] = set()
for source in (primary, secondary):
for item in source:
if item in seen:
continue
seen.add(item)
merged.append(item)
return merged
def _build_alias_map(alias_payload: List[Any]) -> Dict[str, Dict[str, Any]]:
    """Index alias payload entries by snippet id, merging repeated ids.

    Each bucket holds normalised 'aliases', 'keywords' and 'intent_tags' lists;
    entries without a dict shape or an 'id' are ignored.
    """
    alias_map: Dict[str, Dict[str, Any]] = {}
    for entry in alias_payload:
        if not isinstance(entry, dict):
            continue
        entry_id = entry.get("id")
        if not entry_id:
            continue
        bucket = alias_map.get(entry_id)
        if bucket is None:
            bucket = {"aliases": [], "keywords": [], "intent_tags": []}
            alias_map[entry_id] = bucket
        bucket["aliases"] = _merge_alias_lists(
            bucket["aliases"], _normalize_aliases(entry.get("aliases"))
        )
        bucket["keywords"] = _merge_str_lists(
            bucket["keywords"], _normalize_str_list(entry.get("keywords"))
        )
        bucket["intent_tags"] = _merge_str_lists(
            bucket["intent_tags"], _normalize_str_list(entry.get("intent_tags"))
        )
    return alias_map
def merge_snippet_records_from_db(
    table_id: int,
    version_ts: int,
    *,
    engine: Optional[Engine] = None,
) -> List[Dict[str, Any]]:
    """
    Load snippet + snippet_alias JSON from action_results after snippet_alias is stored,
    then merge into a unified snippet object list ready for downstream RAG.
    """
    engine = engine or get_engine()
    snippets, aliases, updated_at, alias_action_id, snippet_action_id = _load_snippet_sources(
        engine, table_id, version_ts
    )
    alias_map = _build_alias_map(aliases)
    action_id = alias_action_id or snippet_action_id

    merged: List[Dict[str, Any]] = []
    consumed_ids: set[str] = set()
    for raw in snippets:
        if not isinstance(raw, dict) or not raw.get("id"):
            continue
        sid = raw["id"]
        record = dict(raw)
        merged_aliases = _normalize_aliases(record.get("aliases"))
        merged_keywords = _normalize_str_list(record.get("keywords"))
        merged_intents = _normalize_str_list(record.get("intent_tags"))
        extra = alias_map.get(sid)
        if extra:
            # Alias-action data is merged after the snippet's own values.
            merged_aliases = _merge_alias_lists(merged_aliases, extra["aliases"])
            merged_keywords = _merge_str_lists(merged_keywords, extra["keywords"])
            merged_intents = _merge_str_lists(merged_intents, extra["intent_tags"])
        # Assignment order matters: it fixes key order of newly-added keys in
        # the resulting dict (and therefore in the serialised merged_json).
        record["aliases"] = merged_aliases
        record["keywords"] = merged_keywords
        record["intent_tags"] = merged_intents
        record["table_id"] = table_id
        record["version_ts"] = version_ts
        record["updated_at_from_action"] = updated_at
        record["source"] = "snippet"
        record["action_result_id"] = action_id
        merged.append(record)
        consumed_ids.add(sid)

    # Alias entries with no matching snippet become standalone records.
    for sid, extra in alias_map.items():
        if sid in consumed_ids:
            continue
        merged.append(
            {
                "id": sid,
                "aliases": extra["aliases"],
                "keywords": extra["keywords"],
                "intent_tags": extra["intent_tags"],
                "table_id": table_id,
                "version_ts": version_ts,
                "updated_at_from_action": updated_at,
                "source": "alias_only",
                "action_result_id": action_id,
            }
        )
    return merged
def _stable_rag_item_id(table_id: int, version_ts: int, snippet_id: str) -> int:
digest = hashlib.md5(f"{table_id}:{version_ts}:{snippet_id}".encode("utf-8")).hexdigest()
return int(digest[:16], 16) % 9_000_000_000_000_000_000
def _build_rag_text(snippet: Dict[str, Any]) -> str:
# Deterministic text concatenation for embedding input.
parts: List[str] = []
def _add(label: str, value: Any) -> None:
if value is None:
return
if isinstance(value, list):
value = ", ".join([str(v) for v in value if v])
elif isinstance(value, dict):
value = json.dumps(value, ensure_ascii=False)
if value:
parts.append(f"{label}: {value}")
_add("Title", snippet.get("title") or snippet.get("id"))
_add("Description", snippet.get("desc"))
_add("Business", snippet.get("business_caliber"))
_add("Type", snippet.get("type"))
_add("Examples", snippet.get("examples") or [])
_add("Aliases", [a.get("text") for a in snippet.get("aliases") or [] if isinstance(a, dict)])
_add("Keywords", snippet.get("keywords") or [])
_add("IntentTags", snippet.get("intent_tags") or [])
_add("Applicability", snippet.get("applicability"))
_add("DialectSQL", snippet.get("dialect_sql"))
return "\n".join(parts)
def _prepare_rag_payloads(
    snippets: List[Dict[str, Any]],
    table_id: int,
    version_ts: int,
    workspace_id: int,
    rag_item_type: str = "SNIPPET",
) -> Tuple[List[Dict[str, Any]], List[RagItemPayload]]:
    """Turn merged snippet dicts into (rag_snippet rows, RAG API payloads).

    Snippets without an 'id' are silently skipped; snippets without an
    'action_result_id' are skipped with a warning since the DB row could not
    reference its provenance.
    """
    rows: List[Dict[str, Any]] = []
    payloads: List[RagItemPayload] = []
    fallback_ts = datetime.utcnow()

    def _as_datetime(raw: Any) -> datetime:
        # Accept datetime or ISO-8601 string; anything else maps to the fallback.
        if isinstance(raw, datetime):
            return raw
        if isinstance(raw, str):
            try:
                return datetime.fromisoformat(raw)
            except ValueError:
                return fallback_ts
        return fallback_ts

    for snippet in snippets:
        snippet_id = snippet.get("id")
        if not snippet_id:
            continue
        action_result_id = snippet.get("action_result_id")
        if action_result_id is None:
            logger.warning(
                "Skipping snippet without action_result_id for RAG ingestion (table_id=%s version_ts=%s snippet_id=%s)",
                table_id,
                version_ts,
                snippet_id,
            )
            continue
        rag_item_id = _stable_rag_item_id(table_id, version_ts, snippet_id)
        rag_text = _build_rag_text(snippet)
        rows.append(
            {
                "rag_item_id": rag_item_id,
                "workspace_id": workspace_id,
                "table_id": table_id,
                "version_ts": version_ts,
                "action_result_id": action_result_id,
                "snippet_id": snippet_id,
                "rag_text": rag_text,
                "merged_json": json.dumps(snippet, ensure_ascii=False),
                "updated_at": _as_datetime(snippet.get("updated_at_from_action") or fallback_ts),
            }
        )
        payloads.append(
            RagItemPayload(
                id=rag_item_id,
                workspaceId=workspace_id,
                name=snippet.get("title") or snippet_id,
                embeddingData=rag_text,
                type=rag_item_type or "SNIPPET",
            )
        )
    return rows, payloads
def _upsert_rag_snippet_rows(engine: Engine, rows: Sequence[Dict[str, Any]]) -> None:
    """Replace rag_snippet rows by rag_item_id (delete-then-insert) in one transaction."""
    if not rows:
        return
    purge_stmt = text("DELETE FROM rag_snippet WHERE rag_item_id=:rag_item_id")
    insert_stmt = text(
        """
        INSERT INTO rag_snippet (
            rag_item_id,
            workspace_id,
            table_id,
            version_ts,
            action_result_id,
            snippet_id,
            rag_text,
            merged_json,
            updated_at
        ) VALUES (
            :rag_item_id,
            :workspace_id,
            :table_id,
            :version_ts,
            :action_result_id,
            :snippet_id,
            :rag_text,
            :merged_json,
            :updated_at
        )
        """
    )
    # engine.begin() commits on success and rolls back on any exception.
    with engine.begin() as conn:
        for params in rows:
            conn.execute(purge_stmt, params)
            conn.execute(insert_stmt, params)
async def ingest_snippet_rag_from_db(
    table_id: int,
    version_ts: int,
    *,
    workspace_id: int,
    rag_item_type: str = "SNIPPET",
    client,
    engine: Optional[Engine] = None,
    rag_client: Optional[RagAPIClient] = None,
) -> List[int]:
    """
    Merge snippet + alias JSON from action_results, persist to rag_snippet, then push to RAG via addBatch.
    Returns list of rag_item_id ingested.
    """
    engine = engine or get_engine()
    merged_snippets = merge_snippet_records_from_db(table_id, version_ts, engine=engine)
    if not merged_snippets:
        # Nothing to push — either no action rows exist or their JSON was empty.
        logger.info(
            "No snippets available for RAG ingestion (table_id=%s version_ts=%s)",
            table_id,
            version_ts,
        )
        return []
    db_rows, rag_payloads = _prepare_rag_payloads(
        merged_snippets,
        table_id=table_id,
        version_ts=version_ts,
        workspace_id=workspace_id,
        rag_item_type=rag_item_type,
    )
    # Persist locally first so the rag_snippet table reflects what we attempt to push.
    _upsert_rag_snippet_rows(engine, db_rows)
    rag_client = rag_client or RagAPIClient()
    await rag_client.add_batch(client, rag_payloads)
    return [db_row["rag_item_id"] for db_row in db_rows]