"""
Dataset setup and loaders
"""
from datasets import cityscapes
from datasets import mapillary
from datasets import synthia
from datasets import kitti
from datasets import camvid
from datasets import bdd100k
from datasets import gtav
from datasets import idd
from datasets import nullloader

from datasets import multi_loader
from datasets.sampler import DistributedSampler

import torchvision.transforms as standard_transforms

import transforms.joint_transforms as joint_transforms
import transforms.transforms as extended_transforms
from torch.utils.data import DataLoader, ConcatDataset
import torch


# Module-level defaults. The functions in this file never read these two names;
# they use the per-dataset ``dataset.num_classes`` / ``dataset.ignore_label``.
# NOTE(review): presumably imported elsewhere (e.g. ``datasets.num_classes``), and
# 19 presumably matches a Cityscapes-style 19-class label set -- verify.
num_classes = 19
# NOTE(review): presumably the mask value for pixels excluded from loss/metrics -- verify.
ignore_label = 255


def get_train_joint_transform(args, dataset):
    """
    Build the geometric augmentations applied jointly to image and mask.

    Args:
        args: input config arguments (reads crop_size, crop_nopad, pre_size,
            scale_min, scale_max, rrotate)
        dataset: dataset class object; its ``ignore_label`` is passed as
            ``ignore_index`` to the crop and (optional) rotate transforms

    return: (train_joint_transform_list, train_joint_transform) -- the raw list,
        which the class-uniform datasets consume directly, and its Compose wrapper
    """
    joint_ops = [
        joint_transforms.RandomSizeAndCrop(args.crop_size,
                                           crop_nopad=args.crop_nopad,
                                           pre_size=args.pre_size,
                                           scale_min=args.scale_min,
                                           scale_max=args.scale_max,
                                           ignore_index=dataset.ignore_label),
        joint_transforms.Resize(args.crop_size),
        joint_transforms.RandomHorizontallyFlip(),
    ]

    # Random rotation is opt-in: only added for a positive angle.
    if args.rrotate > 0:
        joint_ops.append(joint_transforms.RandomRotate(degree=args.rrotate,
                                                       ignore_index=dataset.ignore_label))

    composed = joint_transforms.Compose(joint_ops)
    return joint_ops, composed


def get_input_transforms(args, dataset):
    """
    Build the image-only (appearance) transforms for training and validation.

    Args:
        args: input config arguments (reads color_aug, bblur, gblur); only
            whether color_aug is positive matters, not its magnitude
        dataset: dataset class object (not used by this function)

    return: train_input_transform, val_input_transform -- both Compose objects
        ending in ToTensor; validation does nothing but ToTensor
    """
    train_ops = []

    # Colour jitter, applied with probability 0.5 when enabled.
    if args.color_aug > 0.0:
        jitter = standard_transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)
        train_ops.append(standard_transforms.RandomApply([jitter], p=0.5))

    # At most one blur: bilateral takes precedence over Gaussian.
    if args.bblur:
        train_ops.append(extended_transforms.RandomBilateralBlur())
    elif args.gblur:
        train_ops.append(extended_transforms.RandomGaussianBlur())

    train_ops.append(standard_transforms.ToTensor())
    val_ops = [standard_transforms.ToTensor()]

    train_pipeline = standard_transforms.Compose(train_ops)
    val_pipeline = standard_transforms.Compose(val_ops)
    return train_pipeline, val_pipeline

def get_color_geometric_transforms():
    """
    Build the strong ("hard") RobustNet augmentations used for covariance statistics.

    Takes no arguments.

    return: (color_input_transform, geometric_input_transform)
        color_input_transform: ColorJitter(0.8, 0.8, 0.8, 0.3) -> RandomGaussianBlur -> ToTensor
        geometric_input_transform: RandomHorizontalFlip(p=1.0), i.e. always flip -> ToTensor
    """
    color_ops = [
        standard_transforms.ColorJitter(0.8, 0.8, 0.8, 0.3),
        extended_transforms.RandomGaussianBlur(),
    ]
    geometric_ops = [standard_transforms.RandomHorizontalFlip(p=1.0)]

    # Both pipelines end by converting the PIL image to a tensor.
    color_ops.append(standard_transforms.ToTensor())
    geometric_ops.append(standard_transforms.ToTensor())

    color_pipeline = standard_transforms.Compose(color_ops)
    geometric_pipeline = standard_transforms.Compose(geometric_ops)
    return color_pipeline, geometric_pipeline

def get_meta_transforms():
    """
    Build the strong photometric transform passed as ``meta_transform`` to training sets.

    return: Compose(ColorJitter(0.8, 0.8, 0.8, 0.3), RandomGaussianBlur(), ToTensor())
    """
    steps = [
        standard_transforms.ColorJitter(0.8, 0.8, 0.8, 0.3),
        extended_transforms.RandomGaussianBlur(),
        standard_transforms.ToTensor(),
    ]
    return standard_transforms.Compose(steps)

def get_target_transforms(args, dataset):
    """
    Build the mask-only transforms.

    Args:
        args: input config arguments (reads jointwtborder)
        dataset: dataset class object; its ``ignore_label`` and ``num_classes``
            are read only when ``args.jointwtborder`` is set

    return: target_transform, target_train_transform, target_aux_train_transform
        -- all MaskToTensor, except that the training transform becomes
        RelaxedBoundaryLossToTensor when jointwtborder is enabled
    """
    target_transform = extended_transforms.MaskToTensor()
    target_train_transform = (
        extended_transforms.RelaxedBoundaryLossToTensor(dataset.ignore_label,
                                                        dataset.num_classes)
        if args.jointwtborder
        else extended_transforms.MaskToTensor()
    )
    target_aux_train_transform = extended_transforms.MaskToTensor()
    return target_transform, target_train_transform, target_aux_train_transform


def create_extra_val_loader(args, dataset, val_input_transform, target_transform, val_sampler):
    """
    Create a validation loader for a dataset that is evaluated but not trained on.

    Args:
        args: input config arguments (reads cv, image_in, crop_size, syncbn,
            val_batch_size, num_workers)
        dataset: dataset name (str): 'cityscapes', 'idd', 'bdd100k', 'gtav',
            'synthia', 'mapillary' or 'null_loader'
        val_input_transform: validation input transforms
        target_transform: target (mask) transforms
        val_sampler: ignored -- the sampler is always rebuilt from args.syncbn;
            the parameter is kept only for call-site compatibility

    return: DataLoader over the dataset's 'val' split

    Raises:
        Exception: if ``dataset`` is not a supported name
    """
    # Keyword arguments shared by all split-based dataset classes below.
    split_kwargs = dict(transform=val_input_transform,
                        target_transform=target_transform,
                        cv_split=args.cv,
                        image_in=args.image_in)

    if dataset == 'cityscapes':
        val_set = cityscapes.CityScapes('fine', 'val', 0, **split_kwargs)
    elif dataset == 'idd':
        val_set = idd.Idd('val', 0, **split_kwargs)
    elif dataset == 'bdd100k':
        val_set = bdd100k.BDD100K('val', 0, **split_kwargs)
    elif dataset == 'gtav':
        val_set = gtav.GTAV('val', 0, **split_kwargs)
    elif dataset == 'synthia':
        val_set = synthia.Synthia('val', 0, **split_kwargs)
    elif dataset == 'mapillary':
        eval_size = 1536
        val_joint_transform_list = [
            joint_transforms.ResizeHeight(eval_size),
            joint_transforms.CenterCropPad(eval_size)]
        # Forward image_in exactly as setup_loaders does for its Mapillary val
        # set, so both code paths build the dataset the same way.
        val_set = mapillary.Mapillary('semantic', 'val',
                                      joint_transform_list=val_joint_transform_list,
                                      transform=val_input_transform,
                                      target_transform=target_transform,
                                      image_in=args.image_in,
                                      test=False)
    elif dataset == 'null_loader':
        val_set = nullloader.nullloader(args.crop_size)
    else:
        raise Exception('Dataset {} is not supported'.format(dataset))

    # Under syncbn use the module-level DistributedSampler (no padding, no
    # permutation); otherwise no sampler, and shuffle=False keeps dataset order.
    if args.syncbn:
        sampler = DistributedSampler(val_set, pad=False, permutation=False, consecutive_sample=False)
    else:
        sampler = None

    return DataLoader(val_set, batch_size=args.val_batch_size,
                      num_workers=args.num_workers // 2, shuffle=False, drop_last=False,
                      sampler=sampler)

def create_covstat_val_loader(args, dataset, val_input_transform, target_transform, val_sampler):
    """
    Create the loader used to collect covariance statistics.

    Despite the name, this reads the dataset's *train* split through the
    ``*Aug`` dataset classes. Those classes additionally receive the strong
    colour and geometric transforms from ``get_color_geometric_transforms()``.

    Args:
        args: input config arguments (reads cv, image_in, crop_size, num_workers)
        dataset: dataset name (str)
        val_input_transform: input transforms
        target_transform: target (mask) transforms
        val_sampler: ignored -- no sampler is used; kept for call-site compatibility

    return: DataLoader with batch_size=1 and shuffle=True

    Raises:
        NotImplementedError: for 'mapillary', which has no augmented loader here
        Exception: if ``dataset`` is not a supported name
    """
    color_transform, geometric_transform = get_color_geometric_transforms()

    # Keyword arguments shared by all *Aug dataset classes below.
    aug_kwargs = dict(transform=val_input_transform,
                      color_transform=color_transform,
                      geometric_transform=geometric_transform,
                      target_transform=target_transform,
                      cv_split=args.cv,
                      image_in=args.image_in)

    if dataset == 'cityscapes':
        val_set = cityscapes.CityScapesAug('fine', 'train', 0, **aug_kwargs)
    elif dataset == 'idd':
        val_set = idd.IddAug('train', 0, **aug_kwargs)
    elif dataset == 'bdd100k':
        val_set = bdd100k.BDD100KAug('train', 0, **aug_kwargs)
    elif dataset == 'gtav':
        val_set = gtav.GTAVAug('train', 0, **aug_kwargs)
    elif dataset == 'synthia':
        val_set = synthia.SynthiaAug('train', 0, **aug_kwargs)
    elif dataset == 'mapillary':
        # Raise instead of calling exit(): library code should not terminate
        # the interpreter, and the caller gets a catchable, descriptive error.
        raise NotImplementedError(
            'Dataset {} is not supported for covariance statistics'.format(dataset))
    elif dataset == 'null_loader':
        val_set = nullloader.nullloader(args.crop_size)
    else:
        raise Exception('Dataset {} is not supported'.format(dataset))

    # No DistributedSampler even under syncbn: each process iterates the
    # whole set, shuffled, one image at a time.
    return DataLoader(val_set, batch_size=1,
                      num_workers=args.num_workers // 2, shuffle=True, drop_last=False,
                      sampler=None)

def _parse_coarse_boost_classes(args):
    """Parse ``args.coarse_boost_classes`` (e.g. '3,7,12') into a list of ints; None when unset/empty."""
    if not args.coarse_boost_classes:
        return None
    return [int(c) for c in args.coarse_boost_classes.split(',')]


def setup_loaders(args):
    """
    Set up the training and validation data loaders.

    Supported names in ``args.dataset``: cityscapes, idd, bdd100k, gtav,
    synthia, mapillary, null_loader.

    Side effects: sets ``args.train_batch_size``, ``args.val_batch_size`` and
    ``args.num_workers``.

    Args:
        args: input config arguments

    return: (train_loader, val_loaders, train_set, extra_val_loader, covstat_val_loader)
        train_loader: DataLoader over all training sets combined
        val_loaders: dict name -> DataLoader over each training dataset's val split
        train_set: the combined training dataset
        extra_val_loader: dict name -> DataLoader for each entry of args.val_dataset
        covstat_val_loader: dict name -> DataLoader for each entry of args.covstat_val_dataset

    Raises:
        Exception: if no name, or not every name, in ``args.dataset`` is recognised
    """
    # A non-positive bs_mult_val means "reuse bs_mult". Apply this on the syncbn
    # path too; otherwise bs_mult_val <= 0 would yield an invalid batch size.
    bs_mult_val = args.bs_mult_val if args.bs_mult_val > 0 else args.bs_mult
    if args.syncbn:
        # Readjust batch size to the per-process mini-batch size for syncbn.
        args.train_batch_size = args.bs_mult
        args.val_batch_size = bs_mult_val
    else:
        args.train_batch_size = args.bs_mult * args.ngpu
        args.val_batch_size = bs_mult_val * args.ngpu

    args.num_workers = 1 if args.test_mode else 8

    train_sets = []
    val_sets = []
    val_dataset_names = []

    for dataset_name in args.dataset:
        if dataset_name == 'cityscapes':
            dataset = cityscapes
            city_mode = args.city_mode  # e.g. 'train'; can be 'trainval'
            city_quality = 'fine'
            train_joint_transform_list, train_joint_transform = get_train_joint_transform(args, dataset)
            train_input_transform, val_input_transform = get_input_transforms(args, dataset)
            target_transform, target_train_transform, target_aux_train_transform = get_target_transforms(args, dataset)
            meta_transform = get_meta_transforms()

            if args.class_uniform_pct:
                train_set = dataset.CityScapesUniform(
                    city_quality, city_mode, args.maxSkip,
                    joint_transform_list=train_joint_transform_list,
                    transform=train_input_transform,
                    target_transform=target_train_transform,
                    target_aux_transform=target_aux_train_transform,
                    dump_images=args.dump_augmentation_images,
                    cv_split=args.cv,
                    class_uniform_pct=args.class_uniform_pct,
                    class_uniform_tile=1024,
                    test=args.test_mode,
                    coarse_boost_classes=_parse_coarse_boost_classes(args),
                    image_in=args.image_in)
            else:
                train_set = dataset.CityScapes(
                    city_quality, city_mode, 0,
                    joint_transform=train_joint_transform,
                    transform=train_input_transform,
                    target_transform=target_train_transform,
                    target_aux_transform=target_aux_train_transform,
                    dump_images=args.dump_augmentation_images,
                    image_in=args.image_in,
                    meta_transform=meta_transform)

            val_set = dataset.CityScapes('fine', 'val', 0,
                                         transform=val_input_transform,
                                         target_transform=target_transform,
                                         cv_split=args.cv,
                                         image_in=args.image_in)

        elif dataset_name == 'idd':
            dataset = idd
            city_mode = args.city_mode  # e.g. 'train'; can be 'trainval'
            train_joint_transform_list, train_joint_transform = get_train_joint_transform(args, dataset)
            train_input_transform, val_input_transform = get_input_transforms(args, dataset)
            target_transform, target_train_transform, target_aux_train_transform = get_target_transforms(args, dataset)
            meta_transform = get_meta_transforms()

            if args.class_uniform_pct:
                train_set = dataset.IddUniform(
                    city_mode, args.maxSkip,
                    joint_transform_list=train_joint_transform_list,
                    transform=train_input_transform,
                    target_transform=target_train_transform,
                    target_aux_transform=target_aux_train_transform,
                    dump_images=args.dump_augmentation_images,
                    cv_split=args.cv,
                    class_uniform_pct=args.class_uniform_pct,
                    class_uniform_tile=1024,
                    test=args.test_mode,
                    image_in=args.image_in)
            else:
                train_set = dataset.Idd(
                    city_mode, 0,
                    joint_transform=train_joint_transform,
                    transform=train_input_transform,
                    target_transform=target_train_transform,
                    target_aux_transform=target_aux_train_transform,
                    dump_images=args.dump_augmentation_images,
                    image_in=args.image_in,
                    meta_transform=meta_transform)

            val_set = dataset.Idd('val', 0,
                                  transform=val_input_transform,
                                  target_transform=target_transform,
                                  cv_split=args.cv,
                                  image_in=args.image_in)

        elif dataset_name == 'bdd100k':
            dataset = bdd100k
            bdd_mode = 'train'  # can be 'trainval'
            train_joint_transform_list, train_joint_transform = get_train_joint_transform(args, dataset)
            train_input_transform, val_input_transform = get_input_transforms(args, dataset)
            target_transform, target_train_transform, target_aux_train_transform = get_target_transforms(args, dataset)
            meta_transform = get_meta_transforms()

            if args.class_uniform_pct:
                train_set = dataset.BDD100KUniform(
                    bdd_mode, args.maxSkip,
                    joint_transform_list=train_joint_transform_list,
                    transform=train_input_transform,
                    target_transform=target_train_transform,
                    target_aux_transform=target_aux_train_transform,
                    dump_images=args.dump_augmentation_images,
                    cv_split=args.cv,
                    class_uniform_pct=args.class_uniform_pct,
                    class_uniform_tile=1024,
                    test=args.test_mode,
                    coarse_boost_classes=_parse_coarse_boost_classes(args),
                    image_in=args.image_in)
            else:
                train_set = dataset.BDD100K(
                    bdd_mode, 0,
                    joint_transform=train_joint_transform,
                    transform=train_input_transform,
                    target_transform=target_train_transform,
                    target_aux_transform=target_aux_train_transform,
                    dump_images=args.dump_augmentation_images,
                    cv_split=args.cv,
                    image_in=args.image_in,
                    meta_transform=meta_transform)

            val_set = dataset.BDD100K('val', 0,
                                      transform=val_input_transform,
                                      target_transform=target_transform,
                                      cv_split=args.cv,
                                      image_in=args.image_in)

        elif dataset_name == 'gtav':
            dataset = gtav
            gtav_mode = 'train'  # can be 'trainval'
            train_joint_transform_list, train_joint_transform = get_train_joint_transform(args, dataset)
            train_input_transform, val_input_transform = get_input_transforms(args, dataset)
            target_transform, target_train_transform, target_aux_train_transform = get_target_transforms(args, dataset)
            meta_transform = get_meta_transforms()

            if args.class_uniform_pct:
                train_set = dataset.GTAVUniform(
                    gtav_mode, args.maxSkip,
                    joint_transform_list=train_joint_transform_list,
                    transform=train_input_transform,
                    target_transform=target_train_transform,
                    target_aux_transform=target_aux_train_transform,
                    dump_images=args.dump_augmentation_images,
                    cv_split=args.cv,
                    class_uniform_pct=args.class_uniform_pct,
                    class_uniform_tile=1024,
                    test=args.test_mode,
                    coarse_boost_classes=_parse_coarse_boost_classes(args),
                    image_in=args.image_in)
            else:
                train_set = dataset.GTAV(
                    gtav_mode, 0,
                    joint_transform=train_joint_transform,
                    transform=train_input_transform,
                    target_transform=target_train_transform,
                    target_aux_transform=target_aux_train_transform,
                    dump_images=args.dump_augmentation_images,
                    cv_split=args.cv,
                    image_in=args.image_in,
                    meta_transform=meta_transform)

            val_set = dataset.GTAV('val', 0,
                                   transform=val_input_transform,
                                   target_transform=target_transform,
                                   cv_split=args.cv,
                                   image_in=args.image_in)

        elif dataset_name == 'synthia':
            dataset = synthia
            synthia_mode = 'train'  # can be 'trainval'
            train_joint_transform_list, train_joint_transform = get_train_joint_transform(args, dataset)
            train_input_transform, val_input_transform = get_input_transforms(args, dataset)
            target_transform, target_train_transform, target_aux_train_transform = get_target_transforms(args, dataset)
            meta_transform = get_meta_transforms()

            if args.class_uniform_pct:
                train_set = dataset.SynthiaUniform(
                    synthia_mode, args.maxSkip,
                    joint_transform_list=train_joint_transform_list,
                    transform=train_input_transform,
                    target_transform=target_train_transform,
                    target_aux_transform=target_aux_train_transform,
                    dump_images=args.dump_augmentation_images,
                    cv_split=args.cv,
                    class_uniform_pct=args.class_uniform_pct,
                    class_uniform_tile=760,
                    test=args.test_mode,
                    coarse_boost_classes=_parse_coarse_boost_classes(args),
                    image_in=args.image_in)
            else:
                train_set = dataset.Synthia(
                    synthia_mode, 0,
                    joint_transform=train_joint_transform,
                    transform=train_input_transform,
                    target_transform=target_train_transform,
                    target_aux_transform=target_aux_train_transform,
                    dump_images=args.dump_augmentation_images,
                    cv_split=args.cv,
                    image_in=args.image_in,
                    meta_transform=meta_transform)

            val_set = dataset.Synthia('val', 0,
                                      transform=val_input_transform,
                                      target_transform=target_transform,
                                      cv_split=args.cv,
                                      image_in=args.image_in)

        elif dataset_name == 'mapillary':
            dataset = mapillary
            train_joint_transform_list, train_joint_transform = get_train_joint_transform(args, dataset)
            train_input_transform, val_input_transform = get_input_transforms(args, dataset)
            target_transform, target_train_transform, target_aux_train_transform = get_target_transforms(args, dataset)
            meta_transform = get_meta_transforms()

            eval_size = 1536
            val_joint_transform_list = [
                joint_transforms.ResizeHeight(eval_size),
                joint_transforms.CenterCropPad(eval_size)]

            train_set = dataset.Mapillary(
                'semantic', 'train',
                joint_transform_list=train_joint_transform_list,
                transform=train_input_transform,
                target_transform=target_train_transform,
                target_aux_transform=target_aux_train_transform,
                image_in=args.image_in,
                dump_images=args.dump_augmentation_images,
                class_uniform_pct=args.class_uniform_pct,
                class_uniform_tile=1024,
                test=args.test_mode,
                meta_transform=meta_transform)
            val_set = dataset.Mapillary(
                'semantic', 'val',
                joint_transform_list=val_joint_transform_list,
                transform=val_input_transform,
                target_transform=target_transform,
                image_in=args.image_in,
                test=False)

        elif dataset_name == 'null_loader':
            train_set = nullloader.nullloader(args.crop_size)
            val_set = nullloader.nullloader(args.crop_size)

        else:
            # Unrecognised names are skipped here and reported after the loop.
            continue

        train_sets.append(train_set)
        val_sets.append(val_set)
        val_dataset_names.append(dataset_name)

    if not train_sets:
        raise Exception('Dataset {} is not supported'.format(args.dataset))

    if len(train_sets) != len(args.dataset):
        raise Exception('Something went wrong. Please check your dataset names are valid')

    # Combine all training sets into one dataset.
    if args.image_uniform_sampling:
        train_set = ConcatDataset(train_sets)
    else:
        train_set = multi_loader.DomainUniformConcatDataset(args, train_sets)

    # One val loader per training dataset, keyed by dataset name.
    val_loaders = {}
    for name, val_set in zip(val_dataset_names, val_sets):
        if args.syncbn:
            val_sampler = DistributedSampler(val_set, pad=False, permutation=False, consecutive_sample=False)
        else:
            val_sampler = None
        val_loaders[name] = DataLoader(val_set, batch_size=args.val_batch_size,
                                       num_workers=args.num_workers // 2, shuffle=False, drop_last=False,
                                       sampler=val_sampler)

    if args.syncbn:
        train_sampler = DistributedSampler(train_set, pad=True, permutation=True, consecutive_sample=False)
    else:
        train_sampler = None

    # DataLoader forbids shuffle=True together with an explicit sampler.
    train_loader = DataLoader(train_set, batch_size=args.train_batch_size,
                              num_workers=args.num_workers, shuffle=(train_sampler is None),
                              drop_last=True, sampler=train_sampler)

    # The loops below receive dataset *names* (strings), not dataset modules.
    # get_target_transforms(args, name) would read name.ignore_label, which raises
    # AttributeError, whenever args.jointwtborder is set. Only its first result
    # (a plain MaskToTensor) was ever used, so build that directly. Both helper
    # loaders ignore the sampler argument, so pass None explicitly instead of
    # a sampler leaked from the loop above.
    extra_val_loader = {}
    for val_dataset in args.val_dataset:
        _, val_input_transform = get_input_transforms(args, val_dataset)
        target_transform = extended_transforms.MaskToTensor()
        extra_val_loader[val_dataset] = create_extra_val_loader(
            args, val_dataset, val_input_transform, target_transform, None)

    covstat_val_loader = {}
    for val_dataset in args.covstat_val_dataset:
        _, val_input_transform = get_input_transforms(args, val_dataset)
        target_transform = extended_transforms.MaskToTensor()
        covstat_val_loader[val_dataset] = create_covstat_val_loader(
            args, val_dataset, val_input_transform, target_transform, None)

    return train_loader, val_loaders, train_set, extra_val_loader, covstat_val_loader

