#! -*- coding: utf-8
import time
from datetime import datetime
from logging import getLogger

import torch
import torch.distributed as dist

from ..graphs.dynamic_graph import DynamicGraph

# Public API of this module: only the exchanger class is exported by `import *`.
__all__ = ["MeanModelParameterExchanger"]


class MeanModelParameterExchanger(object):
    """Wrap an optimizer and periodically average parameters with graph neighbors.

    Every ``local_step`` calls to :meth:`step`, the trainable parameters
    (``requires_grad=True``) of ``model`` are exchanged point-to-point with the
    neighbors returned by ``graph`` and replaced by a weighted combination of
    the local and received values (see :meth:`mean_param`). Each exchange uses
    the next of the ``graph.length`` neighbor sets, cycling.

    Args:
        model: Model whose trainable parameters are exchanged.
        optimizer: Wrapped optimizer; ``step``/``zero_grad`` are delegated to it.
        node_id: This node's id in ``graph``. Neighbor ids are used directly as
            peer ranks in the ``torch.distributed`` P2P operations.
        graph: Object exposing ``length`` and
            ``get_neighbors(node_id, idx=...) -> (in_neighbors, out_neighbors)``,
            where both are dicts keyed by node id and ``in_neighbors`` maps to
            the mixing weight.
        local_step: Number of optimizer steps between two exchanges.
        tag_offset: First message tag used for the per-parameter P2P messages.
    """

    def __init__(self, model: torch.nn.Module, optimizer: torch.optim.Optimizer,
                 node_id: int, graph: "DynamicGraph", local_step: int,
                 tag_offset: int = 0):
        self.logger = getLogger(__name__)

        self.model = model
        self.optimizer = optimizer

        self.node_id = node_id
        self.graph = graph
        self.local_step = local_step
        self.step_counter = 0
        self.comm_cnt = 0  # number of completed exchanges
        # Period after which step_counter wraps; a multiple of local_step, so
        # the exchange cadence is unaffected by the wrap.
        self._max_step = self.local_step * self.graph.length
        self.graph_idx = 0  # index of the neighbor set used by the next exchange
        self.tag_offset = tag_offset
        # Requires an initialized default process group.
        self.backend = dist.get_backend()

        # Seconds spent waiting on P2P transfers during the last exchange.
        self.comm_proc = 0.0

    def __str__(self) -> str:
        msg = super().__str__()
        msg += "\n"
        msg += str(self.optimizer)
        return msg

    @property
    def communication_proc(self):
        """Last exchange's communication time, plus the optimizer's own if it reports one."""
        comm_proc = self.comm_proc
        if hasattr(self.optimizer, "communication_proc"):
            comm_proc += self.optimizer.communication_proc
        elif hasattr(self.optimizer, "comm_proc"):
            comm_proc += self.optimizer.comm_proc
        return comm_proc

    @torch.no_grad()
    def step(self, closure=None):
        """Run one optimizer step and exchange parameters every ``local_step`` steps.

        Returns:
            Whatever the wrapped optimizer's ``step(closure)`` returns.
        """
        loss = self.optimizer.step(closure)

        self.step_counter = (self.step_counter + 1) % self._max_step
        if self.step_counter % self.local_step == 0:
            self.update()

        return loss

    def zero_grad(self, set_to_none: bool = False):
        """Delegate to the wrapped optimizer's ``zero_grad``."""
        self.optimizer.zero_grad(set_to_none=set_to_none)

    @torch.no_grad()
    def update(self):
        """Exchange parameters with the current neighbor set and average them in.

        Queues an isend of every trainable parameter to each out-neighbor and
        an irecv from each in-neighbor (skipping this node), runs them as one
        batch, waits for completion, then calls :meth:`mean_param`. Advances
        ``graph_idx`` to the next neighbor set.

        Raises:
            Any exception from the distributed calls is logged at CRITICAL
            level and re-raised.
        """
        graph_idx = self.graph_idx
        in_neighbors, out_neighbors = self.graph.get_neighbors(self.node_id,
                                                               idx=graph_idx)
        self.graph_idx = (self.graph_idx + 1) % self.graph.length

        task_list = []
        received_params = {}
        for node_id in out_neighbors:
            if node_id != self.node_id:
                task_list += self.send_param(node_id)

        for node_id in in_neighbors:
            if node_id != self.node_id:
                tasks, params = self.recv_param(node_id)
                task_list += tasks
                received_params[node_id] = params

        # Reset every round so an exchange without remote peers does not keep
        # reporting the previous round's communication time.
        self.comm_proc = 0.0
        if task_list:
            try:
                start = time.perf_counter()  # monotonic, unlike datetime.now()
                requests = dist.batch_isend_irecv(task_list)
                for request in requests:
                    request.wait()
                self.comm_proc = time.perf_counter() - start
            except Exception:
                self.logger.critical("current task list: %s, neighbors in: %s, out: %s",
                                     task_list, in_neighbors, list(out_neighbors),
                                     exc_info=True)
                raise

        self.mean_param(received_params, in_neighbors, tau=graph_idx)
        self.logger.debug(
            "[node%s] exchange parameter: step=%s, neighbors-%s/%s: in=[%s], out=[%s], comm proc=%s sec",
            self.node_id, self.step_counter, graph_idx + 1, self.graph.length,
            ", ".join(map(str, in_neighbors)), ", ".join(map(str, out_neighbors)),
            self.comm_proc)
        self.comm_cnt += 1

    @torch.no_grad()
    def send_param(self, node_id):
        """Build isend operations for every trainable parameter to ``node_id``.

        Parameters whose name contains ``lora_A``/``lora_B`` but not
        ``lora_A.0``/``lora_B.0`` are sent negated.
        NOTE(review): the reason for that negation is not visible in this file;
        presumably part of a LoRA-specific averaging scheme — confirm.
        With the gloo backend the tensors are copied to CPU first.

        Returns:
            List of ``dist.P2POp`` with consecutive tags starting at
            ``tag_offset``, in ``model.named_parameters()`` order. This order
            and tagging must match :meth:`recv_param` on the peer.
        """
        task_list = []
        on_gloo = self.backend == "gloo"

        tag = self.tag_offset
        for param_name, p in self.model.named_parameters():
            if not p.requires_grad:
                continue
            is_lora = "lora_A" in param_name or "lora_B" in param_name
            is_first_lora = "lora_A.0" in param_name or "lora_B.0" in param_name
            if is_lora and not is_first_lora:
                senddata = -p
                self.logger.debug("send param from %d to %d: -1*%s",
                                  self.node_id, node_id, param_name)
            else:
                senddata = p
                self.logger.debug("send param from %d to %d: %s",
                                  self.node_id, node_id, param_name)
            task_list.append(dist.P2POp(dist.isend,
                                        senddata.cpu() if on_gloo else senddata,
                                        node_id, tag=tag))
            tag += 1
        return task_list

    @torch.no_grad()
    def recv_param(self, node_id):
        """Build irecv operations for every trainable parameter from ``node_id``.

        Receive buffers are zero-initialized copies of each parameter's shape,
        placed on CPU for gloo and on the parameter's device otherwise.

        Returns:
            ``(task_list, received_params)``: the ``dist.P2POp`` list and the
            buffers the data will land in, both in parameter order with tags
            starting at ``tag_offset`` (matching :meth:`send_param`).
        """
        task_list = []
        received_params = []
        on_gloo = self.backend == "gloo"

        tag = self.tag_offset
        for p in self.model.parameters():
            if not p.requires_grad:
                continue
            buf = torch.zeros_like(p, device="cpu" if on_gloo else p.device)
            task_list.append(dist.P2POp(dist.irecv, buf, node_id, tag=tag))
            received_params.append(buf)
            tag += 1
        return task_list, received_params

    @torch.no_grad()
    def mean_param(self, recieved_params, neighbors, tau: int = 0):
        """Replace each trainable parameter with a weighted sum over neighbors, in place.

        ``p <- neighbors[self.node_id] * p + sum_j neighbors[j] * recieved_params[j][i]``
        for every other neighbor ``j`` (in dict order).

        Args:
            recieved_params: Maps neighbor id to the list of received tensors,
                aligned with this model's trainable parameters.
            neighbors: Maps node id to its mixing weight; must contain
                ``self.node_id`` whenever the model has trainable parameters.
            tau: Not used by this implementation; kept for interface
                compatibility.
        """
        trainable = [p for p in self.model.parameters() if p.requires_grad]
        if not trainable:
            return

        self_weight = neighbors[self.node_id]
        remote = [(weight, recieved_params[node_id])
                  for node_id, weight in neighbors.items()
                  if node_id != self.node_id]

        for i, p in enumerate(trainable):
            p.data *= self_weight
            for weight, tensors in remote:
                p.data += weight * tensors[i].to(p.device)
