import jax.random as random
import jax.numpy as jnp

from jax.scipy.stats import norm as normal_dist

from tensorflow_probability.substrates.jax.distributions.noncentral_chi2 import NoncentralChi2

from .gauss_helpers import gauss_kl
from .gauss_helpers import gauss_ratio_stats
from .gauss_helpers import gauss_inf_div
from .gauss_helpers import gauss_log_mass_of_sphere
from .gauss_helpers import gauss_ratio_super_level_sols

from .util import log1mexp
from .util import logsubexp


class IsotropicGaussianDensityRatio:
  """Models the case when both Q and P are isotropic Gaussians.

  Stores the parameters of Q and P together with quantities derived from them
  (ratio statistics, divergences in nats and bits) and exposes the log density
  ratio, the P/Q-measures of its superlevel sets, and the right-hand sides of
  several ODEs built from those measures.
  """

  def __init__(self, q_loc, q_scale, p_loc, p_scale):
    """Validates the parameters and precomputes derived quantities.

    Args:
      q_loc: location of Q; must expose `.shape` and `.size`.
      q_scale: scalar scale of Q.
      p_loc: location of P; must have the same shape as `q_loc`.
      p_scale: scalar scale of P.

    Raises:
      ValueError: if a scale is not a scalar or the location shapes differ.
    """

    if not jnp.isscalar(q_scale) or not jnp.isscalar(p_scale):
      raise ValueError("The scales have to be scalars!")

    if not q_loc.shape == p_loc.shape:
      raise ValueError("The location of Q has to have the same shape as the location of P!")

    # NOTE: the `q_scale < p_scale` check is currently disabled in this class:
    # if not q_scale < p_scale:
    #   raise ValueError("The Q scale has to be smaller than P scale!")

    # Just bookkeeping
    self.q_loc = q_loc
    self.q_scale = q_scale
    self.q_var = q_scale**2

    self.p_loc = p_loc
    self.p_scale = p_scale
    self.p_var = p_scale**2

    # Problem dimensionality
    self.dim = self.q_loc.size

    # Compute ratio statistics
    self.r_loc, self.r_var, self.logZ = gauss_ratio_stats(q_loc, q_scale**2, p_loc, p_scale**2)
    self.r_scale = jnp.sqrt(self.r_var)

    # Divergences in nats
    self.kl = gauss_kl(q_loc, q_scale, p_loc, p_scale)
    self.inf_div = gauss_inf_div(q_loc, q_scale, p_loc, p_scale)

    # Divergences in bits
    self.kl2 = self.kl / jnp.log(2.)
    self.inf_div2 = self.inf_div / jnp.log(2.)

    # The h-independent part of the squared radius used by the width methods:
    # r_var * (2 * logZ - dim * (log(2 * pi) + log(r_var))).
    self.radius_sq_const = -self.dim * (jnp.log(2. * jnp.pi) + jnp.log(self.r_var)) + 2. * self.logZ
    self.radius_sq_const = self.r_var * self.radius_sq_const

    super().__init__()

  def sample_p(self, key, shape, lower=-jnp.inf, upper=jnp.inf):
    """Draws samples from P truncated to [lower, upper].

    The bounds are standardised with P's location and scale before being handed
    to `jax.random.truncated_normal`; the draw is then mapped back.
    """
    std_lower = (lower - self.p_loc) / self.p_scale
    std_upper = (upper - self.p_loc) / self.p_scale

    return self.p_loc + self.p_scale * random.truncated_normal(key, std_lower, std_upper, shape=shape)

  def sample_q(self, key, shape, lower=-jnp.inf, upper=jnp.inf):
    """Draws samples from Q truncated to [lower, upper].

    The bounds are standardised with Q's location and scale before being handed
    to `jax.random.truncated_normal`; the draw is then mapped back.
    """
    std_lower = (lower - self.q_loc) / self.q_scale
    std_upper = (upper - self.q_loc) / self.q_scale

    return self.q_loc + self.q_scale * random.truncated_normal(key, std_lower, std_upper, shape=shape)

  def log_ratio(self, x):
    """Returns the elementwise difference of the Q and P normal log-pdfs at x."""
    log_q = normal_dist.logpdf(x, self.q_loc, self.q_scale)
    log_p = normal_dist.logpdf(x, self.p_loc, self.p_scale)

    return log_q - log_p

  def ratio(self, x):
    """Returns `exp(log_ratio(x))`."""
    return jnp.exp(self.log_ratio(x))

  def log_width_p(self, h, log=False):
    """Computes the log-P-measure of the h-superlevel set of the density ratio.

    Args:
      h: the level; interpreted as log-level when `log` is True.
      log: whether `h` is already given in the log domain.
    """

    log_h = h if log else jnp.log(h)

    # NOTE(review): this is the same expression that `log_minus_width_p_prime`
    # calls `radius_sq`; confirm that `gauss_log_mass_of_sphere` expects it.
    radius = -2. * self.r_var * log_h + self.radius_sq_const

    return gauss_log_mass_of_sphere(self.r_loc, radius, self.p_loc, self.p_scale)

  def log_width_q(self, h, log=False):
    """Computes the log-Q-measure of the h-superlevel set of the density ratio.

    Args:
      h: the level; interpreted as log-level when `log` is True.
      log: whether `h` is already given in the log domain.
    """

    log_h = h if log else jnp.log(h)

    radius = -2. * self.r_var * log_h + self.radius_sq_const

    return gauss_log_mass_of_sphere(self.r_loc, radius, self.q_loc, self.q_scale)

  def log_1m_lower_p_mass(self, h, log=False):
    r"""Computes the log of 1 - \int_0^h width_p(y) dy.

    Evaluated as `logsubexp(log_width_q(h), log_h + log_width_p(h))`.
    """

    log_h = h if log else jnp.log(h)

    return logsubexp(self.log_width_q(h, log), log_h + self.log_width_p(h, log))

  def log_lower_p_mass(self, h, log=False):
    r"""Computes the log of \int_0^h width_p(y) dy."""
    return log1mexp(self.log_1m_lower_p_mass(h, log))

  def log_minus_width_p_prime(self, h, log=False):
    """Evaluates log(2 * r_var) - log(h) + log-pdf of a noncentral chi-square.

    Judging by the name, this is the log of minus the derivative of the P-width
    with respect to h.

    Args:
      h: the level; interpreted as log-level when `log` is True.
      log: whether `h` is already given in the log domain.
    """
    log_h = h if log else jnp.log(h)

    radius_sq = -2. * self.r_var * log_h + self.radius_sq_const

    # Standardised offset between the ratio location and the P location.
    loc = (self.r_loc - self.p_loc) / self.p_scale
    noncentrality = (loc**2).sum()

    # NOTE(review): the noncentrality is standardised by p_scale, whereas
    # radius_sq is passed unscaled - confirm against `gauss_log_mass_of_sphere`.
    return jnp.log(2. * self.r_var) - log_h + NoncentralChi2(self.dim, noncentrality).log_prob(radius_sq)

  def stretch_ode(self, Y, h):
    """Implements a first-order ODE governing the stretch function for the density ratio.
    This function is to be passed to `scipy.odeint`.

    The state `Y` is unused; the derivative depends on `h` only.
    """

    log_derivative = -self.log_1m_lower_p_mass(h)

    return jnp.exp(log_derivative)

  def log_stretch_ode(self, Y, h):
    """Implements a first-order ODE governing the log of the stretch function for the density ratio.
    This function is to be passed to `scipy.odeint`."""

    log_derivative = -(Y + self.log_1m_lower_p_mass(h))

    return jnp.exp(log_derivative)

  def inv_stretch_ode(self, Y, t):
    """Implements a first-order ODE governing the inverse of the stretch function for the density ratio.
    This function is to be passed to `scipy.odeint`.

    The time argument `t` is unused; the derivative depends on `Y` only.
    """

    return jnp.exp(self.log_1m_lower_p_mass(Y))

  def inv_stretch_ode_log_time(self, Y, q):
    """Implements a first-order ODE governing the inverse of the stretch function for the density ratio.
    This function is to be passed to `scipy.odeint`.

    Same as `inv_stretch_ode`, but with the change of variables q = log t
    """

    return jnp.exp(q + self.log_1m_lower_p_mass(Y))

  def log_inv_stretch_ode_log_time(self, Y, q):
    """Implements a first-order ODE governing the log of the inverse of the stretch function for the density ratio.
    This function is to be passed to `scipy.odeint`.

    Here `Y` is the log of the inverse stretch, so the level is `exp(Y)`.
    """

    return jnp.exp(self.log_1m_lower_p_mass(jnp.exp(Y)) + q - Y)
 

class OneDimensionalTruncatableGaussianDensityRatio:
  """Models the case when both Q and P are 1D Gaussians, and we might want to truncate Q

  Every method that takes `lower` / `upper` treats them as the truncation
  bounds; the defaults (-inf, inf) correspond to no truncation.
  """

  def __init__(self, q_loc, q_scale, p_loc, p_scale):
    """Validates the parameters and precomputes derived quantities.

    Args:
      q_loc: scalar location of Q.
      q_scale: scalar scale of Q; must be strictly smaller than `p_scale`.
      p_loc: scalar location of P.
      p_scale: scalar scale of P.

    Raises:
      ValueError: if any parameter is not a scalar, or `q_scale >= p_scale`.
    """

    if not jnp.isscalar(q_scale) or not jnp.isscalar(p_scale):
      raise ValueError("The scales have to be scalars!")

    # Both locations are checked (previously `p_scale` was tested twice and
    # `p_loc` was never validated).
    if not jnp.isscalar(q_loc) or not jnp.isscalar(p_loc):
      raise ValueError("The locations have to be scalars!")

    if not q_scale < p_scale:
      raise ValueError("The Q scale has to be smaller than P scale!")

    # Just bookkeeping
    self.q_loc = q_loc
    self.q_scale = q_scale
    self.q_var = q_scale**2

    self.p_loc = p_loc
    self.p_scale = p_scale
    self.p_var = p_scale**2

    # Problem dimensionality: the locations are validated to be scalars, so it
    # is always 1. (Reading `q_loc.size` would fail for plain Python numbers.)
    self.dim = 1

    # Compute ratio statistics
    self.r_loc, self.r_var, self.logZ = gauss_ratio_stats(q_loc, q_scale**2, p_loc, p_scale**2)
    self.r_scale = jnp.sqrt(self.r_var)

    # Divergences in nats
    self.kl = gauss_kl(q_loc, q_scale, p_loc, p_scale)
    self.inf_div = gauss_inf_div(q_loc, q_scale, p_loc, p_scale)

    # Divergences in bits
    self.kl2 = self.kl / jnp.log(2.)
    self.inf_div2 = self.inf_div / jnp.log(2.)

    # Compute some constants for convenient computation later:
    # r_var * (2 * logZ - dim * (log(2 * pi) + log(r_var))).
    self.radius_sq_const = -self.dim * (jnp.log(2. * jnp.pi) + jnp.log(self.r_var)) + 2. * self.logZ
    self.radius_sq_const = self.r_var * self.radius_sq_const

    super().__init__()

  def log_ratio(self, x, lower=-jnp.inf, upper=jnp.inf):
    """Log density ratio of Q truncated to [lower, upper] against P.

    The Q log-density is renormalised by the log Q-mass of [lower, upper];
    points outside the interval get -inf.
    """
    log_q_mass = logsubexp(normal_dist.logcdf(upper, self.q_loc, self.q_scale),
                           normal_dist.logcdf(lower, self.q_loc, self.q_scale))

    log_r = normal_dist.logpdf(x, self.q_loc, self.q_scale) - log_q_mass - normal_dist.logpdf(x, self.p_loc, self.p_scale)

    # mask for lower <= x <= upper.
    # It is non-negative if and only if the condition is satisfied.
    indicator_mask = jnp.minimum(x - lower, upper - x)

    return jnp.where(indicator_mask >= 0., log_r, -jnp.inf)

  def ratio(self, x, lower=-jnp.inf, upper=jnp.inf):
    """Returns `exp(log_ratio(x, lower, upper))`."""
    return jnp.exp(self.log_ratio(x, lower, upper))

  def _log_super_level_mass(self, log_h, loc, scale, lower, upper):
    """Log-mass under N(loc, scale) of the superlevel interval clipped to [lower, upper]."""
    sol_down, sol_up = gauss_ratio_super_level_sols(
      self.q_loc, self.q_scale, self.p_loc, self.p_scale, log_h)

    # compute bound measures: for the parent bound we should be guaranteed intersection
    interval_down = jnp.maximum(lower, sol_down)
    interval_up = jnp.minimum(upper, sol_up)

    log_cdf_up = normal_dist.logcdf(interval_up, loc, scale)
    log_cdf_down = normal_dist.logcdf(interval_down, loc, scale)

    # `intersection_criterion`: the sign of the difference is -1 if and only if
    # intervals (lower, upper) and (sol_down, sol_up) DO NOT intersect
    log_measure, intersection_criterion = logsubexp(log_cdf_up, log_cdf_down, return_sign=True)

    return jnp.where(intersection_criterion <= 0, -jnp.inf, log_measure)

  def log_width_p(self, h, lower=-jnp.inf, upper=jnp.inf, log=False):
    """Computes the log-P-measure of the h-superlevel set of the density ratio.

    Args:
      h: the level; interpreted as log-level when `log` is True.
      lower: lower truncation bound.
      upper: upper truncation bound.
      log: whether `h` is already given in the log domain.
    """

    log_h = h if log else jnp.log(h)

    return self._log_super_level_mass(log_h, self.p_loc, self.p_scale, lower, upper)

  def log_width_q(self, h, lower=-jnp.inf, upper=jnp.inf, log=False):
    """Computes the log-Q-measure of the h-superlevel set of the density ratio.

    Args:
      h: the level; interpreted as log-level when `log` is True.
      lower: lower truncation bound.
      upper: upper truncation bound.
      log: whether `h` is already given in the log domain.
    """

    log_h = h if log else jnp.log(h)

    return self._log_super_level_mass(log_h, self.q_loc, self.q_scale, lower, upper)

  def log_1m_lower_p_mass(self, h, lower=-jnp.inf, upper=jnp.inf, log=False):
    r"""Computes the log of 1 - \int_0^h width_p(y) dy.

    Evaluated as `logsubexp(log_width_q(h), log_h + log_width_p(h))`.
    """

    log_h = h if log else jnp.log(h)

    return logsubexp(self.log_width_q(h, lower, upper, log), log_h + self.log_width_p(h, lower, upper, log))

  def log_lower_p_mass(self, h, lower=-jnp.inf, upper=jnp.inf, log=False):
    r"""Computes the log of \int_0^h width_p(y) dy."""
    return log1mexp(self.log_1m_lower_p_mass(h, lower, upper, log))

  def stretch_ode(self, Y, h, lower=-jnp.inf, upper=jnp.inf):
    """Implements a first-order ODE governing the stretch function for the density ratio.
    This function is to be passed to `scipy.odeint`.

    The state `Y` is unused; the derivative depends on `h` only.
    """

    log_derivative = -self.log_1m_lower_p_mass(h, lower, upper)

    return jnp.exp(log_derivative)

  def log_stretch_ode(self, Y, h, lower=-jnp.inf, upper=jnp.inf):
    """Implements a first-order ODE governing the log of the stretch function for the density ratio.
    This function is to be passed to `scipy.odeint`."""

    log_derivative = -(Y + self.log_1m_lower_p_mass(h, lower, upper))

    return jnp.exp(log_derivative)

  def inv_stretch_ode(self, Y, t, lower=-jnp.inf, upper=jnp.inf):
    """Implements a first-order ODE governing the inverse of the stretch function for the density ratio.
    This function is to be passed to `scipy.odeint`.

    The time argument `t` is unused; the derivative depends on `Y` only.
    """

    return jnp.exp(self.log_1m_lower_p_mass(Y, lower, upper))

  def inv_stretch_ode_log_time(self, Y, q, lower=-jnp.inf, upper=jnp.inf):
    """Implements a first-order ODE governing the inverse of the stretch function for the density ratio.
    This function is to be passed to `scipy.odeint`.

    Same as `inv_stretch_ode`, but with the change of variables q = log t
    """

    return jnp.exp(q + self.log_1m_lower_p_mass(Y, lower, upper))

  def log_inv_stretch_ode_log_time(self, Y, q, lower=-jnp.inf, upper=jnp.inf):
    """Implements a first-order ODE governing the log of the inverse of the stretch function for the density ratio.
    This function is to be passed to `scipy.odeint`.

    Here `Y` is the log of the inverse stretch, so the level is `exp(Y)`.
    """

    return jnp.exp(self.log_1m_lower_p_mass(jnp.exp(Y), lower, upper) + q - Y)
 
class OneDimensionalTruncatableGaussianDensityRatioEqualVariance:
  """Models the case when both Q and P are 1D Gaussians with equal variance, and we might want to truncate Q

  Q shares P's scale, so only `p_scale` is taken. `log_width_p` is not
  implemented yet; every method that relies on it (the lower-mass and ODE
  methods) therefore raises `NotImplementedError` as well.
  """

  def __init__(self, q_loc, p_loc, p_scale):
    """Validates the parameters and precomputes derived quantities.

    Args:
      q_loc: scalar location of Q.
      p_loc: scalar location of P.
      p_scale: scalar scale shared by Q and P.

    Raises:
      ValueError: if any parameter is not a scalar.
    """

    if not jnp.isscalar(p_scale):
      raise ValueError("The scales have to be scalars!")

    # Both locations are checked (previously `p_scale` was tested again and
    # `p_loc` was never validated).
    if not jnp.isscalar(q_loc) or not jnp.isscalar(p_loc):
      raise ValueError("The locations have to be scalars!")

    # Just bookkeeping
    self.q_loc = q_loc
    self.q_scale = p_scale
    self.q_var = p_scale**2

    self.p_loc = p_loc
    self.p_scale = p_scale
    self.p_var = p_scale**2

    # Problem dimensionality: the locations are validated to be scalars, so it
    # is always 1. (Reading `q_loc.size` would fail for plain Python numbers.)
    self.dim = 1

    # Divergences in nats
    self.kl = gauss_kl(q_loc, p_scale, p_loc, p_scale)
    self.inf_div = jnp.inf

    # Divergences in bits
    self.kl2 = self.kl / jnp.log(2.)
    self.inf_div2 = self.inf_div / jnp.log(2.)

    super().__init__()

  def log_ratio(self, x, lower=-jnp.inf, upper=jnp.inf):
    """Log density ratio of Q truncated to [lower, upper] against P.

    The Q log-density is renormalised by the log Q-mass of [lower, upper];
    points outside the interval get -inf.
    """
    log_q_mass = logsubexp(normal_dist.logcdf(upper, self.q_loc, self.q_scale),
                           normal_dist.logcdf(lower, self.q_loc, self.q_scale))

    log_r = normal_dist.logpdf(x, self.q_loc, self.q_scale) - log_q_mass - normal_dist.logpdf(x, self.p_loc, self.p_scale)

    # mask for lower <= x <= upper.
    # It is non-negative if and only if the condition is satisfied.
    indicator_mask = jnp.minimum(x - lower, upper - x)

    return jnp.where(indicator_mask >= 0., log_r, -jnp.inf)

  def ratio(self, x, lower=-jnp.inf, upper=jnp.inf):
    """Returns `exp(log_ratio(x, lower, upper))`."""
    return jnp.exp(self.log_ratio(x, lower, upper))

  def log_width_p(self, h, lower=-jnp.inf, upper=jnp.inf, log=False):
    """Computes the log-P-measure of the h-superlevel set of the density ratio.

    Raises:
      NotImplementedError: always; the equal-variance case is still a TODO.
    """

    # TODO: implement the equal-variance P-width. Failing loudly here avoids
    # silently returning None, which used to surface later as an opaque
    # TypeError inside `log_1m_lower_p_mass`.
    raise NotImplementedError(
      "log_width_p is not implemented for the equal-variance density ratio.")

  def log_width_q(self, h, lower=-jnp.inf, upper=jnp.inf, log=False):
    """Computes the log-Q-measure of the h-superlevel set of the density ratio.

    Args:
      h: the level; interpreted as log-level when `log` is True.
      lower: lower truncation bound.
      upper: upper truncation bound.
      log: whether `h` is already given in the log domain.
    """

    log_h = h if log else jnp.log(h)

    # NOTE(review): Q and P share the same scale here - confirm that
    # `gauss_ratio_super_level_sols` handles the equal-variance case.
    sol_down, sol_up = gauss_ratio_super_level_sols(
      self.q_loc, self.q_scale, self.p_loc, self.p_scale, log_h)

    # compute bound measures: for the parent bound we should be guaranteed intersection
    set_down = jnp.maximum(lower, sol_down)
    set_up = jnp.minimum(upper, sol_up)

    # `intersection_criterion` is larger than 0 if and only if the
    # intervals (lower, upper) and (sol_down, sol_up) DO NOT intersect
    intersection_criterion = set_down - set_up

    log_q_upper_term = normal_dist.logcdf(set_up, self.q_loc, self.q_scale)
    log_q_lower_term = normal_dist.logcdf(set_down, self.q_loc, self.q_scale)
    log_q_measure = logsubexp(log_q_upper_term, log_q_lower_term)

    return jnp.where(intersection_criterion > 0, -jnp.inf, log_q_measure)

  def log_1m_lower_p_mass(self, h, lower=-jnp.inf, upper=jnp.inf, log=False):
    r"""Computes the log of 1 - \int_0^h width_p(y) dy.

    Evaluated as `logsubexp(log_width_q(h), log_h + log_width_p(h))`.
    """

    log_h = h if log else jnp.log(h)

    return logsubexp(self.log_width_q(h, lower, upper, log), log_h + self.log_width_p(h, lower, upper, log))

  def log_lower_p_mass(self, h, lower=-jnp.inf, upper=jnp.inf, log=False):
    r"""Computes the log of \int_0^h width_p(y) dy."""
    return log1mexp(self.log_1m_lower_p_mass(h, lower, upper, log))

  def stretch_ode(self, Y, h, lower=-jnp.inf, upper=jnp.inf):
    """Implements a first-order ODE governing the stretch function for the density ratio.
    This function is to be passed to `scipy.odeint`.

    The state `Y` is unused; the derivative depends on `h` only.
    """

    log_derivative = -self.log_1m_lower_p_mass(h, lower, upper)

    return jnp.exp(log_derivative)

  def log_stretch_ode(self, Y, h, lower=-jnp.inf, upper=jnp.inf):
    """Implements a first-order ODE governing the log of the stretch function for the density ratio.
    This function is to be passed to `scipy.odeint`."""

    log_derivative = -(Y + self.log_1m_lower_p_mass(h, lower, upper))

    return jnp.exp(log_derivative)

  def inv_stretch_ode(self, Y, t, lower=-jnp.inf, upper=jnp.inf):
    """Implements a first-order ODE governing the inverse of the stretch function for the density ratio.
    This function is to be passed to `scipy.odeint`.

    The time argument `t` is unused; the derivative depends on `Y` only.
    """

    return jnp.exp(self.log_1m_lower_p_mass(Y, lower, upper))

  def inv_stretch_ode_log_time(self, Y, q, lower=-jnp.inf, upper=jnp.inf):
    """Implements a first-order ODE governing the inverse of the stretch function for the density ratio.
    This function is to be passed to `scipy.odeint`.

    Same as `inv_stretch_ode`, but with the change of variables q = log t
    """

    return jnp.exp(q + self.log_1m_lower_p_mass(Y, lower, upper))

  def log_inv_stretch_ode_log_time(self, Y, q, lower=-jnp.inf, upper=jnp.inf):
    """Implements a first-order ODE governing the log of the inverse of the stretch function for the density ratio.
    This function is to be passed to `scipy.odeint`.

    Here `Y` is the log of the inverse stretch, so the level is `exp(Y)`.
    """

    return jnp.exp(self.log_1m_lower_p_mass(jnp.exp(Y), lower, upper) + q - Y)
