import jax.numpy as jnp
from jax.scipy.stats import norm as normal_dist
from tensorflow_probability.substrates.jax.math import lambertw
from tensorflow_probability.substrates.jax.distributions.noncentral_chi2 import NoncentralChi2


def gauss_kl(q_loc, q_scale, p_loc, p_scale):
  """Closed-form KL[Q || P] between Gaussians, evaluated elementwise.

  Q has location `q_loc` and standard deviation `q_scale`; P has location
  `p_loc` and standard deviation `p_scale`. No reduction is performed, so the
  result has the broadcast shape of the inputs.

  Returns:
    0.5 * (q_var / p_var + (p_loc - q_loc)^2 / p_var + log(p_var / q_var) - 1).
  """
  # Variance ratio sigma_q^2 / sigma_p^2.
  scale_ratio_sq = (q_scale / p_scale)**2

  # Squared location gap measured in units of P's variance.
  loc_gap_sq = (p_loc - q_loc)**2 / p_scale**2

  # log(sigma_p^2 / sigma_q^2), written via the log standard deviations.
  log_var_gap = 2. * (jnp.log(p_scale) - jnp.log(q_scale))

  return 0.5 * (scale_ratio_sq + loc_gap_sq + log_var_gap - 1.)


def gauss_ratio_stats(q_loc, q_var, p_loc, p_var):
  """Parameters of the scaled Gaussian equal to the density ratio q(x) / p(x).

  With q = N(q_loc, q_var) and p = N(p_loc, p_var), the ratio can be written
  as exp(logZ) * N(x; ratio_loc, ratio_var). Note that this function takes
  variances, not standard deviations.

  The log and square root of `p_var - q_var` are taken, so the outputs are
  only finite where `q_var < p_var`.

  Returns:
    A tuple `(ratio_loc, ratio_var, logZ)`.
  """
  excess_var = p_var - q_var

  # Precision of the ratio is 1 / q_var - 1 / p_var; its location is the
  # matching precision-weighted difference of the two locations.
  loc = (p_var * q_loc - q_var * p_loc) / excess_var
  var = p_var * q_var / excess_var

  # Log of the factor multiplying the normalised Gaussian density.
  log_scaling = jnp.log(p_var) - jnp.log(excess_var)
  log_scaling = log_scaling - normal_dist.logpdf(q_loc, p_loc, jnp.sqrt(excess_var))

  return loc, var, log_scaling


def gauss_inf_div(q_loc, q_scale, p_loc, p_scale):
  """Computes D_infty[Q || P], the supremum of log(q(x) / p(x)) over x.

  The density ratio is a scaled Gaussian (see `gauss_ratio_stats`), so its
  supremum is attained at the ratio's location, where the normalised Gaussian
  density equals 1 / sqrt(2 * pi * ratio_var).

  Arguments are locations and standard deviations; the scales are squared
  before being handed to `gauss_ratio_stats`.
  """
  stats = gauss_ratio_stats(q_loc, q_scale**2, p_loc, p_scale**2)
  ratio_var, log_scaling = stats[1], stats[2]

  # Log of the peak value of N(x; ratio_loc, ratio_var).
  log_peak = -0.5 * (jnp.log(2. * jnp.pi) + jnp.log(ratio_var))

  return log_peak + log_scaling


def gauss_variance_from_kl_and_mean(kl, mean):
  """Variance of Q = N(mean, var) that attains a given KL[Q || N(0, 1)].

  Solves `2 * kl = var + mean**2 - log(var) - 1` for `var`. Rearranged, this
  reads `log(var) - var = mean**2 - 2 * kl - 1`, whose solution is
  `var = -W(-exp(mean**2 - 2 * kl - 1))` with W the Lambert W function.

  TFP's `lambertw` is documented to evaluate the principal branch only, so the
  returned root is the one with `var <= 1`.

  Args:
    kl: target KL divergence; scalar or array.
    mean: location of Q; scalar or array broadcastable against `kl`.

  Returns:
    The variance of Q, with the broadcast shape of the inputs.

  Raises:
    ValueError: if any element has `mean**2 > 2 * kl`, in which case no real
      solution exists.
  """
  exponent = -2. * kl + mean**2 - 1.

  # `log(var) - var` never exceeds -1, so a larger exponent has no solution.
  # `jnp.any` keeps the check valid for batched inputs, matching the
  # validation in `gauss_mean_and_variance_from_kl_and_inf_div`.
  if jnp.any(exponent > -1.):
    raise ValueError("mean**2 must be less than twice the KL!")

  return -lambertw(-jnp.exp(exponent))


def gauss_mean_and_variance_from_kl_and_inf_div(kl, inf_div):
  """Recovers the mean and variance of Q from KL[Q || P] and D_infty[Q || P].

  The formulas are consistent with `gauss_kl` and `gauss_inf_div` evaluated
  against a standard normal P. Eliminating the mean from the two divergence
  equations leaves `var * (log(var) + b) = a` with `b = 2 * inf_div - 1` and
  `a = b - 2 * kl`, which is solved with the Lambert W function.

  Returns:
    A tuple `(mean, var)`. The mean comes from a square root, so it is the
    non-negative solution.

  Raises:
    ValueError: if any element of `kl` is greater than or equal to the
      corresponding element of `inf_div`.
  """
  invalid = jnp.any(kl >= inf_div)
  if invalid:
    raise ValueError("KL must be less than the infinity divergence!")

  offset = 2. * inf_div - 1.
  coeff = offset - 2. * kl

  # u = log(var) + offset satisfies u * exp(u) = coeff * exp(offset).
  log_var_plus_offset = lambertw(coeff * jnp.exp(offset))
  var = jnp.exp(log_var_plus_offset - offset)

  # Invert the KL equation 2 * kl = var + mean^2 - log(var) - 1 for the mean.
  mean_sq = 2. * kl - var + jnp.log(var) + 1.

  return jnp.sqrt(mean_sq), var


def gauss_log_mass_of_sphere(sphere_loc, sphere_radius_squared, p_loc, p_scale):
  """Log P-measure of a sphere, for a Gaussian measure P.

  P has location `p_loc` and scale `p_scale`; the sphere is centred at
  `sphere_loc` with squared radius `sphere_radius_squared`.

  The sphere centre is standardised under P, and the mass is read off the CDF
  of a noncentral chi-squared distribution. Its degrees of freedom are the
  number of elements of the standardised centre, and its noncentrality is the
  squared norm of that centre.

  NOTE(review): the squared radius is divided by `p_scale**2` directly, which
  presumably assumes a single scale shared by all dimensions -- confirm
  against callers.

  Returns:
    The log-mass, with 0 substituted wherever `sphere_radius_squared` is
    infinite.
  """
  standardized_center = (sphere_loc - p_loc) / p_scale

  num_dims = standardized_center.size
  squared_offset = (standardized_center**2).sum()

  scaled_radius_squared = sphere_radius_squared / p_scale**2
  log_mass = NoncentralChi2(num_dims, squared_offset).log_cdf(scaled_radius_squared)

  # An infinite radius covers the whole space, so the log-mass is set to 0
  # explicitly rather than taken from log_cdf.
  return jnp.where(sphere_radius_squared == jnp.inf, 0., log_mass)


def gauss_ratio_super_level_sols(q_loc, q_scale, p_loc, p_scale, log_level):
  """Solves log(q(x) / p(x)) = log_level for x.

  The density ratio is a scaled Gaussian (see `gauss_ratio_stats`), so the
  equation is quadratic in x and its two roots sit symmetrically around the
  ratio's location.

  Returns:
    A tuple `(sol_down, sol_up)` with `sol_down <= sol_up`. The half-width is
    a square root, so both entries are NaN where `log_level` exceeds the
    supremum of the log-ratio.
  """
  center, ratio_var, log_scaling = gauss_ratio_stats(
      q_loc, q_scale**2, p_loc, p_scale**2)

  # log(2 * pi * ratio_var), from the normalising constant of the Gaussian.
  log_gauss_const = jnp.log(2. * jnp.pi) + jnp.log(ratio_var)

  # (x - center)^2 = ratio_var * (2 * (logZ - log_level) - log_gauss_const).
  half_width = jnp.sqrt(
      ratio_var * (2. * (log_scaling - log_level) - log_gauss_const))

  return center - half_width, center + half_width