data-ge/scripts/huggingface_download.py

import argparse
import logging
import os
from typing import Dict, Iterable, List, Optional
import datasets
from datasets import DownloadConfig
from huggingface_hub import snapshot_download

# Batch-download datasets and models from Hugging Face.
# Proxy settings and download behaviour (HTTP timeout, retry count) are configurable via
# command-line arguments; items are downloaded in a loop and stored under the root
# directory (default "file") in its "dataset" and "model" sub-directories.
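
# Example invocation (a sketch; the repository IDs below are placeholders, not taken
# from this script):
#   python huggingface_download.py \
#       --dataset imdb --model bert-base-uncased \
#       --retries 2 --timeout 60 \
#       --proxy http=http://127.0.0.1:7890 --proxy https=http://127.0.0.1:7890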


def _parse_id_list(values: Iterable[str]) -> List[str]:
    """Flatten IDs passed as repeated arguments and/or comma-separated strings into a list."""
    ids: List[str] = []
    for value in values:
        value = value.strip()
        if not value:
            continue
        if "," in value:
            ids.extend(v.strip() for v in value.split(",") if v.strip())
        else:
            ids.append(value)
    return ids
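
# For example (with placeholder IDs), _parse_id_list(["squad,glue", " mnli "]) returns
# ["squad", "glue", "mnli"]: comma-separated entries are split and surrounding
# whitespace is stripped.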


def _parse_proxy_args(proxy_args: Iterable[str]) -> Dict[str, str]:
    """Parse proxy settings from the command line, given in the form scheme=url."""
    proxies: Dict[str, str] = {}
    for item in proxy_args:
        raw = item.strip()
        if not raw:
            continue
        if "=" not in raw:
            logging.warning("Proxy argument %s is missing the '=' separator; ignoring it", raw)
            continue
        key, value = raw.split("=", 1)
        key = key.strip()
        value = value.strip()
        if not key or not value:
            logging.warning("Proxy argument %s could not be parsed; ignoring it", raw)
            continue
        proxies[key] = value
    return proxies
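
# For example, _parse_proxy_args(["http=http://127.0.0.1:7890", "https=http://127.0.0.1:7890"])
# returns {"http": "http://127.0.0.1:7890", "https": "http://127.0.0.1:7890"}; the address
# and port here are illustrative only.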


def _sanitize_dir_name(name: str) -> str:
    return name.replace("/", "__")
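
# For example, _sanitize_dir_name("org/name") returns "org__name", so each dataset is
# saved into a single flat directory under the root rather than a nested path.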


def _ensure_dirs(root_dir: str) -> Dict[str, str]:
    paths = {
        "dataset": os.path.join(root_dir, "dataset"),
        "model": os.path.join(root_dir, "model"),
    }
    for path in paths.values():
        os.makedirs(path, exist_ok=True)
    return paths


def _build_download_config(cache_dir: str, retries: Optional[int], proxies: Dict[str, str]) -> DownloadConfig:
    config_kwargs = {"cache_dir": cache_dir}
    if retries is not None:
        config_kwargs["max_retries"] = retries
    if proxies:
        config_kwargs["proxies"] = proxies
    return DownloadConfig(**config_kwargs)


def _apply_timeout(timeout: Optional[float]) -> None:
    if timeout is None:
        return
    str_timeout = str(timeout)
    os.environ.setdefault("HF_DATASETS_HTTP_TIMEOUT", str_timeout)
    os.environ.setdefault("HF_HUB_HTTP_TIMEOUT", str_timeout)


def _resolve_log_level(level_name: str) -> int:
    if isinstance(level_name, int):
        return level_name
    upper_name = str(level_name).upper()
    return getattr(logging, upper_name, logging.INFO)


def _build_argument_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(
        description="Batch-download Hugging Face datasets and models and store them in a target directory."
    )
    parser.add_argument(
        "-d",
        "--dataset",
        action="append",
        default=[],
        help="Dataset ID to download; may be repeated or given as a comma-separated list.",
    )
    parser.add_argument(
        "-m",
        "--model",
        action="append",
        default=[],
        help="Model ID to download; may be repeated or given as a comma-separated list.",
    )
    parser.add_argument(
        "-r",
        "--root",
        default="file",
        help="Root storage directory, defaults to 'file'.",
    )
    parser.add_argument(
        "--retries",
        type=int,
        default=None,
        help="Number of retries after a failure; no retries by default.",
    )
    parser.add_argument(
        "--timeout",
        type=float,
        default=None,
        help="HTTP timeout in seconds; defaults to the library settings.",
    )
    parser.add_argument(
        "-p",
        "--proxy",
        action="append",
        default=[],
        help="Proxy setting in the form scheme=url; may be repeated, e.g. --proxy http=http://127.0.0.1:7890",
    )
    parser.add_argument(
        "--log-level",
        default="INFO",
        help="Logging level, defaults to INFO.",
    )
    return parser


def download_datasets(dataset_ids: Iterable[str], root_dir: str, retries: Optional[int], proxies: Dict[str, str]) -> None:
    if not dataset_ids:
        return
    cache_dir = root_dir
    download_config = _build_download_config(cache_dir, retries, proxies)
    for dataset_id in dataset_ids:
        try:
            logging.info("Starting download of dataset %s", dataset_id)
            # Use load_dataset to trigger the cached download.
            dataset = datasets.load_dataset(
                dataset_id,
                cache_dir=cache_dir,
                download_config=download_config,
                download_mode="reuse_cache_if_exists",
            )
            target_path = os.path.join(root_dir, _sanitize_dir_name(dataset_id))
            dataset.save_to_disk(target_path)
            logging.info("Dataset %s downloaded; stored at %s", dataset_id, target_path)
        except Exception as exc:  # pylint: disable=broad-except
            logging.error("Failed to download dataset %s: %s", dataset_id, exc)


def download_models(
    model_ids: Iterable[str],
    target_dir: str,
    retries: Optional[int],
    proxies: Dict[str, str],
    timeout: Optional[float],
) -> None:
    if not model_ids:
        return
    max_attempts = (retries or 0) + 1
    hub_kwargs = {
        "local_dir": target_dir,
        "local_dir_use_symlinks": False,
        "max_workers": os.cpu_count() or 4,
    }
    if proxies:
        hub_kwargs["proxies"] = proxies
    if timeout is not None:
        # snapshot_download exposes etag_timeout rather than a generic timeout argument;
        # the overall HTTP timeout is applied via the environment in _apply_timeout().
        hub_kwargs["etag_timeout"] = timeout
    for model_id in model_ids:
        attempt = 0
        while attempt < max_attempts:
            attempt += 1
            try:
                logging.info("Starting download of model %s (attempt %s/%s)", model_id, attempt, max_attempts)
                snapshot_download(
                    repo_id=model_id,
                    **hub_kwargs,
                )
                logging.info("Model %s downloaded; stored at %s", model_id, target_dir)
                break
            except Exception as exc:  # pylint: disable=broad-except
                logging.error("Failed to download model %s: %s", model_id, exc)
                if attempt >= max_attempts:
                    logging.error("Model %s still failed to download after retries", model_id)


def main() -> None:
    parser = _build_argument_parser()
    args = parser.parse_args()
    logging.basicConfig(
        level=_resolve_log_level(args.log_level),
        format="%(asctime)s - %(levelname)s - %(message)s",
    )
    dataset_ids = _parse_id_list(args.dataset)
    model_ids = _parse_id_list(args.model)
    retries = args.retries
    timeout = args.timeout
    proxies = _parse_proxy_args(args.proxy)
    _apply_timeout(timeout)
    if not dataset_ids and not model_ids:
        logging.warning(
            "No datasets or models were specified; "
            "use --dataset / --model to pass Hugging Face IDs"
        )
        return
    dirs = _ensure_dirs(args.root)
    download_datasets(dataset_ids, dirs["dataset"], retries, proxies)
    download_models(model_ids, dirs["model"], retries, proxies, timeout)


if __name__ == "__main__":
    main()