import torchvision
import torchvision.transforms as transforms


# Shared preprocessing pipeline applied to every split returned by load_data:
# convert each image to a tensor, then normalize all three channels with
# mean 0.5 and std 0.5, i.e. (x - 0.5) / 0.5 per channel. For inputs that
# ToTensor scales to [0, 1] (e.g. uint8 PIL images) this yields [-1, 1].
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3),
    ]
)


def load_data(dataset, *, root="data", download=True):
    """Load the full train and test splits of a supported image dataset.

    Both splits are built with the module-level ``transform`` pipeline.

    Args:
        dataset: Dataset name, matched case-insensitively. One of
            ``"cifar10"``, ``"cifar100"`` or ``"SVHN"`` (so ``"svhn"`` and
            ``"CIFAR10"`` are accepted too).
        root: Directory the dataset files are read from / downloaded to.
            Defaults to ``"data"``, the previously hard-coded location.
        download: Forwarded to the torchvision dataset constructors; when
            true, missing files are downloaded into ``root``.

    Returns:
        A ``(trainset_full, testset_full)`` tuple of torchvision datasets.

    Raises:
        ValueError: If ``dataset`` is not a supported dataset name.
    """
    # Normalise the name first so "SVHN"/"svhn" and "cifar10"/"CIFAR10" are
    # all accepted; non-strings fall through to the ValueError below.
    name = dataset.lower() if isinstance(dataset, str) else None
    if name not in ("cifar10", "cifar100", "svhn"):
        raise ValueError(
            f"Unknown dataset: {dataset} "
            "(expected one of: cifar10, cifar100, SVHN; case-insensitive)"
        )

    # Keyword arguments shared by every constructor call below.
    common = {"root": root, "download": download, "transform": transform}
    datasets = torchvision.datasets

    if name == "svhn":
        # SVHN selects its split by name rather than via a ``train`` flag.
        trainset_full = datasets.SVHN(split="train", **common)
        testset_full = datasets.SVHN(split="test", **common)
        # Give SVHN an explicit ``classes`` list (the digits 0-9 as ints),
        # presumably so callers can read ``.classes`` the same way as on the
        # CIFAR datasets. Each split gets its own list object.
        trainset_full.classes = list(range(10))
        testset_full.classes = list(range(10))
    else:
        cifar = datasets.CIFAR10 if name == "cifar10" else datasets.CIFAR100
        trainset_full = cifar(train=True, **common)
        testset_full = cifar(train=False, **common)

    return trainset_full, testset_full
