import nninfo
import numpy as np
from numpy.random import Philox, Generator
import torch.utils.data
import copy
import torchvision
import torchvision.transforms as transforms
from abc import ABC, abstractmethod
import os.path
from os import path

# Separator inserted between a dataset name and the name of each nested subset
# when subset names are composed by the split helpers in this module.
SUBSET_SYMBOL = "/"  # <dataset>/<subset>/<subsubset> etc.

# Module-level logger, obtained through nninfo's logging helper.
log = nninfo.log.get_logger(__name__)


class TaskManager(nninfo.exp_comp.ExperimentComponent):
    """
    Helper class, handles a task with one 'full_set' dataset and several subsets,
    which are then used by Trainer, Tester or Evaluation.
    """

    def __init__(self, task_id=None, reload=False, component_dir=None, **kwargs):
        """
        Creates a new instance of TaskManager. Loads the
        dataset according to task_id.

        Keyword Args:
            task_id (str): one of the pre-implemented Tasks:
                'tishby_dat', 'andreas_dat', 'fake_dat' etc.
            reload (bool): whether to restore the task and its subset tree
                from a previously saved 'task.json' in component_dir instead
                of creating it from task_id.
            component_dir: directory handed to the FileManager when
                reload is True; unused otherwise.

        Passes additional arguments on to dataset creation if given.
        """

        super().__init__()

        self._task = None
        self._dataset = None
        self._kwargs = kwargs
        self._save_data = False

        if not reload:
            self._init_datasets(task_id, **kwargs)
        else:
            self._reload_datasets(component_dir)

    def _init_datasets(self, task_id, **kwargs):
        """
        Creates the task and its 'full_set' dataset.

        Args:
            task_id (str): one of the pre-implemented Tasks:
                'tishby_dat', 'andreas_dat', 'fake_dat' etc.

        Keyword Args:
            task_kwargs (dict): forwarded to the Task constructor (optional).
            n_samples (int): required for tasks that are not finite.
            seed: optional seed for tasks that are not finite.

        Raises:
            KeyError: if the task is not finite and 'n_samples' is missing.
        """

        task_kwargs = kwargs.get("task_kwargs", dict())
        self._task = Task.from_id(task_id, **task_kwargs)

        if self._task.finite:
            self._dataset = CachedDataset(self._task, "full_set")
        else:
            seed = kwargs.get("seed")
            n_samples = kwargs["n_samples"]
            self._dataset = LazyDataset(self._task, "full_set", n_samples, seed)

    def save(self, component_dir):
        """
        Saves indices and all relevant settings
        for the task's dataset and subsets to the task subdirectory.
        This ensures later usage of the task to use identical samples.

        If save_data is true, the data itself is also stored.
        """

        def get_branch(dataset):
            # Recursively serialise the subset tree as
            # {subset_name: [children_branch, index_list]}.
            branch_dict = dict()
            for subset in dataset.children:
                branch_dict[subset.name] = [
                    get_branch(subset),
                    subset.indices.tolist(),
                ]
            return branch_dict

        task_saver = nninfo.file_io.FileManager(component_dir, write=True)

        output_dict = {
            "task_id": self._task.task_id,
            "kwarg_dict": self._kwargs,
            "datasets_tree": get_branch(self._dataset),
            "save_data": self._save_data,
            "task_kwargs": self._task.kwargs,
        }

        if isinstance(self._dataset, LazyDataset):
            output_dict["n_samples"] = len(self._dataset)
            output_dict["seed"] = self._dataset.seed

        if self._save_data:
            # Stack inputs and labels separately: x and y usually have
            # different shapes, so packing the (x, y) pairs into one array
            # would produce a ragged array (an error on recent NumPy and an
            # unreadable object array on older versions).
            samples = [self._dataset[i] for i in range(len(self._dataset))]
            x_data = np.stack([np.asarray(x) for x, _ in samples])
            y_data = np.stack([np.asarray(y) for _, y in samples])
            log.info("Saving input x_data and output y_data to npy.")
            task_saver.write(x_data, "x_data.npy")
            task_saver.write(y_data, "y_data.npy")

        log.info("Saving task structure to json.")
        task_saver.write(output_dict, "task.json")

    def _reload_datasets(self, component_dir):
        """
        Is called when an experiment from a different session is loaded or a different,
        previously used task needs to be loaded.

        Sets all the necessary task settings and reloads the data.
        """

        def load_branch(parent_dataset, branch_dict):
            # Rebuild the subset tree in the format written by save().
            for subset_name, value in branch_dict.items():
                children_branch_dict = value[0]
                indices = np.array(value[1])
                parent_dataset.create_subset(subset_name, indices)
                # create_subset appends, so the new subset is the last child.
                load_branch(parent_dataset.children[-1], children_branch_dict)

        task_loader = nninfo.file_io.FileManager(component_dir, read=True)
        d = task_loader.read("task.json")

        task_id = d["task_id"] if "task_id" in d else d["data_key"]  # legacy support
        self._kwargs = d["kwarg_dict"]
        datasets_tree = d["datasets_tree"]
        task_kwargs = d["task_kwargs"] if "task_kwargs" in d else dict()

        self._task = Task.from_id(task_id, **task_kwargs)

        if d["save_data"]:
            # If saved data exists, load as cached dataset
            self._save_data = True
            self._dataset = CachedDataset(self._task, "full_set", component_dir)
        elif self._task.finite:
            self._dataset = CachedDataset(self._task, "full_set")
        else:
            seed = d["seed"]
            n_samples = self._kwargs["n_samples"]
            self._dataset = LazyDataset(self._task, "full_set", n_samples, seed)

        load_branch(self._dataset, datasets_tree)

    def get_binning_limits(self, label):
        """
        Returns the task's x_limits for label "X" and y_limits for label "Y".

        Raises:
            AttributeError: for any other label.
        """
        if label == "X":
            return self._task.x_limits
        elif label == "Y":
            return self._task.y_limits
        else:
            raise AttributeError(
                "Unknown label {!r}; expected 'X' or 'Y'.".format(label)
            )

    def get_input_output_dimensions(self):
        """Returns the tuple (x_dim, y_dim) of the underlying task."""
        return self._task.x_dim, self._task.y_dim

    def __getitem__(self, dataset_name):
        """
        Finds the dataset by the given name in the dataset tree
        """
        return self._dataset.find(dataset_name)

    def __str__(self, level=0):
        """Returns the string representation of the dataset tree (level is unused)."""
        ret = self._dataset.__str__()
        return ret

    @property
    def save_data(self):
        """Whether save() also writes the samples themselves to disk."""
        return self._save_data

    @save_data.setter
    def save_data(self, save_data):
        self._save_data = save_data


class DataSet(torch.utils.data.Dataset):
    """
    Node of the dataset tree: holds the task, a name and a list of child subsets.

    Concrete subclasses provide __len__ and __getitem__.
    """

    def __init__(self, task, name):
        self._task = task
        self._name = name
        self._children = []

    def __str__(self, level=0):
        """
        Recursive function that allows for printing of the Dataset Tree / Subset Branch.

        Args:
            level (int): level of branch

        Returns:
            str: Representation of this branch (Tree).
        """
        ret = "\t" * level + self.__repr__() + "\n"
        for child in self._children:
            ret += child.__str__(level=level + 1)
        return ret

    def __repr__(self):
        return (
            self._name
            + ": \t"
            + str(len(self))
            + " elements."
            + ("(lazy)" if isinstance(self, LazyDataset) else "")
            + ("(cached)" if isinstance(self, CachedDataset) else "")
        )

    def find(self, dataset_name):
        """
        Depth-first search for dataset_name in the dataset tree

        Returns:
            The first dataset with that name, or None if there is none.
        """
        if self._name == dataset_name:
            return self
        for chld in self._children:
            result = chld.find(dataset_name)
            if result is not None:
                return result
        return None

    def create_subset(self, name, keep_idx):
        """Appends a SubSet called name that exposes the samples at keep_idx."""
        child = SubSet(self, name, keep_idx)
        self._children.append(child)

    def _split_names(self):
        """Returns the [train, test, val] subset names derived from this dataset's name."""
        return [
            self._name + SUBSET_SYMBOL + suffix for suffix in ("train", "test", "val")
        ]

    def train_test_val_random_split(self, train_len, test_len, val_len):
        """
        Creates train/test/val subsets from a random permutation of all indices.

        Returns:
            list: the three subset names, or [None, None, None] (after
            printing a message) if the lengths do not add up to len(self).
        """
        total_len = train_len + test_len + val_len
        if len(self) != total_len:
            print(
                "Split can only be performed if the subdatasets total "
                "length matches with the length of the dataset"
            )
            return [None, None, None]

        # create random indices
        indices = np.random.permutation(total_len)
        train_idx = indices[:train_len]
        test_idx = indices[train_len : train_len + test_len]
        val_idx = indices[train_len + test_len :]

        train_name, test_name, val_name = self._split_names()

        # create subsets
        self.create_subset(train_name, train_idx)
        self.create_subset(test_name, test_idx)
        self.create_subset(val_name, val_idx)
        return [train_name, test_name, val_name]

    def train_test_val_sequential_split(self, train_len, test_len, val_len):
        """
        Creates train/test/val subsets from consecutive index ranges.

        Returns:
            list: the three subset names, in the order train, test, val.

        Raises:
            AssertionError: if the lengths do not add up to len(self).
        """
        assert train_len + test_len + val_len == len(self), 'Split can only be performed if the subsets comprise the whole set'

        train_name, test_name, val_name = self._split_names()
        test_end = train_len + test_len

        self.create_subset(train_name, np.arange(0, train_len))
        self.create_subset(test_name, np.arange(train_len, test_end))
        # The validation subset is registered under its name (not its length).
        self.create_subset(val_name, np.arange(test_end, test_end + val_len))
        return [train_name, test_name, val_name]

    def one_class_split(self):
        """Creates one subset per class; returns the subset names."""
        return self._class_wise_split("one_class")

    def all_but_one_class_split(self):
        """Creates, per class, a subset of all other classes; returns the subset names."""
        return self._class_wise_split("all_but_one_class")

    def multiple_class_split(self, class_list):
        """Creates one subset holding the classes in class_list; returns its name in a list."""
        return self._class_wise_split("multiple", class_list=class_list)

    def _class_wise_split(self, method, class_list=None):
        """
        Creates subsets based on the labels stored in self._y.

        Args:
            method (str): "one_class", "all_but_one_class" or "multiple".
            class_list: labels to combine, only used for "multiple".

        Returns:
            list: names of the subsets that were created.

        Raises:
            NotImplementedError: for an unknown method.
        """
        # Reads self._y, which this base class does not define itself.
        dataset_labels_np = self._y
        # TODO: Test this for one-hot-labels
        classes = np.unique(dataset_labels_np)
        return_name_list = []
        if method == "one_class":
            for cls_idx, cls in enumerate(classes):
                keep_idx = np.where(dataset_labels_np == cls)[0]
                cls_name = self._name + SUBSET_SYMBOL + "class_" + str(cls_idx)
                self.create_subset(cls_name, keep_idx)
                return_name_list.append(cls_name)
        elif method == "all_but_one_class":
            for cls_idx, cls in enumerate(classes):
                keep_idx = np.where(dataset_labels_np != cls)[0]
                cls_name = self._name + SUBSET_SYMBOL + "not_class_" + str(cls_idx)
                self.create_subset(cls_name, keep_idx)
                return_name_list.append(cls_name)
        elif method == "multiple":
            keep_idx_list = []
            cls_name = self._name + SUBSET_SYMBOL + "multiple"
            for el in class_list:
                keep_idx_list.extend(np.where(dataset_labels_np == el)[0].tolist())
                cls_name += "_" + str(el)
            # dtype=int keeps the array usable as an index even when it is empty.
            keep_idx = np.array(keep_idx_list, dtype=int)
            self.create_subset(cls_name, keep_idx)
            return_name_list.append(cls_name)
        else:
            raise NotImplementedError
        return return_name_list

    @property
    def name(self):
        return self._name

    @property
    def task(self):
        return self._task

    @property
    def children(self):
        return self._children


class SubSet(DataSet):
    """
    View onto a parent dataset that exposes only the samples at the given indices.

    Samples are always fetched from the parent, so the order of the parent
    Dataset must never change!
    """

    def __init__(self, parent, name, indices):
        # The subset shares its parent's task.
        super().__init__(parent.task, name)
        self._indices = indices
        self._parent = parent

    @property
    def indices(self):
        """Positions in the parent dataset that make up this subset."""
        return self._indices

    @property
    def parent(self):
        """Dataset this subset draws its samples from."""
        return self._parent

    def __len__(self):
        return len(self._indices)

    def __getitem__(self, idx):
        # Translate the subset-local position into a parent position.
        parent_idx = self._indices[idx]
        return self.parent[parent_idx]


class CachedDataset(DataSet):
    """Dataset that loads all samples from its task once and keeps them as attributes."""

    def __init__(self, task, name, component_dir=None):
        super().__init__(task, name)
        # load_samples returns the (inputs, labels) pair for the whole task.
        self._x, self._y = self._task.load_samples(component_dir)

    def __getitem__(self, idx):
        features = self._x[idx]
        label = self._y[idx]
        return features, label

    def __len__(self):
        sample_count = self._x.shape[0]
        return sample_count


class LazyDataset(DataSet):
    """
    Dataset whose samples are generated on access by the task.

    Before each access the Philox counter is set from the requested index,
    so a given (seed, index) pair always drives the generator from the same
    state.
    """

    def __init__(self, task, name, length, seed=None, condition=None):
        """
        Args:
            task: task that is not finite and implements generate_sample.
            name (str): name of the dataset.
            length (int): number of samples the dataset exposes.
            seed (int): seed of the Philox generator; a random one is drawn
                (and kept in self.seed) if None.
            condition: passed on to the task's generate_sample.
        """
        super().__init__(task, name)
        assert not task.finite
        self._len = length

        if seed is not None:
            self._seed = seed
        else:
            seed_generator = Generator(Philox())
            self._seed = int(
                seed_generator.integers(np.iinfo(np.int64).max)
            )  # Has to be int to be json serializable

        # Seed with the stored value (not the raw argument): otherwise a
        # dataset created with seed=None would report a seed that does not
        # reproduce its samples.
        self._philox = Philox(self._seed)
        self._philox_state = self._philox.state
        self._rng = Generator(self._philox)

        self._condition = condition

    def extend(self, n_samples):
        """
        Get a copy of the dataset with a different length
        """
        name_ext = self._name + "_ext"
        return LazyDataset(self._task, name_ext, n_samples, self._seed)

    def condition(self, n_samples, condition):
        """
        Get a copy of the dataset with a condition
        """
        name_cond = self._name + "_cond"
        return LazyDataset(
            self._task, name_cond, n_samples, self._seed, condition=condition
        )

    @property
    def seed(self):
        """Seed the Philox generator of this dataset was created with."""
        return self._seed

    def __len__(self):
        return self._len

    def __getitem__(self, idx):
        """
        Generates the sample belonging to position idx.

        Raises:
            IndexError: if idx lies outside [-len(self), len(self)).
        """
        if not -self._len <= idx < self._len:
            raise IndexError
        if idx < 0:
            # Sequence-style negative indexing; the counter must be non-negative.
            idx += self._len

        # Reset the generator to its initial state with the counter set to idx.
        self._philox_state["state"]["counter"][-1] = idx
        self._philox.state = self._philox_state
        return self._task.generate_sample(self._rng, condition=self._condition)

def binary_encode_label(y, bits):
    """
    Encode the label tensor y as a tensor of bits.

    A trailing axis of length `bits` is appended; entry k is 1 if bit k
    (least significant first) of the label is set. The result is a uint8 tensor.
    """
    powers_of_two = 2 ** torch.arange(bits)
    mask = powers_of_two.to(y.device, y.dtype)
    # Broadcast every label against all bit masks.
    selected = y.unsqueeze(-1).bitwise_and(mask)
    is_set = selected != 0
    return is_set.byte()

def quaternary_encode_label(y, quits):
    """
    Encode the label tensor y as a tensor of quits.

    Each quit combines two neighbouring bits of the binary encoding:
    the even-positioned bit counts 1, the odd-positioned bit counts 2.
    """
    bit_tensor = binary_encode_label(y, 2 * quits)
    low_bits = bit_tensor[:, ::2]
    high_bits = bit_tensor[:, 1::2]
    return low_bits + 2 * high_bits

def octal_encode_label(y, octs):
    """
    Encode the label tensor y as a tensor of octs (base-8 digits).

    A trailing axis of length `octs` is appended; entry k holds octal digit k
    of the label, least significant digit first. The result is a uint8 tensor.

    Each digit is built from three bits. Pairing two quaternary digits, as
    the previous implementation did, yields four-bit (base-16) digits instead.

    Args:
        y: integer label tensor with non-negative entries.
        octs (int): number of octal digits to produce.
    """
    # Digit k starts at bit 3 * k.
    shifts = (3 * torch.arange(octs)).to(y.device, y.dtype)
    digits = (y.unsqueeze(-1) >> shifts) & 7
    return digits.byte()

class Task(ABC):
    """
    Abstract description of a learning task.

    Subclasses declare whether the task is finite, its identifier, the
    limits and dimensions of inputs and labels, and provide either
    load_samples or generate_sample.
    """

    def __init__(self, **kwargs):
        # Keyword arguments are kept as given; subclasses read their settings from them.
        self._kwargs = kwargs

    @property
    def kwargs(self):
        """Keyword arguments the task was created with."""
        return self._kwargs

    @property
    @abstractmethod
    def finite(self):
        raise NotImplementedError

    @property
    @abstractmethod
    def task_id(self):
        raise NotImplementedError

    @property
    @abstractmethod
    def x_limits(self):
        raise NotImplementedError

    @property
    @abstractmethod
    def y_limits(self):
        raise NotImplementedError

    @property
    @abstractmethod
    def x_dim(self):
        """
        Returns the dimension of the feature component of a single sample
        """
        raise NotImplementedError

    @property
    @abstractmethod
    def y_dim(self):
        """
        Returns the dimension of the label component of a single sample
        """
        raise NotImplementedError

    @classmethod
    def from_id(cls, task_id, **kwargs):
        """
        Builds the task registered under task_id, forwarding kwargs to it.

        Raises:
            NotImplementedError: if task_id is not one of the known ids.
        """
        if task_id == "xor_dat":
            return XorTask(**kwargs)
        if task_id == "tishby_dat":
            return TishbyTask(**kwargs)
        if task_id == "parity":
            return ParityTask(**kwargs)
        if task_id == "andreas_dat":
            return AndreasTask(**kwargs)
        if task_id == "fake_dat":
            return FakeTask(**kwargs)
        if task_id == "checkerboard":
            return CheckerboardTask(**kwargs)
        if task_id == "mnist_1d_dat":
            return Mnist1DTask(**kwargs)
        if task_id == "mnist_binary_dat":
            return MnistBinaryTask(**kwargs)
        if task_id == "mnist8_binary_dat":
            return MnistEightBinaryTask(**kwargs)
        if task_id == "emnist_1d_dat":
            return EMnist1DTask(**kwargs)
        if task_id == "mnist_reduced_dat":
            return ReducedMnistTask(**kwargs)
        if task_id == "XorTaskMissInfo_dat":
            return XorTaskMissInfo(**kwargs)
        if task_id == "fashion_mnist_1d_dat":
            return FashionMnistTask(**kwargs)
        if task_id == "combined_mnist_1d_dat":
            return CombinedMnistTask(**kwargs)
        if task_id == "combined_mnist_binary_dat":
            return CombinedMnistBinaryTask(**kwargs)
        if task_id == "combined_mnist_quaternary_dat":
            return CombinedMnistQuaternaryTask(**kwargs)
        if task_id == "combined_mnist_octal_dat":
            return CombinedMnistOctalTask(**kwargs)
        if task_id == "cifar10_1d_dat":
            return CIFAR10Task(**kwargs)

        # No known id matched.
        raise NotImplementedError('Task "{}" not found.'.format(task_id))

    def generate_sample(self, rng, condition=None):
        """Default for tasks without sample generation: always raises."""
        raise NotImplementedError(
            "Finite-sample tasks do not support the generation of samples."
        )

    def load_samples(self, component_dir):
        """
        Reads 'x_data.npy' and 'y_data.npy' from component_dir.

        Returns:
            tuple: (x, y) as float tensors.
        """
        reader = nninfo.file_io.FileManager(component_dir, read=True)
        stored_x = reader.read("x_data.npy")
        stored_y = reader.read("y_data.npy")
        x_tensor = torch.tensor(stored_x, dtype=torch.float)
        y_tensor = torch.tensor(stored_y, dtype=torch.float)
        return x_tensor, y_tensor


class TishbyTask(Task):
    """Finite task whose samples are read from the 'var_u.mat' file in ../data/Tishby_2017/."""

    @property
    def task_id(self):
        return "tishby_dat"

    @property
    def finite(self):
        return True

    @property
    def x_dim(self):
        return 12

    @property
    def y_dim(self):
        return 1

    @property
    def x_limits(self):
        return "binary"

    @property
    def y_limits(self):
        return "binary"

    def load_samples(self, component_dir):
        """
        Loads inputs from the "F" entry and labels from the transposed "y" entry.

        component_dir is not used; the data location is fixed.

        Returns:
            tuple: (x, y) as float tensors.
        """
        self._data_location = "../data/Tishby_2017/"
        mat_reader = nninfo.file_io.FileManager(self._data_location, read=True)
        contents = mat_reader.read("var_u.mat")
        features = torch.tensor(contents["F"], dtype=torch.float)
        labels = torch.tensor(contents["y"].T, dtype=torch.float)
        return features, labels


class XorTask(Task):
    """Generated task: two uniform inputs, label is the XOR of their >= 0.5 indicators."""

    @property
    def task_id(self):
        return "xor_dat"

    @property
    def finite(self):
        return False

    @property
    def x_dim(self):
        return 2

    @property
    def y_dim(self):
        return 1

    @property
    def x_limits(self):
        return (0, 1)

    @property
    def y_limits(self):
        return "binary"

    def generate_sample(self, rng, condition=None):
        """
        Draws one sample from rng (condition is not used).

        Returns:
            tuple: (x, y) with x a float32 array of two values and y a
            float tensor holding the single label.
        """
        point = rng.random(2, dtype=np.float32)
        first_high = point[0] >= 0.5
        second_high = point[1] >= 0.5
        # XOR of two booleans is the same as their inequality.
        label = first_high != second_high
        return point, torch.tensor([label], dtype=torch.float)

class CheckerboardTask(Task):
    """
    Generated task: two uniform inputs labelled by the parity of their grid cell.

    Expects the task kwarg 'size', indexed as size[0] and size[1]
    (number of cells per input dimension).
    """

    @property
    def task_id(self):
        return "checkerboard"

    @property
    def finite(self):
        return False

    @property
    def x_dim(self):
        return 2

    @property
    def y_dim(self):
        return 1

    @property
    def x_limits(self):
        return (0, 1)

    @property
    def y_limits(self):
        return "binary"

    def generate_sample(self, rng, condition=None):
        """
        Draws one sample from rng (condition is not used).

        Returns:
            tuple: (x, y) with x a float32 array of two values and y a
            float tensor holding the single label.
        """
        grid = self._kwargs['size']
        point = rng.random(2, dtype=np.float32)
        # Cell index along each input dimension.
        cell_0 = int(point[0] * grid[0])
        cell_1 = int(point[1] * grid[1])
        label = (cell_0 + cell_1) % 2
        return point, torch.tensor([label], dtype=torch.float)

    

class XorTaskMissInfo(Task):
    """Generated XOR task with three uniform inputs; the label depends only on the first two."""

    @property
    def task_id(self):
        return "XorTaskMissInfo_dat"

    @property
    def finite(self):
        return False

    @property
    def x_dim(self):
        return 3

    @property
    def y_dim(self):
        return 1

    @property
    def x_limits(self):
        return (0, 1)

    @property
    def y_limits(self):
        return "binary"

    def generate_sample(self, rng, condition=None):
        """
        Draws one sample from rng (condition is not used).

        Returns:
            tuple: (x, y) with x a float32 array of three values and y a
            float tensor holding the single label.
        """
        point = rng.random(3, dtype=np.float32)
        first_high = point[0] >= 0.5
        second_high = point[1] >= 0.5
        # XOR of two booleans is the same as their inequality.
        label = first_high != second_high
        return point, torch.tensor([label], dtype=torch.float)
    
    
class ParityTask(Task):
    """
    Generated parity task.

    Reads the task kwargs 'n_bits' (input dimension) and 'continuous'
    (uniform inputs thresholded at 0.5 instead of random bits).
    """

    @property
    def task_id(self):
        return "parity"

    @property
    def finite(self):
        return False

    @property
    def x_dim(self):
        return self._kwargs["n_bits"]

    @property
    def y_dim(self):
        return 1

    @property
    def x_limits(self):
        if self._kwargs["continuous"]:
            return (0, 1)
        return "binary"

    @property
    def y_limits(self):
        return "binary"

    def generate_sample(self, rng, condition=None):
        """
        Draws one sample from rng (condition is not used).

        Returns:
            tuple: (x, y) as float tensors; y holds the single parity label.
        """
        width = self._kwargs["n_bits"]

        if self._kwargs["continuous"]:
            values = rng.random(size=width, dtype=np.float32)
            # Parity of the number of inputs at or above 0.5.
            parity = (values >= 0.5).sum() % 2
        else:
            values = rng.integers(2, size=width)
            parity = values.sum() % 2

        x_tensor = torch.tensor(values, dtype=torch.float)
        y_tensor = torch.tensor([parity], dtype=torch.float)
        return x_tensor, y_tensor


class RecMajorityTask(Task):
    """
    Recursive majority task (work in progress).

    The task can be constructed and exposes the metadata properties, but
    sample generation is not implemented yet.
    """

    def __init__(self, **kwargs):
        """
        Expected kwargs:
            voter_list: list of numbers of voters for each recursive layer

        Raises:
            ValueError: if 'voter_list' is missing.
        """
        if "voter_list" not in kwargs:
            raise ValueError("RecMajorityTask requires the 'voter_list' keyword argument.")

        # Task.__init__ only takes keyword arguments.
        super().__init__(**kwargs)

    @property
    def finite(self):
        # NOTE(review): assumed to be a generated task, like the other tasks
        # that take an rng in generate_sample -- confirm.
        return False

    @property
    def task_id(self):
        return "rec_maj"

    @property
    def x_limits(self):
        return "binary"

    @property
    def y_limits(self):
        return "binary"

    @property
    def x_dim(self):
        return self._kwargs["voter_list"][0]

    @property
    def y_dim(self):
        return 1

    def generate_sample(self, rng, condition=None):
        """
        Not implemented yet.

        The previous stub drew a constant vector and returned None, which
        callers would have treated as a sample; fail loudly instead.
        """
        raise NotImplementedError(
            "RecMajorityTask does not implement the generation of samples yet."
        )


def generate_andreas_data(n_input_bits=12, thresh=None, noise_level=0.0):
    """
    Enumerate all bit configurations with n_input_bits bits.

    The rows are ordered from the largest number (all ones) down to zero,
    each written most significant bit first.

    Args:
        n_input_bits (int): number of bits per row (at most 32).
        thresh: currently unused.
        noise_level: currently unused.

    Returns:
        numpy.ndarray: uint8 array of shape (2 ** n_input_bits, n_input_bits).
        The matrix used to be computed and then discarded.
    """
    # difficulty level of the task
    problem_size = n_input_bits
    # enumerate all numbers 0 .. 2**n - 1 as a column vector
    n_samples = 2 ** n_input_bits
    d = np.arange(n_samples, dtype=np.uint32)[:, np.newaxis]
    # Flipping both axes reverses the row order and the byte order of every
    # number, so unpackbits yields most-significant-first bits per row.
    bits = np.unpackbits(np.flip(d.view("uint8")), axis=1)
    # keep only the lowest problem_size bits of every number
    features = bits[:, bits.shape[1] - problem_size :]
    return features


class AndreasTask(Task):
    """
    Finite task over all bit configurations of x_dim bits.

    Every bit is assigned a unit vector on a circle; a sample is labelled 1
    if the summed vector of its bits is longer than the median length.
    """

    @property
    def task_id(self):
        return "andreas_dat"

    @property
    def finite(self):
        return True

    @property
    def x_dim(self):
        # Falls back to 12 bits when no 'x_dim' task kwarg is given.
        if "x_dim" in self._kwargs:
            return self._kwargs["x_dim"]
        return 12

    @property
    def y_dim(self):
        return 1

    @property
    def x_limits(self):
        return "binary"

    @property
    def y_limits(self):
        return "binary"

    def load_samples(self, component_dir):
        """
        Builds all samples and their labels (component_dir is not used).

        Returns:
            tuple: (x, y) as float tensors; y has one column.
        """
        n_bits = self.x_dim
        x = create_all_possible_n_bit_configurations(n_bits)

        # One 2D unit vector per bit, pointing to equally spaced positions on
        # a circle of radius 1 around the center. Writing to row -i fills the
        # rows in reverse order (row 0 first, then the last row, ...).
        directions = np.zeros((self.x_dim, 2))
        for i in range(n_bits):
            directions[-i, 0] = np.sin(i * 2 * np.pi / n_bits)  # dimension 1
            directions[-i, 1] = np.cos(i * 2 * np.pi / n_bits)  # dimension 2

        # Length of the summed vector of every sample.
        sum_dim1 = np.sum(np.multiply(x, directions[:, 0]), axis=1)
        sum_dim2 = np.sum(np.multiply(x, directions[:, 1]), axis=1)
        lengths = np.sqrt(np.square(sum_dim1) + np.square(sum_dim2))
        # possibly one could add noise to the lengths vector here

        # Threshold rule: everything inside the median circle is class 0,
        # everything outside is class 1.
        outside = lengths > np.median(lengths)
        y = np.array(outside * 1)[:, np.newaxis]

        x_tensor = torch.tensor(x, dtype=torch.float)
        y_tensor = torch.tensor(y, dtype=torch.float)
        return x_tensor, y_tensor


class FakeTask(Task):
    """
    Finite task over all bit configurations of x_dim bits.

    The first half of the samples is labelled 0, the second half 1.
    """

    @property
    def finite(self):
        return True

    @property
    def task_id(self):
        return "fake_dat"

    @property
    def x_limits(self):
        return "binary"

    @property
    def y_limits(self):
        return "binary"

    @property
    def x_dim(self):
        # Falls back to 12 bits when no 'x_dim' task kwarg is given.
        return self._kwargs.get("x_dim", 12)

    @property
    def y_dim(self):
        return 1

    def load_samples(self, component_dir):
        """
        Builds all samples and their labels (component_dir is not used).

        Returns:
            tuple: (x, y) as float tensors; y has one column.
        """
        n_bits = self.x_dim
        x = create_all_possible_n_bit_configurations(n_bits)

        # effectively setting y with x_0=0 to 1
        # (np.int was removed in NumPy 1.24; the builtin int is equivalent)
        y = np.zeros(x.shape[0], dtype=int)
        y[x.shape[0] // 2 :] = 1
        y = y[:, np.newaxis]
        return torch.tensor(x, dtype=torch.float), torch.tensor(y, dtype=torch.float)


class Mnist1DTask(Task):
    """Finite task: flattened MNIST (train and test) plus QMNIST 'test50k' images with integer labels."""

    @property
    def task_id(self):
        return "mnist_1d_dat"

    @property
    def finite(self):
        return True

    @property
    def x_dim(self):
        return 784

    @property
    def y_dim(self):
        return 10

    @property
    def x_limits(self):
        return (0, 1)

    @property
    def y_limits(self):
        return "binary"

    def load_samples(self, component_dir):
        """
        Loads (downloading if necessary) the datasets below "../".

        component_dir is not used.

        Returns:
            tuple: (x, y) with x as float32 rows of 784 values and y as long labels.
        """
        train_part = torchvision.datasets.MNIST(root="../", download=True, train=True)
        test_part = torchvision.datasets.MNIST(root="../", download=True, train=False)
        qmnist_part = torchvision.datasets.QMNIST(root="../", what='test50k', download=True, train=False)

        images = torch.cat([train_part.data, test_part.data, qmnist_part.data])
        # Only the first column of the QMNIST targets is used as label.
        labels = torch.cat([train_part.targets, test_part.targets, qmnist_part.targets[:, 0]])

        pixels = images.reshape(-1, 784) / 256.
        return pixels.type(torch.float32), labels.type(torch.long)


class MnistBinaryTask(Task):
    """Finite task: flattened MNIST and QMNIST 'test50k' images with labels encoded as 4 bits."""

    @property
    def task_id(self):
        return "mnist_binary_dat"

    @property
    def finite(self):
        return True

    @property
    def x_dim(self):
        return 784

    @property
    def y_dim(self):
        return 4

    @property
    def x_limits(self):
        return (0, 1)

    @property
    def y_limits(self):
        return "binary"

    def load_samples(self, component_dir):
        """
        Loads (downloading if necessary) the datasets below "../".

        component_dir is not used.

        Returns:
            tuple: (x, y) as float32 tensors; x has 784 columns, y has 4.
        """
        train_part = torchvision.datasets.MNIST(
            root="../", download=True, train=True)
        test_part = torchvision.datasets.MNIST(
            root="../", download=True, train=False)
        qmnist_part = torchvision.datasets.QMNIST(
            root="../", what='test50k', download=True, train=False)

        images = torch.cat([train_part.data, test_part.data, qmnist_part.data])
        # Only the first column of the QMNIST targets is used as label.
        labels = torch.cat(
            [train_part.targets, test_part.targets, qmnist_part.targets[:, 0]])

        pixels = images.reshape(-1, 784) / 256.
        label_bits = self.binary(labels, 4)

        return pixels.type(torch.float32), label_bits.type(torch.float32)

    def binary(self, x, bits):
        """Returns a uint8 tensor with a trailing axis holding the lowest `bits` bits of x."""
        bit_masks = (2 ** torch.arange(bits)).to(x.device, x.dtype)
        is_set = x.unsqueeze(-1).bitwise_and(bit_masks) != 0
        return is_set.byte()


class FashionMnistTask(Task):
    """Finite task: flattened FashionMNIST (train and test) images with integer labels."""

    @property
    def task_id(self):
        return "fashion_mnist_1d_dat"

    @property
    def finite(self):
        return True

    @property
    def x_dim(self):
        return 784

    @property
    def y_dim(self):
        return 10

    @property
    def x_limits(self):
        return (0, 1)

    @property
    def y_limits(self):
        return "binary"

    def load_samples(self, component_dir):
        """
        Loads (downloading if necessary) FashionMNIST below "../".

        component_dir is not used.

        Returns:
            tuple: (x, y) with x as float32 rows of 784 values and y as long labels.
        """
        train_part = torchvision.datasets.FashionMNIST(
            root="../", download=True, train=True)
        test_part = torchvision.datasets.FashionMNIST(
            root="../", download=True, train=False)

        images = torch.cat([train_part.data, test_part.data])
        labels = torch.cat([train_part.targets, test_part.targets])

        pixels = images.reshape(-1, 784) / 256.
        return pixels.type(torch.float32), labels.type(torch.long)


class CombinedMnistTask(Task):
    """
    Finite task combining MNIST and FashionMNIST.

    MNIST keeps its labels 0-9; FashionMNIST labels are shifted by 10.
    """

    @property
    def task_id(self):
        return "combined_mnist_1d_dat"

    @property
    def finite(self):
        return True

    @property
    def x_dim(self):
        return 784

    @property
    def y_dim(self):
        return 20

    @property
    def x_limits(self):
        return (0, 1)

    @property
    def y_limits(self):
        return "binary"

    def load_samples(self, component_dir):
        """
        Loads (downloading if necessary) both datasets below "../".

        component_dir is not used. The sample order is MNIST train,
        FashionMNIST train, MNIST test, FashionMNIST test.

        Returns:
            tuple: (x, y) with x as float32 rows of 784 values and y as long labels.
        """
        digits_train = torchvision.datasets.MNIST(
            root="../", download=True, train=True)
        digits_test = torchvision.datasets.MNIST(
            root="../", download=True, train=False)

        fashion_train = torchvision.datasets.FashionMNIST(
            root="../", download=True, train=True)
        fashion_test = torchvision.datasets.FashionMNIST(
            root="../", download=True, train=False)

        images = torch.cat([digits_train.data, fashion_train.data,
                            digits_test.data, fashion_test.data])
        labels = torch.cat([digits_train.targets, fashion_train.targets + 10,
                            digits_test.targets, fashion_test.targets + 10])

        pixels = images.reshape(-1, 784) / 256.
        return pixels.type(torch.float32), labels.type(torch.long)

class CombinedMnistBinaryTask(Task):
    """
    Finite task combining MNIST and FashionMNIST with labels encoded as 5 bits.

    MNIST keeps its labels 0-9; FashionMNIST labels are shifted by 10
    before the encoding.
    """

    @property
    def task_id(self):
        return "combined_mnist_binary_dat"

    @property
    def finite(self):
        return True

    @property
    def x_dim(self):
        return 784

    @property
    def y_dim(self):
        return 5

    @property
    def x_limits(self):
        return (0, 1)

    @property
    def y_limits(self):
        return "binary"

    def load_samples(self, component_dir):
        """
        Loads (downloading if necessary) both datasets below "../".

        component_dir is not used. The sample order is MNIST train,
        FashionMNIST train, MNIST test, FashionMNIST test.

        Returns:
            tuple: (x, y) as float32 tensors; y holds the encoded labels.
        """
        digits_train = torchvision.datasets.MNIST(
            root="../", download=True, train=True)
        digits_test = torchvision.datasets.MNIST(
            root="../", download=True, train=False)

        fashion_train = torchvision.datasets.FashionMNIST(
            root="../", download=True, train=True)
        fashion_test = torchvision.datasets.FashionMNIST(
            root="../", download=True, train=False)

        images = torch.cat([digits_train.data, fashion_train.data,
                            digits_test.data, fashion_test.data])
        labels = torch.cat([digits_train.targets, fashion_train.targets + 10,
                            digits_test.targets, fashion_test.targets + 10])

        pixels = images.reshape(-1, 784) / 256.
        encoded = binary_encode_label(labels, 5)

        return pixels.type(torch.float32), encoded.type(torch.float32)

class CombinedMnistQuaternaryTask(Task):
    """
    Finite task combining MNIST and FashionMNIST with labels encoded as 3 quits.

    MNIST keeps its labels 0-9; FashionMNIST labels are shifted by 10
    before the encoding.
    """

    @property
    def task_id(self):
        return "combined_mnist_quaternary_dat"

    @property
    def finite(self):
        return True

    @property
    def x_dim(self):
        return 784

    @property
    def y_dim(self):
        return 3

    @property
    def x_limits(self):
        return (0, 1)

    @property
    def y_limits(self):
        return "binary"

    def load_samples(self, component_dir):
        """
        Loads (downloading if necessary) both datasets below "../".

        component_dir is not used. The sample order is MNIST train,
        FashionMNIST train, MNIST test, FashionMNIST test.

        Returns:
            tuple: (x, y) as float32 tensors; y holds the encoded labels.
        """
        digits_train = torchvision.datasets.MNIST(
            root="../", download=True, train=True)
        digits_test = torchvision.datasets.MNIST(
            root="../", download=True, train=False)

        fashion_train = torchvision.datasets.FashionMNIST(
            root="../", download=True, train=True)
        fashion_test = torchvision.datasets.FashionMNIST(
            root="../", download=True, train=False)

        images = torch.cat([digits_train.data, fashion_train.data,
                            digits_test.data, fashion_test.data])
        labels = torch.cat([digits_train.targets, fashion_train.targets + 10,
                            digits_test.targets, fashion_test.targets + 10])

        pixels = images.reshape(-1, 784) / 256.
        encoded = quaternary_encode_label(labels, 3)

        return pixels.type(torch.float32), encoded.type(torch.float32)

class CombinedMnistOctalTask(Task):
    """
    Finite task combining MNIST and FashionMNIST with labels encoded by octal_encode_label.

    MNIST keeps its labels 0-9; FashionMNIST labels are shifted by 10
    before the encoding.
    """

    @property
    def task_id(self):
        return "combined_mnist_octal_dat"

    @property
    def finite(self):
        return True

    @property
    def x_dim(self):
        return 784

    @property
    def y_dim(self):
        return 2

    @property
    def x_limits(self):
        return (0, 1)

    @property
    def y_limits(self):
        return "binary"

    def load_samples(self, component_dir):
        """
        Loads (downloading if necessary) both datasets below "../".

        component_dir is not used. The sample order is MNIST train,
        FashionMNIST train, MNIST test, FashionMNIST test.

        Returns:
            tuple: (x, y) as float32 tensors; y holds the encoded labels.
        """
        digits_train = torchvision.datasets.MNIST(
            root="../", download=True, train=True)
        digits_test = torchvision.datasets.MNIST(
            root="../", download=True, train=False)

        fashion_train = torchvision.datasets.FashionMNIST(
            root="../", download=True, train=True)
        fashion_test = torchvision.datasets.FashionMNIST(
            root="../", download=True, train=False)

        images = torch.cat([digits_train.data, fashion_train.data,
                            digits_test.data, fashion_test.data])
        labels = torch.cat([digits_train.targets, fashion_train.targets + 10,
                            digits_test.targets, fashion_test.targets + 10])

        pixels = images.reshape(-1, 784) / 256.
        encoded = octal_encode_label(labels, 2)

        return pixels.type(torch.float32), encoded.type(torch.float32)
class MnistEightBinaryTask(Task):
    """MNIST restricted to the digits 0-7, with 3-bit binary labels.

    Samples come from MNIST train, MNIST test and the QMNIST ``test50k``
    split; every sample labelled 8 or 9 is dropped.

    Sizes noted by the original author (not re-verified here):
        len train: 48200
        len test: 48275
    """

    @property
    def finite(self):
        """bool: always ``True`` for this task."""
        return True

    @property
    def task_id(self):
        """str: identifier of this task."""
        return "mnist8_binary_dat"

    @property
    def x_limits(self):
        """tuple: declared input range ``(0, 1)``."""
        return (0, 1)

    @property
    def y_limits(self):
        """str: the literal ``"binary"``."""
        return "binary"

    @property
    def x_dim(self):
        """int: length of one flattened 28x28 image."""
        return 784

    @property
    def y_dim(self):
        """int: number of label bits."""
        return 3

    def load_samples(self, component_dir):
        """Download (if needed) and assemble all samples with digits 0-7.

        Args:
            component_dir: not used by this implementation.

        Returns:
            tuple: ``(x, y)`` as float32 tensors. ``x`` holds the images
            reshaped to 784 columns and divided by 256; ``y`` holds the 3-bit
            encoding of each label produced by :meth:`binary`.
        """
        mnist_train = torchvision.datasets.MNIST(
            root="../", download=True, train=True)
        mnist_eval = torchvision.datasets.MNIST(
            root="../", download=True, train=False)
        qmnist_eval = torchvision.datasets.QMNIST(
            root="../", what='test50k', download=True, train=False)

        image_parts = [mnist_train.data, mnist_eval.data, qmnist_eval.data]
        # Only column 0 of the QMNIST target tensor is used as the label.
        label_parts = [
            mnist_train.targets,
            mnist_eval.targets,
            qmnist_eval.targets[:, 0],
        ]

        x = torch.cat(image_parts).reshape(-1, 784) / 256.
        y = torch.cat(label_parts)

        # Keep every sample whose label is neither 8 nor 9.
        keep = ~((y == 8) | (y == 9))
        x = x[keep]
        y = y[keep]

        y_bits = self.binary(y, 3)

        return x.type(torch.float32), y_bits.type(torch.float32)

    def binary(self, x, bits):
        """Expand integer values into their lowest ``bits`` bits.

        Args:
            x (torch.Tensor): integer tensor to encode.
            bits (int): number of bits to extract per value.

        Returns:
            torch.Tensor: uint8 tensor with one extra trailing dimension of
            size ``bits``; entry ``i`` along that dimension is 1 when bit
            ``2**i`` is set in the value (least significant bit first).
        """
        exponents = torch.arange(bits).to(x.device, x.dtype)
        bit_values = 2 ** exponents
        is_set = torch.bitwise_and(x.unsqueeze(-1), bit_values) != 0
        return is_set.byte()

class EMnist1DTask(Task):
    """EMNIST ``digits`` split with flattened images and integer class labels."""

    @property
    def finite(self):
        """bool: always ``True`` for this task."""
        return True

    @property
    def task_id(self):
        """str: identifier of this task."""
        return "emnist_1d_dat"

    @property
    def x_limits(self):
        """tuple: declared input range ``(0, 1)``."""
        return (0, 1)

    @property
    def y_limits(self):
        """str: the literal ``"binary"``."""
        return "binary"

    @property
    def x_dim(self):
        """int: length of one flattened 28x28 image."""
        return 784

    @property
    def y_dim(self):
        """int: declared output dimension (10)."""
        return 10

    def load_samples(self, component_dir):
        """Download (if needed) and assemble the EMNIST digits samples.

        Train and test splits are concatenated, in that order.

        Args:
            component_dir: not used by this implementation.

        Returns:
            tuple: ``(x, y)`` where ``x`` is a float32 tensor holding the
            images reshaped to 784 columns and divided by 256, and ``y`` is a
            long tensor with the dataset's targets.
        """
        # Both splits must live under the same root; the train split
        # previously used ".../" (a typo), which created a stray directory
        # named "..." instead of sharing "../" with the test split.
        emnist = torchvision.datasets.EMNIST(
            root="../", split='digits', download=True, train=True)
        emnist_test = torchvision.datasets.EMNIST(
            root="../", split='digits', download=True, train=False)

        # Swap the two image axes before flattening.
        # NOTE(review): presumably corrects EMNIST's transposed image storage
        # so that digits are oriented like MNIST - confirm.
        images = torch.cat([
            torch.transpose(emnist.data, 1, 2),
            torch.transpose(emnist_test.data, 1, 2),
        ])
        x = images.reshape(-1, 784) / 256
        y = torch.cat([emnist.targets, emnist_test.targets])

        return x.type(torch.float32), y.type(torch.long)

class ReducedMnistTask(Task):
    """MNIST restricted to the digits 3, 6, 8 and 9.

    Labels keep their original digit values (they are not re-indexed).
    """

    @property
    def finite(self):
        """bool: always ``True`` for this task."""
        return True

    @property
    def task_id(self):
        """str: identifier of this task."""
        return "mnist_reduced_dat"

    @property
    def x_limits(self):
        """tuple: declared input range ``(0, 1)``."""
        return (0, 1)

    @property
    def y_limits(self):
        """str: the literal ``"binary"``."""
        return "binary"

    @property
    def x_dim(self):
        """int: length of one flattened 28x28 image."""
        return 784

    @property
    def y_dim(self):
        """int: declared output dimension (10)."""
        return 10

    def load_samples(self, component_dir):
        """Download (if needed) and assemble the reduced MNIST samples.

        Train and test splits are concatenated, in that order, and then
        filtered down to the digits 3, 6, 8 and 9.

        Args:
            component_dir: not used by this implementation.

        Returns:
            tuple: ``(x, y)`` where ``x`` is a float32 tensor holding the
            images reshaped to 784 columns and divided by 256, and ``y`` is a
            long tensor with the original digit labels.
        """
        mnist = torchvision.datasets.MNIST(root="../", download=True, train=True)
        mnist_test = torchvision.datasets.MNIST(root="../", download=True, train=False)

        # Use ``data``/``targets`` rather than the deprecated
        # ``train_data``/``test_data``/``train_labels``/``test_labels`` aliases,
        # which only forward to these attributes after emitting a warning.
        # Pixels are divided by 256 so the values respect the ``x_limits`` of
        # (0, 1) declared above, like the other MNIST-based tasks here.
        x = torch.cat([mnist.data, mnist_test.data]).reshape(-1, 784) / 256.
        y = torch.cat([mnist.targets, mnist_test.targets])

        # Keep only the four selected digits.
        keep = (y == 3) | (y == 6) | (y == 8) | (y == 9)
        x = x[keep]
        y = y[keep]

        return x.type(torch.float32), y.type(torch.long)

class CIFAR10Task(Task):
    """CIFAR-10 with per-channel standardised images and integer labels.

    NOTE(review): ``x_limits`` declares (0, 1) although the returned images
    are standardised (mean-subtracted, std-divided), and ``x_dim`` is a flat
    size although ``load_samples`` returns 4-D images - confirm how callers
    use these properties.
    """

    @property
    def finite(self):
        """bool: always ``True`` for this task."""
        return True

    @property
    def task_id(self):
        """str: identifier of this task."""
        return "cifar10_1d_dat"

    @property
    def x_limits(self):
        """tuple: declared input range ``(0, 1)``."""
        return (0, 1)

    @property
    def y_limits(self):
        """str: the literal ``"binary"``."""
        return "binary"

    @property
    def x_dim(self):
        """int: number of values in one 32x32 image with 3 channels."""
        return 32 * 32 * 3

    @property
    def y_dim(self):
        """int: declared output dimension (10)."""
        return 10

    def load_samples(self, component_dir):
        """Download (if needed), standardise and assemble the CIFAR-10 samples.

        Train and test splits are concatenated, in that order. The statistics
        used for standardisation are computed over the first three axes of
        the train split only, and applied to both splits.

        Args:
            component_dir: not used by this implementation.

        Returns:
            tuple: ``(x, y)`` where ``x`` is a float32 tensor of standardised
            images with axes 1 and 3 swapped, and ``y`` is a long tensor of
            class labels.
        """
        train_set = torchvision.datasets.CIFAR10(root="../", download=True, train=True)
        test_set = torchvision.datasets.CIFAR10(root="../", download=True, train=False)

        # Dividing by 1. turns the integer pixel arrays into float tensors.
        train_pixels = torch.from_numpy(train_set.data) / 1.
        test_pixels = torch.from_numpy(test_set.data) / 1.

        # Statistics over every axis except the last, train split only.
        train_mean = torch.mean(train_pixels, dim=[0, 1, 2])
        train_std = torch.std(train_pixels, dim=[0, 1, 2])

        x = (torch.cat([train_pixels, test_pixels]) - train_mean) / train_std

        # Swap axes 1 and 3.
        # NOTE(review): assuming the data is (N, H, W, C), this yields
        # (N, C, W, H), i.e. height and width are swapped as well - confirm
        # this is intended.
        x = torch.transpose(x, 1, 3)

        train_labels = torch.tensor(train_set.targets)
        test_labels = torch.tensor(test_set.targets)
        y = torch.cat([train_labels, test_labels])

        return x.type(torch.float32), y.type(torch.long)

def create_all_possible_n_bit_configurations(n_bits):
    """Enumerate every bit pattern of length ``n_bits``.

    Args:
        n_bits (int): number of bits per pattern.

    Returns:
        numpy.ndarray: uint8 array of shape ``(2 ** n_bits, n_bits)``. Each
        row is one pattern, most significant bit first. Rows are ordered from
        the largest integer value (all ones) down to zero (all zeros).
    """
    count = 2 ** n_bits
    # every integer value 0 .. count - 1, stored as 32-bit unsigned integers
    integers = np.linspace(0, count - 1, num=count, endpoint=True, dtype=np.uint32)
    # reinterpret each integer as its four raw bytes (one row per integer)
    byte_rows = integers.reshape(-1, 1).view("uint8")
    # reverse both axes: the byte reversal puts the most significant byte
    # first (on little-endian hosts), the row reversal makes the values descend
    reversed_bytes = byte_rows[::-1, ::-1]
    # expand every byte into its 8 bits
    all_bits = np.unpackbits(reversed_bytes, axis=1)
    # keep only the trailing n_bits columns; an explicit start index is used
    # so that n_bits == 0 yields zero columns
    first_kept = all_bits.shape[1] - n_bits
    return all_bits[:, first_kept:]
