#! -*- coding: utf-8
import typing
from datetime import datetime
from logging import getLogger

import cvxpy as cp
import numpy as np
import torch
import torch.distributed as dist

# Public API of this module: the server- and client-side optimizers defined below.
__all__ = ["FedAvgMVDRServer", "FedAvgMVDRClient"]


class FedAvgMVDRServer(torch.optim.Optimizer):
    """Server-side "optimizer" that aggregates client models with MVDR weights.

    The server never applies gradients itself. Every ``local_step`` calls to
    :meth:`step` it receives the trainable parameters of every client in
    ``client_node_ranks``, mixes them with minimum-variance weights (see
    :meth:`exchange`) and sends the mixed global model back.

    Args:
        params: parameters of the global model (same layout as the clients').
        rank: rank of this process (stored only).
        local_step: number of :meth:`step` calls between two exchanges (>= 1).
        lr: stored in the param-group defaults; the server does not use it.
        client_node_ranks: ranks of the client processes; must be non-empty.
        seed: seed of ``self.rs`` (a ``numpy.random.RandomState``).
        tag_offset: first point-to-point tag, so that several wrappers can
            share one process group without tag collisions.
        beta: exponential smoothing factor of the client covariance across
            exchanges, in ``[0, 1)``.
        eps: diagonal loading added to the covariance so the quadratic
            program stays strictly convex.

    Raises:
        ValueError: if ``client_node_ranks`` is empty or None, if
            ``local_step < 1``, or if ``beta`` is outside ``[0, 1)``.
    """

    def __init__(self, params, rank: int, local_step: int, lr: float = 0.001,
                 client_node_ranks: typing.Optional[typing.List[int]] = None,
                 seed: int = 0, tag_offset: int = 0,
                 beta: float = 0.75, eps: float = 1e-4):
        defaults = dict(lr=lr)
        super().__init__(params, defaults)
        self.logger = getLogger(__name__)

        self.rs = np.random.RandomState(seed)

        # Copy, so neither the caller's list nor a shared default is aliased.
        self.client_node_ranks = list(client_node_ranks or [])
        self.is_standalone = len(self.client_node_ranks) == 0
        # Validate before touching torch.distributed so a misconfiguration is
        # reported as such (and is not stripped like an `assert` under -O).
        if self.is_standalone:
            raise ValueError("FedAvg Server Optimizer can't work on stand alone.")
        if local_step < 1:
            raise ValueError(f"local_step must be >= 1, got {local_step}")
        if not 0.0 <= beta < 1.0:
            raise ValueError(f"beta must be in [0, 1), got {beta}")

        self.rank = rank
        self.local_step = local_step
        self.step_counter = 0
        self._max_step = self.local_step * int(1e10)
        self.comm_cnt = 0  # cumulative number of completed exchanges
        # Kept so send/recv tags can be separated from other wrapper classes.
        self.tag_offset = tag_offset
        self.beta = beta
        self.eps = eps
        # Exponentially smoothed client covariance; None until the first exchange.
        self.R: typing.Optional[np.ndarray] = None
        self.backend = dist.get_backend()

    @torch.no_grad()
    def step(self, closure: typing.Callable = None, **kwargs):
        """Count one step and run :meth:`exchange` every ``local_step`` steps.

        The server applies no parameter update of its own; ``closure`` and
        ``kwargs`` are accepted for API compatibility and ignored.

        Returns:
            None
        """
        self.step_counter = (self.step_counter + 1) % self._max_step
        if self.step_counter % self.local_step == 0:
            self.exchange()
        return None

    def _exchange_params(self) -> typing.List[torch.Tensor]:
        """Return the trainable parameters in the fixed order that defines tags."""
        return [p for group in self.param_groups
                for p in group["params"] if p.requires_grad]

    def _comm_device(self, p: torch.Tensor):
        """Device for communication buffers: CPU for gloo, otherwise ``p.device``."""
        return "cpu" if self.backend == "gloo" else p.device

    def _recv_client_params(self, params: typing.List[torch.Tensor]) -> dict:
        """Receive every client's copy of ``params``.

        Returns:
            dict mapping each parameter to a list of received buffers, one per
            client, in ``client_node_ranks`` order.
        """
        recved = {p: [] for p in params}
        tasks = []
        for node_id in self.client_node_ranks:
            for tag, p in enumerate(params, start=self.tag_offset):
                buff = torch.zeros_like(p, dtype=p.dtype, device=self._comm_device(p))
                tasks.append(dist.P2POp(dist.irecv, buff, node_id, tag=tag))
                recved[p].append(buff)
        for task in dist.batch_isend_irecv(tasks):
            task.wait()
        return recved

    @torch.no_grad()
    def _client_covariance(self, recved: dict) -> np.ndarray:
        """Return ``eps * I + H @ H.T / d`` for the N x d matrix ``H`` of client models.

        Row n of ``H`` is client n's trainable parameters, flattened and
        concatenated, and left uncentered (as in ver 0.3). The Gram matrix is
        accumulated parameter by parameter, so ``H`` is never materialised and
        parameters may live on different devices. It is computed in at least
        float32 to limit overflow and precision loss for half-precision models.
        The result is explicitly symmetrised for the quadratic program.

        Raises:
            ValueError: if there is nothing to build the covariance from.
        """
        gram, dim = None, 0
        for p, bufs in recved.items():
            acc_dtype = torch.promote_types(p.dtype, torch.float32)
            rows = torch.stack([b.to(device=p.device, dtype=acc_dtype).flatten()
                                for b in bufs], dim=0)
            block = (rows @ rows.T).cpu().numpy().astype(np.float64)
            gram = block if gram is None else gram + block
            dim += rows.shape[1]
        if gram is None or dim == 0:
            raise ValueError("no client parameters to build the covariance from")
        cov = self.eps * np.eye(gram.shape[0]) + gram / dim
        return 0.5 * (cov + cov.T)

    def _solve_mixing_weights(self, cov: np.ndarray) -> np.ndarray:
        """Solve ``min_w w^T cov w`` subject to ``sum(w) == 1`` and ``w >= 0``.

        Falls back to uniform weights (plain averaging), with a warning, when
        the solver raises or returns no usable solution.
        """
        n = cov.shape[0]
        uniform = np.full(n, 1.0 / n)
        w = cp.Variable(n)
        problem = cp.Problem(cp.Minimize(cp.quad_form(w, cov)),
                             [cp.sum(w) == 1, w >= 0])
        try:
            problem.solve()
        except cp.SolverError as err:
            self.logger.warning("MVDR solver failed (%s); using uniform weights", err)
            return uniform
        if w.value is None:
            self.logger.warning("MVDR solver returned no solution (status=%s); "
                                "using uniform weights", problem.status)
            return uniform
        # Remove tiny negative values from solver tolerance, then renormalise.
        weights = np.clip(w.value, 0.0, None)
        total = weights.sum()
        if not np.isfinite(total) or total <= 0.0:
            self.logger.warning("MVDR weights degenerate (sum=%s); using uniform weights",
                                total)
            return uniform
        return weights / total

    def _send_global_params(self, params: typing.List[torch.Tensor]):
        """Send the (already mixed) global ``params`` to every client."""
        # Build each payload once instead of once per client.
        payloads = [p.cpu() if self.backend == "gloo" else p for p in params]
        tasks = [dist.P2POp(dist.isend, payload, node_id, tag=tag)
                 for node_id in self.client_node_ranks
                 for tag, payload in enumerate(payloads, start=self.tag_offset)]
        for task in dist.batch_isend_irecv(tasks):
            task.wait()

    # ver 0.3
    @torch.no_grad()
    def exchange(self):
        """Run one aggregation round.

        1. Receive every client's trainable parameters.
        2. Build the regularised (uncentered) second-moment matrix of the
           client models. Smooth it exponentially with ``beta`` across
           rounds and store it in ``self.R``.
        3. Solve for non-negative mixing weights summing to one that minimise
           ``w^T R w``, and write the weighted sum into the global model.
        4. Send the global model back to every client.
        """
        params = self._exchange_params()
        recved = self._recv_client_params(params)

        started = datetime.now()
        cov = self._client_covariance(recved)
        prev = getattr(self, "R", None)
        # First round (or a changed number of clients): start from zeros.
        if prev is None or np.shape(prev) != cov.shape:
            prev = np.zeros_like(cov)
        self.R = self.beta * prev + (1.0 - self.beta) * cov
        weights = self._solve_mixing_weights(self.R)

        self.logger.debug("update global model: communication=%d, %d params",
                          self.comm_cnt, len(recved))
        self.logger.debug("mixing weights: [%s]", ", ".join(map(str, weights)))
        # Weighted sum of client models -> global model.
        for p in params:
            mixed = torch.zeros_like(p)
            for weight, buf in zip(weights, recved[p]):
                mixed.add_(buf.to(device=p.device, dtype=p.dtype), alpha=float(weight))
            p.copy_(mixed)
        self.logger.debug("MVDRmixing process %f sec",
                          (datetime.now() - started).total_seconds())

        self._send_global_params(params)
        self.comm_cnt += 1

    # # var 0.2
    # @torch.no_grad()
    # def exchange(self):
    #     # recved client model parameters.
    #     tasks, recved = [], {}
    #     for node_id in self.client_node_ranks:
    #         i = self.tag_offset
    #         for group in self.param_groups:
    #             for p in group["params"]:
    #                 if not p.requires_grad:
    #                     continue
    #                 if not p in recved:
    #                     recved[p] = []
    #                 buff = torch.zeros_like(p, dtype=p.dtype,
    #                                         device="cpu" if self.backend == "gloo" else p.device)
    #                 tasks.append(dist.P2POp(dist.irecv, buff, node_id, tag=i))
    #                 recved[p].append(buff)
    #                 i += 1

    #     for task in dist.batch_isend_irecv(tasks):
    #         task.wait()

    #     p0 = datetime.now()
    #     R = []
    #     eps = 1e-4
    #     for group in self.param_groups:
    #         for p in group["params"]:
    #             if not p.requires_grad:
    #                 continue
    #             # N, flatten param.
    #             params = torch.stack([d.to(device=p.device, dtype=p.dtype).flatten()
    #                                   for d in recved[p]], dim=0)
    #             params = params - params.mean(dim=0, keepdim=True)
    #             R.append(params)
    #     R = torch.cat(R, dim=-1)  # N, full-flatten param.
    #     N, d = R.shape
    #     R = eps * np.eye(N) \
    #         + torch.matmul(R, R.transpose(0, 1)).detach().cpu().numpy()/d

    #     a = np.ones(N)
    #     w = cp.Variable(N)
    #     objective = cp.Minimize(cp.quad_form(w, R))
    #     constraints = [w @ a == 1, w >= 0]
    #     prob = cp.Problem(objective, constraints)
    #     prob.solve()
    #     w = np.clip(w.value, 0.0, None)
    #     w /= w.sum()

    #     self.logger.debug(
    #         f"update global model: communication={self.comm_cnt}, {len(recved)} params")
    #     self.logger.debug("mixing weights: [%s]", ", ".join(list(map(str, w))))
    #     # mean client model parameters and set to global model.
    #     for group in self.param_groups:
    #         for p in group["params"]:
    #             if not p.requires_grad:
    #                 continue
    #             p.data.copy_(torch.stack([(float(weight) * d).to(device=p.device, dtype=p.dtype) for weight, d in zip(w, recved[p])],
    #                                      dim=0).sum(dim=0))
    #     self.logger.debug("MVDRmixing process %f sec",
    #                       (datetime.now()-p0).total_seconds())

    #     # send global model parameters.
    #     tasks = []
    #     for node_id in self.client_node_ranks:
    #         i = self.tag_offset
    #         for group in self.param_groups:
    #             for p in group["params"]:
    #                 if not p.requires_grad:
    #                     continue
    #                 tasks.append(dist.P2POp(dist.isend, p.cpu() if self.backend == "gloo" else p,
    #                                         node_id, tag=i))
    #                 i += 1

    #     for task in dist.batch_isend_irecv(tasks):
    #         task.wait()
    #     self.comm_cnt += 1

    # # ver0.1
    # @torch.no_grad()
    # def exchange(self):
    #     # recved client model parameters.
    #     tasks, recved = [], {}
    #     for node_id in self.client_node_ranks:
    #         i = self.tag_offset
    #         for group in self.param_groups:
    #             for p in group["params"]:
    #                 if not p.requires_grad:
    #                     continue
    #                 if not p in recved:
    #                     recved[p] = []
    #                 buff = torch.zeros_like(p, dtype=p.dtype,
    #                                         device="cpu" if self.backend == "gloo" else p.device)
    #                 tasks.append(dist.P2POp(dist.irecv, buff, node_id, tag=i))
    #                 recved[p].append(buff)
    #                 i += 1

    #     for task in dist.batch_isend_irecv(tasks):
    #         task.wait()

    #     p0 = datetime.now()
    #     R = []
    #     for group in self.param_groups:
    #         for p in group["params"]:
    #             if not p.requires_grad:
    #                 continue
    #             # N, flatten param.
    #             R.append(torch.stack([d.to(device=p.device, dtype=p.dtype).flatten()
    #                                   for d in recved[p]], dim=0))
    #     R = torch.cat(R, dim=-1)  # N, full-flatten param.
    #     N, d = R.shape
    #     # self.logger.critical("R.shape %s", R.shape)
    #     # R = R + hj ・ hj^T  / d
    #     # R = R@R.T + 0.1 * np.eye(N)
    #     ###
    #     R = torch.matmul(R, R.transpose(0, 1)).detach().cpu().numpy()/d
    #     ###
    #     # R = torch.matmul(R, R.transpose(0, 1)).detach().cpu().numpy()
    #     # R /= R.sum(axis=0, keepdims=True)
    #     ###

    #     a = np.ones(N)
    #     w = cp.Variable(N)
    #     objective = cp.Minimize(cp.quad_form(w, R))
    #     constraints = [w @ a == 1, w >= 0]
    #     prob = cp.Problem(objective, constraints)
    #     prob.solve()
    #     w = np.clip(w.value, 0.0, None)
    #     w /= w.sum()

    #     self.logger.debug(
    #         f"update global model: communication={self.comm_cnt}, {len(recved)} params")
    #     self.logger.debug("mixing weights: [%s]", ", ".join(list(map(str, w))))
    #     # mean client model parameters and set to global model.
    #     for group in self.param_groups:
    #         for p in group["params"]:
    #             if not p.requires_grad:
    #                 continue
    #             p.data.copy_(torch.stack([(float(weight) * d).to(device=p.device, dtype=p.dtype) for weight, d in zip(w, recved[p])],
    #                                      dim=0).sum(dim=0))
    #     self.logger.debug("MVDRmixing process %f sec",
    #                       (datetime.now()-p0).total_seconds())

    #     # send global model parameters.
    #     tasks = []
    #     for node_id in self.client_node_ranks:
    #         i = self.tag_offset
    #         for group in self.param_groups:
    #             for p in group["params"]:
    #                 if not p.requires_grad:
    #                     continue
    #                 tasks.append(dist.P2POp(dist.isend, p.cpu() if self.backend == "gloo" else p,
    #                                         node_id, tag=i))
    #                 i += 1

    #     for task in dist.batch_isend_irecv(tasks):
    #         task.wait()
    #     self.comm_cnt += 1


class FedAvgMVDRClient(torch.optim.Optimizer):
    """Client-side optimizer: plain SGD plus periodic model exchange.

    Every :meth:`step` applies ``p -= lr * p.grad`` to parameters that have a
    gradient. Every ``local_step`` steps the client sends all of its trainable
    parameters to ``server_node_rank`` and overwrites them with the model the
    server sends back. Without a valid server rank the client runs standalone
    (local SGD only, no communication).

    Args:
        params: model parameters (same layout as the server's).
        rank: rank of this process (stored only).
        local_step: number of steps between two exchanges (>= 1).
        lr: SGD learning rate (param-group default).
        server_node_rank: rank of the server; ``None`` or negative means standalone.
        seed: accepted for signature parity with the server; unused here.
        tag_offset: first point-to-point tag; must equal the server's.

    Raises:
        ValueError: if ``local_step < 1``.
    """

    def __init__(self, params, rank: int, local_step: int, lr: float = 0.001,
                 server_node_rank: typing.Optional[int] = None,
                 seed: int = 0, tag_offset: int = 0):
        defaults = dict(lr=lr)
        super().__init__(params, defaults)
        self.logger = getLogger(__name__)

        if local_step < 1:
            raise ValueError(f"local_step must be >= 1, got {local_step}")

        self.server_node_rank = server_node_rank
        self.is_standalone = not (isinstance(server_node_rank, int)
                                  and server_node_rank >= 0)

        self.rank = rank
        self.local_step = local_step
        self.step_counter = 0
        self._max_step = self.local_step * int(1e10)
        self.comm_cnt = 0  # cumulative number of completed exchanges
        # Kept so send/recv tags can be separated from other wrapper classes.
        self.tag_offset = tag_offset
        # The backend is only used by exchange(), which is a no-op standalone;
        # don't require an initialised process group in that case.
        if self.is_standalone and not (dist.is_available() and dist.is_initialized()):
            self.backend = None
        else:
            self.backend = dist.get_backend()

    def __str__(self) -> str:
        summary = (f"\nrank={self.rank}, server rank={self.server_node_rank}, "
                   f"local step={self.local_step}")
        return super().__str__() + summary

    @torch.no_grad()
    def step(self, closure: typing.Callable = None, **kwargs):
        """Apply one SGD step, then exchange with the server every ``local_step`` steps.

        ``closure`` and ``kwargs`` are accepted for API compatibility and ignored.

        Returns:
            None
        """
        for group in self.param_groups:
            lr = group["lr"]
            for p in group["params"]:
                if not p.requires_grad or p.grad is None:
                    continue
                p.add_(p.grad, alpha=-lr)

        self.step_counter = (self.step_counter + 1) % self._max_step
        if self.step_counter % self.local_step == 0:
            self.exchange()
        return None

    def _exchange_params(self) -> typing.List[torch.Tensor]:
        """Return the parameters to exchange, in tag order.

        This must be exactly the set the server uses (every ``requires_grad``
        parameter). Filtering on ``grad is None`` here would shift the tags
        of every later parameter and desynchronise client and server.
        """
        return [p for group in self.param_groups
                for p in group["params"] if p.requires_grad]

    @torch.no_grad()
    def exchange(self):
        """Send the local model to the server and overwrite it with the global model.

        Does nothing in standalone mode.
        """
        if self.is_standalone:
            return
        params = self._exchange_params()
        gloo = self.backend == "gloo"

        # Send the local model to the server.
        tasks = [dist.P2POp(dist.isend, p.cpu() if gloo else p,
                            self.server_node_rank, tag=tag)
                 for tag, p in enumerate(params, start=self.tag_offset)]

        # Receive the global model from the server, using the same tags.
        recved = {}
        for tag, p in enumerate(params, start=self.tag_offset):
            buff = torch.zeros_like(p, device="cpu" if gloo else p.device)
            tasks.append(dist.P2POp(dist.irecv, buff, self.server_node_rank, tag=tag))
            recved[p] = buff

        for task in dist.batch_isend_irecv(tasks):
            task.wait()

        # Overwrite the local model with the global one.
        for p in params:
            p.copy_(recved[p].to(dtype=p.dtype, device=p.device))
        self.comm_cnt += 1