#! -*- coding: utf-8
import typing

import torch
import torch.distributed

# Public entry points of this module; the per-rank helpers below are internal
# building blocks.
__all__ = [
    "send_model_parameter",
    "recv_model_parameter",
]


@torch.no_grad()
def send_model_parameter(dest_ranks: typing.List[int], model: torch.nn.Module) -> None:
    """Send every trainable parameter of ``model`` to each rank in ``dest_ranks``.

    One ``isend`` per (destination rank, trainable parameter) is built with
    ``send_param``; all of them are launched together through
    ``torch.distributed.batch_isend_irecv`` and this call blocks until every
    send has completed. Each destination must post the matching receives
    (e.g. ``recv_model_parameter``) or this call will not return.

    Args:
        dest_ranks: Ranks to send to. Nothing is sent when this is empty.
        model: Module whose ``requires_grad`` parameters are sent. Nothing is
            sent when it has no trainable parameters.
    """
    # batch_isend_irecv does not accept an empty op list, so bail out before
    # building one when there is nothing to send.
    if not dest_ranks or not any(p.requires_grad for p in model.parameters()):
        return
    tasklist: typing.List[torch.distributed.P2POp] = []

    for rank in dest_ranks:
        tasklist += send_param(rank, model)

    for task in torch.distributed.batch_isend_irecv(tasklist):
        task.wait()


@torch.no_grad()
def send_param(rank: int, model: torch.nn.Module, tagoffset: int = 0) -> typing.List[torch.distributed.P2POp]:
    """Describe one ``isend`` to ``rank`` per trainable parameter of ``model``.

    The returned ``P2POp`` objects are only descriptions; they are launched by
    passing them to ``torch.distributed.batch_isend_irecv``.

    Args:
        rank: Destination rank.
        model: Module whose ``requires_grad`` parameters are sent, in
            ``model.parameters()`` order.
        tagoffset: Added to each parameter's position to form its message tag.

    Returns:
        One ``isend`` op per trainable parameter. The i-th op carries tag
        ``i + tagoffset``.
    """
    # The gloo backend is fed CPU tensors; other backends send the parameter
    # tensor as-is.
    to_cpu = torch.distributed.get_backend() == "gloo"
    trainable = [param for param in model.parameters() if param.requires_grad]

    ops: typing.List[torch.distributed.P2POp] = []
    for position, param in enumerate(trainable):
        payload = param.cpu() if to_cpu else param
        ops.append(torch.distributed.P2POp(torch.distributed.isend,
                                           payload,
                                           rank, tag=position + tagoffset))
    return ops


@torch.no_grad()
def recv_model_parameter(src_ranks: typing.List[int], model: torch.nn.Module) -> None:
    """Replace ``model``'s trainable parameters with the mean of peers' copies.

    For every rank in ``src_ranks`` one receive buffer per trainable parameter
    is posted via ``recv_param``. All receives are launched together through
    ``torch.distributed.batch_isend_irecv`` and awaited. Each trainable
    parameter is then overwritten with the element-wise mean of the copies
    received from ``src_ranks``. The local value is NOT part of that average.

    Peers are expected to send matching parameters (same order and shapes),
    e.g. with ``send_model_parameter``, since buffers are allocated from the
    local model and matched by position.

    Args:
        src_ranks: Ranks to receive from. Nothing happens when this is empty.
        model: Module whose ``requires_grad`` parameters are overwritten.
            Nothing happens when it has no trainable parameters.
    """
    # batch_isend_irecv does not accept an empty op list, so bail out before
    # building one when there is nothing to receive.
    if not src_ranks or not any(p.requires_grad for p in model.parameters()):
        return
    tasklist: typing.List[torch.distributed.P2POp] = []
    # received[i] collects, per source rank, the buffer for the i-th trainable
    # parameter.
    received: typing.List[typing.List[torch.Tensor]] = [
        [] for p in model.parameters() if p.requires_grad]

    for rank in src_ranks:
        tasks, buffs = recv_param(rank, model)
        tasklist += tasks
        for bucket, buf in zip(received, buffs):
            bucket.append(buf)

    for task in torch.distributed.batch_isend_irecv(tasklist):
        task.wait()

    mean_param(model, received)


@torch.no_grad()
def recv_param(
    rank: int, model: torch.nn.Module, tagoffset: int = 0
) -> typing.Tuple[typing.List[torch.distributed.P2POp], typing.List[torch.Tensor]]:
    """Describe one ``irecv`` from ``rank`` per trainable parameter of ``model``.

    A zero-filled buffer shaped like each trainable parameter is allocated. On
    the gloo backend it lives on the CPU; otherwise it is on the parameter's
    device. An ``irecv`` op targeting that buffer is paired with it. The ops
    are only descriptions; they are launched by passing them to
    ``torch.distributed.batch_isend_irecv``, after which the buffers hold the
    received values.

    Args:
        rank: Source rank.
        model: Module whose ``requires_grad`` parameters define the buffers,
            in ``model.parameters()`` order.
        tagoffset: Added to each parameter's position to form its message tag.

    Returns:
        ``(ops, buffers)``: the i-th op receives into ``buffers[i]`` with tag
        ``i + tagoffset``.
    """
    backend = torch.distributed.get_backend()
    tasklist: typing.List[torch.distributed.P2POp] = []
    buffers: typing.List[torch.Tensor] = []

    for tag, p in enumerate([p for p in model.parameters() if p.requires_grad]):
        # gloo is fed CPU tensors; other backends receive on the parameter's device.
        buff = torch.zeros_like(
            p, device="cpu" if backend == "gloo" else p.device)
        tasklist.append(torch.distributed.P2POp(torch.distributed.irecv,
                                                buff,
                                                rank, tag=tag+tagoffset))
        buffers.append(buff)

    return tasklist, buffers


@torch.no_grad()
def mean_param(model: torch.nn.Module, params: typing.List[typing.List[torch.Tensor]]) -> None:
    """Overwrite each trainable parameter with the mean of its candidate tensors.

    Args:
        model: Module whose ``requires_grad`` parameters are overwritten, in
            ``model.parameters()`` order.
        params: ``params[i]`` is a non-empty list of tensors shaped like the
            i-th trainable parameter. They are stacked along a new leading
            dimension, averaged over it, moved to the parameter's device and
            assigned to ``.data``. The ``Parameter`` object itself is kept.
            Pairing uses ``zip``, so it stops at the shorter of the two
            sequences.
    """
    for dest, recved in zip([p for p in model.parameters() if p.requires_grad], params):
        dest.data = torch.mean(torch.stack(recved, dim=0),
                               dim=0).to(dest.device)


@torch.no_grad()
def merge_model_parameter(self_rank: int, target_ranks: typing.List[int], model: torch.nn.Module) -> None:
    """Average ``model``'s trainable parameters with those of peer ranks.

    The local trainable parameters are sent to every rank in ``target_ranks``
    other than ``self_rank``, and the same parameters are received from each
    of them. Sends and receives are launched together through
    ``torch.distributed.batch_isend_irecv`` and awaited. Each trainable
    parameter is then replaced by the element-wise mean of the local value
    and all received copies.

    Peers are expected to run this same exchange with matching parameters
    (same order and shapes), since buffers are allocated from the local model
    and matched by position.

    Args:
        self_rank: This process's rank. Occurrences in ``target_ranks`` are
            skipped.
        target_ranks: Ranks to exchange with.
        model: Module whose ``requires_grad`` parameters are exchanged and
            overwritten.

    Nothing happens when there is no peer other than ``self_rank`` or the
    model has no trainable parameters.
    """
    peers = [rank for rank in target_ranks if rank != self_rank]
    # batch_isend_irecv does not accept an empty op list, so bail out before
    # building one when there is nothing to exchange.
    if not peers or not any(p.requires_grad for p in model.parameters()):
        return
    backend = torch.distributed.get_backend()
    tasklist: typing.List[torch.distributed.P2POp] = []
    # Each bucket starts with the local value so the mean includes it. Under
    # gloo the local value is taken on the CPU to match the received buffers.
    params = [[p.cpu() if backend == "gloo" else p] for p in model.parameters()
              if p.requires_grad]

    for rank in peers:
        tasklist += send_param(rank, model)

    for rank in peers:
        tasks, buffs = recv_param(rank, model)
        tasklist += tasks
        for param, buf in zip(params, buffs):
            param.append(buf)

    for task in torch.distributed.batch_isend_irecv(tasklist):
        task.wait()

    mean_param(model, params)
