# -*- coding: utf-8 -*-
import torch
from torch.optim.optimizer import Optimizer

__all__ = ["GossipOptimizer"]


class GossipOptimizer(Optimizer):
    """Gradient descent with a momentum buffer kept per parameter.

    For every parameter ``p`` that has a gradient ``g`` the update is::

        v <- g + beta * v
        p <- p - lr * v

    where ``v`` is stored in ``self.state[p]["momentum"]`` and starts at zero.
    This class performs only the local update shown above; it does no
    communication between workers itself.

    Args:
        params: Iterable of parameters, or of dicts defining parameter groups.
        lr: Step size. Must be non-negative.
        beta: Momentum coefficient. Must be non-negative.

    Raises:
        ValueError: If ``lr`` or ``beta`` is negative.
    """

    def __init__(self, params, lr=1e-5, beta=0.9):
        if lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if beta < 0.0:
            raise ValueError("Invalid beta value: {}".format(beta))

        # NOTE: neither attribute below is read by ``step``. The effective
        # step size is the per-group ``group["lr"]``, and no L2 penalty is
        # applied. They are kept so external code that reads them still works.
        self.l2_penalty = 0.001
        self.lr = lr

        defaults = dict(lr=lr, beta=beta)
        super(GossipOptimizer, self).__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: Optional callable that re-evaluates the model and returns
                the loss. It is called once, with gradient tracking enabled,
                *before* the parameters are updated, so gradients it computes
                are the ones used by this step.

        Returns:
            The value returned by ``closure``, or ``None`` when no closure
            was given.
        """
        loss = None
        if closure is not None:
            # ``step`` runs under ``no_grad``; autograd has to be switched
            # back on so the closure can call ``backward()``.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group["lr"]
            beta = group["beta"]

            for p in group["params"]:
                grad = p.grad
                if grad is None:
                    # Parameter did not take part in the backward pass.
                    continue

                state = self.state[p]
                if "momentum" not in state:
                    state["momentum"] = torch.zeros_like(p)
                momentum = state["momentum"]

                # In-place updates: same values as rebinding ``.data`` but
                # without allocating new tensors on every step.
                momentum.mul_(beta).add_(grad)
                p.add_(momentum, alpha=-lr)

        return loss
