""" Download models for Huggingface Transformers

This is useful for when you want to pre-download models
on an inference server to save time (you can download
one model while using another for batch inference)

You can specify one or more models separated by commas.
Usage:
    pdm run scripts/download_models.py --models=bigcode/starcoderbase-1b,bigcode/starcoderbase
"""

import argparse
import huggingface_hub
from transformers import __version__ as TRANSFORMERS_VERSION
import logging

# Module-level logger named after this module, with its own threshold set to
# INFO so the progress messages logged below pass the logger-level filter.
# NOTE(review): no handler is attached here, so whether INFO records are
# actually emitted depends on logging being configured elsewhere — confirm.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


def main(args):
    """Download every file of each requested Hugging Face Hub repository.

    A failure on one model is logged and does not stop the remaining
    downloads. A model counts as downloaded only if all of its files were
    fetched. A summary of successes and failures is logged at the end.

    Args:
        args: Parsed command-line namespace. Reads ``args.models`` (a
            comma-separated string of repo ids) and ``args.cache_dir``
            (forwarded to ``hf_hub_download``; may be ``None``).
    """
    # Tolerate whitespace and stray commas ("a, b,") and drop repeated names
    # while keeping the order the user gave.
    stripped = (name.strip() for name in args.models.split(","))
    models = list(dict.fromkeys(name for name in stripped if name))
    downloaded = []
    for model in models:
        # Use a %s placeholder: passing the model as a bare extra argument
        # makes logging fail to format the record.
        logger.info("Downloading model: %s", model)
        try:
            for filename in huggingface_hub.list_repo_files(model):
                huggingface_hub.hf_hub_download(
                    model,
                    filename,
                    library_name="transformers",
                    library_version=TRANSFORMERS_VERSION,
                    cache_dir=args.cache_dir,
                )
        except Exception as e:
            # Deliberately broad: this is best-effort per model, so any error
            # is reported and the loop moves on to the next model.
            logger.error("Failed to download %s: %s", model, e)
        else:
            downloaded.append(model)
    logger.info("Downloaded models: %s", downloaded)
    # Preserve request order in the report (a set difference would not).
    failures = [model for model in models if model not in downloaded]
    if failures:
        logger.warning("Failed to download models: %s", failures)


if __name__ == "__main__":
    # Script entry point: configure logging, parse the command line, then
    # hand the parsed namespace to main().
    #
    # Without a configured handler, logging falls back to its last-resort
    # handler, which only emits WARNING and above, so the INFO progress and
    # summary messages would be silently dropped.
    logging.basicConfig(level=logging.INFO)
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--models",
        type=str,
        required=True,
        # Each name is passed to the Hub API as-is; no pattern matching is
        # performed, so the help text must not promise regex support.
        help="One or more Hugging Face Hub model repo ids to download, "
        "separated by commas (e.g. bigcode/starcoderbase-1b,bigcode/starcoderbase).",
    )
    parser.add_argument(
        "--cache-dir", type=str, default=None, help="Cache dir for models"
    )
    args = parser.parse_args()
    main(args)
