import torchvision.transforms as transforms

# MEAN_STD_PER_DATASET = {
#     "cifar10": {'mean': [0.4914, 0.4822, 0.4465], 'std': [0.2470, 0.2435, 0.2616]},
#     "cifar100": {'mean': [], 'std': []},
#     "stl10": {'mean':[], 'std': []},
# }

# Crop size (in pixels) handed to the torchvision crop transforms for each
# supported dataset; its keys also define which datasets are supported.
CROP_SIZE_PER_DATASET = dict(
    cifar10=32,
    cifar100=32,
    stl10=96,
)


class TwoCropsTransform:
    """Produce two views of one input by applying a transform twice.

    The wrapped ``base_transform`` is called twice on the same input. The first
    result serves as the query view and the second as the key view. When
    ``base_transform`` is random (e.g. random crops), the two views will
    generally differ.
    """

    def __init__(self, base_transform):
        # Callable applied to the input once per view.
        self.base_transform = base_transform

    def __call__(self, x):
        # Evaluated in order, so the query view is produced before the key
        # view; this matters for stateful or random transforms.
        views = [self.base_transform(x) for _ in range(2)]
        return views



def get_transforms(args):
    """Build the training augmentation pipeline for a supported dataset.

    Reads these attributes from ``args``:

    * ``dataset`` -- must be a key of ``CROP_SIZE_PER_DATASET``
      (``'cifar10'``, ``'cifar100'`` or ``'stl10'``); selects the crop size.
    * ``use_weak_transforms`` -- if truthy, use a light pipeline (padded
      random crop + horizontal flip). Otherwise use the stronger pipeline:
      random-resized crop, color jitter, grayscale and flip.
    * ``strength`` -- color-jitter strength multiplier; only read when
      ``use_weak_transforms`` is falsy.

    The chosen list of transforms is printed before being composed.

    Returns:
        A ``transforms.Compose`` whose last step is ``ToTensor()``. No
        normalization is applied.

    Raises:
        ValueError: if ``args.dataset`` is not a supported dataset.
    """
    # Validate against the crop-size table so the supported set has a single
    # source of truth. Use an explicit raise rather than ``assert`` so the
    # check still runs under ``python -O``.
    if args.dataset not in CROP_SIZE_PER_DATASET:
        raise ValueError(
            "Only support CIFAR10, CIFAR100, and STL10; "
            f"got dataset={args.dataset!r}"
        )
    crop_size = CROP_SIZE_PER_DATASET[args.dataset]

    if args.use_weak_transforms:
        augmentation = [
            transforms.RandomCrop(crop_size, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ]
    else:
        strength = args.strength
        augmentation = [
            transforms.RandomResizedCrop(crop_size),
            # Positional ColorJitter args: brightness, contrast, saturation
            # (each 0.8 * strength) and hue (0.2 * strength); applied with p=0.8.
            transforms.RandomApply(
                [transforms.ColorJitter(0.8 * strength,
                                        0.8 * strength,
                                        0.8 * strength,
                                        0.2 * strength)],
                p=0.8,
            ),
            transforms.RandomGrayscale(p=0.2),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ]

    # NOTE: normalization is currently disabled. The module-level
    # MEAN_STD_PER_DATASET table it would read is commented out, and its
    # cifar100/stl10 entries are empty. Re-enable once those stats exist:
    # if args.use_normalize:
    #     augmentation.append(transforms.Normalize(
    #         mean=MEAN_STD_PER_DATASET[args.dataset]['mean'],
    #         std=MEAN_STD_PER_DATASET[args.dataset]['std']))

    print(f'Data Augmentation: {augmentation}')
    return transforms.Compose(augmentation)