"""
    Dataset loader, to load the clean and poisoning datasets
"""

# torch ...
import torch
from torch.utils.data import Dataset
from torchvision import datasets, transforms


# ------------------------------------------------------------------------------
#   Dataset class (to load the denoised data)
# ------------------------------------------------------------------------------
class DenoisedDataset(Dataset):
    """
        Map-style dataset over pre-loaded (denoised) samples and their labels.

        Item i is the pair (data[i], labels[i]). When a transform is given it
        is applied to the sample only; the label is returned as stored.
    """

    def __init__(self, data, labels, transform=None):
        """
            Keep references to the samples, the labels and the optional
            transform, then print the shape of the stored data.

            data      : indexable container exposing '.shape' and 'len()'
            labels    : indexable container, looked up with the same index
            transform : optional callable applied to each returned sample
        """
        self.transform = transform
        self.plabels   = labels
        self.pdata     = data
        summary = 'DenoisedDataset: read the data - {}'.format(self.pdata.shape)
        print (summary)


    def __len__(self):
        """ Number of instances, measured on the stored data. """
        return len(self.pdata)


    def __getitem__(self, idx):
        """ Return (data, label) for idx; label is the index of the class. """
        sample = self.pdata[idx]
        target = self.plabels[idx]

        # no transform configured: hand back the stored sample untouched
        if not self.transform:
            return (sample, target)

        # otherwise convert the sample (e.g., into the Pytorch format) first
        return (self.transform(sample), target)


# ------------------------------------------------------------------------------
#   Dataset loader
# ------------------------------------------------------------------------------
def load_dataset(dataset, nbatch, normalize, augment, kwargs):
    """
        Build the (train, valid) data loaders for a clean dataset.

        dataset   : dataset name, either 'cifar10' or 'cifar100'
        nbatch    : batch size used by both loaders
        normalize : forwarded to the raw dataset loader
        augment   : forwarded to the raw dataset loader
        kwargs    : dict of extra keyword arguments for the DataLoaders
                    (None is treated as an empty dict)

        Returns (train_loader, valid_loader); only the train loader shuffles.
        Raises AssertionError when the dataset name is not supported.
    """
    # : reject unknown names up-front; raise explicitly (not 'assert False')
    #   so the check is still enforced under 'python -O', where assert
    #   statements are removed and the function would otherwise fail at the
    #   return statement with an unbound variable
    if dataset not in ('cifar10', 'cifar100'):
        raise AssertionError( \
            'Error: invalid dataset name [{}]'.format(dataset))

    # : load the raw data
    if 'cifar10' == dataset:
        trainset, validset = _load_cifar10(normalize, augment)
    else:
        trainset, validset = _load_cifar100(normalize, augment)

    # : make loaders (identical for every supported dataset)
    loader_args = kwargs or {}

    train_loader = torch.utils.data.DataLoader(
        trainset, batch_size=nbatch, shuffle=True, **loader_args)

    valid_loader = torch.utils.data.DataLoader(
        validset, batch_size=nbatch, shuffle=False, **loader_args)

    return train_loader, valid_loader


def load_denoised_dataset(dataset, nbatch, normalize, augment, datpath, kwargs):
    """
        Build the (train, valid) data loaders for a denoised dataset.

        dataset   : dataset name; only 'cifar10' is supported
        nbatch    : batch size used by both loaders
        normalize : forwarded to the denoised dataset loader
        augment   : forwarded to the denoised dataset loader
        datpath   : location of the stored denoised data
        kwargs    : dict of extra keyword arguments for the DataLoaders
                    (None is treated as an empty dict)

        Returns (train_loader, valid_loader); only the train loader shuffles.
        Raises AssertionError when the dataset name is not supported.
    """
    # : reject unknown names up-front; raise explicitly (not 'assert False')
    #   so the check is still enforced under 'python -O', where assert
    #   statements are removed and the function would otherwise fail at the
    #   return statement with an unbound variable
    if 'cifar10' != dataset:
        raise AssertionError( \
            'Error: invalid dataset name [{}]'.format(dataset))

    # : load the denoised data
    trainset, validset = _load_denoised_cifar(normalize, augment, datpath)

    # : make loaders
    loader_args = kwargs or {}

    train_loader = torch.utils.data.DataLoader(
        trainset, batch_size=nbatch, shuffle=True, **loader_args)

    valid_loader = torch.utils.data.DataLoader(
        validset, batch_size=nbatch, shuffle=False, **loader_args)

    return train_loader, valid_loader


# ------------------------------------------------------------------------------
#   Raw dataset loaders
# ------------------------------------------------------------------------------
def _load_cifar10(normalize=True, augment=True):
    """
        Build the CIFAR10 train/valid datasets with their transform pipelines.

        normalize : append the per-channel Normalize step to both pipelines
        augment   : put a random crop (32, padding 4) and a random horizontal
                    flip in front of the training pipeline only

        Returns (trainset, validset), both rooted at
        'datasets/originals/cifar10' and requested with download=True.
    """
    root = 'datasets/originals/cifar10'
    mean = (0.4914, 0.4822, 0.4465)
    std  = (0.2023, 0.1994, 0.2010)

    # training pipeline: [augmentation] -> tensor -> [normalization]
    train_steps = []
    if augment:
        train_steps.append(transforms.RandomCrop(32, padding=4))
        train_steps.append(transforms.RandomHorizontalFlip())
    train_steps.append(transforms.ToTensor())

    # validation pipeline: tensor -> [normalization], never augmented
    valid_steps = [transforms.ToTensor()]

    if normalize:
        train_steps.append(transforms.Normalize(mean, std))
        valid_steps.append(transforms.Normalize(mean, std))

    # wrap each split together with its own pipeline
    trainset = datasets.CIFAR10(
        root=root, train=True, download=True,
        transform=transforms.Compose(train_steps))
    validset = datasets.CIFAR10(
        root=root, train=False, download=True,
        transform=transforms.Compose(valid_steps))
    return trainset, validset


def _load_cifar100(normalize=True, augment=True):
    """
        Build the CIFAR100 train/valid datasets with their transform pipelines.

        normalize : append the per-channel Normalize step to both pipelines
        augment   : put a random crop (32, padding 4) and a random horizontal
                    flip in front of the training pipeline only

        Returns (trainset, validset), both rooted at
        'datasets/originals/cifar100' and requested with download=True.
    """
    root = 'datasets/originals/cifar100'
    mean = (0.5070, 0.4865, 0.4409)
    std  = (0.2673, 0.2564, 0.2762)

    # training pipeline: [augmentation] -> tensor -> [normalization]
    train_steps = []
    if augment:
        train_steps.append(transforms.RandomCrop(32, padding=4))
        train_steps.append(transforms.RandomHorizontalFlip())
    train_steps.append(transforms.ToTensor())

    # validation pipeline: tensor -> [normalization], never augmented
    valid_steps = [transforms.ToTensor()]

    if normalize:
        train_steps.append(transforms.Normalize(mean, std))
        valid_steps.append(transforms.Normalize(mean, std))

    # wrap each split together with its own pipeline
    trainset = datasets.CIFAR100(
        root=root, train=True, download=True,
        transform=transforms.Compose(train_steps))
    validset = datasets.CIFAR100(
        root=root, train=False, download=True,
        transform=transforms.Compose(valid_steps))
    return trainset, validset


def _load_denoised_cifar(normalize=True, augment=True, datpath=None):
    """
        Build the denoised train/valid datasets from a file read by torch.load.

        normalize : append the per-channel Normalize step to both pipelines
        augment   : put a random crop (32, padding 4) and a random horizontal
                    flip in front of the training pipeline only
        datpath   : file handed to torch.load; the loaded object is indexed
                    with the keys 'train_data', 'train_labels', 'test_data'
                    and 'test_labels'

        Returns (trainset, validset) as DenoisedDataset instances. No ToTensor
        step is added here, unlike the raw loaders.
    """
    mean = (0.4914, 0.4822, 0.4465)
    std  = (0.2023, 0.1994, 0.2010)

    # training pipeline: [augmentation] -> [normalization]
    train_steps = []
    if augment:
        train_steps.append(transforms.RandomCrop(32, padding=4))
        train_steps.append(transforms.RandomHorizontalFlip())

    # validation pipeline: [normalization] only
    valid_steps = []

    if normalize:
        train_steps.append(transforms.Normalize(mean, std))
        valid_steps.append(transforms.Normalize(mean, std))

    # read the stored splits
    # NOTE(review): torch.load may unpickle arbitrary objects - only point
    #   datpath at files from a trusted source
    stored = torch.load(datpath)

    # wrap each split together with its own pipeline
    trainset = DenoisedDataset(
        stored['train_data'], stored['train_labels'],
        transform=transforms.Compose(train_steps))
    validset = DenoisedDataset(
        stored['test_data'], stored['test_labels'],
        transform=transforms.Compose(valid_steps))
    return trainset, validset