Generate RAG snippets, persist them to the database, and write them to RAG

zhaoawd
2025-12-09 00:15:22 +08:00
parent ebd79b75bd
commit 3218e51bad
11 changed files with 1231 additions and 8 deletions

View File

@@ -0,0 +1,207 @@
from __future__ import annotations

import os
import random
from datetime import datetime, timedelta
from pathlib import Path
from typing import Generator, List

import pytest
from fastapi.testclient import TestClient
from sqlalchemy import text
from sqlalchemy.exc import SQLAlchemyError

# Ensure project root on path for direct execution
ROOT = Path(__file__).resolve().parents[1]
import sys

if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

from app import db
from app.main import create_app

TEST_USER_ID = 98765
# SCHEMA_PATH = Path("file/tableschema/metrics.sql")
DEFAULT_MYSQL_URL = "mysql+pymysql://root:12345678@127.0.0.1:3306/data-ge?charset=utf8mb4"

# def _run_sql_script(engine, sql_text: str) -> None:
#     """Execute semicolon-terminated SQL statements sequentially."""
#     statements: List[str] = []
#     buffer: List[str] = []
#     for line in sql_text.splitlines():
#         stripped = line.strip()
#         if not stripped or stripped.startswith("--"):
#             continue
#         buffer.append(line)
#         if stripped.endswith(";"):
#             statements.append("\n".join(buffer).rstrip(";"))
#             buffer = []
#     if buffer:
#         statements.append("\n".join(buffer))
#     with engine.begin() as conn:
#         for stmt in statements:
#             conn.execute(text(stmt))


# def _ensure_metric_schema(engine) -> None:
#     if not SCHEMA_PATH.exists():
#         pytest.skip("metrics.sql schema file not found.")
#     raw_sql = SCHEMA_PATH.read_text(encoding="utf-8")
#     raw_sql = raw_sql.replace("CREATE TABLE metric_def", "CREATE TABLE IF NOT EXISTS metric_def")
#     raw_sql = raw_sql.replace("CREATE TABLE metric_schedule", "CREATE TABLE IF NOT EXISTS metric_schedule")
#     raw_sql = raw_sql.replace("CREATE TABLE metric_job_run", "CREATE TABLE IF NOT EXISTS metric_job_run")
#     raw_sql = raw_sql.replace("CREATE TABLE metric_result", "CREATE TABLE IF NOT EXISTS metric_result")
#     _run_sql_script(engine, raw_sql)


@pytest.fixture(scope="module")
def client() -> Generator[TestClient, None, None]:
    mysql_url = os.getenv("TEST_DATABASE_URL", DEFAULT_MYSQL_URL)
    os.environ["DATABASE_URL"] = mysql_url
    db.get_engine.cache_clear()
    engine = db.get_engine()
    try:
        with engine.connect() as conn:
            conn.execute(text("SELECT 1"))
    except SQLAlchemyError:
        pytest.skip(f"Cannot connect to MySQL at {mysql_url}")
    # _ensure_metric_schema(engine)
    app = create_app()
    with TestClient(app) as test_client:
        yield test_client
    # cleanup test artifacts
    with engine.begin() as conn:
        conn.execute(text("DELETE FROM metric_result WHERE metric_id IN (SELECT id FROM metric_def WHERE created_by=:uid)"), {"uid": TEST_USER_ID})
        conn.execute(text("DELETE FROM metric_job_run WHERE metric_id IN (SELECT id FROM metric_def WHERE created_by=:uid)"), {"uid": TEST_USER_ID})
        conn.execute(text("DELETE FROM metric_schedule WHERE metric_id IN (SELECT id FROM metric_def WHERE created_by=:uid)"), {"uid": TEST_USER_ID})
        conn.execute(text("DELETE FROM metric_def WHERE created_by=:uid"), {"uid": TEST_USER_ID})
    db.get_engine.cache_clear()


def test_metric_crud_and_schedule_mysql(client: TestClient) -> None:
    code = f"metric_{random.randint(1000, 9999)}"
    create_payload = {
        "metric_code": code,
        "metric_name": "订单数",
        "biz_domain": "order",
        "biz_desc": "订单总数",
        "base_sql": "select count(*) as order_cnt from orders",
        "time_grain": "DAY",
        "dim_binding": ["dt"],
        "update_strategy": "FULL",
        "metric_aliases": ["订单量"],
        "created_by": TEST_USER_ID,
    }
    resp = client.post("/api/v1/metrics", json=create_payload)
    assert resp.status_code == 200, resp.text
    metric = resp.json()
    metric_id = metric["id"]
    assert metric["metric_code"] == code

    # Update metric
    resp = client.post(f"/api/v1/metrics/{metric_id}", json={"metric_name": "订单数-更新", "is_active": False})
    assert resp.status_code == 200
    assert resp.json()["is_active"] is False

    # Get metric
    resp = client.get(f"/api/v1/metrics/{metric_id}")
    assert resp.status_code == 200
    assert resp.json()["metric_name"] == "订单数-更新"

    # Create schedule
    resp = client.post(
        "/api/v1/metric-schedules",
        json={"metric_id": metric_id, "cron_expr": "0 2 * * *", "priority": 5, "enabled": True},
    )
    assert resp.status_code == 200, resp.text
    schedule = resp.json()
    schedule_id = schedule["id"]

    # Update schedule
    resp = client.post(f"/api/v1/metric-schedules/{schedule_id}", json={"enabled": False, "retry_times": 1})
    assert resp.status_code == 200
    assert resp.json()["enabled"] is False

    # List schedules for metric
    resp = client.get(f"/api/v1/metrics/{metric_id}/schedules")
    assert resp.status_code == 200
    assert any(s["id"] == schedule_id for s in resp.json())


def test_metric_runs_and_results_mysql(client: TestClient) -> None:
    code = f"gmv_{random.randint(1000, 9999)}"
    metric_id = client.post(
        "/api/v1/metrics",
        json={
            "metric_code": code,
            "metric_name": "GMV",
            "biz_domain": "order",
            "base_sql": "select sum(pay_amount) as gmv from orders",
            "time_grain": "DAY",
            "dim_binding": ["dt"],
            "update_strategy": "FULL",
            "created_by": TEST_USER_ID,
        },
    ).json()["id"]

    # Trigger run
    resp = client.post(
        "/api/v1/metric-runs/trigger",
        json={
            "metric_id": metric_id,
            "triggered_by": "API",
            "data_time_from": (datetime.utcnow() - timedelta(days=1)).isoformat(),
            "data_time_to": datetime.utcnow().isoformat(),
        },
    )
    assert resp.status_code == 200, resp.text
    run = resp.json()
    run_id = run["id"]
    assert run["status"] == "RUNNING"

    # List runs
    resp = client.get("/api/v1/metric-runs", params={"metric_id": metric_id})
    assert resp.status_code == 200
    assert any(r["id"] == run_id for r in resp.json())

    # Get run
    resp = client.get(f"/api/v1/metric-runs/{run_id}")
    assert resp.status_code == 200

    # Write results
    now = datetime.utcnow()
    resp = client.post(
        f"/api/v1/metric-results/{metric_id}",
        json={
            "metric_id": metric_id,
            "results": [
                {"stat_time": (now - timedelta(days=1)).isoformat(), "metric_value": 123.45, "data_version": run_id},
                {"stat_time": now.isoformat(), "metric_value": 234.56, "data_version": run_id},
            ],
        },
    )
    assert resp.status_code == 200, resp.text
    assert resp.json()["inserted"] == 2

    # Query results
    resp = client.get("/api/v1/metric-results", params={"metric_id": metric_id})
    assert resp.status_code == 200
    results = resp.json()
    assert len(results) >= 2

    # Latest result
    resp = client.get("/api/v1/metric-results/latest", params={"metric_id": metric_id})
    assert resp.status_code == 200
    latest = resp.json()
    assert float(latest["metric_value"]) in {123.45, 234.56}


if __name__ == "__main__":
    import pytest as _pytest

    raise SystemExit(_pytest.main([__file__]))
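
The client fixture above repoints the app at the test database by mutating DATABASE_URL and then calling db.get_engine.cache_clear(), which only works if get_engine is memoized. A minimal sketch of the engine-factory shape this implies — illustrative only, the repository's app/db.py may differ:

import os
from functools import lru_cache

from sqlalchemy import create_engine
from sqlalchemy.engine import Engine


@lru_cache(maxsize=1)
def get_engine() -> Engine:
    # Hypothetical sketch: cache_clear() in the fixture forces the next call
    # to re-read DATABASE_URL, which is how tests swap in TEST_DATABASE_URL.
    return create_engine(os.environ["DATABASE_URL"], pool_pre_ping=True)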

View File

@@ -0,0 +1,156 @@
from __future__ import annotations

import json
from datetime import datetime

import httpx
import pytest
from sqlalchemy import create_engine, text

from app.services.table_snippet import ingest_snippet_rag_from_db


def _setup_sqlite_engine():
    engine = create_engine("sqlite://")
    with engine.begin() as conn:
        conn.execute(
            text(
                """
                CREATE TABLE action_results (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    table_id INTEGER,
                    version_ts INTEGER,
                    action_type TEXT,
                    status TEXT,
                    snippet_json TEXT,
                    snippet_alias_json TEXT,
                    updated_at TEXT
                )
                """
            )
        )
        conn.execute(
            text(
                """
                CREATE TABLE rag_snippet (
                    rag_item_id INTEGER PRIMARY KEY,
                    action_result_id INTEGER NOT NULL,
                    workspace_id INTEGER,
                    table_id INTEGER,
                    version_ts INTEGER,
                    snippet_id TEXT,
                    rag_text TEXT,
                    merged_json TEXT,
                    updated_at TEXT
                )
                """
            )
        )
    return engine


def _insert_action_row(engine, payload: dict) -> None:
    with engine.begin() as conn:
        conn.execute(
            text(
                """
                INSERT INTO action_results (table_id, version_ts, action_type, status, snippet_json, snippet_alias_json, updated_at)
                VALUES (:table_id, :version_ts, :action_type, :status, :snippet_json, :snippet_alias_json, :updated_at)
                """
            ),
            {
                "table_id": payload["table_id"],
                "version_ts": payload["version_ts"],
                "action_type": payload["action_type"],
                "status": payload.get("status", "success"),
                "snippet_json": json.dumps(payload.get("snippet_json"), ensure_ascii=False)
                if payload.get("snippet_json") is not None
                else None,
                "snippet_alias_json": json.dumps(payload.get("snippet_alias_json"), ensure_ascii=False)
                if payload.get("snippet_alias_json") is not None
                else None,
                "updated_at": payload.get("updated_at") or datetime.utcnow().isoformat(),
            },
        )


class _StubRagClient:
    def __init__(self) -> None:
        self.received = None

    async def add_batch(self, _client, items):
        self.received = items
        return {"count": len(items)}


@pytest.mark.asyncio
async def test_ingest_snippet_rag_from_db_persists_and_calls_rag_client() -> None:
    engine = _setup_sqlite_engine()
    table_id = 321
    version_ts = 20240102000000
    snippet_payload = [
        {
            "id": "snpt_topn",
            "title": "TopN",
            "aliases": [{"text": "站点水表排行前N", "tone": "中性"}],
            "keywords": ["TopN", "站点"],
        }
    ]
    alias_payload = [
        {
            "id": "snpt_topn",
            "aliases": [
                {"text": "站点水表排行前N", "tone": "中性"},
                {"text": "按站点水表TopN", "tone": "专业"},
            ],
            "keywords": ["TopN", "排行"],
            "intent_tags": ["topn", "aggregate"],
        },
        {
            "id": "snpt_extra",
            "aliases": [{"text": "额外别名"}],
            "keywords": ["extra"],
        },
    ]
    _insert_action_row(
        engine,
        {
            "table_id": table_id,
            "version_ts": version_ts,
            "action_type": "snippet_alias",
            "snippet_json": snippet_payload,
            "snippet_alias_json": alias_payload,
            "updated_at": "2024-01-02T00:00:00",
        },
    )
    rag_stub = _StubRagClient()
    async with httpx.AsyncClient() as client:
        rag_ids = await ingest_snippet_rag_from_db(
            table_id=table_id,
            version_ts=version_ts,
            workspace_id=99,
            rag_item_type="SNIPPET",
            client=client,
            engine=engine,
            rag_client=rag_stub,
        )
    assert rag_stub.received is not None
    assert len(rag_stub.received) == 2  # includes alias-only row
    assert len(rag_ids) == 2
    with engine.connect() as conn:
        rows = list(
            conn.execute(
                text("SELECT snippet_id, action_result_id, rag_text, merged_json FROM rag_snippet ORDER BY snippet_id")
            )
        )
    assert {row[0] for row in rows} == {"snpt_extra", "snpt_topn"}
    assert all(row[1] is not None for row in rows)
    topn_row = next(row for row in rows if row[0] == "snpt_topn")
    assert "TopN" in topn_row[2]
    assert "按站点水表TopN" in topn_row[2]
    assert "排行" in topn_row[2]

View File

@@ -0,0 +1,213 @@
from __future__ import annotations

import json
import os
import random
from datetime import datetime, timedelta
from typing import List
from pathlib import Path
import sys

import pytest
from sqlalchemy import text
from sqlalchemy.engine import Engine
from sqlalchemy.exc import SQLAlchemyError

# Ensure the project root is importable when running directly via python.
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

from app import db
from app.main import create_app
from app.services.table_snippet import merge_snippet_records_from_db

DEFAULT_MYSQL_URL = "mysql+pymysql://root:12345678@127.0.0.1:3306/data-ge?charset=utf8mb4"


@pytest.fixture()
def mysql_engine() -> Engine:
    mysql_url = os.getenv("TEST_DATABASE_URL", DEFAULT_MYSQL_URL)
    os.environ["DATABASE_URL"] = mysql_url
    db.get_engine.cache_clear()
    engine = db.get_engine()
    try:
        with engine.connect() as conn:
            conn.execute(text("SELECT 1"))
            exists = conn.execute(text("SHOW TABLES LIKE 'action_results'")).scalar()
            if not exists:
                pytest.skip("action_results table not found in test database.")
    except SQLAlchemyError:
        pytest.skip(f"Cannot connect to MySQL at {mysql_url}")
    return engine


def _insert_action_row(
    engine: Engine,
    *,
    table_id: int,
    version_ts: int,
    action_type: str,
    status: str = "success",
    snippet_json: List[dict] | None = None,
    snippet_alias_json: List[dict] | None = None,
    updated_at: datetime | None = None,
) -> None:
    snippet_json_str = json.dumps(snippet_json, ensure_ascii=False) if snippet_json is not None else None
    snippet_alias_json_str = (
        json.dumps(snippet_alias_json, ensure_ascii=False) if snippet_alias_json is not None else None
    )
    with engine.begin() as conn:
        conn.execute(
            text(
                """
                INSERT INTO action_results (
                    table_id, version_ts, action_type, status,
                    callback_url, table_schema_version_id, table_schema,
                    snippet_json, snippet_alias_json, updated_at
                ) VALUES (
                    :table_id, :version_ts, :action_type, :status,
                    :callback_url, :table_schema_version_id, :table_schema,
                    :snippet_json, :snippet_alias_json, :updated_at
                )
                ON DUPLICATE KEY UPDATE
                    status=VALUES(status),
                    snippet_json=VALUES(snippet_json),
                    snippet_alias_json=VALUES(snippet_alias_json),
                    updated_at=VALUES(updated_at)
                """
            ),
            {
                "table_id": table_id,
                "version_ts": version_ts,
                "action_type": action_type,
                "status": status,
                "callback_url": "http://localhost/test-callback",
                "table_schema_version_id": "1",
                "table_schema": json.dumps({}, ensure_ascii=False),
                "snippet_json": snippet_json_str,
                "snippet_alias_json": snippet_alias_json_str,
                "updated_at": updated_at or datetime.utcnow(),
            },
        )


def _cleanup(engine: Engine, table_id: int, version_ts: int) -> None:
    with engine.begin() as conn:
        conn.execute(
            text("DELETE FROM action_results WHERE table_id=:table_id AND version_ts=:version_ts"),
            {"table_id": table_id, "version_ts": version_ts},
        )


def test_merge_prefers_alias_row_and_appends_alias_only_entries(mysql_engine: Engine) -> None:
    table_id = 990000000 + random.randint(1, 9999)
    version_ts = int(datetime.utcnow().strftime("%Y%m%d%H%M%S"))
    alias_updated = datetime(2024, 1, 2, 0, 0, 0)
    snippet_payload = [
        {
            "id": "snpt_topn",
            "aliases": [{"text": "站点水表排行前N", "tone": "中性"}],
            "keywords": ["TopN", "站点"],
        }
    ]
    alias_payload = [
        {
            "id": "snpt_topn",
            "aliases": [
                {"text": "站点水表排行前N", "tone": "中性"},
                {"text": "按站点水表TopN", "tone": "专业"},
            ],
            "keywords": ["TopN", "排行"],
            "intent_tags": ["topn", "aggregate"],
        },
        {
            "id": "snpt_extra",
            "aliases": [{"text": "额外别名"}],
            "keywords": ["extra"],
        },
    ]
    _insert_action_row(
        mysql_engine,
        table_id=table_id,
        version_ts=version_ts,
        action_type="snippet_alias",
        snippet_json=snippet_payload,
        snippet_alias_json=alias_payload,
        updated_at=alias_updated,
    )
    try:
        merged = merge_snippet_records_from_db(table_id, version_ts, engine=mysql_engine)
        assert len(merged) == 2
        topn = next(item for item in merged if item["id"] == "snpt_topn")
        assert topn["source"] == "snippet"
        assert topn["updated_at_from_action"] == alias_updated
        assert {a["text"] for a in topn["aliases"]} == {"站点水表排行前N", "按站点水表TopN"}
        assert set(topn["keywords"]) == {"TopN", "站点", "排行"}
        assert set(topn["intent_tags"]) == {"topn", "aggregate"}
        alias_only = next(item for item in merged if item["source"] == "alias_only")
        assert alias_only["id"] == "snpt_extra"
        assert alias_only["aliases"][0]["text"] == "额外别名"
    finally:
        _cleanup(mysql_engine, table_id, version_ts)


def test_merge_falls_back_to_snippet_row_when_alias_row_missing_snippet_json(mysql_engine: Engine) -> None:
    table_id = 991000000 + random.randint(1, 9999)
    version_ts = int((datetime.utcnow() + timedelta(seconds=1)).strftime("%Y%m%d%H%M%S"))
    alias_updated = datetime(2024, 1, 3, 0, 0, 0)
    alias_payload = [
        {
            "id": "snpt_quality",
            "aliases": [{"text": "质量检查"}],
            "keywords": ["quality"],
        }
    ]
    snippet_payload = [
        {
            "id": "snpt_quality",
            "title": "质量检查",
            "keywords": ["data-quality"],
            "aliases": [{"text": "质量检查"}],
        }
    ]
    _insert_action_row(
        mysql_engine,
        table_id=table_id,
        version_ts=version_ts,
        action_type="snippet_alias",
        snippet_json=None,
        snippet_alias_json=alias_payload,
        updated_at=alias_updated,
    )
    _insert_action_row(
        mysql_engine,
        table_id=table_id,
        version_ts=version_ts,
        action_type="snippet",
        snippet_json=snippet_payload,
        snippet_alias_json=None,
        updated_at=datetime(2024, 1, 2, 0, 0, 0),
    )
    try:
        merged = merge_snippet_records_from_db(table_id, version_ts, engine=mysql_engine)
        assert len(merged) == 1
        record = merged[0]
        assert record["id"] == "snpt_quality"
        assert record["source"] == "snippet"
        assert record["updated_at_from_action"] == alias_updated
        assert set(record["keywords"]) == {"data-quality", "quality"}
        assert {a["text"] for a in record["aliases"]} == {"质量检查"}
    finally:
        _cleanup(mysql_engine, table_id, version_ts)
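
Taken together, the two tests above specify merge_snippet_records_from_db's semantics: the snippet_alias row's updated_at wins; aliases and keywords are unioned per snippet id; entries backed by snippet_json carry source "snippet"; and alias ids with no snippet counterpart are appended with source "alias_only". A pure-Python sketch of just that merge step, assuming the JSON columns are already loaded — a hypothetical helper, not the committed code:

def merge_snippets_sketch(snippet_rows, alias_rows, alias_updated_at):
    """Merge snippet and alias payloads the way the tests above expect."""
    alias_by_id = {a["id"]: a for a in (alias_rows or [])}
    merged = []
    for snip in snippet_rows or []:
        record = dict(snip)
        record["source"] = "snippet"
        record["updated_at_from_action"] = alias_updated_at
        alias = alias_by_id.pop(snip["id"], None)
        if alias:
            # Union aliases by text and keywords by value, preserving order.
            seen = {a["text"] for a in record.get("aliases", [])}
            record["aliases"] = record.get("aliases", []) + [
                a for a in alias.get("aliases", []) if a["text"] not in seen
            ]
            record["keywords"] = list(
                dict.fromkeys(list(record.get("keywords", [])) + list(alias.get("keywords", [])))
            )
            if alias.get("intent_tags"):
                record["intent_tags"] = list(alias["intent_tags"])
        merged.append(record)
    # Alias entries with no snippet counterpart survive as alias-only records.
    for alias in alias_by_id.values():
        merged.append({**alias, "source": "alias_only", "updated_at_from_action": alias_updated_at})
    return merged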