"""
Based on: https://github.com/crowsonkb/k-diffusion
"""
# import random

import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from .nn import mean_flat, append_dims, append_zero

import os
import tools.torch_tools as torch_tools

def normalize_tensor(in_feat,eps=1e-10):
    """Scale ``in_feat`` to unit L2 norm along dimension 1.

    ``eps`` is added to the norm before dividing, so an all-zero slice yields
    zeros rather than a division by zero.
    """
    squared_sum = th.sum(in_feat ** 2, dim=1, keepdim=True)
    l2_norm = th.sqrt(squared_sum)
    return in_feat / (l2_norm + eps)

def get_weightings(weight_schedule, snrs, sigma_data, t, s):
    """Return per-sample loss weights for the named weighting schedule.

    Args:
        weight_schedule: Name of the schedule (see the branches below).
        snrs: Signal-to-noise ratios; used by every schedule except
            ``"uniform_g"``.
        sigma_data: Data standard deviation, used by the ``"karras"`` and
            ``"karras_weight"`` schedules.
        t, s: Start and target times; only read by ``"uniform_g"``.

    Returns:
        The weights computed from ``snrs`` (or from ``t`` and ``s`` for
        ``"uniform_g"``).

    Raises:
        NotImplementedError: If ``weight_schedule`` is not a known name.
    """
    if weight_schedule == "snr":
        weightings = snrs
    elif weight_schedule == "sq-snr":
        weightings = snrs**0.5
    elif weight_schedule == "snr+1":
        weightings = snrs + 1
    elif weight_schedule == "karras":
        weightings = snrs + 1.0 / sigma_data**2
    elif weight_schedule == "truncated-snr":
        weightings = th.clamp(snrs, min=1.0)
    elif weight_schedule == "uniform":
        weightings = th.ones_like(snrs)
    elif weight_schedule == "uniform_g":
        # The only schedule that depends on the (t, s) pair instead of the SNR.
        return 1./(1. - s / t)
    elif weight_schedule == "karras_weight":
        sigma = snrs ** -0.5
        weightings = (sigma ** 2 + sigma_data ** 2) / (sigma * sigma_data) ** 2
    elif weight_schedule == "sq-t-inverse":
        weightings = 1. / snrs ** 0.25
    else:
        # Name the offending value so a misconfigured run is easy to diagnose.
        raise NotImplementedError(f"unknown weight_schedule: {weight_schedule!r}")
    return weightings

def hinge_d_loss(logits_real, logits_fake):
    """Hinge discriminator loss: mean of the real and fake hinge terms, halved.

    Real logits are penalised when below +1, fake logits when above -1.
    """
    real_term = F.relu(1. - logits_real).mean()
    fake_term = F.relu(1. + logits_fake).mean()
    return 0.5 * (real_term + fake_term)


def vanilla_d_loss(logits_real, logits_fake):
    """Softplus discriminator loss averaged over the real and fake terms."""
    real_term = th.mean(F.softplus(-logits_real))
    fake_term = th.mean(F.softplus(logits_fake))
    return 0.5 * (real_term + fake_term)

class KarrasDenoiser:
    def __init__(
        self,
        args,
        schedule_sampler,
        diffusion_schedule_sampler,
        feature_networks=None,
    ):
        self.args = args
        self.schedule_sampler = schedule_sampler
        self.diffusion_schedule_sampler = diffusion_schedule_sampler
        # self.omega_sampler = OmegaSampler(args.omega_min, args.omega_max)
        self.feature_networks = feature_networks
        self.num_timesteps = args.start_scales
        self.dist = nn.MSELoss(reduction='none')

    def get_snr(self, sigmas):
        return sigmas**-2

    def get_sigmas(self, sigmas):
        return sigmas

    def get_c_in(self, sigma):
        return 1 / (sigma**2 + self.args.sigma_data**2) ** 0.5

    def get_scalings(self, sigma):
        c_skip = self.args.sigma_data**2 / (sigma**2 + self.args.sigma_data**2)
        c_out = sigma * self.args.sigma_data / (sigma**2 + self.args.sigma_data**2) ** 0.5
        return c_skip, c_out

    def get_scalings_t(self, t, s): # TODO: check what this is later
        c_skip = th.zeros_like(t)
        c_out = ((t ** 2 + self.args.sigma_data ** 2) / (s ** 2 + self.args.sigma_data ** 2)) ** 0.5
        return c_skip, c_out

    def get_scalings_for_generalized_boundary_condition(self, t, s):
        if self.args.parametrization.lower() == 'euler':
            c_skip = s / t
        elif self.args.parametrization.lower() == 'variance':
            c_skip = (((s - self.args.sigma_min) ** 2 + self.args.sigma_data ** 2) / ((t - self.args.sigma_min) ** 2 + self.args.sigma_data ** 2)).sqrt()
        elif self.args.parametrization.lower() == 'euler_variance_mixed':
            c_skip = s / (t + 1.) + \
                     (((s - self.args.sigma_min) ** 2 + self.args.sigma_data ** 2) /
                      ((t - self.args.sigma_min) ** 2 + self.args.sigma_data ** 2)).sqrt() / (t + 1.)
        c_out = (1. - s / t)
        return c_skip, c_out

    def get_scalings_for_boundary_condition(self, sigma):
        c_skip = self.args.sigma_data**2 / (
            (sigma - self.args.sigma_min) ** 2 + self.args.sigma_data**2
        )
        c_out = (
            (sigma - self.args.sigma_min)
            * self.args.sigma_data
            / (sigma**2 + self.args.sigma_data**2) ** 0.5 
        )
        return c_skip, c_out

    def calculate_adaptive_weight(self, loss1, loss2, last_layer=None, allow_unused=False):
        loss1_grad = th.autograd.grad(loss1, last_layer, retain_graph=True, allow_unused=allow_unused)[0]
        loss2_grad = th.autograd.grad(loss2, last_layer, retain_graph=True)[0]
        d_weight = th.norm(loss1_grad) / (th.norm(loss2_grad) + 1e-8)
        #print("consistency gradient: ", th.norm(loss1_grad))
        #print("denoising gradient: ", th.norm(loss2_grad))
        #print("weight: ", d_weight)
        d_weight = th.clamp(d_weight, 0.0, 1e3).detach()
        return d_weight

    def adopt_weight(self, weight, global_step, threshold=0, value=0.):
        if global_step < threshold:
            weight = value
        return weight

    def rescaling_t(self, t):
        rescaled_t = 1000 * 0.25 * th.log(t + 1e-44)
        return rescaled_t

    def get_t(self, ind): 
        if self.args.time_continuous:
            t = self.args.sigma_max ** (1 / self.args.rho) + ind * (
                    self.args.sigma_min ** (1 / self.args.rho) - self.args.sigma_max ** (1 / self.args.rho)
            )
            t = t ** self.args.rho
        else: # Same as EDM's eq.(5)
            t = self.args.sigma_max ** (1 / self.args.rho) + ind / (self.args.start_scales - 1) * (
                    self.args.sigma_min ** (1 / self.args.rho) - self.args.sigma_max ** (1 / self.args.rho)
            )
            t = t ** self.args.rho
        return t

    def get_num_heun_step(self, step):
        if self.args.num_heun_step_random:
            #if step % self.args.g_learning_period == 0:
            if self.args.time_continuous:
                num_heun_step = np.random.rand() * self.args.num_heun_step / self.args.start_scales
            else:
                if self.args.heun_step_strategy == 'uniform':
                    num_heun_step = np.random.randint(1, 1+self.args.num_heun_step)
                elif self.args.heun_step_strategy == 'weighted':
                    p = np.array([i ** self.args.heun_step_multiplier for i in range(1, 1+self.args.num_heun_step)])
                    p = p / sum(p)
                    num_heun_step = np.random.choice([i+1 for i in range(len(p))], size=1, p=p)[0]
            # else:
            #    if self.args.time_continuous:
            #        num_heun_step = np.random.rand() / self.args.d_learning_period +\
            #                        (self.args.d_learning_period - 1) / self.args.d_learning_period
            #    else:
                    #num_heun_step = np.random.randint((self.args.d_learning_period - 1) * self.args.start_scales //
                    #                                  self.args.d_learning_period, 1+self.args.num_heun_step)
            #        num_heun_step = self.args.num_heun_step
        else:
            if self.args.time_continuous:
                num_heun_step = self.args.num_heun_step / self.args.start_scales
            else:
                num_heun_step = self.args.num_heun_step
        return num_heun_step

    @th.no_grad()
    def heun_solver(self, x, ind, teacher_model, dims, cond, num_step=1):
        """Advance ``x`` by ``num_step`` Heun (second-order) steps with the teacher.

        Each step goes from ``get_t(ind + k)`` to ``get_t(ind + k + 1)``: an
        Euler predictor is followed by a correction that averages the slopes
        at both ends. ``dims`` is the number of dimensions that time tensors
        are expanded to via ``append_dims``. Runs without gradient tracking.
        """
        for offset in range(num_step):
            t_cur = self.get_t(ind + offset)
            denoised_cur = self.denoise_fn(teacher_model, x, t_cur, cond=cond, s=None, ctm=False, teacher=True) # D_{\theta}
            slope_cur = (x - denoised_cur) / append_dims(t_cur, dims)

            t_next = self.get_t(ind + offset + 1)
            # Euler predictor at the next time.
            x_euler = x + slope_cur * append_dims(t_next - t_cur, dims)
            denoised_next = self.denoise_fn(teacher_model, x_euler, t_next, cond=cond, s=None, ctm=False, teacher=True)
            slope_next = (x_euler - denoised_next) / append_dims(t_next, dims)

            # Corrector: average of the two slopes.
            x = x + (slope_cur + slope_next) * append_dims((t_next - t_cur) / 2, dims)
        return x

    @th.no_grad()
    def heun_solver_cfg(self, x, ind, guidance_scale, teacher_model, dims, cond, num_step=1):
        """Heun solver with classifier-free guidance applied to the teacher output.

        For every teacher evaluation the input and time are duplicated, the
        first half is paired with ``cond`` and the second half with empty
        prompts, and the two outputs are combined as
        ``uncond + guidance_scale * (text - uncond)``. ``cond`` is
        concatenated with a list, so it is expected to be a list. Runs without
        gradient tracking.
        """
        def guided_denoise(latent, sigma, prompts):
            # One batched teacher call covering conditional and unconditional halves.
            sigma_pair = th.cat([sigma] * 2)
            latent_pair = th.cat([latent] * 2)
            stacked = self.denoise_fn(teacher_model, latent_pair, sigma_pair, cond=prompts, s=None, ctm=False, teacher=True) # D_{\theta}
            with_text, without_text = stacked.chunk(2)
            return without_text + append_dims(guidance_scale, dims) * (with_text - without_text)

        for offset in range(num_step):
            t_cur = self.get_t(ind + offset)
            cond_cfg = cond + ([""] * len(cond))
            slope_cur = (x - guided_denoise(x, t_cur, cond_cfg)) / append_dims(t_cur, dims)

            t_next = self.get_t(ind + offset + 1)
            # Euler predictor at the next time.
            x_euler = x + slope_cur * append_dims(t_next - t_cur, dims)
            slope_next = (x_euler - guided_denoise(x_euler, t_next, cond_cfg)) / append_dims(t_next, dims)

            # Corrector: average of the two slopes.
            x = x + (slope_cur + slope_next) * append_dims((t_next - t_cur) / 2, dims)
        return x

    def get_gan_estimate(self, estimate, step, x_t, t, t_dt, s, model, target_model, ctm, cond):
        if self.args.gan_estimate_type == 'consistency':
            # NOTE: If we use different timestep for gan, then use here.
            estimate = self.denoise_fn(model, x_t, t, cond=cond, s=th.ones_like(s) * self.args.sigma_min, ctm=ctm)
        elif self.args.gan_estimate_type == 'enable_grad':
            if self.args.auxiliary_type == 'enable_grad':
                estimate = estimate
            else:
                estimate = self.get_estimate(step, x_t, t, t_dt, s, model, target_model, ctm=ctm,
                                             auxiliary_type='enable_grad')
        elif self.args.gan_estimate_type == 'only_high_freq':
            estimate = self.get_estimate(step, x_t, t, t_dt, s, model, target_model, ctm=ctm,
                                         type='stop_grad', auxiliary_type='enable_grad')
        elif self.args.gan_estimate_type == 'same':
            estimate = estimate
        return estimate

    def get_estimate(self, step, x_t, t, t_dt, s, model, target_model, ctm, cfg=None, cond=None, type=None, auxiliary_type=None):
        """Student estimate used on the prediction side of the consistency loss.

        ``model`` jumps ``x_t`` from time ``t`` to time ``s``. When
        ``args.match_point`` is ``'zs'`` that result is returned directly;
        otherwise ``target_model`` carries it on from ``s`` to ``sigma_min``.

        ``step``, ``t_dt``, ``type`` and ``auxiliary_type`` are accepted for
        call-site compatibility but are not used here.
        """
        distiller = self.denoise_fn(model, x_t, t, cond=cond, s=s, ctm=ctm, cfg=cfg)
        if self.args.match_point == 'zs':
            return distiller
        sigma_min = th.ones_like(s) * self.args.sigma_min
        return self.denoise_fn(target_model, distiller, s, cond=cond, s=sigma_min, ctm=ctm, cfg=cfg)

    @th.no_grad()
    def get_target(self, step, x_t_dt, t_dt, s, model, target_model, ctm, cond, cfg=None):
        """Gradient-free target used on the reference side of the consistency loss.

        ``target_model`` jumps ``x_t_dt`` from ``t_dt`` to ``s``; unless
        ``args.match_point`` is ``'zs'`` it then jumps again from ``s`` to
        ``sigma_min``. The result is detached. ``step`` and ``model`` are not
        used.
        """
        jumped = self.denoise_fn(target_model, x_t_dt, t_dt, cond=cond, s=s, ctm=ctm, cfg=cfg)
        if self.args.match_point != 'zs':
            jumped = self.denoise_fn(target_model, jumped, s, cond=cond, s=th.ones_like(s) * self.args.sigma_min, ctm=ctm, cfg=cfg)
        return jumped.detach()

    def denoise_fn(self, model, x, t, cond, s, ctm=False, cfg=None, teacher=False):
        """Call ``denoise`` and return only its second output (the denoised sample)."""
        _, denoised = self.denoise(model, x, t, cond=cond, s=s, ctm=ctm, teacher=teacher, cfg=cfg)
        return denoised

    def denoise(self, model, x_t, t, cond=None, s=None, ctm=False, teacher=False, cfg=None):
        """Run ``model`` on the scaled input and apply the configured output parametrization.

        Args:
            model: Callable invoked as ``model(c_in * x_t, t, prompt=cond,
                s_tinmesteps=s, teacher=teacher, cfg=cfg)``.
            x_t: Noisy input.
            t: Current time / noise level (one value per batch element is
                assumed, since it is expanded with ``append_dims``).
            cond: Conditioning passed to the model as ``prompt``.
            s: Target time; required when ``ctm`` is true and ``teacher`` is false.
            ctm: Selects the CTM parametrization branch.
            teacher: Selects the teacher-specific scalings.
            cfg: Passed through to the model.

        Returns:
            ``(model_output, denoised)``. With ``ctm=True`` the first element
            is the output after the inner parametrization (g_theta); with
            ``ctm=False`` it is the raw network output.

        Raises:
            NotImplementedError: ``ctm`` and ``teacher`` are both true and
                ``args.parametrization`` is unknown.
        """
        # NOTE: We don't need rescaling stuff if the model trained with EDM's VESDE,
        # so t and s are passed to the network as they are.
        ndim = x_t.ndim
        c_in = append_dims(self.get_c_in(t), ndim)
        model_output = model(c_in * x_t, t, prompt=cond, s_tinmesteps=s, teacher=teacher, cfg=cfg)

        if not ctm:
            if teacher:
                scalings = self.get_scalings(t)
            else:
                # CM's boundary condition at the smallest time instant. See Appendix C of the CM paper.
                scalings = self.get_scalings_for_boundary_condition(t)
            c_skip, c_out = [append_dims(c, ndim) for c in scalings]
            denoised = c_out * model_output + c_skip * x_t
            return model_output, denoised

        inner = self.args.inner_parametrization
        if inner == 'edm':
            c_skip, c_out = [append_dims(c, ndim) for c in self.get_scalings(t)]
            model_output = c_out * model_output + c_skip * x_t # g_{\theta}, Same as EDM's eq.(7)
        elif inner == 'scale':
            c_skip, c_out = [append_dims(c, ndim) for c in self.get_scalings_t(t, s)]
            model_output = c_out * model_output + c_skip * x_t
        # 'no' (and any other value) keeps the raw network output.

        if teacher:
            parametrization = self.args.parametrization.lower()
            sigma_min = self.args.sigma_min
            sigma_data = self.args.sigma_data
            if parametrization == 'euler': # NOTE: Normally, do here.
                denoised = model_output
            elif parametrization == 'variance':
                coeff = (sigma_min ** 2 + sigma_data ** 2 - sigma_min * t) / \
                        ((t - sigma_min) ** 2 + sigma_data ** 2)
                denoised = model_output + append_dims(coeff, ndim) * x_t
            elif parametrization == 'euler_variance_mixed':
                coeff = t / (t + 1.) * (1. + (t - sigma_min) /
                                        ((t - sigma_min) ** 2 + sigma_data ** 2))
                denoised = model_output + x_t - append_dims(coeff, ndim) * x_t
            else:
                # Previously fell through and failed with UnboundLocalError on `denoised`.
                raise NotImplementedError(
                    f"unknown parametrization: {self.args.parametrization!r}"
                )
            return model_output, denoised

        # Identity check: `!=` against None is non-idiomatic and relies on
        # how the tensor type implements comparison.
        assert s is not None, "s is required when ctm=True and teacher=False"
        c_skip, c_out = [
            append_dims(c, ndim)
            for c in self.get_scalings_for_generalized_boundary_condition(t, s)
        ]
        denoised = c_out * model_output + c_skip * x_t # G_{\theta} Last eq of Lemma 1 on CTM paper.
        return model_output, denoised

    def get_consistency_loss(self, estimate, target, weights, loss_domain='latent', loss_norm='', teacher_model=None, s=None, prompt=None):
        """Weighted distance between ``estimate`` and ``target``.

        Args:
            estimate, target: Tensors to compare.
            weights: Multiplied into the reduced loss.
            loss_domain: Only ``'latent'`` is implemented for the ``'l2'``,
                ``'l1'`` and ``'ictm'`` norms.
            loss_norm: ``'l2'``, ``'l1'``, ``'ictm'`` (pseudo-Huber) or
                ``'feature_space'``.
            teacher_model: Needed for ``'feature_space'``; its
                ``extract_feature_space`` method supplies the features.
            s: Time tensor, used by ``'feature_space'``.
            prompt: Forwarded to ``extract_feature_space``.

        Returns:
            The weighted loss. If its mean is NaN it is passed through
            ``self.null`` (defined outside this view) first.

        Raises:
            NotImplementedError: Unknown ``loss_norm``, or a ``loss_domain``
                other than ``'latent'`` for the plain norms.
        """
        if loss_norm in ('l2', 'l1', 'ictm'):
            if loss_domain != 'latent':
                # 'mel' and 'waveform' were never implemented; other values used to
                # fail later with UnboundLocalError.
                raise NotImplementedError(
                    f"loss_domain {loss_domain!r} is not supported for loss_norm {loss_norm!r}"
                )
            if loss_norm == 'l2':
                consistency_loss = weights * mean_flat((estimate - target) ** 2)
            elif loss_norm == 'l1':
                consistency_loss = weights * mean_flat(th.abs(estimate - target))
            else:  # Pseudo-Huber loss
                c = 0.00054 * th.sqrt(th.tensor(self.args.latent_channels*self.args.latent_f_size*self.args.latent_t_size))
                consistency_loss = weights * mean_flat(th.sqrt((estimate - target) ** 2 + c ** 2) - c)

        elif loss_norm == 'feature_space':
            estimate_out = estimate
            target_out = target
            # NOTE(review): for any other match_point the raw tensors are iterated
            # below as if they were feature lists - confirm whether that is intended.
            if self.args.match_point in ('z0', 'zs'):
                if self.args.match_point == 'z0':
                    timesteps = th.ones_like(s) * self.args.sigma_min
                else:
                    timesteps = s
                c_in = append_dims(self.get_c_in(timesteps), estimate.ndim)
                estimate_out = teacher_model.extract_feature_space(estimate * c_in, timesteps=timesteps, prompt=prompt, unet_mode = self.args.unet_mode)
                target_out = teacher_model.extract_feature_space(target * c_in, timesteps=timesteps, prompt=prompt, unet_mode = self.args.unet_mode)

            consistency_loss = 0.
            if self.args.loss_distance == 'l2':
                for est_feature, tgt_feature in zip(estimate_out, target_out):
                    layer_loss = mean_flat((normalize_tensor(est_feature) - normalize_tensor(tgt_feature)) ** 2)
                    if th.isnan(layer_loss.mean()):
                        consistency_loss += self.null(layer_loss)
                    else:
                        consistency_loss += layer_loss
            elif self.args.loss_distance == 'l1':
                for est_feature, tgt_feature in zip(estimate_out, target_out):
                    layer_loss = F.l1_loss(normalize_tensor(est_feature), normalize_tensor(tgt_feature))
                    if th.isnan(layer_loss.mean()):
                        consistency_loss += self.null(layer_loss.mean())
                    else:
                        consistency_loss += layer_loss
            consistency_loss = weights * consistency_loss

        else:
            raise NotImplementedError

        if th.isnan(consistency_loss.mean()):
            consistency_loss = self.null(consistency_loss)
        return consistency_loss

    def get_denoising_loss(self, model, x_start, consistency_loss, step, cond, loss_target):
        """Denoising loss of ``model`` at freshly sampled noise levels.

        Noise levels come from ``self.diffusion_schedule_sampler``; ``x_start``
        is perturbed with Gaussian noise and ``model`` is evaluated with
        ``s == t`` (g_theta). The squared error against ``loss_target`` is
        weighted by ``args.diffusion_weight_schedule``.

        When ``args.apply_adaptive_weight`` is set and the loss is not NaN, the
        loss is rescaled by the ratio of gradient norms of ``consistency_loss``
        and this loss at the model's ``ctm_unet.conv_out.weight``; otherwise
        the balance weight is 1. Previously that disabled case raised
        ``UnboundLocalError``.

        Returns:
            The balanced denoising loss.
        """
        # The weights returned by the sampler are not used; they are recomputed below.
        sigmas, _ = self.diffusion_schedule_sampler.sample(x_start.shape[0], x_start.device)
        noise = th.randn_like(x_start)
        dims = x_start.ndim
        x_t = x_start + noise * append_dims(sigmas, dims)
        if self.args.unform_sampled_cfg_distill:
            # self.sampled_cfg is expected to be set elsewhere before this call.
            model_estimate = self.denoise(model, x_t, sigmas, cond=cond, s=sigmas, ctm=True, teacher=False, cfg=self.sampled_cfg)[0] # g_{\theta}(z_t, cond, t, t, omega)
        else:
            model_estimate = self.denoise(model, x_t, sigmas, cond=cond, s=sigmas, ctm=True, teacher=False)[0] # g_{\theta}(z_t, cond, t, t)
        snrs = self.get_snr(sigmas)
        denoising_weights = append_dims(get_weightings(self.args.diffusion_weight_schedule, snrs, self.args.sigma_data, None, None), dims)
        denoising_loss = mean_flat(denoising_weights * (model_estimate - loss_target) ** 2)

        balance_weight = 1.
        if th.isnan(denoising_loss.mean()):
            denoising_loss = self.null(denoising_loss)
        elif self.args.apply_adaptive_weight:
            # A wrapped model exposes the network under `.module`.
            core = model if hasattr(model, 'ctm_unet') else model.module
            balance_weight = self.calculate_adaptive_weight(consistency_loss.mean(), denoising_loss.mean(),
                                                            last_layer=core.ctm_unet.conv_out.weight)

        balance_weight = self.adopt_weight(balance_weight, step, threshold=0, value=1.)
        return denoising_loss * balance_weight

    def get_discriminator_loss(self, estimate=None, target=None, discriminator=None, step=0):
        #print("loss_norm: ", self.args.loss_norm, learn_generator, self.args.d_architecture, self.args.discriminator_fix)


        d_real = discriminator(target.unsqueeze(1))
        d_fake = discriminator(estimate.unsqueeze(1).clone().detach())
        loss_d = 0
        k=0
        for x_fake, x_real in zip(d_fake, d_real):
            if th.isnan(th.mean(x_fake[-1] ** 2)): 
                loss_d += self.null(th.mean(x_fake[-1] ** 2))
            else:
                loss_d += th.mean(x_fake[-1] ** 2)
            # loss_d += th.mean(x_fake[-1] ** 2)
            if th.isnan(th.mean((1 - x_real[-1]) ** 2)): 
                loss_d += self.null(th.mean((1 - x_real[-1]) ** 2))
            else:
                loss_d += th.mean((1 - x_real[-1]) ** 2)
            # loss_d += th.mean((1 - x_real[-1]) ** 2)
            # print(f'global_step_{step}_loss_d_k_{k}_{loss_d.mean()}_device_{target.device}')
            k += 1
        discriminator_loss = self.adopt_weight(1.0, step, threshold=self.args.discriminator_start_itr) * loss_d
        if self.args.r1_reg_enable:
            reg = 0
            # j_gamma = 0
            target.requires_grad = True
            d_real = discriminator(target.unsqueeze(1))
            for x_real in d_real:
                reg_grad = th.autograd.grad(x_real[-1].mean(), target, retain_graph=True)[0]
                reg += reg_grad * self.args.reg_gamma
                # print(reg_grad)
                # for j in range(len(x_real)-1):
                #     j_gamma = x_real[j]**2
                # gamma = th.mean(j_gamma) * self.args.reg_gamma
                # print(gamma)
                # reg += gamma * th.mean(reg_grad.pow(2)) 
            r1_reg_term = reg.mean() 
                           
        
        return discriminator_loss
    def get_conditional_discriminator_loss(self, model, cond=None, estimate=None, target=None, discriminator=None, step=0):

        d_real = discriminator(target.unsqueeze(1), cond, model)
        d_fake = discriminator(estimate.unsqueeze(1).clone().detach(), cond, model)
        loss_d = 0
        k=0
        for x_fake, x_real in zip(d_fake, d_real):
            if th.isnan(th.mean(x_fake[-1] ** 2)):
                loss_d += self.null(th.mean(x_fake[-1] ** 2))
            else:
                loss_d += th.mean(x_fake[-1] ** 2)
            # loss_d += th.mean(x_fake[-1] ** 2)
            if th.isnan(th.mean((1 - x_real[-1]) ** 2)):
                loss_d += self.null(th.mean((1 - x_real[-1]) ** 2))
            else:
                loss_d += th.mean((1 - x_real[-1]) ** 2)
            # loss_d += th.mean((1 - x_real[-1]) ** 2)
            # print(f'global_step_{step}_loss_d_k_{k}_{loss_d.mean()}_device_{target.device}')
            k += 1
        # print(f'global_step_{step}_loss_d_{loss_d.mean()}_device_{estimate.device}')
        discriminator_loss = self.adopt_weight(1.0, step, threshold=self.args.discriminator_start_itr) * loss_d
        if self.args.r1_reg_enable:
            reg = 0
            # j_gamma = 0
            target.requires_grad = True
            d_real = discriminator(target.unsqueeze(1), cond, model)
            for x_real in d_real:
                reg_grad = th.autograd.grad(x_real[-1].mean(), target, retain_graph=True)[0]
                reg += reg_grad * self.args.reg_gamma
                # print(reg_grad)
                # for j in range(len(x_real)-1):
                #     j_gamma = x_real[j]**2
                # gamma = th.mean(j_gamma) * self.args.reg_gamma
                # print(gamma)
                # reg += gamma * th.mean(reg_grad.pow(2)) 
            r1_reg_term = reg.mean() 
                        
        
        return discriminator_loss


    def get_latent_vqgan_discriminator_loss(self, estimate=None, target=None, discriminator=None, step=0):
        """Hinge discriminator loss on detached ``target`` and ``estimate``.

        A NaN hinge loss is routed through ``self.null`` (defined outside
        this view). The result is multiplied by 0 until
        ``args.discriminator_start_itr`` steps and by 1 afterwards.
        """
        logits_real = discriminator(target.detach())
        logits_fake = discriminator(estimate.clone().detach())
        hinge = hinge_d_loss(logits_real, logits_fake)
        loss_d = self.null(hinge) if th.isnan(hinge) else hinge

        discriminator_loss = self.adopt_weight(1.0, step, threshold=self.args.discriminator_start_itr) * loss_d

        if self.args.r1_reg_enable:
            # NOTE(review): this R1 term is computed but never added to the returned
            # loss, and it feeds `target.unsqueeze(1)` although the main path above
            # uses `target` without the extra axis - confirm before enabling.
            reg = 0
            target.requires_grad = True
            real_outputs = discriminator(target.unsqueeze(1))
            for real_feats in real_outputs:
                reg_grad = th.autograd.grad(real_feats[-1].mean(), target, retain_graph=True)[0]
                reg += reg_grad * self.args.reg_gamma
            r1_reg_term = reg.mean()

        return discriminator_loss

    def get_latent_cvqgan_discriminator_loss(self, model, cond=None, estimate=None, target=None, discriminator=None, step=0):
        """Hinge discriminator loss for a conditional latent discriminator.

        ``discriminator`` is called as ``discriminator(x, cond, model)`` on
        detached ``target`` and ``estimate``. A NaN hinge loss is routed
        through ``self.null`` (defined outside this view). The result is
        multiplied by 0 until ``args.discriminator_start_itr`` steps and by 1
        afterwards.
        """
        logits_real = discriminator(target.detach(), cond, model)
        logits_fake = discriminator(estimate.clone().detach(), cond, model)
        hinge = hinge_d_loss(logits_real, logits_fake)
        loss_d = self.null(hinge) if th.isnan(hinge) else hinge

        discriminator_loss = self.adopt_weight(1.0, step, threshold=self.args.discriminator_start_itr) * loss_d

        if self.args.r1_reg_enable:
            # NOTE(review): this R1 term is computed but never added to the returned
            # loss, and the discriminator is called here without `cond`/`model` and
            # with an extra axis, unlike the main path above - confirm before enabling.
            reg = 0
            target.requires_grad = True
            real_outputs = discriminator(target.unsqueeze(1))
            for real_feats in real_outputs:
                reg_grad = th.autograd.grad(real_feats[-1].mean(), target, retain_graph=True)[0]
                reg += reg_grad * self.args.reg_gamma
            r1_reg_term = reg.mean()

        return discriminator_loss

    
    def get_generator_loss(self, model, estimate=None, target=None, consistency_loss=None, discriminator=None,
                               step=0):

        # Borrowed from DAC's codes https://github.com/descriptinc/descript-audio-codec/blob/main/dac/nn/loss.py
        d_real = discriminator(target.unsqueeze(1))
        d_fake = discriminator(estimate.unsqueeze(1))
        loss_g = 0
        loss_feature = 0
        k=0
        g_null_flg = False
        fm_null_flg = False
        for x_fake in d_fake:
            if th.isnan(th.mean((1 - x_fake[-1]) ** 2)):
                loss_g += self.null(th.mean((1 - x_fake[-1]) ** 2))
                g_null_flg = True
            else:
                loss_g += th.mean((1 - x_fake[-1]) ** 2)
            # print(f'global_step_{step}_loss_g_k_{k}_{loss_g.mean()}_device_{estimate.device}')
            k += 1
            # loss_g += th.mean((1 - x_fake[-1]) ** 2)
            
            
        for i in range(len(d_fake)):
            for j in range(len(d_fake[i]) - 1):
                if th.isnan(F.l1_loss(d_fake[i][j], d_real[i][j].detach()).mean()): 
                    loss_feature += self.null(F.l1_loss(d_fake[i][j], d_real[i][j].detach()))
                    fm_null_flg = True
                else:
                    loss_feature += F.l1_loss(d_fake[i][j], d_real[i][j].detach())

        
        if not g_null_flg:
            if self.args.d_apply_adaptive_weight:
                try:
                    balance_weight = self.calculate_adaptive_weight(consistency_loss.mean(), loss_g.mean(),
                                                            last_layer=model.ctm_unet.conv_out.weight)
                except:
                    balance_weight = self.calculate_adaptive_weight(consistency_loss.mean(), loss_g.mean(),
                                                            last_layer=model.module.ctm_unet.conv_out.weight)
            g_weight = balance_weight
        else:
            g_weight = 1.
            
        if not fm_null_flg:
            if self.args.fm_apply_adaptive_weight:
                try:
                    balance_weight = self.calculate_adaptive_weight(consistency_loss.mean(), loss_feature.mean(),
                                                            last_layer=model.ctm_unet.conv_out.weight)
                except:
                    balance_weight = self.calculate_adaptive_weight(consistency_loss.mean(), loss_feature.mean(),
                                                            last_layer=model.module.ctm_unet.conv_out.weight)

            fm_weight = balance_weight
        else:
            fm_weight = 1.

        gen_loss = self.adopt_weight(g_weight, step, threshold=self.args.discriminator_start_itr) * loss_g # generator loss
        loss_feature = self.adopt_weight(fm_weight, step, threshold=self.args.discriminator_start_itr) * loss_feature # fm loss      
        
        return gen_loss, loss_feature

    def get_conditional_generator_loss(self, model, cond=None, estimate=None, target=None, consistency_loss=None, discriminator=None, step=0,):
        """Compute the conditional adversarial and feature-matching generator losses.

        The discriminator is called as ``discriminator(x.unsqueeze(1), cond, model)``
        on both ``target`` and ``estimate``. For every element of its output, the
        last entry feeds the least-squares adversarial term ``mean((1 - d)**2)`` and
        all preceding entries feed an L1 feature-matching term against the detached
        entries computed from ``target``.

        Args:
            model: Generator; its ``ctm_unet.conv_out.weight`` (directly or under
                ``model.module``) is passed as ``last_layer`` to
                ``self.calculate_adaptive_weight``.
            cond: Conditioning forwarded unchanged to the discriminator.
            estimate: Generated signal (``unsqueeze(1)`` is applied before the call).
            target: Reference signal (``unsqueeze(1)`` is applied before the call).
            consistency_loss: Loss whose mean is balanced against the GAN terms.
            discriminator: Callable ``(x, cond, model) -> sequence of sequences``.
            step: Current training step, forwarded to ``self.adopt_weight``.

        Returns:
            Tuple ``(gen_loss, loss_feature)``, each already multiplied by
            ``self.adopt_weight(weight, step, threshold=discriminator_start_itr)``.

        Note:
            A term that evaluates to NaN is replaced by ``self.null(term)`` and the
            corresponding adaptive weight is skipped (weight 1). When an adaptive
            weight flag is disabled the weight is 1 as well; previously this case
            raised ``UnboundLocalError`` or silently reused the weight computed for
            the adversarial term.
        """
        # Borrowed from DAC's codes https://github.com/descriptinc/descript-audio-codec/blob/main/dac/nn/loss.py
        d_real = discriminator(target.unsqueeze(1), cond, model)
        d_fake = discriminator(estimate.unsqueeze(1), cond, model)

        loss_g = 0
        loss_feature = 0
        g_null_flg = False
        fm_null_flg = False

        # Adversarial term: computed once per discriminator output, then NaN-checked.
        for x_fake in d_fake:
            adv_term = th.mean((1 - x_fake[-1]) ** 2)
            if th.isnan(adv_term):
                loss_g += self.null(adv_term)
                g_null_flg = True
            else:
                loss_g += adv_term

        # Feature matching over every entry except the last one of each output.
        for i, fake_maps in enumerate(d_fake):
            real_maps = d_real[i]
            for j in range(len(fake_maps) - 1):
                fm_term = F.l1_loss(fake_maps[j], real_maps[j].detach())
                if th.isnan(fm_term):
                    loss_feature += self.null(fm_term)
                    fm_null_flg = True
                else:
                    loss_feature += fm_term

        def last_layer_weight():
            # A wrapped model (e.g. DDP) exposes the network under `.module`.
            try:
                return model.ctm_unet.conv_out.weight
            except AttributeError:
                return model.module.ctm_unet.conv_out.weight

        g_weight = 1.
        if not g_null_flg and self.args.d_apply_adaptive_weight:
            g_weight = self.calculate_adaptive_weight(consistency_loss.mean(), loss_g.mean(),
                                                      last_layer=last_layer_weight())

        fm_weight = 1.
        if not fm_null_flg and self.args.fm_apply_adaptive_weight:
            fm_weight = self.calculate_adaptive_weight(consistency_loss.mean(), loss_feature.mean(),
                                                       last_layer=last_layer_weight())

        gen_loss = self.adopt_weight(g_weight, step, threshold=self.args.discriminator_start_itr) * loss_g # generator loss
        loss_feature = self.adopt_weight(fm_weight, step, threshold=self.args.discriminator_start_itr) * loss_feature # fm loss

        return gen_loss, loss_feature

    def get_latent_vqgan_generator_loss(self, model, estimate=None, target=None, consistency_loss=None, discriminator=None, step=0):
        """Compute the unconditional VQGAN-style generator loss ``-mean(D(estimate))``.

        Args:
            model: Generator; its ``ctm_unet.conv_out.weight`` (directly or under
                ``model.module``) is passed as ``last_layer`` to
                ``self.calculate_adaptive_weight``.
            estimate: Generated sample passed to ``discriminator``.
            target: Unused here; kept for a signature parallel to the other
                generator-loss methods.
            consistency_loss: Loss whose mean is balanced against the GAN term.
            discriminator: Callable ``estimate -> tensor``.
            step: Current training step, forwarded to ``self.adopt_weight``.

        Returns:
            Tuple ``(gen_loss, loss_feature)``; ``loss_feature`` is always a zero
            tensor shaped like ``gen_loss`` because this variant has no
            feature-matching term.

        Note:
            If the mean discriminator output is NaN the loss is replaced by
            ``self.null(...)`` and the weight is 1. The weight is also 1 when
            ``d_apply_adaptive_weight`` is disabled (this case previously raised
            ``UnboundLocalError``).
        """
        # Borrowed from DAC's codes https://github.com/descriptinc/descript-audio-codec/blob/main/dac/nn/loss.py
        d_fake = discriminator(estimate)
        mean_fake = th.mean(d_fake)
        g_null_flg = bool(th.isnan(mean_fake))
        loss_g = self.null(mean_fake) if g_null_flg else -mean_fake

        g_weight = 1.
        if not g_null_flg and self.args.d_apply_adaptive_weight:
            # A wrapped model (e.g. DDP) exposes the network under `.module`.
            try:
                last_layer = model.ctm_unet.conv_out.weight
            except AttributeError:
                last_layer = model.module.ctm_unet.conv_out.weight
            g_weight = self.calculate_adaptive_weight(consistency_loss.mean(), loss_g.mean(),
                                                      last_layer=last_layer)

        gen_loss = self.adopt_weight(g_weight, step, threshold=self.args.discriminator_start_itr) * loss_g # generator loss
        loss_feature = th.zeros_like(gen_loss)

        return gen_loss, loss_feature


    def get_latent_cvqgan_generator_loss(
        self, 
        model, 
        estimate=None, 
        target=None, 
        cond=None, 
        consistency_loss=None, 
        discriminator=None,
        step=0):
        """Compute the conditional VQGAN-style generator loss ``-mean(D(estimate, cond, model))``.

        Args:
            model: Generator; forwarded to the discriminator, and its
                ``ctm_unet.conv_out.weight`` (directly or under ``model.module``)
                is passed as ``last_layer`` to ``self.calculate_adaptive_weight``.
            estimate: Generated sample passed to ``discriminator``.
            target: Unused here; kept for a signature parallel to the other
                generator-loss methods.
            cond: Conditioning forwarded unchanged to the discriminator.
            consistency_loss: Loss whose mean is balanced against the GAN term.
            discriminator: Callable ``(estimate, cond, model) -> tensor``.
            step: Current training step, forwarded to ``self.adopt_weight``.

        Returns:
            Tuple ``(gen_loss, loss_feature)``; ``loss_feature`` is always a zero
            tensor shaped like ``gen_loss`` because this variant has no
            feature-matching term.

        Note:
            If the mean discriminator output is NaN the loss is replaced by
            ``self.null(...)`` and the weight is 1. The weight is also 1 when
            ``d_apply_adaptive_weight`` is disabled (this case previously raised
            ``UnboundLocalError``).
        """
        # Borrowed from DAC's codes https://github.com/descriptinc/descript-audio-codec/blob/main/dac/nn/loss.py
        d_fake = discriminator(estimate, cond, model)
        mean_fake = th.mean(d_fake)
        g_null_flg = bool(th.isnan(mean_fake))
        loss_g = self.null(mean_fake) if g_null_flg else -mean_fake

        g_weight = 1.
        if not g_null_flg and self.args.d_apply_adaptive_weight:
            # A wrapped model (e.g. DDP) exposes the network under `.module`.
            try:
                last_layer = model.ctm_unet.conv_out.weight
            except AttributeError:
                last_layer = model.module.ctm_unet.conv_out.weight
            g_weight = self.calculate_adaptive_weight(consistency_loss.mean(), loss_g.mean(),
                                                      last_layer=last_layer)

        gen_loss = self.adopt_weight(g_weight, step, threshold=self.args.discriminator_start_itr) * loss_g # generator loss
        loss_feature = th.zeros_like(gen_loss)
        return gen_loss, loss_feature

    def check_isnan(self, loss):
        """Return ``loss`` unchanged, or a zero stand-in when its mean is NaN.

        The stand-in has the same shape as ``loss`` and has ``requires_grad``
        switched on.
        """
        if not th.isnan(loss.mean()):
            return loss
        zeroed = th.zeros_like(loss)
        return zeroed.requires_grad_(True)
    
    def null(self, x_start):
        """Build a zero tensor like ``x_start`` (same device) with ``requires_grad`` enabled."""
        zero_loss = th.zeros_like(x_start, device=x_start.device)
        return zero_loss.requires_grad_(True)
    
    
    def get_samples(
        self,
        step,
        model,
        wavs,
        cond=None,
        model_kwargs=None,
        target_model=None,
        teacher_model=None,
        stage1_model=None,
        stft=None,
        accelerator=None,
        noise=None,
        # init_step=0,
        ctm=True,
    ):
        """Encode a batch, perturb it at sampled times and build the estimate/target pair.

        Steps, in order: convert ``wavs`` to mel/waveform (optionally appending
        augmented mixtures), encode the mel with ``stage1_model``, draw noise and
        time indices, compute the student ``estimate`` from the noised latent, run
        the solver (teacher) or ``denoise_fn`` (no teacher) to get ``x_t_dt``, and
        compute ``target`` from it. When ``args.unform_sampled_cfg_distill`` is set,
        a per-sample guidance scale is drawn uniformly in ``[w_min, w_max)`` and
        stored on ``self.sampled_cfg``.

        ``model_kwargs`` is accepted but not used by this method.

        Returns:
            Tuple ``(estimate, target, x_start, mel, waveform, prompt, t, s)``;
            ``s`` is ``None`` when ``ctm`` is false.
        """
        # Prepare latent representation of mel through stage1 model
        target_length = int(self.args.duration * 102.4)
        with th.no_grad():
            # mel: [batch, 1024, 64], waveform: [batch, 163840]
            mel, _, waveform = torch_tools.wav_to_fbank(wavs, target_length, stft)
            mel = mel.unsqueeze(1).to(accelerator.device)
            waveform = waveform.to(accelerator.device)
            prompt = list(cond)
            # TODO: Implement data augmentation of TANGO later.
            if self.args.tango_data_augment and len(cond) > 1:
                # The last training batch may hold a single instance; the length
                # check keeps the augmentation function from raising on it.
                mixed_mel, _, mixed_waveform, mixed_captions = torch_tools.augment_wav_to_fbank(
                    wavs, cond, self.args.augment_num, target_length, stft)
                mel = th.cat([mel, mixed_mel.unsqueeze(1).to(accelerator.device)], 0)
                waveform = th.cat([waveform, mixed_waveform.to(accelerator.device)], 0)
                prompt.extend(mixed_captions)
            encoded = stage1_model.encode_first_stage(mel)
            x_start = stage1_model.get_first_stage_encoding(encoded) # z_{0} [batch, 8, 256, 16]

        th.cuda.empty_cache()

        if noise is None:
            noise = th.randn_like(x_start)
        dims = x_start.ndim
        assert self.args.consistency_weight > 0.
        num_heun_step = self.get_num_heun_step(step)

        batch_size = x_start.shape[0]
        indices, _ = self.schedule_sampler.sample_t(batch_size, x_start.device, num_heun_step, self.args.time_continuous)
        t = self.get_t(indices)
        t_dt = self.get_t(indices + num_heun_step)
        s = None
        if ctm:
            new_indices = self.schedule_sampler.sample_s(self.args, batch_size, x_start.device, indices,
                                                         num_heun_step, self.args.time_continuous,
                                                         N=self.args.start_scales)
            s = self.get_t(new_indices)
        x_t = x_start + noise * append_dims(t, dims) # z_t

        # z_{est}(z_{t}, cond, t, s)
        # = G_{\sg(\theta)}(G_{\theta}(z_t, cond, t, s), cond, t=s, s=sigma_min)
        estimate_kwargs = {}
        if self.args.unform_sampled_cfg_distill:
            w_range = self.args.w_max - self.args.w_min
            self.sampled_cfg = w_range * th.rand((noise.shape[0],), device=accelerator.device) + self.args.w_min
            estimate_kwargs["cfg"] = self.sampled_cfg
        estimate = self.get_estimate(step, x_t, t, t_dt, s, model, target_model, ctm=ctm, cond=prompt, **estimate_kwargs)
        th.cuda.empty_cache()

        if not teacher_model:
            with th.no_grad():
                x_t_dt = self.denoise_fn(target_model, x_t, t, cond=prompt, s=t_dt, ctm=ctm) # NOTE: we don't use this so far.
        elif self.args.cfg_distill:
            # Solver(z_t, cond, t, u, \omega; \phi)
            x_t_dt = self.heun_solver_cfg(x_t, indices, self.args.target_cfg, teacher_model, dims, cond=prompt, num_step=num_heun_step)
        elif self.args.unform_sampled_cfg_distill:
            x_t_dt = self.heun_solver_cfg(x_t, indices, self.sampled_cfg, teacher_model, dims, cond=prompt, num_step=num_heun_step)
        else:
            # Solver(z_t, cond, t, u; \phi)
            x_t_dt = self.heun_solver(x_t, indices, teacher_model, dims, cond=prompt, num_step=num_heun_step)

        # z_{target}(z_{t}, cond, t, t_dt, s)
        target_kwargs = {"cfg": self.sampled_cfg} if self.args.unform_sampled_cfg_distill else {}
        target = self.get_target(step, x_t_dt, t_dt, s, model, target_model, ctm=ctm, cond=prompt, **target_kwargs)

        th.cuda.empty_cache()
        return estimate, target, x_start, mel, waveform, prompt, t, s
    
    def get_gen_loss(
        self,
        step,
        model,
        estimate,
        target,
        x_start,
        mel,
        waveform,
        prompt,
        t,
        s,
        teacher_model,
        stage1_model,
        accelerator,
        discriminator,
        model_kwargs,
    ):
        """Assemble the generator-side loss terms for one training step.

        Always computes ``terms["consistency_loss"]``. Adds
        ``terms["denoising_loss"]`` when ``args.diffusion_training`` is set and
        ``args.dsm_loss_target`` is ``'z_0'`` or ``'z_target'``. Adds
        ``terms["g_loss"]`` and ``terms["fm_loss"]`` when
        ``args.discriminator_training`` is set, ``step`` has reached
        ``args.discriminator_start_itr`` and ``args.d_architecture`` is one of the
        handled names; an unrecognised architecture adds nothing.

        The discriminator input depends on the architecture family:
        waveform (``DAC_*``: latent -> mel -> waveform), mel (``MEL_*``/``*MBDisc``:
        latent -> mel) or latent (``L_*``: used as is). For the waveform and mel
        families ``args.gan_target`` selects the real sample: ``'z_target'``
        decodes ``target``; ``'z_0'`` uses the ground-truth ``waveform``/``mel``
        (waveforms are trimmed to their common length).

        ``model_kwargs`` is accepted but not used by this method.

        Returns:
            dict mapping loss names to their values.

        Raises:
            ValueError: if ``args.gan_target`` is not ``'z_target'`` or ``'z_0'``
                for an architecture that needs it (this previously surfaced as an
                ``UnboundLocalError``).
        """
        terms = {}
        snrs = self.get_snr(t)
        weights = get_weightings(self.args.weight_schedule, snrs, self.args.sigma_data, t, s)
        terms["consistency_loss"] = self.get_consistency_loss(estimate, target, weights,
                                                              loss_domain=self.args.loss_domain,
                                                              loss_norm=self.args.loss_norm,
                                                              teacher_model=teacher_model,
                                                              s=s, prompt=prompt)
        th.cuda.empty_cache()

        if self.args.diffusion_training:
            dsm_loss_target = self.args.dsm_loss_target
            # NOTE: 'z_target' is basically not used anymore.
            if dsm_loss_target in ('z_0', 'z_target'):
                terms['denoising_loss'] = self.get_denoising_loss(
                    model, x_start,
                    terms["consistency_loss"],
                    step, cond=prompt,
                    loss_target=x_start if dsm_loss_target == 'z_0' else target)
        th.cuda.empty_cache()

        if not self.args.discriminator_training or step < self.args.discriminator_start_itr:
            return terms

        th.cuda.empty_cache()
        arch = self.args.d_architecture
        # Keyword arguments shared by every generator-loss variant.
        shared = dict(consistency_loss=terms["consistency_loss"],
                      discriminator=discriminator,
                      step=step)

        if arch in ('DAC_GAN', 'DAC_SAN', 'DAC_CGAN', 'DAC_CSAN'):
            estimated_mel = stage1_model.decode_first_stage(estimate) # convert latent to mel
            estimated_waveform = stage1_model.decode_to_waveform(estimated_mel) # convert mel to waveform

            gan_target = self.args.gan_target
            if gan_target == 'z_target': # NOTE: basically not used anymore.
                target_mel = stage1_model.decode_first_stage(target) # convert latent to mel
                target_waveform = stage1_model.decode_to_waveform(target_mel) # convert mel to waveform
            elif gan_target == 'z_0':
                target_waveform = waveform.to(accelerator.device)
                # Trim both signals to their common length before comparing them.
                length_min = min(target_waveform.shape[-1], estimated_waveform.shape[-1])
                target_waveform = target_waveform[:, :length_min]
                estimated_waveform = estimated_waveform[:, :length_min]
            else:
                raise ValueError(f"Unsupported gan_target: {gan_target!r} (expected 'z_target' or 'z_0')")

            if arch in ('DAC_GAN', 'DAC_SAN'):
                terms['g_loss'], terms['fm_loss'] = self.get_generator_loss(
                    model, estimate=estimated_waveform, target=target_waveform, **shared)
            else:
                terms['g_loss'], terms['fm_loss'] = self.get_conditional_generator_loss(
                    model, cond=prompt, estimate=estimated_waveform, target=target_waveform, **shared)

        elif arch in ('MEL_VQGAN', 'MBDisc', 'MEL_CVQGAN', 'CMBDisc'):
            estimated_mel = stage1_model.decode_first_stage(estimate) # convert latent to mel

            gan_target = self.args.gan_target
            if gan_target == 'z_target': # NOTE: basically not used anymore.
                target_mel = stage1_model.decode_first_stage(target) # convert latent to mel
            elif gan_target == 'z_0':
                target_mel = mel.to(accelerator.device)
            else:
                raise ValueError(f"Unsupported gan_target: {gan_target!r} (expected 'z_target' or 'z_0')")

            if arch in ('MEL_VQGAN', 'MBDisc'):
                terms['g_loss'], terms['fm_loss'] = self.get_latent_vqgan_generator_loss(
                    model, estimate=estimated_mel, target=target_mel, **shared)
            else:
                terms['g_loss'], terms['fm_loss'] = self.get_latent_cvqgan_generator_loss(
                    model, estimate=estimated_mel, target=target_mel, cond=prompt, **shared)

        elif arch in ('L_VQGAN',):
            terms['g_loss'], terms['fm_loss'] = self.get_latent_vqgan_generator_loss(
                model, estimate=estimate, target=x_start, **shared)

        elif arch in ('L_CVQGAN',):
            terms['g_loss'], terms['fm_loss'] = self.get_latent_cvqgan_generator_loss(
                model, estimate=estimate, target=x_start, cond=prompt, **shared)

        else:
            # Unrecognised architecture: no adversarial terms are produced.
            return terms

        th.cuda.empty_cache()
        return terms
    
    def get_disc_loss(
        self,
        step,
        model,
        estimate,
        target,
        x_start,
        mel,
        waveform,
        prompt,
        stage1_model,
        accelerator,
        discriminator,
    ):
        """Compute the discriminator loss for one training step.

        Requires ``args.discriminator_training`` (asserted). Returns an empty dict
        while ``step`` is below ``args.discriminator_start_itr`` or when
        ``args.d_architecture`` is not one of the handled names.

        The discriminator input depends on the architecture family:
        waveform (``DAC_*``: latent -> mel -> waveform), mel (``MEL_*``/``*MBDisc``:
        latent -> mel) or latent (``L_*``: ``estimate`` vs ``x_start``). For the
        waveform and mel families ``args.gan_target`` selects the real sample:
        ``'z_target'`` decodes ``target``; ``'z_0'`` uses the ground-truth
        ``waveform``/``mel`` (waveforms are trimmed to their common length).
        Conditional variants additionally receive ``model`` and ``prompt``.

        Returns:
            dict that is either empty or holds a single ``'d_loss'`` entry.

        Raises:
            ValueError: if ``args.gan_target`` is not ``'z_target'`` or ``'z_0'``
                for an architecture that needs it (this previously surfaced as an
                ``UnboundLocalError``).
        """
        terms = {}
        assert self.args.discriminator_training
        if step < self.args.discriminator_start_itr:
            return terms

        arch = self.args.d_architecture

        if arch in ('DAC_GAN', 'DAC_SAN', 'DAC_CGAN', 'DAC_CSAN'):
            estimated_mel = stage1_model.decode_first_stage(estimate) # convert latent to mel
            estimated_waveform = stage1_model.decode_to_waveform(estimated_mel) # convert mel to waveform

            gan_target = self.args.gan_target
            if gan_target == 'z_target':
                target_mel = stage1_model.decode_first_stage(target) # convert latent to mel
                target_waveform = stage1_model.decode_to_waveform(target_mel) # convert mel to waveform
            elif gan_target == 'z_0':
                target_waveform = waveform.to(accelerator.device)
                # Trim both signals to their common length before comparing them.
                length_min = min(target_waveform.shape[-1], estimated_waveform.shape[-1])
                target_waveform = target_waveform[:, :length_min]
                estimated_waveform = estimated_waveform[:, :length_min]
            else:
                raise ValueError(f"Unsupported gan_target: {gan_target!r} (expected 'z_target' or 'z_0')")

            if arch in ('DAC_GAN', 'DAC_SAN'):
                terms['d_loss'] = self.get_discriminator_loss(estimate=estimated_waveform,
                                                              target=target_waveform,
                                                              discriminator=discriminator, step=step)
            else:
                terms['d_loss'] = self.get_conditional_discriminator_loss(model, prompt,
                                                                          estimate=estimated_waveform,
                                                                          target=target_waveform,
                                                                          discriminator=discriminator, step=step)

        elif arch in ('MEL_VQGAN', 'MBDisc', 'MEL_CVQGAN', 'CMBDisc'):
            estimated_mel = stage1_model.decode_first_stage(estimate) # convert latent to mel

            gan_target = self.args.gan_target
            if gan_target == 'z_target':
                target_mel = stage1_model.decode_first_stage(target) # convert latent to mel
            elif gan_target == 'z_0':
                target_mel = mel.to(accelerator.device)
            else:
                raise ValueError(f"Unsupported gan_target: {gan_target!r} (expected 'z_target' or 'z_0')")

            if arch in ('MEL_VQGAN', 'MBDisc'):
                terms['d_loss'] = self.get_latent_vqgan_discriminator_loss(
                    estimate=estimated_mel,
                    target=target_mel,
                    discriminator=discriminator,
                    step=step,
                )
            else:
                terms['d_loss'] = self.get_latent_cvqgan_discriminator_loss(
                    model, prompt,
                    estimate=estimated_mel,
                    target=target_mel,
                    discriminator=discriminator,
                    step=step,
                )

        elif arch in ('L_VQGAN',):
            terms['d_loss'] = self.get_latent_vqgan_discriminator_loss(
                estimate=estimate,
                target=x_start,
                discriminator=discriminator,
                step=step,
            )

        elif arch in ('L_CVQGAN',):
            terms['d_loss'] = self.get_latent_cvqgan_discriminator_loss(
                model, prompt,
                estimate=estimate,
                target=x_start,
                discriminator=discriminator,
                step=step,
            )

        return terms
    
    
    
    def ctm_losses(
        self,
        step,
        model,
        wavs,
        cond=None,
        model_kwargs=None,
        target_model=None,
        teacher_model=None,
        stage1_model=None,
        stft=None,
        accelerator=None,
        noise=None,
        discriminator=None,
        # init_step=0,
        ctm=True,
        gen_backword=False,
    ):
        
        # Prepare latent representation of mel through stage1 model
        target_length = int(self.args.duration * 102.4) 
        with th.no_grad():
            # mel, _, waveform = torch_tools.wav_to_fbank(wavs, target_length, stft).to(dtype=self.args.weight_dtype) # mel: [batch, 1024, 64], waveform: [batch, 163840]
            mel, _, waveform = torch_tools.wav_to_fbank(wavs, target_length, stft) # mel: [batch, 1024, 64], waveform: [batch, 163840]
            mel = mel.unsqueeze(1).to(accelerator.device)
            waveform = waveform.to(accelerator.device)
            prompt = list(cond)
            # TODO: Implement data augmentation of TANGO later.
            if self.args.tango_data_augment and len(cond) > 1:
                # the last batch of the training data may have only one instance
                # we check the length here so that the augmentation function doesn't throw an error
                mixed_mel, _, mixed_waveform, mixed_captions = torch_tools.augment_wav_to_fbank(wavs, cond, self.args.augment_num, target_length, stft)
                mixed_mel = mixed_mel.unsqueeze(1).to(accelerator.device)
                mixed_waveform = mixed_waveform.to(accelerator.device)
                mel = th.cat([mel, mixed_mel], 0)
                waveform = th.cat([waveform, mixed_waveform], 0)
                prompt += mixed_captions
            x_start = stage1_model.get_first_stage_encoding(stage1_model.encode_first_stage(mel)) # z_{0} [batch, 8, 256, 16]
        
        th.cuda.empty_cache()
        
        if noise is None:
            noise = th.randn_like(x_start)
        dims = x_start.ndim
        s = None
        terms = {}
        assert self.args.consistency_weight > 0.
        num_heun_step = [self.get_num_heun_step(step)] 
        num_heun_step = num_heun_step[0]

        indices, _ = self.schedule_sampler.sample_t(x_start.shape[0], x_start.device, num_heun_step, self.args.time_continuous)
        t = self.get_t(indices)
        t_dt = self.get_t(indices + num_heun_step)
        if ctm:
            new_indices = self.schedule_sampler.sample_s(self.args, x_start.shape[0], x_start.device, indices,
                                                        num_heun_step, self.args.time_continuous,
                                                        N=self.args.start_scales)
            s = self.get_t(new_indices)

        x_t = x_start + noise * append_dims(t, dims) # z_t

        estimate = self.get_estimate(step, x_t, t, t_dt, s, model, target_model, ctm=ctm, cond=prompt)
        # z_{est}(z_{t}, cond, t, s) 
        # = G_{\sg(\theta)}(G_{\theta}(z_t, cond, t, s), cond, t=s, s=sigma_min)
        th.cuda.empty_cache()

        if teacher_model:
            if self.args.cfg_distill:
                x_t_dt = self.heun_solver_cfg(x_t, indices, self.args.target_cfg, teacher_model, dims, cond=prompt, num_step=num_heun_step)
                # Solver(z_t, cond, t, u, \omega; \phi)
            
            else:
                x_t_dt = self.heun_solver(x_t, indices, teacher_model, dims, cond=prompt, num_step=num_heun_step,)
                # Solver(z_t, cond, t, u; \phi)
        
        else:
            with th.no_grad():
                x_t_dt = self.denoise_fn(target_model, x_t, t, cond=prompt, s=t_dt, ctm=ctm) # NOTE: we don't use this so far.

        target = self.get_target(step, x_t_dt, t_dt, s, 
                                    model, target_model, ctm=ctm, cond=prompt) 
        # z_{target}(z_{t}, cond, t, t_dt, s)

        th.cuda.empty_cache()
        if gen_backword:
            snrs = self.get_snr(t)
            weights = get_weightings(self.args.weight_schedule, snrs, self.args.sigma_data, t, s)
            terms["consistency_loss"] = self.get_consistency_loss(estimate, target, weights, 
                                                                    loss_domain=self.args.loss_domain, 
                                                                    loss_norm=self.args.loss_norm)
            th.cuda.empty_cache()
            if self.args.diffusion_training:
                if self.args.dsm_loss_target == 'z_0':
                    terms['denoising_loss'] = self.get_denoising_loss(model, x_start,
                                                                    terms["consistency_loss"],
                                                                    step, cond=prompt,
                                                                    loss_target=x_start)
                elif self.args.dsm_loss_target == 'z_target': # NOTE: Bassically, we don't use this anymore.
                    terms['denoising_loss'] = self.get_denoising_loss(model, x_start,
                                                                    terms["consistency_loss"],
                                                                    step, cond=prompt,
                                                                    loss_target=target)
            th.cuda.empty_cache()
            # if self.args.discriminator_training and step - init_step >= self.args.discriminator_start_itr:
            if self.args.discriminator_training:
                if step >= self.args.discriminator_start_itr:
                    gan_estimate = self.get_gan_estimate(estimate, step, x_t, t, t_dt, s, 
                                                            model, target_model, ctm=ctm, cond=prompt)
                    # breakpoint() # 40768MiB
                    th.cuda.empty_cache()
                    if self.args.d_architecture in ['DAC_GAN', 'DAC_SAN']:
                        estimated_mel = stage1_model.decode_first_stage(gan_estimate) # convert latent to mel
                        estimated_waveform = stage1_model.decode_to_waveform(estimated_mel) # convert mel to waveform
                        
                        if self.args.gan_target == 'z_target':
                            target_mel = stage1_model.decode_first_stage(target) # convert latent to mel
                            target_waveform = stage1_model.decode_to_waveform(target_mel) # convert mel to waveform
                            
                        elif self.args.gan_target == 'z_0':
                            target_waveform = waveform.to(accelerator.device)
                            length_min = min(target_waveform.shape[-1], estimated_waveform.shape[-1])
                            target_waveform = target_waveform[:, :length_min]
                            estimated_waveform = estimated_waveform[:, :length_min]

                        # breakpoint() 41450MiB
                        terms['g_loss'], terms['fm_loss'] = self.get_generator_loss(
                            model, estimate=estimated_waveform, 
                            target=target_waveform, 
                            consistency_loss=terms["consistency_loss"], 
                            discriminator=discriminator, 
                            step=step,
                        )
                        th.cuda.empty_cache()
                    elif self.args.d_architecture in ['DAC_CGAN', 'DAC_CSAN']:
                        estimated_mel = stage1_model.decode_first_stage(gan_estimate) # convert latent to mel
                        estimated_waveform = stage1_model.decode_to_waveform(estimated_mel) # convert mel to waveform
                        
                        if self.args.gan_target == 'z_target': # NOTE: Bassically, we don't use this anymore.
                            target_mel = stage1_model.decode_first_stage(target) # convert latent to mel
                            target_waveform = stage1_model.decode_to_waveform(target_mel) # convert mel to waveform
                            
                        elif self.args.gan_target == 'z_0':
                            target_waveform = waveform.to(accelerator.device)
                            length_min = min(target_waveform.shape[-1], estimated_waveform.shape[-1])
                            target_waveform = target_waveform[:, :length_min]
                            estimated_waveform = estimated_waveform[:, :length_min]

                        terms['g_loss'], terms['fm_loss'] = self.get_conditional_generator_loss( 
                                                                    model, cond=prompt, 
                                                                    estimate=estimated_waveform,
                                                                    target=target_waveform,
                                                                    consistency_loss=terms["consistency_loss"],
                                                                    discriminator=discriminator,
                                                                    step=step)
                        th.cuda.empty_cache()
                        
                    elif self.args.d_architecture in ['MEL_VQGAN']:
                        estimated_mel = stage1_model.decode_first_stage(gan_estimate) # convert latent to mel
                        
                        if self.args.gan_target == 'z_target': # NOTE: Bassically, we don't use this anymore.
                            target_mel = stage1_model.decode_first_stage(target) # convert latent to mel
                        elif self.args.gan_target == 'z_0':
                            target_mel = mel.to(accelerator.device)
                        
                        terms['g_loss'], terms['fm_loss'] = self.get_latent_vqgan_generator_loss(
                            model, estimate=estimated_mel, 
                            target=target_mel, 
                            consistency_loss=terms["consistency_loss"], 
                            discriminator=discriminator, 
                            step=step,
                        )
                        th.cuda.empty_cache()
                    
                    elif self.args.d_architecture in ['MEL_CVQGAN']:
                        estimated_mel = stage1_model.decode_first_stage(gan_estimate) # convert latent to mel
                        
                        if self.args.gan_target == 'z_target': # NOTE: Bassically, we don't use this anymore.
                            target_mel = stage1_model.decode_first_stage(target) # convert latent to mel
                        elif self.args.gan_target == 'z_0':
                            target_mel = mel.to(accelerator.device)
                        
                        terms['g_loss'], terms['fm_loss'] = self.get_latent_cvqgan_generator_loss(
                            model, estimate=estimated_mel, 
                            target=target_mel,
                            cond=prompt, 
                            consistency_loss=terms["consistency_loss"], 
                            discriminator=discriminator, 
                            step=step,
                        )
                        th.cuda.empty_cache()
                    
                    
                    elif self.args.d_architecture in ['L_VQGAN']:
                        
                        terms['g_loss'], terms['fm_loss'] = self.get_latent_vqgan_generator_loss(
                            model, estimate=gan_estimate, 
                            target=x_start, 
                            consistency_loss=terms["consistency_loss"], 
                            discriminator=discriminator, 
                            step=step,
                        )
                        th.cuda.empty_cache()
                    elif self.args.d_architecture in ['L_CVQGAN']:    
                        terms['g_loss'], terms['fm_loss'] = self.get_latent_cvqgan_generator_loss(
                            model, estimate=gan_estimate, 
                            target=x_start,
                            cond=prompt, 
                            consistency_loss=terms["consistency_loss"], 
                            discriminator=discriminator, 
                            step=step,
                        )
                        th.cuda.empty_cache()

        else:
            assert self.args.discriminator_training
            if step >= self.args.discriminator_start_itr:
                gan_estimate = self.get_gan_estimate(estimate, step, x_t, t, t_dt, s, 
                                                    model, target_model, ctm=ctm, cond=prompt)
                # breakpoint() 27064MiB
                if not self.args.d_architecture in ['L_VQGAN', 'L_CVQGAN', 'MEL_VQGAN', 'MEL_CVQGAN']:
                    estimated_mel = stage1_model.decode_first_stage(gan_estimate) # convert latent to mel
                    estimated_waveform = stage1_model.decode_to_waveform(estimated_mel) # convert mel to waveform
                            
                    if self.args.gan_target == 'z_target':
                        target_mel = stage1_model.decode_first_stage(target) # convert latent to mel
                        target_waveform = stage1_model.decode_to_waveform(target_mel) # convert mel to waveform
                    elif self.args.gan_target == 'z_0':
                        target_waveform = waveform.to(accelerator.device)
                        length_min = min(target_waveform.shape[-1], estimated_waveform.shape[-1])
                        target_waveform = target_waveform[:, :length_min]
                        estimated_waveform = estimated_waveform[:, :length_min]

                
                if self.args.d_architecture in ['DAC_GAN', 'DAC_SAN']:
                
                    terms['d_loss'] = self.get_discriminator_loss(estimate=estimated_waveform, 
                                                                    target=target_waveform, 
                                                                    discriminator=discriminator, step=step)
                elif self.args.d_architecture in ['DAC_CGAN', 'DAC_CSAN']:
                    terms['d_loss'] = self.get_conditional_discriminator_loss(model, prompt, 
                                                                    estimate=estimated_waveform, 
                                                                    target=target_waveform, 
                                                                    discriminator=discriminator, step=step)
                
                elif self.args.d_architecture in ['MEL_VQGAN']:
                    estimated_mel = stage1_model.decode_first_stage(gan_estimate) # convert latent to mel
                    if self.args.gan_target == 'z_target':
                        target_mel = stage1_model.decode_first_stage(target) # convert latent to mel
                    elif self.args.gan_target == 'z_0':
                        target_mel = mel.to(accelerator.device)
                        
                    terms['d_loss'] = self.get_latent_vqgan_discriminator_loss(
                        estimate=estimated_mel, 
                        target=target_mel, 
                        discriminator=discriminator, 
                        step=step
                        )
                
                elif self.args.d_architecture in ['MEL_CVQGAN']:
                    estimated_mel = stage1_model.decode_first_stage(gan_estimate) # convert latent to mel
                    if self.args.gan_target == 'z_target':
                        target_mel = stage1_model.decode_first_stage(target) # convert latent to mel
                    elif self.args.gan_target == 'z_0':
                        target_mel = mel.to(accelerator.device)
                        
                    terms['d_loss'] = self.get_latent_cvqgan_discriminator_loss(
                        model, prompt,
                        estimate=estimated_mel, 
                        target=target_mel, 
                        discriminator=discriminator, 
                        step=step
                        )
                
                
                elif self.args.d_architecture in ['L_VQGAN']:
                    terms['d_loss'] = self.get_latent_vqgan_discriminator_loss(
                        estimate=gan_estimate, 
                        target=x_start, 
                        discriminator=discriminator, 
                        step=step
                        )
                elif self.args.d_architecture in ['L_CVQGAN']:
                    terms['d_loss'] = self.get_latent_cvqgan_discriminator_loss(
                        model, prompt,
                        estimate=gan_estimate, 
                        target=x_start, 
                        discriminator=discriminator, 
                        step=step
                        )
            

        return terms

def karras_sample(
    diffusion,
    model,
    shape,
    steps,
    cond,
    guidance_scale,
    gamma=0.9,
    cfg_aug=None,
    clip_denoised=True,
    progress=False,
    callback=None,
    model_kwargs=None,
    device=None,
    sigma_min=0.002,
    sigma_max=80,
    rho=7.0,
    sampler="heun",
    s_churn=0.0,
    s_tmin=0.0,
    s_tmax=float("inf"),
    s_noise=1.0,
    ts=None,
    x_T=None,
    ctm=False,
    teacher=False,
    clip_output=False,
    train=False,
):
    """Draw samples by running one of the module's samplers on a Karras schedule.

    Args:
        diffusion: Object exposing ``denoise_fn(model, x_t, t, cond=, s=, ctm=,
            teacher=, cfg=)``; it is called once per network evaluation.
        model: Network handed through to ``diffusion.denoise_fn``.
        shape: Shape of the initial noise when ``x_T`` is not supplied.
        steps: Number of schedule points requested (some samplers get
            ``steps + 1``, see below).
        cond: Conditioning forwarded to samplers that declare a ``cond``
            parameter.
        guidance_scale: Guidance weight forwarded to samplers that declare it.
        gamma: Only forwarded to the ``"gamma_multistep_cfg"`` sampler.
        cfg_aug: Forwarded to samplers that declare a ``cfg_aug`` parameter.
        clip_denoised: Clamp every denoiser output to ``[-1, 1]``.
        progress, callback: Forwarded to samplers that declare them.
        model_kwargs: Unused; kept for interface compatibility.
        device: Device for the schedule and the initial noise.
        sigma_min, sigma_max, rho: Schedule parameters for
            ``get_sigmas_karras``; also forwarded as ``t_min``/``t_max``/``rho``
            to the ``ts``-based samplers.
        sampler: Key selecting the sampling function.
        s_churn, s_tmin, s_tmax, s_noise: Forwarded to ``"heun"``,
            ``"heun_cfg"`` and ``"dpm"``.
        ts: Optional step indices for the ``ts``-based samplers.
        x_T: Optional starting tensor; drawn as ``randn * sigma_max`` if None.
        ctm, teacher: Mode flags forwarded to the denoiser via the sampler.
        clip_output: Clamp the final sample to ``[-1, 1]``.
        train: When true, ``teacher`` is forced to False for the Heun samplers.

    Returns:
        The tensor produced by the selected sampler.

    Raises:
        KeyError: If ``sampler`` is not a known sampler name.

    Note:
        Keyword arguments that the selected sampler does not declare are
        dropped instead of being passed on. Previously they were always
        passed, which made most samplers fail with ``TypeError`` (including
        the default ``"heun"``, which received ``ctm``/``teacher`` twice).
    """
    import inspect

    # These samplers are given one extra schedule point.
    if sampler in ("progdist", "euler", "exact", "exact_cfg", "cm_multistep_cfg", "gamma_multistep_cfg"):
        sigmas = get_sigmas_karras(steps + 1, sigma_min, sigma_max, rho, device=device)
    else:
        sigmas = get_sigmas_karras(steps, sigma_min, sigma_max, rho, device=device)

    if x_T is None:
        x_T = th.randn(*shape, device=device) * sigma_max

    sample_fn = {
        "heun": sample_heun,
        "heun_cfg": sample_heun_cfg,
        "dpm": sample_dpm,
        "ancestral": sample_euler_ancestral,
        "onestep": sample_onestep,
        "exact": sample_exact,
        "exact_cfg": sample_exact_cfg,
        "progdist": sample_progdist,
        "euler": sample_euler,
        "multistep": stochastic_iterative_sampler,
        "cm_multistep_cfg": sample_multistep_cfg,
        "gamma_multistep_cfg": sample_gamma_multistep_cfg,
    }[sampler]

    if sampler in ("heun", "dpm", "heun_cfg"):
        sampler_args = dict(
            s_churn=s_churn, s_tmin=s_tmin, s_tmax=s_tmax, s_noise=s_noise
        )
    elif sampler in ("multistep", "exact", "exact_cfg", "cm_multistep_cfg"):
        sampler_args = dict(
            ts=ts, t_min=sigma_min, t_max=sigma_max, rho=rho, steps=steps
        )
    elif sampler == "gamma_multistep_cfg":
        sampler_args = dict(
            ts=ts, t_min=sigma_min, t_max=sigma_max, rho=rho, steps=steps, gamma=gamma,
        )
    else:
        sampler_args = {}
    if sampler in ("heun", "heun_cfg"):
        # During training the Heun samplers never run in teacher mode.
        sampler_args["teacher"] = False if train else teacher
        sampler_args["ctm"] = ctm

    def denoiser(x_t, t, s=th.ones(x_T.shape[0], device=device), cond=None, ctm=False, teacher=False, cfg=None):
        denoised = diffusion.denoise_fn(model, x_t, t, cond=cond, s=s, ctm=ctm, teacher=teacher, cfg=cfg)
        if clip_denoised:
            denoised = denoised.clamp(-1, 1)
        return denoised

    # Common arguments first; sampler-specific values take precedence so the
    # Heun "teacher" override above wins over the raw ``teacher`` flag.
    call_kwargs = dict(
        cond=cond,
        ctm=ctm,
        teacher=teacher,
        guidance_scale=guidance_scale,
        cfg_aug=cfg_aug,
        progress=progress,
        callback=callback,
    )
    call_kwargs.update(sampler_args)

    # Forward only what the chosen sampler declares (unless it takes **kwargs).
    accepted = inspect.signature(sample_fn).parameters
    takes_var_kwargs = any(
        p.kind is inspect.Parameter.VAR_KEYWORD for p in accepted.values()
    )
    if not takes_var_kwargs:
        call_kwargs = {k: v for k, v in call_kwargs.items() if k in accepted}

    x_0 = sample_fn(denoiser, x_T, sigmas, **call_kwargs)

    if clip_output:
        print("clip output")
        return x_0.clamp(-1, 1)
    return x_0


def get_sigmas_karras(n, sigma_min, sigma_max, rho=7.0, device="cpu"):
    """Build the noise-level schedule of Karras et al. (2022).

    ``n`` levels are interpolated linearly in ``sigma ** (1 / rho)`` space,
    starting at ``sigma_max`` and ending at ``sigma_min``. The result is passed
    through ``append_zero`` (presumably adding a trailing 0 -- see ``.nn``)
    and moved to ``device``.
    """
    exponent = 1 / rho
    start = sigma_max ** exponent
    stop = sigma_min ** exponent
    fraction = th.linspace(0, 1, n)
    schedule = (start + fraction * (stop - start)) ** rho
    return append_zero(schedule).to(device)


def to_d(x, sigma, denoised):
    """Turn a denoiser prediction into the Karras ODE derivative.

    Returns ``(x - denoised) / sigma`` with ``sigma`` first passed through
    ``append_dims(sigma, x.ndim)`` so it can divide ``x``-shaped tensors.
    """
    residual = x - denoised
    sigma_expanded = append_dims(sigma, x.ndim)
    return residual / sigma_expanded


def get_ancestral_step(sigma_from, sigma_to):
    """Split an ancestral step into a deterministic and a stochastic part.

    Returns:
        ``(sigma_down, sigma_up)``: the noise level to step down to, and the
        standard deviation of the fresh noise to add afterwards.
    """
    from_sq = sigma_from**2
    to_sq = sigma_to**2
    sigma_up = (to_sq * (from_sq - to_sq) / from_sq) ** 0.5
    sigma_down = (to_sq - sigma_up**2) ** 0.5
    return sigma_down, sigma_up


@th.no_grad()
def sample_euler_ancestral(
    model,
    x,
    sigmas,
    progress=False,
    callback=None):
    """Ancestral sampling: an Euler step down, then fresh noise, per level.

    Args:
        model: Callable ``model(x, sigma_batch)`` returning the denoised
            estimate.
        x: Starting tensor; its first dimension sizes the per-sample sigma
            vector.
        sigmas: Sequence of noise levels; consecutive pairs define the steps.
        progress: Wrap the step loop in a tqdm progress bar.
        callback: Optional callable receiving a dict with keys ``x``, ``i``,
            ``sigma``, ``sigma_hat`` and ``denoised`` before each update.

    Returns:
        The tensor after the last step.
    """
    batch_ones = x.new_ones([x.shape[0]])
    step_ids = range(len(sigmas) - 1)
    if progress:
        from tqdm.auto import tqdm

        step_ids = tqdm(step_ids)

    for i in step_ids:
        denoised = model(x, sigmas[i] * batch_ones)
        sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1])
        if callback is not None:
            info = dict(
                x=x,
                i=i,
                sigma=sigmas[i],
                sigma_hat=sigmas[i],
                denoised=denoised,
            )
            callback(info)
        # Deterministic Euler move to sigma_down ...
        slope = to_d(x, sigmas[i], denoised)
        x = x + slope * (sigma_down - sigmas[i])
        # ... followed by re-noising with standard deviation sigma_up.
        x = x + th.randn_like(x) * sigma_up
    return x


@th.no_grad()
def sample_midpoint_ancestral(
    model,
    x,
    ts,
    progress=False,
    callback=None):
    """Integrate ``model`` over ``ts`` with midpoint-method steps.

    The step size is ``1 / len(ts)``. For every ``tn`` in ``ts`` the model is
    evaluated at ``(x, tn)`` and again at the half-step point; the second
    evaluation drives the update.

    Args:
        model: Callable ``model(x, t_batch)`` returning an ``x``-shaped tensor.
        x: Starting tensor.
        ts: Sized iterable of time values.
        progress: Wrap ``ts`` in a tqdm progress bar.
        callback: Optional callable receiving a dict with keys ``x``, ``tn``,
            ``dn`` and ``dn_2`` after each update.

    Returns:
        The tensor after the last step.
    """
    batch_ones = x.new_ones([x.shape[0]])
    h = 1 / len(ts)
    time_points = ts
    if progress:
        from tqdm.auto import tqdm

        time_points = tqdm(ts)

    for tn in time_points:
        dn = model(x, tn * batch_ones)
        x_half = x + (h / 2) * dn
        t_half = (tn + h / 2) * batch_ones
        dn_2 = model(x_half, t_half)
        x = x + h * dn_2
        if callback is not None:
            callback(dict(x=x, tn=tn, dn=dn, dn_2=dn_2))
    return x

@th.no_grad()
def sample_exact(
    denoiser,
    x,
    sigmas,
    cond,
    ctm,
    teacher,
    guidance_scale=None,
    progress=False,
    callback=None,
    ts=None,
    t_min=0.002,
    t_max=80.0,
    rho=7.0,
    steps=40,
):
    """Sample by jumping from each noise level directly to the next one.

    At every step the denoiser is asked to map ``x`` from ``sigmas[i]`` to
    ``sigmas[i + 1]`` (passed as ``s``) and its output becomes the new ``x``.
    If the next level is exactly 0 an Euler step is taken instead. The final
    entry of the schedule is never stepped to (the loop stops one short).

    Args:
        denoiser: Callable ``denoiser(x, t, cond=, s=, ctm=, teacher=)``.
        x: Starting tensor.
        sigmas: Noise schedule; replaced when ``ts`` is given.
        cond: Conditioning forwarded to the denoiser.
        ctm, teacher: Flags forwarded to the denoiser.
        guidance_scale: Unused; kept for a signature parallel to the CFG
            samplers.
        progress: Show a tqdm progress bar.
        callback: Optional callable receiving a dict with keys ``x``, ``i``,
            ``sigma`` and ``denoised`` after each step.
        ts: Optional step indices. When non-empty, the schedule is rebuilt
            from them using ``t_min``, ``t_max``, ``rho`` and ``steps``.
        t_min, t_max, rho, steps: Parameters of the ``ts``-derived schedule.

    Returns:
        The tensor after the last executed step.
    """
    s_in = x.new_ones([x.shape[0]])
    if ts is not None and len(ts) > 0:
        t_max_rho = t_max ** (1 / rho)
        t_min_rho = t_min ** (1 / rho)
        sigmas = th.tensor(
            [(t_max_rho + t / (steps - 1) * (t_min_rho - t_max_rho)) ** rho for t in ts]
        )
        sigmas = append_zero(sigmas).to(x.device)

    # Slice before wrapping: a tqdm object cannot be sliced.
    indices = range(len(sigmas) - 1)[:-1]
    if progress:
        from tqdm.auto import tqdm

        indices = tqdm(indices)

    for i in indices:
        sigma = sigmas[i]
        sigma_next = sigmas[i + 1]
        if sigma_next != 0:
            # Direct jump: the denoiser output at target time s is the next state.
            denoised = denoiser(x, sigma * s_in, cond=cond, s=sigma_next * s_in, ctm=ctm, teacher=teacher)
            x = denoised
        else:
            # Target level 0: evaluate at s = sigma and take an Euler step.
            denoised = denoiser(x, sigma * s_in, cond=cond, s=sigma * s_in, ctm=ctm, teacher=teacher)
            d = to_d(x, sigma, denoised)
            x = x + d * (sigma_next - sigma)

        if callback is not None:
            callback(
                {
                    "x": x,
                    "i": i,
                    "sigma": sigma,
                    "denoised": denoised,
                }
            )
    return x

@th.no_grad()
def sample_exact_cfg(
    denoiser,
    x,
    sigmas,
    cond,
    guidance_scale,
    cfg_aug=None,
    ctm=False,
    teacher=False,
    progress=False,
    callback=None,
    ts=None,
    t_min=0.002,
    t_max=80.0,
    rho=7.0,
    steps=40,
):
    r"""CTM \gamma-sampling with \gamma = 0, using classifier-free guidance.

    Each step evaluates the denoiser on a doubled batch (``cond`` followed by
    the same number of empty prompts) and combines the halves as
    ``uncond + guidance_scale * (text - uncond)``. The guided output for
    target time ``sigmas[i + 1]`` becomes the next ``x``; if that level is
    exactly 0 an Euler step is taken instead. The final entry of the schedule
    is never stepped to (the loop stops one short).

    Args:
        denoiser: Callable ``denoiser(x, t, cond=, s=, ctm=, teacher=, cfg=)``.
        x: Starting tensor.
        sigmas: Noise schedule; replaced when ``ts`` is given.
        cond: List of prompts (it is concatenated with a list of ``""``).
        guidance_scale: Weight of the conditional direction.
        cfg_aug: Optional scalar; broadcast to one value per element of
            ``x``'s first dimension and passed to the denoiser as ``cfg``.
        ctm, teacher: Flags forwarded to the denoiser.
        progress: Show a tqdm progress bar.
        callback: Optional callable receiving a dict with keys ``x``, ``i``,
            ``sigma`` and ``denoised`` after each step.
        ts: Optional step indices. When non-empty, the schedule is rebuilt
            from them using ``t_min``, ``t_max``, ``rho`` and ``steps``.
        t_min, t_max, rho, steps: Parameters of the ``ts``-derived schedule.

    Returns:
        The tensor after the last executed step.
    """
    s_in = x.new_ones([x.shape[0]])
    if cfg_aug is not None:
        cfg_aug = cfg_aug * th.ones((x.shape[0],), device=x.device)
    if ts is not None and len(ts) > 0:
        t_max_rho = t_max ** (1 / rho)
        t_min_rho = t_min ** (1 / rho)
        sigmas = th.tensor(
            [(t_max_rho + t / (steps - 1) * (t_min_rho - t_max_rho)) ** rho for t in ts]
        )
        sigmas = append_zero(sigmas).to(x.device)

    # Slice before wrapping: a tqdm object cannot be sliced.
    indices = range(len(sigmas) - 1)[:-1]
    if progress:
        from tqdm.auto import tqdm

        indices = tqdm(indices)

    for i in indices:
        sigma = sigmas[i]
        sigma_next = sigmas[i + 1]
        jump = bool(sigma_next != 0)
        # Jump to the next level when it is non-zero, otherwise evaluate at
        # s = sigma and finish with an Euler step.
        target = sigma_next if jump else sigma

        x_in = th.cat([x] * 2)
        cond_cfg = cond + ([""] * len(cond))
        denoised = denoiser(x_in, th.cat([sigma * s_in] * 2), cond=cond_cfg,
                            s=th.cat([target * s_in] * 2), ctm=ctm, teacher=teacher, cfg=cfg_aug)
        denoised_text, denoised_uncond = denoised.chunk(2)
        denoised = denoised_uncond + guidance_scale * (denoised_text - denoised_uncond)

        if jump:
            x = denoised
        else:
            d = to_d(x, sigma, denoised)
            x = x + d * (sigma_next - sigma)

        if callback is not None:
            callback(
                {
                    "x": x,
                    "i": i,
                    "sigma": sigma,
                    "denoised": denoised,
                }
            )
    return x


@th.no_grad()
def sample_heun(
    denoiser,
    x,
    sigmas,
    cond,
    guidance_scale=None,
    progress=False,
    callback=None,
    s_churn=0.0,
    s_tmin=0.0,
    s_tmax=float("inf"),
    s_noise=1.0,
    teacher=False,
    ctm=False,
):
    """Implements Algorithm 2 (Heun steps) from Karras et al. (2022).

    Args:
        denoiser: Callable ``denoiser(x, t, cond=, s=, ctm=, teacher=)``.
        x: Starting tensor.
        sigmas: Noise schedule; consecutive pairs define the steps.
        cond: Conditioning forwarded to the denoiser.
        guidance_scale: Unused; kept for a signature parallel to
            ``sample_heun_cfg``.
        progress: Show a tqdm progress bar.
        callback: Optional callable receiving a dict with keys ``x``, ``i``,
            ``sigma``, ``sigma_hat`` and ``denoised`` after the first
            evaluation of each step.
        s_churn, s_tmin, s_tmax, s_noise: Stochasticity ("churn") settings;
            noise is only injected while ``s_tmin <= sigma <= s_tmax``.
        teacher, ctm: Flags forwarded to the denoiser. When ``teacher`` is
            set and ``ctm`` is not, the denoiser is called with ``s=None``;
            otherwise ``s`` equals the evaluation time.

    Returns:
        The tensor after the last step.
    """
    s_in = x.new_ones([x.shape[0]])
    # Only a teacher evaluated outside CTM mode is called without a target time.
    no_target_time = bool(teacher) and not ctm

    def _denoise(x_cur, sigma):
        s = None if no_target_time else sigma * s_in
        return denoiser(x_cur, sigma * s_in, cond=cond, s=s, ctm=ctm, teacher=teacher)

    indices = range(len(sigmas) - 1)
    if progress:
        from tqdm.auto import tqdm

        indices = tqdm(indices)

    for i in indices:
        gamma = (
            min(s_churn / (len(sigmas) - 1), 2**0.5 - 1)
            if s_tmin <= sigmas[i] <= s_tmax
            else 0.0
        )
        # Drawn every step (even when unused) so the RNG stream does not
        # depend on whether churn is active.
        eps = th.randn_like(x) * s_noise
        sigma_hat = sigmas[i] * (gamma + 1)
        if gamma > 0:
            x = x + eps * (sigma_hat**2 - sigmas[i] ** 2) ** 0.5

        denoised = _denoise(x, sigma_hat)
        d = to_d(x, sigma_hat, denoised)
        if callback is not None:
            callback(
                {
                    "x": x,
                    "i": i,
                    "sigma": sigmas[i],
                    "sigma_hat": sigma_hat,
                    "denoised": denoised,
                }
            )
        dt = sigmas[i + 1] - sigma_hat
        if sigmas[i + 1] == 0:
            # Euler method: no second evaluation is possible at sigma = 0.
            x = x + d * dt
        else:
            # Heun's method: average the slopes at both ends of the step.
            x_2 = x + d * dt
            denoised_2 = _denoise(x_2, sigmas[i + 1])
            d_2 = to_d(x_2, sigmas[i + 1], denoised_2)
            d_prime = (d + d_2) / 2
            x = x + d_prime * dt
    return x

@th.no_grad()
def sample_heun_cfg(
    denoiser,
    x,
    sigmas,
    cond,
    guidance_scale=None,
    progress=False,
    callback=None,
    s_churn=0.0,
    s_tmin=0.0,
    s_tmax=float("inf"),
    s_noise=1.0,
    teacher=False,
    ctm=False,
):
    """Algorithm 2 (Heun steps) from Karras et al. (2022) with classifier-free guidance.

    Every denoiser evaluation runs on a doubled batch (``cond`` followed by
    the same number of empty prompts); the halves are combined as
    ``uncond + guidance_scale * (text - uncond)``.

    Args:
        denoiser: Callable ``denoiser(x, t, cond=, s=, ctm=, teacher=)``.
        x: Starting tensor.
        sigmas: Noise schedule; consecutive pairs define the steps.
        cond: List of prompts (it is concatenated with a list of ``""``).
        guidance_scale: Weight of the conditional direction. The default of
            None is not usable: it is multiplied with a tensor.
        progress: Show a tqdm progress bar.
        callback: Optional callable receiving a dict with keys ``x``, ``i``,
            ``sigma``, ``sigma_hat`` and ``denoised`` after the first
            evaluation of each step.
        s_churn, s_tmin, s_tmax, s_noise: Stochasticity ("churn") settings;
            noise is only injected while ``s_tmin <= sigma <= s_tmax``.
        teacher, ctm: Flags forwarded to the denoiser. When ``teacher`` is
            set and ``ctm`` is not, the denoiser is called with ``s=None``;
            otherwise ``s`` equals the evaluation time.

    Returns:
        The tensor after the last step.
    """
    s_in = x.new_ones([x.shape[0]])
    # Only a teacher evaluated outside CTM mode is called without a target time.
    no_target_time = bool(teacher) and not ctm

    def _guided_denoise(x_cur, sigma):
        """Run the doubled-batch denoiser call and apply the guidance formula."""
        x_in = th.cat([x_cur] * 2)
        cond_cfg = cond + ([""] * len(cond))
        s = None if no_target_time else th.cat([sigma * s_in] * 2)
        out = denoiser(x_in, th.cat([sigma * s_in] * 2), cond=cond_cfg, s=s,
                       ctm=ctm, teacher=teacher)
        out_text, out_uncond = out.chunk(2)
        return out_uncond + guidance_scale * (out_text - out_uncond)

    indices = range(len(sigmas) - 1)
    if progress:
        from tqdm.auto import tqdm

        indices = tqdm(indices)

    for i in indices:
        gamma = (
            min(s_churn / (len(sigmas) - 1), 2**0.5 - 1)
            if s_tmin <= sigmas[i] <= s_tmax
            else 0.0
        )
        # Drawn every step (even when unused) so the RNG stream does not
        # depend on whether churn is active.
        eps = th.randn_like(x) * s_noise
        sigma_hat = sigmas[i] * (gamma + 1)
        if gamma > 0:
            x = x + eps * (sigma_hat**2 - sigmas[i] ** 2) ** 0.5

        denoised = _guided_denoise(x, sigma_hat)
        d = to_d(x, sigma_hat, denoised)
        if callback is not None:
            callback(
                {
                    "x": x,
                    "i": i,
                    "sigma": sigmas[i],
                    "sigma_hat": sigma_hat,
                    "denoised": denoised,
                }
            )
        dt = sigmas[i + 1] - sigma_hat
        if sigmas[i + 1] == 0:
            # Euler method: no second evaluation is possible at sigma = 0.
            x = x + d * dt
        else:
            # Heun's method: average the slopes at both ends of the step.
            x_2 = x + d * dt
            denoised_2 = _guided_denoise(x_2, sigmas[i + 1])
            d_2 = to_d(x_2, sigmas[i + 1], denoised_2)
            d_prime = (d + d_2) / 2
            x = x + d_prime * dt
    return x


@th.no_grad()
def sample_multistep_cfg(
    denoiser,
    x,
    sigmas,
    cond,
    guidance_scale,
    ctm=False,
    teacher=False,
    progress=False,
    callback=None,
    ts=None,
    t_min=0.002,
    t_max=80.0,
    rho=7.0,
    steps=40,
):
    r"""CM multistep sampling (CTM \gamma-sampling with \gamma = 1) with classifier-free guidance.

    Each step asks the denoiser to map ``x`` from ``sigmas[i]`` all the way to
    ``t_min`` on a doubled batch (``cond`` followed by the same number of empty
    prompts), combines the halves as
    ``uncond + guidance_scale * (text - uncond)``, and then re-noises the
    result to level ``sigmas[i + 1]``. The last executed step returns the
    guided estimate without adding noise. The final entry of the schedule is
    never stepped to (the loop stops one short).

    Args:
        denoiser: Callable ``denoiser(x, t, cond=, s=, ctm=, teacher=)``.
        x: Starting tensor.
        sigmas: Noise schedule; replaced when ``ts`` is given.
        cond: List of prompts (it is concatenated with a list of ``""``).
        guidance_scale: Weight of the conditional direction.
        ctm, teacher: Flags forwarded to the denoiser.
        progress: Show a tqdm progress bar.
        callback: Optional callable receiving a dict with keys ``x``, ``i``,
            ``sigma`` and ``denoised`` after each step.
        ts: Optional step indices. When non-empty, the schedule is rebuilt
            from them using ``t_min``, ``t_max``, ``rho`` and ``steps``.
        t_min: Target time of every jump; also used for the rebuilt schedule.
        t_max, rho, steps: Parameters of the ``ts``-derived schedule.

    Returns:
        The tensor after the last executed step.
    """
    s_in = x.new_ones([x.shape[0]])
    if ts is not None and len(ts) > 0:
        t_max_rho = t_max ** (1 / rho)
        t_min_rho = t_min ** (1 / rho)
        sigmas = th.tensor(
            [(t_max_rho + t / (steps - 1) * (t_min_rho - t_max_rho)) ** rho for t in ts]
        )
        sigmas = append_zero(sigmas).to(x.device)

    n_intervals = len(sigmas) - 1
    # Slice before wrapping: a tqdm object cannot be sliced.
    indices = range(n_intervals)[:-1]
    if progress:
        from tqdm.auto import tqdm

        indices = tqdm(indices)

    for i in indices:
        sigma = sigmas[i]
        x_in = th.cat([x] * 2)
        cond_cfg = cond + ([""] * len(cond))
        denoised = denoiser(x_in, th.cat([sigma * s_in] * 2), cond=cond_cfg,
                            s=th.cat([t_min * s_in] * 2), ctm=ctm, teacher=teacher)
        denoised_text, denoised_uncond = denoised.chunk(2)
        denoised = denoised_uncond + guidance_scale * (denoised_text - denoised_uncond)
        if i < n_intervals - 2:
            # Re-noise from level t_min up to the next level of the schedule.
            noise_std = th.sqrt(sigmas[i + 1] ** 2 - t_min ** 2)
            x = denoised + noise_std * th.randn_like(denoised)
        else:
            # Last executed step: keep the clean estimate.
            x = denoised

        if callback is not None:
            callback(
                {
                    "x": x,
                    "i": i,
                    "sigma": sigma,
                    "denoised": denoised,
                }
            )
    return x

@th.no_grad()
def sample_gamma_multistep_cfg(
    denoiser,
    x,
    sigmas,
    cond,
    guidance_scale,
    ctm=False,
    teacher=False,
    progress=False,
    callback=None,
    ts=None,
    t_min=0.002,
    t_max=80.0,
    rho=7.0,
    steps=40,
    gamma=0.9,
):
    r"""CTM \gamma-sampling with classifier-free guidance.

    Each step asks the denoiser to map ``x`` from ``sigmas[i]`` to the
    intermediate time ``s = sqrt(1 - gamma**2) * (sigmas[i + 1] - t_min) + t_min``
    on a doubled batch (``cond`` followed by the same number of empty prompts),
    combines the halves as ``uncond + guidance_scale * (text - uncond)``, and
    then adds noise with standard deviation ``sqrt(sigmas[i + 1]**2 - s**2)``.
    The last executed step returns the guided estimate without adding noise.
    The final entry of the schedule is never stepped to (the loop stops one
    short).

    Args:
        denoiser: Callable ``denoiser(x, t, cond=, s=, ctm=, teacher=)``.
        x: Starting tensor.
        sigmas: Noise schedule; replaced when ``ts`` is given.
        cond: List of prompts (it is concatenated with a list of ``""``).
        guidance_scale: Weight of the conditional direction.
        ctm, teacher: Flags forwarded to the denoiser.
        progress: Show a tqdm progress bar.
        callback: Optional callable receiving a dict with keys ``x``, ``i``,
            ``sigma`` and ``denoised`` after each step (same payload as the
            other multistep samplers in this module).
        ts: Optional step indices. When non-empty, the schedule is rebuilt
            from them using ``t_min``, ``t_max``, ``rho`` and ``steps``.
        t_min, t_max, rho, steps: Parameters of the ``ts``-derived schedule;
            ``t_min`` also enters the intermediate time ``s``.
        gamma: Stochasticity parameter; must be neither 0 nor 1.

    Returns:
        The tensor after the last executed step.

    Raises:
        AssertionError: If ``gamma`` is 0 or 1.
    """
    assert gamma != 0.0 and gamma != 1.0, "gamma must be neither 0 nor 1 for this sampler"

    s_in = x.new_ones([x.shape[0]])
    if ts is not None and len(ts) > 0:
        t_max_rho = t_max ** (1 / rho)
        t_min_rho = t_min ** (1 / rho)
        sigmas = th.tensor(
            [(t_max_rho + t / (steps - 1) * (t_min_rho - t_max_rho)) ** rho for t in ts]
        )
        sigmas = append_zero(sigmas).to(x.device)

    n_intervals = len(sigmas) - 1
    # Slice before wrapping: a tqdm object cannot be sliced.
    indices = range(n_intervals)[:-1]
    if progress:
        from tqdm.auto import tqdm

        indices = tqdm(indices)

    # Loop-invariant factor of the intermediate target time.
    keep = float(np.sqrt(1. - gamma ** 2))

    for i in indices:
        sigma = sigmas[i]
        s = keep * (sigmas[i + 1] - t_min) + t_min

        x_in = th.cat([x] * 2)
        cond_cfg = cond + ([""] * len(cond))
        denoised = denoiser(x_in, th.cat([sigma * s_in] * 2), cond=cond_cfg,
                            s=th.cat([s * s_in] * 2), ctm=ctm, teacher=teacher)
        denoised_text, denoised_uncond = denoised.chunk(2)
        denoised = denoised_uncond + guidance_scale * (denoised_text - denoised_uncond)
        if i < n_intervals - 2:
            # Re-noise from level s up to the next level of the schedule.
            std = th.sqrt(sigmas[i + 1] ** 2 - s ** 2)
            x = denoised + std * th.randn_like(denoised)
        else:
            # Last executed step: keep the clean estimate.
            x = denoised

        if callback is not None:
            callback(
                {
                    "x": x,
                    "i": i,
                    "sigma": sigma,
                    "denoised": denoised,
                }
            )

    return x

@th.no_grad()
def sample_progdist(
    denoiser,
    x,
    sigmas,
    generator=None,
    progress=False,
    callback=None,
):
    """Euler sampling over the schedule with its last entry removed.

    Args:
        denoiser: Callable ``denoiser(x, sigma_batch)`` returning the denoised
            estimate.
        x: Starting tensor.
        sigmas: Noise schedule; the final entry (the zero sigma) is dropped
            before stepping.
        generator: Unused.
        progress: Show a tqdm progress bar.
        callback: Optional callable receiving a dict with keys ``x``, ``i``,
            ``sigma`` and ``denoised`` before each update.

    Returns:
        The tensor after the last step.
    """
    batch_ones = x.new_ones([x.shape[0]])
    levels = sigmas[:-1]  # skip the zero sigma

    step_ids = range(len(levels) - 1)
    if progress:
        from tqdm.auto import tqdm

        step_ids = tqdm(step_ids)

    for i in step_ids:
        level = levels[i]
        denoised = denoiser(x, level * batch_ones)
        slope = to_d(x, level, denoised)
        if callback is not None:
            callback(dict(x=x, i=i, sigma=level, denoised=denoised))
        x = x + slope * (levels[i + 1] - level)

    return x

@th.no_grad()
def sample_euler(
    denoiser,
    x,
    sigmas,
    progress=False,
    callback=None,
):
    """Deterministic sampling with one Euler step per pair of noise levels.

    Args:
        denoiser: Callable ``denoiser(x, sigma_batch)`` returning the denoised
            estimate.
        x: Starting tensor.
        sigmas: Noise schedule; consecutive pairs define the steps.
        progress: Show a tqdm progress bar.
        callback: Optional callable receiving a dict with keys ``x``, ``i``,
            ``sigma`` and ``denoised`` before each update.

    Returns:
        The tensor after the last step.
    """
    batch_ones = x.new_ones([x.shape[0]])
    step_ids = range(len(sigmas) - 1)
    if progress:
        from tqdm.auto import tqdm

        step_ids = tqdm(step_ids)

    for i in step_ids:
        level = sigmas[i]
        denoised = denoiser(x, level * batch_ones)
        slope = to_d(x, level, denoised)
        if callback is not None:
            callback(dict(x=x, i=i, sigma=sigmas[i], denoised=denoised))
        step = sigmas[i + 1] - level
        x = x + slope * step
    return x


@th.no_grad()
def sample_dpm(
    denoiser,
    x,
    sigmas,
    progress=False,
    callback=None,
    s_churn=0.0,
    s_tmin=0.0,
    s_tmax=float("inf"),
    s_noise=1.0,
):
    """A sampler inspired by DPM-Solver-2 and Algorithm 2 from Karras et al. (2022).

    Each step optionally raises the noise level ("churn"), evaluates the
    denoiser, moves to a midpoint noise level, evaluates again and uses the
    midpoint slope for the full step.

    Args:
        denoiser: Callable ``denoiser(x, sigma_batch)`` returning the denoised
            estimate.
        x: Starting tensor.
        sigmas: Noise schedule; consecutive pairs define the steps.
        progress: Show a tqdm progress bar.
        callback: Optional callable receiving a dict with keys ``x``, ``i``,
            ``sigma``, ``sigma_hat`` and ``denoised`` after the first
            evaluation of each step.
        s_churn, s_tmin, s_tmax, s_noise: Stochasticity settings; noise is
            only injected while ``s_tmin <= sigma <= s_tmax``.

    Returns:
        The tensor after the last step.
    """
    batch_ones = x.new_ones([x.shape[0]])
    step_ids = range(len(sigmas) - 1)
    if progress:
        from tqdm.auto import tqdm

        step_ids = tqdm(step_ids)

    for i in step_ids:
        if s_tmin <= sigmas[i] <= s_tmax:
            gamma = min(s_churn / (len(sigmas) - 1), 2**0.5 - 1)
        else:
            gamma = 0.0
        # Noise is drawn on every step, whether or not it ends up being used.
        eps = th.randn_like(x) * s_noise
        sigma_hat = sigmas[i] * (gamma + 1)
        if gamma > 0:
            x = x + eps * (sigma_hat**2 - sigmas[i] ** 2) ** 0.5

        denoised = denoiser(x, sigma_hat * batch_ones)
        slope = to_d(x, sigma_hat, denoised)
        if callback is not None:
            callback(
                dict(
                    x=x,
                    i=i,
                    sigma=sigmas[i],
                    sigma_hat=sigma_hat,
                    denoised=denoised,
                )
            )

        # Midpoint method, where the midpoint is chosen according to a rho=3 Karras schedule
        sigma_mid = ((sigma_hat ** (1 / 3) + sigmas[i + 1] ** (1 / 3)) / 2) ** 3
        to_mid = sigma_mid - sigma_hat
        full_step = sigmas[i + 1] - sigma_hat
        x_mid = x + slope * to_mid
        denoised_mid = denoiser(x_mid, sigma_mid * batch_ones)
        slope_mid = to_d(x_mid, sigma_mid, denoised_mid)
        x = x + slope_mid * full_step
    return x


@th.no_grad()
def sample_onestep(
    distiller,
    x,
    sigmas,
    progress=False,
    callback=None,
):
    """Generate in a single call of a distilled model.

    ``distiller`` is invoked once as ``distiller(x, sigma_batch, None)`` where
    ``sigma_batch`` holds ``sigmas[0]`` for every element of ``x``'s first
    dimension. ``progress`` and ``callback`` are accepted but unused.
    """
    first_level = sigmas[0] * x.new_ones([x.shape[0]])
    return distiller(x, first_level, None)


@th.no_grad()
def stochastic_iterative_sampler(
    distiller,
    x,
    sigmas,
    ts,
    progress=False,
    callback=None,
    t_min=0.002,
    t_max=80.0,
    rho=7.0,
    steps=40,
):
    """Alternate between a distilled-model call and re-noising.

    For each consecutive pair in ``ts`` the current index is mapped to a noise
    level on the rho-schedule between ``t_max`` and ``t_min``, ``distiller`` is
    called as ``distiller(x, t_batch, None)``, and its output is re-noised to
    the level of the next index (clipped to ``[t_min, t_max]``).

    Args:
        distiller: Callable ``distiller(x, t_batch, None)``.
        x: Starting tensor.
        sigmas: Unused.
        ts: Sequence of step indices; ``len(ts) - 1`` iterations are run.
        progress, callback: Unused.
        t_min, t_max, rho, steps: Parameters of the index-to-noise-level map.

    Returns:
        The tensor after the last iteration.
    """
    t_max_rho = t_max ** (1 / rho)
    t_min_rho = t_min ** (1 / rho)
    span = t_min_rho - t_max_rho
    batch_ones = x.new_ones([x.shape[0]])

    for k in range(len(ts) - 1):
        level = (t_max_rho + ts[k] / (steps - 1) * span) ** rho
        x0 = distiller(x, level * batch_ones, None)
        next_level = (t_max_rho + ts[k + 1] / (steps - 1) * span) ** rho
        next_level = np.clip(next_level, t_min, t_max)
        noise_std = np.sqrt(next_level**2 - t_min**2)
        x = x0 + th.randn_like(x) * noise_std

    return x
