"""Diffusion Posterior Sampling (Chung & Kim et al. 2023)."""
import functools

import jax
import jax.numpy as jnp


def get_dps_sampler(sde, score_fn, y, likelihood, shape, inverse_scaler, eps):
  """Get Chung & Kim's DPS sampling function.

  The returned sampler runs ancestral (DDPM-style) reverse steps. After each
  step it subtracts `scale` times the gradient, w.r.t. the current x_t, of the
  measurement residual norm ||y - A(x0_hat(x_t))||. Here x0_hat is the
  score-based (Tweedie) estimate of x_0 given x_t, and A is
  `likelihood.apply_forward_operator`.

  Args:
    sde: An `sde_lib.SDE` object. Its `alphas`, `discrete_betas`, `T` and `N`
      attributes and its `prior_sampling(rng, shape)` method are used.
    score_fn: The score function s(x, t). It is called on a batch holding a
      single sample and a length-1 time vector.
    y: Measurement vector, of shape (m,).
    likelihood: A `likelihood_lib.IndependentGaussianLikelihood` object. Only
      its batched `apply_forward_operator` method is used.
    shape: Batch sampling shape, i.e., (b, h, w, c).
    inverse_scaler: Data inverse scaling function.
    eps: Epsilon value for t0.

  Returns:
    DPS sampling function `dps_sampler(rng, scale)`. It returns
    `inverse_scaler` applied to the final batch of samples. `scale` is the
    step size of the measurement-gradient correction.
  """
  alphas_cumprod = jnp.cumprod(sde.alphas, axis=0)
  alphas_cumprod_prev = jnp.append(1.0, alphas_cumprod[:-1])

  # Tweedie estimate: x0_hat = (x_t + (1 - abar_t) * score) / sqrt(abar_t).
  sqrt_recip_alphas_cumprod = jnp.sqrt(1.0 / alphas_cumprod)
  x0_score_coef = sqrt_recip_alphas_cumprod * (1.0 - alphas_cumprod)

  # Mean coefficients of the Gaussian posterior q(x_{t-1} | x_t, x_0).
  posterior_mean_coef1 = (
      sde.discrete_betas * jnp.sqrt(alphas_cumprod_prev) /
      (1.0 - alphas_cumprod)
  )
  posterior_mean_coef2 = (
      (1.0 - alphas_cumprod_prev) * jnp.sqrt(sde.alphas) /
      (1.0 - alphas_cumprod)
  )

  posterior_variance = (
      sde.discrete_betas * (1.0 - alphas_cumprod_prev) /
      (1.0 - alphas_cumprod)
  )
  # "Fixed large" variance: beta_t for t >= 1. posterior_variance[0] is zero
  # (alphas_cumprod_prev[0] == 1), so its log would be -inf; index 1 is
  # borrowed for t = 0 instead.
  model_variance = jnp.append(posterior_variance[1], sde.discrete_betas[1:])
  model_log_variance = jnp.log(model_variance)

  @functools.partial(jax.vmap, in_axes=(0, 0, None))
  @functools.partial(jax.value_and_grad, has_aux=True)
  def val_and_grad_fn(x, t, t_idx):
    """Per-sample residual norm (and its gradient w.r.t. x) plus x0_hat."""
    score = score_fn(x[None, :], jnp.ones(1) * t)[0]

    # Predict x_0 | x_t.
    x0_hat = sqrt_recip_alphas_cumprod[t_idx] * x + x0_score_coef[t_idx] * score

    # Measurement residual in observation space.
    diff = y - likelihood.apply_forward_operator(x0_hat[None, ...])[0]
    # As in the official DPS experiments, the loss is the unsquared norm
    # ||diff|| (= ||diff||^2 / ||diff||), not the squared norm.
    # NOTE: the gradient of the norm is not finite at diff == 0 exactly.
    residual_norm = jnp.linalg.norm(diff)
    return residual_norm, (x0_hat,)

  def step_fn(rng, xt, t_batch, t_idx, scale):
    """One ancestral reverse step followed by the DPS gradient correction."""
    (_, (x0_hat,)), gradient = val_and_grad_fn(xt, t_batch, t_idx)

    # Sample x_{t-1} ~ q(x_{t-1} | x_t, x0_hat).
    mean = posterior_mean_coef1[t_idx] * x0_hat + posterior_mean_coef2[t_idx] * xt
    # No noise on the final step (t_idx == 0), as in DDPM ancestral sampling.
    # Otherwise the returned sample would carry an extra layer of noise.
    std = jnp.where(t_idx > 0, jnp.exp(0.5 * model_log_variance[t_idx]), 0.0)
    noise = jax.random.normal(rng, xt.shape)
    xt_prev = mean + std * noise

    # Measurement-consistency correction; `gradient` is taken w.r.t. x_t.
    return xt_prev - scale * gradient

  def dps_sampler(rng, scale):
    """Draw a batch of DPS samples.

    Args:
      rng: JAX PRNG key.
      scale: Step size of the measurement-gradient correction.

    Returns:
      `inverse_scaler` applied to the final samples.
    """
    timesteps = jnp.linspace(sde.T, eps, sde.N)

    # Initial sample.
    rng, step_rng = jax.random.split(rng)
    x = sde.prior_sampling(step_rng, shape)

    def loop_body(carry, i):
      rng, x = carry
      t = timesteps[i]
      # Map continuous time t in [eps, T] to a discrete index in [0, N - 1].
      idx = (t * (sde.N - 1) / sde.T).astype(jnp.int32)
      vec_t = jnp.ones(shape[0]) * t

      rng, step_rng = jax.random.split(rng)
      x = step_fn(step_rng, x, vec_t, idx, scale)
      # Only the final sample is needed. Emitting no per-step output avoids
      # stacking all N intermediate batches in memory.
      return (rng, x), None

    (_, x), _ = jax.lax.scan(loop_body, (rng, x), jnp.arange(sde.N))

    return inverse_scaler(x)

  return dps_sampler