diff --git a/.env b/.env index 97c6958..995ee87 100644 --- a/.env +++ b/.env @@ -16,6 +16,9 @@ DEFAULT_IMPORT_MODEL=deepseek:deepseek-chat # Service configuration IMPORT_GATEWAY_BASE_URL=http://localhost:8000 +# prod nbackend base url +NBACKEND_BASE_URL=https://chatbi.agentcarrier.cn/chatbi/api + # HTTP client configuration HTTP_CLIENT_TIMEOUT=120 HTTP_CLIENT_TRUST_ENV=false diff --git a/app/main.py b/app/main.py index 7b7d55b..1eee626 100644 --- a/app/main.py +++ b/app/main.py @@ -16,6 +16,8 @@ from fastapi.responses import JSONResponse from app.exceptions import ProviderAPICallError, ProviderConfigurationError from app.models import ( + ActionStatus, + ActionType, DataImportAnalysisJobAck, DataImportAnalysisJobRequest, LLMRequest, @@ -25,10 +27,11 @@ from app.models import ( TableSnippetUpsertRequest, TableSnippetUpsertResponse, ) +from app.routers import chat_router, metrics_router from app.services import LLMGateway from app.services.import_analysis import process_import_analysis_job from app.services.table_profiling import process_table_profiling_job -from app.services.table_snippet import upsert_action_result +from app.services.table_snippet import ingest_snippet_rag_from_db, upsert_action_result def _ensure_log_directories(config: dict[str, Any]) -> None: @@ -135,6 +138,9 @@ def create_app() -> FastAPI: version="0.1.0", lifespan=lifespan, ) + # Chat/metric management APIs + application.include_router(chat_router) + application.include_router(metrics_router) @application.exception_handler(RequestValidationError) async def request_validation_exception_handler( @@ -230,11 +236,12 @@ def create_app() -> FastAPI: ) async def upsert_table_snippet( payload: TableSnippetUpsertRequest, + client: httpx.AsyncClient = Depends(get_http_client), ) -> TableSnippetUpsertResponse: request_copy = payload.model_copy(deep=True) try: - return await asyncio.to_thread(upsert_action_result, request_copy) + response = await asyncio.to_thread(upsert_action_result, request_copy) except Exception as exc: logger.error( "Failed to upsert table snippet: table_id=%s version_ts=%s action_type=%s", @@ -244,6 +251,29 @@ def create_app() -> FastAPI: exc_info=True, ) raise HTTPException(status_code=500, detail=str(exc)) from exc + else: + if ( + payload.action_type == ActionType.SNIPPET_ALIAS + and payload.status == ActionStatus.SUCCESS + and payload.rag_workspace_id is not None + ): + try: + await ingest_snippet_rag_from_db( + table_id=payload.table_id, + version_ts=payload.version_ts, + workspace_id=payload.rag_workspace_id, + rag_item_type=payload.rag_item_type or "SNIPPET", + client=client, + ) + except Exception: + logger.exception( + "Failed to ingest snippet RAG artifacts", + extra={ + "table_id": payload.table_id, + "version_ts": payload.version_ts, + }, + ) + return response @application.post("/__mock__/import-callback") async def mock_import_callback(payload: dict[str, Any]) -> dict[str, str]: diff --git a/app/models.py b/app/models.py index 4cf2c6d..7405597 100644 --- a/app/models.py +++ b/app/models.py @@ -247,6 +247,16 @@ class TableSnippetUpsertRequest(BaseModel): ge=0, description="Version timestamp aligned with the pipeline (yyyyMMddHHmmss as integer).", ) + rag_workspace_id: Optional[int] = Field( + None, + ge=0, + description="Optional workspace identifier for RAG ingestion; when provided and action_type=snippet_alias " + "with status=success, merged snippets will be written to rag_snippet and pushed to RAG.", + ) + rag_item_type: Optional[str] = Field( + "SNIPPET", + description="Optional RAG item 
type used when pushing snippets to RAG. Defaults to 'SNIPPET'.", + ) action_type: ActionType = Field(..., description="Pipeline action type for this record.") status: ActionStatus = Field( ActionStatus.SUCCESS, description="Execution status for the action." diff --git a/app/schemas/rag.py b/app/schemas/rag.py new file mode 100644 index 0000000..6399700 --- /dev/null +++ b/app/schemas/rag.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +from typing import Any, List + +from pydantic import BaseModel, ConfigDict, Field + + +class RagItemPayload(BaseModel): + """Payload for creating or updating a single RAG item.""" + + model_config = ConfigDict(populate_by_name=True, extra="ignore") + + id: int = Field(..., description="Unique identifier for the RAG item.") + workspace_id: int = Field(..., alias="workspaceId", description="Workspace identifier.") + name: str = Field(..., description="Readable name of the item.") + embedding_data: str = Field(..., alias="embeddingData", description="Serialized embedding payload.") + type: str = Field(..., description='Item type, e.g. "METRIC".') + + +class RagDeleteRequest(BaseModel): + """Payload for deleting a single RAG item.""" + + model_config = ConfigDict(populate_by_name=True, extra="ignore") + + id: int = Field(..., description="Identifier of the item to delete.") + type: str = Field(..., description="Item type matching the stored record.") + + +class RagRetrieveRequest(BaseModel): + """Payload for retrieving RAG items by semantic query.""" + + model_config = ConfigDict(populate_by_name=True, extra="ignore") + + query: str = Field(..., description="Search query text.") + num: int = Field(..., description="Number of items to return.") + workspace_id: int = Field(..., alias="workspaceId", description="Workspace scope for the search.") + type: str = Field(..., description="Item type to search, e.g. 
METRIC.") + + +class RagRetrieveResponse(BaseModel): + """Generic RAG retrieval response wrapper.""" + + model_config = ConfigDict(extra="allow") + + data: List[Any] = Field(default_factory=list, description="Retrieved items.") + diff --git a/app/services/rag_client.py b/app/services/rag_client.py new file mode 100644 index 0000000..056dcf2 --- /dev/null +++ b/app/services/rag_client.py @@ -0,0 +1,83 @@ +from __future__ import annotations + +import logging +from typing import Any, Sequence + +import httpx + +from app.exceptions import ProviderAPICallError +from app.schemas.rag import RagDeleteRequest, RagItemPayload, RagRetrieveRequest +from app.settings import RAG_API_AUTH_TOKEN, RAG_API_BASE_URL + + +logger = logging.getLogger(__name__) + + +class RagAPIClient: + """Thin async client wrapper around the RAG endpoints described in doc/rag-api.md.""" + + def __init__(self, *, base_url: str | None = None, auth_token: str | None = None) -> None: + resolved_base = base_url or RAG_API_BASE_URL + self._base_url = resolved_base.rstrip("/") + self._auth_token = auth_token or RAG_API_AUTH_TOKEN + + def _headers(self) -> dict[str, str]: + headers = {"Content-Type": "application/json"} + if self._auth_token: + headers["Authorization"] = f"Bearer {self._auth_token}" + return headers + + async def _post( + self, + client: httpx.AsyncClient, + path: str, + payload: Any, + ) -> Any: + url = f"{self._base_url}{path}" + try: + response = await client.post(url, json=payload, headers=self._headers()) + response.raise_for_status() + except httpx.HTTPStatusError as exc: + status_code = exc.response.status_code if exc.response else None + response_text = exc.response.text if exc.response else "" + logger.error( + "RAG API responded with an error (%s) for %s: %s", + status_code, + url, + response_text, + exc_info=True, + ) + raise ProviderAPICallError( + "RAG API call failed.", + status_code=status_code, + response_text=response_text, + ) from exc + except httpx.HTTPError as exc: + logger.error("Transport error calling RAG API %s: %s", url, exc, exc_info=True) + raise ProviderAPICallError(f"RAG API call failed: {exc}") from exc + + try: + return response.json() + except ValueError: + logger.warning("RAG API returned non-JSON response for %s; returning raw text.", url) + return {"raw": response.text} + + async def add(self, client: httpx.AsyncClient, payload: RagItemPayload) -> Any: + body = payload.model_dump(by_alias=True, exclude_none=True) + return await self._post(client, "/rag/add", body) + + async def add_batch(self, client: httpx.AsyncClient, items: Sequence[RagItemPayload]) -> Any: + body = [item.model_dump(by_alias=True, exclude_none=True) for item in items] + return await self._post(client, "/rag/addBatch", body) + + async def update(self, client: httpx.AsyncClient, payload: RagItemPayload) -> Any: + body = payload.model_dump(by_alias=True, exclude_none=True) + return await self._post(client, "/rag/update", body) + + async def delete(self, client: httpx.AsyncClient, payload: RagDeleteRequest) -> Any: + body = payload.model_dump(by_alias=True, exclude_none=True) + return await self._post(client, "/rag/delete", body) + + async def retrieve(self, client: httpx.AsyncClient, payload: RagRetrieveRequest) -> Any: + body = payload.model_dump(by_alias=True, exclude_none=True) + return await self._post(client, "/rag/retrieve", body) diff --git a/app/services/table_snippet.py b/app/services/table_snippet.py index a97436d..345d296 100644 --- a/app/services/table_snippet.py +++ b/app/services/table_snippet.py @@ 
-1,19 +1,19 @@ from __future__ import annotations +import hashlib import json import logging -from typing import Any, Dict, Tuple +from datetime import datetime +from typing import Any, Dict, List, Optional, Sequence, Tuple from sqlalchemy import text from sqlalchemy.engine import Engine from sqlalchemy.exc import SQLAlchemyError from app.db import get_engine -from app.models import ( - ActionType, - TableSnippetUpsertRequest, - TableSnippetUpsertResponse, -) +from app.models import ActionType, TableSnippetUpsertRequest, TableSnippetUpsertResponse +from app.schemas.rag import RagItemPayload +from app.services.rag_client import RagAPIClient logger = logging.getLogger(__name__) @@ -46,6 +46,7 @@ def _prepare_model_params(params: Dict[str, Any] | None) -> str | None: def _collect_common_columns(request: TableSnippetUpsertRequest) -> Dict[str, Any]: + # Build the base column set shared by all action types; action-specific fields are populated later. logger.debug( "Collecting common columns for table_id=%s version_ts=%s action_type=%s", request.table_id, @@ -215,3 +216,405 @@ def upsert_action_result(request: TableSnippetUpsertRequest) -> TableSnippetUpse status=request.status, updated=updated, ) + + +def _decode_json_field(value: Any) -> Any: + """Decode JSON columns that may be returned as str/bytes/dicts/lists.""" + if value is None: + return None + if isinstance(value, (dict, list)): + return value + if isinstance(value, (bytes, bytearray)): + try: + value = value.decode("utf-8") + except Exception: # pragma: no cover - defensive + return None + if isinstance(value, str): + try: + return json.loads(value) + except json.JSONDecodeError: + logger.warning("Failed to decode JSON field: %s", value) + return None + return None + + +def _coerce_json_array(value: Any) -> List[Any]: + decoded = _decode_json_field(value) + return decoded if isinstance(decoded, list) else [] + + +def _fetch_action_payload( + engine: Engine, table_id: int, version_ts: int, action_type: ActionType +) -> Optional[Dict[str, Any]]: + sql = text( + """ + SELECT id AS action_result_id, snippet_json, snippet_alias_json, updated_at, status + FROM action_results + WHERE table_id = :table_id + AND version_ts = :version_ts + AND action_type = :action_type + AND status IN ('success', 'partial') + ORDER BY CASE status WHEN 'success' THEN 0 ELSE 1 END, updated_at DESC + LIMIT 1 + """ + ) + with engine.connect() as conn: + row = conn.execute( + sql, + { + "table_id": table_id, + "version_ts": version_ts, + "action_type": action_type.value, + }, + ).mappings().first() + return dict(row) if row else None + + +def _load_snippet_sources( + engine: Engine, table_id: int, version_ts: int +) -> Tuple[List[Any], List[Any], Optional[datetime], Optional[int], Optional[int]]: + alias_row = _fetch_action_payload(engine, table_id, version_ts, ActionType.SNIPPET_ALIAS) + snippet_row = _fetch_action_payload(engine, table_id, version_ts, ActionType.SNIPPET) + + snippet_json = _coerce_json_array(alias_row.get("snippet_json") if alias_row else None) + alias_json = _coerce_json_array(alias_row.get("snippet_alias_json") if alias_row else None) + updated_at: Optional[datetime] = alias_row.get("updated_at") if alias_row else None + alias_action_id: Optional[int] = alias_row.get("action_result_id") if alias_row else None + snippet_action_id: Optional[int] = snippet_row.get("action_result_id") if snippet_row else None + + if not snippet_json and snippet_row: + snippet_json = _coerce_json_array(snippet_row.get("snippet_json")) + if updated_at is None: + 
updated_at = snippet_row.get("updated_at") + if alias_action_id is None: + alias_action_id = snippet_action_id + + if not updated_at and alias_row: + updated_at = alias_row.get("updated_at") + + return snippet_json, alias_json, updated_at, alias_action_id, snippet_action_id + + +def _normalize_aliases(raw_aliases: Any) -> List[Dict[str, Any]]: + aliases: List[Dict[str, Any]] = [] + seen: set[str] = set() + if not raw_aliases: + return aliases + if not isinstance(raw_aliases, list): + return aliases + for item in raw_aliases: + if isinstance(item, dict): + text_val = item.get("text") + if not text_val or text_val in seen: + continue + seen.add(text_val) + aliases.append({"text": text_val, "tone": item.get("tone")}) + elif isinstance(item, str): + if item in seen: + continue + seen.add(item) + aliases.append({"text": item}) + return aliases + + +def _normalize_str_list(values: Any) -> List[str]: + if not values: + return [] + if not isinstance(values, list): + return [] + seen: set[str] = set() + normalised: List[str] = [] + for val in values: + if not isinstance(val, str): + continue + if val in seen: + continue + seen.add(val) + normalised.append(val) + return normalised + + +def _merge_alias_lists(primary: List[Dict[str, Any]], secondary: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + merged: List[Dict[str, Any]] = [] + seen: set[str] = set() + for source in (primary, secondary): + for item in source: + if not isinstance(item, dict): + continue + text_val = item.get("text") + if not text_val or text_val in seen: + continue + seen.add(text_val) + merged.append({"text": text_val, "tone": item.get("tone")}) + return merged + + +def _merge_str_lists(primary: List[str], secondary: List[str]) -> List[str]: + merged: List[str] = [] + seen: set[str] = set() + for source in (primary, secondary): + for item in source: + if item in seen: + continue + seen.add(item) + merged.append(item) + return merged + + +def _build_alias_map(alias_payload: List[Any]) -> Dict[str, Dict[str, Any]]: + alias_map: Dict[str, Dict[str, Any]] = {} + for item in alias_payload: + if not isinstance(item, dict): + continue + alias_id = item.get("id") + if not alias_id: + continue + existing = alias_map.setdefault( + alias_id, + {"aliases": [], "keywords": [], "intent_tags": []}, + ) + existing["aliases"] = _merge_alias_lists( + existing["aliases"], _normalize_aliases(item.get("aliases")) + ) + existing["keywords"] = _merge_str_lists( + existing["keywords"], _normalize_str_list(item.get("keywords")) + ) + existing["intent_tags"] = _merge_str_lists( + existing["intent_tags"], _normalize_str_list(item.get("intent_tags")) + ) + return alias_map + + +def merge_snippet_records_from_db( + table_id: int, + version_ts: int, + *, + engine: Optional[Engine] = None, +) -> List[Dict[str, Any]]: + """ + Load snippet + snippet_alias JSON from action_results after snippet_alias is stored, + then merge into a unified snippet object list ready for downstream RAG. 
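+    Alias ids that appear only in snippet_alias_json are appended as extra records with source="alias_only".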
+ """ + engine = engine or get_engine() + snippets, aliases, updated_at, alias_action_id, snippet_action_id = _load_snippet_sources( + engine, table_id, version_ts + ) + alias_map = _build_alias_map(aliases) + + merged: List[Dict[str, Any]] = [] + seen_ids: set[str] = set() + + for snippet in snippets: + if not isinstance(snippet, dict): + continue + snippet_id = snippet.get("id") + if not snippet_id: + continue + alias_info = alias_map.get(snippet_id) + record = dict(snippet) + record_aliases = _normalize_aliases(record.get("aliases")) + record_keywords = _normalize_str_list(record.get("keywords")) + record_intents = _normalize_str_list(record.get("intent_tags")) + + if alias_info: + record_aliases = _merge_alias_lists(record_aliases, alias_info["aliases"]) + record_keywords = _merge_str_lists(record_keywords, alias_info["keywords"]) + record_intents = _merge_str_lists(record_intents, alias_info["intent_tags"]) + + record["aliases"] = record_aliases + record["keywords"] = record_keywords + record["intent_tags"] = record_intents + record["table_id"] = table_id + record["version_ts"] = version_ts + record["updated_at_from_action"] = updated_at + record["source"] = "snippet" + record["action_result_id"] = alias_action_id or snippet_action_id + merged.append(record) + seen_ids.add(snippet_id) + + for alias_id, alias_info in alias_map.items(): + if alias_id in seen_ids: + continue + merged.append( + { + "id": alias_id, + "aliases": alias_info["aliases"], + "keywords": alias_info["keywords"], + "intent_tags": alias_info["intent_tags"], + "table_id": table_id, + "version_ts": version_ts, + "updated_at_from_action": updated_at, + "source": "alias_only", + "action_result_id": alias_action_id or snippet_action_id, + } + ) + + return merged + + +def _stable_rag_item_id(table_id: int, version_ts: int, snippet_id: str) -> int: + digest = hashlib.md5(f"{table_id}:{version_ts}:{snippet_id}".encode("utf-8")).hexdigest() + return int(digest[:16], 16) % 9_000_000_000_000_000_000 + + +def _build_rag_text(snippet: Dict[str, Any]) -> str: + # Deterministic text concatenation for embedding input. 
+ parts: List[str] = [] + + def _add(label: str, value: Any) -> None: + if value is None: + return + if isinstance(value, list): + value = ", ".join([str(v) for v in value if v]) + elif isinstance(value, dict): + value = json.dumps(value, ensure_ascii=False) + if value: + parts.append(f"{label}: {value}") + + _add("Title", snippet.get("title") or snippet.get("id")) + _add("Description", snippet.get("desc")) + _add("Business", snippet.get("business_caliber")) + _add("Type", snippet.get("type")) + _add("Examples", snippet.get("examples") or []) + _add("Aliases", [a.get("text") for a in snippet.get("aliases") or [] if isinstance(a, dict)]) + _add("Keywords", snippet.get("keywords") or []) + _add("IntentTags", snippet.get("intent_tags") or []) + _add("Applicability", snippet.get("applicability")) + _add("DialectSQL", snippet.get("dialect_sql")) + return "\n".join(parts) + + +def _prepare_rag_payloads( + snippets: List[Dict[str, Any]], + table_id: int, + version_ts: int, + workspace_id: int, + rag_item_type: str = "SNIPPET", +) -> Tuple[List[Dict[str, Any]], List[RagItemPayload]]: + rows: List[Dict[str, Any]] = [] + payloads: List[RagItemPayload] = [] + now = datetime.utcnow() + + for snippet in snippets: + snippet_id = snippet.get("id") + if not snippet_id: + continue + action_result_id = snippet.get("action_result_id") + if action_result_id is None: + logger.warning( + "Skipping snippet without action_result_id for RAG ingestion (table_id=%s version_ts=%s snippet_id=%s)", + table_id, + version_ts, + snippet_id, + ) + continue + rag_item_id = _stable_rag_item_id(table_id, version_ts, snippet_id) + rag_text = _build_rag_text(snippet) + merged_json = json.dumps(snippet, ensure_ascii=False) + updated_at_raw = snippet.get("updated_at_from_action") or now + if isinstance(updated_at_raw, str): + try: + updated_at = datetime.fromisoformat(updated_at_raw) + except ValueError: + updated_at = now + else: + updated_at = updated_at_raw if isinstance(updated_at_raw, datetime) else now + + row = { + "rag_item_id": rag_item_id, + "workspace_id": workspace_id, + "table_id": table_id, + "version_ts": version_ts, + "action_result_id": action_result_id, + "snippet_id": snippet_id, + "rag_text": rag_text, + "merged_json": merged_json, + "updated_at": updated_at, + } + rows.append(row) + + payloads.append( + RagItemPayload( + id=rag_item_id, + workspaceId=workspace_id, + name=snippet.get("title") or snippet_id, + embeddingData=rag_text, + type=rag_item_type or "SNIPPET", + ) + ) + + return rows, payloads + + +def _upsert_rag_snippet_rows(engine: Engine, rows: Sequence[Dict[str, Any]]) -> None: + if not rows: + return + delete_sql = text("DELETE FROM rag_snippet WHERE rag_item_id=:rag_item_id") + insert_sql = text( + """ + INSERT INTO rag_snippet ( + rag_item_id, + workspace_id, + table_id, + version_ts, + action_result_id, + snippet_id, + rag_text, + merged_json, + updated_at + ) VALUES ( + :rag_item_id, + :workspace_id, + :table_id, + :version_ts, + :action_result_id, + :snippet_id, + :rag_text, + :merged_json, + :updated_at + ) + """ + ) + with engine.begin() as conn: + for row in rows: + conn.execute(delete_sql, row) + conn.execute(insert_sql, row) + + +async def ingest_snippet_rag_from_db( + table_id: int, + version_ts: int, + *, + workspace_id: int, + rag_item_type: str = "SNIPPET", + client, + engine: Optional[Engine] = None, + rag_client: Optional[RagAPIClient] = None, +) -> List[int]: + """ + Merge snippet + alias JSON from action_results, persist to rag_snippet, then push to RAG via addBatch. 
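+    Existing rag_snippet rows with the same rag_item_id are deleted before insert, so re-ingesting a table/version does not create duplicate rows.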
+ Returns list of rag_item_id ingested. + """ + engine = engine or get_engine() + snippets = merge_snippet_records_from_db(table_id, version_ts, engine=engine) + if not snippets: + logger.info( + "No snippets available for RAG ingestion (table_id=%s version_ts=%s)", + table_id, + version_ts, + ) + return [] + + rows, payloads = _prepare_rag_payloads( + snippets, + table_id=table_id, + version_ts=version_ts, + workspace_id=workspace_id, + rag_item_type=rag_item_type, + ) + + _upsert_rag_snippet_rows(engine, rows) + + rag_client = rag_client or RagAPIClient() + await rag_client.add_batch(client, payloads) + return [row["rag_item_id"] for row in rows] diff --git a/doc/rag-api.md b/doc/rag-api.md new file mode 100644 index 0000000..5da7984 --- /dev/null +++ b/doc/rag-api.md @@ -0,0 +1,57 @@ +#添加RAG +curl --location --request POST 'http://127.0.0.1:8000/rag/add' \ +--header 'Content-Type: application/json' \ +--header 'Authorization: Bearer ' \ +--data-raw '{ + "id": 0, + "workspaceId": 0, + "name": "string", + "embeddingData": "string", + "type": "METRIC" +}' + +#批量添加RAG +curl --location --request POST 'http://127.0.0.1:8000/rag/addBatch' \ +--header 'Content-Type: application/json' \ +--header 'Authorization: Bearer ' \ +--data-raw '[ + { + "id": 0, + "workspaceId": 0, + "name": "string", + "embeddingData": "string", + "type": "METRIC" + } +]' + +#更新RAG +curl --location --request POST 'http://127.0.0.1:8000/rag/update' \ +--header 'Content-Type: application/json' \ +--header 'Authorization: Bearer ' \ +--data-raw '{ + "id": 0, + "workspaceId": 0, + "name": "string", + "embeddingData": "string", + "type": "METRIC" +}' + +#删除RAG +curl --location --request POST 'http://127.0.0.1:8000/rag/delete' \ +--header 'Content-Type: application/json' \ +--header 'Authorization: Bearer ' \ +--data-raw '{ + "id": 0, + "type": "METRIC" +}' + +#检索RAG +curl --location --request POST 'http://127.0.0.1:8000/rag/retrieve' \ +--header 'Content-Type: application/json' \ +--header 'Authorization: Bearer ' \ +--data-raw '{ + "query": "string", + "num": 0, + "workspaceId": 0, + "type": "METRIC" +}' \ No newline at end of file diff --git a/file/tableschema/rag_snippet.sql b/file/tableschema/rag_snippet.sql new file mode 100644 index 0000000..e6f3c5a --- /dev/null +++ b/file/tableschema/rag_snippet.sql @@ -0,0 +1,15 @@ +CREATE TABLE `rag_snippet` ( + `rag_item_id` bigint NOT NULL COMMENT 'RAG item id (stable hash of table/version/snippet_id)', + `workspace_id` bigint NOT NULL COMMENT 'RAG workspace scope', + `table_id` bigint NOT NULL COMMENT '来源表ID', + `version_ts` bigint NOT NULL COMMENT '表版本号', + `action_result_id` bigint NOT NULL COMMENT '来源 action_results 主键ID(snippet_alias 或 snippet 行)', + `snippet_id` varchar(255) COLLATE utf8mb4_bin NOT NULL COMMENT '原始 snippet id', + `rag_text` text COLLATE utf8mb4_bin NOT NULL COMMENT '用于向量化的拼接文本', + `merged_json` json NOT NULL COMMENT '合并后的 snippet 对象', + `updated_at` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`rag_item_id`), + KEY `idx_action_result` (`action_result_id`), + KEY `idx_workspace` (`workspace_id`), + KEY `idx_table_version` (`table_id`,`version_ts`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin COMMENT='RAG snippet 索引缓存'; diff --git a/test/test_metrics_api_mysql.py b/test/test_metrics_api_mysql.py new file mode 100644 index 0000000..520a4cf --- /dev/null +++ b/test/test_metrics_api_mysql.py @@ -0,0 +1,207 @@ +from __future__ import annotations + +import os +import random +from datetime import datetime, 
timedelta +from pathlib import Path +from typing import Generator, List + +import pytest +from fastapi.testclient import TestClient +from sqlalchemy import text +from sqlalchemy.exc import SQLAlchemyError + +# Ensure project root on path for direct execution +ROOT = Path(__file__).resolve().parents[1] +import sys +if str(ROOT) not in sys.path: + sys.path.insert(0, str(ROOT)) + +from app import db +from app.main import create_app + + +TEST_USER_ID = 98765 +#SCHEMA_PATH = Path("file/tableschema/metrics.sql") +DEFAULT_MYSQL_URL = "mysql+pymysql://root:12345678@127.0.0.1:3306/data-ge?charset=utf8mb4" + + +# def _run_sql_script(engine, sql_text: str) -> None: +# """Execute semicolon-terminated SQL statements sequentially.""" +# statements: List[str] = [] +# buffer: List[str] = [] +# for line in sql_text.splitlines(): +# stripped = line.strip() +# if not stripped or stripped.startswith("--"): +# continue +# buffer.append(line) +# if stripped.endswith(";"): +# statements.append("\n".join(buffer).rstrip(";")) +# buffer = [] +# if buffer: +# statements.append("\n".join(buffer)) +# with engine.begin() as conn: +# for stmt in statements: +# conn.execute(text(stmt)) + + +# def _ensure_metric_schema(engine) -> None: +# if not SCHEMA_PATH.exists(): +# pytest.skip("metrics.sql schema file not found.") +# raw_sql = SCHEMA_PATH.read_text(encoding="utf-8") +# raw_sql = raw_sql.replace("CREATE TABLE metric_def", "CREATE TABLE IF NOT EXISTS metric_def") +# raw_sql = raw_sql.replace("CREATE TABLE metric_schedule", "CREATE TABLE IF NOT EXISTS metric_schedule") +# raw_sql = raw_sql.replace("CREATE TABLE metric_job_run", "CREATE TABLE IF NOT EXISTS metric_job_run") +# raw_sql = raw_sql.replace("CREATE TABLE metric_result", "CREATE TABLE IF NOT EXISTS metric_result") +# _run_sql_script(engine, raw_sql) + + +@pytest.fixture(scope="module") +def client() -> Generator[TestClient, None, None]: + mysql_url = os.getenv("TEST_DATABASE_URL", DEFAULT_MYSQL_URL) + os.environ["DATABASE_URL"] = mysql_url + db.get_engine.cache_clear() + engine = db.get_engine() + try: + with engine.connect() as conn: + conn.execute(text("SELECT 1")) + except SQLAlchemyError: + pytest.skip(f"Cannot connect to MySQL at {mysql_url}") + + #_ensure_metric_schema(engine) + + app = create_app() + with TestClient(app) as test_client: + yield test_client + + # cleanup test artifacts + with engine.begin() as conn: + conn.execute(text("DELETE FROM metric_result WHERE metric_id IN (SELECT id FROM metric_def WHERE created_by=:uid)"), {"uid": TEST_USER_ID}) + conn.execute(text("DELETE FROM metric_job_run WHERE metric_id IN (SELECT id FROM metric_def WHERE created_by=:uid)"), {"uid": TEST_USER_ID}) + conn.execute(text("DELETE FROM metric_schedule WHERE metric_id IN (SELECT id FROM metric_def WHERE created_by=:uid)"), {"uid": TEST_USER_ID}) + conn.execute(text("DELETE FROM metric_def WHERE created_by=:uid"), {"uid": TEST_USER_ID}) + db.get_engine.cache_clear() + + +def test_metric_crud_and_schedule_mysql(client: TestClient) -> None: + code = f"metric_{random.randint(1000,9999)}" + create_payload = { + "metric_code": code, + "metric_name": "订单数", + "biz_domain": "order", + "biz_desc": "订单总数", + "base_sql": "select count(*) as order_cnt from orders", + "time_grain": "DAY", + "dim_binding": ["dt"], + "update_strategy": "FULL", + "metric_aliases": ["订单量"], + "created_by": TEST_USER_ID, + } + resp = client.post("/api/v1/metrics", json=create_payload) + assert resp.status_code == 200, resp.text + metric = resp.json() + metric_id = metric["id"] + assert 
metric["metric_code"] == code + + # Update metric + resp = client.post(f"/api/v1/metrics/{metric_id}", json={"metric_name": "订单数-更新", "is_active": False}) + assert resp.status_code == 200 + assert resp.json()["is_active"] is False + + # Get metric + resp = client.get(f"/api/v1/metrics/{metric_id}") + assert resp.status_code == 200 + assert resp.json()["metric_name"] == "订单数-更新" + + # Create schedule + resp = client.post( + "/api/v1/metric-schedules", + json={"metric_id": metric_id, "cron_expr": "0 2 * * *", "priority": 5, "enabled": True}, + ) + assert resp.status_code == 200, resp.text + schedule = resp.json() + schedule_id = schedule["id"] + + # Update schedule + resp = client.post(f"/api/v1/metric-schedules/{schedule_id}", json={"enabled": False, "retry_times": 1}) + assert resp.status_code == 200 + assert resp.json()["enabled"] is False + + # List schedules for metric + resp = client.get(f"/api/v1/metrics/{metric_id}/schedules") + assert resp.status_code == 200 + assert any(s["id"] == schedule_id for s in resp.json()) + + +def test_metric_runs_and_results_mysql(client: TestClient) -> None: + code = f"gmv_{random.randint(1000,9999)}" + metric_id = client.post( + "/api/v1/metrics", + json={ + "metric_code": code, + "metric_name": "GMV", + "biz_domain": "order", + "base_sql": "select sum(pay_amount) as gmv from orders", + "time_grain": "DAY", + "dim_binding": ["dt"], + "update_strategy": "FULL", + "created_by": TEST_USER_ID, + }, + ).json()["id"] + + # Trigger run + resp = client.post( + "/api/v1/metric-runs/trigger", + json={ + "metric_id": metric_id, + "triggered_by": "API", + "data_time_from": (datetime.utcnow() - timedelta(days=1)).isoformat(), + "data_time_to": datetime.utcnow().isoformat(), + }, + ) + assert resp.status_code == 200, resp.text + run = resp.json() + run_id = run["id"] + assert run["status"] == "RUNNING" + + # List runs + resp = client.get("/api/v1/metric-runs", params={"metric_id": metric_id}) + assert resp.status_code == 200 + assert any(r["id"] == run_id for r in resp.json()) + + # Get run + resp = client.get(f"/api/v1/metric-runs/{run_id}") + assert resp.status_code == 200 + + # Write results + now = datetime.utcnow() + resp = client.post( + f"/api/v1/metric-results/{metric_id}", + json={ + "metric_id": metric_id, + "results": [ + {"stat_time": (now - timedelta(days=1)).isoformat(), "metric_value": 123.45, "data_version": run_id}, + {"stat_time": now.isoformat(), "metric_value": 234.56, "data_version": run_id}, + ], + }, + ) + assert resp.status_code == 200, resp.text + assert resp.json()["inserted"] == 2 + + # Query results + resp = client.get("/api/v1/metric-results", params={"metric_id": metric_id}) + assert resp.status_code == 200 + results = resp.json() + assert len(results) >= 2 + + # Latest result + resp = client.get("/api/v1/metric-results/latest", params={"metric_id": metric_id}) + assert resp.status_code == 200 + latest = resp.json() + assert float(latest["metric_value"]) in {123.45, 234.56} + + +if __name__ == "__main__": + import pytest as _pytest + + raise SystemExit(_pytest.main([__file__])) diff --git a/test/test_snippet_rag_ingest.py b/test/test_snippet_rag_ingest.py new file mode 100644 index 0000000..1668ead --- /dev/null +++ b/test/test_snippet_rag_ingest.py @@ -0,0 +1,156 @@ +from __future__ import annotations + +import json +from datetime import datetime + +import httpx +import pytest +from sqlalchemy import create_engine, text + +from app.services.table_snippet import ingest_snippet_rag_from_db + + +def _setup_sqlite_engine(): + engine = 
create_engine("sqlite://") + with engine.begin() as conn: + conn.execute( + text( + """ + CREATE TABLE action_results ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + table_id INTEGER, + version_ts INTEGER, + action_type TEXT, + status TEXT, + snippet_json TEXT, + snippet_alias_json TEXT, + updated_at TEXT + ) + """ + ) + ) + conn.execute( + text( + """ + CREATE TABLE rag_snippet ( + rag_item_id INTEGER PRIMARY KEY, + action_result_id INTEGER NOT NULL, + workspace_id INTEGER, + table_id INTEGER, + version_ts INTEGER, + snippet_id TEXT, + rag_text TEXT, + merged_json TEXT, + updated_at TEXT + ) + """ + ) + ) + return engine + + +def _insert_action_row(engine, payload: dict) -> None: + with engine.begin() as conn: + conn.execute( + text( + """ + INSERT INTO action_results (table_id, version_ts, action_type, status, snippet_json, snippet_alias_json, updated_at) + VALUES (:table_id, :version_ts, :action_type, :status, :snippet_json, :snippet_alias_json, :updated_at) + """ + ), + { + "table_id": payload["table_id"], + "version_ts": payload["version_ts"], + "action_type": payload["action_type"], + "status": payload.get("status", "success"), + "snippet_json": json.dumps(payload.get("snippet_json"), ensure_ascii=False) + if payload.get("snippet_json") is not None + else None, + "snippet_alias_json": json.dumps(payload.get("snippet_alias_json"), ensure_ascii=False) + if payload.get("snippet_alias_json") is not None + else None, + "updated_at": payload.get("updated_at") or datetime.utcnow().isoformat(), + }, + ) + + +class _StubRagClient: + def __init__(self) -> None: + self.received = None + + async def add_batch(self, _client, items): + self.received = items + return {"count": len(items)} + + +@pytest.mark.asyncio +async def test_ingest_snippet_rag_from_db_persists_and_calls_rag_client() -> None: + engine = _setup_sqlite_engine() + table_id = 321 + version_ts = 20240102000000 + + snippet_payload = [ + { + "id": "snpt_topn", + "title": "TopN", + "aliases": [{"text": "站点水表排行前N", "tone": "中性"}], + "keywords": ["TopN", "站点"], + } + ] + alias_payload = [ + { + "id": "snpt_topn", + "aliases": [ + {"text": "站点水表排行前N", "tone": "中性"}, + {"text": "按站点水表TopN", "tone": "专业"}, + ], + "keywords": ["TopN", "排行"], + "intent_tags": ["topn", "aggregate"], + }, + { + "id": "snpt_extra", + "aliases": [{"text": "额外别名"}], + "keywords": ["extra"], + }, + ] + + _insert_action_row( + engine, + { + "table_id": table_id, + "version_ts": version_ts, + "action_type": "snippet_alias", + "snippet_json": snippet_payload, + "snippet_alias_json": alias_payload, + "updated_at": "2024-01-02T00:00:00", + }, + ) + + rag_stub = _StubRagClient() + async with httpx.AsyncClient() as client: + rag_ids = await ingest_snippet_rag_from_db( + table_id=table_id, + version_ts=version_ts, + workspace_id=99, + rag_item_type="SNIPPET", + client=client, + engine=engine, + rag_client=rag_stub, + ) + + assert rag_stub.received is not None + assert len(rag_stub.received) == 2 # includes alias-only row + assert len(rag_ids) == 2 + + with engine.connect() as conn: + rows = list( + conn.execute( + text("SELECT snippet_id, action_result_id, rag_text, merged_json FROM rag_snippet ORDER BY snippet_id") + ) + ) + assert {row[0] for row in rows} == {"snpt_extra", "snpt_topn"} + assert all(row[1] is not None for row in rows) + topn_row = next(row for row in rows if row[0] == "snpt_topn") + assert "TopN" in topn_row[2] + assert "按站点水表TopN" in topn_row[2] + assert "排行" in topn_row[2] diff --git a/test/test_table_snippet_merge.py b/test/test_table_snippet_merge.py new 
file mode 100644 index 0000000..b0c1a3e --- /dev/null +++ b/test/test_table_snippet_merge.py @@ -0,0 +1,213 @@ +from __future__ import annotations + +import json +import os +import random +from datetime import datetime, timedelta +from typing import List +from pathlib import Path + +import sys +import pytest +from sqlalchemy import text +from sqlalchemy.engine import Engine +from sqlalchemy.exc import SQLAlchemyError + +# Ensure the project root is importable when running directly via python. +ROOT = Path(__file__).resolve().parents[1] +if str(ROOT) not in sys.path: + sys.path.insert(0, str(ROOT)) + +from app import db +from app.main import create_app + + +from app.services.table_snippet import merge_snippet_records_from_db + + +DEFAULT_MYSQL_URL = "mysql+pymysql://root:12345678@127.0.0.1:3306/data-ge?charset=utf8mb4" + + +@pytest.fixture() +def mysql_engine() -> Engine: + mysql_url = os.getenv("TEST_DATABASE_URL", DEFAULT_MYSQL_URL) + os.environ["DATABASE_URL"] = mysql_url + db.get_engine.cache_clear() + engine = db.get_engine() + try: + with engine.connect() as conn: + conn.execute(text("SELECT 1")) + exists = conn.execute(text("SHOW TABLES LIKE 'action_results'")).scalar() + if not exists: + pytest.skip("action_results table not found in test database.") + except SQLAlchemyError: + pytest.skip(f"Cannot connect to MySQL at {mysql_url}") + return engine + + +def _insert_action_row( + engine: Engine, + *, + table_id: int, + version_ts: int, + action_type: str, + status: str = "success", + snippet_json: List[dict] | None = None, + snippet_alias_json: List[dict] | None = None, + updated_at: datetime | None = None, +) -> None: + snippet_json_str = json.dumps(snippet_json, ensure_ascii=False) if snippet_json is not None else None + snippet_alias_json_str = ( + json.dumps(snippet_alias_json, ensure_ascii=False) if snippet_alias_json is not None else None + ) + with engine.begin() as conn: + conn.execute( + text( + """ + INSERT INTO action_results ( + table_id, version_ts, action_type, status, + callback_url, table_schema_version_id, table_schema, + snippet_json, snippet_alias_json, updated_at + ) VALUES ( + :table_id, :version_ts, :action_type, :status, + :callback_url, :table_schema_version_id, :table_schema, + :snippet_json, :snippet_alias_json, :updated_at + ) + ON DUPLICATE KEY UPDATE + status=VALUES(status), + snippet_json=VALUES(snippet_json), + snippet_alias_json=VALUES(snippet_alias_json), + updated_at=VALUES(updated_at) + """ + ), + { + "table_id": table_id, + "version_ts": version_ts, + "action_type": action_type, + "status": status, + "callback_url": "http://localhost/test-callback", + "table_schema_version_id": "1", + "table_schema": json.dumps({}, ensure_ascii=False), + "snippet_json": snippet_json_str, + "snippet_alias_json": snippet_alias_json_str, + "updated_at": updated_at or datetime.utcnow(), + }, + ) + + +def _cleanup(engine: Engine, table_id: int, version_ts: int) -> None: + with engine.begin() as conn: + conn.execute( + text("DELETE FROM action_results WHERE table_id=:table_id AND version_ts=:version_ts"), + {"table_id": table_id, "version_ts": version_ts}, + ) + + +def test_merge_prefers_alias_row_and_appends_alias_only_entries(mysql_engine: Engine) -> None: + table_id = 990000000 + random.randint(1, 9999) + version_ts = int(datetime.utcnow().strftime("%Y%m%d%H%M%S")) + alias_updated = datetime(2024, 1, 2, 0, 0, 0) + + snippet_payload = [ + { + "id": "snpt_topn", + "aliases": [{"text": "站点水表排行前N", "tone": "中性"}], + "keywords": ["TopN", "站点"], + } + ] + alias_payload = [ + 
{ + "id": "snpt_topn", + "aliases": [ + {"text": "站点水表排行前N", "tone": "中性"}, + {"text": "按站点水表TopN", "tone": "专业"}, + ], + "keywords": ["TopN", "排行"], + "intent_tags": ["topn", "aggregate"], + }, + { + "id": "snpt_extra", + "aliases": [{"text": "额外别名"}], + "keywords": ["extra"], + }, + ] + + _insert_action_row( + mysql_engine, + table_id=table_id, + version_ts=version_ts, + action_type="snippet_alias", + snippet_json=snippet_payload, + snippet_alias_json=alias_payload, + updated_at=alias_updated, + ) + + try: + merged = merge_snippet_records_from_db(table_id, version_ts, engine=mysql_engine) + assert len(merged) == 2 + topn = next(item for item in merged if item["id"] == "snpt_topn") + assert topn["source"] == "snippet" + assert topn["updated_at_from_action"] == alias_updated + assert {a["text"] for a in topn["aliases"]} == {"站点水表排行前N", "按站点水表TopN"} + assert set(topn["keywords"]) == {"TopN", "站点", "排行"} + assert set(topn["intent_tags"]) == {"topn", "aggregate"} + + alias_only = next(item for item in merged if item["source"] == "alias_only") + assert alias_only["id"] == "snpt_extra" + assert alias_only["aliases"][0]["text"] == "额外别名" + finally: + _cleanup(mysql_engine, table_id, version_ts) + + +def test_merge_falls_back_to_snippet_row_when_alias_row_missing_snippet_json(mysql_engine: Engine) -> None: + table_id = 991000000 + random.randint(1, 9999) + version_ts = int((datetime.utcnow() + timedelta(seconds=1)).strftime("%Y%m%d%H%M%S")) + + alias_updated = datetime(2024, 1, 3, 0, 0, 0) + alias_payload = [ + { + "id": "snpt_quality", + "aliases": [{"text": "质量检查"}], + "keywords": ["quality"], + } + ] + snippet_payload = [ + { + "id": "snpt_quality", + "title": "质量检查", + "keywords": ["data-quality"], + "aliases": [{"text": "质量检查"}], + } + ] + + _insert_action_row( + mysql_engine, + table_id=table_id, + version_ts=version_ts, + action_type="snippet_alias", + snippet_json=None, + snippet_alias_json=alias_payload, + updated_at=alias_updated, + ) + _insert_action_row( + mysql_engine, + table_id=table_id, + version_ts=version_ts, + action_type="snippet", + snippet_json=snippet_payload, + snippet_alias_json=None, + updated_at=datetime(2024, 1, 2, 0, 0, 0), + ) + + try: + merged = merge_snippet_records_from_db(table_id, version_ts, engine=mysql_engine) + + assert len(merged) == 1 + record = merged[0] + assert record["id"] == "snpt_quality" + assert record["source"] == "snippet" + assert record["updated_at_from_action"] == alias_updated + assert set(record["keywords"]) == {"data-quality", "quality"} + assert {a["text"] for a in record["aliases"]} == {"质量检查"} + finally: + _cleanup(mysql_engine, table_id, version_ts)
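
Usage sketch for the new ingestion path: a minimal driver script, assuming DATABASE_URL, RAG_API_BASE_URL and RAG_API_AUTH_TOKEN are configured and a snippet_alias action result has already been stored via upsert_action_result. The table_id, version_ts and workspace_id values below are illustrative, borrowed from the tests.

import asyncio

import httpx

from app.services.table_snippet import ingest_snippet_rag_from_db


async def main() -> None:
    # Merge snippet + alias JSON from action_results, persist rows to rag_snippet,
    # then push the payloads to the RAG service via /rag/addBatch.
    async with httpx.AsyncClient(timeout=30.0) as client:
        rag_item_ids = await ingest_snippet_rag_from_db(
            table_id=321,                 # illustrative values
            version_ts=20240102000000,
            workspace_id=99,
            rag_item_type="SNIPPET",
            client=client,
        )
    # Each id is the stable hash derived from (table_id, version_ts, snippet_id).
    print(f"Ingested {len(rag_item_ids)} RAG items: {rag_item_ids}")


if __name__ == "__main__":
    asyncio.run(main())

When the same flow runs inside the upsert_table_snippet endpoint, the shared application httpx client injected via get_http_client is reused instead of creating a new one.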