"""
MNIST
_____________________________________________________
"""
import codecs
import os
import os.path
import shutil
import string
import warnings
from typing import Any, Callable, Dict, List, Optional, Tuple
from urllib.error import URLError

import numpy as np
import torch
import torchvision
from PIL import Image
from torchvision.datasets.utils import (
    check_integrity,
    download_and_extract_archive,
    extract_archive,
    verify_str_arg,
)
from torchvision.datasets.vision import VisionDataset

from torch.utils.data import Dataset

import os
from typing import IO, Any, Callable, Dict, List, Optional, Tuple, Union
from urllib.error import URLError

import torch
from torchvision.datasets.mnist import MNIST, VisionDataset
from torchvision.datasets.utils import download_url, extract_archive, verify_str_arg
from torchvision.transforms import Compose, Normalize, ToTensor, Resize

from src.utils.utils import load_gc_torch_files

# Cell

# Normalize transform built with mean 0.1307 and std 0.3081 -- the same two
# constants indices_AmbiguousMNIST applies in place when ``normalize=True``.
MNIST_NORMALIZATION = Normalize((0.1307,), (0.3081,))

# based on torchvision.datasets.mnist.py (https://github.com/pytorch/vision/blob/37eb37a836fbc2c26197dfaf76d2a3f4f39f15df/torchvision/datasets/mnist.py)
class indices_AmbiguousMNIST(VisionDataset):
    """
    Ambiguous-MNIST Dataset, returning ``(index, image, target)`` per item.

    Please cite:
        @article{mukhoti2021deterministic,
          title={Deterministic Neural Networks with Appropriate Inductive Biases Capture Epistemic and Aleatoric Uncertainty},
          author={Mukhoti, Jishnu and Kirsch, Andreas and van Amersfoort, Joost and Torr, Philip HS and Gal, Yarin},
          journal={arXiv preprint arXiv:2102.11582},
          year={2021}
        }

    Every stored sample comes with several labels. During construction each
    sample is repeated once per label so the dataset becomes single-label; the
    first 60000 flattened entries are the training split, the remainder the
    test split.

    Args:
        root (string): Root directory. The sample and label files are stored in
            (and read from) ``<root>/<class name>/``.
        train (bool, optional): If True, use the first 60000 flattened entries,
            otherwise everything after them.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in the data folder. If the files already exist, they are not
            downloaded again.
        transform (callable, optional): A function/transform applied to the image
            tensor of each item.
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        normalize (bool, optional): Normalize the samples (subtract 0.1307, divide by 0.3081).
        noise_stddev (float, optional): Standard deviation of the Gaussian noise added
            to the selected split after normalization. No noise is added when it is
            not greater than 0.
        device: Device to use (pass `num_workers=0, pin_memory=False` to the DataLoader for max throughput)
    """

    mirrors = ["http://github.com/BlackHC/ddu_dirty_mnist/releases/download/data-v1.0.0/"]

    # Maps a resource name to (file name, md5 checksum).
    resources = dict(
        data=("amnist_samples.pt", "4f7865093b1d28e34019847fab917722"),
        targets=("amnist_labels.pt", "3bfc055a9f91a76d8d493e8b898c3c95"),
    )

    def __init__(
        self,
        root: str,
        *,
        train: bool = True,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = True,
        normalize: bool = True,
        noise_stddev=0.05,
        device=None,
    ):
        super().__init__(root, transform=transform, target_transform=target_transform)

        self.train = train  # training set or test set

        if download:
            self.download()

        self.data = torch.load(self.resource_path("data"), map_location=device)
        if normalize:
            self.data = self.data.sub_(0.1307).div_(0.3081)

        self.targets = torch.load(self.resource_path("targets"), map_location=device)

        # Each sample has `num_multi_labels` many labels.
        num_multi_labels = self.targets.shape[1]

        # Flatten the multi-label dataset into a single-label dataset with samples repeated x `num_multi_labels` many times
        self.data = self.data.expand(-1, num_multi_labels, 28, 28).reshape(-1, 1, 28, 28)
        self.targets = self.targets.reshape(-1)

        data_range = slice(None, 60000) if self.train else slice(60000, None)
        self.data = self.data[data_range]

        if noise_stddev > 0.0:
            self.data += torch.randn_like(self.data) * noise_stddev

        self.targets = self.targets[data_range]
        self.sequence = np.arange(len(self.data))

    def __getitem__(self, index: int) -> Tuple[int, Any, Any]:
        """
        Args:
            index (int): Index
        Returns:
            tuple: (index, image, target) where target is index of the target class.
        """
        img, target = self.data[index], self.targets[index]
        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return index, img, target

    def __len__(self) -> int:
        return len(self.sequence)

    @property
    def data_folder(self) -> str:
        """Directory holding the resource files: ``<root>/<class name>``."""
        return os.path.join(self.root, self.__class__.__name__)

    def resource_path(self, name):
        """Return the on-disk path of the resource registered under ``name``."""
        return os.path.join(self.data_folder, self.resources[name][0])

    def _check_exists(self) -> bool:
        """Return True when every resource file exists (checksums are not verified here)."""
        return all(os.path.exists(self.resource_path(name)) for name in self.resources)

    def download(self) -> None:
        """Download the data if it doesn't exist in data_folder already.

        Raises:
            RuntimeError: if a file could not be fetched from any mirror.
        """

        if self._check_exists():
            return

        os.makedirs(self.data_folder, exist_ok=True)

        # download files
        for filename, md5 in self.resources.values():
            for mirror in self.mirrors:
                url = "{}{}".format(mirror, filename)
                try:
                    print("Downloading {}".format(url))
                    download_url(url, root=self.data_folder, filename=filename, md5=md5)
                except URLError as error:
                    # Only connection-level failures fall through to the next
                    # mirror; any other exception propagates unchanged.
                    print("Failed to download (trying next):\n{}".format(error))
                    continue
                finally:
                    print()
                break
            else:
                raise RuntimeError("Error downloading {}".format(filename))

        print("Done!")

class indices_TMNIST(VisionDataset):
    """`tMNIST <https://github.com/mcaandewiel/tMNIST-PyTorch>`_ Dataset.

    Each item is returned as ``(index, image, target)``.

    Args:
        root (string): Path of the file handed to ``np.load``. It must contain a
            pickled dict with the keys ``"images"`` and ``"labels"``.
        train (bool, optional): Stored on the instance as ``self.train``; the
            loading code does not use it to select a split.
        transform (callable, optional): A function/transform applied to each
            image after it has been reshaped to ``(40, 40, -1)``.
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
    """

    @property
    def train_labels(self):
        warnings.warn("train_labels has been renamed targets")
        return self.targets

    @property
    def test_labels(self):
        warnings.warn("test_labels has been renamed targets")
        return self.targets

    @property
    def train_data(self):
        warnings.warn("train_data has been renamed data")
        return self.data

    @property
    def test_data(self):
        warnings.warn("test_data has been renamed data")
        return self.data

    def __init__(
        self,
        root: str,
        train: bool = True,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> None:
        super(indices_TMNIST, self).__init__(
            root, transform=transform, target_transform=target_transform
        )
        self.train = train  # training set or test set

        self.root = root
        self.data = self._load_data()
        # `train_labels` / `test_labels` read `self.targets`, so expose the labels
        # under that name as well.
        self.targets = self.labels
        # `self.data` is the loaded dict, so its length is the number of keys, not
        # the number of samples. The index sequence must cover every sample.
        self.sequence = np.arange(len(self.labels))

    def _load_data(self):
        """Load the dict stored at ``self.root`` and cache its images and labels.

        Returns:
            The loaded dict (also the value kept in ``self.data``).
        """
        # NOTE(security): allow_pickle=True unpickles the file content, which can
        # execute arbitrary code. Only load files from a trusted source.
        data = np.load(self.root, allow_pickle=True).item()
        self.images = data["images"]
        self.labels = data["labels"]

        return data

    def __getitem__(self, index: int) -> Tuple[int, Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (index, image, target) where target is index of the target class.
        """
        position = self.sequence[index]
        img, target = self.images[position].reshape(40, 40, -1), self.labels[position]

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return index, img, target

    def __len__(self) -> int:
        # Same convention as the other datasets in this file: the length is the
        # number of positions that `__getitem__` can resolve through `sequence`.
        return len(self.sequence)

class indices_MNIST(VisionDataset):
    """The standard PyTorch class, modified so that each datapoint is returned
    as (index, sample, target) instead of (sample, target). What follows is
    adapted from the docstring of the standard PyTorch class.

    `MNIST <http://yann.lecun.com/exdb/mnist/>`_ Dataset.

    The data is downloaded automatically when neither the legacy processed
    files nor the raw files are present.

    Args:
        root (string): Root directory of dataset. Raw files live in
            ``<root>/<class name>/raw``; legacy ``training.pt`` / ``test.pt``
            files are looked up in ``<root>/<class name>/processed``.
        train (bool, optional): If True, creates dataset from the training files,
            otherwise from the test files.
        transform (callable, optional): A function/transform that  takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.

    Raises:
        RuntimeError: if the raw files are still missing after the download attempt.
    """

    mirrors = [
        "http://yann.lecun.com/exdb/mnist/",
        "https://ossci-datasets.s3.amazonaws.com/mnist/",
    ]

    # (archive file name, md5 checksum)
    resources = [
        ("train-images-idx3-ubyte.gz", "f68b3c2dcbeaaa9fbdd348bbdeb94873"),
        ("train-labels-idx1-ubyte.gz", "d53e105ee54ea40749a09fcbcd1e9432"),
        ("t10k-images-idx3-ubyte.gz", "9fb629c4189551a2d022fa330f9573f3"),
        ("t10k-labels-idx1-ubyte.gz", "ec29112dd5afa0611ce80d1b7f02629c"),
    ]

    training_file = "training.pt"
    test_file = "test.pt"
    classes = [
        "0 - zero",
        "1 - one",
        "2 - two",
        "3 - three",
        "4 - four",
        "5 - five",
        "6 - six",
        "7 - seven",
        "8 - eight",
        "9 - nine",
    ]

    @property
    def train_labels(self):
        warnings.warn("train_labels has been renamed targets")
        return self.targets

    @property
    def test_labels(self):
        warnings.warn("test_labels has been renamed targets")
        return self.targets

    @property
    def train_data(self):
        warnings.warn("train_data has been renamed data")
        return self.data

    @property
    def test_data(self):
        warnings.warn("test_data has been renamed data")
        return self.data

    def __init__(
        self,
        root: str,
        train: bool = True,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> None:
        super(indices_MNIST, self).__init__(
            root, transform=transform, target_transform=target_transform
        )
        self.train = train  # training set or test set

        if self._check_legacy_exist():
            self.data, self.targets = self._load_legacy_data()
        else:
            self.download()

            if not self._check_exists():
                raise RuntimeError("Dataset not found even though we tried downloading")

            self.data, self.targets = self._load_data()

        # Must be set on BOTH loading paths: `__len__` and `__getitem__` rely on it.
        self.sequence = np.arange(len(self.data))

    def _check_legacy_exist(self):
        """Return True when both legacy processed files are present."""
        processed_folder_exists = os.path.exists(self.processed_folder)
        if not processed_folder_exists:
            return False

        return all(
            check_integrity(os.path.join(self.processed_folder, file))
            for file in (self.training_file, self.test_file)
        )

    def _load_legacy_data(self):
        # This is for BC only. We no longer cache the data in a custom binary, but simply read from the raw data
        # directly.
        data_file = self.training_file if self.train else self.test_file
        return torch.load(os.path.join(self.processed_folder, data_file))

    def _load_data(self):
        """Read images and labels of the selected split from the raw IDX files."""
        image_file = f"{'train' if self.train else 't10k'}-images-idx3-ubyte"
        data = read_image_file(os.path.join(self.raw_folder, image_file))

        label_file = f"{'train' if self.train else 't10k'}-labels-idx1-ubyte"
        targets = read_label_file(os.path.join(self.raw_folder, label_file))

        return data, targets

    def __getitem__(self, index: int) -> Tuple[int, Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (index, image, target) where target is index of the target class.
        """
        position = self.sequence[index]
        img, target = self.data[position], int(self.targets[position])

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img.numpy(), mode="L")

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return index, img, target

    def __len__(self) -> int:
        return len(self.sequence)

    @property
    def raw_folder(self) -> str:
        return os.path.join(self.root, self.__class__.__name__, "raw")

    @property
    def processed_folder(self) -> str:
        return os.path.join(self.root, self.__class__.__name__, "processed")

    @property
    def class_to_idx(self) -> Dict[str, int]:
        return {_class: i for i, _class in enumerate(self.classes)}

    def _check_exists(self) -> bool:
        """Return True when every extracted raw file is present."""
        return all(
            check_integrity(
                os.path.join(
                    self.raw_folder, os.path.splitext(os.path.basename(url))[0]
                )
            )
            for url, _ in self.resources
        )

    def download(self) -> None:
        """Download the MNIST data if it doesn't exist already.

        Raises:
            RuntimeError: if a file could not be fetched from any mirror.
        """

        if self._check_exists():
            return

        os.makedirs(self.raw_folder, exist_ok=True)

        for filename, md5 in self.resources:
            for mirror in self.mirrors:
                url = "{}{}".format(mirror, filename)
                try:
                    print("Downloading {}".format(url))
                    download_and_extract_archive(
                        url, download_root=self.raw_folder, filename=filename, md5=md5
                    )
                except URLError as error:
                    print("Failed to download (trying next):\n{}".format(error))
                    continue
                finally:
                    print()
                break
            else:
                raise RuntimeError("Error downloading {}".format(filename))

    def extra_repr(self) -> str:
        return "Split: {}".format("Train" if self.train is True else "Test")

class indices_QMNIST(torchvision.datasets.MNIST):
    """`QMNIST <https://github.com/facebookresearch/qmnist>`_ Dataset.

    Each item is returned as ``(index, image, target)``.

    Args:
        root (string): Root directory of dataset whose ``processed``
            subdir contains torch binary files with the datasets.
        what (string,optional): Can be 'train', 'test', 'test10k',
            'test50k', or 'nist' for respectively the mnist compatible
            training set, the 60k qmnist testing set, the 10k qmnist
            examples that match the mnist testing set, the 50k
            remaining qmnist testing examples, or all the nist
            digits. The default is to select 'train' or 'test'
            according to the compatibility argument 'train'.
        compat (bool,optional): A boolean that says whether the target
            for each example is class number (for compatibility with
            the MNIST dataloader) or a torch vector containing the
            full qmnist information. Default=True.
        download (bool, optional): If true, downloads the dataset from
            the internet and puts it in root directory. If dataset is
            already downloaded, it is not downloaded again.
        transform (callable, optional): A function/transform that
            takes in an PIL image and returns a transformed
            version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform
            that takes in the target and transforms it.
        train (bool,optional,compatibility): When argument 'what' is
            not specified, this boolean decides whether to load the
            training set or the testing set.  Default: True.
    """

    # Maps every accepted value of `what` to the key of `resources` it is read from.
    subsets = {
        "train": "train",
        "test": "test",
        "test10k": "test",
        "test50k": "test",
        "nist": "nist",
    }
    # For every resource key: [(images url, md5), (labels url, md5)].
    resources: Dict[str, List[Tuple[str, str]]] = {  # type: ignore[assignment]
        "train": [
            (
                "https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-train-images-idx3-ubyte.gz",
                "ed72d4157d28c017586c42bc6afe6370",
            ),
            (
                "https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-train-labels-idx2-int.gz",
                "0058f8dd561b90ffdd0f734c6a30e5e4",
            ),
        ],
        "test": [
            (
                "https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-test-images-idx3-ubyte.gz",
                "1394631089c404de565df7b7aeaf9412",
            ),
            (
                "https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-test-labels-idx2-int.gz",
                "5b5b05890a5e13444e108efe57b788aa",
            ),
        ],
        "nist": [
            (
                "https://raw.githubusercontent.com/facebookresearch/qmnist/master/xnist-images-idx3-ubyte.xz",
                "7f124b3b8ab81486c9d8c2749c17f834",
            ),
            (
                "https://raw.githubusercontent.com/facebookresearch/qmnist/master/xnist-labels-idx2-int.xz",
                "5ed0e788978e45d4a8bd4b7caec3d79d",
            ),
        ],
    }
    classes = [
        "0 - zero",
        "1 - one",
        "2 - two",
        "3 - three",
        "4 - four",
        "5 - five",
        "6 - six",
        "7 - seven",
        "8 - eight",
        "9 - nine",
    ]

    def __init__(
        self,
        root: str,
        what: Optional[str] = None,
        compat: bool = True,
        train: bool = True,
        **kwargs: Any,
    ) -> None:
        if what is None:
            what = "train" if train else "test"
        # These attributes are assigned before the parent constructor runs and
        # must stay ahead of it.
        self.what = verify_str_arg(what, "what", tuple(self.subsets.keys()))
        self.compat = compat
        self.data_file = what + ".pt"
        self.training_file = self.data_file
        self.test_file = self.data_file
        super(indices_QMNIST, self).__init__(root, train, **kwargs)
        self.sequence = np.arange(len(self.data))

    @property
    def images_file(self) -> str:
        """Path of the extracted images file of the selected subset."""
        (url, _), _ = self.resources[self.subsets[self.what]]
        return os.path.join(self.raw_folder, os.path.splitext(os.path.basename(url))[0])

    @property
    def labels_file(self) -> str:
        """Path of the extracted labels file of the selected subset."""
        _, (url, _) = self.resources[self.subsets[self.what]]
        return os.path.join(self.raw_folder, os.path.splitext(os.path.basename(url))[0])

    def _check_exists(self) -> bool:
        return all(
            check_integrity(file) for file in (self.images_file, self.labels_file)
        )

    def _load_data(self):
        """Read images and labels of the selected subset from the raw files."""
        data = read_sn3_pascalvincent_tensor(self.images_file)
        assert data.dtype == torch.uint8
        assert data.ndimension() == 3

        targets = read_sn3_pascalvincent_tensor(self.labels_file).long()
        assert targets.ndimension() == 2

        # 'test10k' and 'test50k' are the first 10000 and the remaining entries
        # of the test files respectively.
        if self.what == "test10k":
            data = data[0:10000, :, :].clone()
            targets = targets[0:10000, :].clone()
        elif self.what == "test50k":
            data = data[10000:, :, :].clone()
            targets = targets[10000:, :].clone()

        return data, targets

    def download(self) -> None:
        """Download the QMNIST data if it doesn't exist already.
        Note that we only download what has been asked for (argument 'what').
        """
        if self._check_exists():
            return

        os.makedirs(self.raw_folder, exist_ok=True)
        split = self.resources[self.subsets[self.what]]

        for url, md5 in split:
            filename = url.rpartition("/")[2]
            file_path = os.path.join(self.raw_folder, filename)
            if not os.path.isfile(file_path):
                download_and_extract_archive(
                    url, self.raw_folder, filename=filename, md5=md5
                )

    def __getitem__(self, index: int) -> Tuple[int, Any, Any]:
        """Return ``(index, image, target)`` for the given position.

        With ``compat`` set, the target is reduced to ``int(target[0])``;
        otherwise the full target row is passed on.
        """
        # redefined to handle the compat flag
        img, target = self.data[self.sequence[index]], self.targets[self.sequence[index]]
        img = Image.fromarray(img.numpy(), mode="L")
        if self.transform is not None:
            img = self.transform(img)
        if self.compat:
            target = int(target[0])
        if self.target_transform is not None:
            target = self.target_transform(target)
        return index, img, target

    def extra_repr(self) -> str:
        return "Split: {}".format(self.what)

class indices_infiMNIST(VisionDataset):
    """infiMNIST subset returning ``(index, image, target)`` per item.

    The arrays ``x`` (images) and ``y`` (labels) are read from
    ``<root>/infimnist.npz`` and entries 100000 up to (excluding) 8100000 are
    kept.

    Args:
        root (string): Directory that contains ``infimnist.npz``.
        transform (callable, optional): A function/transform that takes in a PIL
            image and returns a transformed version.
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.

    Raises:
        IndexError: if the archive holds fewer than 8100000 entries.
    """

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> None:
        super(indices_infiMNIST, self).__init__(
            root, transform=transform, target_transform=target_transform
        )
        self._data_file = os.path.join(root, "infimnist.npz")
        self.data, self.targets = self._load_data()

        self.n_classes = 10
        self.img_shape = (28, 28)

        first, stop = 100000, 8100000
        # A slice silently truncates, so check the bounds explicitly and keep
        # failing on archives that are too short.
        if len(self.data) < stop or len(self.targets) < stop:
            raise IndexError(
                "infimnist.npz must hold at least {} entries, found {}".format(
                    stop, min(len(self.data), len(self.targets))
                )
            )
        # Basic slicing returns a view; indexing with an 8M-element index array
        # would copy the whole image array a second time.
        self.data = self.data[first:stop]
        self.targets = self.targets[first:stop]

        self.sequence = np.arange(len(self.data))

    def _load_data(self):
        """Read the ``x`` and ``y`` arrays from the archive and close it again."""
        with np.load(self._data_file) as archive:
            data = archive["x"]
            targets = archive["y"]
        return data, targets

    def __getitem__(self, index: int) -> Tuple[int, Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (index, image, target) where target is index of the target class.
        """
        img, target = self.data[index], int(self.targets[index])

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img, mode="L")

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return index, img, target

    def __len__(self) -> int:
        return len(self.sequence)


def get_int(b: bytes) -> int:
    """Return the integer whose big-endian byte representation is ``b``."""
    hex_digits = b.hex()
    return int(hex_digits, 16)

# Maps the type code stored in the file's magic number to
# (torch dtype, numpy dtype used to decode the raw bytes,
#  numpy dtype the decoded array is converted to before it becomes a tensor).
SN3_PASCALVINCENT_TYPEMAP = {
    8: (torch.uint8, np.uint8, np.uint8),
    9: (torch.int8, np.int8, np.int8),
    11: (torch.int16, np.dtype(">i2"), "i2"),
    12: (torch.int32, np.dtype(">i4"), "i4"),
    13: (torch.float32, np.dtype(">f4"), "f4"),
    14: (torch.float64, np.dtype(">f8"), "f8"),
}

def read_sn3_pascalvincent_tensor(path: str, strict: bool = True) -> torch.Tensor:
    """Read a SN3 file in "Pascal Vincent" format (Lush file 'libidx/idx-io.lsh').

    Layout as parsed here: a 4-byte big-endian magic number whose lowest byte
    is the number of dimensions and whose upper part is the type code, one
    4-byte big-endian size per dimension, then the raw values.

    Args:
        path: Path of an uncompressed file. The file is opened with ``open``;
            compressed files and file objects are not handled.
        strict: When True, assert that the number of values in the file equals
            the product of the dimension sizes.

    Returns:
        A tensor with the shape given in the header. It owns its memory, so
        it can be modified in place.

    Raises:
        AssertionError: if the header declares an unsupported number of
            dimensions or type code, or if ``strict`` is set and the sizes
            do not match.
    """
    # read
    with open(path, "rb") as f:
        data = f.read()
    # parse
    magic = int.from_bytes(data[0:4], "big")
    nd = magic % 256
    ty = magic // 256
    assert 1 <= nd <= 3
    # Membership instead of a range check: code 10 lies inside 8..14 but has no
    # entry in the type map.
    assert ty in SN3_PASCALVINCENT_TYPEMAP
    _, raw_dtype, native_dtype = SN3_PASCALVINCENT_TYPEMAP[ty]
    s = [int.from_bytes(data[4 * (i + 1) : 4 * (i + 2)], "big") for i in range(nd)]
    parsed = np.frombuffer(data, dtype=raw_dtype, offset=(4 * (nd + 1)))
    assert parsed.shape[0] == np.prod(s) or not strict
    # `np.frombuffer` over `bytes` yields a read-only array. Always copy while
    # converting, so the tensor never aliases immutable memory.
    return torch.from_numpy(parsed.astype(native_dtype)).view(*s)

def read_label_file(path: str) -> torch.Tensor:
    """Load the label file at ``path`` and return its content as an int64 tensor.

    The file must decode to a one-dimensional ``uint8`` tensor.
    """
    labels = read_sn3_pascalvincent_tensor(path, strict=False)
    assert labels.dtype == torch.uint8
    assert labels.ndimension() == 1
    return labels.to(torch.int64)

def read_image_file(path: str) -> torch.Tensor:
    """Load the image file at ``path``.

    The file must decode to a three-dimensional ``uint8`` tensor, which is
    returned unchanged.
    """
    images = read_sn3_pascalvincent_tensor(path, strict=False)
    assert images.dtype == torch.uint8
    assert images.ndimension() == 3
    return images

"""
CIFAR10
_____________________________________________________

"""

import os
import os.path
import pickle
from typing import Any, Callable, Optional, Tuple

import numpy as np
import pandas as pd
from PIL import Image
from torchvision import transforms
from torchvision.datasets.utils import check_integrity, download_and_extract_archive
from torchvision.datasets.vision import VisionDataset
from torchvision.io import read_image

import ssl

# Ensure CIFAR100 can be loaded.
# NOTE(review): this swaps the module-wide default HTTPS context factory of
# `ssl` for the unverified one at import time. The assignment is global, so it
# is not limited to the CIFAR download -- presumably every HTTPS request that
# relies on the default context afterwards skips certificate verification.
# Confirm this is acceptable, or scope it to the download call.
ssl._create_default_https_context = ssl._create_unverified_context


class indices_CIFAR10(VisionDataset):
    """The standard PyTorch class, modified so that each datapoint is returned
    as (index, sample, target, target) instead of (sample, target). What
    follows is adapted from the docstring of the standard PyTorch class.

    `CIFAR10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.
    Args:
        root (string): Root directory of dataset where directory
            ``cifar-10-batches-py`` exists or will be downloaded to.
        train (bool, optional): If True, creates dataset from training set, otherwise
            creates from test set.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        sequence (optional): Indexable collection that maps a dataset position to
            a row of the loaded data. Defaults to ``np.arange(len(data))``.

    Raises:
        RuntimeError: if the files are missing or corrupted after the download
            attempt, or if the metadata file fails its integrity check.
    """

    base_folder = "cifar-10-batches-py"
    url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
    filename = "cifar-10-python.tar.gz"
    tgz_md5 = "c58f30108f718f92721af3b95e74349a"
    # [file name, md5 checksum]
    train_list = [
        ["data_batch_1", "c99cafc152244af753f735de768cd75f"],
        ["data_batch_2", "d4bba439e000b95fd0a9bffe97cbabec"],
        ["data_batch_3", "54ebc095f3ab1f0389bbae665268c751"],
        ["data_batch_4", "634d18415352ddfa80567beed471001a"],
        ["data_batch_5", "482c414d41f54cd18b22e5b47cb7c3cb"],
    ]

    test_list = [
        ["test_batch", "40351d587109b95175f43aff81a1287e"],
    ]
    meta = {
        "filename": "batches.meta",
        "key": "label_names",
        "md5": "5ff9c542aee3614f3951f8cda6e48888",
    }

    def __init__(
        self,
        root: str,
        train: bool = True,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        sequence=None,
    ) -> None:

        super(indices_CIFAR10, self).__init__(
            root, transform=transform, target_transform=target_transform
        )

        self.train = train  # training set or test set

        self.download()

        if not self._check_integrity():
            raise RuntimeError(
                "Dataset not found or corrupted even though we tried to download it"
            )

        if self.train:
            downloaded_list = self.train_list
        else:
            downloaded_list = self.test_list

        self.data: Any = []
        self.targets = []

        # now load the pickled numpy arrays
        for file_name, _ in downloaded_list:
            file_path = os.path.join(self.root, self.base_folder, file_name)
            with open(file_path, "rb") as f:
                # NOTE(security): unpickling can execute arbitrary code. The md5
                # of every batch file was verified by `_check_integrity` above.
                entry = pickle.load(f, encoding="latin1")
                self.data.append(entry["data"])
                # Batches carry their labels under "labels" or "fine_labels".
                if "labels" in entry:
                    self.targets.extend(entry["labels"])
                else:
                    self.targets.extend(entry["fine_labels"])

        self.data = np.vstack(self.data).reshape(-1, 3, 32, 32)
        self.data = self.data.transpose((0, 2, 3, 1))  # convert to HWC

        self._load_meta()
        if sequence is not None:
            self.sequence = sequence
        else:
            self.sequence = np.arange(len(self.data))

    def _load_meta(self) -> None:
        """Load the class names and build the name -> index mapping.

        Raises:
            RuntimeError: if the metadata file is missing or fails its md5 check.
        """
        path = os.path.join(self.root, self.base_folder, self.meta["filename"])
        if not check_integrity(path, self.meta["md5"]):
            raise RuntimeError("Dataset metadata file not found or corrupted")
        with open(path, "rb") as infile:
            # NOTE(security): unpickling; the md5 was verified just above.
            data = pickle.load(infile, encoding="latin1")
            self.classes = data[self.meta["key"]]
        self.class_to_idx = {_class: i for i, _class in enumerate(self.classes)}

    def __getitem__(self, index: int) -> Tuple[int, Any, Any, Any]:
        """
        Args:
            index (int): Index
        Returns:
            tuple: (index, image, target, target) where target is index of the
            target class. ``index`` is the requested position, and the image and
            target are looked up through ``self.sequence[index]``.
        """
        position = self.sequence[index]
        img, target = self.data[position], self.targets[position]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return index, img, target, target #second target is category label for robust setting

    def __len__(self) -> int:
        return len(self.sequence)

    def _check_integrity(self) -> bool:
        """Return True when every train and test batch file passes its md5 check."""
        root = self.root
        for fentry in self.train_list + self.test_list:
            filename, md5 = fentry[0], fentry[1]
            fpath = os.path.join(root, self.base_folder, filename)
            if not check_integrity(fpath, md5):
                return False
        return True

    def download(self) -> None:
        """Download and extract the archive unless all batch files already verify."""
        if self._check_integrity():
            print("Files already downloaded and verified")
            return
        download_and_extract_archive(
            self.url, self.root, filename=self.filename, md5=self.tgz_md5
        )

    def extra_repr(self) -> str:
        return "Split: {}".format("Train" if self.train is True else "Test")


class indices_CIFAR2(indices_CIFAR10):
    """
    Subset of CIFAR10 restricted to a few classes (classes 1 and 9 unless
    ``classes`` is given), meant as a simpler setting than full CIFAR10.

    After the CIFAR10 data is loaded, only the samples of the selected classes
    are kept and their targets are renumbered to ``0 .. len(classes) - 1``,
    following the order of the remaining class names.
    """

    def __init__(self,
            root: str,
            train: bool = True,
            transform: Optional[Callable] = None,
            target_transform: Optional[Callable] = None,
            sequence=None,
            classes=None) -> None:

        # Load the complete CIFAR10 split first.
        super().__init__(root,
                    train=train,
                    transform=transform,
                    target_transform=target_transform,
                    sequence=sequence)

        subset_classes = [1, 9] if classes is None else classes

        # Remember the CIFAR10 numbering before it gets replaced.
        previous_class_to_idx = self.class_to_idx

        # A numpy array allows boolean masking of the targets.
        self.targets = np.array(self.targets)

        # Reduce the metadata to the selected classes.
        self.classes = [
            name
            for position, name in enumerate(self.classes)
            if position in subset_classes
        ]
        self.class_to_idx = {
            name: position for position, name in enumerate(self.classes)
        }

        # Keep only the samples whose target is one of the selected classes.
        keep = np.isin(self.targets, subset_classes)
        self.data = self.data[keep]
        self.targets = self.targets[keep]

        # Renumber the targets from the CIFAR10 index to the subset index.
        for name in self.classes:
            stale = self.targets == previous_class_to_idx[name]
            self.targets[stale] = self.class_to_idx[name]

        # The data shrank, so the sequence has to be rebuilt.
        self.sequence = np.arange(len(self.data)) if sequence is None else sequence
                    

class indices_CIFAR100(indices_CIFAR10):
    """`CIFAR100 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.

    This is a subclass of the `CIFAR10` Dataset. In addition to the fine class
    label, every item carries the id of the superclass (see ``superclass``)
    that the fine class belongs to.
    """

    base_folder = "cifar-100-python"
    url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
    filename = "cifar-100-python.tar.gz"
    tgz_md5 = "eb9058c3a382ffc7106e4002c42a8d85"
    train_list = [
        ["train", "16019d7e3df5f24257cddd939b257f8d"],
    ]

    test_list = [
        ["test", "f0ef6b0ae62326f3e7ffdfab6717acfc"],
    ]
    meta = {
        "filename": "meta",
        "key": "fine_label_names",
        "md5": "7973b15100ade9c7d40fb424638fde48",
    }

    # Superclass id -> names of the fine classes grouped under it.
    superclass = {
        0: ['beaver', 'dolphin', 'otter', 'seal', 'whale'],
        1: ['aquarium_fish', 'flatfish', 'ray', 'shark', 'trout'],
        2: ['orchid', 'poppy', 'rose', 'sunflower', 'tulip'],
        3: ['bottle', 'bowl', 'can', 'cup', 'plate'],
        4: ['apple', 'mushroom', 'orange', 'pear', 'sweet_pepper'],
        5: ['clock', 'keyboard', 'lamp', 'telephone', 'television'],
        6: ['bed', 'chair', 'couch', 'table', 'wardrobe'],
        7: ['bee', 'beetle', 'butterfly', 'caterpillar', 'cockroach'],
        8: ['bear', 'leopard', 'lion', 'tiger', 'wolf'],
        9: ['bridge', 'castle', 'house', 'road', 'skyscraper'],
        10: ['cloud', 'forest', 'mountain', 'plain', 'sea'],
        11: ['camel', 'cattle', 'chimpanzee', 'elephant', 'kangaroo'],
        12: ['fox', 'porcupine', 'possum', 'raccoon', 'skunk'],
        13: ['crab', 'lobster', 'snail', 'spider', 'worm'],
        14: ['baby', 'boy', 'girl', 'man', 'woman'],
        15: ['crocodile', 'dinosaur', 'lizard', 'snake', 'turtle'],
        16: ['hamster', 'mouse', 'rabbit', 'shrew', 'squirrel'],
        17: ['maple_tree', 'oak_tree', 'palm_tree', 'pine_tree', 'willow_tree'],
        18: ['bicycle', 'bus', 'motorcycle', 'pickup_truck', 'train'],
        19: ['lawn_mower', 'rocket', 'streetcar', 'tank', 'tractor'],
    }

    def __init__(self, *args, **kwargs):
        """Initialise the parent dataset, then build the class -> superclass map."""
        super().__init__(*args, **kwargs)
        self.group_index = None
        self.invert_index()

    def invert_index(self):
        """Populate ``self.group_index`` (fine class index -> superclass id).

        A fine class whose name is not listed in ``superclass`` maps to ``None``.
        """
        # Reverse the superclass table once instead of scanning every group
        # for every class; a name listed twice keeps its last group.
        name_to_group = {
            name: group_id
            for group_id, names in self.superclass.items()
            for name in names
        }
        self.group_index = {
            class_idx: name_to_group.get(class_name)
            for class_name, class_idx in self.class_to_idx.items()
        }

    def __getitem__(self, index: int) -> Tuple[int, Any, Any, Any]:
        """
        Args:
            index (int): Index
        Returns:
            tuple: (index, image, target, category) where target is index of the
            target class and category is the superclass id of that class.
        """
        # ``sequence`` maps the requested position onto the stored sample.
        position = self.sequence[index]
        img, target = self.data[position], self.targets[position]
        # Looked up before target_transform so the key is the raw class index.
        category = self.group_index[target]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        # category is the group label used in the robust setting
        return index, img, target, category

"""
WATERBIRDS
_____________________________________________________

"""

class indices_WaterBirds(VisionDataset):
    """Waterbirds dataset whose items are ``(index, image, target, group)``.

    The archive is downloaded into ``root`` when its folder is missing, and
    ``metadata.csv`` inside it provides the file names, labels (``y``), groups
    (``place``) and split codes (0 = train, 1 = val, 2 = test).
    """

    base_folder = "waterbird_complete95_forest2water2"
    url = "https://nlp.stanford.edu/data/dro/waterbird_complete95_forest2water2.tar.gz"
    filename = "waterbird_complete95_forest2water2.tar.gz"

    def __init__(self,
                 root: str,
                 train: bool,
                 dims: tuple,
                 transform: Optional[Callable] = None,
                 trainset_transform: Optional[Callable] = None,
                 valset_transform: Optional[Callable] = None,
                 target_transform: Optional[Callable] = None,
                 sequence=None,
                 init_loaded: bool = False):
        """
        Reformatted from the CUB Dataset class found at:
        https://github.com/kohpangwei/group_DRO/blob/master/data/cub_dataset.py

        Parameters
        ----------
        root : str
            Directory that contains (or will receive) the dataset folder.
        train : bool
            When True keep the rows whose split code is 0 (train) or 1 (val),
            otherwise keep the rows whose split code is 2 (test).
        dims : tuple
            Size handed to ``transforms.Resize``; used when images are
            preloaded (``init_loaded=True``).
        transform : Optional[Callable], optional
            Applied to an image loaded lazily in ``__getitem__``.
        trainset_transform, valset_transform : Optional[Callable], optional
            Stored on the instance; not applied by this class.
        target_transform : Optional[Callable], optional
            Applied to both the target and the group label.
        sequence : optional
            Maps dataset positions onto rows of the selected split. Defaults
            to ``np.arange`` over the selected rows.
        init_loaded : bool, optional
            When True, load (and resize) every image when the object is
            created; when False, load a single image each time
            ``__getitem__`` is called.
        """
        self.root = root
        self.train = train
        self.dims = dims
        self.sequence = sequence
        self.init_loaded = init_loaded
        self.target_transform = target_transform
        self.trainset_transform = trainset_transform
        self.valset_transform = valset_transform

        super().__init__(
            root, transform=transform, target_transform=target_transform
        )

        self.data_dir = os.path.join(self.root, self.base_folder)
        self.download()

        # Read in the metadata:
        self.metadata = pd.read_csv(os.path.join(self.data_dir,
                                                 'metadata.csv'))

        self.targets = self.metadata['y'].values
        self.groups = self.metadata['place'].values
        self.num_classes = 2

        self.filename_array = self.metadata['img_filename'].values
        self.split_array = self.metadata['split'].values
        self.split_dict = {
            'train': 0,
            'val': 1,
            'test': 2}

        # Restrict every per-sample array to the requested split(s).
        if self.train:
            split_filter = np.isin(
                self.split_array,
                [self.split_dict['train'], self.split_dict['val']])
        else:
            split_filter = self.split_array == self.split_dict['test']

        self.filename_array = self.filename_array[split_filter]
        self.targets = list(self.targets[split_filter])
        self.groups = list(self.groups[split_filter])
        self.split = self.split_array[split_filter]  # split code of each kept row

        # Preloaded images are resized so they can be stacked into one tensor.
        resize_transform = transforms.Resize(self.dims)

        self.data = None
        if self.init_loaded:
            self.data = self.load_data(self.filename_array,
                                       resize_transform)

        if sequence is not None:
            self.sequence = sequence
        else:
            self.sequence = np.arange(len(self.filename_array))
            if self.train:  # positions of the train and val rows
                self.train_split_indices = self.sequence[self.split == 0]
                self.val_split_indices = self.sequence[self.split == 1]

    def load_data(self, filenames, transforms):
        """Read one image or a collection of images from ``self.data_dir``.

        Args:
            filenames: a single file name (``str`` / path-like) or an iterable
                of file names (list, tuple, numpy array, ...).
            transforms: callable applied to every loaded image, or ``None``
                to return the images as read.

        Returns:
            The float image tensor for a single file name, otherwise the
            ``torch.stack`` of all loaded images.
        """
        def load_one(name):
            image = read_image(os.path.join(self.data_dir, name)).float()
            return image if transforms is None else transforms(image)

        if isinstance(filenames, (str, os.PathLike)):
            return load_one(filenames)

        # Could be parallelised; images are currently read one after another.
        return torch.stack([load_one(name) for name in filenames])

    def assert_index_type(self, idx):
        """Return ``idx`` converted to a long-integer array/tensor."""
        if isinstance(idx, np.ndarray):
            idx = idx.astype(np.longlong)
        elif isinstance(idx, torch.Tensor):
            idx = idx.to(dtype=torch.long)
        else:
            idx = torch.tensor(idx, dtype=torch.long)

        return idx

    def __getitem__(self, idx: int):
        """Return ``(idx, image, target, group)`` for dataset position ``idx``."""
        position = self.sequence[idx]
        target = self.targets[position]
        group = self.groups[position]
        split = self.split[position]

        if self.init_loaded:
            img = self.data[position]
        else:
            img = self.load_data(self.filename_array[position],
                                 self.transform)

        assert len(np.unique(split)) == 1, 'indices_waterbirds split should only have one unqiue elements'

        # NOTE: trainset_transform / valset_transform are stored but not
        # applied here; no split-specific transform is active.

        if self.target_transform is not None:
            target = self.target_transform(target)
            # The same transform is reused for the group label.
            group = self.target_transform(group)

        return idx, img, target, group

    def download(self):
        """Download and extract the archive unless ``self.data_dir`` exists."""
        if not os.path.exists(self.data_dir):
            print(f'File {self.data_dir} does not exist downloading...')
            download_and_extract_archive(self.url, self.root)

        else:
            print(f'Files {self.data_dir} already downloaded')

    def __len__(self):
        # Length follows ``sequence`` so a subset sequence is never over-indexed.
        return len(self.sequence)


"""
ImageFolder
_____________________________________________________

"""

import os
import os.path
from typing import Any, Callable, Dict, List, Optional, Tuple, cast

from PIL import Image
from torchvision.datasets.vision import VisionDataset


def has_file_allowed_extension(
    filename: str, extensions: Union[str, Tuple[str, ...]]
) -> bool:
    """Checks if a file is an allowed extension.

    Args:
        filename (string): path to a file
        extensions (string or iterable of strings): extensions to consider
            (lowercase). A single string, a tuple or a list are all accepted.

    Returns:
        bool: True if the filename ends with one of given extensions
    """
    # str.endswith only accepts a str or a tuple, so other iterables
    # (e.g. a list) are converted first.
    suffixes = extensions if isinstance(extensions, str) else tuple(extensions)
    return filename.lower().endswith(suffixes)


def is_image_file(filename: str) -> bool:
    """Tell whether ``filename`` carries one of the known image extensions.

    Args:
        filename (string): path to a file

    Returns:
        bool: True if the filename ends with an entry of ``IMG_EXTENSIONS``
    """
    return has_file_allowed_extension(filename=filename, extensions=IMG_EXTENSIONS)


def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]:
    """Collect the class folders directly under ``directory``.

    Every sub-directory is a class; the names are sorted and numbered from 0
    in that order.

    See :class:`DatasetFolder` for details.

    Raises:
        FileNotFoundError: If ``directory`` contains no sub-directory.
    """
    class_names = [entry.name for entry in os.scandir(directory) if entry.is_dir()]
    class_names.sort()
    if len(class_names) == 0:
        raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")

    return class_names, dict(zip(class_names, range(len(class_names))))


def make_dataset(
    directory: str,
    class_to_idx: Optional[Dict[str, int]] = None,
    extensions: Optional[Tuple[str, ...]] = None,
    is_valid_file: Optional[Callable[[str], bool]] = None,
) -> List[Tuple[str, int]]:
    """Generates a list of samples of a form (path_to_sample, class).

    See :class:`DatasetFolder` for details.

    Note: The class_to_idx parameter is here optional and will use the logic of the ``find_classes`` function
    by default.

    Args:
        directory (str): root dataset directory (``~`` is expanded).
        class_to_idx (optional): mapping from class folder name to class index.
        extensions (optional): allowed file extensions. Exactly one of
            ``extensions`` and ``is_valid_file`` must be given.
        is_valid_file (optional): predicate called with the full path of each
            candidate file.

    Raises:
        ValueError: If ``class_to_idx`` is empty, or if ``extensions`` and
            ``is_valid_file`` are both None or both given.
        FileNotFoundError: If some class has no valid file.

    Returns:
        List[Tuple[str, int]]: samples of a form (path_to_sample, class)
    """
    directory = os.path.expanduser(directory)

    if class_to_idx is None:
        _, class_to_idx = find_classes(directory)
    elif not class_to_idx:
        raise ValueError(
            "'class_to_index' must have at least one entry to collect any samples."
        )

    both_none = extensions is None and is_valid_file is None
    both_something = extensions is not None and is_valid_file is not None
    if both_none or both_something:
        raise ValueError(
            "Both extensions and is_valid_file cannot be None or not None at the same time"
        )

    if extensions is not None:

        def accepts(path: str) -> bool:
            return has_file_allowed_extension(path, cast(Tuple[str, ...], extensions))

    else:
        accepts = cast(Callable[[str], bool], is_valid_file)

    instances = []
    available_classes = set()
    for target_class in sorted(class_to_idx.keys()):
        class_index = class_to_idx[target_class]
        target_dir = os.path.join(directory, target_class)
        if not os.path.isdir(target_dir):
            continue
        for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
            for fname in sorted(fnames):
                # The predicate receives the full path (not the bare file
                # name) so that it can inspect the file itself.
                path = os.path.join(root, fname)
                if accepts(path):
                    instances.append((path, class_index))
                    available_classes.add(target_class)

    empty_classes = set(class_to_idx.keys()) - available_classes
    if empty_classes:
        msg = (
            f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. "
        )
        if extensions is not None:
            msg += f"Supported extensions are: {', '.join(extensions)}"
        raise FileNotFoundError(msg)

    return instances


class indices_DatasetFolder(VisionDataset):
    """The standard PyTorch class, modified so that each datapoint is returned
    as (index, sample, target, target) instead of (sample, target). Index is
    the position requested from the dataset; the last element repeats the
    target because classes currently double as groups. What follows is the
    docstring of the standard PyTorch class.


    A generic data loader.

    This default directory structure can be customized by overriding the
    :meth:`find_classes` method.

    Args:
        root (string): Root directory path.
        loader (callable): A function to load a sample given its path.
        extensions (tuple[string]): A list of allowed extensions.
            both extensions and is_valid_file should not be passed.
        transform (callable, optional): A function/transform that takes in
            a sample and returns a transformed version.
            E.g, ``transforms.RandomCrop`` for images.
        target_transform (callable, optional): A function/transform that takes
            in the target and transforms it.
        is_valid_file (callable, optional): A function that takes path of a file
            and check if the file is a valid file (used to check of corrupt files)
            both extensions and is_valid_file should not be passed.

     Attributes:
        classes (list): List of the class names sorted alphabetically.
        class_to_idx (dict): Dict with items (class_name, class_index).
        samples (list): List of (sample path, class_index) tuples
        targets (list): The class_index value for each image in the dataset
        sequence (ndarray): Maps a dataset position onto an entry of
            ``samples`` / ``targets``; defaults to ``np.arange(len(samples))``.
    """

    def __init__(
        self,
        root: str,
        loader: Callable[[str], Any],
        extensions: Optional[Tuple[str, ...]] = None,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        is_valid_file: Optional[Callable[[str], bool]] = None,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)
        classes, class_to_idx = self.find_classes(self.root)
        samples = self.make_dataset(self.root, class_to_idx, extensions, is_valid_file)

        self.loader = loader
        self.extensions = extensions

        self.classes = classes
        self.class_to_idx = class_to_idx
        self.samples = samples  # the targets are **also** here!
        self.targets = [s[1] for s in samples]

        self.sequence = np.arange(len(self.samples))

    @staticmethod
    def make_dataset(
        directory: str,
        class_to_idx: Dict[str, int],
        extensions: Optional[Tuple[str, ...]] = None,
        is_valid_file: Optional[Callable[[str], bool]] = None,
    ) -> List[Tuple[str, int]]:
        """Generates a list of samples of a form (path_to_sample, class).

        This can be overridden to e.g. read files from a compressed zip file instead of from the disk.

        Args:
            directory (str): root dataset directory, corresponding to ``self.root``.
            class_to_idx (Dict[str, int]): Dictionary mapping class name to class index.
            extensions (optional): A list of allowed extensions.
                Either extensions or is_valid_file should be passed. Defaults to None.
            is_valid_file (optional): A function that takes path of a file
                and checks if the file is a valid file
                (used to check of corrupt files) both extensions and
                is_valid_file should not be passed. Defaults to None.

        Raises:
            ValueError: In case ``class_to_idx`` is None or empty.
            ValueError: In case ``extensions`` and ``is_valid_file`` are None or both are not None.
            FileNotFoundError: In case no valid file was found for any class.

        Returns:
            List[Tuple[str, int]]: samples of a form (path_to_sample, class)
        """
        if class_to_idx is None:
            # prevent potential bug since make_dataset() would use the class_to_idx logic of the
            # find_classes() function, instead of using that of the find_classes() method, which
            # is potentially overridden and thus could have a different logic.
            raise ValueError("The class_to_idx parameter cannot be None.")
        return make_dataset(
            directory, class_to_idx, extensions=extensions, is_valid_file=is_valid_file
        )

    def find_classes(self, directory: str) -> Tuple[List[str], Dict[str, int]]:
        """Find the class folders in a dataset structured as follows::

            directory/
                class_x/
                    xxx.ext
                    xxy.ext
                    .../xxz.ext
                class_y/
                    123.ext
                    nsdf3.ext
                    .../asd932_.ext

        This method can be overridden to only consider
        a subset of classes, or to adapt to a different dataset directory structure.

        Args:
            directory(str): Root directory path, corresponding to ``self.root``

        Raises:
            FileNotFoundError: If ``directory`` has no class folders.

        Returns:
            (Tuple[List[str], Dict[str, int]]): List of all classes and dictionary mapping each class to an index.
        """
        return find_classes(directory)

    def __getitem__(self, index: int) -> Tuple[int, Any, Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (index, sample, target, target) where target is class_index
            of the target class; it is returned twice because classes are
            used as groups for the time being.
        """
        # Resolve the stored position once, as a plain int, for both lookups.
        position = int(self.sequence[index])
        path, _ = self.samples[position]
        # The target comes from ``targets`` rather than ``samples`` so that
        # labels rewritten in ``targets`` (dataset corruption) take effect.
        target = self.targets[position]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return index, sample, target, target

    def __len__(self) -> int:
        return len(self.sequence)


# Lower-case file suffixes that are treated as image files.
IMG_EXTENSIONS = (
    ".jpg", ".jpeg", ".png",
    ".ppm", ".bmp", ".pgm",
    ".tif", ".tiff", ".webp",
)


def pil_loader(path: str) -> Image.Image:
    """Open the image at ``path`` with PIL and return it converted to RGB."""
    # The file is opened explicitly so that it is closed on exit, which avoids a
    # ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
    with open(path, "rb") as handle:
        return Image.open(handle).convert("RGB")


def accimage_loader(path: str) -> Any:
    """Load ``path`` with accimage, falling back to ``pil_loader`` on IOError."""
    import accimage

    try:
        image = accimage.Image(path)
    except IOError:
        # Potentially a decoding problem, fall back to PIL.Image
        image = pil_loader(path)
    return image


def default_loader(path: str) -> Any:
    """Load ``path`` with the loader matching torchvision's image backend."""
    from torchvision import get_image_backend

    # accimage is used only when it is the configured backend; PIL otherwise.
    loader = accimage_loader if get_image_backend() == "accimage" else pil_loader
    return loader(path)


class indices_ImageFolder(indices_DatasetFolder):
    """Image-folder dataset built on :class:`indices_DatasetFolder`.

    Items are produced by ``indices_DatasetFolder.__getitem__``; this class
    only fixes the allowed extensions to ``IMG_EXTENSIONS`` (unless a custom
    ``is_valid_file`` is supplied) and exposes ``samples`` as ``imgs``.

    The images are expected to be arranged in this way by default: ::

        root/dog/xxx.png
        root/dog/xxy.png
        root/dog/[...]/xxz.png

        root/cat/123.png
        root/cat/nsdf3.png
        root/cat/[...]/asd932_.png

    Args:
        root (string): Root directory path.
        transform (callable, optional): A function/transform that takes in the
            loaded image and returns a transformed version.
            E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        loader (callable, optional): A function to load an image given its path.
        is_valid_file (callable, optional): A function that decides whether a
            candidate file is kept; when given, no extension filter is used.

     Attributes:
        classes (list): List of the class names sorted alphabetically.
        class_to_idx (dict): Dict with items (class_name, class_index).
        imgs (list): List of (image path, class_index) tuples
    """

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        loader: Callable[[str], Any] = default_loader,
        is_valid_file: Optional[Callable[[str], bool]] = None,
    ):
        # Extensions and is_valid_file are mutually exclusive in the parent,
        # so the extension filter is dropped when a predicate is supplied.
        allowed_extensions = None if is_valid_file is not None else IMG_EXTENSIONS
        super().__init__(
            root,
            loader,
            allowed_extensions,
            transform=transform,
            target_transform=target_transform,
            is_valid_file=is_valid_file,
        )
        self.imgs = self.samples


import warnings
from contextlib import contextmanager
import os
import shutil
import tempfile
from typing import Any, Dict, List, Iterator, Optional, Tuple
import torch

# Archive key -> (archive file name, md5 checksum passed to the integrity check).
ARCHIVE_META = dict(
    train=("ILSVRC2012_img_train.tar", "1d675b47d978889d74fa0da5fadfb00e"),
    val=("ILSVRC2012_img_val.tar", "29b22e2961454d5413ddabcf34fc5622"),
    devkit=("ILSVRC2012_devkit_t12.tar.gz", "fa75699e90414af021442c21a62c3abf"),
)

# Name of the metadata file written to / read from the dataset root.
META_FILE = "meta.bin"


class indices_ImageNet(indices_ImageFolder):
    """`ImageNet <http://image-net.org/>`_ 2012 Classification Dataset.

    Args:
        root (string): Root directory of the ImageNet Dataset.
        split (string, optional): The dataset split, supports ``train``, or ``val``.
        download (bool, optional): Deprecated. ``True`` raises a RuntimeError
            and ``False`` emits a RuntimeWarning; leave it as ``None``.
        transform (callable, optional): A function/transform that  takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        loader (callable, optional): A function to load an image given its path.

     Attributes:
        classes (list): List of the class name tuples.
        class_to_idx (dict): Dict with items (class_name, class_index).
        wnids (list): List of the WordNet IDs.
        wnid_to_idx (dict): Dict with items (wordnet_id, class_index).
        imgs (list): List of (image path, class_index) tuples
        targets (list): The class_index value for each image in the dataset
    """

    def __init__(
        self,
        root: str,
        split: str = "train",
        download: Optional[bool] = None,
        **kwargs: Any,
    ) -> None:
        if download is True:
            msg = (
                "The dataset is no longer publicly accessible. You need to "
                "download the archives externally and place them in the root "
                "directory."
            )
            raise RuntimeError(msg)
        elif download is False:
            msg = (
                "The use of the download flag is deprecated, since the dataset "
                "is no longer publicly accessible."
            )
            warnings.warn(msg, RuntimeWarning)

        # root and split must exist before the parent constructor runs because
        # parse_archives() and split_folder read them.
        root = self.root = os.path.expanduser(root)
        self.split = verify_str_arg(split, "split", ("train", "val"))

        self.parse_archives()
        wnid_to_classes = load_meta_file(self.root)[0]

        # The parent is rooted at the split folder; afterwards ``root`` is
        # pointed back at the dataset root.
        super().__init__(self.split_folder, **kwargs)
        self.root = root

        # The folder names found by the parent are WordNet ids; keep them and
        # expose the class-name tuples as ``classes``.
        self.wnids = self.classes
        self.wnid_to_idx = self.class_to_idx
        self.classes = [wnid_to_classes[wnid] for wnid in self.wnids]
        self.class_to_idx = {
            cls: idx for idx, clss in enumerate(self.classes) for cls in clss
        }

    def parse_archives(self) -> None:
        """Create the meta file and the split folder from the archives if missing."""
        if not check_integrity(os.path.join(self.root, META_FILE)):
            parse_devkit_archive(self.root)

        if not os.path.isdir(self.split_folder):
            if self.split == "train":
                parse_train_archive(self.root)
            elif self.split == "val":
                parse_val_archive(self.root)

    @property
    def split_folder(self) -> str:
        """Path of the folder holding the images of the selected split."""
        return os.path.join(self.root, self.split)

    def extra_repr(self) -> str:
        return "Split: {split}".format(**self.__dict__)


def load_meta_file(
    root: str, file: Optional[str] = None
) -> Tuple[Dict[str, str], List[str]]:
    """Load the meta file stored in ``root`` with ``torch.load``.

    Args:
        root (str): Directory that contains the meta file.
        file (str, optional): Name of the meta file. Defaults to ``META_FILE``.

    Raises:
        RuntimeError: If the integrity check of the file fails.
    """
    meta_path = os.path.join(root, META_FILE if file is None else file)

    if not check_integrity(meta_path):
        msg = (
            "The meta file {} is not present in the root directory or is corrupted. "
            "This file is automatically created by the ImageNet dataset."
        )
        raise RuntimeError(msg.format(meta_path, root))

    return torch.load(meta_path)


def _verify_archive(root: str, file: str, md5: str) -> None:
    """Raise RuntimeError unless ``root/file`` passes the md5 integrity check."""
    archive_path = os.path.join(root, file)
    if check_integrity(archive_path, md5):
        return

    msg = (
        "The archive {} is not present in the root directory or is corrupted. "
        "You need to download it externally and place it in {}."
    )
    raise RuntimeError(msg.format(file, root))


def parse_devkit_archive(root: str, file: Optional[str] = None) -> None:
    """Extract the ImageNet2012 devkit archive and store its meta information
    as a binary file (``META_FILE``) inside ``root``.

    Args:
        root (str): Root directory containing the devkit archive
        file (str, optional): Name of devkit archive. Defaults to
            'ILSVRC2012_devkit_t12.tar.gz'
    """
    import scipy.io as sio

    def parse_meta_mat(devkit_root: str) -> Tuple[Dict[int, str], Dict[str, str]]:
        # Read the synset table and keep only the entries without children.
        synsets = sio.loadmat(
            os.path.join(devkit_root, "data", "meta.mat"), squeeze_me=True
        )["synsets"]
        child_counts = list(zip(*synsets))[4]
        leaves = [
            synsets[position]
            for position, count in enumerate(child_counts)
            if count == 0
        ]
        idcs, wnids, class_names = list(zip(*leaves))[:3]
        class_tuples = [tuple(names.split(", ")) for names in class_names]
        return dict(zip(idcs, wnids)), dict(zip(wnids, class_tuples))

    def parse_val_groundtruth_txt(devkit_root: str) -> List[int]:
        # One integer per line of the ground-truth file.
        truth_path = os.path.join(
            devkit_root, "data", "ILSVRC2012_validation_ground_truth.txt"
        )
        with open(truth_path, "r") as txtfh:
            return [int(line) for line in txtfh]

    @contextmanager
    def get_tmp_dir() -> Iterator[str]:
        # Temporary directory that is removed when the context exits.
        scratch = tempfile.mkdtemp()
        try:
            yield scratch
        finally:
            shutil.rmtree(scratch)

    default_name, md5 = ARCHIVE_META["devkit"]
    if file is None:
        file = default_name

    _verify_archive(root, file, md5)

    with get_tmp_dir() as tmp_dir:
        extract_archive(os.path.join(root, file), tmp_dir)

        devkit_root = os.path.join(tmp_dir, "ILSVRC2012_devkit_t12")
        idx_to_wnid, wnid_to_classes = parse_meta_mat(devkit_root)
        val_wnids = [
            idx_to_wnid[idx] for idx in parse_val_groundtruth_txt(devkit_root)
        ]

        torch.save((wnid_to_classes, val_wnids), os.path.join(root, META_FILE))


def parse_train_archive(
    root: str, file: Optional[str] = None, folder: str = "train"
) -> None:
    """Extract the ImageNet2012 train images archive into ``root/folder`` and
    then extract every archive found there into a folder of its own.

    Args:
        root (str): Root directory containing the train images archive
        file (str, optional): Name of train images archive. Defaults to
            'ILSVRC2012_img_train.tar'
        folder (str, optional): Optional name for train images folder. Defaults to
            'train'
    """
    default_name, md5 = ARCHIVE_META["train"]
    if file is None:
        file = default_name

    _verify_archive(root, file, md5)

    train_root = os.path.join(root, folder)
    extract_archive(os.path.join(root, file), train_root)

    # Each entry is unpacked next to itself (extension stripped) and then removed.
    for entry in os.listdir(train_root):
        archive = os.path.join(train_root, entry)
        target_dir, _ = os.path.splitext(archive)
        extract_archive(archive, target_dir, remove_finished=True)


def parse_val_archive(
    root: str,
    file: Optional[str] = None,
    wnids: Optional[List[str]] = None,
    folder: str = "val",
) -> None:
    """Extract the ImageNet2012 validation images archive into ``root/folder``
    and move every image into a sub-folder named after its WordNet ID.

    Args:
        root (str): Root directory containing the validation images archive
        file (str, optional): Name of validation images archive. Defaults to
            'ILSVRC2012_img_val.tar'
        wnids (list, optional): List of WordNet IDs of the validation images. If None
            is given, the IDs are loaded from the meta file in the root directory
        folder (str, optional): Optional name for validation images folder. Defaults to
            'val'
    """
    default_name, md5 = ARCHIVE_META["val"]
    if file is None:
        file = default_name
    if wnids is None:
        wnids = load_meta_file(root)[1]

    _verify_archive(root, file, md5)

    val_root = os.path.join(root, folder)
    extract_archive(os.path.join(root, file), val_root)

    # Sorted image paths are paired with the WordNet IDs by position.
    images = sorted(os.path.join(val_root, name) for name in os.listdir(val_root))

    for wnid in set(wnids):
        os.mkdir(os.path.join(val_root, wnid))

    for wnid, img_file in zip(wnids, images):
        destination = os.path.join(val_root, wnid, os.path.basename(img_file))
        shutil.move(img_file, destination)


"""
---------------------------------FUSION DATASET -------------------------------
"""

#%%
import pandas as pd
from torch.utils.data import Dataset

class indices_Fusion(Dataset):
    """Tabular dataset read from a pickled pandas DataFrame.

    Each item is ``(index, features, target, target)``. The target is derived
    from the ``ITG`` / ``ETG`` / ``TEM`` columns: 0 when none of them equals 1,
    otherwise 1 (ITG), 2 (ETG) or 3 (TEM), where a later flag overrides an
    earlier one when several are set.
    """

    def __init__(self, root: str,
                     train: bool = True,
                     transform: Optional[Callable] = None,
                     target_transform: Optional[Callable] = None,
                     sequence=None):
        """
        Args:
            root: path of the pickled DataFrame.
            train: stored on the instance; the same file is read either way.
            transform, target_transform: stored on the instance but not
                applied by ``__getitem__``.
            sequence: optional mapping from dataset position to row; defaults
                to every row in order.
        """
        self.root = root
        self.train = train
        self.transform = transform
        self.target_transform = target_transform

        # NOTE(review): read_pickle executes pickle data; only point ``root``
        # at files from a trusted source.
        dataset = pd.read_pickle(root)

        # Remove regression labels:
        columns = [col for col in dataset.columns
                   if col not in ["efiitg_gb", "efetem_gb", "efeetg_gb"]]

        dataset = dataset[columns]

        # Process true labels:
        unstable_filter = (dataset['ITG'] == 1) |\
                          (dataset['ETG'] == 1) |\
                          (dataset['TEM'] == 1)

        stable_filter = (dataset['ITG'] == 0) & \
                        (dataset['TEM'] == 0) & \
                        (dataset['ETG'] == 0)

        # Keep rows that are either unstable (some flag is 1) or stable (all
        # flags are 0). The copy makes the label assignment below operate on
        # an independent frame instead of a slice of the loaded one.
        dataset = dataset[unstable_filter | stable_filter].copy()

        dataset['label'] = 0
        dataset.loc[dataset['ITG'] == 1, 'label'] = 1
        dataset.loc[dataset['ETG'] == 1, 'label'] = 2
        dataset.loc[dataset['TEM'] == 1, 'label'] = 3

        self.targets = torch.tensor(dataset['label'].values)
        # NOTE(review): the ITG/ETG/TEM columns that define the label stay in
        # the feature tensor - confirm this is intended.
        self.dataset = torch.tensor(
            dataset.loc[:, dataset.columns != 'label'].values)

        # Same convention as the other indices_* datasets: ``sequence`` maps a
        # dataset position onto a stored row.
        if sequence is not None:
            self.sequence = sequence
        else:
            self.sequence = np.arange(len(self.dataset))

    def __len__(self):
        return len(self.sequence)

    def download(self):
        # We need to set this up to get the dataset from google
        pass

    def __getitem__(self, index):
        """Return ``(index, features, target, target)`` for position ``index``."""
        position = int(self.sequence[index])
        data = self.dataset[position]
        target = self.targets[position]

        # The target is returned twice; the second value is the group slot.
        return index, data, target, target
