import itertools
import warnings
from abc import ABC, abstractmethod
from typing import Iterator, List

import torch
import torch.distributed as dist
from torch import distributed
from torch.distributed.algorithms.ddp_comm_hooks.default_hooks import _allreduce_fut
from torch.nn.parallel import DistributedDataParallel

def average_parameters(
    params: Iterator[torch.nn.Parameter], process_group: dist.ProcessGroup
):
    """
    Average the given parameters across the ranks of a process group.

    The parameters are packed into a single contiguous buffer so that one
    all-reduce is enough. This needs extra memory of the same size as the
    parameters themselves.

    Args:
        params: Iterator over the parameters to average. The ``data`` of each
            parameter is rebound to its slice of the averaged buffer.
        process_group: Group to all-reduce over; ``None`` selects
            ``dist.group.WORLD``.

    Returns:
        None. When the calling rank is not part of the group, the function
        returns before reading ``params`` or changing any parameter.
    """
    group = dist.group.WORLD if process_group is None else process_group
    # A rank outside the group must leave its parameters alone.
    if dist._rank_not_in_group(group):
        return

    # Two passes are needed (pack, then unpack), so materialise the iterator.
    param_list = list(params)
    # Mixed dtypes are implicitly up-cast by the concatenation; ``type_as``
    # in the unpacking loop restores each parameter's own dtype.
    flat = torch.cat([param.data.view(-1) for param in param_list])
    # Pre-divide so that the summing all-reduce yields the mean.
    flat /= dist.get_world_size(group)
    # Make sure the all-reduce does not conflict with any other ongoing
    # process group.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    dist.all_reduce(flat, group=group)

    start = 0
    for param in param_list:
        end = start + param.numel()
        param.data = flat[start:end].view_as(param).type_as(param)
        start = end


class ModelAverager(ABC):
    r"""Abstract base class shared by all model averagers.

    Stores the process group used for averaging and a step counter that
    starts at zero.

    Args:
        process_group: The process group to be used for all-reduce.
                       If ``None``, the default process group, which
                       is created by :func:`torch.distributed.init_process_group`,
                       will be used. (default: ``None``)
    """

    def __init__(self, process_group=None):
        # No explicit group given: fall back to the default (world) group.
        if process_group is None:
            process_group = dist.group.WORLD
        self.process_group = process_group
        # Step counter, initialised to zero.
        self.step = 0

    @abstractmethod
    def average_parameters(self, params):
        """Average ``params``; every concrete averager must override this."""
        raise NotImplementedError


class PeriodicModelAverager(ModelAverager):
    r"""
    Averages parameters periodically after the warm-up stage.

    This can be used for running `post-local SGD <https://arxiv.org/abs/1808.07217>`_,
    by running :class:`~torch.nn.parallel.DistributedDataParallel` (DDP)
    using the subgroups created by :meth:`~torch.distributed.new_subgroups`.

    Args:
        period (int): The number of steps per model averaging.
                      Usually the period should be greater than ``1`` to reduce the communication cost.
                      Otherwise, only DDP needs to be used.
        warmup_steps (int): The number of warm-up steps. During this stage,
                            model averaging is skipped.
        process_group: The process group to be used for all-reduce.
                       If ``None``, the default process group, which
                       is created by :func:`torch.distributed.init_process_group`,
                       will be used. (default: ``None``)

    Raises:
        ValueError: If ``warmup_steps`` is negative or ``period`` is smaller than ``1``.

    Example::

        >>>  import torch
        >>>  import torch.distributed as dist
        >>>  import torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook as post_localSGD
        >>>  import torch.distributed.algorithms.model_averaging.averagers as averagers
        >>>  import torch.nn as nn
        >>>
        >>>  dist.init_process_group("nccl", rank=rank, world_size=16)
        >>>  torch.cuda.set_device(rank)
        >>>  module = nn.Linear(1, 1, bias=False).to(rank)
        >>>  model = nn.parallel.DistributedDataParallel(
        >>>     module, device_ids=[rank], output_device=rank
        >>>  )
        >>>  # Register a post-localSGD communication hook.
        >>>  subgroup, subgroups = dist.new_subgroups()
        >>>  state = post_localSGD.PostLocalSGDState(subgroup=subgroup, start_localSGD_iter=100)
        >>>  model.register_comm_hook(state, post_localSGD.post_localSGD_hook)
        >>>
        >>>  # In the first 100 steps, run global gradient averaging like normal DDP at every step.
        >>>  # After 100 steps, run model averaging every 4 steps.
        >>>  # Note that ``warmup_steps`` must be the same as ``start_localSGD_iter`` used in ``PostLocalSGDState``.
        >>>  averager = averagers.PeriodicModelAverager(period=4, warmup_steps=100)
        >>>  for step in range(0, 200):
        >>>     optimizer.zero_grad()
        >>>     loss = loss_fn(output, labels)
        >>>     loss.backward()
        >>>     optimizer.step()
        >>>     # Average parameters globally after ``optimizer.step()``.
        >>>     # Thus, the inter-node communication only occurs periodically after ``warmup_steps``.
        >>>     averager.average_parameters(model.parameters())

    .. warning ::
        `PeriodicModelAverager` is experimental and subject to change.
    """

    def __init__(
        self,
        period,
        warmup_steps=0,
        process_group=None,
    ):
        super().__init__(process_group)
        if warmup_steps < 0:
            raise ValueError("Arg ``warmup_steps`` must be a non-negative number.")
        self.warmup_steps = warmup_steps
        if period < 1:
            raise ValueError("Arg ``period`` must be a positive value.")
        elif period == 1:
            # Legal but pointless: warn instead of failing.
            warnings.warn(
                "When period is 1, no need to use model averaging because the communication cost "
                "of all-reducing parameters will be no less than the cost of all-reducing gradients "
                "by DistributedDataParallel in the backward pass. Therefore, only "
                "DistributedDataParallel should be used for this case."
            )
        self.period = period

    def average_parameters(self, params):
        r"""
        Averages parameters if ``step`` is no less than ``warmup_steps``
        and ``step - warmup_steps`` is a multiple of ``period``.

        ``step`` is increased by 1 on every call, whether or not averaging
        took place, so this is meant to be called once per training iteration.

        Args:
            params: Iterable of the parameters to average.
        """
        if (
            self.step >= self.warmup_steps
            and (self.step - self.warmup_steps) % self.period == 0
        ):
            # Resolves to the module-level function, not to this method.
            average_parameters(iter(params), self.process_group)
        self.step += 1


def hierarchical_allreduce_hook(
    process_group_state: tuple,
    bucket: distributed.GradBucket
) -> torch.futures.Future[torch.Tensor]:
    """DDP communication hook that averages gradients in two levels.

    On a synchronising step the bucket goes through three chained collectives:
    reduce-scatter inside the local group, an fp16 all-reduce across the
    reduce group, and an all-gather inside the local group. On any other step
    the bucket is returned untouched so gradients keep accumulating locally.

    Args:
        process_group_state: 4-tuple ``(local_groups, reduce_groups,
            global_step, gradient_acc)``. ``local_groups`` is indexed with
            ``rank // 8`` and ``reduce_groups`` with ``rank % 8``.
            ``global_step`` is ``None`` (synchronise on every call) or an
            object with a ``step`` attribute; synchronisation happens when
            ``global_step.step % gradient_acc == 0``.
        bucket: Gradient bucket whose flat buffer is averaged in place.

    Returns:
        A future whose value is the bucket's buffer.
    """
    # Ranks per local group; has to agree with how the groups were built.
    local_world_size = 8
    rank = distributed.get_rank()
    device = torch.cuda.current_device()
    list_local_group: List[distributed.ProcessGroup] = process_group_state[0]
    list_reduce_group: List[distributed.ProcessGroup] = process_group_state[1]
    group_local = list_local_group[rank // local_world_size]
    group_reduce = list_reduce_group[rank % local_world_size]
    # Take the divisor from the group itself: a fixed value only gives the
    # mean when the reduce group happens to contain exactly that many ranks.
    reduce_world_size = distributed.get_world_size(group_reduce)

    assert len(process_group_state) == 4, ""
    global_step = process_group_state[2]
    gradient_acc = process_group_state[3]

    if global_step is not None and global_step.step % gradient_acc != 0:
        # Accumulation step: hand the local gradients back unchanged.
        skipped: torch.futures.Future[torch.Tensor] = torch.futures.Future()
        skipped.set_result(bucket.buffer())
        return skipped

    grad_buffer = bucket.buffer()
    # reduce_scatter needs equally sized chunks, so zero-pad the flat buffer
    # up to the next multiple of the local group size when necessary.
    rest = -grad_buffer.size(0) % local_world_size
    if rest == 0:
        tensor = grad_buffer
    else:
        tensor = torch.zeros(grad_buffer.size(0) + rest, device=device)
        tensor[:-rest] = grad_buffer

    assert tensor.size(0) % local_world_size == 0

    # Pre-divide so the summing reduce-scatter yields the local mean. Without
    # padding this modifies the bucket's buffer in place.
    tensor.div_(local_world_size)
    tensor_each = torch.zeros(tensor.size(0) // local_world_size, device=tensor.device)
    reduce_scatter_fut: torch.futures.Future = distributed.reduce_scatter(
        output=tensor_each,
        input_list=list(tensor.chunk(local_world_size)),
        group=group_local,
        async_op=True).get_future()

    def _fut_allreduce(fut):
        # Stage 2: average this rank's shard across the reduce group in fp16.
        tensor_reduce_scatter = fut.wait()[0]

        compressed_tensor = tensor_reduce_scatter.to(torch.float16).div_(reduce_world_size)
        allreduce_fut = distributed.all_reduce(
            tensor=compressed_tensor,
            op=distributed.ReduceOp.SUM,
            group=group_reduce,
            async_op=True).get_future()
        return allreduce_fut.wait()

    def _fut_allgather(fut):
        # Stage 3: cast back to fp32 and collect every shard of the local group.
        tensor_allreduce: torch.Tensor = fut.wait()[0].float()

        final_tensor = torch.zeros_like(tensor)
        allgather_fut = distributed.all_gather(
            list(final_tensor.chunk(local_world_size)), tensor_allreduce,
            group=group_local, async_op=True).get_future()
        return allgather_fut.wait()

    def _output(fut):
        # Drop the padding again and write the result into the bucket.
        gather_tensor = fut.wait()[0]
        gather_tensor = torch.reshape(gather_tensor, tensor.size())
        if rest != 0:
            gather_tensor = gather_tensor[:-rest]

        buffer = bucket.buffer()
        buffer.copy_(gather_tensor)
        return buffer

    return reduce_scatter_fut.then(_fut_allreduce).then(_fut_allgather).then(_output)


def posted_hierarchical_allreduce_hook(
    process_group_state: tuple,
    bucket: distributed.GradBucket
) -> torch.futures.Future[torch.Tensor]:
    """Two-level gradient-averaging DDP hook that switches itself off later on.

    While ``global_step.step`` is below ``start_local_sgd_iter`` and the step
    is a synchronising one, the bucket goes through three chained collectives:
    reduce-scatter inside the local group, an fp16 all-reduce across the
    reduce group, and an all-gather inside the local group. In every other
    case the bucket is returned untouched.

    Args:
        process_group_state: 5-tuple ``(local_groups, reduce_groups,
            global_step, gradient_acc, start_local_sgd_iter)``.
            ``local_groups`` is indexed with ``rank // 8`` and
            ``reduce_groups`` with ``rank % 8``. ``global_step`` must not be
            ``None`` and must expose a ``step`` attribute; synchronisation
            happens when ``global_step.step % gradient_acc == 0``.
        bucket: Gradient bucket whose flat buffer is averaged in place.

    Returns:
        A future whose value is the bucket's buffer.
    """
    # Ranks per local group; has to agree with how the groups were built.
    local_world_size = 8
    rank = distributed.get_rank()
    device = torch.cuda.current_device()
    list_local_group: List[distributed.ProcessGroup] = process_group_state[0]
    list_reduce_group: List[distributed.ProcessGroup] = process_group_state[1]
    group_local = list_local_group[rank // local_world_size]
    group_reduce = list_reduce_group[rank % local_world_size]
    # Take the divisor from the group itself: a fixed value only gives the
    # mean when the reduce group happens to contain exactly that many ranks.
    reduce_world_size = distributed.get_world_size(group_reduce)

    assert len(process_group_state) == 5, ""
    global_step = process_group_state[2]
    gradient_acc = process_group_state[3]
    start_local_sgd_iter = process_group_state[4]

    assert global_step is not None
    if (
        global_step.step >= start_local_sgd_iter
        or global_step.step % gradient_acc != 0
    ):
        # Either past the switch-off step or an accumulation step: hand the
        # local gradients back unchanged.
        skipped: torch.futures.Future[torch.Tensor] = torch.futures.Future()
        skipped.set_result(bucket.buffer())
        return skipped

    grad_buffer = bucket.buffer()
    # reduce_scatter needs equally sized chunks, so zero-pad the flat buffer
    # up to the next multiple of the local group size when necessary.
    rest = -grad_buffer.size(0) % local_world_size
    if rest == 0:
        tensor = grad_buffer
    else:
        tensor = torch.zeros(grad_buffer.size(0) + rest, device=device)
        tensor[:-rest] = grad_buffer

    assert tensor.size(0) % local_world_size == 0

    # Pre-divide so the summing reduce-scatter yields the local mean. Without
    # padding this modifies the bucket's buffer in place.
    tensor.div_(local_world_size)
    tensor_each = torch.zeros(tensor.size(0) // local_world_size, device=tensor.device)
    reduce_scatter_fut: torch.futures.Future = distributed.reduce_scatter(
        output=tensor_each,
        input_list=list(tensor.chunk(local_world_size)),
        group=group_local,
        async_op=True).get_future()

    def _fut_allreduce(fut):
        # Stage 2: average this rank's shard across the reduce group in fp16.
        tensor_reduce_scatter = fut.wait()[0]

        compressed_tensor = tensor_reduce_scatter.to(torch.float16).div_(reduce_world_size)
        allreduce_fut = distributed.all_reduce(
            tensor=compressed_tensor,
            op=distributed.ReduceOp.SUM,
            group=group_reduce,
            async_op=True).get_future()
        return allreduce_fut.wait()

    def _fut_allgather(fut):
        # Stage 3: cast back to fp32 and collect every shard of the local group.
        tensor_allreduce: torch.Tensor = fut.wait()[0].float()

        final_tensor = torch.zeros_like(tensor)
        allgather_fut = distributed.all_gather(
            list(final_tensor.chunk(local_world_size)), tensor_allreduce,
            group=group_local, async_op=True).get_future()
        return allgather_fut.wait()

    def _output(fut):
        # Drop the padding again and write the result into the bucket.
        gather_tensor = fut.wait()[0]
        gather_tensor = torch.reshape(gather_tensor, tensor.size())
        if rest != 0:
            gather_tensor = gather_tensor[:-rest]

        buffer = bucket.buffer()
        buffer.copy_(gather_tensor)
        return buffer

    return reduce_scatter_fut.then(_fut_allreduce).then(_fut_allgather).then(_output)


def gradient_acc_hook(
    process_group_state: tuple,
    bucket: distributed.GradBucket
) -> torch.futures.Future[torch.Tensor]:
    """DDP communication hook that only communicates on selected steps.

    Args:
        process_group_state: Tuple whose first item exposes a ``step``
            attribute and whose second item is the accumulation interval.
        bucket: Gradient bucket handed in by DDP.

    Returns:
        When ``step`` is a multiple of the interval, the future produced by
        ``_allreduce_fut`` for the bucket's buffer (called with ``None`` as
        the process group -- presumably the default group; confirm against
        torch's ``default_hooks``). Otherwise an already completed future
        holding the unchanged buffer.
    """
    step_holder = process_group_state[0]
    interval = process_group_state[1]

    if step_holder.step % interval != 0:
        # Accumulation step: skip communication and return the local gradients.
        completed: torch.futures.Future[torch.Tensor] = torch.futures.Future()
        completed.set_result(bucket.buffer())
        return completed

    return _allreduce_fut(None, bucket.buffer())

def register_acc(
    model: DistributedDataParallel,
    global_step,
    gradient_acc):
    """Register ``gradient_acc_hook`` as the communication hook of ``model``.

    Args:
        model: The DDP-wrapped model to register the hook on.
        global_step: Object whose ``step`` attribute the hook reads.
        gradient_acc: Accumulation interval the hook takes ``step`` modulo.
    """
    hook_state = (global_step, gradient_acc)
    model.register_comm_hook(hook_state, gradient_acc_hook)


def register_2d_allreduce(
    model: DistributedDataParallel,
    global_step,
    gradient_acc):
    """Create the two-level process groups and register ``hierarchical_allreduce_hook``.

    Ranks are split into local groups of 8 consecutive ranks (presumably the
    GPUs of one node) and into 8 reduce groups, each holding the ranks that
    share the same position inside their local group. Every rank creates
    every group, in the same order.

    Args:
        model: The DDP-wrapped model to register the hook on.
        global_step: ``None`` or an object whose ``step`` attribute the hook reads.
        gradient_acc: Accumulation interval the hook takes ``step`` modulo.

    Raises:
        ValueError: If the world size is not a multiple of 8. Some ranks
            would otherwise belong to no local group, which only surfaces
            later inside the hook.
    """
    assert distributed.is_initialized()
    # Must match the local group size hard-coded in hierarchical_allreduce_hook.
    local_world_size = 8
    world_size = distributed.get_world_size()
    if world_size % local_world_size != 0:
        raise ValueError(
            f"world size ({world_size}) must be a multiple of {local_world_size} "
            "for the 2D all-reduce hook."
        )
    num_nodes = world_size // local_world_size

    # Local groups first, then reduce groups: creation order is kept identical
    # on every rank.
    list_group_local = [
        distributed.new_group(
            [idx_local_rank + idx_node * local_world_size
             for idx_local_rank in range(local_world_size)],
            backend="nccl")
        for idx_node in range(num_nodes)
    ]
    list_group_reduce = [
        distributed.new_group(
            [idx_local_rank + idx_node * local_world_size
             for idx_node in range(num_nodes)],
            backend="nccl")
        for idx_local_rank in range(local_world_size)
    ]
    if distributed.get_rank() == 0:
        import logging
        logging.info(list_group_local)
        logging.info(list_group_reduce)

    model.register_comm_hook(
        (list_group_local, list_group_reduce, global_step, gradient_acc),
        hierarchical_allreduce_hook)


def register_posted_2d_allreduce(
    model: DistributedDataParallel,
    global_step,
    gradient_acc,
    start_local_sgd_iter):
    """Create the two-level process groups and register ``posted_hierarchical_allreduce_hook``.

    Ranks are split into local groups of 8 consecutive ranks (presumably the
    GPUs of one node) and into 8 reduce groups, each holding the ranks that
    share the same position inside their local group. Every rank creates
    every group, in the same order.

    Args:
        model: The DDP-wrapped model to register the hook on.
        global_step: Object whose ``step`` attribute the hook reads.
        gradient_acc: Accumulation interval the hook takes ``step`` modulo.
        start_local_sgd_iter: Step from which the hook stops communicating.

    Raises:
        ValueError: If the world size is not a multiple of 8. Some ranks
            would otherwise belong to no local group, which only surfaces
            later inside the hook.
    """
    assert distributed.is_initialized()
    # Must match the local group size hard-coded in
    # posted_hierarchical_allreduce_hook.
    local_world_size = 8
    world_size = distributed.get_world_size()
    if world_size % local_world_size != 0:
        raise ValueError(
            f"world size ({world_size}) must be a multiple of {local_world_size} "
            "for the 2D all-reduce hook."
        )
    num_nodes = world_size // local_world_size

    # Local groups first, then reduce groups: creation order is kept identical
    # on every rank.
    list_group_local = [
        distributed.new_group(
            [idx_local_rank + idx_node * local_world_size
             for idx_local_rank in range(local_world_size)],
            backend="nccl")
        for idx_node in range(num_nodes)
    ]
    list_group_reduce = [
        distributed.new_group(
            [idx_local_rank + idx_node * local_world_size
             for idx_node in range(num_nodes)],
            backend="nccl")
        for idx_local_rank in range(local_world_size)
    ]
    if distributed.get_rank() == 0:
        import logging
        logging.info(list_group_local)
        logging.info(list_group_reduce)

    model.register_comm_hook(
        (list_group_local, list_group_reduce, global_step, gradient_acc, start_local_sgd_iter),
        posted_hierarchical_allreduce_hook)


if __name__ == "__main__":
    # Manual check: print the annotations of the custom hook next to those of
    # torch's built-in fp16 compression hook so they can be compared by eye.
    import inspect
    from torch.distributed.algorithms.ddp_comm_hooks.default_hooks import \
        fp16_compress_hook

    print(inspect.signature(hierarchical_allreduce_hook).return_annotation)

    reference_sig = inspect.signature(fp16_compress_hook)
    print(reference_sig.parameters["bucket"].annotation)
    print(reference_sig.return_annotation)
