# -*- coding: utf-8 -*-
import typing
from abc import abstractmethod
from logging import getLogger

import torch
import torch.distributed as dist

# Public API of this module: the server-side optimizer and its client counterpart.
__all__ = ["AFedOptServer", "AFedOptClient"]


class AFedOptClient(torch.optim.Optimizer):
    """Client-side optimizer of the AFedOpt client/server pair.

    Each call to :meth:`step` applies a plain SGD update (with optional L2
    weight decay) to the local parameters.  Whenever the internal step
    counter is a multiple of ``local_step`` the client additionally calls
    :meth:`exchange`, which sends the local parameters to the server rank and
    overwrites them with the parameters received back.

    Args:
        params: Iterable of parameters or parameter-group dicts, as accepted
            by ``torch.optim.Optimizer``.
        rank: Rank of this process. Stored on the instance; not otherwise
            read by this class.
        local_step: Number of ``step`` calls between two exchanges with the
            server. Must be >= 1.
        lr: Learning rate of the local SGD update.
        beta1, beta2, tau: Stored in the param-group defaults; not read by
            the client itself.
        weight_decay: L2 coefficient added to the gradient when > 0.
        server_node_rank: Rank of the server process. ``None`` or a negative
            value selects standalone mode (no communication).
        tag_offset: First message tag used for the exchange, kept so that
            send/recv tags can be separated from other wrapper classes.

    Raises:
        ValueError: If ``local_step`` is smaller than 1.
    """

    def __init__(self, params, rank: int, local_step: int, lr: float = 0.001,
                 beta1: float = 0.9, beta2: float = 0.999, tau: float = 1e-3,
                 weight_decay: float = 0.0,
                 server_node_rank: typing.Optional[int] = None,
                 tag_offset: int = 0):
        # step() computes ``step_counter % local_step``; reject values that
        # would otherwise surface later as a ZeroDivisionError.
        if local_step < 1:
            raise ValueError(
                f"local_step must be a positive integer, got {local_step!r}")

        defaults = dict(lr=lr, beta1=beta1, beta2=beta2, tau=tau,
                        weight_decay=weight_decay)
        super().__init__(params, defaults)
        self.logger = getLogger(__name__)

        self.server_node_rank = server_node_rank
        self.is_standalone = not (isinstance(server_node_rank, int)
                                  and server_node_rank >= 0)

        self.rank = rank
        self.local_step = local_step
        self.step_counter = 0
        # The step counter wraps around at this value.
        self._max_step = self.local_step*int(1e10)
        self.comm_cnt = 0  # cumulative number of communication rounds
        self.tag_offset = tag_offset

        # A standalone client never communicates, so it must not require an
        # initialised process group just to look up the backend name.
        if self.is_standalone and not (dist.is_available()
                                       and dist.is_initialized()):
            self.backend = None
        else:
            self.backend = dist.get_backend()

    def _shared_params(self):
        """Yield every parameter that takes part in the server exchange."""
        for group in self.param_groups:
            for p in group["params"]:
                if p.requires_grad:
                    yield p

    @torch.no_grad()
    def step(self, closure: typing.Callable = None, **kwargs):
        """Run one local SGD step and, when due, exchange with the server.

        Args:
            closure: Accepted for ``torch.optim.Optimizer`` API compatibility
                only; it is never evaluated.
            **kwargs: Ignored.

        Returns:
            Always ``None`` (no loss is computed here).
        """
        loss = None

        for group in self.param_groups:
            lr = group["lr"]
            weight_decay = group["weight_decay"]
            for p in group["params"]:
                if not p.requires_grad or p.grad is None:
                    continue
                grad = p.grad
                if weight_decay > 0.0:
                    # Out-of-place on purpose: leave ``p.grad`` untouched so
                    # the decay term is not baked into the stored gradient.
                    grad = grad.add(p, alpha=weight_decay)

                p.add_(grad, alpha=-lr)

        if self.step_counter % self.local_step == 0:
            self.exchange()

        self.step_counter = (self.step_counter + 1) % self._max_step
        return loss

    @torch.no_grad()
    def exchange(self):
        """Send the local model to the server and adopt the returned model.

        Does nothing in standalone mode.  Every parameter with
        ``requires_grad`` is exchanged, whether or not it currently holds a
        gradient, so that the message tags line up with the server, which
        selects parameters by ``requires_grad`` alone.
        """
        if self.is_standalone:  # standalone mode
            return

        params = list(self._shared_params())
        if not params:
            # Nothing to exchange; avoid handing an empty op list to
            # batch_isend_irecv.
            return

        # Tensors are staged on the CPU when the backend is gloo.
        via_cpu = self.backend == "gloo"
        ops, buffers = [], []

        # send local model to server.
        for tag, p in enumerate(params, start=self.tag_offset):
            ops.append(dist.P2POp(dist.isend, p.cpu() if via_cpu else p,
                                  self.server_node_rank, tag=tag))

        # recv global model from server (same tag sequence as the sends).
        for tag, p in enumerate(params, start=self.tag_offset):
            buff = torch.zeros_like(p,
                                    device="cpu" if via_cpu else p.device)
            ops.append(dist.P2POp(dist.irecv, buff,
                                  self.server_node_rank, tag=tag))
            buffers.append(buff)

        for work in dist.batch_isend_irecv(ops):
            work.wait()

        # overwrite the local model with the server's model.
        for p, buff in zip(params, buffers):
            p.data.copy_(buff.to(dtype=p.dtype, device=p.device))

        self.comm_cnt += 1


class AFedOptServer(torch.optim.Optimizer):
    """Server-side base optimizer of the AFedOpt client/server pair.

    Every ``local_step`` calls to :meth:`step` the server receives the
    parameters of every client, averages their difference to the current
    global parameters, updates the per-parameter ``"momentum"`` and
    ``"velocity"`` state from that averaged delta, applies
    ``p += lr * momentum / (sqrt(velocity) + tau)`` and sends the new global
    parameters back to all clients.

    Subclasses must override :meth:`update_velocity` so that it populates
    ``self.state[p]["velocity"]``.

    Args:
        params: Iterable of parameters or parameter-group dicts, as accepted
            by ``torch.optim.Optimizer``.
        rank: Rank of this process. Stored on the instance; not otherwise
            read by this class.
        local_step: Number of ``step`` calls between two communication
            rounds. Must be >= 1.
        lr: Server learning rate applied to the adaptive update.
        beta1: Exponential decay rate of the momentum estimate.
        beta2: Stored in the param-group defaults; not read by this base
            class.
        tau: Constant added to ``sqrt(velocity)`` in the denominator.
        weight_decay: Stored in the param-group defaults; not read by this
            base class.
        client_node_ranks: Ranks of the participating clients. Must contain
            at least one rank.
        tag_offset: First message tag used for the exchange, kept so that
            send/recv tags can be separated from other wrapper classes.

    Raises:
        ValueError: If ``local_step`` is smaller than 1, or if
            ``client_node_ranks`` is ``None`` or empty.
    """

    def __init__(self, params, rank: int, local_step: int, lr: float = 0.001,
                 beta1: float = 0.9, beta2: float = 0.999, tau: float = 1e-3,
                 weight_decay: float = 0.0,
                 client_node_ranks: typing.Optional[typing.List[int]] = None,
                 tag_offset: int = 0):
        # step() computes ``step_counter % local_step``; reject values that
        # would otherwise surface later as a ZeroDivisionError.
        if local_step < 1:
            raise ValueError(
                f"local_step must be a positive integer, got {local_step!r}")
        # Validate with an explicit raise (asserts vanish under ``python -O``)
        # and before touching torch.distributed, so the caller gets this
        # message rather than a process-group error.
        if client_node_ranks is None or len(client_node_ranks) == 0:
            raise ValueError(
                "AFedOptServer cannot run standalone: client_node_ranks "
                "must contain at least one client rank.")

        defaults = dict(lr=lr, beta1=beta1, beta2=beta2, tau=tau,
                        weight_decay=weight_decay)
        super().__init__(params, defaults)
        self.logger = getLogger(__name__)

        self.client_node_ranks = client_node_ranks
        # Always False after the validation above; kept as an attribute for
        # symmetry with the client.
        self.is_standalone = False

        self.rank = rank
        self.local_step = local_step
        self.step_counter = 0
        # The step counter wraps around at this value.
        self._max_step = self.local_step*int(1e10)
        self.comm_cnt = 0  # cumulative number of communication rounds
        self.tag_offset = tag_offset
        self.backend = dist.get_backend()

    def _shared_params(self):
        """Yield every parameter that takes part in the client exchange."""
        for group in self.param_groups:
            for p in group["params"]:
                if p.requires_grad:
                    yield p

    @torch.no_grad()
    def step(self, closure: typing.Callable = None, **kwargs):
        """Run one server step; communicate and update the model when due.

        Args:
            closure: Accepted for ``torch.optim.Optimizer`` API compatibility
                only; it is never evaluated.
            **kwargs: Ignored.

        Returns:
            Always ``None`` (no loss is computed here).
        """
        loss = None

        # The server performs no gradient-based update of its own; it only
        # acts on communication rounds.
        if self.step_counter % self.local_step == 0:
            deltas = self.recv_param()
            # update momentum from delta.
            self.update_momentum(deltas)
            # update velocity from delta
            self.update_velocity(deltas)

            for group in self.param_groups:
                # tau plays the role of eps in the denominator.
                lr, tau = group["lr"], group["tau"]
                for p in group["params"]:
                    if not p.requires_grad:
                        continue
                    state = self.state[p]
                    # "momentum" and "velocity" must have been initialised by
                    # update_momentum / update_velocity above.
                    m, v = state["momentum"], state["velocity"]
                    p.data.add_((m/(v.sqrt() + tau)), alpha=lr)

            self.send_param()
            self.comm_cnt += 1

        self.step_counter = (self.step_counter + 1) % self._max_step

        return loss

    @torch.no_grad()
    def recv_param(self):
        """Receive every client's parameters and return the averaged deltas.

        Returns:
            Dict mapping each exchanged parameter to the mean over clients of
            ``client_param - server_param``.
        """
        params = list(self._shared_params())
        # Tensors are staged on the CPU when the backend is gloo.
        via_cpu = self.backend == "gloo"

        ops = []
        recved = {p: {} for p in params}
        for node_id in self.client_node_ranks:
            # The tag sequence restarts at tag_offset for every client.
            for tag, p in enumerate(params, start=self.tag_offset):
                buff = torch.zeros_like(p, dtype=p.dtype,
                                        device="cpu" if via_cpu else p.device)
                ops.append(dist.P2POp(dist.irecv, buff, node_id, tag=tag))
                recved[p][node_id] = buff

        if ops:  # avoid handing an empty op list to batch_isend_irecv
            for work in dist.batch_isend_irecv(ops):
                work.wait()

        deltas = {}  # collected update differences
        for p in params:
            # delta = new param - old param
            deltas[p] = torch.stack(
                [recved[p][node_id].to(device=p.device, dtype=p.dtype) - p.data
                 for node_id in self.client_node_ranks], dim=0).mean(dim=0)

        return deltas

    @torch.no_grad()
    def send_param(self):
        """Send the current global parameters to every client."""
        params = list(self._shared_params())
        via_cpu = self.backend == "gloo"

        ops = []
        for node_id in self.client_node_ranks:
            for tag, p in enumerate(params, start=self.tag_offset):
                ops.append(dist.P2POp(dist.isend, p.cpu() if via_cpu else p,
                                      node_id, tag=tag))

        if ops:  # avoid handing an empty op list to batch_isend_irecv
            for work in dist.batch_isend_irecv(ops):
                work.wait()

    @torch.no_grad()
    def update_momentum(self, deltas: typing.Dict[torch.Tensor, torch.Tensor]):
        """Update the exponential moving average of the deltas in place.

        For every parameter present in ``deltas``:
        ``momentum = beta1 * momentum + (1 - beta1) * delta``.  The momentum
        buffer is created as zeros on first use.

        Args:
            deltas: Mapping from parameter to its averaged delta.
        """
        for group in self.param_groups:
            beta1 = group["beta1"]
            for p in group["params"]:
                if not p.requires_grad or p not in deltas:
                    continue

                state, delta = self.state[p], deltas[p]
                if "momentum" not in state:
                    state["momentum"] = torch.zeros_like(p)
                momentum = state["momentum"]
                momentum.mul_(beta1).add_(delta, alpha=(1.0-beta1))

    @abstractmethod
    def update_velocity(self, deltas: typing.Dict[torch.Tensor, torch.Tensor]):
        """Update ``self.state[p]["velocity"]`` from the averaged deltas.

        Must be overridden by subclasses; :meth:`step` reads the
        ``"velocity"`` state entry right after calling this method.

        NOTE(review): this class does not declare an ABC metaclass, so
        ``@abstractmethod`` is presumably not enforced at instantiation time;
        the ``NotImplementedError`` below is the effective guard.

        Args:
            deltas: Mapping from parameter to its averaged delta.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError("Abstract method!")
