from torchvision import datasets, transforms
import numpy as np
from torch.utils.data import Dataset, DataLoader, Subset
import torchvision.transforms as transforms

def _make_transform(size, mean, std):
    """Build the Resize -> ToTensor -> Normalize pipeline used by every dataset."""
    return transforms.Compose(
        [transforms.Resize(size),
         transforms.ToTensor(),
         transforms.Normalize(mean, std)])


def _keep_mnist_zero_one(dataset):
    """Restrict an MNIST dataset, in place, to the samples labelled 0 or 1."""
    keep = (dataset.targets == 0) | (dataset.targets == 1)
    dataset.data = dataset.data[keep]
    dataset.targets = dataset.targets[keep]
    return dataset


def _subset_zero_one(dataset):
    """Wrap a dataset in a ``Subset`` holding only the samples labelled 0 or 1."""
    # Convert the label container once instead of once per comparison.
    labels = np.asarray(dataset.targets)
    keep = np.where((labels == 0) | (labels == 1))[0]
    return Subset(dataset, keep)


def load_dataset(dataset_name, size=(32, 32)):
    """Load the train and test splits of a torchvision dataset.

    Every image is resized to ``size``, converted to a tensor and normalized.
    The files are stored under (and, if missing, downloaded to) the ``data``
    directory.

    'mnist' and 'cifar10' are reduced to the samples of classes 0 and 1 only
    (the original note on this function says this is "to support two
    workers"). For 'mnist' the dataset objects are filtered in place; for
    'cifar10' the returned objects are ``Subset`` wrappers. 'cifar100' and
    'svhn' are returned unfiltered.

    Args:
        dataset_name: One of 'mnist', 'cifar10', 'cifar100' or 'svhn'.
        size: Target image size handed to ``transforms.Resize``.

    Returns:
        A tuple ``(train_data, test_data, num_classes, num_input_channels)``.

    Raises:
        ValueError: If ``dataset_name`` is not one of the supported names.
    """
    supported = ('mnist', 'cifar10', 'cifar100', 'svhn')
    # Validate up front with a real exception: an ``assert`` is stripped under
    # ``python -O`` and the function would then fail later with an unrelated
    # UnboundLocalError.
    if dataset_name not in supported:
        raise ValueError(
            f"Unknown dataset {dataset_name!r}; expected one of {supported}")

    # All three RGB datasets share the same normalization constants.
    rgb_mean = (0.5, 0.5, 0.5)
    rgb_std = (0.5, 0.5, 0.5)

    if dataset_name == 'mnist':
        transform = _make_transform(size, (0.1307,), (0.3081,))
        train_data = _keep_mnist_zero_one(
            datasets.MNIST(root='data', train=True, download=True,
                           transform=transform))
        test_data = _keep_mnist_zero_one(
            datasets.MNIST(root='data', train=False, download=True,
                           transform=transform))
        num_classes = 2
        num_input_channels = 1

    elif dataset_name == 'cifar10':
        transform = _make_transform(size, rgb_mean, rgb_std)
        # CIFAR-10 class 0 is 'plane' and class 1 is 'car'.
        train_data = _subset_zero_one(
            datasets.CIFAR10(root='data', train=True,
                             download=True, transform=transform))
        test_data = _subset_zero_one(
            datasets.CIFAR10(root='data', train=False,
                             download=True, transform=transform))
        num_classes = 2
        num_input_channels = 3

    elif dataset_name == 'cifar100':
        transform = _make_transform(size, rgb_mean, rgb_std)
        train_data = datasets.CIFAR100(root='data', train=True,
                                       download=True, transform=transform)
        test_data = datasets.CIFAR100(root='data', train=False,
                                      download=True, transform=transform)
        num_classes = 100
        num_input_channels = 3

    else:  # 'svhn'
        transform = _make_transform(size, rgb_mean, rgb_std)
        # No target_transform: torchvision's SVHN already yields each label as
        # a plain int in 0..9 (it remaps the raw label 10 to digit 0). The
        # former ``int(target[0]) - 1`` transform therefore raised TypeError
        # on the first sample and would have shifted the labels to -1..8.
        train_data = datasets.SVHN(root='data', split='train', download=True,
                                   transform=transform)
        test_data = datasets.SVHN(root='data', split='test', download=True,
                                  transform=transform)
        num_classes = 10
        num_input_channels = 3

    return train_data, test_data, num_classes, num_input_channels
