# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import logging
from pathlib import Path
import shutil
import tempfile

from torchvision.datasets import MNIST

# Root directory under which the MNIST dataset is cached: the platform's
# temporary directory as reported by the standard library.
TEMPDIR: str = tempfile.gettempdir()


def setup_cached_mnist(max_attempts: int = 5) -> None:
    """Ensure the MNIST dataset is downloaded and cached under ``TEMPDIR``.

    Each attempt constructs ``torchvision.datasets.MNIST`` with
    ``download=True``. This skips the download when the dataset is already
    there and checks its checksum. A ``RuntimeError`` is treated as a failed
    or corrupted download: the ``TEMPDIR/MNIST`` folder is erased (if it
    exists) and the download is retried.

    Args:
        max_attempts: Maximum number of download attempts (default 5).

    Raises:
        SystemExit: With code -1 if every attempt fails.
    """
    # Monkey patch the resource URLs to work around a possible blacklist.
    # Do it once, and only if missing: prepending inside the retry loop (or on
    # every call) would keep growing the mirror list with duplicates.
    mirror = "https://github.com/blefaudeux/mnist_dataset/raw/main/"
    if mirror not in MNIST.mirrors:
        MNIST.mirrors = [mirror] + MNIST.mirrors

    mnist_root = Path(TEMPDIR) / "MNIST"

    for _attempt in range(max_attempts):
        try:
            # Skips the download if the dataset is already there, and checks the checksum
            MNIST(transform=None, download=True, root=TEMPDIR)
        except RuntimeError as e:
            logging.warning(e)
            # Corrupted data, erase and restart. The folder may not exist if
            # the download failed before writing anything; calling rmtree on a
            # missing path would raise and abort the remaining retries.
            if mnist_root.exists():
                shutil.rmtree(str(mnist_root))
        else:
            logging.info("Dataset downloaded")
            return

    logging.error("Could not download MNIST dataset")
    # Same exit semantics as the site-provided exit(-1), but always available
    # in scripts and without closing sys.stdin as a side effect.
    raise SystemExit(-1)
