import math

import numpy as np
import torch
from torch import Tensor
from torch.autograd import Variable
from torch.optim import Optimizer


class pSGD(torch.optim.SGD):
    """SGD whose gradients are augmented with a Gaussian-mixture prior term.

    Before each update, the elementwise gradient of the negative log-density
    of the zero-mean two-component Gaussian mixture

        p(w) = sparse_ratio       * N(w; 0, prior_sd**2)
             + (1 - sparse_ratio) * N(w; 0, sparse_sd**2)

    is added to every parameter's ``.grad``. The ordinary
    :class:`torch.optim.SGD` update then runs on the combined gradient.

    Args:
        params: iterable of parameters or param-group dicts, as for SGD.
        lr: learning rate.
        prior_sd: standard deviation of the component weighted by
            ``sparse_ratio``. Must be positive.
        sparse_sd: standard deviation of the component weighted by
            ``1 - sparse_ratio``. Must be positive whenever
            ``sparse_ratio < 1``; unused when ``sparse_ratio == 1``.
        sparse_ratio: mixture weight in ``[0, 1]``. Note that, per the
            formula above, it weights the ``prior_sd`` component.
            ``sparse_ratio == 1`` reduces the prior term to
            ``param / prior_sd**2``.
        momentum, dampening, weight_decay, nesterov: forwarded to SGD.

    Raises:
        ValueError: if ``prior_sd``, ``sparse_sd`` or ``sparse_ratio`` is
            out of range.
    """

    def __init__(self,
                 params,
                 lr,
                 prior_sd=1,
                 sparse_sd=0.1,
                 sparse_ratio=1,
                 momentum=0,
                 dampening=0,
                 weight_decay=0,
                 nesterov=False):
        if not 0 <= sparse_ratio <= 1:
            raise ValueError(f"Invalid sparse_ratio value: {sparse_ratio} (expected 0 <= sparse_ratio <= 1)")
        if prior_sd <= 0:
            raise ValueError(f"Invalid prior_sd value: {prior_sd} (must be positive)")
        if sparse_ratio < 1 and sparse_sd <= 0:
            raise ValueError(f"Invalid sparse_sd value: {sparse_sd} (must be positive)")
        # Keywords rather than positionals so a change in SGD's positional order cannot misroute them.
        super().__init__(params, lr=lr, momentum=momentum, dampening=dampening,
                         weight_decay=weight_decay, nesterov=nesterov)
        self.prior_sd = prior_sd
        self.sparse_sd = sparse_sd
        self.sparse_ratio = sparse_ratio

    def step(self, closure=None):
        """Add the prior gradient to every ``.grad`` and take one SGD step.

        Args:
            closure: optional callable that re-evaluates the model, calls
                ``backward()`` and returns the loss.

        Returns:
            Whatever ``closure`` returned, or ``None`` when no closure is given.
        """
        loss = None
        # Run the closure *before* adding the prior term. A closure usually
        # zeroes and recomputes the gradients, so running it afterwards
        # (inside SGD.step) would discard the prior contribution.
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        # The prior term only modifies the update; keep it out of autograd.
        with torch.no_grad():
            for group in self.param_groups:
                for p in group['params']:
                    if p.grad is None:
                        continue
                    p.grad.add_(self._prior_gradient(p, self.prior_sd, self.sparse_sd, self.sparse_ratio))

        super().step()
        return loss

    def _prior_gradient(self, param, prior_sd, sparse_sd, sparse_ratio):
        """Return ``-d/dw log p(w)`` of the mixture prior, evaluated at ``param``.

        The result is ``coef * param``. Here ``coef`` is the average of the
        two component precisions (``1/sd**2``), each weighted by that
        component's posterior responsibility for the value of ``param``.

        Raises:
            ValueError: if ``sparse_ratio`` is outside ``[0, 1]``.
        """
        if not 0 <= sparse_ratio <= 1:
            raise ValueError(f"Invalid sparse_ratio value: {sparse_ratio} (expected 0 <= sparse_ratio <= 1)")
        if sparse_ratio == 1:
            return param / prior_sd ** 2
        if sparse_ratio == 0:
            # Only the sparse_sd component remains. Handled explicitly because
            # log(0) below is undefined, and the former 0 * exp(...) form
            # produced NaN once exp overflowed.
            return param / sparse_sd ** 2

        prior_prec = 1 / prior_sd ** 2
        sparse_prec = 1 / sparse_sd ** 2
        # Log-odds that the value came from the prior_sd component rather
        # than the sparse_sd component.
        log_odds = (
            math.log(sparse_ratio / (1 - sparse_ratio))
            + math.log(sparse_sd / prior_sd)
            + 0.5 * (sparse_prec - prior_prec) * torch.square(param)
        )
        # Responsibility of the sparse_sd component, 1 / (1 + exp(log_odds)).
        # sigmoid saturates cleanly instead of overflowing through exp.
        sparse_resp = torch.sigmoid(-log_odds)
        coef = prior_prec + (sparse_prec - prior_prec) * sparse_resp
        return coef * param

if __name__ == "__main__":
    # Example usage: one optimisation step on a toy regression problem.
    # Only pSGD's real hyperparameters are passed; it has no `scale_factor`.
    model = torch.nn.Linear(2, 1)
    optimizer = pSGD(model.parameters(), lr=0.01, prior_sd=1, sparse_sd=0.1, sparse_ratio=0.5)

    inputs = torch.randn(8, 2)
    targets = torch.randn(8, 1)

    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(inputs), targets)
    loss.backward()
    optimizer.step()  # adds the prior gradient, then applies the SGD update
    print(f"loss before step: {loss.item():.4f}")