import os
import argparse
import logging
import time
from pathlib import Path
import PIL.Image as Image
import functools

import numpy as np
import torch
import torchvision.transforms as T
import torchvision.transforms.functional as TF
import json
from collections import namedtuple
import zipfile
from typing import Any, Callable, Dict, List, Optional, Union, Tuple

from torchvision.datasets.utils import extract_archive, verify_str_arg, iterable_to_str
import training_config as config
from cityscapes_helper import id2trainid
        
import torch
from torch.nn.parallel.data_parallel import DataParallel
from torch.nn.parallel.parallel_apply import parallel_apply
from torch.nn.parallel._functions import Scatter


def do_args(jupyter=False):
    """Build the command-line parser, parse arguments and derive run settings.

    Most defaults are read from the ``training_config`` module (imported as
    ``config``). After parsing, the function:

    * sets ``args.name`` to ``<timestamp>_<name_prefix>e<epochs>s<steps>b<batch><name_suffix>``,
      or to ``args.resume_name`` when ``args.resume`` is set;
    * overrides ``args.hr_nclasses`` and ``args.input_nchannels`` when one of the
      dataset flags (``--cityscapes``, ``--isaid``, ``--potsdam``, ``--vaihingen``)
      is active.

    Args:
        jupyter: When true, parse an empty argument list instead of ``sys.argv``
            so that only the defaults are used.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(
        description="Wrapper utility for training and testing land cover models.",
    )
    # NOTE(review): the help text mentions keras.fit although this file imports torch.
    parser.add_argument(
        "-v",
        "--verbose",
        type=int,
        help="Verbosity of keras.fit",
        default=config.VERBOSE,
    )
    # parser.add_argument("--name", type=str, help="Experiment name", required=True)
    parser.add_argument("--name-prefix", type=str,
                        help="Experiment name prefix", default=config.NAME_PREFIX)
    parser.add_argument("--name-suffix", type=str,
                        help="Experiment name suffix", default=config.NAME_SUFFIX)
    parser.add_argument(
        "--data-dir",
        type=str,
        help="Path to data directory containing the splits CSV files",
        default=config.DATA_DIR,
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        help="Output base directory",
        default=config.OUTPUT_DIR,
    )
    parser.add_argument(
        "--log-dir",
        type=str,
        help="Log directory for Tensorboard and console records",
        default=config.LOG_DIR,
    )
    parser.add_argument(
        "--training-states",
        nargs="+",
        type=str,
        help="States to use as training",
        default=config.TRAINING_STATES,
    )
    parser.add_argument(
        "--validation-states",
        nargs="+",
        type=str,
        help="States to use as validation",
        default=config.VALIDATION_STATES,
    )
    parser.add_argument(
        "--superres-states",
        nargs="+",
        type=str,
        help="States to use only superres loss with",
        default=config.SUPERRES_STATES,
    )
    # NOTE(review): for every ``store_true`` flag whose default comes from config,
    # a config default of True cannot be switched off from the command line.
    parser.add_argument(
        "--do-color",
        action="store_true",
        help="Enable color augmentation",
        default=config.DO_COLOR,
    )
    parser.add_argument(
        "--color-augmentation-intensity", type=float, help="Color augmentation intensity",
        default=config.COLOR_AUGMENTATION_INTENSITY
    )
    parser.add_argument(
        "--do-label-overloading",
        action="store_true",
        help="Enable label overloading",
        default=config.DO_LABEL_OVERLOADING,
    )
    parser.add_argument(
        "--net-model",
        type=str,
        default=config.NET_MODEL,
        help="Model architecture to use"
    )
    parser.add_argument(
        "--att-hidden", type=int, help="Dimension of attention hidden", default=config.ATT_HIDDEN
    )
    parser.add_argument(
        "--epochs", type=int, help="Number of epochs", default=config.EPOCHS
    )
    parser.add_argument(
        "--loss",
        type=str,
        help="Loss function",
        default=config.LOSS,
        choices=["crossentropy", "superres"],
    )
    parser.add_argument(
        "--learning-rate",
        type=float,
        help="Learning rate",
        default=config.LEARNING_RATE,
    )
    parser.add_argument(
        "--batch-size", type=int, help="Batch size", default=config.BATCH_SIZE
    )
    parser.add_argument(
        "--step-size", type=int, help="Learning rate scheduler step size", default=config.STEP_SIZE
    )
    parser.add_argument(
        "--shuffle", action="store_true", help="Dataloader shuffle", default=config.SHUFFLE
    )
    # NOTE(review): both loss weights are parsed as int, so fractional weights
    # given on the command line are rejected by argparse.
    parser.add_argument(
        "--crossentropy-loss-weight", type=int, help="Weight of crossentropy loss if using superres",
        default=config.CROSSENTROPY_LOSS_WEIGHT
    )
    parser.add_argument(
        "--superres-loss-weight", type=int, help="Weight of superres loss if using superres",
        default=config.SUPERRES_LOSS_WEIGHT
    )
    parser.add_argument(
        "--mode", type=str, help="Used for remote console debug", default='Nothing'
    )
    parser.add_argument(
        "--input-size", type=int, help="Input image size in pixels", default=config.INPUT_SIZE
    )
    parser.add_argument(
        "--input-nchannels", type=int, help="Number of input image channels", default=config.INPUT_NCHANNELS
    )
    parser.add_argument(
        "--data-type", type=str, help="Data type of imagery", default=config.DATA_TYPE
    )
    parser.add_argument(
        "--hr-label-key", type=str, help="Keys for transformation of hr labels", default=config.HR_LABEL_KEY
    )
    parser.add_argument(
        "--hr-label-index", type=int, help="Positional index of hr labels in patches", default=config.HR_LABEL_INDEX
    )
    parser.add_argument(
        "--hr-nclasses", type=int, help="Number of target hr classes", default=config.HR_NCLASSES
    )
    parser.add_argument(
        "--lr-label-key", type=str, help="Keys for transformation of lr labels", default=config.LR_LABEL_KEY
    )
    parser.add_argument(
        "--lr-label-index", type=int, help="Positional index of lr labels in patches", default=config.LR_LABEL_INDEX
    )
    parser.add_argument(
        "--lr-nclasses", type=int, help="Number of target lr classes", default=config.LR_NCLASSES
    )
    parser.add_argument(
        "--training-steps-per-epoch", type=int, help="Number of training steps or batches per epoch",
        default=config.TRAINING_STEPS_PER_EPOCH
    )
    parser.add_argument(
        "--validation-steps-per-epoch", type=int, help="Number of validation steps or batches per epoch",
        default=config.VALIDATION_STEPS_PER_EPOCH
    )
    parser.add_argument(
        "--model-save-checkpoint", type=int, help="Save model at each checkpoint",
        default=config.MODEL_SAVE_CHECKPOINT
    )
    parser.add_argument(
        "--seed", type=int, help="Numpy random seed. Used to ensure we get the same validation batches",
        default=config.SEED
    )
    parser.add_argument(
        "--resume",
        action="store_true",
        help="Resume a previous training",
        default=config.RESUME,
    )
    parser.add_argument(
        "--resume-name", type=str, help="Experiment name to resume", default=config.RESUME_NAME
    )
    parser.add_argument(
        "--resume-model", type=str, help="Load a model to resume training", default=config.RESUME_MODEL
    )
    # parser.add_argument(
    #     "--freeze-att-until", type=int, help="Freeze the attention weights until reaching the indicated epoch",
    #     default=config.FREEZE_ATT_UNTIL
    # )
    parser.add_argument(
        "--optimizer", type=str, help="Optimizer", default=config.OPTIMIZER, choices=["sgd"]
    )
    parser.add_argument(
        "--hdf5",
        action="store_true",
        help="Use HDF5 datasets",
        default=config.HDF5
    )
    parser.add_argument(
        "--workers", type=int, help="Dataloader workers.",
        default=config.WORKERS
    )
    parser.add_argument(
        "--multiout-loss-type",
        type=str,
        help="Multiout loss type",
        default=config.MULTIOUT_LOSS_TYPE,
        choices=["decay", "same"],
    )
    parser.add_argument(
        "--att-mh", type=int, help="Number of attention multi-head.",
        default=config.ATT_MH
    )
    parser.add_argument(
        "--att-sm", type=float, help="Attention multiplier before softmax.",
        default=config.ATT_SM
    )
    parser.add_argument(
        "--att-ks", type=int, help="Attention kernel size.",
        default=config.ATT_KS
    )
    parser.add_argument(
        "--att-two-w",
        action="store_true",
        help="Use two different mlp for att q and k",
        default=config.ATT_TWO_W
    )
    parser.add_argument(
        "--output-stride", type=int, help="DeepLab output stride.",
        default=config.OUTPUT_STRIDE
    )
    parser.add_argument(
        "--cityscapes",
        action="store_true",
        help="Using Cityscapes dataset",
        default=config.CITYSCAPES
    )
    parser.add_argument(
        "--stem-os", type=int, help="Dynamic Routing stem output stride.",
        default=config.STEM_OS,
        choices=[1, 2, 4]
    )
    parser.add_argument(
        "--freeze-until", type=int, help="Freeze part of parameters until given epoch.",
        default=config.FREEZE_UNTIL,
    )
    # The three budget options below have no explicit default, so they are
    # None unless given on the command line.
    parser.add_argument(
        "--nested-budget", type=float,
    )
    parser.add_argument(
        "--backbone-budget", type=float,
    )
    parser.add_argument(
        "--budget-loss-ratio", type=float,
    )
    parser.add_argument(
        "--city",
        action="store_true",
        help="Using city only dataset",
        default=config.CITY
    )
    parser.add_argument(
        "--isaid",
        action="store_true",
        help="Using iSAID dataset",
        default=config.ISAID
    )
    parser.add_argument(
        "--isaid-dir",
        type=str,
        help="iSAID dataset directory",
        default=config.ISAID_DIR
    )
    parser.add_argument("--potsdam",action="store_true",help="Using potsdam dataset",default=False)
    parser.add_argument("--vaihingen",action="store_true",help="Using vaihingen dataset",default=False)


    # The options below use underscore-style names and literal defaults
    # instead of values from the config module.
    parser.add_argument('--local_rank', default=-1, type=int,
                    help='node rank for distributed training')

    parser.add_argument('--weight_decay', default=0.0005, type=float,
                    help='weight decay ')

    parser.add_argument('--slowbackbone', default=0, type=int,
                    help='slow backbine or not ')

    parser.add_argument(
        "--stagerate",default=1,type=float,
    )

    parser.add_argument(
        "--quantile",default=0.3,type=float,
    )

    parser.add_argument(
        "--gamma",default=0.9,type=float,
    )

    parser.add_argument('--valid_all', default=0, type=int)
    # In notebook mode ignore sys.argv and rely on the defaults only.
    args = parser.parse_args([]) if jupyter else parser.parse_args()
    # args.name = f"{args.name_prefix}" \
    #             f"{args.loss}_" \
    #             f'''{"CO_" if args.do_color else ""}''' \
    #             f'''{"LO_" if args.do_label_overloading else ""}''' \
    #             f"e{args.epochs}s{args.training_steps_per_epoch}b{args.batch_size}"\
    #             f"{args.name_suffix}"
    # Experiment name: timestamp, prefix, epochs/steps/batch-size tag, suffix.
    args.name = f'{get_datetime_str_simplified()}_' \
                f"{args.name_prefix}" \
                f"e{args.epochs}s{args.training_steps_per_epoch}b{args.batch_size}"\
                f"{args.name_suffix}"

    # When resuming, reuse the name of the experiment being resumed.
    if args.resume:
        args.name = args.resume_name

    # Dataset flags take precedence over --hr-nclasses / --input-nchannels;
    # the first matching flag in this chain wins.
    if args.cityscapes:
        args.hr_nclasses = 19
        args.input_nchannels = 3
    elif args.isaid:
        args.hr_nclasses = 16
        args.input_nchannels = 3
    elif args.potsdam or args.vaihingen:
        args.hr_nclasses = 6
        args.input_nchannels = 3

    return args


def run_initialize(args, name):
    """Set up the experiment directory, logger and configuration dump.

    Creates ``<args.output_dir>/<args.name>``, stores it as ``args.output_path``,
    logs a start/resume banner followed by every argument (packed into lines of
    bounded width), runs ``check_directories`` and appends all arguments to
    ``config.txt`` inside the experiment directory.

    Args:
        args: Parsed argument namespace; gains an ``output_path`` attribute.
        name: Logger name handed to ``get_logger``.
    """
    # Directory that receives the log file and model checkpoint records.
    experiment_dir = Path(args.output_dir) / Path(args.name)
    experiment_dir.mkdir(parents=True, exist_ok=True)
    args.output_path = experiment_dir

    logger = get_logger(name, args)
    status = "resumes" if args.resume else "starts"
    logger.info(f"Experiment {args.name} {status} at {get_datetime_str()}")
    logger.info(f"Experiment results and logs are saved at {str(args.output_path)}")
    logger.info("Configuration")

    # Pack "KEY: value" entries onto one line until the width limit is hit.
    line_limit = 127
    pending = ""
    for key, value in vars(args).items():
        entry = f"{key.upper()}: {value}"
        if len(f"{pending}  {entry}") > line_limit:
            logger.info(pending)
            pending = entry
        elif pending:
            pending = f"{pending}    {entry}"
        else:
            pending = entry
    logger.info(pending)
    check_directories(args)

    # Append the configuration, one entry per line, to config.txt.
    with open(str(experiment_dir / Path("config.txt")), 'a') as config_file:
        for key, value in vars(args).items():
            config_file.write(f"{key.upper()}: {value}\n")


def get_logger(name, args, save_dir=None, level=logging.DEBUG):
    """Return a named logger that writes to the console and to ``log.txt``.

    The function is safe to call repeatedly with the same ``name``: a console
    handler and a file handler for the resolved log file are attached only if
    the logger does not have them yet, so messages are not emitted several
    times. Handlers are left open for the lifetime of the logger.

    Args:
        name: Logger name passed to ``logging.getLogger``.
        args: Object exposing ``output_path``; only read when ``save_dir`` is None.
        save_dir: Optional directory for ``log.txt``; overrides ``args.output_path``.
        level: Level applied to the logger and to any handler created here.

    Returns:
        The configured ``logging.Logger``.
    """
    LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    logger = logging.getLogger(name)
    logger.setLevel(level)
    formatter = logging.Formatter(LOG_FORMAT)

    if save_dir is None:
        logfile = str(args.output_path / Path("log.txt"))
    else:
        logfile = str(Path(save_dir) / Path("log.txt"))
    # FileHandler stores its target as an absolute path in ``baseFilename``.
    logfile_abs = os.path.abspath(logfile)

    # FileHandler subclasses StreamHandler, so exclude it when looking for
    # an existing console handler.
    has_console = any(
        isinstance(h, logging.StreamHandler) and not isinstance(h, logging.FileHandler)
        for h in logger.handlers
    )
    has_file = any(
        isinstance(h, logging.FileHandler) and h.baseFilename == logfile_abs
        for h in logger.handlers
    )

    if not has_console:
        ch = logging.StreamHandler()
        ch.setLevel(level)
        ch.setFormatter(formatter)
        logger.addHandler(ch)

    if not has_file:
        fh = logging.FileHandler(logfile, mode='a')
        fh.setLevel(level)
        fh.setFormatter(formatter)
        logger.addHandler(fh)

    return logger

class Logger_Lin():
    """Minimal logger that prints each message and appends it to ``log.txt``.

    Every message is flushed to disk as soon as it is written so that log
    lines are not lost if the process terminates abnormally. The instance
    can be closed explicitly with ``close()`` or used as a context manager.
    """

    def __init__(self, name, args, save_dir=None):
        """Open the log file in append mode.

        Args:
            name: Label included in every log line.
            args: Object exposing ``output_path``; only read when ``save_dir`` is None.
            save_dir: Optional directory for ``log.txt``; overrides ``args.output_path``.
        """
        self.name = name

        if save_dir is None:
            logfile = str(args.output_path / Path("log.txt"))
        else:
            logfile = str(Path(save_dir) / Path("log.txt"))
        print(logfile)
        self.f = open(logfile, 'a')

    def info(self, string):
        """Print ``string`` prefixed with a timestamp and the logger name, and append it to the file."""
        cur_t = time.strftime("%y-%m-%d-%a_%H:%M:%S", time.localtime())
        s = f"{cur_t}    {self.name}    {string}"
        print(s)
        self.f.write(s + '\n')
        # Flush immediately so the line survives a crash.
        self.f.flush()

    def close(self):
        """Close the underlying file; calling it more than once is harmless."""
        if not self.f.closed:
            self.f.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False

def save_epoch_summary(args, summary_str):
    """Append ``summary_str`` to ``epoch_summary.txt`` under ``args.output_path``.

    Args:
        args: Object exposing ``output_path``.
        summary_str: A string, or an iterable of strings, written without any
            separator being added.
    """
    summary_path = args.output_path / Path("epoch_summary.txt")
    with open(str(summary_path), 'a') as summary_file:
        summary_file.writelines(summary_str)


def to_float(arr, data_type=config.DATA_TYPE):
    """Scale ``arr`` into the range [0, 1] according to ``data_type``.

    Args:
        arr: Array-like of raw values.
        data_type: ``"int8"`` divides by 255.0, ``"int16"`` divides by 4096.0.

    Returns:
        The scaled values clipped to [0.0, 1.0].

    Raises:
        ValueError: If ``data_type`` is neither ``"int8"`` nor ``"int16"``.
    """
    if data_type == "int8":
        divisor = 255.0
    elif data_type == "int16":
        divisor = 4096.0
    else:
        raise ValueError("Select an appropriate data type.")
    return np.clip(arr / divisor, 0.0, 1.0)


def handle_labels(arr, key_txt):
    """Translate label values in ``arr`` using a (source, destination) key table.

    The key table is loaded with ``np.loadtxt``; each row is a
    ``(src label, dst label)`` pair. Every pair is matched against the
    *original* values of ``arr``, so a pixel is translated at most once even
    when a destination label also appears as a source label in another row.

    Args:
        arr: NumPy array of labels. It is modified in place.
        key_txt: Anything accepted by ``np.loadtxt`` (path, file object, or
            list of lines) holding two columns per row.

    Returns:
        ``arr`` itself, after translation.
    """
    # atleast_2d keeps a single-row key table iterable as one (src, dst) row.
    key_array = np.atleast_2d(np.loadtxt(key_txt))
    if key_array.size == 0:
        # Empty key table: nothing to translate.
        return arr

    # Snapshot of the untouched labels so translations cannot chain.
    source = arr.copy()
    for scr_l, dst_l in key_array:
        if scr_l != dst_l:
            arr[source == scr_l] = dst_l

    return arr


def get_datetime_list():
    """Return the current Asia/Shanghai local time as ``[year, month, day, hour, minute, second]``.

    Sets the process-wide ``TZ`` environment variable and calls ``time.tzset()``.
    """
    os.environ['TZ'] = 'Asia/Shanghai'
    time.tzset()
    # The first six fields of struct_time are year, month, day, hour, minute, second.
    return list(time.localtime()[:6])


#fix evaluate bug: decrease the bug probability when in DDP mode
def get_datetime_str():
    """Return the current Asia/Shanghai local time formatted as ``YYYY-MM-DD Weekday HH:MM``.

    Sets the process-wide ``TZ`` environment variable and calls ``time.tzset()``.
    """
    os.environ['TZ'] = 'Asia/Shanghai'
    time.tzset()
    # strftime without an explicit time tuple formats the current local time.
    return time.strftime("%Y-%m-%d %A %H:%M")

#fix evaluate bug: decrease the bug probability when in DDP mode
def get_datetime_str_simplified():
    """Return the current Asia/Shanghai local time formatted as ``YY-MM-DD-Wkd_HH-MM``.

    Sets the process-wide ``TZ`` environment variable and calls ``time.tzset()``.
    """
    os.environ['TZ'] = 'Asia/Shanghai'
    time.tzset()
    # strftime without an explicit time tuple formats the current local time.
    return time.strftime("%y-%m-%d-%a_%H-%M")


def check_directories(args):
    """Verify the data directory exists and create the output and log directories.

    Args:
        args: Object exposing ``data_dir``, ``output_dir`` and ``log_dir``
            (plus whatever ``get_logger`` reads from it).

    Raises:
        AssertionError: If ``args.data_dir`` is not an existing directory, or
            if an output/log path is not a directory after creation.
    """
    logger = get_logger(__name__, args)
    logger.info("Ensuring all folders are there...")

    # The data directory must already exist; it is never created here.
    data_dir = Path(args.data_dir)
    assert data_dir.is_dir(), (
        "DATA_DIR (%s) does not exist. Make sure path is correct." % args.data_dir
    )

    # Output and log directories are created on demand, then verified.
    for label, attr in (("OUTPUT_DIR", "output_dir"), ("LOG_DIR", "log_dir")):
        location = getattr(args, attr)
        directory = Path(location)
        directory.mkdir(parents=True, exist_ok=True)
        assert directory.is_dir(), (
            "%s (%s) does not exist. Make sure path is correct." % (label, location)
        )


def load_nlcd_stats(
    stats_mu=config.LR_STATS_MU,
    stats_sigma=config.LR_STATS_SIGMA,
    class_weights=config.LR_CLASS_WEIGHTS,
    lr_classes=config.LR_NCLASSES,
    hr_classes=config.HR_NCLASSES,
):
    """Load per-class mean/variance tables and optional class weights.

    Both tables are read with ``np.loadtxt`` and must have ``lr_classes`` rows
    and ``hr_classes - 1`` columns; a zero column is prepended to each. Means
    are post-processed by ``do_nlcd_means_tuning``; variances are floored at
    0.0001.

    Args:
        stats_mu: Source of the means table.
        stats_sigma: Source of the variances table.
        class_weights: Source of the class weights; when falsy, all-ones
            weights of length ``lr_classes`` are used.
        lr_classes: Expected number of rows.
        hr_classes: Expected number of columns after prepending the zero column.

    Returns:
        Tuple ``(nlcd_class_weights, nlcd_means, nlcd_vars)``.
    """
    mu_table = np.loadtxt(stats_mu)
    assert lr_classes == mu_table.shape[0]
    assert hr_classes == (mu_table.shape[1] + 1)
    nlcd_means = np.concatenate((np.zeros((lr_classes, 1)), mu_table), axis=1)
    # Replace exact zeros with a tiny value, then reset the prepended column.
    nlcd_means[nlcd_means == 0] = 0.000001
    nlcd_means[:, 0] = 0
    nlcd_means = do_nlcd_means_tuning(nlcd_means)

    sigma_table = np.loadtxt(stats_sigma)
    assert lr_classes == sigma_table.shape[0]
    assert hr_classes == (sigma_table.shape[1] + 1)
    nlcd_vars = np.concatenate((np.zeros((lr_classes, 1)), sigma_table), axis=1)
    # Floor the variances (including the prepended zero column).
    nlcd_vars[nlcd_vars < 0.0001] = 0.0001

    if class_weights:
        nlcd_class_weights = np.loadtxt(class_weights)
        assert lr_classes == nlcd_class_weights.shape[0]
    else:
        nlcd_class_weights = np.ones((lr_classes,))

    return nlcd_class_weights, nlcd_means, nlcd_vars


def do_nlcd_means_tuning(nlcd_means):
    """Apply fixed offsets to the means table and renormalise its rows.

    The offsets are applied to ``nlcd_means`` in place; the normalised result
    is a new array in which each row is divided by the sum of its non-negative
    entries and the first and last rows are set to zero.

    Args:
        nlcd_means: 2-D NumPy array; rows 3..6 of column 4 are incremented in place.

    Returns:
        The renormalised array.
    """
    # Tuning offsets; the first one is currently zero and changes nothing.
    nlcd_means[2:, 1] -= 0
    nlcd_means[3:7, 4] += 0.25

    row_totals = np.maximum(0, nlcd_means).sum(axis=1, keepdims=True)
    tuned = nlcd_means / row_totals
    tuned[[0, -1], :] = 0
    return tuned


def load_model(model, path, multi_gpus=False):
    """Load ``model_state_dict`` from a checkpoint file into ``model``.

    Args:
        model: Module whose ``load_state_dict`` receives the weights.
        path: Checkpoint file containing a ``"model_state_dict"`` entry.
        multi_gpus: When true, keys are loaded unchanged. Otherwise a leading
            ``"module."`` prefix is stripped from each key.

    Returns:
        ``model`` with the loaded weights.
    """
    # Load onto the CPU so checkpoints saved on a GPU can be read on any
    # host; load_state_dict copies the values into the model's own tensors.
    checkpoint = torch.load(path, map_location="cpu")["model_state_dict"]
    if multi_gpus:
        model.load_state_dict(checkpoint)
    else:
        prefix = 'module.'
        # Strip only a leading prefix: a blanket replace would also mangle
        # keys that merely contain the text, e.g. "submodule.weight".
        model.load_state_dict({
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in checkpoint.items()
        })
    return model


#please use ext_transform to build your transform.
class MyTransform:
    """Empty placeholder transform class; it holds no state and defines no behaviour."""

    def __init__(self):
        """Create the placeholder; there is nothing to initialise."""






def scatter(inputs, target_gpus, chunk_sizes, dim=0):
    r"""
    Slices tensors into chunks of the given sizes and distributes them
    across the given GPUs. Duplicates references to objects that are not
    tensors.

    Args:
        inputs: Tensor, or (possibly nested) tuple/list/dict of tensors and
            other objects.
        target_gpus: Devices to scatter to; one output entry is produced per device.
        chunk_sizes: Per-device chunk sizes forwarded to ``Scatter.apply``.
        dim: Dimension along which tensors are split.

    Returns:
        A list with one entry per target device, each mirroring the structure
        of ``inputs``.

    Raises:
        Exception: Whatever ``Scatter.apply`` raises; diagnostics are printed
            first and the original exception is re-raised with its traceback.
    """

    def scatter_map(obj):
        if isinstance(obj, torch.Tensor):
            try:
                return Scatter.apply(target_gpus, chunk_sizes, dim, obj)
            except Exception:
                print('obj', obj.size())
                print('dim', dim)
                print('chunk_sizes', chunk_sizes)
                # Re-raise instead of exiting the interpreter so callers get
                # the real error and traceback.
                raise
        if isinstance(obj, tuple) and len(obj) > 0:
            return list(zip(*map(scatter_map, obj)))
        if isinstance(obj, list) and len(obj) > 0:
            return list(map(list, zip(*map(scatter_map, obj))))
        if isinstance(obj, dict) and len(obj) > 0:
            return list(map(type(obj), zip(*map(scatter_map, obj.items()))))
        # Non-tensor leaves (and empty containers) are shared by reference.
        return [obj for targets in target_gpus]

    # After scatter_map is called, a scatter_map cell will exist. This cell
    # has a reference to the actual function scatter_map, which has references
    # to a closure that has a reference to the scatter_map cell (because the
    # fn is recursive). To avoid this reference cycle, we set the function to
    # None, clearing the cell
    try:
        return scatter_map(inputs)
    finally:
        scatter_map = None


def scatter_kwargs(inputs, kwargs, target_gpus, chunk_sizes, dim=0):
    """Scatter positional and keyword arguments, padding both to the same length.

    Falsy ``inputs`` or ``kwargs`` are not scattered. The shorter result is
    padded with empty tuples (positional) or empty dicts (keyword).

    Returns:
        Tuple ``(inputs, kwargs)`` of two equally long tuples.
    """
    scattered_args = scatter(inputs, target_gpus, chunk_sizes, dim) if inputs else []
    scattered_kwargs = scatter(kwargs, target_gpus, chunk_sizes, dim) if kwargs else []

    shortfall = len(scattered_kwargs) - len(scattered_args)
    if shortfall > 0:
        scattered_args.extend(() for _ in range(shortfall))
    elif shortfall < 0:
        scattered_kwargs.extend({} for _ in range(-shortfall))

    return tuple(scattered_args), tuple(scattered_kwargs)

class CallbackContext:
    """Empty attribute container; instances accept arbitrary attributes."""

def execute_replication_callbacks(modules):
    """
    Invoke the replication callback `__data_parallel_replicate__` on every sub-module of every replica.

    One `CallbackContext` is created per sub-module position of the first replica and is passed
    to the sub-module at the same position in every replica, so copies of a sub-module can
    exchange information through it. The callback is called as
    `__data_parallel_replicate__(ctx, copy_id)`.

    Replicas are processed in list order, so the callbacks of the first copy run before those
    of any later copy.
    """
    first_copy = modules[0]
    submodule_count = len(list(first_copy.modules()))
    contexts = [CallbackContext() for _ in range(submodule_count)]

    for copy_id, replica in enumerate(modules):
        for position, submodule in enumerate(replica.modules()):
            if not hasattr(submodule, '__data_parallel_replicate__'):
                continue
            submodule.__data_parallel_replicate__(contexts[position], copy_id)

class Sync_DataParallel(DataParallel):
    """``DataParallel`` that runs the replication callbacks after each replicate step."""

    def replicate(self, module, device_ids):
        """Replicate ``module`` and then call ``execute_replication_callbacks`` on the copies."""
        replicas = super().replicate(module, device_ids)
        # Required for synchronized batch norm: the copies are linked
        # together through the replication callbacks.
        execute_replication_callbacks(replicas)
        return replicas

class BalancedDataParallel(DataParallel):
    """``DataParallel`` variant that gives the first device a custom batch share.

    The first device in ``device_ids`` receives ``gpu0_bsz`` samples and the
    rest of the batch is split as evenly as possible over the other devices.
    With ``gpu0_bsz == 0`` the first device receives no data and runs no
    replica. When the custom split is not applicable (a single device, or
    ``gpu0_bsz`` not smaller than the share of the other devices) the regular
    even split of ``DataParallel`` is used.
    """

    def __init__(self, gpu0_bsz, *args, **kwargs):
        """Store ``gpu0_bsz`` and forward all other arguments to ``DataParallel``."""
        super().__init__(*args, **kwargs)
        self.gpu0_bsz = gpu0_bsz

    def forward(self, *inputs, **kwargs):
        """Scatter the batch, run the replicas in parallel and gather the outputs."""
        if not self.device_ids:
            return self.module(*inputs, **kwargs)
        if len(self.device_ids) == 1:
            # A single device takes the whole batch; the balanced split needs
            # at least two devices and is skipped here.
            inputs, kwargs = super().scatter(inputs, kwargs, self.device_ids)
            return self.module(*inputs[0], **kwargs[0])
        if self.gpu0_bsz == 0:
            device_ids = self.device_ids[1:]
        else:
            device_ids = self.device_ids
        inputs, kwargs = self.scatter(inputs, kwargs, device_ids)
        replicas = self.replicate(self.module, self.device_ids)
        if self.gpu0_bsz == 0:
            replicas = replicas[1:]
        # The batch may be smaller than the number of devices: keep only as
        # many replicas/devices as there are input chunks.
        if len(replicas) > len(inputs):
            replicas = replicas[:len(inputs)]
            device_ids = device_ids[:len(inputs)]
        outputs = self.parallel_apply(replicas, device_ids, inputs, kwargs)
        return self.gather(outputs, self.output_device)

    def parallel_apply(self, replicas, device_ids, inputs, kwargs):
        """Run each replica on its input chunk on the matching device."""
        return parallel_apply(replicas, inputs, kwargs, device_ids)

    def _chunk_sizes(self, bsz):
        """Return per-device chunk sizes for a batch of ``bsz``, or None for an even split."""
        num_dev = len(self.device_ids)
        if num_dev < 2:
            return None
        gpu0_bsz = self.gpu0_bsz
        bsz_unit = (bsz - gpu0_bsz) // (num_dev - 1)
        if gpu0_bsz >= bsz_unit:
            return None
        chunk_sizes = [gpu0_bsz] + [bsz_unit] * (num_dev - 1)
        # Hand the remainder out one sample at a time to the other devices.
        delta = bsz - sum(chunk_sizes)
        for i in range(delta):
            chunk_sizes[i + 1] += 1
        if gpu0_bsz == 0:
            chunk_sizes = chunk_sizes[1:]
        return chunk_sizes

    def scatter(self, inputs, kwargs, device_ids):
        """Split ``inputs``/``kwargs`` over ``device_ids`` using the balanced chunk sizes."""
        if not inputs:
            # No positional tensor to read the batch size from.
            return super().scatter(inputs, kwargs, device_ids)
        chunk_sizes = self._chunk_sizes(inputs[0].size(self.dim))
        if chunk_sizes is None:
            return super().scatter(inputs, kwargs, device_ids)
        return scatter_kwargs(inputs, kwargs, device_ids, chunk_sizes, dim=self.dim)

    # def replicate(self, module, device_ids):
    #     modules = super(BalancedDataParallel, self).replicate(module, device_ids)
    #     # !!!Important!!!  If you want to employ SYNC BN, this code is very very important!
    #     execute_replication_callbacks(modules)
    #     return modules

class BalancedDataParallel_gpu0gpu1(DataParallel):
    """``DataParallel`` variant that gives the first two devices a custom batch share.

    The first two devices in ``device_ids`` each receive ``gpu01_bsz`` samples
    and the rest of the batch is split as evenly as possible over the other
    devices. With ``gpu01_bsz == 0`` (and more than two devices) the first two
    devices receive no data and run no replica. When the custom split is not
    applicable (fewer than three devices, or ``gpu01_bsz`` not smaller than
    the share of the other devices) the regular even split of ``DataParallel``
    is used. Replication callbacks are executed after every replicate step.
    """

    def __init__(self, gpu01_bsz, *args, **kwargs):
        """Store ``gpu01_bsz`` and forward all other arguments to ``DataParallel``."""
        super().__init__(*args, **kwargs)
        self.gpu01_bsz = gpu01_bsz

    def forward(self, *inputs, **kwargs):
        """Scatter the batch, run the replicas in parallel and gather the outputs."""
        if not self.device_ids:
            return self.module(*inputs, **kwargs)
        if len(self.device_ids) == 1:
            # A single device takes the whole batch; the balanced split is skipped.
            inputs, kwargs = super().scatter(inputs, kwargs, self.device_ids)
            return self.module(*inputs[0], **kwargs[0])

        # The first two devices can only be left idle when others remain.
        skip_reserved = self.gpu01_bsz == 0 and len(self.device_ids) > 2
        device_ids = self.device_ids[2:] if skip_reserved else self.device_ids
        inputs, kwargs = self.scatter(inputs, kwargs, device_ids)

        replicas = self.replicate(self.module, self.device_ids)
        if skip_reserved:
            replicas = replicas[2:]
        # The batch may be smaller than the number of devices: keep only as
        # many replicas/devices as there are input chunks.
        if len(replicas) > len(inputs):
            replicas = replicas[:len(inputs)]
            device_ids = device_ids[:len(inputs)]

        outputs = self.parallel_apply(replicas, device_ids, inputs, kwargs)
        return self.gather(outputs, self.output_device)

    def parallel_apply(self, replicas, device_ids, inputs, kwargs):
        """Run each replica on its input chunk on the matching device."""
        return parallel_apply(replicas, inputs, kwargs, device_ids)

    def _chunk_sizes(self, bsz):
        """Return per-device chunk sizes for a batch of ``bsz``, or None for an even split."""
        num_dev = len(self.device_ids)
        if num_dev < 3:
            return None
        gpu01_bsz = self.gpu01_bsz
        num_rest = num_dev - 2
        # Two devices are reserved, so the remaining samples are shared by
        # the other ``num_dev - 2`` devices.
        bsz_unit = (bsz - 2 * gpu01_bsz) // num_rest
        if gpu01_bsz >= bsz_unit:
            return None
        chunk_sizes = [gpu01_bsz] * 2 + [bsz_unit] * num_rest
        # The remainder is always smaller than num_rest, so this stays in range.
        delta = bsz - sum(chunk_sizes)
        for i in range(delta):
            chunk_sizes[i + 2] += 1
        if gpu01_bsz == 0:
            chunk_sizes = chunk_sizes[2:]
        return chunk_sizes

    def scatter(self, inputs, kwargs, device_ids):
        """Split ``inputs``/``kwargs`` over ``device_ids`` using the balanced chunk sizes."""
        if not inputs:
            # No positional tensor to read the batch size from.
            return super().scatter(inputs, kwargs, device_ids)
        chunk_sizes = self._chunk_sizes(inputs[0].size(self.dim))
        if chunk_sizes is None:
            return super().scatter(inputs, kwargs, device_ids)
        return scatter_kwargs(inputs, kwargs, device_ids, chunk_sizes, dim=self.dim)

    def replicate(self, module, device_ids):
        """Replicate ``module`` and then call ``execute_replication_callbacks`` on the copies."""
        modules = super().replicate(module, device_ids)
        execute_replication_callbacks(modules)
        return modules