import torch
from optim import DecentralizedOptimizer


class DNSGD(DecentralizedOptimizer):
    """Decentralized optimizer taking normalized steps along a tracked gradient.

    Each node keeps a tracker ``V`` that is updated with the difference
    between the current and previous local gradients and then mixed with
    neighbours through ``accGossip``. The parameters move along
    ``V / ||V||`` scaled by ``self.lr`` and are mixed again.

    NOTE(review): ``self.G``, ``self.lr``, ``self.mix``, ``self.model`` and
    ``self.flatten_grads`` are not defined in this class; presumably they are
    provided by ``DecentralizedOptimizer`` -- confirm against the base class.
    """

    def __init__(self, model, K=10, K_hat=10, **kwargs):
        """Allocate the optimizer state.

        Args:
            model: Wrapped model exposing ``module`` and ``flat_parameters``.
            K: Number of gossip rounds applied to the parameters per update.
            K_hat: Number of gossip rounds applied to the tracker ``V``.
            **kwargs: Forwarded unchanged to ``DecentralizedOptimizer``.
        """
        super().__init__(model, **kwargs)
        self.K = K
        self.K_hat = K_hat

        self.device = next(model.module.parameters()).device
        # Scratch buffer handed to self.mix() as its second argument.
        self.buf = torch.zeros_like(model.flat_parameters, device=self.device)
        # Tracked (gossiped) gradient estimate.
        self.V = torch.zeros_like(model.flat_parameters, device=self.device)
        # Local gradient from the previous call, used for the tracking difference.
        self.prev_grad = torch.zeros_like(model.flat_parameters, device=self.device)

        # Presumably the second-largest eigenvalue magnitude of the mixing
        # matrix -- TODO confirm against G.compute_lambda2.
        self.glambda2 = self.G.compute_lambda2(self.G.graph_type)
        # Momentum coefficient for accGossip; the square root is computed once.
        root = (1 - self.glambda2 ** 2) ** 0.5
        self.eta_y = (1 - root) / (1 + root)

    @torch.no_grad()
    def init(self):
        """Seed the tracker with the current gradient and take the first step.

        Runs under ``torch.no_grad()`` (as ``step`` does) because it writes
        into ``model.flat_parameters`` in place.
        """
        self.buf.zero_()
        grad = self.flatten_grads(self.model.module).to(self.device)
        # In-place write keeps the buffer allocated in __init__, matching step().
        self.V[:] = self.accGossip(grad, self.eta_y, self.K_hat)
        self.prev_grad[:] = grad

        # Unit-norm direction; the epsilon guards against a zero tracker.
        direction = self.V / (self.V.norm() + 1e-12)
        proposal = self.model.flat_parameters - self.lr * direction

        self.buf.zero_()
        self.model.flat_parameters[:] = self.accGossip(proposal, self.eta_y, self.K)

    def accGossip(self, tensor, eta=None, K=10):
        """Apply ``K`` rounds of momentum-accelerated mixing to ``tensor``.

        Each round computes ``(1 + eta) * mix(curr) - eta * prev`` where
        ``curr`` and ``prev`` are the two most recent iterates (both start as
        copies of ``tensor``).

        Args:
            tensor: Value to mix; it is not modified.
            eta: Momentum coefficient. ``None`` (the default) selects
                ``self.eta_y``.
            K: Number of mixing rounds. With ``K <= 0`` a copy of ``tensor``
                is returned.

        Returns:
            A new tensor holding the result of the final round.
        """
        if eta is None:
            # The declared default used to fail at ``1 + eta``.
            eta = self.eta_y

        curr = tensor.clone()
        prev = tensor.clone()
        out = tensor.clone()

        for _ in range(K):
            self.buf.zero_()
            # Presumably writes the neighbour-mixed value of curr into
            # self.buf -- confirm against the base class mix().
            self.mix(curr, self.buf)
            out = (1 + eta) * self.buf - eta * prev
            prev, curr = curr, out
        return out

    @torch.no_grad()
    def step(self, step, closure=None):
        """Perform one tracking update followed by a normalized, mixed step.

        Args:
            step: Iteration index; not used by this method.
            closure: Optional callable evaluated with gradients enabled; its
                return value is reported as the loss.

        Returns:
            A tuple ``(loss, grad)`` where ``loss`` is the closure result (or
            ``None``) and ``grad`` is the local flattened gradient, detached
            and moved to the CPU.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        self.buf.zero_()
        grad = self.flatten_grads(self.model.module).to(self.device)
        # Gradient tracking: add the change in the local gradient, then mix.
        self.V[:] = self.accGossip(self.V + grad - self.prev_grad, self.eta_y, self.K_hat)
        self.prev_grad[:] = grad

        # Unit-norm direction; the epsilon guards against a zero tracker.
        direction = self.V / (self.V.norm() + 1e-12)
        proposal = self.model.flat_parameters - self.lr * direction
        self.buf.zero_()
        self.model.flat_parameters[:] = self.accGossip(proposal, self.eta_y, self.K)

        return loss, grad.detach().cpu()
