# TODO: Fix the loss functions of all diffusion models to use different subkeys. Check Song's JAX implementation.
# TODO: Make the B matrices for GRF non-trainable using lax.stop_gradient(). See GRF function in Song's code.

import jax
import jax.numpy as jnp
import flax.linen as lnn
from jax import jit, random
from typing import Sequence
import optax
from tqdm.notebook import tqdm
import torch
from torch.utils.data import Dataset, DataLoader
from functools import partial

# rand_key = random.PRNGKey(0)

class Diffusion(lnn.Module):
  """Score network for Song's SDE (diffusion coefficient sigma**t, no drift).

  Data and time are lifted with Gaussian random Fourier features (B_x, B_t)
  before an MLP whose hidden layers receive an additive time embedding. The
  final Dense output is divided by marginal_prob_std(t), following Song's code.
  """
  features: Sequence[int]
  mapping_size: int
  num_dimensions: int # Dimensionality of the data vectors. Time not included.
  sigma : float # This is actually barsigma.
  grf_scale_x : float = 10.0
  grf_scale_t : float = 10.0

  @lnn.compact
  def __call__(self, x, t):
    """Return the score estimate for data x at diffusion times t.

    Args:
      x: data of shape [..., num_dimensions].
      t: times with the leading shape of x (one fewer axis).

    Returns:
      Array of shape [..., features[-1]], divided by marginal_prob_std(t).
    """
    # NOTE: under @lnn.compact Flax names submodules by creation order, so the
    # order of the Dense/LayerNorm calls below defines the parameter tree.
    # The random key is automatically passed as the first argument to normal().
    # This code produces the same value of B, even if it is executed multiple times.
    B_x = self.grf_scale_x * self.param('B_x', lnn.initializers.normal(), (self.mapping_size, self.num_dimensions))
    B_t = self.grf_scale_t * self.param('B_t', lnn.initializers.normal(), (self.mapping_size, 1))

    embed = self.input_mapping(t[..., None], B_t) # Convert from [batch_size,] to [batch_size, 1]
    embed = lnn.Dense(embed.shape[-1])(embed) # embed.shape[-1] = 2 * (mapping_size)
    embed = lnn.sigmoid(embed)
    pos = self.input_mapping(x, B_x)
    pos = lnn.Dense(pos.shape[-1])(pos) # This definitely helps improve learned scores.
    pos = lnn.sigmoid(pos)
    h = pos

    for feat in self.features[:-1]:
        tau = lnn.Dense(feat)(embed)
        h = lnn.Dense(feat)(h)
        h += tau
        h = lnn.LayerNorm()(h)
        h = lnn.sigmoid(h)

    # No time embedded in the last step, following Song's code.
    h = lnn.Dense(self.features[-1])(h)
    h = h / jnp.expand_dims(self.marginal_prob_std(t), -1)
    return h

  # Fourier feature mapping
  def input_mapping(self, x, B):
    """Map x to [sin(2*pi*x @ B.T), cos(2*pi*x @ B.T)]; return x unchanged if B is None."""
    if B is None:
      return x
    else:
      x_proj = (2.*jnp.pi*x) @ B.T
      return jnp.concatenate([jnp.sin(x_proj), jnp.cos(x_proj)], axis=-1)

  def marginal_prob_std(self, t):
    """Std of the perturbation kernel at time t: sqrt((sigma**(2t) - 1) / (2 ln sigma))."""
    return jnp.sqrt((self.sigma**(2 * t) - 1.) / 2. / jnp.log(self.sigma))

  def loss(self, params, x, key, eps=1e-5, num_steps=1):
    """Denoising score-matching loss.

    Args:
      params: Flax parameters for this module.
      x: data batch of shape [batch_size, num_dimensions].
      key: PRNG key. Pass a fresh key at every iteration; it is split here so
        the sampled times and the noise are independent.
      eps: smallest time sampled; marginal_prob_std(0) == 0 would make the
        network output divide by zero.
      num_steps: number of (t, z) draws per data point.

    Returns:
      Scalar loss averaged over the batch and num_steps axes.
    """
    x = jnp.tile(x[:, jnp.newaxis, :], (1, num_steps, 1)) # shape is [batch_size, num_steps, num_dimensions]
    # Independent subkeys: drawing t and z from the same key correlates them.
    t_key, z_key = random.split(key)
    random_t = random.uniform(t_key, x.shape[:-1]) * (1. - eps) + eps
    # random_t = 1 - jnp.sqrt(1 - random_t) # Skew the uniform distribution to a triangle.
    z = random.normal(z_key, x.shape)
    std = self.marginal_prob_std(random_t) # shape is [batch_size, num_steps]
    perturbed_x = x + z * std[..., None]
    score = self.apply(params, perturbed_x, random_t)
    # Squared residual summed over data dimensions -> [batch_size, num_steps],
    # then averaged over both remaining axes.
    return jnp.mean(jnp.sum((score * std[..., None] + z) ** 2, axis=-1))


class DiffusionOU(lnn.Module):
  """Score network for the VP SDE with a linear beta schedule.

  beta(t) = beta_min + (beta_max - beta_min) * t and diffusion sqrt(beta(t)).
  Inputs are lifted with Gaussian random Fourier features (B_x, B_t); the
  final Dense output is divided by marginal_prob_std(t).
  """
  features: Sequence[int]
  mapping_size: int
  num_dimensions: int # Dimensionality of the data vectors. Time not included.
  # sigma: float
  beta_min: float
  beta_max: float
  grf_scale_x : float = 10.0
  grf_scale_t : float = 10.0

  @lnn.compact
  def __call__(self, x, t):
    """Return the score estimate for data x at diffusion times t.

    Args:
      x: data of shape [..., num_dimensions].
      t: times with the leading shape of x (one fewer axis).

    Returns:
      Array of shape [..., features[-1]], divided by marginal_prob_std(t).
    """
    # NOTE: under @lnn.compact Flax names submodules by creation order, so the
    # order of the Dense/LayerNorm calls below defines the parameter tree.
    # The random key is automatically passed as the first argument to normal().
    # This code produces the same value of B, even if it is executed multiple times.
    B_x = self.grf_scale_x * self.param('B_x', lnn.initializers.normal(), (self.mapping_size, self.num_dimensions))
    B_t = self.grf_scale_t * self.param('B_t', lnn.initializers.normal(), (self.mapping_size, 1))

    embed = self.input_mapping(t[..., None], B_t) # Convert from [batch_size,] to [batch_size, 1]
    embed = lnn.Dense(embed.shape[-1])(embed) # embed.shape[-1] = 2 * (mapping_size)
    embed = lnn.sigmoid(embed)
    pos = self.input_mapping(x, B_x)
    pos = lnn.Dense(pos.shape[-1])(pos) # This definitely helps improve learned scores.
    pos = lnn.sigmoid(pos)
    h = pos

    for feat in self.features[:-1]:
        tau = lnn.Dense(feat)(embed)
        h = lnn.Dense(feat)(h)
        h += tau
        h = lnn.LayerNorm()(h)
        h = lnn.sigmoid(h)

    # No time embedded in the last step, following Song's code.
    h = lnn.Dense(self.features[-1])(h)
    h = h / jnp.expand_dims(self.marginal_prob_std(t), -1)
    return h

  # Fourier feature mapping
  def input_mapping(self, x, B):
    """Map x to [sin(2*pi*x @ B.T), cos(2*pi*x @ B.T)]; return x unchanged if B is None."""
    if B is None:
      return x
    else:
      x_proj = (2.*jnp.pi*x) @ B.T
      return jnp.concatenate([jnp.sin(x_proj), jnp.cos(x_proj)], axis=-1)

  # # With constant coefficients
  # def sigma_at(self, t):
  #   return self.sigma # * (t ** 0) # To return the same shape

  # def beta_at(self, t):
  #   return self.beta # * (t ** 0)

  # def expintbeta(self, t):
  #   return jnp.exp(-self.beta * t) # NOTE: For time-dependent beta we need to change this.

  # VP SDE
  def beta_at(self, t):
    """Linear schedule beta(t) = beta_min + (beta_max - beta_min) * t."""
    return self.beta_min + (self.beta_max - self.beta_min) * t

  def sigma_at(self, t):
    """Diffusion coefficient sqrt(beta(t))."""
    return jnp.sqrt(self.beta_at(t))

  def expintbeta(self, t):
    """exp(-int_0^t beta(s) ds) for the linear schedule."""
    return jnp.exp(-0.5 * (self.beta_max - self.beta_min) * (t ** 2) - self.beta_min * t)

  def marginal_prob_std(self, t):
    """Std of the perturbation kernel at time t (equals sqrt(1 - expintbeta(t)))."""
    return self.sigma_at(t) * jnp.sqrt((1-self.expintbeta(t)) / self.beta_at(t))

  def loss(self, params, x, key, eps=1e-5, num_steps=1):
    """Denoising score-matching loss for the VP SDE.

    Args:
      params: Flax parameters for this module.
      x: data batch of shape [batch_size, num_dimensions].
      key: PRNG key. Pass a fresh key at every iteration; it is split here so
        the sampled times and the noise are independent.
      eps: smallest time sampled (marginal_prob_std(0) == 0).
      num_steps: number of (t, z) draws per data point.

    Returns:
      Scalar loss averaged over the batch and num_steps axes.
    """
    x = jnp.tile(x[:, jnp.newaxis, :], (1, num_steps, 1)) # shape is [batch_size, num_steps, num_dimensions]
    # Independent subkeys: drawing t and z from the same key correlates them.
    t_key, z_key = random.split(key)
    random_t = random.uniform(t_key, x.shape[:-1]) * (1. - eps) + eps
    # random_t = 1 - jnp.sqrt(1 - random_t) # Skew the uniform distribution to a triangle.
    z = random.normal(z_key, x.shape)
    std = self.marginal_prob_std(random_t) # shape is [batch_size, num_steps]
    # The OU mean decays by sqrt(expintbeta); this is where VP differs from Diffusion.
    perturbed_x = x * jnp.sqrt(self.expintbeta(random_t))[..., None] + z * std[..., None]
    score = self.apply(params, perturbed_x, random_t)
    # Squared residual summed over data dimensions -> [batch_size, num_steps],
    # then averaged over both remaining axes.
    return jnp.mean(jnp.sum((score * std[..., None] + z) ** 2, axis=-1))


# A class that implements Song's SDE with no normalization of the final NN output.
# Introduced in FastLikelihood.ipynb.
class DiffusionUnitNorm(Diffusion):
  """Diffusion whose forward pass returns the raw final-layer output.

  The forward pass is the same as Diffusion.__call__, except that the output
  is not divided by marginal_prob_std(t). beta_at / sigma_at / expintbeta give
  the model the same schedule interface as the OU classes (zero drift,
  diffusion sigma**t).
  """
  @lnn.compact
  def __call__(self, x, t):
    """Return the unnormalized network output for data x at times t.

    Args:
      x: data of shape [..., num_dimensions].
      t: times with the leading shape of x (one fewer axis).

    Returns:
      Array of shape [..., features[-1]].
    """
    # NOTE: under @lnn.compact Flax names submodules by creation order, so the
    # order of the Dense/LayerNorm calls below defines the parameter tree.
    # The random key is automatically passed as the first argument to normal().
    # This code produces the same value of B, even if it is executed multiple times.
    B_x = self.grf_scale_x * self.param('B_x', lnn.initializers.normal(), (self.mapping_size, self.num_dimensions))
    B_t = self.grf_scale_t * self.param('B_t', lnn.initializers.normal(), (self.mapping_size, 1))

    embed = self.input_mapping(t[..., None], B_t) # Convert from [batch_size,] to [batch_size, 1]
    embed = lnn.Dense(embed.shape[-1])(embed) # embed.shape[-1] = 2 * (mapping_size)
    embed = lnn.sigmoid(embed)
    pos = self.input_mapping(x, B_x)
    pos = lnn.Dense(pos.shape[-1])(pos) # This definitely helps improve learned scores.
    pos = lnn.sigmoid(pos)
    h = pos

    for feat in self.features[:-1]:
        tau = lnn.Dense(feat)(embed)
        h = lnn.Dense(feat)(h)
        h += tau
        h = lnn.LayerNorm()(h)
        h = lnn.sigmoid(h)

    # No time embedded in the last step, following Song's code.
    h = lnn.Dense(self.features[-1])(h)
    return h

  # These functions make Song's model look like an OU class, so that we can use DensityHelper on it.
  def beta_at(self, t):
    """Drift rate: identically zero, with the shape of t."""
    # jnp.zeros_like rather than 0 ** t: the latter evaluates to 1 at t = 0.
    return jnp.zeros_like(t)

  def sigma_at(self, t):
    """Diffusion coefficient sigma**t."""
    return self.sigma ** t

  def expintbeta(self, t):
    """exp(-int_0^t beta) == 1 for all t, with the shape of t."""
    return jnp.ones_like(t)


# A diffusion model with the VP scheme whose network output is not divided by
# marginal_prob_std. This improves the likelihood performance.
# We could have overridden the __call__ function more succinctly, like in Backgrounds.ipynb,
# but I did not want to normalize and then remove the normalization.
class DiffusionOUUnitNorm(DiffusionOU):
  """DiffusionOU whose forward pass returns the raw final-layer output.

  The forward pass is the same as DiffusionOU.__call__, except that the output
  is not divided by marginal_prob_std(t). The schedule and loss methods are
  inherited from DiffusionOU.
  """
  @lnn.compact
  def __call__(self, x, t):
    """Return the unnormalized network output for data x at times t.

    Args:
      x: data of shape [..., num_dimensions].
      t: times with the leading shape of x (one fewer axis).

    Returns:
      Array of shape [..., features[-1]].
    """
    # NOTE: under @lnn.compact Flax names submodules by creation order, so the
    # order of the Dense/LayerNorm calls below defines the parameter tree.
    # The random key is automatically passed as the first argument to normal().
    # This code produces the same value of B, even if it is executed multiple times.
    B_x = self.grf_scale_x * self.param('B_x', lnn.initializers.normal(), (self.mapping_size, self.num_dimensions))
    B_t = self.grf_scale_t * self.param('B_t', lnn.initializers.normal(), (self.mapping_size, 1))

    # Time embedding: Fourier features -> Dense -> sigmoid.
    embed = self.input_mapping(t[..., None], B_t) # Convert from [batch_size,] to [batch_size, 1]
    embed = lnn.Dense(embed.shape[-1])(embed) # embed.shape[-1] = 2 * (mapping_size)
    embed = lnn.sigmoid(embed)
    # Data embedding: Fourier features -> Dense -> sigmoid.
    pos = self.input_mapping(x, B_x)
    pos = lnn.Dense(pos.shape[-1])(pos) # This definitely helps improve learned scores.
    pos = lnn.sigmoid(pos)
    h = pos

    # Hidden layers: the projected time embedding is added before LayerNorm.
    for feat in self.features[:-1]:
        tau = lnn.Dense(feat)(embed)
        h = lnn.Dense(feat)(h)
        h += tau
        h = lnn.LayerNorm()(h)
        h = lnn.sigmoid(h)

    # No time embedded in the last step, following Song's code.
    h = lnn.Dense(self.features[-1])(h)

    # Unlike DiffusionOU, the output is returned without dividing by marginal_prob_std.
    return h


# A class that implements the entropy matching model.
# (1) The loss function has been updated for the entropy matching scheme. Ho prefactor used.
# (2) The output of the network is no longer normalized. See Backgrounds.ipynb for a discussion.
class DiffusionEM(DiffusionOUUnitNorm):
  """Entropy-matching model on the VP SDE with an unnormalized network output.

  The network output ntheta is added to -x, which is the score of the VP
  process' standard-normal equilibrium, to form the score estimate used in
  the loss.
  """
  # We override the loss function to incorporate the changes.
  def loss(self, params, x, key, eps=1e-5, num_steps=1):
    """Entropy-matching loss (Ho prefactor).

    Args:
      params: Flax parameters for this module.
      x: data batch of shape [batch_size, num_dimensions].
      key: PRNG key. Pass a fresh key at every iteration; it is split here so
        the sampled times and the noise are independent.
      eps: smallest time sampled.
      num_steps: number of (t, z) draws per data point.

    Returns:
      Scalar loss averaged over the batch and num_steps axes.
    """
    x = jnp.tile(x[:, jnp.newaxis, :], (1, num_steps, 1)) # shape is [batch_size, num_steps, num_dimensions]
    # Independent subkeys: drawing t and z from the same key correlates them.
    t_key, z_key = random.split(key)
    random_t = random.uniform(t_key, x.shape[:-1]) * (1. - eps) + eps
    z = random.normal(z_key, x.shape)
    std = self.marginal_prob_std(random_t) # shape is [batch_size, num_steps]
    perturbed_x = x * jnp.sqrt(self.expintbeta(random_t))[..., None] + z * std[..., None] # Different for OU
    ntheta = self.apply(params, perturbed_x, random_t)
    # Score estimate = -perturbed_x (equilibrium score) + ntheta.
    return jnp.mean(jnp.sum(((-perturbed_x + ntheta) * std[..., None] + z) ** 2, axis=-1)) # \lambda = 2 \Sigma^2 / \sigma^2


# A class that wraps a NumPy array as a PyTorch Dataset.
# Works for any dimensions. The name is a legacy artifact.
class Data2D(Dataset):
  """Map-style PyTorch Dataset that indexes straight into an in-memory array.

  Any object supporting len() and integer indexing works; despite the name,
  the stored items may have any number of dimensions.
  """

  def __init__(self, data):
    self.data = data

  def __getitem__(self, idx):
    # One sample is simply the idx-th entry along the first axis.
    sample = self.data[idx]
    return sample

  def __len__(self):
    # Number of samples = length of the first axis.
    n_samples = len(self.data)
    return n_samples


# Changed this function to accept the key as its first parameter, when we put it in a module.
# Note that num_steps has to be partialed into the loss function otherwise JAX complains.
def train_diffusion(key, model, params, learning_rate, epochs, train_data, batch_size, num_steps=10):
  """Train a diffusion model's parameters with Adam.

  Args:
    key: PRNG key. It is split once per batch and only the fresh subkey is
      handed to the loss, so no key is ever consumed twice.
    model: object exposing loss(params, x, key, num_steps=...).
    params: initial parameters.
    learning_rate: Adam learning rate.
    epochs: number of passes over train_data.
    train_data: array-like indexed along its first axis; wrapped in Data2D.
    batch_size: mini-batch size of the shuffled DataLoader.
    num_steps: number of (t, z) draws per sample. It is bound with partial
      because model.loss uses it as a tile count, which must be static under jit.

  Returns:
    (params, losses): the trained parameters and, per epoch, the
    sample-weighted mean loss.

  Raises:
    ValueError: if train_data is empty (the epoch average would divide by zero).
  """
  dataset = Data2D(train_data)
  if len(dataset) == 0:
    raise ValueError("train_data is empty; nothing to train on.")

  tx = optax.adam(learning_rate=learning_rate)
  opt_state = tx.init(params)
  updater = jit(tx.update) # Very important: jit here boosts speed by 4x!
  loss_grad_fn = jit(jax.value_and_grad(partial(model.loss, num_steps=num_steps))) # MUCH faster when jit-ed!
  applier = jit(optax.apply_updates) # Also speeds up by 80%!
  data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
  losses = []

  ekey = key # ekey is only ever split; it is never passed to the loss directly.
  for _ in (pbar := tqdm(range(epochs), desc='train iter', leave=True)):
    total_loss = 0
    num_items = 0

    for batch in data_loader:
      # Fresh subkey for this batch; ekey survives only as the parent of later splits.
      ekey, subkey = random.split(ekey)
      x = jnp.asarray(batch)
      loss_val, grads = loss_grad_fn(params, x, subkey)
      updates, opt_state = updater(grads, opt_state)
      params = applier(params, updates)
      # Weight by batch size so a short final batch does not skew the epoch mean.
      total_loss += loss_val * x.shape[0]
      num_items += x.shape[0]

    epoch_loss = total_loss / num_items
    losses.append(epoch_loss)
    pbar.set_description("Loss: {:5f}".format(epoch_loss))

  return params, losses

# ------------------------------------------------------------------------

# This is the VPx process. It becomes regular VP if kappa=1.
class DiffusionEMRedux(lnn.Module):
  """Entropy-matching model for the VPx SDE.

  Forward SDE: dx = -beta(s) x / 2 ds + kappa * sqrt(beta(s)) dW with
  beta(s) = beta_min + (beta_max - beta_min) * s; kappa = 1 gives the standard
  VP SDE. The equilibrium score is -x / kappa**2 (grad_logp_eq), and in the
  loss the network output is added to it to form the score estimate.
  """
  features: Sequence[int]
  mapping_size: int
  num_dimensions: int # Dimensionality of the data vectors. Time not included.
  # sigma: float
  beta_min : float
  beta_max : float
  kappa : float = 1.0 # Diffusion scale; marginal_prob_std(s) -> kappa as mu(s) -> 0.
  x_embed : bool = True # If False, x bypasses the Fourier feature mapping.
  maxL_prefactor : bool = False # If True, weight the loss by 0.5 * (sigma_at / std)**2.
  grf_scale_x : float = 10.0 # Scale applied to the Gaussian Fourier matrix B_x.
  grf_scale_s : float = 10.0 # Scale applied to the Gaussian Fourier matrix B_s.

  @lnn.compact
  def __call__(self, x, s):
    """Return the raw network output for data x at diffusion times s.

    Args:
      x: data of shape [..., num_dimensions].
      s: times with the leading shape of x (one fewer axis).

    Returns:
      Array of shape [..., features[-1]], not normalized by marginal_prob_std.
    """
    # NOTE: under @lnn.compact Flax names submodules by creation order, so the
    # order of the Dense/LayerNorm calls below defines the parameter tree.
    B_x = self.grf_scale_x * self.param('B_x', lnn.initializers.normal(), (self.mapping_size, self.num_dimensions))
    B_s = self.grf_scale_s * self.param('B_s', lnn.initializers.normal(), (self.mapping_size, 1))
    # Stop gradients from flowing through B_x and B_s, so the loss gradient
    # w.r.t. the Fourier matrices is zero. [NEW in Redux]
    B_x = jax.lax.stop_gradient(B_x)
    B_s = jax.lax.stop_gradient(B_s)
    B_x = B_x if self.x_embed else None # None makes input_mapping return x unchanged.

    embed = self.input_mapping(s[..., None], B_s) # Convert from [batch_size,] to [batch_size, 1]
    embed = lnn.Dense(embed.shape[-1])(embed) # embed.shape[-1] = 2 * (mapping_size)
    embed = lnn.sigmoid(embed)
    pos = self.input_mapping(x, B_x)
    pos = lnn.Dense(pos.shape[-1])(pos) # This definitely helps improve learned scores.
    pos = lnn.sigmoid(pos)
    h = pos

    # Hidden layers: the projected time embedding is added before LayerNorm.
    for feat in self.features[:-1]:
        tau = lnn.Dense(feat)(embed)
        h = lnn.Dense(feat)(h)
        h += tau
        h = lnn.LayerNorm()(h)
        h = lnn.sigmoid(h)

    # No time embedded in the last step, following Song's code.
    h = lnn.Dense(self.features[-1])(h)

    # Normalization by marginal_prob_std is intentionally disabled:
    # h = h / jnp.expand_dims(self.marginal_prob_std(t), -1) #
    return h

  # Fourier feature mapping
  def input_mapping(self, x, B):
    """Map x to [sin(2*pi*x @ B.T), cos(2*pi*x @ B.T)]; return x unchanged if B is None."""
    if B is None:
      return x
    else:
      x_proj = (2.*jnp.pi*x) @ B.T
      return jnp.concatenate([jnp.sin(x_proj), jnp.cos(x_proj)], axis=-1)

  # VP SDE with kappa
  def beta_at(self, s):
    """Linear schedule beta(s) = beta_min + (beta_max - beta_min) * s."""
    return self.beta_min + (self.beta_max - self.beta_min) * s

  def bplus(self, x, s):
    """Forward drift -beta(s) * x / 2; s has one fewer axis than x."""
    return - self.beta_at(s)[..., None] * x / 2

  def sigma_at(self, s):
    """Diffusion coefficient kappa * sqrt(beta(s))."""
    return self.kappa * jnp.sqrt(self.beta_at(s))

  # This is the square root of expintbeta.
  def mu(self, s):
    """Mean decay factor exp(-1/2 * int_0^s beta(u) du)."""
    return jnp.exp(- 0.5 * self.beta_min * s - 0.25 * (self.beta_max - self.beta_min) * (s ** 2))

  def marginal_prob_std(self, s):
    """Std of the perturbation kernel: kappa * sqrt(1 - mu(s)**2)."""
    return self.kappa * jnp.sqrt((1 - self.mu(s)**2))

  def grad_logp_eq(self, x, s):
    """Score of the equilibrium distribution N(0, kappa**2 I); independent of s."""
    return - x / self.kappa**2

  # The entropy matching loss.
  def loss(self, params, x, key, eps=1e-5, num_steps=1):
    """Entropy-matching loss for a batch x of shape [batch_size, num_dimensions].

    key is split internally into separate subkeys for s and z; pass a fresh key
    every iteration. Times s are drawn uniformly from [eps, 1). num_steps
    (s, z) draws are taken per data point; the result is a scalar mean.
    """
    x = jnp.tile(x[:, jnp.newaxis, :], (1, num_steps, 1)) # shape is [batch_size, num_steps, 2]
    key, subkey = random.split(key) # [NEW in Redux]
    random_s = random.uniform(subkey, x.shape[:-1]) * (1. - eps) + eps
    key, subkey = random.split(key)
    z = random.normal(subkey, x.shape)
    std = self.marginal_prob_std(random_s) # shape is [batch_size, num_steps]
    perturbed_x = x * self.mu(random_s)[..., None] + z * std[..., None] # Different for OU
    etheta = self.apply(params, perturbed_x, random_s)
    # int(bool) exponent: the weight is the max-likelihood prefactor when enabled, else 1.
    prefactor = (0.5 * (self.sigma_at(random_s)/std) ** 2) ** int(self.maxL_prefactor) # Maximum likelihood.
    # Score estimate = grad_logp_eq(perturbed_x) + etheta.
    return jnp.mean(prefactor * jnp.sum(((-perturbed_x/self.kappa**2 + etheta) * std[..., None] + z) ** 2, axis=-1))


class DiffusionSL(lnn.Module):
  """Entropy-matching model whose forward mean decays linearly, mu(s) = 1 - s.

  Forward SDE: dx = -x / (1 - s) ds + Sigma_0 * sqrt(2 / (1 - s)) dW, with
  marginal std Sigma_0 * sqrt(1 - (1 - s)**2). At s = 1 the marginal is the
  equilibrium N(0, Sigma_0**2 I) (score grad_logp_eq = -x / Sigma_0**2); the
  coefficients are singular there, so the loss samples s below 1.
  """
  features: Sequence[int]
  mapping_size: int
  num_dimensions: int # Dimensionality of the data vectors. Time not included.
  Sigma_0 : float # Equilibrium standard deviation.
  x_embed : bool = True # If False, x bypasses the Fourier feature mapping.
  maxL_prefactor : bool = False # If True, weight the loss by 0.5 * (sigma_at / std)**2.
  grf_scale_x : float = 10.0 # Scale applied to the Gaussian Fourier matrix B_x.
  grf_scale_s : float = 10.0 # Scale applied to the Gaussian Fourier matrix B_s.

  @lnn.compact
  def __call__(self, x, s):
    """Return the raw network output for data x at diffusion times s.

    Args:
      x: data of shape [..., num_dimensions].
      s: times with the leading shape of x (one fewer axis).

    Returns:
      Array of shape [..., features[-1]], not normalized by marginal_prob_std.
    """
    # NOTE: under @lnn.compact Flax names submodules by creation order, so the
    # order of the Dense/LayerNorm calls below defines the parameter tree.
    B_x = self.grf_scale_x * self.param('B_x', lnn.initializers.normal(), (self.mapping_size, self.num_dimensions))
    B_s = self.grf_scale_s * self.param('B_s', lnn.initializers.normal(), (self.mapping_size, 1))
    # Stop gradients from flowing through B_x and B_s, so the loss gradient
    # w.r.t. the Fourier matrices is zero. [NEW in Redux]
    B_x = jax.lax.stop_gradient(B_x)
    B_s = jax.lax.stop_gradient(B_s)
    B_x = B_x if self.x_embed else None # None makes input_mapping return x unchanged.

    embed = self.input_mapping(s[..., None], B_s) # Convert from [batch_size,] to [batch_size, 1]
    embed = lnn.Dense(embed.shape[-1])(embed) # embed.shape[-1] = 2 * (mapping_size)
    embed = lnn.sigmoid(embed)
    pos = self.input_mapping(x, B_x)
    pos = lnn.Dense(pos.shape[-1])(pos) # This definitely helps improve learned scores.
    pos = lnn.sigmoid(pos)
    h = pos

    # Hidden layers: the projected time embedding is added before LayerNorm.
    for feat in self.features[:-1]:
        tau = lnn.Dense(feat)(embed)
        h = lnn.Dense(feat)(h)
        h += tau
        h = lnn.LayerNorm()(h)
        h = lnn.sigmoid(h)

    # No time embedded in the last step, following Song's code.
    h = lnn.Dense(self.features[-1])(h)

    # Normalization by marginal_prob_std is intentionally disabled:
    # h = h / jnp.expand_dims(self.marginal_prob_std(t), -1) #
    return h

  # Fourier feature mapping
  def input_mapping(self, x, B):
    """Map x to [sin(2*pi*x @ B.T), cos(2*pi*x @ B.T)]; return x unchanged if B is None."""
    if B is None:
      return x
    else:
      x_proj = (2.*jnp.pi*x) @ B.T
      return jnp.concatenate([jnp.sin(x_proj), jnp.cos(x_proj)], axis=-1)

  def bplus(self, x, s):
    """Forward drift -x / (1 - s); s has one fewer axis than x. Singular at s = 1."""
    return - x / (1 - s[..., None])

  def sigma_at(self, s):
    """Diffusion coefficient Sigma_0 * sqrt(2 / (1 - s)). Singular at s = 1."""
    return self.Sigma_0 * jnp.sqrt(2/(1 - s))

  def mu(self, s):
    """Mean decay factor 1 - s."""
    return 1-s

  def marginal_prob_std(self, s):
    """Std of the perturbation kernel: Sigma_0 * sqrt(1 - (1 - s)**2)."""
    return self.Sigma_0 * jnp.sqrt(1 - (1-s)**2)

  def grad_logp_eq(self, x, s):
    """Score of the equilibrium distribution N(0, Sigma_0**2 I); independent of s."""
    return - x / self.Sigma_0**2

  # The entropy matching loss.
  def loss(self, params, x, key, eps=1e-5, num_steps=1):
    """Entropy-matching loss for a batch x of shape [batch_size, num_dimensions].

    key is split internally into separate subkeys for s and z; pass a fresh key
    every iteration. Times s are drawn uniformly from [eps, 1 - eps), which avoids
    both std == 0 at s = 0 and the singular coefficients at s = 1.
    """
    x = jnp.tile(x[:, jnp.newaxis, :], (1, num_steps, 1)) # shape is [batch_size, num_steps, 2]
    key, subkey = random.split(key) # [NEW in Redux]
    random_s = random.uniform(subkey, x.shape[:-1]) * (1. - 2.*eps) + eps # Stop shy of T=1.
    key, subkey = random.split(key)
    z = random.normal(subkey, x.shape)
    std = self.marginal_prob_std(random_s) # shape is [batch_size, num_steps]
    perturbed_x = x * self.mu(random_s)[..., None] + z * std[..., None] # Sample from the forward kernel at s.
    etheta = self.apply(params, perturbed_x, random_s)
    # int(bool) exponent: the weight is the max-likelihood prefactor when enabled, else 1.
    prefactor = (0.5 * (self.sigma_at(random_s)/std) ** 2) ** int(self.maxL_prefactor) # Maximum likelihood.
    # Score estimate = grad_logp_eq(perturbed_x) + etheta.
    return jnp.mean(prefactor * jnp.sum(((-perturbed_x/self.Sigma_0**2 + etheta) * std[..., None] + z) ** 2, axis=-1))
  

# A diffusion model class that just stores the diffusion parameters and time scaling.
class DiffusionBare():
  """Holds the VPx SDE coefficients and the schedule quantities derived from them.

  beta(t) is linear from beta_min (t = 0) to beta_max (t = 1); the diffusion
  coefficient is kappa * sqrt(beta(t)) and the drift is -beta(t) * x / 2.
  No network is attached.
  """

  def __init__(self, beta_min, beta_max, kappa=1.):
    self.beta_min, self.beta_max, self.kappa = beta_min, beta_max, kappa

  # VP SDE
  def beta_at(self, t):
    """Linear schedule beta(t)."""
    slope = self.beta_max - self.beta_min
    return self.beta_min + slope * t

  def bplus(self, x, t):
    """Forward drift; t carries one fewer axis than x."""
    rate = self.beta_at(t)[..., None]
    return -rate * x / 2

  def sigma_at(self, t):
    """Diffusion coefficient kappa * sqrt(beta(t))."""
    beta = self.beta_at(t)
    return self.kappa * jnp.sqrt(beta)

  def mu(self, t):
    """Mean decay factor: the square root of exp(-int_0^t beta)."""
    slope = self.beta_max - self.beta_min
    exponent = - 0.5 * self.beta_min * t - 0.25 * slope * (t ** 2)
    return jnp.exp(exponent)

  def marginal_prob_std(self, t):
    """Std of the perturbation kernel: kappa * sqrt(1 - mu(t)**2)."""
    decay = self.mu(t)
    return self.kappa * jnp.sqrt(1 - decay ** 2)

  def grad_logp_eq(self, x, t):
    """Score of the equilibrium N(0, kappa**2 I); t is ignored."""
    return -x / self.kappa ** 2