import torch
import torchvision
from torchvision import transforms
import numpy as np





###### transformations

class FlattenImage(torch.nn.Module):
    '''
    Transform module that collapses its input into a
    single dimension.

    Intended to be placed inside a pytorch transform
    pipeline when loading data.

    '''
    def __init__(self):
        super().__init__()

    def forward(self, x):
        '''
        Return ```x``` reshaped to 1D, i.e. ```x.reshape(-1)```.
        '''
        flattened = x.reshape(-1)
        return flattened






###### datasets

def get_mnist(path, return_targets=False):
    '''
    Function to get the MNIST data from pytorch with
    some transformations first.

    The returned MNIST images will be flattened to 1D
    float tensors.

    Arguments
    ---------

    - ```path```: ```str```:
        The directory where the data is located or will be
        downloaded to. If the data is already present, this
        should be a directory containing ```MNIST```.
    
    - ```return_targets```: ```bool```, (optional):
        Whether to return the targets along with the 
        datasets.
        Defaults to ```False```.

    Returns
    ---------

        - ```train_mnist```: ```torch.utils.data.Dataset```

        - ```test_mnist```: ```torch.utils.data.Dataset```

        - If ```return_targets=True```:
            - ```train_mnist_targets```: ```torch.tensor``` (dtype ```torch.long```)
            - ```test_mnist_targets```: ```torch.tensor``` (dtype ```torch.long```)

    '''
    # The same pipeline is applied to both splits:
    # PIL image -> integer tensor -> float tensor -> normalise -> 1D vector.
    transform_images = transforms.Compose([
                            transforms.PILToTensor(),
                            transforms.ConvertImageDtype(torch.float),
                            # mean=0, std=1 computes (x - 0) / 1, so values are unchanged.
                            transforms.Normalize(mean=0, std=1),
                            FlattenImage(),
                            ])

    train_mnist = torchvision.datasets.MNIST(root=path,
                                             download=True,
                                             train=True,
                                             transform=transform_images)

    test_mnist = torchvision.datasets.MNIST(root=path,
                                            download=True,
                                            train=False,
                                            transform=transform_images)

    if not return_targets:
        return train_mnist, test_mnist

    # Request torch.long explicitly: converting with the builtin ``int``
    # follows NumPy's platform default integer, which is not 64-bit on every
    # platform / NumPy version. np.asarray keeps list and tensor targets on
    # the same code path, and torch.tensor copies so the dataset's own
    # targets are not aliased.
    train_mnist_targets = torch.tensor(np.asarray(train_mnist.targets), dtype=torch.long)
    test_mnist_targets = torch.tensor(np.asarray(test_mnist.targets), dtype=torch.long)

    return train_mnist, test_mnist, train_mnist_targets, test_mnist_targets



def get_fmnist(path, return_targets=False):
    '''
    Function to get the FMNIST (Fashion-MNIST) data from
    pytorch with some transformations first.

    The returned FMNIST images will be flattened to 1D
    float tensors.

    Arguments
    ---------

    - ```path```: ```str```:
        The directory where the data is located or will be
        downloaded to. If the data is already present, this
        should be a directory containing ```FashionMNIST```.
    
    - ```return_targets```: ```bool```, (optional):
        Whether to return the targets along with the 
        datasets.
        Defaults to ```False```.

    Returns
    ---------

        - ```train_fmnist```: ```torch.utils.data.Dataset```

        - ```test_fmnist```: ```torch.utils.data.Dataset```

        - If ```return_targets=True```:
            - ```train_fmnist_targets```: ```torch.tensor``` (dtype ```torch.long```)
            - ```test_fmnist_targets```: ```torch.tensor``` (dtype ```torch.long```)

    '''
    # The same pipeline is applied to both splits:
    # PIL image -> integer tensor -> float tensor -> normalise -> 1D vector.
    transform_images = transforms.Compose([
                            transforms.PILToTensor(),
                            transforms.ConvertImageDtype(torch.float),
                            # mean=0, std=1 computes (x - 0) / 1, so values are unchanged.
                            transforms.Normalize(mean=0, std=1),
                            FlattenImage(),
                            ])

    train_fmnist = torchvision.datasets.FashionMNIST(root=path,
                                                     download=True,
                                                     train=True,
                                                     transform=transform_images)

    test_fmnist = torchvision.datasets.FashionMNIST(root=path,
                                                    download=True,
                                                    train=False,
                                                    transform=transform_images)

    if not return_targets:
        return train_fmnist, test_fmnist

    # Request torch.long explicitly: converting with the builtin ``int``
    # follows NumPy's platform default integer, which is not 64-bit on every
    # platform / NumPy version. np.asarray keeps list and tensor targets on
    # the same code path, and torch.tensor copies so the dataset's own
    # targets are not aliased.
    train_fmnist_targets = torch.tensor(np.asarray(train_fmnist.targets), dtype=torch.long)
    test_fmnist_targets = torch.tensor(np.asarray(test_fmnist.targets), dtype=torch.long)

    return train_fmnist, test_fmnist, train_fmnist_targets, test_fmnist_targets



def get_cifar10(path, return_targets=False):
    '''
    Function to get the CIFAR 10 data from pytorch with
    some transformations first.

    The returned CIFAR 10 images are converted to float
    tensors. They are NOT flattened: the transform pipeline
    below contains no flattening step, so each image keeps
    the layout produced by ```PILToTensor```.

    Arguments
    ---------

    - ```path```: ```str```:
        The directory where the data is located or will be
        downloaded to. If the data is already present, this
        should be a directory containing ```cifar-10-batches-py```.
    
    - ```return_targets```: ```bool```, (optional):
        Whether to return the targets along with the 
        datasets.
        Defaults to ```False```.

    Returns
    ---------

        - ```train_cifar```: ```torch.utils.data.Dataset```

        - ```test_cifar```: ```torch.utils.data.Dataset```

        - If ```return_targets=True```:
            - ```train_cifar_targets```: ```torch.tensor``` (dtype ```torch.long```)
            - ```test_cifar_targets```: ```torch.tensor``` (dtype ```torch.long```)

    '''
    # The same pipeline is applied to both splits:
    # PIL image -> integer tensor -> float tensor.
    transform_images = transforms.Compose([
                            transforms.PILToTensor(),
                            transforms.ConvertImageDtype(torch.float),
                            # Per-channel normalisation is disabled; kept here for reference:
                            #transforms.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.247, 0.243, 0.261)),
                            ])

    train_cifar = torchvision.datasets.CIFAR10(root=path,
                                               download=True,
                                               train=True,
                                               transform=transform_images)

    test_cifar = torchvision.datasets.CIFAR10(root=path,
                                              download=True,
                                              train=False,
                                              transform=transform_images)

    if not return_targets:
        return train_cifar, test_cifar

    # Request torch.long explicitly: converting with the builtin ``int``
    # follows NumPy's platform default integer, which is not 64-bit on every
    # platform / NumPy version. torch.tensor copies, so the dataset's own
    # targets are not aliased.
    train_cifar_targets = torch.tensor(np.asarray(train_cifar.targets), dtype=torch.long)
    test_cifar_targets = torch.tensor(np.asarray(test_cifar.targets), dtype=torch.long)

    return train_cifar, test_cifar, train_cifar_targets, test_cifar_targets


def get_cifar100(path, return_targets=False):
    '''
    Function to get the CIFAR 100 data from pytorch with
    some transformations first.

    The returned CIFAR 100 images are converted to float
    tensors. They are NOT flattened: the transform pipeline
    below contains no flattening step, so each image keeps
    the layout produced by ```PILToTensor```.

    Arguments
    ---------

    - ```path```: ```str```:
        The directory where the data is located or will be
        downloaded to. If the data is already present, this
        should be a directory containing ```cifar-100-python```.
    
    - ```return_targets```: ```bool```, (optional):
        Whether to return the targets along with the 
        datasets.
        Defaults to ```False```.

    Returns
    ---------

        - ```train_cifar```: ```torch.utils.data.Dataset```

        - ```test_cifar```: ```torch.utils.data.Dataset```

        - If ```return_targets=True```:
            - ```train_cifar_targets```: ```torch.tensor``` (dtype ```torch.long```)
            - ```test_cifar_targets```: ```torch.tensor``` (dtype ```torch.long```)

    '''
    # The same pipeline is applied to both splits:
    # PIL image -> integer tensor -> float tensor.
    transform_images = transforms.Compose([
                            transforms.PILToTensor(),
                            transforms.ConvertImageDtype(torch.float),
                            # Per-channel normalisation is disabled; kept here for reference:
                            #transforms.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.247, 0.243, 0.261)),
                            ])

    train_cifar = torchvision.datasets.CIFAR100(root=path,
                                                download=True,
                                                train=True,
                                                transform=transform_images)

    test_cifar = torchvision.datasets.CIFAR100(root=path,
                                               download=True,
                                               train=False,
                                               transform=transform_images)

    if not return_targets:
        return train_cifar, test_cifar

    # Request torch.long explicitly: converting with the builtin ``int``
    # follows NumPy's platform default integer, which is not 64-bit on every
    # platform / NumPy version. torch.tensor copies, so the dataset's own
    # targets are not aliased.
    train_cifar_targets = torch.tensor(np.asarray(train_cifar.targets), dtype=torch.long)
    test_cifar_targets = torch.tensor(np.asarray(test_cifar.targets), dtype=torch.long)

    return train_cifar, test_cifar, train_cifar_targets, test_cifar_targets



