"""Abstract SDE classes, Reverse SDE, and VE/VP SDEs."""

import abc
import torch
import numpy as np

class SDE(abc.ABC):
  """Abstract base class for SDEs. Methods operate on a mini-batch of inputs."""

  def __init__(self, N):
    """Construct an SDE.
    Args:
      N: number of discretization time steps.
    """
    super().__init__()
    self.N = N

  @property
  @abc.abstractmethod
  def T(self):
    """End time of the SDE."""

  @abc.abstractmethod
  def get_sde_coefficients(self, x, t):
    """Drift and diffusion terms of the forward SDE at state x and time t."""

  @abc.abstractmethod
  def marginal_prob_terms(self, x, t):
    """Parameters to determine the marginal distribution of the SDE, $p_t(x)$.

    `perturb` unpacks the result as a `(mean, std)` pair, where `std` is a
    tensor that is broadcast over the non-batch axes of `x`.
    """

  def perturb(self, x_0, t):
    """Return a sample x_t drawn from the perturbation kernel p_t(x | x_0).
    Args:
      x_0: batch of clean samples, batch axis first.
      t: time steps forwarded to `marginal_prob_terms`.
    Returns:
      mean + std * z with z ~ N(0, I), same shape as x_0.
    """
    noise = torch.randn_like(x_0)
    # The abstract hook is `marginal_prob_terms`; no `marginal_prob` method
    # is defined on this class.
    mean, std = self.marginal_prob_terms(x_0, t)
    # Append one singleton axis per non-batch axis of x_0 so std broadcasts.
    trailing_axes = (None,) * len(x_0.shape[1:])
    return mean + std[(...,) + trailing_axes] * noise

  @abc.abstractmethod
  def prior_sampling(self, shape):
    """Generate one sample from the prior distribution, $p_T(x)$."""

  @abc.abstractmethod
  def prior_logp(self, z):
    """Compute log-density of the prior distribution.
    Useful for computing the log-likelihood via probability flow ODE.
    Args:
      z: latent code
    Returns:
      log probability density
    """

class PinnedBrownSDE(SDE):
  """Pinned Brownian motion SDE with a constant diffusion coefficient.

  Forward dynamics: dx = (miu - x) / (T - t) dt + beta dW, with T = 1.
  """

  def __init__(self, beta_const=0.01, miu=0, N=3000):
    """Pinned Brownian motion-based gen modelling.
    Args:
      beta_const: constant diffusion coefficient beta.
      miu: value the forward drift pulls the state towards.
      N: number of discretization steps
    """
    super().__init__(N)  # stores N
    self.beta = beta_const
    self.miu = miu

  @property
  def T(self):
    """End time of the SDE."""
    return 1

  def get_sde_coefficients(self, x, t):
    """Coefficients of the forward pinned Brownian process.
    Args:
      x: batch of states, batch axis first.
      t: tensor of time steps; it is broadcast over the non-batch axes of x.
    Returns:
      drift: (miu - x) / (T - t), same shape as x. Diverges at t == T.
      diffusion: beta for every entry of t, same shape as t.
    """
    inv_time_left = 1.0 / (self.T - t)
    trailing_axes = (None,) * len(x.shape[1:])
    drift = (self.miu - x) * inv_time_left[(...,) + trailing_axes]
    # ones_like keeps the diffusion on the device and in the dtype of t.
    diffusion = self.beta * torch.ones_like(t)
    return drift, diffusion

  def init_score_fn(self, score_fn):
    """Store the score function, called as score_fn(x, t) by the reverse SDE."""
    self.score_fn = score_fn

  def get_reverse_sde_coefficients(self, x, t, probability_flow=False):
    """Drift and diffusion terms of the reverse SDE (or probability flow ODE).

    With forward coefficients (f, g) the result is
      SDE: drift f - g^2 * score,       diffusion g
      ODE: drift f - 0.5 * g^2 * score, diffusion 0
    `init_score_fn` must have been called beforehand.
    """
    drift, diffusion = self.get_sde_coefficients(x, t)  # forward time coeffs
    score = self.score_fn(x, t)
    score_weight = 0.5 if probability_flow else 1.
    trailing_axes = (None,) * len(x.shape[1:])
    drift = drift - diffusion[(...,) + trailing_axes] ** 2 * score * score_weight
    # Set the diffusion function to zero for ODEs.
    diffusion = 0. if probability_flow else diffusion
    return drift, diffusion

  def marginal_prob_terms(self, x_0, t):
    """Mean and std of the perturbation kernel p(x_t | x_0).
    Args:
      x_0: batch of original samples [n, d]
      t: time steps [n]
    Returns:
      mean: (1 - t/T) * x_0 + (t/T) * miu, same shape as x_0
      std: beta * sqrt(t * (T - t) / T), shape [n]
    """
    assert t.shape[0] == x_0.shape[0], "t and x_0 must have the same batch size for mean and std calculation"
    remaining = (self.T - t) / self.T
    expand = (...,) + (None,) * len(x_0.shape[1:])
    # The forward drift pulls the state towards miu, so the mean interpolates
    # between x_0 and miu; the miu term vanishes for the default miu = 0.
    mean = remaining[expand] * x_0 + (1.0 - remaining)[expand] * self.miu
    std = torch.sqrt(remaining * t) * self.beta  # elementwise, [n]
    return mean, std

  def prior_sampling(self, shape):
    """Return the state at t = T: std is zero there, so every entry equals miu."""
    return torch.zeros(*shape) + self.miu

  def prior_logp(self, z):
    """Placeholder: prints "no" and returns 0; no log-density is computed."""
    print("no")
    return 0


class GenPinnedBrownSDE(SDE):
  """Pinned Brownian motion SDE with a time-dependent diffusion coefficient.

  The squared diffusion beta(t)^2 grows linearly from beta_min to beta_max and
  `get_alpha` is its integral over [0, t]. Forward dynamics:
    dx = (miu - x) * beta(t)^2 / (alpha(T) - alpha(t)) dt + beta(t) dW
  """

  def __init__(self, beta_min=0.1, beta_max=20, miu=0, N=3000):
    """Pinned Brownian motion-based gen modelling.
    Args:
      beta_min: value of beta(0)^2
      beta_max: value of beta(T)^2
      miu: value the forward drift pulls the state towards.
      N: number of discretization steps
    """
    super().__init__(N)  # stores N
    self.beta_0 = beta_min
    self.beta_1 = beta_max
    self.miu = miu

  @property
  def T(self):
    """End time of the SDE."""
    return 1

  def get_alpha(self, t):
    """Integral of the linear schedule beta_0 + s * (beta_1 - beta_0) / T over [0, t]."""
    return self.beta_0 * t + 0.5 * t**2 * (self.beta_1 - self.beta_0) / self.T

  def get_sde_coefficients(self, x, t):
    """Coefficients of the forward pinned Brownian process.
    Args:
      x: batch of states, batch axis first.
      t: tensor of time steps; it is broadcast over the non-batch axes of x.
    Returns:
      drift: (miu - x) * beta(t)^2 / (alpha(T) - alpha(t)), same shape as x.
        Diverges at t == T.
      diffusion: beta(t), same shape as t.
    """
    beta_sq = self.beta_0 + t * (self.beta_1 - self.beta_0) / self.T
    pull_rate = beta_sq / (self.get_alpha(self.T) - self.get_alpha(t))

    trailing_axes = (None,) * len(x.shape[1:])
    drift = (self.miu - x) * pull_rate[(...,) + trailing_axes]
    # beta_sq already has the shape, dtype and device of t, so no extra
    # ones-tensor is needed.
    diffusion = torch.sqrt(beta_sq)
    return drift, diffusion

  def init_score_fn(self, score_fn):
    """Store the score function, called as score_fn(x, t) by the reverse SDE."""
    self.score_fn = score_fn

  def get_reverse_sde_coefficients(self, x, t, probability_flow=False):
    """Drift and diffusion terms of the reverse SDE (or probability flow ODE).

    With forward coefficients (f, g) the result is
      SDE: drift f - g^2 * score,       diffusion g
      ODE: drift f - 0.5 * g^2 * score, diffusion 0
    `init_score_fn` must have been called beforehand.
    """
    drift, diffusion = self.get_sde_coefficients(x, t)  # forward time coeffs
    score = self.score_fn(x, t)
    score_weight = 0.5 if probability_flow else 1.
    trailing_axes = (None,) * len(x.shape[1:])
    drift = drift - diffusion[(...,) + trailing_axes] ** 2 * score * score_weight
    # Set the diffusion function to zero for ODEs.
    diffusion = 0. if probability_flow else diffusion
    return drift, diffusion

  def marginal_prob_terms(self, x_0, t):
    """Mean and std of the perturbation kernel p(x_t | x_0).
    Args:
      x_0: batch of original samples [n, d]
      t: time steps [n]
    Returns:
      mean: r * x_0 + (1 - r) * miu, same shape as x_0, where
        r = (alpha(T) - alpha(t)) / (alpha(T) - alpha(0))
      std: sqrt(r * (alpha(t) - alpha(0))), shape [n]
    """
    assert t.shape[0] == x_0.shape[0], "t and x_0 must have the same batch size for mean and std calculation"
    alpha_t = self.get_alpha(t)
    alpha_0 = self.get_alpha(torch.zeros_like(t))
    alpha_T = self.get_alpha(self.T)

    remaining = (alpha_T - alpha_t) / (alpha_T - alpha_0)
    expand = (...,) + (None,) * len(x_0.shape[1:])
    # The forward drift pulls the state towards miu, so the mean interpolates
    # between x_0 and miu; the miu term vanishes for the default miu = 0.
    mean = remaining[expand] * x_0 + (1.0 - remaining)[expand] * self.miu

    std = torch.sqrt(remaining * (alpha_t - alpha_0))
    return mean, std

  def prior_sampling(self, shape):
    """Return the state at t = T: std is zero there, so every entry equals miu."""
    return torch.zeros(*shape) + self.miu

  def prior_logp(self, z):
    """Placeholder: prints "no" and returns 0; no log-density is computed."""
    print("no")
    return 0

  
class VPSDE(SDE):
  """Variance Preserving SDE: dx = -0.5 * beta(t) * x dt + sqrt(beta(t)) dW."""

  def __init__(self, beta_min=0.1, beta_max=20, N=1000):
    """Construct a Variance Preserving SDE.
    Args:
      beta_min: value of beta(0)
      beta_max: value of beta(1)
      N: number of discretization steps
    """
    super().__init__(N)  # stores N
    self.beta_0 = beta_min
    self.beta_1 = beta_max
    # Discrete-time schedule derived from the continuous beta(t).
    self.discrete_betas = torch.linspace(beta_min / N, beta_max / N, N)
    self.alphas = 1. - self.discrete_betas
    self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
    self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
    self.sqrt_1m_alphas_cumprod = torch.sqrt(1. - self.alphas_cumprod)

  @property
  def T(self):
    """End time of the SDE."""
    return 1

  def _log_mean_coeff(self, t):
    """Log of the factor that scales x_0 in the mean of p(x_t | x_0)."""
    return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0

  def perturbation_coefficients(self, t):
    """Return (a_t, sigma_t): the mean scaling and std of p(x_t | x_0)."""
    log_mean_coeff = self._log_mean_coeff(t)
    a_t = torch.exp(log_mean_coeff)
    sigma_t = torch.sqrt(1. - torch.exp(2. * log_mean_coeff))
    return a_t, sigma_t

  def snr(self, t):
    """Signal-to-noise ratio a_t^2 / sigma_t^2 at time t."""
    a_t, sigma_t = self.perturbation_coefficients(t)
    return a_t ** 2 / sigma_t ** 2

  def get_sde_coefficients(self, x, t):
    """Returns the drift and diffusion terms of the forward SDE, that is -0.5*beta*x_t and sqrt(beta)."""
    beta_t = self.beta_0 + t * (self.beta_1 - self.beta_0)
    trailing_axes = (None,) * len(x.shape[1:])
    drift = -0.5 * beta_t[(...,) + trailing_axes] * x
    diffusion = torch.sqrt(beta_t)
    return drift, diffusion

  def init_score_fn(self, score_fn):
    """Store the score function, called as score_fn(x, t) by the reverse SDE."""
    self.score_fn = score_fn

  def get_reverse_sde_coefficients(self, x, t, probability_flow=False):
    """Returns the drift and diffusion terms of the reverse SDE (ODE), that is [-0.5*beta*x_t - beta*score*const] and sqrt(beta) / 0 (sde/ode).

    const is 1 for the SDE and 0.5 for the probability flow ODE.
    `init_score_fn` must have been called beforehand.
    """
    drift, diffusion = self.get_sde_coefficients(x, t)  # forward time coeffs
    score = self.score_fn(x, t)
    score_weight = 0.5 if probability_flow else 1.
    trailing_axes = (None,) * len(x.shape[1:])
    drift = drift - diffusion[(...,) + trailing_axes] ** 2 * score * score_weight
    # Set the diffusion function to zero for ODEs.
    diffusion = 0. if probability_flow else diffusion
    return drift, diffusion

  def marginal_prob_terms(self, x_0, t):
    """Mean and std of the perturbation kernel p(x_t | x_0).
    Args:
      x_0: batch of original samples [n, d]
      t: time steps [n]
    Returns:
      mean: exp(log_mean_coeff) * x_0, same shape as x_0
      std: sqrt(1 - exp(2 * log_mean_coeff)), shape [n]
    """
    assert t.shape[0] == x_0.shape[0], "t and x_0 must have the same batch size for mean and std calculation"
    log_mean_coeff = self._log_mean_coeff(t)
    trailing_axes = (None,) * len(x_0.shape[1:])
    mean = torch.exp(log_mean_coeff[(...,) + trailing_axes]) * x_0
    std = torch.sqrt(1. - torch.exp(2. * log_mean_coeff))  # [n]
    return mean, std

  def prior_sampling(self, shape):
    """Draw a standard normal sample of the given shape."""
    return torch.randn(*shape)

  def prior_logp(self, z):
    """Log-density of z under a standard normal prior.

    Inputs with more than two axes are reduced over every non-batch axis,
    giving one value per batch element. Inputs with one or two axes are
    reduced over the last axis only.
    """
    if z.dim() > 2:
      reduce_dims = tuple(range(1, z.dim()))
      num_dims = int(np.prod(z.shape[1:]))
    else:
      reduce_dims = (-1,)
      num_dims = z.shape[-1]
    # Plain Python float so the subtraction is dispatched by torch.
    log_norm = float(num_dims / 2. * np.log(2 * np.pi))
    return -log_norm - torch.sum(z ** 2, dim=reduce_dims) / 2.
  

class VESDE(SDE):
  """Variance Exploding SDE with sigma(t) = sigma_min * (sigma_max / sigma_min) ** t."""

  def __init__(self, sigma_min=0.01, sigma_max=4, N=1000):
    """Construct a Variance Exploding SDE.
    Args:
      sigma_min: value of sigma(0)
      sigma_max: value of sigma(1)
      N: number of discretization steps
    """
    super().__init__(N)  # stores N
    self.sigma_min = sigma_min
    self.sigma_max = sigma_max
    # N noise levels, geometrically spaced between sigma_min and sigma_max.
    self.discrete_sigmas = torch.exp(torch.linspace(np.log(self.sigma_min), np.log(self.sigma_max), N))

  @property
  def T(self):
    """End time of the SDE."""
    return 1

  def get_sde_coefficients(self, x, t):
    """Drift and diffusion terms of the forward SDE.
    Returns:
      drift: zeros shaped like x.
      diffusion: sigma(t) * sqrt(2 * log(sigma_max / sigma_min)).
    """
    sigma = self.sigma_min * (self.sigma_max / self.sigma_min) ** t
    log_ratio = np.log(self.sigma_max) - np.log(self.sigma_min)
    drift = torch.zeros_like(x)
    diffusion = sigma * torch.sqrt(torch.tensor(2 * log_ratio).type_as(t))
    return drift, diffusion

  def marginal_prob_terms(self, x, t):
    """Mean and std of the perturbation kernel p(x_t | x_0).
    Returns:
      mean: x itself, unchanged.
      std: sigma_min * (sigma_max / sigma_min) ** t, same shape as t.
    """
    sigma_min = torch.tensor(self.sigma_min).type_as(t)
    sigma_max = torch.tensor(self.sigma_max).type_as(t)
    std = sigma_min * (sigma_max / sigma_min) ** t
    mean = x
    return mean, std

  def init_score_fn(self, score_fn):
    """Store the score function, called as score_fn(x, t) by the reverse SDE."""
    self.score_fn = score_fn

  def get_reverse_sde_coefficients(self, x, t, probability_flow=False):
    """Drift and diffusion terms of the reverse SDE (or probability flow ODE).

    The forward drift is zero, so with forward diffusion g the result is
      SDE: drift -g^2 * score,       diffusion g
      ODE: drift -0.5 * g^2 * score, diffusion 0
    `init_score_fn` must have been called beforehand.
    """
    _, diffusion = self.get_sde_coefficients(x, t)  # forward time coeffs
    score = self.score_fn(x, t)
    score_weight = 0.5 if probability_flow else 1.
    trailing_axes = (None,) * len(x.shape[1:])
    # forward drift is zero
    drift = - diffusion[(...,) + trailing_axes] ** 2 * score * score_weight
    # Set the diffusion function to zero for ODEs.
    diffusion = 0. if probability_flow else diffusion
    return drift, diffusion

  def prior_sampling(self, shape):
    """Draw a zero-mean normal sample with std sigma_max of the given shape."""
    return torch.randn(*shape) * self.sigma_max

  def prior_logp(self, z):
    """Log-density of z under N(0, sigma_max^2 I), reduced over all non-batch axes."""
    shape = z.shape
    N = np.prod(shape[1:])
    dims_to_reduce = tuple(range(len(z.shape))[1:])
    return -N / 2. * np.log(2 * np.pi * self.sigma_max ** 2) - torch.sum(z ** 2, dim=dims_to_reduce) / (2 * self.sigma_max ** 2)
 
