import torch
from torch.distributions import Normal
from torch.optim import Optimizer
import numpy as np

class SGLD(Optimizer):
    """
    Modified version of pytorch SGD to implement pSGLD.

    The RMSprop preconditioning code is mostly from the pytorch rmsprop
    implementation: each parameter keeps a running average of its squared
    gradient (and, when ``centered`` is set, of the gradient itself), and the
    gradient is divided element-wise by the square root of that average
    before it is applied. When ``addnoise`` is set, a Gaussian noise term is
    added to the preconditioned gradient on every step.

    Args:
        params: iterable of parameters (or parameter-group dicts) handed to
            ``torch.optim.Optimizer``.
        lr: base step size. The step size actually used for a parameter is
            ``lr * (1 - 1e-5) ** (step - 1)``, where ``step`` counts the
            updates applied to that parameter.
        noise: base noise multiplier. The value actually used is
            ``noise * (1 - 5e-5) ** (step - 1)``.
        alpha: smoothing constant of the running averages.
        eps: constant added to the preconditioner to avoid division by zero.
        centered: if True, the squared running mean of the gradient is
            subtracted from the squared-gradient average before the square
            root is taken.
        addnoise: if True, ``step`` injects Gaussian noise; if False the
            update is the preconditioned gradient step only.
    """

    # Per-step multiplicative decay rates applied to the group's ``lr`` and
    # ``noise`` values (the schedule is driven by each parameter's own step
    # counter).
    _LR_DECAY = 1e-5
    _NOISE_DECAY = 5e-5

    def __init__(self, params, lr=1e-3, noise=1e-6, alpha=0.99, eps=1e-8, centered=False, addnoise=True):
        defaults = dict(lr=lr, noise=noise, alpha=alpha, eps=eps, centered=centered, addnoise=addnoise)
        super().__init__(params, defaults)

    def __setstate__(self, state):
        """Restore optimizer state, back-filling group options that may be missing."""
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault('centered', False)
            # ``step`` reads group['addnoise'] unconditionally, so make sure
            # it exists; True is the constructor default.
            group.setdefault('addnoise', True)

    def step(self, lr=None, noise=None, add_noise=False):
        """
        Performs a single optimization step.

        Args:
            lr: if truthy, overwrites ``group['lr']`` of every parameter
                group before the update (the new value persists for later
                steps). ``None`` or ``0`` leaves the groups untouched.
            noise: if truthy, overwrites ``group['noise']`` of every
                parameter group in the same way.
            add_noise: accepted for backward compatibility but ignored;
                noise injection is controlled by the ``addnoise`` option of
                each parameter group.

        Returns:
            Always ``None`` (no closure is evaluated).
        """
        loss = None

        for group in self.param_groups:
            if lr:
                group['lr'] = lr
            if noise:
                group['noise'] = noise

            alpha = group['alpha']
            eps = group['eps']

            for p in group['params']:
                if p.grad is None:
                    continue
                d_p = p.grad.data

                state = self.state[p]

                # Lazy state initialisation on the first update of p.
                if len(state) == 0:
                    state['step'] = 0
                    state['square_avg'] = torch.zeros_like(p.data)
                    if group['centered']:
                        state['grad_avg'] = torch.zeros_like(p.data)

                square_avg = state['square_avg']
                state['step'] += 1

                # Exponentially decayed step size / noise multiplier. Cast to
                # plain Python floats so no NumPy scalar leaks into tensor ops.
                elapsed = state['step'] - 1
                lr_t = float(group['lr'] * np.power(1 - self._LR_DECAY, elapsed))
                noise_t = float(group['noise'] * np.power(1 - self._NOISE_DECAY, elapsed))

                # square_avg <- alpha * square_avg + (1 - alpha) * d_p * d_p
                square_avg.mul_(alpha).addcmul_(d_p, d_p, value=1 - alpha)

                if group['centered']:
                    grad_avg = state['grad_avg']
                    grad_avg.mul_(alpha).add_(d_p, alpha=1 - alpha)
                    # Out-of-place addcmul: square_avg itself must stay intact.
                    avg = square_avg.addcmul(grad_avg, grad_avg, value=-1).sqrt().add_(eps)
                else:
                    avg = square_avg.sqrt().add_(eps)

                if group['addnoise']:
                    # Build the distribution on the same device/dtype as the
                    # gradient; element-wise std is sqrt(1 / (lr_t * avg)).
                    noise_std = torch.ones_like(d_p).div_(lr_t).div_(avg).sqrt()
                    langevin_noise = Normal(torch.zeros_like(d_p), noise_std)
                    noise_coeff = float(np.sqrt(2)) * noise_t
                    # Out-of-place div so p.grad is not overwritten by the
                    # preconditioned gradient.
                    update = d_p.div(avg) + noise_coeff * langevin_noise.sample()
                    p.data.add_(update, alpha=-lr_t)
                else:
                    p.data.addcdiv_(d_p, avg, value=-lr_t)

        return loss
