import os
import glob
import hashlib
import torch
import torchvision
from torchvision.datasets.utils import download_and_extract_archive
from PIL import Image, ImageFile

def get_dataset(args):
    """Build the data loaders and metadata for the dataset named by ``args.dataset``.

    Parameters
    ----------
    args:
        Object exposing ``dataset`` (one of ``'MNIST'``, ``'CIFAR10'``,
        ``'TinyImgNet'``) and ``B`` (forwarded to the loader factory, where it
        is used as the batch size).

    Returns
    -------
    tuple
        ``(train_loader, test_loader, img_size, num_classes, inv_transform)``
        where ``img_size`` is a ``[channels, height, width]`` list and
        ``inv_transform`` is the inverse-normalization transform returned by
        the dataset-specific factory.

    Raises
    ------
    ValueError
        If ``args.dataset`` is not one of the supported names.
    """
    if args.dataset == 'MNIST':
        train_loader, test_loader, inv_transform = get_mnist(args.B)
        img_size = [1, 28, 28]
        num_classes = 10
    elif args.dataset == 'CIFAR10':
        train_loader, test_loader, inv_transform = get_cifar10(args.B)
        img_size = [3, 32, 32]
        num_classes = 10
    elif args.dataset == 'TinyImgNet':
        train_loader, test_loader, inv_transform = get_tinyImgNet(args.B)
        img_size = [3, 64, 64]
        num_classes = 200
    else:
        # Fail loudly instead of reaching the return with unbound locals.
        raise ValueError(
            "Unknown dataset %r; expected one of 'MNIST', 'CIFAR10', 'TinyImgNet'"
            % (args.dataset,)
        )
    return train_loader, test_loader, img_size, num_classes, inv_transform

def get_mnist(B):
    """Create MNIST train/test loaders and the matching inverse transform.

    Parameters
    ----------
    B:
        Batch size passed to both ``DataLoader`` instances.

    Returns
    -------
    tuple
        ``(train_loader, test_loader, inv_transform)``. The train loader
        shuffles, the test loader does not. ``inv_transform`` multiplies by
        the normalization std, adds the mean back and finally scales by 255.
    """
    T = torchvision.transforms
    mnist_mean = 0.1307
    mnist_std = 0.3081

    def preprocess():
        # A fresh pipeline per dataset: tensor conversion, then normalization.
        return T.Compose([
            T.ToTensor(),
            T.Normalize((mnist_mean,), (mnist_std,)),
        ])

    train_set = torchvision.datasets.MNIST(
        './data', train=True, download=True, transform=preprocess())
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=B, shuffle=True)

    test_set = torchvision.datasets.MNIST(
        './data', train=False, download=True, transform=preprocess())
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=B, shuffle=False)

    # Normalize computes (x - mean) / std, so each step below inverts one
    # part of the forward pipeline.
    undo_std = T.Normalize(mean=[0.], std=[1 / mnist_std])
    undo_mean = T.Normalize(mean=[-mnist_mean], std=[1.0])
    to_255 = T.Normalize(mean=[0.], std=[1 / 255.0])
    inv_transform = T.Compose([undo_std, undo_mean, to_255])

    return train_loader, test_loader, inv_transform



def get_tinyImgNet(B):
    """Create Tiny ImageNet train/val loaders and the matching inverse transform.

    Parameters
    ----------
    B:
        Batch size passed to both ``DataLoader`` instances.

    Returns
    -------
    tuple
        ``(train_loader, test_loader, inv_transform)``. The test loader is
        built from the ``'val'`` split and is not shuffled. ``inv_transform``
        multiplies by the per-channel std, adds the mean back and finally
        scales by 255.
    """
    T = torchvision.transforms
    channel_mean = (0.4789886474609375, 0.4457630515098572, 0.3944724500179291)
    channel_std = (0.27698642015457153, 0.2690644860267639, 0.2820819020271301)

    def preprocess():
        # A fresh pipeline per dataset: tensor conversion, then normalization.
        return T.Compose([
            T.ToTensor(),
            T.Normalize(channel_mean, channel_std),
        ])

    train_set = TinyImageNet(root='./data', split='train', download=True,
                             transform=preprocess())
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=B, shuffle=True)

    val_set = TinyImageNet(root='./data', split='val', download=True,
                           transform=preprocess())
    test_loader = torch.utils.data.DataLoader(val_set, batch_size=B, shuffle=False)

    # Normalize computes (x - mean) / std, so each step below inverts one
    # part of the forward pipeline.
    undo_std = T.Normalize(mean=[0., 0., 0.], std=[1 / s for s in channel_std])
    undo_mean = T.Normalize(mean=[-m for m in channel_mean], std=[1.0, 1.0, 1.0])
    to_255 = T.Normalize(mean=[0., 0., 0.], std=[1 / 255.0, 1 / 255.0, 1 / 255.0])
    inv_transform = T.Compose([undo_std, undo_mean, to_255])

    return train_loader, test_loader, inv_transform

def get_cifar10(B):
    """Create CIFAR-10 train/test loaders and the matching inverse transform.

    Parameters
    ----------
    B:
        Batch size passed to both ``DataLoader`` instances.

    Returns
    -------
    tuple
        ``(train_loader, test_loader, inv_transform)``. The train loader
        shuffles, the test loader does not. ``inv_transform`` multiplies by
        the per-channel std, adds the mean back and finally scales by 255.
    """
    T = torchvision.transforms
    channel_mean = (0.4914672374725342, 0.4822617471218109, 0.4467701315879822)
    channel_std = (0.24703224003314972, 0.24348513782024384, 0.26158785820007324)

    def preprocess():
        # A fresh pipeline per dataset: tensor conversion, then normalization.
        return T.Compose([
            T.ToTensor(),
            T.Normalize(channel_mean, channel_std),
        ])

    train_set = torchvision.datasets.CIFAR10(
        './data', train=True, download=True, transform=preprocess())
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=B, shuffle=True)

    test_set = torchvision.datasets.CIFAR10(
        './data', train=False, download=True, transform=preprocess())
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=B, shuffle=False)

    # Normalize computes (x - mean) / std, so each step below inverts one
    # part of the forward pipeline.
    undo_std = T.Normalize(mean=[0., 0., 0.], std=[1 / s for s in channel_std])
    undo_mean = T.Normalize(mean=[-m for m in channel_mean], std=[1.0, 1.0, 1.0])
    to_255 = T.Normalize(mean=[0., 0., 0.], std=[1 / 255.0, 1 / 255.0, 1 / 255.0])
    inv_transform = T.Compose([undo_std, undo_mean, to_255])

    return train_loader, test_loader, inv_transform

# Taken from https://github.com/JonasGeiping/breaching/blob/main/breaching/cases/data/datasets_vision.py
class TinyImageNet(torch.utils.data.Dataset):
    """Tiny ImageNet data set available from `http://cs231n.stanford.edu/tiny-imagenet-200.zip`.

    This is a TinyImageNet variant to the code of Meng Lee, mnicnc404 / Date: 2018/06/04
    References:
        - https://stanford.edu/~shervine/blog/pytorch-how-to-generate-data-parallel.html
    Parameters
    ----------
    root: string
        Directory that contains (or, after download, will contain) the
        `tiny-imagenet-200` folder. A leading `~` is expanded.
    split: string
        Indicating which split to return as a data set.
        Valid option: [`train`, `test`, `val`]
    transform: torchvision.transforms
        A (series) of valid transformation(s) applied to each RGB PIL image.
    target_transform: callable
        Optional transformation applied to each integer label
        (not applied for the unlabeled `test` split).
    cached: bool
        Set to True if there is enough memory (about 5G) and want to minimize disk IO overhead.
    download: bool
        Set to true to automatically download the dataset in to the root folder.

    Items are `(image, label)` pairs; for the `test` split the label is `None`.
    """

    EXTENSION = "JPEG"
    NUM_IMAGES_PER_CLASS = 500
    CLASS_LIST_FILE = "wnids.txt"
    VAL_ANNOTATION_FILE = "val_annotations.txt"
    CLASSES = "words.txt"
    VALID_SPLITS = ("train", "val", "test")

    url = "http://cs231n.stanford.edu/tiny-imagenet-200.zip"
    archive = "tiny-imagenet-200.zip"
    folder = "tiny-imagenet-200"
    train_md5 = "2fe62ea2ae3f40a2722ae2027690bf55"
    val_md5 = "224a3620e1e130875249eab1f72707fc"
    test_md5 = "46af9c5f5b834252831e15951e7ab11f"

    def __init__(self, root, split="train", transform=None, target_transform=None, cached=True, download=True):
        """Init with split, transform, target_transform.

        Raises
        ------
        ValueError
            If `split` is not one of `train`, `val` or `test`.
        """
        if split not in self.VALID_SPLITS:
            raise ValueError("Unknown split %r; expected one of %s" % (split, ", ".join(self.VALID_SPLITS)))

        self.root = os.path.expanduser(root)
        self.split = split
        self.transform = transform
        self.target_transform = target_transform
        self.cached = cached

        # Use the expanded root so that roots such as "~/data" are actually found by glob.
        self.split_dir = os.path.join(self.root, self.folder, self.split)
        self.image_paths = self._find_images()
        self.labels = {}  # fname - label number mapping

        if download:
            self.download()

        self._parse_labels()

        if self.cached:
            self._build_cache()

    def _find_images(self):
        """Return the sorted list of image files found below the split directory."""
        pattern = os.path.join(self.split_dir, "**", "*.%s" % self.EXTENSION)
        return sorted(glob.iglob(pattern, recursive=True))

    def _check_integrity(self):
        """Compare the MD5 of the concatenated image paths against the stored digest.

        This only checks if all files are there; file contents are not hashed.
        Because the digest is computed over the path strings, it also depends
        on the textual form of `root`.
        """
        string_rep = "".join(self.image_paths).encode("utf-8")
        digest = hashlib.md5(string_rep).hexdigest()
        if self.split == "train":
            return digest == self.train_md5
        elif self.split == "val":
            return digest == self.val_md5
        else:
            return digest == self.test_md5

    def download(self):
        """Download and extract the archive unless the integrity check already passes."""
        if self._check_integrity():
            print("Files already downloaded and verified")
            return
        download_and_extract_archive(self.url, self.root, filename=self.archive)
        # The listing taken before extraction is stale (typically empty); refresh it
        # so that labels, targets and the cache cover the freshly extracted files.
        self.image_paths = self._find_images()

    def _parse_labels(self):
        """Populate `label_texts`, `labels`, `classes` and `targets` from the annotation files."""
        with open(os.path.join(self.root, self.folder, self.CLASS_LIST_FILE), "r") as fp:
            # Skip blank lines so an empty string never becomes a class id.
            self.label_texts = sorted(text.strip() for text in fp if text.strip())
        self.label_text_to_number = {text: i for i, text in enumerate(self.label_texts)}

        if self.split == "train":
            # Training files are named "<wnid>_<k>.JPEG" for k in [0, NUM_IMAGES_PER_CLASS).
            for label_text, i in self.label_text_to_number.items():
                for cnt in range(self.NUM_IMAGES_PER_CLASS):
                    self.labels["%s_%d.%s" % (label_text, cnt, self.EXTENSION)] = i
        elif self.split == "val":
            with open(os.path.join(self.split_dir, self.VAL_ANNOTATION_FILE), "r") as fp:
                for line in fp:
                    if not line.strip():
                        continue
                    terms = line.split("\t")
                    file_name, label_text = terms[0], terms[1]
                    self.labels[file_name] = self.label_text_to_number[label_text]

        # Build class names
        label_text_to_word = dict()
        with open(os.path.join(self.root, self.folder, self.CLASSES), "r") as file:
            for line in file:
                if not line.strip():
                    continue
                # Split on the first tab only so a description containing tabs cannot break unpacking.
                label_text, word = line.split("\t", 1)
                label_text_to_word[label_text] = word.split(",")[0].rstrip("\n")
        self.classes = [label_text_to_word[label] for label in self.label_texts]

        # Prepare index - label mapping
        if self.split == "test":
            # No labels are parsed for the test split; looking them up would raise KeyError.
            self.targets = [None] * len(self.image_paths)
        else:
            self.targets = [self.labels[os.path.basename(file_path)] for file_path in self.image_paths]

    def _load_image(self, path):
        """Open `path`, convert it to RGB and close the underlying file handle."""
        with Image.open(path) as img:
            return img.convert("RGB")

    def _build_cache(self):
        """Cache images in RAM."""
        self.cache = [self._load_image(path) for path in self.image_paths]

    def __len__(self):
        """Return length via image paths."""
        return len(self.image_paths)

    def __getitem__(self, index):
        """Return image, label (label is None for the test split)."""
        if self.cached:
            img = self.cache[index]
        else:
            img = self._load_image(self.image_paths[index])

        if self.transform is not None:
            img = self.transform(img)

        if self.split == "test":
            return img, None

        target = self.targets[index]
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target

