import torch as to
import torchvision as tv
from torch.utils.data import Dataset
from torch.nn.functional import one_hot

class DatasetFromList(Dataset):
    """Map-style dataset that wraps an in-memory list of samples."""

    def __init__(self, list):
        """
        Arguments:
            list (list): List of length 'samples'. Each element is one sample.
                If supervised dataset, each sample is a dictionary: {'x': <input>, 'y': <output>}
                If unsupervised dataset, each sample is a data point
        """
        # NOTE: the parameter name shadows the builtin ``list`` inside this
        # method only; it is kept so existing keyword callers keep working.
        self.list = list

    def __len__(self):
        """Return the number of samples in the wrapped list."""
        return len(self.list)

    def __getitem__(self, idx):
        """Return the sample(s) at ``idx``.

        ``idx`` may be an int, a slice, or a tensor. A 0-d tensor is converted
        to a plain int. A 1-D tensor (or a Python list) of indices returns a
        list containing the corresponding samples, in the requested order.
        """
        if to.is_tensor(idx):
            # 0-d tensor -> int; 1-D tensor -> list of ints.
            idx = idx.tolist()

        if isinstance(idx, list) and isinstance(self.list, (list, tuple)):
            # Python sequences reject list indices with TypeError, so gather
            # the requested samples one at a time instead.
            return [self.list[i] for i in idx]

        return self.list[idx]


# Default dataset locations, used when no explicit ``root`` is passed.
_CLUSTER_MNIST_ROOT = '/home/mila/e/ezekiel.williams/bias_learning/bias-learning'
_LOCAL_MNIST_ROOT = '/Users/ezekielwilliams/documents/py_code/bias-learning/datasets'

# Normalization constants kept from the original code (the widely used MNIST
# training-set pixel mean and standard deviation).
_MNIST_MEAN = (0.1307,)
_MNIST_STD = (0.3081,)

_MNIST_NUM_CLASSES = 10


# The transform helpers below are module-level functions instead of lambdas so
# that the returned dataset can be pickled. DataLoader needs this when
# num_workers > 0 and workers are started with 'spawn' (the default on macOS).
def _flatten_image(x):
    """Flatten an image tensor to 1-D."""
    return to.flatten(x)


def _label_to_int64(x):
    """Wrap a scalar class label in a 1-element int64 tensor."""
    return to.tensor([x], dtype=to.int64)


def _label_to_one_hot(x):
    """Encode a class label as a flat float32 one-hot vector of length 10."""
    return to.flatten(one_hot(_label_to_int64(x), _MNIST_NUM_CLASSES)).to(to.float)


def _label_to_float(x):
    """Encode a class label as a 1-element float32 tensor."""
    return to.flatten(_label_to_int64(x)).to(to.float)


def get_mnist_images_dataset(one_hot_output=True, cluster=True, root=None,
                             train=True, download=False):
    """Load MNIST with normalized, flattened images.

    Args:
        one_hot_output: If True, each target is a float32 one-hot vector of
            length 10. Otherwise it is a 1-element float32 tensor holding the
            class index.
        cluster: Chooses the default data directory when ``root`` is None:
            the cluster path if True, the local path otherwise.
        root: Optional data directory. When given, it overrides ``cluster``.
        train: Load the training split if True, the test split otherwise.
        download: Passed through to ``torchvision.datasets.MNIST``.

    Returns:
        A ``torchvision.datasets.MNIST`` instance. Each image is converted to
        a tensor, normalized, and flattened to 1-D (784 values for 28x28 MNIST
        images). Each target is encoded as described for ``one_hot_output``.
    """
    if root is None:
        root = _CLUSTER_MNIST_ROOT if cluster else _LOCAL_MNIST_ROOT

    transform = tv.transforms.Compose([
        tv.transforms.ToTensor(),
        tv.transforms.Normalize(_MNIST_MEAN, _MNIST_STD),
        tv.transforms.Lambda(_flatten_image),
    ])
    target_transform = _label_to_one_hot if one_hot_output else _label_to_float

    return tv.datasets.MNIST(root, train=train, download=download,
                             transform=transform,
                             target_transform=target_transform)