# -*- coding: utf-8 -*-
import typing

import numpy as np
import torch

from .fedprox_dol_early import FedProxDoLEarlyClient, FedProxDoLEarlyServer
from .utils import get_group_value

# Public API of this module: only the concrete client/server optimizers are
# exported; the accessor mixin defined below is not.
__all__ = ["FedProxDoLClient", "FedProxDoLServer", ]


class AFedProxDoLOptimizer:
    """Mixin exposing selected param-group entries as read-only properties.

    Every property delegates to ``get_group_value(self.param_groups, key)``,
    so the host class must provide a ``param_groups`` attribute.
    """

    @property
    def server_loss_best(self):
        """Value of ``"server_loss_best"`` resolved from ``param_groups``."""
        groups = self.param_groups
        return get_group_value(groups, "server_loss_best")

    @property
    def server_loss_out(self):
        """Value of ``"server_loss_out"`` resolved from ``param_groups``."""
        groups = self.param_groups
        return get_group_value(groups, "server_loss_out")

    @property
    def w(self):
        """Value of ``"w"`` resolved from ``param_groups``."""
        groups = self.param_groups
        return get_group_value(groups, "w")

    @property
    def y(self):
        """Value of ``"y"`` resolved from ``param_groups``."""
        groups = self.param_groups
        return get_group_value(groups, "y")


class FedProxDoLClient(FedProxDoLEarlyClient, AFedProxDoLOptimizer):
    """Client-side FedProxDoL optimizer.

    Adds nothing of its own: it is exactly the "Early" client
    (``FedProxDoLEarlyClient``) combined with the ``AFedProxDoLOptimizer``
    accessors.
    """


class FedProxDoLServer(FedProxDoLEarlyServer, AFedProxDoLOptimizer):
    """Server-side FedProxDoL optimizer (full version).

    On top of ``FedProxDoLEarlyServer`` this class keeps two per-parameter
    snapshots in ``self.state[p]``:

    * ``"xbar_out"``  -- a weighted running average of the global model,
    * ``"xbar_best"`` -- the averaged model with the lowest server loss so far,

    and the per-group bookkeeping values ``"mu0"``, ``"w1"``, ``"w2"``,
    ``"server_loss_out"`` and ``"server_loss_best"``.
    """

    # Param-group keys written to the trace log, in output order.  Both the
    # logged values and the log format are derived from this single tuple so
    # they cannot drift apart.
    _TRACE_LOG_KEYS = (
        "mu", "eta", "r", "rdist",
        "vu", "vu_eta",
        "grad_sq_norm", "grad_norm", "param_norm",
        "model_sq_norm", "normalized_model_sq_norm",
        "delta", "client_loss", "server_loss",
        "server_loss_out", "server_loss_best",
        "w1", "w2",
    )
    # "step=%d, mu=%f, ..., w2=%f, warmup=%s"
    _TRACE_LOG_FORMAT = ", ".join(
        ("step=%d",)
        + tuple(key + "=%f" for key in _TRACE_LOG_KEYS)
        + ("warmup=%s",))

    def _trace_log_items(self, group):
        """Return the values of ``_TRACE_LOG_KEYS`` in *group* as a tuple.

        Keys that are absent from the group are reported as ``np.nan``.
        """
        return tuple(group.get(key, np.nan) for key in self._TRACE_LOG_KEYS)

    def _log_trace(self, group):
        """Emit one trace line (log level 5) describing *group*."""
        self.logger.log(5, self._TRACE_LOG_FORMAT,
                        self.step_counter, *self._trace_log_items(group),
                        self.is_warmup)

    @staticmethod
    def _loss_to_float(loss):
        """Convert the closure's return value to a plain Python float."""
        if torch.is_tensor(loss):
            return loss.detach().cpu().float().item()
        return float(loss)

    @torch.no_grad()
    def step(self, closure: typing.Optional[typing.Callable] = None, **kwargs):
        """Record the initial ``mu`` as ``mu0`` per group, then run the parent step.

        Args:
            closure: Optional callable forwarded unchanged to the parent.
            **kwargs: Forwarded unchanged to the parent ``step``.

        Returns:
            Whatever the parent class' ``step`` returns.
        """
        for group in self.param_groups:
            if "mu0" not in group:
                group["mu0"] = group["mu"]
        return super().step(closure=closure, **kwargs)

    def _merge_into_xbar_out(self, group):
        """Fold the current parameters of *group* into the ``xbar_out`` average.

        The new contribution gets weight ``w1 = min(mu / mu0, 1) * r`` (just
        ``r`` when ``mu0`` is zero) and the previous average keeps the
        accumulated weight ``w2``.  Afterwards each trainable parameter holds
        the new average, ``mu0`` is set to the current ``mu`` and ``w2`` is
        increased by ``w1``.
        """
        r, mu = group["r"], group["mu"]
        mu0 = group.get("mu0", mu)

        w1 = (min(mu / mu0, 1) * r) if mu0 != 0.0 else r
        w2 = group.get("w2", 0.0)
        total_weight = w2 + w1

        for p in group["params"]:
            if not p.requires_grad:
                continue
            state = self.state[p]
            # With zero total weight the average is undefined (0/0 would fill
            # the parameter with NaN), so keep the current parameter instead.
            if "xbar_out" in state and total_weight != 0:
                previous = state["xbar_out"]
                # From here on p holds xbar_out rather than xbar.
                p.data.copy_((w2 * previous + w1 * p.data) / total_weight)
            state["xbar_out"] = p.clone().detach()

        group["mu0"] = mu
        group["w1"], group["w2"] = w1, total_weight

    @torch.no_grad()
    def update_global_model(self,
                            closure: typing.Optional[typing.Callable] = None,
                            **kwargs):
        """Update ``xbar_out`` / ``xbar_best`` and the model sent to clients.

        First every group's parameters are replaced by the weighted running
        average ``xbar_out``.  If *closure* is given it is then evaluated
        (with gradients enabled) on that averaged model to obtain the server
        loss.  Outside warmup, a loss no worse than the best so far (or
        ``self.without_model_merge`` being true) stores the current
        parameters as ``xbar_best``; otherwise the parameters are reset to
        ``xbar_best``.

        Args:
            closure: Optional callable returning the server loss (a tensor or
                a number) for the current parameters.
            **kwargs: Accepted for interface compatibility; unused here.
        """
        for group in self.param_groups:
            self._merge_into_xbar_out(group)

        if closure is None:
            return

        # Re-evaluate the server loss on xbar_out (the current parameters).
        with torch.enable_grad():
            raw_loss = closure()

        # Converted lazily and only once: converting inside the loop and
        # rebinding the same name would call tensor methods on a float from
        # the second param group onwards.
        loss_value = None
        for group in self.param_groups:
            if self.is_warmup:
                # During warmup nothing is tracked or updated; only log.
                self._log_trace(group)
                continue

            if loss_value is None:
                loss_value = self._loss_to_float(raw_loss)

            group["server_loss_out"] = loss_value  # kept for reporting
            best = group.setdefault("server_loss_best", loss_value)

            # A NaN "best" can never satisfy `<=`, so treat it as "no best
            # yet"; otherwise a NaN first loss would block every later update
            # and the restore branch would find no xbar_best snapshot.
            if (loss_value <= best or np.isnan(best)
                    or self.without_model_merge):
                # New best model: snapshot it as xbar_best.
                for p in group["params"]:
                    if p.requires_grad:
                        self.state[p]["xbar_best"] = p.clone()
                group["server_loss_best"] = loss_value
            else:
                # Worse than the best: reset p to xbar_best so the model
                # sent to the clients stays the same as last time.
                for p in group["params"]:
                    if p.requires_grad:
                        p.data.copy_(self.state[p]["xbar_best"])

            self._log_trace(group)
