Files
data-ge/test/test_metrics_api_mysql.py
2025-12-09 00:15:22 +08:00

208 lines
7.2 KiB
Python

from __future__ import annotations
import os
import random
from datetime import datetime, timedelta
from pathlib import Path
from typing import Generator, List
import pytest
from fastapi.testclient import TestClient
from sqlalchemy import text
from sqlalchemy.exc import SQLAlchemyError
# Ensure project root on path for direct execution
ROOT = Path(__file__).resolve().parents[1]
import sys
if str(ROOT) not in sys.path:
sys.path.insert(0, str(ROOT))
from app import db
from app.main import create_app
# ``created_by`` value stamped on every metric these tests create; the
# ``client`` fixture's teardown deletes rows whose metric has this creator.
TEST_USER_ID = 98765
# Disabled together with the schema bootstrap helpers commented out below.
#SCHEMA_PATH = Path("file/tableschema/metrics.sql")
# Fallback connection URL used when ``TEST_DATABASE_URL`` is not set.
# NOTE(review): embeds credentials for what looks like a local dev MySQL
# instance - confirm this is never pointed at a shared database.
DEFAULT_MYSQL_URL = "mysql+pymysql://root:12345678@127.0.0.1:3306/data-ge?charset=utf8mb4"
# def _run_sql_script(engine, sql_text: str) -> None:
# """Execute semicolon-terminated SQL statements sequentially."""
# statements: List[str] = []
# buffer: List[str] = []
# for line in sql_text.splitlines():
# stripped = line.strip()
# if not stripped or stripped.startswith("--"):
# continue
# buffer.append(line)
# if stripped.endswith(";"):
# statements.append("\n".join(buffer).rstrip(";"))
# buffer = []
# if buffer:
# statements.append("\n".join(buffer))
# with engine.begin() as conn:
# for stmt in statements:
# conn.execute(text(stmt))
# def _ensure_metric_schema(engine) -> None:
# if not SCHEMA_PATH.exists():
# pytest.skip("metrics.sql schema file not found.")
# raw_sql = SCHEMA_PATH.read_text(encoding="utf-8")
# raw_sql = raw_sql.replace("CREATE TABLE metric_def", "CREATE TABLE IF NOT EXISTS metric_def")
# raw_sql = raw_sql.replace("CREATE TABLE metric_schedule", "CREATE TABLE IF NOT EXISTS metric_schedule")
# raw_sql = raw_sql.replace("CREATE TABLE metric_job_run", "CREATE TABLE IF NOT EXISTS metric_job_run")
# raw_sql = raw_sql.replace("CREATE TABLE metric_result", "CREATE TABLE IF NOT EXISTS metric_result")
# _run_sql_script(engine, raw_sql)
@pytest.fixture(scope="module")
def client() -> Generator[TestClient, None, None]:
    """Yield a ``TestClient`` backed by a live MySQL database for this module.

    The connection URL is read from ``TEST_DATABASE_URL`` (falling back to
    ``DEFAULT_MYSQL_URL``) and exported as ``DATABASE_URL`` before the cached
    engine is rebuilt. If a trivial ``SELECT 1`` fails, the tests using this
    fixture are skipped.

    Teardown always runs, even if the test client fails while shutting down:
    rows belonging to metrics created by ``TEST_USER_ID`` are deleted, the
    engine cache is cleared, and the previous ``DATABASE_URL`` value (or its
    absence) is restored so the override does not leak into other modules.

    Yields:
        A started ``TestClient`` wrapping the application from ``create_app()``.
    """
    mysql_url = os.getenv("TEST_DATABASE_URL", DEFAULT_MYSQL_URL)
    previous_url = os.environ.get("DATABASE_URL")
    os.environ["DATABASE_URL"] = mysql_url
    # Drop any engine cached before the override above.
    # NOTE(review): presumably get_engine() reads DATABASE_URL - confirm in app.db.
    db.get_engine.cache_clear()
    try:
        engine = db.get_engine()
        try:
            with engine.connect() as conn:
                conn.execute(text("SELECT 1"))
        except SQLAlchemyError:
            pytest.skip(f"Cannot connect to MySQL at {mysql_url}")
        #_ensure_metric_schema(engine)
        app = create_app()
        try:
            with TestClient(app) as test_client:
                yield test_client
        finally:
            # Clean up test artifacts. Child tables go first because their
            # sub-selects depend on metric_def rows that the last statement removes.
            with engine.begin() as conn:
                conn.execute(text("DELETE FROM metric_result WHERE metric_id IN (SELECT id FROM metric_def WHERE created_by=:uid)"), {"uid": TEST_USER_ID})
                conn.execute(text("DELETE FROM metric_job_run WHERE metric_id IN (SELECT id FROM metric_def WHERE created_by=:uid)"), {"uid": TEST_USER_ID})
                conn.execute(text("DELETE FROM metric_schedule WHERE metric_id IN (SELECT id FROM metric_def WHERE created_by=:uid)"), {"uid": TEST_USER_ID})
                conn.execute(text("DELETE FROM metric_def WHERE created_by=:uid"), {"uid": TEST_USER_ID})
    finally:
        # Runs on the skip path as well, so neither the cached engine nor the
        # environment override outlives this module.
        db.get_engine.cache_clear()
        if previous_url is None:
            os.environ.pop("DATABASE_URL", None)
        else:
            os.environ["DATABASE_URL"] = previous_url
def test_metric_crud_and_schedule_mysql(client: TestClient) -> None:
    """Walk a metric through create, update and get, then do the same for a schedule.

    Every step goes through the HTTP API exposed by ``client``; the metric is
    created with ``TEST_USER_ID`` as its creator.
    """
    metric_code = f"metric_{random.randint(1000,9999)}"

    # --- metric: create -----------------------------------------------------
    created = client.post(
        "/api/v1/metrics",
        json={
            "metric_code": metric_code,
            "metric_name": "订单数",
            "biz_domain": "order",
            "biz_desc": "订单总数",
            "base_sql": "select count(*) as order_cnt from orders",
            "time_grain": "DAY",
            "dim_binding": ["dt"],
            "update_strategy": "FULL",
            "metric_aliases": ["订单量"],
            "created_by": TEST_USER_ID,
        },
    )
    assert created.status_code == 200, created.text
    created_body = created.json()
    metric_id = created_body["id"]
    assert created_body["metric_code"] == metric_code

    metric_url = f"/api/v1/metrics/{metric_id}"

    # --- metric: update -----------------------------------------------------
    updated = client.post(
        metric_url,
        json={"metric_name": "订单数-更新", "is_active": False},
    )
    assert updated.status_code == 200
    assert updated.json()["is_active"] is False

    # --- metric: get --------------------------------------------------------
    fetched = client.get(metric_url)
    assert fetched.status_code == 200
    assert fetched.json()["metric_name"] == "订单数-更新"

    # --- schedule: create ---------------------------------------------------
    schedule_payload = {
        "metric_id": metric_id,
        "cron_expr": "0 2 * * *",
        "priority": 5,
        "enabled": True,
    }
    schedule_created = client.post("/api/v1/metric-schedules", json=schedule_payload)
    assert schedule_created.status_code == 200, schedule_created.text
    schedule_id = schedule_created.json()["id"]

    # --- schedule: update ---------------------------------------------------
    schedule_updated = client.post(
        f"/api/v1/metric-schedules/{schedule_id}",
        json={"enabled": False, "retry_times": 1},
    )
    assert schedule_updated.status_code == 200
    assert schedule_updated.json()["enabled"] is False

    # --- schedule: list for this metric --------------------------------------
    listing = client.get(f"/api/v1/metrics/{metric_id}/schedules")
    assert listing.status_code == 200
    assert any(entry["id"] == schedule_id for entry in listing.json())
def test_metric_runs_and_results_mysql(client: TestClient) -> None:
    """Create a metric, trigger a run, write two results and read them back.

    Every step goes through the HTTP API exposed by ``client``; the metric is
    created with ``TEST_USER_ID`` as its creator.
    """
    code = f"gmv_{random.randint(1000,9999)}"
    resp = client.post(
        "/api/v1/metrics",
        json={
            "metric_code": code,
            "metric_name": "GMV",
            "biz_domain": "order",
            "base_sql": "select sum(pay_amount) as gmv from orders",
            "time_grain": "DAY",
            "dim_binding": ["dt"],
            "update_strategy": "FULL",
            "created_by": TEST_USER_ID,
        },
    )
    # Check the creation response before indexing into it, so a failed create
    # reports the server's message instead of a bare KeyError on "id".
    assert resp.status_code == 200, resp.text
    metric_id = resp.json()["id"]

    # Trigger run
    # NOTE(review): datetime.utcnow() is deprecated since Python 3.12. It is kept
    # because the API is presumably fed naive UTC ISO strings - confirm before
    # switching to timezone-aware values, which would change the payload.
    resp = client.post(
        "/api/v1/metric-runs/trigger",
        json={
            "metric_id": metric_id,
            "triggered_by": "API",
            "data_time_from": (datetime.utcnow() - timedelta(days=1)).isoformat(),
            "data_time_to": datetime.utcnow().isoformat(),
        },
    )
    assert resp.status_code == 200, resp.text
    run = resp.json()
    run_id = run["id"]
    assert run["status"] == "RUNNING"

    # List runs
    resp = client.get("/api/v1/metric-runs", params={"metric_id": metric_id})
    assert resp.status_code == 200
    assert any(r["id"] == run_id for r in resp.json())

    # Get run
    resp = client.get(f"/api/v1/metric-runs/{run_id}")
    assert resp.status_code == 200

    # Write results: two points one day apart, both tagged with the run id.
    now = datetime.utcnow()
    resp = client.post(
        f"/api/v1/metric-results/{metric_id}",
        json={
            "metric_id": metric_id,
            "results": [
                {"stat_time": (now - timedelta(days=1)).isoformat(), "metric_value": 123.45, "data_version": run_id},
                {"stat_time": now.isoformat(), "metric_value": 234.56, "data_version": run_id},
            ],
        },
    )
    assert resp.status_code == 200, resp.text
    assert resp.json()["inserted"] == 2

    # Query results
    resp = client.get("/api/v1/metric-results", params={"metric_id": metric_id})
    assert resp.status_code == 200
    results = resp.json()
    assert len(results) >= 2

    # Latest result: accept either written value, since how the endpoint
    # picks "latest" is not visible from this test.
    resp = client.get("/api/v1/metric-results/latest", params={"metric_id": metric_id})
    assert resp.status_code == 200
    latest = resp.json()
    assert float(latest["metric_value"]) in {123.45, 234.56}
if __name__ == "__main__":
    # Allow `python <this file>` to run only this module's tests and
    # propagate pytest's exit status to the shell.
    import pytest as _pytest

    exit_status = _pytest.main([__file__])
    raise SystemExit(exit_status)