Generate RAG snippets, persist them to the database, and write them to the RAG service
app/services/rag_client.py (new file, 83 lines)
@@ -0,0 +1,83 @@
from __future__ import annotations

import logging
from typing import Any, Sequence

import httpx

from app.exceptions import ProviderAPICallError
from app.schemas.rag import RagDeleteRequest, RagItemPayload, RagRetrieveRequest
from app.settings import RAG_API_AUTH_TOKEN, RAG_API_BASE_URL


logger = logging.getLogger(__name__)


class RagAPIClient:
    """Thin async client wrapper around the RAG endpoints described in doc/rag-api.md."""

    def __init__(self, *, base_url: str | None = None, auth_token: str | None = None) -> None:
        resolved_base = base_url or RAG_API_BASE_URL
        self._base_url = resolved_base.rstrip("/")
        self._auth_token = auth_token or RAG_API_AUTH_TOKEN

    def _headers(self) -> dict[str, str]:
        headers = {"Content-Type": "application/json"}
        if self._auth_token:
            headers["Authorization"] = f"Bearer {self._auth_token}"
        return headers

    async def _post(
        self,
        client: httpx.AsyncClient,
        path: str,
        payload: Any,
    ) -> Any:
        url = f"{self._base_url}{path}"
        try:
            response = await client.post(url, json=payload, headers=self._headers())
            response.raise_for_status()
        except httpx.HTTPStatusError as exc:
            status_code = exc.response.status_code if exc.response else None
            response_text = exc.response.text if exc.response else ""
            logger.error(
                "RAG API responded with an error (%s) for %s: %s",
                status_code,
                url,
                response_text,
                exc_info=True,
            )
            raise ProviderAPICallError(
                "RAG API call failed.",
                status_code=status_code,
                response_text=response_text,
            ) from exc
        except httpx.HTTPError as exc:
            logger.error("Transport error calling RAG API %s: %s", url, exc, exc_info=True)
            raise ProviderAPICallError(f"RAG API call failed: {exc}") from exc

        try:
            return response.json()
        except ValueError:
            logger.warning("RAG API returned non-JSON response for %s; returning raw text.", url)
            return {"raw": response.text}

    async def add(self, client: httpx.AsyncClient, payload: RagItemPayload) -> Any:
        body = payload.model_dump(by_alias=True, exclude_none=True)
        return await self._post(client, "/rag/add", body)

    async def add_batch(self, client: httpx.AsyncClient, items: Sequence[RagItemPayload]) -> Any:
        body = [item.model_dump(by_alias=True, exclude_none=True) for item in items]
        return await self._post(client, "/rag/addBatch", body)

    async def update(self, client: httpx.AsyncClient, payload: RagItemPayload) -> Any:
        body = payload.model_dump(by_alias=True, exclude_none=True)
        return await self._post(client, "/rag/update", body)

    async def delete(self, client: httpx.AsyncClient, payload: RagDeleteRequest) -> Any:
        body = payload.model_dump(by_alias=True, exclude_none=True)
        return await self._post(client, "/rag/delete", body)

    async def retrieve(self, client: httpx.AsyncClient, payload: RagRetrieveRequest) -> Any:
        body = payload.model_dump(by_alias=True, exclude_none=True)
        return await self._post(client, "/rag/retrieve", body)
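For orientation, a minimal sketch of how this client is intended to be driven: the caller owns the httpx.AsyncClient and passes it in, so connection pooling stays in one place. The query/topK field names on RagRetrieveRequest are illustrative assumptions, not confirmed by app/schemas/rag.py:

import asyncio

import httpx

from app.schemas.rag import RagRetrieveRequest
from app.services.rag_client import RagAPIClient


async def demo() -> None:
    rag = RagAPIClient()  # defaults come from RAG_API_BASE_URL / RAG_API_AUTH_TOKEN
    async with httpx.AsyncClient(timeout=10.0) as client:
        # Field names below are assumed for illustration; check app/schemas/rag.py.
        request = RagRetrieveRequest(query="monthly revenue", topK=5)
        print(await rag.retrieve(client, request))


asyncio.run(demo())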
@@ -1,19 +1,19 @@
 from __future__ import annotations
 
 import hashlib
 import json
 import logging
-from typing import Any, Dict, Tuple
+from datetime import datetime
+from typing import Any, Dict, List, Optional, Sequence, Tuple
 
 from sqlalchemy import text
 from sqlalchemy.engine import Engine
 from sqlalchemy.exc import SQLAlchemyError
 
 from app.db import get_engine
-from app.models import (
-    ActionType,
-    TableSnippetUpsertRequest,
-    TableSnippetUpsertResponse,
-)
+from app.models import ActionType, TableSnippetUpsertRequest, TableSnippetUpsertResponse
+from app.schemas.rag import RagItemPayload
+from app.services.rag_client import RagAPIClient
 
+
 logger = logging.getLogger(__name__)
@@ -46,6 +46,7 @@ def _prepare_model_params(params: Dict[str, Any] | None) -> str | None:
 
 
 def _collect_common_columns(request: TableSnippetUpsertRequest) -> Dict[str, Any]:
+    # Build the base column set shared by all action types; action-specific fields are populated later.
     logger.debug(
         "Collecting common columns for table_id=%s version_ts=%s action_type=%s",
         request.table_id,
@@ -215,3 +216,405 @@ def upsert_action_result(request: TableSnippetUpsertRequest) -> TableSnippetUpsertResponse:
        status=request.status,
        updated=updated,
    )


def _decode_json_field(value: Any) -> Any:
    """Decode JSON columns that may be returned as str/bytes/dicts/lists."""
    if value is None:
        return None
    if isinstance(value, (dict, list)):
        return value
    if isinstance(value, (bytes, bytearray)):
        try:
            value = value.decode("utf-8")
        except Exception:  # pragma: no cover - defensive
            return None
    if isinstance(value, str):
        try:
            return json.loads(value)
        except json.JSONDecodeError:
            logger.warning("Failed to decode JSON field: %s", value)
            return None
    return None


def _coerce_json_array(value: Any) -> List[Any]:
    decoded = _decode_json_field(value)
    return decoded if isinstance(decoded, list) else []
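A quick sanity check of the coercion rules above; anything that does not decode to a JSON array collapses to an empty list:

assert _coerce_json_array('[{"id": "s1"}]') == [{"id": "s1"}]
assert _coerce_json_array(b'["a", "b"]') == ["a", "b"]
assert _coerce_json_array('{"id": "s1"}') == []  # a dict, not a list
assert _coerce_json_array("not json") == []      # decode failure is logged, then dropped
assert _coerce_json_array(None) == []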


def _fetch_action_payload(
    engine: Engine, table_id: int, version_ts: int, action_type: ActionType
) -> Optional[Dict[str, Any]]:
    # Prefer a 'success' row over 'partial', breaking ties by the most recent write.
    sql = text(
        """
        SELECT id AS action_result_id, snippet_json, snippet_alias_json, updated_at, status
        FROM action_results
        WHERE table_id = :table_id
          AND version_ts = :version_ts
          AND action_type = :action_type
          AND status IN ('success', 'partial')
        ORDER BY CASE status WHEN 'success' THEN 0 ELSE 1 END, updated_at DESC
        LIMIT 1
        """
    )
    with engine.connect() as conn:
        row = conn.execute(
            sql,
            {
                "table_id": table_id,
                "version_ts": version_ts,
                "action_type": action_type.value,
            },
        ).mappings().first()
    return dict(row) if row else None


def _load_snippet_sources(
    engine: Engine, table_id: int, version_ts: int
) -> Tuple[List[Any], List[Any], Optional[datetime], Optional[int], Optional[int]]:
    # Prefer the snippet_alias action row (it carries both snippet and alias JSON);
    # fall back to the plain snippet row when the alias row is missing fields.
    alias_row = _fetch_action_payload(engine, table_id, version_ts, ActionType.SNIPPET_ALIAS)
    snippet_row = _fetch_action_payload(engine, table_id, version_ts, ActionType.SNIPPET)

    snippet_json = _coerce_json_array(alias_row.get("snippet_json") if alias_row else None)
    alias_json = _coerce_json_array(alias_row.get("snippet_alias_json") if alias_row else None)
    updated_at: Optional[datetime] = alias_row.get("updated_at") if alias_row else None
    alias_action_id: Optional[int] = alias_row.get("action_result_id") if alias_row else None
    snippet_action_id: Optional[int] = snippet_row.get("action_result_id") if snippet_row else None

    if not snippet_json and snippet_row:
        snippet_json = _coerce_json_array(snippet_row.get("snippet_json"))
        if updated_at is None:
            updated_at = snippet_row.get("updated_at")
        if alias_action_id is None:
            alias_action_id = snippet_action_id

    if not updated_at and alias_row:
        updated_at = alias_row.get("updated_at")

    return snippet_json, alias_json, updated_at, alias_action_id, snippet_action_id


def _normalize_aliases(raw_aliases: Any) -> List[Dict[str, Any]]:
    aliases: List[Dict[str, Any]] = []
    seen: set[str] = set()
    if not raw_aliases:
        return aliases
    if not isinstance(raw_aliases, list):
        return aliases
    for item in raw_aliases:
        if isinstance(item, dict):
            text_val = item.get("text")
            if not text_val or text_val in seen:
                continue
            seen.add(text_val)
            aliases.append({"text": text_val, "tone": item.get("tone")})
        elif isinstance(item, str):
            if item in seen:
                continue
            seen.add(item)
            aliases.append({"text": item})
    return aliases
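The normaliser accepts both dict entries and bare strings and de-duplicates on the alias text while preserving first-seen order, e.g.:

raw = [
    {"text": "monthly revenue", "tone": "formal"},
    "rev",
    {"text": "monthly revenue", "tone": "casual"},  # duplicate text, dropped
    "rev",                                          # duplicate string, dropped
]
assert _normalize_aliases(raw) == [
    {"text": "monthly revenue", "tone": "formal"},
    {"text": "rev"},
]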


def _normalize_str_list(values: Any) -> List[str]:
    if not values:
        return []
    if not isinstance(values, list):
        return []
    seen: set[str] = set()
    normalised: List[str] = []
    for val in values:
        if not isinstance(val, str):
            continue
        if val in seen:
            continue
        seen.add(val)
        normalised.append(val)
    return normalised


def _merge_alias_lists(primary: List[Dict[str, Any]], secondary: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    merged: List[Dict[str, Any]] = []
    seen: set[str] = set()
    for source in (primary, secondary):
        for item in source:
            if not isinstance(item, dict):
                continue
            text_val = item.get("text")
            if not text_val or text_val in seen:
                continue
            seen.add(text_val)
            merged.append({"text": text_val, "tone": item.get("tone")})
    return merged


def _merge_str_lists(primary: List[str], secondary: List[str]) -> List[str]:
    merged: List[str] = []
    seen: set[str] = set()
    for source in (primary, secondary):
        for item in source:
            if item in seen:
                continue
            seen.add(item)
            merged.append(item)
    return merged


def _build_alias_map(alias_payload: List[Any]) -> Dict[str, Dict[str, Any]]:
    # Group alias/keyword/intent lists by snippet id so they can be joined onto snippets.
    alias_map: Dict[str, Dict[str, Any]] = {}
    for item in alias_payload:
        if not isinstance(item, dict):
            continue
        alias_id = item.get("id")
        if not alias_id:
            continue
        existing = alias_map.setdefault(
            alias_id,
            {"aliases": [], "keywords": [], "intent_tags": []},
        )
        existing["aliases"] = _merge_alias_lists(
            existing["aliases"], _normalize_aliases(item.get("aliases"))
        )
        existing["keywords"] = _merge_str_lists(
            existing["keywords"], _normalize_str_list(item.get("keywords"))
        )
        existing["intent_tags"] = _merge_str_lists(
            existing["intent_tags"], _normalize_str_list(item.get("intent_tags"))
        )
    return alias_map


def merge_snippet_records_from_db(
    table_id: int,
    version_ts: int,
    *,
    engine: Optional[Engine] = None,
) -> List[Dict[str, Any]]:
    """
    Load the snippet and snippet_alias JSON stored in action_results (after the
    snippet_alias action has run) and merge them into one unified list of snippet
    records ready for downstream RAG ingestion.
    """
    engine = engine or get_engine()
    snippets, aliases, updated_at, alias_action_id, snippet_action_id = _load_snippet_sources(
        engine, table_id, version_ts
    )
    alias_map = _build_alias_map(aliases)

    merged: List[Dict[str, Any]] = []
    seen_ids: set[str] = set()

    for snippet in snippets:
        if not isinstance(snippet, dict):
            continue
        snippet_id = snippet.get("id")
        if not snippet_id:
            continue
        alias_info = alias_map.get(snippet_id)
        record = dict(snippet)
        record_aliases = _normalize_aliases(record.get("aliases"))
        record_keywords = _normalize_str_list(record.get("keywords"))
        record_intents = _normalize_str_list(record.get("intent_tags"))

        if alias_info:
            record_aliases = _merge_alias_lists(record_aliases, alias_info["aliases"])
            record_keywords = _merge_str_lists(record_keywords, alias_info["keywords"])
            record_intents = _merge_str_lists(record_intents, alias_info["intent_tags"])

        record["aliases"] = record_aliases
        record["keywords"] = record_keywords
        record["intent_tags"] = record_intents
        record["table_id"] = table_id
        record["version_ts"] = version_ts
        record["updated_at_from_action"] = updated_at
        record["source"] = "snippet"
        record["action_result_id"] = alias_action_id or snippet_action_id
        merged.append(record)
        seen_ids.add(snippet_id)

    # Alias entries whose ids never appeared in the snippet JSON still get a record.
    for alias_id, alias_info in alias_map.items():
        if alias_id in seen_ids:
            continue
        merged.append(
            {
                "id": alias_id,
                "aliases": alias_info["aliases"],
                "keywords": alias_info["keywords"],
                "intent_tags": alias_info["intent_tags"],
                "table_id": table_id,
                "version_ts": version_ts,
                "updated_at_from_action": updated_at,
                "source": "alias_only",
                "action_result_id": alias_action_id or snippet_action_id,
            }
        )

    return merged
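To make the merge result concrete, a hypothetical record for a snippet "s1" whose alias row contributed extra keywords (all ids and values below are invented for illustration):

merged = merge_snippet_records_from_db(42, 1700000000)
# merged[0] would have this shape:
# {
#     "id": "s1",
#     ...original snippet fields...,
#     "aliases": [{"text": "monthly revenue", "tone": None}],
#     "keywords": ["revenue", "gmv"],
#     "intent_tags": ["metric_lookup"],
#     "table_id": 42,
#     "version_ts": 1700000000,
#     "updated_at_from_action": <datetime from the action row>,
#     "source": "snippet",        # "alias_only" when only the alias row knows the id
#     "action_result_id": 123,
# }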


def _stable_rag_item_id(table_id: int, version_ts: int, snippet_id: str) -> int:
    # 64-bit slice of the MD5 digest, folded into the positive signed-BIGINT range.
    digest = hashlib.md5(f"{table_id}:{version_ts}:{snippet_id}".encode("utf-8")).hexdigest()
    return int(digest[:16], 16) % 9_000_000_000_000_000_000
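The id derivation is deterministic: the first 16 hex digits of the MD5 digest form a 64-bit integer, and the modulo keeps it inside the positive signed-BIGINT range. A worked check:

import hashlib

digest = hashlib.md5("42:1700000000:s1".encode("utf-8")).hexdigest()
item_id = int(digest[:16], 16) % 9_000_000_000_000_000_000
assert item_id == _stable_rag_item_id(42, 1700000000, "s1")  # same inputs, same id
assert 0 <= item_id < 9_000_000_000_000_000_000              # safe for signed 64-bit columns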


def _build_rag_text(snippet: Dict[str, Any]) -> str:
    # Deterministic text concatenation for embedding input.
    parts: List[str] = []

    def _add(label: str, value: Any) -> None:
        if value is None:
            return
        if isinstance(value, list):
            value = ", ".join([str(v) for v in value if v])
        elif isinstance(value, dict):
            value = json.dumps(value, ensure_ascii=False)
        if value:
            parts.append(f"{label}: {value}")

    _add("Title", snippet.get("title") or snippet.get("id"))
    _add("Description", snippet.get("desc"))
    _add("Business", snippet.get("business_caliber"))
    _add("Type", snippet.get("type"))
    _add("Examples", snippet.get("examples") or [])
    _add("Aliases", [a.get("text") for a in snippet.get("aliases") or [] if isinstance(a, dict)])
    _add("Keywords", snippet.get("keywords") or [])
    _add("IntentTags", snippet.get("intent_tags") or [])
    _add("Applicability", snippet.get("applicability"))
    _add("DialectSQL", snippet.get("dialect_sql"))
    return "\n".join(parts)
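For example, a snippet with only a few fields set produces a compact, labelled block; empty lists are skipped entirely (field values here are invented):

snippet = {
    "id": "s1",
    "title": "Monthly revenue",
    "desc": "Total paid orders per month",
    "aliases": [{"text": "monthly income"}],
    "keywords": ["revenue", "GMV"],
}
print(_build_rag_text(snippet))
# Title: Monthly revenue
# Description: Total paid orders per month
# Aliases: monthly income
# Keywords: revenue, GMV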


def _prepare_rag_payloads(
    snippets: List[Dict[str, Any]],
    table_id: int,
    version_ts: int,
    workspace_id: int,
    rag_item_type: str = "SNIPPET",
) -> Tuple[List[Dict[str, Any]], List[RagItemPayload]]:
    rows: List[Dict[str, Any]] = []
    payloads: List[RagItemPayload] = []
    now = datetime.utcnow()

    for snippet in snippets:
        snippet_id = snippet.get("id")
        if not snippet_id:
            continue
        action_result_id = snippet.get("action_result_id")
        if action_result_id is None:
            logger.warning(
                "Skipping snippet without action_result_id for RAG ingestion (table_id=%s version_ts=%s snippet_id=%s)",
                table_id,
                version_ts,
                snippet_id,
            )
            continue
        rag_item_id = _stable_rag_item_id(table_id, version_ts, snippet_id)
        rag_text = _build_rag_text(snippet)
        # default=str: updated_at_from_action may be a datetime, which json.dumps
        # cannot serialize on its own.
        merged_json = json.dumps(snippet, ensure_ascii=False, default=str)
        updated_at_raw = snippet.get("updated_at_from_action") or now
        if isinstance(updated_at_raw, str):
            try:
                updated_at = datetime.fromisoformat(updated_at_raw)
            except ValueError:
                updated_at = now
        else:
            updated_at = updated_at_raw if isinstance(updated_at_raw, datetime) else now

        row = {
            "rag_item_id": rag_item_id,
            "workspace_id": workspace_id,
            "table_id": table_id,
            "version_ts": version_ts,
            "action_result_id": action_result_id,
            "snippet_id": snippet_id,
            "rag_text": rag_text,
            "merged_json": merged_json,
            "updated_at": updated_at,
        }
        rows.append(row)

        payloads.append(
            RagItemPayload(
                id=rag_item_id,
                workspaceId=workspace_id,
                name=snippet.get("title") or snippet_id,
                embeddingData=rag_text,
                type=rag_item_type or "SNIPPET",
            )
        )

    return rows, payloads
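Each surviving snippet therefore yields a paired DB row and API payload that share the same stable id and embedding text; roughly:

rows, payloads = _prepare_rag_payloads(
    snippets, table_id=42, version_ts=1700000000, workspace_id=7
)
# rows[0]["rag_item_id"] == payloads[0].id
# rows[0]["rag_text"] matches the embeddingData passed to RagItemPayload
# (whether the attribute is exposed as embeddingData or via an alias depends on the schema)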


def _upsert_rag_snippet_rows(engine: Engine, rows: Sequence[Dict[str, Any]]) -> None:
    if not rows:
        return
    # Portable upsert: drop any existing row for the same rag_item_id, then insert,
    # all inside one transaction.
    delete_sql = text("DELETE FROM rag_snippet WHERE rag_item_id=:rag_item_id")
    insert_sql = text(
        """
        INSERT INTO rag_snippet (
            rag_item_id,
            workspace_id,
            table_id,
            version_ts,
            action_result_id,
            snippet_id,
            rag_text,
            merged_json,
            updated_at
        ) VALUES (
            :rag_item_id,
            :workspace_id,
            :table_id,
            :version_ts,
            :action_result_id,
            :snippet_id,
            :rag_text,
            :merged_json,
            :updated_at
        )
        """
    )
    with engine.begin() as conn:
        for row in rows:
            conn.execute(delete_sql, row)
            conn.execute(insert_sql, row)
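The DELETE+INSERT pair inside a single engine.begin() transaction is a dialect-portable upsert keyed on rag_item_id. A native upsert would halve the statements per row; a minimal sketch assuming a MySQL backend (an assumption, the dialect is not pinned here):

from sqlalchemy import text

# Hypothetical alternative, MySQL only; column list matches rag_snippet above.
mysql_upsert_sql = text(
    """
    INSERT INTO rag_snippet (
        rag_item_id, workspace_id, table_id, version_ts,
        action_result_id, snippet_id, rag_text, merged_json, updated_at
    ) VALUES (
        :rag_item_id, :workspace_id, :table_id, :version_ts,
        :action_result_id, :snippet_id, :rag_text, :merged_json, :updated_at
    )
    ON DUPLICATE KEY UPDATE
        rag_text = VALUES(rag_text),
        merged_json = VALUES(merged_json),
        updated_at = VALUES(updated_at)
    """
)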


async def ingest_snippet_rag_from_db(
    table_id: int,
    version_ts: int,
    *,
    workspace_id: int,
    rag_item_type: str = "SNIPPET",
    client,  # httpx.AsyncClient, owned and closed by the caller
    engine: Optional[Engine] = None,
    rag_client: Optional[RagAPIClient] = None,
) -> List[int]:
    """
    Merge snippet + alias JSON from action_results, persist the result to
    rag_snippet, then push it to the RAG service via addBatch. Returns the
    list of rag_item_id values that were ingested.
    """
    engine = engine or get_engine()
    snippets = merge_snippet_records_from_db(table_id, version_ts, engine=engine)
    if not snippets:
        logger.info(
            "No snippets available for RAG ingestion (table_id=%s version_ts=%s)",
            table_id,
            version_ts,
        )
        return []

    rows, payloads = _prepare_rag_payloads(
        snippets,
        table_id=table_id,
        version_ts=version_ts,
        workspace_id=workspace_id,
        rag_item_type=rag_item_type,
    )

    _upsert_rag_snippet_rows(engine, rows)

    rag_client = rag_client or RagAPIClient()
    await rag_client.add_batch(client, payloads)
    return [row["rag_item_id"] for row in rows]
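End to end, the coroutine is meant to be awaited with a caller-owned httpx client; the ids below are illustrative:

import asyncio

import httpx


async def main() -> None:
    async with httpx.AsyncClient(timeout=30.0) as client:
        ids = await ingest_snippet_rag_from_db(
            42,                  # table_id (illustrative)
            1700000000,          # version_ts (illustrative)
            workspace_id=7,
            client=client,
        )
        print(f"Ingested {len(ids)} RAG items")


asyncio.run(main())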