import torch
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader


def mnist_get_datasets(data_dir):
    """Build the MNIST train and test datasets rooted at ``data_dir``.

    Both splits use the same preprocessing: conversion to a tensor followed
    by normalization with mean 0.1307 and std 0.3081.

    Args:
        data_dir: Directory passed to torchvision as the dataset ``root``.

    Returns:
        A ``(train_dataset, test_dataset)`` tuple of ``datasets.MNIST``
        objects. Only the training split is created with ``download=True``.
    """
    mean, std = (0.1307,), (0.3081,)

    train_steps = [transforms.ToTensor(), transforms.Normalize(mean, std)]
    train_dataset = datasets.MNIST(
        root=data_dir,
        train=True,
        download=True,
        transform=transforms.Compose(train_steps),
    )

    # The test split gets its own pipeline object with identical steps.
    test_steps = [transforms.ToTensor(), transforms.Normalize(mean, std)]
    test_dataset = datasets.MNIST(
        root=data_dir,
        train=False,
        transform=transforms.Compose(test_steps),
    )

    return train_dataset, test_dataset


def cifar10_get_datasets(data_dir, use_data_aug=False, cuda=False):
    """Build the CIFAR-10 train and test datasets rooted at ``data_dir``.

    Args:
        data_dir: Directory passed to torchvision as the dataset ``root``.
        use_data_aug: When truthy, the training pipeline is prefixed with a
            padded random crop (32 px, padding 4) and a random horizontal
            flip. The test pipeline is never augmented.
        cuda: Accepted for interface compatibility; not used by this function.

    Returns:
        A ``(train_dataset, test_dataset)`` tuple of ``datasets.CIFAR10``
        objects, both created with ``download=True``.
    """
    channel_mean = (0.4914, 0.4822, 0.4465)
    channel_std = (0.2023, 0.1994, 0.2010)

    # Optional augmentation steps go in front of the tensor/normalize tail.
    train_steps = []
    if use_data_aug:
        train_steps.append(transforms.RandomCrop(32, padding=4))
        train_steps.append(transforms.RandomHorizontalFlip())
    train_steps.append(transforms.ToTensor())
    train_steps.append(transforms.Normalize(channel_mean, channel_std))

    train_dataset = datasets.CIFAR10(
        root=data_dir,
        train=True,
        download=True,
        transform=transforms.Compose(train_steps),
    )

    test_steps = [
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ]
    test_dataset = datasets.CIFAR10(
        root=data_dir,
        train=False,
        download=True,
        transform=transforms.Compose(test_steps),
    )

    return train_dataset, test_dataset


def get_unaugmented_cifar(data_dir):
    """Return the CIFAR-10 training split with no transform attached.

    Neither ``transform`` nor ``download`` is passed, so torchvision's
    defaults apply for both.

    Args:
        data_dir: Directory passed to torchvision as the dataset ``root``.
    """
    raw_train_split = datasets.CIFAR10(train=True, root=data_dir)
    return raw_train_split


class AbstractDataset(torch.utils.data.Dataset):
    """Dataset view over a pair of parallel, indexable collections.

    ``input_data`` is stored as-is on ``self.data``; element 0 holds the
    samples and element 1 holds the matching targets.
    """

    def __init__(self, input_data):
        super().__init__()
        self.data = input_data

    def __getitem__(self, item):
        """Return the ``(sample, target)`` pair stored at position ``item``."""
        samples = self.data[0]
        targets = self.data[1]
        return samples[item], targets[item]

    def __len__(self):
        """Return the number of samples (length of the first collection)."""
        samples = self.data[0]
        return len(samples)


def fetch_ordered_dataset(dset, class_indices):
    """Materialize ``dset`` and regroup its samples by class, in a given order.

    The whole dataset is loaded as one batch, then the samples of each
    requested class are concatenated in the order the classes are listed.

    Args:
        dset: A dataset yielding ``(input, label)`` pairs that the default
            DataLoader collation can stack into tensors.
        class_indices: The class labels to keep, in output order. May be any
            iterable of labels (e.g. ``[2, 0]``). A dict is also accepted,
            in which case its values (in insertion order) are used as the
            class labels.

    Returns:
        A ``(inputs, labels)`` tuple of tensors containing only the samples
        of the requested classes, grouped class by class. Labels keep the
        dtype produced by the loader. If ``class_indices`` is empty, both
        tensors are empty (zero samples).
    """
    # A single batch spanning the full dataset gives us everything at once.
    full_loader = DataLoader(dset, batch_size=len(dset), shuffle=False)
    inputs, labels = next(iter(full_loader))

    # Iterate the class labels themselves. Re-indexing ``class_indices`` with
    # its own elements is only meaningful for a mapping.
    if isinstance(class_indices, dict):
        classes = list(class_indices.values())
    else:
        classes = list(class_indices)

    if not classes:
        return inputs[:0], labels[:0]

    # Build every mask first and concatenate once, so the result is not
    # re-copied per class and the label dtype is not promoted to float.
    masks = [labels == cls for cls in classes]
    output = torch.cat([inputs[mask] for mask in masks], dim=0)
    label_output = torch.cat([labels[mask] for mask in masks], dim=0)
    return output, label_output


class OrderedCifar10(AbstractDataset):
    """CIFAR-10 training split, regrouped class by class.

    The training dataset comes from ``cifar10_get_datasets`` (called with its
    default options) and is reordered by ``fetch_ordered_dataset``.
    """

    def __init__(self, class_indices, data_dir='../data'):
        """Load and reorder the CIFAR-10 training split.

        Args:
            class_indices: Class selection/order, forwarded unchanged to
                ``fetch_ordered_dataset``.
            data_dir: Dataset root directory. Defaults to ``'../data'``,
                the location that was previously hard-coded.
        """
        dset, _ = cifar10_get_datasets(data_dir=data_dir)
        input_data = fetch_ordered_dataset(dset, class_indices)
        super(OrderedCifar10, self).__init__(input_data)


class OrderedCifar10Test(AbstractDataset):
    """CIFAR-10 test split, regrouped class by class.

    The test dataset comes from ``cifar10_get_datasets`` (called with its
    default options) and is reordered by ``fetch_ordered_dataset``.
    """

    def __init__(self, class_indices, data_dir='../data'):
        """Load and reorder the CIFAR-10 test split.

        Args:
            class_indices: Class selection/order, forwarded unchanged to
                ``fetch_ordered_dataset``.
            data_dir: Dataset root directory. Defaults to ``'../data'``,
                the location that was previously hard-coded.
        """
        _, dset = cifar10_get_datasets(data_dir=data_dir)
        input_data = fetch_ordered_dataset(dset, class_indices)
        super(OrderedCifar10Test, self).__init__(input_data)
