227 lines
7.0 KiB
Python
227 lines
7.0 KiB
Python
import argparse
|
||
import logging
|
||
import os
|
||
from typing import Dict, Iterable, List, Optional
|
||
|
||
import datasets
|
||
from datasets import DownloadConfig
|
||
from huggingface_hub import snapshot_download
|
||
|
||
# 批量下载 Hugging Face 上的数据集和模型
|
||
# 支持通过命令行参数配置代理和下载参数,如超时和重试次数,支持批量循环下载,存储到file目录下dataset和model子目录
|
||
|
||
|
||
def _parse_id_list(values: Iterable[str]) -> List[str]:
|
||
"""将多次传入以及逗号分隔的标识整理为列表."""
|
||
ids: List[str] = []
|
||
for value in values:
|
||
value = value.strip()
|
||
if not value:
|
||
continue
|
||
if "," in value:
|
||
ids.extend(v.strip() for v in value.split(",") if v.strip())
|
||
else:
|
||
ids.append(value)
|
||
return ids
|
||
|
||
|
||
def _parse_proxy_args(proxy_args: Iterable[str]) -> Dict[str, str]:
|
||
"""解析命令行传入的代理设置,格式 scheme=url."""
|
||
proxies: Dict[str, str] = {}
|
||
for item in proxy_args:
|
||
raw = item.strip()
|
||
if not raw:
|
||
continue
|
||
if "=" not in raw:
|
||
logging.warning("代理参数 %s 缺少 '=' 分隔符,将忽略该项", raw)
|
||
continue
|
||
key, value = raw.split("=", 1)
|
||
key = key.strip()
|
||
value = value.strip()
|
||
if not key or not value:
|
||
logging.warning("代理参数 %s 解析失败,将忽略该项", raw)
|
||
continue
|
||
proxies[key] = value
|
||
return proxies
|
||
|
||
|
||
def _sanitize_dir_name(name: str) -> str:
|
||
return name.replace("/", "__")
|
||
|
||
|
||
def _ensure_dirs(root_dir: str) -> Dict[str, str]:
|
||
paths = {
|
||
"dataset": os.path.join(root_dir, "dataset"),
|
||
"model": os.path.join(root_dir, "model"),
|
||
}
|
||
for path in paths.values():
|
||
os.makedirs(path, exist_ok=True)
|
||
return paths
|
||
|
||
|
||
def _build_download_config(cache_dir: str, retries: Optional[int], proxies: Dict[str, str]) -> DownloadConfig:
    """Create the ``DownloadConfig`` shared by all dataset downloads.

    ``cache_dir`` is always set. ``max_retries`` and ``proxies`` are forwarded
    only when the user supplied them, so unset options keep the library's own
    defaults.
    """
    optional_kwargs = {}
    if retries is not None:
        optional_kwargs["max_retries"] = retries
    if proxies:
        optional_kwargs["proxies"] = proxies
    return DownloadConfig(cache_dir=cache_dir, **optional_kwargs)
|
||
|
||
|
||
def _apply_timeout(timeout: Optional[float]) -> None:
|
||
if timeout is None:
|
||
return
|
||
str_timeout = str(timeout)
|
||
os.environ.setdefault("HF_DATASETS_HTTP_TIMEOUT", str_timeout)
|
||
os.environ.setdefault("HF_HUB_HTTP_TIMEOUT", str_timeout)
|
||
|
||
|
||
def _resolve_log_level(level_name: str) -> int:
|
||
if isinstance(level_name, int):
|
||
return level_name
|
||
upper_name = str(level_name).upper()
|
||
return getattr(logging, upper_name, logging.INFO)
|
||
|
||
|
||
def _build_argument_parser() -> argparse.ArgumentParser:
|
||
parser = argparse.ArgumentParser(
|
||
description="批量下载 Hugging Face 数据集和模型并存储到指定目录。"
|
||
)
|
||
parser.add_argument(
|
||
"-d",
|
||
"--dataset",
|
||
action="append",
|
||
default=[],
|
||
help="要下载的数据集 ID,可重复使用或传入逗号分隔列表。",
|
||
)
|
||
parser.add_argument(
|
||
"-m",
|
||
"--model",
|
||
action="append",
|
||
default=[],
|
||
help="要下载的模型 ID,可重复使用或传入逗号分隔列表。",
|
||
)
|
||
parser.add_argument(
|
||
"-r",
|
||
"--root",
|
||
default="file",
|
||
help="存储根目录,默认 file。",
|
||
)
|
||
parser.add_argument(
|
||
"--retries",
|
||
type=int,
|
||
default=None,
|
||
help="失败后的重试次数,默认不重试。",
|
||
)
|
||
parser.add_argument(
|
||
"--timeout",
|
||
type=float,
|
||
default=None,
|
||
help="HTTP 超时时间(秒),默认跟随库设置。",
|
||
)
|
||
parser.add_argument(
|
||
"-p",
|
||
"--proxy",
|
||
action="append",
|
||
default=[],
|
||
help="代理设置,格式 scheme=url,可多次传入,例如 --proxy http=http://127.0.0.1:7890",
|
||
)
|
||
parser.add_argument(
|
||
"--log-level",
|
||
default="INFO",
|
||
help="日志级别,默认 INFO。",
|
||
)
|
||
return parser
|
||
|
||
|
||
def download_datasets(dataset_ids: Iterable[str], root_dir: str, retries: Optional[int], proxies: Dict[str, str]) -> None:
    """Download each dataset and save a standalone copy under *root_dir*.

    *root_dir* is used both as the ``datasets`` cache directory and as the parent
    of the saved copies (``<root_dir>/<org>__<name>``). One shared download config
    carries the retry count and proxies. A failure is logged and the loop moves on
    to the next ID; nothing is raised.
    """
    if not dataset_ids:
        return

    shared_config = _build_download_config(root_dir, retries, proxies)

    for dataset_id in dataset_ids:
        try:
            logging.info("开始下载数据集 %s", dataset_id)
            saved_at = _download_one_dataset(dataset_id, root_dir, shared_config)
            logging.info("数据集 %s 下载完成,存储于 %s", dataset_id, saved_at)
        except Exception as exc:  # pylint: disable=broad-except
            logging.error("下载数据集 %s 失败: %s", dataset_id, exc)


def _download_one_dataset(dataset_id: str, root_dir: str, download_config: "DownloadConfig") -> str:
    """Fetch one dataset via the cache in *root_dir*, save it to disk, and return the saved path."""
    # load_dataset populates (or reuses) the cache; save_to_disk then writes the copy.
    loaded = datasets.load_dataset(
        dataset_id,
        cache_dir=root_dir,
        download_config=download_config,
        download_mode="reuse_cache_if_exists",
    )
    destination = os.path.join(root_dir, _sanitize_dir_name(dataset_id))
    loaded.save_to_disk(destination)
    return destination
|
||
|
||
|
||
def download_models(
    model_ids: Iterable[str],
    target_dir: str,
    retries: Optional[int],
    proxies: Dict[str, str],
    timeout: Optional[float],
) -> None:
    """Download each model repository with ``snapshot_download``.

    Every repo goes into its own folder ``<target_dir>/<org>__<name>``, using the
    same naming scheme as datasets. This keeps files such as ``config.json`` from
    different models from overwriting each other.

    Args:
        model_ids: Hugging Face model repo IDs (e.g. ``org/name``).
        target_dir: parent directory for all model folders.
        retries: extra attempts after a failed one. ``None`` or a negative value
            means a single attempt.
        proxies: ``{scheme: url}`` mapping forwarded to ``snapshot_download``.
        timeout: seconds, forwarded as ``etag_timeout`` (the timeout of the
            metadata request). ``None`` keeps the library default.

    Failures are logged, never raised, so one bad ID does not stop the batch.
    """
    if not model_ids:
        return

    # Clamp to >= 1 attempt: a negative --retries used to give zero attempts,
    # which silently skipped every model.
    max_attempts = max(retries or 0, 0) + 1

    hub_kwargs = {
        # Kept for older huggingface_hub versions; newer ones ignore it.
        "local_dir_use_symlinks": False,
        "max_workers": os.cpu_count() or 4,
    }
    if proxies:
        hub_kwargs["proxies"] = proxies
    if timeout is not None:
        # snapshot_download's timeout parameter is `etag_timeout`. It has no
        # `timeout` keyword, so passing one made every call fail with TypeError.
        hub_kwargs["etag_timeout"] = timeout

    for model_id in model_ids:
        local_dir = os.path.join(target_dir, _sanitize_dir_name(model_id))
        for attempt in range(1, max_attempts + 1):
            try:
                logging.info("开始下载模型 %s (尝试 %s/%s)", model_id, attempt, max_attempts)
                snapshot_download(
                    repo_id=model_id,
                    local_dir=local_dir,
                    **hub_kwargs,
                )
            except Exception as exc:  # pylint: disable=broad-except
                logging.error("下载模型 %s 失败: %s", model_id, exc)
            else:
                logging.info("模型 %s 下载完成,存储于 %s", model_id, local_dir)
                break
        else:
            # Every attempt failed.
            logging.error("模型 %s 在重试后仍未成功下载", model_id)
|
||
|
||
|
||
def main() -> None:
    """CLI entry point: parse arguments, configure logging, then run the downloads.

    Datasets are downloaded before models. When no IDs are given, a warning is
    logged and nothing is created on disk.
    """
    cli = _build_argument_parser().parse_args()

    logging.basicConfig(
        level=_resolve_log_level(cli.log_level),
        format="%(asctime)s - %(levelname)s - %(message)s",
    )

    wanted_datasets = _parse_id_list(cli.dataset)
    wanted_models = _parse_id_list(cli.model)
    proxy_map = _parse_proxy_args(cli.proxy)
    _apply_timeout(cli.timeout)

    if wanted_datasets or wanted_models:
        layout = _ensure_dirs(cli.root)
        download_datasets(wanted_datasets, layout["dataset"], cli.retries, proxy_map)
        download_models(wanted_models, layout["model"], cli.retries, proxy_map, cli.timeout)
        return

    logging.warning(
        "未配置任何数据集或模型,"
        "请通过参数 --dataset / --model 指定 Hugging Face ID"
    )
|
||
|
||
|
||
# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|