"""
Based on: https://github.com/crowsonkb/k-diffusion
"""

import os
import numpy as np
import torch
from piq import LPIPS
from tqdm.auto import tqdm
import torch.distributed as dist
import torchvision
import math
from types import SimpleNamespace
from corruption import build_corruption
import torchvision
from corruption.blur import build_blur

from .nn import mean_flat, append_dims, append_zero
from .random_util import BatchedSeedGenerator

def get_env_float(name, default=None):
    """Read environment variable ``name`` and parse it as a float.

    Args:
        name: Name of the environment variable to read.
        default: Value returned when the variable is not set. When left as
            ``None`` a missing variable is reported as an error.

    Returns:
        The parsed float, or ``default`` when the variable is unset and a
        default was supplied.

    Raises:
        ValueError: If the variable is unset and no default was given, or if
            its value cannot be parsed as a float.
    """
    raw_value = os.environ.get(name)
    if raw_value is None:
        # float(None) would raise TypeError and bypass the handler below, so
        # the missing case is reported explicitly.
        if default is not None:
            return default
        raise ValueError(f"Environment variable {name} is not set")
    try:
        return float(raw_value)
    except ValueError:
        raise ValueError(f"Invalid value for {name}: {raw_value}") from None

def corrupt_image(image, corrupt_type=None, corrupt_scale=0.0):
    """Apply a named degradation to a batch of images.

    Supported ``corrupt_type`` values:

    * ``None``: no corruption, ``image`` is returned unchanged.
    * ``"void"``: element-wise cube of the image.
    * ``"noise"``: additive Gaussian noise whose per-sample sigma is drawn
      from a log-normal (log-mean -3.0, log-std 0.5) and multiplied by
      ``corrupt_scale``.
    * ``"blur"``: 3x3 Gaussian blur with ``sigma=corrupt_scale``.
    * ``"blur_<k>"``: Gaussian blur with kernel size ``k`` and torchvision's
      default sigma.
    * ``"blur<k>"``: Gaussian blur with kernel size ``k`` and
      ``sigma=corrupt_scale``.
    * ``"avg<k>"``: ``k`` x ``k`` average pooling with stride 1 and padding
      ``(k - 1) // 2``.
    * anything else is forwarded to ``build_corruption``.

    Args:
        image: Image batch. The "noise" branch indexes the per-sample sigma
            as ``[:, None, None, None]``, so it expects a 4-D tensor.
        corrupt_type: Name of the corruption (see above).
        corrupt_scale: Strength parameter; its meaning depends on the type.

    Returns:
        The corrupted image tensor.
    """
    if corrupt_type is None:
        # The default previously crashed on ``None.startswith``; treat it as
        # "no corruption requested".
        return image

    if corrupt_type == "void":
        # raise "`bridge_guidance_type` 'void' need prior dropout in bridge training"
        # image = torch.zeros_like(image)
        image = image ** 3
    elif corrupt_type == "noise":
        image_noise_sigma = torch.normal(
            mean=-3.0, std=0.5, size=(image.shape[0],), device=image.device)
        image_noise_sigma = torch.exp(
            image_noise_sigma).to(dtype=image.dtype)
        image = image + torch.randn_like(image) * \
            image_noise_sigma[:, None, None, None] * corrupt_scale
    elif corrupt_type == "blur":
        image = torchvision.transforms.functional.gaussian_blur(image, kernel_size=3, sigma=corrupt_scale)
    elif corrupt_type.startswith("blur_"):
        # Must be tested before the bare "blur" prefix: "blur_5" also starts
        # with "blur", which made this branch unreachable and sent "_5" to int().
        ks = int(corrupt_type[5:])
        image = torchvision.transforms.functional.gaussian_blur(image, kernel_size=ks)
    elif corrupt_type.startswith("blur"):
        ks = int(corrupt_type[4:])
        image = torchvision.transforms.functional.gaussian_blur(image, kernel_size=ks, sigma=corrupt_scale)
    elif corrupt_type.startswith("avg"):
        ks = int(corrupt_type[3:])
        pd = (ks - 1) // 2
        image = torch.nn.functional.avg_pool2d(image, kernel_size=ks, stride=1, padding=pd)
    # elif corrupt_type == "downsample":
    #     orig_size = image.shape[-2:]
    #     downsampled = F.interpolate(
    #         image, scale_factor=0.95, mode="bilinear", align_corners=True)
    #     image = F.interpolate(downsampled, size=orig_size,
    #                           mode="bilinear", align_corners=True)
    # elif corrupt_type == 'jpeg':
    #     qf = 95
    #     image = jpeg_decode(jpeg_encode(image, qf), qf)
    else:
        # Fall back to the project's corruption factory; it receives the
        # device and the size of the last image dimension.
        opt = SimpleNamespace(device=image.device, image_size=image.shape[-1])
        log = None  # No logging needed here
        image = build_corruption(opt, log, corrupt_type)(image.float())

    return image

class NoiseSchedule:
    """Abstract interface for the noise schedules used by the bridge samplers.

    Subclasses provide ``get_f_g2`` and ``get_alpha_rho`` and set the
    ``rho_T`` attribute that ``get_abc`` reads.
    """

    def __init__(self):
        raise NotImplementedError

    def get_f_g2(self, t):
        """Return the pair ``(f, g2)`` at time ``t`` (subclass hook)."""
        raise NotImplementedError

    def get_alpha_rho(self, t):
        """Return ``(alpha_t, alpha_bar_t, rho_t, rho_bar_t)`` (subclass hook)."""
        raise NotImplementedError

    def get_abc(self, t):
        """Return the coefficients ``(a_t, b_t, c_t)`` at time ``t``.

        ``bridge_sample`` combines them as ``a_t * xT + b_t * x0 + c_t * noise``.
        """
        alpha_t, alpha_bar_t, rho_t, rho_bar_t = self.get_alpha_rho(t)
        rho_T = self.rho_T
        a_t = (alpha_bar_t * rho_t**2) / rho_T**2
        b_t = (alpha_t * rho_bar_t**2) / rho_T**2
        c_t = (alpha_t * rho_bar_t * rho_t) / rho_T
        return a_t, b_t, c_t


class VPNoiseSchedule(NoiseSchedule):
    """Schedule built from the linear rate ``beta_min + beta_d * t``.

    ``alpha(t) = exp(-0.5 * beta_min * t - 0.25 * beta_d * t**2)`` and
    ``rho(t) = sqrt(exp(beta_min * t + 0.5 * beta_d * t**2) - 1)``; the
    terminal values ``alpha_T`` and ``rho_T`` are evaluated at ``t = 1``.
    """

    def __init__(self, beta_d=2, beta_min=0.1):
        self.beta_d = beta_d
        self.beta_min = beta_min

        def alpha_fn(t):
            return np.e ** (-0.5 * beta_min * t - 0.25 * beta_d * t**2)

        def rho_fn(t):
            # Calls .sqrt(), so ``t`` has to be a tensor here.
            return (np.e ** (beta_min * t + 0.5 * beta_d * t**2) - 1).sqrt()

        def f_fn(t):
            return -0.5 * beta_min - 0.5 * beta_d * t

        def g2_fn(t):
            return beta_min + beta_d * t

        self.alpha_fn = alpha_fn
        self.alpha_T = alpha_fn(1)
        self.rho_fn = rho_fn
        self.rho_T = rho_fn(torch.DoubleTensor([1])).item()

        self.f_fn = f_fn
        self.g2_fn = g2_fn

    def get_f_g2(self, t):
        """Return ``(f, g2)`` evaluated on ``t`` cast to float64."""
        t64 = t.to(torch.float64)
        return self.f_fn(t64), self.g2_fn(t64)

    def get_alpha_rho(self, t):
        """Return ``(alpha_t, alpha_bar_t, rho_t, rho_bar_t)`` in float64."""
        t64 = t.to(torch.float64)
        alpha_t = self.alpha_fn(t64)
        rho_t = self.rho_fn(t64)
        alpha_bar_t = alpha_t / self.alpha_T
        rho_bar_t = (self.rho_T**2 - rho_t**2).sqrt()
        return alpha_t, alpha_bar_t, rho_t, rho_bar_t


class VENoiseSchedule(NoiseSchedule):
    """Schedule with ``alpha(t) = 1`` and ``rho(t) = t``; ``rho_T`` is ``sigma_max``.

    The drift term ``f`` is zero and ``g2(t) = 2 * t``.
    """

    def __init__(self, sigma_max=80.0):
        self.sigma_max = sigma_max

        def alpha_fn(t):
            return torch.ones_like(t)

        def rho_fn(t):
            return t

        def f_fn(t):
            return torch.zeros_like(t)

        def g2_fn(t):
            return 2 * t

        self.alpha_fn = alpha_fn
        self.alpha_T = 1
        self.rho_fn = rho_fn
        self.rho_T = sigma_max

        self.f_fn = f_fn
        self.g2_fn = g2_fn

    def get_f_g2(self, t):
        """Return ``(f, g2)`` evaluated on ``t`` cast to float64."""
        t64 = t.to(torch.float64)
        return self.f_fn(t64), self.g2_fn(t64)

    def get_alpha_rho(self, t):
        """Return ``(alpha_t, alpha_bar_t, rho_t, rho_bar_t)`` in float64."""
        t64 = t.to(torch.float64)
        alpha_t = self.alpha_fn(t64)
        rho_t = self.rho_fn(t64)
        alpha_bar_t = alpha_t / self.alpha_T
        rho_bar_t = (self.rho_T**2 - rho_t**2).sqrt()
        return alpha_t, alpha_bar_t, rho_t, rho_bar_t


class I2SBNoiseSchedule(NoiseSchedule):
    """Discrete, table-based schedule with a symmetric beta sequence.

    The betas are the squares of a linear ramp between
    ``sqrt(beta_min / n_timestep)`` and ``sqrt(beta_max / n_timestep)``; the
    first half of that ramp is concatenated with its mirror image. ``rho`` is
    the square root of the forward cumulative sum of the betas and
    ``rho_bar`` the square root of the backward cumulative sum.

    Args:
        n_timestep: Number of discrete steps.
        beta_min: Numerator of the smallest beta (divided by ``n_timestep``).
        beta_max: Numerator of the largest beta (divided by ``n_timestep``).
        device: Device that holds the lookup tables. Defaults to CUDA when it
            is available (the previous hard-coded behaviour) and to CPU
            otherwise, so the class can also be built on CPU-only hosts.
    """

    def __init__(self, n_timestep=1000, beta_min=0.1, beta_max=1.0, device=None):
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"

        self.n_timestep, self.linear_start, self.linear_end = (
            n_timestep,
            beta_min / n_timestep,
            beta_max / n_timestep,
        )
        betas = (
            torch.linspace(
                self.linear_start**0.5,
                self.linear_end**0.5,
                n_timestep,
                dtype=torch.float64,
            ).to(device)
            ** 2
        )
        # NOTE(review): the tables hold 2 * (n_timestep // 2) entries, so an
        # odd n_timestep leaves index n_timestep - 1 out of range.
        half = self.n_timestep // 2
        betas = torch.cat(
            [
                betas[:half],
                torch.flip(betas[:half], dims=(0,)),
            ]
        )
        std_fwd = torch.sqrt(torch.cumsum(betas, dim=0))
        std_bwd = torch.sqrt(torch.flip(torch.cumsum(torch.flip(betas, dims=(0,)), dim=0), dims=(0,)))

        # The lookup functions below take integer step indices, not times.
        self.alpha_fn = lambda t: torch.ones_like(t).float()
        self.alpha_T = 1
        self.rho_fn = lambda t: std_fwd[t]
        self.rho_T = std_fwd[-1]
        self.rho_bar_fn = lambda t: std_bwd[t]

        self.f_fn = lambda t: torch.zeros_like(t).float()
        self.g2_fn = lambda t: betas[t]

    def _to_index(self, t):
        """Scale ``t`` by ``n_timestep - 1`` and round to a long index tensor."""
        return ((self.n_timestep - 1) * t).round().long()

    def get_f_g2(self, t):
        """Return ``(f, g2)`` looked up at the step index derived from ``t``."""
        idx = self._to_index(t)
        f, g2 = self.f_fn(idx), self.g2_fn(idx)
        return f, g2

    def get_alpha_rho(self, t):
        """Return ``(alpha_t, alpha_bar_t, rho_t, rho_bar_t)`` for ``t``."""
        idx = self._to_index(t)
        alpha_t = self.alpha_fn(idx)
        alpha_bar_t = alpha_t / self.alpha_T
        rho_t = self.rho_fn(idx)
        rho_bar_t = self.rho_bar_fn(idx)
        return alpha_t, alpha_bar_t, rho_t, rho_bar_t


class PreCond:
    """Abstract preconditioner supplying network scalings and loss weights.

    Subclasses implement ``_get_scalings_and_weightings``.
    """

    def __init__(self, ns):
        raise NotImplementedError

    def _get_scalings_and_weightings(self, t):
        """Return ``(c_skip, c_in, c_out, c_noise, weightings)`` (subclass hook)."""
        raise NotImplementedError

    def get_scalings_and_weightings(self, t, ndim):
        """Return the scalings for ``t`` with ``append_dims(..., ndim)`` applied.

        ``c_noise`` is returned exactly as the subclass produced it; the other
        four values are passed through ``append_dims``.
        """
        c_skip, c_in, c_out, c_noise, weightings = self._get_scalings_and_weightings(t)
        c_skip = append_dims(c_skip, ndim)
        c_in = append_dims(c_in, ndim)
        c_out = append_dims(c_out, ndim)
        weightings = append_dims(weightings, ndim)
        return c_skip, c_in, c_out, c_noise, weightings


class I2SBPreCond(PreCond):
    """Preconditioner with ``c_skip = c_in = 1`` and ``c_out = -rho_t``.

    Args:
        ns: Noise schedule providing ``get_alpha_rho``.
        n_timestep: Number of discrete steps in the noise-level table.
        t0: First entry of the table before scaling by ``n_timestep``.
        T: Last entry of the table before scaling by ``n_timestep``.
        device: Device that holds the noise-level table. Defaults to CUDA when
            it is available (the previous hard-coded behaviour) and to CPU
            otherwise.
    """

    def __init__(self, ns, n_timestep=1000, t0=1e-4, T=1.0, device=None):
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.ns = ns
        self.n_timestep = n_timestep
        self.noise_levels = torch.linspace(t0, T, n_timestep).to(device) * n_timestep

    def _get_scalings_and_weightings(self, t):
        """Return ``(c_skip, c_in, c_out, c_noise, weightings)`` for ``t``."""
        _, _, rho_t, _ = self.ns.get_alpha_rho(t)
        c_skip = torch.ones_like(t)
        c_in = torch.ones_like(t)
        c_out = -rho_t
        # The network is conditioned on the tabulated noise level at the
        # rounded step index rather than on ``t`` itself.
        step_index = ((self.n_timestep - 1) * t).round().long()
        c_noise = self.noise_levels[step_index]
        weightings = 1 / c_out**2
        return c_skip, c_in, c_out, c_noise, weightings


class DDBMPreCond(PreCond):
    """Preconditioner derived from the schedule's ``(a_t, b_t, c_t)`` coefficients.

    ``sigma_data_end`` is set equal to ``sigma_data``; ``cov_xy`` enters the
    formulas as the cross term between the two endpoints.
    """

    def __init__(self, ns, sigma_data, cov_xy):
        self.ns = ns
        self.sigma_data = sigma_data
        self.cov_xy = cov_xy
        self.sigma_data_end = sigma_data

    def _get_scalings_and_weightings(self, t):
        """Return ``(c_skip, c_in, c_out, c_noise, weightings)`` for ``t``."""
        a_t, b_t, c_t = self.ns.get_abc(t)
        sd, sd_end, cov = self.sigma_data, self.sigma_data_end, self.cov_xy

        total_var = a_t**2 * sd_end**2 + b_t**2 * sd**2 + 2 * a_t * b_t * cov + c_t**2
        c_in = 1 / (total_var) ** 0.5
        c_skip = (b_t * sd**2 + a_t * cov) / total_var
        c_out = (
            a_t**2 * (sd_end**2 * sd**2 - cov**2) + sd**2 * c_t**2
        ) ** 0.5 * c_in
        # The tiny offset keeps the logarithm finite at t == 0.
        c_noise = 1000 * 0.25 * torch.log(t + 1e-44)
        weightings = 1 / c_out**2
        return c_skip, c_in, c_out, c_noise, weightings


class KarrasDenoiser:
    """Couples a noise schedule and a preconditioner for training and sampling.

    Args:
        noise_schedule: Object providing ``get_abc`` (and, for the samplers,
            ``get_alpha_rho``).
        precond: Object providing ``get_scalings_and_weightings(t, ndim)``.
        t_max: Largest time used; training times are clamped to it.
        t_min: Smallest time, read by the sampling entry point.
        loss_norm: Name of the loss norm. ``"lpips"`` instantiates an LPIPS
            module as ``self.lpips_loss``.
    """

    def __init__(
        self,
        noise_schedule,
        precond,
        t_max=1.0,
        t_min=0.0001,
        loss_norm="lpips",
    ):

        self.t_max = t_max
        self.t_min = t_min

        self.noise_schedule = noise_schedule
        self.precond = precond

        self.loss_norm = loss_norm
        if loss_norm == "lpips":
            self.lpips_loss = LPIPS(replace_pooling=True, reduction="none")

    def bridge_sample(self, x0, xT, t, noise):
        """Return ``a_t * xT + b_t * x0 + c_t * noise`` for the schedule at ``t``."""
        a_t, b_t, c_t = [append_dims(item, x0.ndim) for item in self.noise_schedule.get_abc(t)]
        samples = a_t * xT + b_t * x0 + c_t * noise
        return samples

    def denoise(self, model, x_t, t, **model_kwargs):
        """Run the preconditioned network.

        Returns:
            ``(model_output, denoised, weightings)`` where
            ``denoised = c_out * model_output + c_skip * x_t``.
        """
        c_skip, c_in, c_out, c_noise, weightings = self.precond.get_scalings_and_weightings(t, x_t.ndim)
        model_output = model(c_in * x_t, c_noise, **model_kwargs)
        denoised = c_out * model_output + c_skip * x_t
        return model_output, denoised, weightings

    def training_bridge_losses(self, model, x_start, t, model_kwargs=None, noise=None):
        """Compute the bridge denoising losses for one batch.

        Args:
            model: Network called through ``denoise``.
            x_start: Clean target batch.
            t: Per-sample times; values above ``t_max`` are clamped.
            model_kwargs: Required dict containing ``"xT"`` and optionally
                ``"mask"``. The mask is not forwarded to the model; every
                other entry (including ``"xT"``) is.
            noise: Optional noise tensor; drawn with ``randn_like`` if omitted.

        Returns:
            Dict with ``"xs_mse"`` (unweighted squared error), ``"mse"``
            (weighted squared error) and ``"loss"`` (same tensor as ``"mse"``).
        """
        assert model_kwargs is not None, "model_kwargs (with key 'xT') is required"
        # Work on a shallow copy so that removing "mask" does not mutate the
        # caller's dictionary.
        model_kwargs = dict(model_kwargs)
        xT = model_kwargs["xT"]
        mask = model_kwargs.pop("mask", None)
        if noise is None:
            noise = torch.randn_like(x_start)
        t = torch.minimum(t, torch.ones_like(t) * self.t_max)
        terms = {}

        x_t = self.bridge_sample(x_start, xT, t, noise)

        _, denoised, weights = self.denoise(model, x_t, t, **model_kwargs)

        squared_error = (denoised - x_start) ** 2
        if mask is not None:
            terms["xs_mse"] = mean_flat(mask * squared_error)
            terms["mse"] = mean_flat(weights * mask * squared_error)
        else:
            terms["xs_mse"] = mean_flat(squared_error)
            terms["mse"] = mean_flat(weights * squared_error)

        terms["loss"] = terms["mse"]

        return terms


def karras_sample(
    diffusion,
    model,
    x_T,
    x_0,
    steps,
    mask=None,
    clip_denoised=True,
    model_kwargs=None,
    device=None,
    rho=7.0,
    sampler="heun",
    churn_step_ratio=0.0,
    eta=0.0,
    order=2,
    seed=None,
):
    """Run one of the bridge samplers and clamp its outputs to ``[-1, 1]``.

    Args:
        diffusion: Object providing ``t_min``, ``t_max`` and ``denoise``.
        model: Network passed to ``diffusion.denoise``.
        x_T: Starting tensor of the sampler.
        x_0: Forwarded to the sampler as the ``x_0`` keyword.
        steps: Number of time steps used to build the time grid.
        mask: Optional mask forwarded to the sampler.
        clip_denoised: Clamp every denoiser output to ``[-1, 1]``.
        model_kwargs: Default keyword arguments for the model; its ``"y"`` and
            ``"xT"`` entries are also forwarded to the sampler.
        device: Device of the time grid.
        rho: Exponent of the Karras grid (used by the ``"heun"`` sampler).
        sampler: One of ``"heun"``, ``"ground_truth"``, ``"dbim"``,
            ``"dbimt"``, ``"dbim_high_order"``.
        churn_step_ratio, eta, order, seed: Forwarded to the sampler.

    Returns:
        ``(x_0, path, nfe, pred_x0, sigmas, noise)`` with ``x_0``, ``path``
        and ``pred_x0`` clamped to ``[-1, 1]``.
    """
    assert sampler in [
        "heun",
        "ground_truth",
        "dbim",
        "dbimt",
        "dbim_high_order",
    ], "only these sampler is supported currently"

    if sampler == "heun":
        ts = get_sigmas_karras(steps, diffusion.t_min, diffusion.t_max - 1e-4, rho, device=device)
    else:
        ts = get_sigmas_uniform(steps, diffusion.t_min, diffusion.t_max - 1e-3, device=device)

    sample_fn = {
        "heun": sample_heun,
        "ground_truth": sample_ground_truth,
        "dbim": sample_dbim,
        "dbimt": sample_dbimt,  # corrupt xt by linear decomposition at t instead of t-1
        "dbim_high_order": sample_dbim_high_order,
    }[sampler]

    sampler_args = dict(churn_step_ratio=churn_step_ratio, mask=mask, eta=eta, x_0=x_0, order=order, seed=seed)

    default_kwargs = model_kwargs if model_kwargs else {}

    def denoiser(x_t, sigma, **kwargs):
        # Start from a copy of the default model arguments so that per-call
        # overrides (for example a different label ``y``) never leak back.
        args = dict(default_kwargs)
        args.update(kwargs)

        _, denoised, _ = diffusion.denoise(model, x_t, sigma, **args)

        if clip_denoised:
            denoised = denoised.clamp(-1, 1)
        return denoised

    # Fall back to x_T whenever no usable "xT" entry is supplied; previously a
    # model_kwargs dict without "xT" forwarded ``None`` to the sampler.
    xT_cond = default_kwargs.get("xT", None)
    if xT_cond is None:
        xT_cond = x_T

    x_0, path, nfe, pred_x0, sigmas, noise = sample_fn(
        denoiser,
        diffusion,
        x_T,
        ts,
        # Forward the label and the bridge endpoint so that the sampler can,
        # for instance, build a null label for guidance.
        y=default_kwargs.get("y", None),
        xT=xT_cond,
        **sampler_args,
    )
    # Only the main process reports; also works without a process group.
    if not dist.is_initialized() or dist.get_rank() == 0:
        print("nfe:", nfe)

    return (
        x_0.clamp(-1, 1),
        [x.clamp(-1, 1) for x in path],
        nfe,
        [x.clamp(-1, 1) for x in pred_x0],
        sigmas,
        noise,
    )


def get_sigmas_karras(n, sigma_min, sigma_max, rho=7.0, device="cpu"):
    """Constructs the noise schedule of Karras et al. (2022).

    Interpolates linearly between ``sigma_max ** (1 / rho)`` and
    ``sigma_min ** (1 / rho)`` over ``n`` points, raises the result to the
    power ``rho`` and passes it through ``append_zero``.
    """
    inv_rho_max = sigma_max ** (1 / rho)
    inv_rho_min = sigma_min ** (1 / rho)
    fraction = torch.linspace(0, 1, n)
    sigmas = (inv_rho_max + fraction * (inv_rho_min - inv_rho_max)) ** rho
    return append_zero(sigmas).to(device)


def get_sigmas_uniform(n, t_min, t_max, device="cpu"):
    """Return ``n + 1`` evenly spaced times descending from ``t_max`` to ``t_min``."""
    grid = torch.linspace(t_max, t_min, steps=n + 1)
    return grid.to(device)

def dynamic_thresholding(img, p=0.995):
    """Clamp ``img`` to a per-sample threshold and rescale by that threshold.

    The threshold is the ``p``-quantile of the absolute values taken over all
    dimensions from index 2 on, maximised over dimension 1 and floored at 1.0.
    """
    per_channel = torch.quantile(img.abs().flatten(2), p, dim=2)
    threshold = per_channel.amax(dim=1, keepdim=True).clamp(min=1.0)
    threshold = threshold[:, :, None, None]
    return img.clamp(min=-threshold, max=threshold) / threshold

        
import torch
import torch.distributed as dist
import torch.fft
from tqdm import tqdm
import os
import torchvision.transforms.functional as TF

# 1. Frequency-domain helper (unchanged)
def apply_frequency_scaling(delta, scale_lf, scale_hf, cutoff=0.7, transition=0.15):
    """Scale the low- and high-frequency parts of ``delta`` separately.

    The centred 2-D spectrum of ``delta`` is split by a radial smoothstep
    low-pass mask on a ``[-1, 1]`` grid; the low band is multiplied by
    ``scale_lf`` and the high band by ``scale_hf`` before transforming back.

    Args:
        delta: 4-D tensor; the FFT runs over the last two dimensions.
        scale_lf: Multiplier of the low-frequency band.
        scale_hf: Multiplier of the high-frequency band.
        cutoff: Radius at the centre of the transition band.
        transition: Width of the transition band.

    Returns:
        Real tensor with the shape and dtype of ``delta``.
    """
    _, _, height, width = delta.shape
    spatial = (-2, -1)

    # Forward transform with the zero frequency moved to the centre.
    spectrum = torch.fft.fftshift(torch.fft.fft2(delta.float(), dim=spatial), dim=spatial)

    # Radial coordinate of every frequency bin.
    rows = torch.linspace(-1.0, 1.0, height, device=delta.device)
    cols = torch.linspace(-1.0, 1.0, width, device=delta.device)
    grid_y, grid_x = torch.meshgrid(rows, cols, indexing='ij')
    radius = torch.sqrt(grid_x**2 + grid_y**2)

    # Smoothstep low-pass mask: 1 inside the band, 0 outside.
    inner = cutoff - transition * 0.5
    outer = cutoff + transition * 0.5
    ramp = ((radius - inner) / (outer - inner)).clamp(0, 1)
    low_pass = (1.0 - (3 * ramp**2 - 2 * ramp**3)).view(1, 1, height, width)

    # Scale both bands and recombine.
    scaled = spectrum * low_pass * scale_lf + spectrum * (1.0 - low_pass) * scale_hf

    # Inverse transform back to the original dtype.
    restored = torch.fft.ifft2(torch.fft.ifftshift(scaled, dim=spatial), dim=spatial)
    return restored.real.to(delta.dtype)

@torch.no_grad()
def sample_dbimt(
    denoiser,
    diffusion,
    x,
    ts,
    eta=1.0,
    mask=None,
    seed=None,
    **kwargs,
):
    """DBIM-style bridge sampler with classifier-free and auto-guidance.

    Guidance is configured through environment variables:
    ``guidance_scale`` (classifier-free weight), ``auto_scale`` and
    ``auto_scale_lf`` (high- and low-frequency auto-guidance weights),
    ``frequency_cutoff`` / ``transition_width`` (band split),
    ``corrupt_type`` / ``corrupt_scale`` (input degradation used for
    auto-guidance), ``ag_start_ratio`` (auto-guidance is active once
    ``t / t_max`` is at or below this ratio; classifier-free guidance is used
    before that) and ``guidance_rescale`` (std-rescale blend factor).

    Args:
        denoiser: Callable ``denoiser(x_t, t, y=...)`` returning an x0 estimate.
        diffusion: Object providing ``t_max``, ``bridge_sample`` and
            ``noise_schedule``.
        x: Starting tensor.
        ts: Time grid.
        eta: Stochasticity factor of the update.
        mask: Optional mask; guided regions are blended with ``x_T`` outside it.
        seed: Seed forwarded to ``BatchedSeedGenerator``.
        **kwargs: ``xT`` (bridge endpoint, defaults to ``x``) and ``y``
            (labels); other entries are ignored.

    Returns:
        ``(x, path, nfe, pred_x0, ts, first_noise)``.
    """
    # --- 1. Arguments and environment ---
    x_T = kwargs.get("xT", None)
    if x_T is None:
        # Also covers callers that pass xT=None explicitly.
        x_T = x
    y = kwargs.get("y", None)

    y_null = None
    if y is not None:
        # Label 1000 is used as the unconditional ("null") class.
        y_null = torch.full_like(y, 1000)

    def get_float_safe(key, default):
        val = os.environ.get(key)
        if val is None:
            return default
        return float(val)

    # Guidance weights
    w1 = get_float_safe("guidance_scale", 0.0)
    w2 = get_float_safe("auto_scale", 0.0)
    w_auto_lf = get_float_safe("auto_scale_lf", 0.0)

    # Frequency-split parameters
    freq_cutoff = get_float_safe("frequency_cutoff", 0.7)
    freq_width = get_float_safe("transition_width", 0.15)

    corrupt_type = os.environ.get("corrupt_type", "blur")
    corrupt_scale = get_float_safe("corrupt_scale", 2.0)
    corrupt_fn = lambda image: corrupt_image(image, corrupt_type, corrupt_scale)

    # Read once instead of on every step; the value does not change while
    # sampling.
    rescale_factor = get_float_safe("guidance_rescale", 0.7)

    # Soft mask: blur the hard 0/1 mask so that its edges fade smoothly, then
    # keep it inside [0, 1].
    soft_mask = None
    if mask is not None:
        soft_mask = TF.gaussian_blur(mask, kernel_size=[15, 15], sigma=[3.0, 3.0])
        soft_mask = soft_mask.clamp(0.0, 1.0)

    ag_start_ratio = get_float_safe("ag_start_ratio", 0.5)
    t_max_val = diffusion.t_max

    # Rank 0 (or a run without a process group) owns logging and the
    # progress bar.
    is_main_rank = not dist.is_initialized() or dist.get_rank() == 0

    path = []
    pred_x0 = []
    ones = x.new_ones([x.shape[0]])
    indices = range(len(ts) - 1)
    indices = tqdm(indices, disable=not is_main_rank)
    # NOTE(review): nfe counts one evaluation per step even when a guidance
    # branch calls the denoiser a second time.
    nfe = 0

    def calc_multiguidance_x0(x_in, t_in):
        """Return the guided x0 estimate for ``x_in`` at time ``t_in``."""
        current_t = t_in[0].item()
        out_cond = denoiser(x_in, t_in, y=y)
        if w1 <= 0.0 and w2 <= 0.0 and w_auto_lf <= 0.0:
            # NOTE(review): this early return also skips the final mask
            # fusion below — confirm that is intended.
            return out_cond

        delta_cfg = 0.0
        is_ag_active = (current_t / t_max_val) <= ag_start_ratio
        if is_main_rank:
            print(f"[Debug] Step: {current_t:.6f} | t_max: {t_max_val} | Active: {is_ag_active}")

        # Classifier-free guidance branch (only before auto-guidance starts).
        if w1 > 0.0 and (not is_ag_active):
            out_uncond = denoiser(x_in, t_in, y=y_null)
            delta_cfg = out_cond - out_uncond
            if mask is not None:
                delta_cfg = delta_cfg * mask

        # Auto-guidance branch.
        delta_auto = 0.0
        if is_ag_active and (w2 > 0.0 or w_auto_lf > 0.0):
            x_bad_in = corrupt_fn(x_in)
            out_bad = denoiser(x_bad_in, t_in, y=y)

            raw_delta_auto = out_cond - out_bad

            # Pre-multiply with the soft mask so that the difference fades to
            # zero at the mask border and the FFT does not see a step edge.
            if soft_mask is not None:
                raw_delta_auto = raw_delta_auto * soft_mask

            # Per-band scaling in the frequency domain.
            delta_auto = apply_frequency_scaling(
                raw_delta_auto,
                scale_lf=w_auto_lf,
                scale_hf=w2,
                cutoff=freq_cutoff,
                transition=freq_width,
            )

            # Apply the hard mask after the FFT to remove leakage outside it.
            if mask is not None:
                delta_auto = delta_auto * mask

        # Linear combination (delta_auto is non-zero only while auto-guidance
        # is active).
        x0_combined = out_cond + (w1 * delta_cfg) + delta_auto

        # Blend towards the std of the conditional prediction.
        if rescale_factor > 0.0:
            std_cond = out_cond.std(dim=(1, 2, 3), keepdim=True)
            std_combined = x0_combined.std(dim=(1, 2, 3), keepdim=True)
            factor = std_cond / (std_combined + 1e-8)
            x0_rescaled = x0_combined * factor
            x0_hat = x0_combined * (1 - rescale_factor) + x0_rescaled * rescale_factor
        else:
            x0_hat = x0_combined

        # Dynamic thresholding.
        x0_hat = dynamic_thresholding(x0_hat, p=0.995)

        # Final fusion: keep x_T outside the mask.
        if mask is not None:
            x0_hat = x0_hat * mask + x_T * (1 - mask)

        return x0_hat

    x0_hat = calc_multiguidance_x0(x, diffusion.t_max * ones)

    generator = BatchedSeedGenerator(seed)
    noise = generator.randn_like(x0_hat)
    first_noise = noise

    if len(ts) == 1:
        x = x0_hat
    else:
        x = diffusion.bridge_sample(x0_hat, x_T, ts[0] * ones, noise)

    path.append(x.detach().cpu())
    pred_x0.append(x0_hat.detach().cpu())
    nfe += 1

    noise_schedule = diffusion.noise_schedule
    last_index = len(ts) - 2
    for i in indices:
        s = ts[i]
        t = ts[i + 1]

        x0_hat = calc_multiguidance_x0(x, s * ones)

        # Bridge update from time s to time t.
        a_s, b_s, c_s = [append_dims(item, x0_hat.ndim) for item in noise_schedule.get_abc(s * ones)]
        a_t, b_t, c_t = [append_dims(item, x0_hat.ndim) for item in noise_schedule.get_abc(t * ones)]
        _, _, rho_s, _ = [append_dims(item, x0_hat.ndim) for item in noise_schedule.get_alpha_rho(s * ones)]
        alpha_t, _, rho_t, _ = [
            append_dims(item, x0_hat.ndim) for item in noise_schedule.get_alpha_rho(t * ones)
        ]

        omega_st = eta * (alpha_t * rho_t) * (1 - rho_t**2 / rho_s**2).sqrt()
        tmp_var = (c_t**2 - omega_st**2).sqrt() / c_s
        coeff_xs = tmp_var
        coeff_x0_hat = b_t - tmp_var * b_s
        coeff_xT = a_t - tmp_var * a_s

        # Noise is drawn on every step to keep the generator state in sync;
        # the last step multiplies it by zero.
        noise = generator.randn_like(x0_hat)
        x = coeff_x0_hat * x0_hat + coeff_xT * x_T + coeff_xs * x + (1 if i != last_index else 0) * omega_st * noise

        path.append(x.detach().cpu())
        pred_x0.append(x0_hat.detach().cpu())
        nfe += 1

    return x, path, nfe, pred_x0, ts, first_noise

@torch.no_grad()
def sample_dbim_high_order(
    denoiser,
    diffusion,
    x,
    ts,
    mask=None,
    order=2,
    lower_order_final=True,
    seed=None,
    **kwargs,
):
    """Multistep (order 2 or 3) deterministic bridge sampler.

    The first loop step always uses a first-order update, as does the last one
    when ``lower_order_final`` is set. With ``order == 3`` the second step
    uses the second-order update because only one earlier estimate exists.

    Args:
        denoiser: Callable ``denoiser(x_t, t)`` returning an x0 estimate.
        diffusion: Object providing ``t_max``, ``bridge_sample`` and
            ``noise_schedule``.
        x: Starting tensor; it is also used as the bridge endpoint ``x_T``.
        ts: Time grid.
        mask: Optional mask; x0 estimates are replaced by ``x_T`` outside it.
        order: 2 or 3.
        lower_order_final: Use a first-order update on the last step.
        seed: Seed forwarded to ``BatchedSeedGenerator``.
        **kwargs: Ignored.

    Returns:
        ``(x, path, nfe, pred_x0, ts, first_noise)``.

    Raises:
        NotImplementedError: If ``order`` is neither 2 nor 3.
    """
    if order not in [2, 3]:
        raise NotImplementedError("Not supported")

    # Rank 0 (or a run without a process group) owns logging and the
    # progress bar.
    is_main_rank = not dist.is_initialized() or dist.get_rank() == 0

    x_T = x
    path = []
    pred_x0 = []

    ones = x.new_ones([x.shape[0]])
    indices = range(len(ts) - 1)
    indices = tqdm(indices, disable=not is_main_rank)

    def abc_at(time, ndim):
        """Schedule coefficients (a, b, c) at ``time`` expanded to ``ndim`` dims."""
        return [append_dims(item, ndim) for item in diffusion.noise_schedule.get_abc(time * ones)]

    def predict_x0(x_cur, time):
        """Denoise ``x_cur`` at ``time`` and keep ``x_T`` outside the mask."""
        estimate = denoiser(x_cur, time * ones)
        if mask is not None:
            estimate = estimate * mask + x_T * (1 - mask)
        return estimate

    nfe = 0
    x0_hat = denoiser(x, diffusion.t_max * ones)
    generator = BatchedSeedGenerator(seed)
    noise = generator.randn_like(x0_hat)
    first_noise = noise
    if mask is not None:
        x0_hat = x0_hat * mask + x_T * (1 - mask)
    x = diffusion.bridge_sample(x0_hat, x_T, ts[0] * ones, noise)
    path.append(x.detach().cpu())
    pred_x0.append(x0_hat.detach().cpu())
    nfe += 1

    # History of earlier times and x0 estimates used by the multistep updates.
    u = diffusion.t_max
    if u == 1.0:
        # Step slightly inside the interval for the initial history entry.
        u -= 5e-5
    u = [u for _ in range(order - 1)]
    xu_hat = [x0_hat.detach().clone() for _ in range(order - 1)]

    for i in indices:
        s = ts[i]
        t = ts[i + 1]

        # First Order Update, t < s
        if (lower_order_final and i + 1 == len(ts) - 1) or (i == 0):
            if is_main_rank:
                print("Step order 1")
            a_s, b_s, c_s = abc_at(s, x0_hat.ndim)
            a_t, b_t, c_t = abc_at(t, x0_hat.ndim)

            tmp_var = c_t / c_s
            coeff_xs = tmp_var
            coeff_x0_hat = b_t - tmp_var * b_s
            coeff_xT = a_t - tmp_var * a_s

            x0_hat = predict_x0(x, s)
            nfe += 1
            x_old = x
            x = coeff_xs * x_old + coeff_x0_hat * x0_hat + coeff_xT * x_T

        # Second Order Update, t < s < u
        elif order == 2 or i == 1:
            if is_main_rank:
                print("Step order 2")
            a_u, b_u, c_u = abc_at(u[-1], x0_hat.ndim)
            a_s, b_s, c_s = abc_at(s, x0_hat.ndim)
            a_t, b_t, c_t = abc_at(t, x0_hat.ndim)
            lambda_u, lambda_s, lambda_t = (
                torch.log(b_u / c_u),
                torch.log(b_s / c_s),
                torch.log(b_t / c_t),
            )

            x0_hat = predict_x0(x, s)
            nfe += 1
            h = lambda_t - lambda_s
            h2 = lambda_s - lambda_u
            integral = torch.exp(lambda_t) * (
                (1 - torch.exp(-h)) * x0_hat + (torch.exp(-h) + h - 1) * (x0_hat - xu_hat[-1]) / h2
            )
            x_old = x
            x = x_old * (c_t / c_s) + x_T * (a_t - a_s * (c_t / c_s)) + c_t * integral

        # Third Order Update, t < s < u1 < u2
        elif order == 3:
            if is_main_rank:
                print("Step order 3")
            a_u1, b_u1, c_u1 = abc_at(u[-1], x0_hat.ndim)
            a_u2, b_u2, c_u2 = abc_at(u[-2], x0_hat.ndim)
            a_s, b_s, c_s = abc_at(s, x0_hat.ndim)
            a_t, b_t, c_t = abc_at(t, x0_hat.ndim)
            lambda_u2, lambda_u1, lambda_s, lambda_t = (
                torch.log(b_u2 / c_u2),
                torch.log(b_u1 / c_u1),
                torch.log(b_s / c_s),
                torch.log(b_t / c_t),
            )
            x0_hat = predict_x0(x, s)
            nfe += 1

            h = lambda_t - lambda_s
            h1 = lambda_s - lambda_u1
            h2 = lambda_u1 - lambda_u2
            # Finite-difference estimates of the first and second derivative
            # of the x0 prediction with respect to lambda.
            dx0_hat = ((x0_hat - xu_hat[-1]) * (2 * h1 + h2) / h1 - (xu_hat[-1] - xu_hat[-2]) * h1 / h2) / (h1 + h2)
            d2x0_hat = 2 * ((x0_hat - xu_hat[-1]) / h1 - (xu_hat[-1] - xu_hat[-2]) / h2) / (h1 + h2)
            integral = torch.exp(lambda_t) * (
                (1 - torch.exp(-h)) * x0_hat
                + (torch.exp(-h) + h - 1) * dx0_hat
                + (h**2 / 2 - h + 1 - torch.exp(-h)) * d2x0_hat
            )
            x_old = x
            x = x_old * (c_t / c_s) + x_T * (a_t - a_s * (c_t / c_s)) + c_t * integral

        # Slide the history window forward by one step.
        u.append(s)
        u.pop(0)
        xu_hat.append(x0_hat)
        xu_hat.pop(0)

        path.append(x.detach().cpu())
        pred_x0.append(x0_hat.detach().cpu())

    return x, path, nfe, pred_x0, ts, first_noise


import os
import torch
import torch.distributed as dist
import matplotlib
# Force a non-interactive backend so that headless servers do not raise errors
matplotlib.use('Agg') 
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

# =================================================================
# Revised version: compromise colour limits, neither over-exposed nor too dark
# =================================================================
def _debug_save_spectrum(delta, name_prefix, step_idx):
    """Save the log-magnitude spectrum of the first sample of ``delta`` as a PNG.

    Does nothing on ranks other than 0 when a process group is initialised.
    The image goes to ``$sample_dir/debug_freq`` when the ``sample_dir``
    environment variable is set and to ``debug_freq_fallback`` otherwise.
    """
    if dist.is_initialized() and dist.get_rank() != 0:
        return

    sample_root = os.environ.get("sample_dir")
    save_dir = os.path.join(sample_root, "debug_freq") if sample_root else "debug_freq_fallback"
    os.makedirs(save_dir, exist_ok=True)

    # Centred spectrum; first sample only, averaged over its leading dimension.
    spectrum = torch.fft.fftshift(torch.fft.fft2(delta, dim=(-2, -1)), dim=(-2, -1))
    magnitude = torch.log(torch.abs(spectrum[0]) + 1e-8).mean(dim=0).cpu().numpy()

    # Fixed colour limits: one range for prefixes containing "xt", another
    # for everything else.
    if "xt" in name_prefix:
        color_min, color_max = -2.5, 3.5
    else:
        color_min, color_max = -4.0, 2.0

    plt.figure(figsize=(6, 6))
    plt.imshow(magnitude, cmap='inferno', vmin=color_min, vmax=color_max)
    plt.colorbar(label='Log Magnitude')
    plt.title(f'{name_prefix} Step {step_idx:03d}')
    plt.axis('off')

    save_path = os.path.join(save_dir, f"{name_prefix}_{step_idx:03d}.png")
    plt.savefig(save_path, bbox_inches='tight', pad_inches=0.1)
    plt.close()

def frequency_cfge(
    x0_prior: torch.Tensor,
    x0_like: torch.Tensor,
    guidance_scale: float = 1.0,        # low-frequency scale
    guidance_scale_hf: float = 1.0,     # high-frequency scale
    cutoff: float = 0.7,               # centre of the transition band
    transition_width: float = 0.15,    # width of the transition band
    enable_vis: bool = False,          # save a spectrum plot on every call
):
    """Extrapolate from ``x0_prior`` towards ``x0_like`` with per-band scales.

    The difference ``x0_like - x0_prior`` is split into a low-frequency part
    (radial smoothstep low-pass in the centred 2-D spectrum) and the remaining
    high-frequency part; the two are scaled independently and added back to
    ``x0_prior``.

    Args:
        x0_prior: 4-D base prediction.
        x0_like: Prediction to move towards; same shape as ``x0_prior``.
        guidance_scale: Multiplier of the low-frequency difference.
        guidance_scale_hf: Multiplier of the high-frequency difference.
        cutoff: Radius (on a ``[-1, 1]`` grid) at the centre of the transition.
        transition_width: Width of the smoothstep transition.
        enable_vis: When true, rank 0 saves the log-magnitude spectrum of the
            first sample as a PNG and increments a call counter stored on the
            function object.

    Returns:
        ``x0_prior + guidance_scale * low + guidance_scale_hf * high``.
    """
    B, C, H, W = x0_prior.shape
    device = x0_prior.device
    spatial_dims = (-2, -1)

    delta = x0_like - x0_prior

    delta_f = torch.fft.fft2(delta, dim=spatial_dims)
    delta_f = torch.fft.fftshift(delta_f, dim=spatial_dims)

    # Optional spectrum visualisation (main rank only).
    if enable_vis and (not dist.is_initialized() or dist.get_rank() == 0):
        # A function attribute serves as a persistent counter so that file
        # names keep increasing across calls.
        if not hasattr(frequency_cfge, "frame_count"):
            frequency_cfge.frame_count = 0

        # Prefer the directory from the environment; fall back to a fixed
        # path so that the plot is never dropped silently.
        env_sample_dir = os.environ.get("sample_dir")
        if env_sample_dir:
            save_dir = os.path.join(env_sample_dir, "debug_freq")
        else:
            save_dir = "debug_freq_fallback"

        os.makedirs(save_dir, exist_ok=True)

        # Log-magnitude of the first sample, averaged over its channels.
        magnitude = torch.log(torch.abs(delta_f[0]) + 1e-8).mean(dim=0).cpu().numpy()

        plt.figure(figsize=(6, 6))
        plt.imshow(magnitude, cmap='inferno')
        plt.colorbar(label='Log Magnitude')
        plt.title(f'Freq Spectrum Step {frequency_cfge.frame_count:03d}')
        plt.axis('off')

        filename = f"freq_dist_{frequency_cfge.frame_count:03d}.png"
        save_path = os.path.join(save_dir, filename)

        plt.savefig(save_path, bbox_inches='tight', pad_inches=0.1)
        plt.close()

        print(f"[Vis] Spectrum saved to: {save_path}")

        frequency_cfge.frame_count += 1

    yy, xx = torch.meshgrid(
        torch.linspace(-1.0, 1.0, H, device=device),
        torch.linspace(-1.0, 1.0, W, device=device),
        indexing="ij",
    )
    rr = torch.sqrt(xx**2 + yy**2)

    r0 = cutoff - transition_width * 0.5
    r1 = cutoff + transition_width * 0.5

    ramp = ((rr - r0) / (r1 - r0)).clamp(0, 1)
    lp = 1.0 - (3 * ramp**2 - 2 * ramp**3)   # smoothstep
    lp = lp.view(1, 1, H, W)

    # Low / high split. The inverse shift must be restricted to the two
    # spatial axes: without ``dim`` it also rolls the batch and channel axes
    # and mixes samples and channels.
    delta_low_f = delta_f * lp
    delta_low = torch.fft.ifft2(
        torch.fft.ifftshift(delta_low_f, dim=spatial_dims), dim=spatial_dims
    ).real

    delta_high = delta - delta_low

    delta_used = (
        guidance_scale * delta_low
        + guidance_scale_hf * delta_high
    )

    return x0_prior + delta_used
    

def _save_debug_concat(x, x_bad, step):
    """Write a debug image comparing the current state with its corrupted twin.

    At most the first 8 samples of each batch are kept, rescaled with
    ``(v + 1) / 2`` and clamped to [0, 1], then joined along the last axis so
    that ``x`` sits on the left and ``x_bad`` on the right (assumes NCHW
    tensors, i.e. the last axis is width -- TODO confirm). The image is saved
    as ``debug_concat_step_<step>.png`` inside ``$sample_dir``, or inside the
    current directory when ``sample_dir`` is unset or empty.
    """
    env_sample_dir = os.environ.get("sample_dir")
    save_dir = env_sample_dir if env_sample_dir else "."
    os.makedirs(save_dir, exist_ok=True)

    limit = min(8, x.shape[0])
    vis_xt = torch.clamp((x[:limit].detach().cpu() + 1) * 0.5, 0, 1)
    vis_xbad = torch.clamp((x_bad[:limit].detach().cpu() + 1) * 0.5, 0, 1)

    # Join along the width axis (left: x, right: x_bad). The previous code
    # used dim=2, which stacks the pair vertically for NCHW tensors.
    vis_concat = torch.cat([vis_xt, vis_xbad], dim=-1)

    save_path = os.path.join(save_dir, f"debug_concat_step_{step:03d}.png")
    # nrow=4 puts four comparison pairs on each row of the grid.
    torchvision.utils.save_image(vis_concat, save_path, nrow=4)


@torch.no_grad()
def sample_dbim(
    denoiser,
    diffusion,
    x,
    ts,
    eta=1.0,
    mask=None,
    seed=None,
    **kwargs,
):
    """Run the DBIM sampler with optional frequency-split guidance.

    Configuration is read from environment variables: ``corrupt_type``,
    ``corrupt_scale``, ``guidance_type``, ``guidance_scale``, ``hf_extra``,
    ``ag_start_ratio``, ``doob_scale`` (read on every loop step) and, only
    when guidance is active, ``cutoff``. Every value fetched through
    ``get_env_float`` must be set to a valid float.

    Guidance is skipped when ``guidance_type`` is ``"none"`` or
    ``guidance_scale`` equals 1.0; otherwise the denoiser is evaluated on both
    ``x`` and ``corrupt_fn(x)`` and the two predictions are merged by
    ``frequency_cfge``.

    During the first call of this function in the process (tracked through the
    ``sample_dbim.batches_processed`` attribute) debug images are written on
    rank 0, or unconditionally when ``torch.distributed`` is not initialised.

    Args:
        denoiser: callable ``denoiser(x, t_batch)`` returning an x0 prediction.
        diffusion: object providing ``t_max``, ``bridge_sample`` and
            ``noise_schedule`` (``get_abc``, ``get_alpha_rho``, ``get_f_g2``).
        x: starting state; it is also kept as the bridge endpoint ``x_T``.
        ts: sequence of time steps; ``len(ts) - 1`` loop updates are performed.
        eta: multiplier of the per-step noise scale ``omega_st``.
        mask: optional mask; where given, ``x0_hat * mask + x_T * (1 - mask)``
            replaces the x0 prediction.
        seed: forwarded to ``BatchedSeedGenerator``.
        **kwargs: accepted and ignored.

    Returns:
        Tuple ``(x, path, nfe, pred_x0, ts, first_noise)`` where ``path`` and
        ``pred_x0`` are lists of CPU snapshots, one per step.
    """
    # ----- per-process batch counter (only the first batch dumps debug images)
    if not hasattr(sample_dbim, "batches_processed"):
        sample_dbim.batches_processed = 0
    is_first_batch = sample_dbim.batches_processed == 0

    # dist.get_rank() fails when no process group exists, so treat a
    # non-distributed run as the main process everywhere in this function.
    is_main_process = not dist.is_initialized() or dist.get_rank() == 0
    save_debug = is_first_batch and is_main_process

    corrupt_type = os.environ.get("corrupt_type", "none")
    corrupt_scale = get_env_float("corrupt_scale")
    guidance_type = os.environ.get("guidance_type", "none")
    guidance_scale = get_env_float("guidance_scale")
    hf_extra = get_env_float("hf_extra")
    ag_start_ratio = get_env_float("ag_start_ratio")
    t_max_val = diffusion.t_max

    def corrupt_fn(image):
        return corrupt_image(image, corrupt_type, corrupt_scale)

    use_guidance = not (guidance_type == "none" or guidance_scale == 1.0)
    # `cutoff` is only needed (and only required to be set) with guidance on.
    cutoff = get_env_float("cutoff") if use_guidance else None

    x_T = x
    path = []
    pred_x0 = []

    ones = x.new_ones([x.shape[0]])
    indices = tqdm(range(len(ts) - 1), disable=not is_main_process)

    nfe = 0

    def get_multiscale_weights(current_t_val):
        """Return ``(low_freq_scale, high_freq_scale)`` for time ``current_t_val``.

        The low-frequency scale is always ``guidance_scale``. The
        high-frequency scale is ``hf_extra`` while ``t / t_max`` lies between
        ``t_low`` and ``t_high``; above ``t_high`` it moves quadratically
        towards ``(guidance_scale + hf_extra) / 2`` and below ``t_low`` it
        moves quadratically towards ``guidance_scale``.
        """
        ratio = current_t_val / t_max_val
        t_high = max(ag_start_ratio, 1.0 - ag_start_ratio)
        t_low = min(ag_start_ratio, 1.0 - ag_start_ratio)
        gs_low = guidance_scale
        if ratio > t_high:
            start_ceiling = (guidance_scale + hf_extra) / 2.0
            progress = (ratio - t_high) / (1.0 - t_high + 1e-6)
            gs_high = hf_extra + (start_ceiling - hf_extra) * (progress ** 2)
        elif ratio < t_low:
            progress = (t_low - ratio) / (t_low + 1e-6)
            gs_high = hf_extra + (guidance_scale - hf_extra) * (progress ** 2)
        else:
            gs_high = hf_extra
        return gs_low, gs_high

    def predict_x0(x_cur, t_val):
        """Predict x0 at time ``t_val``, applying frequency guidance if enabled."""
        if not use_guidance:
            return denoiser(x_cur, t_val * ones)
        x0_good = denoiser(x_cur, t_val * ones)
        x0_bad = denoiser(corrupt_fn(x_cur), t_val * ones)
        t_scalar = t_val.item() if isinstance(t_val, torch.Tensor) else t_val
        gs_low, gs_high = get_multiscale_weights(t_scalar)
        return frequency_cfge(
            x0_prior=x0_bad,
            x0_like=x0_good,
            guidance_scale=gs_low,
            guidance_scale_hf=gs_high,
            cutoff=cutoff,
            enable_vis=is_first_batch,
        )

    # ==================== first sampling step ====================
    x0_hat = predict_x0(x, diffusion.t_max)

    generator = BatchedSeedGenerator(seed)
    noise = generator.randn_like(x0_hat)
    first_noise = noise

    if mask is not None:
        x0_hat = x0_hat * mask + x_T * (1 - mask)

    # Bridge to the first time step ts[0].
    x = diffusion.bridge_sample(x0_hat, x_T, ts[0] * ones, noise)

    # Corrupted counterpart of the first state, chosen by guidance type.
    # NOTE: corrupt_fn may draw from the global RNG, so it is always evaluated
    # even though x_bad is only consumed by the debug dump below.
    if guidance_type in ("x0", "mx0"):
        x_bad = diffusion.bridge_sample(corrupt_fn(x0_hat), x_T, ts[0] * ones, noise)
    elif guidance_type in ("x1", "mx1"):
        x_bad = diffusion.bridge_sample(x0_hat, corrupt_fn(x_T), ts[0] * ones, noise)
    elif guidance_type == "xt":
        x_bad = corrupt_fn(x)
    elif guidance_type == "mean":
        x_bad = diffusion.bridge_sample(corrupt_fn(x0_hat), corrupt_fn(x_T), ts[0] * ones, noise)
    else:
        x_bad = x
        assert guidance_type == "none", f"unknown guidance_type: {guidance_type!r}"

    # Debug dump 1: right after the first x_bad has been produced.
    if save_debug:
        _save_debug_concat(x, x_bad, nfe)

    path.append(x.detach().cpu())
    pred_x0.append(x0_hat.detach().cpu())
    nfe += 1

    # ---------- sampling loop ----------
    for i in indices:
        s = ts[i]
        t = ts[i + 1]

        x0_hat = predict_x0(x, s)

        if mask is not None:
            x0_hat = x0_hat * mask + x_T * (1 - mask)

        # ----- DBIM update coefficients -----
        a_s, b_s, c_s = [append_dims(item, x0_hat.ndim) for item in diffusion.noise_schedule.get_abc(s * ones)]
        a_t, b_t, c_t = [append_dims(item, x0_hat.ndim) for item in diffusion.noise_schedule.get_abc(t * ones)]
        _, _, rho_s, _ = [append_dims(item, x0_hat.ndim) for item in diffusion.noise_schedule.get_alpha_rho(s * ones)]
        alpha_t, _, rho_t, _ = [append_dims(item, x0_hat.ndim) for item in diffusion.noise_schedule.get_alpha_rho(t * ones)]

        omega_st = eta * (alpha_t * rho_t) * (1 - rho_t**2 / rho_s**2).sqrt()
        tmp_var = (c_t**2 - omega_st**2).sqrt() / c_s
        coeff_xs = tmp_var
        coeff_x0_hat = b_t - tmp_var * b_s
        coeff_xT = a_t - tmp_var * a_s

        noise = generator.randn_like(x0_hat)
        doob_scale = get_env_float("doob_scale")

        # No noise is injected on the final step.
        noise_term = (1 if i != len(ts) - 2 else 0) * omega_st * noise

        # Corrupted counterpart of the next state.
        if guidance_type == "x0":
            x_bad = coeff_x0_hat * corrupt_fn(x0_hat) + coeff_xT * x_T + coeff_xs * x + noise_term
        elif guidance_type == "x1":
            x_bad = coeff_x0_hat * x0_hat + coeff_xT * corrupt_fn(x_T) + coeff_xs * x + noise_term
        elif guidance_type == "xt":
            x_bad = corrupt_fn(coeff_x0_hat * x0_hat + coeff_xT * x_T + coeff_xs * x + noise_term)
        elif guidance_type == "mean":
            x_bad = b_t * corrupt_fn(x0_hat) + a_t * corrupt_fn(x_T) + coeff_xs * (x - a_s * x_T - b_s * x0_hat) + noise_term
        elif guidance_type == "mx0":
            x_bad = b_t * corrupt_fn(x0_hat) + a_t * x_T + coeff_xs * (x - a_s * x_T - b_s * x0_hat) + noise_term
        elif guidance_type == "mx1":
            x_bad = b_t * x0_hat + a_t * corrupt_fn(x_T) + coeff_xs * (x - a_s * x_T - b_s * x0_hat) + noise_term
        else:
            x_bad = x

        # Doob correction term, evaluated at the state *before* the update.
        _, g2_s = [append_dims(item, x.ndim) for item in diffusion.noise_schedule.get_f_g2(s * ones)]
        alpha_s, alpha_bar_s, _, rho_bar_s = [append_dims(item, x.ndim) for item in diffusion.noise_schedule.get_alpha_rho(s * ones)]
        grad_logpxTlxs = -(x - alpha_bar_s * x_T) / (alpha_s**2 * rho_bar_s**2)
        doob_shift = (t - s) * g2_s * (doob_scale - 1) * grad_logpxTlxs

        x = coeff_x0_hat * x0_hat + coeff_xT * x_T + coeff_xs * x + noise_term
        x += doob_shift
        # Out-of-place on purpose: in the fallback branch x_bad aliases the
        # previous x, whose snapshot in `path` shares storage with it when the
        # tensors live on the CPU. An in-place add would rewrite that snapshot.
        x_bad = x_bad + doob_shift

        # Debug dump 2: after every update of x and x_bad.
        if save_debug:
            _save_debug_concat(x, x_bad, nfe)

        path.append(x.detach().cpu())
        pred_x0.append(x0_hat.detach().cpu())
        nfe += 1

    sample_dbim.batches_processed += 1

    return x, path, nfe, pred_x0, ts, first_noise
@torch.no_grad()
def sample_ground_truth(
    denoiser,
    diffusion,
    x,
    ts,
    x0=None,
    **kwargs,
):
    """Trace the bridge using the known clean sample ``x0`` instead of the model.

    The first state is bridged from the denoiser's prediction at
    ``diffusion.t_max``. Every later state is drawn with
    ``diffusion.bridge_sample(x0, x_T, t, noise)`` from the true ``x0``; the
    denoiser is still evaluated at each step, but only so that its prediction
    can be recorded in ``pred_x0``.

    Args:
        denoiser: callable ``denoiser(x, t_batch)`` returning an x0 prediction.
        diffusion: object providing ``t_max`` and ``bridge_sample``.
        x: starting state; it is also kept as the bridge endpoint ``x_T``.
        ts: sequence of time steps; ``len(ts) - 1`` loop updates are performed.
        x0: ground-truth clean sample (required despite the ``None`` default).
        **kwargs: accepted and ignored.

    Returns:
        Tuple ``(x, path, nfe, pred_x0, ts, first_noise)`` where ``path`` and
        ``pred_x0`` are lists of CPU snapshots, one per step.
    """
    assert x0 is not None, "sample_ground_truth requires x0"
    x_T = x
    path = []
    pred_x0 = []

    ones = x.new_ones([x.shape[0]])
    # dist.get_rank() fails when no process group exists; a non-distributed
    # run is treated as the main process so the progress bar still works.
    is_main_process = not dist.is_initialized() or dist.get_rank() == 0
    indices = tqdm(range(len(ts) - 1), disable=not is_main_process)

    nfe = 0
    x0_hat = denoiser(x, diffusion.t_max * ones)
    noise = torch.randn_like(x0)
    first_noise = noise
    x = diffusion.bridge_sample(x0_hat, x_T, ts[0] * ones, noise)
    path.append(x.detach().cpu())
    pred_x0.append(x0_hat.detach().cpu())
    nfe += 1

    for i in indices:
        s = ts[i]
        t = ts[i + 1]

        # The prediction is recorded only; the next state comes from the true x0.
        x0_hat = denoiser(x, s * ones)
        noise = torch.randn_like(x0)
        x = diffusion.bridge_sample(x0, x_T, t * ones, noise)

        path.append(x.detach().cpu())
        pred_x0.append(x0_hat.detach().cpu())
        nfe += 1

    return x, path, nfe, pred_x0, ts, first_noise


def get_d(denoiser, noise_schedule, x, x_T, t, stochastic):
    """Evaluate the bridge drift at state ``x`` and time ``t``.

    The drift combines ``f_t * x`` with two gradient terms scaled by ``g2_t``:
    ``grad_logq`` (built from the denoiser's x0 prediction) and
    ``grad_logpxTlxt`` (built from the endpoint ``x_T``). ``grad_logq`` is
    weighted by 1 when ``stochastic`` is truthy and by 0.5 otherwise.

    Returns:
        Tuple ``(d, g2_t, denoised)``: the drift, the ``g2`` schedule value
        broadcast to ``x.ndim`` dimensions, and the denoiser output.
    """
    batch_ones = x.new_ones([x.shape[0]])

    def _broadcast(values):
        # Pad every schedule value with trailing dims so it broadcasts over x.
        return [append_dims(value, x.ndim) for value in values]

    drift_coef, diffusion_sq = _broadcast(noise_schedule.get_f_g2(t * batch_ones))
    alpha, alpha_bar, _, rho_bar = _broadcast(noise_schedule.get_alpha_rho(t * batch_ones))
    coef_xT, coef_x0, std = _broadcast(noise_schedule.get_abc(t * batch_ones))

    x0_pred = denoiser(x, t * batch_ones)

    score_q = -(x - (coef_xT * x_T + coef_x0 * x0_pred)) / std**2
    score_h = -(x - alpha_bar * x_T) / (alpha**2 * rho_bar**2)

    score_weight = 1 if stochastic else 0.5
    drift = drift_coef * x - diffusion_sq * (score_weight * score_q - score_h)
    return drift, diffusion_sq, x0_pred


def ddbm_simulate(denoiser, noise_schedule, x, x_T, t_cur, t_next, stochastic, second_order=False):
    """Advance ``x`` from ``t_cur`` to ``t_next`` with one Euler or Heun step.

    The step size is ``t_next - t_cur``, multiplied by
    ``noise_schedule.n_timestep - 1`` for an ``I2SBNoiseSchedule``. When
    ``second_order`` is set, the drift is re-evaluated at the Euler proposal
    and the two drifts are averaged before stepping again from ``x``.

    Returns:
        Tuple ``(x_new, pred_x0)`` where ``pred_x0`` is the denoiser output of
        the last drift evaluation that was performed.
    """
    step = t_next - t_cur
    if isinstance(noise_schedule, I2SBNoiseSchedule):
        step = step * (noise_schedule.n_timestep - 1)

    noise_gate = 1 if stochastic else 0

    slope, g2_t, pred_x0 = get_d(denoiser, noise_schedule, x, x_T, t_cur, stochastic)

    def _advance(drift):
        # randn_like is drawn even when noise_gate is 0, so the global RNG
        # stream is consumed exactly as before.
        diffusion_term = noise_gate * torch.randn_like(x) * (step.abs() ** 0.5) * g2_t.sqrt()
        return x + drift * step + diffusion_term

    x_new = _advance(slope)
    if not second_order:
        return x_new, pred_x0

    slope_next, _, pred_x0 = get_d(denoiser, noise_schedule, x_new, x_T, t_next, stochastic)
    x_new = _advance((slope + slope_next) / 2)
    return x_new, pred_x0


@torch.no_grad()
def sample_heun(
    denoiser,
    diffusion,
    x,
    ts,
    churn_step_ratio=0.0,
    **kwargs,
):
    """Run the Heun sampler with optional churn and late-stage guidance.

    Configuration is read from the environment variables ``corrupt_type``,
    ``corrupt_scale``, ``guidance_scale``, ``ag_start_ratio`` and, whenever
    guidance is actually applied, ``cutoff``. Guidance is applied only while
    ``t / diffusion.t_max <= ag_start_ratio`` and ``guidance_scale != 1.0``;
    it merges the denoiser outputs for ``x`` and ``corrupt_fn(x)`` through
    ``frequency_cfge`` with the same scale for both frequency bands.

    Args:
        denoiser: callable ``denoiser(x, t_batch)`` returning an x0 prediction.
        diffusion: object providing ``t_max`` and ``noise_schedule``.
        x: starting state; it is also kept as the bridge endpoint ``x_T``.
        ts: tensor of time steps; ``len(ts) - 1`` updates are performed.
        churn_step_ratio: when positive, each update starts with a stochastic
            Euler step covering this fraction of the interval.
        **kwargs: accepted and ignored.

    Returns:
        Tuple ``(x, path, nfe, pred_x0, ts, None)`` where ``path`` and
        ``pred_x0`` are lists of CPU snapshots, one per update.
    """
    # dist.get_rank() fails when no process group exists; a non-distributed
    # run is treated as the main process.
    is_main_process = not dist.is_initialized() or dist.get_rank() == 0

    # Debug aid: print the full time schedule so its spacing can be inspected
    # (e.g. to check that it is a Karras schedule).
    if is_main_process:
        print(f"\n[Debug] Time Schedule (ts) | Total Steps: {len(ts)-1}")
        # Printed as a plain list so the values are easy to copy.
        print(ts.cpu().tolist())
        print("="*60)

    x_T = x
    path = []
    pred_x0 = []
    corrupt_type = os.environ.get("corrupt_type", "none")
    corrupt_scale = get_env_float("corrupt_scale")
    guidance_scale = get_env_float("guidance_scale")

    # Names mirror sample_dbim; every float setting goes through get_env_float.
    ag_start_ratio = get_env_float("ag_start_ratio")
    t_max_val = diffusion.t_max

    def corrupt_fn(img):
        return corrupt_image(img, corrupt_type, corrupt_scale)

    # Defined once: nothing captured here changes between loop iterations.
    def guided_denoiser(x_t, t):
        current_t = t[0].item() if isinstance(t, torch.Tensor) else t
        is_ag_active = (current_t / t_max_val) <= ag_start_ratio
        x0_good = denoiser(x_t, t)
        if not is_ag_active or guidance_scale == 1.0:
            return x0_good
        x0_bad = denoiser(corrupt_fn(x_t), t)
        return frequency_cfge(
            x0_prior=x0_bad,
            x0_like=x0_good,
            guidance_scale=guidance_scale,
            guidance_scale_hf=guidance_scale,
            # Looked up lazily so `cutoff` is only required when guidance runs.
            cutoff=get_env_float("cutoff"),
        )

    indices = tqdm(range(len(ts) - 1), disable=not is_main_process)
    nfe = 0
    for i in indices:
        # 1) Optional churn: a stochastic Euler step part-way into the interval.
        if churn_step_ratio > 0:
            t_hat = (ts[i + 1] - ts[i]) * churn_step_ratio + ts[i]

            x, _pred_x0 = ddbm_simulate(
                guided_denoiser,
                diffusion.noise_schedule,
                x,
                x_T,
                ts[i],
                t_hat,
                stochastic=True,
            )
            nfe += 1
        else:
            t_hat = ts[i]

        # 2) Deterministic step to ts[i + 1]: plain Euler when the target time
        #    is 0, second-order Heun otherwise.
        use_heun = not (ts[i + 1] == 0)
        x, _pred_x0 = ddbm_simulate(
            guided_denoiser,
            diffusion.noise_schedule,
            x,
            x_T,
            t_hat,
            ts[i + 1],
            stochastic=False,
            second_order=use_heun,
        )
        nfe += 2 if use_heun else 1

        path.append(x.detach().cpu())
        pred_x0.append(_pred_x0.detach().cpu())

    return x, path, nfe, pred_x0, ts, None