"""
[Lyndon Duong](https://www.lyndonduong.com/sgmcmc/)'s SGHMC implementation
"""
import torch


class SGLD:
    """Stochastic Gradient Langevin Dynamics sampler built on ``torch.optim.SGD``.

    Each call to :meth:`sample` performs one update of the form
    ``param <- param - eta * grad(log_density) - noise`` with
    ``noise ~ N(0, 2 * eta)``, implemented by back-propagating a surrogate
    loss and taking a plain SGD step with ``lr=1``.
    """

    def __init__(self, params, eta, log_density):
        """
        Stochastic gradient monte carlo sampler via Langevin Dynamics

        Parameters
        ----------
        params: iterable of tensors handed to ``torch.optim.SGD``; these are
            the tensors updated in place by :meth:`sample`.
        eta: float (a 0-dim tensor also works)
            learning rate param; also sets the injected noise variance (2 * eta).
        log_density: function computing log_density (loss) for given sample and batch of data.
            It is called as ``log_density(params)`` and must return a scalar
            tensor that is differentiable w.r.t. the params.
            NOTE(review): the SGD step *descends* this value, so it behaves as
            a loss / negative log-density -- confirm the sign with callers.
        """
        self.eta = eta
        self.log_density = log_density
        # lr=1 because the step size is folded into the loss inside `sample`;
        # momentum is set to zero (plain Langevin dynamics).
        self.optimizer = torch.optim.SGD(params, lr=1, momentum=0.)

    def _noise(self, params):
        """Return a surrogate loss whose gradient w.r.t. each param is Gaussian noise.

        For every param a fresh ``noise ~ N(0, 2 * eta)`` tensor is drawn and
        ``(noise * param).sum()`` is accumulated; differentiating that term
        w.r.t. ``param`` yields exactly ``noise``, so the optimizer step
        injects the noise into the update.
        """
        # `** 0.5` works for Python numbers and tensors alike; `torch.sqrt`
        # only accepts tensors and raises TypeError for a float `eta`.
        std = (2 * self.eta) ** 0.5
        loss = 0.
        for param in params:
            noise = torch.randn_like(param) * std
            loss += (noise * param).sum()
        return loss

    def sample(self, params):
        """Run one SGLD update in place and return ``params``.

        ``params`` is passed to ``log_density`` and then iterated again for
        the noise term, so it should be a re-iterable collection (e.g. a
        list), not a one-shot generator.
        """
        self.optimizer.zero_grad()
        loss = self.log_density(params) * self.eta
        loss += self._noise(params)  # add noise*param before calling backward!
        loss.backward()  # let autograd do its thing
        self.optimizer.step()
        return params


# class SGHMC:

#     def __init__(self, params, eta, log_density, alpha=0.0):
#         """
#         Stochastic Gradient Monte Carlo sampler WITH momentum
#         This is Hamiltonian Monte Carlo.

#         Parameters
#         ---------
#         eta: learning rate
#         log_density: loss function for given sample/batch of data
#         alpha: momentum param
#         """
#         self.alpha = alpha
#         self.eta = eta
#         self.log_density = log_density
#         self.optimizer = torch.optim.SGD(params, lr=1, momentum=(1 - self.alpha))

#     def _noise(self, params):
#         std = torch.sqrt(2 * self.alpha * self.eta)
#         loss = 0.
#         for param in params:
#             noise = torch.randn_like(param) * std
#             loss += (noise * param).sum()
#         return loss

#     def sample(self, params):
#         self.optimizer.zero_grad()
#         loss = -self.log_density(params) * self.eta
#         loss += self._noise(params)
#         loss.backward()
#         self.optimizer.step()
#         return params
