# A lot of this code is reused from https://github.com/yongchao97/FRePo
from absl import logging

import os
import numpy as np

from tiny_imagenet_dataset import TinyImageNet
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch

from imagenetLoad import ImageNetDownSample
from cifar10_permutations import cifar10_shuffle
import wilds

# Precomputed per-channel (mean, std) pairs, keyed by dataset name.
def _imagenet_stats():
    """Return a fresh (mean, std) pair holding the standard ImageNet statistics.

    A new pair of lists is built on every call so that no two entries of
    ``data_stats`` share (and could accidentally co-mutate) the same lists.
    """
    return ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])


data_stats = {
    'mnist': ([0.1307], [0.3081]),
    'fashion_mnist': ([0.2861], [0.3530]),
    'cifar10': _imagenet_stats(),
    'cifar100': _imagenet_stats(),
    # Alternative tiny_imagenet statistics that were tried and disabled:
    # 'tiny_imagenet': ([0.4759, 0.4481, 0.3926], [0.2763, 0.2687, 0.2813]),
    'tiny_imagenet': ([0.4802, 0.4481, 0.3975], [0.2302, 0.2265, 0.2262]),
    # 'tiny_imagenet': ([0., 0., 0.], [1., 1., 1.]),
    'imagenet2012': _imagenet_stats(),
    # Identity normalization (zero mean, unit std).
    'imagenet21k': ([0., 0., 0.], [1., 1., 1.]),
    'imagenette': ([0.4626, 0.4588, 0.4251], [0.2790, 0.2745, 0.2973]),
    'imagewoof': ([0.4917, 0.4613, 0.3931], [0.2513, 0.2442, 0.2530]),
    'imagenet_resized/32x32': _imagenet_stats(),
    'imagenet_resized/64x64': ([0.4815, 0.4578, 0.4082], [0.2686, 0.2613, 0.2758]),
    'caltech_birds2011': ([0.4810, 0.4964, 0.4245], [0.2129, 0.2084, 0.2468]),
    'imagenet32': _imagenet_stats(),
    'imagenet64': _imagenet_stats(),
    'cam17': _imagenet_stats(),
}


def center_crop(x, resolution):
    """Crop the central square of an image and resize it to ``resolution`` x ``resolution``.

    NOTE(review): this function uses the name ``tf``, which is not imported in
    the visible part of this file; unless TensorFlow is bound to ``tf``
    elsewhere, calling it raises ``NameError`` -- confirm before use.

    Args:
        x: image tensor. Presumably laid out as (height, width, 3): the first
            two shape entries are read as ``h, w`` and the slice below takes
            exactly 3 channels -- TODO confirm against callers.
        resolution: target side length, passed as both height and width to
            ``tf.image.resize_with_pad``.

    Returns:
        The cropped and resized image tensor.
    """
    shape = tf.shape(x)
    h, w = shape[0], shape[1]
    # Side length of the largest square that fits inside the image.
    size = tf.minimum(h, w)
    # Half of the leftover extent along each axis, so the square is centered;
    # computed in float and cast back to int32 for use as slice offsets.
    begin = tf.cast([h - size, w - size], tf.float32) / 2.0
    begin = tf.cast(begin, tf.int32)
    begin = tf.concat([begin, [0]], axis=0)  # Add channel dimension.
    # The slice size hard-codes 3 channels.
    x = tf.slice(x, begin, [size, size, 3])
    x = tf.image.resize_with_pad(x, resolution, resolution, method='area', antialias=True)
    return x


def get_ds_builder(dataset_name, data_dir):
    """Create a dataset builder for ``dataset_name`` and run its download/prepare step.

    NOTE(review): ``ImagewoofV2``, ``ImagenetteV2``, ``TinyImagenetV2`` and
    ``tfds`` are not defined or imported in the visible part of this file;
    unless they are provided elsewhere, every branch raises ``NameError`` --
    confirm before use.

    Args:
        dataset_name: name of the dataset. 'imagewoof' and 'imagenette' are
            matched exactly; any name containing 'tiny_imagenet' selects the
            Tiny-ImageNet builder; everything else is passed to
            ``tfds.builder``.
        data_dir: directory forwarded to the builder as ``data_dir``.

    Returns:
        The builder, after ``download_and_prepare()`` has been called on it.
    """
    if dataset_name == 'imagewoof':
        ds_builder = ImagewoofV2(data_dir=data_dir)
    elif dataset_name == 'imagenette':
        ds_builder = ImagenetteV2(data_dir=data_dir)
    elif 'tiny_imagenet' in dataset_name:
        # Substring match: variants of the name also map to this builder.
        ds_builder = TinyImagenetV2(data_dir=data_dir)
    else:
        ds_builder = tfds.builder(dataset_name, data_dir=data_dir)
    ds_builder.download_and_prepare()
    return ds_builder


def configure_dataloader(ds, batch_size, x_transform=None, y_transform=None, shuffle=False, seed=0, cache = True):
    """Turn a dataset of (x, y) pairs into a cached/shuffled/mapped/batched pipeline.

    The stages are applied in this order: optional ``cache``, optional
    ``shuffle`` (buffer of ``16 * batch_size`` elements), element-wise ``map``
    of the transforms, ``batch``, ``prefetch``.

    NOTE(review): ``tf`` is referenced below but is not imported in the visible
    part of this file -- confirm it is available before calling.

    Args:
        ds: dataset object exposing ``cache``/``shuffle``/``map``/``batch``/
            ``prefetch``; presumably a ``tf.data.Dataset`` yielding (x, y)
            pairs -- TODO confirm.
        batch_size: number of elements per batch.
        x_transform: optional callable applied to every ``x``; when falsy,
            ``x`` is passed through unchanged.
        y_transform: optional callable applied to every ``y``; ``None`` means
            identity.
        shuffle: whether to shuffle before mapping.
        seed: seed forwarded to ``ds.shuffle``.
        cache: whether to call ``ds.cache()`` first.

    Returns:
        The configured dataset pipeline.
    """
    label_fn = (lambda x: x) if y_transform is None else y_transform

    if cache:
        ds = ds.cache()

    if shuffle:
        ds = ds.shuffle(16 * batch_size, seed=seed)

    # Pick the per-element mapping once, depending on whether inputs are transformed.
    if x_transform:
        def apply_transforms(x, y):
            return (x_transform(x), label_fn(y))
    else:
        def apply_transforms(x, y):
            return (x, label_fn(y))

    ds = ds.map(apply_transforms, tf.data.AUTOTUNE)
    batched = ds.batch(batch_size=batch_size)
    return batched.prefetch(buffer_size=tf.data.AUTOTUNE)

# Standard ImageNet channel statistics, used when `data_stats` has no entry for a dataset.
_DEFAULT_STATS = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

# Native (resolution, num_classes) written to `config` for every dataset that
# `get_dataset` can actually build.
_DATASET_SPECS = {
    'tiny_imagenet': (64, 200),
    'imagenet2012': (224, 1000),
    'imagenet21k': (224, 10450),
    'cifar10': (32, 10),
    'cifar100': (32, 100),
    'imagenet32': (32, 1000),
    'imagenet64': (64, 1000),
    'cam17': (64, 2),
    'cam17_32': (32, 2),
}

_CIFAR10_SPLIT_PREFIX = 'cifar10_split'
_CIFAR100_SPLIT_PREFIX = 'cifar100_split'
_CAMELYON_NAMES = ('cam17', 'cam17_32')


def _parse_split_segments(dataset_name, prefix):
    """Parse the '_'-separated integer segment ids that follow `prefix` in `dataset_name`."""
    suffix = dataset_name.partition(prefix)[2]
    return [int(s) for s in suffix.split('_')]


def _one_hot_target_transform(num_classes):
    """Return a target transform mapping an integer label to a float one-hot vector."""
    return transforms.Compose([
        lambda y: torch.LongTensor([y]),
        transforms.Lambda(lambda y: torch.nn.functional.one_hot(y, num_classes).type(torch.float)[0]),
    ])


def _build_eval_transform(base_name, normalize, resolution, target_resolution):
    """Build the deterministic image pipeline (non-augmented training and validation)."""
    if base_name == 'imagenet2012':
        steps = [transforms.Resize(256), transforms.CenterCrop(224)]
    elif base_name == 'imagenet21k':
        steps = [transforms.Resize((224, 224))]
    elif base_name in ('cifar10', 'cifar100'):
        steps = [transforms.Resize((target_resolution, target_resolution))]
    elif base_name in _CAMELYON_NAMES:
        steps = [transforms.Resize((resolution, resolution))]
    else:
        # tiny_imagenet / imagenet32 / imagenet64 are used at their stored size.
        steps = []
    return transforms.Compose(steps + [transforms.ToTensor(), normalize])


def _build_train_aug_transform(base_name, normalize, resolution):
    """Build the augmented training image pipeline."""
    if base_name == 'imagenet2012':
        steps = [transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip()]
    elif base_name == 'imagenet21k':
        # No random augmentation is applied to ImageNet-21k, only a resize.
        steps = [transforms.Resize((224, 224))]
    elif base_name == 'tiny_imagenet':
        steps = [transforms.RandomResizedCrop(64), transforms.RandomHorizontalFlip()]
    elif base_name in ('cifar10', 'cifar100'):
        # CIFAR is cropped at its native 32px and not flipped; `target_resolution`
        # is intentionally not used on this path.
        steps = [transforms.RandomResizedCrop(32, scale=(0.3, 1.0))]
    else:
        # imagenet32 / imagenet64 / cam17 / cam17_32
        steps = [
            transforms.RandomResizedCrop((resolution, resolution), scale=(0.3, 1.0)),
            transforms.RandomHorizontalFlip(),
        ]
    return transforms.Compose(steps + [transforms.ToTensor(), normalize])


def _camelyon_collate(batch):
    """Collate WILDS (image, label, metadata) samples into (images, one-hot labels).

    The metadata entry of every sample is dropped.
    """
    images = []
    labels = []
    for x, y, _metadata in batch:
        images.append(x)
        labels.append(torch.nn.functional.one_hot(y, 2).type(torch.float))
    return torch.stack(images), torch.stack(labels)


def get_dataset(config, return_raw=False, apply_aug = True, data_folder = 'sdfsdf', batch_size = 256, get_normalization = False, target_class = None, subsample = 1.0, target_resolution = 224):
    """Build train/validation ``DataLoader``s and normalization helpers for a dataset.

    Supported ``config.name`` values are the keys of ``_DATASET_SPECS`` plus
    ``cifar10_split<i>[_<j>...]`` and ``cifar100_split<i>[_<j>...]``, where the
    trailing integers are segment ids. For CIFAR-100 every segment ``s``
    selects classes ``10*s .. 10*s+9``; for CIFAR-10 the ids are forwarded to
    ``SubLoaderCifar10`` as ``sub_split``.

    Side effect: ``config.img_shape`` and ``config.num_classes`` are set inside
    ``config.unlocked()``. ``img_shape`` reports the dataset's native
    resolution, not ``target_resolution``.

    Args:
        config: configuration object providing ``name`` and ``unlocked()``.
        return_raw: unused; kept for interface compatibility.
        apply_aug: use the augmented training pipeline instead of the
            deterministic one. The validation pipeline is never augmented.
        data_folder: root directory holding (or receiving) the dataset files.
        batch_size: batch size of both loaders.
        get_normalization: also return the ``transforms.Normalize`` instance.
        target_class: if not ``None``, passed as a single-element
            ``include_list`` to the CIFAR training loaders. For plain
            'cifar100' this only happens when ``apply_aug`` is true.
        subsample: forwarded to the CIFAR sub-loaders.
        target_resolution: side length CIFAR images are resized to in the
            deterministic (non-augmented / validation) pipeline.

    Returns:
        ``((train_loader, val_loader), preprocess_op, rev_preprocess_op,
        (mean, std))``, with ``normalize`` appended when ``get_normalization``
        is true. ``mean`` and ``std`` are float16 arrays of shape (1, 1, 1, C).

    Raises:
        ValueError: if ``config.name`` is not a supported dataset (previously
            this surfaced as an ``UnboundLocalError`` further down).
    """
    dataset_name = config.name

    # Resolve "<base>_split..." names to their base dataset plus segment ids.
    segments = None
    split_classes = []
    if _CIFAR10_SPLIT_PREFIX in dataset_name:
        base_name = 'cifar10'
        segments = _parse_split_segments(dataset_name, _CIFAR10_SPLIT_PREFIX)
    elif _CIFAR100_SPLIT_PREFIX in dataset_name:
        base_name = 'cifar100'
        segments = _parse_split_segments(dataset_name, _CIFAR100_SPLIT_PREFIX)
        split_classes = [i + 10 * seg for seg in segments for i in range(10)]
    else:
        base_name = dataset_name

    if base_name not in _DATASET_SPECS:
        raise ValueError(
            "Unsupported dataset {!r}; expected one of {} or a "
            "'cifar10_split<i>[_<j>...]' / 'cifar100_split<i>[_<j>...]' name".format(
                dataset_name, sorted(_DATASET_SPECS)))

    resolution, num_classes = _DATASET_SPECS[base_name]
    if base_name == 'cifar100' and segments is not None:
        # Labels are re-indexed into the selected classes by SubLoaderCifar100.
        num_classes = len(split_classes)

    # One source of truth for the statistics: the same numbers feed both the
    # torchvision Normalize transform and the numpy (de)normalization ops.
    mean_values, std_values = data_stats.get(base_name, _DEFAULT_STATS)
    normalize = transforms.Normalize(mean=list(mean_values), std=list(std_values))

    eval_transform = _build_eval_transform(base_name, normalize, resolution, target_resolution)
    if apply_aug:
        train_transform = _build_train_aug_transform(base_name, normalize, resolution)
    else:
        train_transform = _build_eval_transform(base_name, normalize, resolution, target_resolution)

    target_include = [] if target_class is None else [target_class]

    if base_name in ('imagenet2012', 'imagenet21k'):
        dir_prefix = '' if base_name == 'imagenet2012' else 'imagenet21k_'
        train_dataset = datasets.ImageFolder(
            os.path.join(data_folder, dir_prefix + 'train'), train_transform,
            target_transform=_one_hot_target_transform(num_classes))
        val_dataset = datasets.ImageFolder(
            os.path.join(data_folder, dir_prefix + 'val'), eval_transform,
            target_transform=_one_hot_target_transform(num_classes))

    elif base_name == 'tiny_imagenet':
        # Both splits live under data_folder (validation used to read from a
        # hard-coded './data', ignoring data_folder).
        root = os.path.join(data_folder, 'tiny_imagenet')
        train_dataset = TinyImageNet(
            root, split='train', download=True, transform=train_transform,
            target_transform=_one_hot_target_transform(num_classes))
        val_dataset = TinyImageNet(
            root, split='val', download=True, transform=eval_transform,
            target_transform=_one_hot_target_transform(num_classes))

    elif base_name == 'cifar10':
        root = os.path.join(data_folder, 'cifar_10')
        # Segment ids are always passed as a list, the form SubLoaderCifar10
        # iterates over; plain CIFAR-10 forwards `subsample` instead.
        subset_kwargs = {'subsample': subsample} if segments is None else {'sub_split': segments}
        train_dataset = SubLoaderCifar10(
            root, train=True, download=True, transform=train_transform,
            target_transform=_one_hot_target_transform(num_classes),
            include_list=target_include, **subset_kwargs)
        # Validation always uses the full CIFAR-10 test set.
        val_dataset = SubLoaderCifar10(
            root, train=False, download=True, transform=eval_transform,
            target_transform=_one_hot_target_transform(num_classes))

    elif base_name == 'cifar100':
        root = os.path.join(data_folder, 'cifar_100')
        if segments is not None:
            train_include = split_classes
            val_subsample = subsample
        else:
            # NOTE(review): plain CIFAR-100 honours target_class only on the
            # augmented path; kept as-is because SubLoaderCifar100 re-indexes
            # labels, which interacts with the 100-way one-hot encoding.
            train_include = target_include if apply_aug else []
            val_subsample = 1.0
        train_dataset = SubLoaderCifar100(
            root, train=True, download=True, transform=train_transform,
            target_transform=_one_hot_target_transform(num_classes),
            include_list=train_include, subsample=subsample)
        val_dataset = SubLoaderCifar100(
            root, train=False, download=True, transform=eval_transform,
            target_transform=_one_hot_target_transform(num_classes),
            include_list=split_classes, subsample=val_subsample)

    elif base_name in ('imagenet32', 'imagenet64'):
        train_dataset = ImageNetDownSample(
            data_folder, train=True, transform=train_transform,
            target_transform=_one_hot_target_transform(num_classes))
        val_dataset = ImageNetDownSample(
            data_folder, train=False, transform=eval_transform,
            target_transform=_one_hot_target_transform(num_classes))

    else:
        # cam17 / cam17_32: load the WILDS dataset once and take both subsets from it.
        camelyon = wilds.get_dataset(dataset='camelyon17', download=True, root_dir='./cam_data')
        train_dataset = camelyon.get_subset("train", transform=train_transform)
        val_dataset = camelyon.get_subset("test", transform=eval_transform)

    # Broadcastable (1, 1, 1, C) statistics for channel-last numpy batches.
    mean = np.array(mean_values)[None, None, None, :].astype(np.float16)
    std = np.array(std_values)[None, None, None, :].astype(np.float16)

    preprocess_op, rev_preprocess_op = get_preprocess_op_np(mean=mean, std=std, zca_mean=None,
                                                            whitening_transform=None,
                                                            rev_whitening_transform=None,
                                                            block_size=None, use_mean_block=False,
                                                            use_checkboard=False)

    loader_kwargs = dict(batch_size=batch_size, num_workers=8, pin_memory=True)
    if base_name in _CAMELYON_NAMES:
        # WILDS samples carry a metadata entry and integer labels.
        loader_kwargs['collate_fn'] = _camelyon_collate

    train_loader = torch.utils.data.DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = torch.utils.data.DataLoader(val_dataset, shuffle=False, **loader_kwargs)

    logging.info('Resolution: {}'.format(resolution))

    with config.unlocked():
        config.img_shape = (resolution, resolution, 3)
        config.num_classes = num_classes

    if get_normalization:
        return (train_loader, val_loader), preprocess_op, rev_preprocess_op, (mean, std), normalize
    return (train_loader, val_loader), preprocess_op, rev_preprocess_op, (mean, std)


def get_preprocess_op_np(mean=None, std=None, zca_mean=None, whitening_transform=None, rev_whitening_transform=None,
                         block_size=None, use_mean_block=False, use_checkboard=False):
    """Build a pair of functions that standardize a batch and undo the standardization.

    Only ``mean`` and ``std`` are used; the remaining parameters are accepted
    but ignored by this implementation.

    Args:
        mean: value(s) subtracted from the input.
        std: value(s) the centered input is divided by.
        zca_mean, whitening_transform, rev_whitening_transform, block_size,
        use_mean_block, use_checkboard: unused.

    Returns:
        ``(preprocess_op, preprocess_op_rev)`` where ``preprocess_op(images)``
        computes ``(images - mean) / std`` and ``preprocess_op_rev(images)``
        computes ``images * std + mean``.
    """
    # Both closures operate on a whole batch at a time.
    def preprocess_op(images):
        centered = images - mean
        return centered / std

    def preprocess_op_rev(images):
        rescaled = images * std
        return rescaled + mean

    return preprocess_op, preprocess_op_rev

class SubLoaderCifar10(datasets.CIFAR10):
    """CIFAR-10 restricted to chosen classes and/or fixed per-class sub-splits.

    Args:
        *args, **kwargs: forwarded unchanged to ``torchvision.datasets.CIFAR10``.
        include_list: class labels to keep. ``None`` or an empty sequence keeps
            every class.
        subsample: accepted for interface compatibility; currently unused.
        sub_split: ``-1`` (default) disables sub-splitting. Otherwise a single
            segment id or an iterable of segment ids. Applies to the training
            set only: for every class ``c`` and segment ``s`` the samples at
            positions ``cifar10_shuffle[c][500*s : 500*(s+1)]`` among the
            class-``c`` training samples are kept. While sub-splitting is
            active, ``include_list`` is not applied to the training set.

    Behavior notes:
        * ``sub_split`` is honoured even when ``include_list`` is empty.
          Previously an empty ``include_list`` returned early, so the
          requested sub-split was silently ignored and the full training set
          was used.
        * An ``int`` ``sub_split`` is accepted (it used to fail when iterated).
    """

    def __init__(self, *args, include_list=None, subsample = 1.0, sub_split = -1, **kwargs):
        super().__init__(*args, **kwargs)

        # Copy so that the caller's list is never shared with this instance.
        include_list = [] if include_list is None else list(include_list)

        # Normalize sub_split to "None (disabled)" or a list of segment ids.
        if isinstance(sub_split, (int, np.integer)):
            segments = None if sub_split == -1 else [int(sub_split)]
        else:
            segments = [int(s) for s in sub_split]

        labels = np.array(self.targets)

        if self.train and segments is not None:
            selected = []
            for c in range(10):
                # Dataset indices of all samples of class c, in dataset order
                # (no longer assumes exactly 50000 training samples).
                class_positions = np.flatnonzero(labels == c)
                for seg in segments:
                    chosen = cifar10_shuffle[c][500 * seg: 500 * (seg + 1)]
                    selected.extend(class_positions[chosen])
            keep = np.array(selected, dtype=np.int64)
        elif include_list:
            # Class filter; for the test split this is the only filter applied.
            keep = np.isin(labels, include_list)
        else:
            # Nothing requested: leave the dataset untouched.
            return

        self.data = self.data[keep]
        self.targets = labels[keep].tolist()



class SubLoaderCifar100(datasets.CIFAR100):
    """CIFAR-100 restricted to the classes in ``include_list``, with labels re-indexed.

    Args:
        *args, **kwargs: forwarded unchanged to ``torchvision.datasets.CIFAR100``.
        include_list: class labels to keep. An empty list leaves the dataset
            untouched. Otherwise only samples of these classes are kept and
            each label is replaced by its position in ``include_list``.
        subsample: accepted but unused.

    The same filtering and re-indexing is applied to the train and test splits.
    """

    def __init__(self, *args, include_list=[], subsample = 1.0, **kwargs):
        super().__init__(*args, **kwargs)

        # The default list is only read, never mutated.
        if include_list == []:
            return

        original_labels = np.array(self.targets)
        wanted = np.array(include_list).reshape(1, -1)
        # Row i is True when sample i's label matches any wanted class.
        keep = (original_labels.reshape(-1, 1) == wanted).any(axis=1)

        self.data = self.data[keep]
        self.targets = original_labels[keep].tolist()

        print(include_list)
        # Map every kept label to its index within include_list.
        self.targets = [include_list.index(label) for label in self.targets]