rag snippet生成入库和写rag
This commit is contained in:
156
test/test_snippet_rag_ingest.py
Normal file
156
test/test_snippet_rag_ingest.py
Normal file
@ -0,0 +1,156 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from datetime import datetime
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from sqlalchemy import create_engine, text
|
||||
|
||||
from app.services.table_snippet import ingest_snippet_rag_from_db
|
||||
|
||||
|
||||
def _setup_sqlite_engine():
|
||||
engine = create_engine("sqlite://")
|
||||
with engine.begin() as conn:
|
||||
conn.execute(
|
||||
text(
|
||||
"""
|
||||
CREATE TABLE action_results (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
table_id INTEGER,
|
||||
version_ts INTEGER,
|
||||
action_type TEXT,
|
||||
status TEXT,
|
||||
snippet_json TEXT,
|
||||
snippet_alias_json TEXT,
|
||||
updated_at TEXT
|
||||
)
|
||||
"""
|
||||
)
|
||||
)
|
||||
conn.execute(
|
||||
text(
|
||||
"""
|
||||
CREATE TABLE rag_snippet (
|
||||
rag_item_id INTEGER PRIMARY KEY,
|
||||
action_result_id INTEGER NOT NULL,
|
||||
workspace_id INTEGER,
|
||||
table_id INTEGER,
|
||||
version_ts INTEGER,
|
||||
snippet_id TEXT,
|
||||
rag_text TEXT,
|
||||
merged_json TEXT,
|
||||
updated_at TEXT
|
||||
)
|
||||
"""
|
||||
)
|
||||
)
|
||||
return engine
|
||||
|
||||
|
||||
def _insert_action_row(engine, payload: dict) -> None:
|
||||
with engine.begin() as conn:
|
||||
conn.execute(
|
||||
text(
|
||||
"""
|
||||
INSERT INTO action_results (table_id, version_ts, action_type, status, snippet_json, snippet_alias_json, updated_at)
|
||||
VALUES (:table_id, :version_ts, :action_type, :status, :snippet_json, :snippet_alias_json, :updated_at)
|
||||
"""
|
||||
),
|
||||
{
|
||||
"table_id": payload["table_id"],
|
||||
"version_ts": payload["version_ts"],
|
||||
"action_type": payload["action_type"],
|
||||
"status": payload.get("status", "success"),
|
||||
"snippet_json": json.dumps(payload.get("snippet_json"), ensure_ascii=False)
|
||||
if payload.get("snippet_json") is not None
|
||||
else None,
|
||||
"snippet_alias_json": json.dumps(payload.get("snippet_alias_json"), ensure_ascii=False)
|
||||
if payload.get("snippet_alias_json") is not None
|
||||
else None,
|
||||
"updated_at": payload.get("updated_at") or datetime.utcnow().isoformat(),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class _StubRagClient:
|
||||
def __init__(self) -> None:
|
||||
self.received = None
|
||||
|
||||
async def add_batch(self, _client, items):
|
||||
self.received = items
|
||||
return {"count": len(items)}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ingest_snippet_rag_from_db_persists_and_calls_rag_client() -> None:
|
||||
engine = _setup_sqlite_engine()
|
||||
table_id = 321
|
||||
version_ts = 20240102000000
|
||||
|
||||
snippet_payload = [
|
||||
{
|
||||
"id": "snpt_topn",
|
||||
"title": "TopN",
|
||||
"aliases": [{"text": "站点水表排行前N", "tone": "中性"}],
|
||||
"keywords": ["TopN", "站点"],
|
||||
}
|
||||
]
|
||||
alias_payload = [
|
||||
{
|
||||
"id": "snpt_topn",
|
||||
"aliases": [
|
||||
{"text": "站点水表排行前N", "tone": "中性"},
|
||||
{"text": "按站点水表TopN", "tone": "专业"},
|
||||
],
|
||||
"keywords": ["TopN", "排行"],
|
||||
"intent_tags": ["topn", "aggregate"],
|
||||
},
|
||||
{
|
||||
"id": "snpt_extra",
|
||||
"aliases": [{"text": "额外别名"}],
|
||||
"keywords": ["extra"],
|
||||
},
|
||||
]
|
||||
|
||||
_insert_action_row(
|
||||
engine,
|
||||
{
|
||||
"table_id": table_id,
|
||||
"version_ts": version_ts,
|
||||
"action_type": "snippet_alias",
|
||||
"snippet_json": snippet_payload,
|
||||
"snippet_alias_json": alias_payload,
|
||||
"updated_at": "2024-01-02T00:00:00",
|
||||
},
|
||||
)
|
||||
|
||||
rag_stub = _StubRagClient()
|
||||
async with httpx.AsyncClient() as client:
|
||||
rag_ids = await ingest_snippet_rag_from_db(
|
||||
table_id=table_id,
|
||||
version_ts=version_ts,
|
||||
workspace_id=99,
|
||||
rag_item_type="SNIPPET",
|
||||
client=client,
|
||||
engine=engine,
|
||||
rag_client=rag_stub,
|
||||
)
|
||||
|
||||
assert rag_stub.received is not None
|
||||
assert len(rag_stub.received) == 2 # includes alias-only row
|
||||
assert len(rag_ids) == 2
|
||||
|
||||
with engine.connect() as conn:
|
||||
rows = list(
|
||||
conn.execute(
|
||||
text("SELECT snippet_id, action_result_id, rag_text, merged_json FROM rag_snippet ORDER BY snippet_id")
|
||||
)
|
||||
)
|
||||
assert {row[0] for row in rows} == {"snpt_extra", "snpt_topn"}
|
||||
assert all(row[1] is not None for row in rows)
|
||||
topn_row = next(row for row in rows if row[0] == "snpt_topn")
|
||||
assert "TopN" in topn_row[2]
|
||||
assert "按站点水表TopN" in topn_row[2]
|
||||
assert "排行" in topn_row[2]
|
||||
Reference in New Issue
Block a user