# Implementation of HMC following Neal 2012
import torch as t

# Device passed as `device=` to the tensor constructors used by the HMC
# sampler in this file; change it here to allocate those tensors elsewhere.
device = t.device('cpu')


class HMC:
    """Hamiltonian Monte Carlo sampler running N independent chains in parallel.

    Momenta are drawn from a zero-mean Gaussian with per-dimension standard
    deviation ``r_std`` (equivalently a diagonal mass matrix ``r_std**2``).
    Initial positions are drawn from a Gaussian with mean ``q0_mean`` and
    standard deviation ``exp(log_q0_std)``. Proposals come from a leapfrog
    integrator and are accepted or rejected with a Metropolis test.

    Attributes:
        num_steps: number of leapfrog position updates per proposal.
        sample_dim: dimensionality of one sample.
        r_std: momentum standard deviation, shape (1, sample_dim).
        q0_mean: mean of the initial-position distribution.
        log_q0_std: log standard deviation of the initial-position
            distribution.
        lf_step_size: current leapfrog step size, shape (1,). ``sample``
            overwrites it once burn-in ends if a post-burn-in size is given.
    """

    def __init__(self, num_steps, sample_dim, r_std, q0_mean=None, log_q0_std=None):
        """Configure the sampler.

        Args:
            num_steps: leapfrog steps per proposal; must be at least 1.
            sample_dim: dimensionality of the sampled variable.
            r_std: momentum standard deviation (scalar, or anything that
                broadcasts against a (1, sample_dim) tensor).
            q0_mean: tensor with the mean of the initial positions;
                defaults to zeros of shape (1, sample_dim).
            log_q0_std: tensor with the log std of the initial positions;
                defaults to zeros (unit std).

        Raises:
            ValueError: if ``num_steps`` is smaller than 1.
        """
        if num_steps < 1:
            raise ValueError("num_steps must be at least 1")
        self.num_steps = num_steps
        self.sample_dim = sample_dim
        self.r_std = r_std * t.ones(1, sample_dim, device=device)

        if q0_mean is None:
            q0_mean = t.zeros(1, sample_dim, device=device)
        # Tensor.to() returns the moved tensor rather than working in place,
        # so the result has to be stored.
        self.q0_mean = q0_mean.to(device)

        if log_q0_std is None:
            log_q0_std = t.zeros(1, sample_dim, device=device)
        self.log_q0_std = log_q0_std.to(device)

        # Initial step size for burn in phase
        self.lf_step_size = 0.0007*t.ones(1, device=device)

    @staticmethod
    def _potential_grad(potential, x):
        """Return the gradient of ``potential(x).sum()`` with respect to ``x``."""
        # Gradients are needed even if the caller disabled autograd globally.
        with t.enable_grad():
            energy = potential(x).sum()
        return t.autograd.grad(outputs=energy, inputs=x)[0]

    def _kinetic(self, r):
        """Return the per-chain kinetic energy sum(r**2 / r_std**2) / 2."""
        # The mass term r_std**2 must match the one used in the position
        # update of ``leapfrog`` so the Hamiltonian fits the dynamics.
        return (r**2 / self.r_std**2).sum(1)/2

    def leapfrog(self, potential, eps, x, N, mask=1):
        """Simulate one leapfrog trajectory of ``num_steps`` steps.

        Args:
            potential: callable mapping an (N, sample_dim) position tensor
                to potential energies; only its sum is differentiated here.
            eps: leapfrog step size (scalar or broadcastable tensor).
            x: starting positions, shape (N, sample_dim). It is detached
                internally, so it does not need ``requires_grad``.
            N: number of parallel chains.
            mask: multiplier applied to every position and momentum update
                (presumably 0/1 entries to freeze coordinates -- confirm
                against callers).

        Returns:
            Tuple ``(x_final, -r_final, r_init)`` of detached tensors: the
            end position, the negated end momentum and the sampled initial
            momentum.
        """
        mass = self.r_std**2
        r_init = t.randn(N, self.sample_dim, device=device) * self.r_std

        # Work on a fresh leaf so gradients can be taken w.r.t. the position
        # without touching (or depending on) the caller's autograd graph.
        x = x.detach().requires_grad_()

        # Initial half step for the momentum.
        r = r_init - 0.5*eps*mask*self._potential_grad(potential, x)

        # Alternate full position / full momentum steps.
        for _ in range(self.num_steps - 1):
            x = (x + eps*mask*r/mass).detach().requires_grad_()
            r = r - eps*mask*self._potential_grad(potential, x)

        # Last full position step followed by the closing momentum half step.
        # Continuing from ``x`` keeps num_steps == 1 valid (the loop is empty).
        x = (x + eps*mask*r/mass).detach().requires_grad_()
        r_final = r - 0.5*eps*mask*self._potential_grad(potential, x)

        return x.detach().clone(), -r_final.detach().clone(), r_init.detach().clone()

    def sample(self, num_samples, potential, N, burnin, mask=1, post_burnin_step_size=None):
        """Run N parallel HMC chains for ``num_samples`` iterations.

        Progress (iteration index and the acceptance rate over the last 100
        iterations) is printed every 100 iterations.

        Args:
            num_samples: total number of iterations, burn-in included.
            potential: callable mapping an (N, sample_dim) tensor to N
                potential energies (one per chain).
            N: number of parallel chains.
            burnin: number of initial iterations treated as burn-in.
            mask: forwarded to ``leapfrog``.
            post_burnin_step_size: step size used from iteration ``burnin``
                onwards. ``None`` keeps the current ``lf_step_size``. When
                given, ``self.lf_step_size`` is overwritten and stays so
                after this call returns.

        Returns:
            Tuple ``(samples, hamiltonian, acc_rate_output)``:
            samples of shape (num_samples, N, sample_dim), the Hamiltonian
            of the state kept at each iteration with shape (num_samples, N),
            and the post-burn-in acceptance rate per chain with shape (N,).
        """
        acceptance_rate = t.zeros(N, device=device)
        acc_rate_output = t.zeros(N, device=device)
        samples = t.zeros(num_samples, N, self.sample_dim, device=device)
        hamiltonian = t.zeros(num_samples, N, device=device)

        # The chain state is kept detached so no autograd history piles up
        # across iterations; leapfrog creates its own differentiable copy.
        x_init = (t.randn(N, self.sample_dim, device=device)*t.exp(self.log_q0_std)
                  + self.q0_mean).detach()

        for i in range(num_samples):
            # Same boundary (i >= burnin) as the acceptance statistics below.
            if post_burnin_step_size is not None and i >= burnin:
                self.lf_step_size = post_burnin_step_size*t.ones(1, device=device)
            if i % 100 == 0:
                print('Iteration: ', i)
                print('Acceptance rate: ', acceptance_rate.detach().cpu().numpy()/100.)
                acceptance_rate = t.zeros(N, device=device)

            x_final, r_final, r_init = self.leapfrog(
                potential, self.lf_step_size, x_init.clone().requires_grad_(), N, mask=mask)

            ham_init = (potential(x_init).reshape(N) + self._kinetic(r_init)).detach()
            ham_final = (potential(x_final).reshape(N) + self._kinetic(r_final)).detach()

            # Metropolis test for all chains at once. A NaN probability
            # compares False, so diverged proposals are rejected.
            runif = t.rand(N, device=device)
            accept = runif <= t.exp(ham_init - ham_final)

            x_init = t.where(accept.unsqueeze(1), x_final, x_init)
            samples[i, :, :] = x_init
            hamiltonian[i, :] = t.where(accept, ham_final, ham_init)

            accepted = accept.to(acceptance_rate.dtype)
            acceptance_rate += accepted
            if i >= burnin:
                acc_rate_output += accepted

        # Guard against burnin >= num_samples (no post-burn-in iterations).
        post_burnin_count = max(num_samples - max(burnin, 0), 1)
        acc_rate_output = acc_rate_output/post_burnin_count
        return samples.detach(), hamiltonian.detach(), acc_rate_output.detach()
