from typing import Callable, List
from typing import Optional, Union

import numpy as np
import torch
import torch as th
import torch.nn as nn
import torch.nn.functional as F

from cleandiffuser.diffusion import DiffusionModel
from cleandiffuser.nn_condition import BaseNNCondition
from cleandiffuser.nn_diffusion import BaseNNDiffusion
from cleandiffuser.utils import at_least_ndim, create_named_schedule_sampler, Args
from .newedm import ContinuousEDM


def get_weightings(weight_schedule, snrs, sigma_data):
    """Return per-sample loss weights for the requested weighting schedule.

    Args:
        weight_schedule: Either ``"uniform"`` or ``"karras_weight"``.
        snrs: Signal-to-noise ratios; the result has the same shape.
        sigma_data: Data standard deviation used by ``"karras_weight"``.

    Returns:
        Ones for ``"uniform"``; for ``"karras_weight"`` the value
        ``(sigma^2 + sigma_data^2) / (sigma * sigma_data)^2`` with
        ``sigma = snrs ** -0.5``.

    Raises:
        NotImplementedError: For any other schedule name.
    """
    if weight_schedule == "uniform":
        return th.ones_like(snrs)
    if weight_schedule == "karras_weight":
        noise_std = snrs ** -0.5
        numerator = noise_std ** 2 + sigma_data ** 2
        denominator = (noise_std * sigma_data) ** 2
        return numerator / denominator
    raise NotImplementedError()


def _property_differs(left, right) -> bool:
    """Return True when two attribute values should be reported as different.

    Tensors and arrays are compared with ``allclose``; everything else with
    ``!=``. Values that cannot be compared at all (e.g. a tensor against
    ``None``, or tensors with incompatible shape/dtype/device) count as
    different instead of raising.
    """
    try:
        if isinstance(left, torch.Tensor):
            return not torch.allclose(left, right)
        if isinstance(left, np.ndarray):
            return not np.allclose(left, right)
        unequal = left != right
        # `!=` may yield an element-wise result when `right` is array-like.
        if isinstance(unequal, (torch.Tensor, np.ndarray)):
            return bool(unequal.any())
        return bool(unequal)
    except (TypeError, ValueError, RuntimeError):
        return True


def compare_properties(obj1, obj2, properties: List[str]):
    """List the attribute names whose values differ between two objects.

    Args:
        obj1: First object; the type of its attribute decides how the pair
            is compared (``torch.allclose``, ``np.allclose`` or ``!=``).
        obj2: Second object.
        properties: Attribute names to read from both objects.

    Returns:
        The subset of ``properties`` (in input order) whose values differ.
        A pair that cannot be compared is reported as different.

    Raises:
        AttributeError: If either object lacks one of the attributes.
    """
    return [
        prop for prop in properties
        if _property_differs(getattr(obj1, prop), getattr(obj2, prop))
    ]


def pseudo_huber_loss(source: torch.Tensor, target: torch.Tensor, c: float = 0.0):
    """Element-wise pseudo-Huber distance ``sqrt((source - target)^2 + c^2) - c``.

    With the default ``c = 0`` this is the element-wise absolute difference.
    """
    residual = source - target
    smoothed = residual ** 2 + c ** 2
    return smoothed.sqrt() - c


class CMCurriculumLogger:
    """Track the discretisation curriculum used for consistency training.

    The number of noise levels ``Nk`` starts at ``s0`` and doubles every
    ``Kprime`` iterations until it reaches ``s1``. Whenever ``Nk`` changes,
    the noise levels ``sigmas`` and the sampling probabilities ``p_sigmas``
    of the intervals between consecutive levels are recomputed.

    Args:
        s0: Initial number of noise levels.
        s1: Maximum number of noise levels.
        curriculum_cycle: Length of one curriculum cycle in iterations.
        sigma_min: Smallest noise level.
        sigma_max: Largest noise level.
        rho: Exponent of the noise schedule.
        P_mean: Mean of the Gaussian (in log-sigma) used for ``p_sigmas``.
        P_std: Standard deviation of that Gaussian.
    """

    def __init__(
            self, s0: int = 10, s1: int = 1280, curriculum_cycle: int = 100_000,
            sigma_min: float = 0.002, sigma_max: float = 80., rho: float = 7.,
            P_mean: float = -1.1, P_std: float = 2.0
    ):
        # Number of iterations between two doublings of the noise-level count.
        self.Kprime = np.ceil(curriculum_cycle / (np.log2(np.ceil(s1 / s0)) + 1))
        self.Nk = s0
        self.s0, self.s1 = s0, s1
        self.curriculum_cycle = curriculum_cycle
        self.sigma_min, self.sigma_max, self.rho = sigma_min, sigma_max, rho
        self.P_mean, self.P_std = P_mean, P_std

        self.ceil_k_div_Kprime, self.k = None, None

        # Populates `sigmas` / `p_sigmas` for iteration 0.
        self.update_k(0)

    @staticmethod
    def _erf(values):
        """Element-wise Gauss error function of a NumPy array.

        Uses ``torch.erf`` (torch is already a dependency of this module);
        the input dtype is preserved.
        """
        return torch.erf(torch.as_tensor(values)).numpy()

    def update_k(self, k):
        """Set the current iteration and refresh the schedule if the stage changed."""
        self.k = k
        stage = np.ceil(k / self.Kprime)
        if stage != self.ceil_k_div_Kprime:
            self.ceil_k_div_Kprime = stage
            self.Nk = int(min(self.s0 * (2 ** self.ceil_k_div_Kprime), self.s1))

            # Nk + 1 levels, evenly spaced in sigma ** (1 / rho) from sigma_min to sigma_max.
            min_inv_rho = self.sigma_min ** (1 / self.rho)
            max_inv_rho = self.sigma_max ** (1 / self.rho)
            ramp = np.arange(self.Nk + 1, dtype=np.float32) / self.Nk
            self.sigmas = (min_inv_rho + ramp * (max_inv_rho - min_inv_rho)) ** self.rho

            # Difference of erf terms at the two ends of each interval, then normalised.
            scale = self.P_std * (2 ** 0.5)
            upper = self._erf((np.log(self.sigmas[1:]) - self.P_mean) / scale)
            lower = self._erf((np.log(self.sigmas[:-1]) - self.P_mean) / scale)
            self.p_sigmas = upper - lower
            self.p_sigmas = self.p_sigmas / self.p_sigmas.sum()

    def incremental_update_k(self):
        """Advance the iteration counter by one."""
        self.update_k(self.k + 1)

    @property
    def curriculum_process(self):
        """Fraction of the current curriculum cycle that has elapsed, in [0, 1)."""
        return (self.k % self.curriculum_cycle) / self.curriculum_cycle


class ContinuousConsistencyTrajecoteryModel(DiffusionModel):
    """**Continuous-time Consistency Trajectory Model**
    
    The Consistency Model defines a consistency function. 
    A consistency function has the property of self-consistency: 
    its outputs are consistent for arbitrary pairs of (x_t, t) that belong to the same PF ODE trajectory. 
    To learn such a consistency function, the Consistency Model needs to be distilled either from a pre-trained EDM 
    or learned directly through consistency training loss.
    This self-consistency property allows the Consistency Model in theory to achieve one-step generation.

    The current implementation of Consistency Model only supports continuous-time ODEs.
    The sampling steps are required to be greater than 0.

    Args:
        nn_diffusion: BaseNNDiffusion,
            The neural network backbone for the Diffusion model.
        nn_condition: Optional[BaseNNCondition],
            The neural network backbone for the condition embedding.
        
        fix_mask: Union[list, np.ndarray, torch.Tensor],
            Fix some portion of the input data, and only allow the diffusion model to complete the rest part.
            The mask should be in the shape of `x_shape`.
        loss_weight: Union[list, np.ndarray, torch.Tensor],
            Add loss weight. The weight should be in the shape of `x_shape`.
        
        classifier: Optional[BaseClassifier],
            The Consistency Model does not support classifier guidance; please set this option to `None`.
        
        grad_clip_norm: Optional[float],
            Gradient clipping norm.
        ema_rate: float,
            Exponential moving average rate.
        optim_params: Optional[dict],
            Optimizer parameters.
        
        s0: int,
            The minimum number of noise levels. Default: 10.
        s1: int,
            The maximum number of noise levels. Default: 1280.
        data_dim: int,
            The dimension of the data, which affects the `pseudo_huber_constant`.
            As suggested in `improved Consistency Models`, `pseudo_huber_constant` = 0.00054 * np.sqrt(data_dim).
            If `data_dim` is `None`, then `pseudo_huber_constant` = 0.01 will be used.
        P_mean: float,
            Hyperparameter for noise sampling during training. Default: -1.1.
        P_std: float,
            Hyperparameter for noise sampling during training. Default: 2.0.
        sigma_min: float,
            The minimum standard deviation of the noise. Default: 0.002.
        sigma_max: float,
            The maximum standard deviation of the noise. Default: 80.
        sigma_data: float,
            The standard deviation of the data. Default: 1.0.
        rho: float,
            The power of the noise schedule. Default: 7.
        curriculum_cycle: int,
            The cycle of the curriculum process.
            It is best to set `curriculum_cycle` to the number of model training iterations. Default: 100_000.
    
        x_max: Optional[torch.Tensor],
            The maximum value for the input data. `None` indicates no constraint.
        x_min: Optional[torch.Tensor],
            The minimum value for the input data. `None` indicates no constraint.
        
        device: Union[torch.device, str],
            The device to run the model.
    """

    def __init__(
            self,

            # ----------------- Neural Networks ----------------- #
            nn_diffusion: BaseNNDiffusion,
            nn_condition: Optional[BaseNNCondition] = None,

            # ----------------- Masks ----------------- #
            # Fix some portion of the input data, and only allow the diffusion model to complete the rest part.
            fix_mask: Union[list, np.ndarray, torch.Tensor] = None,  # be in the shape of `x_shape`
            # Add loss weight
            loss_weight: Union[list, np.ndarray, torch.Tensor] = None,  # be in the shape of `x_shape`

            # ------------------ Plugins ---------------- #
            # Do not support CG
            classifier=None,

            # ------------------ Training Params ---------------- #
            grad_clip_norm: Optional[float] = None,
            ema_rate: float = 0.9999,
            optim_params: Optional[dict] = None,
            d_optim_params: Optional[dict] = None,

            # ------------------- Consistency Model Params ------------------- #
            s0: int = 10,
            s1: int = 1280,
            data_dim: int = None,
            P_mean: float = -1.1,
            P_std: float = 2.0,
            sigma_min: float = 0.002,
            sigma_max: float = 80.,
            sigma_data: float = 1.0,
            rho: float = 7.0,
            diffusion_mult: float = 0.7,
            curriculum_cycle: int = 100_000,

            x_max: Optional[torch.Tensor] = None,
            x_min: Optional[torch.Tensor] = None,

            device: Union[torch.device, str] = "cpu",
            apply_adaptive_weight: bool = True,
            discriminator = None,
            g_learning_period: int = 2,
            dloss_start_itr: int = 1000,
    ):
        """Store the noise-schedule settings and set up the optional discriminator.

        `s0`, `s1`, `data_dim`, `P_mean`, `P_std` and `curriculum_cycle` are
        accepted for interface compatibility but are not used by this
        constructor.

        Args:
            d_optim_params: Keyword arguments for the discriminator's Adam
                optimizer; `None` means Adam defaults.
            diffusion_mult: Stored in `self.args` for the schedule samplers.
            apply_adaptive_weight: Whether auxiliary losses are rescaled by a
                gradient-norm ratio against the consistency loss.
            discriminator: Optional module used for the GAN loss.
            g_learning_period: Steps where `step % g_learning_period == 0`
                are generator steps; the others are discriminator steps.
            dloss_start_itr: First step at which the generator GAN loss is used.
        """
        # NOTE(review): the literal `0` is passed positionally to the parent;
        # presumably an option this model does not use. Confirm against
        # `DiffusionModel.__init__`.
        super().__init__(
            nn_diffusion, nn_condition, fix_mask, loss_weight, classifier, grad_clip_norm,
            0, ema_rate, optim_params, device)

        self.rho = rho
        self.sigma_data, self.sigma_max, self.sigma_min = sigma_data, sigma_max, sigma_min
        self.x_max = x_max.to(device) if isinstance(x_max, torch.Tensor) else x_max
        self.x_min = x_min.to(device) if isinstance(x_min, torch.Tensor) else x_min
        self.args = Args(sigma_max=sigma_max, sigma_min=sigma_min, rho=rho, diffusion_mult=diffusion_mult)
        self.apply_adaptive_weight = apply_adaptive_weight
        self.discriminator = discriminator
        if discriminator is not None:
            # `d_optim_params=None` falls back to Adam defaults instead of
            # failing on `**None`.
            self.d_optim = torch.optim.Adam(discriminator.parameters(), **(d_optim_params or {}))
        else:
            self.d_optim = None

        self.edm = None
        self.distillation_sigmas, self.distillation_N = None, None
        self.g_learning_period = g_learning_period
        self.step = 0
        self.dloss_start_itr = dloss_start_itr
        # Per-update logging dict; `update` replaces it on every call. Set
        # here so the loss helpers can be called before the first `update`.
        self.logger = None
        

    def prepare_distillation(self, edm: ContinuousEDM, start_scales: int = 18, num_heun_step: int = 17):
        """Attach a pre-trained EDM teacher and initialise this model from it.

        Args:
            edm: The teacher. Its `sigma_data`, `sigma_max`, `sigma_min`,
                `rho`, `x_max`, `x_min`, `fix_mask`, `loss_weight` and
                `device` must match this model's.
            start_scales: Number of discretisation levels used by `get_t`
                and by the schedule samplers.
            num_heun_step: Maximum number of teacher Heun steps per update.

        Raises:
            ValueError: If any of the checked properties differ.
        """
        checklist = [
            "sigma_data", "sigma_max", "sigma_min", "rho", "x_max", "x_min",
            "fix_mask", "loss_weight", "device"]
        differences = compare_properties(self, edm, checklist)
        if differences:
            raise ValueError(f"Properties {differences} are different between the EDM and the Consistency Model.")
        self.edm = edm

        # Copy every parameter whose name also exists in the teacher; student
        # parameters without a teacher counterpart keep their initial values.
        # For now the online model is seeded from `edm.model`; seeding both
        # from `edm.model_ema` could be tried later.
        for student, teacher in ((self.model, edm.model), (self.model_ema, edm.model_ema)):
            teacher_params = dict(teacher.named_parameters())
            for name, param in student.named_parameters():
                source = teacher_params.get(name)
                if source is not None:
                    param.data.copy_(source.data)

        self.model.train()
        self.start_scales = start_scales
        self.num_heun_step = num_heun_step

        self.schedule_sampler = create_named_schedule_sampler(self.args, "uniform", start_scales)
        self.diffusion_schedule_sampler = create_named_schedule_sampler(self.args, 'halflognormal', start_scales)


    @property
    def supported_solvers(self):
        return ["none"]

    @property
    def clip_pred(self):
        return (self.x_max is not None) or (self.x_min is not None)

    def training_noise_schedule(self, N):
        sigma = ((self.sigma_min ** (1 / self.rho) + np.arange(N + 1)
                  / N * (self.sigma_max ** (1 / self.rho) - self.sigma_min ** (1 / self.rho))) ** self.rho)
        return torch.tensor(sigma, device=self.device, dtype=torch.float32)

    # ===================== CM Pre-conditioning =======================

    def get_edm_scalings(self, sigma):
        c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2)
        c_out = sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
        return c_skip, c_out

    def get_outer_scalings(self, t, s=None):
        c_skip = s / t
        c_out = (1. - s / t)
        return c_skip, c_out

    def get_snr(self, sigmas):
        return sigmas**-2

    def get_c_in(self, sigma):
        return 1 / (self.sigma_data ** 2 + sigma ** 2).sqrt()

    def rescaling_t(self, sigma):
        return 0.25 * sigma.log()
    
    def get_t(self, ind):
        t = self.sigma_max ** (1 / self.rho) + ind / (self.start_scales - 1) * (
                self.sigma_min ** (1 / self.rho) - self.sigma_max ** (1 / self.rho)
        )
        t = t ** self.rho
        return t

    def get_denoised_and_G(self, model, x_t, t, s=None, ctm=False, teacher=False):
        """Run the network and apply pre-conditioning.

        Args:
            model: Mapping whose ``"diffusion"`` entry is the network; it is
                called as ``net(c_in * x_t, rescaled_t[, s_noise=rescaled_s])``.
            x_t: Noisy input at noise level `t`.
            t: Current noise level (tensor).
            s: Target noise level of the jump; required when `ctm` is True.
                When `None`, the network is called without `s_noise`.
            ctm: If True, use EDM scalings for the denoiser and then mix it
                with `x_t` through `get_outer_scalings(t, s)`.
            teacher: Only read when `ctm` is False; selects EDM scalings
                instead of `get_cm_scalings`.

        Returns:
            `(denoised, G_theta)`. Both are the same tensor when `ctm` is False.

        Raises:
            ValueError: If `ctm` is True but `s` is None.
        """
        if ctm and s is None:
            # Fail before the forward pass; the outer scalings need `s / t`.
            raise ValueError("`s` must be provided when `ctm=True`.")

        rescaled_t = self.rescaling_t(t)
        c_in = at_least_ndim(self.get_c_in(t), x_t.ndim)
        if s is not None:
            model_output = model["diffusion"](c_in * x_t, rescaled_t, s_noise=self.rescaling_t(s))
        else:
            model_output = model["diffusion"](c_in * x_t, rescaled_t, )

        if ctm or teacher:
            scalings = self.get_edm_scalings(t)
        else:
            # NOTE(review): `get_cm_scalings` is not defined in the visible
            # part of this class; presumably inherited. Verify.
            scalings = self.get_cm_scalings(t)
        c_skip, c_out = [at_least_ndim(c, x_t.ndim) for c in scalings]
        denoised = c_out * model_output + c_skip * x_t

        if not ctm:
            return denoised, denoised

        outer_skip, outer_out = [
            at_least_ndim(c, x_t.ndim)
            for c in self.get_outer_scalings(t, s)
        ]
        G_theta = outer_out * denoised + outer_skip * x_t
        return denoised, G_theta


    def get_ctm_estimate(self, x_t, t, s, model, target_model, x_start, ctm):
        """Student estimate: jump `t -> s` with `model`, then `s -> sigma_min` with `target_model`.

        After each jump the entries selected by `fix_mask` are overwritten
        with values derived from `x_start`; the overwrite at level `s` is
        detached from the graph.
        """
        # ========== x_t -> theta -> x_s ========== #
        _, jumped_to_s = self.get_denoised_and_G(model, x_t, t, s=s, ctm=ctm)
        fixed_at_s = (x_start * self.fix_mask / at_least_ndim(self.get_c_in(s), x_t.dim())).detach()
        x_s = jumped_to_s * (1. - self.fix_mask) + fixed_at_s

        # ========== x_s -> theta_minus -> x_0 ========== #
        sigma_floor = th.ones_like(s) * self.sigma_min
        _, jumped_to_min = self.get_denoised_and_G(target_model, x_s, s, s=sigma_floor, ctm=ctm)
        return jumped_to_min * (1. - self.fix_mask) + x_start * self.fix_mask

    @th.no_grad()
    def get_ctm_target(self, x_t_dt, t_dt, s, model, target_model, x_start, ctm):
        """Gradient-free target: jump `t_dt -> s` with `model`, then `s -> sigma_min` with `target_model`.

        The entries selected by `fix_mask` are overwritten from `x_start`
        after each jump. The result is detached.
        """
        # The decorator already disables gradient tracking for the whole body.
        # ========== x_u -> theta -> x_s ========== #
        _, jumped_to_s = self.get_denoised_and_G(model, x_t_dt, t_dt, s=s, ctm=ctm)
        fixed_at_s = x_start * self.fix_mask / at_least_ndim(self.get_c_in(s), x_t_dt.dim())
        x_s = jumped_to_s * (1. - self.fix_mask) + fixed_at_s

        # ========== x_s -> theta_minus -> x_0 ========== #
        sigma_floor = th.ones_like(s) * self.sigma_min
        _, jumped_to_min = self.get_denoised_and_G(target_model, x_s, s, s=sigma_floor, ctm=ctm)
        result = jumped_to_min * (1. - self.fix_mask) + x_start * self.fix_mask
        return result.detach()

    @th.no_grad()
    def heun_solver(self, x, ind, dims, x_start, num_step=1):
        """Advance `x` by `num_step` Heun steps using the teacher's EMA model.

        Step `k` goes from `get_t(ind + k)` to `get_t(ind + k + 1)`. After
        both the Euler predictor and the Heun corrector, the entries selected
        by `fix_mask` are overwritten from `x_start`.
        """
        # The decorator already disables gradient tracking for the whole body.
        teacher = self.edm.model_ema
        for offset in range(num_step):
            t_cur = self.get_t(ind + offset)
            denoised_cur, _ = self.get_denoised_and_G(teacher, x, t_cur, None, ctm=False, teacher=True)
            slope_cur = (x - denoised_cur) / at_least_ndim(t_cur, dims)
            t_next = self.get_t(ind + offset + 1)

            # Euler predictor.
            x_euler = x + slope_cur * at_least_ndim(t_next - t_cur, dims)
            x_euler = x_euler * (1. - self.fix_mask) + x_start * self.fix_mask / at_least_ndim(self.get_c_in(t_next), x.dim())

            # Heun corrector: average the slopes at both ends of the step.
            denoised_next, _ = self.get_denoised_and_G(teacher, x_euler, t_next, None, ctm=False, teacher=True)
            slope_next = (x_euler - denoised_next) / at_least_ndim(t_next, dims)
            x_heun = x + (slope_cur + slope_next) * at_least_ndim((t_next - t_cur) / 2, dims)
            x_heun = x_heun * (1. - self.fix_mask) + x_start * self.fix_mask / at_least_ndim(self.get_c_in(t_next), x.dim())

            x = x_heun
        return x

    def get_num_heun_step(self, num_heun_step=-1, heun_step_strategy='weighted'):
        if heun_step_strategy == 'uniform':
            num_heun_step = np.random.randint(1,1+num_heun_step)
        elif heun_step_strategy == 'weighted':
            p = np.array([i ** 1.0 for i in range(1,1+num_heun_step)]) # heun_step_multiplier=1.0,
            p = p / sum(p)
            num_heun_step = np.random.choice([i+1 for i in range(len(p))], size=1, p=p)[0]
        return num_heun_step

    def calculate_adaptive_weight(self, loss1, loss2, last_layer=None):
        loss1_grad = th.autograd.grad(loss1, last_layer, retain_graph=True)[0]
        loss2_grad = th.autograd.grad(loss2, last_layer, retain_graph=True)[0]
        d_weight = th.norm(loss1_grad) / (th.norm(loss2_grad) + 1e-4)
        d_weight = th.clamp(d_weight, 0.0, 1e4).detach()
        return d_weight

    def get_DSM_loss(self, model, x_start, consistency_loss):
        """Denoising score-matching loss, optionally balanced against the consistency loss.

        Args:
            model: The network container; also supplies the last-layer
                weight used for adaptive balancing.
            x_start: Clean samples.
            consistency_loss: Loss whose last-layer gradient norm is the
                numerator of the adaptive weight.

        Returns:
            The "karras_weight"-weighted squared denoising error, averaged
            over entries with the `fix_mask` positions zeroed, and multiplied
            by the balance weight.
        """
        # Only the sampled noise levels are used; the weights returned by the
        # sampler are replaced by the "karras_weight" weighting below.
        sigmas, _ = self.diffusion_schedule_sampler.sample(x_start.shape[0], self.device)
        noise = th.randn_like(x_start)
        dims = x_start.ndim
        x_t = x_start + noise * at_least_ndim(sigmas, dims)
        denoised, _ = self.get_denoised_and_G(model, x_t, sigmas, s=sigmas, ctm=True, teacher=True)
        snrs = self.get_snr(sigmas)
        denoising_weights = at_least_ndim(get_weightings("karras_weight", snrs, self.sigma_data), dims)
        denoising_loss = ((1 - self.fix_mask) * denoising_weights * (denoised - x_start) ** 2).mean()
        if self.apply_adaptive_weight:
            last_layer = dict(model.named_parameters())["diffusion.final_layer.linear.weight"]
            balance_weight = self.calculate_adaptive_weight(consistency_loss, denoising_loss,
                                                            last_layer=last_layer)
        else:
            balance_weight = 1.
        # `balance_weight` is a tensor or a plain float depending on the
        # branch above; `float()` handles both. The logger may be unset when
        # this method is called outside `update`.
        logger = getattr(self, "logger", None)
        if logger:
            logger['denoise_weight'] = float(balance_weight)
        return denoising_loss * balance_weight

    def get_GAN_loss(self, model, fake, real, consistency_loss, learn_generator=True):
        if self.discriminator is None:
            g_loss = (((fake - real) ** 2) * (1 - self.fix_mask) * self.loss_weight).mean()
            if self.apply_adaptive_weight: # True
                last_layer = dict(self.model.named_parameters())["diffusion.final_layer.linear.weight"]
                d_weight = self.calculate_adaptive_weight(consistency_loss, g_loss, last_layer=last_layer)
                d_weight = th.clip(d_weight, 0.01, 10.)
            else:
                d_weight = 1.
            g_loss = g_loss * d_weight
            return g_loss
        else:
            if learn_generator:
                logits_fake = self.discriminator(fake)
                g_loss = -logits_fake.mean()
                if self.logger: self.logger['g_loss'] = g_loss.item()
                if self.apply_adaptive_weight: # True
                    last_layer = dict(self.model.named_parameters())["diffusion.final_layer.linear.weight"]
                    d_weight = self.calculate_adaptive_weight(consistency_loss, g_loss, last_layer=last_layer)
                    d_weight = th.clip(d_weight, 0.01, 10.)
                else:
                    d_weight = 1.
                g_loss = g_loss * d_weight
                return g_loss

            else:
                logits_fake = self.discriminator(fake.detach())
                logits_real = self.discriminator(real.detach())
                if self.logger:
                    self.logger["logits_fake"] = logits_fake.mean().item()
                    self.logger["logits_real"] = logits_real.mean().item()
                loss_Dfake = (F.relu(th.ones_like(logits_fake) + logits_fake)).mean()
                loss_Dreal = (F.relu(th.ones_like(logits_real) - logits_real)).mean()
                discriminator_loss = loss_Dreal + loss_Dfake
                return discriminator_loss


    def distillation_loss(self, x_start, condition=None, ctm=True):
        """Compute the distillation loss for one batch.

        The loss is the sum of the consistency loss (student estimate vs.
        target built from the teacher's Heun solution), the denoising loss
        from `get_DSM_loss`, and the GAN term `d_loss` (0 when no
        discriminator is set, or on generator steps before `dloss_start_itr`).

        Args:
            x_start: Clean samples.
            condition: Accepted but not used in this method.
            ctm: Forwarded to `get_ctm_estimate` / `get_ctm_target`.

        Returns:
            `(loss.mean(), None)`.
        """

        assert self.edm is not None, "Please call `prepare_distillation` before distillation."

        terms = {}
        num_heun_step = self.get_num_heun_step(num_heun_step=self.num_heun_step, heun_step_strategy='weighted') # randomly draw the number of teacher steps; self.num_heun_step is the maximum possible count
        indices, _ = self.schedule_sampler.sample_t(x_start.shape[0], self.device, num_heun_step) # t index
        t = self.get_t(indices)
        t_dt = self.get_t(indices + num_heun_step) # u
        new_indices = self.schedule_sampler.sample_s(x_start.shape[0], self.device, indices, num_heun_step, N=self.start_scales)
        s = self.get_t(new_indices) # s

        x_t, _, _ = self.edm.add_noise(x_start, t, None)

        ctm_estimate = self.get_ctm_estimate(x_t, t, s, self.model, self.model_ema, x_start, ctm=ctm)
        # The discriminator is not involved at this point for now
        # ========== x_t -> solver -> x_u ========== #
        x_t_dt = self.heun_solver(x_t, indices, x_start.ndim, x_start, num_step=num_heun_step).detach()

        # ctm_target = self.heun_solver(x_t, indices, x_start.ndim, x_start, num_step=num_heun_step).detach()
        ctm_target = self.get_ctm_target(x_t_dt, t_dt, s, self.model, self.model_ema, x_start, ctm=ctm)

        snrs = self.get_snr(t)
        weights = at_least_ndim(get_weightings("uniform", snrs, self.sigma_data), x_start.ndim)
        # self.loss_weight decides whether a0 gets a multiplied weight; `weights` is the lambda of the algorithm
        terms["consistency_loss"] = (((ctm_estimate - ctm_target) ** 2) * (1 - self.fix_mask) * weights * self.loss_weight).mean()
        terms['denoising_loss'] = self.get_DSM_loss(self.model, x_start, terms["consistency_loss"])
        
        # For now, gan_fake is obtained in the simplest possible way
        
        # gan_num_heun_step = self.get_num_heun_step(num_heun_step=self.args.gan_num_heun_step, # gan_heun_step_strategy='uniform'
        #                                                           heun_step_strategy=self.args.gan_heun_step_strategy)
        # gan_x_t, gan_t, gan_t_dt, gan_s, _, _ = self.get_gan_time(x_start, noise, x_t, t, t_dt, s, indices,
        #                                                             num_heun_step, gan_num_heun_step)
        # _, gan_fake = self.get_denoised_and_G(model, x_t, t, s=th.ones_like(s) * self.args.sigma_min, ctm=ctm, **model_kwargs)
        

        if self.discriminator is None:
            terms["d_loss"] = 0#self.get_GAN_loss(self.model, gan_fake, gan_real, terms["consistency_loss"])
        else:
            # The student estimate serves as the fake sample, the clean data as the real one.
            gan_fake = ctm_estimate
            gan_real = x_start.detach()
            # Steps divisible by g_learning_period train the generator; the others train the discriminator.
            if self.step % self.g_learning_period == 0:
                if self.step >= self.dloss_start_itr:
                    terms["d_loss"] = self.get_GAN_loss(self.model, gan_fake, gan_real, terms["consistency_loss"], learn_generator=True)
                else: 
                    terms["d_loss"] = 0
            else:
                terms["d_loss"] = self.get_GAN_loss(self.model, gan_fake, gan_real, terms["consistency_loss"], learn_generator=False)

        loss = terms["consistency_loss"] + terms['denoising_loss'] + terms["d_loss"]

        return loss.mean(), None


    def update(self, x0, condition=None, update_ema=True, loss_type="distillation", log_training=True):
        """ One-step gradient update.

        Args:
            x0: torch.Tensor,
                Samples from the target distribution.
            condition: Optional,
                Condition of x0. `None` indicates no condition.
            update_ema: bool,
                Whether to update the exponential moving average model.
            loss_type: str,
                The type of loss. `training` or `distillation`.
            log_training: bool,
                Whether to collect the auxiliary statistics
                (`denoise_weight`, `g_loss`, `logits_real`, `logits_fake`).
                When False, the returned log only contains `loss`.

        Returns:
            log: dict,
                The log dictionary.

        Raises:
            ValueError: If `loss_type` is neither `training` nor `distillation`.

        Examples:
            >>> model = ContinuousConsistencyTrajecoteryModel(...)
            >>> x0 = torch.randn(*x_shape)
            >>> condition = torch.randn(*condition_shape)
            >>> log = model.update(x0, condition, loss_type="distillation")  # distillation
        """
        self.logger = {"denoise_weight": 0, "g_loss": 0, "logits_real": 0, "logits_fake": 0} if log_training else None
        if loss_type == "training":
            # NOTE(review): `training_loss` and `cur_logger` are not defined in
            # the visible part of this class; presumably provided elsewhere.
            loss, _ = self.training_loss(x0, condition)
        elif loss_type == "distillation":
            loss, _ = self.distillation_loss(x0, condition)
        else:
            raise ValueError(f"Unknown loss type: {loss_type}")

        is_generator_step = self.step % self.g_learning_period == 0

        loss.backward()
        if self.grad_clip_norm:
            nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip_norm)
        self.optimizer.step()
        self.optimizer.zero_grad()
        if self.discriminator is not None:
            if not is_generator_step:
                self.d_optim.step()
            # Clear discriminator gradients on every step: on generator steps
            # the GAN loss back-propagates through the discriminator, and
            # those gradients must not leak into its next update.
            self.d_optim.zero_grad()

        if update_ema:
            self.ema_update()

        if loss_type == "training":
            self.cur_logger.incremental_update_k()

        log = {"loss": loss.item()}
        if self.logger is not None:
            log["denoise_weight"] = self.logger["denoise_weight"]
            if is_generator_step:
                log["g_loss"] = self.logger["g_loss"]
            else:
                log["logits_real"] = self.logger["logits_real"]
                log["logits_fake"] = self.logger["logits_fake"]
        self.step += 1

        return log


    def get_sigmas_karras(self, n):

        ramp = th.linspace(0, 1, n)
        min_inv_rho = self.sigma_min ** (1 / self.rho)
        max_inv_rho = self.sigma_max ** (1 / self.rho)
        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** self.rho
        sigmas = th.cat([sigmas, sigmas.new_zeros([1])])
        return sigmas.to(self.device)

    @th.no_grad()
    def sample(
            self,
            # ---------- the known fixed portion ---------- #
            prior: torch.Tensor,
            # ----------------- sampling ----------------- #
            solver: str = "none",
            n_samples: int = 1,
            sample_steps: int = 1,
            sample_step_schedule: Union[str, Callable] = "uniform",
            use_ema: bool = True,
            temperature: float = 1.0,
            # ------------------ guidance ------------------ #
            condition_cfg=None,
            mask_cfg=None,
            w_cfg: float = 0.0,
            condition_cg=None,
            w_cg: float = 0.0,
            # ----------- Diffusion-X sampling ----------
            diffusion_x_sampling_steps: int = 0,
            # ----------- Warm-Starting -----------
            warm_start_reference: Optional[torch.Tensor] = None,
            warm_start_forward_level: float = 0.3,
            # ------------------ others ------------------ #
            requires_grad: bool = False,
            preserve_history: bool = False,
            **kwargs,
    ):
        """Generate samples by chaining `sample_steps` jumps along the noise schedule.

        Starting from Gaussian noise at level `sigma_max * temperature`, each
        step jumps from `sigmas[i]` to `sigmas[i + 1]`, where `sigmas` comes
        from `get_sigmas_karras(sample_steps + 1)`. After every jump the
        entries selected by `fix_mask` are reset from `prior`.

        Only `prior`, `sample_steps`, `use_ema`, `temperature`,
        `preserve_history`, `w_cg` and `condition_cg` are read; the other
        arguments are accepted but not used by this method.

        Returns:
            `(x, log)`. `log["sample_history"]` is `None` unless
            `preserve_history` is set, in which case slot 0 holds the initial
            noise and slot `i + 1` the state after step `i`.
        """
        assert w_cg == 0.0 and condition_cg is None, "Consistency Distillation does not support classifier guidance."

        log = {
            "sample_history": np.empty((n_samples, sample_steps + 1, *prior.shape)) if preserve_history else None, }

        model = self.model if not use_ema else self.model_ema

        prior = prior.to(self.device)

        x = torch.randn_like(prior) * self.sigma_max * temperature
        x = x * (1. - self.fix_mask) + prior * self.fix_mask / at_least_ndim(self.get_c_in(torch.tensor(self.sigma_max * temperature, device=self.device)), x.dim())

        if preserve_history:
            log["sample_history"][:, 0] = x.cpu().numpy()

        # ===================== Sampling Schedule ====================
        s_in = x.new_ones([x.shape[0]])
        # `sigmas` has `sample_steps + 2` entries: the schedule plus a trailing 0.
        sigmas = self.get_sigmas_karras(sample_steps + 1)

        for i in range(sample_steps):
            sigma, sigma_next = sigmas[i], sigmas[i + 1]
            _, denoised = self.get_denoised_and_G(model, x, sigma * s_in, s=sigma_next * s_in, ctm=True, teacher=False)
            if sigma_next != 0:
                x = denoised
                x = x * (1. - self.fix_mask) + (prior * self.fix_mask / at_least_ndim(self.get_c_in(sigma_next), x.dim()))
            else:
                # Euler step towards sigma = 0 (only reachable if the
                # schedule itself contains a zero, e.g. `sigma_min == 0`).
                d = (x - denoised) / sigma
                dt = sigma_next - sigma
                x = x + d * dt

            if preserve_history:
                log["sample_history"][:, i + 1] = x.cpu().numpy()

        return x, log
