from torchvision.transforms import transforms

# Per-channel normalization statistics for CIFAR-10 (three values, one per
# colour channel), expressed on the [0, 1] scale that transforms.ToTensor()
# produces, since Normalize is applied after ToTensor in every pipeline below.
#
# NOTE(review): the mean matches the CIFAR-10 training set, but this std is
# the widely copied (0.2023, 0.1994, 0.2010) found in many public CIFAR-10
# training scripts. It is NOT the true per-pixel std of the training set,
# which is roughly (0.2470, 0.2435, 0.2616). The values are kept unchanged
# on purpose: changing them shifts every model input and would invalidate
# checkpoints trained with these constants. Only switch for fresh runs.
cifar10_mean = [0.4914, 0.4822, 0.4465]
cifar10_std = [0.2023, 0.1994, 0.2010]


# CIFAR-10 pipelines at the native 32x32 resolution.
#
# Training: take a random 32x32 crop from the image padded by 4 pixels on
# each side, flip horizontally half of the time, then convert to a tensor
# and normalize each channel with the CIFAR-10 statistics.
transform_train = transforms.Compose([
    transforms.RandomCrop(size=32, padding=4),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),
    transforms.Normalize(mean=cifar10_mean, std=cifar10_std),
])

# Evaluation: no augmentation, only tensor conversion followed by the same
# per-channel normalization used for training.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=cifar10_mean, std=cifar10_std),
])

# CIFAR-10 pipelines for models that take 224x224 inputs: each 32x32 image
# is first resized (upscaled) to 224x224, then augmented/normalized the same
# way as the native-resolution pipelines above.
#
# Training: padding=28 is the 32x32 pipeline's padding=4 scaled by
# 224 / 32 = 7, so the random crop shifts by the same fraction of the image
# (up to 1/8 of its side in each direction).
#
# NOTE(review): padding and cropping happen after upscaling, i.e. on
# 224x224 images. Cropping at 32x32 first and resizing afterwards would be
# cheaper, but it is not output-identical (border interpolation and crop
# granularity differ), so the original order is preserved.
transform_train_224x224 = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomCrop(224, padding=28),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(cifar10_mean, cifar10_std),
])

# Evaluation at 224x224: resize without any augmentation, then convert to a
# tensor and apply the CIFAR-10 per-channel normalization.
transform_test_224x224 = transforms.Compose([
    transforms.Resize(size=(224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=cifar10_mean, std=cifar10_std),
])