import torch
import torchvision.transforms as transforms

# Public API exported by ``from <module> import *``.
__all__ = ["TwoCropTransform"]


# Supervised Contrastive Learning: 
# (paper) https://arxiv.org/pdf/2004.11362.pdf
# (official code) https://github.com/HobbitLong/SupContrast/
class TwoCropTransform:
    """ Create two crops of the same image """
    def __init__(self, transform):
        self.transform = transform

    def __call__(self, x):
        return [self.transform(x), self.transform(x)]

