"""Abstract SDE classes, Reverse SDE, and VE/VP SDEs."""
import abc
import torch
import numpy as np


class SDE(abc.ABC):
  """Abstract base class for SDEs; every method operates on a mini-batch of inputs."""

  def __init__(self, N):
    """Store the discretization resolution.

    Args:
      N: number of discretization time steps.
    """
    super().__init__()
    self.N = N

  @property
  @abc.abstractmethod
  def T(self):
    """End time of the SDE."""

  @abc.abstractmethod
  def sde(self, x, t):
    """Return the `(drift, diffusion)` pair of the SDE evaluated at `x` and `t`."""

  @abc.abstractmethod
  def marginal_prob(self, x, t):
    """Parameters to determine the marginal distribution of the SDE, $p_t(x)$."""

  @abc.abstractmethod
  def prior_sampling(self, shape):
    """Generate one sample from the prior distribution, $p_T(x)$."""

  @abc.abstractmethod
  def prior_logp(self, z):
    """Compute log-density of the prior distribution.

    Useful for computing the log-likelihood via probability flow ODE.

    Args:
      z: latent code
    Returns:
      log probability density
    """

  def discretize(self, x, t):
    """Discretize the SDE in the form: x_{i+1} = x_i + f_i(x_i) + G_i z_i.

    Useful for reverse diffusion sampling and probability flow sampling.
    Defaults to Euler-Maruyama discretization with step size `1 / self.N`.

    Args:
      x: a torch tensor
      t: a torch float representing the time step (from 0 to `self.T`)

    Returns:
      f, G: the drift scaled by the step size and the diffusion scaled by
        the square root of the step size.
    """
    step = 1 / self.N
    drift, diffusion = self.sde(x, t)
    # The step size is wrapped in a tensor on t's device so torch.sqrt can be applied.
    sqrt_step = torch.sqrt(torch.tensor(step, device=t.device))
    return drift * step, diffusion * sqrt_step

  def reverse(self, net_fn, probability_flow=False):
    """Create the reverse-time SDE/ODE.

    The returned object is an instance of a subclass of this object's own
    class whose `sde` and `discretize` wrap the forward ones.

    Args:
      net_fn: a z-dependent PFGM that takes x and z and returns the normalized Poisson field.
        Or a time-dependent score-based model that takes x and t and returns the score.
      probability_flow: If `True`, create the reverse-time ODE used for probability flow sampling.

    Returns:
      An `RSDE` instance exposing `N`, `T`, `sde` and `discretize`.
    """
    # Capture the forward quantities once; the nested class closes over them.
    num_steps = self.N
    end_time = self.T
    forward_sde = self.sde
    forward_discretize = self.discretize

    class RSDE(self.__class__):
      """Reverse-time counterpart of the enclosing SDE."""

      def __init__(self):
        # Deliberately skips the parent constructor: only N and the mode flag are set.
        self.N = num_steps
        self.probability_flow = probability_flow

      @property
      def T(self):
        return end_time

      def sde(self, x, t):
        """Create the drift and diffusion functions for the reverse SDE/ODE."""
        drift, diffusion = forward_sde(x, t)
        score = net_fn(x.float(), t.float())
        # The score term is halved for the probability flow ODE.
        weight = 0.5 if self.probability_flow else 1.
        # diffusion is indexed as a 1-D per-sample tensor and broadcast over three trailing dims.
        drift = drift - diffusion[:, None, None, None] ** 2 * score * weight
        if self.probability_flow:
          # Set the diffusion function to zero for ODEs.
          diffusion = torch.zeros_like(diffusion)
        return drift, diffusion

      def discretize(self, x, t):
        """Create discretized iteration rules for the reverse diffusion sampler."""
        f, G = forward_discretize(x, t)
        weight = 0.5 if self.probability_flow else 1.
        rev_f = f - G[:, None, None, None] ** 2 * net_fn(x, t) * weight
        if self.probability_flow:
          rev_G = torch.zeros_like(G)
        else:
          rev_G = G
        return rev_f, rev_G

    return RSDE()


class VPSDE(SDE):
  """Variance Preserving SDE with beta(t) interpolated linearly between beta_min and beta_max."""

  def __init__(self, config, beta_min=0.1, beta_max=20, N=1000):
    """Construct a Variance Preserving SDE.

    Args:
      config: configuration object; stored on the instance, not read by this class.
      beta_min: value of beta(0)
      beta_max: value of beta(1)
      N: number of discretization steps
    """
    super().__init__(N)  # the base constructor stores N on the instance
    self.config = config
    self.beta_0 = beta_min
    self.beta_1 = beta_max

    # Per-step tables used by the DDPM-style `discretize`.
    betas = torch.linspace(beta_min / N, beta_max / N, N)
    alphas = 1. - betas
    alphas_bar = torch.cumprod(alphas, dim=0)
    self.discrete_betas = betas
    self.alphas = alphas
    self.alphas_cumprod = alphas_bar
    self.sqrt_alphas_cumprod = torch.sqrt(alphas_bar)
    self.sqrt_1m_alphas_cumprod = torch.sqrt(1. - alphas_bar)

  @property
  def T(self):
    """End time of the SDE (always 1)."""
    return 1

  def sde(self, x, t):
    """Return drift `-0.5 * beta(t) * x` and diffusion `sqrt(beta(t))`.

    `t` is indexed as a 1-D tensor and broadcast against three trailing
    dimensions of `x`.
    """
    beta_t = self.beta_0 + t * (self.beta_1 - self.beta_0)
    drift = -0.5 * beta_t[:, None, None, None] * x
    return drift, torch.sqrt(beta_t)

  def marginal_prob(self, x, t):
    """Return the mean and standard deviation of the marginal at time `t`."""
    log_mean_coeff = -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
    scale = torch.exp(log_mean_coeff[:, None, None, None])
    std = torch.sqrt(1. - torch.exp(2. * log_mean_coeff))
    return scale * x, std

  def prior_sampling(self, shape):
    """Draw a standard normal tensor of the given shape."""
    return torch.randn(*shape)

  def prior_logp(self, z):
    """Log-density of `z` under a standard normal, summed over dims 1, 2 and 3."""
    N = np.prod(z.shape[1:])
    squared_norm = torch.sum(z ** 2, dim=(1, 2, 3))
    return -N / 2. * np.log(2 * np.pi) - squared_norm / 2.

  def discretize(self, x, t):
    """DDPM discretization."""
    # Map continuous time onto an index into the N-entry per-step tables.
    timestep = (t * (self.N - 1) / self.T).long()
    beta = self.discrete_betas.to(x.device)[timestep]
    alpha = self.alphas.to(x.device)[timestep]
    f = torch.sqrt(alpha)[:, None, None, None] * x - x
    G = torch.sqrt(beta)
    return f, G


class subVPSDE(SDE):
  """Sub-VP SDE; uses the base-class Euler-Maruyama `discretize`."""

  def __init__(self, config, beta_min=0.1, beta_max=20, N=1000):
    """Construct the sub-VP SDE that excels at likelihoods.

    Args:
      config: configuration object; stored on the instance, not read by this class.
      beta_min: value of beta(0)
      beta_max: value of beta(1)
      N: number of discretization steps
    """
    super().__init__(N)  # the base constructor stores N on the instance
    self.config = config
    self.beta_0 = beta_min
    self.beta_1 = beta_max

  @property
  def T(self):
    """End time of the SDE (always 1)."""
    return 1

  def sde(self, x, t):
    """Return the drift and diffusion at `x`, `t`.

    The drift is `-0.5 * beta(t) * x`. The diffusion is
    `sqrt(beta(t) * discount)`, where `discount` is one minus an
    exponential in `t`.
    """
    beta_t = self.beta_0 + t * (self.beta_1 - self.beta_0)
    drift = -0.5 * beta_t[:, None, None, None] * x
    discount = 1. - torch.exp(-2 * self.beta_0 * t - (self.beta_1 - self.beta_0) * t ** 2)
    return drift, torch.sqrt(beta_t * discount)

  def marginal_prob(self, x, t):
    """Return the marginal mean and the `std` term at time `t`.

    The `std` term is `1 - exp(2 * log_mean_coeff)`; no square root is taken.
    """
    log_mean_coeff = -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
    scale = torch.exp(log_mean_coeff)[:, None, None, None]
    std = 1 - torch.exp(2. * log_mean_coeff)
    return scale * x, std

  def prior_sampling(self, shape):
    """Draw a standard normal tensor of the given shape."""
    return torch.randn(*shape)

  def prior_logp(self, z):
    """Log-density of `z` under a standard normal, summed over dims 1, 2 and 3."""
    N = np.prod(z.shape[1:])
    squared_norm = torch.sum(z ** 2, dim=(1, 2, 3))
    return -N / 2. * np.log(2 * np.pi) - squared_norm / 2.


class VESDE(SDE):
  """Variance Exploding SDE whose noise scale grows geometrically from sigma_min to sigma_max."""

  def __init__(self, config, sigma_min=0.01, sigma_max=50, N=1000):
    """Construct a Variance Exploding SDE.

    Args:
      config: configuration object; stored on the instance, not read by this class.
      sigma_min: smallest sigma.
      sigma_max: largest sigma.
      N: number of discretization steps
    """
    super().__init__(N)
    self.sigma_min = sigma_min
    self.sigma_max = sigma_max
    # N noise levels spaced evenly in log-space between sigma_min and sigma_max.
    self.discrete_sigmas = torch.exp(torch.linspace(np.log(self.sigma_min), np.log(self.sigma_max), N))
    self.N = N
    self.config = config

  @property
  def T(self):
    """End time of the SDE (always 1)."""
    return 1

  def sde(self, x, t):
    """Return a zero drift and the diffusion `sigma(t) * sqrt(2 * log(sigma_max / sigma_min))`."""
    sigma = self.sigma_min * (self.sigma_max / self.sigma_min) ** t
    drift = torch.zeros_like(x)
    diffusion = sigma * torch.sqrt(torch.tensor(2 * (np.log(self.sigma_max) - np.log(self.sigma_min)),
                                                device=t.device))
    return drift, diffusion

  def marginal_prob(self, x, t):
    """Return the marginal mean (`x` itself) and standard deviation `sigma(t)`."""
    std = self.sigma_min * (self.sigma_max / self.sigma_min) ** t
    mean = x
    return mean, std

  def prior_sampling(self, shape):
    """Draw a zero-mean normal tensor of the given shape with standard deviation `sigma_max`."""
    return torch.randn(*shape) * self.sigma_max

  def prior_logp(self, z):
    """Log-density of `z` under N(0, sigma_max^2 I), summed over dims 1, 2 and 3."""
    shape = z.shape
    N = np.prod(shape[1:])
    return -N / 2. * np.log(2 * np.pi * self.sigma_max ** 2) - torch.sum(z ** 2, dim=(1, 2, 3)) / (2 * self.sigma_max ** 2)

  def discretize(self, x, t):
    """SMLD(NCSN) discretization.

    Args:
      x: a torch tensor; only its shape/dtype/device are used, for the zero drift.
      t: a torch tensor of time steps (from 0 to `self.T`).

    Returns:
      f, G: a zero tensor shaped like `x`, and
        `sqrt(sigma_i^2 - sigma_{i-1}^2)` with `sigma_{-1}` taken as 0.
    """
    timestep = (t * (self.N - 1) / self.T).long()
    # Move the table to t's device BEFORE indexing. Indexing the CPU table with
    # an index tensor that lives on another device raises in PyTorch.
    sigmas = self.discrete_sigmas.to(t.device)
    sigma = sigmas[timestep]
    # For timestep == 0 the lookup at index -1 wraps to the last entry; that
    # value is discarded by torch.where in favour of 0.
    adjacent_sigma = torch.where(timestep == 0, torch.zeros_like(t), sigmas[timestep - 1])
    f = torch.zeros_like(x)
    G = torch.sqrt(sigma ** 2 - adjacent_sigma ** 2)
    return f, G


class Poisson():
  """PFGM helper: initial-sample generation and the ODE drift, driven by a config object."""

  def __init__(self, config):
    """Construct a PFGM.

    Args:
      config: configurations; `config.sampling.N` is copied to `self.N`.
    """
    self.config = config
    self.N = config.sampling.N

  @property
  def M(self):
    """The value of `config.training.M`."""
    return self.config.training.M

  def prior_sampling(self, shape):
    """Draw standard Gaussian initial samples shaped like a batch of images.

    Only `shape[0]` (the batch size) is read from `shape`; the remaining
    dimensions come from `config.data`. The result is placed on the current
    CUDA device when CUDA is available and stays on the CPU otherwise.

    NOTE: the PFGM prior on the z=z_max hyperplane (Section 3.3 / Appendix A.4
    of the PFGM paper) is disabled below; the samples are plain Gaussian noise.

    Args:
      shape: sequence whose first entry is the batch size.

    Returns:
      A float tensor of shape `(shape[0], num_channels, image_size, image_size)`.
    """
    data_cfg = self.config.data

    # Sample the radius from p_radius (details in Appendix A.4 in the PFGM paper)
    # max_z = self.config.sampling.z_max
    N = data_cfg.channels * data_cfg.image_size * data_cfg.image_size + 1
    # Sampling from inverse-beta distribution
    # samples_norm = np.random.beta(a=N / 2. - 0.5, b=0.5, size=shape[0])
    # inverse_beta = samples_norm / (1 - samples_norm)
    # Sampling from p_radius(R) by change-of-variable
    # samples_norm = np.sqrt(max_z ** 2 * inverse_beta)
    # clip the sample norm (radius)
    # samples_norm = np.clip(samples_norm, 1, self.config.sampling.upper_norm)
    # samples_norm = torch.from_numpy(samples_norm).cuda().view(len(samples_norm), -1)

    # Draw on the CPU first (as before) and then move, so the random stream is
    # unchanged; fall back to the CPU instead of failing when CUDA is absent.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    gaussian = torch.randn(shape[0], N - 1).to(device)
    # unit_gaussian = gaussian / torch.norm(gaussian, p=2, dim=1, keepdim=True)

    init_samples = gaussian

    return init_samples.float().view(len(init_samples), data_cfg.num_channels,
                                     data_cfg.image_size, data_cfg.image_size)

  def ode(self, net_fn, x, t):
    """Evaluate the ODE drift dx/dt' at `x`, where z = exp(mean(t)).

    Args:
      net_fn: callable taking `(x, z_batch)` and returning `(x_drift, z_drift)`;
        `z_drift` is indexed as a 1-D per-sample tensor.
      x: batch tensor that can be viewed as
        `(len(x), num_channels, image_size, image_size)`.
      t: torch tensor; only its mean is used.

    Returns:
      A tensor shaped `(len(x), num_channels, image_size, image_size)` holding
      `z * dx/dz`.
    """
    sampling_cfg = self.config.sampling
    training_cfg = self.config.training
    data_cfg = self.config.data

    z = np.exp(t.mean().cpu())
    if sampling_cfg.vs:
      print(z)
    # Build the per-sample z batch on x's device rather than hard-coding CUDA.
    z_batch = torch.ones(len(x), device=x.device) * z
    x_drift, z_drift = net_fn(x, z_batch)
    x_drift = x_drift.view(len(x_drift), -1)

    # Substitute the predicted z with the ground-truth
    # Please see Appendix B.2.3 in PFGM paper (https://arxiv.org/abs/2209.11178) for details
    z_exp = sampling_cfg.z_exp
    if z < z_exp and training_cfg.gamma > 0:
      data_dim = data_cfg.image_size * data_cfg.image_size * data_cfg.channels
      sqrt_dim = np.sqrt(data_dim)
      norm_1 = x_drift.norm(p=2, dim=1) / sqrt_dim
      x_norm = training_cfg.gamma * norm_1 / (1 - norm_1)
      x_norm = torch.sqrt(x_norm ** 2 + z ** 2)
      z_drift = -sqrt_dim * torch.ones_like(z_drift) * z / (x_norm + training_cfg.gamma)

    # Predicted normalized Poisson field
    v = torch.cat([x_drift, z_drift[:, None]], dim=1)

    # The 1e-5 offset keeps the division finite when the z-component is exactly 0.
    dt_dz = 1 / (v[:, -1] + 1e-5)
    dx_dt = v[:, :-1].view(len(x), data_cfg.num_channels,
                           data_cfg.image_size, data_cfg.image_size)
    # Reshape dt_dz to (batch, 1, ..., 1) so it broadcasts over x's trailing dims.
    dx_dz = dx_dt * dt_dz.view(-1, *([1] * len(x.size()[1:])))
    # dx/dt_prime =  z * dx/dz
    dx_dt_prime = z * dx_dz
    return dx_dt_prime
  
class Homotopy():
  def __init__(self, config):
    """Construct a PFGM.

    Args:
      config: configurations
    """
    self.config = config