import torch
import torch.nn as nn
import torch.optim as optim





class signSGD(optim.Optimizer):
    """Sign-SGD optimizer.

    Each parameter is updated as ``p <- p - lr * sign(grad)``, so every
    coordinate with a non-zero gradient moves by exactly ``lr`` regardless
    of the gradient magnitude.

    Reference:
        Bernstein, Jeremy, et al. "signSGD with majority vote is
        communication efficient and fault tolerant."

    Args:
        params: iterable of parameters (or dicts defining parameter groups).
        lr: step size applied to the sign of the gradient; must be >= 0.
        rand_zero: if True, coordinates whose gradient is exactly zero get a
            random sign (+1 or -1 with equal probability) instead of 0, so
            they also move by ``lr``. If False, they are left unchanged.

    Raises:
        ValueError: if ``lr`` is negative.
    """

    def __init__(self, params, lr=0.01, rand_zero=True):
        if lr < 0.0:
            # A negative step would silently turn descent into ascent.
            raise ValueError(f"Invalid learning rate: {lr}")
        defaults = dict(lr=lr)
        # Kept on the instance (not in ``defaults``), so it applies to every
        # parameter group.
        self.rand_zero = rand_zero
        super(signSGD, self).__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss. It is run with gradients enabled.

        Returns:
            The loss returned by ``closure``, or None if no closure was given.
        """
        loss = None
        if closure is not None:
            # ``step`` itself runs under no_grad; the closure needs autograd
            # so that it can call ``backward``.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group['lr']
            for p in group['params']:
                if p.grad is None:
                    continue

                # torch.sign maps 0 -> 0, so zero-gradient coordinates would
                # otherwise stay put.
                direction = torch.sign(p.grad)

                if self.rand_zero:
                    zero_mask = direction == 0
                    # Replace exact zeros with a uniformly random +1 / -1.
                    direction[zero_mask] = (
                        torch.randint_like(direction[zero_mask], low=0, high=2) * 2 - 1
                    )

                # In-place update under no_grad instead of mutating ``p.data``.
                p.add_(direction, alpha=-lr)

        return loss





class BernoulliSGD(optim.Optimizer):
    """Noisy Sign-SGD with Bernoulli noise.

    For each coordinate with gradient ``g``, a sign ``s`` is drawn with
    ``P(s = +1) = sigmoid(g / alpha)`` and ``P(s = -1) = 1 - sigmoid(g / alpha)``,
    and the parameter is updated as ``p <- p - lr * s``. Every coordinate
    therefore moves by exactly ``lr``. The drawn sign agrees with
    ``sign(g)`` more often the larger ``|g| / alpha`` is. As ``alpha -> 0``
    the update approaches Sign-SGD, except that exactly-zero gradients remain
    a fair coin flip (``sigmoid(0) == 0.5``).

    Args:
        params: iterable of parameters (or dicts defining parameter groups).
        lr: step size; must be >= 0.
        alpha: temperature of the sigmoid; must be > 0. Smaller values make
            the update closer to deterministic ``sign(grad)``.

    Raises:
        ValueError: if ``lr`` is negative or ``alpha`` is not strictly positive.
    """

    def __init__(self, params, lr=0.01, alpha=0.1):
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not alpha > 0.0:
            # alpha == 0 gives 0/0 = NaN probabilities for zero gradients, and
            # a negative alpha silently flips the update into gradient ascent.
            # ``not >`` also rejects NaN.
            raise ValueError(f"Invalid alpha (must be > 0): {alpha}")
        defaults = dict(lr=lr)
        # Kept on the instance (not in ``defaults``), so it applies to every
        # parameter group.
        self.alpha = alpha
        super(BernoulliSGD, self).__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss. It is run with gradients enabled.

        Returns:
            The loss returned by ``closure``, or None if no closure was given.
        """
        loss = None
        if closure is not None:
            # ``step`` itself runs under no_grad; the closure needs autograd
            # so that it can call ``backward``.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group['lr']
            for p in group['params']:
                if p.grad is None:
                    continue

                # Probability of drawing +1 for each coordinate; uses the raw
                # gradient (no sign is taken here).
                prob_plus = torch.sigmoid(p.grad / self.alpha)

                # Bernoulli samples in {0, 1}, mapped to {-1, +1}.
                noisy_sign = 2 * torch.bernoulli(prob_plus) - 1

                # In-place update under no_grad instead of mutating ``p.data``.
                p.add_(noisy_sign, alpha=-lr)

        return loss























