import os
import math
import inspect
import torch
import importlib
import amp_C
from apex.multi_tensor_apply import multi_tensor_applier

import torch.distributed.distributed_c10d as c10d

# Use the public helper when this PyTorch version exports it, otherwise fall
# back to the private rank table of older PyTorch versions.
# NOTE: this must be a ``from ... import``. The dotted ``import a.b.c`` form
# requires ``c`` to be a module, so it always raised ImportError for this
# function and the private fallback below was used unconditionally.
try:
    from torch.distributed.distributed_c10d import get_process_group_ranks
except ImportError:
    def get_process_group_ranks(group):
        """Return the list of ranks registered for ``group``.

        Reads the keys of the private ``c10d._pg_group_ranks[group]`` mapping.
        """
        return list(c10d._pg_group_ranks[group].keys())

# Resolve the NCCL pre-multiplied-sum factory: try the private name first and
# then the public one; the result stays None when neither attribute exists.
# Ref: https://github.com/pytorch/pytorch/pull/81272
_make_nccl_premul_sum = getattr(torch.distributed, "_make_nccl_premul_sum", None)
if _make_nccl_premul_sum is None:
    _make_nccl_premul_sum = getattr(torch.distributed, "make_nccl_premul_sum", None)

class DistributedFusedLAMB(torch.optim.Optimizer):

    """Implements LAMB algorithm.

    Currently GPU-only.  Requires Apex to be installed via
    ``pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./``.

    This version of fused LAMB implements 2 fusions.

      * Fusion of the LAMB update's elementwise operations
      * A multi-tensor apply launch that batches the elementwise updates applied to all the model's parameters into one or a few kernel launches.

    :class:`apex.optimizers.FusedLAMB`'s usage is identical to any ordinary Pytorch optimizer::

        opt = apex.optimizers.FusedLAMB(model.parameters(), lr = ....)
        ...
        opt.step()

    :class:`apex.optimizers.FusedLAMB` may be used with or without Amp.  If you wish to use :class:`FusedLAMB` with Amp,
    you may choose any ``opt_level``::

        opt = apex.optimizers.FusedLAMB(model.parameters(), lr = ....)
        model, opt = amp.initialize(model, opt, opt_level="O0" or "O1" or "O2")
        ...
        opt.step()

    In general, ``opt_level="O1"`` is recommended.

    LAMB was proposed in `Large Batch Optimization for Deep Learning - Training BERT in 76 minutes`_.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups.
        lr (float, optional): learning rate. (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its norm. (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability. (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
            algorithm from the paper `On the Convergence of Adam and Beyond`_
            NOT SUPPORTED now! (default: False)
        adam_w_mode (boolean, optional): Apply L2 regularization or weight decay
            True for decoupled weight decay(also known as AdamW) (default: True)
        grad_averaging (bool, optional): whether to apply (1-beta1) to grad when
            calculating running averages of gradient. (default: True)
        set_grad_none (bool, optional): whether set grad to None when zero_grad()
            method is called. (default: True)
        max_grad_norm (float, optional): value used to clip global grad norm
            (default: 0.0)
        use_nvlamb (boolean, optional): Apply adaptive learning rate to 0.0
            weight decay parameter (default: False)
        step_supports_amp_scaling (boolean, optional): whether to use customized
            gradient unscaling logic (default: True)

    .. _Large Batch Optimization for Deep Learning - Training BERT in 76 minutes:
        https://arxiv.org/abs/1904.00962
    .. _On the Convergence of Adam and Beyond:
        https://openreview.net/forum?id=ryQu7f-RZ
    """

    class AtomicCounter(object):
        """Lock-protected counter that also remembers every index it was given.

        ``value`` is the number of ``add`` calls so far and ``order`` lists the
        indices in the order they were added.
        """

        def __init__(self):
            import threading
            self._lock = threading.Lock()
            self.order = []
            self.value = 0

        def add(self, idx):
            """Increment ``value`` and append ``idx`` to ``order`` under the lock."""
            self._lock.acquire()
            try:
                self.value = self.value + 1
                self.order.append(idx)
            finally:
                self._lock.release()

    def __init__(self, params,
                 lr=1e-3, bias_correction = True, grad_averaging=True,
                 betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0., max_grad_norm=0.,
                 adam_w_mode=True, use_nvlamb=False,
                 step_supports_amp_scaling=True, overlap_reductions=True,
                 dwu_group_size=0, dwu_num_blocks=4, dwu_num_chunks=4,
                 dwu_num_rs_pg=1, dwu_num_ar_pg=4, dwu_num_ag_pg=0, fused_norm=False,
                 e5m2_allgather=False, verbose=False, clip_after_ar=True,
                 full_ar=False, set_param_views_to_flat_buffer=False, skip_allgather=False,
                 fuse_scale=False, param_order=None, nccl_allgather_channels=0):
        """Store the configuration, create process groups/streams and size the flat buffers.

        Requires an initialized default process group and CUDA: the constructor
        queries the default group, creates new process groups (collective calls)
        and allocates CUDA tensors and streams.

        Arguments:
            params: parameters or parameter-group dicts, forwarded to
                ``torch.optim.Optimizer``.
            lr, bias_correction, betas, eps, weight_decay, grad_averaging,
            max_grad_norm: stored as the per-group defaults.
            adam_w_mode: stored as the integer 1 or 0.
            use_nvlamb, step_supports_amp_scaling, overlap_reductions,
            clip_after_ar, skip_allgather, fuse_scale, verbose: stored as-is on
                the instance; ``verbose`` additionally enables progress prints
                while process groups are created.
            dwu_group_size: ranks per group; a value ``<= 0`` selects
                ``torch.cuda.device_count()``.
            dwu_num_blocks, dwu_num_chunks: block/chunk counts used to size and
                partition the flat buffers.
            dwu_num_rs_pg, dwu_num_ar_pg, dwu_num_ag_pg: number of
                reduce-scatter / all-reduce / all-gather process groups to
                create. ``dwu_num_ag_pg == 0`` reuses the all-reduce groups
                (``full_ar``) or the reduce-scatter groups (otherwise).
            fused_norm: only honored when ``clip_after_ar`` is False.
            e5m2_allgather: when True the new-parameter buffer is ``uint8``
                instead of ``float16``.
            full_ar: selects the "full all-reduce" process-group layout instead
                of reduce-scatter + all-reduce.
            set_param_views_to_flat_buffer: when True the parameter order is
                recorded here, at construction time.
            param_order: optional container indexed by parameter
                (``param_order[p]``) giving a user-defined ordering.
            nccl_allgather_channels: when ``> 0`` (and ``full_ar`` is set),
                exported as ``NCCL_MAX_NCHANNELS``.
        """
        defaults = dict(lr=lr, bias_correction=bias_correction,
                        betas=betas, eps=eps, weight_decay=weight_decay,
                        grad_averaging=grad_averaging,
                        max_grad_norm=max_grad_norm)

        super(DistributedFusedLAMB, self).__init__(params, defaults)

        # The extension modules are imported at construction time and published
        # as module-level globals.
        global fused_adam_cuda, distributed_lamb_cuda
        fused_adam_cuda = importlib.import_module("fused_adam_cuda")
        distributed_lamb_cuda = importlib.import_module("distributed_lamb_cuda")

        self._overflow_buf = torch.cuda.IntTensor([0])
        self._has_overflow = False
        self.multi_tensor_lamb_compute_update_term = distributed_lamb_cuda.multi_tensor_lamb_compute_update_term
        self.multi_tensor_lamb_update_weights = distributed_lamb_cuda.multi_tensor_lamb_update_weights
        import amp_C
        self.multi_tensor_l2norm = amp_C.multi_tensor_l2norm

        self._grad_averaging = grad_averaging
        self._adam_w_mode = 1 if adam_w_mode else 0
        self._use_nvlamb = use_nvlamb
        self._step_supports_amp_scaling = step_supports_amp_scaling
        self._is_accumulation_step = False
        self._last_step = False
        self._overlap_reductions = overlap_reductions
        self._global_scale = None
        self._num_blocks = dwu_num_blocks
        self._num_chunks = dwu_num_chunks
        self._e5m2_allgather = e5m2_allgather
        self._verbose = verbose
        self._clip_after_ar = clip_after_ar
        self._full_ar = full_ar
        self._fuse_scale = fuse_scale
        self._L2_grad_norm = None
        self._set_flat_param_view = set_param_views_to_flat_buffer
        self._skip_ag = skip_allgather
        # fused_norm is forced off whenever clip_after_ar is enabled
        self._fused_norm = fused_norm if not clip_after_ar else False
        self._current_process_group = c10d._get_default_group()
        self._available_ranks = get_process_group_ranks(self._current_process_group)
        # A non-positive dwu_group_size falls back to the CUDA device count
        self._group_size = torch.cuda.device_count() if dwu_group_size <= 0 else dwu_group_size
        self._world_size = torch.distributed.get_world_size()
        self._num_groups = self._world_size // self._group_size
        self._rank_in_group = torch.distributed.get_rank() % self._group_size

        self._lr = torch.tensor(0.0, dtype=torch.float32, device='cuda')

        self._resume_from_checkpoint = False
        self._step = torch.cuda.IntTensor([0])

        # Master weight, moment, gradient buffers
        self._fp32_p, self._fp32_m, self._fp32_v, self._fp16_p, self._fp16_g = None, None, None, None, None

        # Check if collectives have no_copy option
        self._reduce_scatter_no_copy = (
            'no_copy' in inspect.getfullargspec(torch.distributed.reduce_scatter).args
        )
        self._all_gather_no_copy = (
            'no_copy' in inspect.getfullargspec(torch.distributed.all_gather).args
        )

        # Alias the newer collective names to the private "_base" variants when
        # torch.distributed does not provide them.
        if "reduce_scatter_tensor" not in dir(torch.distributed):
            torch.distributed.reduce_scatter_tensor = torch.distributed._reduce_scatter_base
        if "all_gather_into_tensor" not in dir(torch.distributed):
            torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base

        self._num_rs_pg = dwu_num_rs_pg
        self._num_ar_pg = dwu_num_ar_pg
        self._num_ag_pg = dwu_num_ag_pg

        # NOTE: every rank executes every new_group() call below, including for
        # groups it is not a member of, and only keeps the handles of groups
        # that contain its own rank. Do not reorder these calls.
        if self._full_ar: # full all reduce, only need AR and AG groups
            # l2_grad_norm may be reduced within a node to limit from memory reads
            for group_i in range(self._num_groups):
                ranks = [group_i*self._group_size+j for j in range(self._group_size)]
                l2_grad_norm_pg = torch.distributed.new_group(ranks=ranks)
                if torch.distributed.get_rank() in ranks:
                    self._l2_grad_norm_pg = l2_grad_norm_pg

            self._ar_pg = []
            # consider all the ranks
            ranks = list(range(0, self._world_size))
            for i in range(self._num_ar_pg):
                if self._verbose:
                    print(f"creating new AR group {i}: {ranks}")
                grp = torch.distributed.new_group(ranks=ranks)
                if grp != torch.distributed.GroupMember.NON_GROUP_MEMBER:
                    if self._verbose:
                        print(f"group {i}: init barrier (device: {torch.cuda.current_device()})")
                    torch.distributed.barrier(group=grp, device_ids=[torch.cuda.current_device()])
                if self._verbose:
                    print(f"created new AR group {i}: {ranks}")

                if torch.distributed.get_rank() in ranks:
                    self._ar_pg.append(grp)
            # one CUDA stream per all-reduce group
            self._ar_st = [torch.cuda.Stream() for _ in range(self._num_ar_pg)]
            if nccl_allgather_channels > 0:
                # NOTE(review): os.putenv changes the process environment without
                # updating os.environ; presumably intended to be read by NCCL
                # when the groups below are used - confirm.
                os.putenv('NCCL_MAX_NCHANNELS', str(nccl_allgather_channels))
            if self._num_ag_pg == 0:
                # no dedicated all-gather groups requested: share the AR groups/streams
                self._ag_pg = self._ar_pg
                self._ag_st = self._ar_st
                self._num_ag_pg = self._num_ar_pg
            else:
                self._ag_pg = []
                ranks = []
                # NOTE(review): the stride is the CUDA device count rather than
                # self._group_size; the two differ when dwu_group_size > 0 is
                # passed - confirm this is intended.
                stride = torch.cuda.device_count()
                for i in range(self._num_groups):
                    rs = list(range(i*stride, (i+1)*stride))
                    ranks.append(rs)
                for rs in ranks:
                    for i in range(self._num_ag_pg):
                        grp = torch.distributed.new_group(ranks=rs)
                        if torch.distributed.get_rank() in rs:
                            if self._verbose:
                                print(f"creating AG group {i}: {rs}")
                            self._ag_pg.append(grp)

                self._ag_st = [torch.cuda.Stream() for _ in range(self._num_ag_pg)]
        else: # reduce-scatter + all-reduce, need RS, AR, AG groups
            # NOTE: in this branch _ar_pg/_ar_st are only created when there is
            # more than one group.
            if self._num_groups > 1:
                self._ar_pg = []
                for dev_i in range(self._group_size):
                    # ranks holding the same in-group position across all groups
                    ranks = [dev_i+j*self._group_size for j in range(self._num_groups)]
                    for i in range(self._num_ar_pg):
                        if self._verbose:
                            print(f"creating new AR group {i}: {ranks}")
                        grp = torch.distributed.new_group(ranks=ranks)
                        if grp != torch.distributed.GroupMember.NON_GROUP_MEMBER:
                            if self._verbose:
                                print(f"group {i}: init barrier (device: {torch.cuda.current_device()})")
                            torch.distributed.barrier(group=grp, device_ids=[torch.cuda.current_device()])
                        if self._verbose:
                            print(f"created new AR group {i}: {ranks}")

                        if torch.distributed.get_rank() in ranks:
                            self._ar_pg.append(grp)
                self._ar_st = [torch.cuda.Stream() for _ in range(self._num_ar_pg)]
            # consecutive ranks form one reduce-scatter group
            rs_ranks = []
            for group_i in range(self._num_groups):
                rs_ranks.append([group_i*self._group_size+j for j in range(self._group_size)])
            self._rs_pg = []
            for group_i in range(self._num_groups):
                ranks = rs_ranks[group_i]
                for i in range(self._num_rs_pg):
                    grp = torch.distributed.new_group(ranks=ranks)
                    if torch.distributed.get_rank() in ranks:
                        self._rs_pg.append(grp)
                        if self._verbose:
                            print(f"creating RS group : {ranks}")
                # separate group over the same ranks for the L2 grad norm
                l2_grad_norm_pg = torch.distributed.new_group(ranks=ranks)
                if torch.distributed.get_rank() in ranks:
                    self._l2_grad_norm_pg = l2_grad_norm_pg
            self._rs_st = [torch.cuda.Stream() for _ in range(self._num_rs_pg)]
            if self._num_ag_pg == 0:
                # no dedicated all-gather groups requested: share the RS groups/streams
                self._ag_pg = self._rs_pg
                self._ag_st = self._rs_st
                self._num_ag_pg = self._num_rs_pg
            else:
                self._ag_pg = []
                for group_i in range(self._num_groups):
                    ranks = rs_ranks[group_i]
                    for i in range(self._num_ag_pg):
                        grp = torch.distributed.new_group(ranks=ranks)
                        if torch.distributed.get_rank() in ranks:
                            self._ag_pg.append(grp)
                            if self._verbose:
                                print(f"creating AG group : {ranks}")
                self._ag_st = [torch.cuda.Stream() for _ in range(self._num_ag_pg)]
        # barrier on every all-gather group this rank kept a handle for
        for ag_pg in self._ag_pg:
            torch.distributed.barrier(group=ag_pg)

        self._l2_grad_norm_st = torch.cuda.Stream()
        self._completion_st = torch.cuda.Stream()
        # mark the step counter tensor as in use on the completion stream
        self._step.record_stream(self._completion_st)

        # one pending-work slot per block
        self._reductions_works = [None]*self._num_blocks
        self._allgather_works = [None]*self._num_blocks

        self._one = torch.cuda.IntTensor([1])

        self._first_step = True
        self._lazy_init_stage1_done, self._lazy_init_stage2_done = False, False
        self._param_order = self.AtomicCounter()

        # Walk all trainable parameters once to collect them, their per-group
        # hyper-parameters, and the total (padded) size of the flat buffer.
        p_offset = 0
        p_i = 0
        self._model_params = []
        self._grad_accs = []
        self._group_properties = []
        for group in self.param_groups:
            prev = None
            beta1, beta2 = group['betas']
            beta3 = 1.0 - beta1 if self._grad_averaging else 1.0
            bias_correction = 1 if group['bias_correction'] else 0
            eps = group['eps']
            weight_decay = group['weight_decay']
            for p in group['params']:
                if not p.requires_grad:
                    continue
                self._model_params.append(p)
                self._group_properties.append((
                    weight_decay,
                    bias_correction,
                    beta1,
                    beta2,
                    beta3,
                    eps
                    ))
                p_grads_size = p.numel()
                # With flat param views the order is fixed here; otherwise
                # _param_order stays empty at this point.
                if self._set_flat_param_view:
                    if param_order:
                        # this is executed when param_order is specified by the user
                        self._param_order.add(param_order[p])
                    else:
                        self._param_order.add(p_i)
                p_offset += p_grads_size
                # Only enforce 128b alignment (64 * fp16) for non-consecutive parameters
                # RNN is one example of consecutive parameters:
                # (weight_ih, weight_hh, bias_ih, bias_hh)
                if prev is not None and (prev.data_ptr() + prev.numel() * prev.element_size() != p.data_ptr()):
                    p_offset = ((p_offset + 63) // 64) * 64
                prev = p
                p_i += 1
        if param_order:
            # replace the recorded user-supplied keys with the permutation that sorts them
            self._param_order.order = torch.argsort(torch.tensor(self._param_order.order)).tolist()
        self._grads_generated = [False]*len(self._model_params)
        self._grads_fp16, self._grads_fp32 = [], []
        # NOTE: _current_block is only defined when reductions are overlapped
        if self._overlap_reductions:
            self._current_block = self._num_blocks

        self._net_total_param_size = p_offset
        self._total_param_size = p_offset
        # Round the flat size up to a multiple of 256 * blocks * chunks * group_size
        # so it divides evenly into blocks, chunks and shards.
        dwu_min_page_size = 256 * self._num_blocks * self._num_chunks * self._group_size
        self._total_param_size = ((self._total_param_size + dwu_min_page_size - 1) // dwu_min_page_size) * dwu_min_page_size
        self._new_params = torch.zeros([self._total_param_size], dtype=torch.uint8 if self._e5m2_allgather else torch.float16, device='cuda')



    def _lazy_init_stage1(self):
        """Run the first lazy-initialization stage; a no-op once it has completed.

        The stage does the following:

          * broadcasts every parameter from rank 0 and registers a hook on the
            gradient-accumulator node of each trainable parameter,
          * derives the block / chunk / shard sizes from the padded flat size,
          * allocates the flat gradient buffer and this rank's shard-sized
            master-weight, moment, update and half-precision buffers,
          * builds the nested lists of views (blocks, chunks, shards) into
            those buffers.

        The master weights and moments are left untouched when ``_fp32_p`` is
        already populated (the "loaded from checkpoint" case). ``_fp32_u`` is
        allocated independently of them: ``__init__`` never creates it, so it
        would otherwise be missing whenever ``_fp32_p`` was pre-populated.
        """
        if self._lazy_init_stage1_done:
            return

        def _register_reduction_hook(param, param_i):
            # expand_as() yields a view whose grad_fn's first input is taken as
            # the parameter's gradient accumulator node.
            param_tmp = param.expand_as(param)
            grad_acc = param_tmp.grad_fn.next_functions[0][0]

            def allreduce_hook(*unused):
                if self._first_step:
                    # On the first step only the order in which hooks fire is
                    # recorded, and only when the order was not already fixed
                    # at construction time (flat param view).
                    if not self._set_flat_param_view:
                        self._param_order.add(param_i)
                else:
                    idx = self._param_order.order.index(param_i)
                    self._do_overlapped_reduction(idx, param)

            grad_acc.register_hook(allreduce_hook)
            # The node is stored on the instance (presumably to keep it and its
            # hook alive); stage 2 re-orders this list.
            self._grad_accs.append(grad_acc)

        p_i = 0
        for group in self.param_groups:
            for p in group['params']:
                # every parameter is broadcast, trainable or not
                torch.distributed.broadcast(p, 0)
                if not p.requires_grad:
                    continue
                _register_reduction_hook(p, p_i)
                p_i += 1

        self._block_size = self._total_param_size // self._num_blocks
        self._chunk_size = self._block_size // self._num_chunks
        self._shard_size = self._chunk_size // self._group_size

        self._flat_grads = torch.zeros([self._total_param_size], dtype=torch.float16, device='cuda')
        # number of elements owned by one rank across all blocks and chunks
        self._mega_shard_size = self._num_blocks * self._num_chunks * self._shard_size
        # initialize master weights, moments buffers if not loaded from checkpoint
        if self._fp32_p is None:
            self._fp32_p = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')
            self._fp32_m = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')
            self._fp32_v = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')
        # The update-term buffer must exist even when the master weights were
        # provided beforehand, because it is split unconditionally below.
        if getattr(self, '_fp32_u', None) is None:
            self._fp32_u = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')
        # FIXME: Rethink fp16 label since it's either uint8 or fp16
        self._fp16_p = torch.zeros([self._mega_shard_size], dtype=torch.uint8 if self._e5m2_allgather else torch.float16, device='cuda')
        self._fp16_g = torch.zeros([self._mega_shard_size], dtype=torch.float16, device='cuda')

        def _flat_split(p):
            # Flat layout: block -> chunk -> shard (one shard per rank in the group).
            def _blockify(t):
                return [t[block_id*self._block_size:(block_id+1)*self._block_size] for block_id in range(self._num_blocks)]
            def _chunkify(t):
                return [t[chunk_id*self._chunk_size:(chunk_id+1)*self._chunk_size] for chunk_id in range(self._num_chunks)]
            def _shardify(t):
                return [t[shard_id*self._shard_size:(shard_id+1)*self._shard_size] for shard_id in range(self._group_size)]
            list_of_blocks = _blockify(p)
            list_of_list_of_chunks = [_chunkify(block) for block in list_of_blocks]
            list_of_list_of_list_of_shards = [[_shardify(chunk) for chunk in chunks] for chunks in list_of_list_of_chunks]
            return list_of_blocks, list_of_list_of_chunks, list_of_list_of_list_of_shards

        def _full_packed_split(p):
            # Packed layout over the whole buffer: mega shard (per rank) -> block -> chunk.
            def _shardify(t):
                return [t[mega_shard*self._mega_shard_size:(mega_shard+1)*self._mega_shard_size] for mega_shard in range(self._group_size)]
            def _blockify(t):
                return [t[block_id*self._num_chunks*self._shard_size:(block_id+1)*self._num_chunks*self._shard_size] for block_id in range(self._num_blocks)]
            def _chunkify(t):
                return [t[chunk_id*self._shard_size:(chunk_id+1)*self._shard_size] for chunk_id in range(self._num_chunks)]
            list_of_mega_shards = _shardify(p)
            list_of_list_of_mega_blocks = [_blockify(mega_shard) for mega_shard in list_of_mega_shards]
            list_of_list_of_list_of_mega_chunks = [[_chunkify(mega_block) for mega_block in mega_blocks] for mega_blocks in list_of_list_of_mega_blocks]
            return list_of_mega_shards, list_of_list_of_mega_blocks, list_of_list_of_list_of_mega_chunks

        def _packed_split(p):
            # Packed layout of a single rank's buffer: block -> chunk.
            def _packed_blockify(t):
                packed_block_size = self._num_chunks*self._shard_size
                return [t[block_id*packed_block_size:(block_id+1)*packed_block_size] for block_id in range(self._num_blocks)]
            def _packed_chunkify(t):
                # in the packed format, each chunk contains one shard, so packed_chunk_size == self._shard_size
                return [t[chunk_id*self._shard_size:(chunk_id+1)*self._shard_size] for chunk_id in range(self._num_chunks)]
            list_of_blocks = _packed_blockify(p)
            list_of_list_of_chunks = [_packed_chunkify(block) for block in list_of_blocks]
            return list_of_blocks, list_of_list_of_chunks

        def _split_assign(shards):
            # For every (block, chunk) pick the shard view that belongs to this rank.
            return [
                [shards[block_id][chunk_id][self._rank_in_group] for chunk_id in range(self._num_chunks)]
                for block_id in range(self._num_blocks)
            ]

        self._new_params_mega_shards, self._new_params_mega_blocks, self._new_params_mega_chunks = _full_packed_split(self._new_params)
        # this splitting scheme is needed when allgather needs to be split into multiple chunks in a contiguous way
        self._new_params2_blocks, self._new_params2_chunks, self._new_params2_shards = _flat_split(self._new_params)

        self._fp32_p_blocks, self._fp32_p_chunks = _packed_split(self._fp32_p)
        self._fp32_m_blocks, self._fp32_m_chunks = _packed_split(self._fp32_m)
        self._fp32_v_blocks, self._fp32_v_chunks = _packed_split(self._fp32_v)
        self._fp32_u_blocks, self._fp32_u_chunks = _packed_split(self._fp32_u)
        self._fp16_p_blocks, self._fp16_p_chunks = _packed_split(self._fp16_p)

        # The flat gradient buffer is split the same way in both modes.
        self._flat_grads_blocks, self._flat_grads_chunks, self._flat_grads_shards = _flat_split(self._flat_grads)
        if self._full_ar:
            # for weight update: use this rank's shards of the flat gradient buffer directly
            self._fp16_g_chunks = _split_assign(self._flat_grads_shards)
        else:
            self._fp16_g_blocks, self._fp16_g_chunks = _packed_split(self._fp16_g)

        self._lazy_init_stage1_done = True

    def _lazy_init_stage2(self):
        """Run the second lazy-initialization stage; a no-op once it has completed.

        Re-orders the parameter bookkeeping lists by ``self._param_order.order``,
        assigns every parameter its offset in the flat gradient buffer, maps the
        overlap between each parameter and each shard onto views of the packed
        buffers, builds the tensor lists and per-fragment hyper-parameter
        tensors used by the fused kernels, and finally calls
        ``complete_reductions()`` and clears ``_first_step``.
        """
        if self._lazy_init_stage2_done: return
        if not self._set_flat_param_view:
            # reversing is needed for overlapping allreduce and backprop, but currently not supported for flat param view
            self._param_order.order.reverse()

        # re-order model_params, grad_accs, group_properties lists
        self._model_params = [self._model_params[i] for i in self._param_order.order]
        self._grad_accs = [self._grad_accs[i] for i in self._param_order.order]
        self._group_properties = [self._group_properties[i] for i in self._param_order.order]

        def _get_flat_view(param):
            # Return a 1-D view of param. For channels_last / channels_last_3d
            # contiguous tensors the view is first re-expressed with the channel
            # dimension last so that view(-1) is applied to a dense layout.
            if param.is_contiguous(memory_format=torch.channels_last):
                K, C, H, W = param.shape
                pv = param.as_strided(size=(K,H,W,C), stride=(H*W*C, W*C, C, 1))
            elif param.is_contiguous(memory_format=torch.channels_last_3d):
                K, C, D, H, W = param.shape
                pv = param.as_strided(size=(K,D,H,W,C), stride=(D*H*W*C, H*W*C, W*C, C, 1))
            else:
                pv = param
            return pv.view(-1)

        # re-collect grads info (size, offset) after ordering
        prev = None
        p_offset = 0
        self._grads_info = []
        self._individual_flat_grads = []
        for i, p in enumerate(self._model_params):
            p_grads_size = p.numel()
            self._grads_info.append({"param_grads_size":p_grads_size, "param_offset":p_offset})
            # view of the flat gradient buffer shaped like the parameter
            self._individual_flat_grads.append(self._flat_grads[p_offset:p_offset+p_grads_size].view_as(p))
            # for the first iteration
            self._do_overlapped_reduction(i, p)
            p_offset += p_grads_size
            # Only enforce 128b alignment (64 * fp16) for non-consecutive parameters
            # RNN is one example of consecutive parameters:
            # (weight_ih, weight_hh, bias_ih, bias_hh)
            if prev is not None and (prev.data_ptr() + prev.numel() * prev.element_size() != p.data_ptr()):
                p_offset = ((p_offset + 63) // 64) * 64
            prev = p

        # For every block, find the highest parameter index whose offset does
        # not exceed the block's start offset (index 0 if there is none).
        self._low_param_i = [0]*self._num_blocks
        for block_id in range(self._num_blocks-1,-1,-1):
            p_i = len(self._grads_info)-1
            while p_i > 0 and self._grads_info[p_i]["param_offset"] > block_id*self._block_size:
                p_i -= 1
            self._low_param_i[block_id] = p_i
        #print("self._low_param_i", self._low_param_i)

        # This paragraph does two things:
        # 1) Copy model parameters into master buffer
        # 2) Create tensor lists for unpacking new parameter tensor after all-gather
        self._packed_flat_to_model_params_fp16 = []
        self._packed_flat_to_model_params_fp32 = []
        self._model_params_num = len(self._model_params)
        self._contrib_tensor_list = []
        self._contrib_min_param_i, self._contrib_max_param_i = -1, -1
        self._contrib_update_frag_for_norm = []
        self._contrib_model_param_for_norm_fp16 = []
        self._contrib_model_param_for_norm_fp32 = []
        self._contrib_model_param_for_norm_is_fp16 = []
        self._model_param_is_contrib = []
        self._contrib_group_properties = []
        for shard_id in range(self._group_size):
            for block_id in range(self._num_blocks):
                for chunk_id in range(self._num_chunks):
                    # [flat_shard_start, flat_shard_end) is this shard's range in the flat layout
                    flat_shard_start = (((block_id * self._num_chunks + chunk_id) * self._group_size) + shard_id) * self._shard_size
                    flat_shard_end = flat_shard_start + self._shard_size
                    for param_i, (p, grads_info, group_props) in enumerate(zip(self._model_params, self._grads_info, self._group_properties)):
                        flat_grad_start = grads_info["param_offset"]
                        flat_grad_end = flat_grad_start + grads_info["param_grads_size"]
                        # intersection of the parameter range with the shard range (max of starts, min of ends)
                        clipped_start = (lambda a,b: a if a > b else b)(flat_grad_start, flat_shard_start)
                        clipped_end = (lambda a,b: a if a < b else b)(flat_grad_end, flat_shard_end)
                        if clipped_start < clipped_end:
                            # the parameter overlaps this shard
                            grad_offset = clipped_start - flat_grad_start
                            grad_length = clipped_end - clipped_start
                            shard_offset = clipped_start - flat_shard_start
                            pf = _get_flat_view(p)
                            model_param_fragment = pf[grad_offset:grad_offset+grad_length]
                            new_param_packed_fragment = self._new_params_mega_chunks[shard_id][block_id][chunk_id][shard_offset:shard_offset+grad_length]
                            # (packed source, model destination) pairs, bucketed by the model dtype
                            if model_param_fragment.dtype == torch.float16:
                                self._packed_flat_to_model_params_fp16.append( (new_param_packed_fragment, model_param_fragment) )
                            else:
                                self._packed_flat_to_model_params_fp32.append( (new_param_packed_fragment, model_param_fragment) )
                            # only fragments in this rank's own shard get optimizer state
                            if shard_id == self._rank_in_group:
                                self._model_param_is_contrib.append(param_i)
                                # copy model parameters into master buffer
                                master_param_fragment = self._fp32_p_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length]
                                opti_state_m_fragment = self._fp32_m_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length]
                                opti_state_v_fragment = self._fp32_v_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length]
                                opti_state_u_fragment = self._fp32_u_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length]
                                opti_state_g_fragment = self._fp16_g_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length]
                                opti_state_p_fragment = self._fp16_p_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length]
                                #print("model_param_fragment.size()=%s, new_param_packed_fragment.size()=%s, master_param_fragment.size()=%s" % (str(model_param_fragment.size()), str(new_param_packed_fragment.size()), str(master_param_fragment.size())))
                                # master weights are kept as-is when resuming from a checkpoint
                                if not self._resume_from_checkpoint:
                                    master_param_fragment.copy_(model_param_fragment)
                                self._contrib_group_properties.append(group_props)
                                self._contrib_tensor_list.append((master_param_fragment, opti_state_m_fragment, opti_state_v_fragment, opti_state_u_fragment, opti_state_g_fragment, opti_state_p_fragment)) # p, m, v, u, g, p_copy
                                self._contrib_update_frag_for_norm.append(opti_state_u_fragment)
                                if p.dtype == torch.float16:
                                    self._contrib_model_param_for_norm_fp16.append(p)
                                else:
                                    self._contrib_model_param_for_norm_fp32.append(p)
                                self._contrib_model_param_for_norm_is_fp16.append(True if p.dtype == torch.float16 else False)
                                if self._contrib_min_param_i < 0: self._contrib_min_param_i = param_i
                                self._contrib_max_param_i = param_i
        self._contrib_model_param_for_norm_num = len(self._contrib_model_param_for_norm_is_fp16)
        # empty per-dtype lists are replaced by None
        if len(self._contrib_model_param_for_norm_fp16) == 0: self._contrib_model_param_for_norm_fp16 = None
        if len(self._contrib_model_param_for_norm_fp32) == 0: self._contrib_model_param_for_norm_fp32 = None
        # boolean masks over the contributing fragments; the fp32 mask must be
        # built first because the next line overwrites the Python list it reads
        self._contrib_model_param_for_norm_is_fp32 = torch.tensor([not is_fp16 for is_fp16 in self._contrib_model_param_for_norm_is_fp16], dtype=torch.bool, device='cuda')
        self._contrib_model_param_for_norm_is_fp16 = torch.tensor([is_fp16 for is_fp16 in self._contrib_model_param_for_norm_is_fp16], dtype=torch.bool, device='cuda')
        # parameter index of every contributing fragment
        self._offsets = torch.tensor(self._model_param_is_contrib, dtype=torch.int64, device='cuda')

        # Transpose the per-fragment tuples into per-role tuples.
        # NOTE(review): this unpacking raises ValueError when the list is empty,
        # i.e. when no parameter overlaps this rank's shards.
        p, m, v, u, g, p_copy = list(zip(*self._contrib_tensor_list))
        self._contrib_compute_update_term_tensor_list = [g, p, m, v, u]
        self._contrib_update_weights_tensor_list = [u, p, p_copy]

        # per-fragment hyper-parameters as CUDA tensors, in the dtype of the update buffer
        math_type = self._fp32_u.dtype
        decay, bias_correction, beta1, beta2, beta3, epsilon = list(zip(*self._contrib_group_properties))
        self._contrib_beta1 = torch.tensor(beta1, dtype=math_type, device='cuda')
        self._contrib_beta2 = torch.tensor(beta2, dtype=math_type, device='cuda')
        self._contrib_beta3 = torch.tensor(beta3, dtype=math_type, device='cuda')
        self._contrib_bias_correction = torch.tensor(bias_correction, dtype=torch.int, device='cuda')
        self._contrib_epsilon = torch.tensor(epsilon, dtype=math_type, device='cuda')
        self._contrib_weight_decay = torch.tensor(decay, dtype=math_type, device='cuda')

        # transpose the (packed, model) pairs into two parallel tuples, or None when there are no pairs
        self._packed_flat_to_model_params_fp16 = list(zip(*self._packed_flat_to_model_params_fp16)) if len(self._packed_flat_to_model_params_fp16) > 0 else None
        self._packed_flat_to_model_params_fp32 = list(zip(*self._packed_flat_to_model_params_fp32)) if len(self._packed_flat_to_model_params_fp32) > 0 else None

        self._lazy_init_stage2_done = True

        self.complete_reductions()
        self._first_step = False

    def set_is_accumulation_step(self, is_accumulation_step):
        self._is_accumulation_step = is_accumulation_step

    def set_last_step(self, last_step):
        self._last_step = last_step

    def _get_flush_block(self):
        flush_block = []
        if self._current_block > 0 and self._grads_generated[self._low_param_i[self._current_block-1]]:
            num_grads = len(self._grads_generated)
            contiguous_idx = num_grads
            while contiguous_idx > 0 and self._grads_generated[contiguous_idx-1]:
                contiguous_idx -= 1

            if contiguous_idx < num_grads and self._grads_info[contiguous_idx]["param_offset"] <= (self._current_block-1)*self._block_size:
                self._current_block -= 1
                start = self._current_block * self._block_size
                end = (self._current_block+1) * self._block_size
                flush_block = [start, end]

        return flush_block

    def _full_all_reduce_scale(self, block_id, scale):
        """Launch asynchronous all-reduce(s) of one gradient block with a fused pre-multiply.

        The reduction op is built with ``_make_nccl_premul_sum(scale)``.

        When ``self._clip_after_ar`` is set, each chunk of the block
        (``self._flat_grads_chunks[block_id][chunk_id]``) is reduced on its
        own; the stream and process group are selected by the global chunk
        index modulo ``self._num_ar_pg``.  Otherwise the whole block
        (``self._flat_grads_blocks[block_id]``) is reduced with a single call,
        selected by ``block_id`` modulo ``self._num_ar_pg``, and the same work
        handle is stored once per chunk.

        Arguments:
            block_id (int): index of the gradient block to reduce.
            scale: value handed to ``_make_nccl_premul_sum``.

        The work handles are stored in ``self._reductions_works[block_id]``.
        """
        num_pg = self._num_ar_pg
        if self._clip_after_ar:
            handles = []
            for chunk_id in range(self._num_chunks):
                slot = (block_id * self._num_chunks + chunk_id) % num_pg
                side_stream = self._ar_st[slot]
                # The side stream must observe the work already queued on the
                # current stream before the collective is enqueued on it.
                side_stream.wait_stream(torch.cuda.current_stream())
                with torch.cuda.stream(side_stream):
                    handle = torch.distributed.all_reduce(
                        self._flat_grads_chunks[block_id][chunk_id],
                        group=self._ar_pg[slot],
                        async_op=True,
                        op=_make_nccl_premul_sum(scale),
                    )
                handles.append(handle)
        else:
            slot = block_id % num_pg
            side_stream = self._ar_st[slot]
            side_stream.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(side_stream):
                block_handle = torch.distributed.all_reduce(
                    self._flat_grads_blocks[block_id],
                    group=self._ar_pg[slot],
                    async_op=True,
                    op=_make_nccl_premul_sum(scale),
                )
            # One collective covers the whole block; every chunk slot shares it.
            handles = [block_handle] * self._num_chunks
        self._reductions_works[block_id] = handles

    def _full_all_reduce(self, block_id):
        """Launch one asynchronous all-reduce per chunk of a gradient block.

        For every chunk the stream and process group are selected by the
        global chunk index (``block_id * self._num_chunks + chunk_id``) modulo
        ``self._num_ar_pg``.

        Arguments:
            block_id (int): index of the gradient block to reduce.

        The work handles are stored in ``self._reductions_works[block_id]``.
        """
        handles = []
        for chunk_id in range(self._num_chunks):
            slot = (block_id * self._num_chunks + chunk_id) % self._num_ar_pg
            side_stream = self._ar_st[slot]
            # The side stream must observe the work already queued on the
            # current stream before the collective is enqueued on it.
            side_stream.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(side_stream):
                handle = torch.distributed.all_reduce(
                    self._flat_grads_chunks[block_id][chunk_id],
                    group=self._ar_pg[slot],
                    async_op=True,
                )
            handles.append(handle)
        self._reductions_works[block_id] = handles

    def _reduce_scatter_and_all_reduce_scale(self, block_id, scale):
        """Reduce one gradient block in two stages with a fused pre-multiply.

        Stage 1 (reduction within each node): every chunk of the block is
        reduce-scattered into ``self._fp16_g_chunks[block_id][chunk_id]`` with
        ``_make_nccl_premul_sum(scale)`` as the reduction op.  This changes the
        gradient format from [block * chunk * shard] to [shard * block * chunk],
        which is the same format as the fp32 master parameters.

        Stage 2 (reduction across nodes for each rank): only when
        ``self._num_groups > 1``, each reduce-scattered chunk is all-reduced.

        Arguments:
            block_id (int): index of the gradient block to reduce.
            scale: value handed to ``_make_nccl_premul_sum``.

        The work handles of the last stage that ran are stored in
        ``self._reductions_works[block_id]``.
        """
        chunk_count = self._num_chunks
        handles = [None] * chunk_count

        for chunk_id in range(chunk_count):
            slot = (block_id * chunk_count + chunk_id) % self._num_rs_pg
            side_stream = self._rs_st[slot]
            # The reduce-scatter stream waits for both the current stream and
            # the L2-grad-norm stream before the collective is enqueued.
            side_stream.wait_stream(torch.cuda.current_stream())
            side_stream.wait_stream(self._l2_grad_norm_st)
            with torch.cuda.stream(side_stream):
                if self._reduce_scatter_no_copy:
                    handle = torch.distributed.reduce_scatter(
                        output=self._fp16_g_chunks[block_id][chunk_id],
                        input_list=self._flat_grads_shards[block_id][chunk_id],
                        group=self._rs_pg[slot],
                        async_op=True,
                        no_copy=True,
                        op=_make_nccl_premul_sum(scale),
                    )
                else:
                    handle = torch.distributed.reduce_scatter_tensor(
                        output=self._fp16_g_chunks[block_id][chunk_id],
                        input=self._flat_grads_chunks[block_id][chunk_id],
                        group=self._rs_pg[slot],
                        async_op=True,
                        op=_make_nccl_premul_sum(scale),
                    )
            handles[chunk_id] = handle

        if self._num_groups > 1:
            for chunk_id in range(chunk_count):
                slot = (block_id * chunk_count + chunk_id) % self._num_ar_pg
                with torch.cuda.stream(self._ar_st[slot]):
                    # Wait on the reduce-scatter from inside the all-reduce
                    # stream context, then replace its handle with the
                    # all-reduce handle.
                    handles[chunk_id].wait()
                    handles[chunk_id] = torch.distributed.all_reduce(
                        self._fp16_g_chunks[block_id][chunk_id],
                        group=self._ar_pg[slot],
                        async_op=True,
                    )
        self._reductions_works[block_id] = handles

    def _reduce_scatter_and_all_reduce(self, block_id):
        """Reduce one gradient block in two stages (no fused scaling).

        Stage 1 (reduction within each node): every chunk of the block is
        reduce-scattered into ``self._fp16_g_chunks[block_id][chunk_id]``.
        This changes the gradient format from [block * chunk * shard] to
        [shard * block * chunk], which is the same format as the fp32 master
        parameters.

        Stage 2 (reduction across nodes for each rank): only when
        ``self._num_groups > 1``, each reduce-scattered chunk is all-reduced.

        Arguments:
            block_id (int): index of the gradient block to reduce.

        The work handles of the last stage that ran are stored in
        ``self._reductions_works[block_id]``.
        """
        chunk_count = self._num_chunks
        handles = [None] * chunk_count

        for chunk_id in range(chunk_count):
            slot = (block_id * chunk_count + chunk_id) % self._num_rs_pg
            side_stream = self._rs_st[slot]
            # The reduce-scatter stream must observe the work already queued
            # on the current stream before the collective is enqueued on it.
            side_stream.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(side_stream):
                if self._reduce_scatter_no_copy:
                    handle = torch.distributed.reduce_scatter(
                        output=self._fp16_g_chunks[block_id][chunk_id],
                        input_list=self._flat_grads_shards[block_id][chunk_id],
                        group=self._rs_pg[slot],
                        async_op=True,
                        no_copy=True,
                    )
                else:
                    handle = torch.distributed.reduce_scatter_tensor(
                        output=self._fp16_g_chunks[block_id][chunk_id],
                        input=self._flat_grads_chunks[block_id][chunk_id],
                        group=self._rs_pg[slot],
                        async_op=True,
                    )
            handles[chunk_id] = handle

        if self._num_groups > 1:
            for chunk_id in range(chunk_count):
                slot = (block_id * chunk_count + chunk_id) % self._num_ar_pg
                with torch.cuda.stream(self._ar_st[slot]):
                    # Wait on the reduce-scatter from inside the all-reduce
                    # stream context, then replace its handle with the
                    # all-reduce handle.
                    handles[chunk_id].wait()
                    handles[chunk_id] = torch.distributed.all_reduce(
                        self._fp16_g_chunks[block_id][chunk_id],
                        group=self._ar_pg[slot],
                        async_op=True,
                    )
        self._reductions_works[block_id] = handles

    def _pipeline_block_reductions(self, block_id):
        """Flatten, optionally clip, and reduce the gradients of one block.

        Two modes, selected by ``self._clip_after_ar``:

        * Clip after all-reduce: pending gradients are flattened with a
          ``1/world_size`` scale and the block is reduced.  When
          ``block_id == 0`` all reductions are waited on, the squared L2 norm
          of the reduced gradients is computed on ``self._l2_grad_norm_st``,
          all-reduced over ``self._l2_grad_norm_pg``, and its square root is
          stored in ``self._L2_grad_norm``.
        * Clip before all-reduce: pending gradients are flattened unscaled,
          the L2 norm of ``self._flat_grads`` is computed here unless
          ``self._fused_norm`` is set, and a combined clipping / averaging
          ``scale`` is derived from it.  The scale is either applied in place
          or fused into the reduction op (``self._fuse_scale``).  When
          ``block_id == 0`` all reductions are waited on.

        Arguments:
            block_id (int): index of the gradient block to reduce.
        """
        if self._clip_after_ar:
            self._flatten_grad_mt(1.0/self._world_size)

            if self._full_ar:
                self._full_all_reduce(block_id)
            else:
                self._reduce_scatter_and_all_reduce(block_id)

            # Compute L2 grad norm
            if block_id == 0:
                with torch.cuda.stream(self._l2_grad_norm_st):
                    # Distinct loop names keep the ``block_id`` argument intact.
                    for blk in range(self._num_blocks):
                        for chunk_id in range(self._num_chunks):
                            self._reductions_works[blk][chunk_id].wait()
                    # Since the packed format is contiguous after reductions, only one norm is needed
                    if self._full_ar:
                        # this flattening of lists is to keep multi_tensor_apply function happy, it wants depth=1 for l2 norm computation
                        flat_list = [item for sublist in self._fp16_g_chunks for item in sublist]
                        l2_grad_norm_sq = multi_tensor_applier(self.multi_tensor_l2norm, self._overflow_buf, [flat_list], False)[0]**2
                    else:
                        l2_grad_norm_sq = self._fp16_g.norm(dtype=torch.float32, p=2)**2
                    torch.distributed.all_reduce(l2_grad_norm_sq, group=self._l2_grad_norm_pg)
                    self._L2_grad_norm = l2_grad_norm_sq.sqrt()
        else:
            # Copy model grads to flat grads buffer
            self._flatten_grad_mt(1.0)

            # Compute L2 grad norm
            self._l2_grad_norm_st.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(self._l2_grad_norm_st):
                if not self._fused_norm:
                    self._L2_grad_norm = self._flat_grads.norm(dtype=torch.float16, p=2).float()
            torch.cuda.current_stream().wait_stream(self._l2_grad_norm_st)

            # Apply clipping & pre-reduction scaling on grads
            loss_scale = self.global_scale
            max_grad_norm = loss_scale*self.defaults['max_grad_norm']
            coeff = max_grad_norm /(1e-6+self.L2_grad_norm)
            # Cap the clipping coefficient at 1.
            coeff = (coeff>1) * self._one + (coeff<=1) * coeff
            tmp = torch.cat(((self._one), (coeff)))
            # ``coeff + 1 > coeff`` is false for inf/NaN, so index 0 (self._one)
            # is selected then; otherwise index 1 (coeff) is selected.
            index = (coeff+1>coeff).int()
            scale = tmp.index_select(0, index).half()/self._world_size
            if not self._fuse_scale:
                self._flat_grads.mul_(scale)

            if self._full_ar:
                if self._fuse_scale:
                    self._full_all_reduce_scale(block_id, scale)
                else:
                    self._full_all_reduce(block_id)
            else:
                if self._fuse_scale:
                    self._reduce_scatter_and_all_reduce_scale(block_id, scale)
                else:
                    self._reduce_scatter_and_all_reduce(block_id)

            if block_id == 0:
                for blk in range(self._num_blocks):
                    for chunk_id in range(self._num_chunks):
                        self._reductions_works[blk][chunk_id].wait()

    def __compute_contrib_param_norm(self):
        """Return the per-tensor L2 norms of the model parameters tracked for norms.

        The norms come from element ``[1]`` of the ``multi_tensor_l2norm``
        result, computed separately for the fp16 and the fp32 parameter lists.
        When both lists are present the two results are merged into a single
        tensor of length ``self._contrib_model_param_for_norm_num`` using the
        boolean masks ``self._contrib_model_param_for_norm_is_fp16`` and
        ``self._contrib_model_param_for_norm_is_fp32``.

        Returns:
            torch.Tensor: the per-tensor norms.

        Raises:
            RuntimeError: if neither an fp16 nor an fp32 parameter list is set.
        """
        fp16_params = self._contrib_model_param_for_norm_fp16
        fp32_params = self._contrib_model_param_for_norm_fp32
        if fp16_params is None and fp32_params is None:
            raise RuntimeError("no model parameters registered for the parameter norm computation")

        if fp16_params is not None and fp32_params is not None:
            gnorm_fp16 = multi_tensor_applier(self.multi_tensor_l2norm, self._overflow_buf, [fp16_params], True)[1]
            gnorm_fp32 = multi_tensor_applier(self.multi_tensor_l2norm, self._overflow_buf, [fp32_params], True)[1]
            # masked_scatter_ requires the destination to have the same dtype
            # as its source, so allocate with the dtype of the computed norms
            # (a bool destination would be rejected).
            gnorm = torch.empty(size=[self._contrib_model_param_for_norm_num], dtype=gnorm_fp16.dtype, device='cuda')
            gnorm.masked_scatter_(self._contrib_model_param_for_norm_is_fp16, gnorm_fp16)
            gnorm.masked_scatter_(self._contrib_model_param_for_norm_is_fp32, gnorm_fp32)
            return gnorm

        only_params = fp16_params if fp16_params is not None else fp32_params
        return multi_tensor_applier(self.multi_tensor_l2norm, self._overflow_buf, [only_params], True)[1]

    def __compute_contrib_update_norm(self):
        """Return the L2 norm of the update term for every model parameter.

        The squared per-tensor norms of the local update fragments
        (``self._contrib_update_frag_for_norm``) are scattered into a
        zero-initialised vector of length ``self._model_params_num`` at the
        positions given by ``self._offsets``, summed across
        ``self._ag_pg[0]`` with an all-reduce, and square-rooted.

        Returns:
            torch.Tensor: float32 tensor of length ``self._model_params_num``.
        """
        squared_norms = torch.zeros(size=[self._model_params_num], dtype=torch.float32, device='cuda')
        fragment_norms = multi_tensor_applier(
            self.multi_tensor_l2norm,
            self._overflow_buf,
            [self._contrib_update_frag_for_norm],
            True,
        )[1]
        local_squared = fragment_norms ** 2
        # NOTE(review): scatter_ overwrites rather than accumulates; if
        # self._offsets can contain the same parameter index more than once,
        # only one fragment's contribution would survive -- confirm against
        # how self._offsets is built.
        squared_norms.scatter_(dim=0, index=self._offsets, src=local_squared)
        torch.distributed.all_reduce(squared_norms, group=self._ag_pg[0])
        return torch.sqrt(squared_norms)

    def _pipeline_step(self):
        """Run the LAMB update on the completion stream and all-gather the result.

        Sequence, as visible in this method:

          1. Read the global gradient norm and derive ``is_finite`` and
             ``self._overflow_buf`` from it; when clipping happens before the
             all-reduce, both are additionally reduced over
             ``self._current_process_group`` (MIN and MAX respectively).
          2. Add ``is_finite`` to ``self._step``.
          3. On ``self._completion_st``: wait for all outstanding reductions,
             compute the parameter norms, launch the "compute update term"
             kernel, compute the update norms, launch the "update weights"
             kernel.
          4. Unless ``self._skip_ag`` is set, all-gather the updated
             parameters over ``self._ag_pg[0]``.
        """
        global_scale = self.global_scale
        # if clip before ar, set max_grad_norm to 0
        max_grad_norm = self.defaults['max_grad_norm'] * self._clip_after_ar
        self._completion_st.wait_stream(self._l2_grad_norm_st)
        global_grad_norm = self.L2_grad_norm

        # check global_grad_norm and fill overflow_buf
        # ``x + 1 > x`` is false when x is inf or NaN, so is_finite is 0 then.
        is_finite = (global_grad_norm + 1 > global_grad_norm).int()
        self._overflow_buf = self._one * (is_finite ^ self._one) # toggle between 0 and 1

        if not self._clip_after_ar:
            # MIN over is_finite / MAX over the overflow flag: the value that
            # signals a non-finite norm on any rank is what every rank ends up with.
            torch.distributed.all_reduce(is_finite,
                                         op=torch.distributed.ReduceOp.MIN,
                                         group=self._current_process_group)
            torch.distributed.all_reduce(self._overflow_buf,
                                         op=torch.distributed.ReduceOp.MAX,
                                         group=self._current_process_group)

        # increment step counter if no overflow
        self._step += is_finite
        # The completion stream must observe the work queued so far on the
        # current stream and on the L2-grad-norm stream.
        self._completion_st.wait_stream(torch.cuda.current_stream())
        self._completion_st.wait_stream(self._l2_grad_norm_st)

        # Call step kernel once per step
        # Call all-gather once per step
        with torch.cuda.stream(self._completion_st):
            # Every block/chunk reduction must have completed before the update.
            for block_id in range(self._num_blocks):
                for chunk_id in range(self._num_chunks):
                    self._reductions_works[block_id][chunk_id].wait()
            param_norm = self.__compute_contrib_param_norm()
            multi_tensor_applier(self.multi_tensor_lamb_compute_update_term,
                    self._overflow_buf,
                    self._contrib_compute_update_term_tensor_list, # g, p, m, v, u
                    self._contrib_beta1,
                    self._contrib_beta2,
                    self._contrib_beta3,
                    self._contrib_bias_correction,
                    self._step,
                    self._contrib_epsilon,
                    self._adam_w_mode,
                    self._contrib_weight_decay,
                    global_scale,
                    global_grad_norm,
                    max_grad_norm)
            # The update norms are taken after the update-term kernel has been launched.
            upd_norm = self.__compute_contrib_update_norm()
            multi_tensor_applier(self.multi_tensor_lamb_update_weights,
                    self._overflow_buf,
                    self._contrib_update_weights_tensor_list, # u, p, p_copy
                    param_norm,
                    upd_norm,
                    self._offsets,
                    self._lr,
                    self._contrib_weight_decay,
                    global_grad_norm,
                    self._use_nvlamb)
            if not self._skip_ag:
                # allgather chunking is currently not supported for clip after allreduce
                if not self._clip_after_ar:
                    # One all-gather per (block, chunk) pair.
                    for block in range(self._num_blocks):
                        for chunk in range(self._num_chunks):
                            if self._all_gather_no_copy:
                                torch.distributed.all_gather(
                                    tensor_list = self._new_params2_shards[block][chunk],
                                    tensor = self._fp16_p_chunks[block][chunk],
                                    group = self._ag_pg[0],
                                    no_copy = True,
                                )
                            else:
                                torch.distributed.all_gather_into_tensor(
                                    output_tensor = self._new_params2_blocks[block],
                                    input_tensor = self._fp16_p_chunks[block][chunk],
                                    group = self._ag_pg[0],
                                )
                else:
                    # Single all-gather of the whole local parameter buffer.
                    if self._all_gather_no_copy:
                        torch.distributed.all_gather(
                            tensor_list = self._new_params_mega_shards,
                            tensor = self._fp16_p,
                            group = self._ag_pg[0],
                            no_copy = True,
                        )
                    else:
                        torch.distributed.all_gather_into_tensor(
                            output_tensor = self._new_params,
                            input_tensor = self._fp16_p,
                            group = self._ag_pg[0],
                        )

    def _flatten_grad_mt(self, scale):
        if len(self._grads_fp16) > 0:
            self._overflow_buf.zero_()
            if not self._fused_norm:
                multi_tensor_applier(
                        amp_C.multi_tensor_scale,
                        self._overflow_buf,
                        list(zip(*self._grads_fp16)),
                        scale)
            else:
                self._L2_grad_norm=multi_tensor_applier(
                        amp_C.multi_tensor_l2norm_scale,
                        self._overflow_buf,
                        list(zip(*self._grads_fp16)),
                        scale, False)[0].float()

            self._grads_fp16 = []
        if len(self._grads_fp32) > 0:
            self._overflow_buf.zero_()
            if not self._fused_norm:
                multi_tensor_applier(
                        amp_C.multi_tensor_scale,
                        self._overflow_buf,
                        list(zip(*self._grads_fp32)),
                        scale)
            else:
                self._L2_grad_norm=multi_tensor_applier(
                        amp_C.multi_tensor_l2norm_scale,
                        self._overflow_buf,
                        list(zip(*self._grads_fp32)),
                        scale, False)[0].float()
            self._grads_fp32 = []

    def _do_overlapped_reduction(self, param_i, param):
        if not self._is_accumulation_step:
            # handle overlapped reductions
            if param.dtype == torch.float16:
                self._grads_fp16.append( (param.grad, self._individual_flat_grads[param_i]) )
            else:
                self._grads_fp32.append( (param.grad, self._individual_flat_grads[param_i]) )
            self._grads_generated[param_i]=True
            if not self._first_step and not self._last_step:
                if self._overlap_reductions:
                    flush_block = self._get_flush_block()
                    while flush_block:
                        block_id = flush_block[0] // self._block_size
                        self._pipeline_block_reductions(block_id)
                        flush_block = self._get_flush_block()

    def set_global_scale(self, global_scale):
        """Set global scale.
        """
        self._global_scale = global_scale

    @property
    def global_scale(self):
        return self._global_scale

    @property
    def L2_grad_norm(self):
        """The stored gradient L2 norm (``self._L2_grad_norm``).

        Before returning, the current CUDA stream is made to wait on
        ``self._l2_grad_norm_st``.
        """
        consumer_stream = torch.cuda.current_stream()
        consumer_stream.wait_stream(self._l2_grad_norm_st)
        return self._L2_grad_norm

    def complete_reductions(self):
        """Finish the gradient reductions for this iteration and reset the bookkeeping.

        On the last step, the flat-gradient slice of every parameter whose
        gradient was not generated is zeroed and the parameter is marked as
        generated.  If this is the first step, the last step, or overlapped
        reductions are disabled, the reduction pipeline is run for every block
        from the highest index down to 0.  Finally the current stream waits on
        the L2-grad-norm stream, ``self._current_block`` is reset to
        ``self._num_blocks`` and all "gradient generated" flags are cleared.
        """
        if self._last_step:
            # zero out gradients that have not been completed yet
            for param_i, already_generated in enumerate(self._grads_generated):
                if already_generated:
                    continue
                info = self._grads_info[param_i]
                begin = info["param_offset"]
                length = info["param_grads_size"]
                self._flat_grads[begin:begin+length].zero_()
                self._grads_generated[param_i] = True

        run_full_pipeline = self._first_step or self._last_step or not self._overlap_reductions
        if run_full_pipeline:
            # nothing done so far, run full pipeline after reductions
            for block_id in reversed(range(self._num_blocks)):
                self._pipeline_block_reductions(block_id)

        torch.cuda.current_stream().wait_stream(self._l2_grad_norm_st)

        # Reset the per-iteration state for the next backward pass.
        self._current_block = self._num_blocks
        self._grads_generated = [False]*len(self._grads_info)

    def step(self, closure=None, grad_scaler=None):
        """Perform a single optimization step.

        Arguments:
            closure (callable, optional): if given, it is called once before
                the update and its return value is returned.
            grad_scaler (optional): if given, the overflow flag is written
                (as a float tensor) into
                ``grad_scaler._per_optimizer_states[id(self)]["found_inf_per_device"]``
                for the current CUDA device.

        Returns:
            The value returned by ``closure``, or ``None`` without a closure.
        """
        loss = closure() if closure is not None else None

        self._pipeline_step()

        if grad_scaler is not None:
            found_inf = self._overflow_buf.float()
            scaler_state = grad_scaler._per_optimizer_states[id(self)]
            device = torch.device('cuda', torch.cuda.current_device())
            scaler_state["found_inf_per_device"][device] = found_inf

        self._completion_st.wait_stream(torch.cuda.current_stream())
        if not self._set_flat_param_view:
            # Copy self._new_params to model params
            with torch.cuda.stream(self._completion_st), torch.no_grad():
                for packed in (self._packed_flat_to_model_params_fp16,
                               self._packed_flat_to_model_params_fp32):
                    if packed is None:
                        continue
                    multi_tensor_applier(
                            fused_adam_cuda.maybe_cast_mt,
                            self._overflow_buf,
                            packed)

        torch.cuda.current_stream().wait_stream(self._completion_st)

        # Drop the work handles of this iteration.
        self._reductions_works = [None]*self._num_blocks
        self._allgather_works = [None]*self._num_blocks

        return loss

    def state_dict(self):
        """
        Returns a dict containing the current state of this :class:`DistributedFusedAdam` instance.
        Example::
            checkpoint = {}
            checkpoint['model'] = model.state_dict()
            checkpoint['optimizer'] = optimizer.state_dict()
            torch.save(checkpoint, "saved.pth")
        """
        # save step, master weights and first/second moments
        state_dict = {}
        state_dict['step'] = self._step
        state_dict['fp32_p'] = self._fp32_p
        state_dict['fp32_m'] = self._fp32_m
        state_dict['fp32_v'] = self._fp32_v
        return state_dict

    def load_state_dict(self, state_dict):
        """
        Loads a state_dict created by an earlier call to state_dict().
        If an DistributedFusedAdam instance was constructed from some ``init_optimizer``,
        whose parameters in turn came from ``model``, it is expected that the user
        will call ``model.load_state_dict()`` before
        ``optimizer.load_state_dict()`` is called.
        Example::
            model = torch.nn.Linear(D_in, D_out).cuda().half()
            optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
            optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0)
            ...
            checkpoint = torch.load("saved.pth")
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
        """
        # restore step, master weights and first/second moments
        self._step = state_dict['step']
        self._fp32_p = state_dict['fp32_p'].to(device="cuda")
        self._fp32_m = state_dict['fp32_m'].to(device="cuda")
        self._fp32_v = state_dict['fp32_v'].to(device="cuda")
        self._resume_from_checkpoint = True
