import torch
import torchvision
from torchvision import datasets, transforms


def get_cifar10(batch_size, datapath, *, num_workers=2):
    """Build CIFAR-10 train/test DataLoaders.

    The training split is augmented with a random 32x32 crop (after 4-pixel
    padding) and a random horizontal flip. Both splits are converted to
    tensors and normalized per channel with the mean/std constants below.

    Args:
        batch_size: Number of samples per batch for both loaders.
        datapath: Root directory for the dataset. ``download=True`` is
            passed, so torchvision fetches the files there if missing.
        num_workers: Worker processes for each DataLoader (default 2,
            the value previously hard-coded).

    Returns:
        A ``(trainloader, testloader)`` tuple. The train loader shuffles;
        the test loader iterates in a fixed order.
    """
    # Per-channel normalization constants shared by both splits.
    mean = (0.4914, 0.4822, 0.4465)
    std = (0.247, 0.243, 0.261)

    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    trainset = torchvision.datasets.CIFAR10(root=datapath, train=True,
                                            download=True, transform=transform_train)
    testset = torchvision.datasets.CIFAR10(root=datapath, train=False,
                                           download=True, transform=transform_test)

    # Use the ``batch_size`` parameter rather than a global ``args`` object,
    # which is not defined in this module and made the argument a no-op.
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                              shuffle=True, num_workers=num_workers)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                             shuffle=False, num_workers=num_workers)
    return trainloader, testloader



def get_cifar100(batch_size, datapath, *, num_workers=2):
    """Build CIFAR-100 train/test DataLoaders.

    The training split is augmented with a random 32x32 crop (after 4-pixel
    padding) and a random horizontal flip. Both splits are converted to
    tensors and normalized per channel with the mean/std constants below.

    Args:
        batch_size: Number of samples per batch for both loaders.
        datapath: Root directory for the dataset. ``download=True`` is
            passed, so torchvision fetches the files there if missing.
        num_workers: Worker processes for each DataLoader (default 2,
            the value previously hard-coded).

    Returns:
        A ``(trainloader, testloader)`` tuple. The train loader shuffles;
        the test loader iterates in a fixed order.
    """
    # Per-channel normalization constants shared by both splits.
    mean = (0.50707516, 0.48654887, 0.44091784)
    std = (0.26733429, 0.25643846, 0.27615047)

    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    trainset = torchvision.datasets.CIFAR100(root=datapath, train=True,
                                             download=True, transform=transform_train)
    testset = torchvision.datasets.CIFAR100(root=datapath, train=False,
                                            download=True, transform=transform_test)

    # Use the ``batch_size`` parameter rather than a global ``args`` object,
    # which is not defined in this module and made the argument a no-op.
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                              shuffle=True, num_workers=num_workers)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                             shuffle=False, num_workers=num_workers)
    return trainloader, testloader



def get_svhn(batch_size, datapath, *, num_workers=2):
    """Build SVHN train/test DataLoaders.

    The training split is augmented with a random 32x32 crop (after 4-pixel
    padding). No horizontal flip is applied. Both splits are converted to
    tensors and normalized per channel with the mean/std constants below.

    Args:
        batch_size: Number of samples per batch for both loaders.
        datapath: Parent directory. Files are stored in the ``svhn``
            subdirectory, and ``download=True`` is passed, so torchvision
            fetches them there if missing.
        num_workers: Worker processes for each DataLoader (default 2,
            the value previously hard-coded).

    Returns:
        A ``(trainloader, testloader)`` tuple. The train loader shuffles;
        the test loader iterates in a fixed order.
    """
    # Per-channel normalization constants shared by both splits.
    mean = (0.4378, 0.4439, 0.4729)
    std = (0.1980, 0.2011, 0.1971)

    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    svhn_root = datapath + '/svhn'
    trainset = torchvision.datasets.SVHN(svhn_root, split='train', download=True,
                                         transform=transform_train)
    testset = torchvision.datasets.SVHN(svhn_root, split='test', download=True,
                                        transform=transform_test)

    # Use the ``batch_size`` parameter rather than a global ``args`` object,
    # which is not defined in this module and made the argument a no-op.
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                              shuffle=True, num_workers=num_workers)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                             shuffle=False, num_workers=num_workers)
    return trainloader, testloader



def get_stl10(batch_size, datapath, *, num_workers=2):
    """Build STL-10 train/test DataLoaders.

    The training split is augmented with a random 96x96 crop (after 4-pixel
    padding) and a random horizontal flip. Both splits are converted to
    tensors and normalized with mean 0.5 / std 0.5 per channel. This maps
    ``ToTensor``'s [0, 1] range to [-1, 1].

    Args:
        batch_size: Number of samples per batch for both loaders.
        datapath: Root directory for the dataset. ``download=True`` is
            passed, so torchvision fetches the files there if missing.
        num_workers: Worker processes for each DataLoader (default 2,
            the value previously hard-coded).

    Returns:
        A ``(trainloader, testloader)`` tuple. The train loader shuffles;
        the test loader iterates in a fixed order.
    """
    mean = (0.5, 0.5, 0.5)
    std = (0.5, 0.5, 0.5)

    transform_train = transforms.Compose([
        transforms.RandomCrop(96, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    trainset = torchvision.datasets.STL10(root=datapath, split='train', download=True,
                                          transform=transform_train)
    testset = torchvision.datasets.STL10(root=datapath, split='test', download=True,
                                         transform=transform_test)

    # Use the ``batch_size`` parameter rather than a global ``args`` object,
    # which is not defined in this module and made the argument a no-op.
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                              shuffle=True, num_workers=num_workers)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                             shuffle=False, num_workers=num_workers)
    return trainloader, testloader