import os
import argparse
import logging
import time
from pathlib import Path
import PIL.Image as Image

import numpy as np
import torch
import torchvision.transforms as T
import torchvision.transforms.functional as TF
import json
from collections import namedtuple
import zipfile
from typing import Any, Callable, Dict, List, Optional, Union, Tuple

from torchvision.datasets.utils import extract_archive, verify_str_arg, iterable_to_str
from torchvision.datasets.vision import VisionDataset
from torchvision.datasets.cityscapes import Cityscapes
import training_config as config
from cityscapes_helper import id2trainid
        
import torch
from torch.nn.parallel.data_parallel import DataParallel
from torch.nn.parallel.parallel_apply import parallel_apply
from torch.nn.parallel._functions import Scatter


def do_args(jupyter: bool = False) -> argparse.Namespace:
    """Build the command-line parser and return the parsed experiment options.

    Almost every option takes its default from the ``training_config`` module
    (imported here as ``config``), so running without arguments reproduces the
    configuration file.

    Args:
        jupyter: When True, an empty argument list is parsed instead of
            ``sys.argv`` so that only the defaults are used.

    Returns:
        The parsed namespace with these post-processing steps applied:

        * ``name`` is set to ``<timestamp>_<prefix>e<epochs>s<steps>b<batch><suffix>``,
          or to ``resume_name`` when ``resume`` is set.
        * When ``cityscapes`` is set, ``hr_nclasses`` is forced to 20 and
          ``input_nchannels`` to 3.

    Note:
        The boolean options use ``action="store_true"`` with a default taken
        from ``config``; if that default is already True the flag cannot be
        turned off from the command line.
    """
    parser = argparse.ArgumentParser(
        description="Wrapper utility for training and testing land cover models.",
    )
    parser.add_argument(
        "-v",
        "--verbose",
        type=int,
        help="Verbosity of keras.fit",
        default=config.VERBOSE,
    )
    # --- experiment naming and directories ---
    # parser.add_argument("--name", type=str, help="Experiment name", required=True)
    parser.add_argument("--name-prefix", type=str,
                        help="Experiment name prefix", default=config.NAME_PREFIX)
    parser.add_argument("--name-suffix", type=str,
                        help="Experiment name suffix", default=config.NAME_SUFFIX)
    parser.add_argument(
        "--data-dir",
        type=str,
        help="Path to data directory containing the splits CSV files",
        default=config.DATA_DIR,
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        help="Output base directory",
        default=config.OUTPUT_DIR,
    )
    parser.add_argument(
        "--log-dir",
        type=str,
        help="Log directory for Tensorboard and console records",
        default=config.LOG_DIR,
    )
    # --- dataset splits ---
    parser.add_argument(
        "--training-states",
        nargs="+",
        type=str,
        help="States to use as training",
        default=config.TRAINING_STATES,
    )
    parser.add_argument(
        "--validation-states",
        nargs="+",
        type=str,
        help="States to use as validation",
        default=config.VALIDATION_STATES,
    )
    parser.add_argument(
        "--superres-states",
        nargs="+",
        type=str,
        help="States to use only superres loss with",
        default=config.SUPERRES_STATES,
    )
    # --- augmentation ---
    parser.add_argument(
        "--do-color",
        action="store_true",
        help="Enable color augmentation",
        default=config.DO_COLOR,
    )
    parser.add_argument(
        "--color-augmentation-intensity", type=float, help="Color augmentation intensity",
        default=config.COLOR_AUGMENTATION_INTENSITY
    )
    parser.add_argument(
        "--do-label-overloading",
        action="store_true",
        help="Enable label overloading",
        default=config.DO_LABEL_OVERLOADING,
    )
    # --- model and optimisation ---
    parser.add_argument(
        "--net-model",
        type=str,
        default=config.NET_MODEL,
        help="Model architecture to use"
    )
    parser.add_argument(
        "--att-hidden", type=int, help="Dimension of attention hidden", default=config.ATT_HIDDEN
    )
    parser.add_argument(
        "--epochs", type=int, help="Number of epochs", default=config.EPOCHS
    )
    parser.add_argument(
        "--loss",
        type=str,
        help="Loss function",
        default=config.LOSS,
        choices=["crossentropy", "superres"],
    )
    parser.add_argument(
        "--learning-rate",
        type=float,
        help="Learning rate",
        default=config.LEARNING_RATE,
    )
    parser.add_argument(
        "--batch-size", type=int, help="Batch size", default=config.BATCH_SIZE
    )
    parser.add_argument(
        "--step-size", type=int, help="Learning rate scheduler step size", default=config.STEP_SIZE
    )
    parser.add_argument(
        "--shuffle", action="store_true", help="Dataloader shuffle", default=config.SHUFFLE
    )
    parser.add_argument(
        "--crossentropy-loss-weight", type=int, help="Weight of crossentropy loss if using superres",
        default=config.CROSSENTROPY_LOSS_WEIGHT
    )
    parser.add_argument(
        "--superres-loss-weight", type=int, help="Weight of superres loss if using superres",
        default=config.SUPERRES_LOSS_WEIGHT
    )
    parser.add_argument(
        "--mode", type=str, help="Used for remote console debug", default='Nothing'
    )
    # --- input imagery and label layout ---
    parser.add_argument(
        "--input-size", type=int, help="Input image size in pixels", default=config.INPUT_SIZE
    )
    parser.add_argument(
        "--input-nchannels", type=int, help="Number of input image channels", default=config.INPUT_NCHANNELS
    )
    parser.add_argument(
        "--data-type", type=str, help="Data type of imagery", default=config.DATA_TYPE
    )
    parser.add_argument(
        "--hr-label-key", type=str, help="Keys for transformation of hr labels", default=config.HR_LABEL_KEY
    )
    parser.add_argument(
        "--hr-label-index", type=int, help="Positional index of hr labels in patches", default=config.HR_LABEL_INDEX
    )
    parser.add_argument(
        "--hr-nclasses", type=int, help="Number of target hr classes", default=config.HR_NCLASSES
    )
    parser.add_argument(
        "--lr-label-key", type=str, help="Keys for transformation of lr labels", default=config.LR_LABEL_KEY
    )
    parser.add_argument(
        "--lr-label-index", type=int, help="Positional index of lr labels in patches", default=config.LR_LABEL_INDEX
    )
    parser.add_argument(
        "--lr-nclasses", type=int, help="Number of target lr classes", default=config.LR_NCLASSES
    )
    # --- training schedule, checkpoints and resuming ---
    parser.add_argument(
        "--training-steps-per-epoch", type=int, help="Number of training steps or batches per epoch",
        default=config.TRAINING_STEPS_PER_EPOCH
    )
    parser.add_argument(
        "--validation-steps-per-epoch", type=int, help="Number of validation steps or batches per epoch",
        default=config.VALIDATION_STEPS_PER_EPOCH
    )
    parser.add_argument(
        "--model-save-checkpoint", type=int, help="Save model at each checkpoint",
        default=config.MODEL_SAVE_CHECKPOINT
    )
    parser.add_argument(
        "--seed", type=int, help="Numpy random seed. Used to ensure we get the same validation batches",
        default=config.SEED
    )
    parser.add_argument(
        "--resume",
        action="store_true",
        help="Resume a previous training",
        default=config.RESUME,
    )
    parser.add_argument(
        "--resume-name", type=str, help="Experiment name to resume", default=config.RESUME_NAME
    )
    parser.add_argument(
        "--resume-model", type=str, help="Load a model to resume training", default=config.RESUME_MODEL
    )
    # parser.add_argument(
    #     "--freeze-att-until", type=int, help="Freeze the attention weights until reaching the indicated epoch",
    #     default=config.FREEZE_ATT_UNTIL
    # )
    parser.add_argument(
        "--optimizer", type=str, help="Optimizer", default=config.OPTIMIZER, choices=["sgd"]
    )
    parser.add_argument(
        "--hdf5",
        action="store_true",
        help="Use HDF5 datasets",
        default=config.HDF5
    )
    parser.add_argument(
        "--workers", type=int, help="Dataloader workers.",
        default=config.WORKERS
    )
    parser.add_argument(
        "--multiout-loss-type",
        type=str,
        help="Multiout loss type",
        default=config.MULTIOUT_LOSS_TYPE,
        choices=["decay", "same"],
    )
    # --- attention / architecture details ---
    parser.add_argument(
        "--att-mh", type=int, help="Number of attention multi-head.",
        default=config.ATT_MH
    )
    parser.add_argument(
        "--att-sm", type=float, help="Attention multiplier before softmax.",
        default=config.ATT_SM
    )
    parser.add_argument(
        "--att-ks", type=int, help="Attention kernel size.",
        default=config.ATT_KS
    )
    parser.add_argument(
        "--att-two-w",
        action="store_true",
        help="Use two different mlp for att q and k",
        default=config.ATT_TWO_W
    )
    parser.add_argument(
        "--output-stride", type=int, help="DeepLab output stride.",
        default=config.OUTPUT_STRIDE
    )
    parser.add_argument(
        "--cityscapes",
        action="store_true",
        help="Using Cityscapes dataset",
        default=config.CITYSCAPES
    )
    parser.add_argument(
        "--stem-os", type=int, help="Dynamic Routing stem output stride.",
        default=config.STEM_OS,
        choices=[1, 2, 4]
    )
    parser.add_argument(
        "--freeze-until", type=int, help="Freeze part of parameters until given epoch.",
        default=config.FREEZE_UNTIL,
    )
    # The three budget options below have no help text and no explicit
    # default, so they are None unless given on the command line.
    parser.add_argument(
        "--nested-budget", type=float,
    )
    parser.add_argument(
        "--backbone-budget", type=float,
    )
    parser.add_argument(
        "--budget-loss-ratio", type=float,
    )
    parser.add_argument(
        "--city",
        action="store_true",
        help="Using city only dataset",
        default=config.CITY
    )

    # In a notebook sys.argv belongs to the kernel, so parse an empty list.
    args = parser.parse_args([]) if jupyter else parser.parse_args()
    # args.name = f"{args.name_prefix}" \
    #             f"{args.loss}_" \
    #             f'''{"CO_" if args.do_color else ""}''' \
    #             f'''{"LO_" if args.do_label_overloading else ""}''' \
    #             f"e{args.epochs}s{args.training_steps_per_epoch}b{args.batch_size}"\
    #             f"{args.name_suffix}"
    # Experiment name: timestamp, prefix, epochs/steps/batch size, suffix.
    args.name = f'{get_datetime_str_simplified()}_' \
                f"{args.name_prefix}" \
                f"e{args.epochs}s{args.training_steps_per_epoch}b{args.batch_size}"\
                f"{args.name_suffix}"

    # A resumed run reuses the name of the experiment being resumed.
    if args.resume:
        args.name = args.resume_name

    # Cityscapes overrides the class count and channel count from config/CLI.
    if args.cityscapes:
        args.hr_nclasses = 20
        args.input_nchannels = 3

    return args


def run_initialize(args, name):
    """Set up the experiment folder, logger and configuration dump.

    Creates ``<output_dir>/<name>`` and stores it on ``args.output_path``,
    logs a start/resume banner followed by the whole configuration packed
    into lines of limited width, verifies the working directories and finally
    appends every option to ``config.txt`` inside the experiment folder.

    Args:
        args: Parsed options namespace; gains an ``output_path`` attribute.
        name: Logger name passed on to ``get_logger``.
    """
    # Folder holding the log file and the model checkpoints of this run.
    output_path = Path(args.output_dir) / Path(args.name)
    output_path.mkdir(parents=True, exist_ok=True)
    args.output_path = output_path

    logger = get_logger(name, args)
    logger.info(f"Experiment {args.name} {'resumes' if args.resume else 'starts'} at {get_datetime_str()}")
    logger.info(f"Experiment results and logs are saved at {str(args.output_path)}")
    logger.info("Configuration")

    # Pack "KEY: value" entries into log lines, flushing a line whenever the
    # next entry would push it past the width limit.
    max_log_str_len = 127
    pending = ""
    for key, value in vars(args).items():
        entry = f"{key.upper()}: {value}"
        if len(pending) + 2 + len(entry) > max_log_str_len:
            logger.info(pending)
            pending = entry
        elif pending:
            pending = f"{pending}    {entry}"
        else:
            pending = entry
    logger.info(pending)

    check_directories(args)

    # Append mode keeps the configuration of earlier (resumed) runs.
    config_file = output_path / Path("config.txt")
    with open(str(config_file), 'a') as f:
        for key, value in vars(args).items():
            f.write(f"{key.upper()}: {value}\n")


def get_logger(name, args, save_dir=None, level=logging.DEBUG):
    """Return the logger ``name`` writing to the console and to ``log.txt``.

    ``logging.getLogger`` returns the same object for the same name, so this
    function may be called repeatedly. Handlers are attached only when an
    equivalent one is not already present; otherwise every call would add
    another console/file handler and each message would be emitted several
    times. Handlers that already exist are left untouched.

    Args:
        name: Logger name.
        args: Options namespace; ``args.output_path`` is the log directory
            when ``save_dir`` is None.
        save_dir: Optional directory for ``log.txt`` overriding
            ``args.output_path``.
        level: Level applied to the logger and to newly created handlers.

    Returns:
        The configured ``logging.Logger``.
    """
    LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    logger = logging.getLogger(name)
    logger.setLevel(level)
    formatter = logging.Formatter(LOG_FORMAT)

    # Console handler. FileHandler subclasses StreamHandler, so it has to be
    # excluded explicitly when looking for an existing console handler.
    has_console = any(
        isinstance(handler, logging.StreamHandler)
        and not isinstance(handler, logging.FileHandler)
        for handler in logger.handlers
    )
    if not has_console:
        ch = logging.StreamHandler()
        ch.setLevel(level)
        ch.setFormatter(formatter)
        logger.addHandler(ch)

    # File handler, one per distinct log file.
    if save_dir is None:
        logfile = str(args.output_path / Path("log.txt"))
    else:
        logfile = str(Path(save_dir) / Path("log.txt"))
    # FileHandler stores the absolute path in ``baseFilename``.
    logfile_abs = os.path.abspath(logfile)
    has_file = any(
        isinstance(handler, logging.FileHandler)
        and handler.baseFilename == logfile_abs
        for handler in logger.handlers
    )
    if not has_file:
        fh = logging.FileHandler(logfile, mode='a')
        fh.setLevel(level)
        fh.setFormatter(formatter)
        logger.addHandler(fh)

    return logger


def save_epoch_summary(args, summary_str):
    """Append ``summary_str`` to ``epoch_summary.txt`` under ``args.output_path``."""
    summary_file = args.output_path / Path("epoch_summary.txt")
    with open(str(summary_file), mode='a') as handle:
        handle.writelines(summary_str)


def to_float(arr, data_type=config.DATA_TYPE):
    """Scale raw pixel values into the range [0, 1].

    Args:
        arr: Array of raw values.
        data_type: ``"int8"`` divides by 255, ``"int16"`` divides by 4096.

    Returns:
        The scaled array, clipped to [0, 1].

    Raises:
        ValueError: If ``data_type`` is neither ``"int8"`` nor ``"int16"``.
    """
    if data_type == "int8":
        divisor = 255.0
    elif data_type == "int16":
        divisor = 4096.0
    else:
        raise ValueError("Select an appropriate data type.")
    return np.clip(arr / divisor, 0.0, 1.0)


def handle_labels(arr, key_txt):
    """Translate label values of ``arr`` using a ``(src, dst)`` key file.

    Each row of ``key_txt`` holds a source label and the destination label it
    maps to. Matching is done against a snapshot of the original labels so
    that translations do not chain: with rows ``1 -> 2`` and ``2 -> 3`` a
    pixel that was 1 ends up as 2, not 3.

    Args:
        arr: Numpy array of labels. It is modified in place, as before.
        key_txt: Path (or file-like object) readable by ``np.loadtxt`` with
            two columns per row.

    Returns:
        ``arr`` itself, with the translated labels.
    """
    # ndmin=2 keeps a single-row key file iterable as one (src, dst) pair.
    key_array = np.loadtxt(key_txt, ndmin=2)
    # Masks are computed on the untouched copy; writes go to the caller's array.
    original = arr.copy()

    for scr_l, dst_l in key_array:
        if scr_l != dst_l:
            arr[original == scr_l] = dst_l

    return arr


def get_datetime_list():
    """Return the current Asia/Shanghai time as ``[year, month, day, hour, minute, second]``.

    Sets the ``TZ`` environment variable and calls ``time.tzset()``, which
    changes the local time zone for the whole process.
    """
    os.environ.update(TZ='Asia/Shanghai')
    time.tzset()
    # The first six struct_time fields are year, month, day, hour, minute, second.
    return list(time.localtime()[:6])


def get_datetime_str():
    """Return the current Asia/Shanghai time formatted as ``YYYY-MM-DD Weekday HH:MM:SS``.

    Sets the ``TZ`` environment variable and calls ``time.tzset()``, which
    changes the local time zone for the whole process.
    """
    os.environ.update(TZ='Asia/Shanghai')
    time.tzset()
    # Without an explicit time tuple strftime formats the current local time.
    return time.strftime("%Y-%m-%d %A %H:%M:%S")


def get_datetime_str_simplified():
    """Return the current Asia/Shanghai time as ``YY-MM-DD-Wkd_HH-MM-SS``.

    Sets the ``TZ`` environment variable and calls ``time.tzset()``, which
    changes the local time zone for the whole process.
    """
    os.environ.update(TZ='Asia/Shanghai')
    time.tzset()
    # Without an explicit time tuple strftime formats the current local time.
    return time.strftime("%y-%m-%d-%a_%H-%M-%S")


def check_directories(args):
    """Verify the data directory exists and create the output and log directories.

    Args:
        args: Options namespace providing ``data_dir``, ``output_dir`` and
            ``log_dir`` (plus whatever ``get_logger`` needs).

    Raises:
        AssertionError: If ``data_dir`` is not a directory, or if one of the
            other two is still not a directory after being created.
    """
    logger = get_logger(__name__, args)

    # Ensure folders are there and no overwrite
    logger.info("Ensuring all folders are there...")

    # The data directory must already exist; it is never created here.
    data_dir = Path(args.data_dir)
    assert data_dir.is_dir(), (
        "DATA_DIR (%s) does not exist. Make sure path is correct." % args.data_dir
    )

    # Output and log directories are created on demand, then verified.
    for label, attribute in (("OUTPUT_DIR", "output_dir"), ("LOG_DIR", "log_dir")):
        location = getattr(args, attribute)
        folder = Path(location)
        folder.mkdir(parents=True, exist_ok=True)
        assert folder.is_dir(), (
            f"{label} ({location}) does not exist. Make sure path is correct."
        )
    # output_path = (Path(args.output_dir) / Path(f'{get_datetime_str_simplified()}_{args.name}'))
    # output_path.mkdir(parents=True, exist_ok=False)
    # args.output_path = output_path


def load_nlcd_stats(
    stats_mu=config.LR_STATS_MU,
    stats_sigma=config.LR_STATS_SIGMA,
    class_weights=config.LR_CLASS_WEIGHTS,
    lr_classes=config.LR_NCLASSES,
    hr_classes=config.HR_NCLASSES,
):
    """Load the per-class mean/variance tables and the optional class weights.

    Both statistics files are read with ``np.loadtxt`` and must have
    ``lr_classes`` rows and ``hr_classes - 1`` columns; a zero column is
    prepended to each table.

    Args:
        stats_mu: Text file with the means table.
        stats_sigma: Text file with the variances table.
        class_weights: Text file with one weight per low-resolution class;
            when falsy, all weights are 1.
        lr_classes: Expected number of rows.
        hr_classes: Expected number of columns after prepending the zero column.

    Returns:
        Tuple ``(nlcd_class_weights, nlcd_means, nlcd_vars)``.
    """
    means_table = np.loadtxt(stats_mu)
    assert lr_classes == means_table.shape[0]
    assert hr_classes == (means_table.shape[1] + 1)
    nlcd_means = np.concatenate((np.zeros((lr_classes, 1)), means_table), axis=1)
    # Exact zeros are replaced by a tiny value, except in the prepended column.
    nlcd_means = np.where(nlcd_means == 0, 0.000001, nlcd_means)
    nlcd_means[:, 0] = 0
    nlcd_means = do_nlcd_means_tuning(nlcd_means)

    vars_table = np.loadtxt(stats_sigma)
    assert lr_classes == vars_table.shape[0]
    assert hr_classes == (vars_table.shape[1] + 1)
    nlcd_vars = np.concatenate((np.zeros((lr_classes, 1)), vars_table), axis=1)
    # Floor the variances at 0.0001.
    nlcd_vars = np.maximum(nlcd_vars, 0.0001)

    if class_weights:
        nlcd_class_weights = np.loadtxt(class_weights)
        assert lr_classes == nlcd_class_weights.shape[0]
    else:
        nlcd_class_weights = np.ones((lr_classes,))

    return nlcd_class_weights, nlcd_means, nlcd_vars


def do_nlcd_means_tuning(nlcd_means):
    """Apply the hand-set offsets to the means table and renormalise each row.

    The offsets are added to the caller's array in place; the normalised
    table that is returned is a new array.

    Args:
        nlcd_means: 2-D float array of per-class means.

    Returns:
        New array in which every row is divided by the sum of its
        non-negative entries and the first and last rows are set to zero.
    """
    # Hand-tuned offsets (the first one is currently a no-op of 0).
    nlcd_means[2:, 1] -= 0
    nlcd_means[3:7, 4] += 0.25

    # Only non-negative entries contribute to the row totals.
    row_totals = np.maximum(nlcd_means, 0).sum(axis=1, keepdims=True)
    tuned = nlcd_means / row_totals

    # First and last rows are zeroed out after normalisation.
    tuned[[0, -1], :] = 0
    return tuned


def load_model(model, path, multi_gpus=False):
    """Load the ``"model_state_dict"`` entry of a checkpoint into ``model``.

    Args:
        model: Module that receives the weights.
        path: Checkpoint file containing a ``"model_state_dict"`` entry.
        multi_gpus: When True the keys are used as stored. Otherwise a
            leading ``"module."`` (as added by ``DataParallel`` wrappers) is
            removed from every key first.

    Returns:
        ``model`` with the weights loaded.
    """
    # Load onto the CPU so a checkpoint written on a GPU can be read on any
    # machine; load_state_dict copies the values into the model's own tensors.
    checkpoint = torch.load(path, map_location="cpu")["model_state_dict"]
    if multi_gpus:
        model.load_state_dict(checkpoint)
    else:
        prefix = 'module.'
        # Strip the prefix only at the start of a key: a plain replace() would
        # also mangle names such as "aspp_module.weight".
        model.load_state_dict({
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in checkpoint.items()
        })
    return model


class MyTransform:
    """Joint image/label transform with an optional random square crop.

    Args:
        size: Side length of the crop window.
        random_crop: Whether to pad (if needed) and randomly crop both inputs.
    """

    def __init__(self, size, random_crop):
        self.sample_size = size
        self.random_crop = random_crop

    def __call__(self, img, target):
        """Convert both inputs to tensors, optionally crop, and remap the labels.

        Args:
            img: Image accepted by ``TF.to_tensor``.
            target: Label data accepted by ``torch.tensor``.

        Returns:
            ``(img, target)``. With ``random_crop`` the image is cut to
            ``size x size`` and the target gains a leading dimension of 1;
            without it neither is cropped and the target keeps its shape.
        """
        crop = self.sample_size
        img = TF.to_tensor(img)  # C, H, W tensor
        target = torch.tensor(target)

        if self.random_crop:
            _, height, width = img.shape
            # Largest shortfall against the crop size; the same padding is
            # applied to image and target.
            pad = max(crop - height, crop - width)
            if pad > 0:
                img = TF.pad(img, padding=pad)
                target = TF.pad(target, padding=pad)

            _, height, width = img.shape
            top = np.random.randint(0, height - crop + 1)
            left = np.random.randint(0, width - crop + 1)
            img = img[:, top:top + crop, left:left + crop]
            target = target[top:top + crop, left:left + crop].unsqueeze(dim=0)

        # Element-wise lookup of every label value in id2trainid.
        target = target.map_(target, lambda x, y: id2trainid[x])
        return img, target




class Cityscapes(VisionDataset):
    """`Cityscapes <http://www.cityscapes-dataset.com/>`_ Dataset.

    Note:
        This class has the same name as the ``Cityscapes`` imported from
        ``torchvision.datasets.cityscapes`` at the top of this file, so it
        replaces that import in this module's namespace. In this version
        ``__getitem__`` converts the target to a numpy array and then calls
        ``self.transforms(image, target)`` with the PIL image and that array.

    Args:
        root (string): Root directory of dataset where directory ``leftImg8bit``
            and ``gtFine`` or ``gtCoarse`` are located.
        split (string, optional): The image split to use, ``train``, ``test`` or ``val`` if mode="fine"
            otherwise ``train``, ``train_extra`` or ``val``
        mode (string, optional): The quality mode to use, ``fine`` or ``coarse``
        target_type (string or list, optional): Type of target to use, ``instance``, ``semantic``, ``polygon``
            or ``color``. Can also be a list to output a tuple with all specified target types.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        transforms (callable, optional): A function/transform that takes input sample and its target as entry
            and returns a transformed version.

    Examples:

        Get semantic segmentation target

        .. code-block:: python

            dataset = Cityscapes('./data/cityscapes', split='train', mode='fine',
                                 target_type='semantic')

            img, smnt = dataset[0]

        Get multiple targets

        .. code-block:: python

            dataset = Cityscapes('./data/cityscapes', split='train', mode='fine',
                                 target_type=['instance', 'color', 'polygon'])

            img, (inst, col, poly) = dataset[0]

        Validate on the "coarse" set

        .. code-block:: python

            dataset = Cityscapes('./data/cityscapes', split='val', mode='coarse',
                                 target_type='semantic')

            img, smnt = dataset[0]
    """

    # Based on https://github.com/mcordts/cityscapesScripts
    CityscapesClass = namedtuple('CityscapesClass', ['name', 'id', 'train_id', 'category', 'category_id',
                                                     'has_instances', 'ignore_in_eval', 'color'])

    # One entry per label id; train_id 255 is used for every entry whose
    # ignore_in_eval flag is True (except 'license plate', which uses -1).
    classes = [
        CityscapesClass('unlabeled', 0, 255, 'void', 0, False, True, (0, 0, 0)),
        CityscapesClass('ego vehicle', 1, 255, 'void', 0, False, True, (0, 0, 0)),
        CityscapesClass('rectification border', 2, 255, 'void', 0, False, True, (0, 0, 0)),
        CityscapesClass('out of roi', 3, 255, 'void', 0, False, True, (0, 0, 0)),
        CityscapesClass('static', 4, 255, 'void', 0, False, True, (0, 0, 0)),
        CityscapesClass('dynamic', 5, 255, 'void', 0, False, True, (111, 74, 0)),
        CityscapesClass('ground', 6, 255, 'void', 0, False, True, (81, 0, 81)),
        CityscapesClass('road', 7, 0, 'flat', 1, False, False, (128, 64, 128)),
        CityscapesClass('sidewalk', 8, 1, 'flat', 1, False, False, (244, 35, 232)),
        CityscapesClass('parking', 9, 255, 'flat', 1, False, True, (250, 170, 160)),
        CityscapesClass('rail track', 10, 255, 'flat', 1, False, True, (230, 150, 140)),
        CityscapesClass('building', 11, 2, 'construction', 2, False, False, (70, 70, 70)),
        CityscapesClass('wall', 12, 3, 'construction', 2, False, False, (102, 102, 156)),
        CityscapesClass('fence', 13, 4, 'construction', 2, False, False, (190, 153, 153)),
        CityscapesClass('guard rail', 14, 255, 'construction', 2, False, True, (180, 165, 180)),
        CityscapesClass('bridge', 15, 255, 'construction', 2, False, True, (150, 100, 100)),
        CityscapesClass('tunnel', 16, 255, 'construction', 2, False, True, (150, 120, 90)),
        CityscapesClass('pole', 17, 5, 'object', 3, False, False, (153, 153, 153)),
        CityscapesClass('polegroup', 18, 255, 'object', 3, False, True, (153, 153, 153)),
        CityscapesClass('traffic light', 19, 6, 'object', 3, False, False, (250, 170, 30)),
        CityscapesClass('traffic sign', 20, 7, 'object', 3, False, False, (220, 220, 0)),
        CityscapesClass('vegetation', 21, 8, 'nature', 4, False, False, (107, 142, 35)),
        CityscapesClass('terrain', 22, 9, 'nature', 4, False, False, (152, 251, 152)),
        CityscapesClass('sky', 23, 10, 'sky', 5, False, False, (70, 130, 180)),
        CityscapesClass('person', 24, 11, 'human', 6, True, False, (220, 20, 60)),
        CityscapesClass('rider', 25, 12, 'human', 6, True, False, (255, 0, 0)),
        CityscapesClass('car', 26, 13, 'vehicle', 7, True, False, (0, 0, 142)),
        CityscapesClass('truck', 27, 14, 'vehicle', 7, True, False, (0, 0, 70)),
        CityscapesClass('bus', 28, 15, 'vehicle', 7, True, False, (0, 60, 100)),
        CityscapesClass('caravan', 29, 255, 'vehicle', 7, True, True, (0, 0, 90)),
        CityscapesClass('trailer', 30, 255, 'vehicle', 7, True, True, (0, 0, 110)),
        CityscapesClass('train', 31, 16, 'vehicle', 7, True, False, (0, 80, 100)),
        CityscapesClass('motorcycle', 32, 17, 'vehicle', 7, True, False, (0, 0, 230)),
        CityscapesClass('bicycle', 33, 18, 'vehicle', 7, True, False, (119, 11, 32)),
        CityscapesClass('license plate', -1, -1, 'vehicle', 7, False, True, (0, 0, 142)),
    ]

    def __init__(
            self,
            root: str,
            split: str = "train",
            mode: str = "fine",
            target_type: Union[List[str], str] = "instance",
            transform: Optional[Callable] = None,
            target_transform: Optional[Callable] = None,
            transforms: Optional[Callable] = None,
    ) -> None:
        """Validate the arguments, locate (or extract) the data and index all samples.

        Raises:
            RuntimeError: If the image or target directory is missing and the
                corresponding zip archives are not both present under ``root``.
        """
        super(Cityscapes, self).__init__(root, transforms, transform, target_transform)
        self.mode = 'gtFine' if mode == 'fine' else 'gtCoarse'
        self.images_dir = os.path.join(self.root, 'leftImg8bit', split)
        self.targets_dir = os.path.join(self.root, self.mode, split)
        self.target_type = target_type
        self.split = split
        # Parallel lists: images[i] is an image path, targets[i] the list of
        # target paths (one per requested target type) for that image.
        self.images = []
        self.targets = []

        # The set of valid splits depends on the annotation quality mode.
        verify_str_arg(mode, "mode", ("fine", "coarse"))
        if mode == "fine":
            valid_modes = ("train", "test", "val")
        else:
            valid_modes = ("train", "train_extra", "val")
        msg = ("Unknown value '{}' for argument split if mode is '{}'. "
               "Valid values are {{{}}}.")
        msg = msg.format(split, mode, iterable_to_str(valid_modes))
        verify_str_arg(split, "split", valid_modes, msg)

        # Normalise target_type to a list and validate every entry.
        if not isinstance(target_type, list):
            self.target_type = [target_type]
        [verify_str_arg(value, "target_type",
                        ("instance", "semantic", "polygon", "color"))
         for value in self.target_type]

        # If either directory is missing, fall back to extracting the zip
        # archives expected under root; fail if they are not both there.
        if not os.path.isdir(self.images_dir) or not os.path.isdir(self.targets_dir):

            if split == 'train_extra':
                image_dir_zip = os.path.join(self.root, 'leftImg8bit{}'.format('_trainextra.zip'))
            else:
                image_dir_zip = os.path.join(self.root, 'leftImg8bit{}'.format('_trainvaltest.zip'))

            if self.mode == 'gtFine':
                target_dir_zip = os.path.join(self.root, '{}{}'.format(self.mode, '_trainvaltest.zip'))
            elif self.mode == 'gtCoarse':
                target_dir_zip = os.path.join(self.root, '{}{}'.format(self.mode, '.zip'))

            if os.path.isfile(image_dir_zip) and os.path.isfile(target_dir_zip):
                extract_archive(from_path=image_dir_zip, to_path=self.root)
                extract_archive(from_path=target_dir_zip, to_path=self.root)
            else:
                raise RuntimeError('Dataset not found or incomplete. Please make sure all required folders for the'
                                   ' specified "split" and "mode" are inside the "root" directory')

        # Index every image of every city folder together with the paths of
        # its targets. Sample order follows os.listdir.
        for city in os.listdir(self.images_dir):
            img_dir = os.path.join(self.images_dir, city)
            target_dir = os.path.join(self.targets_dir, city)
            for file_name in os.listdir(img_dir):
                target_types = []
                for t in self.target_type:
                    # Target files share the image name up to '_leftImg8bit'.
                    target_name = '{}_{}'.format(file_name.split('_leftImg8bit')[0],
                                                 self._get_target_suffix(self.mode, t))
                    target_types.append(os.path.join(target_dir, target_name))

                self.images.append(os.path.join(img_dir, file_name))
                self.targets.append(target_types)
    #     self._initdata()
    
    
    # Disabled variant that loaded every image and target during construction
    # and served them from memory; kept for reference.
    # def _initdata(self):
    #     self.actural_images = []
    #     self.actural_targets = []
    #     for index in range(len(self.images)):
    #         image = Image.open(self.images[index]).convert('RGB')
    #         targets: Any = []
    #         for i, t in enumerate(self.target_type):
    #             if t == 'polygon':
    #                 target = self._load_json(self.targets[index][i])
    #             else:
    #                 target = Image.open(self.targets[index][i])
    #             targets.append(target)
    #         target = tuple(targets) if len(targets) > 1 else targets[0]
    #         self.actural_images.append(image)
    #         self.actural_targets.append(np.array(target))

    # def __getitem__(self, index: int) -> Tuple[Any, Any]:
    #     """
    #     Args:
    #         index (int): Index
    #     Returns:
    #         tuple: (image, target) where target is a tuple of all target types if target_type is a list with more
    #         than one item. Otherwise target is a json object if target_type="polygon", else the image segmentation.
    #     """
    #     image,target = self.actural_images[index],self.actural_targets[index]
        
    #     if self.transforms is not None:
    #         image, target = self.transforms(image, target)
    #     return image, target
    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Load the image and target(s) at ``index`` from disk.

        Args:
            index (int): Index
        Returns:
            tuple: (image, target). The image is opened as RGB. The loaded
            target (or the tuple of targets when several target types were
            requested) is passed through ``np.array`` before ``self.transforms``
            is applied to the pair, if it is set.
        """

        image = Image.open(self.images[index]).convert('RGB')

        targets: Any = []
        for i, t in enumerate(self.target_type):
            if t == 'polygon':
                target = self._load_json(self.targets[index][i])
            else:
                target = Image.open(self.targets[index][i])

            targets.append(target)

        target = tuple(targets) if len(targets) > 1 else targets[0]
        # NOTE(review): np.array is applied to whatever was collected above;
        # with several target types this is a tuple — confirm that case is
        # actually used before relying on it.
        target = np.array(target)
        
        if self.transforms is not None:
            image, target = self.transforms(image, target)

        return image, target

    def __len__(self) -> int:
        """Return the number of indexed images."""
        return len(self.images)

    def extra_repr(self) -> str:
        """Return the split, mode and target type as a multi-line string."""
        lines = ["Split: {split}", "Mode: {mode}", "Type: {target_type}"]
        return '\n'.join(lines).format(**self.__dict__)

    def _load_json(self, path: str) -> Dict[str, Any]:
        """Read and return the JSON document stored at ``path``."""
        with open(path, 'r') as file:
            data = json.load(file)
        return data

    def _get_target_suffix(self, mode: str, target_type: str) -> str:
        """Return the file-name suffix of a target file for ``mode`` and ``target_type``.

        Any target type other than ``instance``, ``semantic`` or ``color``
        yields the polygons JSON suffix.
        """
        if target_type == 'instance':
            return '{}_instanceIds.png'.format(mode)
        elif target_type == 'semantic':
            return '{}_labelIds.png'.format(mode)
        elif target_type == 'color':
            return '{}_color.png'.format(mode)
        else:
            return '{}_polygons.json'.format(mode)
        
        


def scatter(inputs, target_gpus, chunk_sizes, dim=0):
    r"""
    Slices tensors into chunks of the given sizes and
    distributes them across given GPUs. Duplicates
    references to objects that are not tensors.

    Args:
        inputs: Tensor, or (possibly nested) tuple/list/dict of tensors and
            other objects.
        target_gpus: Devices to scatter to.
        chunk_sizes: Size of the chunk sent to each device along ``dim``.
        dim: Dimension along which tensors are split.

    Returns:
        A list with one entry per target device, each mirroring the structure
        of ``inputs``.

    Raises:
        Exception: Whatever ``Scatter.apply`` raises is logged together with
            the tensor size, ``dim`` and ``chunk_sizes`` and then re-raised.
            (Previously the process was terminated with ``quit()``, which hid
            the real error.)
    """

    def scatter_map(obj):
        if isinstance(obj, torch.Tensor):
            try:
                return Scatter.apply(target_gpus, chunk_sizes, dim, obj)
            except Exception:
                logging.getLogger(__name__).error(
                    "Scatter failed: obj size=%s, dim=%s, chunk_sizes=%s",
                    tuple(obj.size()), dim, chunk_sizes,
                )
                raise
        if isinstance(obj, tuple) and len(obj) > 0:
            return list(zip(*map(scatter_map, obj)))
        if isinstance(obj, list) and len(obj) > 0:
            return list(map(list, zip(*map(scatter_map, obj))))
        if isinstance(obj, dict) and len(obj) > 0:
            return list(map(type(obj), zip(*map(scatter_map, obj.items()))))
        # Non-tensor leaves (and empty containers) are shared by reference.
        return [obj for _ in target_gpus]

    # After scatter_map is called, a scatter_map cell will exist. This cell
    # has a reference to the actual function scatter_map, which has references
    # to a closure that has a reference to the scatter_map cell (because the
    # fn is recursive). To avoid this reference cycle, we set the function to
    # None, clearing the cell
    try:
        return scatter_map(inputs)
    finally:
        scatter_map = None


def scatter_kwargs(inputs, kwargs, target_gpus, chunk_sizes, dim=0):
    """Scatter positional and keyword arguments and pad them to equal length.

    Returns:
        Tuple ``(inputs, kwargs)`` of two equally long tuples; the shorter
        side is padded with empty tuples / empty dicts.
    """
    scattered_args = scatter(inputs, target_gpus, chunk_sizes, dim) if inputs else []
    scattered_kwargs = scatter(kwargs, target_gpus, chunk_sizes, dim) if kwargs else []

    # Positive: positional side is shorter; negative: keyword side is shorter.
    shortfall = len(scattered_kwargs) - len(scattered_args)
    if shortfall > 0:
        scattered_args.extend(() for _ in range(shortfall))
    elif shortfall < 0:
        scattered_kwargs.extend({} for _ in range(-shortfall))

    return tuple(scattered_args), tuple(scattered_kwargs)


class BalancedDataParallel(DataParallel):
    """``DataParallel`` that gives the first device its own share of the batch.

    Args:
        gpu0_bsz: Number of samples sent to the first device. With 0 the
            first device receives no data and runs no replica.
        *args, **kwargs: Forwarded to ``DataParallel``.
    """

    def __init__(self, gpu0_bsz, *args, **kwargs):
        self.gpu0_bsz = gpu0_bsz
        super().__init__(*args, **kwargs)

    def forward(self, *inputs, **kwargs):
        """Scatter the batch, run the replicas in parallel and gather the outputs."""
        if not self.device_ids:
            return self.module(*inputs, **kwargs)
        if len(self.device_ids) == 1:
            # Nothing to balance with one device; the balanced split would
            # divide by zero, so use the stock scatter and run the module.
            inputs, kwargs = super().scatter(inputs, kwargs, self.device_ids)
            return self.module(*inputs[0], **kwargs[0])
        if self.gpu0_bsz == 0:
            device_ids = self.device_ids[1:]
        else:
            device_ids = self.device_ids
        inputs, kwargs = self.scatter(inputs, kwargs, device_ids)
        replicas = self.replicate(self.module, self.device_ids)
        if self.gpu0_bsz == 0:
            replicas = replicas[1:]
        # A small batch may yield fewer chunks than devices; parallel_apply
        # needs replicas, devices and inputs to have the same length.
        replicas = replicas[:len(inputs)]
        device_ids = device_ids[:len(inputs)]
        outputs = self.parallel_apply(replicas, device_ids, inputs, kwargs)
        return self.gather(outputs, self.output_device)

    def parallel_apply(self, replicas, device_ids, inputs, kwargs):
        """Run each replica on its inputs on the matching device."""
        return parallel_apply(replicas, inputs, kwargs, device_ids)

    def scatter(self, inputs, kwargs, device_ids):
        """Split the batch so the first device gets ``gpu0_bsz`` samples.

        The remaining samples are divided evenly over the other devices, the
        remainder going one sample each to the first of them. Falls back to
        the stock ``DataParallel.scatter`` when there is only one device or
        when ``gpu0_bsz`` is not smaller than the per-device share.
        """
        num_dev = len(self.device_ids)
        if num_dev < 2:
            # Guard against dividing by ``num_dev - 1 == 0`` below.
            return super().scatter(inputs, kwargs, device_ids)
        bsz = inputs[0].size(self.dim)
        gpu0_bsz = self.gpu0_bsz
        bsz_unit = (bsz - gpu0_bsz) // (num_dev - 1)
        if gpu0_bsz >= bsz_unit:
            return super().scatter(inputs, kwargs, device_ids)
        chunk_sizes = [gpu0_bsz] + [bsz_unit] * (num_dev - 1)
        # Hand out the samples left over by the integer division.
        delta = bsz - sum(chunk_sizes)
        for i in range(delta):
            chunk_sizes[i + 1] += 1
        if gpu0_bsz == 0:
            # The first device is skipped entirely, so drop its empty chunk.
            chunk_sizes = chunk_sizes[1:]
        return scatter_kwargs(inputs, kwargs, device_ids, chunk_sizes, dim=self.dim)

