import torch
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import os

# Directory passed as the storage root to every torchvision dataset that
# create_dataset() constructs.
root = "/home/sharefolder/data/"

# Augmenting pipeline used by the "-std" dataset variants: random resized crop
# to 224 pixels, random horizontal flip, tensor conversion, then per-channel
# normalization with the fixed mean/std listed below.
std_transform = transforms.Compose(
    [
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)


class ZipDataset(torch.utils.data.Dataset):
    """Pair two datasets index-by-index.

    Every item of both wrapped datasets must unpack into exactly two values
    (presumably an input and its label). Item ``i`` of this dataset is the
    triple ``(first value of dataset_0[i], first value of dataset_1[i],
    second value of dataset_0[i])``; the second value coming from
    ``dataset_1`` is discarded.

    The reported length is that of the shorter wrapped dataset.
    """

    def __init__(self, dataset_0, dataset_1):
        self.dataset_0, self.dataset_1 = dataset_0, dataset_1

    def __len__(self):
        sizes = (len(self.dataset_0), len(self.dataset_1))
        return min(sizes)

    def __getitem__(self, index):
        # dataset_0 is read first, then dataset_1.
        first_value, label = self.dataset_0[index]
        second_value, _dropped_label = self.dataset_1[index]
        return first_value, second_value, label


def _check_split(parts, allowed):
    """Return the split token of an already '_'-split dataset name.

    Raises AssertionError when the token is missing or not one of ``allowed``.
    Raised explicitly (not via ``assert``) so the check survives ``python -O``.
    """
    split = parts[1] if len(parts) > 1 else None
    if split not in allowed:
        raise AssertionError(
            "dataset %r expects a split in %r (as '%s_<split>'), got %r"
            % (parts[0], allowed, parts[0], split)
        )
    return split


def create_dataset(name, transform=None):
    """Build a dataset from a textual specification.

    Accepted forms of ``name``:

    * ``'<Dataset>_<split>'`` where ``<Dataset>`` is one of ``CIFAR10``,
      ``CIFAR100``, ``ImageNet`` or ``STL10``. The split is ``train``/``val``
      for CIFAR and ImageNet and ``train``/``test`` for STL10. These use the
      transforms bundled with ``ResNet18_Weights.IMAGENET1K_V1``.
    * ``'<Dataset>-std_<split>'``: same datasets, but with the module-level
      ``std_transform`` pipeline.
    * ``'A+B'``: a ``ZipDataset`` over the datasets described by ``A`` and
      ``B`` (only the first two '+'-separated parts are used).

    Args:
        name: Dataset specification string as described above.
        transform: Must be ``None``; custom transforms are not supported.

    Returns:
        The constructed torchvision dataset, or a ``ZipDataset`` for
        composite names.

    Raises:
        AssertionError: If ``transform`` is not ``None`` or the split is
            missing or invalid for the requested dataset.
        Exception: With message ``'dataset not implemented'`` if the dataset
            name is not recognised.
    """
    # Checked before the composite branch so that a custom transform is
    # rejected for 'A+B' names too, instead of being silently dropped.
    if transform is not None:
        raise AssertionError('custom transforms are not supported')

    if '+' in name:
        sub_names = name.split('+')
        return ZipDataset(create_dataset(sub_names[0]),
                          create_dataset(sub_names[1]))

    parts = name.split('_')
    kind = parts[0]
    uses_std = kind.endswith('-std')
    base = kind[:-len('-std')] if uses_std else kind

    # Reject unknown datasets first, then validate the split; nothing below
    # touches torchvision until both checks have passed.
    if base in ('CIFAR10', 'CIFAR100', 'ImageNet'):
        split = _check_split(parts, ('train', 'val'))
    elif base == 'STL10':
        split = _check_split(parts, ('train', 'test'))
    else:
        raise Exception('dataset not implemented')

    # Console notes are kept exactly as before: 'CIFAR10-std' and the
    # ImageNet variants stay silent.
    note_labels = {
        'CIFAR10': 'CIFAR-10',
        'CIFAR100': 'CIFAR-100',
        'CIFAR100-std': 'CIFAR-100',
        'STL10': 'STL-10',
        'STL10-std': 'STL-10',
    }
    label = note_labels.get(kind)
    if label is not None:
        print("Note that currently %s uses ImageNet transforms" % label)

    if uses_std:
        chosen_transform = std_transform
    else:
        chosen_transform = torchvision.models.ResNet18_Weights.IMAGENET1K_V1.transforms()

    if base == 'CIFAR10':
        # download=True now also applies to 'CIFAR10-std', matching every
        # other CIFAR/STL variant.
        return datasets.CIFAR10(root, train=split == 'train', download=True,
                                transform=chosen_transform)
    if base == 'CIFAR100':
        return datasets.CIFAR100(root, train=split == 'train', download=True,
                                 transform=chosen_transform)
    if base == 'ImageNet':
        return datasets.ImageNet(root, split=split, transform=chosen_transform)
    return datasets.STL10(root, split=split, download=True,
                          transform=chosen_transform)
