#! -*- coding: utf-8
import typing
from datetime import datetime
from logging import getLogger

import torch
import torch.distributed as dist

from asdfghjkl.precondition.shampoo import (GRAFT, MOMENTUM, PRECONDITIONER,
                                            STEP, AdagradGraft, Graft,
                                            LayerwiseGrafting, Preconditioner,
                                            SGDGraft, ShampooHyperParams)
from dynamic_graph import DynamicGraph

__all__ = ["Shampoo", "DShampoo"]


class Shampoo(torch.optim.Optimizer):
    """The Shampoo optimizer.

    Parameter groups may carry a ``"shampoo"`` flag (default ``True``).
    Groups with the flag set are updated with Shampoo-preconditioned
    gradients, with optional layer-wise grafting, weight decay, gradient
    smoothing and (Nesterov) momentum as configured by ``hyperparams``.
    Groups with ``shampoo=False`` are updated with plain momentum SGD.

    Args:
        params: iterable of parameters or parameter-group dicts.
        lr: learning rate (a per-group ``"lr"`` overrides it).
        momentum: momentum coefficient (a per-group ``"momentum"`` overrides it).
        hyperparams: Shampoo hyper-parameters. ``None`` (the default) builds a
            fresh ``ShampooHyperParams()`` for this optimizer.
    """

    def __init__(self,
                 params,
                 lr=1.0,
                 momentum=0.9,
                 hyperparams=None):
        if hyperparams is None:
            # Built per instance: a default-argument instance would be shared
            # by every optimizer constructed without explicit hyperparams.
            hyperparams = ShampooHyperParams()
        defaults = dict(lr=lr, momentum=momentum, shampoo=True)
        self.hps = hyperparams
        self.logger = getLogger(__name__)
        super().__init__(params, defaults)

    def init_var_state(self, var, state):
        """Initialize the optimizer state of a single Shampoo-managed parameter.

        Fills ``state`` in place with the step counter, a zero momentum
        buffer, the Shampoo preconditioner and the grafting object selected
        by ``self.hps.graft_type``.
        """
        state[STEP] = 0
        # zeros_like already allocates on var's device. Passing
        # var.get_device() breaks on CPU, where it returns -1.
        state[MOMENTUM] = torch.zeros_like(var.data)
        state[PRECONDITIONER] = Preconditioner(var, self.hps)
        if self.hps.graft_type == LayerwiseGrafting.ADAGRAD:
            state[GRAFT] = AdagradGraft(self.hps, var)
        elif self.hps.graft_type == LayerwiseGrafting.SGD:
            state[GRAFT] = SGDGraft(self.hps, var)
        else:
            state[GRAFT] = Graft(self.hps, var)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: optional callable that re-evaluates the model and returns
                the loss. It is run with autograd enabled before the update.

        Returns:
            The value returned by ``closure``, or ``None`` without a closure.

        Raises:
            RuntimeError: if a Shampoo-managed parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            is_shampoo = group.get("shampoo", self.defaults["shampoo"])
            for p in group['params']:
                if p.grad is None:
                    continue
                if is_shampoo:
                    self._shampoo_update(p, group)
                else:
                    self._sgd_update(p, group)
        return loss

    def _shampoo_update(self, p, group):
        """Apply one Shampoo (+ grafting) update to parameter ``p`` in place."""
        hps = self.hps
        lr, beta = group['lr'], group['momentum']
        grad = p.grad.data
        if grad.is_sparse:
            raise RuntimeError(
                'Shampoo does not support sparse yet')
        state = self.state[p]
        if not state:
            self.init_var_state(p, state)
        state[STEP] += 1
        preconditioned = state[STEP] >= hps.start_preconditioning_step

        preconditioner = state[PRECONDITIONER]
        graft = state[GRAFT]

        # Gather statistics, compute preconditioners
        graft.add_statistics(grad)
        if state[STEP] % hps.statistics_compute_steps == 0:
            preconditioner.add_statistics(grad)
        if state[STEP] % hps.preconditioning_compute_steps == 0:
            preconditioner.compute_preconditioners()

        # Precondition gradients. Before preconditioning starts, shampoo_grad
        # *is* p.grad, so the in-place ops below also modify p.grad.
        graft_grad = graft.precondition_gradient(grad)
        shampoo_grad = grad
        if preconditioned:
            shampoo_grad = preconditioner.preconditioned_grad(grad)

        # Grafting: take the direction from Shampoo and the norm from the graft.
        if hps.grafting:
            graft_norm = torch.norm(graft_grad)
            shampoo_norm = torch.norm(shampoo_grad)
            shampoo_grad.mul_(graft_norm / (shampoo_norm + 1e-16))

        # Weight decay
        if hps.weight_decay != 0.0:
            shampoo_grad.add_(p.data, alpha=hps.weight_decay)
            # A graft may hand back `grad` itself, and shampoo_grad can be
            # `grad` too; decaying the same tensor twice would double the decay.
            if graft_grad is not shampoo_grad:
                graft_grad.add_(p.data, alpha=hps.weight_decay)

        # Gradient smoothing: exponential moving average with factor beta1.
        # The buffer is only read when beta1 > 0, so it is only kept then.
        if hps.beta1 > 0.0:
            if "shampoo_grad" in state:
                state["shampoo_grad"].mul_(hps.beta1).add_(shampoo_grad,
                                                           alpha=(1 - hps.beta1))
                shampoo_grad = state["shampoo_grad"]
            else:
                state["shampoo_grad"] = shampoo_grad.detach().clone()

        # Momentum (heavy ball) for both the Shampoo and the graft directions.
        state[MOMENTUM].mul_(beta).add_(shampoo_grad)
        # NOTE(review): the graft momentum is fed `grad`, not `graft_grad`;
        # `grad` may already have been rescaled/decayed in place above when
        # shampoo_grad aliases it — confirm this is intended.
        graft_momentum = graft.update_momentum(grad, beta)

        if preconditioned:
            update, current = state[MOMENTUM], shampoo_grad
        else:
            update, current = graft_momentum, graft_grad

        if hps.nesterov:
            # Out of place: `update` may be a persistent momentum buffer, and
            # scaling it in place would corrupt the next step's momentum.
            update = update.mul(beta).add_(current)

        # Final update
        p.data.add_(update, alpha=-lr)

    def _sgd_update(self, p, group):
        """Apply a momentum-SGD update to ``p`` (parameter groups with ``shampoo=False``)."""
        grad = p.grad.data
        beta = group['momentum']
        state = self.state[p]
        if MOMENTUM not in state:
            state[MOMENTUM] = torch.zeros_like(p.data)
        buf = state[MOMENTUM]
        buf.mul_(beta).add_(grad)
        # Nesterov look-ahead computed out of place so the buffer stays intact.
        update = buf.mul(beta).add_(grad) if self.hps.nesterov else buf
        p.data.add_(update, alpha=-group['lr'])


class DShampoo(Shampoo):
    """Decentralized Shampoo.

    Each node runs the local Shampoo update. During the first
    ``graph.length`` steps of every ``local_step_exchange``-step window, it
    first exchanges its preconditioner statistics with its neighbours in
    ``graph`` using point-to-point ``torch.distributed`` messages. Each local
    statistic is then replaced by a weighted mix of its own and the received
    statistics before preconditioners are recomputed.

    Args:
        params: iterable of parameters or parameter-group dicts.
        node_id: rank of this node in ``graph``.
        graph: neighbourhood schedule; ``get_neighbors`` is cycled through
            ``graph.length`` topologies.
        lr, momentum: see :class:`Shampoo`.
        hyperparams: Shampoo hyper-parameters. ``None`` (the default) builds a
            fresh ``ShampooHyperParams()``.
        local_step_exchange: length of the exchange window, in steps.
        tag_offset: first message tag used for the exchanged tensors.
    """

    def __init__(self, params, node_id: int, graph: DynamicGraph,
                 lr: float = 1.0, momentum: float = 0.9,
                 hyperparams: typing.Optional[ShampooHyperParams] = None,
                 local_step_exchange: int = 50,
                 tag_offset: int = 0):
        if hyperparams is None:
            # Built per instance: a default-argument instance would be shared.
            hyperparams = ShampooHyperParams()
        super().__init__(params, lr=lr, momentum=momentum, hyperparams=hyperparams)

        self.local_step_exchange = local_step_exchange
        self.node_id = node_id
        self.graph = graph
        self.tag_offset = tag_offset

        self.graph_idx = 0  # index of the next topology in `graph`
        self.backend = dist.get_backend()

        self.comm_proc = 0.0  # wall time (s) of the last completed exchange
        self.step_iter = 0

    @property
    def communication_proc(self):
        """Wall-clock seconds spent in the most recent successful exchange."""
        return self.comm_proc

    @torch.no_grad()
    def step(self, closure=None):
        """Optionally exchange statistics with neighbours, then take a Shampoo step.

        Args:
            closure: optional callable that recomputes the loss and gradients.
                It is run with autograd enabled before anything else.

        Returns:
            The value returned by ``closure``, or ``None`` without a closure.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        if self.step_iter == 0:
            self._init_states()

        if self.step_iter % self.local_step_exchange < self.graph.length:
            self.exchange_params(self.node_id)
        self.step_iter += 1

        super().step()
        return loss

    def _init_states(self):
        """Create Shampoo state for every Shampoo-managed parameter with a gradient.

        The exchange methods skip parameters without a preconditioner, so the
        state must exist before the very first exchange.
        """
        for group in self.param_groups:
            if not group.get("shampoo", self.defaults["shampoo"]):
                continue
            for p in group['params']:
                if p.grad is None:
                    continue
                if p.grad.data.is_sparse:
                    raise RuntimeError(
                        'Shampoo does not support sparse yet')
                state = self.state[p]
                if not state:
                    self.init_var_state(p, state)

    def _exchangeable_preconditioners(self):
        """Yield the preconditioner of every parameter taking part in the exchange.

        send_params, recv_params and update_params all walk this one sequence,
        so message tags and received-buffer indices stay aligned.
        """
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                state = self.state[p]
                if PRECONDITIONER in state:
                    yield state[PRECONDITIONER]

    @torch.no_grad()
    def exchange_params(self, node_id):
        """Exchange preconditioner statistics with the current graph neighbours.

        Args:
            node_id: unused; neighbours are always looked up for
                ``self.node_id``. Kept for backward compatibility.
        """
        graph_idx = self.graph_idx
        in_neighbors, out_neighbors = self.graph.get_neighbors(self.node_id,
                                                               idx=graph_idx)
        self.graph_idx = (self.graph_idx + 1) % self.graph.length

        tasklist, received_params = [], {}

        for peer in out_neighbors:
            if peer != self.node_id:
                tasklist += self.send_params(peer)

        for peer in in_neighbors:
            if peer == self.node_id:
                continue
            tasks, params = self.recv_params(peer)
            if tasks:
                tasklist += tasks
                received_params[peer] = params

        try:
            if tasklist:
                start = datetime.now()
                for request in dist.batch_isend_irecv(tasklist):
                    request.wait()
                self.comm_proc = (datetime.now() - start).total_seconds()
            elif in_neighbors or out_neighbors:
                self.logger.warning("No exchange parameters!!")
        except Exception:
            self.logger.critical(f"current task list: {tasklist}, neighbors out: {list(out_neighbors.keys())}, in: {in_neighbors}",
                                 exc_info=True)
            raise

        if received_params:
            self.update_params(received_params, in_neighbors)

        self.logger.debug(
            "[node%s] exchange parameter: step=%s, neighbors-%s/%s: in=[%s], out=[%s], comm proc=%s sec",
            self.node_id, self.step_iter, graph_idx + 1, self.graph.length,
            ", ".join(map(str, in_neighbors)),
            ", ".join(map(str, out_neighbors)),
            self.comm_proc)

    @torch.no_grad()
    def send_params(self, node_id) -> typing.List[dist.P2POp]:
        """Build ``isend`` ops for every exchanged statistic, addressed to ``node_id``.

        Tags start at ``self.tag_offset`` and increase by one per tensor.
        """
        # With the gloo backend, CPU copies of the statistics are sent.
        to_cpu = self.backend == "gloo"
        tasklist = []
        tag = self.tag_offset
        for preconditioner in self._exchangeable_preconditioners():
            for statistic in preconditioner.statistics:
                tensor = statistic.cpu() if to_cpu else statistic
                tasklist.append(dist.P2POp(dist.isend, tensor, node_id, tag=tag))
                tag += 1
        return tasklist

    @torch.no_grad()
    def recv_params(self, node_id: int) -> typing.Tuple[typing.List[dist.P2POp],
                                                        typing.List[typing.List[torch.Tensor]]]:
        """Build ``irecv`` ops (and their buffers) for statistics coming from ``node_id``.

        Returns:
            ``(ops, buffers)`` where ``buffers[i][j]`` receives statistic ``j``
            of the ``i``-th exchanged preconditioner.
        """
        to_cpu = self.backend == "gloo"
        tasklist, recved = [], []
        tag = self.tag_offset
        for preconditioner in self._exchangeable_preconditioners():
            buffers = []
            for statistic in preconditioner.statistics:
                recv_buffer = torch.zeros_like(
                    statistic, device="cpu" if to_cpu else statistic.device)
                tasklist.append(dist.P2POp(dist.irecv, recv_buffer, node_id, tag=tag))
                buffers.append(recv_buffer)
                tag += 1
            recved.append(buffers)
        return tasklist, recved

    @torch.no_grad()
    def update_params(self, received_params: typing.Dict[int, typing.List[typing.List[torch.Tensor]]],
                      neigbors: typing.Dict[int, float]):
        """Mix local statistics with the received ones and recompute preconditioners.

        Each statistic ``S`` becomes ``w_self * S + sum_j w_j * S_j`` over the
        in-neighbours ``j != self.node_id``, with weights taken from
        ``neigbors``.

        Raises:
            KeyError: if ``neigbors`` has no weight for ``self.node_id``.
        """
        self_weight = neigbors[self.node_id]
        for i, preconditioner in enumerate(self._exchangeable_preconditioners()):
            for idx, statistic in enumerate(preconditioner.statistics):
                statistic.mul_(self_weight)
                for peer, weight in neigbors.items():
                    if peer == self.node_id:
                        continue
                    recved = received_params[peer][i][idx]
                    statistic.add_(weight * recved.to(statistic.device))
            preconditioner.compute_preconditioners()
