import torch
from optim import DecentralizedOptimizer


class DSGD(DecentralizedOptimizer):
    """Decentralized SGD optimizer operating on the model's flat parameter vector.

    Each :meth:`step` calls the inherited ``mix`` on the current flat
    parameters, then writes ``mixed - lr * grad`` back into them in place.
    """

    def __init__(self, model, **kwargs):
        super().__init__(model, **kwargs)

        # Keep the mixing buffer on the same device as the wrapped module's
        # first parameter.
        first_param = next(model.module.parameters())
        self.device = first_param.device
        # Scratch buffer handed to `self.mix`; same shape/dtype as the flat
        # parameter vector.
        self.buf = torch.zeros_like(model.flat_parameters, device=self.device)

    @torch.no_grad()
    def step(self, step, closure=None):
        """Perform one mix-then-descend update.

        Args:
            step: Iteration index. Not read by this method; presumably part
                of the base-class ``step`` signature — TODO confirm.
            closure: Optional callable that re-evaluates the model and returns
                the loss. It is run with gradients enabled.

        Returns:
            A ``(loss, grad)`` tuple: ``loss`` is the closure's result (or
            ``None`` when no closure is given) and ``grad`` is a detached CPU
            copy of the flattened gradient used for this update.
        """
        if closure is None:
            loss = None
        else:
            # The method body runs under no_grad; the closure needs autograd
            # enabled to produce gradients.
            with torch.enable_grad():
                loss = closure()

        # Clear the buffer before `mix` writes into it — presumably `mix`
        # accumulates into its second argument (TODO confirm in
        # DecentralizedOptimizer).
        self.buf.zero_()
        local_grad = self.flatten_grads(self.model.module)
        local_grad = local_grad.to(self.device)

        self.mix(self.model.flat_parameters, self.buf)

        # x <- mix(x) - lr * g, written into the existing parameter storage.
        updated = self.buf - self.lr * local_grad
        self.model.flat_parameters[:] = updated

        grad_cpu = local_grad.detach().cpu()
        return loss, grad_cpu
