Generate RAG snippets, store them in the database, and write them to RAG

zhaoawd
2025-12-09 00:15:22 +08:00
parent ebd79b75bd
commit 3218e51bad
11 changed files with 1231 additions and 8 deletions

.env

@@ -16,6 +16,9 @@ DEFAULT_IMPORT_MODEL=deepseek:deepseek-chat
# Service configuration
IMPORT_GATEWAY_BASE_URL=http://localhost:8000
# Prod nbackend base URL
NBACKEND_BASE_URL=https://chatbi.agentcarrier.cn/chatbi/api
# HTTP client configuration
HTTP_CLIENT_TIMEOUT=120
HTTP_CLIENT_TRUST_ENV=false

app/main.py

@@ -16,6 +16,8 @@ from fastapi.responses import JSONResponse
from app.exceptions import ProviderAPICallError, ProviderConfigurationError
from app.models import (
    ActionStatus,
    ActionType,
    DataImportAnalysisJobAck,
    DataImportAnalysisJobRequest,
    LLMRequest,
@@ -25,10 +27,11 @@ from app.models import (
    TableSnippetUpsertRequest,
    TableSnippetUpsertResponse,
)
from app.routers import chat_router, metrics_router
from app.services import LLMGateway
from app.services.import_analysis import process_import_analysis_job
from app.services.table_profiling import process_table_profiling_job
from app.services.table_snippet import ingest_snippet_rag_from_db, upsert_action_result


def _ensure_log_directories(config: dict[str, Any]) -> None:
@@ -135,6 +138,9 @@ def create_app() -> FastAPI:
        version="0.1.0",
        lifespan=lifespan,
    )

    # Chat/metric management APIs
    application.include_router(chat_router)
    application.include_router(metrics_router)

    @application.exception_handler(RequestValidationError)
    async def request_validation_exception_handler(
@@ -230,11 +236,12 @@ def create_app() -> FastAPI:
    )
    async def upsert_table_snippet(
        payload: TableSnippetUpsertRequest,
        client: httpx.AsyncClient = Depends(get_http_client),
    ) -> TableSnippetUpsertResponse:
        request_copy = payload.model_copy(deep=True)
        try:
            response = await asyncio.to_thread(upsert_action_result, request_copy)
        except Exception as exc:
            logger.error(
                "Failed to upsert table snippet: table_id=%s version_ts=%s action_type=%s",
@@ -244,6 +251,29 @@ def create_app() -> FastAPI:
                exc_info=True,
            )
            raise HTTPException(status_code=500, detail=str(exc)) from exc
        else:
            if (
                payload.action_type == ActionType.SNIPPET_ALIAS
                and payload.status == ActionStatus.SUCCESS
                and payload.rag_workspace_id is not None
            ):
                try:
                    await ingest_snippet_rag_from_db(
                        table_id=payload.table_id,
                        version_ts=payload.version_ts,
                        workspace_id=payload.rag_workspace_id,
                        rag_item_type=payload.rag_item_type or "SNIPPET",
                        client=client,
                    )
                except Exception:
                    logger.exception(
                        "Failed to ingest snippet RAG artifacts",
                        extra={
                            "table_id": payload.table_id,
                            "version_ts": payload.version_ts,
                        },
                    )
            return response

    @application.post("/__mock__/import-callback")
    async def mock_import_callback(payload: dict[str, Any]) -> dict[str, str]:

app/models.py

@@ -247,6 +247,16 @@ class TableSnippetUpsertRequest(BaseModel):
        ge=0,
        description="Version timestamp aligned with the pipeline (yyyyMMddHHmmss as integer).",
    )
    rag_workspace_id: Optional[int] = Field(
        None,
        ge=0,
        description="Optional workspace identifier for RAG ingestion; when provided and action_type=snippet_alias "
        "with status=success, merged snippets will be written to rag_snippet and pushed to RAG.",
    )
    rag_item_type: Optional[str] = Field(
        "SNIPPET",
        description="Optional RAG item type used when pushing snippets to RAG. Defaults to 'SNIPPET'.",
    )
    action_type: ActionType = Field(..., description="Pipeline action type for this record.")
    status: ActionStatus = Field(
        ActionStatus.SUCCESS, description="Execution status for the action."

app/schemas/rag.py (new file)

@@ -0,0 +1,46 @@
from __future__ import annotations
from typing import Any, List
from pydantic import BaseModel, ConfigDict, Field
class RagItemPayload(BaseModel):
"""Payload for creating or updating a single RAG item."""
model_config = ConfigDict(populate_by_name=True, extra="ignore")
id: int = Field(..., description="Unique identifier for the RAG item.")
workspace_id: int = Field(..., alias="workspaceId", description="Workspace identifier.")
name: str = Field(..., description="Readable name of the item.")
embedding_data: str = Field(..., alias="embeddingData", description="Serialized embedding payload.")
type: str = Field(..., description='Item type, e.g. "METRIC".')
class RagDeleteRequest(BaseModel):
"""Payload for deleting a single RAG item."""
model_config = ConfigDict(populate_by_name=True, extra="ignore")
id: int = Field(..., description="Identifier of the item to delete.")
type: str = Field(..., description="Item type matching the stored record.")
class RagRetrieveRequest(BaseModel):
"""Payload for retrieving RAG items by semantic query."""
model_config = ConfigDict(populate_by_name=True, extra="ignore")
query: str = Field(..., description="Search query text.")
num: int = Field(..., description="Number of items to return.")
workspace_id: int = Field(..., alias="workspaceId", description="Workspace scope for the search.")
type: str = Field(..., description="Item type to search, e.g. METRIC.")
class RagRetrieveResponse(BaseModel):
"""Generic RAG retrieval response wrapper."""
model_config = ConfigDict(extra="allow")
data: List[Any] = Field(default_factory=list, description="Retrieved items.")
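A small sketch of the alias behaviour (values are placeholders): populate_by_name=True lets callers construct models with either the Python field names or the wire aliases, and model_dump(by_alias=True) restores the camelCase keys the RAG API expects.

from app.schemas.rag import RagItemPayload

item = RagItemPayload(id=1, workspace_id=99, name="demo", embedding_data="Title: demo", type="SNIPPET")
assert item.model_dump(by_alias=True) == {
    "id": 1,
    "workspaceId": 99,  # workspace_id serialized via its alias
    "name": "demo",
    "embeddingData": "Title: demo",
    "type": "SNIPPET",
}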

app/services/rag_client.py

@@ -0,0 +1,83 @@
from __future__ import annotations
import logging
from typing import Any, Sequence
import httpx
from app.exceptions import ProviderAPICallError
from app.schemas.rag import RagDeleteRequest, RagItemPayload, RagRetrieveRequest
from app.settings import RAG_API_AUTH_TOKEN, RAG_API_BASE_URL
logger = logging.getLogger(__name__)
class RagAPIClient:
"""Thin async client wrapper around the RAG endpoints described in doc/rag-api.md."""
def __init__(self, *, base_url: str | None = None, auth_token: str | None = None) -> None:
resolved_base = base_url or RAG_API_BASE_URL
self._base_url = resolved_base.rstrip("/")
self._auth_token = auth_token or RAG_API_AUTH_TOKEN
def _headers(self) -> dict[str, str]:
headers = {"Content-Type": "application/json"}
if self._auth_token:
headers["Authorization"] = f"Bearer {self._auth_token}"
return headers
async def _post(
self,
client: httpx.AsyncClient,
path: str,
payload: Any,
) -> Any:
url = f"{self._base_url}{path}"
try:
response = await client.post(url, json=payload, headers=self._headers())
response.raise_for_status()
except httpx.HTTPStatusError as exc:
status_code = exc.response.status_code if exc.response else None
response_text = exc.response.text if exc.response else ""
logger.error(
"RAG API responded with an error (%s) for %s: %s",
status_code,
url,
response_text,
exc_info=True,
)
raise ProviderAPICallError(
"RAG API call failed.",
status_code=status_code,
response_text=response_text,
) from exc
except httpx.HTTPError as exc:
logger.error("Transport error calling RAG API %s: %s", url, exc, exc_info=True)
raise ProviderAPICallError(f"RAG API call failed: {exc}") from exc
try:
return response.json()
except ValueError:
logger.warning("RAG API returned non-JSON response for %s; returning raw text.", url)
return {"raw": response.text}
async def add(self, client: httpx.AsyncClient, payload: RagItemPayload) -> Any:
body = payload.model_dump(by_alias=True, exclude_none=True)
return await self._post(client, "/rag/add", body)
async def add_batch(self, client: httpx.AsyncClient, items: Sequence[RagItemPayload]) -> Any:
body = [item.model_dump(by_alias=True, exclude_none=True) for item in items]
return await self._post(client, "/rag/addBatch", body)
async def update(self, client: httpx.AsyncClient, payload: RagItemPayload) -> Any:
body = payload.model_dump(by_alias=True, exclude_none=True)
return await self._post(client, "/rag/update", body)
async def delete(self, client: httpx.AsyncClient, payload: RagDeleteRequest) -> Any:
body = payload.model_dump(by_alias=True, exclude_none=True)
return await self._post(client, "/rag/delete", body)
async def retrieve(self, client: httpx.AsyncClient, payload: RagRetrieveRequest) -> Any:
body = payload.model_dump(by_alias=True, exclude_none=True)
return await self._post(client, "/rag/retrieve", body)
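A minimal usage sketch, assuming RAG_API_BASE_URL and RAG_API_AUTH_TOKEN are configured in app.settings; the ids, workspace, and item values are placeholders:

import asyncio

import httpx

from app.schemas.rag import RagItemPayload, RagRetrieveRequest
from app.services.rag_client import RagAPIClient


async def demo() -> None:
    rag = RagAPIClient()  # resolves base URL / token from app.settings
    async with httpx.AsyncClient() as client:
        await rag.add(
            client,
            RagItemPayload(id=1, workspace_id=99, name="demo", embedding_data="Title: demo", type="SNIPPET"),
        )
        hits = await rag.retrieve(
            client,
            RagRetrieveRequest(query="demo", num=5, workspace_id=99, type="SNIPPET"),
        )
        print(hits)


asyncio.run(demo())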

app/services/table_snippet.py

@@ -1,19 +1,19 @@
from __future__ import annotations

import hashlib
import json
import logging
from datetime import datetime
from typing import Any, Dict, List, Optional, Sequence, Tuple

from sqlalchemy import text
from sqlalchemy.engine import Engine
from sqlalchemy.exc import SQLAlchemyError

from app.db import get_engine
from app.models import ActionType, TableSnippetUpsertRequest, TableSnippetUpsertResponse
from app.schemas.rag import RagItemPayload
from app.services.rag_client import RagAPIClient

logger = logging.getLogger(__name__)
@@ -46,6 +46,7 @@ def _prepare_model_params(params: Dict[str, Any] | None) -> str | None:
def _collect_common_columns(request: TableSnippetUpsertRequest) -> Dict[str, Any]:
    # Build the base column set shared by all action types; action-specific fields are populated later.
    logger.debug(
        "Collecting common columns for table_id=%s version_ts=%s action_type=%s",
        request.table_id,
@@ -215,3 +216,405 @@ def upsert_action_result(request: TableSnippetUpsertRequest) -> TableSnippetUpsertResponse:
        status=request.status,
        updated=updated,
    )
def _decode_json_field(value: Any) -> Any:
"""Decode JSON columns that may be returned as str/bytes/dicts/lists."""
if value is None:
return None
if isinstance(value, (dict, list)):
return value
if isinstance(value, (bytes, bytearray)):
try:
value = value.decode("utf-8")
except Exception: # pragma: no cover - defensive
return None
if isinstance(value, str):
try:
return json.loads(value)
except json.JSONDecodeError:
logger.warning("Failed to decode JSON field: %s", value)
return None
return None
def _coerce_json_array(value: Any) -> List[Any]:
decoded = _decode_json_field(value)
return decoded if isinstance(decoded, list) else []
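# Illustrative behaviour of the two helpers above (editorial note, not in the original file):
#   _decode_json_field('{"a": 1}')  -> {"a": 1}
#   _decode_json_field(b'[1, 2]')   -> [1, 2]
#   _decode_json_field("not json")  -> None (a warning is logged)
#   _coerce_json_array('{"a": 1}')  -> [] (decoded value is not a list)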
def _fetch_action_payload(
engine: Engine, table_id: int, version_ts: int, action_type: ActionType
) -> Optional[Dict[str, Any]]:
sql = text(
"""
SELECT id AS action_result_id, snippet_json, snippet_alias_json, updated_at, status
FROM action_results
WHERE table_id = :table_id
AND version_ts = :version_ts
AND action_type = :action_type
AND status IN ('success', 'partial')
ORDER BY CASE status WHEN 'success' THEN 0 ELSE 1 END, updated_at DESC
LIMIT 1
"""
)
with engine.connect() as conn:
row = conn.execute(
sql,
{
"table_id": table_id,
"version_ts": version_ts,
"action_type": action_type.value,
},
).mappings().first()
return dict(row) if row else None
def _load_snippet_sources(
engine: Engine, table_id: int, version_ts: int
) -> Tuple[List[Any], List[Any], Optional[datetime], Optional[int], Optional[int]]:
alias_row = _fetch_action_payload(engine, table_id, version_ts, ActionType.SNIPPET_ALIAS)
snippet_row = _fetch_action_payload(engine, table_id, version_ts, ActionType.SNIPPET)
snippet_json = _coerce_json_array(alias_row.get("snippet_json") if alias_row else None)
alias_json = _coerce_json_array(alias_row.get("snippet_alias_json") if alias_row else None)
updated_at: Optional[datetime] = alias_row.get("updated_at") if alias_row else None
alias_action_id: Optional[int] = alias_row.get("action_result_id") if alias_row else None
snippet_action_id: Optional[int] = snippet_row.get("action_result_id") if snippet_row else None
if not snippet_json and snippet_row:
snippet_json = _coerce_json_array(snippet_row.get("snippet_json"))
if updated_at is None:
updated_at = snippet_row.get("updated_at")
if alias_action_id is None:
alias_action_id = snippet_action_id
if not updated_at and alias_row:
updated_at = alias_row.get("updated_at")
return snippet_json, alias_json, updated_at, alias_action_id, snippet_action_id
def _normalize_aliases(raw_aliases: Any) -> List[Dict[str, Any]]:
aliases: List[Dict[str, Any]] = []
seen: set[str] = set()
if not raw_aliases:
return aliases
if not isinstance(raw_aliases, list):
return aliases
for item in raw_aliases:
if isinstance(item, dict):
text_val = item.get("text")
if not text_val or text_val in seen:
continue
seen.add(text_val)
aliases.append({"text": text_val, "tone": item.get("tone")})
elif isinstance(item, str):
if item in seen:
continue
seen.add(item)
aliases.append({"text": item})
return aliases
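# Illustrative behaviour (editorial note, not in the original file): dicts and bare
# strings are accepted, de-duplicated by their text, and normalised to dict form:
#   _normalize_aliases([{"text": "TopN", "tone": "中性"}, "TopN", "排行"])
#   -> [{"text": "TopN", "tone": "中性"}, {"text": "排行"}]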
def _normalize_str_list(values: Any) -> List[str]:
if not values:
return []
if not isinstance(values, list):
return []
seen: set[str] = set()
normalised: List[str] = []
for val in values:
if not isinstance(val, str):
continue
if val in seen:
continue
seen.add(val)
normalised.append(val)
return normalised
def _merge_alias_lists(primary: List[Dict[str, Any]], secondary: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
merged: List[Dict[str, Any]] = []
seen: set[str] = set()
for source in (primary, secondary):
for item in source:
if not isinstance(item, dict):
continue
text_val = item.get("text")
if not text_val or text_val in seen:
continue
seen.add(text_val)
merged.append({"text": text_val, "tone": item.get("tone")})
return merged
def _merge_str_lists(primary: List[str], secondary: List[str]) -> List[str]:
merged: List[str] = []
seen: set[str] = set()
for source in (primary, secondary):
for item in source:
if item in seen:
continue
seen.add(item)
merged.append(item)
return merged
def _build_alias_map(alias_payload: List[Any]) -> Dict[str, Dict[str, Any]]:
alias_map: Dict[str, Dict[str, Any]] = {}
for item in alias_payload:
if not isinstance(item, dict):
continue
alias_id = item.get("id")
if not alias_id:
continue
existing = alias_map.setdefault(
alias_id,
{"aliases": [], "keywords": [], "intent_tags": []},
)
existing["aliases"] = _merge_alias_lists(
existing["aliases"], _normalize_aliases(item.get("aliases"))
)
existing["keywords"] = _merge_str_lists(
existing["keywords"], _normalize_str_list(item.get("keywords"))
)
existing["intent_tags"] = _merge_str_lists(
existing["intent_tags"], _normalize_str_list(item.get("intent_tags"))
)
return alias_map
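# Illustrative behaviour (editorial note, not in the original file): entries sharing
# an id are folded into a single map entry with merged, de-duplicated fields:
#   _build_alias_map([
#       {"id": "snpt_topn", "keywords": ["TopN"]},
#       {"id": "snpt_topn", "keywords": ["TopN", "排行"], "intent_tags": ["topn"]},
#   ])
#   -> {"snpt_topn": {"aliases": [], "keywords": ["TopN", "排行"], "intent_tags": ["topn"]}}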
def merge_snippet_records_from_db(
table_id: int,
version_ts: int,
*,
engine: Optional[Engine] = None,
) -> List[Dict[str, Any]]:
"""
Load snippet + snippet_alias JSON from action_results after snippet_alias is stored,
then merge into a unified snippet object list ready for downstream RAG.
"""
engine = engine or get_engine()
snippets, aliases, updated_at, alias_action_id, snippet_action_id = _load_snippet_sources(
engine, table_id, version_ts
)
alias_map = _build_alias_map(aliases)
merged: List[Dict[str, Any]] = []
seen_ids: set[str] = set()
for snippet in snippets:
if not isinstance(snippet, dict):
continue
snippet_id = snippet.get("id")
if not snippet_id:
continue
alias_info = alias_map.get(snippet_id)
record = dict(snippet)
record_aliases = _normalize_aliases(record.get("aliases"))
record_keywords = _normalize_str_list(record.get("keywords"))
record_intents = _normalize_str_list(record.get("intent_tags"))
if alias_info:
record_aliases = _merge_alias_lists(record_aliases, alias_info["aliases"])
record_keywords = _merge_str_lists(record_keywords, alias_info["keywords"])
record_intents = _merge_str_lists(record_intents, alias_info["intent_tags"])
record["aliases"] = record_aliases
record["keywords"] = record_keywords
record["intent_tags"] = record_intents
record["table_id"] = table_id
record["version_ts"] = version_ts
record["updated_at_from_action"] = updated_at
record["source"] = "snippet"
record["action_result_id"] = alias_action_id or snippet_action_id
merged.append(record)
seen_ids.add(snippet_id)
for alias_id, alias_info in alias_map.items():
if alias_id in seen_ids:
continue
merged.append(
{
"id": alias_id,
"aliases": alias_info["aliases"],
"keywords": alias_info["keywords"],
"intent_tags": alias_info["intent_tags"],
"table_id": table_id,
"version_ts": version_ts,
"updated_at_from_action": updated_at,
"source": "alias_only",
"action_result_id": alias_action_id or snippet_action_id,
}
)
return merged
def _stable_rag_item_id(table_id: int, version_ts: int, snippet_id: str) -> int:
digest = hashlib.md5(f"{table_id}:{version_ts}:{snippet_id}".encode("utf-8")).hexdigest()
return int(digest[:16], 16) % 9_000_000_000_000_000_000
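# Editorial note: the id is a pure function of (table_id, version_ts, snippet_id), so
# re-ingesting the same snippet maps to the same rag_item_id and the DELETE+INSERT in
# _upsert_rag_snippet_rows replaces the existing row instead of duplicating it. The
# modulo keeps the value inside a signed 64-bit BIGINT.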
def _build_rag_text(snippet: Dict[str, Any]) -> str:
# Deterministic text concatenation for embedding input.
parts: List[str] = []
def _add(label: str, value: Any) -> None:
if value is None:
return
if isinstance(value, list):
value = ", ".join([str(v) for v in value if v])
elif isinstance(value, dict):
value = json.dumps(value, ensure_ascii=False)
if value:
parts.append(f"{label}: {value}")
_add("Title", snippet.get("title") or snippet.get("id"))
_add("Description", snippet.get("desc"))
_add("Business", snippet.get("business_caliber"))
_add("Type", snippet.get("type"))
_add("Examples", snippet.get("examples") or [])
_add("Aliases", [a.get("text") for a in snippet.get("aliases") or [] if isinstance(a, dict)])
_add("Keywords", snippet.get("keywords") or [])
_add("IntentTags", snippet.get("intent_tags") or [])
_add("Applicability", snippet.get("applicability"))
_add("DialectSQL", snippet.get("dialect_sql"))
return "\n".join(parts)
def _prepare_rag_payloads(
snippets: List[Dict[str, Any]],
table_id: int,
version_ts: int,
workspace_id: int,
rag_item_type: str = "SNIPPET",
) -> Tuple[List[Dict[str, Any]], List[RagItemPayload]]:
rows: List[Dict[str, Any]] = []
payloads: List[RagItemPayload] = []
now = datetime.utcnow()
for snippet in snippets:
snippet_id = snippet.get("id")
if not snippet_id:
continue
action_result_id = snippet.get("action_result_id")
if action_result_id is None:
logger.warning(
"Skipping snippet without action_result_id for RAG ingestion (table_id=%s version_ts=%s snippet_id=%s)",
table_id,
version_ts,
snippet_id,
)
continue
rag_item_id = _stable_rag_item_id(table_id, version_ts, snippet_id)
rag_text = _build_rag_text(snippet)
merged_json = json.dumps(snippet, ensure_ascii=False)
updated_at_raw = snippet.get("updated_at_from_action") or now
if isinstance(updated_at_raw, str):
try:
updated_at = datetime.fromisoformat(updated_at_raw)
except ValueError:
updated_at = now
else:
updated_at = updated_at_raw if isinstance(updated_at_raw, datetime) else now
row = {
"rag_item_id": rag_item_id,
"workspace_id": workspace_id,
"table_id": table_id,
"version_ts": version_ts,
"action_result_id": action_result_id,
"snippet_id": snippet_id,
"rag_text": rag_text,
"merged_json": merged_json,
"updated_at": updated_at,
}
rows.append(row)
payloads.append(
RagItemPayload(
id=rag_item_id,
workspaceId=workspace_id,
name=snippet.get("title") or snippet_id,
embeddingData=rag_text,
type=rag_item_type or "SNIPPET",
)
)
return rows, payloads
def _upsert_rag_snippet_rows(engine: Engine, rows: Sequence[Dict[str, Any]]) -> None:
if not rows:
return
delete_sql = text("DELETE FROM rag_snippet WHERE rag_item_id=:rag_item_id")
insert_sql = text(
"""
INSERT INTO rag_snippet (
rag_item_id,
workspace_id,
table_id,
version_ts,
action_result_id,
snippet_id,
rag_text,
merged_json,
updated_at
) VALUES (
:rag_item_id,
:workspace_id,
:table_id,
:version_ts,
:action_result_id,
:snippet_id,
:rag_text,
:merged_json,
:updated_at
)
"""
)
with engine.begin() as conn:
for row in rows:
conn.execute(delete_sql, row)
conn.execute(insert_sql, row)
async def ingest_snippet_rag_from_db(
table_id: int,
version_ts: int,
*,
workspace_id: int,
rag_item_type: str = "SNIPPET",
client,
engine: Optional[Engine] = None,
rag_client: Optional[RagAPIClient] = None,
) -> List[int]:
"""
Merge snippet + alias JSON from action_results, persist to rag_snippet, then push to RAG via addBatch.
    Returns the list of ingested rag_item_id values.
"""
engine = engine or get_engine()
snippets = merge_snippet_records_from_db(table_id, version_ts, engine=engine)
if not snippets:
logger.info(
"No snippets available for RAG ingestion (table_id=%s version_ts=%s)",
table_id,
version_ts,
)
return []
rows, payloads = _prepare_rag_payloads(
snippets,
table_id=table_id,
version_ts=version_ts,
workspace_id=workspace_id,
rag_item_type=rag_item_type,
)
_upsert_rag_snippet_rows(engine, rows)
rag_client = rag_client or RagAPIClient()
await rag_client.add_batch(client, payloads)
return [row["rag_item_id"] for row in rows]
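For reference, a hedged sketch of the merged record shape produced by merge_snippet_records_from_db; values are illustrative, and the tests below exercise concrete cases:

merged = merge_snippet_records_from_db(321, 20240102000000)
# Each record is the source snippet dict enriched with merged alias data and
# provenance fields, roughly:
# {
#     "id": "snpt_topn",
#     "aliases": [{"text": "站点水表排行前N", "tone": "中性"}, ...],
#     "keywords": ["TopN", "站点", "排行"],
#     "intent_tags": ["topn", "aggregate"],
#     "table_id": 321,
#     "version_ts": 20240102000000,
#     "updated_at_from_action": datetime(2024, 1, 2),
#     "source": "snippet",        # or "alias_only"
#     "action_result_id": 42,     # hypothetical primary key
# }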

doc/rag-api.md (new file)

@@ -0,0 +1,57 @@
# Add RAG item
curl --location --request POST 'http://127.0.0.1:8000/rag/add' \
--header 'Content-Type: application/json' \
--header 'Authorization: Bearer ' \
--data-raw '{
"id": 0,
"workspaceId": 0,
"name": "string",
"embeddingData": "string",
"type": "METRIC"
}'
# Batch-add RAG items
curl --location --request POST 'http://127.0.0.1:8000/rag/addBatch' \
--header 'Content-Type: application/json' \
--header 'Authorization: Bearer ' \
--data-raw '[
{
"id": 0,
"workspaceId": 0,
"name": "string",
"embeddingData": "string",
"type": "METRIC"
}
]'
# Update RAG item
curl --location --request POST 'http://127.0.0.1:8000/rag/update' \
--header 'Content-Type: application/json' \
--header 'Authorization: Bearer ' \
--data-raw '{
"id": 0,
"workspaceId": 0,
"name": "string",
"embeddingData": "string",
"type": "METRIC"
}'
# Delete RAG item
curl --location --request POST 'http://127.0.0.1:8000/rag/delete' \
--header 'Content-Type: application/json' \
--header 'Authorization: Bearer ' \
--data-raw '{
"id": 0,
"type": "METRIC"
}'
# Retrieve RAG items
curl --location --request POST 'http://127.0.0.1:8000/rag/retrieve' \
--header 'Content-Type: application/json' \
--header 'Authorization: Bearer ' \
--data-raw '{
"query": "string",
"num": 0,
"workspaceId": 0,
"type": "METRIC"
}'


@@ -0,0 +1,15 @@
CREATE TABLE `rag_snippet` (
  `rag_item_id` bigint NOT NULL COMMENT 'RAG item id (stable hash of table/version/snippet_id)',
  `workspace_id` bigint NOT NULL COMMENT 'RAG workspace scope',
  `table_id` bigint NOT NULL COMMENT 'Source table id',
  `version_ts` bigint NOT NULL COMMENT 'Table version timestamp',
  `action_result_id` bigint NOT NULL COMMENT 'Source action_results primary key (snippet_alias or snippet row)',
  `snippet_id` varchar(255) COLLATE utf8mb4_bin NOT NULL COMMENT 'Original snippet id',
  `rag_text` text COLLATE utf8mb4_bin NOT NULL COMMENT 'Concatenated text used for embedding',
  `merged_json` json NOT NULL COMMENT 'Merged snippet object',
  `updated_at` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
  PRIMARY KEY (`rag_item_id`),
  KEY `idx_action_result` (`action_result_id`),
  KEY `idx_workspace` (`workspace_id`),
  KEY `idx_table_version` (`table_id`,`version_ts`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin COMMENT='RAG snippet index cache';


@@ -0,0 +1,207 @@
from __future__ import annotations
import os
import random
from datetime import datetime, timedelta
from pathlib import Path
from typing import Generator, List
import pytest
from fastapi.testclient import TestClient
from sqlalchemy import text
from sqlalchemy.exc import SQLAlchemyError
# Ensure project root on path for direct execution
ROOT = Path(__file__).resolve().parents[1]
import sys
if str(ROOT) not in sys.path:
sys.path.insert(0, str(ROOT))
from app import db
from app.main import create_app
TEST_USER_ID = 98765
#SCHEMA_PATH = Path("file/tableschema/metrics.sql")
DEFAULT_MYSQL_URL = "mysql+pymysql://root:12345678@127.0.0.1:3306/data-ge?charset=utf8mb4"
# def _run_sql_script(engine, sql_text: str) -> None:
# """Execute semicolon-terminated SQL statements sequentially."""
# statements: List[str] = []
# buffer: List[str] = []
# for line in sql_text.splitlines():
# stripped = line.strip()
# if not stripped or stripped.startswith("--"):
# continue
# buffer.append(line)
# if stripped.endswith(";"):
# statements.append("\n".join(buffer).rstrip(";"))
# buffer = []
# if buffer:
# statements.append("\n".join(buffer))
# with engine.begin() as conn:
# for stmt in statements:
# conn.execute(text(stmt))
# def _ensure_metric_schema(engine) -> None:
# if not SCHEMA_PATH.exists():
# pytest.skip("metrics.sql schema file not found.")
# raw_sql = SCHEMA_PATH.read_text(encoding="utf-8")
# raw_sql = raw_sql.replace("CREATE TABLE metric_def", "CREATE TABLE IF NOT EXISTS metric_def")
# raw_sql = raw_sql.replace("CREATE TABLE metric_schedule", "CREATE TABLE IF NOT EXISTS metric_schedule")
# raw_sql = raw_sql.replace("CREATE TABLE metric_job_run", "CREATE TABLE IF NOT EXISTS metric_job_run")
# raw_sql = raw_sql.replace("CREATE TABLE metric_result", "CREATE TABLE IF NOT EXISTS metric_result")
# _run_sql_script(engine, raw_sql)
@pytest.fixture(scope="module")
def client() -> Generator[TestClient, None, None]:
mysql_url = os.getenv("TEST_DATABASE_URL", DEFAULT_MYSQL_URL)
os.environ["DATABASE_URL"] = mysql_url
db.get_engine.cache_clear()
engine = db.get_engine()
try:
with engine.connect() as conn:
conn.execute(text("SELECT 1"))
except SQLAlchemyError:
pytest.skip(f"Cannot connect to MySQL at {mysql_url}")
#_ensure_metric_schema(engine)
app = create_app()
with TestClient(app) as test_client:
yield test_client
# cleanup test artifacts
with engine.begin() as conn:
conn.execute(text("DELETE FROM metric_result WHERE metric_id IN (SELECT id FROM metric_def WHERE created_by=:uid)"), {"uid": TEST_USER_ID})
conn.execute(text("DELETE FROM metric_job_run WHERE metric_id IN (SELECT id FROM metric_def WHERE created_by=:uid)"), {"uid": TEST_USER_ID})
conn.execute(text("DELETE FROM metric_schedule WHERE metric_id IN (SELECT id FROM metric_def WHERE created_by=:uid)"), {"uid": TEST_USER_ID})
conn.execute(text("DELETE FROM metric_def WHERE created_by=:uid"), {"uid": TEST_USER_ID})
db.get_engine.cache_clear()
def test_metric_crud_and_schedule_mysql(client: TestClient) -> None:
code = f"metric_{random.randint(1000,9999)}"
create_payload = {
"metric_code": code,
"metric_name": "订单数",
"biz_domain": "order",
"biz_desc": "订单总数",
"base_sql": "select count(*) as order_cnt from orders",
"time_grain": "DAY",
"dim_binding": ["dt"],
"update_strategy": "FULL",
"metric_aliases": ["订单量"],
"created_by": TEST_USER_ID,
}
resp = client.post("/api/v1/metrics", json=create_payload)
assert resp.status_code == 200, resp.text
metric = resp.json()
metric_id = metric["id"]
assert metric["metric_code"] == code
# Update metric
resp = client.post(f"/api/v1/metrics/{metric_id}", json={"metric_name": "订单数-更新", "is_active": False})
assert resp.status_code == 200
assert resp.json()["is_active"] is False
# Get metric
resp = client.get(f"/api/v1/metrics/{metric_id}")
assert resp.status_code == 200
assert resp.json()["metric_name"] == "订单数-更新"
# Create schedule
resp = client.post(
"/api/v1/metric-schedules",
json={"metric_id": metric_id, "cron_expr": "0 2 * * *", "priority": 5, "enabled": True},
)
assert resp.status_code == 200, resp.text
schedule = resp.json()
schedule_id = schedule["id"]
# Update schedule
resp = client.post(f"/api/v1/metric-schedules/{schedule_id}", json={"enabled": False, "retry_times": 1})
assert resp.status_code == 200
assert resp.json()["enabled"] is False
# List schedules for metric
resp = client.get(f"/api/v1/metrics/{metric_id}/schedules")
assert resp.status_code == 200
assert any(s["id"] == schedule_id for s in resp.json())
def test_metric_runs_and_results_mysql(client: TestClient) -> None:
code = f"gmv_{random.randint(1000,9999)}"
metric_id = client.post(
"/api/v1/metrics",
json={
"metric_code": code,
"metric_name": "GMV",
"biz_domain": "order",
"base_sql": "select sum(pay_amount) as gmv from orders",
"time_grain": "DAY",
"dim_binding": ["dt"],
"update_strategy": "FULL",
"created_by": TEST_USER_ID,
},
).json()["id"]
# Trigger run
resp = client.post(
"/api/v1/metric-runs/trigger",
json={
"metric_id": metric_id,
"triggered_by": "API",
"data_time_from": (datetime.utcnow() - timedelta(days=1)).isoformat(),
"data_time_to": datetime.utcnow().isoformat(),
},
)
assert resp.status_code == 200, resp.text
run = resp.json()
run_id = run["id"]
assert run["status"] == "RUNNING"
# List runs
resp = client.get("/api/v1/metric-runs", params={"metric_id": metric_id})
assert resp.status_code == 200
assert any(r["id"] == run_id for r in resp.json())
# Get run
resp = client.get(f"/api/v1/metric-runs/{run_id}")
assert resp.status_code == 200
# Write results
now = datetime.utcnow()
resp = client.post(
f"/api/v1/metric-results/{metric_id}",
json={
"metric_id": metric_id,
"results": [
{"stat_time": (now - timedelta(days=1)).isoformat(), "metric_value": 123.45, "data_version": run_id},
{"stat_time": now.isoformat(), "metric_value": 234.56, "data_version": run_id},
],
},
)
assert resp.status_code == 200, resp.text
assert resp.json()["inserted"] == 2
# Query results
resp = client.get("/api/v1/metric-results", params={"metric_id": metric_id})
assert resp.status_code == 200
results = resp.json()
assert len(results) >= 2
# Latest result
resp = client.get("/api/v1/metric-results/latest", params={"metric_id": metric_id})
assert resp.status_code == 200
latest = resp.json()
assert float(latest["metric_value"]) in {123.45, 234.56}
if __name__ == "__main__":
import pytest as _pytest
raise SystemExit(_pytest.main([__file__]))


@@ -0,0 +1,156 @@
from __future__ import annotations
import json
from datetime import datetime
import httpx
import pytest
from sqlalchemy import create_engine, text
from app.services.table_snippet import ingest_snippet_rag_from_db
def _setup_sqlite_engine():
engine = create_engine("sqlite://")
with engine.begin() as conn:
conn.execute(
text(
"""
CREATE TABLE action_results (
id INTEGER PRIMARY KEY AUTOINCREMENT,
table_id INTEGER,
version_ts INTEGER,
action_type TEXT,
status TEXT,
snippet_json TEXT,
snippet_alias_json TEXT,
updated_at TEXT
)
"""
)
)
conn.execute(
text(
"""
CREATE TABLE rag_snippet (
rag_item_id INTEGER PRIMARY KEY,
action_result_id INTEGER NOT NULL,
workspace_id INTEGER,
table_id INTEGER,
version_ts INTEGER,
snippet_id TEXT,
rag_text TEXT,
merged_json TEXT,
updated_at TEXT
)
"""
)
)
return engine
def _insert_action_row(engine, payload: dict) -> None:
with engine.begin() as conn:
conn.execute(
text(
"""
INSERT INTO action_results (table_id, version_ts, action_type, status, snippet_json, snippet_alias_json, updated_at)
VALUES (:table_id, :version_ts, :action_type, :status, :snippet_json, :snippet_alias_json, :updated_at)
"""
),
{
"table_id": payload["table_id"],
"version_ts": payload["version_ts"],
"action_type": payload["action_type"],
"status": payload.get("status", "success"),
"snippet_json": json.dumps(payload.get("snippet_json"), ensure_ascii=False)
if payload.get("snippet_json") is not None
else None,
"snippet_alias_json": json.dumps(payload.get("snippet_alias_json"), ensure_ascii=False)
if payload.get("snippet_alias_json") is not None
else None,
"updated_at": payload.get("updated_at") or datetime.utcnow().isoformat(),
},
)
class _StubRagClient:
def __init__(self) -> None:
self.received = None
async def add_batch(self, _client, items):
self.received = items
return {"count": len(items)}
@pytest.mark.asyncio
async def test_ingest_snippet_rag_from_db_persists_and_calls_rag_client() -> None:
engine = _setup_sqlite_engine()
table_id = 321
version_ts = 20240102000000
snippet_payload = [
{
"id": "snpt_topn",
"title": "TopN",
"aliases": [{"text": "站点水表排行前N", "tone": "中性"}],
"keywords": ["TopN", "站点"],
}
]
alias_payload = [
{
"id": "snpt_topn",
"aliases": [
{"text": "站点水表排行前N", "tone": "中性"},
{"text": "按站点水表TopN", "tone": "专业"},
],
"keywords": ["TopN", "排行"],
"intent_tags": ["topn", "aggregate"],
},
{
"id": "snpt_extra",
"aliases": [{"text": "额外别名"}],
"keywords": ["extra"],
},
]
_insert_action_row(
engine,
{
"table_id": table_id,
"version_ts": version_ts,
"action_type": "snippet_alias",
"snippet_json": snippet_payload,
"snippet_alias_json": alias_payload,
"updated_at": "2024-01-02T00:00:00",
},
)
rag_stub = _StubRagClient()
async with httpx.AsyncClient() as client:
rag_ids = await ingest_snippet_rag_from_db(
table_id=table_id,
version_ts=version_ts,
workspace_id=99,
rag_item_type="SNIPPET",
client=client,
engine=engine,
rag_client=rag_stub,
)
assert rag_stub.received is not None
assert len(rag_stub.received) == 2 # includes alias-only row
assert len(rag_ids) == 2
with engine.connect() as conn:
rows = list(
conn.execute(
text("SELECT snippet_id, action_result_id, rag_text, merged_json FROM rag_snippet ORDER BY snippet_id")
)
)
assert {row[0] for row in rows} == {"snpt_extra", "snpt_topn"}
assert all(row[1] is not None for row in rows)
topn_row = next(row for row in rows if row[0] == "snpt_topn")
assert "TopN" in topn_row[2]
assert "按站点水表TopN" in topn_row[2]
assert "排行" in topn_row[2]


@@ -0,0 +1,213 @@
from __future__ import annotations
import json
import os
import random
from datetime import datetime, timedelta
from typing import List
from pathlib import Path
import sys
import pytest
from sqlalchemy import text
from sqlalchemy.engine import Engine
from sqlalchemy.exc import SQLAlchemyError
# Ensure the project root is importable when running directly via python.
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
sys.path.insert(0, str(ROOT))
from app import db
from app.main import create_app
from app.services.table_snippet import merge_snippet_records_from_db
DEFAULT_MYSQL_URL = "mysql+pymysql://root:12345678@127.0.0.1:3306/data-ge?charset=utf8mb4"
@pytest.fixture()
def mysql_engine() -> Engine:
mysql_url = os.getenv("TEST_DATABASE_URL", DEFAULT_MYSQL_URL)
os.environ["DATABASE_URL"] = mysql_url
db.get_engine.cache_clear()
engine = db.get_engine()
try:
with engine.connect() as conn:
conn.execute(text("SELECT 1"))
exists = conn.execute(text("SHOW TABLES LIKE 'action_results'")).scalar()
if not exists:
pytest.skip("action_results table not found in test database.")
except SQLAlchemyError:
pytest.skip(f"Cannot connect to MySQL at {mysql_url}")
return engine
def _insert_action_row(
engine: Engine,
*,
table_id: int,
version_ts: int,
action_type: str,
status: str = "success",
snippet_json: List[dict] | None = None,
snippet_alias_json: List[dict] | None = None,
updated_at: datetime | None = None,
) -> None:
snippet_json_str = json.dumps(snippet_json, ensure_ascii=False) if snippet_json is not None else None
snippet_alias_json_str = (
json.dumps(snippet_alias_json, ensure_ascii=False) if snippet_alias_json is not None else None
)
with engine.begin() as conn:
conn.execute(
text(
"""
INSERT INTO action_results (
table_id, version_ts, action_type, status,
callback_url, table_schema_version_id, table_schema,
snippet_json, snippet_alias_json, updated_at
) VALUES (
:table_id, :version_ts, :action_type, :status,
:callback_url, :table_schema_version_id, :table_schema,
:snippet_json, :snippet_alias_json, :updated_at
)
ON DUPLICATE KEY UPDATE
status=VALUES(status),
snippet_json=VALUES(snippet_json),
snippet_alias_json=VALUES(snippet_alias_json),
updated_at=VALUES(updated_at)
"""
),
{
"table_id": table_id,
"version_ts": version_ts,
"action_type": action_type,
"status": status,
"callback_url": "http://localhost/test-callback",
"table_schema_version_id": "1",
"table_schema": json.dumps({}, ensure_ascii=False),
"snippet_json": snippet_json_str,
"snippet_alias_json": snippet_alias_json_str,
"updated_at": updated_at or datetime.utcnow(),
},
)
def _cleanup(engine: Engine, table_id: int, version_ts: int) -> None:
with engine.begin() as conn:
conn.execute(
text("DELETE FROM action_results WHERE table_id=:table_id AND version_ts=:version_ts"),
{"table_id": table_id, "version_ts": version_ts},
)
def test_merge_prefers_alias_row_and_appends_alias_only_entries(mysql_engine: Engine) -> None:
table_id = 990000000 + random.randint(1, 9999)
version_ts = int(datetime.utcnow().strftime("%Y%m%d%H%M%S"))
alias_updated = datetime(2024, 1, 2, 0, 0, 0)
snippet_payload = [
{
"id": "snpt_topn",
"aliases": [{"text": "站点水表排行前N", "tone": "中性"}],
"keywords": ["TopN", "站点"],
}
]
alias_payload = [
{
"id": "snpt_topn",
"aliases": [
{"text": "站点水表排行前N", "tone": "中性"},
{"text": "按站点水表TopN", "tone": "专业"},
],
"keywords": ["TopN", "排行"],
"intent_tags": ["topn", "aggregate"],
},
{
"id": "snpt_extra",
"aliases": [{"text": "额外别名"}],
"keywords": ["extra"],
},
]
_insert_action_row(
mysql_engine,
table_id=table_id,
version_ts=version_ts,
action_type="snippet_alias",
snippet_json=snippet_payload,
snippet_alias_json=alias_payload,
updated_at=alias_updated,
)
try:
merged = merge_snippet_records_from_db(table_id, version_ts, engine=mysql_engine)
assert len(merged) == 2
topn = next(item for item in merged if item["id"] == "snpt_topn")
assert topn["source"] == "snippet"
assert topn["updated_at_from_action"] == alias_updated
assert {a["text"] for a in topn["aliases"]} == {"站点水表排行前N", "按站点水表TopN"}
assert set(topn["keywords"]) == {"TopN", "站点", "排行"}
assert set(topn["intent_tags"]) == {"topn", "aggregate"}
alias_only = next(item for item in merged if item["source"] == "alias_only")
assert alias_only["id"] == "snpt_extra"
assert alias_only["aliases"][0]["text"] == "额外别名"
finally:
_cleanup(mysql_engine, table_id, version_ts)
def test_merge_falls_back_to_snippet_row_when_alias_row_missing_snippet_json(mysql_engine: Engine) -> None:
table_id = 991000000 + random.randint(1, 9999)
version_ts = int((datetime.utcnow() + timedelta(seconds=1)).strftime("%Y%m%d%H%M%S"))
alias_updated = datetime(2024, 1, 3, 0, 0, 0)
alias_payload = [
{
"id": "snpt_quality",
"aliases": [{"text": "质量检查"}],
"keywords": ["quality"],
}
]
snippet_payload = [
{
"id": "snpt_quality",
"title": "质量检查",
"keywords": ["data-quality"],
"aliases": [{"text": "质量检查"}],
}
]
_insert_action_row(
mysql_engine,
table_id=table_id,
version_ts=version_ts,
action_type="snippet_alias",
snippet_json=None,
snippet_alias_json=alias_payload,
updated_at=alias_updated,
)
_insert_action_row(
mysql_engine,
table_id=table_id,
version_ts=version_ts,
action_type="snippet",
snippet_json=snippet_payload,
snippet_alias_json=None,
updated_at=datetime(2024, 1, 2, 0, 0, 0),
)
try:
merged = merge_snippet_records_from_db(table_id, version_ts, engine=mysql_engine)
assert len(merged) == 1
record = merged[0]
assert record["id"] == "snpt_quality"
assert record["source"] == "snippet"
assert record["updated_at_from_action"] == alias_updated
assert set(record["keywords"]) == {"data-quality", "quality"}
assert {a["text"] for a in record["aliases"]} == {"质量检查"}
finally:
_cleanup(mysql_engine, table_id, version_ts)