import argparse
import logging
import os
from typing import Dict, Iterable, List, Optional

import datasets
from datasets import DownloadConfig
from huggingface_hub import snapshot_download

# Batch-download datasets and models from Hugging Face.
# Proxy and download parameters such as timeouts and retry counts are configured via
# command-line arguments; repositories are downloaded in a loop and stored under the
# dataset/ and model/ subdirectories of the root directory (default: file).


def _parse_id_list(values: Iterable[str]) -> List[str]:
    """Normalize repeated and comma-separated IDs into a flat list."""
    ids: List[str] = []
    for value in values:
        value = value.strip()
        if not value:
            continue
        if "," in value:
            ids.extend(v.strip() for v in value.split(",") if v.strip())
        else:
            ids.append(value)
    return ids


def _parse_proxy_args(proxy_args: Iterable[str]) -> Dict[str, str]:
    """Parse proxy settings passed on the command line, in the form scheme=url."""
    proxies: Dict[str, str] = {}
    for item in proxy_args:
        raw = item.strip()
        if not raw:
            continue
        if "=" not in raw:
            logging.warning("Proxy argument %s is missing the '=' separator; ignoring it", raw)
            continue
        key, value = raw.split("=", 1)
        key = key.strip()
        value = value.strip()
        if not key or not value:
            logging.warning("Could not parse proxy argument %s; ignoring it", raw)
            continue
        proxies[key] = value
    return proxies


def _sanitize_dir_name(name: str) -> str:
    return name.replace("/", "__")


def _ensure_dirs(root_dir: str) -> Dict[str, str]:
    paths = {
        "dataset": os.path.join(root_dir, "dataset"),
        "model": os.path.join(root_dir, "model"),
    }
    for path in paths.values():
        os.makedirs(path, exist_ok=True)
    return paths


def _build_download_config(cache_dir: str, retries: Optional[int], proxies: Dict[str, str]) -> DownloadConfig:
    config_kwargs = {"cache_dir": cache_dir}
    if retries is not None:
        config_kwargs["max_retries"] = retries
    if proxies:
        config_kwargs["proxies"] = proxies
    return DownloadConfig(**config_kwargs)


def _apply_timeout(timeout: Optional[float]) -> None:
    # Propagate the timeout to the underlying libraries via environment variables.
    if timeout is None:
        return
    str_timeout = str(timeout)
    os.environ.setdefault("HF_DATASETS_HTTP_TIMEOUT", str_timeout)
    os.environ.setdefault("HF_HUB_HTTP_TIMEOUT", str_timeout)


def _resolve_log_level(level_name: str) -> int:
    if isinstance(level_name, int):
        return level_name
    upper_name = str(level_name).upper()
    return getattr(logging, upper_name, logging.INFO)


def _build_argument_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(
        description="Batch-download Hugging Face datasets and models and store them under the given directory."
    )
    parser.add_argument(
        "-d",
        "--dataset",
        action="append",
        default=[],
        help="Dataset ID to download; may be repeated or given as a comma-separated list.",
    )
    parser.add_argument(
        "-m",
        "--model",
        action="append",
        default=[],
        help="Model ID to download; may be repeated or given as a comma-separated list.",
    )
    parser.add_argument(
        "-r",
        "--root",
        default="file",
        help="Root storage directory (default: file).",
    )
    parser.add_argument(
        "--retries",
        type=int,
        default=None,
        help="Number of retries after a failure (default: no retries).",
    )
    parser.add_argument(
        "--timeout",
        type=float,
        default=None,
        help="HTTP timeout in seconds (default: library settings).",
    )
    parser.add_argument(
        "-p",
        "--proxy",
        action="append",
        default=[],
        help="Proxy setting in the form scheme=url; may be passed multiple times, e.g. --proxy http=http://127.0.0.1:7890",
    )
    parser.add_argument(
        "--log-level",
        default="INFO",
        help="Logging level (default: INFO).",
    )
    return parser


def download_datasets(dataset_ids: Iterable[str], root_dir: str, retries: Optional[int], proxies: Dict[str, str]) -> None:
    if not dataset_ids:
        return
    cache_dir = root_dir
    download_config = _build_download_config(cache_dir, retries, proxies)
    for dataset_id in dataset_ids:
        try:
            logging.info("Downloading dataset %s", dataset_id)
            # Use load_dataset to trigger the download into the cache.
            dataset = datasets.load_dataset(
                dataset_id,
                cache_dir=cache_dir,
                download_config=download_config,
                download_mode="reuse_cache_if_exists",
            )
            target_path = os.path.join(root_dir, _sanitize_dir_name(dataset_id))
            dataset.save_to_disk(target_path)
            logging.info("Dataset %s downloaded and saved to %s", dataset_id, target_path)
        except Exception as exc:  # pylint: disable=broad-except
            logging.error("Failed to download dataset %s: %s", dataset_id, exc)
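
# Illustrative programmatic use of download_datasets (the "squad" ID and the
# retries value are placeholders, not defaults of this script); the normal entry
# point is the CLI defined in main() below:
#
#   download_datasets(["squad"], root_dir="file/dataset", retries=2, proxies={})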
def download_models(
    model_ids: Iterable[str],
    target_dir: str,
    retries: Optional[int],
    proxies: Dict[str, str],
    timeout: Optional[float],
) -> None:
    if not model_ids:
        return
    max_attempts = (retries or 0) + 1
    hub_kwargs = {
        "local_dir_use_symlinks": False,
        "max_workers": os.cpu_count() or 4,
    }
    if proxies:
        hub_kwargs["proxies"] = proxies
    if timeout is not None:
        # snapshot_download has no generic `timeout` argument; etag_timeout is the
        # closest per-request knob it exposes.
        hub_kwargs["etag_timeout"] = timeout
    for model_id in model_ids:
        # Give each model its own subdirectory so snapshots do not overwrite each other.
        local_dir = os.path.join(target_dir, _sanitize_dir_name(model_id))
        attempt = 0
        while attempt < max_attempts:
            attempt += 1
            try:
                logging.info("Downloading model %s (attempt %s/%s)", model_id, attempt, max_attempts)
                snapshot_download(
                    repo_id=model_id,
                    local_dir=local_dir,
                    **hub_kwargs,
                )
                logging.info("Model %s downloaded and saved to %s", model_id, local_dir)
                break
            except Exception as exc:  # pylint: disable=broad-except
                logging.error("Failed to download model %s: %s", model_id, exc)
                if attempt >= max_attempts:
                    logging.error("Model %s could not be downloaded after all retries", model_id)


def main() -> None:
    parser = _build_argument_parser()
    args = parser.parse_args()
    logging.basicConfig(
        level=_resolve_log_level(args.log_level),
        format="%(asctime)s - %(levelname)s - %(message)s",
    )
    dataset_ids = _parse_id_list(args.dataset)
    model_ids = _parse_id_list(args.model)
    retries = args.retries
    timeout = args.timeout
    proxies = _parse_proxy_args(args.proxy)
    _apply_timeout(timeout)
    if not dataset_ids and not model_ids:
        logging.warning(
            "No datasets or models were configured; "
            "specify Hugging Face IDs via --dataset / --model"
        )
        return
    dirs = _ensure_dirs(args.root)
    download_datasets(dataset_ids, dirs["dataset"], retries, proxies)
    download_models(model_ids, dirs["model"], retries, proxies, timeout)


if __name__ == "__main__":
    main()
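
# Example invocation (illustrative; the script filename and the repository IDs
# below are placeholders):
#
#   python download_hf.py \
#       --dataset squad,imdb \
#       --model bert-base-uncased \
#       --retries 3 --timeout 60 \
#       --proxy http=http://127.0.0.1:7890 --proxy https=http://127.0.0.1:7890
#
# Datasets are saved under file/dataset/<id with "/" replaced by "__"> and models
# under file/model/<id with "/" replaced by "__">.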