# data-ge/test/test_snippet_rag_ingest.py
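"""Tests for app.services.table_snippet.ingest_snippet_rag_from_db.

Covers the happy path: snippet and alias payloads stored in action_results are
merged, persisted into rag_snippet, and pushed to the RAG client as one batch.
"""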
from __future__ import annotations

import json
from datetime import datetime

import httpx
import pytest
from sqlalchemy import create_engine, text

from app.services.table_snippet import ingest_snippet_rag_from_db


def _setup_sqlite_engine():
    """Create an in-memory SQLite engine with the two tables the ingest touches."""
    engine = create_engine("sqlite://")
    with engine.begin() as conn:
        conn.execute(
            text(
                """
                CREATE TABLE action_results (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    table_id INTEGER,
                    version_ts INTEGER,
                    action_type TEXT,
                    status TEXT,
                    snippet_json TEXT,
                    snippet_alias_json TEXT,
                    updated_at TEXT
                )
                """
            )
        )
        conn.execute(
            text(
                """
                CREATE TABLE rag_snippet (
                    rag_item_id INTEGER PRIMARY KEY,
                    action_result_id INTEGER NOT NULL,
                    workspace_id INTEGER,
                    table_id INTEGER,
                    version_ts INTEGER,
                    created_at TEXT,
                    snippet_id TEXT,
                    rag_text TEXT,
                    merged_json TEXT,
                    updated_at TEXT
                )
                """
            )
        )
    return engine


def _insert_action_row(engine, payload: dict) -> None:
    """Insert one action_results row, JSON-encoding the snippet payloads when present."""
    with engine.begin() as conn:
        conn.execute(
            text(
                """
                INSERT INTO action_results (table_id, version_ts, action_type, status, snippet_json, snippet_alias_json, updated_at)
                VALUES (:table_id, :version_ts, :action_type, :status, :snippet_json, :snippet_alias_json, :updated_at)
                """
            ),
            {
                "table_id": payload["table_id"],
                "version_ts": payload["version_ts"],
                "action_type": payload["action_type"],
                "status": payload.get("status", "success"),
                "snippet_json": json.dumps(payload.get("snippet_json"), ensure_ascii=False)
                if payload.get("snippet_json") is not None
                else None,
                "snippet_alias_json": json.dumps(payload.get("snippet_alias_json"), ensure_ascii=False)
                if payload.get("snippet_alias_json") is not None
                else None,
                "updated_at": payload.get("updated_at") or datetime.utcnow().isoformat(),
            },
        )


class _StubRagClient:
    """Stand-in for the real RAG client: records the batch instead of calling the service."""

    def __init__(self) -> None:
        self.received = None

    async def add_batch(self, _client, items):
        self.received = items
        return {"count": len(items)}


@pytest.mark.asyncio
async def test_ingest_snippet_rag_from_db_persists_and_calls_rag_client() -> None:
    engine = _setup_sqlite_engine()
    table_id = 321
    version_ts = 20240102000000
    snippet_payload = [
        {
            "id": "snpt_topn",
            "title": "TopN",
            "aliases": [{"text": "站点水表排行前N", "tone": "中性"}],
            "keywords": ["TopN", "站点"],
        }
    ]
    alias_payload = [
        {
            "id": "snpt_topn",
            "aliases": [
                {"text": "站点水表排行前N", "tone": "中性"},
                {"text": "按站点水表TopN", "tone": "专业"},
            ],
            "keywords": ["TopN", "排行"],
            "intent_tags": ["topn", "aggregate"],
        },
        {
            "id": "snpt_extra",
            "aliases": [{"text": "额外别名"}],
            "keywords": ["extra"],
        },
    ]
    _insert_action_row(
        engine,
        {
            "table_id": table_id,
            "version_ts": version_ts,
            "action_type": "snippet_alias",
            "snippet_json": snippet_payload,
            "snippet_alias_json": alias_payload,
            "updated_at": "2024-01-02T00:00:00",
        },
    )

    rag_stub = _StubRagClient()
    async with httpx.AsyncClient() as client:
        rag_ids = await ingest_snippet_rag_from_db(
            table_id=table_id,
            version_ts=version_ts,
            workspace_id=99,
            rag_item_type="SNIPPET",
            client=client,
            engine=engine,
            rag_client=rag_stub,
        )

    assert rag_stub.received is not None
    assert len(rag_stub.received) == 2  # includes alias-only row
    assert len(rag_ids) == 2

    with engine.connect() as conn:
        rows = list(
            conn.execute(
                text("SELECT snippet_id, action_result_id, rag_text, merged_json FROM rag_snippet ORDER BY snippet_id")
            )
        )

    assert {row[0] for row in rows} == {"snpt_extra", "snpt_topn"}
    assert all(row[1] is not None for row in rows)

    topn_row = next(row for row in rows if row[0] == "snpt_topn")
    assert "TopN" in topn_row[2]
    assert "按站点水表TopN" in topn_row[2]
    assert "排行" in topn_row[2]