import torch
from torchvision import transforms

# CIFAR-10 preprocessing follows the ProxQuant paper, although the values in _cifar10_stats differ from it.

_cifar10_stats = {'mean': [0.4914, 0.4822, 0.4465], 'std': [0.247, 0.243, 0.261]}


def pad_random_crop_cifar10(input_size=32, scale_size=40):
    """Build the augmenting CIFAR-10 transform: padded random crop plus flip.

    Args:
        input_size: Size passed to ``transforms.RandomCrop``.
        scale_size: Reference size used only to derive the crop padding as
            ``int((scale_size - input_size) / 2)``.

    Returns:
        A ``transforms.Compose`` that applies, in order, a random crop with
        the derived padding, a random horizontal flip, tensor conversion and
        normalization with ``_cifar10_stats``.
    """
    # int() truncates toward zero, so an odd size difference loses its
    # remainder instead of rounding.
    margin = int((scale_size - input_size) / 2)

    steps = [
        transforms.RandomCrop(input_size, padding=margin),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(**_cifar10_stats),
    ]
    return transforms.Compose(steps)


def scale_crop_cifar10(input_size=32, scale_size=32):
    """Build the deterministic CIFAR-10 transform: optional resize plus center crop.

    Args:
        input_size: Size passed to ``transforms.CenterCrop``.
        scale_size: Size passed to ``transforms.Resize``. The resize step is
            only added when this differs from ``input_size``.

    Returns:
        A ``transforms.Compose`` that applies, in order, an optional resize,
        a center crop, tensor conversion and normalization with
        ``_cifar10_stats``.
    """
    steps = []

    if scale_size != input_size:
        # transforms.Scale was deprecated in favour of transforms.Resize and
        # later removed from torchvision; Resize takes the same size argument.
        steps.append(transforms.Resize(scale_size))

    steps.extend([
        transforms.CenterCrop(input_size),
        transforms.ToTensor(),
        transforms.Normalize(**_cifar10_stats),
    ])

    return transforms.Compose(steps)

# ImageNet training pipeline: random resized crop to 224, random horizontal
# flip, tensor conversion, then normalization with the mean/std listed here.
train_transform_imagenet = transforms.Compose(
    [
        transforms.RandomResizedCrop(size=224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# ImageNet validation pipeline: resize to 256, center crop to 224, tensor
# conversion, then normalization with the mean/std listed here.
val_transform_imagenet = transforms.Compose(
    [
        transforms.Resize(size=256),
        transforms.CenterCrop(size=224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
