import argparse
import os

import gdown

# Google Drive *folder IDs* (not full URLs, despite the name), looked up as
# URLS[kind][model]. The script's entry point builds the Drive folder URL
# from the selected ID before handing it to gdown.
URLS = dict(
    # Single-task fine-tuned model checkpoints, one folder per model.
    checkpoints={
        "ViT-B-32": "16xmoeMqNf1JlICqjMUEgHiH6mHn0G1f0",
        "ViT-B-16": "1dmLqIwCYPqO0qsNEuflPhC5-_QodQHDv",
        "ViT-L-14": "1QFBcz79RqXUAkEVlYOItt6VJIcf1Ppq6",
    },
    # Per-task binary masks generated by the 'TALL-masks' method, one folder per model.
    tall_masks={
        "ViT-B-32": "1jpqsurrAdD5bn9i7pRBsTq4E3ayPb0YK",
        "ViT-B-16": "1jYNsdeFz6vlwIl5s4T48zeTm7FdwTSL9",
        "ViT-L-14": "16GVDMhpScmM3zeyfubjmNRFLiFKkBE7p",
    },
)

# Command-line interface: which model to fetch (--model) and which artifact
# kind (--kind). The choices mirror the keys of URLS; keep them in sync.
parser = argparse.ArgumentParser(
    # ArgumentParser's first positional parameter is ``prog`` (the program name
    # printed in usage lines), so the summary must be passed as ``description``.
    description="Download checkpoints for Vision Transformer models",
)
parser.add_argument(
    "--model",
    type=str,
    required=True,
    help="Model type to download",
    choices=["ViT-B-32", "ViT-B-16", "ViT-L-14"],
)

parser.add_argument(
    "--kind",
    type=str,
    required=True,
    help=(
        "Kind of download: checkpoints refer to single-task fine-tuned models and tall_masks for the per-task binary"
        " masks generated by the 'TALL-masks' method."
    ),
    choices=["checkpoints", "tall_masks"],
)


def drive_folder_url(folder_id):
    """Return the public Google Drive URL for the folder with ID ``folder_id``.

    Uses the account-independent ``/drive/folders/<id>`` form. A ``/drive/u/<n>/``
    prefix selects a signed-in browser account and is meaningless for an
    anonymous client such as gdown.
    """
    return "https://drive.google.com/drive/folders/" + folder_id


if __name__ == "__main__":
    args = parser.parse_args()
    folder_id = URLS[args.kind][args.model]
    folder_url = drive_folder_url(folder_id)
    # Files land in ./<kind>/<model>/ relative to the current working directory.
    output_dir = os.path.join(args.kind, args.model)
    downloaded = gdown.download_folder(folder_url, output=output_dir, quiet=False)
    # gdown documents a None return for a failed folder download; without this
    # check the script would report success (exit status 0) after a failure.
    if downloaded is None:
        parser.exit(1, f"error: failed to download {args.kind} for {args.model} from {folder_url}\n")
