# Integration tests: merging snippet records read from the action_results MySQL table.
from __future__ import annotations
|
|
|
|
import json
|
|
import os
|
|
import random
|
|
from datetime import datetime, timedelta
|
|
from typing import List
|
|
from pathlib import Path
|
|
|
|
import sys
|
|
import pytest
|
|
from sqlalchemy import text
|
|
from sqlalchemy.engine import Engine
|
|
from sqlalchemy.exc import SQLAlchemyError
|
|
|
|
# Ensure the project root is importable when running directly via python.
|
|
ROOT = Path(__file__).resolve().parents[1]
|
|
if str(ROOT) not in sys.path:
|
|
sys.path.insert(0, str(ROOT))
|
|
|
|
from app import db
|
|
from app.main import create_app
|
|
|
|
|
|
from app.services.table_snippet import merge_snippet_records_from_db
|
|
|
|
|
|
DEFAULT_MYSQL_URL = "mysql+pymysql://root:12345678@127.0.0.1:3306/data-ge?charset=utf8mb4"
|
|
|
|
|
|
@pytest.fixture()
def mysql_engine() -> Engine:
    """Yield an Engine bound to the test MySQL database, skipping if unavailable.

    Points ``DATABASE_URL`` at ``TEST_DATABASE_URL`` (or ``DEFAULT_MYSQL_URL``),
    clears the cached engine so ``db.get_engine()`` rebuilds against that URL,
    then verifies connectivity and the presence of the ``action_results`` table
    before handing the engine to the test.

    Teardown restores the previous ``DATABASE_URL`` value and clears the
    engine cache again, so the environment mutation does not leak into other
    tests (the original implementation left both permanently modified).
    """
    mysql_url = os.getenv("TEST_DATABASE_URL", DEFAULT_MYSQL_URL)
    previous_url = os.environ.get("DATABASE_URL")  # saved for teardown
    os.environ["DATABASE_URL"] = mysql_url
    db.get_engine.cache_clear()  # drop any engine built with the old URL
    engine = db.get_engine()
    try:
        try:
            with engine.connect() as conn:
                conn.execute(text("SELECT 1"))
                exists = conn.execute(text("SHOW TABLES LIKE 'action_results'")).scalar()
            if not exists:
                pytest.skip("action_results table not found in test database.")
        except SQLAlchemyError:
            pytest.skip(f"Cannot connect to MySQL at {mysql_url}")
        yield engine
    finally:
        # Undo the environment/cache mutation regardless of test outcome.
        if previous_url is None:
            os.environ.pop("DATABASE_URL", None)
        else:
            os.environ["DATABASE_URL"] = previous_url
        db.get_engine.cache_clear()
|
|
|
|
|
|
def _insert_action_row(
    engine: Engine,
    *,
    table_id: int,
    version_ts: int,
    action_type: str,
    status: str = "success",
    snippet_json: List[dict] | None = None,
    snippet_alias_json: List[dict] | None = None,
    updated_at: datetime | None = None,
) -> None:
    """Upsert one row into ``action_results`` for the given table/version pair.

    ``snippet_json`` / ``snippet_alias_json`` are serialized to JSON text
    (``None`` is stored as SQL NULL).  On a duplicate key the status, both
    JSON columns, and ``updated_at`` are overwritten.
    """

    def serialize(payload: List[dict] | None) -> str | None:
        # ensure_ascii=False keeps non-ASCII text readable in the stored JSON.
        return None if payload is None else json.dumps(payload, ensure_ascii=False)

    params = {
        "table_id": table_id,
        "version_ts": version_ts,
        "action_type": action_type,
        "status": status,
        "callback_url": "http://localhost/test-callback",
        "table_schema_version_id": "1",
        "table_schema": json.dumps({}, ensure_ascii=False),
        "snippet_json": serialize(snippet_json),
        "snippet_alias_json": serialize(snippet_alias_json),
        "updated_at": updated_at or datetime.utcnow(),
    }
    statement = text(
        """
        INSERT INTO action_results (
            table_id, version_ts, action_type, status,
            callback_url, table_schema_version_id, table_schema,
            snippet_json, snippet_alias_json, updated_at
        ) VALUES (
            :table_id, :version_ts, :action_type, :status,
            :callback_url, :table_schema_version_id, :table_schema,
            :snippet_json, :snippet_alias_json, :updated_at
        )
        ON DUPLICATE KEY UPDATE
            status=VALUES(status),
            snippet_json=VALUES(snippet_json),
            snippet_alias_json=VALUES(snippet_alias_json),
            updated_at=VALUES(updated_at)
        """
    )
    with engine.begin() as conn:
        conn.execute(statement, params)
|
|
|
|
|
|
def _cleanup(engine: Engine, table_id: int, version_ts: int) -> None:
    """Delete every action_results row written for this table/version pair."""
    statement = text(
        "DELETE FROM action_results WHERE table_id=:table_id AND version_ts=:version_ts"
    )
    with engine.begin() as conn:
        conn.execute(statement, {"table_id": table_id, "version_ts": version_ts})
|
|
|
|
|
|
def test_merge_prefers_alias_row_and_appends_alias_only_entries(mysql_engine: Engine) -> None:
    """A snippet_alias row enriches the matching snippet and adds alias-only records."""
    table_id = 990000000 + random.randint(1, 9999)
    version_ts = int(datetime.utcnow().strftime("%Y%m%d%H%M%S"))
    alias_row_updated_at = datetime(2024, 1, 2, 0, 0, 0)

    base_snippets = [
        {
            "id": "snpt_topn",
            "aliases": [{"text": "站点水表排行前N", "tone": "中性"}],
            "keywords": ["TopN", "站点"],
        }
    ]
    alias_records = [
        {
            "id": "snpt_topn",
            "aliases": [
                {"text": "站点水表排行前N", "tone": "中性"},
                {"text": "按站点水表TopN", "tone": "专业"},
            ],
            "keywords": ["TopN", "排行"],
            "intent_tags": ["topn", "aggregate"],
        },
        {
            "id": "snpt_extra",
            "aliases": [{"text": "额外别名"}],
            "keywords": ["extra"],
        },
    ]

    _insert_action_row(
        mysql_engine,
        table_id=table_id,
        version_ts=version_ts,
        action_type="snippet_alias",
        snippet_json=base_snippets,
        snippet_alias_json=alias_records,
        updated_at=alias_row_updated_at,
    )

    try:
        merged = merge_snippet_records_from_db(table_id, version_ts, engine=mysql_engine)

        assert len(merged) == 2
        by_id = {item["id"]: item for item in merged}

        # The snippet present in both payloads keeps its snippet origin but
        # absorbs aliases / keywords / intent tags from the alias payload.
        topn = by_id["snpt_topn"]
        assert topn["source"] == "snippet"
        assert topn["updated_at_from_action"] == alias_row_updated_at
        assert {alias["text"] for alias in topn["aliases"]} == {"站点水表排行前N", "按站点水表TopN"}
        assert set(topn["keywords"]) == {"TopN", "站点", "排行"}
        assert set(topn["intent_tags"]) == {"topn", "aggregate"}

        # The alias-only entry is appended as its own record.
        alias_only = by_id["snpt_extra"]
        assert alias_only["source"] == "alias_only"
        assert alias_only["aliases"][0]["text"] == "额外别名"
    finally:
        _cleanup(mysql_engine, table_id, version_ts)
|
|
|
|
|
|
def test_merge_falls_back_to_snippet_row_when_alias_row_missing_snippet_json(mysql_engine: Engine) -> None:
    """When the alias row lacks snippet_json, the snippet-action row supplies it."""
    table_id = 991000000 + random.randint(1, 9999)
    version_ts = int((datetime.utcnow() + timedelta(seconds=1)).strftime("%Y%m%d%H%M%S"))

    alias_row_updated_at = datetime(2024, 1, 3, 0, 0, 0)
    alias_records = [
        {
            "id": "snpt_quality",
            "aliases": [{"text": "质量检查"}],
            "keywords": ["quality"],
        }
    ]
    fallback_snippets = [
        {
            "id": "snpt_quality",
            "title": "质量检查",
            "keywords": ["data-quality"],
            "aliases": [{"text": "质量检查"}],
        }
    ]

    # Alias row with no snippet_json; a plain snippet row (older updated_at)
    # must be used as the snippet source instead.
    _insert_action_row(
        mysql_engine,
        table_id=table_id,
        version_ts=version_ts,
        action_type="snippet_alias",
        snippet_json=None,
        snippet_alias_json=alias_records,
        updated_at=alias_row_updated_at,
    )
    _insert_action_row(
        mysql_engine,
        table_id=table_id,
        version_ts=version_ts,
        action_type="snippet",
        snippet_json=fallback_snippets,
        snippet_alias_json=None,
        updated_at=datetime(2024, 1, 2, 0, 0, 0),
    )

    try:
        merged = merge_snippet_records_from_db(table_id, version_ts, engine=mysql_engine)

        assert len(merged) == 1
        (record,) = merged
        assert record["id"] == "snpt_quality"
        assert record["source"] == "snippet"
        # The alias row's timestamp wins even though the snippet body came
        # from the fallback row.
        assert record["updated_at_from_action"] == alias_row_updated_at
        assert set(record["keywords"]) == {"data-quality", "quality"}
        assert {alias["text"] for alias in record["aliases"]} == {"质量检查"}
    finally:
        _cleanup(mysql_engine, table_id, version_ts)
|