import torch
from sympy.abc import kappa
from torch.optim import Optimizer
from decentralized_opt.tensor_tools import reduce_tensors, flatten_tensors
from decentralized_opt import log


class DecentralizedOptimizer(Optimizer):
    """Optimizer base that mixes flat tensors with the ranks adjacent in ``G``.

    The constructor registers ``model.module.parameters()`` with
    ``torch.optim.Optimizer`` using an empty defaults dict and stores the
    remaining configuration on the instance.

    Args:
        model: Wrapper object exposing the real network as ``model.module``
            (presumably a DDP-style wrapper; confirm against the caller).
        lr: Learning rate; stored on ``self.lr`` and not used by this class.
        G: Communication topology. ``mix`` reads ``G.process_group[dst]`` and
            ``G.graph.neighbors(rank)`` from it.
        kappa: Optional weight applied to this rank's own tensor in ``mix``.
            ``None`` selects uniform averaging.
        world_size: Number of participating ranks.
        rank: Index of the current process.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, model, lr=1, G=None, kappa=None, world_size=1, rank=0, **kwargs):
        super().__init__(model.module.parameters(), dict())
        log.info(f'world size = {world_size}, rank = {rank}, lr = {lr}')
        self.lr = lr
        self.model = model
        self.world_size = world_size
        self.rank = rank
        self.G = G
        self.kappa = kappa

    @torch.no_grad()
    def mix(self, flat_tensor, flat_buf):
        """Combine ``flat_tensor`` with the neighbours' tensors into ``flat_buf``.

        For every rank ``dst`` a ``reduce_tensors`` call is issued on
        ``G.process_group[dst]``: the destination rank reduces into
        ``flat_buf``, and every rank listed in ``G.graph.neighbors(dst)``
        contributes its ``flat_tensor``. All returned requests are awaited
        before ``flat_buf`` is rescaled in place.

        NOTE(review): the rescaling assumes ``reduce_tensors`` leaves in
        ``flat_buf`` the sum of this rank's tensor and its neighbours'
        tensors; confirm against ``decentralized_opt.tensor_tools``.

        Args:
            flat_tensor: This rank's flattened tensor (the contribution).
            flat_buf: Buffer receiving the reduction; modified in place.
        """
        log.debug('Mixing')

        reqs = []

        for dst in range(self.world_size):
            log.debug('dst is rank %d', dst)
            group = self.G.process_group[dst]
            if dst == self.rank:
                log.debug('receiving')
                reqs += reduce_tensors([flat_tensor], dst, group, bufs=[flat_buf])
                log.debug('rank %d recv ', self.rank)
            else:
                # Send
                if self.rank in self.G.graph.neighbors(dst):
                    log.debug('sending to %d', dst)
                    reqs += reduce_tensors([flat_tensor], dst, group)
                    log.debug('rank %d send to %d', self.rank, dst)

        for req in reqs:
            req.wait()

        # Count neighbours through the same accessor that decides who sends
        # above, so the normalisation matches the set of contributors.
        n_neighbors = len(list(self.G.graph.neighbors(self.rank)))

        # Test the instance attribute: the bare name ``kappa`` would resolve to
        # the SymPy symbol imported at module level, not the configured value.
        if self.kappa is not None:
            # Own tensor gets weight kappa; the remaining (1 - kappa) is split
            # evenly between the neighbours.
            neighbor_weight = (1.0 - self.kappa) / n_neighbors if n_neighbors > 0 else 0.0
            flat_buf.mul_(neighbor_weight)
            # flat_buf already holds flat_tensor scaled by neighbor_weight, so
            # only the difference up to kappa is added.
            flat_buf.add_(flat_tensor * (self.kappa - neighbor_weight))
        else:
            # Uniform average over this rank and its neighbours.
            flat_buf.div_(n_neighbors + 1)

        log.debug('Mixing done')

    @torch.no_grad()
    def zero_grad(self):
        """Clear gradients by delegating to ``self.model.zero_grad()``."""
        self.model.zero_grad()

    @torch.no_grad()
    def flatten_grads(self, module):
        """Return ``flatten_tensors`` applied to the ``.grad`` of every parameter.

        NOTE(review): a parameter whose ``.grad`` is still ``None`` is passed
        through unchanged; whether ``flatten_tensors`` tolerates that is not
        visible here.
        """
        return flatten_tensors([t.grad for t in module.parameters()])
