file和demo
This commit is contained in:
10001
file/dataset/ecommerce_orders_clean.csv
Normal file
10001
file/dataset/ecommerce_orders_clean.csv
Normal file
File diff suppressed because it is too large
Load Diff
226
scripts/huggingface_download.py
Normal file
226
scripts/huggingface_download.py
Normal file
@ -0,0 +1,226 @@
|
|||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from typing import Dict, Iterable, List, Optional
|
||||||
|
|
||||||
|
import datasets
|
||||||
|
from datasets import DownloadConfig
|
||||||
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
|
# 批量下载 Hugging Face 上的数据集和模型
|
||||||
|
# 支持通过命令行参数配置代理和下载参数,如超时和重试次数,支持批量循环下载,存储到file目录下dataset和model子目录
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_id_list(values: Iterable[str]) -> List[str]:
|
||||||
|
"""将多次传入以及逗号分隔的标识整理为列表."""
|
||||||
|
ids: List[str] = []
|
||||||
|
for value in values:
|
||||||
|
value = value.strip()
|
||||||
|
if not value:
|
||||||
|
continue
|
||||||
|
if "," in value:
|
||||||
|
ids.extend(v.strip() for v in value.split(",") if v.strip())
|
||||||
|
else:
|
||||||
|
ids.append(value)
|
||||||
|
return ids
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_proxy_args(proxy_args: Iterable[str]) -> Dict[str, str]:
|
||||||
|
"""解析命令行传入的代理设置,格式 scheme=url."""
|
||||||
|
proxies: Dict[str, str] = {}
|
||||||
|
for item in proxy_args:
|
||||||
|
raw = item.strip()
|
||||||
|
if not raw:
|
||||||
|
continue
|
||||||
|
if "=" not in raw:
|
||||||
|
logging.warning("代理参数 %s 缺少 '=' 分隔符,将忽略该项", raw)
|
||||||
|
continue
|
||||||
|
key, value = raw.split("=", 1)
|
||||||
|
key = key.strip()
|
||||||
|
value = value.strip()
|
||||||
|
if not key or not value:
|
||||||
|
logging.warning("代理参数 %s 解析失败,将忽略该项", raw)
|
||||||
|
continue
|
||||||
|
proxies[key] = value
|
||||||
|
return proxies
|
||||||
|
|
||||||
|
|
||||||
|
def _sanitize_dir_name(name: str) -> str:
|
||||||
|
return name.replace("/", "__")
|
||||||
|
|
||||||
|
|
||||||
|
def _ensure_dirs(root_dir: str) -> Dict[str, str]:
|
||||||
|
paths = {
|
||||||
|
"dataset": os.path.join(root_dir, "dataset"),
|
||||||
|
"model": os.path.join(root_dir, "model"),
|
||||||
|
}
|
||||||
|
for path in paths.values():
|
||||||
|
os.makedirs(path, exist_ok=True)
|
||||||
|
return paths
|
||||||
|
|
||||||
|
|
||||||
|
def _build_download_config(cache_dir: str, retries: Optional[int], proxies: Dict[str, str]) -> DownloadConfig:
|
||||||
|
config_kwargs = {"cache_dir": cache_dir}
|
||||||
|
if retries is not None:
|
||||||
|
config_kwargs["max_retries"] = retries
|
||||||
|
if proxies:
|
||||||
|
config_kwargs["proxies"] = proxies
|
||||||
|
return DownloadConfig(**config_kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def _apply_timeout(timeout: Optional[float]) -> None:
|
||||||
|
if timeout is None:
|
||||||
|
return
|
||||||
|
str_timeout = str(timeout)
|
||||||
|
os.environ.setdefault("HF_DATASETS_HTTP_TIMEOUT", str_timeout)
|
||||||
|
os.environ.setdefault("HF_HUB_HTTP_TIMEOUT", str_timeout)
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_log_level(level_name: str) -> int:
|
||||||
|
if isinstance(level_name, int):
|
||||||
|
return level_name
|
||||||
|
upper_name = str(level_name).upper()
|
||||||
|
return getattr(logging, upper_name, logging.INFO)
|
||||||
|
|
||||||
|
|
||||||
|
def _build_argument_parser() -> argparse.ArgumentParser:
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="批量下载 Hugging Face 数据集和模型并存储到指定目录。"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"-d",
|
||||||
|
"--dataset",
|
||||||
|
action="append",
|
||||||
|
default=[],
|
||||||
|
help="要下载的数据集 ID,可重复使用或传入逗号分隔列表。",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"-m",
|
||||||
|
"--model",
|
||||||
|
action="append",
|
||||||
|
default=[],
|
||||||
|
help="要下载的模型 ID,可重复使用或传入逗号分隔列表。",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"-r",
|
||||||
|
"--root",
|
||||||
|
default="file",
|
||||||
|
help="存储根目录,默认 file。",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--retries",
|
||||||
|
type=int,
|
||||||
|
default=None,
|
||||||
|
help="失败后的重试次数,默认不重试。",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--timeout",
|
||||||
|
type=float,
|
||||||
|
default=None,
|
||||||
|
help="HTTP 超时时间(秒),默认跟随库设置。",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"-p",
|
||||||
|
"--proxy",
|
||||||
|
action="append",
|
||||||
|
default=[],
|
||||||
|
help="代理设置,格式 scheme=url,可多次传入,例如 --proxy http=http://127.0.0.1:7890",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--log-level",
|
||||||
|
default="INFO",
|
||||||
|
help="日志级别,默认 INFO。",
|
||||||
|
)
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def download_datasets(dataset_ids: Iterable[str], root_dir: str, retries: Optional[int], proxies: Dict[str, str]) -> None:
|
||||||
|
if not dataset_ids:
|
||||||
|
return
|
||||||
|
cache_dir = root_dir
|
||||||
|
download_config = _build_download_config(cache_dir, retries, proxies)
|
||||||
|
for dataset_id in dataset_ids:
|
||||||
|
try:
|
||||||
|
logging.info("开始下载数据集 %s", dataset_id)
|
||||||
|
# 使用 load_dataset 触发缓存下载
|
||||||
|
dataset = datasets.load_dataset(
|
||||||
|
dataset_id,
|
||||||
|
cache_dir=cache_dir,
|
||||||
|
download_config=download_config,
|
||||||
|
download_mode="reuse_cache_if_exists",
|
||||||
|
)
|
||||||
|
target_path = os.path.join(root_dir, _sanitize_dir_name(dataset_id))
|
||||||
|
dataset.save_to_disk(target_path)
|
||||||
|
logging.info("数据集 %s 下载完成,存储于 %s", dataset_id, target_path)
|
||||||
|
except Exception as exc: # pylint: disable=broad-except
|
||||||
|
logging.error("下载数据集 %s 失败: %s", dataset_id, exc)
|
||||||
|
|
||||||
|
|
||||||
|
def download_models(
|
||||||
|
model_ids: Iterable[str],
|
||||||
|
target_dir: str,
|
||||||
|
retries: Optional[int],
|
||||||
|
proxies: Dict[str, str],
|
||||||
|
timeout: Optional[float],
|
||||||
|
) -> None:
|
||||||
|
if not model_ids:
|
||||||
|
return
|
||||||
|
max_attempts = (retries or 0) + 1
|
||||||
|
hub_kwargs = {
|
||||||
|
"local_dir": target_dir,
|
||||||
|
"local_dir_use_symlinks": False,
|
||||||
|
"max_workers": os.cpu_count() or 4,
|
||||||
|
}
|
||||||
|
if proxies:
|
||||||
|
hub_kwargs["proxies"] = proxies
|
||||||
|
if timeout is not None:
|
||||||
|
hub_kwargs["timeout"] = timeout
|
||||||
|
for model_id in model_ids:
|
||||||
|
attempt = 0
|
||||||
|
while attempt < max_attempts:
|
||||||
|
attempt += 1
|
||||||
|
try:
|
||||||
|
logging.info("开始下载模型 %s (尝试 %s/%s)", model_id, attempt, max_attempts)
|
||||||
|
snapshot_download(
|
||||||
|
repo_id=model_id,
|
||||||
|
**hub_kwargs,
|
||||||
|
)
|
||||||
|
logging.info("模型 %s 下载完成,存储于 %s", model_id, target_dir)
|
||||||
|
break
|
||||||
|
except Exception as exc: # pylint: disable=broad-except
|
||||||
|
logging.error("下载模型 %s 失败: %s", model_id, exc)
|
||||||
|
if attempt >= max_attempts:
|
||||||
|
logging.error("模型 %s 在重试后仍未成功下载", model_id)
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
parser = _build_argument_parser()
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
logging.basicConfig(
|
||||||
|
level=_resolve_log_level(args.log_level),
|
||||||
|
format="%(asctime)s - %(levelname)s - %(message)s",
|
||||||
|
)
|
||||||
|
|
||||||
|
dataset_ids = _parse_id_list(args.dataset)
|
||||||
|
model_ids = _parse_id_list(args.model)
|
||||||
|
retries = args.retries
|
||||||
|
timeout = args.timeout
|
||||||
|
proxies = _parse_proxy_args(args.proxy)
|
||||||
|
_apply_timeout(timeout)
|
||||||
|
|
||||||
|
if not dataset_ids and not model_ids:
|
||||||
|
logging.warning(
|
||||||
|
"未配置任何数据集或模型,"
|
||||||
|
"请通过参数 --dataset / --model 指定 Hugging Face ID"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
dirs = _ensure_dirs(args.root)
|
||||||
|
|
||||||
|
download_datasets(dataset_ids, dirs["dataset"], retries, proxies)
|
||||||
|
download_models(model_ids, dirs["model"], retries, proxies, timeout)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
80
scripts/table_snippet_demo.py
Normal file
80
scripts/table_snippet_demo.py
Normal file
@ -0,0 +1,80 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
|
||||||
|
def build_demo_payload() -> Dict[str, Any]:
|
||||||
|
now = datetime.utcnow()
|
||||||
|
started_at = now.replace(microsecond=0).isoformat() + "Z"
|
||||||
|
finished_at = now.replace(microsecond=0).isoformat() + "Z"
|
||||||
|
return {
|
||||||
|
"table_id": 42,
|
||||||
|
"version_ts": 20251101200000,
|
||||||
|
"action_type": "snippet",
|
||||||
|
"status": "success",
|
||||||
|
"callback_url": "http://localhost:9999/dummy-callback",
|
||||||
|
"table_schema_version_id": 7,
|
||||||
|
"table_schema": {
|
||||||
|
"columns": [
|
||||||
|
{"name": "order_id", "type": "bigint"},
|
||||||
|
{"name": "order_dt", "type": "date"},
|
||||||
|
{"name": "gmv", "type": "decimal(18,2)"},
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"result_json": [
|
||||||
|
{
|
||||||
|
"id": "snpt_daily_gmv",
|
||||||
|
"title": "按日GMV",
|
||||||
|
"desc": "统计每日GMV总额",
|
||||||
|
"type": "trend",
|
||||||
|
"dialect_sql": {
|
||||||
|
"mysql": "SELECT order_dt, SUM(gmv) AS total_gmv FROM orders GROUP BY order_dt ORDER BY order_dt"
|
||||||
|
},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"result_summary_json": {"total_snippets": 1},
|
||||||
|
"html_report_url": None,
|
||||||
|
"error_code": None,
|
||||||
|
"error_message": None,
|
||||||
|
"started_at": started_at,
|
||||||
|
"finished_at": finished_at,
|
||||||
|
"duration_ms": 1234,
|
||||||
|
"result_checksum": "demo-checksum",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> int:
|
||||||
|
base_url = os.getenv("TABLE_SNIPPET_DEMO_BASE_URL", "http://localhost:8000")
|
||||||
|
endpoint = f"{base_url.rstrip('/')}/v1/table/snippet"
|
||||||
|
payload = build_demo_payload()
|
||||||
|
|
||||||
|
print(f"POST {endpoint}")
|
||||||
|
print(json.dumps(payload, ensure_ascii=False, indent=2))
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = requests.post(endpoint, json=payload, timeout=30)
|
||||||
|
except requests.RequestException as exc:
|
||||||
|
print(f"Request failed: {exc}", file=sys.stderr)
|
||||||
|
return 1
|
||||||
|
|
||||||
|
print(f"\nStatus: {response.status_code}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
data = response.json()
|
||||||
|
print("Response JSON:")
|
||||||
|
print(json.dumps(data, ensure_ascii=False, indent=2))
|
||||||
|
except ValueError:
|
||||||
|
print("Response Text:")
|
||||||
|
print(response.text)
|
||||||
|
|
||||||
|
return 0 if response.ok else 1
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
raise SystemExit(main())
|
||||||
Reference in New Issue
Block a user