import torch
from optim import DecentralizedOptimizer


class DNASA(DecentralizedOptimizer):
    """Decentralized optimizer that moves along a normalized, averaged direction.

    State is kept as flat vectors shaped like ``model.flat_parameters``:

    * ``V`` -- the flattened gradient recorded on the previous step.
    * ``U`` -- accumulates the change in ``V`` between steps and is mixed
      on every step.
    * ``Z`` -- a weighted average of ``Z`` and ``U`` (weight ``alpha``);
      parameters move along ``Z / ||Z||``.
    * ``buf`` -- scratch space handed to ``self.mix`` as its output.

    ``self.mix``, ``self.flatten_grads`` and ``self.model`` are not defined
    here; they are presumably provided by ``DecentralizedOptimizer``.
    """

    def __init__(self, model, alpha=0.01, **kwargs):
        """Allocate zero-initialized state on the model's device.

        Args:
            model: Wrapper exposing ``module`` (whose ``parameters()`` are
                used to pick the device) and ``flat_parameters``.
            alpha: Initial averaging weight. NOTE: ``step()`` overwrites
                ``self.alpha`` with its own schedule on every call, so this
                value is only observable before the first step.
            **kwargs: Forwarded unchanged to ``DecentralizedOptimizer``.
        """
        super().__init__(model, **kwargs)
        self.alpha = alpha

        self.device = next(model.module.parameters()).device
        self.buf = torch.zeros_like(model.flat_parameters, device=self.device)
        self.V = torch.zeros_like(model.flat_parameters, device=self.device)
        self.U = torch.zeros_like(model.flat_parameters, device=self.device)
        self.Z = torch.zeros_like(model.flat_parameters, device=self.device)

    def init(self):
        """Seed ``V``, ``U`` and ``Z`` with the model's current flattened gradients."""
        # Flatten and transfer once; all three buffers start from the same vector.
        grad = self.flatten_grads(self.model.module).to(self.device)
        self.V[:] = grad
        self.U[:] = grad
        self.Z[:] = grad

    def _mix_in_place(self, tensor):
        """Run ``self.mix`` on ``tensor`` via ``self.buf`` and copy the result back."""
        self.buf.zero_()
        self.mix(tensor, self.buf)
        tensor[:] = self.buf

    @torch.no_grad()
    def step(self, step, closure=None):
        """Perform one optimization step.

        Args:
            step: Positive iteration counter driving the schedules
                ``lr = 10**0.25 / step**0.75`` and ``alpha = (10 / step)**0.5``.
            closure: Optional callable that re-evaluates the model and returns
                the loss; it is invoked with gradients enabled.

        Returns:
            Tuple ``(loss, grad)`` where ``loss`` is the closure's return value
            (``None`` without a closure) and ``grad`` is the flattened gradient
            moved to the CPU.

        Raises:
            ValueError: If ``step`` is not positive. Zero would divide by zero
                and a negative value would make the fractional powers complex.
        """
        if step <= 0:
            raise ValueError(
                f"step must be a positive iteration count, got {step!r}"
            )

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        self.lr = 10 ** 0.25 / (step ** 0.75)
        # NOTE(review): this exceeds 1 for step < 10, making (1 - alpha)
        # negative in the Z update below -- confirm that is intended.
        self.alpha = (10 / step) ** 0.5

        # Move the parameters a distance of lr along the unit direction of Z;
        # the epsilon guards against a zero-norm Z.
        X = self.model.flat_parameters
        X -= self.lr * self.Z / (self.Z.norm() + 1e-12)

        # Fold the change in the recorded gradient into U, then remember the
        # new gradient. Taking the difference before overwriting V avoids
        # cloning the old V.
        new_grad = self.flatten_grads(self.model.module).to(self.device)
        self.U += new_grad - self.V
        self.V[:] = new_grad

        # Write into the existing Z buffer (instead of rebinding the name) so
        # Z is updated in place like X, V and U.
        self.Z[:] = (1 - self.alpha) * self.Z + self.alpha * self.U

        for tensor in (X, self.U, self.Z):
            self._mix_in_place(tensor)

        # Re-read the gradients after mixing, as the original implementation did.
        grad = self.flatten_grads(self.model.module).to(self.device)

        return loss, grad.detach().cpu()
