# -*- coding: utf-8 -*-
import typing
from logging import getLogger

import torch
import torch.distributed as dist

from .utils import calc_grad_norm, calc_param_norm, calc_state_norm, get_group_value


class AScaffoldOptimizer:
    """Read-only accessors for norm statistics kept in ``param_groups``.

    The class expects the concrete optimizer to provide ``self.param_groups``;
    every property forwards that attribute and a fixed key to
    ``get_group_value`` and returns whatever that helper returns.
    """

    @property
    def param_norm(self) -> float:
        """Value looked up under the ``"param_norm"`` group key."""
        key = "param_norm"
        return get_group_value(self.param_groups, key)

    @property
    def grad_norm(self) -> float:
        """Value looked up under the ``"grad_norm"`` group key."""
        key = "grad_norm"
        return get_group_value(self.param_groups, key)

    @property
    def delta_x_norm(self) -> float:
        """Value looked up under the ``"delta_x_norm"`` group key."""
        key = "delta_x_norm"
        return get_group_value(self.param_groups, key)

    @property
    def local_c_norm(self) -> float:
        """Value looked up under the ``"local_c_norm"`` group key."""
        key = "local_c_norm"
        return get_group_value(self.param_groups, key)

    @property
    def global_c_norm(self) -> float:
        """Value looked up under the ``"global_c_norm"`` group key."""
        key = "global_c_norm"
        return get_group_value(self.param_groups, key)


class ScaffoldServer(torch.optim.Optimizer, AScaffoldOptimizer):
    """Server half of a SCAFFOLD-style client/server optimizer pair.

    The server does not use gradients.  On every ``local_step``-th call to
    :meth:`step` (starting with the very first call) it

    1. receives, from each rank in ``client_node_ranks``, one parameter
       delta and one control delta per trainable parameter,
    2. stores the client-average of the parameter deltas in
       ``state["delta_x"]`` and adds the client-average of the control
       deltas to ``state["global_control"]``,
    3. moves every trainable parameter by ``lr * delta_x``, and
    4. sends the parameters and ``global_control`` back to every client.

    Message tags start at ``tag_offset`` and grow by one per tensor; the
    sequence restarts for every client rank.

    Args:
        params: Parameters or parameter groups, forwarded to
            ``torch.optim.Optimizer``.
        rank: Stored as ``self.rank``; not otherwise read by this class.
        local_step: Number of :meth:`step` calls between two communication
            rounds.
        lr: Learning rate used when ``server_lr`` is not a real number.
        server_lr: Overrides ``lr`` when it is an ``int`` or ``float``
            (``bool`` is ignored).
        client_lr: Unused by the server.
        weight_decay: Stored in the group defaults; not read by the server.
        client_node_ranks: Ranks of the clients to exchange tensors with.
        tag_offset: First tag used for the point-to-point messages.

    Raises:
        AssertionError: If ``client_node_ranks`` is ``None`` or empty.
    """

    def __init__(self, params, rank: int, local_step: int, lr: float = 0.001,
                 server_lr: typing.Optional[float] = None,
                 client_lr: typing.Optional[float] = None,
                 weight_decay: float = 0.0,
                 client_node_ranks: typing.Optional[typing.List[int]] = None,
                 tag_offset: int = 0):
        # Any real number overrides ``lr`` so that e.g. ``server_lr=1`` is not
        # silently dropped; ``bool`` is excluded because it subclasses ``int``.
        if isinstance(server_lr, (int, float)) \
                and not isinstance(server_lr, bool):
            lr = server_lr
        defaults = dict(lr=lr, weight_decay=weight_decay)
        super().__init__(params, defaults)
        self.logger = getLogger(__name__)

        self.client_node_ranks = client_node_ranks
        self.is_standalone = self.client_node_ranks is None \
            or len(self.client_node_ranks) == 0

        self.rank = rank
        self.local_step = local_step
        self.step_counter = 0
        # The counter wraps at a multiple of ``local_step`` so that wrapping
        # never shifts the communication schedule.
        self._max_step = self.local_step * int(1e10)
        self.comm_cnt = 0  # cumulative number of communication rounds
        # Kept so that send/recv tags can be separated from those used by
        # other wrapper classes.
        self.tag_offset = tag_offset
        self.backend = dist.get_backend()

        # Raised explicitly (instead of ``assert``) so the check survives
        # ``python -O``; the exception type and message are unchanged.
        if self.is_standalone:
            raise AssertionError(
                "Scaffold Server Optimizer can't work on stand alone.")

    def _comm_params(self):
        """Yield, in group order, every parameter that is exchanged."""
        for group in self.param_groups:
            for p in group["params"]:
                if p.requires_grad:
                    yield p

    def _make_recv_buffer(self, p):
        """Return a zero tensor shaped like *p* to receive into.

        The buffer lives on the CPU when the backend is ``"gloo"`` and on
        ``p.device`` otherwise.
        """
        device = "cpu" if self.backend == "gloo" else p.device
        return torch.zeros_like(p, dtype=p.dtype, device=device)

    def _wire_tensor(self, tensor):
        """Return *tensor* as it is handed to ``isend`` (CPU copy for gloo)."""
        return tensor.cpu() if self.backend == "gloo" else tensor

    @torch.no_grad()
    def step(self, closure: typing.Optional[typing.Callable] = None, **kwargs):
        """Run one server step; communicate on every ``local_step``-th call.

        ``closure`` is not evaluated and extra keyword arguments are ignored.

        Returns:
            Always ``None``.
        """
        loss = None

        # Outside a communication round the server leaves parameters as is.
        if self.step_counter % self.local_step == 0:
            self.recv_params()
            for group in self.param_groups:
                lr = group["lr"]

                # Recorded before the parameters are moved below.
                group["param_norm"] = calc_param_norm(group)
                for p in group["params"]:
                    if not p.requires_grad:
                        continue
                    state = self.state[p]
                    if len(state) == 0:
                        # difference between the global and the local model
                        state["delta_x"] = torch.zeros_like(p)
                        state["global_control"] = torch.zeros_like(p)

                    p.data.add_(state["delta_x"], alpha=lr)

                group["delta_x_norm"] = calc_state_norm(group, self.state,
                                                        "delta_x")
                group["global_c_norm"] = calc_state_norm(group, self.state,
                                                         "global_control")

            self.send_params()

        self.step_counter = (self.step_counter + 1) % self._max_step

        return loss

    @torch.no_grad()
    def recv_params(self):
        """Receive and average the clients' parameter and control deltas.

        For every trainable parameter two tensors are received from each
        client: first the parameter delta, then the control delta.  After all
        transfers completed, ``state["delta_x"]`` is replaced by the mean
        parameter delta and the mean control delta is added in place to
        ``state["global_control"]`` (created as zeros when missing).
        """
        tasks, delta_x, client_c = [], {}, {}

        for node_id in self.client_node_ranks:
            tag = self.tag_offset
            for p in self._comm_params():
                # Order matters: the parameter delta takes the first tag,
                # the control delta the following one.
                for received in (delta_x, client_c):
                    buff = self._make_recv_buffer(p)
                    tasks.append(
                        dist.P2POp(dist.irecv, buff, node_id, tag=tag))
                    received.setdefault(p, []).append(buff)
                    tag += 1

        for task in dist.batch_isend_irecv(tasks):
            task.wait()

        # Only parameters that actually received data get a state entry.
        for p in self._comm_params():
            state = self.state[p]
            mean_dx = torch.stack(delta_x[p], dim=0).mean(dim=0)
            state["delta_x"] = mean_dx.to(dtype=p.dtype, device=p.device)

            if "global_control" not in state:
                state["global_control"] = torch.zeros_like(p)
            global_c = state["global_control"]
            mean_dc = torch.stack(client_c[p], dim=0).mean(dim=0)
            global_c.add_(mean_dc.to(dtype=global_c.dtype,
                                     device=global_c.device))

    @torch.no_grad()
    def send_params(self):
        """Send parameters and ``global_control`` to every client.

        For each trainable parameter the parameter itself is sent first and
        its ``global_control`` second, using consecutive tags.  Increments
        ``comm_cnt`` once all transfers completed.
        """
        tasks = []
        for node_id in self.client_node_ranks:
            tag = self.tag_offset
            for p in self._comm_params():
                control = self.state[p]["global_control"]
                for tensor in (p, control):
                    tasks.append(dist.P2POp(dist.isend,
                                            self._wire_tensor(tensor),
                                            node_id, tag=tag))
                    tag += 1

        for task in dist.batch_isend_irecv(tasks):
            task.wait()

        self.comm_cnt += 1


class ScaffoldClient(torch.optim.Optimizer, AScaffoldOptimizer):
    """Client half of a SCAFFOLD-style client/server optimizer pair.

    Every call to :meth:`step` applies
    ``p -= lr * (grad - local_control + global_control)`` to each parameter
    that requires grad and has a gradient.  On every ``local_step``-th call
    (starting with the very first call) the client additionally updates
    ``local_control``, sends its deltas to ``server_node_rank`` and replaces
    its parameters and ``global_control`` with the tensors received back.

    Message tags start at ``tag_offset`` and grow by one per tensor.

    NOTE(review): ``is_standalone`` is computed but never consulted;
    :meth:`step` attempts the exchange even when ``server_node_rank`` is
    ``None`` or negative -- confirm whether standalone use is meant to work.

    NOTE(review): the client skips parameters whose ``grad`` is ``None``,
    whereas ``ScaffoldServer`` enumerates tensors by ``requires_grad`` only;
    if those sets differ the tag sequences no longer line up -- confirm that
    every trainable parameter always receives a gradient.

    Args:
        params: Parameters or parameter groups, forwarded to
            ``torch.optim.Optimizer``.
        rank: Stored as ``self.rank``; not otherwise read by this class.
        local_step: Number of :meth:`step` calls between two communication
            rounds; also the ``K`` used in the control update.
        lr: Learning rate used when ``client_lr`` is not a real number.
        server_lr: Unused by the client.
        client_lr: Overrides ``lr`` when it is an ``int`` or ``float``
            (``bool`` is ignored).
        weight_decay: When positive, ``weight_decay * p`` is added in place
            to ``p.grad`` before the update.
        server_node_rank: Rank of the server to exchange tensors with.
        tag_offset: First tag used for the point-to-point messages.
    """

    def __init__(self, params, rank: int, local_step: int, lr: float = 0.001,
                 server_lr: typing.Optional[float] = None,
                 client_lr: typing.Optional[float] = None,
                 weight_decay: float = 0.0,
                 server_node_rank: typing.Optional[int] = None,
                 tag_offset: int = 0):
        # Any real number overrides ``lr`` so that e.g. ``client_lr=1`` is not
        # silently dropped; ``bool`` is excluded because it subclasses ``int``.
        if isinstance(client_lr, (int, float)) \
                and not isinstance(client_lr, bool):
            lr = client_lr
        defaults = dict(lr=lr, weight_decay=weight_decay)
        super().__init__(params, defaults)
        self.logger = getLogger(__name__)

        self.server_node_rank = server_node_rank
        self.is_standalone = not (isinstance(server_node_rank, int)
                                  and server_node_rank >= 0)

        self.rank = rank
        self.local_step = local_step
        self.step_counter = 0
        # The counter wraps at a multiple of ``local_step`` so that wrapping
        # never shifts the communication schedule.
        self._max_step = self.local_step * int(1e10)
        self.comm_cnt = 0  # cumulative number of communication rounds
        # Kept so that send/recv tags can be separated from those used by
        # other wrapper classes.
        self.tag_offset = tag_offset
        self.backend = dist.get_backend()

    def _comm_params(self):
        """Yield, in group order, every parameter that is exchanged."""
        for group in self.param_groups:
            for p in group["params"]:
                if p.requires_grad and p.grad is not None:
                    yield p

    def _make_recv_buffer(self, p):
        """Return a zero tensor shaped like *p* to receive into.

        The buffer lives on the CPU when the backend is ``"gloo"`` and on
        ``p.device`` otherwise.
        """
        device = "cpu" if self.backend == "gloo" else p.device
        return torch.zeros_like(p, dtype=p.dtype, device=device)

    def _wire_tensor(self, tensor):
        """Return *tensor* as it is handed to ``isend`` (CPU copy for gloo)."""
        return tensor.cpu() if self.backend == "gloo" else tensor

    @torch.no_grad()
    def step(self, closure: typing.Optional[typing.Callable] = None, **kwargs):
        """Apply one local update and, when due, synchronise with the server.

        ``closure`` is not evaluated and extra keyword arguments are ignored.

        Returns:
            Always ``None``.
        """
        loss = None

        for group in self.param_groups:
            eta = group["lr"]
            weight_decay = group["weight_decay"]

            # Norms are recorded before this step's update is applied.
            group["param_norm"] = calc_param_norm(group)
            group["grad_norm"] = calc_grad_norm(group).to(dtype=torch.float32,
                                                          device="cpu")
            for p in group["params"]:
                if not p.requires_grad or p.grad is None:
                    continue
                if weight_decay > 0.0:
                    p.grad.add_(p.data, alpha=weight_decay)

                state = self.state[p]
                if len(state) == 0:
                    # Snapshot of the parameter at first use; replaced by the
                    # server's parameters in recv_params().
                    state["global"] = p.clone()
                    state["global_control"] = torch.zeros_like(p)
                    state["local_control"] = torch.zeros_like(p)

                corrected = p.grad - state["local_control"] \
                    + state["global_control"]
                p.add_(corrected, alpha=-eta)

            group["local_c_norm"] = calc_state_norm(group, self.state,
                                                    "local_control")
            group["global_c_norm"] = calc_state_norm(group, self.state,
                                                     "global_control")

        if self.step_counter % self.local_step == 0:
            # local_control += -global_control + (global - p) / (K * lr)
            for group in self.param_groups:
                K, eta = self.local_step, group["lr"]
                for p in group["params"]:
                    if not p.requires_grad or p.grad is None:
                        continue
                    state = self.state[p]
                    delta_x = (state["global"] - p.data).detach()
                    delta_c = - state["global_control"] + delta_x / (K * eta)
                    state["local_control"].add_(delta_c)

            self.send_params()  # sends: delta_x, control delta
            self.recv_params()  # receives: parameters, global_control
            self.comm_cnt += 1

        self.step_counter = (self.step_counter + 1) % self._max_step
        return loss

    @torch.no_grad()
    def recv_params(self):
        """Receive parameters and ``global_control`` from the server.

        For every exchanged parameter two tensors are received: first the
        parameter values, then the control values.  After all transfers
        completed the parameter is overwritten in place, ``state["global"]``
        is set to a clone of it and ``state["global_control"]`` is
        overwritten in place.
        """
        tasks, params, global_c = [], {}, {}
        node_id = self.server_node_rank

        tag = self.tag_offset
        for p in self._comm_params():
            # Order matters: parameter values take the first tag, the
            # control values the following one.
            for received in (params, global_c):
                buff = self._make_recv_buffer(p)
                tasks.append(dist.P2POp(dist.irecv, buff, node_id, tag=tag))
                received[p] = buff
                tag += 1

        for task in dist.batch_isend_irecv(tasks):
            task.wait()

        # Only parameters that actually received data get touched.
        for p in self._comm_params():
            state = self.state[p]
            p.data.copy_(params[p].to(dtype=p.dtype, device=p.device))
            state["global"] = p.clone()  # new reference point for delta_x
            state["global_control"].copy_(global_c[p].to(dtype=p.dtype,
                                                         device=p.device))

    @torch.no_grad()
    def send_params(self):
        """Send the parameter delta and the control delta to the server.

        For every exchanged parameter ``p - state["global"]`` is sent first
        and ``local_control - global_control`` second, using consecutive
        tags.

        NOTE(review): the control tensor is ``local_control - global_control``
        evaluated after step() updated ``local_control``; the SCAFFOLD paper
        transmits the change of the local control instead -- confirm that
        this variant is intended.
        """
        node_id = self.server_node_rank
        tasks = []
        tag = self.tag_offset
        for p in self._comm_params():
            state = self.state[p]
            delta_x = (p.data - state["global"]).detach()
            delta_c = state["local_control"] - state["global_control"]
            for tensor in (delta_x, delta_c):
                tasks.append(dist.P2POp(dist.isend, self._wire_tensor(tensor),
                                        node_id, tag=tag))
                tag += 1

        for task in dist.batch_isend_irecv(tasks):
            task.wait()
