import math
import torch
import torch.nn as nn
from tqdm import tqdm

def cosine_beta_schedule(timesteps, s=0.008, max_beta=0.999):
    """Return per-step betas for a cosine noise schedule.

    The cumulative product of alphas follows
    ``alpha_bar(t) = cos(((t / T) + s) / (1 + s) * pi / 2) ** 2``, normalised so
    that ``alpha_bar(0) == 1``; each beta is
    ``1 - alpha_bar(t) / alpha_bar(t - 1)``.

    Args:
        timesteps: number of diffusion steps ``T``.
        s: small offset added to ``t / T`` (default 0.008).
        max_beta: upper clip for the betas. ``alpha_bar(T)`` is ~0, so the
            last betas approach 1 without it. Defaults to the previously
            hard-coded 0.999.

    Returns:
        float64 tensor of shape ``(timesteps,)`` with values in
        ``[0, max_beta]``.
    """
    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps, dtype=torch.float64)
    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0, max_beta)

def extract(a, t, x_shape):
    """Gather ``a[t]`` and reshape it to broadcast against a tensor of ``x_shape``.

    Args:
        a: per-timestep values; ``a[t]`` must yield one element per entry of ``t``.
        t: 1-D tensor of timestep indices, one per batch element.
        x_shape: shape the result will be broadcast against; ``x_shape[0]``
            must equal ``len(t)``.

    Returns:
        Tensor of shape ``[bs, 1, ..., 1]`` with ``len(x_shape)`` dimensions.
    """
    assert t.dim() == 1
    (batch,) = t.shape
    gathered = a[t]
    assert x_shape[0] == batch and gathered.shape[0] == batch
    # Singleton trailing dims so the per-example coefficient broadcasts over x.
    trailing_ones = [1] * (len(x_shape) - 1)
    return gathered.reshape([batch, *trailing_ones])

class PortableDiffusionModel(nn.Module):
  """Gaussian diffusion model around a noise-prediction network.

  ``net(x, t)`` is treated as a prediction of the noise that was added to
  ``x_0``: it is regressed onto that noise in ``p_loss_simple`` and passed to
  ``predict_start_from_noise`` in ``p_mean_variance``. ``p_gradient`` also
  needs ``net.forward_no_grad`` and ``p_energy`` needs ``net.neg_logp_unnorm``.

  All schedule tensors are precomputed in ``__init__`` from
  ``cosine_beta_schedule`` and stored as plain attributes (not registered
  buffers), so ``Module.to()`` does not move or cast them. Coefficients are
  therefore gathered through ``_gather_coef``, which matches them to the
  device and dtype of the tensor they are combined with.
  """

  def __init__(self, dim, n_steps, net, loss_type='simple', mc_loss=True, var_type='learned', samples_per_step=1, name=None):
    """Build the model and precompute the noise schedule.

    Args:
      dim: data dimensionality; inputs must have shape ``(B, dim)``.
      n_steps: number of diffusion steps.
      net: network called as ``net(x, t)``.
      loss_type: ``'simple'`` (noise MSE) or ``'kl'`` (per-step KL / NLL).
      mc_loss: if True, ``loss`` samples one ``t`` per example; otherwise it
        enumerates every ``t`` via ``loss_all_t``.
      var_type: reverse-process variance: ``'beta_reverse'`` (posterior
        variance), ``'beta_forward'`` (betas) or ``'learned'``
        (``_out_logvar``).
      samples_per_step: repeats per ``t`` in ``loss_all_t`` / ``logpx``.
      name: unused; kept for signature compatibility.
    """
    super().__init__()
    assert var_type in ('beta_forward', 'beta_reverse', 'learned')
    self.net = net
    self._var_type = var_type
    self._n_steps = n_steps
    self._dim = dim
    self._loss_type = loss_type
    self._mc_loss = mc_loss
    self._samples_per_step = samples_per_step

    betas = cosine_beta_schedule(n_steps)
    alphas = 1. - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)
    # alpha_bar_{t-1}, with alpha_bar_{-1} := 1. Built with torch.cat so it keeps
    # the schedule's dtype instead of being rounded into a float32 buffer.
    alphas_cumprod_prev = torch.cat([
        torch.ones(1, dtype=alphas_cumprod.dtype, device=alphas_cumprod.device),
        alphas_cumprod[:-1],
    ])

    self._betas = betas
    self._alphas = alphas
    self._log_alphas = torch.log(alphas)
    self._sqrt_alphas = torch.sqrt(alphas)
    self._sqrt_recip_alphas = 1. / torch.sqrt(alphas)
    self._alphas_cumprod = alphas_cumprod
    self._alphas_cumprod_prev = alphas_cumprod_prev
    self._sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
    self._sqrt_one_minus_alphas_cumprod = torch.sqrt(1 - alphas_cumprod)
    self._log_one_minus_alphas_cumprod = torch.log(1 - alphas_cumprod)

    self._sqrt_recip_alphas_cumprod = torch.rsqrt(alphas_cumprod)
    self._sqrt_recipm1_alphas_cumprod = torch.sqrt(1 / alphas_cumprod - 1)
    # Despite the name, this is 1 / sqrt(1 - alpha_bar_t) (not sqrt(1/alpha_bar_t - 1)).
    self._sqrt_recipm1_alphas_cumprod_custom = torch.sqrt(1. / (1 - alphas_cumprod))

    # Posterior q(x_{t-1} | x_t, x_0) = N(coef1 * x_0 + coef2 * x_t, posterior_variance).
    self._posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
    # posterior_variance[0] is 0 (alpha_bar_{-1} = 1), so clip before taking the log.
    self._posterior_log_variance_clipped = torch.log(
        torch.clip(self._posterior_variance, min=torch.min(betas)))
    self._posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1 - alphas_cumprod)
    self._posterior_mean_coef2 = (1 - alphas_cumprod_prev) * torch.sqrt(alphas) / (1 - alphas_cumprod)
    # NOTE(review): var_type='learned' reads this tensor, but it is a plain tensor,
    # not an nn.Parameter, so no optimizer updates it and it stays at log(betas).
    # Making it a Parameter would change state_dict()/parameters(); confirm intent.
    self._out_logvar = torch.log(betas).clone()

  def _gather_coef(self, a, t, x):
    """Return ``extract(a, t, x.shape)`` on ``x``'s device and in ``x``'s dtype.

    Keeps float32 inputs float32 (the schedule itself is float64) and allows
    the data and ``t`` to live on a different device than the schedule.
    """
    return extract(a.to(t.device), t, x.shape).to(device=x.device, dtype=x.dtype)

  def energy_scale(self, t):
    """Return sqrt(1 / alpha_bar_t - 1) indexed at ``t``."""
    return self._sqrt_recipm1_alphas_cumprod[t]

  def data_scale(self, t):
    """Return 1 / sqrt(alpha_bar_t) indexed at ``t``."""
    return self._sqrt_recip_alphas_cumprod[t]

  def forward(self, x, t):
    """Run ``net`` on ``(x_t, t)``; the output is used as the predicted noise.

    ``x`` is promoted to at least 2-D and ``t`` to at least 1-D first. The
    output must have the same shape as ``x``.
    """
    x = torch.atleast_2d(x)
    t = torch.atleast_1d(t)
    outs = self.net(x, t)

    assert outs.shape == x.shape
    assert x.shape[1:] == (self._dim,) and len(t.size()) == 1
    return outs

  def forward_no_grad(self, x, t):
    """Like ``forward`` but calls ``net.forward_no_grad`` instead of ``net``."""
    x = torch.atleast_2d(x)
    t = torch.atleast_1d(t)
    outs = self.net.forward_no_grad(x, t)

    assert outs.shape == x.shape
    assert x.shape[1:] == (self._dim,) and len(t.size()) == 1
    return outs

  def stats(self):
    """Returns static variables for computing variances."""
    # The misspelled 'posterior_log_variace_clipped' key is kept as-is because
    # callers may index the returned dict by it.
    return {
        'betas': self._betas,
        'alphas': self._alphas,
        'alphas_cumprod': self._alphas_cumprod,
        'alphas_cumprod_prev': self._alphas_cumprod_prev,
        'sqrt_alphas_cumprod': self._sqrt_alphas_cumprod,
        'sqrt_one_minus_alphas_cumprod': self._sqrt_one_minus_alphas_cumprod,
        'log_one_minus_alphas_cumprod': self._log_one_minus_alphas_cumprod,
        'sqrt_recip_alphas_cumprod': self._sqrt_recip_alphas_cumprod,
        'sqrt_recipm1_alphas_cumprod': self._sqrt_recipm1_alphas_cumprod,
        'posterior_variance': self._posterior_variance,
        'posterior_log_variace_clipped': self._posterior_log_variance_clipped
    }

  def q_mean_variance(self, x_0, t):
    """Return (mean, variance, log_variance) of q(x_t | x_0)."""
    mean = self._gather_coef(self._sqrt_alphas_cumprod, t, x_0) * x_0
    variance = self._gather_coef(1. - self._alphas_cumprod, t, x_0)
    log_variance = self._gather_coef(self._log_one_minus_alphas_cumprod, t, x_0)
    return mean, variance, log_variance

  def q_sample(self, x_0, t, noise=None):
    """Sample x_t ~ q(x_t | x_0); ``noise`` defaults to fresh standard normal noise."""
    assert x_0.shape[1:] == (self._dim,) and len(t.size()) == 1
    if noise is None:
      noise = torch.randn_like(x_0)

    x_t = (self._gather_coef(self._sqrt_alphas_cumprod, t, x_0) * x_0
           + self._gather_coef(self._sqrt_one_minus_alphas_cumprod, t, x_0) * noise)
    assert x_t.shape == x_0.shape
    return x_t

  def p_loss_simple(self, x_0, t):
    """Noise-prediction MSE for given x_0 and t, averaged over dim; shape (B,)."""
    assert x_0.shape[1:] == (self._dim,) and len(t.size()) == 1
    noise = torch.randn_like(x_0)

    x_noise = self.q_sample(x_0, t, noise)
    noise_recon = self.forward(x_noise, t)
    assert noise_recon.shape == x_0.shape
    mse = torch.square(noise_recon - noise)

    assert mse.shape == (t.shape[0], self._dim)
    mse = torch.mean(mse, dim=1)  # avg over the output dimension
    assert mse.shape == t.shape
    return mse

  def p_loss_kl(self, x_0, t):
    """Per-example KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t)); NLL of x_0 where t == 0.

    Both terms are averaged over dim; result has shape (B,).
    """
    assert x_0.shape[1:] == (self._dim,) and len(t.size()) == 1

    x_t = self.q_sample(x_0, t)
    q_mean, _, q_log_variance = self.q_posterior(x_0, x_t, t)
    p_mean, _, p_log_variance = self.p_mean_variance(x_t, t)

    dist_q = torch.distributions.normal.Normal(q_mean, torch.exp(0.5 * q_log_variance))
    dist_p = torch.distributions.normal.Normal(p_mean, torch.exp(0.5 * p_log_variance))
    kl = torch.distributions.kl.kl_divergence(dist_q, dist_p).mean(-1)
    nll = -dist_p.log_prob(x_0).mean(-1)
    loss = torch.where(t == 0, nll, kl)

    assert nll.shape == kl.shape and loss.shape == (t.shape[0],)
    return loss

  def q_posterior(self, x_0, x_t, t):
    """Return (mean, variance, clipped log-variance) of q(x_{t-1} | x_0, x_t)."""
    assert x_0.shape[1:] == (self._dim,) and len(t.size()) == 1
    mean = (
        self._gather_coef(self._posterior_mean_coef1, t, x_t) * x_0
        + self._gather_coef(self._posterior_mean_coef2, t, x_t) * x_t
    )
    var = self._gather_coef(self._posterior_variance, t, x_t)
    log_var_clipped = self._gather_coef(self._posterior_log_variance_clipped, t, x_t)
    assert var.shape == log_var_clipped.shape
    assert x_0.shape == x_t.shape and x_0.shape == mean.shape
    return mean, var, log_var_clipped

  def predict_start_from_noise(self, x_t, t, noise):
    """Invert q_sample: x_0 = x_t / sqrt(alpha_bar_t) - sqrt(1/alpha_bar_t - 1) * noise."""
    assert x_t.shape[1:] == (self._dim,) and len(t.size()) == 1
    x_0 = (
        self._gather_coef(self._sqrt_recip_alphas_cumprod, t, x_t) * x_t
        - self._gather_coef(self._sqrt_recipm1_alphas_cumprod, t, x_t) * noise
    )
    assert x_0.shape == x_t.shape
    return x_0

  def p_mean_variance(self, x, t, clip=math.inf):
    """Return (mean, variance, log_variance) of p(x_{t-1} | x_t).

    The mean is the q-posterior mean evaluated at the network's x_0 estimate
    (clipped to ``[-clip, clip]``); the variance depends on ``var_type``.
    """
    assert x.shape[1:] == (self._dim,) and len(t.size()) == 1

    eps = self.forward(x, t)
    x_recon = torch.clip(self.predict_start_from_noise(x, t, noise=eps), -clip, clip)

    mean, var, log_var = self.q_posterior(x_recon, x, t)

    assert mean.shape == x.shape
    assert var.shape == (x.shape[0], 1) and var.shape == log_var.shape

    if self._var_type == 'beta_reverse':
      pass  # keep the posterior variance from q_posterior
    elif self._var_type == 'beta_forward':
      var = self._gather_coef(self._betas, t, x)
      log_var = torch.log(var)
    elif self._var_type == 'learned':
      log_var = self._gather_coef(self._out_logvar, t, x)
      var = torch.exp(log_var)
    else:
      raise ValueError(f'{self._var_type} not recognised.')

    assert var.shape == (x.shape[0], 1) and var.shape == log_var.shape
    assert mean.shape == (x.shape[0], x.shape[1])
    return mean, var, log_var

  def p_sample(self, x, t, rng_key=None, clip=math.inf):
    """Sample x_{t-1} ~ p(x_{t-1} | x_t).

    ``rng_key`` is unused (torch's global RNG is used); it is kept for
    signature compatibility.
    """
    assert x.shape[1:] == (self._dim,) and len(t.size()) == 1
    mean, _, log_var = self.p_mean_variance(x, t, clip=clip)
    noise = torch.randn_like(x)
    x_tm1 = mean + torch.exp(0.5 * log_var) * noise
    assert x.shape == x_tm1.shape
    return x_tm1

  def _prior_kl(self, x_0):
    """KL(q(x_T | x_0) || N(0, I)) at the last step, averaged over dim; shape (B,)."""
    t = torch.full((x_0.shape[0],), self._n_steps - 1, dtype=torch.long, device=x_0.device)
    qt_mean, _, qt_log_variance = self.q_mean_variance(x_0, t)
    qt_dist = torch.distributions.normal.Normal(qt_mean, torch.exp(0.5 * qt_log_variance))
    p_dist = torch.distributions.normal.Normal(torch.zeros_like(qt_mean), torch.ones_like(qt_mean))
    kl = torch.distributions.kl.kl_divergence(qt_dist, p_dist).mean(-1)
    assert kl.shape == (x_0.shape[0],)
    return kl

  def logpx(self, x_0):
    """Full ELBO estimate of the model.

    Returns ``{'logpx': tensor}`` with one entry per row produced by
    ``loss_all_t`` (B * n_steps * samples_per_step); the prior KL of each
    example is spread evenly over that example's rows.
    """
    e = self._prior_kl(x_0)
    assert e.shape == (x_0.shape[0],)
    n_repeats = self._n_steps * self._samples_per_step
    e = e.repeat_interleave(n_repeats, dim=0) / n_repeats
    kls = self.loss_all_t(x_0, loss_type='kl')
    logpx = -(kls + e) * self._dim * self._n_steps
    return {'logpx': logpx}

  def sample(self, n, clip=math.inf):
    """Draw ``n`` samples by running the reverse chain from t = n_steps - 1 to 0."""
    x = torch.randn(n, self._dim)
    # int64 timesteps, matching what the training paths (randint/arange) feed the net.
    for j in tqdm(range(self._n_steps - 1, -1, -1), desc='RD'):
      t = torch.full((n,), j, dtype=torch.long)
      x = self.p_sample(x, t, rng_key=None, clip=clip)
    assert x.shape == (n, self._dim)
    return x

  def loss(self, x):
    """Training loss using the configured ``loss_type`` and ``mc_loss`` mode."""
    if self._mc_loss:
      return self.loss_mc(x, loss_type=self._loss_type)
    else:
      return self.loss_all_t(x, loss_type=self._loss_type)

  def loss_mc(self, x, loss_type=None):
    """Compute training loss with one uniformly sampled t per example; shape (B,)."""
    assert x.shape[1:] == (self._dim,)
    t = torch.randint(0, self._n_steps, (x.shape[0],), device=x.device)

    if loss_type == 'simple':
      loss = self.p_loss_simple(x, t)
    elif loss_type == 'kl':
      loss = self.p_loss_kl(x, t)
    else:
      raise ValueError(f'Unrecognized loss type: {loss_type}')

    assert loss.shape == (x.shape[0],)
    return loss

  def loss_all_t(self, x, loss_type=None):
    """Compute training loss enumerated over all t's (unreduced).

    Row ``i`` of the result pairs example ``x[i // (n_steps * samples_per_step)]``
    with ``t = i % n_steps``, so every example sees every t exactly
    ``samples_per_step`` times. Result shape: (B * n_steps * samples_per_step,).
    """
    assert x.shape[1:] == (self._dim,)
    t = torch.cat([torch.arange(0, self._n_steps, device=x.device)] * x.shape[0])
    t = torch.tile(t, (self._samples_per_step,)).reshape(-1)
    # Tiling along the last dim then reshaping repeats each row consecutively.
    x_r = torch.tile(x, (self._n_steps * self._samples_per_step,)).reshape(-1, *x.shape[1:])
    assert x_r.shape[:1] == t.shape[:1]

    if loss_type == 'simple':
      loss = self.p_loss_simple(x_r, t)
    elif loss_type == 'kl':
      loss = self.p_loss_kl(x_r, t)
    else:
      raise ValueError(f'Unrecognized loss type: {loss_type}')
    return loss

  def p_gradient(self, x, t, clip=math.inf):
    """Return ``net.forward_no_grad(x, t)`` scaled by 1 / sqrt(1 - alpha_bar_t).

    ``clip`` is unused; it is kept for signature compatibility.
    """
    # Promote here as well (not only inside forward_no_grad) so a scalar t
    # also works for the coefficient lookup, as in p_energy.
    x = torch.atleast_2d(x)
    t = torch.atleast_1d(t)
    gradient = self.forward_no_grad(x, t)
    return gradient * self._gather_coef(self._sqrt_recipm1_alphas_cumprod_custom, t, gradient)

  def p_energy(self, x, t, clip=math.inf):
    """Return ``net.neg_logp_unnorm(x, t)`` scaled by 1 / sqrt(1 - alpha_bar_t).

    ``clip`` is unused; it is kept for signature compatibility.
    """
    x = torch.atleast_2d(x)
    t = torch.atleast_1d(t)
    assert x.shape[1:] == (self._dim,) and len(t.size()) == 1
    energy = self.net.neg_logp_unnorm(x, t)
    return energy * self._gather_coef(self._sqrt_recipm1_alphas_cumprod_custom, t, energy)