import argparse
import os
import zipfile

import gdown

# Maps a benchmark name (as accepted by ``--datasets`` in benchmark mode) to
# the list of dataset keys that are fetched for it.  Every dataset key listed
# here is handed to ``download_dataset``, so it must also appear in
# ``dir_dict`` and ``download_id_dict``.
# Key order matters: ``--datasets all`` expands to ``list(benchmarks_dict.keys())``.
# NOTE(review): the first entry of most lists looks like the benchmark's own
# (in-distribution) dataset, but nothing in this file relies on that -- confirm.
benchmarks_dict = {
    "bimcv": [
        "bimcv",
        "ct",
        "xraybone",
        "actmed",
        "mnist",
        "cifar10",
        "texture",
        "tin",
    ],
    "mnist": [
        "mnist",
        "notmnist",
        "fashionmnist",
        "texture",
        "cifar10",
        "tin",
        "places365",
        "cinic10",
    ],
    "cifar-10": [
        "cifar10",
        "cifar100",
        "tin",
        "mnist",
        "svhn",
        "texture",
        "places365",
        "tin597",
    ],
    "cifar-100": [
        "cifar100",
        "cifar10",
        "tin",
        "svhn",
        "texture",
        "places365",
        "tin597",
    ],
    # "imagenet-200" and "imagenet-1k" currently list the same datasets.
    "imagenet-200": [
        "imagenet_1k",
        "ssb_hard",
        "ninco",
        "inaturalist",
        "texture",
        "openimage_o",
        "imagenet_v2",
        "imagenet_c",
        "imagenet_r",
    ],
    "imagenet-1k": [
        "imagenet_1k",
        "ssb_hard",
        "ninco",
        "inaturalist",
        "texture",
        "openimage_o",
        "imagenet_v2",
        "imagenet_c",
        "imagenet_r",
    ],
    # Datasets that do not belong to any of the benchmarks above.
    "misc": [
        "cifar10c",
        "fractals_and_fvis",
        "usps",
        "imagenet10",
        "hannover",
        # Disabled entries:
        # 'imagenet200_cae', 'imagenet200_edsr', 'imagenet200_stylized'
    ],
}

# Maps a storage sub-directory to the dataset keys that live inside it.
# ``download_dataset`` joins ``save_dir[0]``, the key and the dataset name to
# build the target directory, so the keys are used verbatim as path components
# (note the trailing slash).  A dataset that is not listed under any key is
# reported as invalid by ``download_dataset`` and skipped.
dir_dict = {
    "images_classic/": [
        "cifar100",
        "tin",
        "tin597",
        "svhn",
        "cinic10",
        "imagenet10",
        "mnist",
        "fashionmnist",
        "cifar10",
        "cifar100c",
        "places365",
        "cifar10c",
        "fractals_and_fvis",
        "usps",
        "texture",
        "notmnist",
    ],
    "images_largescale/": [
        "imagenet_1k",
        "species_sub",
        "ssb_hard",
        "ninco",
        "inaturalist",
        "places",
        "sun",
        "openimage_o",
        "imagenet_v2",
        "imagenet_c",
        "imagenet_r",
        # Disabled entries:
        # 'imagenet200_cae', 'imagenet200_edsr', 'imagenet200_stylized'
    ],
    "images_medical/": ["actmed", "bimcv", "ct", "hannover", "xraybone"],
}

# Maps a checkpoint / dataset / image-list key to the file ID that is passed
# to ``gdown.download(id=...)`` (presumably Google Drive file IDs -- the IDs
# are opaque here and must be kept byte-for-byte).
# The downloaded archive is expected to be named ``<key>.zip``.
# NOTE(review): "imagenet_o" has an ID but is not listed in ``dir_dict``, so
# ``download_dataset("imagenet_o", ...)`` currently reports it as invalid.
download_id_dict = {
    "osr": "1L9MpK9QZq-o-JrFHrfo5lM4-FsFPk0e9",
    "mnist_lenet": "13mEvYF9rVIuch8u0RVDLf_JMOk3PAYCj",
    "cifar10_res18": "1rPEScK7TFjBn_W_frO-8RSPwIG6_x0fJ",
    "cifar100_res18": "1OOf88A48yXFw4fSU02XQT-3OQKf31Csy",
    "imagenet_res50": "1tgY_PsfkazLDyI1pniDMDEehntBhFyF3",
    "cifar10_res18_v1.5": "1byGeYxM_PlLjT72wZsMQvP6popJeWBgt",
    "cifar100_res18_v1.5": "1s-1oNrRtmA0pGefxXJOUVRYpaoAML0C-",
    "imagenet200_res18_v1.5": "1ddVmwc8zmzSjdLUO84EuV4Gz1c7vhIAs",
    "imagenet_res50_v1.5": "15PdDMNRfnJ7f2oxW6lI-Ge4QJJH3Z0Fy",
    "benchmark_imglist": "1XKzBdWCqg3vPoj-D32YixJyJJ0hL63gP",
    "usps": "1KhbWhlFlpFjEIb4wpvW0s9jmXXsHonVl",
    "cifar100": "1PGKheHUsf29leJPPGuXqzLBMwl8qMF8_",
    "cifar10": "1Co32RiiWe16lTaiOU6JMMnyUYS41IlO1",
    "cifar10c": "170DU_ficWWmbh6O2wqELxK9jxRiGhlJH",
    "cinic10": "190gdcfbvSGbrRK6ZVlJgg5BqqED6H_nn",
    "svhn": "1DQfc11HOtB1nEwqS4pWUFp8vtQ3DczvI",
    "fashionmnist": "1nVObxjUBmVpZ6M0PPlcspsMMYHidUMfa",
    "cifar100c": "1MnETiQh9RTxJin2EHeSoIAJA28FRonHx",
    "mnist": "1CCHAGWqA1KJTFFswuF9cbhmB-j98Y1Sb",
    "fractals_and_fvis": "1EZP8RGOP-XbMsKex3r-BGI5F1WAP_PJ3",
    "tin": "1PZ-ixyx52U989IKsMA2OT-24fToTrelC",
    "tin597": "1R0d8zBcUxWNXz6CPXanobniiIfQbpKzn",
    "texture": "1OSz1m3hHfVWbRdmMwKbUzoU8Hg9UKcam",
    "imagenet10": "1qRKp-HCLkmfiWwR-PXthN7-2dxIQVKxP",
    "notmnist": "16ueghlyzunbksnc_ccPgEAloRW9pKO-K",
    "places365": "1Ec-LRSTf6u5vEctKX9vRp9OA6tqnJ0Ay",
    "places": "1fZ8TbPC4JGqUCm-VtvrmkYxqRNp2PoB3",
    "sun": "1ISK0STxWzWmg-_uUr4RQ8GSLFW7TZiKp",
    "species_sub": "1-JCxDx__iFMExkYRMylnGJYTPvyuX6aq",
    "imagenet_1k": "1i1ipLDFARR-JZ9argXd2-0a6DXwVhXEj",
    "ssb_hard": "1PzkA-WGG8Z18h0ooL_pDdz9cO-DCIouE",
    "ninco": "1Z82cmvIB0eghTehxOGP5VTdLt7OD3nk6",
    "imagenet_v2": "1akg2IiE22HcbvTBpwXQoD7tgfPCdkoho",
    "imagenet_r": "1EzjMN2gq-bVV7lg-MEAdeuBuz-7jbGYU",
    "imagenet_c": "1JeXL9YH4BO8gCJ631c5BHbaSsl-lekHt",
    "imagenet_o": "1S9cFV7fGvJCcka220-pIO9JPZL1p1V8w",
    "openimage_o": "1VUFXnB_z70uHfdgJG2E_pjYOcEgqM7tE",
    "inaturalist": "1zfLfMvoUD0CUlKNnkk7LgxZZBnTBipdj",
    "actmed": "1tibxL_wt6b3BjliPaQ2qjH54Wo4ZXWYb",
    "ct": "1k5OYN4inaGgivJBJ5L8pHlopQSVnhQ36",
    "hannover": "1NmqBDlcA1dZQKOvgcILG0U1Tm6RP0s2N",
    "xraybone": "1ZzO3y1-V_IeksJXEvEfBYKRoQLLvPYe9",
    "bimcv": "1nAA45V6e0s5FAq2BJsj9QH5omoihb7MZ",
}


def require_download(filename, path):
    """Decide whether ``filename`` still has to be fetched into ``path``.

    The directory ``path`` is scanned; the target counts as already present
    as soon as one entry satisfies any of:

    * the entry name starts with ``filename``,
    * ``filename`` starts with the entry name,
    * ``path`` itself ends with ``filename``.

    The last test is evaluated per entry on purpose: an *empty* directory
    never matches, even when its own name ends with ``filename``.

    Args:
        filename: Name (prefix) of the dataset or checkpoint to look for.
        path: Existing directory to inspect.

    Returns:
        ``False`` if a matching entry was found, otherwise ``True`` (after
        printing a "needs download" notice).
    """
    already_present = any(
        entry.startswith(filename)
        or filename.startswith(entry)
        or path.endswith(filename)
        for entry in os.listdir(path)
    )
    if already_present:
        return False

    print(filename + " needs download:")
    return True


def download_dataset(dataset, args):
    """Download and unpack one dataset into its directory under ``save_dir[0]``.

    The target directory is ``<save_dir[0]>/<group>/<dataset>``, where
    ``<group>`` is the ``dir_dict`` key that lists ``dataset``.  The directory
    is created if needed.  Nothing is downloaded when ``require_download``
    reports the dataset as already present (for a dataset directory this is
    the case as soon as the directory is non-empty).

    Args:
        dataset: Dataset key; must appear in ``dir_dict`` and
            ``download_id_dict``.
        args: Parsed command-line namespace; only ``args.save_dir[0]`` is read.

    Returns:
        None.  An unknown dataset is reported on stdout and skipped.
    """
    for sub_dir, members in dir_dict.items():
        if dataset in members:
            store_path = os.path.join(args.save_dir[0], sub_dir, dataset)
            # exist_ok avoids the check-then-create race of exists()+makedirs().
            os.makedirs(store_path, exist_ok=True)
            break
    else:
        print("Invalid dataset detected {}".format(dataset))
        return

    if not require_download(dataset, store_path):
        return

    print(store_path)
    # A trailing slash makes gdown treat ``output`` as a directory and keep
    # the remote file name.
    if not store_path.endswith("/"):
        store_path = store_path + "/"
    downloaded = gdown.download(id=download_id_dict[dataset], output=store_path)

    # Prefer the path gdown reports; fall back to the conventional archive
    # name when it returns nothing.
    file_path = downloaded or os.path.join(store_path, dataset + ".zip")
    with zipfile.ZipFile(file_path, "r") as zip_file:
        zip_file.extractall(store_path)
    # The archive is only needed for extraction.
    os.remove(file_path)


def _fetch_and_unpack(name, store_path):
    """Download the archive registered as ``name`` into ``store_path`` and unzip it.

    ``store_path`` must be an existing directory path ending in "/".  The
    archive is deleted after extraction.
    """
    downloaded = gdown.download(id=download_id_dict[name], output=store_path)
    # Prefer the path gdown reports; fall back to the conventional archive
    # name when it returns nothing.
    file_path = downloaded or os.path.join(store_path, name + ".zip")
    with zipfile.ZipFile(file_path, "r") as zip_file:
        zip_file.extractall(store_path)
    os.remove(file_path)


def main():
    """Parse the command line and download the requested datasets/checkpoints.

    Options:
        --contents: what to fetch; any of 'datasets', 'checkpoints'.
        --datasets: benchmark names (or dataset names when
            ``--dataset_mode dataset``), or one of the presets
            'default', 'ood_v1.5', 'all'.
        --checkpoints: checkpoint names, or one of the presets
            'ood', 'ood_v1.5', 'all'.
        --save_dir: two paths -- datasets directory, then results directory.
        --dataset_mode: 'default'/'benchmark' to treat ``--datasets`` as
            benchmark names, 'dataset' to treat them as dataset names.
    """
    parser = argparse.ArgumentParser(description="Download datasets and checkpoints")
    # Accepted values: 'datasets', 'checkpoints'.
    parser.add_argument("--contents", nargs="+", default=["checkpoints"])
    parser.add_argument("--datasets", nargs="+", default=["default"])
    parser.add_argument("--checkpoints", nargs="+", default=["ood_v1.5"])
    parser.add_argument("--save_dir", nargs="+", default=["./data", "./results"])
    parser.add_argument("--dataset_mode", default="default")
    args = parser.parse_args()

    # Checkpoints go to the second save_dir entry; fail early with a clear
    # message instead of an IndexError halfway through.
    if "checkpoints" in args.contents and len(args.save_dir) < 2:
        parser.error(
            "--save_dir needs two paths (datasets dir, results dir) "
            "to download checkpoints"
        )

    # A preset keyword in first position replaces the whole list.
    dataset_presets = {
        "default": ["mnist", "cifar-10", "cifar-100"],
        "ood_v1.5": ["cifar-10", "cifar-100", "imagenet-200", "imagenet-1k"],
        "all": list(benchmarks_dict.keys()),
    }
    args.datasets = dataset_presets.get(args.datasets[0], args.datasets)

    checkpoint_presets = {
        "ood": [
            "mnist_lenet",
            "cifar10_res18",
            "cifar100_res18",
            "imagenet_res50",
        ],
        "ood_v1.5": [
            "cifar10_res18_v1.5",
            "cifar100_res18_v1.5",
            "imagenet200_res18_v1.5",
            "imagenet_res50_v1.5",
        ],
        # NOTE(review): 'all' does not include the v1.5 checkpoints.
        "all": [
            "mnist_lenet",
            "cifar10_res18",
            "cifar100_res18",
            "imagenet_res50",
            "osr",
        ],
    }
    args.checkpoints = checkpoint_presets.get(args.checkpoints[0], args.checkpoints)

    for content in args.contents:
        if content == "datasets":
            store_path = args.save_dir[0]
            # The image-list archive is downloaded straight into the dataset
            # root, so the root has to exist before the download starts.
            os.makedirs(store_path, exist_ok=True)
            if not store_path.endswith("/"):
                store_path = store_path + "/"
            if not os.path.exists(os.path.join(store_path, "benchmark_imglist")):
                _fetch_and_unpack("benchmark_imglist", store_path)

            if args.dataset_mode == "default" or args.dataset_mode == "benchmark":
                for benchmark in args.datasets:
                    if benchmark not in benchmarks_dict:
                        # Mirror download_dataset's handling of unknown names
                        # instead of aborting with a KeyError.
                        print("Invalid benchmark detected {}".format(benchmark))
                        continue
                    for dataset in benchmarks_dict[benchmark]:
                        download_dataset(dataset, args)

            if args.dataset_mode == "dataset":
                for dataset in args.datasets:
                    download_dataset(dataset, args)

        elif content == "checkpoints":
            # The location is chosen from the FIRST checkpoint only: v1.5
            # checkpoints go directly into the results directory, everything
            # else into its "checkpoints/" sub-directory.
            if "v1.5" in args.checkpoints[0]:
                store_path = args.save_dir[1]
            else:
                store_path = os.path.join(args.save_dir[1], "checkpoints/")
            os.makedirs(store_path, exist_ok=True)

            if not store_path.endswith("/"):
                store_path = store_path + "/"

            for checkpoint in args.checkpoints:
                if require_download(checkpoint, store_path):
                    _fetch_and_unpack(checkpoint, store_path)


if __name__ == "__main__":
    main()
