from __future__ import annotations

import json
from datetime import datetime

import httpx
import pytest
from sqlalchemy import create_engine, text

from app.services.table_snippet import ingest_snippet_rag_from_db


def _setup_sqlite_engine():
    """Create an in-memory SQLite engine with the two tables the ingest flow touches."""
    engine = create_engine("sqlite://")
    with engine.begin() as conn:
        conn.execute(
            text(
                """
                CREATE TABLE action_results (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    table_id INTEGER,
                    version_ts INTEGER,
                    action_type TEXT,
                    status TEXT,
                    snippet_json TEXT,
                    snippet_alias_json TEXT,
                    updated_at TEXT
                )
                """
            )
        )
        conn.execute(
            text(
                """
                CREATE TABLE rag_snippet (
                    rag_item_id INTEGER PRIMARY KEY,
                    action_result_id INTEGER NOT NULL,
                    workspace_id INTEGER,
                    table_id INTEGER,
                    version_ts INTEGER,
                    created_at TEXT,
                    snippet_id TEXT,
                    rag_text TEXT,
                    merged_json TEXT,
                    updated_at TEXT
                )
                """
            )
        )
    return engine


def _insert_action_row(engine, payload: dict) -> None:
    """Insert one action_results row, serializing the JSON payloads when present."""
    with engine.begin() as conn:
        conn.execute(
            text(
                """
                INSERT INTO action_results
                    (table_id, version_ts, action_type, status, snippet_json, snippet_alias_json, updated_at)
                VALUES
                    (:table_id, :version_ts, :action_type, :status, :snippet_json, :snippet_alias_json, :updated_at)
                """
            ),
            {
                "table_id": payload["table_id"],
                "version_ts": payload["version_ts"],
                "action_type": payload["action_type"],
                "status": payload.get("status", "success"),
                "snippet_json": json.dumps(payload.get("snippet_json"), ensure_ascii=False)
                if payload.get("snippet_json") is not None
                else None,
                "snippet_alias_json": json.dumps(payload.get("snippet_alias_json"), ensure_ascii=False)
                if payload.get("snippet_alias_json") is not None
                else None,
                "updated_at": payload.get("updated_at") or datetime.utcnow().isoformat(),
            },
        )


class _StubRagClient:
    """Stand-in for the RAG client: records the batch it receives instead of calling out."""

    def __init__(self) -> None:
        self.received = None

    async def add_batch(self, _client, items):
        self.received = items
        return {"count": len(items)}


@pytest.mark.asyncio
async def test_ingest_snippet_rag_from_db_persists_and_calls_rag_client() -> None:
    engine = _setup_sqlite_engine()
    table_id = 321
    version_ts = 20240102000000
    # Base snippets and alias augmentations; the Chinese strings are test data that the
    # rag_text assertions below depend on, so they are kept verbatim.
    snippet_payload = [
        {
            "id": "snpt_topn",
            "title": "TopN",
            "aliases": [{"text": "站点水表排行前N", "tone": "中性"}],
            "keywords": ["TopN", "站点"],
        }
    ]
    alias_payload = [
        {
            "id": "snpt_topn",
            "aliases": [
                {"text": "站点水表排行前N", "tone": "中性"},
                {"text": "按站点水表TopN", "tone": "专业"},
            ],
            "keywords": ["TopN", "排行"],
            "intent_tags": ["topn", "aggregate"],
        },
        {
            # Alias-only entry with no matching base snippet; it must still be ingested.
            "id": "snpt_extra",
            "aliases": [{"text": "额外别名"}],
            "keywords": ["extra"],
        },
    ]
    _insert_action_row(
        engine,
        {
            "table_id": table_id,
            "version_ts": version_ts,
            "action_type": "snippet_alias",
            "snippet_json": snippet_payload,
            "snippet_alias_json": alias_payload,
            "updated_at": "2024-01-02T00:00:00",
        },
    )

    rag_stub = _StubRagClient()
    async with httpx.AsyncClient() as client:
        rag_ids = await ingest_snippet_rag_from_db(
            table_id=table_id,
            version_ts=version_ts,
            workspace_id=99,
            rag_item_type="SNIPPET",
            client=client,
            engine=engine,
            rag_client=rag_stub,
        )

    assert rag_stub.received is not None
    assert len(rag_stub.received) == 2  # includes alias-only row
    assert len(rag_ids) == 2

    with engine.connect() as conn:
        rows = list(
            conn.execute(
                text(
                    "SELECT snippet_id, action_result_id, rag_text, merged_json "
                    "FROM rag_snippet ORDER BY snippet_id"
                )
            )
        )
    assert {row[0] for row in rows} == {"snpt_extra", "snpt_topn"}
    assert all(row[1] is not None for row in rows)

    # The merged rag_text for snpt_topn should carry the title, the alias added by
    # snippet_alias_json, and the extra keyword.
    topn_row = next(row for row in rows if row[0] == "snpt_topn")
    assert "TopN" in topn_row[2]
    assert "按站点水表TopN" in topn_row[2]
    assert "排行" in topn_row[2]
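

# ---------------------------------------------------------------------------
# Illustrative sketch only -- NOT the production service. This hypothetical
# helper (_sketch_ingest_snippet_rag) shows one way the flow this test
# exercises could work, assuming ingest_snippet_rag_from_db reads the latest
# matching action_results row, merges snippet_json with snippet_alias_json by
# snippet id, persists one rag_snippet row per merged snippet, and hands the
# batch to the RAG client. All names and the rag_text format below are
# assumptions for illustration.
# ---------------------------------------------------------------------------
async def _sketch_ingest_snippet_rag(*, table_id, version_ts, workspace_id, client, engine, rag_client):
    with engine.connect() as conn:
        row = conn.execute(
            text(
                "SELECT id, snippet_json, snippet_alias_json FROM action_results "
                "WHERE table_id = :t AND version_ts = :v ORDER BY id DESC LIMIT 1"
            ),
            {"t": table_id, "v": version_ts},
        ).first()

    # Index base snippets by id, then fold alias rows in; alias-only ids get a new entry.
    snippets = {s["id"]: dict(s) for s in json.loads(row.snippet_json or "[]")}
    for alias in json.loads(row.snippet_alias_json or "[]"):
        merged = snippets.setdefault(alias["id"], {"id": alias["id"]})
        merged.setdefault("aliases", []).extend(alias.get("aliases", []))
        merged.setdefault("keywords", []).extend(alias.get("keywords", []))

    items, rag_ids = [], []
    with engine.begin() as conn:
        for snippet_id, merged in sorted(snippets.items()):
            # rag_text concatenates title, alias texts, and keywords so all of them are searchable.
            parts = [merged.get("title", "")]
            parts += [a.get("text", "") for a in merged.get("aliases", [])]
            parts += merged.get("keywords", [])
            rag_text = " ".join(p for p in parts if p)
            conn.execute(
                text(
                    "INSERT INTO rag_snippet "
                    "(action_result_id, workspace_id, table_id, version_ts, snippet_id, rag_text, merged_json, created_at) "
                    "VALUES (:ar, :ws, :t, :v, :sid, :txt, :merged, :ts)"
                ),
                {
                    "ar": row.id,
                    "ws": workspace_id,
                    "t": table_id,
                    "v": version_ts,
                    "sid": snippet_id,
                    "txt": rag_text,
                    "merged": json.dumps(merged, ensure_ascii=False),
                    "ts": datetime.utcnow().isoformat(),
                },
            )
            items.append({"id": snippet_id, "text": rag_text})
            rag_ids.append(snippet_id)

    # Push the whole batch to the RAG backend (the test replaces this with _StubRagClient).
    await rag_client.add_batch(client, items)
    return rag_ids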