Generate RAG snippets, store them in the DB, and write to RAG
diff --git a/.env b/.env
@@ -16,6 +16,9 @@ DEFAULT_IMPORT_MODEL=deepseek:deepseek-chat
 # Service configuration
 IMPORT_GATEWAY_BASE_URL=http://localhost:8000
+
+# prod nbackend base url
+NBACKEND_BASE_URL=https://chatbi.agentcarrier.cn/chatbi/api
 
 # HTTP client configuration
 HTTP_CLIENT_TIMEOUT=120
 HTTP_CLIENT_TRUST_ENV=false
diff --git a/app/main.py b/app/main.py
@@ -16,6 +16,8 @@ from fastapi.responses import JSONResponse
 
 from app.exceptions import ProviderAPICallError, ProviderConfigurationError
 from app.models import (
+    ActionStatus,
+    ActionType,
     DataImportAnalysisJobAck,
     DataImportAnalysisJobRequest,
     LLMRequest,
@@ -25,10 +27,11 @@ from app.models import (
     TableSnippetUpsertRequest,
     TableSnippetUpsertResponse,
 )
+from app.routers import chat_router, metrics_router
 from app.services import LLMGateway
 from app.services.import_analysis import process_import_analysis_job
 from app.services.table_profiling import process_table_profiling_job
-from app.services.table_snippet import upsert_action_result
+from app.services.table_snippet import ingest_snippet_rag_from_db, upsert_action_result
 
 
 def _ensure_log_directories(config: dict[str, Any]) -> None:
@@ -135,6 +138,9 @@ def create_app() -> FastAPI:
         version="0.1.0",
         lifespan=lifespan,
     )
+    # Chat/metric management APIs
+    application.include_router(chat_router)
+    application.include_router(metrics_router)
 
     @application.exception_handler(RequestValidationError)
     async def request_validation_exception_handler(
@@ -230,11 +236,12 @@ def create_app() -> FastAPI:
     )
     async def upsert_table_snippet(
         payload: TableSnippetUpsertRequest,
+        client: httpx.AsyncClient = Depends(get_http_client),
     ) -> TableSnippetUpsertResponse:
         request_copy = payload.model_copy(deep=True)
 
         try:
-            return await asyncio.to_thread(upsert_action_result, request_copy)
+            response = await asyncio.to_thread(upsert_action_result, request_copy)
         except Exception as exc:
             logger.error(
                 "Failed to upsert table snippet: table_id=%s version_ts=%s action_type=%s",
@@ -244,6 +251,29 @@ def create_app() -> FastAPI:
                 exc_info=True,
             )
             raise HTTPException(status_code=500, detail=str(exc)) from exc
+        else:
+            if (
+                payload.action_type == ActionType.SNIPPET_ALIAS
+                and payload.status == ActionStatus.SUCCESS
+                and payload.rag_workspace_id is not None
+            ):
+                try:
+                    await ingest_snippet_rag_from_db(
+                        table_id=payload.table_id,
+                        version_ts=payload.version_ts,
+                        workspace_id=payload.rag_workspace_id,
+                        rag_item_type=payload.rag_item_type or "SNIPPET",
+                        client=client,
+                    )
+                except Exception:
+                    logger.exception(
+                        "Failed to ingest snippet RAG artifacts",
+                        extra={
+                            "table_id": payload.table_id,
+                            "version_ts": payload.version_ts,
+                        },
+                    )
+        return response
 
     @application.post("/__mock__/import-callback")
     async def mock_import_callback(payload: dict[str, Any]) -> dict[str, str]:
diff --git a/app/models.py b/app/models.py
@@ -247,6 +247,16 @@ class TableSnippetUpsertRequest(BaseModel):
         ge=0,
         description="Version timestamp aligned with the pipeline (yyyyMMddHHmmss as integer).",
     )
+    rag_workspace_id: Optional[int] = Field(
+        None,
+        ge=0,
+        description="Optional workspace identifier for RAG ingestion; when provided and action_type=snippet_alias "
+        "with status=success, merged snippets will be written to rag_snippet and pushed to RAG.",
+    )
+    rag_item_type: Optional[str] = Field(
+        "SNIPPET",
+        description="Optional RAG item type used when pushing snippets to RAG. Defaults to 'SNIPPET'.",
+    )
     action_type: ActionType = Field(..., description="Pipeline action type for this record.")
     status: ActionStatus = Field(
         ActionStatus.SUCCESS, description="Execution status for the action."
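Note: a request that exercises the new ingestion path must satisfy all three conditions checked in `upsert_table_snippet`. A minimal illustrative payload (field values here are made up, and the route path depends on where the handler is mounted, which this hunk does not show):

```python
# Hypothetical example payload for the table-snippet upsert endpoint.
payload = {
    "table_id": 321,
    "version_ts": 20240102000000,     # yyyyMMddHHmmss as an integer
    "action_type": "snippet_alias",   # condition 1: snippet_alias action
    "status": "success",              # condition 2: success status
    "rag_workspace_id": 99,           # condition 3: workspace id present
    "rag_item_type": "SNIPPET",       # optional; "SNIPPET" is the default
    # ...plus the other required TableSnippetUpsertRequest fields not shown in this hunk
}
```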
diff --git a/app/schemas/rag.py b/app/schemas/rag.py
new file (46 lines)
@@ -0,0 +1,46 @@
+from __future__ import annotations
+
+from typing import Any, List
+
+from pydantic import BaseModel, ConfigDict, Field
+
+
+class RagItemPayload(BaseModel):
+    """Payload for creating or updating a single RAG item."""
+
+    model_config = ConfigDict(populate_by_name=True, extra="ignore")
+
+    id: int = Field(..., description="Unique identifier for the RAG item.")
+    workspace_id: int = Field(..., alias="workspaceId", description="Workspace identifier.")
+    name: str = Field(..., description="Readable name of the item.")
+    embedding_data: str = Field(..., alias="embeddingData", description="Serialized embedding payload.")
+    type: str = Field(..., description='Item type, e.g. "METRIC".')
+
+
+class RagDeleteRequest(BaseModel):
+    """Payload for deleting a single RAG item."""
+
+    model_config = ConfigDict(populate_by_name=True, extra="ignore")
+
+    id: int = Field(..., description="Identifier of the item to delete.")
+    type: str = Field(..., description="Item type matching the stored record.")
+
+
+class RagRetrieveRequest(BaseModel):
+    """Payload for retrieving RAG items by semantic query."""
+
+    model_config = ConfigDict(populate_by_name=True, extra="ignore")
+
+    query: str = Field(..., description="Search query text.")
+    num: int = Field(..., description="Number of items to return.")
+    workspace_id: int = Field(..., alias="workspaceId", description="Workspace scope for the search.")
+    type: str = Field(..., description="Item type to search, e.g. METRIC.")
+
+
+class RagRetrieveResponse(BaseModel):
+    """Generic RAG retrieval response wrapper."""
+
+    model_config = ConfigDict(extra="allow")
+
+    data: List[Any] = Field(default_factory=list, description="Retrieved items.")
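Note: because `RagItemPayload` sets `populate_by_name=True` and declares camelCase aliases, it can be constructed with either naming style, and `model_dump(by_alias=True)` (as used by the client below) always emits the wire-format keys. A small sketch:

```python
from app.schemas.rag import RagItemPayload

# Alias-style kwargs, matching how _prepare_rag_payloads builds items.
item = RagItemPayload(
    id=1,
    workspaceId=99,
    name="TopN",
    embeddingData="Title: TopN\nKeywords: TopN",
    type="SNIPPET",
)
assert item.workspace_id == 99                            # snake_case attribute access
assert "embeddingData" in item.model_dump(by_alias=True)  # camelCase on the wire
```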
diff --git a/app/services/rag_client.py b/app/services/rag_client.py
new file (83 lines)
@@ -0,0 +1,83 @@
+from __future__ import annotations
+
+import logging
+from typing import Any, Sequence
+
+import httpx
+
+from app.exceptions import ProviderAPICallError
+from app.schemas.rag import RagDeleteRequest, RagItemPayload, RagRetrieveRequest
+from app.settings import RAG_API_AUTH_TOKEN, RAG_API_BASE_URL
+
+
+logger = logging.getLogger(__name__)
+
+
+class RagAPIClient:
+    """Thin async client wrapper around the RAG endpoints described in doc/rag-api.md."""
+
+    def __init__(self, *, base_url: str | None = None, auth_token: str | None = None) -> None:
+        resolved_base = base_url or RAG_API_BASE_URL
+        self._base_url = resolved_base.rstrip("/")
+        self._auth_token = auth_token or RAG_API_AUTH_TOKEN
+
+    def _headers(self) -> dict[str, str]:
+        headers = {"Content-Type": "application/json"}
+        if self._auth_token:
+            headers["Authorization"] = f"Bearer {self._auth_token}"
+        return headers
+
+    async def _post(
+        self,
+        client: httpx.AsyncClient,
+        path: str,
+        payload: Any,
+    ) -> Any:
+        url = f"{self._base_url}{path}"
+        try:
+            response = await client.post(url, json=payload, headers=self._headers())
+            response.raise_for_status()
+        except httpx.HTTPStatusError as exc:
+            status_code = exc.response.status_code if exc.response else None
+            response_text = exc.response.text if exc.response else ""
+            logger.error(
+                "RAG API responded with an error (%s) for %s: %s",
+                status_code,
+                url,
+                response_text,
+                exc_info=True,
+            )
+            raise ProviderAPICallError(
+                "RAG API call failed.",
+                status_code=status_code,
+                response_text=response_text,
+            ) from exc
+        except httpx.HTTPError as exc:
+            logger.error("Transport error calling RAG API %s: %s", url, exc, exc_info=True)
+            raise ProviderAPICallError(f"RAG API call failed: {exc}") from exc
+
+        try:
+            return response.json()
+        except ValueError:
+            logger.warning("RAG API returned non-JSON response for %s; returning raw text.", url)
+            return {"raw": response.text}
+
+    async def add(self, client: httpx.AsyncClient, payload: RagItemPayload) -> Any:
+        body = payload.model_dump(by_alias=True, exclude_none=True)
+        return await self._post(client, "/rag/add", body)
+
+    async def add_batch(self, client: httpx.AsyncClient, items: Sequence[RagItemPayload]) -> Any:
+        body = [item.model_dump(by_alias=True, exclude_none=True) for item in items]
+        return await self._post(client, "/rag/addBatch", body)
+
+    async def update(self, client: httpx.AsyncClient, payload: RagItemPayload) -> Any:
+        body = payload.model_dump(by_alias=True, exclude_none=True)
+        return await self._post(client, "/rag/update", body)
+
+    async def delete(self, client: httpx.AsyncClient, payload: RagDeleteRequest) -> Any:
+        body = payload.model_dump(by_alias=True, exclude_none=True)
+        return await self._post(client, "/rag/delete", body)
+
+    async def retrieve(self, client: httpx.AsyncClient, payload: RagRetrieveRequest) -> Any:
+        body = payload.model_dump(by_alias=True, exclude_none=True)
+        return await self._post(client, "/rag/retrieve", body)
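Note: the client deliberately borrows a caller-owned `httpx.AsyncClient` instead of creating its own, so connection pooling is shared with the rest of the app. A minimal usage sketch (the base URL and token here are illustrative; by default they resolve from `RAG_API_BASE_URL` / `RAG_API_AUTH_TOKEN`):

```python
import asyncio

import httpx

from app.schemas.rag import RagItemPayload
from app.services.rag_client import RagAPIClient


async def push_one_batch() -> None:
    rag = RagAPIClient(base_url="http://127.0.0.1:8000", auth_token="example-token")
    async with httpx.AsyncClient() as client:
        # Posts the serialized items to /rag/addBatch and returns the parsed JSON body.
        result = await rag.add_batch(
            client,
            [RagItemPayload(id=1, workspaceId=99, name="TopN", embeddingData="Title: TopN", type="SNIPPET")],
        )
        print(result)


asyncio.run(push_one_batch())
```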
diff --git a/app/services/table_snippet.py b/app/services/table_snippet.py
@@ -1,19 +1,19 @@
 from __future__ import annotations
 
+import hashlib
 import json
 import logging
-from typing import Any, Dict, Tuple
+from datetime import datetime
+from typing import Any, Dict, List, Optional, Sequence, Tuple
 
 from sqlalchemy import text
 from sqlalchemy.engine import Engine
 from sqlalchemy.exc import SQLAlchemyError
 
 from app.db import get_engine
-from app.models import (
-    ActionType,
-    TableSnippetUpsertRequest,
-    TableSnippetUpsertResponse,
-)
+from app.models import ActionType, TableSnippetUpsertRequest, TableSnippetUpsertResponse
+from app.schemas.rag import RagItemPayload
+from app.services.rag_client import RagAPIClient
 
 logger = logging.getLogger(__name__)
@@ -46,6 +46,7 @@ def _prepare_model_params(params: Dict[str, Any] | None) -> str | None:
 
 
 def _collect_common_columns(request: TableSnippetUpsertRequest) -> Dict[str, Any]:
+    # Build the base column set shared by all action types; action-specific fields are populated later.
     logger.debug(
         "Collecting common columns for table_id=%s version_ts=%s action_type=%s",
         request.table_id,
@@ -215,3 +216,405 @@ def upsert_action_result(request: TableSnippetUpsertRequest) -> TableSnippetUpsertResponse:
         status=request.status,
         updated=updated,
     )
+
+
+def _decode_json_field(value: Any) -> Any:
+    """Decode JSON columns that may be returned as str/bytes/dicts/lists."""
+    if value is None:
+        return None
+    if isinstance(value, (dict, list)):
+        return value
+    if isinstance(value, (bytes, bytearray)):
+        try:
+            value = value.decode("utf-8")
+        except Exception:  # pragma: no cover - defensive
+            return None
+    if isinstance(value, str):
+        try:
+            return json.loads(value)
+        except json.JSONDecodeError:
+            logger.warning("Failed to decode JSON field: %s", value)
+            return None
+    return None
+
+
+def _coerce_json_array(value: Any) -> List[Any]:
+    decoded = _decode_json_field(value)
+    return decoded if isinstance(decoded, list) else []
+
+
+def _fetch_action_payload(
+    engine: Engine, table_id: int, version_ts: int, action_type: ActionType
+) -> Optional[Dict[str, Any]]:
+    sql = text(
+        """
+        SELECT id AS action_result_id, snippet_json, snippet_alias_json, updated_at, status
+        FROM action_results
+        WHERE table_id = :table_id
+          AND version_ts = :version_ts
+          AND action_type = :action_type
+          AND status IN ('success', 'partial')
+        ORDER BY CASE status WHEN 'success' THEN 0 ELSE 1 END, updated_at DESC
+        LIMIT 1
+        """
+    )
+    with engine.connect() as conn:
+        row = conn.execute(
+            sql,
+            {
+                "table_id": table_id,
+                "version_ts": version_ts,
+                "action_type": action_type.value,
+            },
+        ).mappings().first()
+    return dict(row) if row else None
+
+
+def _load_snippet_sources(
+    engine: Engine, table_id: int, version_ts: int
+) -> Tuple[List[Any], List[Any], Optional[datetime], Optional[int], Optional[int]]:
+    alias_row = _fetch_action_payload(engine, table_id, version_ts, ActionType.SNIPPET_ALIAS)
+    snippet_row = _fetch_action_payload(engine, table_id, version_ts, ActionType.SNIPPET)
+
+    snippet_json = _coerce_json_array(alias_row.get("snippet_json") if alias_row else None)
+    alias_json = _coerce_json_array(alias_row.get("snippet_alias_json") if alias_row else None)
+    updated_at: Optional[datetime] = alias_row.get("updated_at") if alias_row else None
+    alias_action_id: Optional[int] = alias_row.get("action_result_id") if alias_row else None
+    snippet_action_id: Optional[int] = snippet_row.get("action_result_id") if snippet_row else None
+
+    if not snippet_json and snippet_row:
+        snippet_json = _coerce_json_array(snippet_row.get("snippet_json"))
+        if updated_at is None:
+            updated_at = snippet_row.get("updated_at")
+        if alias_action_id is None:
+            alias_action_id = snippet_action_id
+
+    if not updated_at and alias_row:
+        updated_at = alias_row.get("updated_at")
+
+    return snippet_json, alias_json, updated_at, alias_action_id, snippet_action_id
+
+
+def _normalize_aliases(raw_aliases: Any) -> List[Dict[str, Any]]:
+    aliases: List[Dict[str, Any]] = []
+    seen: set[str] = set()
+    if not raw_aliases:
+        return aliases
+    if not isinstance(raw_aliases, list):
+        return aliases
+    for item in raw_aliases:
+        if isinstance(item, dict):
+            text_val = item.get("text")
+            if not text_val or text_val in seen:
+                continue
+            seen.add(text_val)
+            aliases.append({"text": text_val, "tone": item.get("tone")})
+        elif isinstance(item, str):
+            if item in seen:
+                continue
+            seen.add(item)
+            aliases.append({"text": item})
+    return aliases
+
+
+def _normalize_str_list(values: Any) -> List[str]:
+    if not values:
+        return []
+    if not isinstance(values, list):
+        return []
+    seen: set[str] = set()
+    normalised: List[str] = []
+    for val in values:
+        if not isinstance(val, str):
+            continue
+        if val in seen:
+            continue
+        seen.add(val)
+        normalised.append(val)
+    return normalised
+
+
+def _merge_alias_lists(primary: List[Dict[str, Any]], secondary: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+    merged: List[Dict[str, Any]] = []
+    seen: set[str] = set()
+    for source in (primary, secondary):
+        for item in source:
+            if not isinstance(item, dict):
+                continue
+            text_val = item.get("text")
+            if not text_val or text_val in seen:
+                continue
+            seen.add(text_val)
+            merged.append({"text": text_val, "tone": item.get("tone")})
+    return merged
+
+
+def _merge_str_lists(primary: List[str], secondary: List[str]) -> List[str]:
+    merged: List[str] = []
+    seen: set[str] = set()
+    for source in (primary, secondary):
+        for item in source:
+            if item in seen:
+                continue
+            seen.add(item)
+            merged.append(item)
+    return merged
+
+
+def _build_alias_map(alias_payload: List[Any]) -> Dict[str, Dict[str, Any]]:
+    alias_map: Dict[str, Dict[str, Any]] = {}
+    for item in alias_payload:
+        if not isinstance(item, dict):
+            continue
+        alias_id = item.get("id")
+        if not alias_id:
+            continue
+        existing = alias_map.setdefault(
+            alias_id,
+            {"aliases": [], "keywords": [], "intent_tags": []},
+        )
+        existing["aliases"] = _merge_alias_lists(
+            existing["aliases"], _normalize_aliases(item.get("aliases"))
+        )
+        existing["keywords"] = _merge_str_lists(
+            existing["keywords"], _normalize_str_list(item.get("keywords"))
+        )
+        existing["intent_tags"] = _merge_str_lists(
+            existing["intent_tags"], _normalize_str_list(item.get("intent_tags"))
+        )
+    return alias_map
+
+
+def merge_snippet_records_from_db(
+    table_id: int,
+    version_ts: int,
+    *,
+    engine: Optional[Engine] = None,
+) -> List[Dict[str, Any]]:
+    """
+    Load snippet + snippet_alias JSON from action_results after snippet_alias is stored,
+    then merge into a unified snippet object list ready for downstream RAG.
+    """
+    engine = engine or get_engine()
+    snippets, aliases, updated_at, alias_action_id, snippet_action_id = _load_snippet_sources(
+        engine, table_id, version_ts
+    )
+    alias_map = _build_alias_map(aliases)
+
+    merged: List[Dict[str, Any]] = []
+    seen_ids: set[str] = set()
+
+    for snippet in snippets:
+        if not isinstance(snippet, dict):
+            continue
+        snippet_id = snippet.get("id")
+        if not snippet_id:
+            continue
+        alias_info = alias_map.get(snippet_id)
+        record = dict(snippet)
+        record_aliases = _normalize_aliases(record.get("aliases"))
+        record_keywords = _normalize_str_list(record.get("keywords"))
+        record_intents = _normalize_str_list(record.get("intent_tags"))
+
+        if alias_info:
+            record_aliases = _merge_alias_lists(record_aliases, alias_info["aliases"])
+            record_keywords = _merge_str_lists(record_keywords, alias_info["keywords"])
+            record_intents = _merge_str_lists(record_intents, alias_info["intent_tags"])
+
+        record["aliases"] = record_aliases
+        record["keywords"] = record_keywords
+        record["intent_tags"] = record_intents
+        record["table_id"] = table_id
+        record["version_ts"] = version_ts
+        record["updated_at_from_action"] = updated_at
+        record["source"] = "snippet"
+        record["action_result_id"] = alias_action_id or snippet_action_id
+        merged.append(record)
+        seen_ids.add(snippet_id)
+
+    for alias_id, alias_info in alias_map.items():
+        if alias_id in seen_ids:
+            continue
+        merged.append(
+            {
+                "id": alias_id,
+                "aliases": alias_info["aliases"],
+                "keywords": alias_info["keywords"],
+                "intent_tags": alias_info["intent_tags"],
+                "table_id": table_id,
+                "version_ts": version_ts,
+                "updated_at_from_action": updated_at,
+                "source": "alias_only",
+                "action_result_id": alias_action_id or snippet_action_id,
+            }
+        )
+
+    return merged
+
+
+def _stable_rag_item_id(table_id: int, version_ts: int, snippet_id: str) -> int:
+    digest = hashlib.md5(f"{table_id}:{version_ts}:{snippet_id}".encode("utf-8")).hexdigest()
+    return int(digest[:16], 16) % 9_000_000_000_000_000_000
+
+
+def _build_rag_text(snippet: Dict[str, Any]) -> str:
+    # Deterministic text concatenation for embedding input.
+    parts: List[str] = []
+
+    def _add(label: str, value: Any) -> None:
+        if value is None:
+            return
+        if isinstance(value, list):
+            value = ", ".join([str(v) for v in value if v])
+        elif isinstance(value, dict):
+            value = json.dumps(value, ensure_ascii=False)
+        if value:
+            parts.append(f"{label}: {value}")
+
+    _add("Title", snippet.get("title") or snippet.get("id"))
+    _add("Description", snippet.get("desc"))
+    _add("Business", snippet.get("business_caliber"))
+    _add("Type", snippet.get("type"))
+    _add("Examples", snippet.get("examples") or [])
+    _add("Aliases", [a.get("text") for a in snippet.get("aliases") or [] if isinstance(a, dict)])
+    _add("Keywords", snippet.get("keywords") or [])
+    _add("IntentTags", snippet.get("intent_tags") or [])
+    _add("Applicability", snippet.get("applicability"))
+    _add("DialectSQL", snippet.get("dialect_sql"))
+    return "\n".join(parts)
+
+
+def _prepare_rag_payloads(
+    snippets: List[Dict[str, Any]],
+    table_id: int,
+    version_ts: int,
+    workspace_id: int,
+    rag_item_type: str = "SNIPPET",
+) -> Tuple[List[Dict[str, Any]], List[RagItemPayload]]:
+    rows: List[Dict[str, Any]] = []
+    payloads: List[RagItemPayload] = []
+    now = datetime.utcnow()
+
+    for snippet in snippets:
+        snippet_id = snippet.get("id")
+        if not snippet_id:
+            continue
+        action_result_id = snippet.get("action_result_id")
+        if action_result_id is None:
+            logger.warning(
+                "Skipping snippet without action_result_id for RAG ingestion (table_id=%s version_ts=%s snippet_id=%s)",
+                table_id,
+                version_ts,
+                snippet_id,
+            )
+            continue
+        rag_item_id = _stable_rag_item_id(table_id, version_ts, snippet_id)
+        rag_text = _build_rag_text(snippet)
+        merged_json = json.dumps(snippet, ensure_ascii=False)
+        updated_at_raw = snippet.get("updated_at_from_action") or now
+        if isinstance(updated_at_raw, str):
+            try:
+                updated_at = datetime.fromisoformat(updated_at_raw)
+            except ValueError:
+                updated_at = now
+        else:
+            updated_at = updated_at_raw if isinstance(updated_at_raw, datetime) else now
+
+        row = {
+            "rag_item_id": rag_item_id,
+            "workspace_id": workspace_id,
+            "table_id": table_id,
+            "version_ts": version_ts,
+            "action_result_id": action_result_id,
+            "snippet_id": snippet_id,
+            "rag_text": rag_text,
+            "merged_json": merged_json,
+            "updated_at": updated_at,
+        }
+        rows.append(row)
+
+        payloads.append(
+            RagItemPayload(
+                id=rag_item_id,
+                workspaceId=workspace_id,
+                name=snippet.get("title") or snippet_id,
+                embeddingData=rag_text,
+                type=rag_item_type or "SNIPPET",
+            )
+        )
+
+    return rows, payloads
+
+
+def _upsert_rag_snippet_rows(engine: Engine, rows: Sequence[Dict[str, Any]]) -> None:
+    if not rows:
+        return
+    delete_sql = text("DELETE FROM rag_snippet WHERE rag_item_id=:rag_item_id")
+    insert_sql = text(
+        """
+        INSERT INTO rag_snippet (
+            rag_item_id,
+            workspace_id,
+            table_id,
+            version_ts,
+            action_result_id,
+            snippet_id,
+            rag_text,
+            merged_json,
+            updated_at
+        ) VALUES (
+            :rag_item_id,
+            :workspace_id,
+            :table_id,
+            :version_ts,
+            :action_result_id,
+            :snippet_id,
+            :rag_text,
+            :merged_json,
+            :updated_at
+        )
+        """
+    )
+    with engine.begin() as conn:
+        for row in rows:
+            conn.execute(delete_sql, row)
+            conn.execute(insert_sql, row)
+
+
+async def ingest_snippet_rag_from_db(
+    table_id: int,
+    version_ts: int,
+    *,
+    workspace_id: int,
+    rag_item_type: str = "SNIPPET",
+    client,
+    engine: Optional[Engine] = None,
+    rag_client: Optional[RagAPIClient] = None,
+) -> List[int]:
+    """
+    Merge snippet + alias JSON from action_results, persist to rag_snippet, then push to RAG via addBatch.
+    Returns list of rag_item_id ingested.
+    """
+    engine = engine or get_engine()
+    snippets = merge_snippet_records_from_db(table_id, version_ts, engine=engine)
+    if not snippets:
+        logger.info(
+            "No snippets available for RAG ingestion (table_id=%s version_ts=%s)",
+            table_id,
+            version_ts,
+        )
+        return []
+
+    rows, payloads = _prepare_rag_payloads(
+        snippets,
+        table_id=table_id,
+        version_ts=version_ts,
+        workspace_id=workspace_id,
+        rag_item_type=rag_item_type,
+    )
+
+    _upsert_rag_snippet_rows(engine, rows)
+
+    rag_client = rag_client or RagAPIClient()
+    await rag_client.add_batch(client, payloads)
+    return [row["rag_item_id"] for row in rows]
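Note: `_stable_rag_item_id` is what makes re-ingestion idempotent: the same (table_id, version_ts, snippet_id) triple always hashes to the same `rag_item_id`, so the delete-then-insert in `_upsert_rag_snippet_rows` overwrites the existing row instead of accumulating duplicates. The derivation in isolation:

```python
import hashlib


def stable_rag_item_id(table_id: int, version_ts: int, snippet_id: str) -> int:
    # The first 16 hex digits of the MD5 digest form a 64-bit value; the modulo
    # keeps it below 9e18 so it fits in a signed BIGINT column.
    digest = hashlib.md5(f"{table_id}:{version_ts}:{snippet_id}".encode("utf-8")).hexdigest()
    return int(digest[:16], 16) % 9_000_000_000_000_000_000


a = stable_rag_item_id(321, 20240102000000, "snpt_topn")
b = stable_rag_item_id(321, 20240102000000, "snpt_topn")
assert a == b  # deterministic: re-ingesting a version targets the same rag_snippet row
```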
diff --git a/doc/rag-api.md b/doc/rag-api.md
new file (57 lines)
@@ -0,0 +1,57 @@
+# Add a RAG item
+curl --location --request POST 'http://127.0.0.1:8000/rag/add' \
+--header 'Content-Type: application/json' \
+--header 'Authorization: Bearer ' \
+--data-raw '{
+  "id": 0,
+  "workspaceId": 0,
+  "name": "string",
+  "embeddingData": "string",
+  "type": "METRIC"
+}'
+
+# Add RAG items in batch
+curl --location --request POST 'http://127.0.0.1:8000/rag/addBatch' \
+--header 'Content-Type: application/json' \
+--header 'Authorization: Bearer ' \
+--data-raw '[
+  {
+    "id": 0,
+    "workspaceId": 0,
+    "name": "string",
+    "embeddingData": "string",
+    "type": "METRIC"
+  }
+]'
+
+# Update a RAG item
+curl --location --request POST 'http://127.0.0.1:8000/rag/update' \
+--header 'Content-Type: application/json' \
+--header 'Authorization: Bearer ' \
+--data-raw '{
+  "id": 0,
+  "workspaceId": 0,
+  "name": "string",
+  "embeddingData": "string",
+  "type": "METRIC"
+}'
+
+# Delete a RAG item
+curl --location --request POST 'http://127.0.0.1:8000/rag/delete' \
+--header 'Content-Type: application/json' \
+--header 'Authorization: Bearer ' \
+--data-raw '{
+  "id": 0,
+  "type": "METRIC"
+}'
+
+# Retrieve RAG items
+curl --location --request POST 'http://127.0.0.1:8000/rag/retrieve' \
+--header 'Content-Type: application/json' \
+--header 'Authorization: Bearer ' \
+--data-raw '{
+  "query": "string",
+  "num": 0,
+  "workspaceId": 0,
+  "type": "METRIC"
+}'
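Note: the retrieve endpoint above maps onto `RagAPIClient.retrieve` with a `RagRetrieveRequest`. A sketch, assuming the RAG service settings are configured:

```python
import asyncio

import httpx

from app.schemas.rag import RagRetrieveRequest
from app.services.rag_client import RagAPIClient


async def search_snippets() -> None:
    rag = RagAPIClient()  # resolves base URL and token from settings
    async with httpx.AsyncClient() as client:
        result = await rag.retrieve(
            client,
            RagRetrieveRequest(query="TopN by site", num=5, workspaceId=99, type="SNIPPET"),
        )
        print(result)


asyncio.run(search_snippets())
```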
diff --git a/file/tableschema/rag_snippet.sql b/file/tableschema/rag_snippet.sql
new file (15 lines)
@@ -0,0 +1,15 @@
+CREATE TABLE `rag_snippet` (
+  `rag_item_id` bigint NOT NULL COMMENT 'RAG item id (stable hash of table/version/snippet_id)',
+  `workspace_id` bigint NOT NULL COMMENT 'RAG workspace scope',
+  `table_id` bigint NOT NULL COMMENT 'Source table ID',
+  `version_ts` bigint NOT NULL COMMENT 'Table version timestamp',
+  `action_result_id` bigint NOT NULL COMMENT 'Primary key of the source action_results row (snippet_alias or snippet)',
+  `snippet_id` varchar(255) COLLATE utf8mb4_bin NOT NULL COMMENT 'Original snippet id',
+  `rag_text` text COLLATE utf8mb4_bin NOT NULL COMMENT 'Concatenated text used for embedding',
+  `merged_json` json NOT NULL COMMENT 'Merged snippet object',
+  `updated_at` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+  PRIMARY KEY (`rag_item_id`),
+  KEY `idx_action_result` (`action_result_id`),
+  KEY `idx_workspace` (`workspace_id`),
+  KEY `idx_table_version` (`table_id`,`version_ts`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin COMMENT='RAG snippet index cache';
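Note: a quick way to inspect what was cached for one table version is to query via the `idx_table_version` index. A sketch, assuming `app.db.get_engine` points at the database holding `rag_snippet` and using illustrative id values:

```python
from sqlalchemy import text

from app.db import get_engine

engine = get_engine()
with engine.connect() as conn:
    rows = conn.execute(
        text(
            "SELECT snippet_id, rag_text FROM rag_snippet "
            "WHERE table_id = :table_id AND version_ts = :version_ts"
        ),
        {"table_id": 321, "version_ts": 20240102000000},
    ).all()
for snippet_id, rag_text in rows:
    print(snippet_id, rag_text[:80])
```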
diff --git a/test/test_metrics_api_mysql.py b/test/test_metrics_api_mysql.py
new file (207 lines)
@@ -0,0 +1,207 @@
+from __future__ import annotations
+
+import os
+import random
+from datetime import datetime, timedelta
+from pathlib import Path
+from typing import Generator, List
+
+import pytest
+from fastapi.testclient import TestClient
+from sqlalchemy import text
+from sqlalchemy.exc import SQLAlchemyError
+
+# Ensure project root on path for direct execution
+ROOT = Path(__file__).resolve().parents[1]
+import sys
+if str(ROOT) not in sys.path:
+    sys.path.insert(0, str(ROOT))
+
+from app import db
+from app.main import create_app
+
+
+TEST_USER_ID = 98765
+# SCHEMA_PATH = Path("file/tableschema/metrics.sql")
+DEFAULT_MYSQL_URL = "mysql+pymysql://root:12345678@127.0.0.1:3306/data-ge?charset=utf8mb4"
+
+
+# def _run_sql_script(engine, sql_text: str) -> None:
+#     """Execute semicolon-terminated SQL statements sequentially."""
+#     statements: List[str] = []
+#     buffer: List[str] = []
+#     for line in sql_text.splitlines():
+#         stripped = line.strip()
+#         if not stripped or stripped.startswith("--"):
+#             continue
+#         buffer.append(line)
+#         if stripped.endswith(";"):
+#             statements.append("\n".join(buffer).rstrip(";"))
+#             buffer = []
+#     if buffer:
+#         statements.append("\n".join(buffer))
+#     with engine.begin() as conn:
+#         for stmt in statements:
+#             conn.execute(text(stmt))
+
+
+# def _ensure_metric_schema(engine) -> None:
+#     if not SCHEMA_PATH.exists():
+#         pytest.skip("metrics.sql schema file not found.")
+#     raw_sql = SCHEMA_PATH.read_text(encoding="utf-8")
+#     raw_sql = raw_sql.replace("CREATE TABLE metric_def", "CREATE TABLE IF NOT EXISTS metric_def")
+#     raw_sql = raw_sql.replace("CREATE TABLE metric_schedule", "CREATE TABLE IF NOT EXISTS metric_schedule")
+#     raw_sql = raw_sql.replace("CREATE TABLE metric_job_run", "CREATE TABLE IF NOT EXISTS metric_job_run")
+#     raw_sql = raw_sql.replace("CREATE TABLE metric_result", "CREATE TABLE IF NOT EXISTS metric_result")
+#     _run_sql_script(engine, raw_sql)
+
+
+@pytest.fixture(scope="module")
+def client() -> Generator[TestClient, None, None]:
+    mysql_url = os.getenv("TEST_DATABASE_URL", DEFAULT_MYSQL_URL)
+    os.environ["DATABASE_URL"] = mysql_url
+    db.get_engine.cache_clear()
+    engine = db.get_engine()
+    try:
+        with engine.connect() as conn:
+            conn.execute(text("SELECT 1"))
+    except SQLAlchemyError:
+        pytest.skip(f"Cannot connect to MySQL at {mysql_url}")
+
+    # _ensure_metric_schema(engine)
+
+    app = create_app()
+    with TestClient(app) as test_client:
+        yield test_client
+
+    # cleanup test artifacts
+    with engine.begin() as conn:
+        conn.execute(text("DELETE FROM metric_result WHERE metric_id IN (SELECT id FROM metric_def WHERE created_by=:uid)"), {"uid": TEST_USER_ID})
+        conn.execute(text("DELETE FROM metric_job_run WHERE metric_id IN (SELECT id FROM metric_def WHERE created_by=:uid)"), {"uid": TEST_USER_ID})
+        conn.execute(text("DELETE FROM metric_schedule WHERE metric_id IN (SELECT id FROM metric_def WHERE created_by=:uid)"), {"uid": TEST_USER_ID})
+        conn.execute(text("DELETE FROM metric_def WHERE created_by=:uid"), {"uid": TEST_USER_ID})
+    db.get_engine.cache_clear()
+
+
+def test_metric_crud_and_schedule_mysql(client: TestClient) -> None:
+    code = f"metric_{random.randint(1000,9999)}"
+    create_payload = {
+        "metric_code": code,
+        "metric_name": "订单数",
+        "biz_domain": "order",
+        "biz_desc": "订单总数",
+        "base_sql": "select count(*) as order_cnt from orders",
+        "time_grain": "DAY",
+        "dim_binding": ["dt"],
+        "update_strategy": "FULL",
+        "metric_aliases": ["订单量"],
+        "created_by": TEST_USER_ID,
+    }
+    resp = client.post("/api/v1/metrics", json=create_payload)
+    assert resp.status_code == 200, resp.text
+    metric = resp.json()
+    metric_id = metric["id"]
+    assert metric["metric_code"] == code
+
+    # Update metric
+    resp = client.post(f"/api/v1/metrics/{metric_id}", json={"metric_name": "订单数-更新", "is_active": False})
+    assert resp.status_code == 200
+    assert resp.json()["is_active"] is False
+
+    # Get metric
+    resp = client.get(f"/api/v1/metrics/{metric_id}")
+    assert resp.status_code == 200
+    assert resp.json()["metric_name"] == "订单数-更新"
+
+    # Create schedule
+    resp = client.post(
+        "/api/v1/metric-schedules",
+        json={"metric_id": metric_id, "cron_expr": "0 2 * * *", "priority": 5, "enabled": True},
+    )
+    assert resp.status_code == 200, resp.text
+    schedule = resp.json()
+    schedule_id = schedule["id"]
+
+    # Update schedule
+    resp = client.post(f"/api/v1/metric-schedules/{schedule_id}", json={"enabled": False, "retry_times": 1})
+    assert resp.status_code == 200
+    assert resp.json()["enabled"] is False
+
+    # List schedules for metric
+    resp = client.get(f"/api/v1/metrics/{metric_id}/schedules")
+    assert resp.status_code == 200
+    assert any(s["id"] == schedule_id for s in resp.json())
+
+
+def test_metric_runs_and_results_mysql(client: TestClient) -> None:
+    code = f"gmv_{random.randint(1000,9999)}"
+    metric_id = client.post(
+        "/api/v1/metrics",
+        json={
+            "metric_code": code,
+            "metric_name": "GMV",
+            "biz_domain": "order",
+            "base_sql": "select sum(pay_amount) as gmv from orders",
+            "time_grain": "DAY",
+            "dim_binding": ["dt"],
+            "update_strategy": "FULL",
+            "created_by": TEST_USER_ID,
+        },
+    ).json()["id"]
+
+    # Trigger run
+    resp = client.post(
+        "/api/v1/metric-runs/trigger",
+        json={
+            "metric_id": metric_id,
+            "triggered_by": "API",
+            "data_time_from": (datetime.utcnow() - timedelta(days=1)).isoformat(),
+            "data_time_to": datetime.utcnow().isoformat(),
+        },
+    )
+    assert resp.status_code == 200, resp.text
+    run = resp.json()
+    run_id = run["id"]
+    assert run["status"] == "RUNNING"
+
+    # List runs
+    resp = client.get("/api/v1/metric-runs", params={"metric_id": metric_id})
+    assert resp.status_code == 200
+    assert any(r["id"] == run_id for r in resp.json())
+
+    # Get run
+    resp = client.get(f"/api/v1/metric-runs/{run_id}")
+    assert resp.status_code == 200
+
+    # Write results
+    now = datetime.utcnow()
+    resp = client.post(
+        f"/api/v1/metric-results/{metric_id}",
+        json={
+            "metric_id": metric_id,
+            "results": [
+                {"stat_time": (now - timedelta(days=1)).isoformat(), "metric_value": 123.45, "data_version": run_id},
+                {"stat_time": now.isoformat(), "metric_value": 234.56, "data_version": run_id},
+            ],
+        },
+    )
+    assert resp.status_code == 200, resp.text
+    assert resp.json()["inserted"] == 2
+
+    # Query results
+    resp = client.get("/api/v1/metric-results", params={"metric_id": metric_id})
+    assert resp.status_code == 200
+    results = resp.json()
+    assert len(results) >= 2
+
+    # Latest result
+    resp = client.get("/api/v1/metric-results/latest", params={"metric_id": metric_id})
+    assert resp.status_code == 200
+    latest = resp.json()
+    assert float(latest["metric_value"]) in {123.45, 234.56}
+
+
+if __name__ == "__main__":
+    import pytest as _pytest
+
+    raise SystemExit(_pytest.main([__file__]))
diff --git a/test/test_snippet_rag_ingest.py b/test/test_snippet_rag_ingest.py
new file (156 lines)
@@ -0,0 +1,156 @@
+from __future__ import annotations
+
+import json
+from datetime import datetime
+
+import httpx
+import pytest
+from sqlalchemy import create_engine, text
+
+from app.services.table_snippet import ingest_snippet_rag_from_db
+
+
+def _setup_sqlite_engine():
+    engine = create_engine("sqlite://")
+    with engine.begin() as conn:
+        conn.execute(
+            text(
+                """
+                CREATE TABLE action_results (
+                    id INTEGER PRIMARY KEY AUTOINCREMENT,
+                    table_id INTEGER,
+                    version_ts INTEGER,
+                    action_type TEXT,
+                    status TEXT,
+                    snippet_json TEXT,
+                    snippet_alias_json TEXT,
+                    updated_at TEXT
+                )
+                """
+            )
+        )
+        conn.execute(
+            text(
+                """
+                CREATE TABLE rag_snippet (
+                    rag_item_id INTEGER PRIMARY KEY,
+                    action_result_id INTEGER NOT NULL,
+                    workspace_id INTEGER,
+                    table_id INTEGER,
+                    version_ts INTEGER,
+                    snippet_id TEXT,
+                    rag_text TEXT,
+                    merged_json TEXT,
+                    updated_at TEXT
+                )
+                """
+            )
+        )
+    return engine
+
+
+def _insert_action_row(engine, payload: dict) -> None:
+    with engine.begin() as conn:
+        conn.execute(
+            text(
+                """
+                INSERT INTO action_results (table_id, version_ts, action_type, status, snippet_json, snippet_alias_json, updated_at)
+                VALUES (:table_id, :version_ts, :action_type, :status, :snippet_json, :snippet_alias_json, :updated_at)
+                """
+            ),
+            {
+                "table_id": payload["table_id"],
+                "version_ts": payload["version_ts"],
+                "action_type": payload["action_type"],
+                "status": payload.get("status", "success"),
+                "snippet_json": json.dumps(payload.get("snippet_json"), ensure_ascii=False)
+                if payload.get("snippet_json") is not None
+                else None,
+                "snippet_alias_json": json.dumps(payload.get("snippet_alias_json"), ensure_ascii=False)
+                if payload.get("snippet_alias_json") is not None
+                else None,
+                "updated_at": payload.get("updated_at") or datetime.utcnow().isoformat(),
+            },
+        )
+
+
+class _StubRagClient:
+    def __init__(self) -> None:
+        self.received = None
+
+    async def add_batch(self, _client, items):
+        self.received = items
+        return {"count": len(items)}
+
+
+@pytest.mark.asyncio
+async def test_ingest_snippet_rag_from_db_persists_and_calls_rag_client() -> None:
+    engine = _setup_sqlite_engine()
+    table_id = 321
+    version_ts = 20240102000000
+
+    snippet_payload = [
+        {
+            "id": "snpt_topn",
+            "title": "TopN",
+            "aliases": [{"text": "站点水表排行前N", "tone": "中性"}],
+            "keywords": ["TopN", "站点"],
+        }
+    ]
+    alias_payload = [
+        {
+            "id": "snpt_topn",
+            "aliases": [
+                {"text": "站点水表排行前N", "tone": "中性"},
+                {"text": "按站点水表TopN", "tone": "专业"},
+            ],
+            "keywords": ["TopN", "排行"],
+            "intent_tags": ["topn", "aggregate"],
+        },
+        {
+            "id": "snpt_extra",
+            "aliases": [{"text": "额外别名"}],
+            "keywords": ["extra"],
+        },
+    ]
+
+    _insert_action_row(
+        engine,
+        {
+            "table_id": table_id,
+            "version_ts": version_ts,
+            "action_type": "snippet_alias",
+            "snippet_json": snippet_payload,
+            "snippet_alias_json": alias_payload,
+            "updated_at": "2024-01-02T00:00:00",
+        },
+    )
+
+    rag_stub = _StubRagClient()
+    async with httpx.AsyncClient() as client:
+        rag_ids = await ingest_snippet_rag_from_db(
+            table_id=table_id,
+            version_ts=version_ts,
+            workspace_id=99,
+            rag_item_type="SNIPPET",
+            client=client,
+            engine=engine,
+            rag_client=rag_stub,
+        )
+
+    assert rag_stub.received is not None
+    assert len(rag_stub.received) == 2  # includes alias-only row
+    assert len(rag_ids) == 2
+
+    with engine.connect() as conn:
+        rows = list(
+            conn.execute(
+                text("SELECT snippet_id, action_result_id, rag_text, merged_json FROM rag_snippet ORDER BY snippet_id")
+            )
+        )
+    assert {row[0] for row in rows} == {"snpt_extra", "snpt_topn"}
+    assert all(row[1] is not None for row in rows)
+    topn_row = next(row for row in rows if row[0] == "snpt_topn")
+    assert "TopN" in topn_row[2]
+    assert "按站点水表TopN" in topn_row[2]
+    assert "排行" in topn_row[2]
diff --git a/test/test_table_snippet_merge.py b/test/test_table_snippet_merge.py
new file (213 lines)
@@ -0,0 +1,213 @@
+from __future__ import annotations
+
+import json
+import os
+import random
+from datetime import datetime, timedelta
+from typing import List
+from pathlib import Path
+
+import sys
+import pytest
+from sqlalchemy import text
+from sqlalchemy.engine import Engine
+from sqlalchemy.exc import SQLAlchemyError
+
+# Ensure the project root is importable when running directly via python.
+ROOT = Path(__file__).resolve().parents[1]
+if str(ROOT) not in sys.path:
+    sys.path.insert(0, str(ROOT))
+
+from app import db
+from app.main import create_app
+
+
+from app.services.table_snippet import merge_snippet_records_from_db
+
+
+DEFAULT_MYSQL_URL = "mysql+pymysql://root:12345678@127.0.0.1:3306/data-ge?charset=utf8mb4"
+
+
+@pytest.fixture()
+def mysql_engine() -> Engine:
+    mysql_url = os.getenv("TEST_DATABASE_URL", DEFAULT_MYSQL_URL)
+    os.environ["DATABASE_URL"] = mysql_url
+    db.get_engine.cache_clear()
+    engine = db.get_engine()
+    try:
+        with engine.connect() as conn:
+            conn.execute(text("SELECT 1"))
+            exists = conn.execute(text("SHOW TABLES LIKE 'action_results'")).scalar()
+            if not exists:
+                pytest.skip("action_results table not found in test database.")
+    except SQLAlchemyError:
+        pytest.skip(f"Cannot connect to MySQL at {mysql_url}")
+    return engine
+
+
+def _insert_action_row(
+    engine: Engine,
+    *,
+    table_id: int,
+    version_ts: int,
+    action_type: str,
+    status: str = "success",
+    snippet_json: List[dict] | None = None,
+    snippet_alias_json: List[dict] | None = None,
+    updated_at: datetime | None = None,
+) -> None:
+    snippet_json_str = json.dumps(snippet_json, ensure_ascii=False) if snippet_json is not None else None
+    snippet_alias_json_str = (
+        json.dumps(snippet_alias_json, ensure_ascii=False) if snippet_alias_json is not None else None
+    )
+    with engine.begin() as conn:
+        conn.execute(
+            text(
+                """
+                INSERT INTO action_results (
+                    table_id, version_ts, action_type, status,
+                    callback_url, table_schema_version_id, table_schema,
+                    snippet_json, snippet_alias_json, updated_at
+                ) VALUES (
+                    :table_id, :version_ts, :action_type, :status,
+                    :callback_url, :table_schema_version_id, :table_schema,
+                    :snippet_json, :snippet_alias_json, :updated_at
+                )
+                ON DUPLICATE KEY UPDATE
+                    status=VALUES(status),
+                    snippet_json=VALUES(snippet_json),
+                    snippet_alias_json=VALUES(snippet_alias_json),
+                    updated_at=VALUES(updated_at)
+                """
+            ),
+            {
+                "table_id": table_id,
+                "version_ts": version_ts,
+                "action_type": action_type,
+                "status": status,
+                "callback_url": "http://localhost/test-callback",
+                "table_schema_version_id": "1",
+                "table_schema": json.dumps({}, ensure_ascii=False),
+                "snippet_json": snippet_json_str,
+                "snippet_alias_json": snippet_alias_json_str,
+                "updated_at": updated_at or datetime.utcnow(),
+            },
+        )
+
+
+def _cleanup(engine: Engine, table_id: int, version_ts: int) -> None:
+    with engine.begin() as conn:
+        conn.execute(
+            text("DELETE FROM action_results WHERE table_id=:table_id AND version_ts=:version_ts"),
+            {"table_id": table_id, "version_ts": version_ts},
+        )
+
+
+def test_merge_prefers_alias_row_and_appends_alias_only_entries(mysql_engine: Engine) -> None:
+    table_id = 990000000 + random.randint(1, 9999)
+    version_ts = int(datetime.utcnow().strftime("%Y%m%d%H%M%S"))
+    alias_updated = datetime(2024, 1, 2, 0, 0, 0)
+
+    snippet_payload = [
+        {
+            "id": "snpt_topn",
+            "aliases": [{"text": "站点水表排行前N", "tone": "中性"}],
+            "keywords": ["TopN", "站点"],
+        }
+    ]
+    alias_payload = [
+        {
+            "id": "snpt_topn",
+            "aliases": [
+                {"text": "站点水表排行前N", "tone": "中性"},
+                {"text": "按站点水表TopN", "tone": "专业"},
+            ],
+            "keywords": ["TopN", "排行"],
+            "intent_tags": ["topn", "aggregate"],
+        },
+        {
+            "id": "snpt_extra",
+            "aliases": [{"text": "额外别名"}],
+            "keywords": ["extra"],
+        },
+    ]
+
+    _insert_action_row(
+        mysql_engine,
+        table_id=table_id,
+        version_ts=version_ts,
+        action_type="snippet_alias",
+        snippet_json=snippet_payload,
+        snippet_alias_json=alias_payload,
+        updated_at=alias_updated,
+    )
+
+    try:
+        merged = merge_snippet_records_from_db(table_id, version_ts, engine=mysql_engine)
+        assert len(merged) == 2
+        topn = next(item for item in merged if item["id"] == "snpt_topn")
+        assert topn["source"] == "snippet"
+        assert topn["updated_at_from_action"] == alias_updated
+        assert {a["text"] for a in topn["aliases"]} == {"站点水表排行前N", "按站点水表TopN"}
+        assert set(topn["keywords"]) == {"TopN", "站点", "排行"}
+        assert set(topn["intent_tags"]) == {"topn", "aggregate"}
+
+        alias_only = next(item for item in merged if item["source"] == "alias_only")
+        assert alias_only["id"] == "snpt_extra"
+        assert alias_only["aliases"][0]["text"] == "额外别名"
+    finally:
+        _cleanup(mysql_engine, table_id, version_ts)
+
+
+def test_merge_falls_back_to_snippet_row_when_alias_row_missing_snippet_json(mysql_engine: Engine) -> None:
+    table_id = 991000000 + random.randint(1, 9999)
+    version_ts = int((datetime.utcnow() + timedelta(seconds=1)).strftime("%Y%m%d%H%M%S"))
+
+    alias_updated = datetime(2024, 1, 3, 0, 0, 0)
+    alias_payload = [
+        {
+            "id": "snpt_quality",
+            "aliases": [{"text": "质量检查"}],
+            "keywords": ["quality"],
+        }
+    ]
+    snippet_payload = [
+        {
+            "id": "snpt_quality",
+            "title": "质量检查",
+            "keywords": ["data-quality"],
+            "aliases": [{"text": "质量检查"}],
+        }
+    ]
+
+    _insert_action_row(
+        mysql_engine,
+        table_id=table_id,
+        version_ts=version_ts,
+        action_type="snippet_alias",
+        snippet_json=None,
+        snippet_alias_json=alias_payload,
+        updated_at=alias_updated,
+    )
+    _insert_action_row(
+        mysql_engine,
+        table_id=table_id,
+        version_ts=version_ts,
+        action_type="snippet",
+        snippet_json=snippet_payload,
+        snippet_alias_json=None,
+        updated_at=datetime(2024, 1, 2, 0, 0, 0),
+    )
+
+    try:
+        merged = merge_snippet_records_from_db(table_id, version_ts, engine=mysql_engine)
+
+        assert len(merged) == 1
+        record = merged[0]
+        assert record["id"] == "snpt_quality"
+        assert record["source"] == "snippet"
+        assert record["updated_at_from_action"] == alias_updated
+        assert set(record["keywords"]) == {"data-quality", "quality"}
+        assert {a["text"] for a in record["aliases"]} == {"质量检查"}
+    finally:
+        _cleanup(mysql_engine, table_id, version_ts)