# Implementation of stochastic gradient Hamiltonian Monte Carlo (SGHMC).
# 
# Reference: Stochastic Gradient Hamiltonian Monte Carlo (https://arxiv.org/pdf/1402.4102.pdf)

from numpy import sqrt
import torch
from torch.optim import Optimizer


class SGHMC(Optimizer):
    """Stochastic gradient Hamiltonian Monte Carlo sampler (Chen et al., 2014).

    Every parameter ``q`` is paired with a momentum buffer ``p`` (identity mass
    matrix). One call to :meth:`step` performs one discretised step of the
    SGHMC dynamics, reading the gradient from ``q.grad``::

        p <- p - h * q.grad - h * gamma * p + sqrt(h) * sigma * N(0, I)
        q <- q + h * p

    The momentum is updated before the position. ``q.grad`` is the gradient at
    the current ``q``, so this ordering plays the role of ``grad U(theta_i)`` in
    Algorithm 2 of the paper: the gradient is taken at the position the
    previous step produced. The result is a semi-implicit Euler scheme rather
    than a forward-Euler step that mixes stale gradients with moved positions.

    Args:
        params: iterable of parameters, or of dicts defining parameter groups.
        h: step size; must be non-negative.
        gamma: friction coefficient multiplying the momentum.
        sigma: scale of the injected Gaussian noise. The noise added to each
            momentum per step has standard deviation ``sqrt(h) * sigma``.

    Raises:
        ValueError: if ``h`` is negative.

    Momentum buffers live in ``group['momentums']``: one tensor per entry of
    ``group['params']``, in the same order.
    """

    def __init__(self, params, h, gamma, sigma):
        if h < 0.0:
            raise ValueError("Invalid step size: {}".format(h))
        defaults = dict(h=h, gamma=gamma, sigma=sigma)
        # The base constructor registers each initial group through
        # add_param_group, which allocates the momentum buffers below.
        super().__init__(params, defaults)

    def add_param_group(self, param_group):
        """Register a parameter group and give each of its parameters a zero momentum.

        Allocating buffers here, instead of only in ``__init__``, ensures that
        groups added after construction also get momentum buffers.
        """
        super().add_param_group(param_group)
        group = self.param_groups[-1]
        group['momentums'] = [torch.zeros_like(q) for q in group['params']]

    @torch.no_grad()
    def step(self, closure=None):
        """Perform one SGHMC update on every parameter that has a gradient.

        Args:
            closure: optional callable that re-evaluates the model, calls
                ``backward()`` and returns the loss. It is run with gradient
                tracking enabled before any parameter is updated.

        Returns:
            Whatever ``closure`` returned, or ``None`` when no closure is given.
        """
        loss = None
        if closure is not None:
            # step() itself runs under no_grad; the closure needs autograd.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            h, gamma, sigma = group['h'], group['gamma'], group['sigma']
            noise_scale = float(sqrt(h) * sigma)

            for q, p in zip(group['params'], group['momentums']):
                if q.grad is None:
                    # Parameter took no part in the loss: leave it and its
                    # momentum untouched instead of crashing.
                    continue
                # Momentum first: p <- p - h * (grad + gamma * p) + sqrt(h) * sigma * noise.
                p.add_(q.grad + gamma * p, alpha=-h)
                p.add_(torch.randn_like(p), alpha=noise_scale)
                # Then position, driven by the freshly updated momentum.
                q.add_(p, alpha=h)

        return loss

    def refresh_momentum(self):
        """Resample every momentum buffer from a standard normal distribution.

        This corresponds to the optional momentum resampling step of the
        paper's Algorithm 2 when the mass matrix is the identity.
        """
        for group in self.param_groups:
            group['momentums'] = [torch.randn_like(q) for q in group['params']]