# This code is mainly borrowed from https://github.com/sthalles/SimCLR
import numpy as np
import torch
from torchvision import transforms, datasets
# Seed NumPy's global RNG once, at import time, so NumPy-based randomness is
# repeatable across runs.
# NOTE(review): this only affects NumPy's generator; whether the torchvision
# augmentations below draw from it (rather than torch's or Python's RNG)
# should be confirmed before relying on it for reproducibility.
np.random.seed(0)

class ContrastiveLearningViewGenerator:
    """Produce several independently transformed views of a single sample.

    Calling an instance applies ``base_transform`` to the input ``n_views``
    times and returns the results as a list. With a random transform and the
    default ``n_views=2`` this yields the query/key pair used in contrastive
    learning.
    """

    def __init__(self, base_transform, n_views=2):
        # Callable applied once per requested view.
        self.n_views = n_views
        self.base_transform = base_transform

    def __call__(self, x):
        """Return a list with ``n_views`` results of ``base_transform(x)``."""
        views = []
        for _ in range(self.n_views):
            views.append(self.base_transform(x))
        return views

class ContrastiveDataset:
    """Holder for the augmentation pipelines used to build contrastive views.

    The two pipeline builders are static methods and do not need an instance.
    """

    def __init__(self, root, transforms, target_transforms=None, classes=None, ann_file=None, test_mode=False):
        # Only ``root`` is recorded here; the remaining arguments are accepted
        # but not stored by this class.
        self.root = root

    @staticmethod
    def _blur_kernel_size(size):
        """Return an odd Gaussian-blur kernel size of roughly 10% of ``size``.

        Args:
            size: Target image size, either a single number or a list/tuple
                of per-dimension sizes.

        Returns:
            An odd positive int for a scalar ``size``, or a list of odd
            positive ints (one per dimension) for a list/tuple ``size``.
        """
        # ``// 2 * 2 + 1`` rounds down to an even number and adds one, so the
        # result is always odd and at least 1.
        if isinstance(size, (list, tuple)):
            return [int(0.1 * side) // 2 * 2 + 1 for side in size]
        return int(0.1 * size) // 2 * 2 + 1

    @staticmethod
    def get_simclr_pipeline_transform(size, s=1):
        """Return the SimCLR-style image augmentation pipeline.

        The pipeline applies, in order: random resized crop, random horizontal
        flip, random vertical flip, colour jitter (with probability 0.8),
        random grayscale (with probability 0.2), Gaussian blur and conversion
        to a tensor.

        Args:
            size: Output crop size; an int or a list/tuple of ints. Tuples are
                handled the same way as lists.
            s: Strength multiplier for the colour jitter.

        Returns:
            A ``transforms.Compose`` with the steps listed above.
        """
        kernel_size = ContrastiveDataset._blur_kernel_size(size)
        color_jitter = transforms.ColorJitter(
            brightness=0.8 * s,
            contrast=0.8 * s,
            saturation=0.8 * s,
            hue=0.2 * s,
        )
        data_transforms = transforms.Compose([
            transforms.RandomResizedCrop(size=size),
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            transforms.RandomApply([color_jitter], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            transforms.GaussianBlur(kernel_size=kernel_size),
            transforms.ToTensor(),
        ])
        return data_transforms

    @staticmethod
    def get_simclr_audio_pipeline_transform(size, s=1):
        """Return the augmentation pipeline used for audio inputs.

        The pipeline converts the input to a tensor first, then applies a
        random crop to ``size`` followed by random horizontal and vertical
        flips. No colour or blur augmentation is applied.

        Args:
            size: Crop size forwarded to ``transforms.RandomCrop``.
            s: Unused; kept so the signature mirrors
                ``get_simclr_pipeline_transform``.

        Returns:
            A ``transforms.Compose`` with the steps listed above.
        """
        data_transforms = transforms.Compose([
            transforms.ToTensor(),
            transforms.RandomCrop(size=size),
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
        ])
        return data_transforms
    

# Crop size handed to the audio pipeline; it is fixed and does not depend on
# the ``size`` argument supplied by the caller.
_AUDIO_CROP_SIZE = (128, 512)

# For every supported dataset name: the dataset attributes that receive a view
# generator, in assignment order, each paired with a flag telling whether the
# audio pipeline (True) or the image pipeline (False) is used.
_CONTRASTIVE_VIEW_ATTRS = {
    'cifar10': (('transforms', False),),
    'cifar100': (('transforms', False),),
    'tuberlin': (('transforms', False),),
    'imagenet': (('transforms', False),),
    'sketchy_pair': (('photo_transforms', False), ('sketch_transforms', False)),
    'sketchy_photo': (('photo_transforms', False),),
    'sketchy_sketch': (('sketch_transforms', False),),
    'vggsound_pair': (('video_transforms', False), ('audio_transforms', True)),
    'vggsound_video': (('video_transforms', False),),
    'vggsound_audio': (('audio_transforms', True),),
    'nyu_depth_pair': (('photo_transforms', False), ('depth_transforms', False)),
    'nyu_depth_photo': (('photo_transforms', False),),
    'nyu_depth_depth': (('depth_transforms', False),),
}


def _load_contrastive_dataset_class(name):
    """Lazily import and return the dataset class registered under ``name``."""
    # Imports stay inside the branches so that only the module backing the
    # requested dataset is loaded.
    if name == 'cifar10':
        from .cifar import CIFAR10
        return CIFAR10
    if name == 'cifar100':
        from .cifar import CIFAR100
        return CIFAR100
    if name == 'tuberlin':
        from .tuberlin import TUBerlin
        return TUBerlin
    if name == 'imagenet':
        from .imagenet import ImageNet
        return ImageNet
    if name == 'sketchy_pair':
        from .sketchy import SketchyPair
        return SketchyPair
    if name == 'sketchy_photo':
        from .sketchy import SketchyPhoto
        return SketchyPhoto
    if name == 'sketchy_sketch':
        from .sketchy import SketchySketch
        return SketchySketch
    if name == 'vggsound_pair':
        from .vggsound import VGGSoundPair
        return VGGSoundPair
    if name == 'vggsound_video':
        from .vggsound import VGGSoundVideo
        return VGGSoundVideo
    if name == 'vggsound_audio':
        from .vggsound import VGGSoundAudio
        return VGGSoundAudio
    if name == 'nyu_depth_pair':
        from .nyu_depth_v2 import NYUDepthV2Pair
        return NYUDepthV2Pair
    if name == 'nyu_depth_photo':
        from .nyu_depth_v2 import NYUDepthV2Photo
        return NYUDepthV2Photo
    if name == 'nyu_depth_depth':
        from .nyu_depth_v2 import NYUDepthV2Depth
        return NYUDepthV2Depth
    raise ValueError('Unknown contrastive dataset name: {!r}'.format(name))


def get_contrastive_dataset(name, *args, **kwargs):
    """Build the dataset called ``name`` and attach contrastive view generators.

    Args:
        name: One of the keys of ``_CONTRASTIVE_VIEW_ATTRS``.
        *args: Positional arguments forwarded to the dataset constructor.
        **kwargs: Keyword arguments forwarded to the dataset constructor,
            except ``size``, which is required, removed before forwarding and
            used as the image pipeline size. Audio attributes always use
            ``_AUDIO_CROP_SIZE`` instead.

    Returns:
        The constructed dataset with its transform attribute(s) replaced by
        ``ContrastiveLearningViewGenerator`` instances and ``contrastive`` set
        to ``True``.

    Raises:
        ValueError: If ``name`` is not a supported dataset. This is raised
            explicitly so the check also runs under ``python -O``.
        KeyError: If ``size`` is missing from ``kwargs``.
    """
    if name not in _CONTRASTIVE_VIEW_ATTRS:
        raise ValueError(
            'Unknown contrastive dataset name: {!r}; expected one of {}'.format(
                name, sorted(_CONTRASTIVE_VIEW_ATTRS)))

    # ``size`` is consumed here and must not reach the dataset constructor.
    size = kwargs.pop('size')
    dataset = _load_contrastive_dataset_class(name)(*args, **kwargs)

    # Every attribute gets its own pipeline and generator instance.
    for attr_name, is_audio in _CONTRASTIVE_VIEW_ATTRS[name]:
        if is_audio:
            pipeline = ContrastiveDataset.get_simclr_audio_pipeline_transform(size=_AUDIO_CROP_SIZE)
        else:
            pipeline = ContrastiveDataset.get_simclr_pipeline_transform(size)
        setattr(dataset, attr_name, ContrastiveLearningViewGenerator(pipeline))

    dataset.contrastive = True
    return dataset
