file和demo
This commit is contained in:
10001
file/dataset/ecommerce_orders_clean.csv
Normal file
10001
file/dataset/ecommerce_orders_clean.csv
Normal file
File diff suppressed because it is too large
Load Diff
226
scripts/huggingface_download.py
Normal file
226
scripts/huggingface_download.py
Normal file
@ -0,0 +1,226 @@
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
from typing import Dict, Iterable, List, Optional
|
||||
|
||||
import datasets
|
||||
from datasets import DownloadConfig
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
# 批量下载 Hugging Face 上的数据集和模型
|
||||
# 支持通过命令行参数配置代理和下载参数,如超时和重试次数,支持批量循环下载,存储到file目录下dataset和model子目录
|
||||
|
||||
|
||||
def _parse_id_list(values: Iterable[str]) -> List[str]:
|
||||
"""将多次传入以及逗号分隔的标识整理为列表."""
|
||||
ids: List[str] = []
|
||||
for value in values:
|
||||
value = value.strip()
|
||||
if not value:
|
||||
continue
|
||||
if "," in value:
|
||||
ids.extend(v.strip() for v in value.split(",") if v.strip())
|
||||
else:
|
||||
ids.append(value)
|
||||
return ids
|
||||
|
||||
|
||||
def _parse_proxy_args(proxy_args: Iterable[str]) -> Dict[str, str]:
|
||||
"""解析命令行传入的代理设置,格式 scheme=url."""
|
||||
proxies: Dict[str, str] = {}
|
||||
for item in proxy_args:
|
||||
raw = item.strip()
|
||||
if not raw:
|
||||
continue
|
||||
if "=" not in raw:
|
||||
logging.warning("代理参数 %s 缺少 '=' 分隔符,将忽略该项", raw)
|
||||
continue
|
||||
key, value = raw.split("=", 1)
|
||||
key = key.strip()
|
||||
value = value.strip()
|
||||
if not key or not value:
|
||||
logging.warning("代理参数 %s 解析失败,将忽略该项", raw)
|
||||
continue
|
||||
proxies[key] = value
|
||||
return proxies
|
||||
|
||||
|
||||
def _sanitize_dir_name(name: str) -> str:
|
||||
return name.replace("/", "__")
|
||||
|
||||
|
||||
def _ensure_dirs(root_dir: str) -> Dict[str, str]:
|
||||
paths = {
|
||||
"dataset": os.path.join(root_dir, "dataset"),
|
||||
"model": os.path.join(root_dir, "model"),
|
||||
}
|
||||
for path in paths.values():
|
||||
os.makedirs(path, exist_ok=True)
|
||||
return paths
|
||||
|
||||
|
||||
def _build_download_config(cache_dir: str, retries: Optional[int], proxies: Dict[str, str]) -> DownloadConfig:
|
||||
config_kwargs = {"cache_dir": cache_dir}
|
||||
if retries is not None:
|
||||
config_kwargs["max_retries"] = retries
|
||||
if proxies:
|
||||
config_kwargs["proxies"] = proxies
|
||||
return DownloadConfig(**config_kwargs)
|
||||
|
||||
|
||||
def _apply_timeout(timeout: Optional[float]) -> None:
|
||||
if timeout is None:
|
||||
return
|
||||
str_timeout = str(timeout)
|
||||
os.environ.setdefault("HF_DATASETS_HTTP_TIMEOUT", str_timeout)
|
||||
os.environ.setdefault("HF_HUB_HTTP_TIMEOUT", str_timeout)
|
||||
|
||||
|
||||
def _resolve_log_level(level_name: str) -> int:
|
||||
if isinstance(level_name, int):
|
||||
return level_name
|
||||
upper_name = str(level_name).upper()
|
||||
return getattr(logging, upper_name, logging.INFO)
|
||||
|
||||
|
||||
def _build_argument_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="批量下载 Hugging Face 数据集和模型并存储到指定目录。"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-d",
|
||||
"--dataset",
|
||||
action="append",
|
||||
default=[],
|
||||
help="要下载的数据集 ID,可重复使用或传入逗号分隔列表。",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-m",
|
||||
"--model",
|
||||
action="append",
|
||||
default=[],
|
||||
help="要下载的模型 ID,可重复使用或传入逗号分隔列表。",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-r",
|
||||
"--root",
|
||||
default="file",
|
||||
help="存储根目录,默认 file。",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--retries",
|
||||
type=int,
|
||||
default=None,
|
||||
help="失败后的重试次数,默认不重试。",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--timeout",
|
||||
type=float,
|
||||
default=None,
|
||||
help="HTTP 超时时间(秒),默认跟随库设置。",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-p",
|
||||
"--proxy",
|
||||
action="append",
|
||||
default=[],
|
||||
help="代理设置,格式 scheme=url,可多次传入,例如 --proxy http=http://127.0.0.1:7890",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--log-level",
|
||||
default="INFO",
|
||||
help="日志级别,默认 INFO。",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def download_datasets(dataset_ids: Iterable[str], root_dir: str, retries: Optional[int], proxies: Dict[str, str]) -> None:
|
||||
if not dataset_ids:
|
||||
return
|
||||
cache_dir = root_dir
|
||||
download_config = _build_download_config(cache_dir, retries, proxies)
|
||||
for dataset_id in dataset_ids:
|
||||
try:
|
||||
logging.info("开始下载数据集 %s", dataset_id)
|
||||
# 使用 load_dataset 触发缓存下载
|
||||
dataset = datasets.load_dataset(
|
||||
dataset_id,
|
||||
cache_dir=cache_dir,
|
||||
download_config=download_config,
|
||||
download_mode="reuse_cache_if_exists",
|
||||
)
|
||||
target_path = os.path.join(root_dir, _sanitize_dir_name(dataset_id))
|
||||
dataset.save_to_disk(target_path)
|
||||
logging.info("数据集 %s 下载完成,存储于 %s", dataset_id, target_path)
|
||||
except Exception as exc: # pylint: disable=broad-except
|
||||
logging.error("下载数据集 %s 失败: %s", dataset_id, exc)
|
||||
|
||||
|
||||
def download_models(
|
||||
model_ids: Iterable[str],
|
||||
target_dir: str,
|
||||
retries: Optional[int],
|
||||
proxies: Dict[str, str],
|
||||
timeout: Optional[float],
|
||||
) -> None:
|
||||
if not model_ids:
|
||||
return
|
||||
max_attempts = (retries or 0) + 1
|
||||
hub_kwargs = {
|
||||
"local_dir": target_dir,
|
||||
"local_dir_use_symlinks": False,
|
||||
"max_workers": os.cpu_count() or 4,
|
||||
}
|
||||
if proxies:
|
||||
hub_kwargs["proxies"] = proxies
|
||||
if timeout is not None:
|
||||
hub_kwargs["timeout"] = timeout
|
||||
for model_id in model_ids:
|
||||
attempt = 0
|
||||
while attempt < max_attempts:
|
||||
attempt += 1
|
||||
try:
|
||||
logging.info("开始下载模型 %s (尝试 %s/%s)", model_id, attempt, max_attempts)
|
||||
snapshot_download(
|
||||
repo_id=model_id,
|
||||
**hub_kwargs,
|
||||
)
|
||||
logging.info("模型 %s 下载完成,存储于 %s", model_id, target_dir)
|
||||
break
|
||||
except Exception as exc: # pylint: disable=broad-except
|
||||
logging.error("下载模型 %s 失败: %s", model_id, exc)
|
||||
if attempt >= max_attempts:
|
||||
logging.error("模型 %s 在重试后仍未成功下载", model_id)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = _build_argument_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
logging.basicConfig(
|
||||
level=_resolve_log_level(args.log_level),
|
||||
format="%(asctime)s - %(levelname)s - %(message)s",
|
||||
)
|
||||
|
||||
dataset_ids = _parse_id_list(args.dataset)
|
||||
model_ids = _parse_id_list(args.model)
|
||||
retries = args.retries
|
||||
timeout = args.timeout
|
||||
proxies = _parse_proxy_args(args.proxy)
|
||||
_apply_timeout(timeout)
|
||||
|
||||
if not dataset_ids and not model_ids:
|
||||
logging.warning(
|
||||
"未配置任何数据集或模型,"
|
||||
"请通过参数 --dataset / --model 指定 Hugging Face ID"
|
||||
)
|
||||
return
|
||||
|
||||
dirs = _ensure_dirs(args.root)
|
||||
|
||||
download_datasets(dataset_ids, dirs["dataset"], retries, proxies)
|
||||
download_models(model_ids, dirs["model"], retries, proxies, timeout)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
80
scripts/table_snippet_demo.py
Normal file
80
scripts/table_snippet_demo.py
Normal file
@ -0,0 +1,80 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
def build_demo_payload() -> Dict[str, Any]:
|
||||
now = datetime.utcnow()
|
||||
started_at = now.replace(microsecond=0).isoformat() + "Z"
|
||||
finished_at = now.replace(microsecond=0).isoformat() + "Z"
|
||||
return {
|
||||
"table_id": 42,
|
||||
"version_ts": 20251101200000,
|
||||
"action_type": "snippet",
|
||||
"status": "success",
|
||||
"callback_url": "http://localhost:9999/dummy-callback",
|
||||
"table_schema_version_id": 7,
|
||||
"table_schema": {
|
||||
"columns": [
|
||||
{"name": "order_id", "type": "bigint"},
|
||||
{"name": "order_dt", "type": "date"},
|
||||
{"name": "gmv", "type": "decimal(18,2)"},
|
||||
]
|
||||
},
|
||||
"result_json": [
|
||||
{
|
||||
"id": "snpt_daily_gmv",
|
||||
"title": "按日GMV",
|
||||
"desc": "统计每日GMV总额",
|
||||
"type": "trend",
|
||||
"dialect_sql": {
|
||||
"mysql": "SELECT order_dt, SUM(gmv) AS total_gmv FROM orders GROUP BY order_dt ORDER BY order_dt"
|
||||
},
|
||||
}
|
||||
],
|
||||
"result_summary_json": {"total_snippets": 1},
|
||||
"html_report_url": None,
|
||||
"error_code": None,
|
||||
"error_message": None,
|
||||
"started_at": started_at,
|
||||
"finished_at": finished_at,
|
||||
"duration_ms": 1234,
|
||||
"result_checksum": "demo-checksum",
|
||||
}
|
||||
|
||||
|
||||
def main() -> int:
|
||||
base_url = os.getenv("TABLE_SNIPPET_DEMO_BASE_URL", "http://localhost:8000")
|
||||
endpoint = f"{base_url.rstrip('/')}/v1/table/snippet"
|
||||
payload = build_demo_payload()
|
||||
|
||||
print(f"POST {endpoint}")
|
||||
print(json.dumps(payload, ensure_ascii=False, indent=2))
|
||||
|
||||
try:
|
||||
response = requests.post(endpoint, json=payload, timeout=30)
|
||||
except requests.RequestException as exc:
|
||||
print(f"Request failed: {exc}", file=sys.stderr)
|
||||
return 1
|
||||
|
||||
print(f"\nStatus: {response.status_code}")
|
||||
|
||||
try:
|
||||
data = response.json()
|
||||
print("Response JSON:")
|
||||
print(json.dumps(data, ensure_ascii=False, indent=2))
|
||||
except ValueError:
|
||||
print("Response Text:")
|
||||
print(response.text)
|
||||
|
||||
return 0 if response.ok else 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
Reference in New Issue
Block a user