#! -*- coding: utf-8
import typing
from logging import getLogger

import numpy as np
import torch
import torch.distributed as dist

from .utils import get_group_value

__all__ = ["FedProxServer", "FedProxClient"]


class AFedProxOptimizer:
    """Mixin giving read-only access to server-loss values kept in ``param_groups``.

    The host class must provide ``self.param_groups`` (as
    ``torch.optim.Optimizer`` does). Both properties delegate to
    ``get_group_value`` from ``.utils``; how that helper resolves a value
    across several parameter groups is defined there.
    """

    @property
    def server_loss_best(self):
        """Return the ``"server_loss_best"`` entry as resolved by ``get_group_value``."""
        key = "server_loss_best"
        return get_group_value(self.param_groups, key)

    @property
    def server_loss_out(self):
        """Return the ``"server_loss_out"`` entry as resolved by ``get_group_value``."""
        key = "server_loss_out"
        return get_group_value(self.param_groups, key)


class FedProxServer(torch.optim.Optimizer, AFedProxOptimizer):
    """Server (aggregator) side of FedProx training over ``torch.distributed``.

    The server never applies gradient updates itself. On every
    ``local_step``-th call to :meth:`step` (the very first call is skipped)
    it:

    1. receives the parameters of every rank in ``client_node_ranks`` and
       overwrites its own parameters with their element-wise mean,
    2. optionally (``use_model_marge`` and a closure given) evaluates the
       server loss on the averaged model and keeps that model only if the
       loss does not exceed the best loss seen so far; otherwise it restores
       the best model,
    3. sends the resulting global model back to every client.

    Only parameters with ``requires_grad=True`` are communicated, in
    ``param_groups`` order, with one P2P message tag per parameter starting
    at ``tag_offset``.

    Args:
        params: iterable of parameters or parameter-group dicts.
        rank: rank of this process (stored as ``self.rank``).
        local_step: number of :meth:`step` calls per communication round.
        lr, mu, weight_decay: stored as parameter-group defaults; not used by
            the methods of this class.
        use_model_marge: enable the best-model selection described above.
        client_node_ranks: ranks of the client processes; must be non-empty.
            ``None`` is treated as empty.
        tag_offset: first P2P tag used, so tags can be kept apart from other
            wrappers sharing the same process group.

    Raises:
        AssertionError: if ``client_node_ranks`` is ``None`` or empty.
    """

    def _trace_log_items(self, group):
        """Return ``(server_loss_out, server_loss_best)`` of *group*, NaN if unset."""
        return (group.get("server_loss_out", np.nan),
                group.get("server_loss_best", np.nan),)

    def __init__(self, params, rank: int, local_step: int,
                 lr: float = 0.001, mu: float = 0.1,
                 weight_decay: float = 0.0,
                 use_model_marge: bool = False,
                 client_node_ranks: typing.Optional[typing.List[int]] = None,
                 tag_offset: int = 0):
        defaults = dict(lr=lr, mu=mu, weight_decay=weight_decay)
        super().__init__(params, defaults)
        self.logger = getLogger(__name__)

        # ``None`` default instead of a shared mutable ``[]``.
        self.client_node_ranks = (client_node_ranks
                                  if client_node_ranks is not None else [])
        self.is_standalone = len(self.client_node_ranks) == 0

        self.rank = rank
        self.local_step = local_step
        self.step_counter = 0
        # Wrap the counter around a multiple of local_step so it stays bounded
        # without shifting the communication schedule.
        self._max_step = self.local_step * int(1e10)
        self.comm_cnt = 0  # cumulative communication count (not updated in this class)
        # Kept configurable so send/recv tags can be separated from other
        # wrapper classes using the same process group.
        self.tag_offset = tag_offset
        self.backend = dist.get_backend()

        self.use_model_marge = use_model_marge

        assert not self.is_standalone, \
            "FedProx Server Optimizer can't work on stand alone."

    @torch.no_grad()
    def step(self, closure: typing.Optional[typing.Callable] = None, **kwargs):
        """Run one server step; communicates every ``local_step`` calls.

        Args:
            closure: optional callable re-evaluating the server loss on the
                current parameters. It must return a scalar tensor or a
                number. It is only used when ``use_model_marge`` is True.

        Returns:
            Always ``None`` (no parameter update is performed here).
        """
        loss = None

        if self.step_counter > 0 and self.step_counter % self.local_step == 0:
            # Receive client models and replace our parameters with their mean.
            tag_offset = self.recv_param(self.tag_offset)

            if self.use_model_marge and closure is not None:
                self._update_best_model(closure)

            # Send the global model to all clients.
            self.send_param(tag_offset)

        self.step_counter = (self.step_counter + 1) % self._max_step

        return loss

    @torch.no_grad()
    def _update_best_model(self, closure: typing.Callable) -> None:
        """Recompute the server loss on the current (averaged) parameters.

        The current parameters are recorded as the best model when their loss
        is <= the best loss so far. Otherwise they are overwritten with the
        stored best model, so clients receive the same model as last round.
        """
        with torch.enable_grad():
            server_loss_out = closure()
        # Convert once, before looping over groups, so optimizers with several
        # parameter groups (and closures returning plain numbers) work.
        if isinstance(server_loss_out, torch.Tensor):
            server_loss_out = server_loss_out.detach().cpu().float().item()
        else:
            server_loss_out = float(server_loss_out)

        for group in self.param_groups:
            group["server_loss_out"] = server_loss_out  # kept for logging
            server_loss_best = group.setdefault("server_loss_best",
                                                server_loss_out)
            trainable = [p for p in group["params"] if p.requires_grad]

            if server_loss_out <= server_loss_best:
                # New best model: snapshot the current parameters.
                for p in trainable:
                    self.state[p]["xbar_best"] = p.clone()
                group["server_loss_best"] = server_loss_out
            else:
                # Worse than the best: roll back to the best snapshot.
                for p in trainable:
                    state = self.state[p]
                    if "xbar_best" not in state:
                        state["xbar_best"] = p.clone()
                    p.data.copy_(state["xbar_best"])

        # Logged once per round using the last parameter group.
        self.logger.log(5,
                        "step=%d, server_loss_out=%f, server_loss_best=%f",
                        self.step_counter, *self._trace_log_items(group))

    def _iter_comm_params(self):
        """Yield the parameters that are communicated (``requires_grad``), in group order."""
        for group in self.param_groups:
            for p in group["params"]:
                if p.requires_grad:
                    yield p

    @staticmethod
    def _wait_p2p(tasks: typing.List[dist.P2POp]) -> None:
        """Launch *tasks* as one batch and wait for all of them.

        ``dist.batch_isend_irecv`` rejects an empty list, so an empty batch
        is skipped.
        """
        if not tasks:
            return
        for work in dist.batch_isend_irecv(tasks):
            work.wait()

    @torch.no_grad()
    def recv_param(self, tag_offset) -> int:
        """Receive every client's parameters and store their mean in-place.

        Args:
            tag_offset: tag of the first parameter message from each client.

        Returns:
            The first tag after the ones used here (input for :meth:`send_param`).
        """
        tasks = []
        received_params = {}  # param -> {client rank: buffer}

        last_tag_offset = tag_offset
        for node_id in self.client_node_ranks:
            tag = tag_offset
            for p in self._iter_comm_params():
                # gloo only supports CPU tensors for P2P.
                buff = torch.zeros_like(
                    p, device="cpu" if self.backend == "gloo" else p.device)
                tasks.append(dist.P2POp(dist.irecv, buff, node_id, tag=tag))
                received_params.setdefault(p, {})[node_id] = buff
                tag += 1
            last_tag_offset = max(tag, last_tag_offset)

        self._wait_p2p(tasks)

        for p, per_node in received_params.items():
            mean = torch.mean(torch.stack(list(per_node.values()), dim=0), dim=0)
            p.data.copy_(mean.to(device=p.device, dtype=p.dtype))
        return last_tag_offset

    @torch.no_grad()
    def send_param(self, tag_offset) -> int:
        """Send the current parameters to every client.

        Args:
            tag_offset: tag of the first parameter message to each client.

        Returns:
            The first tag after the ones used here (``tag_offset`` if nothing
            was sent).
        """
        tasks = []
        last_tag_offset = tag_offset
        for node_id in self.client_node_ranks:
            tag = tag_offset
            for p in self._iter_comm_params():
                tensor = p.cpu() if self.backend == "gloo" else p
                tasks.append(dist.P2POp(dist.isend, tensor, node_id, tag=tag))
                tag += 1
            last_tag_offset = max(tag, last_tag_offset)

        self._wait_p2p(tasks)

        return last_tag_offset


class FedProxClient(torch.optim.Optimizer, AFedProxOptimizer):
    """Client side of FedProx training over ``torch.distributed``.

    Every call to :meth:`step` applies a proximal SGD update to each
    parameter that has a gradient::

        p <- p - lr * (grad + weight_decay * p + mu * (p - xbar))

    Here ``xbar`` is the latest global model received from the server. Before
    the first exchange it is a copy of the parameter taken at its first update.

    On every ``local_step``-th call (the very first call is skipped), unless
    running standalone, the client sends its parameters to
    ``server_node_rank`` and receives the global model. It then stores that
    model as ``xbar`` and copies it into the parameters. All parameters with
    ``requires_grad=True`` are exchanged, in ``param_groups`` order, which
    matches the message layout used by ``FedProxServer``.

    Args:
        params: iterable of parameters or parameter-group dicts.
        rank: rank of this process (stored as ``self.rank``).
        local_step: number of :meth:`step` calls per communication round.
        lr: learning rate.
        mu: proximal-term coefficient.
        weight_decay: L2 penalty coefficient (applied when > 0).
        use_model_marge: stored as ``self.use_model_marge``; not used by
            this class's methods.
        server_node_rank: rank of the server; ``None`` or a negative value
            means standalone (no communication).
        tag_offset: first P2P tag used, so tags can be kept apart from other
            wrappers sharing the same process group.
    """

    def __init__(self, params, rank: int, local_step: int,
                 lr: float = 0.001, mu: float = 0.1,
                 weight_decay: float = 0.0,
                 use_model_marge: bool = False,
                 server_node_rank: typing.Optional[int] = None,
                 tag_offset: int = 0):
        defaults = dict(lr=lr, mu=mu, weight_decay=weight_decay)
        super().__init__(params, defaults)

        self.server_node_rank = server_node_rank
        self.is_standalone = not (isinstance(server_node_rank, int)
                                  and server_node_rank >= 0)

        self.rank = rank
        self.local_step = local_step
        self.step_counter = 0
        # Wrap the counter around a multiple of local_step so it stays bounded
        # without shifting the communication schedule.
        self._max_step = self.local_step * int(1e10)
        self.comm_cnt = 0  # cumulative communication count (not updated in this class)
        # Kept configurable so send/recv tags can be separated from other
        # wrapper classes using the same process group.
        self.tag_offset = tag_offset
        # The backend is only needed for communication, so a standalone client
        # may run without an initialized process group.
        if self.is_standalone and not (dist.is_available()
                                       and dist.is_initialized()):
            self.backend = None
        else:
            self.backend = dist.get_backend()
        self.logger = getLogger(__name__)

        self.use_model_marge = use_model_marge

    @torch.no_grad()
    def step(self, closure: typing.Optional[typing.Callable] = None, **kwargs):
        """Apply one FedProx update; exchange models every ``local_step`` calls.

        Args:
            closure: accepted for API compatibility; not evaluated.

        Returns:
            Always ``None``.
        """
        loss = None

        for group in self.param_groups:
            lr, mu = group["lr"], group["mu"]
            weight_decay = group["weight_decay"]
            for p in group["params"]:
                if not p.requires_grad or p.grad is None:
                    continue

                state = self.state[p]
                if "xbar" not in state:
                    state["xbar"] = p.clone()

                grad = p.grad
                if weight_decay > 0.0:
                    # Out-of-place so p.grad itself is left untouched
                    # (in-place would compound under gradient accumulation).
                    grad = grad.add(p, alpha=weight_decay)
                grad = grad + mu * (p - state["xbar"])
                p.add_(grad, alpha=-lr)

        if (not self.is_standalone and self.step_counter > 0
                and self.step_counter % self.local_step == 0):
            self._exchange()

        self.step_counter = (self.step_counter + 1) % self._max_step
        return loss

    @torch.no_grad()
    def _exchange(self) -> None:
        """Send local params, receive the global model, and adopt it as ``xbar``."""
        tag_offset, send_tasks = self.send_param(self.tag_offset)
        tag_offset, recv_tasks, recv_params = self.recv_param(tag_offset)

        tasks = send_tasks + recv_tasks
        # dist.batch_isend_irecv rejects an empty op list.
        if tasks:
            for work in dist.batch_isend_irecv(tasks):
                work.wait()

        for p, buf in recv_params.items():
            xbar = buf.to(device=p.device, dtype=p.dtype)
            self.state[p]["xbar"] = xbar
            p.data.copy_(xbar)

    @torch.no_grad()
    def send_param(self, tag_offset: int) -> typing.Tuple[int, typing.List[dist.P2POp]]:
        """Build (but do not launch) isend ops for every trainable parameter.

        Returns:
            ``(next_tag, ops)``, where ``next_tag`` is the first unused tag.
        """
        tasks = []
        tag = tag_offset
        for group in self.param_groups:
            for p in group["params"]:
                # Same selection as the server (requires_grad only) so tags
                # line up even if some parameter received no gradient.
                if not p.requires_grad:
                    continue
                tensor = p.cpu() if self.backend == "gloo" else p
                tasks.append(dist.P2POp(dist.isend, tensor,
                                        self.server_node_rank, tag=tag))
                tag += 1
        return tag, tasks

    @torch.no_grad()
    def recv_param(self, tag_offset: int) -> typing.Tuple[int, typing.List[dist.P2POp], typing.Dict[torch.Tensor, torch.Tensor]]:
        """Build (but do not launch) irecv ops for every trainable parameter.

        Returns:
            ``(next_tag, ops, buffers)``, where ``buffers`` maps each parameter
            to the tensor its global value will be received into.
        """
        tasks = []
        received_params = {}

        tag = tag_offset
        for group in self.param_groups:
            for p in group["params"]:
                # Same selection as send_param / the server.
                if not p.requires_grad:
                    continue
                # gloo only supports CPU tensors for P2P.
                buf = torch.zeros_like(
                    p, device="cpu" if self.backend == "gloo" else p.device)
                tasks.append(dist.P2POp(dist.irecv, buf,
                                        self.server_node_rank, tag=tag))
                received_params[p] = buf
                tag += 1
        return tag, tasks, received_params
