208 lines
7.2 KiB
Python
208 lines
7.2 KiB
Python
from __future__ import annotations
|
|
|
|
import os
|
|
import random
|
|
from datetime import datetime, timedelta
|
|
from pathlib import Path
|
|
from typing import Generator, List
|
|
|
|
import pytest
|
|
from fastapi.testclient import TestClient
|
|
from sqlalchemy import text
|
|
from sqlalchemy.exc import SQLAlchemyError
|
|
|
|
# Make the project root importable so `from app import ...` resolves when
# this file is executed directly instead of through the project's test runner.
import sys

ROOT = Path(__file__).resolve().parents[1]

if str(ROOT) not in sys.path:
    # Prepend so the checkout containing this file wins over other entries.
    sys.path.insert(0, str(ROOT))
|
|
|
|
from app import db
|
|
from app.main import create_app
|
|
|
|
|
|
# `created_by` value stamped on every metric these tests create; the `client`
# fixture's teardown deletes rows by this id.
TEST_USER_ID: int = 98765

# Disabled together with the commented-out schema bootstrap helpers below.
# SCHEMA_PATH = Path("file/tableschema/metrics.sql")

# Connection URL used when TEST_DATABASE_URL is not set in the environment.
DEFAULT_MYSQL_URL: str = "mysql+pymysql://root:12345678@127.0.0.1:3306/data-ge?charset=utf8mb4"
|
|
|
|
|
|
# def _run_sql_script(engine, sql_text: str) -> None:
|
|
# """Execute semicolon-terminated SQL statements sequentially."""
|
|
# statements: List[str] = []
|
|
# buffer: List[str] = []
|
|
# for line in sql_text.splitlines():
|
|
# stripped = line.strip()
|
|
# if not stripped or stripped.startswith("--"):
|
|
# continue
|
|
# buffer.append(line)
|
|
# if stripped.endswith(";"):
|
|
# statements.append("\n".join(buffer).rstrip(";"))
|
|
# buffer = []
|
|
# if buffer:
|
|
# statements.append("\n".join(buffer))
|
|
# with engine.begin() as conn:
|
|
# for stmt in statements:
|
|
# conn.execute(text(stmt))
|
|
|
|
|
|
# def _ensure_metric_schema(engine) -> None:
|
|
# if not SCHEMA_PATH.exists():
|
|
# pytest.skip("metrics.sql schema file not found.")
|
|
# raw_sql = SCHEMA_PATH.read_text(encoding="utf-8")
|
|
# raw_sql = raw_sql.replace("CREATE TABLE metric_def", "CREATE TABLE IF NOT EXISTS metric_def")
|
|
# raw_sql = raw_sql.replace("CREATE TABLE metric_schedule", "CREATE TABLE IF NOT EXISTS metric_schedule")
|
|
# raw_sql = raw_sql.replace("CREATE TABLE metric_job_run", "CREATE TABLE IF NOT EXISTS metric_job_run")
|
|
# raw_sql = raw_sql.replace("CREATE TABLE metric_result", "CREATE TABLE IF NOT EXISTS metric_result")
|
|
# _run_sql_script(engine, raw_sql)
|
|
|
|
|
|
@pytest.fixture(scope="module")
def client() -> Generator[TestClient, None, None]:
    """Yield a ``TestClient`` for the app, backed by a live MySQL database.

    Setup:
        * The connection URL is read from ``TEST_DATABASE_URL`` and falls back
          to ``DEFAULT_MYSQL_URL``; it is exported as ``DATABASE_URL``.
        * The ``db.get_engine`` cache is cleared before the engine is
          requested.
        * A ``SELECT 1`` probe is executed; if it raises ``SQLAlchemyError``
          the tests using this fixture are skipped.

    Teardown (always executed, including on the skip path):
        * Rows in ``metric_result``, ``metric_job_run`` and
          ``metric_schedule`` that belong to metrics created by
          ``TEST_USER_ID`` are deleted, followed by those ``metric_def`` rows.
        * ``DATABASE_URL`` is restored to its previous value (or removed if it
          was unset) so the override does not leak into other test modules.
        * The ``db.get_engine`` cache is cleared again.

    Yields:
        The ``TestClient`` wrapping ``create_app()``.
    """
    mysql_url = os.getenv("TEST_DATABASE_URL", DEFAULT_MYSQL_URL)
    previous_url = os.environ.get("DATABASE_URL")
    os.environ["DATABASE_URL"] = mysql_url
    db.get_engine.cache_clear()
    try:
        engine = db.get_engine()
        try:
            with engine.connect() as conn:
                conn.execute(text("SELECT 1"))
        except SQLAlchemyError:
            # Render the URL with the password masked so credentials supplied
            # via TEST_DATABASE_URL do not end up in test reports.
            safe_url = engine.url.render_as_string(hide_password=True)
            pytest.skip(f"Cannot connect to MySQL at {safe_url}")

        # _ensure_metric_schema(engine)

        app = create_app()
        try:
            with TestClient(app) as test_client:
                yield test_client
        finally:
            # Clean up test artifacts. Child tables go first because their
            # sub-select needs the parent metric_def rows to still exist.
            with engine.begin() as conn:
                for table in ("metric_result", "metric_job_run", "metric_schedule"):
                    conn.execute(
                        text(
                            f"DELETE FROM {table} WHERE metric_id IN "
                            "(SELECT id FROM metric_def WHERE created_by=:uid)"
                        ),
                        {"uid": TEST_USER_ID},
                    )
                conn.execute(
                    text("DELETE FROM metric_def WHERE created_by=:uid"),
                    {"uid": TEST_USER_ID},
                )
    finally:
        # Undo the environment override and drop the cached engine even when
        # the probe skipped or the cleanup above failed.
        if previous_url is None:
            os.environ.pop("DATABASE_URL", None)
        else:
            os.environ["DATABASE_URL"] = previous_url
        db.get_engine.cache_clear()
|
|
|
|
|
|
def test_metric_crud_and_schedule_mysql(client: TestClient) -> None:
    """Create, update and fetch a metric, then create, update and list a schedule for it."""
    metric_code = f"metric_{random.randint(1000,9999)}"
    new_metric = {
        "metric_code": metric_code,
        "metric_name": "订单数",
        "biz_domain": "order",
        "biz_desc": "订单总数",
        "base_sql": "select count(*) as order_cnt from orders",
        "time_grain": "DAY",
        "dim_binding": ["dt"],
        "update_strategy": "FULL",
        "metric_aliases": ["订单量"],
        "created_by": TEST_USER_ID,
    }

    # --- metric: create ---
    created = client.post("/api/v1/metrics", json=new_metric)
    assert created.status_code == 200, created.text
    metric_body = created.json()
    metric_id = metric_body["id"]
    assert metric_body["metric_code"] == metric_code

    metric_url = f"/api/v1/metrics/{metric_id}"

    # --- metric: update ---
    metric_patch = {"metric_name": "订单数-更新", "is_active": False}
    updated = client.post(metric_url, json=metric_patch)
    assert updated.status_code == 200
    assert updated.json()["is_active"] is False

    # --- metric: read back ---
    fetched = client.get(metric_url)
    assert fetched.status_code == 200
    assert fetched.json()["metric_name"] == "订单数-更新"

    # --- schedule: create ---
    new_schedule = {
        "metric_id": metric_id,
        "cron_expr": "0 2 * * *",
        "priority": 5,
        "enabled": True,
    }
    schedule_created = client.post("/api/v1/metric-schedules", json=new_schedule)
    assert schedule_created.status_code == 200, schedule_created.text
    schedule_id = schedule_created.json()["id"]

    # --- schedule: update ---
    schedule_patch = {"enabled": False, "retry_times": 1}
    schedule_updated = client.post(
        f"/api/v1/metric-schedules/{schedule_id}",
        json=schedule_patch,
    )
    assert schedule_updated.status_code == 200
    assert schedule_updated.json()["enabled"] is False

    # --- schedule: list for the metric ---
    listing = client.get(f"{metric_url}/schedules")
    assert listing.status_code == 200
    assert any(entry["id"] == schedule_id for entry in listing.json())
|
|
|
|
|
|
def test_metric_runs_and_results_mysql(client: TestClient) -> None:
    """Trigger a run for a fresh metric, then write and read back its results.

    Steps: create a metric, trigger a run through
    ``/api/v1/metric-runs/trigger``, list and fetch that run, post two result
    rows tagged with the run id, then query all results and the latest result.
    """
    # Function-scope import: the module-level datetime import only brings in
    # ``datetime`` and ``timedelta``.
    from datetime import timezone

    code = f"gmv_{random.randint(1000,9999)}"
    resp = client.post(
        "/api/v1/metrics",
        json={
            "metric_code": code,
            "metric_name": "GMV",
            "biz_domain": "order",
            "base_sql": "select sum(pay_amount) as gmv from orders",
            "time_grain": "DAY",
            "dim_binding": ["dt"],
            "update_strategy": "FULL",
            "created_by": TEST_USER_ID,
        },
    )
    # Check the creation explicitly so a failure reports the response body
    # instead of a KeyError on "id".
    assert resp.status_code == 200, resp.text
    metric_id = resp.json()["id"]

    # Naive UTC timestamp: same value and ISO format as the deprecated
    # datetime.utcnow(), without the DeprecationWarning on newer Pythons.
    now = datetime.now(timezone.utc).replace(tzinfo=None)
    yesterday = now - timedelta(days=1)

    # Trigger run
    resp = client.post(
        "/api/v1/metric-runs/trigger",
        json={
            "metric_id": metric_id,
            "triggered_by": "API",
            "data_time_from": yesterday.isoformat(),
            "data_time_to": now.isoformat(),
        },
    )
    assert resp.status_code == 200, resp.text
    run = resp.json()
    run_id = run["id"]
    assert run["status"] == "RUNNING"

    # List runs
    resp = client.get("/api/v1/metric-runs", params={"metric_id": metric_id})
    assert resp.status_code == 200, resp.text
    assert any(r["id"] == run_id for r in resp.json())

    # Get run
    resp = client.get(f"/api/v1/metric-runs/{run_id}")
    assert resp.status_code == 200, resp.text

    # Write results
    resp = client.post(
        f"/api/v1/metric-results/{metric_id}",
        json={
            "metric_id": metric_id,
            "results": [
                {"stat_time": yesterday.isoformat(), "metric_value": 123.45, "data_version": run_id},
                {"stat_time": now.isoformat(), "metric_value": 234.56, "data_version": run_id},
            ],
        },
    )
    assert resp.status_code == 200, resp.text
    assert resp.json()["inserted"] == 2

    # Query results
    resp = client.get("/api/v1/metric-results", params={"metric_id": metric_id})
    assert resp.status_code == 200, resp.text
    results = resp.json()
    assert len(results) >= 2

    # Latest result
    resp = client.get("/api/v1/metric-results/latest", params={"metric_id": metric_id})
    assert resp.status_code == 200, resp.text
    latest = resp.json()
    assert float(latest["metric_value"]) in {123.45, 234.56}
|
|
|
|
|
|
if __name__ == "__main__":
    # Direct execution: run pytest on this file only and propagate its exit
    # status to the calling process.
    import pytest as _pytest

    exit_status = _pytest.main([__file__])
    raise SystemExit(exit_status)
|