#! -*- coding: utf-8
import typing
from logging import getLogger

import numpy as np
import torch
import torch.distributed as dist

from .utils import get_group_value

# Public API: only the two concrete optimizers are exported; the
# AFedDynOptimizer mixin below is not listed here.
__all__ = ["FedDynServer", "FedDynClient"]


class AFedDynOptimizer:
    """Mixin exposing logging values that FedDyn optimizers store per group.

    Each property looks a key up across ``self.param_groups`` through
    ``get_group_value``; how several groups are combined is decided by that
    helper, not here.
    """

    def _group_value(self, key):
        """Delegate the lookup of ``key`` over all param groups."""
        groups = self.param_groups
        return get_group_value(groups, key)

    @property
    def loss_correct(self):
        """Value recorded under the ``"loss_correct"`` group key."""
        return self._group_value("loss_correct")

    @property
    def grad_norm(self):
        """Value recorded under the ``"grad_norm"`` group key."""
        return self._group_value("grad_norm")


class FedDynServer(torch.optim.Optimizer, AFedDynOptimizer):
    """Server side of a FedDyn-style federated optimizer.

    The server never takes a gradient step. Every ``local_step`` calls to
    :meth:`step` it receives one tensor per trainable parameter from every
    client, adds the client-average of those tensors to the global
    parameters in place, and then sends the updated parameters back to all
    clients.

    Args:
        params: Parameters (or parameter groups) holding the global model.
        rank: Rank of this process (stored only).
        local_step: Number of :meth:`step` calls between communication rounds.
        lr, alpha, max_grad_norm: Stored as group defaults; the server update
            itself does not read them.
        client_node_ranks: Ranks of the client processes. Must be non-empty;
            ``None`` (the default) or an empty list fails the standalone check.
        tag_offset: First P2P tag used, so that P2P traffic can be kept apart
            from other wrapper classes sharing the process group.

    Raises:
        AssertionError: If no client ranks are given.
    """

    def __init__(self, params, rank: int, local_step: int, lr: float = 0.001,
                 alpha: float = 0.0, max_grad_norm: typing.Optional[float] = None,
                 client_node_ranks: typing.Optional[typing.List[int]] = None,
                 tag_offset: int = 0):
        defaults = dict(lr=lr, alpha=alpha, max_grad_norm=max_grad_norm)
        super().__init__(params, defaults)
        self.logger = getLogger(__name__)

        # Copy the caller's ranks (None -> empty) instead of keeping a shared
        # mutable default list.
        self.client_node_ranks = list(client_node_ranks or [])
        self.is_standalone = len(self.client_node_ranks) == 0

        self.rank = rank
        self.local_step = local_step
        self.step_counter = 0
        self._max_step = self.local_step * int(1e10)
        self.comm_cnt = 0  # cumulative number of communication rounds
        # Kept configurable so send/recv tags can be separated from other
        # wrapper classes using the same process group.
        self.tag_offset = tag_offset
        self.backend = dist.get_backend()

        assert not self.is_standalone, \
            "FedDyn Server Optimizer can't work on stand alone."

    def _trainable_params(self):
        """Yield the parameters exchanged with clients, in a fixed order."""
        for group in self.param_groups:
            for p in group["params"]:
                if p.requires_grad:
                    yield p

    def _comm_device(self, p):
        """Device for P2P buffers: CPU for gloo, otherwise ``p``'s device."""
        return "cpu" if self.backend == "gloo" else p.device

    @staticmethod
    def _run_p2p(ops):
        """Issue a batch of P2P ops and wait for all of them.

        ``dist.batch_isend_irecv`` cannot take an empty list, so an empty
        batch is skipped.
        """
        if not ops:
            return
        for work in dist.batch_isend_irecv(ops):
            work.wait()

    @torch.no_grad()
    def step(self, closure: typing.Callable = None, **kwargs):
        """Advance the step counter and run a communication round when due.

        ``closure`` is accepted for API compatibility but is not evaluated.
        The global parameters change only inside :meth:`recv_params`.

        Returns:
            Always ``None``.
        """
        if self.step_counter > 0 and self.step_counter % self.local_step == 0:
            self.recv_params()  # aggregate client tensors into the global model
            self.send_params()  # send the updated global model back
            self.comm_cnt += 1

        self.step_counter = (self.step_counter + 1) % self._max_step
        return None

    @torch.no_grad()
    def recv_params(self):
        """Receive one tensor per trainable parameter from every client and
        add the element-wise mean over clients to each parameter in place."""
        params = list(self._trainable_params())
        recved = [[] for _ in params]  # recved[param_idx] -> one buffer per client
        ops = []
        for node_id in self.client_node_ranks:
            # Tags restart at tag_offset for each client; the peer rank
            # disambiguates messages from different clients.
            for i, p in enumerate(params):
                buff = torch.zeros_like(p, device=self._comm_device(p))
                ops.append(dist.P2POp(dist.irecv, buff, node_id,
                                      tag=self.tag_offset + i))
                recved[i].append(buff)
        self._run_p2p(ops)

        for p, buffs in zip(params, recved):
            mean = torch.stack(buffs, dim=0).mean(dim=0)
            p.data.add_(mean.to(device=p.device, dtype=p.dtype))

    @torch.no_grad()
    def send_params(self):
        """Send the current global parameters to every client."""
        # Stage each tensor once (a CPU copy under gloo), not once per client.
        payloads = [p.cpu() if self.backend == "gloo" else p
                    for p in self._trainable_params()]
        ops = [dist.P2POp(dist.isend, t, node_id, tag=self.tag_offset + i)
               for node_id in self.client_node_ranks
               for i, t in enumerate(payloads)]
        self._run_p2p(ops)


class FedDynClient(torch.optim.Optimizer, AFedDynOptimizer):
    """Client side of a FedDyn-style federated optimizer.

    Each :meth:`step`:

    1. Evaluates ``closure`` (which must run ``backward()``).
    2. Adds ``alpha * delta`` to every gradient. This is the gradient of the
       correction term ``<theta, delta>``.
    3. Optionally rescales the total gradient norm down to ``max_grad_norm``.
    4. Takes a plain SGD step with ``lr``.

    Every ``local_step`` steps, :meth:`exchange` does three things. It
    accumulates ``nabla += theta - global``, sends ``nabla`` to the server,
    and overwrites the local parameters with the global parameters it
    receives back. It then sets ``delta = nabla - global``.

    Args:
        params: Parameters (or parameter groups) of the local model.
        rank: Rank of this process (stored only).
        local_step: Number of steps between communication rounds.
        lr: SGD learning rate.
        alpha: Weight of the correction gradient ``delta``.
        max_grad_norm: If given and positive, the total gradient norm of a
            group is rescaled down to this value.
        server_node_rank: Rank of the server. ``None`` or a negative value
            runs standalone, with no communication.
        tag_offset: First P2P tag used, so that P2P traffic can be kept apart
            from other wrapper classes sharing the process group.
    """

    def _trace_log_items(self, group):
        """Return the ``(loss_correct, grad_norm)`` values logged for ``group``."""
        return (group.get("loss_correct", np.nan),
                group.get("grad_norm", np.nan),)

    def __init__(self, params, rank: int, local_step: int, lr: float = 0.001,
                 alpha: float = 0.0, max_grad_norm: typing.Optional[float] = None,
                 server_node_rank: typing.Optional[int] = None,
                 tag_offset: int = 0):
        defaults = dict(lr=lr, alpha=alpha, max_grad_norm=max_grad_norm)
        super().__init__(params, defaults)
        self.logger = getLogger(__name__)

        self.server_node_rank = server_node_rank
        # Without a valid (non-negative int) server rank, exchange() is a no-op.
        self.is_standalone = not (isinstance(server_node_rank, int)
                                  and server_node_rank >= 0)

        self.rank = rank
        self.local_step = local_step
        self.step_counter = 0
        self._max_step = self.local_step * int(1e10)
        self.comm_cnt = 0  # cumulative number of communication rounds
        # Kept configurable so send/recv tags can be separated from other
        # wrapper classes using the same process group.
        self.tag_offset = tag_offset
        self.backend = dist.get_backend()

    @torch.no_grad()
    def step(self, closure: typing.Callable = None, **kwargs):
        """Perform one local update and, when due, one communication round.

        Args:
            closure: Required. Called as ``closure(recalc=False)`` with grad
                enabled. It must compute the loss and call ``backward()``
                itself.

        Returns:
            Whatever ``closure`` returned.
        """
        assert closure is not None
        with torch.enable_grad():
            loss = closure(recalc=False)  # backward() is called inside closure

        for group in self.param_groups:
            self._update_group(group)

        if self.step_counter > 0 and self.step_counter % self.local_step == 0:
            self.exchange()

        self.step_counter = (self.step_counter + 1) % self._max_step
        return loss

    def _update_group(self, group):
        """Correct, optionally clip, and apply the gradients of one group."""
        lr = group["lr"]
        alpha = group["alpha"]
        max_grad_norm = group["max_grad_norm"]

        # Only parameters that are trainable AND received a gradient can be
        # corrected and updated.
        trainables = [p for p in group["params"]
                      if p.requires_grad and p.grad is not None]
        if not trainables:
            return  # nothing to update (torch.cat/stack reject empty lists)

        deltas = [self.state[p]["delta"] if "delta" in self.state[p]
                  else torch.zeros_like(p) for p in trainables]

        # Correction term loss_correct = <theta, delta>. delta is a constant
        # tensor, so d(alpha * loss_correct)/d(theta) is alpha * delta and it
        # is added directly instead of building an autograd graph.
        loss_correct = torch.sum(torch.cat([p.flatten() for p in trainables])
                                 * torch.cat([d.flatten() for d in deltas]))
        for p, delta in zip(trainables, deltas):
            p.grad.add_(delta, alpha=alpha)

        grad_norm = torch.stack([p.grad.norm() for p in trainables]
                                ).norm().detach().cpu().item()
        # Accept any positive number (not only float) as a clipping threshold.
        if max_grad_norm is not None and 0.0 < max_grad_norm < grad_norm:
            coef = max_grad_norm / grad_norm
            for p in trainables:
                p.grad.mul_(coef)

        for p in trainables:
            state = self.state[p]
            if "global" not in state:
                # Parameters before the first local update of this optimizer.
                state["global"] = p.data.clone()
            if "delta" not in state:
                state["delta"] = torch.zeros_like(p)
            p.add_(p.grad, alpha=-lr)

        # Recorded for logging / the AFedDynOptimizer properties.
        group["loss_correct"] = loss_correct.detach().cpu().item()
        group["grad_norm"] = grad_norm
        self.logger.log(5, "step=%d, loss_correct=%f, grad_norm=%s",
                        self.step_counter, *self._trace_log_items(group))

    @torch.no_grad()
    def exchange(self):
        """Run one communication round with the server (no-op if standalone).

        Sends ``nabla`` for every trainable parameter, receives the global
        parameters, copies them into the local model, and recomputes
        ``delta = nabla - global``.
        """
        if self.is_standalone:
            return

        # The server selects parameters by requires_grad only. Using the same
        # selection keeps the per-parameter tags and message counts aligned
        # even if some parameter currently has no gradient.
        params = [p for group in self.param_groups for p in group["params"]
                  if p.requires_grad]
        if not params:
            return  # dist.batch_isend_irecv rejects an empty op list

        # Accumulate nabla += p - global (how far p moved since the last sync).
        for p in params:
            state = self.state[p]
            if "global" not in state:
                # p has not been updated by this optimizer yet.
                state["global"] = p.data.clone()
            if "nabla" not in state:
                state["nabla"] = torch.zeros_like(p)
            state["nabla"].add_((state["global"] - p.data), alpha=-1.0)

        ops, recved = [], []
        # Send nabla to the server.
        for i, p in enumerate(params):
            nabla = self.state[p]["nabla"]
            ops.append(dist.P2POp(dist.isend,
                                  nabla.cpu() if self.backend == "gloo" else nabla,
                                  self.server_node_rank, tag=self.tag_offset + i))
        # Receive the global model from the server.
        for i, p in enumerate(params):
            buff = torch.zeros_like(
                p, device="cpu" if self.backend == "gloo" else p.device)
            ops.append(dist.P2POp(dist.irecv, buff,
                                  self.server_node_rank, tag=self.tag_offset + i))
            recved.append(buff)

        for work in dist.batch_isend_irecv(ops):
            work.wait()
        self.comm_cnt += 1

        # Overwrite the local model with the received global model.
        for p, buff in zip(params, recved):
            global_params = buff.to(dtype=p.dtype, device=p.device)
            state = self.state[p]
            state["global"] = global_params
            p.data.copy_(global_params)
            state["delta"] = state["nabla"] - global_params
