#! -*- coding: utf-8
import typing
from logging import getLogger

import numpy as np
import torch
import torch.distributed as dist

from .dog_utils import calc_mu_dog, calc_vu_dog
from .utils import calc_grad_norm, calc_param_norm, get_group_value

__all__ = ["FedProxDoLEarlyClient", "FedProxDoLEarlyServer", ]


class AFedProxDoLEarlyOptimizer(object):
    """Read-only accessors shared by the FedProx-DoL client and server optimizers.

    Every property forwards to ``get_group_value`` with ``self.param_groups``
    and the name of one param-group entry, so the host class has to expose
    ``param_groups``. How the value is selected or combined across groups is
    decided entirely by ``get_group_value``.
    """

    @property
    def mu(self) -> float:
        """Param-group entry ``"mu"``."""
        return get_group_value(self.param_groups, "mu")

    @property
    def eta(self) -> float:
        """Param-group entry ``"eta"``."""
        return get_group_value(self.param_groups, "eta")

    @property
    def r(self) -> float:
        """Param-group entry ``"r"``."""
        return get_group_value(self.param_groups, "r")

    @property
    def vu(self) -> float:
        """Param-group entry ``"vu"``."""
        return get_group_value(self.param_groups, "vu")

    @property
    def G(self) -> float:
        """Param-group entry ``"G"``."""
        return get_group_value(self.param_groups, "G")

    @property
    def rbar(self) -> float:
        """Param-group entry ``"rbar"``."""
        return get_group_value(self.param_groups, "rbar")

    @property
    def param_norm(self) -> float:
        """Param-group entry ``"param_norm"`` (the stored value, not recomputed here)."""
        return get_group_value(self.param_groups, "param_norm")

    @property
    def grad_norm(self) -> float:
        """Param-group entry ``"grad_norm"``."""
        return get_group_value(self.param_groups, "grad_norm")

    @property
    def rdist(self) -> float:
        """Param-group entry ``"rdist"``."""
        return get_group_value(self.param_groups, "rdist")

    @property
    def client_loss(self):
        """Param-group entry ``"client_loss"``."""
        return get_group_value(self.param_groups, "client_loss")

    @property
    def server_loss(self):
        """Param-group entry ``"server_loss"``."""
        return get_group_value(self.param_groups, "server_loss")

    @property
    def delta_loss(self):
        """Param-group entry ``"delta"`` (note the key differs from the property name)."""
        return get_group_value(self.param_groups, "delta")


class FedProxDoLEarlyClient(torch.optim.Optimizer, AFedProxDoLEarlyOptimizer):
    def _trace_log_items(self, group):
        return (group.get("rbar", np.nan),
                group.get("G", np.nan),
                group.get("eta", np.nan),
                group.get("mu", np.nan),
                group.get("rdist", np.nan),
                group.get("client_loss", np.nan),
                group.get("grad_norm", np.nan),
                group.get("param_norm", np.nan), )

    def __init__(self, params, rank: int, local_step: int,
                 lr: float = 1e-3, vu: float = 1e-4,
                 reps_rel: float = 1e-6, weight_decay: float = 0.0, eps: float = 1e-8,
                 init_eta: typing.Optional[float] = None, init_mu: typing.Optional[float] = None,
                 clip_param_norm: float = None, clip_dist_norm: float = None,
                 fix_eta: bool = False, fix_mu: bool = False,
                 without_model_merge: bool= False,
                 nsample: int = None, use_global_model_on_calc_local_loss: bool = False,
                 server_node_rank: typing.Optional[int] = None,
                 warmup_t: int = 0, warmup_eta: float = None, warmup_mu: float = None, tag_offset: int = 0):
        init_mu = (self.calc_mu(lr, vu) if not isinstance(
            init_mu, float) else init_mu)
        init_eta = (self.calc_mu(lr, vu) if not isinstance(
            init_eta, float) else init_eta)
        defaults = dict(lr=lr, vu=vu,
                        mu=init_mu if warmup_t <= 0 or warmup_mu is None else warmup_mu,
                        init_mu=init_mu, warmup_mu=warmup_mu,
                        eta=init_eta if warmup_t <= 0 or warmup_eta is None else warmup_eta,
                        init_eta=init_eta, warmup_eta=warmup_eta,
                        reps_rel=reps_rel, weight_decay=weight_decay, eps=eps,
                        clip_param_norm=clip_param_norm, clip_dist_norm=clip_dist_norm)
        super().__init__(params, defaults)
        self.logger = getLogger(__name__)

        self.server_node_rank = server_node_rank
        self.is_standalone = not (isinstance(server_node_rank, int)
                                  and server_node_rank >= 0)

        self.rank = rank
        self.local_step = local_step
        self.step_counter = 0
        self._max_step = self.local_step*int(1e10)  # local_stepの整数倍
        self.comm_cnt = 0  # 累積コミュニケーション回数
        self.tag_offset = tag_offset  # 送受信タグを他のラッパクラスと分離できるように実装を残す
        self.is_send_grad_norm = True
        self.is_warmup = True
        self.warmup_t = warmup_t
        self.fix_eta = fix_eta
        self.fix_mu = fix_mu

        self.without_model_merge = without_model_merge

        self.nsample = nsample
        self.use_global_model_on_calc_local_loss = use_global_model_on_calc_local_loss

        self.backend = dist.get_backend()

    @torch.no_grad()
    def step(self, closure: typing.Callable = None, **kwargs):
        """Run one local proximal update and, every ``local_step`` steps, sync with the server.

        Args:
            closure: Required. Must compute the loss and call ``backward()``
                itself. During a sync it is called again with the keyword
                arguments ``recalc`` and ``nsample`` to re-evaluate the loss.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            The most recent value returned by ``closure``.
        """
        assert closure is not None
        with torch.enable_grad():
            loss = closure()  # backward() is called inside the closure

        for group in self.param_groups:
            # Bookkeeping for logging and for the statistics sent to the server.
            group["param_norm"] = calc_param_norm(group)
            group["grad_norm"] = calc_grad_norm(group).to(dtype=torch.float32,
                                                          device="cpu")
            if "grad_sq_norm_sum" not in group:
                group["grad_sq_norm_sum"] = group["grad_norm"]**2
            if "G_buff" not in group:
                group["G_buff"] = []
            group["G_buff"].append(group["grad_norm"]**2)

        if self.is_send_grad_norm:  # send the accumulated squared gradient norm
            if not self.is_standalone:
                _, tasks = self.send_grad_norm(self.tag_offset)
                for task in dist.batch_isend_irecv(tasks):
                    task.wait()
            self.is_send_grad_norm = False

        # Value that will be sent to the server as this client's loss.
        self._store_client_loss(loss)

        # Local update:
        # x <- x - eta * (grad f(x) + (x - xbar) / mu)
        for group in self.param_groups:
            mu, eta = group["mu"], group["eta"]
            weight_decay = group["weight_decay"]

            for p in group["params"]:
                if p.grad is None or not p.requires_grad:
                    continue
                if weight_decay > 0.0:
                    p.grad.add_(p.data, alpha=weight_decay)

                state = self.state[p]
                if "xbar" not in state:
                    state["xbar"] = p.clone().detach()

                if self.fix_mu and mu == 0.0:
                    # A fixed mu of zero disables the proximal term.
                    grad = p.grad
                else:
                    grad = p.grad + (p.data - state["xbar"]) / mu
                p.add_(grad, alpha=-eta)

        # Periodic exchange with the server.
        if self.step_counter > 0 and self.step_counter % self.local_step == 0:
            if not self.is_standalone:
                tag_offset = self.tag_offset
                tag_offset, send_tasks = self.send_param(tag_offset)
                tag_offset, recv_tasks, recv_params = self.recv_param(
                    tag_offset)

                for task in dist.batch_isend_irecv(send_tasks + recv_tasks):
                    task.wait()

                if not self.use_global_model_on_calc_local_loss:
                    # Re-evaluate the loss while the parameters are still the local model.
                    resample = isinstance(self.nsample, int)
                    with torch.enable_grad():
                        loss = closure(recalc=resample, nsample=self.nsample)
                    self._store_client_loss(loss, log_change=resample)

                # Adopt the received global model as both xbar and the current parameters.
                for group in self.param_groups:
                    for p in group["params"]:
                        if not p.requires_grad or p.grad is None:
                            continue
                        if p not in recv_params:
                            continue
                        xbar = recv_params[p].to(device=p.device,
                                                 dtype=p.dtype)
                        self.state[p]["xbar"] = xbar
                        p.data.copy_(xbar)

                if self.use_global_model_on_calc_local_loss:
                    # Re-evaluate the loss with the global model now loaded.
                    with torch.enable_grad():
                        loss = closure(recalc=True, nsample=self.nsample)
                    self._store_client_loss(loss, log_change=True)

                tag_offset, send_loss_tasks = self.send_loss(tag_offset)
                tag_offset, recv_mu_tasks, recv_mus = self.recv_mu(tag_offset)
                tag_offset, recv_eta_tasks, recv_etas = self.recv_eta(
                    tag_offset)

                for task in dist.batch_isend_irecv(send_loss_tasks + recv_mu_tasks + recv_eta_tasks):
                    task.wait()

                # Apply the mu / eta chosen by the server for the next round.
                for group, mu, eta in zip(self.param_groups, recv_mus.detach().cpu().numpy(),
                                          recv_etas.detach().cpu().numpy()):
                    group["mu"], group["eta"] = float(mu), float(eta)
            self.is_send_grad_norm = True

        for group in self.param_groups:
            self.logger.log(5,
                            ", ".join(["step=%d", "rbar=%f", "G=%f", "eta=%f", "mu=%f", "rdist=%f",
                                       "client_loss=%f", "grad_norm=%f", "param_norm=%f", "warmup=%s"]),
                            self.step_counter, *self._trace_log_items(group), self.is_warmup)

        self.step_counter = (self.step_counter + 1) % self._max_step
        if self.is_warmup and self.step_counter >= self.warmup_t:
            # End of warm-up: flush the buffered squared gradient norms and
            # request that the next step sends them.
            self.is_warmup, self.is_send_grad_norm = False, True
            for group in self.param_groups:
                g_buff = group["G_buff"]
                group["grad_sq_norm_sum"] = torch.sum(torch.stack(g_buff, dim=0),
                                                      dim=0) if len(g_buff) > 0 else group["grad_norm"]**2
                group["G_buff"] = []

        return loss

    def _store_client_loss(self, loss, log_change: bool = False):
        """Write ``loss`` (detached to a float CPU tensor if it is a Tensor) into every group."""
        value = loss.detach().cpu().float() if isinstance(loss, torch.Tensor) \
            else loss
        if log_change and self.param_groups:
            previous = self.param_groups[-1].get("client_loss", np.nan)
            self.logger.log(5, "client_loss %f -> %f", previous,
                            value.item() if isinstance(value, torch.Tensor) else value)
        for group in self.param_groups:
            group["client_loss"] = value

    @torch.no_grad()
    def send_param(self, tag_offset: int) -> typing.Tuple[int, typing.List[dist.P2POp]]:
        tasks = []
        i = tag_offset
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None or not p.requires_grad:
                    continue
                tasks.append(dist.P2POp(dist.isend,
                                        p.cpu() if self.backend == "gloo" else p,
                                        self.server_node_rank, tag=i))
                i += 1
        return i, tasks

    @torch.no_grad()
    def recv_param(self, tag_offset: int) -> typing.Tuple[int, typing.List[dist.P2POp], typing.Dict[torch.Tensor, torch.Tensor]]:
        tasks = []
        recieved_params = {}

        i = tag_offset
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None or not p.requires_grad:
                    continue
                buf = torch.zeros_like(p,
                                       device="cpu" if self.backend == "gloo" else p.device)
                tasks.append(dist.P2POp(dist.irecv, buf,
                             self.server_node_rank, tag=i))
                recieved_params[p] = buf
                i += 1
        return i, tasks, recieved_params

    @torch.no_grad()
    def recv_mu(self, tag_offset: int) -> typing.Tuple[int, typing.List[dist.P2POp], torch.Tensor]:
        """Build (but do not launch) the ``irecv`` op for the per-group ``mu`` values.

        The buffer has one slot per param group and is created with the
        ``torch.zeros`` defaults (the original note requires it to be a float
        CPU tensor).

        Returns:
            ``(tag_offset + 1, [op], buffer)``.
        """
        mu_buffer = torch.zeros(len(self.param_groups))
        recv_op = dist.P2POp(dist.irecv, mu_buffer, self.server_node_rank,
                             tag=tag_offset)
        return tag_offset + 1, [recv_op], mu_buffer

    @torch.no_grad()
    def send_loss(self, tag_offset: int) -> typing.Tuple[int, typing.List[dist.P2POp]]:
        # force change dtype to float32 and cpu
        losses = torch.tensor([group["client_loss"]
                              for group in self.param_groups])
        assert not torch.any(torch.isnan(losses)), \
            f"loss is NaN: norms={losses.detach().cpu().numpy()} step={self.step_counter}"

        tasks = [dist.P2POp(dist.isend,
                            losses,  # norms allocated on main memory.,
                            self.server_node_rank, tag=tag_offset)]
        return tag_offset+len(tasks), tasks

    @torch.no_grad()
    def recv_eta(self, tag_offset: int) -> typing.Tuple[int, typing.List[dist.P2POp], torch.Tensor]:
        """Build (but do not launch) the ``irecv`` op for the per-group ``eta`` values.

        The buffer has one slot per param group and is created with the
        ``torch.zeros`` defaults (the original note requires it to be a float
        CPU tensor).

        Returns:
            ``(tag_offset + 1, [op], buffer)``.
        """
        eta_buffer = torch.zeros(len(self.param_groups))
        recv_op = dist.P2POp(dist.irecv, eta_buffer, self.server_node_rank,
                             tag=tag_offset)
        return tag_offset + 1, [recv_op], eta_buffer

    @torch.no_grad()
    def send_grad_norm(self, tag_offset: int) -> typing.Tuple[int, typing.List[dist.P2POp]]:
        # force change dtype to float32 and cpu
        for group in self.param_groups:
            g_buff = group["G_buff"]
            group["grad_sq_norm_sum"] = torch.sum(torch.stack(g_buff, dim=0),
                                                  dim=0) if len(g_buff) > 0 else group["grad_norm"]**2
            # group["grad_sq_norm_sum"] = torch.mean(torch.stack(g_buff, dim=0),
            #                                       dim=0) if len(g_buff) > 0 else group["grad_norm"]**2
            group["G_buff"] = []

        norms = torch.tensor([group["grad_sq_norm_sum"]
                             for group in self.param_groups])
        assert not torch.any(torch.isnan(norms)), \
            f"grad norm is NaN: norms={norms.detach().cpu().numpy()} step={self.step_counter}"

        tasks = [dist.P2POp(dist.isend,
                            norms,  # norms allocated on main memory.,
                            self.server_node_rank, tag=tag_offset)]
        return tag_offset+len(tasks), tasks


class FedProxDoLEarlyServer(torch.optim.Optimizer, AFedProxDoLEarlyOptimizer):
    def _trace_log_items(self, group):
        return (group.get("mu", np.nan),
                group.get("eta", np.nan),
                group.get("r", np.nan),
                group.get("rdist", np.nan),
                group.get("vu", np.nan),
                group.get("vu_eta", np.nan),
                group.get("grad_sq_norm", np.nan),
                group.get("grad_norm", np.nan),
                group.get("param_norm", np.nan),
                group.get("model_sq_norm", np.nan),
                group.get("normalized_model_sq_norm", np.nan),
                group.get("delta", np.nan),
                group.get("client_loss", np.nan),
                group.get("server_loss", np.nan),)

    def __init__(self, params, rank: int, local_step: int,
                 lr: float = 1e-3, vu: float = 1e-4,
                 reps_rel: float = 1e-6, weight_decay: float = 0.0, eps: float = 1e-8,
                 init_eta: typing.Optional[float] = None, init_mu: typing.Optional[float] = None,
                 clip_param_norm: float = None, clip_dist_norm: float = None,
                 fix_eta: bool = False, fix_mu: bool = False,
                 without_model_merge: bool= False,
                 nsample: int = None, use_global_model_on_calc_local_loss: bool = False,
                 client_node_ranks: typing.List[int] = [],
                 warmup_t: int = 0, warmup_eta: float = None, warmup_mu: float = None,
                 tag_offset: int = 0):
        init_mu = (self.calc_mu(lr, vu) if not isinstance(init_mu, float)
                   else init_mu)
        defaults = dict(lr=lr, r=lr, vu=vu,
                        mu=init_mu if warmup_mu is None else warmup_mu,
                        init_mu=init_mu, warmup_eta=warmup_eta,
                        reps_rel=reps_rel, weight_decay=weight_decay, eps=eps, init_eta=init_eta,
                        clip_param_norm=clip_param_norm, clip_dist_norm=clip_dist_norm)
        super().__init__(params, defaults)
        self.logger = getLogger(__name__)

        self.client_node_ranks = client_node_ranks
        self.is_standalone = self.client_node_ranks is None \
            or len(self.client_node_ranks) == 0

        self.rank = rank
        self.local_step = local_step
        self.step_counter = 0
        self._max_step = self.local_step*int(1e10)
        self.comm_cnt = 0  # 累積コミュニケーション回数
        self.tag_offset = tag_offset  # 送受信タグを他のラッパクラスと分離できるように実装を残す
        self.is_recv_grad_norm = True
        self.is_warmup = True
        self.warmup_t = warmup_t
        self.fix_eta = fix_eta
        self.fix_mu = fix_mu

        self.without_model_merge = without_model_merge

        self.backend = dist.get_backend()

    @torch.no_grad()
    def step(self, closure: typing.Callable = None, **kwargs):
        assert closure is not None
        server_loss = None

        if self.is_recv_grad_norm:  # send gradient norm
            _ = self.recv_grad_norm(self.tag_offset)
            self.is_recv_grad_norm = False

        # send/recv parameter
        # self.step_count == 0 is first_step...
        if self.step_counter > 0 and self.step_counter % self.local_step == 0:
            tag_offset = self.tag_offset
            tag_offset = self.recv_param(tag_offset)  # Recv client model
            # Send global model to clients
            tag_offset = self.send_param(tag_offset)
            # recv and mean to group["client_loss"]
            tag_offset = self.recv_loss(tag_offset)

            server_loss = None
            if closure is not None:
                with torch.enable_grad():
                    server_loss = closure()

            for group in self.param_groups:  # Update global model
                group["param_norm"] = calc_param_norm(group)
                # group["grad_norm"] = calc_grad_norm(group).to(dtype=torch.float32,
                #                                               device="cpu")
                if not "vu_eta" in group:
                    group["vu_eta"] = float(group["vu"])
                if not "eta" in group:
                    group["eta"] = float(group["init_eta"])
                server_loss = server_loss.detach().cpu().float().item()
                group["server_loss"] = server_loss  # 記録用に保持

                # if self.is_warmup:  # warmup中はmu, r, v/uなどハイパーパラメータを更新しない
                #     self.logger.log(5, "step=%d, mu=%f, eta=%f, r=%f, vu=%f, vu_eta=%f, grad_norm=%f, param_norm=%f, delta=%f, client_loss=%f, server_loss=%f, warmup=%s",
                #                     self.step_counter, *self._trace_log_items(group), self.is_warmup)
                #     continue

                r, eps = group["r"], group["eps"]
                grad_sq_norm = group["grad_sq_norm"]
                server_loss = group["server_loss"]

                # update r
                rdist = []
                for p in group["params"]:
                    if not p.requires_grad:
                        continue
                    state = self.state[p]
                    if not "init_buffer" in state:
                        state["init_buffer"] = p.clone().detach()
                    if not "xbar0" in state:
                        state["xbar0"] = p.clone().detach()

                    # xbar = state["xbar"] # updated on recv_param
                    # 異なるdtypeが挟まると更新できない
                    rdist.append((p-state["init_buffer"]).norm())

                rdist = torch.stack(rdist).norm().detach().cpu().item()
                clip_dist_norm = group["clip_dist_norm"]
                if isinstance(clip_dist_norm, float):
                    rdist = np.clip(rdist, None, clip_dist_norm)
                r = float(np.max([r, rdist]))

                # update eta
                # 直前時間のグローバルモデルと今の時間のグローバルモデルの差分のnorm
                # grad_norm = torch.stack([torch.norm((p.detach() - self.state[p]["xbar0"])/((group["eta"])*self.local_step))
                #                          for p in group["params"]
                #                          if p.requires_grad]).norm().detach().item()
                # group["grad_norm"] = grad_norm
                # vu_eta = group["vu_eta"]
                # vu_eta = vu_eta + grad_norm ** 2 * self.local_step
                # # vu_eta = vu_eta + grad_sq_norm
                # eta = float(self.calc_mu(r, vu_eta, eps=eps))
                grad_norm = torch.stack([torch.norm((p.detach() - self.state[p]["xbar0"])/((group["eta"])))
                                         for p in group["params"]
                                         if p.requires_grad]).norm().detach().item()
                group["grad_norm"] = grad_norm
                vu_eta = group["vu_eta"]
                # vu_eta = vu_eta + grad_norm ** 2
                # vu_eta = vu_eta + grad_sq_norm
                vu_eta = float(self.calc_vu(r, vu_eta, grad_sq_norm))
                eta = float(self.calc_mu(r, vu_eta, eps=eps))

                # update mu
                vu, mu = group["vu"], group["mu"]
                normalized_model_sq_norm = group["model_sq_norm"] / (2*(mu)) if mu != 0.0 \
                    else group["model_sq_norm"]
                # normalized_model_sq_norm = group["model_sq_norm"] / (2*(mu**2))
                group["normalized_model_sq_norm"] = normalized_model_sq_norm
                # delta = np.max([server_loss - group["client_loss"], 0.0])
                delta = np.max([server_loss - group["client_loss"] - normalized_model_sq_norm,
                                0.0])

                # DoL, DoWL差分をメソッド抽出
                # vu = vu + delta / (mu+eps)
                # mu = r / np.sqrt(vu)
                vu = float(self.calc_vu(r, vu, delta * mu))
                mu = float(self.calc_mu(r, vu, eps=eps))
                # if not self.fix_mu:
                #     mu = r / np.sqrt(vu)

                if not self.is_warmup:
                    if not self.fix_mu:
                        group["mu"] = mu
                    if not self.fix_eta:
                        group["eta"] = eta
                    group["r"], group["vu"] = r, vu
                    group["rdist"] = rdist
                    group["vu_eta"] = vu_eta
                # group["r"], group["vu"] = r, vu
                # group["rdist"] = rdist
                # group["vu_eta"] = vu_eta
                group["delta"] = delta

            self.update_global_model(closure=closure, **kwargs)

            tag_offset = self.send_mu(tag_offset)  # Send global mu to clients.
            # Send global mu to clients.
            tag_offset = self.send_eta(tag_offset)

            self.is_recv_grad_norm = True

        self.step_counter = (self.step_counter + 1) % self._max_step
        if self.is_warmup and self.step_counter >= self.warmup_t:
            self.is_warmup, self.is_recv_grad_norm = False, True
            for group in self.param_groups:
                for p in group["params"]:
                    if not p.requires_grad:
                        continue
                    state = self.state[p]
                    # update init_buffer on current parameters.
                    state["init_buffer"] = p.clone().detach()
                group["mu"] = group["init_mu"]  # reset mu
                group["eta"] = group["init_eta"]  # reset mu
            # for group in self.param_groups:
            #     self.logger.log(5, "end of warmup step=%d, rbar=%f, mu=%f vu=%f, delta=%f",
            #                     self.step_counter, group["rbar"], group["mu"], group["vu"], group["delta"])

        return server_loss

    def calc_vu(self, r: float, vu: float, delta: float) -> float:
        """Return the updated accumulator by delegating to ``calc_vu_dog(r, vu, delta)``."""
        updated_vu = calc_vu_dog(r, vu, delta)
        return updated_vu

    def calc_mu(self, r: float, vu: float, eps: float = 0.0) -> float:
        """Return the coefficient computed by ``calc_mu_dog(r, vu, eps=eps)``."""
        coefficient = calc_mu_dog(r, vu, eps=eps)
        return coefficient

    @torch.no_grad()
    def update_global_model(self, closure: typing.Callable = None, **kwargs):
        for group in self.param_groups:  # Update global model
            self.logger.log(5,
                            ", ".join(["step=%d", "mu=%f", "eta=%f", "r=%f", "rdist=%f", "vu=%f", "vu_eta=%f",
                                       "grad_sq_norm=%f", "grad_norm=%f", "param_norm=%f",
                                       "model_sq_norm=%f", "normalized_model_sq_norm=%f",
                                       "delta=%f", "client_loss=%f", "server_loss=%f", "warmup=%s"]),
                            self.step_counter, *self._trace_log_items(group), self.is_warmup)
            for p in group["params"]:
                if not p.requires_grad:
                    continue
                state = self.state[p]
                state["xbar0"] = p.clone().detach()

    @torch.no_grad()
    def recv_loss(self, tag_offset) -> int:
        """Receive each client's per-group losses and store their mean in ``group["client_loss"]``.

        One float32 CPU buffer (one slot per param group) is received from
        every rank in ``client_node_ranks`` using tag ``tag_offset``; the call
        blocks until all receives complete.

        Returns:
            ``tag_offset + 1``.
        """
        buffers, ops = [], []
        for node_id in self.client_node_ranks:
            # Clients always send this as a float32 CPU tensor.
            landing = torch.zeros(len(self.param_groups),
                                  device="cpu", dtype=torch.float32)
            buffers.append(landing)
            ops.append(dist.P2POp(dist.irecv, landing, node_id, tag=tag_offset))

        for request in dist.batch_isend_irecv(ops):
            request.wait()

        stacked = torch.stack(buffers, dim=0)
        mean_losses = torch.mean(stacked, dim=0).detach().cpu().numpy()

        for group, mean_loss in zip(self.param_groups, mean_losses):
            group["client_loss"] = float(mean_loss)

        return tag_offset + 1

    @torch.no_grad()
    def send_eta(self, tag_offset) -> int:
        """Send every group's ``eta`` (``0.0`` when unset) to all clients and wait for completion.

        The values travel as a single float32 CPU tensor with tag ``tag_offset``.

        Returns:
            ``tag_offset + 1``.
        """
        per_group = [group["eta"] if "eta" in group else 0.0
                     for group in self.param_groups]
        payload = torch.tensor(np.array(per_group),
                               dtype=torch.float32, device="cpu")

        ops = [dist.P2POp(dist.isend, payload, node_id, tag=tag_offset)
               for node_id in self.client_node_ranks]
        for request in dist.batch_isend_irecv(ops):
            request.wait()

        return tag_offset + 1

    @torch.no_grad()
    def recv_grad_norm(self, tag_offset) -> int:
        """Receive each client's per-group squared gradient norms and store their mean.

        One float32 CPU buffer (one slot per param group) is received from
        every rank in ``client_node_ranks`` using tag ``tag_offset``; the mean
        over clients is written to ``group["grad_sq_norm"]``.

        Returns:
            ``tag_offset + 1``.
        """
        buffers, ops = [], []
        for node_id in self.client_node_ranks:
            # Clients always send this as a float32 CPU tensor.
            landing = torch.zeros(len(self.param_groups),
                                  device="cpu", dtype=torch.float32)
            buffers.append(landing)
            ops.append(dist.P2POp(dist.irecv, landing, node_id, tag=tag_offset))

        for request in dist.batch_isend_irecv(ops):
            request.wait()

        stacked = torch.stack(buffers, dim=0)
        mean_sq_norms = torch.mean(stacked, dim=0).detach().cpu().numpy()

        for group, mean_sq_norm in zip(self.param_groups, mean_sq_norms):
            group["grad_sq_norm"] = float(mean_sq_norm)

        return tag_offset + 1

    @torch.no_grad()
    def send_mu(self, tag_offset) -> int:
        """Send every group's ``mu`` (``0.0`` when unset) to all clients and wait for completion.

        The values travel as a single float32 CPU tensor with tag ``tag_offset``.

        Returns:
            ``tag_offset + 1``.
        """
        per_group = [group["mu"] if "mu" in group else 0.0
                     for group in self.param_groups]
        payload = torch.tensor(np.array(per_group),
                               dtype=torch.float32, device="cpu")

        ops = [dist.P2POp(dist.isend, payload, node_id, tag=tag_offset)
               for node_id in self.client_node_ranks]
        for request in dist.batch_isend_irecv(ops):
            request.wait()

        return tag_offset + 1

    @torch.no_grad()
    def send_param(self, tag_offset) -> int:
        tasks = []
        for node_id in self.client_node_ranks:
            i = tag_offset
            for group in self.param_groups:
                for p in group["params"]:
                    if not p.requires_grad:
                        continue
                    tasks.append(dist.P2POp(dist.isend,
                                            p.cpu() if self.backend == "gloo" else p,
                                            node_id, tag=i))
                    i += 1

        for task in dist.batch_isend_irecv(tasks):
            task.wait()

        return i

    @torch.no_grad()
    def recv_param(self, tag_offset) -> int:
        """Receive every client's trainable parameters and average them into the server model.

        Before a parameter is overwritten, the norm of the difference between
        each client's copy and the server's current value is recorded. The
        per-client squared model distance, averaged over clients, is stored in
        every group as ``model_sq_norm``.

        Returns:
            The next free tag after the parameter tags.
        """
        ops = []
        inbox = {}

        next_free_tag = tag_offset
        for node_id in self.client_node_ranks:
            tag = tag_offset
            for group in self.param_groups:
                for p in group["params"]:
                    if not p.requires_grad:
                        continue
                    target = "cpu" if self.backend == "gloo" else p.device
                    landing = torch.zeros_like(p, device=target)
                    ops.append(dist.P2POp(dist.irecv, landing, node_id, tag=tag))
                    if p not in inbox:
                        inbox[p] = {}
                    inbox[p][node_id] = landing
                    tag += 1
            next_free_tag = max(tag, next_free_tag)

        for request in dist.batch_isend_irecv(ops):
            request.wait()

        # Record how far each local model is from the global model it started from.
        distances = {node_id: [] for node_id in self.client_node_ranks}
        for group in self.param_groups:
            for p in group["params"]:
                if p not in inbox:
                    continue

                from_clients = inbox[p]
                current = p.detach().cpu()
                for node_id in self.client_node_ranks:
                    gap = from_clients[node_id].detach().cpu() - current
                    distances[node_id].append(gap.norm().detach())

                # Overwrite the server parameter with the mean over clients.
                stacked = torch.stack(list(from_clients.values()), dim=0)
                averaged = torch.mean(stacked, dim=0)
                p.data.copy_(averaged.to(device=p.device, dtype=p.dtype))

        per_client_sq = [torch.stack(norms).norm()**2
                         for norms in distances.values()]
        model_sq_norm = torch.stack(per_client_sq).mean().detach().cpu().item()
        for group in self.param_groups:
            group["model_sq_norm"] = model_sq_norm

        return next_free_tag
