# Inherited from open-mmlab/mmclassification
# Modified by Hangyu LIN

import copy
import platform
import random
from functools import partial

import numpy as np
import torch

from utils.parallel import get_dist_info
from torch.utils.data import DataLoader

from datasets.transforms import build_transforms
from datasets.samplers import build_sampler
from .cifar import CIFAR10, CIFAR100, NoisyCIFAR10, NoisyCIFAR100
from .imagenet import ImageNet
from .tuberlin import TUBerlin
from .sketchy import SketchyPair, SketchyPhoto, SketchySketch
from .vggsound import VGGSoundPair, VGGSoundVideo, VGGSoundAudio
from .nyu_depth_v2 import NYUDepthV2Pair, NYUDepthV2Depth, NYUDepthV2Photo
from .contrastive_dataset import get_contrastive_dataset


def worker_init_fn(worker_id, num_workers, rank, seed):
    """Seed the Python, NumPy and PyTorch RNGs inside a DataLoader worker.

    The seed is ``seed + num_workers * rank + worker_id``. Workers of one
    process therefore get consecutive seeds, and processes of different ranks
    get disjoint ranges (assuming ``worker_id < num_workers``).

    Args:
        worker_id (int): Index of the worker within its DataLoader.
        num_workers (int): Number of workers per DataLoader.
        rank (int): Rank of the current process.
        seed (int): User-supplied base seed.
    """
    global_worker_index = rank * num_workers + worker_id
    derived_seed = seed + global_worker_index
    for seed_everything in (np.random.seed, random.seed, torch.manual_seed):
        seed_everything(derived_seed)

def _first_size(resized_size):
    """Return the leading dimension of ``resized_size``.

    A sequence such as ``[224, 224]`` yields its first element. A plain int,
    which is not subscriptable, is returned unchanged.
    """
    try:
        return resized_size[0]
    except TypeError:
        return resized_size


def _build_trans_kwargs(resized_size):
    """Return the per-transform keyword arguments passed to ``build_transforms``.

    Keys are transform names and values are the kwargs for that transform.
    Every size-dependent entry is derived from ``resized_size``.
    """
    first_size = _first_size(resized_size)
    return {
        'RandomResizedCrop': {'size': resized_size},
        'RandomCrop': {'size': resized_size, 'padding': 4, 'padding_mode': 'reflect'},
        'Resize': {'size': resized_size},
        # NOTE(review): these mean/std values look like commonly quoted
        # CIFAR-10 (Normalize) and MNIST (GaryNormalize) statistics; they are
        # applied to every dataset type. Confirm that is intended.
        'Normalize': {'mean': (0.4914, 0.4822, 0.4465), 'std': (0.247, 0.243, 0.261)},
        'GaryNormalize': {'mean': (0.1307,), 'std': (0.3081,)},
        'SegPresetTrain': {
            'base_size': int(first_size * 1.2),
            'crop_size': first_size,
            'hflip_prob': 0.5,
            'mean': (0.485, 0.456, 0.406),
            'std': (0.229, 0.224, 0.225),
        },
        'SegPresetEval': {'base_sizes': resized_size},
    }


def _split_noisy_dataset(inner_dataset, outer_dataset, inner_ratio, outer_ratio):
    """Subsample two copies of a noisy CIFAR dataset in place.

    Both copies draw from ``inner_dataset.data_infos`` (N entries). The outer
    copy keeps ``int(outer_ratio * N)`` samples and the inner copy keeps
    ``int(inner_ratio * N)`` samples. Each uses its own random permutation,
    so the two subsets may overlap.

    Returns:
        list: ``[inner_dataset, outer_dataset]``.
    """
    data_infos = np.array(inner_dataset.data_infos)
    num_samples = len(data_infos)
    num_outer = int(outer_ratio * num_samples)
    num_inner = int(inner_ratio * num_samples)
    # Outer indices are drawn first; keep this order so a fixed NumPy seed
    # reproduces the same split.
    outer_inds = np.random.choice(num_samples, num_samples, replace=False)[:num_outer]
    inner_inds = np.random.choice(num_samples, num_samples, replace=False)[:num_inner]
    inner_dataset.data_infos = data_infos[inner_inds]
    outer_dataset.data_infos = data_infos[outer_inds]
    return [inner_dataset, outer_dataset]


def _contrastive_dataset(data_type, dataset_kwargs, resized_size, default_args):
    """Build the contrastive variant of ``data_type`` (which starts with 'con_').

    ``size`` is added to ``dataset_kwargs``. Keys in ``default_args``
    override any entry of the same name.
    """
    ori_type = data_type[len('con_'):]
    return get_contrastive_dataset(
        ori_type, **{**dataset_kwargs, 'size': resized_size, **default_args})


def build_datasets(cfg, default_args=None):
    """Build datasets from a dataset config.

    Args:
        cfg (dict-like): Config supporting both attribute access and ``.get``.
            ``cfg.type``, ``cfg.root``, ``cfg.transforms`` and
            ``cfg.resized_size`` are parallel lists with one entry per dataset.
            Optional keys are ``test_mode`` (default False), ``used_ratios``
            (default 1 per dataset) and ``split_num`` (TU-Berlin, default 60).
            Some types read further keys: ``trans_rate``, ``trans_type``,
            ``inner_ratio`` and ``outer_ratio`` for noisy CIFAR, and
            ``photo_augs`` / ``sketch_augs`` for Sketchy.
        default_args (dict, optional): Extra keyword arguments forwarded to
            every dataset constructor.

    Returns:
        dict: Maps each dataset type string to a list of datasets. Noisy CIFAR
        types yield ``[inner_dataset, outer_dataset]``; every other type yields
        a one-element list. If a type is listed twice, the later entry wins.

    Raises:
        AssertionError: If the per-dataset config lists differ in length.
        ValueError: If a dataset type is not supported.
    """
    assert len(cfg.type) == len(cfg.root) == len(cfg.transforms) == len(cfg.resized_size), \
        'The numbers of dataset types, roots, transforms and resized sizes should be equal'
    if default_args is None:
        default_args = {}
    test_mode = cfg.get('test_mode', False)
    used_ratios = cfg.get('used_ratios', [1] * len(cfg.type))
    datasets = {}
    for data_type, root, trans_list, resized_size, used_ratio in zip(
            cfg.type, cfg.root, cfg.transforms, cfg.resized_size, used_ratios):
        trans_kwargs = _build_trans_kwargs(resized_size)
        if data_type == 'cifar10':
            dataset = CIFAR10(
                root=root,
                transforms=build_transforms(trans_list, trans_kwargs),
                test_mode=test_mode,
                **default_args)
        elif data_type == 'cifar100':
            dataset = CIFAR100(
                root=root,
                transforms=build_transforms(trans_list, trans_kwargs),
                test_mode=test_mode,
                **default_args)
        elif data_type in ('noisy_cifar10', 'noisy_cifar100'):
            noisy_cls = NoisyCIFAR10 if data_type == 'noisy_cifar10' else NoisyCIFAR100
            # Two separate instances, so their data_infos can be subsampled
            # independently below.
            inner_dataset = noisy_cls(
                root=root,
                transforms=build_transforms(trans_list, trans_kwargs),
                trans_rate=cfg.trans_rate,
                trans_type=cfg.trans_type,
                test_mode=test_mode,
                **default_args)
            outer_dataset = noisy_cls(
                root=root,
                transforms=build_transforms(trans_list, trans_kwargs),
                trans_rate=cfg.trans_rate,
                trans_type=cfg.trans_type,
                test_mode=test_mode,
                **default_args)
            dataset = _split_noisy_dataset(
                inner_dataset, outer_dataset, cfg.inner_ratio, cfg.outer_ratio)
        elif data_type == 'mnist':
            # NOTE(review): MNIST is not implemented; this branch currently
            # builds a CIFAR100 dataset as a placeholder.
            dataset = CIFAR100(
                root=root,
                transforms=build_transforms(trans_list, trans_kwargs),
                test_mode=test_mode,
                **default_args)
        elif data_type == 'imagenet':
            dataset = ImageNet(
                root=root,
                transforms=build_transforms(trans_list, trans_kwargs),
                test_mode=test_mode,
                **default_args)
        elif data_type == 'tuberlin':
            dataset = TUBerlin(
                root=root,
                transforms=build_transforms(trans_list, trans_kwargs),
                test_mode=test_mode,
                split_num=cfg.get('split_num', 60),
                **default_args)
        elif data_type == 'sketchy_pair':
            dataset = SketchyPair(
                root=root,
                photo_transforms=build_transforms(trans_list, trans_kwargs),
                sketch_transforms=build_transforms(trans_list, trans_kwargs),
                used_ratio=used_ratio,
                photo_augs=cfg.photo_augs,
                sketch_augs=cfg.sketch_augs,
                test_mode=test_mode,
                **default_args)
        elif data_type == 'sketchy_photo':
            dataset = SketchyPhoto(
                root=root,
                photo_transforms=build_transforms(trans_list, trans_kwargs),
                used_ratio=used_ratio,
                photo_augs=cfg.photo_augs,
                test_mode=test_mode,
                **default_args)
        elif data_type == 'sketchy_sketch':
            dataset = SketchySketch(
                root=root,
                sketch_transforms=build_transforms(trans_list, trans_kwargs),
                used_ratio=used_ratio,
                sketch_augs=cfg.sketch_augs,
                test_mode=test_mode,
                **default_args)
        elif data_type == 'vggsound_pair':
            dataset = VGGSoundPair(
                root=root,
                video_transforms=build_transforms(trans_list, trans_kwargs),
                audio_transforms=build_transforms(trans_list, trans_kwargs),
                test_mode=test_mode,
                **default_args)
        elif data_type == 'vggsound_video':
            dataset = VGGSoundVideo(
                root=root,
                video_transforms=build_transforms(trans_list, trans_kwargs),
                test_mode=test_mode,
                **default_args)
        elif data_type == 'vggsound_audio':
            dataset = VGGSoundAudio(
                root=root,
                audio_transforms=build_transforms(trans_list, trans_kwargs),
                test_mode=test_mode,
                **default_args)
        elif data_type == 'nyu_depth_pair':
            dataset = NYUDepthV2Pair(
                root=root,
                photo_transforms=build_transforms(trans_list, trans_kwargs),
                depth_transforms=build_transforms(trans_list, trans_kwargs),
                test_mode=test_mode,
                **default_args)
        elif data_type == 'nyu_depth_photo':
            dataset = NYUDepthV2Photo(
                root=root,
                photo_transforms=build_transforms(trans_list, trans_kwargs),
                resized_size=resized_size,
                test_mode=test_mode,
                **default_args)
        elif data_type == 'nyu_depth_depth':
            dataset = NYUDepthV2Depth(
                root=root,
                depth_transforms=build_transforms(trans_list, trans_kwargs),
                resized_size=resized_size,
                test_mode=test_mode,
                **default_args)
        elif data_type in ('con_cifar10', 'con_cifar100', 'con_imagenet', 'con_tuberlin'):
            dataset = _contrastive_dataset(data_type, {
                'root': root,
                'transforms': build_transforms(trans_list, trans_kwargs),
                'test_mode': test_mode,
            }, resized_size, default_args)
        elif data_type == 'con_sketchy_pair':
            dataset = _contrastive_dataset(data_type, {
                'root': root,
                'photo_transforms': build_transforms(trans_list, trans_kwargs),
                'sketch_transforms': build_transforms(trans_list, trans_kwargs),
                'used_ratio': used_ratio,
                'photo_augs': cfg.photo_augs,
                'sketch_augs': cfg.sketch_augs,
                'test_mode': test_mode,
            }, resized_size, default_args)
        elif data_type == 'con_sketchy_photo':
            dataset = _contrastive_dataset(data_type, {
                'root': root,
                'photo_transforms': build_transforms(trans_list, trans_kwargs),
                'used_ratio': used_ratio,
                'photo_augs': cfg.photo_augs,
                'test_mode': test_mode,
            }, resized_size, default_args)
        elif data_type == 'con_sketchy_sketch':
            dataset = _contrastive_dataset(data_type, {
                'root': root,
                'sketch_transforms': build_transforms(trans_list, trans_kwargs),
                'used_ratio': used_ratio,
                'sketch_augs': cfg.sketch_augs,
                'test_mode': test_mode,
            }, resized_size, default_args)
        elif data_type == 'con_vggsound_pair':
            dataset = _contrastive_dataset(data_type, {
                'root': root,
                'video_transforms': build_transforms(trans_list, trans_kwargs),
                'audio_transforms': build_transforms(trans_list, trans_kwargs),
                'test_mode': test_mode,
            }, resized_size, default_args)
        elif data_type == 'con_vggsound_video':
            dataset = _contrastive_dataset(data_type, {
                'root': root,
                'video_transforms': build_transforms(trans_list, trans_kwargs),
                'test_mode': test_mode,
            }, resized_size, default_args)
        elif data_type == 'con_vggsound_audio':
            dataset = _contrastive_dataset(data_type, {
                'root': root,
                'audio_transforms': build_transforms(trans_list, trans_kwargs),
                'test_mode': test_mode,
            }, resized_size, default_args)
        elif data_type == 'con_nyu_depth_pair':
            dataset = _contrastive_dataset(data_type, {
                'root': root,
                'photo_transforms': build_transforms(trans_list, trans_kwargs),
                'depth_transforms': build_transforms(trans_list, trans_kwargs),
                'test_mode': test_mode,
            }, resized_size, default_args)
        elif data_type == 'con_nyu_depth_photo':
            dataset = _contrastive_dataset(data_type, {
                'root': root,
                'photo_transforms': build_transforms(trans_list, trans_kwargs),
                'resized_size': resized_size,
                'test_mode': test_mode,
            }, resized_size, default_args)
        elif data_type == 'con_nyu_depth_depth':
            dataset = _contrastive_dataset(data_type, {
                'root': root,
                'depth_transforms': build_transforms(trans_list, trans_kwargs),
                'resized_size': resized_size,
                'test_mode': test_mode,
            }, resized_size, default_args)
        else:
            raise ValueError(f'Unsupported type {data_type} of dataset.')
        datasets[data_type] = dataset if isinstance(dataset, list) else [dataset]
    return datasets

def build_dataloaders(datasets,
                      samples_per_gpu,
                      workers_per_gpu,
                      num_gpus=1,
                      dist=True,
                      shuffle=True,
                      drop_last=False,
                      seed=None,
                      pin_memory=True,
                      persistent_workers=True,
                      sampler_cfg=None,
                      **kwargs):
    """Build one DataLoader per dataset with shared loader options.

    Args:
        datasets (Dataset | list): A single dataset or a list of datasets.
        samples_per_gpu (int | list[int]): Batch size per GPU. A scalar is
            applied to every dataset; a list gives one value per dataset.
        workers_per_gpu (int | list[int]): Loader workers per GPU. A scalar is
            applied to every dataset; a list gives one value per dataset.
        The remaining arguments (and ``kwargs``) are forwarded unchanged to
        ``build_dataloader`` for every dataset.

    Returns:
        list: One DataLoader per dataset, in input order.

    Raises:
        AssertionError: If per-dataset lists do not match the number of
            datasets.
    """
    if not isinstance(datasets, list):
        datasets = [datasets]
    num_datasets = len(datasets)
    # Broadcast scalars independently so any scalar/list mix works.
    if not isinstance(samples_per_gpu, list):
        samples_per_gpu = [samples_per_gpu] * num_datasets
    if not isinstance(workers_per_gpu, list):
        workers_per_gpu = [workers_per_gpu] * num_datasets
    assert len(samples_per_gpu) == len(workers_per_gpu) == num_datasets, \
        'samples_per_gpu and workers_per_gpu must match the number of datasets'

    return [
        build_dataloader(dataset,
                         samples_per_gpu_,
                         workers_per_gpu_,
                         num_gpus,
                         dist,
                         shuffle,
                         drop_last,
                         seed,
                         pin_memory,
                         persistent_workers,
                         sampler_cfg,
                         **kwargs)
        for dataset, samples_per_gpu_, workers_per_gpu_ in zip(
            datasets, samples_per_gpu, workers_per_gpu)
    ]

        
def build_dataloader(dataset,
                     samples_per_gpu,
                     workers_per_gpu,
                     num_gpus=1,
                     dist=True,
                     shuffle=True,
                     drop_last=False,
                     seed=None,
                     pin_memory=True,
                     persistent_workers=True,
                     sampler_cfg=None,
                     **kwargs):
    """Build PyTorch DataLoader.

    In distributed training, each GPU/process has a dataloader.
    In non-distributed training, there is only one dataloader for all GPUs.

    Args:
        dataset (Dataset): A PyTorch dataset.
        samples_per_gpu (int): Number of training samples on each GPU, i.e.,
            batch size of each GPU.
        workers_per_gpu (int): How many subprocesses to use for data loading
            for each GPU.
        num_gpus (int): Number of GPUs. Only used in non-distributed training.
        dist (bool): Distributed training/test or not. Default: True.
        shuffle (bool): Whether to shuffle the data at every epoch.
            Default: True.
        drop_last (bool): Forwarded only to the default DistributedSampler
            (used when ``dist`` is True and no ``sampler_cfg`` is given);
            not passed to the DataLoader itself. Default: False.
        seed (int, optional): Base seed. When given, every worker is seeded
            via ``worker_init_fn``. Default: None.
        pin_memory (bool): Whether to use pin_memory in DataLoader.
            Default: True.
        persistent_workers (bool): If True, the data loader will not shut
            down the worker processes after a dataset has been consumed once,
            keeping the workers' Dataset instances alive. Only applied when
            the loader has worker processes (num_workers > 0); requires
            PyTorch>=1.7.0. Default: True.
        sampler_cfg (dict, optional): Sampler configuration that overrides
            the default sampler. It is not modified; a copy is used.
        kwargs: Any keyword argument to be used to initialize DataLoader.

    Returns:
        DataLoader: A PyTorch dataloader.
    """
    rank, world_size = get_dist_info()

    if sampler_cfg:
        # Copy so the caller's config is not mutated; build_dataloaders passes
        # the same sampler_cfg for every dataset.
        sampler_cfg = copy.copy(sampler_cfg)
        # shuffle=False when val and test
        sampler_cfg.update(shuffle=shuffle)
        sampler = build_sampler(
            sampler_cfg,
            default_args=dict(
                dataset=dataset, num_replicas=world_size, rank=rank,
                seed=seed))
    elif dist:
        sampler = build_sampler(
            dict(type='DistributedSampler'),
            dict(dataset=dataset,
                 num_replicas=world_size,
                 rank=rank,
                 shuffle=shuffle,
                 drop_last=drop_last,
                 seed=seed))
    else:
        sampler = None

    # DataLoader rejects shuffle=True together with an explicit sampler;
    # ordering is then the sampler's responsibility.
    if sampler is not None:
        shuffle = False

    if dist:
        batch_size = samples_per_gpu
        num_workers = workers_per_gpu
    else:
        batch_size = num_gpus * samples_per_gpu
        num_workers = num_gpus * workers_per_gpu

    init_fn = partial(
        worker_init_fn, num_workers=num_workers, rank=rank,
        seed=seed) if seed is not None else None

    # DataLoader raises ValueError for persistent_workers=True with
    # num_workers == 0, so only forward it when worker processes exist.
    if persistent_workers and num_workers > 0:
        kwargs['persistent_workers'] = True

    return DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        num_workers=num_workers,
        pin_memory=pin_memory,
        shuffle=shuffle,
        worker_init_fn=init_fn,
        **kwargs)

