import math
import torch
import torch.nn as nn
import torch.nn.functional as F


def extract(v, t, x_shape):
    """
    Look up ``v[t]`` for every entry of the 1-D index tensor ``t``.

    The gathered values are cast to float32 and viewed as
    ``[len(t), 1, ..., 1]``, with ``len(x_shape) - 1`` trailing singleton
    dimensions, so they broadcast against a tensor of shape ``x_shape``.
    """
    picked = v.gather(0, t).float()
    trailing_ones = (1,) * (len(x_shape) - 1)
    return picked.view(t.shape[0], *trailing_ones)


def vp_var(T, beta_min, beta_max, eps_small=1e-3):
    """Return the variance of the VP-SDE marginal q(x_t | x_0) on a uniform grid.

    The grid has ``T + 1`` points: ``k / T`` for ``k = 0..T`` mapped affinely
    onto ``[eps_small, 1]``. The result is a float64 tensor of length ``T + 1``
    holding ``1 - exp(2 * log_mean_coeff(s))`` at each grid point ``s``.
    """
    grid = torch.arange(0, T + 1, dtype=torch.float64) / T
    s = (1. - eps_small) * grid + eps_small
    log_coef = -0.25 * s ** 2 * (beta_max - beta_min) - 0.5 * s * beta_min
    variance = 1. - torch.exp(2. * log_coef)
    assert variance.shape[0] == (T + 1)
    return variance


def get_beta_schedule(beta_schedule, beta_1, beta_T, T):
    """Build a length-``T`` float32 tensor of diffusion betas.

    Schedules (all computed in float64, then cast to float32):
      * ``'linear'``: ``T`` evenly spaced values from ``beta_1`` to ``beta_T``.
      * ``'cosine'``: betas derived from a squared-cosine ``alpha_bar`` curve
        with offset 0.008, clipped to ``[0, 0.999]``; ``beta_1``/``beta_T``
        are ignored.
      * ``'vpsde'``: betas implied by ``vp_var`` with
        ``beta_min = 1000 * beta_1`` and ``beta_max = 1000 * beta_T``.

    Raises:
        NotImplementedError: if ``beta_schedule`` is not one of the above.
    """
    if beta_schedule == 'linear':
        schedule = torch.linspace(beta_1, beta_T, T, dtype=torch.float64)
    elif beta_schedule == 'cosine':
        offset = 0.008
        steps = torch.linspace(0, T, T + 1, dtype=torch.float64)
        curve = torch.cos(((steps / T) + offset) / (1 + offset) * torch.pi * 0.5) ** 2
        ab = curve / curve[0]
        schedule = torch.clip(1 - (ab[1:] / ab[:-1]), 0, 0.999)
    elif beta_schedule == 'vpsde':
        marginal_var = vp_var(T, beta_1 * 1000, beta_T * 1000)
        ab = 1. - marginal_var
        schedule = 1. - ab[1:] / ab[:-1]
    else:
        raise NotImplementedError(beta_schedule)
    assert schedule.shape[0] == T
    return schedule.type(torch.float32)


def get_reduced_beta_schedule(beta_schedule, betas, T):
    """Sub-sample a length-``ori_T`` beta schedule down to ``T`` steps.

    ``T + 1`` non-decreasing positions ``i_0 = 0, ..., i_T = ori_T`` of the
    original chain are selected:
      * ``'linear'``: every ``ceil(ori_T / T)``-th position, plus ``ori_T``;
      * ``'quadratic'``: ``i_k = ceil(k**2 * ori_T / T**2)``.
    New betas are then chosen so that the reduced chain reproduces the
    original cumulative product ``alpha_bar`` at those positions.

    Args:
        beta_schedule: ``'linear'`` or ``'quadratic'`` (sub-sampling rule).
        betas: 1-D tensor with the ``ori_T`` original betas.
        T: number of steps of the reduced chain.

    Returns:
        ``(rdc_betas, idx)``: the ``T`` reduced betas, and for each reduced
        step the zero-based original step index ``i_k - 1`` it ends at.

    Raises:
        NotImplementedError: for an unknown ``beta_schedule``.
    """
    ori_T = betas.shape[0]
    aug_betas = F.pad(betas, (1, 0), value=0.)
    aug_alphas_bar = torch.cumprod(1. - aug_betas, dim=0)
    if beta_schedule == 'linear':
        stride = math.ceil(ori_T / T)
        idx = torch.arange(0, ori_T, stride)
        idx = torch.cat((idx, torch.tensor([ori_T]))).type(torch.int64)
    elif beta_schedule == 'quadratic':
        # Exact integer ceil(k^2 * ori_T / T^2). A floating-point product can
        # land just above an integer (float32: 900 * (1000 / 900) -> 1000.00006
        # for ori_T=1000, T=30), making ceil overshoot and indexing past ori_T.
        denom = T * T
        idx = torch.tensor([-(-(k * k * ori_T) // denom) for k in range(T + 1)],
                           dtype=torch.int64)
    else:
        raise NotImplementedError(beta_schedule)
    rdc_alphas_bar = aug_alphas_bar[idx]
    assert rdc_alphas_bar.shape[0] == T + 1, \
        'sub-sampling did not produce T + 1 positions for this (ori_T, T) pair'
    # ratio of consecutive selected alpha_bars gives each reduced step's alpha
    rdc_betas = 1. - rdc_alphas_bar[1:] / rdc_alphas_bar[:-1]
    idx = idx[1:] - 1
    return (rdc_betas, idx)


def normal_kl(mean1, logvar1, mean2, logvar2):
    """Element-wise KL divergence KL(N(mean1, e^logvar1) || N(mean2, e^logvar2)).

    The log-variances must be tensors, since they are passed to ``torch.exp``.
    All arguments broadcast against one another.
    """
    var_ratio = torch.exp(logvar1 - logvar2)
    scaled_sq_diff = (mean1 - mean2) ** 2 * torch.exp(-logvar2)
    return 0.5 * (-1.0 + logvar2 - logvar1 + var_ratio + scaled_sq_diff)


def noise_like(shape, device, repeat=False):
    """Draw standard-normal noise of the given ``shape`` on ``device``.

    With ``repeat=True``, a single sample of shape ``(1, *shape[1:])`` is
    drawn and tiled along the first dimension, so every batch entry gets
    identical noise.
    """
    if not repeat:
        return torch.randn(shape, device=device)
    tile = (1,) * (len(shape) - 1)
    single = torch.randn((1, *shape[1:]), device=device)
    return single.repeat(shape[0], *tile)


class DDPM(nn.Module):
    """Discrete-time denoising diffusion probabilistic model.

    Step ``t`` runs over ``[1, T]``, and ``self.model`` is always called with
    the zero-based index ``t - 1``.

    Buffer layout:
      * Forward-process buffers (``betas``, ``sqrt_betas``, ``sqrt_alphas``,
        ``alphas_bar``, ``sqrt_alphas_bar``, ``sqrt_one_minus_alphas_bar``)
        are padded with a leading ``t = 0`` entry (beta = 0, alpha_bar = 1).
        They have length ``T + 1`` and are indexed with ``t``.
      * Posterior and x_0-prediction buffers have length ``T`` and are
        indexed with ``t - 1``.

    Args:
        model: network called as ``model(x_t, t - 1)``; its output is
            interpreted according to ``mean_type``.
        betas: 1-D tensor with the ``T`` forward variances, each in ``(0, 1]``.
        mean_type: what ``model`` predicts: ``'xprev'`` (x_{t-1}),
            ``'xstart'`` (x_0) or ``'epsilon'`` (the injected noise).
        var_type: reverse-process variance, ``'fixedsmall'`` (posterior
            variance) or ``'fixedlarge'`` (beta_t).
    """

    def __init__(self, model, betas,
                 mean_type='epsilon', var_type='fixedsmall'):
        super().__init__()
        assert isinstance(betas, torch.Tensor)
        assert bool(torch.all(betas > 0)) and bool(torch.all(betas <= 1))
        assert mean_type in ['xprev', 'xstart', 'epsilon']
        assert var_type in ['fixedlarge', 'fixedsmall']
        self.model = model
        self.T = int(betas.shape[0])
        self.mean_type = mean_type
        self.var_type = var_type

        # prepend a t = 0 step with beta = 0 so that index t addresses step t
        aug_betas = F.pad(betas, (1, 0), value=0.)
        aug_alphas = 1. - aug_betas
        aug_alphas_bar = torch.cumprod(aug_alphas, dim=0)

        # forward process q(x_t | x_0) and q(x_t | x_{t-1}); length T + 1, indexed by t
        self.register_buffer('betas',
                             aug_betas)
        self.register_buffer('sqrt_betas',
                             torch.sqrt(aug_betas))
        self.register_buffer('sqrt_alphas',
                             torch.sqrt(aug_alphas))
        self.register_buffer('alphas_bar',
                             aug_alphas_bar)
        self.register_buffer('sqrt_alphas_bar',
                             torch.sqrt(aug_alphas_bar))
        self.register_buffer('sqrt_one_minus_alphas_bar',
                             torch.sqrt(1. - aug_alphas_bar))

        alphas = aug_alphas[1:]
        alphas_bar = aug_alphas_bar[1:]
        alphas_bar_prev = aug_alphas_bar[:-1]

        # predicting x_0 from (x_t, eps); length T, indexed by t - 1
        self.register_buffer('sqrt_recip_alphas_bar',
                             torch.sqrt(1. / alphas_bar))
        self.register_buffer('sqrt_recipm1_alphas_bar',
                             torch.sqrt(1. / alphas_bar - 1))

        # posterior q(x_{t-1} | x_t, x_0); length T, indexed by t - 1
        posterior_var = betas * (1. - alphas_bar_prev) / (1. - alphas_bar)
        self.register_buffer('posterior_var',
                             posterior_var)
        # log calculation clipped because posterior_var[0] == 0
        # (alphas_bar_prev[0] == 1 at the beginning of the chain)
        self.register_buffer('posterior_log_var_clipped',
                             torch.log(torch.cat([posterior_var[1:2], posterior_var[1:]])))
        self.register_buffer('posterior_mean_coef1',
                             betas * torch.sqrt(alphas_bar_prev) / (1. - alphas_bar))
        self.register_buffer('posterior_mean_coef2',
                             (1. - alphas_bar_prev) * torch.sqrt(alphas) / (1. - alphas_bar))

    def q_forward_marginal(self, x_0, t):
        """
        Compute the mean and variance of the forward diffusion marginal q(x_t | x_0).

        ``t`` is in ``[0, T]``. At ``t == 0`` the padded beta is 0, so the
        result is ``mean = x_0`` and ``var = 0``.
        """
        mean = extract(self.sqrt_alphas_bar, t, x_0.shape) * x_0
        var = extract(1. - self.alphas_bar, t, x_0.shape)
        return mean, var

    def q_sample(self, x_0, t, noise=None):
        """
        Sample from the forward diffusion q(x_t | x_0), with ``t`` in ``[0, T]``.

        Fresh standard-normal noise is drawn when ``noise`` is None.
        """
        if noise is None:
            noise = torch.randn_like(x_0)
        return (
                extract(self.sqrt_alphas_bar, t, x_0.shape) * x_0
                + extract(self.sqrt_one_minus_alphas_bar, t, x_0.shape) * noise
        )

    def q_conditional_reverse_kernel(self, x_0, x_t, t):
        """
        Compute the posterior q(x_{t-1} | x_t, x_0), with ``t`` in ``[1, T]``.

        Returns ``(mean, var, log_var_clipped)``, each broadcastable to ``x_t``.
        """
        assert x_0.shape == x_t.shape
        temp = t - 1
        posterior_mean = (
                extract(self.posterior_mean_coef1, temp, x_t.shape) * x_0
                + extract(self.posterior_mean_coef2, temp, x_t.shape) * x_t
        )
        posterior_var = extract(self.posterior_var, temp, x_t.shape)
        posterior_log_var_clipped = extract(self.posterior_log_var_clipped, temp, x_t.shape)
        return posterior_mean, posterior_var, posterior_log_var_clipped

    def predict_xstart_from_eps(self, x_t, t, eps):
        """
        Predict x_0 from x_t and the noise eps, with ``t`` in ``[1, T]``.
        """
        assert x_t.shape == eps.shape
        temp = t - 1
        return (
                extract(self.sqrt_recip_alphas_bar, temp, x_t.shape) * x_t
                - extract(self.sqrt_recipm1_alphas_bar, temp, x_t.shape) * eps
        )

    def predict_xstart_from_xprev(self, x_t, t, x_tm1):
        """
        Predict x_0 from x_t and x_{t-1}, with ``t`` in ``[1, T]``.

        This inverts the posterior mean ``x_{t-1} = coef1 * x_0 + coef2 * x_t``.
        """
        assert x_t.shape == x_tm1.shape
        temp = t - 1
        return (  # (x_{t-1} - coef2 * x_t) / coef1
                extract(1. / self.posterior_mean_coef1, temp, x_t.shape) * x_tm1
                - extract(self.posterior_mean_coef2 / self.posterior_mean_coef1, temp, x_t.shape) * x_t
        )

    def p_denoising_kernel(self, x_t, t, clip_denoised: bool):
        """
        Compute the Gaussian denoising kernel p(x_{t-1} | x_t), with ``t`` in ``[1, T]``.

        Returns ``(mean, var, log_var, x_0_prediction)``.
        """
        temp = t - 1
        if self.var_type == 'fixedlarge':
            # the first variance is replaced by the posterior variance to get
            # a better decoder log likelihood; var and log-var stay consistent
            model_var = torch.cat([self.posterior_var[1:2], self.betas[2:]])
            model_log_var = torch.log(model_var)
        elif self.var_type == 'fixedsmall':
            model_var = self.posterior_var
            model_log_var = self.posterior_log_var_clipped
        else:
            raise NotImplementedError(self.var_type)
        model_var = extract(model_var, temp, x_t.shape)
        model_log_var = extract(model_log_var, temp, x_t.shape)

        # Mean parameterization
        _maybe_clip = lambda x_: (torch.clip(x_, -1., 1.) if clip_denoised else x_)
        if self.mean_type == 'xprev':       # the model predicts x_{t-1}
            x_tm1 = self.model(x_t, temp)
            x_0 = _maybe_clip(self.predict_xstart_from_xprev(x_t, t, x_tm1=x_tm1))
            model_mean = x_tm1
        elif self.mean_type == 'xstart':    # the model predicts x_0
            x_0 = _maybe_clip(self.model(x_t, temp))
            model_mean, _, _ = self.q_conditional_reverse_kernel(x_0, x_t, t)
        elif self.mean_type == 'epsilon':   # the model predicts epsilon
            eps = self.model(x_t, temp)
            x_0 = _maybe_clip(self.predict_xstart_from_eps(x_t, t, eps=eps))
            model_mean, _, _ = self.q_conditional_reverse_kernel(x_0, x_t, t)
        else:
            raise NotImplementedError(self.mean_type)

        return model_mean, model_var, model_log_var, x_0

    # === training one step ===
    def one_step_train(self, x_0):
        """Return the element-wise (unreduced) MSE training loss for a batch.

        For each sample, ``t`` is drawn uniformly from ``[1, T]`` and ``x_t``
        is sampled from q(x_t | x_0). The model output (called with index
        ``t - 1``) is regressed onto the target matching ``mean_type``:
          * ``'epsilon'``: the noise;
          * ``'xstart'``: x_0;
          * ``'xprev'``: the posterior mean of q(x_{t-1} | x_t, x_0).
        """
        t = torch.randint(self.T, size=(x_0.shape[0],), device=x_0.device) + 1
        noise = torch.randn_like(x_0)
        x_t = self.q_sample(x_0, t, noise)
        if self.mean_type == 'epsilon':
            target = noise
        elif self.mean_type == 'xstart':
            target = x_0
        elif self.mean_type == 'xprev':
            target, _, _ = self.q_conditional_reverse_kernel(x_0, x_t, t)
        else:
            raise NotImplementedError(self.mean_type)
        # the idx of t in Unet is [0, T-1]
        loss = F.mse_loss(self.model(x_t, t - 1), target, reduction='none')
        return loss

    # === sampling ===
    def ddpm_p_sample(self, x_t, t, repeat_noise=False, clip_denoised=True):
        """Draw x_{t-1} ~ p(x_{t-1} | x_t); no noise is added where ``t == 1``."""
        assert x_t.shape[0] == t.shape[0]
        assert bool(torch.all(t >= 1)) and bool(torch.all(t <= self.T)), 't should be in [1, T]'
        noise = noise_like(x_t.shape, x_t.device, repeat_noise)
        nonzero_mask = (1 - (t == 1).float()).reshape(x_t.shape[0], *((1,) * (len(x_t.shape) - 1)))
        noise = nonzero_mask * noise
        model_mean, _, model_log_var, _ = self.p_denoising_kernel(x_t, t, clip_denoised)
        x_tm1 = model_mean + torch.exp(0.5 * model_log_var) * noise
        return x_tm1

    def ddpm_sample(self, x_T):
        """Run ancestral sampling from ``t = T`` down to 1 and clip the result to [-1, 1]."""
        x_t = x_T
        for time_step in reversed(range(self.T)):
            t = x_t.new_ones([x_T.shape[0], ], dtype=torch.long) * (time_step + 1)
            x_t = self.ddpm_p_sample(x_t, t)
        x_0 = x_t
        return torch.clip(x_0, -1, 1)

    def forward(self, x, mode):
        """Dispatch on ``mode``: ``'train'`` returns the loss, ``'sample'`` returns samples from ``x = x_T``."""
        if mode == 'train':
            return self.one_step_train(x)
        elif mode == 'sample':
            return self.ddpm_sample(x)
        else:
            raise NotImplementedError


class DPM(nn.Module):
    """Diffusion sampler (DDPM ancestral or DDIM-style) over a given beta schedule.

    Step ``t`` runs over ``[1, T]``, where ``T = len(reduced_betas[0])``.

    Buffer layout:
      * Forward-process buffers (``betas``, ``sqrt_betas``, ``sqrt_alphas``,
        ``alphas_bar``, ``sqrt_alphas_bar``, ``sqrt_one_minus_alphas_bar``)
        are padded with a leading ``t = 0`` entry (beta = 0, alpha_bar = 1).
        They have length ``T + 1`` and are indexed with ``t``.
      * Posterior and prediction buffers have length ``T`` and are indexed
        with ``t - 1``.

    Args:
        M: network. With ``mean_type='epsilon'`` it is called as
            ``M(x_t, idx[t - 1])`` and its output is read as the noise. With
            ``'xstart'`` it is called as ``M(x_t, t - 1, u)``, where ``u`` is
            standard-normal of shape ``(batch, nz)``, and its output is read
            as x_0.
        reduced_betas: ``(betas, idx)`` pair, e.g. as returned by
            ``get_reduced_beta_schedule``. ``betas`` holds the ``T`` betas,
            each in ``[0, 1]``. ``idx`` holds the network time index used
            for each step in the ``'epsilon'`` parameterization.
        nz: size of the stochastic input ``u`` (used by ``'xstart'`` only).
        mean_type: ``'xstart'`` or ``'epsilon'``.
        var_type: ``'fixedsmall'`` or ``'fixedlarge'``.
    """

    def __init__(self, M, reduced_betas, nz,
                 mean_type='epsilon', var_type='fixedsmall'):
        super().__init__()
        betas = reduced_betas[0]
        assert isinstance(betas, torch.Tensor)
        assert bool(torch.all(betas >= 0)) and bool(torch.all(betas <= 1))
        assert mean_type in ['xstart', 'epsilon']
        assert var_type in ['fixedlarge', 'fixedsmall']
        self.M = M
        self.T = int(betas.shape[0])
        self.register_buffer('idx', reduced_betas[1])
        self.nz = nz
        self.mean_type = mean_type
        self.var_type = var_type

        # prepend a t = 0 step with beta = 0 so that index t addresses step t
        aug_betas = F.pad(betas, (1, 0), value=0.)
        aug_alphas = 1. - aug_betas
        aug_alphas_bar = torch.cumprod(aug_alphas, dim=0)

        # forward process q(x_t | x_0) and q(x_t | x_{t-1}); length T + 1, indexed by t
        self.register_buffer('betas',
                             aug_betas)
        self.register_buffer('sqrt_betas',
                             torch.sqrt(aug_betas))
        self.register_buffer('sqrt_alphas',
                             torch.sqrt(aug_alphas))
        self.register_buffer('alphas_bar',
                             aug_alphas_bar)
        self.register_buffer('sqrt_alphas_bar',
                             torch.sqrt(aug_alphas_bar))
        self.register_buffer('sqrt_one_minus_alphas_bar',
                             torch.sqrt(1. - aug_alphas_bar))

        alphas = aug_alphas[1:]
        alphas_bar = aug_alphas_bar[1:]
        alphas_bar_prev = aug_alphas_bar[:-1]

        # converting between x_0 and noise; length T, indexed by t - 1
        self.register_buffer('sqrt_recip_alphas_bar',
                             torch.sqrt(1. / alphas_bar))
        self.register_buffer('sqrt_recipm1_alphas_bar',
                             torch.sqrt(1. / alphas_bar - 1))
        self.register_buffer('sqrt_recip_one_minus_alphas_bar',
                             torch.sqrt(1. / (1 - alphas_bar)))
        self.register_buffer('sqrt_recipm1_one_minus_alphas_bar',
                             torch.sqrt(1. / (1 - alphas_bar) - 1))

        # posterior q(x_{t-1} | x_t, x_0); length T, indexed by t - 1
        posterior_var = betas * (1. - alphas_bar_prev) / (1. - alphas_bar)
        self.register_buffer('posterior_var',
                             posterior_var)
        # log calculation clipped because posterior_var[0] == 0
        # (alphas_bar_prev[0] == 1 at the beginning of the chain)
        self.register_buffer('posterior_log_var_clipped',
                             torch.log(torch.cat([posterior_var[1:2], posterior_var[1:]])))
        self.register_buffer('posterior_mean_coef1',
                             betas * torch.sqrt(alphas_bar_prev) / (1. - alphas_bar))
        self.register_buffer('posterior_mean_coef2',
                             (1. - alphas_bar_prev) * torch.sqrt(alphas) / (1. - alphas_bar))

    def q_forward_marginal(self, x_0, t):
        """
        Compute the mean and variance of the forward diffusion marginal q(x_t | x_0).

        ``t`` is in ``[0, T]``. At ``t == 0`` the padded beta is 0, so the
        result is ``mean = x_0`` and ``var = 0``.
        """
        mean = extract(self.sqrt_alphas_bar, t, x_0.shape) * x_0
        var = extract(1. - self.alphas_bar, t, x_0.shape)
        return mean, var

    def q_sample(self, x_0, t, noise=None):
        """
        Sample from the forward diffusion q(x_t | x_0), with ``t`` in ``[0, T]``.

        Fresh standard-normal noise is drawn when ``noise`` is None.
        """
        if noise is None:
            noise = torch.randn_like(x_0)
        return (
                extract(self.sqrt_alphas_bar, t, x_0.shape) * x_0
                + extract(self.sqrt_one_minus_alphas_bar, t, x_0.shape) * noise
        )

    def q_forward_kernel(self, x_tm1, t):
        """
        Compute the mean and variance of the forward kernel q(x_t | x_{t-1}),
        with ``t`` in ``[1, T]``.
        """
        mean = extract(self.sqrt_alphas, t, x_tm1.shape) * x_tm1
        var = extract(self.betas, t, x_tm1.shape)
        return mean, var

    def q_forward_step(self, x_tm1, t, noise=None):
        """
        Sample from the forward kernel q(x_t | x_{t-1}), with ``t`` in ``[1, T]``.
        """
        if noise is None:
            noise = torch.randn_like(x_tm1)
        return (
                extract(self.sqrt_alphas, t, x_tm1.shape) * x_tm1
                + extract(self.sqrt_betas, t, x_tm1.shape) * noise
        )

    def q_forward_pair(self, x_0, t, fixed=False):
        """
        Sample ``(x_t, x_{t+1})`` given ``x_0``, with ``t`` in ``[0, T - 1]``.

        With ``fixed=False`` (the default), this is a joint sample from
        q(x_t | x_0) q(x_{t+1} | x_t).

        With ``fixed=True``, x_t, x_{t+1} and x_T are each drawn from their
        marginal q(. | x_0) using one shared noise tensor, and the triple
        ``(x_t, x_tp1, x_T)`` is returned.
        """
        if fixed:
            noise = torch.randn_like(x_0)
            t_last = x_0.new_ones([x_0.shape[0], ], dtype=torch.long) * self.T
            x_t = self.q_sample(x_0, t, noise)
            x_tp1 = self.q_sample(x_0, t + 1, noise)
            x_T = self.q_sample(x_0, t_last, noise)
            return x_t, x_tp1, x_T
        x_t = self.q_sample(x_0, t)
        x_tp1 = self.q_forward_step(x_t, t + 1)
        return x_t, x_tp1

    def q_conditional_reverse_kernel(self, x_0, x_t, t):
        """
        Compute the DDPM posterior q(x_{t-1} | x_t, x_0), with ``t`` in ``[1, T]``.

        Returns ``(mean, var, log_var_clipped)``, each broadcastable to ``x_t``.
        """
        assert x_0.shape == x_t.shape
        temp = t - 1
        posterior_mean = (
                extract(self.posterior_mean_coef1, temp, x_t.shape) * x_0
                + extract(self.posterior_mean_coef2, temp, x_t.shape) * x_t
        )
        posterior_var = extract(self.posterior_var, temp, x_t.shape)
        posterior_log_var_clipped = extract(self.posterior_log_var_clipped, temp, x_t.shape)
        return posterior_mean, posterior_var, posterior_log_var_clipped

    def predict_xstart_from_eps(self, x_t, t, eps):
        """
        Predict x_0 from x_t and the noise eps, with ``t`` in ``[1, T]``.
        """
        assert x_t.shape == eps.shape
        temp = t - 1
        return (
                extract(self.sqrt_recip_alphas_bar, temp, x_t.shape) * x_t
                - extract(self.sqrt_recipm1_alphas_bar, temp, x_t.shape) * eps
        )

    def predict_eps_from_xstart(self, x_t, t, x_0):
        """
        Predict the noise from x_t and x_0, with ``t`` in ``[1, T]``.

        This computes ``(x_t - sqrt(alpha_bar_t) x_0) / sqrt(1 - alpha_bar_t)``.
        """
        assert x_t.shape == x_0.shape
        temp = t - 1
        return (
                extract(self.sqrt_recip_one_minus_alphas_bar, temp, x_t.shape) * x_t
                - extract(self.sqrt_recipm1_one_minus_alphas_bar, temp, x_t.shape) * x_0
        )

    def _predict_xstart(self, x_t, t, clip_denoised):
        """Run the network at step ``t`` (in [1, T]) and return its x_0 estimate."""
        temp = t - 1
        if self.mean_type == 'xstart':    # the model predicts x_0 with stochastic input u
            u = noise_like([x_t.shape[0], self.nz], x_t.device)
            # NOTE(review): this branch passes t - 1 while the 'epsilon' branch
            # maps the step through self.idx -- confirm this is intended
            x_0 = self.M(x_t, temp, u)
        elif self.mean_type == 'epsilon':   # the model predicts epsilon deterministically
            eps = self.M(x_t, self.idx[temp])
            x_0 = self.predict_xstart_from_eps(x_t, t, eps=eps)
        else:
            raise NotImplementedError(self.mean_type)
        return torch.clip(x_0, -1., 1.) if clip_denoised else x_0

    def _iddim_coefs(self, temp, shape):
        """Return ``(c_x0, c_xT)`` such that ``z_{t-1} = c_x0 * x_0 + c_xT * x_T``.

        ``z_{t-1}`` re-noises x_0 to step ``t - 1`` using the noise implied by
        ``(x_0, x_T)``: ``eps_T = (x_T - sqrt(abar_T) x_0) / sqrt(1 - abar_T)``.
        ``temp`` is ``t - 1``.
        """
        c_x0 = extract(
            self.sqrt_alphas_bar
            - self.sqrt_one_minus_alphas_bar * self.sqrt_recipm1_one_minus_alphas_bar[-1],
            temp, shape)
        c_xT = extract(
            self.sqrt_one_minus_alphas_bar * self.sqrt_recip_one_minus_alphas_bar[-1],
            temp, shape)
        return c_x0, c_xT

    def ddpm_denoising_kernel(self, x_t, t, clip_denoised: bool):
        """
        Compute the DDPM Gaussian denoising kernel p(x_{t-1} | x_t), with ``t`` in ``[1, T]``.

        Returns ``(mean, var, log_var, x_0_prediction)``.
        """
        temp = t - 1
        if self.var_type == 'fixedlarge':
            # the first variance is replaced by the posterior variance to get
            # a better decoder log likelihood; var and log-var stay consistent
            model_var = torch.cat([self.posterior_var[1:2], self.betas[2:]])
            model_log_var = torch.log(model_var)
        elif self.var_type == 'fixedsmall':
            model_var = self.posterior_var
            model_log_var = self.posterior_log_var_clipped
        else:
            raise NotImplementedError(self.var_type)
        model_var = extract(model_var, temp, x_t.shape)
        model_log_var = extract(model_log_var, temp, x_t.shape)

        x_0 = self._predict_xstart(x_t, t, clip_denoised)
        model_mean, _, _ = self.q_conditional_reverse_kernel(x_0, x_t, t)
        return model_mean, model_var, model_log_var, x_0

    def iddim_denoising_kernel(self, x_t, y_t, t, x_T, m, clip_denoised: bool):
        """One interpolated-DDIM step from ``t`` (in [1, T]) to ``t - 1``.

        ``z_{t-1}`` re-noises the predicted x_0 with the noise implied by x_T.
        ``y_{t-1}`` is a deterministic DDIM step from ``y_t``. The returned
        state is ``x_{t-1} = (1 - m) * y_{t-1} + m * z_{t-1}``.

        When ``y_t`` is None or ``m == 1.0``, it returns
        ``(z_{t-1}, None, x_0)``; otherwise it returns
        ``(x_{t-1}, y_{t-1}, x_0)``.
        """
        temp = t - 1
        x_0 = self._predict_xstart(x_t, t, clip_denoised)
        c_x0, c_xT = self._iddim_coefs(temp, x_0.shape)
        z_tm1 = c_x0 * x_0 + c_xT * x_T
        if y_t is None or m == 1.0:
            return z_tm1, None, x_0
        eps_ = self.predict_eps_from_xstart(y_t, t, x_0=x_0)
        y_tm1 = (
            extract(self.sqrt_alphas_bar, temp, x_0.shape) * x_0
            + extract(self.sqrt_one_minus_alphas_bar, temp, x_0.shape) * eps_
        )
        x_tm1 = (1. - m) * y_tm1 + m * z_tm1
        return x_tm1, y_tm1, x_0

    def iddim_denoising_kernel_display(self, x_t, y_t, t, x_T, m, clip_denoised: bool):
        """Same step as ``iddim_denoising_kernel``, plus a fourth output ``x_0_``.

        ``x_0_`` is the x_0 that would yield ``x_{t-1}`` under the ``z``
        formula, i.e. ``(x_{t-1} - c_xT * x_T) / c_x0``. It equals the
        predicted x_0 when no mixing happens.
        """
        x_tm1, y_tm1, x_0 = self.iddim_denoising_kernel(x_t, y_t, t, x_T, m, clip_denoised)
        if y_tm1 is None:
            return x_tm1, None, x_0, x_0
        c_x0, c_xT = self._iddim_coefs(t - 1, x_0.shape)
        x_0_ = (x_tm1 - c_xT * x_T) / c_x0
        return x_tm1, y_tm1, x_0, x_0_

    # === sampling ===
    def ddpm_p_sample(self, x_t, t, repeat_noise=False, clip_denoised=True):
        """Draw x_{t-1} ~ p(x_{t-1} | x_t) without tracking gradients through the network.

        No noise is added where ``t == 1``.
        """
        assert x_t.shape[0] == t.shape[0]
        assert bool(torch.all(t >= 1)) and bool(torch.all(t <= self.T)), 't should be in [1, T]'
        noise = noise_like(x_t.shape, x_t.device, repeat_noise)
        nonzero_mask = (1 - (t == 1).float()).reshape(x_t.shape[0], *((1,) * (len(x_t.shape) - 1)))
        noise = nonzero_mask * noise
        with torch.no_grad():
            model_mean, _, model_log_var, _ = self.ddpm_denoising_kernel(x_t, t, clip_denoised)
        x_tm1 = model_mean + torch.exp(0.5 * model_log_var) * noise
        return x_tm1

    def ddpm_sample(self, x_T):
        """Run ancestral sampling from ``t = T`` down to 1 and clip the result to [-1, 1]."""
        x_t = x_T
        for time_step in reversed(range(self.T)):
            t = x_t.new_ones([x_T.shape[0], ], dtype=torch.long) * (time_step + 1)
            x_t = self.ddpm_p_sample(x_t, t)
        x_0 = x_t
        return torch.clip(x_0, -1, 1)

    def iddim_p_sample(self, x_t, y_t, t, t_T, m, clip_denoised=True):
        """One interpolated-DDIM step without gradients; ``t_T`` is the initial state x_T."""
        assert x_t.shape[0] == t.shape[0]
        assert bool(torch.all(t >= 1)) and bool(torch.all(t <= self.T)), 't should be in [1, T]'
        with torch.no_grad():
            x_tm1, y_tm1, _ = self.iddim_denoising_kernel(x_t, y_t, t, t_T, m, clip_denoised)
        return x_tm1, y_tm1

    def iddim_sample(self, x_T, m):
        """Run interpolated-DDIM sampling from ``x_T`` with mixing weight ``m``.

        The result is clipped to [-1, 1].
        """
        x_t = x_T
        y_t = x_T
        for time_step in reversed(range(self.T)):
            t = x_t.new_ones([x_T.shape[0], ], dtype=torch.long) * (time_step + 1)
            x_t, y_t = self.iddim_p_sample(x_t, y_t, t, x_T, m)
        x_0 = x_t
        return torch.clip(x_0, -1, 1)

    def forward(self, x, mode, m=None):
        """Dispatch on ``mode``: ``'ddpm_sample'`` or ``'iddim_sample'``.

        ``'iddim_sample'`` requires the mixing weight ``m``.
        """
        if mode == 'ddpm_sample':
            return self.ddpm_sample(x)
        elif mode == 'iddim_sample':
            if m is None:
                raise ValueError("mode 'iddim_sample' requires the mixing weight m")
            return self.iddim_sample(x, m)
        else:
            raise NotImplementedError
