from functools import partial

import jax.numpy as jnp
from jax import random
from jax.lax import while_loop
from jax.lax import cond
from jax.lax import switch
from jax.scipy.stats import norm as norm_dist
from jax.experimental.ode import odeint

from tensorflow_probability.substrates.jax.distributions.noncentral_chi2 import NoncentralChi2

from .util import logsubexp
from .util import trunc_gumbel

from .gauss_helpers import gauss_ratio_stats
from .gauss_helpers import gauss_log_mass_of_sphere
from .gauss_helpers import gauss_inf_div

from .priority_queue import create_heap, heappop, heappush, is_empty


def slow_gprs_encoder(seed, q_loc, q_scale, max_iter):
  """GPRS encoder, naive variant. Assume standard Gaussian proposal.

  Repeatedly draws a proposal x ~ N(0, 1) together with a non-decreasing log
  arrival time (via a truncated Gumbel). Each iteration computes the level
  reached by integrating the ODE ``level_derivative`` from t = 0 to
  t = exp(log_time). The loop stops at the first proposal whose log density
  ratio log q(x) - log N(x; 0, 1) exceeds the log of that level, or once the
  step counter passes ``max_iter``.

  Args:
    seed: integer seed passed to ``random.PRNGKey``.
    q_loc: mean of the Gaussian target q.
    q_scale: standard deviation of the Gaussian target q.
    max_iter: iteration budget; the loop runs while ``k <= max_iter``.

  Returns:
    ``(x, log_time, k)``: the last proposal, its log arrival time, and the
    number of loop iterations executed.
  """

  # Debug output; under jax.jit this only fires when the function is traced.
  print("tracing")

  root_key = random.PRNGKey(seed)

  r_loc, r_var, logZ = gauss_ratio_stats(q_loc, q_scale**2., 0., 1.)

  # Level-independent part of the squared radius of the ratio's level sets.
  radius_sq_const = r_var * (-(jnp.log(2. * jnp.pi) + jnp.log(r_var)) + 2. * logZ)

  def level_derivative(level, _):
    """ODE right-hand side: exp(logsubexp(log Q(sphere), log level + log P(sphere)))."""
    log_level = jnp.log(level)
    radius_sq = -2. * r_var * log_level + radius_sq_const

    log_p_mass = gauss_log_mass_of_sphere(r_loc, radius_sq, 0., 1.)
    log_q_mass = gauss_log_mass_of_sphere(r_loc, radius_sq, q_loc, q_scale)

    return jnp.exp(logsubexp(log_q_mass, log_level + log_p_mass))

  # Loop state: (step, key, x, log_time, log_ratio, log_level).
  def keep_going(state):
    step, log_ratio, log_level = state[0], state[4], state[5]
    return (step <= max_iter) & (log_ratio <= log_level)

  def step_fn(state):
    step, key_root, _, prev_log_time, _, _ = state

    log_time_key, x_key = random.split(random.fold_in(key_root, step))

    new_log_time = -trunc_gumbel(log_time_key, shape=(), loc=0., bound=-prev_log_time)
    new_x = random.normal(x_key, shape=())

    t_span = jnp.array([0., jnp.exp(new_log_time)])
    new_level = odeint(level_derivative, jnp.array(0.), t_span)[-1]

    new_log_ratio = norm_dist.logpdf(new_x, q_loc, q_scale) - norm_dist.logpdf(new_x)

    return step + 1, key_root, new_x, new_log_time, new_log_ratio, jnp.log(new_level)

  steps, _, x, log_time, _, _ = while_loop(
    cond_fun=keep_going,
    body_fun=step_fn,
    init_val=(0, root_key, jnp.nan, -jnp.inf, -jnp.inf, -jnp.inf),
  )

  return x, log_time, steps


def sac_gprs_encoder(seed, q_loc, q_scale, max_iter):
  """GPRS encoder with a shrinking search interval. Assume standard Gaussian proposal.

  Proposals are drawn uniformly from a sub-interval [left, right] of [0, 1],
  which is CDF space of the N(0, 1) proposal, and mapped through
  ``norm_dist.ppf``. After every proposal the interval is cut at the
  proposal's CDF value, keeping the side that contains ``r_loc``.
  Acceptance works as in ``slow_gprs_encoder``: stop once the log density
  ratio exceeds the log of the ODE-integrated level, or once the step counter
  passes ``max_iter``.

  Args:
    seed: integer seed passed to ``random.PRNGKey``.
    q_loc: mean of the Gaussian target q.
    q_scale: standard deviation of the Gaussian target q.
    max_iter: iteration budget; the loop runs while ``k <= max_iter``.

  Returns:
    ``(x, log_time, k, heap_index)``: the last proposal, its log arrival time,
    the number of iterations, and a binary path index. The index records, for
    every proposal after the first, whether it fell below ``r_loc``; it is
    pinned to 1 after the first step.
  """

  # Debug output; under jax.jit this only fires when the function is traced.
  print("tracing")

  root_key = random.PRNGKey(seed)

  r_loc, r_var, logZ = gauss_ratio_stats(q_loc, q_scale**2., 0., 1.)

  # Level-independent part of the squared radius of the ratio's level sets.
  radius_sq_const = r_var * (-(jnp.log(2. * jnp.pi) + jnp.log(r_var)) + 2. * logZ)

  def level_derivative(level, _):
    """ODE right-hand side: exp(logsubexp(log Q(sphere), log level + log P(sphere)))."""
    log_level = jnp.log(level)
    radius_sq = -2. * r_var * log_level + radius_sq_const

    log_p_mass = gauss_log_mass_of_sphere(r_loc, radius_sq, 0., 1.)
    log_q_mass = gauss_log_mass_of_sphere(r_loc, radius_sq, q_loc, q_scale)

    return jnp.exp(logsubexp(log_q_mass, log_level + log_p_mass))

  # Loop state:
  # (step, key, path, x, log_time, log_ratio, log_level, left, right).
  def keep_going(state):
    step, log_ratio, log_level = state[0], state[5], state[6]
    return (step <= max_iter) & (log_ratio <= log_level)

  def step_fn(state):
    step, key_root, path, _, prev_log_time, _, _, left, right = state

    log_time_key, u_key = random.split(random.fold_in(key_root, step))

    width = right - left
    log_time = -trunc_gumbel(log_time_key, shape=(), loc=jnp.log(width), bound=-prev_log_time)

    u = left + width * random.uniform(u_key, shape=())
    x = norm_dist.ppf(u)

    t_span = jnp.array([0., jnp.exp(log_time)])
    log_level = jnp.log(odeint(level_derivative, jnp.array(0.), t_span)[-1])

    log_ratio = norm_dist.logpdf(x, q_loc, q_scale) - norm_dist.logpdf(x)

    below_mode = jnp.all(x - r_loc < 0)
    above_mode = jnp.all(x - r_loc > 0)
    new_left = cond(below_mode, lambda: u, lambda: left)
    new_right = cond(above_mode, lambda: u, lambda: right)

    # The max() pins the heap index of the first arrival to 1.
    path = jnp.maximum(2 * path + below_mode.astype(jnp.int32), 1)

    return step + 1, key_root, path, x, log_time, log_ratio, log_level, new_left, new_right

  steps, _, path, x, log_time, _, _, _, _ = while_loop(
    cond_fun=keep_going,
    body_fun=step_fn,
    init_val=(0, root_key, 0, jnp.nan, -jnp.inf, -jnp.inf, -jnp.inf, 0., 1.),
  )

  # NOTE: an alternative code-length estimate was previously sketched as
  # floor(-log2(right - left)) + 2; it is not returned.
  return x, log_time, steps, path


def dyadic_gprs_encoder(seed, q_loc, q_scale, max_iter):
  """GPRS encoder that refines its search interval by random dyadic splits.

  Assume standard Gaussian proposal. The search interval [bound_left,
  bound_right] lives in [0, 1], which is CDF space of the N(0, 1) proposal.
  Each iteration:
    * draws a proposal uniformly in the current interval, mapped through
      ``norm_dist.ppf``, with a non-decreasing log arrival time;
    * computes the acceptance level by integrating the survival probability
      restricted to the current interval;
    * picks the left or right half of the interval at random. The right half
      is chosen with probability equal to its share of the survival
      probability at that level.
  The loop stops once the log density ratio exceeds the log level, or once
  the step counter passes ``max_iter``.

  Args:
    seed: integer seed passed to ``random.PRNGKey``.
    q_loc: mean of the Gaussian target q.
    q_scale: standard deviation of the Gaussian target q.
    max_iter: iteration budget; the loop runs while ``k <= max_iter``.

  Returns:
    ``(x, log_time, k, heap_index)``: the last proposal, its log arrival time,
    the number of iterations, and the binary path index of the half-interval
    choices (pinned to 1 after the first step).
  """

  base_key = random.PRNGKey(seed)

  # NOTE(review): the level-set formulas below take jnp.log / jnp.sqrt of
  # r_var-dependent terms, so they presumably require r_var > 0 (e.g. q_scale < 1).
  # Confirm with callers.
  r_loc, r_var, logZ = gauss_ratio_stats(q_loc, q_scale**2., 0., 1.)

  def compute_log_survival_prob(level, bound_lower, bound_upper):
    """log(Q(S ∩ I) - level * P(S ∩ I)).

    Here S is the superlevel set of the density ratio at ``level`` and I is
    the image of [bound_lower, bound_upper] under ``norm_dist.ppf``.
    """

    log_level = jnp.log(level)

    # Superlevel set {x : ratio(x) >= level} = [r_loc - sol_term, r_loc + sol_term].
    # This assumes the ratio has the form exp(logZ) * N(x; r_loc, r_var).
    sol_term = 2. * (logZ - log_level) - (jnp.log(2. * jnp.pi) + jnp.log(r_var))
    sol_term = jnp.sqrt(r_var * sol_term)

    sol_upper = r_loc + sol_term
    sol_lower = r_loc - sol_term

    # Map the CDF-space bounds back to x-space.
    bound_lower = norm_dist.ppf(bound_lower)
    bound_upper = norm_dist.ppf(bound_upper)

    intersection_lower = jnp.maximum(bound_lower, sol_lower)
    intersection_upper = jnp.minimum(bound_upper, sol_upper)

    log_q_mass, intersection_criterion = logsubexp(
      norm_dist.logcdf(intersection_upper, q_loc, q_scale),
      norm_dist.logcdf(intersection_lower, q_loc, q_scale),
      return_sign=True)

    # A non-positive sign means the intersection is empty, so its mass is zero.
    log_q_mass = jnp.where(intersection_criterion <= 0, -jnp.inf, log_q_mass)

    log_p_mass = logsubexp(
      norm_dist.logcdf(intersection_upper),
      norm_dist.logcdf(intersection_lower))

    log_p_mass = jnp.where(intersection_criterion <= 0, -jnp.inf, log_p_mass)

    return logsubexp(log_q_mass, log_level + log_p_mass)

  # Loop state:
  # (k, key, heap_index, x, log_time, log_ratio, log_level, bound_left, bound_right).
  def cond_fun(args):
    k, _, _, _, _, log_ratio, log_level, _, _ = args
    return (k <= max_iter) & (log_ratio <= log_level)

  def body_fun(args):
    k, base_key, heap_index, x, log_time, log_ratio, log_level, bound_left, bound_right = args

    key = random.fold_in(base_key, k)
    log_time_key, u_key, b_key = random.split(key, num=3)

    bound_size = bound_right - bound_left
    midpoint = (bound_left + bound_right) / 2.

    # Index 0 selects the left half, index 1 the right half.
    bound_left_options = jnp.array([bound_left, midpoint])
    bound_right_options = jnp.array([midpoint, bound_right])

    log_time = -trunc_gumbel(log_time_key, shape=(), loc=jnp.log(bound_size), bound=-log_time)

    u = random.uniform(u_key, shape=())
    u = bound_left + bound_size * u

    x = norm_dist.ppf(u)
    log_ratio = norm_dist.logpdf(x, q_loc, q_scale) - norm_dist.logpdf(x)

    # odeint is called immediately below, so the closure sees this
    # iteration's (pre-split) bounds.
    def inv_sigma_prime(level, _):
      return jnp.exp(compute_log_survival_prob(level, bound_left, bound_right))

    level = odeint(inv_sigma_prime, jnp.array(0.), jnp.array([0., jnp.exp(log_time)]))[-1]
    log_level = jnp.log(level)

    left_log_survival = compute_log_survival_prob(level, bound_left, midpoint)
    right_log_survival = compute_log_survival_prob(level, midpoint, bound_right)

    # Probability of descending into the right half, given this level.
    log_cond_right_prob = right_log_survival - jnp.logaddexp(left_log_survival, right_log_survival)

    b = random.bernoulli(b_key, jnp.exp(log_cond_right_prob)).astype(jnp.int32)

    bound_left = bound_left_options[b]
    bound_right = bound_right_options[b]

    # By picking the maximum, we ensure that the heap index of the first arrival is always 1.
    heap_index = jnp.maximum(2 * heap_index + b, 1)

    return k + 1, base_key, heap_index, x, log_time, log_ratio, log_level, bound_left, bound_right

  k, base_key, heap_index, x, log_time, log_ratio, log_level, bound_left, bound_right = while_loop(
    cond_fun=cond_fun,
    body_fun=body_fun,
    init_val=(0, base_key, 0, jnp.nan, -jnp.inf, -jnp.inf, -jnp.inf, 0., 1.)
  )

  return x, log_time, k, heap_index


def sac_a_star_encoder(seed, q_loc, q_scale, max_iter):
  """A* coding encoder with a shrinking search interval.

  Assume standard Gaussian proposal. Proposals are drawn uniformly from a
  sub-interval [left, right] of [0, 1], which is CDF space of the N(0, 1)
  proposal, and mapped through ``norm_dist.ppf``. Log times are
  non-increasing (Gumbel truncated at the previous log time). The interval is
  cut at every proposal, keeping the side that contains ``r_loc``.

  The encoder tracks the best score ``log_time + log q(x)/p(x)`` seen so far
  (lower bound). It also tracks ``log_time + log_ratio_bound`` (upper bound on
  every later score). It stops when the lower bound exceeds the upper bound,
  or once the step counter passes ``max_iter``.

  Args:
    seed: integer seed passed to ``random.PRNGKey``.
    q_loc: mean of the Gaussian target q.
    q_scale: standard deviation of the Gaussian target q.
    max_iter: iteration budget; the loop runs while ``k <= max_iter``.

  Returns:
    ``(x, index, num_iters)``: the best sample, the step at which it was
    drawn, and the step counter at exit.
  """

  root_key = random.PRNGKey(seed)
  r_loc, _, _ = gauss_ratio_stats(q_loc, q_scale**2., 0., 1.)
  # Presumably log sup_x q(x)/p(x); added to a log time it bounds later scores.
  log_ratio_bound = gauss_inf_div(q_loc, q_scale, 0., 1.)

  # Loop state:
  # (step, key, best_step, best_x, log_time, lower, upper, left, right).
  def keep_going(state):
    step, lower, upper = state[0], state[5], state[6]
    return (step <= max_iter) & (lower <= upper)

  def step_fn(state):
    step, key_root, best_step, best_sample, prev_log_time, lower, _, left, right = state

    log_time_key, u_key = random.split(random.fold_in(key_root, step))

    width = right - left
    log_time = trunc_gumbel(log_time_key, shape=(), loc=jnp.log(width), bound=prev_log_time)

    u = left + width * random.uniform(u_key, shape=())
    x = norm_dist.ppf(u)

    score = (norm_dist.logpdf(x, q_loc, q_scale) - norm_dist.logpdf(x)) + log_time

    best_step, best_sample, lower = cond(
      jnp.all(score > lower),
      lambda: (step, x, score),
      lambda: (best_step, best_sample, lower))

    upper = log_time + log_ratio_bound

    new_left = cond(jnp.all(x - r_loc < 0), lambda: u, lambda: left)
    new_right = cond(jnp.all(x - r_loc > 0), lambda: u, lambda: right)

    return step + 1, key_root, best_step, best_sample, log_time, lower, upper, new_left, new_right

  num_iters, _, index, x, _, _, _, _, _ = while_loop(
    cond_fun=keep_going,
    body_fun=step_fn,
    init_val=(1, root_key, 0, jnp.nan, jnp.inf, -jnp.inf, jnp.inf, 0., 1.),
  )

  return x, index, num_iters


def dyadic_a_star_encoder(seed, q_loc, q_scale, max_iter):
  """Dyadic A* coding encoder for a Gaussian target and standard Gaussian proposal.

  Searches a binary tree of nested sub-intervals of [0, 1], which is CDF space
  of the N(0, 1) proposal. Node ``h`` has children ``2h`` (left half) and
  ``2h + 1`` (right half). Each node carries its own sample:
    * a log arrival time from ``trunc_gumbel`` truncated at its parent's log
      time;
    * a location drawn uniformly in the node's interval and mapped through
      ``norm_dist.ppf``.
  The node's randomness comes from ``random.fold_in(base_key, h)``.

  Nodes are popped best-first by an upper bound on
  ``log_time + log q(x)/p(x)`` over their subtree. A child is only pushed if
  that bound can still beat the best score found so far.

  Args:
    seed: integer seed passed to ``random.PRNGKey``.
    q_loc: mean of the Gaussian target q.
    q_scale: standard deviation of the Gaussian target q.
    max_iter: iteration budget; the loop runs while ``k < max_iter`` (k starts
      at 1) and the heap is non-empty.

  Returns:
    ``(x, heap_index, k)``: the best sample found, the heap index of the node
    that produced it, and the loop counter at exit.
  """

  # Heap items are 6-vectors:
  #   [-upper_bound, x, log_time, heap_index, bound_left, bound_right]
  # The upper bound is negated because the heap is a min-heap and we want the
  # largest bound first.
  # NOTE(review): create_heap(6, 30) is assumed to mean (item_size, capacity).
  # Each pop may push two children, so confirm 30 is enough for max_iter.
  heap = create_heap(6, 30)

  base_key = random.PRNGKey(seed)
  log_ratio_global_bound = gauss_inf_div(q_loc, q_scale, 0., 1.)

  r_loc, _, _ = gauss_ratio_stats(q_loc, q_scale**2., 0., 1.)

  def log_ratio_bound_in_interval(left, right):
    """Upper bound on log q(x)/p(x) for x in [ppf(left), ppf(right)].

    Expects bounds to be a subset of [0, 1].
    """
    left = norm_dist.ppf(left)
    right = norm_dist.ppf(right)

    # Switch cond takes 3 possible values:
    # 0: r_loc <= left          -> the bound is attained at left
    # 1: left < r_loc <= right  -> the interval contains the global maximiser
    # 2: right < r_loc          -> the bound is attained at right
    switch_cond = jnp.all(left < r_loc).astype(jnp.int32) + jnp.all(right < r_loc).astype(jnp.int32)
    return switch(switch_cond,
                  (lambda: norm_dist.logpdf(left, q_loc, q_scale) - norm_dist.logpdf(left),
                   lambda: log_ratio_global_bound,
                   lambda: norm_dist.logpdf(right, q_loc, q_scale) - norm_dist.logpdf(right)))

  def push_child(heap, lower_bound, base_key, parent_log_time, child_heap_index, left, right):
    """Sample the node for [left, right] and push it unless it cannot beat lower_bound."""
    key = random.fold_in(base_key, child_heap_index)
    log_time_key, u_key = random.split(key)

    bound_size = right - left
    child_log_time = trunc_gumbel(log_time_key, shape=(), loc=jnp.log(bound_size), bound=parent_log_time)

    child_u = random.uniform(u_key)
    child_u = left + bound_size * child_u

    child_x = norm_dist.ppf(child_u)

    # Every sample in this child's subtree has log time <= child_log_time (each
    # node is truncated at its parent's time) and lies in [left, right]. So the
    # child's OWN log time plus the interval's ratio bound bounds the subtree.
    # The root uses the same rule: its own log time plus the global bound.
    child_upper_bound = child_log_time + log_ratio_bound_in_interval(left, right)

    return cond(
      child_upper_bound > lower_bound,
      lambda: heappush(heap, jnp.array([
        -child_upper_bound,  # negate the upper bound, because we are using a min-heap implementation
        child_x,
        child_log_time,
        child_heap_index.astype(child_x.dtype),
        left,
        right
      ])),
      lambda: heap
    )

  def cond_fun(args):
    k, _, _, _, heap, _ = args

    return (k < max_iter) & ~is_empty(heap)

  def body_fun(args):
    k, base_key, best_heap_index, best_x, heap, lower_bound = args

    top_item, heap = heappop(heap)

    # top_item[0] is the (negated) upper bound, which we throw away
    x = top_item[1]
    log_time = top_item[2]
    # NOTE(review): heap indices are stored in x's float dtype; with float32
    # they are exact only up to 2**24, i.e. roughly 24 levels deep.
    heap_index = top_item[3].astype(jnp.int32)
    bound_left = top_item[4]
    bound_right = top_item[5]

    log_ratio = norm_dist.logpdf(x, q_loc, q_scale) - norm_dist.logpdf(x)

    # if the current sample sets a better lower bound, update stuff
    lower_bound, best_heap_index, best_x = cond(
      lower_bound < log_time + log_ratio,
      lambda: (log_time + log_ratio, heap_index, x),
      lambda: (lower_bound, best_heap_index, best_x))

    # split current bounds into left and right halves
    midpoint = (bound_left + bound_right) / 2.

    heap = push_child(heap, lower_bound, base_key, log_time, 2 * heap_index, bound_left, midpoint)
    heap = push_child(heap, lower_bound, base_key, log_time, 2 * heap_index + 1, midpoint, bound_right)

    return k + 1, base_key, best_heap_index, best_x, heap, lower_bound

  # Root node: heap index 1, covering the whole interval [0, 1].
  key = random.fold_in(base_key, 1)
  log_time_key, x_key = random.split(key)

  log_time = trunc_gumbel(log_time_key, shape=(), loc=0., bound=jnp.inf)
  x = random.normal(x_key)

  lower_bound = log_time + norm_dist.logpdf(x, q_loc, q_scale) - norm_dist.logpdf(x)
  upper_bound = log_time + log_ratio_global_bound

  heap = heappush(heap, jnp.array([
    -upper_bound,  # negate the upper bound, because we are using a min-heap implementation
    x,
    log_time,
    1.,
    0.,
    1.
  ]))

  k, _, heap_index, x, _, _ = while_loop(
    cond_fun,
    body_fun,
    (1, base_key, 1, x, heap, lower_bound)
  )

  return x, heap_index, k