import torch
from optim import DecentralizedOptimizer


class DSGT(DecentralizedOptimizer):
    """Gradient-tracking optimizer operating on flattened model parameters.

    Each step refreshes a tracking variable ``y`` with ``self.mix`` plus the
    change in the local gradient, then moves the parameters along
    ``-lr * y`` and mixes them as well::

        y <- mix(y) + g_k - g_{k-1}
        x <- mix(x - lr * y)

    NOTE(review): ``mix``, ``lr``, ``model`` and ``flatten_grads`` are
    assumed to be provided by ``DecentralizedOptimizer``; ``mix(src, dst)``
    is assumed to write its result into ``dst`` (it is always handed a
    freshly zeroed buffer here) -- confirm against the base class.
    """

    def __init__(self, model, **kwargs):
        super().__init__(model, **kwargs)

        # All optimizer state lives on the device of the wrapped module's
        # first parameter.
        self.device = next(model.module.parameters()).device
        template = model.flat_parameters
        # Scratch buffer that receives the output of every mix() call.
        self.buf = torch.zeros_like(template, device=self.device)
        # Gradient-tracking variable.
        self.y = torch.zeros_like(template, device=self.device)
        # Flattened gradient seen at the previous step.
        self.prev_grad = torch.zeros_like(template, device=self.device)

    def init(self):
        """Seed ``y`` and ``prev_grad`` with the module's current flattened gradients.

        NOTE(review): if this is called, the first ``step`` computes
        ``y = mix(g_0) + g_1 - g_0``; if it is skipped, both buffers start
        from zeros instead -- confirm which start-up callers expect.
        """
        current = self.flatten_grads(self.model.module).to(self.device)
        self.y.copy_(current)
        self.prev_grad.copy_(self.y)

    @torch.no_grad()
    def step(self, step, closure=None):
        """Perform one gradient-tracking update of the model parameters.

        Args:
            step: iteration index; not read by this method (kept for the
                shared optimizer ``step`` signature -- NOTE(review): confirm).
            closure: optional callable that re-evaluates the loss; it is run
                with autograd re-enabled.

        Returns:
            ``(loss, grad)``: the closure's result (``None`` when no closure
            is given) and the flattened gradient used in this step, detached
            and moved to the CPU.
        """
        if closure is None:
            loss = None
        else:
            with torch.enable_grad():
                loss = closure()

        g_new = self.flatten_grads(self.model.module).to(self.device)

        # Tracking update: y <- mix(y) + (g_k - g_{k-1}).
        self.buf.zero_()
        self.mix(self.y, self.buf)
        mixed_plus_grad = self.buf + g_new
        # NOTE(review): y is rebound to a new tensor rather than overwritten,
        # so the tensor just handed to mix() is left untouched.
        self.y = mixed_plus_grad - self.prev_grad

        # Parameter update: x <- mix(x - lr * y), written back in place.
        x = self.model.flat_parameters
        self.buf.zero_()
        self.mix(x - self.lr * self.y, self.buf)
        self.model.flat_parameters.copy_(self.buf)

        # Remember this step's gradient for the next tracking correction.
        self.prev_grad.copy_(g_new)

        return loss, g_new.detach().cpu()
