158 lines
4.9 KiB
Python
158 lines
4.9 KiB
Python
from __future__ import annotations
|
|
|
|
import json
|
|
from datetime import datetime
|
|
|
|
import httpx
|
|
import pytest
|
|
from sqlalchemy import create_engine, text
|
|
|
|
from app.services.table_snippet import ingest_snippet_rag_from_db
|
|
|
|
|
|
def _setup_sqlite_engine():
    """Build a SQLite engine that already contains the empty test schema.

    The engine is created from the bare ``sqlite://`` URL and two tables,
    ``action_results`` and ``rag_snippet``, are created inside a single
    transaction before the engine is handed back.

    Returns:
        The SQLAlchemy engine with both tables created.
    """
    # DDL is listed in the order it must run; both statements share one
    # transaction opened by ``begin()`` below.
    schema_statements = (
        """
        CREATE TABLE action_results (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            table_id INTEGER,
            version_ts INTEGER,
            action_type TEXT,
            status TEXT,
            snippet_json TEXT,
            snippet_alias_json TEXT,
            updated_at TEXT
        )
        """,
        """
        CREATE TABLE rag_snippet (
            rag_item_id INTEGER PRIMARY KEY,
            action_result_id INTEGER NOT NULL,
            workspace_id INTEGER,
            table_id INTEGER,
            version_ts INTEGER,
            created_at TEXT,
            snippet_id TEXT,
            rag_text TEXT,
            merged_json TEXT,
            updated_at TEXT
        )
        """,
    )

    sqlite_engine = create_engine("sqlite://")
    with sqlite_engine.begin() as connection:
        for ddl in schema_statements:
            connection.execute(text(ddl))
    return sqlite_engine
|
|
|
|
|
|
def _insert_action_row(engine, payload: dict) -> None:
    """Insert a single ``action_results`` row described by *payload*.

    Args:
        engine: Object whose ``begin()`` context manager yields a connection
            exposing ``execute`` (a SQLAlchemy engine in this file).
        payload: Column values for the new row. ``table_id``, ``version_ts``
            and ``action_type`` are required keys. ``status`` defaults to
            ``"success"``. ``snippet_json`` and ``snippet_alias_json`` are
            JSON-encoded when present and stored as NULL when missing or
            ``None``. ``updated_at`` falls back to the current UTC time (naive
            ISO-8601 string) when missing or falsy.

    Raises:
        KeyError: If one of the required keys is absent from *payload*.
    """
    # Imported locally so this helper does not depend on a change to the
    # module-level import block.
    from datetime import timezone

    # ``datetime.utcnow()`` is deprecated since Python 3.12. Taking an aware
    # "now" and dropping tzinfo yields the same naive UTC ISO string as before.
    updated_at = payload.get("updated_at") or (
        datetime.now(timezone.utc).replace(tzinfo=None).isoformat()
    )

    params = {
        "table_id": payload["table_id"],
        "version_ts": payload["version_ts"],
        "action_type": payload["action_type"],
        "status": payload.get("status", "success"),
        "snippet_json": _json_or_none(payload.get("snippet_json")),
        "snippet_alias_json": _json_or_none(payload.get("snippet_alias_json")),
        "updated_at": updated_at,
    }

    with engine.begin() as conn:
        conn.execute(
            text(
                """
                INSERT INTO action_results (table_id, version_ts, action_type, status, snippet_json, snippet_alias_json, updated_at)
                VALUES (:table_id, :version_ts, :action_type, :status, :snippet_json, :snippet_alias_json, :updated_at)
                """
            ),
            params,
        )


def _json_or_none(value):
    """Return *value* as a JSON string, passing ``None`` through unchanged."""
    if value is None:
        return None
    # ensure_ascii=False keeps non-ASCII text readable instead of \u-escaped.
    return json.dumps(value, ensure_ascii=False)
|
|
|
|
|
|
class _StubRagClient:
    """Test double that records the batch handed to ``add_batch``."""

    def __init__(self) -> None:
        # Stays ``None`` until ``add_batch`` has been awaited at least once.
        self.received = None

    async def add_batch(self, _client, items):
        """Store *items* on the stub and report how many there were.

        Args:
            _client: Ignored; accepted only so the call signature matches.
            items: The batch to record; must support ``len()``.

        Returns:
            A dict with a single ``"count"`` key holding ``len(items)``.
        """
        self.received = items
        item_count = len(items)
        return {"count": item_count}
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_ingest_snippet_rag_from_db_persists_and_calls_rag_client() -> None:
    """Ingest one seeded action row and verify the client call and stored rows.

    A single ``action_results`` row is seeded whose ``snippet_json`` holds one
    snippet (``snpt_topn``) and whose ``snippet_alias_json`` holds that snippet
    plus one that exists only there (``snpt_extra``). After calling
    ``ingest_snippet_rag_from_db`` the test expects two items to reach the stub
    RAG client, two ids to be returned, and two ``rag_snippet`` rows to exist,
    with the ``snpt_topn`` text containing values taken from the alias payload.
    """
    db_engine = _setup_sqlite_engine()
    table_id, version_ts = 321, 20240102000000

    snippets = [
        {
            "id": "snpt_topn",
            "title": "TopN",
            "aliases": [{"text": "站点水表排行前N", "tone": "中性"}],
            "keywords": ["TopN", "站点"],
        }
    ]
    snippet_aliases = [
        {
            "id": "snpt_topn",
            "aliases": [
                {"text": "站点水表排行前N", "tone": "中性"},
                {"text": "按站点水表TopN", "tone": "专业"},
            ],
            "keywords": ["TopN", "排行"],
            "intent_tags": ["topn", "aggregate"],
        },
        {
            "id": "snpt_extra",
            "aliases": [{"text": "额外别名"}],
            "keywords": ["extra"],
        },
    ]

    action_row = {
        "table_id": table_id,
        "version_ts": version_ts,
        "action_type": "snippet_alias",
        "snippet_json": snippets,
        "snippet_alias_json": snippet_aliases,
        "updated_at": "2024-01-02T00:00:00",
    }
    _insert_action_row(db_engine, action_row)

    stub_client = _StubRagClient()
    async with httpx.AsyncClient() as http_client:
        returned_ids = await ingest_snippet_rag_from_db(
            table_id=table_id,
            version_ts=version_ts,
            workspace_id=99,
            rag_item_type="SNIPPET",
            client=http_client,
            engine=db_engine,
            rag_client=stub_client,
        )

    assert stub_client.received is not None
    # Two items expected: ``snpt_extra`` appears only in the alias payload.
    assert len(stub_client.received) == 2
    assert len(returned_ids) == 2

    query = text(
        "SELECT snippet_id, action_result_id, rag_text, merged_json FROM rag_snippet ORDER BY snippet_id"
    )
    with db_engine.connect() as connection:
        stored_rows = list(connection.execute(query))

    # Column order follows the SELECT: 0=snippet_id, 1=action_result_id, 2=rag_text.
    stored_ids = {row[0] for row in stored_rows}
    assert stored_ids == {"snpt_extra", "snpt_topn"}
    assert not any(row[1] is None for row in stored_rows)

    topn_text = next(row[2] for row in stored_rows if row[0] == "snpt_topn")
    for expected_fragment in ("TopN", "按站点水表TopN", "排行"):
        assert expected_fragment in topn_text
|