# pylint: disable=g-bare-generic
import math
import torch
import torch.distributions as D

from tqdm import tqdm
from .functions import *

class Sampler:
  """Common state shared by the annealed samplers.

  Args:
    initial_distribution: distribution the chains start from (the samplers call its
      `sample(sample_shape=...)` and `log_prob(...)`).
    target_distribution: callable `(x, i)` returning, per sample, the (possibly
      unnormalised) log-density of the i-th annealed distribution, or None. When
      given, it is also used as the energy function and its gradient w.r.t. `x` is
      derived with autograd. The samplers treat this value as log p: they move along
      its gradient and use it as log p in their Metropolis-Hastings tests.
    gradient_function: callable `(x, i)` returning the gradient of that log-density
      w.r.t. `x`. Required, and only used, when `target_distribution` is None.
    energy_function: callable `(x, i)` returning the log-density. Only used when
      `target_distribution` is None.

  Note: the constructor calls `torch.manual_seed(1)`, so building a sampler reseeds
  the global torch RNG.
  """
  def __init__(self, initial_distribution, target_distribution, gradient_function, energy_function = None):
    torch.manual_seed(1)
    self._initial_distribution = initial_distribution
    self._target_distribution = target_distribution
    if target_distribution is None:
      assert gradient_function is not None, "gradient_function is required when target_distribution is None."
      self._gradient_function = gradient_function
      self._energy_function = energy_function
    else:
      def fit_gradient_function(target_distribution, x, i):
        """Return d/dx sum(target_distribution(x, i)), detached from any graph."""
        # Differentiate w.r.t. a detached alias of x. This leaves the caller's
        # tensor (its requires_grad flag and any graph it is part of) untouched.
        # enable_grad() lets this work even when called inside torch.no_grad().
        with torch.enable_grad():
          x_req = x.detach().requires_grad_(True)
          logp = target_distribution(x_req, i).sum()
          grad, = torch.autograd.grad(logp, x_req)
        return grad.detach()
      self._gradient_function = lambda x, i: fit_gradient_function(target_distribution, x, i)
      self._energy_function = target_distribution


class AnnealedULASampler(Sampler):
  """ Unadjusted Langevin Dynamics (ULA)

  Annealed sampler with Langevin moves and no accept/reject step. Stage 0 is the
  initial distribution. Each of stages 1 .. num_steps - 1 gets `num_samples_per_step`
  moves. For every move, the log weight accumulates
  log(reverse kernel at the previous stage) - log(forward kernel at the current stage).
  """
  def __init__(self, num_steps: int, num_samples_per_step: int, step_sizes: torch.Tensor,
               initial_distribution, target_distribution, gradient_function):
    super().__init__(initial_distribution, target_distribution, gradient_function)
    assert len(step_sizes) == num_steps, "Must have as many stepsizes as intermediate distributions."
    self._step_sizes = step_sizes
    self._num_steps = num_steps
    self._num_samples_per_step = num_samples_per_step
    self._total_steps = self._num_samples_per_step * (self._num_steps - 1)

  def transition_distribution(self, i, x):
    """Return the Langevin kernel of stage `i` at `x`: N(x + eps_i * grad_i(x), 2 * eps_i * I).

    The last dimension of `x` is treated as the event dimension.
    """
    step_size = self._step_sizes[i]
    mu = x + step_size * self._gradient_function(x, i)
    # ones_like keeps std on mu's device/dtype; torch.ones(mu.shape) would always
    # allocate on the CPU and break CUDA inputs.
    std = torch.ones_like(mu) * math.sqrt(2 * step_size)
    return D.Independent(D.Normal(mu, std), reinterpreted_batch_ndims = 1)

  def sample(self, n_samples: int):
    """Run `n_samples` annealed ULA chains.

    Returns:
      (x, logw, None): the final samples; the log weights (None when no
      `target_distribution` was given, since the terminal term cannot be added);
      and a placeholder in the slot where the adjusted samplers return acceptance rates.
    """
    x = self._initial_distribution.sample(sample_shape=(n_samples,))
    logw = -self._initial_distribution.log_prob(x)

    for i in tqdm(range(self._total_steps), desc='ULA-%d'%len(self._step_sizes)):
      # Stage 0 is the initial distribution, so moves start at stage 1.
      dist_ind = (i // self._num_samples_per_step) + 1
      dist_forward = self.transition_distribution(dist_ind, x)
      x_hat = dist_forward.sample()
      dist_reverse = self.transition_distribution(dist_ind - 1, x_hat)
      logw += dist_reverse.log_prob(x) - dist_forward.log_prob(x_hat)
      x = x_hat

    if self._target_distribution is None:
      return x, None, None
    logw += self._target_distribution(x, self._num_steps - 1)
    return x, logw, None


class AnnealedMALASampler(Sampler):
  """ Metropolis-Adjusted-Langevin-Algorithm (MALA)

  Annealed sampler: at each stage, propose from that stage's Langevin kernel and
  accept with the Metropolis-Hastings ratio. The ratio combines `energy_function`
  (used as log p) with the forward/reverse kernel densities. Stage 0 is the initial
  distribution. Stages 1 .. num_steps - 1 each get `num_samples_per_step` moves.
  """
  def __init__(self, num_steps: int, num_samples_per_step: int, step_sizes: torch.Tensor,
               initial_distribution, target_distribution, gradient_function, energy_function):
    super().__init__(initial_distribution, target_distribution, gradient_function, energy_function)
    assert len(step_sizes) == num_steps, "Must have as many stepsizes as intermediate distributions."
    self._step_sizes = step_sizes
    self._num_steps = num_steps
    self._num_samples_per_step = num_samples_per_step
    self._total_steps = self._num_samples_per_step * (self._num_steps - 1)
    self._total_steps_reverse = self._num_samples_per_step * self._num_steps

  def transition_distribution(self, i, x):
    """Return the Langevin kernel of stage `i` at `x`: N(x + eps_i * grad_i(x), 2 * eps_i * I).

    The last dimension of `x` is treated as the event dimension.
    """
    step_size = self._step_sizes[i]
    mu = x + step_size * self._gradient_function(x, i)
    # ones_like keeps std on mu's device/dtype (torch.ones(mu.shape) is CPU-only).
    std = torch.ones_like(mu) * math.sqrt(2 * step_size)
    return D.Independent(D.Normal(mu, std), reinterpreted_batch_ndims = 1)

  def sample(self, n_samples: int):
    """Run `n_samples` annealed MALA chains.

    Returns:
      (x, logw, accept_rate):
        x: the final samples.
        logw: starts from -log q0(x0) and accumulates (log p(x) - log p(x_hat)) for
          every accepted move.
        accept_rate: the per-stage fraction of accepted proposals, of length
          `num_steps`. Stage 0 is never visited and stays 0.
    """
    x = self._initial_distribution.sample(sample_shape=(n_samples,))
    logw = -self._initial_distribution.log_prob(x)
    accept_rate = torch.zeros((self._num_steps,), device=x.device)

    for i in tqdm(range(self._total_steps), desc='MALA-%d'%len(self._step_sizes)):
      # Stage 0 is the initial distribution. Like the other annealed samplers,
      # moves target stages 1 .. num_steps - 1, so the final stage is visited.
      dist_ind = (i // self._num_samples_per_step) + 1

      # Langevin proposal and its reverse kernel (both at the current stage).
      dist_forward = self.transition_distribution(dist_ind, x)
      x_hat = dist_forward.sample()
      dist_reverse = self.transition_distribution(dist_ind, x_hat)

      logp_x = self._energy_function(x, dist_ind)
      logp_x_hat = self._energy_function(x_hat, dist_ind)
      logp_reverse = dist_reverse.log_prob(x)
      logp_forward = dist_forward.log_prob(x_hat)

      # Metropolis-Hastings test.
      logp_accept = logp_x_hat - logp_x + logp_reverse - logp_forward
      u = torch.rand((x.shape[0],), device=x.device)
      accept_mask = u < torch.exp(logp_accept)
      accept = accept_mask.type(torch.get_default_dtype())

      # Select rather than blend with 0/1 weights. A rejected proposal holding
      # inf/nan would otherwise turn the kept sample and weight into nan (0 * inf).
      x = torch.where(accept_mask.reshape((-1,) + (1,) * (x.dim() - 1)), x_hat, x)
      logw += (logp_x - logp_x_hat).masked_fill(~accept_mask, 0.)

      accept_rate[dist_ind] += accept.mean()

    accept_rate /= self._num_samples_per_step
    # NOTE: unlike the unadjusted samplers, no terminal target log-density is added to logw.
    return x, logw, accept_rate


class AnnealedUHMCSampler(Sampler):
  """ Unadjusted-Hamiltonian-Monte-Carlo (UHMC)

  Annealed sampler with no accept/reject step. Each transition partially refreshes
  the momentum, v' = c * v + sqrt(1 - c^2) * eps * mass_diag_sqrt, and then advances
  (x, v') with the module-level `leapfrog_step` helper. Stage 0 is the initial
  distribution. Stages 1 .. num_steps - 1 each get `num_samples_per_step` transitions.
  """
  def __init__(self, num_steps: int, num_samples_per_step: int, step_sizes: torch.Tensor,
               damping_coeff: float, mass_diag_sqrt: float, num_leapfrog_steps: int,
               initial_distribution, target_distribution, gradient_function):
    super().__init__(initial_distribution, target_distribution, gradient_function)
    assert len(step_sizes) == num_steps, "Must have as many stepsizes as intermediate distributions."
    self._damping_coeff = damping_coeff
    self._mass_diag_sqrt = mass_diag_sqrt
    self._step_sizes = step_sizes
    self._num_steps = num_steps
    self._num_leapfrog_steps = num_leapfrog_steps
    self._num_samples_per_step = num_samples_per_step
    self._total_steps = self._num_samples_per_step * (self._num_steps - 1)

  def leapfrog_step(self, x, v, i):
    """Advance (x, v) with the `leapfrog_step` helper, using stage `i`'s gradient and step size."""
    step_size = self._step_sizes[i]
    return leapfrog_step(x, v, lambda _x: self._gradient_function(_x, i), step_size, self._mass_diag_sqrt, self._num_leapfrog_steps)

  def sample(self, n_samples: int):
    """Run `n_samples` annealed UHMC chains.

    Returns:
      (x, logw, None): the final positions; the log weights (None when no
      `target_distribution` was given); and a placeholder in the acceptance-rate slot.
    """
    x_k = self._initial_distribution.sample(sample_shape=(n_samples,))
    # Momentum distribution N(0, mass_diag_sqrt^2 I), allocated like x_k (device/dtype).
    v_dist = D.Independent(
      D.Normal(torch.zeros_like(x_k), torch.ones_like(x_k) * self._mass_diag_sqrt),
      reinterpreted_batch_ndims = 1,
    )
    v_k = v_dist.sample()
    logw = -self._initial_distribution.log_prob(x_k)

    for i in tqdm(range(self._total_steps), desc='UHMC-%d'%len(self._step_sizes)):
      # Stage 0 is the initial distribution. The dynamics target stages
      # 1 .. num_steps - 1 and end on the stage used for the terminal weight below.
      dist_ind = (i // self._num_samples_per_step) + 1
      eps = torch.randn_like(x_k)
      # Partial momentum refresh.
      v_k_prime = v_k * self._damping_coeff + math.sqrt(1. - self._damping_coeff**2) * eps * self._mass_diag_sqrt
      x_k, v_k = self.leapfrog_step(x_k, v_k_prime, dist_ind)
      # Weight update: change in momentum log-density over the trajectory.
      logw += v_dist.log_prob(v_k) - v_dist.log_prob(v_k_prime)

    if self._target_distribution is None:
      return x_k, None, None
    logw += self._target_distribution(x_k, self._num_steps - 1)
    return x_k, logw, None


class AnnealedMAHMCSampler(Sampler):
  """ Metropolis-Adjusted-Hamiltonian-Monte-Carlo (MAHMC)

  Each transition works in three parts:
    1. Partially refresh the momentum, v' = c * v + sqrt(1 - c^2) * eps * mass_diag_sqrt.
    2. Propose (x_next, v_next) with the `leapfrog_step` helper.
    3. Accept with probability min(1, exp(joint_next - joint_prev)), where
       joint = energy_function(x, stage) + log N(v; 0, mass_diag_sqrt^2 I).
  On rejection, the chain keeps x and the refreshed momentum v'. Stage 0 is the
  initial distribution. Stages 1 .. num_steps - 1 each get `num_samples_per_step`
  transitions.
  """
  def __init__(self, num_steps: int, num_samples_per_step: int, step_sizes: torch.Tensor,
               damping_coeff: float, mass_diag_sqrt: float, num_leapfrog_steps: int,
               initial_distribution, target_distribution, gradient_function, energy_function):
    super().__init__(initial_distribution, target_distribution, gradient_function, energy_function)
    assert len(step_sizes) == num_steps, "Must have as many stepsizes as intermediate distributions."
    self._damping_coeff = damping_coeff
    self._mass_diag_sqrt = mass_diag_sqrt
    self._step_sizes = step_sizes
    self._num_steps = num_steps
    self._num_leapfrog_steps = num_leapfrog_steps
    self._num_samples_per_step = num_samples_per_step
    self._total_steps = self._num_samples_per_step * (self._num_steps - 1)
    self._total_steps_reverse = self._num_samples_per_step * self._num_steps

  def leapfrog_step(self, x, v, i):
    """Advance (x, v) with the `leapfrog_step` helper, using stage `i`'s gradient and step size."""
    step_size = self._step_sizes[i]
    return leapfrog_step(x, v, lambda _x: self._gradient_function(_x, i), step_size, self._mass_diag_sqrt, self._num_leapfrog_steps)

  def sample(self, n_samples: int):
    """Run `n_samples` annealed MAHMC chains.

    Returns:
      (x, logw, accept_rate):
        x: the final positions.
        logw: starts from -log q0(x0) and accumulates (log p(x) - log p(x_next)) for
          every accepted move.
        accept_rate: the per-stage fraction of accepted proposals, of length
          `num_steps`. Stage 0 is never visited and stays 0.
    """
    x_k = self._initial_distribution.sample(sample_shape=(n_samples,))
    # Momentum distribution N(0, mass_diag_sqrt^2 I), allocated like x_k (device/dtype).
    v_dist = D.Independent(
      D.Normal(torch.zeros_like(x_k), torch.ones_like(x_k) * self._mass_diag_sqrt),
      reinterpreted_batch_ndims = 1,
    )
    v_k = v_dist.sample()
    logw = -self._initial_distribution.log_prob(x_k)
    accept_rate = torch.zeros((self._num_steps,), device=x_k.device)

    for i in tqdm(range(self._total_steps), desc='MAHMC-%d'%len(self._step_sizes)):
      dist_ind = (i // self._num_samples_per_step) + 1
      eps = torch.randn_like(x_k)

      # Partial momentum refresh.
      v_k_prime = v_k * self._damping_coeff + math.sqrt(1. - self._damping_coeff ** 2) * eps * self._mass_diag_sqrt

      # Proposal.
      x_k_next, v_k_next = self.leapfrog_step(x_k, v_k_prime, dist_ind)

      # Joint (position + momentum) log-densities before and after the trajectory.
      logp_v_p = v_dist.log_prob(v_k_prime)
      logp_v = v_dist.log_prob(v_k_next)
      logp_x = self._energy_function(x_k, dist_ind)
      logp_x_hat = self._energy_function(x_k_next, dist_ind)
      log_joint_prev = logp_x + logp_v_p
      log_joint_next = logp_x_hat + logp_v

      # Metropolis-Hastings test.
      logp_accept = log_joint_next - log_joint_prev
      u = torch.rand((x_k_next.shape[0],), device=x_k_next.device)
      accept_mask = u < torch.exp(logp_accept)
      accept = accept_mask.type(torch.get_default_dtype())

      # Select rather than blend with 0/1 weights. A rejected proposal holding
      # inf/nan would otherwise turn the kept state and weight into nan (0 * inf).
      logw += (logp_x - logp_x_hat).masked_fill(~accept_mask, 0.)
      sample_mask = accept_mask.reshape((-1,) + (1,) * (x_k.dim() - 1))
      x_k = torch.where(sample_mask, x_k_next, x_k)
      v_k = torch.where(sample_mask, v_k_next, v_k_prime)
      accept_rate[dist_ind] += accept.mean()

    accept_rate /= self._num_samples_per_step
    # NOTE: no terminal target log-density is added to logw here.
    return x_k, logw, accept_rate


class AnnealedLHMCSampler(Sampler):
  """ Leader Hamiltonian Monte Carlo (LHMC)

  Metropolis-adjusted HMC whose trajectories come from the `leader_leapfrog_step`
  helper. Chains are split into groups of `num_particles` particles. One
  Metropolis-Hastings decision is made per group, from the summed log acceptance
  ratios of its particles, and applied to every particle in the group. Stage 0 is the
  initial distribution. Stages 1 .. num_steps - 1 each get `num_samples_per_step`
  transitions.

  `lambda_G` is passed to `leader_leapfrog_step`. `lambda_L` and `lambda_V` are stored
  but only referenced by the disabled LULA step in `sample`. `elastic_variance`,
  `elastic_force` and `random_momentum` are accepted but not used by this class.
  `num_particles` (default 4) is the group size.
  """
  def __init__(self, num_steps: int,
               num_samples_per_step: int,
               step_sizes: torch.Tensor,
               lambda_L: float, lambda_G: float, lambda_V: float,
               elastic_variance: float, elastic_force: float, random_momentum: bool,
               damping_coeff: float, mass_diag_sqrt: float, num_leapfrog_steps: int,
               initial_distribution, target_distribution, gradient_function, energy_function,
               num_particles: int = 4):
    super().__init__(initial_distribution, target_distribution, gradient_function, energy_function)
    assert len(step_sizes) == num_steps, "Must have as many stepsizes as intermediate distributions."
    self._lambda_L, self._lambda_G, self._lambda_V = lambda_L, lambda_G, lambda_V
    self._damping_coeff = damping_coeff
    self._mass_diag_sqrt = mass_diag_sqrt
    self._step_sizes = step_sizes
    self._num_steps = num_steps
    self._num_leapfrog_steps = num_leapfrog_steps
    self._num_samples_per_step = num_samples_per_step
    self._num_particles = num_particles
    self._total_steps = self._num_samples_per_step * (self._num_steps - 1)
    self._total_steps_reverse = self._num_samples_per_step * self._num_steps

  def leapfrog_step(self, x, v, i, num_particles):
    """Advance (x, v) with `leader_leapfrog_step`, using stage `i`'s energy, gradient and step size."""
    step_size = self._step_sizes[i]
    return leader_leapfrog_step(
      x, v, self._lambda_G,
      # Evaluate the energy at the point the helper passes in (_x). Capturing the
      # outer x would always return the energy of the trajectory's starting point.
      lambda _x: self._energy_function(_x, i),
      lambda _x: self._gradient_function(_x, i),
      step_size, self._mass_diag_sqrt, self._num_leapfrog_steps, num_particles)

  def sample(self, n_samples: int):
    """Run `n_samples` annealed LHMC chains.

    Returns:
      (x, logw, accept_rate):
        x: the final positions.
        logw: starts from -log q0(x0) and accumulates (log p(x) - log p(x_next)) for
          every accepted move.
        accept_rate: the per-stage fraction of accepted proposals, of length
          `num_steps`. Stage 0 is never visited and stays 0.

    Raises:
      ValueError: if `n_samples` is not a multiple of `num_particles`.
    """
    num_particles = self._num_particles
    if n_samples % num_particles != 0:
      raise ValueError(
        "n_samples (%d) must be a multiple of num_particles (%d)." % (n_samples, num_particles))

    x_k = self._initial_distribution.sample(sample_shape=(n_samples,))
    # Momentum distribution N(0, mass_diag_sqrt^2 I), allocated like x_k (device/dtype).
    v_dist = D.Independent(
      D.Normal(torch.zeros_like(x_k), torch.ones_like(x_k) * self._mass_diag_sqrt),
      reinterpreted_batch_ndims = 1,
    )
    v_k = v_dist.sample()
    logw = -self._initial_distribution.log_prob(x_k)
    accept_rate = torch.zeros((self._num_steps,), device=x_k.device)

    for i in tqdm(range(self._total_steps), desc='LHMC-%d'%len(self._step_sizes)):
      dist_ind = (i // self._num_samples_per_step) + 1
      eps = torch.randn_like(x_k)

      # Partial momentum refresh.
      v_k_prime = v_k * self._damping_coeff + math.sqrt(1. - self._damping_coeff ** 2) * eps * self._mass_diag_sqrt

      # Proposal from the leader leapfrog integrator.
      x_k_next, v_k_next = self.leapfrog_step(x_k, v_k_prime, dist_ind, num_particles)

      # Joint (position + momentum) log-densities before and after the trajectory.
      logp_v_p = v_dist.log_prob(v_k_prime)
      logp_v = v_dist.log_prob(v_k_next)
      logp_x = self._energy_function(x_k, dist_ind)
      logp_x_hat = self._energy_function(x_k_next, dist_ind)
      log_joint_prev = logp_x + logp_v_p
      log_joint_next = logp_x_hat + logp_v
      logp_accept = log_joint_next - log_joint_prev

      # One MH decision per group. view(num_particles, -1) puts flat index
      # p * n_groups + g at row p, column g. The decision is broadcast back with the
      # same layout, so each particle receives the decision of the group it was
      # summed into. A repeat_interleave expansion would instead pair index
      # g * num_particles + p with group g, which disagrees with this view.
      # NOTE(review): this layout must match the grouping used inside
      # leader_leapfrog_step — confirm.
      group_logp_accept = logp_accept.view(num_particles, -1).sum(dim=0)
      group_accept = torch.rand_like(group_logp_accept) < torch.exp(group_logp_accept)
      accept_mask = group_accept.unsqueeze(0).expand(num_particles, -1).reshape(-1)
      accept = accept_mask.type(torch.get_default_dtype())

      # Select rather than blend with 0/1 weights. A rejected proposal holding
      # inf/nan would otherwise turn the kept state and weight into nan (0 * inf).
      logw += (logp_x - logp_x_hat).masked_fill(~accept_mask, 0.)
      sample_mask = accept_mask.reshape((-1,) + (1,) * (x_k.dim() - 1))
      x_k = torch.where(sample_mask, x_k_next, x_k)
      v_k = torch.where(sample_mask, v_k_next, v_k_prime)
      accept_rate[dist_ind] += accept.mean()

      # Disabled LULA refinement step (the only place lambda_L / lambda_V would be used):
      # energy_function = lambda x: self._energy_function(x, dist_ind)
      # gradient_function = lambda x: self._gradient_function(x, dist_ind)
      # step_size = min(self._step_sizes[dist_ind] * self._num_leapfrog_steps, 0.5)
      # _, x_k = lula(energy_function, gradient_function, x_k, step_size, self._lambda_L, self._lambda_V)

    accept_rate /= self._num_samples_per_step
    # NOTE: no terminal target log-density is added to logw here.
    return x_k, logw, accept_rate