from torchvision import transforms
import torch


def aug01():
    """Build the 32x32 random-crop / flip training pipeline.

    Train: ToPILImage -> RandomCrop(32, padding=4) -> RandomHorizontalFlip
    -> ToTensor. No normalization is applied.

    Returns:
        ``(train_transform, None)``. The test slot is ``None``; presumably
        callers treat that as "no transform" -- confirm against callers.
    """
    train_steps = [
        # Input must be something ToPILImage accepts (tensor or ndarray).
        transforms.ToPILImage(),
        # Pad 4 px on every border, then take a random 32x32 window.
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ]
    return transforms.Compose(train_steps), None


def aug02():
    """Build a deterministic resize/center-crop pipeline used for both splits.

    Pipeline: Resize(256) -> CenterCrop(224) -> ToTensor -> Normalize with
    the commonly used ImageNet per-channel mean/std.

    Returns:
        ``(pipeline, pipeline)`` -- the very same Compose object twice.
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    pipeline = transforms.Compose([
        transforms.Resize(256),       # shorter side -> 256
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ])
    return pipeline, pipeline


def aug03():
    """Build (train, test) transforms that only apply ImageNet normalization.

    No augmentation and no type conversion happens here. Inputs must already
    be float tensors, because ``Normalize`` does not accept PIL images.
    Train and test each get their own Compose, and both wrap the same
    ``Normalize`` instance.

    Returns:
        ``(transform_train, transform_test)``.
    """
    # Commonly used ImageNet per-channel statistics. An earlier (removed,
    # previously commented-out) variant used the commonly quoted CIFAR-10
    # statistics mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010),
    # preceded by ToPILImage -> RandomCrop(32, padding=4) ->
    # RandomHorizontalFlip -> ToTensor (the same steps aug01 uses).
    normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))

    transform_train = transforms.Compose([normalize])
    transform_test = transforms.Compose([normalize])

    return transform_train, transform_test

def aug04():
    """Build 112x112 crop pipelines with ImageNet normalization.

    Train: ToPILImage -> RandomCrop(112) -> RandomHorizontalFlip -> ToTensor
    -> Normalize.
    Test:  ToPILImage -> CenterCrop(112) -> ToTensor -> Normalize.
    Both pipelines share one ``Normalize`` instance.

    Returns:
        ``(trn_transform, tst_transform)``.
    """
    crop_size = 112
    imagenet_norm = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])

    train_pipeline = transforms.Compose([
        transforms.ToPILImage(),
        transforms.RandomCrop(crop_size),    # random 112x112 window, no padding
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        imagenet_norm,
    ])
    eval_pipeline = transforms.Compose([
        transforms.ToPILImage(),
        transforms.CenterCrop(crop_size),    # deterministic central window
        transforms.ToTensor(),
        imagenet_norm,
    ])
    return train_pipeline, eval_pipeline

def aug05():
    """Build 224x224 crop pipelines with ImageNet normalization.

    Train: ToPILImage -> RandomCrop(224) -> RandomHorizontalFlip -> ToTensor
    -> Normalize.
    Test:  ToPILImage -> CenterCrop(224) -> ToTensor -> Normalize.
    Both pipelines share one ``Normalize`` instance.

    Returns:
        ``(trn_transform, tst_transform)``.
    """
    side = 224
    norm = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                std=[0.229, 0.224, 0.225])

    train_ops = [transforms.ToPILImage(),
                 transforms.RandomCrop(side),
                 transforms.RandomHorizontalFlip(),
                 transforms.ToTensor(),
                 norm]
    test_ops = [transforms.ToPILImage(),
                transforms.CenterCrop(side),
                transforms.ToTensor(),
                norm]
    return transforms.Compose(train_ops), transforms.Compose(test_ops)

def aug10():
    """Build one 128x128 resize/center-crop pipeline for both splits.

    Pipeline: Resize(128) (shorter side -> 128) -> CenterCrop(128) -> ToTensor.
    No normalization.

    Returns:
        ``(pipeline, pipeline)`` -- the same Compose object twice.
    """
    side = 128
    pipeline = transforms.Compose([
        transforms.Resize(side),
        transforms.CenterCrop(side),
        transforms.ToTensor(),
    ])
    return pipeline, pipeline

def aug11():
    """Build one 256x256 resize/center-crop pipeline for both splits.

    Pipeline: Resize(256) (shorter side -> 256) -> CenterCrop(256) -> ToTensor.
    No normalization.

    Returns:
        ``(pipeline, pipeline)`` -- the same Compose object twice.
    """
    side = 256
    pipeline = transforms.Compose([
        transforms.Resize(side),
        transforms.CenterCrop(side),
        transforms.ToTensor(),
    ])
    return pipeline, pipeline

def noise00():
    """Build (train, test) transforms for additive Gaussian noise, std 0.05.

    Train: returns ``x + 0.05 * N(0, 1)`` with noise drawn element-wise in
    ``x``'s shape. ``x`` must be a tensor (its ``.size()``/``.device`` are read);
    there is no ToTensor step here.
    Test: identity, input returned unchanged.

    Returns:
        ``(transform, tst_transform)``.
    """
    transform = transforms.Compose([
        # Allocate the noise on x's device; plain torch.randn(size) uses the
        # default device (CPU unless changed) and fails against CUDA inputs.
        transforms.Lambda(lambda x: x + 0.05 * torch.randn(x.size(), device=x.device)),
    ])
    tst_transform = transforms.Lambda(lambda x: x)
    return transform, tst_transform

def noise01():
    """Build (train, test) transforms for additive Gaussian noise, std 0.1.

    Train: returns ``x + 0.1 * N(0, 1)`` with noise drawn element-wise in
    ``x``'s shape. ``x`` must be a tensor (its ``.size()``/``.device`` are read);
    there is no ToTensor step here.
    Test: identity, input returned unchanged.

    Returns:
        ``(transform, tst_transform)``.
    """
    transform = transforms.Compose([
        # Allocate the noise on x's device; plain torch.randn(size) uses the
        # default device (CPU unless changed) and fails against CUDA inputs.
        transforms.Lambda(lambda x: x + 0.1 * torch.randn(x.size(), device=x.device)),
    ])
    tst_transform = transforms.Lambda(lambda x: x)
    return transform, tst_transform

def noise02():
    """Build (train, test) transforms for additive Gaussian noise, std 0.2.

    Train: returns ``x + 0.2 * N(0, 1)`` with noise drawn element-wise in
    ``x``'s shape. ``x`` must be a tensor (its ``.size()``/``.device`` are read);
    there is no ToTensor step here.
    Test: identity, input returned unchanged.

    Returns:
        ``(transform, tst_transform)``.
    """
    transform = transforms.Compose([
        # Allocate the noise on x's device; plain torch.randn(size) uses the
        # default device (CPU unless changed) and fails against CUDA inputs.
        transforms.Lambda(lambda x: x + 0.2 * torch.randn(x.size(), device=x.device)),
    ])
    tst_transform = transforms.Lambda(lambda x: x)
    return transform, tst_transform

def noise03():
    """Build (train, test) transforms for additive Gaussian noise, std 0.5.

    Train: returns ``x + 0.5 * N(0, 1)`` with noise drawn element-wise in
    ``x``'s shape. ``x`` must be a tensor (its ``.size()``/``.device`` are read);
    there is no ToTensor step here.
    Test: identity, input returned unchanged.

    Returns:
        ``(transform, tst_transform)``.
    """
    transform = transforms.Compose([
        # Allocate the noise on x's device; plain torch.randn(size) uses the
        # default device (CPU unless changed) and fails against CUDA inputs.
        transforms.Lambda(lambda x: x + 0.5 * torch.randn(x.size(), device=x.device)),
    ])
    tst_transform = transforms.Lambda(lambda x: x)
    return transform, tst_transform

def noise04():
    """Build (train, test) transforms for additive Gaussian noise, std 0.75.

    Train: returns ``x + 0.75 * N(0, 1)`` with noise drawn element-wise in
    ``x``'s shape. ``x`` must be a tensor (its ``.size()``/``.device`` are read);
    there is no ToTensor step here.
    Test: identity, input returned unchanged.

    Returns:
        ``(transform, tst_transform)``.
    """
    transform = transforms.Compose([
        # Allocate the noise on x's device; plain torch.randn(size) uses the
        # default device (CPU unless changed) and fails against CUDA inputs.
        transforms.Lambda(lambda x: x + 0.75 * torch.randn(x.size(), device=x.device)),
    ])
    tst_transform = transforms.Lambda(lambda x: x)
    return transform, tst_transform

def noise05():
    """Build (train, test) transforms for additive Gaussian noise, std 1.0.

    Train: returns ``x + 1.0 * N(0, 1)`` with noise drawn element-wise in
    ``x``'s shape. ``x`` must be a tensor (its ``.size()``/``.device`` are read);
    there is no ToTensor step here.
    Test: identity, input returned unchanged.

    Returns:
        ``(transform, tst_transform)``.
    """
    transform = transforms.Compose([
        # Allocate the noise on x's device; plain torch.randn(size) uses the
        # default device (CPU unless changed) and fails against CUDA inputs.
        transforms.Lambda(lambda x: x + 1.0 * torch.randn(x.size(), device=x.device)),
    ])
    tst_transform = transforms.Lambda(lambda x: x)
    return transform, tst_transform

def noise06():
    """Build (train, test) transforms for additive Gaussian noise, std 5.0.

    Train: returns ``x + 5.0 * N(0, 1)`` with noise drawn element-wise in
    ``x``'s shape. ``x`` must be a tensor (its ``.size()``/``.device`` are read);
    there is no ToTensor step here.
    Test: identity, input returned unchanged.

    Returns:
        ``(transform, tst_transform)``.
    """
    transform = transforms.Compose([
        # Allocate the noise on x's device; plain torch.randn(size) uses the
        # default device (CPU unless changed) and fails against CUDA inputs.
        transforms.Lambda(lambda x: x + 5.0 * torch.randn(x.size(), device=x.device)),
    ])
    tst_transform = transforms.Lambda(lambda x: x)
    return transform, tst_transform

def noise07():
    """Build (train, test) transforms for additive Gaussian noise, std 10.0.

    Train: returns ``x + 10.0 * N(0, 1)`` with noise drawn element-wise in
    ``x``'s shape. ``x`` must be a tensor (its ``.size()``/``.device`` are read);
    there is no ToTensor step here.
    Test: identity, input returned unchanged.

    Returns:
        ``(transform, tst_transform)``.
    """
    transform = transforms.Compose([
        # Allocate the noise on x's device; plain torch.randn(size) uses the
        # default device (CPU unless changed) and fails against CUDA inputs.
        transforms.Lambda(lambda x: x + 10.0 * torch.randn(x.size(), device=x.device)),
    ])
    tst_transform = transforms.Lambda(lambda x: x)
    return transform, tst_transform


def aug01noise00():
    """Build crop/flip augmentation followed by Gaussian noise (std 0.05).

    Train: ToPILImage -> RandomCrop(32, padding=4) -> RandomHorizontalFlip
    -> ToTensor -> add ``0.05 * N(0, 1)`` element-wise noise.
    Test: identity, input returned unchanged.

    Returns:
        ``(trn_transform, tst_transform)``.
    """
    sigma = 0.05

    def add_noise(img):
        # img is the ToTensor output here, so it is a CPU float tensor.
        return img + sigma * torch.randn(img.size())

    steps = [
        transforms.ToPILImage(),
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Lambda(add_noise),
    ]
    return transforms.Compose(steps), transforms.Lambda(lambda img: img)

def aug01noise01():
    """Build crop/flip augmentation followed by Gaussian noise (std 0.1).

    Train: ToPILImage -> RandomCrop(32, padding=4) -> RandomHorizontalFlip
    -> ToTensor -> add ``0.1 * N(0, 1)`` element-wise noise.
    Test: identity, input returned unchanged.

    Returns:
        ``(trn_transform, tst_transform)``.
    """
    sigma = 0.1

    def add_noise(img):
        # img is the ToTensor output here, so it is a CPU float tensor.
        return img + sigma * torch.randn(img.size())

    steps = [
        transforms.ToPILImage(),
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Lambda(add_noise),
    ]
    return transforms.Compose(steps), transforms.Lambda(lambda img: img)

def aug01noise02():
    """Build crop/flip augmentation followed by Gaussian noise (std 0.2).

    Train: ToPILImage -> RandomCrop(32, padding=4) -> RandomHorizontalFlip
    -> ToTensor -> add ``0.2 * N(0, 1)`` element-wise noise.
    Test: identity, input returned unchanged.

    Returns:
        ``(trn_transform, tst_transform)``.
    """
    sigma = 0.2

    def add_noise(img):
        # img is the ToTensor output here, so it is a CPU float tensor.
        return img + sigma * torch.randn(img.size())

    steps = [
        transforms.ToPILImage(),
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Lambda(add_noise),
    ]
    return transforms.Compose(steps), transforms.Lambda(lambda img: img)

def aug01noise03():
    """Build crop/flip augmentation followed by Gaussian noise (std 0.5).

    Train: ToPILImage -> RandomCrop(32, padding=4) -> RandomHorizontalFlip
    -> ToTensor -> add ``0.5 * N(0, 1)`` element-wise noise.
    Test: identity, input returned unchanged.

    Returns:
        ``(trn_transform, tst_transform)``.
    """
    sigma = 0.5

    def add_noise(img):
        # img is the ToTensor output here, so it is a CPU float tensor.
        return img + sigma * torch.randn(img.size())

    steps = [
        transforms.ToPILImage(),
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Lambda(add_noise),
    ]
    return transforms.Compose(steps), transforms.Lambda(lambda img: img)

def aug01noise04():
    """Build crop/flip augmentation followed by Gaussian noise (std 1.0).

    Train: ToPILImage -> RandomCrop(32, padding=4) -> RandomHorizontalFlip
    -> ToTensor -> add ``1.0 * N(0, 1)`` element-wise noise.
    Test: identity, input returned unchanged.

    NOTE(review): std here is 1.0, while noise04 uses 0.75. Confirm that the
    aug01noiseNN and noiseNN indices are not meant to share noise levels.

    Returns:
        ``(trn_transform, tst_transform)``.
    """
    sigma = 1.0

    def add_noise(img):
        # img is the ToTensor output here, so it is a CPU float tensor.
        return img + sigma * torch.randn(img.size())

    steps = [
        transforms.ToPILImage(),
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Lambda(add_noise),
    ]
    return transforms.Compose(steps), transforms.Lambda(lambda img: img)

def aug01noise05():
    """Build crop/flip augmentation followed by Gaussian noise (std 2.0).

    Train: ToPILImage -> RandomCrop(32, padding=4) -> RandomHorizontalFlip
    -> ToTensor -> add ``2.0 * N(0, 1)`` element-wise noise.
    Test: identity, input returned unchanged.

    NOTE(review): std here is 2.0, while noise05 uses 1.0. Confirm that the
    aug01noiseNN and noiseNN indices are not meant to share noise levels.

    Returns:
        ``(trn_transform, tst_transform)``.
    """
    sigma = 2.0

    def add_noise(img):
        # img is the ToTensor output here, so it is a CPU float tensor.
        return img + sigma * torch.randn(img.size())

    steps = [
        transforms.ToPILImage(),
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Lambda(add_noise),
    ]
    return transforms.Compose(steps), transforms.Lambda(lambda img: img)