# Implementation of preconditioned stochastic gradient Langevin dynamics (pSGLD).
#
# Reference: Preconditioned Stochastic Gradient Langevin Dynamics for Deep Neural Networks (https://arxiv.org/pdf/1512.07666.pdf)

from numpy import sqrt
import torch
from torch.optim import Optimizer


class pSGLD(Optimizer):
    """Preconditioned stochastic gradient Langevin dynamics (pSGLD) sampler.

    For every parameter the optimizer keeps two buffers of the same shape:

    * ``V`` -- an exponential moving average of the squared gradient, and
    * ``G`` -- the diagonal of the preconditioner, ``1 / (lam + sqrt(V))``.
      G is a diagonal matrix; only its diagonal is stored.

    The buffers live in each parameter group under the keys ``'Vs'`` and
    ``'Gs'`` (lists aligned with ``group['params']``).

    The preconditioner is *not* refreshed by :meth:`step`; call
    :meth:`update_preconditioner` after the backward pass and before
    :meth:`step`.  Both buffers start at zero, so until
    :meth:`update_preconditioner` has run, ``G`` is all zeros and the
    preconditioned update applied by :meth:`step` is zero as well.

    Args:
        params: iterable of parameters (or dicts defining parameter groups).
        h: step size; must be non-negative.
        lam: constant added to ``sqrt(V)`` in the preconditioner denominator.
        alpha: weight of the previous ``V`` in the moving average.

    Raises:
        ValueError: if ``h`` is negative.
    """

    def __init__(self, params, h, lam=1e-5, alpha=0.99):
        if h < 0.0:
            raise ValueError("Invalid step size: {}".format(h))
        defaults = dict(h=h, lam=lam, alpha=alpha)
        super(pSGLD, self).__init__(params, defaults)
        for group in self.param_groups:
            self._ensure_buffers(group)

    @staticmethod
    def _ensure_buffers(group):
        """Create the zero-initialised ``Vs``/``Gs`` lists of ``group`` if absent."""
        # zeros_like already matches each parameter's shape, dtype and device,
        # so no legacy ``Tensor.new`` wrapper is needed.
        if 'Vs' not in group:
            group['Vs'] = [torch.zeros_like(param) for param in group['params']]
        if 'Gs' not in group:
            group['Gs'] = [torch.zeros_like(param) for param in group['params']]

    def step(self, closure=None):
        """Apply one pSGLD update to every parameter that has a gradient.

        Each such parameter is changed in place by
        ``-h * G * grad + sqrt(2 * h) * sqrt(G) * noise`` where ``noise`` is
        drawn from a standard normal distribution with the parameter's shape.

        Args:
            closure: optional callable that re-evaluates the model and
                returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` if no closure
            was given.
        """
        loss = None
        if closure is not None:
            # The closure typically runs a backward pass, so gradients must
            # be enabled even if the caller is inside a no_grad context.
            with torch.enable_grad():
                loss = closure()

        with torch.no_grad():
            for group in self.param_groups:
                # Groups added after construction have no buffers yet.
                self._ensure_buffers(group)
                h = group['h']
                # Plain Python float: avoids mixing numpy scalars into
                # tensor arithmetic.
                noise_scale = (2.0 * h) ** 0.5
                for param, G in zip(group['params'], group['Gs']):
                    if param.grad is None:
                        continue
                    noise = torch.randn_like(param)
                    drift = -h * G * param.grad
                    diffusion = noise_scale * torch.sqrt(G) * noise
                    param.add_(drift + diffusion)

        return loss

    def update_preconditioner(self):
        """Refresh ``V`` and ``G`` from the current gradients.

        For every parameter with a gradient:
        ``V <- alpha * V + (1 - alpha) * grad ** 2`` and
        ``G <- 1 / (lam + sqrt(V))``.
        Parameters whose ``grad`` is ``None`` keep their previous buffers.
        """
        with torch.no_grad():
            for group in self.param_groups:
                self._ensure_buffers(group)
                lam, alpha = group['lam'], group['alpha']
                for i, (param, V) in enumerate(zip(group['params'], group['Vs'])):
                    if param.grad is None:
                        # Mirror step(): nothing to accumulate for this parameter.
                        continue
                    V = alpha * V + (1.0 - alpha) * param.grad ** 2

                    # update V and G
                    group['Vs'][i] = V
                    group['Gs'][i] = 1.0 / (lam + torch.sqrt(V))