"""Functions for computing probability lower bounds."""
import functools
import math

import equinox
import jax
import jax.numpy as jnp

import utils
from probability_flow import get_hutchinson_div_fn, draw_epsilon
from score_flow.bound_likelihood import get_likelihood_offset_fn
from score_flow import sde_lib

def get_div_drift_fn(sde):
  """Returns a function computing the divergence of the SDE drift f(x, t).

  The returned function computes the divergence exactly as a per-example
  coefficient c(t) times the number of elements per example, i.e. it treats
  the drift as linear in x.

  Args:
    sde: An `sde_lib.VPSDE`, `sde_lib.subVPSDE` or `sde_lib.VESDE` instance.
      Other SDE types raise `NotImplementedError` when the returned function
      is called.

  Returns:
    `div_drift_fn(x, t, epsilon=None)` returning div_x f(x, t) per example.
  """
  def div_drift_fn(x, t, epsilon=None):
    """Returns the divergence of SDE f(x, t) with respect to x.

    Args:
      x: Batch of inputs; axis 0 is the batch axis, all other axes are
        flattened into the per-example dimension.
      t: Times, one per example.
      epsilon: Ignored. The divergence is computed in closed form, so no
        Hutchinson probe is required; the argument is accepted so callers
        that pass a probe vector (as for the score divergence) do not fail.

    Returns:
      div_x f(x, t), with the shape of `t`.

    Raises:
      NotImplementedError: If the drift divergence is unknown for `sde`.
    """
    del epsilon  # Unused: the divergence is exact.
    # Elements per example; works for any input rank, not only (b, h, w, c).
    dim = math.prod(x.shape[1:])
    if isinstance(sde, (sde_lib.VPSDE, sde_lib.subVPSDE)):
      # Drift coefficient -0.5 * beta(t) with beta linear in t.
      drift_coeff = -0.5 * (sde.beta_0 + t * (sde.beta_1 - sde.beta_0))
    elif isinstance(sde, sde_lib.VESDE):
      drift_coeff = jnp.zeros_like(t)
    else:
      raise NotImplementedError(
        f'div(f(x, t)) not implemented for SDE of type {type(sde)}')
    return drift_coeff * dim
  return div_drift_fn


def get_likelihood_bound_fn(
  score_fn,
  sde,
  eps=1e-3,
  nt=1000,
  importance_weighting=True,
  dsm=True,
  eps_offset=True,
  n_trace_estimates=1,
  hutchinson_type='rademacher'):
  """Returns the function for estimating lower bound of logp(x).

  The time integral in the bound is estimated by Monte Carlo over `nt` time
  samples per example. All time samples are evaluated in parallel with
  `jax.vmap`, so memory grows with `nt`.

  Args:
    score_fn: Score model, called as `score_fn(x, t)`.
    sde: SDE object; this function uses its `marginal_prob`, `sde`,
      `prior_logp` and `T` members, plus the importance-sampling helpers
      when `importance_weighting` is True.
    eps: Smallest integration time.
    nt: Number of Monte-Carlo time samples per example.
    importance_weighting: If True, draw times with
      `sde.sample_importance_weighted_time_for_likelihood` and reweight the
      integrand by `std**2 / g**2 * Z`. Otherwise draw times uniformly in
      [eps, sde.T].
    dsm: If True, use the denoising-score-matching form of the integrand.
      Otherwise use a Hutchinson estimate of div(score).
    eps_offset: If True, add `offset_fn(rng, data)` to -logp to account for
      not integrating exactly to t = 0.
    n_trace_estimates: Number of Hutchinson probes (only used when not dsm).
    hutchinson_type: Distribution of the Hutchinson probes (only used when
      not dsm).

  Returns:
    `likelihood_bound_fn(rng, data)`, returning the per-example lower bound
    on logp(data).
  """

  offset_fn = get_likelihood_offset_fn(sde, score_fn, eps)
  div_drift_fn = get_div_drift_fn(sde)
  div_score_fn = get_hutchinson_div_fn(score_fn)

  def likelihood_bound_fn(rng, data):
    rng, step_rng = jax.random.split(rng)
    if importance_weighting:
      time_samples = sde.sample_importance_weighted_time_for_likelihood(
        step_rng, (nt, data.shape[0]), eps=eps)
      Z = sde.likelihood_importance_cum_weight(sde.T, eps=eps)
    else:
      time_samples = jax.random.uniform(
        step_rng, (nt, data.shape[0]), minval=eps, maxval=sde.T)
      Z = 1

    shape = data.shape  # (b, h, w, c)
    if not dsm:
      @functools.partial(jax.vmap, in_axes=(0, 0))
      def integrand_per_time(rng, t):
        """Hutchinson-trace estimate of the integrand at one time sample.

        Args:
          rng: PRNG key for this time sample.
          t: Times for this time sample, one per example, shape (b,).

        Returns:
          Integrand value per example.
        """
        rng, step_rng = jax.random.split(rng)
        epsilon = draw_epsilon(
          step_rng, (n_trace_estimates, *data[0].shape), hutchinson_type)

        rng, step_rng = jax.random.split(rng)
        noise = jax.random.normal(step_rng, shape)
        mean, std = sde.marginal_prob(data, t)
        noisy_data = mean + utils.batch_mul(std, noise)
        score_val = score_fn(noisy_data, t)
        score_div = div_score_fn(noisy_data, t, epsilon)
        score_norm = jnp.square(
          score_val.reshape((score_val.shape[0], -1))).sum(axis=-1)
        # The drift divergence is exact; no probe vector is needed.
        drift_div = div_drift_fn(noisy_data, t)
        _, g = sde.sde(noisy_data, t)
        integrand = (
          utils.batch_mul(g ** 2, 2 * score_div + score_norm) - 2 * drift_div)
        if importance_weighting:
          integrand = utils.batch_mul(std ** 2 / g ** 2 * Z, integrand)
        return integrand
    else:
      @functools.partial(jax.vmap, in_axes=(0, 0))
      def integrand_per_time(rng, t):
        """Denoising-score-matching estimate of the integrand at one time.

        Args:
          rng: PRNG key for this time sample.
          t: Times for this time sample, one per example, shape (b,).

        Returns:
          integrand: Integrand value per example, shape (b,). The vmap stacks
            these into shape (nt, b).
        """
        rng, step_rng = jax.random.split(rng)
        noise = jax.random.normal(step_rng, shape)
        mean, std = sde.marginal_prob(data, t)
        noisy_data = mean + utils.batch_mul(std, noise)

        drift_div = div_drift_fn(noisy_data, t)
        score_val = score_fn(noisy_data, t)
        # Score of the Gaussian perturbation kernel at noisy_data.
        grad = utils.batch_mul(-(noisy_data - mean), 1 / std ** 2)
        diff1 = score_val - grad
        diff1 = jnp.square(diff1.reshape((diff1.shape[0], -1))).sum(axis=-1)
        diff2 = jnp.square(grad.reshape((grad.shape[0], -1))).sum(axis=-1)
        _, g = sde.sde(noisy_data, t)
        integrand = utils.batch_mul(g ** 2, diff1 - diff2) - 2 * drift_div
        if importance_weighting:
          integrand = utils.batch_mul(std ** 2 / g ** 2 * Z, integrand)
        return integrand

    rng, step_rng = jax.random.split(rng)
    # One independent key per time sample. Broadcasting a single key would
    # reuse the same noise (and Hutchinson probe) at every time sample.
    time_rngs = jax.random.split(step_rng, nt)
    integrands = integrand_per_time(time_rngs, time_samples)
    integral = jnp.mean(integrands, axis=0)

    mean, std = sde.marginal_prob(data, jnp.ones((data.shape[0],)) * sde.T)
    rng, step_rng = jax.random.split(rng)
    noise = jax.random.normal(step_rng, shape)
    neg_prior_logp = -sde.prior_logp(mean + utils.batch_mul(std, noise))
    nlogp = neg_prior_logp + 0.5 * integral

    if eps_offset:
      # Offset to account for not integrating exactly to 0.
      rng, step_rng = jax.random.split(rng)
      nlogp = nlogp + offset_fn(step_rng, data)

    logp = -nlogp
    return logp

  return likelihood_bound_fn

# def get_likelihood_bound_fn(
#   score_fn,
#   sde,
#   eps=1e-3,
#   nt=1000,
#   importance_weighting=True,
#   dsm=True,
#   eps_offset=True,
#   n_trace_estimates=1,
#   hutchinson_type='rademacher'):
#   """Returns the function for estimating lower bound of logp(x)."""

#   offset_fn = get_likelihood_offset_fn(sde, score_fn, eps)
#   div_drift_fn = get_div_drift_fn(sde)
#   div_score_fn = get_hutchinson_div_fn(score_fn)
  
#   def likelihood_bound_fn(rng, data):
#     rng, step_rng = jax.random.split(rng)
#     if importance_weighting:
#       time_samples = sde.sample_importance_weighted_time_for_likelihood(
#         step_rng, (nt, data.shape[0]), eps=eps)
#       Z = sde.likelihood_importance_cum_weight(sde.T, eps=eps)
#     else:
#       time_samples = jax.random.uniform(step_rng, (nt, data.shape[0]), minval=eps, maxval=sde.T)
#       Z = 1
    
#     shape = data.shape  # (b, h, w, c)
#     if not dsm:
#       def scan_fn(carry, vec_time):
#         rng, value = carry
#         rng, step_rng = jax.random.split(rng)
#         epsilon = draw_epsilon(
#           step_rng, (n_trace_estimates, *data[0].shape), hutchinson_type)

#         rng, step_rng = jax.random.split(rng)
#         noise = jax.random.normal(step_rng, shape)
#         mean, std = sde.marginal_prob(data, vec_time)
#         noisy_data = mean + utils.batch_mul(std, noise)
#         score_val = score_fn(noisy_data, vec_time)
#         score_div = div_score_fn(noisy_data, vec_time, epsilon)
#         score_norm = jnp.square(score_val.reshape((score_val.shape[0], -1))).sum(axis=-1)
#         drift_div = div_drift_fn(noisy_data, vec_time, epsilon)
#         f, g = sde.sde(noisy_data, vec_time)
#         integrand = utils.batch_mul(g ** 2, 2 * score_div + score_norm) - 2 * drift_div
#         if importance_weighting:
#           integrand = utils.batch_mul(std ** 2 / g ** 2 * Z, integrand)
#         return (rng, value + integrand), integrand
#     else:
#       def scan_fn(carry, vec_time):
#         rng, value = carry

#         # Perturb data.
#         rng, step_rng = jax.random.split(rng)
#         noise = jax.random.normal(step_rng, shape)
#         mean, std = sde.marginal_prob(data, vec_time)
#         noisy_data = mean + utils.batch_mul(std, noise)

#         drift_div = div_drift_fn(noisy_data, vec_time)
#         score_val = score_fn(noisy_data, vec_time)
#         grad = utils.batch_mul(-(noisy_data - mean), 1 / std ** 2)
#         diff1 = score_val - grad
#         diff1 = jnp.square(diff1.reshape((diff1.shape[0], -1))).sum(axis=-1)
#         diff2 = jnp.square(grad.reshape((grad.shape[0], -1))).sum(axis=-1)
#         _, g = sde.sde(noisy_data, vec_time)
#         integrand = utils.batch_mul(g ** 2, diff1 - diff2) - 2 * drift_div
#         if importance_weighting:
#           integrand = utils.batch_mul(std ** 2 / g ** 2 * Z, integrand)
#         return (rng, value + integrand), integrand
    
#     (rng, integral), _ = jax.lax.scan(
#       scan_fn, (rng, jnp.zeros((shape[0],))), time_samples)
#     integral = integral / nt
#     mean, std = sde.marginal_prob(data, jnp.ones((data.shape[0],)) * sde.T)
#     rng, step_rng = jax.random.split(rng)
#     noise = jax.random.normal(step_rng, shape)
#     neg_prior_logp = -sde.prior_logp(mean + utils.batch_mul(std, noise))
#     nlogp = neg_prior_logp + 0.5 * integral

#     if eps_offset:
#       # Offset to account for not integrating exactly to 0.
#       rng, step_rng = jax.random.split(rng)
#       nlogp = nlogp + offset_fn(step_rng, data)

#     logp = -nlogp
#     return logp

#   return likelihood_bound_fn


def get_likelihood_bound_loss_fn(
  score_fn,
  sde,
  eps=1e-3,
  nt=128,
  importance_weighting=True,
  dsm=True,
  eps_offset=True,
  n_trace_estimates=1,
  hutchinson_type='rademacher'):
  """Returns the function for estimating upper bound of -logp(x).

  Thin wrapper that negates the lower bound produced by
  `get_likelihood_bound_fn`, so it can be minimized as a loss. All arguments
  are forwarded unchanged. Note the default `nt` here (128) is smaller than
  the default of `get_likelihood_bound_fn` (1000).

  Returns:
    `loss_fn(rng, data)`, returning the negated per-example lower bound.
  """
  # The bound function builds its own offset / divergence helpers, so none
  # are constructed here. Arguments are forwarded by keyword so a change in
  # `get_likelihood_bound_fn`'s parameter order cannot silently misbind them.
  logp_bound_fn = get_likelihood_bound_fn(
    score_fn,
    sde,
    eps=eps,
    nt=nt,
    importance_weighting=importance_weighting,
    dsm=dsm,
    eps_offset=eps_offset,
    n_trace_estimates=n_trace_estimates,
    hutchinson_type=hutchinson_type)

  def loss_fn(rng, data):
    return -logp_bound_fn(rng, data)

  return loss_fn


# def get_likelihood_bound_loss_fn(
#   score_fn,
#   sde,
#   eps=1e-3,
#   importance_weighting=True,
#   dsm=True,
#   eps_offset=True,
#   n_trace_estimates=1,
#   hutchinson_type='rademacher'):
#   """Returns the function for estimating upper bound of -logp(x)."""

#   offset_fn = get_likelihood_offset_fn(sde, score_fn, eps)
#   div_drift_fn = get_div_drift_fn(sde)
#   div_score_fn = get_hutchinson_div_fn(score_fn)

#   def loss_fn(rng, data):
#     shape = data.shape  # (b, h, w, c)
    
#     # x(T)
#     mean, std = sde.marginal_prob(data, jnp.ones((shape[0],)) * sde.T)
#     rng, step_rng = jax.random.split(rng)
#     z = jax.random.normal(step_rng, shape)
#     neg_prior_logp = -sde.prior_logp(mean + utils.batch_mul(std, z))

#     # x(t)
#     rng, step_rng = jax.random.split(rng)
#     if importance_weighting:
#       t = sde.sample_importance_weighted_time_for_likelihood(
#         step_rng, (shape[0],), eps=eps)
#       Z = sde.likelihood_importance_cum_weight(sde.T, eps=eps)
#     else:
#       t = jax.random.uniform(
#         step_rng, (shape[0],), minval=eps, maxval=sde.T)
    
#     rng, step_rng = jax.random.split(rng)
#     z = jax.random.normal(step_rng, shape)
#     mean, std = sde.marginal_prob(data, t)
#     perturbed_data = mean + utils.batch_mul(std, z)
#     score = score_fn(perturbed_data, t)

#     if not dsm:
#       # Estimate div(score(xt, t)).
#       rng, step_rng = jax.random.split(rng)
#       epsilon = draw_epsilon(
#         step_rng, (n_trace_estimates, *data[0].shape), hutchinson_type)
#       score_div = div_score_fn(perturbed_data, t, epsilon)
  
#       score_norm2 = jnp.square(score.reshape((score.shape[0], -1))).sum(axis=-1)
#       f, g = sde.sde(perturbed_data, t)
      
#       # div(f(xt, t)).
#       drift_div = div_drift_fn(perturbed_data, t)

#       losses = (
#         utils.batch_mul(g ** 2, 2 * score_div + score_norm2)
#         - 2 * drift_div)
#       if importance_weighting:
#         losses = utils.batch_mul(std ** 2 / g ** 2 * Z, losses)
#       losses = 0.5 * losses

#     else:
#       if importance_weighting:
#         losses = jnp.square(utils.batch_mul(score, std) + z)
#         losses = jnp.sum(losses.reshape((losses.shape[0], -1)), axis=-1)
#         grad_norm = jnp.square(z).reshape((z.shape[0], -1)).sum(axis=-1)
#         losses = (losses - grad_norm) * Z
#       else:
#         g2 = sde.sde(jnp.zeros_like(data), t)[1] ** 2
#         losses = jnp.square(score + utils.batch_mul(z, 1. / std))
#         losses = jnp.sum(losses.reshape((losses.shape[0], -1)), axis=-1) * g2
#         grad_norm = jnp.square(z).reshape((z.shape[0], -1)).sum(axis=-1)
#         grad_norm = grad_norm * g2 / (std ** 2)
#         losses = losses - grad_norm
    
#       rng, step_rng = jax.random.split(rng)
#       z = jax.random.normal(step_rng, shape)
#       rng, step_rng = jax.random.split(rng)
#       t = jax.random.uniform(step_rng, (shape[0],), minval=eps, maxval=sde.T)
#       mean, std = sde.marginal_prob(data, t)
#       noisy_data = mean + utils.batch_mul(std, z)
#       drift_div = div_drift_fn(noisy_data, t)

#       losses = neg_prior_logp + 0.5 * (losses - 2 * drift_div)
#     if eps_offset:
#       offset_fn = get_likelihood_offset_fn(sde, score_fn, eps)
#       rng, step_rng = jax.random.split(rng)
#       losses = losses + offset_fn(step_rng, data)
#     return losses

#   return loss_fn

def get_likelihood_bound_fn_with_scan(
  score_fn,
  sde,
  eps=1e-3,
  nt=1000,
  nz=1,
  importance_weighting=True,
  dsm=True,
  eps_offset=True,
  n_trace_estimates=1,
  hutchinson_type='rademacher'):
  """Returns the function for estimating lower bound of logp(x).

  Sequential variant of `get_likelihood_bound_fn`. Time samples are processed
  one at a time with `equinox.internal.scan` (checkpointed over time) instead
  of being vmapped. For each time sample, the integrand is averaged over `nz`
  independent noise draws.

  Args:
    score_fn: Score model, called as `score_fn(x, t)`.
    sde: SDE object; this function uses its `marginal_prob`, `sde`,
      `prior_logp` and `T` members, plus the importance-sampling helpers
      when `importance_weighting` is True.
    eps: Smallest integration time.
    nt: Number of Monte-Carlo time samples per example.
    nz: Number of noise samples averaged at each time sample.
    importance_weighting: If True, draw importance-weighted times and
      reweight the integrand by `std**2 / g**2 * Z`. Otherwise draw times
      uniformly in [eps, sde.T].
    dsm: If True, use the denoising-score-matching form of the integrand.
      Otherwise use a Hutchinson estimate of div(score).
    eps_offset: If True, add `offset_fn(rng, data)` to -logp to account for
      not integrating exactly to t = 0.
    n_trace_estimates: Number of Hutchinson probes (only used when not dsm).
    hutchinson_type: Distribution of the Hutchinson probes (only used when
      not dsm).

  Returns:
    `likelihood_bound_fn(rng, data)`, returning the per-example lower bound
    on logp(data).
  """

  offset_fn = get_likelihood_offset_fn(sde, score_fn, eps)
  div_drift_fn = get_div_drift_fn(sde)
  div_score_fn = get_hutchinson_div_fn(score_fn)

  def likelihood_bound_fn(rng, data):
    rng, step_rng = jax.random.split(rng)
    if importance_weighting:
      time_samples = sde.sample_importance_weighted_time_for_likelihood(
        step_rng, (nt, data.shape[0]), eps=eps)
      Z = sde.likelihood_importance_cum_weight(sde.T, eps=eps)
    else:
      time_samples = jax.random.uniform(
        step_rng, (nt, data.shape[0]), minval=eps, maxval=sde.T)
      Z = 1

    shape = data.shape  # (b, h, w, c)
    if not dsm:
      def time_scan_fn(carry, vec_time):
        rng, value = carry
        mean, std = sde.marginal_prob(data, vec_time)
        # g is evaluated at the clean data and reused for every noise draw;
        # presumably the diffusion coefficient does not depend on x — confirm.
        _, g = sde.sde(data, vec_time)

        def noise_scan_fn(rng_integrand, _):
          noise_rng, integrand_sum = rng_integrand
          noise_rng, step_rng = jax.random.split(noise_rng)
          epsilon = draw_epsilon(
            step_rng, (n_trace_estimates, *data[0].shape), hutchinson_type)

          noise_rng, step_rng = jax.random.split(noise_rng)
          noise = jax.random.normal(step_rng, shape)
          noisy_data = mean + utils.batch_mul(std, noise)
          score_val = score_fn(noisy_data, vec_time)
          score_div = div_score_fn(noisy_data, vec_time, epsilon)
          score_norm = jnp.square(
            score_val.reshape((score_val.shape[0], -1))).sum(axis=-1)
          # The drift divergence is exact; no probe vector is needed.
          drift_div = div_drift_fn(noisy_data, vec_time)

          item = (
            utils.batch_mul(g ** 2, 2 * score_div + score_norm) - 2 * drift_div)
          return (noise_rng, integrand_sum + item), item

        rng, noise_rng = jax.random.split(rng)
        (_, integrand), _ = equinox.internal.scan(
          noise_scan_fn,
          (noise_rng, jnp.zeros((shape[0],))),
          jnp.arange(nz),
          kind='lax')
        integrand = integrand / nz
        if importance_weighting:
          integrand = utils.batch_mul(std ** 2 / g ** 2 * Z, integrand)
        return (rng, value + integrand), integrand

    else:
      def time_scan_fn(carry, vec_time):
        rng, value = carry
        mean, std = sde.marginal_prob(data, vec_time)
        # g is evaluated at the clean data and reused for every noise draw;
        # presumably the diffusion coefficient does not depend on x — confirm.
        _, g = sde.sde(data, vec_time)

        def noise_scan_fn(rng_integrand, _):
          noise_rng, integrand_sum = rng_integrand
          # Perturb data.
          noise_rng, step_rng = jax.random.split(noise_rng)
          noise = jax.random.normal(step_rng, shape)
          noisy_data = mean + utils.batch_mul(std, noise)

          drift_div = div_drift_fn(noisy_data, vec_time)
          score_val = score_fn(noisy_data, vec_time)
          # Score of the Gaussian perturbation kernel at noisy_data.
          grad = utils.batch_mul(-(noisy_data - mean), 1 / std ** 2)
          diff1 = score_val - grad
          diff1 = jnp.square(diff1.reshape((diff1.shape[0], -1))).sum(axis=-1)
          diff2 = jnp.square(grad.reshape((grad.shape[0], -1))).sum(axis=-1)

          item = utils.batch_mul(g ** 2, diff1 - diff2) - 2 * drift_div
          return (noise_rng, integrand_sum + item), item

        rng, noise_rng = jax.random.split(rng)
        (_, integrand), _ = equinox.internal.scan(
          noise_scan_fn,
          (noise_rng, jnp.zeros((shape[0],))),
          jnp.arange(nz),
          kind='checkpointed')
        integrand = integrand / nz
        if importance_weighting:
          integrand = utils.batch_mul(std ** 2 / g ** 2 * Z, integrand)
        return (rng, value + integrand), integrand

    # Accumulate the per-time integrands sequentially; time_samples has shape
    # (nt, b), so each scan step sees one time per example.
    (rng, integral), _ = equinox.internal.scan(
      time_scan_fn, (rng, jnp.zeros((shape[0],))), time_samples,
      kind='checkpointed')
    integral = integral / nt

    mean, std = sde.marginal_prob(data, jnp.ones((data.shape[0],)) * sde.T)
    rng, step_rng = jax.random.split(rng)
    noise = jax.random.normal(step_rng, shape)
    neg_prior_logp = -sde.prior_logp(mean + utils.batch_mul(std, noise))
    nlogp = neg_prior_logp + 0.5 * integral

    if eps_offset:
      # Offset to account for not integrating exactly to 0.
      rng, step_rng = jax.random.split(rng)
      nlogp = nlogp + offset_fn(step_rng, data)

    logp = -nlogp
    return logp

  return likelihood_bound_fn