# -*- coding: utf-8 -*-
import typing
from datetime import datetime
from logging import getLogger

import torch
import torch.distributed as dist

from ..graphs.dynamic_graph import DynamicGraph

__all__ = ["Scaffold"]


class Scaffold(torch.optim.Optimizer):
    """Gradient-descent optimizer with correction terms and peer-to-peer exchange.

    Every call to :meth:`step` applies ``p <- p - lr * (grad + c_bar - c_prm)``
    to each parameter that has a gradient.  Whenever ``step_counter`` is a
    multiple of ``local_step`` (this includes the very first call), the
    parameters are exchanged with the neighbours given by ``graph`` and the
    per-parameter state is refreshed (see :meth:`update_state`).

    The class name suggests a decentralized variant of the SCAFFOLD
    algorithm, with ``c_prm`` / ``c_bar`` acting as local / aggregated control
    variates -- TODO confirm against the reference the author followed.

    Per-parameter state (``self.state[p]``):
        * ``i_prm`` -- copy of the parameter taken when the state is created
          and refreshed at every communication round.
        * ``c_prm`` -- ``i_prm - p`` as measured at the last communication
          round (zeros before the first round).
        * ``c_bar`` -- ``w_self * c_prm + sum_j w_j * (p - recv_j) /
          (lr * local_step)`` over the in-neighbours ``j`` (zeros before the
          first round).
        * ``p_sum`` -- sum of the parameters received from the in-neighbours.

    Args:
        params: Iterable of parameters or parameter-group dicts, as accepted
            by ``torch.optim.Optimizer``.
        node_id: Identifier of this node; used to query ``graph`` and passed
            to ``torch.distributed`` as the peer rank of neighbours.
        graph: Object exposing ``length`` and
            ``get_neighbors(node_id, idx=...)``.  The latter must return
            ``(in_neighbors, out_neighbors)``, both mappings keyed by node id;
            the values of ``in_neighbors`` are used as mixing weights and the
            mapping must contain ``node_id`` itself.
        local_step: Number of :meth:`step` calls between two communication
            rounds.
        lr: Learning rate.
        with_avg: If true, every step additionally replaces the parameter by
            ``p_sum + weight * p`` where ``weight`` is this node's self-weight
            from the last communication round (``1.0`` before the first).
        tag_offset: First message tag; the k-th exchanged parameter uses tag
            ``tag_offset + k``.

    Attributes:
        step_counter: Number of steps taken, modulo
            ``local_step * graph.length``.
        comm_cnt: Number of completed communication rounds.
        comm_proc: Wall-clock seconds spent in the most recent round that
            actually transferred data (``0.0`` until then).
        graph_idx: Index passed to ``graph.get_neighbors`` on the next round.

    Note:
        ``torch.distributed`` must already be initialised when the optimizer
        is constructed, because the backend name is queried in ``__init__``.
    """

    def __init__(self, params, node_id: int, graph: DynamicGraph, local_step: int,
                 lr: float = 1e-5, with_avg: bool = False, tag_offset: int = 0):
        self.logger = getLogger(__name__)
        self.lr = lr
        self.with_avg = with_avg

        # "weight" is this node's self-weight; it is overwritten after each
        # communication round by update_state().
        defaults = dict(lr=lr, weight=1.0)
        super().__init__(params, defaults)

        self.node_id = node_id
        self.graph = graph
        self.local_step = local_step
        self.step_counter = 0
        self.comm_cnt = 0
        # Defined up front so that it can be read before the first round
        # (previously it only came into existence inside update()).
        self.comm_proc = 0.0
        self._max_step = self.local_step * self.graph.length
        self.graph_idx = 0
        self.tag_offset = tag_offset
        self.backend = dist.get_backend()

    @torch.no_grad()
    def step(self, closure=None):
        """Perform one optimisation step and, if due, one communication round.

        Args:
            closure: Optional callable that re-evaluates the model and returns
                the loss; it is executed with gradients enabled.

        Returns:
            The value returned by ``closure``, or ``None`` if none was given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group["lr"]
            w = group["weight"]
            for p in group["params"]:
                if p.grad is None:
                    continue

                state = self.state[p]
                if not state:
                    # Lazily create the per-parameter state on first use.
                    state["c_prm"] = torch.zeros_like(p)
                    state["c_bar"] = torch.zeros_like(p)
                    state["i_prm"] = p.data.clone()
                    state["p_sum"] = torch.zeros_like(p)

                c_prm, c_bar = state["c_prm"], state["c_bar"]

                # Corrected gradient step.
                p.data = p.data - lr * (p.grad.data + c_bar - c_prm)
                if self.with_avg:
                    # NOTE(review): this mixing is applied on every step using
                    # the p_sum / weight of the last communication round, and
                    # p_sum is an unweighted sum -- confirm this is intended.
                    p.data = state["p_sum"] + (w * p.data)

        # The check runs before the counter is advanced, so the first call
        # (step_counter == 0) already communicates.
        if self.step_counter % self.local_step == 0:
            self.update()

        self.step_counter = (self.step_counter + 1) % self._max_step

        return loss

    def _exchanged_params(self):
        """Yield, in group order, the parameters that are sent and received."""
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None or not p.requires_grad:
                    continue
                yield p

    @torch.no_grad()
    def send_param(self, node_id):
        """Build the send operations that ship the parameters to ``node_id``.

        Args:
            node_id: Destination rank.

        Returns:
            list of ``dist.P2POp``, one ``isend`` per exchanged parameter,
            tagged consecutively starting at ``tag_offset``.  Nothing is sent
            until the operations are submitted by the caller.
        """
        # Tensors are moved to the CPU when the backend is gloo.
        to_cpu = self.backend == "gloo"
        return [
            dist.P2POp(dist.isend, p.cpu() if to_cpu else p, node_id, tag=tag)
            for tag, p in enumerate(self._exchanged_params(),
                                    start=self.tag_offset)
        ]

    @torch.no_grad()
    def recv_param(self, node_id):
        """Build the receive operations for the parameters of ``node_id``.

        Args:
            node_id: Source rank.

        Returns:
            Tuple ``(tasks, buffers)``: the ``irecv`` operations and the
            zero-initialised tensors they will fill, both in the same order as
            the operations produced by :meth:`send_param`.
        """
        on_cpu = self.backend == "gloo"
        tasks = []
        buffers = []
        for tag, p in enumerate(self._exchanged_params(),
                                start=self.tag_offset):
            buf = torch.zeros_like(p, device="cpu" if on_cpu else p.device)
            tasks.append(dist.P2POp(dist.irecv, buf, node_id, tag=tag))
            buffers.append(buf)
        return tasks, buffers

    @torch.no_grad()
    def update(self):
        """Run one communication round and refresh the optimizer state.

        Sends the parameters to every out-neighbour, receives them from every
        in-neighbour (this node itself is skipped in both directions), waits
        for all transfers and then calls :meth:`update_state`.

        Raises:
            Exception: Any error raised while submitting or waiting for the
                transfers is logged at CRITICAL level and re-raised.
        """
        graph_idx = self.graph_idx
        in_neighbors, out_neighbors = self.graph.get_neighbors(self.node_id,
                                                               idx=graph_idx)
        # Cycle through the graph's entries, one per communication round.
        self.graph_idx = (self.graph_idx + 1) % self.graph.length

        task_list = []
        recieved_params = {}
        for node_id in out_neighbors.keys():
            if node_id != self.node_id:
                task_list += self.send_param(node_id)

        for node_id in in_neighbors:
            if node_id != self.node_id:
                tasks, params = self.recv_param(node_id)
                task_list += tasks
                recieved_params[node_id] = params

        try:
            if task_list:
                s = datetime.now()
                task_list = dist.batch_isend_irecv(task_list)
                for task in task_list:
                    task.wait()
                self.comm_proc = (datetime.now() - s).total_seconds()
        except Exception:
            # Log the context needed to diagnose a failed exchange, then let
            # the error propagate.
            self.logger.critical(f"current task list: {task_list}, neighbors in: {in_neighbors}, out: {list(out_neighbors.keys())}",
                                 exc_info=True)
            raise
        self.comm_cnt += 1

        self.update_state(recieved_params, in_neighbors)

    @torch.no_grad()
    def update_state(self, recieved_params, neighbors):
        """Recompute the per-parameter state after a communication round.

        Args:
            recieved_params: Mapping from neighbour id to the list of tensors
                received from it, ordered as produced by :meth:`recv_param`.
            neighbors: Mapping from in-neighbour id to its weight; must
                contain ``self.node_id``.
        """
        w = neighbors[self.node_id]

        # Position of the current parameter inside each received list.
        # NOTE(review): this loop skips only parameters without a gradient,
        # whereas send/recv also skip parameters with requires_grad == False;
        # if such a parameter still carries a gradient the positions no
        # longer line up -- confirm that case cannot occur.
        i = 0
        for group in self.param_groups:
            lr = group["lr"]
            for p in group["params"]:
                if p.grad is None:
                    continue

                state = self.state[p]
                i_prm, c_prm = state["i_prm"], state["c_prm"]

                # Drift since the previous round, then move the reference
                # point to the current parameter (updated in place).
                c_prm.copy_(i_prm - p.data)
                i_prm.copy_(p.data)

                p_sum = torch.zeros_like(c_prm)
                z_sum = w * c_prm
                for node_id, weight in neighbors.items():
                    if self.node_id == node_id:
                        continue
                    recv = recieved_params[node_id][i].to(i_prm.device)
                    # NOTE(review): added without `weight`, unlike z_sum.
                    p_sum += recv
                    z_sum = z_sum + weight * ((i_prm - recv)
                                              / (lr * self.local_step))
                i += 1
                state["p_sum"] = p_sum
                state["c_bar"] = z_sum

            group["weight"] = w  # self-weight used by step() when with_avg
