import math
import os
from functools import partial
import matplotlib.pyplot as plt
import numpy as np
import torch
from tqdm.auto import tqdm
import pywt  # <-- added for wavelet transforms

from util.img_utils import clear_color
from .posterior_mean_variance import get_mean_processor, get_var_processor

import cv2
import numpy as np
import pywt
#from .frequency_pass_filter import apply_radial_filter

def compute_wavelet_strength(x: float, alpha: float = 0.5, beta: float = 1000.0) -> float:
    """
    Map a diffusion step to a wavelet strength value.

    Evaluates ``1 / (7 + exp(x * (alpha / beta)))``. With the default
    ``alpha`` and ``beta`` the result is exactly 0.125 at ``x = 0`` and
    decreases slowly as ``x`` grows (about 0.12499 at ``x = 1``).

    Args:
        x:     Current diffusion step (or any scalar).
        alpha: Multiplier applied to ``x`` inside the exponent.
        beta:  Divider applied to ``x`` inside the exponent.

    Returns:
        The wavelet strength as a float.
    """
    exponent = x * (alpha / beta)
    denominator = 7.0 + math.exp(exponent)
    return 1.0 / denominator

def process_image_with_wavelet(input_tensor, wavelet_strength=None, step=None):
    """
    Run a single-level Haar DWT / inverse DWT round trip on every channel of
    an image tensor and scale the reconstruction by ``wavelet_strength``.

    The tensor is converted to an 8-bit image, each channel is decomposed
    with ``pywt.dwt2`` and rebuilt with ``pywt.idwt2``, the rebuilt channel
    is multiplied by ``wavelet_strength``, clipped to [0, 255], and the
    result is converted back to a float tensor divided by 255.

    NOTE(review): the coefficients themselves are not modified, so the net
    effect is a uint8 quantisation followed by a global brightness scale of
    ``wavelet_strength`` — confirm this is the intended behaviour.

    Args:
        input_tensor (torch.Tensor): 4-D tensor with a batch size of 1,
            i.e. shape [1, C, H, W]. Must be ``float32`` (values are
            interpreted on a [0, 1] scale) or ``uint8``.
        wavelet_strength (float, optional): Scale applied to the
            reconstruction. When ``None`` it is computed from ``step`` via
            ``compute_wavelet_strength``.
        step (float, optional): Current diffusion step; required when
            ``wavelet_strength`` is ``None``.

    Returns:
        torch.Tensor: float32 tensor with the same shape as the input, values
        in [0, 1], placed on the input tensor's device.

    Raises:
        ValueError: If neither ``wavelet_strength`` nor ``step`` is given, or
            if the tensor is not 4-D with a batch size of 1.
        TypeError: If the tensor dtype is neither float32 nor uint8.
    """
    # Determine strength
    if wavelet_strength is None:
        if step is None:
            raise ValueError("必须提供 step 参数以动态计算 wavelet_strength, 或直接指定 wavelet_strength")
        wavelet_strength = compute_wavelet_strength(step)

    if input_tensor.dim() != 4 or input_tensor.size(0) != 1:
        raise ValueError("输入张量必须是形状为 [1, 3, H, W] 的4D张量")

    if input_tensor.dtype not in (torch.float32, torch.uint8):
        raise TypeError(f"不支持的张量数据类型: {input_tensor.dtype}")

    # [1, C, H, W] -> [H, W, C] on the CPU.
    hwc = input_tensor.squeeze(0).permute(1, 2, 0).detach().cpu().numpy()
    if input_tensor.dtype == torch.float32:
        # Clip before the cast: converting out-of-range floats (e.g. negative
        # values) to uint8 wraps around / is undefined instead of saturating.
        np_img = (np.clip(hwc, 0.0, 1.0) * 255).astype(np.uint8)
    else:
        np_img = hwc

    def wavelet_process(channel):
        rows, cols = channel.shape
        cA, (cH, cV, cD) = pywt.dwt2(channel, 'haar')
        reconstructed = pywt.idwt2((cA, (cH, cV, cD)), 'haar')
        # For odd-sized inputs the reconstruction is one row/column larger
        # than the original; crop so the output shape matches the input.
        reconstructed = reconstructed[:rows, :cols]
        return np.clip(reconstructed * wavelet_strength, 0, 255).astype(np.uint8)

    # Process every channel independently (works for any channel count).
    processed_channels = [
        wavelet_process(np.ascontiguousarray(np_img[..., c]))
        for c in range(np_img.shape[-1])
    ]
    reconstructed_img = np.stack(processed_channels, axis=-1)

    # Back to [1, C, H, W], float32, scaled to [0, 1].
    reconstructed_tensor = torch.from_numpy(reconstructed_img) \
        .permute(2, 0, 1) \
        .unsqueeze(0) \
        .float() / 255

    return reconstructed_tensor.to(input_tensor.device)

# def process_image_with_wavelet(input_tensor, wavelet_strength=0.125):
#     """
#     处理输入的PyTorch张量，应用小波变换并重构，整体控制小波变换的影响。

#     Args:
#         input_tensor (torch.Tensor): 输入张量，形状为 [1, 3, H, W]
#         wavelet_strength (float): 控制小波变换整体影响的参数，值越小，影响越小。

#     Returns:
#         torch.Tensor: 处理后的张量，形状同输入，数据类型为 float32，范围 [0, 1]
#     """
#     # 确保输入是4D张量
#     if input_tensor.dim() != 4 or input_tensor.size(0) != 1:
#         raise ValueError("输入张量必须是形状为 [1, 3, H, W] 的4D张量")

#     # 将张量移动到CPU并转换为NumPy数组
#     if input_tensor.dtype == torch.float32:
#         np_img = (input_tensor.squeeze(0).permute(1, 2, 0).detach().cpu().numpy() * 255).astype(np.uint8)
#     elif input_tensor.dtype == torch.uint8:
#         np_img = input_tensor.squeeze(0).permute(1, 2, 0).detach().cpu().numpy()
#     else:
#         raise TypeError("不支持的张量数据类型: {}".format(input_tensor.dtype))

#     # 分离颜色通道
#     b_channel, g_channel, r_channel = cv2.split(np_img)

#     # 定义处理单通道的小波变换函数
#     def wavelet_process(channel):
#         coeffs = pywt.dwt2(channel, 'haar')
#         cA, (cH, cV, cD) = coeffs

#         # 直接减少整个图像的小波变换影响
#         reconstructed = pywt.idwt2((cA, (cH, cV, cD)), 'haar')
#         return np.clip(reconstructed * wavelet_strength, 0, 255).astype(np.uint8)

#     # 处理每个通道
#     reconstructed_b = wavelet_process(b_channel)
#     reconstructed_g = wavelet_process(g_channel)
#     reconstructed_r = wavelet_process(r_channel)

#     # 合并通道
#     reconstructed_img = cv2.merge([reconstructed_b, reconstructed_g, reconstructed_r])

#     # 转换回张量，调整维度顺序为 CxHxW，并归一化到 [0, 1]
#     reconstructed_tensor = torch.from_numpy(reconstructed_img) \
#                                .permute(2, 0, 1) \
#                                .unsqueeze(0) \
#                                .float() / 255

#     return reconstructed_tensor


# ============== Wavelet Functions ==============
import pywt
import numpy as np


def wavelet_transform(x, wavelet='db1', wavelet_strength=0.125):
    """
    Compute a single-level 2-D discrete wavelet transform and return the
    amplitude and phase of every sub-band.

    The four coefficient arrays returned by ``pywt.dwt2`` are stored, in
    order, under the keys 'LL', 'LH', 'HL' and 'HH'. Every sub-band
    (including 'LL') is multiplied by ``wavelet_strength`` before its
    amplitude (``np.abs``) and phase (``np.angle``) are taken.

    Args:
        x (np.ndarray): Input image, assumed to be a 2-D array.
        wavelet (str): Wavelet name passed to ``pywt.dwt2``; defaults to 'db1'.
        wavelet_strength (float): Scale applied to all sub-bands; smaller
            values shrink the coefficients more.

    Returns:
        dict: Maps 'LL', 'LH', 'HL', 'HH' to an ``(amplitude, phase)`` tuple.
    """
    approximation, details = pywt.dwt2(x, wavelet)
    first_detail, second_detail, third_detail = details
    raw_bands = (
        ('LL', approximation),
        ('LH', first_detail),
        ('HL', second_detail),
        ('HH', third_detail),
    )

    result = {}
    for band_name, coefficients in raw_bands:
        # Uniformly damp the band, then split it into magnitude and angle.
        damped = coefficients * wavelet_strength
        result[band_name] = (np.abs(damped), np.angle(damped))
    return result


def inverse_wavelet_transform(coeffs_dict, wavelet='db1', wavelet_strength=0.125):
    """
    Rebuild an image from wavelet sub-bands with ``pywt.idwt2``.

    The real part of every sub-band — 'LL' as well as 'LH', 'HL' and 'HH' —
    is multiplied by ``wavelet_strength`` before the inverse transform.

    Args:
        coeffs_dict (dict): Holds the entries 'LL', 'LH', 'HL' and 'HH'.
        wavelet (str): Wavelet name passed to ``pywt.idwt2``; defaults to 'db1'.
        wavelet_strength (float): Scale applied to all four sub-bands.

    Returns:
        np.ndarray: The reconstructed image.
    """
    damped = {
        band: np.real(coeffs_dict[band]) * wavelet_strength
        for band in ('LL', 'LH', 'HL', 'HH')
    }
    detail_bands = (damped['LH'], damped['HL'], damped['HH'])
    return pywt.idwt2((damped['LL'], detail_bands), wavelet)


def inverse_wavelet_transform_combined(A, P, wavelet='db1', wavelet_strength=0.125):
    """
    Inverse wavelet transform from a combined amplitude and phase.

    Builds the complex coefficient ``A * exp(1j * P)``, scales it by
    ``wavelet_strength`` and uses it as the approximation band of
    ``pywt.idwt2``; all three detail bands are zero.

    Args:
        A (np.ndarray): Combined amplitude.
        P (np.ndarray): Combined phase.
        wavelet (str): Wavelet name passed to ``pywt.idwt2``; defaults to 'db1'.
        wavelet_strength (float): Scale applied to the combined coefficient.

    Returns:
        np.ndarray: The reconstructed image.
    """
    approximation = A * np.exp(1j * P)
    # Damp the combined coefficient as a whole.
    approximation *= wavelet_strength

    # The detail bands carry no information in this reconstruction.
    empty_detail = np.zeros_like(approximation)
    return pywt.idwt2((approximation, (empty_detail,) * 3), wavelet)


# ============== End Wavelet Functions ==============

# Module-level registry: sampler name -> sampler class.
__SAMPLER__ = {}


def register_sampler(name: str):
    """
    Class decorator factory that stores the decorated class in the sampler
    registry under ``name``.

    Raises:
        NameError: If a (truthy) entry is already registered under ``name``.
    """
    def wrapper(cls):
        existing = __SAMPLER__.get(name)
        if existing:
            raise NameError(f"Name {name} is already registered!")
        __SAMPLER__[name] = cls
        return cls

    return wrapper


def get_sampler(name: str):
    """
    Look up a sampler class in the registry.

    Raises:
        NameError: If nothing is registered under ``name``.
    """
    sampler_cls = __SAMPLER__.get(name, None)
    if sampler_cls is None:
        raise NameError(f"Name {name} is not defined!")
    return sampler_cls


def create_sampler(sampler,
                   steps,
                   noise_schedule,
                   model_mean_type,
                   model_var_type,
                   dynamic_threshold,
                   clip_denoised,
                   rescale_timesteps,
                   timestep_respacing=""):
    """
    Build a registered sampler from a named noise schedule.

    ``sampler`` is the registry name of the sampler class. The beta schedule
    is produced by ``get_named_beta_schedule(noise_schedule, steps)`` and the
    retained timesteps by ``space_timesteps``. An empty/falsy
    ``timestep_respacing`` keeps all ``steps`` timesteps.

    Returns:
        An instance of the registered sampler class.
    """
    sampler_cls = get_sampler(name=sampler)
    schedule_betas = get_named_beta_schedule(noise_schedule, steps)

    # Without an explicit respacing, use a single section with every step.
    respacing = timestep_respacing or [steps]
    selected_steps = space_timesteps(steps, respacing)

    return sampler_cls(
        use_timesteps=selected_steps,
        betas=schedule_betas,
        model_mean_type=model_mean_type,
        model_var_type=model_var_type,
        dynamic_threshold=dynamic_threshold,
        clip_denoised=clip_denoised,
        rescale_timesteps=rescale_timesteps,
    )


class GaussianDiffusion:
    """
    Gaussian diffusion process defined by a beta schedule.

    The constructor precomputes the schedule-derived coefficients (in
    float64). The class provides forward sampling ``q(x_t | x_0)``, the
    posterior ``q(x_{t-1} | x_t, x_0)``, model-based mean/variance prediction
    and a measurement-conditioned sampling loop. ``p_sample`` must be
    implemented by subclasses.
    """

    def __init__(self,
                 betas,
                 model_mean_type,
                 model_var_type,
                 dynamic_threshold,
                 clip_denoised,
                 rescale_timesteps):
        """
        Precompute all schedule-dependent quantities.

        :param betas: 1-D sequence of beta values, each in (0, 1].
        :param model_mean_type: forwarded to ``get_mean_processor``.
        :param model_var_type: forwarded to ``get_var_processor``.
        :param dynamic_threshold: forwarded to ``get_mean_processor``.
        :param clip_denoised: forwarded to ``get_mean_processor``.
        :param rescale_timesteps: if truthy, ``_scale_timesteps`` rescales
                                  timesteps to a 1000-step range.
        """
        # use float64 for accuracy.
        betas = np.array(betas, dtype=np.float64)
        self.betas = betas
        assert betas.ndim == 1, "betas must be 1-D"
        betas_in_range = np.all(betas > 0) and np.all(betas <= 1)
        assert betas_in_range, "betas must be in (0..1]"

        self.num_timesteps = int(betas.shape[0])
        self.rescale_timesteps = rescale_timesteps

        alphas = 1.0 - betas
        cumprod = np.cumprod(alphas, axis=0)
        cumprod_prev = np.append(1.0, cumprod[:-1])
        self.alphas_cumprod = cumprod
        self.alphas_cumprod_prev = cumprod_prev
        self.alphas_cumprod_next = np.append(cumprod[1:], 0.0)
        assert cumprod_prev.shape == (self.num_timesteps,)

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.sqrt_alphas_cumprod = np.sqrt(cumprod)
        self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - cumprod)
        self.log_one_minus_alphas_cumprod = np.log(1.0 - cumprod)
        self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / cumprod)
        self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / cumprod - 1)

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = betas * (1.0 - cumprod_prev) / (1.0 - cumprod)
        self.posterior_variance = posterior_variance
        # The first posterior variance is 0 (cumprod_prev[0] == 1), so its log
        # is replaced by the log of the second entry.
        self.posterior_log_variance_clipped = np.log(
            np.append(posterior_variance[1], posterior_variance[1:])
        )
        self.posterior_mean_coef1 = (
            betas * np.sqrt(cumprod_prev) / (1.0 - cumprod)
        )
        self.posterior_mean_coef2 = (
            (1.0 - cumprod_prev) * np.sqrt(alphas) / (1.0 - cumprod)
        )

        self.mean_processor = get_mean_processor(
            model_mean_type,
            betas=betas,
            dynamic_threshold=dynamic_threshold,
            clip_denoised=clip_denoised,
        )
        self.var_processor = get_var_processor(model_var_type, betas=betas)

    def q_mean_variance(self, x_start, t):
        """
        Get the distribution q(x_t | x_0).

        :param x_start: the [N x C x ...] tensor of noiseless inputs.
        :param t: the diffusion step index (0 means one step).
        :return: A tuple (mean, variance, log_variance) with the same shape as x_start.
        """
        scale = extract_and_expand(self.sqrt_alphas_cumprod, t, x_start)
        mean = scale * x_start
        variance = extract_and_expand(1.0 - self.alphas_cumprod, t, x_start)
        log_variance = extract_and_expand(self.log_one_minus_alphas_cumprod, t, x_start)
        return mean, variance, log_variance

    def q_sample(self, x_start, t):
        """
        Diffuse the data for a given number of diffusion steps, i.e. draw a
        sample from q(x_t | x_0) using freshly drawn Gaussian noise.
        """
        noise = torch.randn_like(x_start)
        assert noise.shape == x_start.shape

        signal_scale = extract_and_expand(self.sqrt_alphas_cumprod, t, x_start)
        noise_scale = extract_and_expand(self.sqrt_one_minus_alphas_cumprod, t, x_start)
        return signal_scale * x_start + noise_scale * noise

    def q_posterior_mean_variance(self, x_start, x_t, t):
        """
        Compute the mean and variance of the diffusion posterior:
            q(x_{t-1} | x_t, x_0)

        :return: A tuple (mean, variance, clipped log-variance).
        """
        assert x_start.shape == x_t.shape
        weight_x0 = extract_and_expand(self.posterior_mean_coef1, t, x_start)
        weight_xt = extract_and_expand(self.posterior_mean_coef2, t, x_t)
        mean = weight_x0 * x_start + weight_xt * x_t
        variance = extract_and_expand(self.posterior_variance, t, x_t)
        log_variance_clipped = extract_and_expand(self.posterior_log_variance_clipped, t, x_t)
        assert (
            mean.shape[0]
            == variance.shape[0]
            == log_variance_clipped.shape[0]
            == x_start.shape[0]
        )
        return mean, variance, log_variance_clipped

    def p_sample_loop(
        self,
        model,
        x_start,
        measurement,
        measurement_cond_fn,
        record,
        save_root,
        ):
        """
        Measurement-conditioned sampling loop, iterating the timestep index
        from ``num_timesteps - 1`` down to 0. Each iteration performs:
          1) one reverse step via ``self.p_sample``
          2) measurement conditioning via ``measurement_cond_fn``
          3) for the first two iterations only (the two highest timestep
             indices), a pass through ``process_image_with_wavelet``

        A radial-filter step exists only as a disabled (commented-out) call.

        :param model: the denoising model passed on to ``p_sample``.
        :param x_start: the initial tensor the loop starts from.
        :param measurement: the observed measurement.
        :param measurement_cond_fn: callable returning ``(img, distance)``.
        :param record: if truthy, save an image every 10th timestep index.
        :param save_root: directory containing the ``progress`` folder.
        :return: the final (detached) sample.
        """
        # NOTE: an earlier experimental variant of this loop fused wavelet
        # amplitude/phase sub-bands with a factor N * exp(-w * t)
        # (wavelet_N=1.0, wavelet_w=0.5) and called
        # process_image_with_wavelet only when idx > 998.
        img = x_start
        device = x_start.device
        total_steps = self.num_timesteps

        pbar = tqdm(range(total_steps - 1, -1, -1), desc='sampling')
        for idx in pbar:
            batch_size = img.shape[0]
            t_batch = torch.full((batch_size,), idx, device=device, dtype=torch.long)

            # 1) one reverse diffusion step; gradients w.r.t. img are needed
            #    by the conditioning function.
            img = img.requires_grad_()
            out = self.p_sample(x=img, t=t_batch, model=model)

            # 2) measurement conditioning
            noisy_measurement = self.q_sample(measurement, t=t_batch)
            img, distance = measurement_cond_fn(
                x_t=out['sample'],
                measurement=measurement,
                noisy_measurement=noisy_measurement,
                x_prev=img,
                x_0_hat=out['pred_xstart'],
            )

            # (disabled) radial filter:
            # img = apply_radial_filter(img, t=idx, T_max=T_max)

            # 3) wavelet pass while idx is one of the two largest indices,
            #    i.e. during the first two iterations of this loop.
            if idx >= total_steps - 2:
                img = process_image_with_wavelet(img, step=idx)

            # detach & record
            img = img.detach()
            pbar.set_postfix({'distance': f'{distance:.4f}'}, refresh=False)

            if record and idx % 10 == 0:
                snapshot_path = os.path.join(save_root, f"progress/x_{idx:04d}.jpg")
                plt.imsave(snapshot_path, clear_color(img))

        return img

    def p_sample(self, model, x, t):
        """Perform one reverse step; must be provided by subclasses."""
        raise NotImplementedError

    def p_mean_variance(self, model, x, t):
        """
        Run the model and derive the reverse-step mean/variance and the
        predicted x_0.

        If the model output has twice as many channels as ``x`` it is split
        into the mean-related output and the variance values; otherwise the
        same output is used for both.

        :return: dict with keys 'mean', 'variance', 'log_variance', 'pred_xstart'.
        """
        model_output = model(x, self._scale_timesteps(t))
        num_channels = x.shape[1]
        if model_output.shape[1] == 2 * num_channels:
            model_output, model_var_values = torch.split(model_output, num_channels, dim=1)
        else:
            model_var_values = model_output

        model_mean, pred_xstart = self.mean_processor.get_mean_and_xstart(x, t, model_output)
        model_variance, model_log_variance = self.var_processor.get_variance(model_var_values, t)
        assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape

        return {
            'mean': model_mean,
            'variance': model_variance,
            'log_variance': model_log_variance,
            'pred_xstart': pred_xstart,
        }

    def _scale_timesteps(self, t):
        """Rescale ``t`` to a 1000-step range when ``rescale_timesteps`` is set."""
        if not self.rescale_timesteps:
            return t
        return t.float() * (1000.0 / self.num_timesteps)


def space_timesteps(num_timesteps, section_counts):
    """
    Create a set of timesteps to use from an original diffusion process.

    The original process is divided into equally sized sections (the first
    ``num_timesteps % len(section_counts)`` sections get one extra step) and
    the requested number of steps is spread evenly within each section.

    :param num_timesteps: number of steps in the original process.
    :param section_counts: a list of per-section step counts, a single int,
                           a comma-separated string such as "10,15,20", or
                           the special form "ddimN", which selects exactly N
                           steps with a fixed integer stride.
    :return: a set of step indices from the original process.
    :raises ValueError: if "ddimN" cannot be satisfied with an integer
                        stride, if no sections are given, or if a section is
                        asked for more steps than it contains.
    """
    if isinstance(section_counts, str):
        if section_counts.startswith("ddim"):
            desired_count = int(section_counts[len("ddim"):])
            for stride in range(1, num_timesteps):
                if len(range(0, num_timesteps, stride)) == desired_count:
                    return set(range(0, num_timesteps, stride))
            # Report the number of steps that was requested, not the length
            # of the original process.
            raise ValueError(
                f"cannot create exactly {desired_count} steps with an integer stride"
            )
        section_counts = [int(x) for x in section_counts.split(",")]
    elif isinstance(section_counts, int):
        section_counts = [section_counts]

    if len(section_counts) == 0:
        # Would otherwise surface as an unexplained ZeroDivisionError below.
        raise ValueError("section_counts must contain at least one section")

    size_per, extra = divmod(num_timesteps, len(section_counts))
    start_idx = 0
    all_steps = []
    for i, section_count in enumerate(section_counts):
        size = size_per + (1 if i < extra else 0)
        if size < section_count:
            raise ValueError(
                f"cannot divide section of {size} steps into {section_count}"
            )
        # Fractional stride so that the first and last step of the section
        # are both included when more than one step is requested.
        frac_stride = 1 if section_count <= 1 else (size - 1) / (section_count - 1)
        cur_idx = 0.0
        for _ in range(section_count):
            all_steps.append(start_idx + round(cur_idx))
            cur_idx += frac_stride
        start_idx += size
    return set(all_steps)


class SpacedDiffusion(GaussianDiffusion):
    """
    A diffusion process which can skip steps in a base diffusion process.

    Only the timesteps in ``use_timesteps`` are retained. New betas are
    derived from the base process's cumulative alphas, and ``timestep_map``
    records the original index of every retained step.
    """

    def __init__(self, use_timesteps, **kwargs):
        """
        :param use_timesteps: collection of timestep indices (from the
                              original process) to retain.
        :param kwargs: the keyword arguments used to build the base
                       ``GaussianDiffusion``; must include ``betas``.
        """
        self.use_timesteps = set(use_timesteps)
        self.original_num_steps = len(kwargs["betas"])
        self.timestep_map = []

        full_process = GaussianDiffusion(**kwargs)
        respaced_betas = []
        prev_cumprod = 1.0
        for step, cumprod in enumerate(full_process.alphas_cumprod):
            if step not in self.use_timesteps:
                continue
            # Pick beta so the cumulative alpha at this retained step equals
            # the base process's cumulative alpha at the same step.
            respaced_betas.append(1 - cumprod / prev_cumprod)
            prev_cumprod = cumprod
            self.timestep_map.append(step)

        kwargs["betas"] = np.array(respaced_betas)
        super().__init__(**kwargs)

    def p_mean_variance(self, model, *args, **kwargs):
        """Delegate to the base class with the model wrapped for timestep remapping."""
        wrapped = self._wrap_model(model)
        return super().p_mean_variance(wrapped, *args, **kwargs)

    # NOTE(review): the three delegating methods below call base-class
    # methods that the GaussianDiffusion class in this file does not appear
    # to define — confirm they exist before relying on them.
    def training_losses(self, model, *args, **kwargs):
        wrapped = self._wrap_model(model)
        return super().training_losses(wrapped, *args, **kwargs)

    def condition_mean(self, cond_fn, *args, **kwargs):
        wrapped = self._wrap_model(cond_fn)
        return super().condition_mean(wrapped, *args, **kwargs)

    def condition_score(self, cond_fn, *args, **kwargs):
        wrapped = self._wrap_model(cond_fn)
        return super().condition_score(wrapped, *args, **kwargs)

    def _wrap_model(self, model):
        """Wrap ``model`` in a ``_WrappedModel`` unless it already is one."""
        if not isinstance(model, _WrappedModel):
            model = _WrappedModel(
                model, self.timestep_map, self.rescale_timesteps, self.original_num_steps
            )
        return model

    def _scale_timesteps(self, t):
        """Return ``t`` unchanged; rescaling is handled by the wrapped model."""
        return t


class _WrappedModel:
    """
    Callable wrapper that converts spaced timestep indices into the original
    process's timesteps before invoking the wrapped model.
    """

    def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps):
        self.model = model
        self.timestep_map = timestep_map
        self.rescale_timesteps = rescale_timesteps
        self.original_num_steps = original_num_steps

    def __call__(self, x, ts, **kwargs):
        """
        Map ``ts`` through ``timestep_map`` (optionally rescaling the result
        to a 1000-step range) and call the wrapped model with ``x``, the
        mapped timesteps and any extra keyword arguments.
        """
        lookup = torch.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
        original_ts = lookup[ts]
        if self.rescale_timesteps:
            factor = 1000.0 / self.original_num_steps
            original_ts = original_ts.float() * factor
        return self.model(x, original_ts, **kwargs)


@register_sampler(name='ddpm')
class DDPM(SpacedDiffusion):
    """DDPM (ancestral) sampler."""

    def p_sample(self, model, x, t):
        """
        Draw x_{t-1} from the model's predicted mean and log-variance.

        Noise is added only to samples whose timestep is non-zero. The check
        is done per sample, so batches with more than one element are
        supported (a plain ``if t != 0`` on a multi-element tensor raises).

        :param model: the denoising model.
        :param x: the current tensor x_t, batch dimension first.
        :param t: timestep indices, one per batch element.
        :return: dict with keys 'sample' and 'pred_xstart'.
        """
        out = self.p_mean_variance(model, x, t)
        mean = out['mean']
        noise = torch.randn_like(x)
        noisy = mean + torch.exp(0.5 * out['log_variance']) * noise

        # no noise when t == 0; shape the mask so it broadcasts over x
        t_index = torch.as_tensor(t, device=x.device)
        add_noise = (t_index != 0).reshape(-1, *([1] * (x.dim() - 1)))
        sample = torch.where(add_noise, noisy, mean)
        return {'sample': sample, 'pred_xstart': out['pred_xstart']}


@register_sampler(name='ddim')
class DDIM(SpacedDiffusion):
    """DDIM sampler; ``eta`` controls the amount of injected noise."""

    def p_sample(self, model, x, t, eta=0.0):
        """
        Perform one DDIM step.

        Noise (scaled by ``sigma``, which is zero for ``eta == 0``) is added
        only to samples whose timestep is non-zero. The check is done per
        sample, so batches with more than one element are supported (a plain
        ``if t != 0`` on a multi-element tensor raises).

        :param model: the denoising model.
        :param x: the current tensor x_t, batch dimension first.
        :param t: timestep indices, one per batch element.
        :param eta: stochasticity factor; 0.0 gives a deterministic step.
        :return: dict with keys 'sample' and 'pred_xstart'.
        """
        out = self.p_mean_variance(model, x, t)
        eps = self.predict_eps_from_x_start(x, t, out['pred_xstart'])
        alpha_bar = extract_and_expand(self.alphas_cumprod, t, x)
        alpha_bar_prev = extract_and_expand(self.alphas_cumprod_prev, t, x)
        sigma = (
                eta
                * torch.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
                * torch.sqrt(1 - alpha_bar / alpha_bar_prev)
        )
        noise = torch.randn_like(x)
        mean_pred = (
                out["pred_xstart"] * torch.sqrt(alpha_bar_prev)
                + torch.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
        )

        # no noise when t == 0; shape the mask so it broadcasts over x
        t_index = torch.as_tensor(t, device=x.device)
        add_noise = (t_index != 0).reshape(-1, *([1] * (x.dim() - 1)))
        sample = torch.where(add_noise, mean_pred + sigma * noise, mean_pred)
        return {"sample": sample, "pred_xstart": out["pred_xstart"]}

    def predict_eps_from_x_start(self, x_t, t, pred_xstart):
        """Recover the noise estimate implied by ``x_t`` and the predicted x_0."""
        coef1 = extract_and_expand(self.sqrt_recip_alphas_cumprod, t, x_t)
        coef2 = extract_and_expand(self.sqrt_recipm1_alphas_cumprod, t, x_t)
        return (coef1 * x_t - pred_xstart) / coef2


def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
    """
    Return the beta schedule registered under ``schedule_name``.

    Supported names:
      * "linear": linearly spaced betas whose end points are 0.0001 and 0.02
        multiplied by ``1000 / num_diffusion_timesteps``.
      * "cosine": betas obtained from a squared-cosine alpha-bar function
        via ``betas_for_alpha_bar``.

    :raises NotImplementedError: for any other schedule name.
    """
    if schedule_name == "cosine":
        def cosine_alpha_bar(t):
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

        return betas_for_alpha_bar(num_diffusion_timesteps, cosine_alpha_bar)

    if schedule_name != "linear":
        raise NotImplementedError(f"unknown beta schedule: {schedule_name}")

    # Scale the end points so the schedule stays comparable when the number
    # of steps differs from 1000.
    scale = 1000 / num_diffusion_timesteps
    first_beta = scale * 0.0001
    last_beta = scale * 0.02
    return np.linspace(first_beta, last_beta, num_diffusion_timesteps, dtype=np.float64)


def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    """
    Discretize an alpha-bar function into a beta schedule.

    For step ``i`` the beta is ``1 - alpha_bar((i + 1) / T) / alpha_bar(i / T)``
    with ``T = num_diffusion_timesteps``, capped at ``max_beta``.

    :param num_diffusion_timesteps: number of betas to produce.
    :param alpha_bar: callable taking a value in [0, 1].
    :param max_beta: upper bound applied to every beta.
    :return: a numpy array of betas.
    """
    total = num_diffusion_timesteps
    return np.array([
        min(1 - alpha_bar((i + 1) / total) / alpha_bar(i / total), max_beta)
        for i in range(total)
    ])


def extract_and_expand(array, time, target):
    """
    Index the numpy ``array`` with ``time`` and broadcast the result to the
    shape of ``target``.

    The selected values are moved to ``target``'s device, converted to
    float32, padded with trailing singleton dimensions until they have as
    many dimensions as ``target``, and then expanded to its shape.
    """
    on_device = torch.from_numpy(array).to(target.device)
    picked = on_device[time].float()
    for _ in range(target.ndim - picked.ndim):
        picked = picked.unsqueeze(-1)
    return picked.expand_as(target)


def expand_as(array, target):
    """
    Broadcast ``array`` to the shape of ``target`` on ``target``'s device.

    ``array`` may be a numpy array, a Python or numpy float scalar, or
    anything that already offers the tensor interface. Trailing singleton
    dimensions are appended until it has as many dimensions as ``target``.
    """
    if isinstance(array, np.ndarray):
        array = torch.from_numpy(array)
    elif isinstance(array, (float, np.floating)):
        # `np.float` was only an alias of the builtin `float` and has been
        # removed from NumPy, so the scalar check is spelled out explicitly.
        array = torch.tensor([array])
    while array.ndim < target.ndim:
        array = array.unsqueeze(-1)
    return array.expand_as(target).to(target.device)


def _extract_into_tensor(arr, timesteps, broadcast_shape):
    """
    Index the numpy array ``arr`` with ``timesteps`` and broadcast the result
    to ``broadcast_shape``.

    The selected values live on ``timesteps``' device, are converted to
    float32, and receive trailing singleton dimensions until their rank
    matches ``len(broadcast_shape)``.
    """
    values = torch.from_numpy(arr).to(device=timesteps.device)
    selected = values[timesteps].float()
    for _ in range(len(broadcast_shape) - selected.dim()):
        selected = selected.unsqueeze(-1)
    return selected.expand(broadcast_shape)
