from torch.nn.parallel import DistributedDataParallel
from torch import distributed
from typing import List
import torch

def hierarchical_allreduce_hook(
    process_group_state: tuple,
    bucket: distributed.GradBucket
) -> torch.futures.Future[torch.Tensor]:
    """DDP comm hook: intra-node ReduceScatter -> inter-node fp16 Allreduce -> intra-node Allgather.

    The bucket is averaged over all ranks in three stages:

    1. every rank pre-divides by the node size and reduce-scatters (SUM) inside
       its node, so each local rank holds the node mean of one shard;
    2. that shard is cast to fp16, divided by the number of nodes and
       all-reduced (SUM) across the nodes' matching local ranks;
    3. the shards are all-gathered inside the node and written back into the
       bucket buffer (without the zero padding added for stage 1).

    Args:
        process_group_state: ``(list_local_group, list_reduce_group, global_step,
            gradient_acc)``. ``list_local_group[n]`` is the group of node ``n``
            (indexed by ``rank // local_world_size``); ``list_reduce_group[i]`` is
            the cross-node group of local rank ``i`` (indexed by
            ``rank % local_world_size``). One reduce group exists per local rank,
            so ``len(list_reduce_group)`` is taken as the node size.
            ``global_step`` is ``None`` or an object with an integer
            ``global_step`` attribute; communication only happens when it is
            ``None`` or ``global_step.global_step % gradient_acc == 0``.
        bucket: gradient bucket handed to the hook by DDP.

    Returns:
        A future resolving to ``bucket.buffer()``. On communication steps the
        buffer is overwritten with the averaged gradients; on accumulation steps
        it is returned untouched.
    """
    assert len(process_group_state) == 4, (
        "process_group_state must be "
        "(list_local_group, list_reduce_group, global_step, gradient_acc)"
    )
    list_local_group, list_reduce_group, global_step, gradient_acc = process_group_state
    buffer = bucket.buffer()

    # Gradient-accumulation step: skip communication, hand back the local gradients.
    if global_step is not None and global_step.global_step % gradient_acc != 0:
        skipped: torch.futures.Future[torch.Tensor] = torch.futures.Future()
        skipped.set_result(buffer)
        return skipped

    local_world_size = len(list_reduce_group)
    rank = distributed.get_rank()
    group_local = list_local_group[rank // local_world_size]
    group_reduce = list_reduce_group[rank % local_world_size]
    # Number of nodes taking part in the cross-node allreduce. This used to be a
    # hard-coded 8, which mis-scaled gradients on any cluster that is not exactly 8 nodes.
    reduce_world_size = distributed.get_world_size(group=group_reduce)

    # reduce_scatter needs equal-sized chunks: zero-pad the flat buffer up to a
    # multiple of the node size, keeping the buffer's own dtype and device.
    numel = buffer.size(0)
    rest = (local_world_size - numel % local_world_size) % local_world_size
    if rest == 0:
        tensor = buffer
    else:
        tensor = torch.zeros(numel + rest, dtype=buffer.dtype, device=buffer.device)
        tensor[:numel] = buffer

    # Pre-divide so the intra-node SUM yields the intra-node mean.
    tensor.div_(local_world_size)
    shard = torch.empty(
        tensor.size(0) // local_world_size, dtype=tensor.dtype, device=tensor.device
    )
    scatter_fut = distributed.reduce_scatter(
        output=shard,
        input_list=list(tensor.chunk(local_world_size)),
        group=group_local,
        async_op=True,
    ).get_future()

    def _allreduce_across_nodes(fut: torch.futures.Future) -> torch.Tensor:
        # Stage 2: fp16-compressed mean of this rank's shard across all nodes.
        fut.wait()
        compressed = shard.to(torch.float16).div_(reduce_world_size)
        distributed.all_reduce(
            tensor=compressed,
            op=distributed.ReduceOp.SUM,
            group=group_reduce,
            async_op=True,
        ).get_future().wait()
        return compressed

    def _allgather_within_node(fut: torch.futures.Future) -> torch.Tensor:
        # Stage 3: reassemble the full (padded) bucket on every local rank.
        reduced = fut.wait().to(tensor.dtype)
        gathered = torch.empty_like(tensor)  # every chunk is overwritten by all_gather
        distributed.all_gather(
            list(gathered.chunk(local_world_size)),
            reduced,
            group=group_local,
            async_op=True,
        ).get_future().wait()
        return gathered

    def _write_back(fut: torch.futures.Future) -> torch.Tensor:
        # Drop the zero padding and store the averaged gradients in the bucket.
        gathered = fut.wait()
        buffer.copy_(gathered[:numel])
        return buffer

    return (
        scatter_fut
        .then(_allreduce_across_nodes)
        .then(_allgather_within_node)
        .then(_write_back)
    )


def register_2d_allreduce(
    model: DistributedDataParallel,
    global_step,
    gradient_acc
    ):
    """Build node-local and cross-node NCCL groups and install ``hierarchical_allreduce_hook``.

    Groups assume 8 ranks per node laid out node-major: node ``n`` owns ranks
    ``8n .. 8n+7``. Cross-node group ``i`` holds local rank ``i`` of every node.
    NOTE(review): confirm the launcher assigns ranks in this order.

    Args:
        model: DDP-wrapped model that receives the communication hook.
        global_step: ``None`` or an object with an integer ``global_step``
            attribute, read by the hook to decide whether to communicate.
        gradient_acc: number of gradient-accumulation steps between syncs.

    Raises:
        AssertionError: if the default process group is not initialized.
        ValueError: if the world size is not a multiple of 8. Previously this
            silently produced incomplete groups that broke the hook on some ranks.
    """
    assert distributed.is_initialized(), (
        "torch.distributed must be initialized before registering the hook"
    )
    gpus_per_node = 8
    world_size = distributed.get_world_size()
    if world_size % gpus_per_node != 0:
        raise ValueError(
            f"world size {world_size} is not a multiple of {gpus_per_node} GPUs per node"
        )
    num_nodes = world_size // gpus_per_node

    # new_group is collective: every rank must create every group, in the same order.
    list_group_local = [
        distributed.new_group(
            [idx_local_rank + idx_node * gpus_per_node for idx_local_rank in range(gpus_per_node)],
            backend="nccl",
        )
        for idx_node in range(num_nodes)
    ]
    list_group_reduce = [
        distributed.new_group(
            [idx_local_rank + idx_node * gpus_per_node for idx_node in range(num_nodes)],
            backend="nccl",
        )
        for idx_local_rank in range(gpus_per_node)
    ]

    if distributed.get_rank() == 0:
        import logging
        logging.info(list_group_local)
        logging.info(list_group_reduce)

    model.register_comm_hook(
        (list_group_local, list_group_reduce, global_step, gradient_acc),
        hierarchical_allreduce_hook,
    )


if __name__ == "__main__":
    # Debug aid: print this hook's return annotation next to the annotations of
    # PyTorch's reference fp16 hook so the two signatures can be compared by eye.
    from inspect import signature

    print(signature(hierarchical_allreduce_hook).return_annotation)

    from torch.distributed.algorithms.ddp_comm_hooks.default_hooks import fp16_compress_hook

    reference = signature(fp16_compress_hook)
    for annotation in (reference.parameters["bucket"].annotation, reference.return_annotation):
        print(annotation)
