import torch
from typing import Callable

# shorthand alias for torch.Tensor, used in the type annotations throughout this file
_T = torch.Tensor


def hamiltonian(pos: _T, vel: _T, nll_fn: Callable) -> _T:
    """
    Evaluates the total energy (potential plus kinetic) of every particle
    :param pos: particle positions, a tensor with shape [N, d]
    :param vel: particle velocities, a tensor with shape [N, d]
    :param nll_fn: a callable mapping a [N, d] tensor of positions to a [N,] tensor of negative log-likelihoods; its
                   output is used as the potential energy
    :return: a tensor with shape [N,]; entry i is nll_fn(pos)[i] + 0.5 * ||vel[i]||^2
    """
    # the potential is evaluated first, exactly once
    potential = nll_fn(pos)
    # kinetic energy: half the squared Euclidean norm of each velocity vector
    kinetic = .5 * (vel * vel).sum(dim=1)
    return potential + kinetic


def MH_accept(energy_prev: _T, energy_next: _T) -> _T:
    """
    Metropolis-Hastings accept/reject test, applied independently to each particle
    :param energy_prev: a tensor with shape [N,] holding the energy of each particle before the proposed move
    :param energy_next: a tensor with shape [N,] holding the energy of each particle after the proposed move
    :return: a boolean tensor with shape [N,]; entry i is True when the move of particle i is accepted, which happens
             when a uniform draw from [0, 1) does not exceed exp(energy_prev[i] - energy_next[i])
    """
    # acceptance ratio; values above 1 (energy decreased) are always accepted
    accept_ratio = (energy_prev - energy_next).exp()
    uniform_draw = torch.rand_like(accept_ratio)
    return uniform_draw <= accept_ratio


@torch.enable_grad()
def leapfrog(pos: _T, vel: _T, step_sz: _T, n_steps: int, nll_fn: Callable) -> tuple[_T, _T]:
    """
    Simulate n steps of Hamiltonian dynamics with the leapfrog integrator (half-step velocity update, full-step
    position update, half-step velocity update), using the potential defined by nll_fn. The scheme is time-reversible:
    integrating the result with the velocity negated returns to the starting point (up to round-off).

    The inputs are never modified: integration runs on detached copies, and gradients are obtained with
    torch.autograd.grad so that no .grad attribute (of pos or of any parameters used inside nll_fn) is written.
    :param pos: a tensor with shape [N, d] of the d-dimensional positions of N particles
    :param vel: a tensor with shape [N, d] of the d-dimensional velocities of N particles
    :param step_sz: a tensor with shape [N,] representing the step size for each of the samples
    :param n_steps: an int representing the number of leaps to run the algorithm; if it is not positive, detached
                    versions of pos and vel are returned unchanged
    :param nll_fn: a pytorch-differentiable, callable function which returns the negative log-likelihood of a given set
                   of positions; should receive as input a tensor with shape [N, d] and return a tensor with shape [N,]
    :return: the updated position and velocity tensors, neither of which requires grad
    """
    # detach instead of calling requires_grad_ in place, so the caller's tensors keep their autograd state
    pos = pos.detach()
    vel = vel.detach()
    if n_steps <= 0:
        return pos, vel

    # broadcast the per-particle step size over the d coordinates
    eps = step_sz.detach()[:, None]

    def potential_grad(x: _T) -> _T:
        # fresh leaf tensor, so the gradient is taken w.r.t. this position only
        x = x.detach().requires_grad_(True)
        total = torch.sum(nll_fn(x))
        # autograd.grad (unlike backward) does not accumulate into .grad of x or of parameters inside nll_fn
        (grad,) = torch.autograd.grad(total, x)
        return grad

    # opening half step for the velocity
    vel = vel - .5 * eps * potential_grad(pos)
    for step in range(n_steps):
        pos = pos + eps * vel
        # consecutive half steps are merged into full steps; only the closing update stays a half step
        scale = 1. if step < n_steps - 1 else .5
        vel = vel - scale * eps * potential_grad(pos)

    return pos, vel


class HMCSampler:
    """
    Hamiltonian Monte Carlo sampler that keeps one step size per particle. After every transition each step size is
    multiplied by step_increase when the running average of that particle's acceptance exceeds the target acceptance,
    and by step_decrease otherwise, and is then clamped to [min_step, max_step].
    """

    def __init__(self, init_step_sz: float=.01, step_decrease: float=.98, step_increase: float=1.02,
                 min_step: float=1e-5, max_step: float=.5, target_acceptance: float=.65,
                 acceptance_smoothing: float=.9):
        """
        :param init_step_sz: the step size every particle starts with
        :param step_decrease: factor applied to a step size when the running acceptance is not above the target
        :param step_increase: factor applied to a step size when the running acceptance is above the target
        :param min_step: lower bound on the step sizes
        :param max_step: upper bound on the step sizes
        :param target_acceptance: the acceptance rate the adaptation aims for
        :param acceptance_smoothing: weight of the previous value in the exponential moving average of acceptances;
                                     clipped to the range [.1, .95]
        """
        # per-particle running average of acceptances; created lazily on the first call
        self.avg_acc = None
        self.target_acc = target_acceptance
        self.smoothing = max(min(acceptance_smoothing, .95), .1)

        # per-particle step sizes; created lazily on the first call, once the number of particles is known
        self.step_sz = None
        self.init_step = init_step_sz
        self.step_increase = step_increase
        self.step_decrease = step_decrease
        self.min_step = min_step
        self.max_step = max_step

    @property
    def step_decresae(self) -> float:
        """
        Alias of step_decrease under the misspelled name this attribute used to have, kept for backward compatibility
        """
        return self.step_decrease

    @step_decresae.setter
    def step_decresae(self, value: float) -> None:
        self.step_decrease = value

    def __call__(self, pos: _T, nll_fn: Callable, n_steps: int, vel: _T=None):
        """
        Runs a single HMC transition and adapts the per-particle step sizes
        :param pos: a tensor with shape [N, d] of particle positions; accepted moves are written into it in place
        :param nll_fn: a callable returning the negative log-likelihood, a [N,] tensor, of a [N, d] tensor of positions
        :param n_steps: number of integration steps used to build the proposal
        :param vel: optional tensor with shape [N, d] of initial velocities, also updated in place for accepted moves;
                    when None, velocities are drawn from a standard normal
        :return: the (pos, vel) tensors after the accept-reject step
        """
        n_particles = pos.shape[0]

        # (re-)initialize the adaptation state on the first call, or when the number of particles has changed, since
        # step sizes and acceptance averages are tracked per particle
        if self.step_sz is None or self.step_sz.shape[0] != n_particles:
            self.step_sz = torch.ones(n_particles, device=pos.device)*self.init_step
            self.avg_acc = None

        # sample a random velocity
        if vel is None: vel = torch.randn_like(pos)

        # integrate on detached copies, so the proposal does not alias (or backpropagate into) the current state
        new_pos, new_vel = leapfrog(pos.detach().clone(), vel.detach().clone(), self.step_sz, n_steps, nll_fn)

        # the accept-reject step needs no gradients
        with torch.no_grad():
            accept = MH_accept(energy_prev=hamiltonian(pos, vel, nll_fn),
                               energy_next=hamiltonian(new_pos, new_vel, nll_fn))

            # overwrite only the particles whose move was accepted
            pos[accept] = new_pos[accept]
            vel[accept] = new_vel[accept]

        # numeric 0/1 acceptance indicator, so the running average is always a floating point tensor
        accepted = accept.to(self.step_sz.dtype)
        if self.avg_acc is None: self.avg_acc = accepted

        # ============ update hyperparameters according to ideal acceptance rate ====================
        # the step sizes are adapted using the average from before the current transition is folded in
        new_step = torch.where(self.avg_acc > self.target_acc,
                               self.step_sz*self.step_increase, self.step_sz*self.step_decrease)
        self.step_sz = new_step.clamp(self.min_step, self.max_step)

        self.avg_acc = self.smoothing*self.avg_acc + (1-self.smoothing)*accepted
        # ============ update hyperparameters according to ideal acceptance rate ====================
        return pos, vel

    def sample(self, pos: _T, nll_fn: Callable, n_steps: int, N: int=1, vel: _T=None,):
        """
        Runs N consecutive HMC transitions and collects the position after each of them
        :param pos: a tensor with shape [N_particles, d] of initial positions; it is updated in place by the transitions
        :param nll_fn: a callable returning the negative log-likelihood of a tensor of positions
        :param n_steps: number of integration steps per transition
        :param N: number of transitions to run, must be at least 1
        :param vel: optional initial velocities; when None, fresh velocities are drawn before every transition,
                    otherwise the velocities returned by one transition are passed on to the next
        :return: a tensor with shape [N, N_particles, d] when N > 1, or with shape [N_particles, d] when N == 1
        :raises ValueError: if N is smaller than 1
        """
        if N < 1:
            raise ValueError(f"N must be at least 1, got {N}")

        resample_vel = vel is None
        samples = []
        for _ in range(N):
            pos, vel = self(pos, nll_fn, n_steps, vel)
            # snapshot the state, since pos keeps being modified in place by later transitions
            samples.append(pos.detach().clone())
            if resample_vel: vel = None
        return torch.stack(samples) if N > 1 else samples[0]
