"""Integration tests for the metric management API against a live MySQL database.

The database URL is taken from ``TEST_DATABASE_URL`` (falling back to
``DEFAULT_MYSQL_URL``). If it cannot be reached, every test in this module is
skipped. Each metric created here is tagged with ``TEST_USER_ID`` so the
module fixture can delete it, and its dependent rows, on teardown.
"""

from __future__ import annotations

import os
import sys
import uuid
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Generator, List

import pytest
from fastapi.testclient import TestClient
from sqlalchemy import text
from sqlalchemy.engine import Engine
from sqlalchemy.exc import SQLAlchemyError

# Ensure project root on path for direct execution.
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

from app import db  # noqa: E402  (must follow the sys.path tweak above)
from app.main import create_app  # noqa: E402

TEST_USER_ID = 98765
# SCHEMA_PATH = Path("file/tableschema/metrics.sql")
DEFAULT_MYSQL_URL = "mysql+pymysql://root:12345678@127.0.0.1:3306/data-ge?charset=utf8mb4"

# Tables whose rows reference metric_def via metric_id. They are purged before
# metric_def itself (same order as before, in case metric_id carries FKs).
_METRIC_CHILD_TABLES = ("metric_result", "metric_job_run", "metric_schedule")

# NOTE: optional schema bootstrap, currently disabled. Re-enable together with
# SCHEMA_PATH above and the call in ``client`` to create the metric tables
# from metrics.sql before the tests run.
# def _run_sql_script(engine, sql_text: str) -> None:
#     """Execute semicolon-terminated SQL statements sequentially."""
#     statements: List[str] = []
#     buffer: List[str] = []
#     for line in sql_text.splitlines():
#         stripped = line.strip()
#         if not stripped or stripped.startswith("--"):
#             continue
#         buffer.append(line)
#         if stripped.endswith(";"):
#             statements.append("\n".join(buffer).rstrip(";"))
#             buffer = []
#     if buffer:
#         statements.append("\n".join(buffer))
#     with engine.begin() as conn:
#         for stmt in statements:
#             conn.execute(text(stmt))
#
#
# def _ensure_metric_schema(engine) -> None:
#     if not SCHEMA_PATH.exists():
#         pytest.skip("metrics.sql schema file not found.")
#     raw_sql = SCHEMA_PATH.read_text(encoding="utf-8")
#     raw_sql = raw_sql.replace("CREATE TABLE metric_def", "CREATE TABLE IF NOT EXISTS metric_def")
#     raw_sql = raw_sql.replace("CREATE TABLE metric_schedule", "CREATE TABLE IF NOT EXISTS metric_schedule")
#     raw_sql = raw_sql.replace("CREATE TABLE metric_job_run", "CREATE TABLE IF NOT EXISTS metric_job_run")
#     raw_sql = raw_sql.replace("CREATE TABLE metric_result", "CREATE TABLE IF NOT EXISTS metric_result")
#     _run_sql_script(engine, raw_sql)


def _utcnow() -> datetime:
    """Return the current UTC time as a naive datetime.

    Drop-in replacement for the deprecated ``datetime.utcnow()``. The tzinfo
    is stripped so ``isoformat()`` yields the same offset-less strings the
    API received before.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None)


def _unique_code(prefix: str) -> str:
    """Return ``prefix`` plus a random hex suffix.

    The suffix is very unlikely to collide with metric codes left behind by
    earlier or concurrent runs.
    """
    return f"{prefix}_{uuid.uuid4().hex[:8]}"


def _purge_test_metrics(engine: Engine) -> None:
    """Delete all metrics created by ``TEST_USER_ID`` together with their dependent rows."""
    params = {"uid": TEST_USER_ID}
    owned_ids = "SELECT id FROM metric_def WHERE created_by=:uid"
    with engine.begin() as conn:
        # Table names come from a fixed module constant, never from input.
        for table in _METRIC_CHILD_TABLES:
            conn.execute(text(f"DELETE FROM {table} WHERE metric_id IN ({owned_ids})"), params)
        conn.execute(text("DELETE FROM metric_def WHERE created_by=:uid"), params)


@pytest.fixture(scope="module")
def client() -> Generator[TestClient, None, None]:
    """Yield a ``TestClient`` for the app, wired to the MySQL test database.

    Points ``DATABASE_URL`` at the test database and clears the cached engine
    so the app picks the new URL up. Skips the module if ``SELECT 1`` fails.
    On teardown it purges the rows created by ``TEST_USER_ID``, clears the
    engine cache again and restores the previous ``DATABASE_URL``.
    """
    mysql_url = os.getenv("TEST_DATABASE_URL", DEFAULT_MYSQL_URL)
    previous_url = os.environ.get("DATABASE_URL")
    os.environ["DATABASE_URL"] = mysql_url
    db.get_engine.cache_clear()
    try:
        engine = db.get_engine()
        try:
            with engine.connect() as conn:
                conn.execute(text("SELECT 1"))
        except SQLAlchemyError:
            pytest.skip(f"Cannot connect to MySQL at {mysql_url}")

        # _ensure_metric_schema(engine)
        app = create_app()
        try:
            with TestClient(app) as test_client:
                yield test_client
        finally:
            # Purge even if app startup/shutdown raised, so test rows never leak.
            _purge_test_metrics(engine)
    finally:
        # Runs on skip, error and success alike: don't leak the test URL or
        # the cached engine into other test modules.
        db.get_engine.cache_clear()
        if previous_url is None:
            os.environ.pop("DATABASE_URL", None)
        else:
            os.environ["DATABASE_URL"] = previous_url


def test_metric_crud_and_schedule_mysql(client: TestClient) -> None:
    """Create, update and fetch a metric, then create, update and list a schedule for it."""
    code = _unique_code("metric")
    create_payload = {
        "metric_code": code,
        "metric_name": "订单数",
        "biz_domain": "order",
        "biz_desc": "订单总数",
        "base_sql": "select count(*) as order_cnt from orders",
        "time_grain": "DAY",
        "dim_binding": ["dt"],
        "update_strategy": "FULL",
        "metric_aliases": ["订单量"],
        "created_by": TEST_USER_ID,
    }
    resp = client.post("/api/v1/metrics", json=create_payload)
    assert resp.status_code == 200, resp.text
    metric = resp.json()
    metric_id = metric["id"]
    assert metric["metric_code"] == code

    # Update metric
    resp = client.post(f"/api/v1/metrics/{metric_id}", json={"metric_name": "订单数-更新", "is_active": False})
    assert resp.status_code == 200, resp.text
    assert resp.json()["is_active"] is False

    # Get metric
    resp = client.get(f"/api/v1/metrics/{metric_id}")
    assert resp.status_code == 200, resp.text
    assert resp.json()["metric_name"] == "订单数-更新"

    # Create schedule
    resp = client.post(
        "/api/v1/metric-schedules",
        json={"metric_id": metric_id, "cron_expr": "0 2 * * *", "priority": 5, "enabled": True},
    )
    assert resp.status_code == 200, resp.text
    schedule = resp.json()
    schedule_id = schedule["id"]

    # Update schedule
    resp = client.post(f"/api/v1/metric-schedules/{schedule_id}", json={"enabled": False, "retry_times": 1})
    assert resp.status_code == 200, resp.text
    assert resp.json()["enabled"] is False

    # List schedules for metric
    resp = client.get(f"/api/v1/metrics/{metric_id}/schedules")
    assert resp.status_code == 200, resp.text
    assert any(s["id"] == schedule_id for s in resp.json())


def test_metric_runs_and_results_mysql(client: TestClient) -> None:
    """Trigger a run for a new metric, write two results and read them back."""
    code = _unique_code("gmv")
    resp = client.post(
        "/api/v1/metrics",
        json={
            "metric_code": code,
            "metric_name": "GMV",
            "biz_domain": "order",
            "base_sql": "select sum(pay_amount) as gmv from orders",
            "time_grain": "DAY",
            "dim_binding": ["dt"],
            "update_strategy": "FULL",
            "created_by": TEST_USER_ID,
        },
    )
    # Check creation explicitly so a failure shows the server error, not a KeyError.
    assert resp.status_code == 200, resp.text
    metric_id = resp.json()["id"]

    # Trigger run
    trigger_time = _utcnow()
    resp = client.post(
        "/api/v1/metric-runs/trigger",
        json={
            "metric_id": metric_id,
            "triggered_by": "API",
            "data_time_from": (trigger_time - timedelta(days=1)).isoformat(),
            "data_time_to": trigger_time.isoformat(),
        },
    )
    assert resp.status_code == 200, resp.text
    run = resp.json()
    run_id = run["id"]
    assert run["status"] == "RUNNING"

    # List runs
    resp = client.get("/api/v1/metric-runs", params={"metric_id": metric_id})
    assert resp.status_code == 200, resp.text
    assert any(r["id"] == run_id for r in resp.json())

    # Get run
    resp = client.get(f"/api/v1/metric-runs/{run_id}")
    assert resp.status_code == 200, resp.text

    # Write results
    now = _utcnow()
    resp = client.post(
        f"/api/v1/metric-results/{metric_id}",
        json={
            "metric_id": metric_id,
            "results": [
                {"stat_time": (now - timedelta(days=1)).isoformat(), "metric_value": 123.45, "data_version": run_id},
                {"stat_time": now.isoformat(), "metric_value": 234.56, "data_version": run_id},
            ],
        },
    )
    assert resp.status_code == 200, resp.text
    assert resp.json()["inserted"] == 2

    # Query results
    resp = client.get("/api/v1/metric-results", params={"metric_id": metric_id})
    assert resp.status_code == 200, resp.text
    results = resp.json()
    assert len(results) >= 2

    # Latest result
    resp = client.get("/api/v1/metric-results/latest", params={"metric_id": metric_id})
    assert resp.status_code == 200, resp.text
    latest = resp.json()
    assert float(latest["metric_value"]) in {123.45, 234.56}


if __name__ == "__main__":
    raise SystemExit(pytest.main([__file__]))