"""
Based on: https://github.com/crowsonkb/k-diffusion
"""

import os
import numpy as np
import torch
from piq import LPIPS
from tqdm.auto import tqdm
import torch.distributed as dist
import torchvision
import math
from types import SimpleNamespace
from corruption import build_corruption
import torchvision
from corruption.blur import build_blur

from .nn import mean_flat, append_dims, append_zero
from .random_util import BatchedSeedGenerator

def get_env_float(name, default=None):
    """Read environment variable ``name`` and parse it as a float.

    Args:
        name: Environment variable to read.
        default: Value returned (converted with ``float``) when the variable is
            unset. ``None`` (the default) means the variable is required.

    Returns:
        The parsed float.

    Raises:
        ValueError: If the variable is unset and no default was given, or if
            its value cannot be parsed as a float.
    """
    raw_value = os.environ.get(name)
    if raw_value is None:
        if default is None:
            # float(None) would raise a TypeError that does not name the variable.
            raise ValueError(f"Environment variable {name} is not set")
        return float(default)
    try:
        return float(raw_value)
    except ValueError as err:
        raise ValueError(f"Invalid value for {name}: {raw_value}") from err

def corrupt_image(image, corrupt_type=None, corrupt_scale=0.0):
    """Apply a synthetic degradation to a batch of images.

    Supported ``corrupt_type`` values handled directly here:
      * ``None``       -- return ``image`` unchanged.
      * ``"void"``     -- element-wise cube, ``image ** 3``.
      * ``"noise"``    -- add Gaussian noise with per-sample std
        ``exp(N(-3, 0.5)) * corrupt_scale`` (the ``[:, None, None, None]``
        indexing assumes a 4-D batch).
      * ``"blur"``     -- Gaussian blur, kernel 3, sigma ``corrupt_scale``.
      * ``"blur_<k>"`` -- Gaussian blur, kernel ``k``, torchvision's default sigma.
      * ``"blur<k>"``  -- Gaussian blur, kernel ``k``, sigma ``corrupt_scale``.
      * ``"avg<k>"``   -- ``k x k`` average pooling, stride 1, padding ``(k - 1) // 2``.
    Any other string is forwarded to ``build_corruption`` with ``image`` cast
    to float32.

    Args:
        image: Image batch tensor.
        corrupt_type: Corruption name (see above) or ``None``.
        corrupt_scale: Strength parameter used by ``noise`` / ``blur`` / ``blur<k>``.

    Returns:
        The corrupted tensor.
    """
    if corrupt_type is None:
        # Previously None fell through to ``None.startswith`` and raised AttributeError.
        return image

    if corrupt_type == "void":
        # An all-zero image was the commented-out alternative; the cube is used instead.
        return image ** 3

    if corrupt_type == "noise":
        log_sigma = torch.normal(
            mean=-3.0, std=0.5, size=(image.shape[0],), device=image.device)
        sigma = torch.exp(log_sigma).to(dtype=image.dtype)
        return image + torch.randn_like(image) * sigma[:, None, None, None] * corrupt_scale

    if corrupt_type == "blur":
        return torchvision.transforms.functional.gaussian_blur(
            image, kernel_size=3, sigma=corrupt_scale)

    # "blur_<k>" must be tested before the generic "blur<k>" prefix: otherwise
    # the "blur" branch captures it and int("_<k>") raises ValueError.
    if corrupt_type.startswith("blur_"):
        kernel = int(corrupt_type[len("blur_"):])
        return torchvision.transforms.functional.gaussian_blur(image, kernel_size=kernel)

    if corrupt_type.startswith("blur"):
        kernel = int(corrupt_type[len("blur"):])
        return torchvision.transforms.functional.gaussian_blur(
            image, kernel_size=kernel, sigma=corrupt_scale)

    if corrupt_type.startswith("avg"):
        kernel = int(corrupt_type[len("avg"):])
        pad = (kernel - 1) // 2
        return torch.nn.functional.avg_pool2d(image, kernel_size=kernel, stride=1, padding=pad)

    # Everything else is delegated to the project's corruption operators.
    opt = SimpleNamespace(device=image.device, image_size=image.shape[-1])
    return build_corruption(opt, None, corrupt_type)(image.float())

class NoiseSchedule:
    """Interface shared by the bridge noise schedules in this module.

    Concrete schedules implement ``get_f_g2`` and ``get_alpha_rho`` and set a
    ``rho_T`` attribute; ``get_abc`` is derived from those.
    """

    def __init__(self):
        # Abstract: concrete schedules define their own constructor.
        raise NotImplementedError

    def get_f_g2(self, t):
        """Return the pair ``(f, g2)`` at time ``t`` (subclass responsibility)."""
        raise NotImplementedError

    def get_alpha_rho(self, t):
        """Return ``(alpha_t, alpha_bar_t, rho_t, rho_bar_t)`` (subclass responsibility)."""
        raise NotImplementedError

    def get_abc(self, t):
        """Return the bridge mixing coefficients ``(a_t, b_t, c_t)`` at time ``t``.

            a_t = alpha_bar_t * rho_t**2 / rho_T**2
            b_t = alpha_t * rho_bar_t**2 / rho_T**2
            c_t = alpha_t * rho_bar_t * rho_t / rho_T
        """
        alpha_t, alpha_bar_t, rho_t, rho_bar_t = self.get_alpha_rho(t)
        rho_T_sq = self.rho_T**2
        coef_xT = (alpha_bar_t * rho_t**2) / rho_T_sq
        coef_x0 = (alpha_t * rho_bar_t**2) / rho_T_sq
        coef_noise = (alpha_t * rho_bar_t * rho_t) / self.rho_T
        return coef_xT, coef_x0, coef_noise


class VPNoiseSchedule(NoiseSchedule):
    """Variance-preserving schedule with linear beta(t) = beta_min + beta_d * t.

        alpha_t = exp(-0.5 * beta_min * t - 0.25 * beta_d * t**2)
        rho_t   = sqrt(exp(beta_min * t + 0.5 * beta_d * t**2) - 1)
        f(t)    = -0.5 * beta_min - 0.5 * beta_d * t
        g2(t)   = beta_min + beta_d * t
    """

    def __init__(self, beta_d=2, beta_min=0.1):
        self.beta_d = beta_d
        self.beta_min = beta_min

        def alpha_fn(t):
            return np.e ** (-0.5 * beta_min * t - 0.25 * beta_d * t**2)

        def rho_fn(t):
            # ``.sqrt()`` requires a tensor argument.
            return (np.e ** (beta_min * t + 0.5 * beta_d * t**2) - 1).sqrt()

        def f_fn(t):
            return -0.5 * beta_min - 0.5 * beta_d * t

        def g2_fn(t):
            return beta_min + beta_d * t

        self.alpha_fn = alpha_fn
        self.rho_fn = rho_fn
        self.f_fn = f_fn
        self.g2_fn = g2_fn
        # Endpoint values at t = 1.
        self.alpha_T = alpha_fn(1)
        self.rho_T = rho_fn(torch.DoubleTensor([1])).item()

    def get_f_g2(self, t):
        """Return ``(f, g2)`` evaluated in float64."""
        t64 = t.to(torch.float64)
        return self.f_fn(t64), self.g2_fn(t64)

    def get_alpha_rho(self, t):
        """Return ``(alpha_t, alpha_t / alpha_T, rho_t, sqrt(rho_T**2 - rho_t**2))`` in float64."""
        t64 = t.to(torch.float64)
        alpha_t = self.alpha_fn(t64)
        rho_t = self.rho_fn(t64)
        rho_bar_t = (self.rho_T**2 - rho_t**2).sqrt()
        return alpha_t, alpha_t / self.alpha_T, rho_t, rho_bar_t


class VENoiseSchedule(NoiseSchedule):
    """Variance-exploding schedule: alpha_t = 1, rho_t = t, f = 0, g2(t) = 2t.

    ``rho_T`` equals ``sigma_max``.
    """

    def __init__(self, sigma_max=80.0):
        self.sigma_max = sigma_max
        self.rho_T = sigma_max
        self.alpha_T = 1

        self.alpha_fn = lambda t: torch.ones_like(t)
        self.rho_fn = lambda t: t
        self.f_fn = lambda t: torch.zeros_like(t)
        self.g2_fn = lambda t: 2 * t

    def get_f_g2(self, t):
        """Return ``(f, g2)`` evaluated in float64."""
        t64 = t.to(torch.float64)
        return self.f_fn(t64), self.g2_fn(t64)

    def get_alpha_rho(self, t):
        """Return ``(alpha_t, alpha_t / alpha_T, rho_t, sqrt(rho_T**2 - rho_t**2))`` in float64."""
        t64 = t.to(torch.float64)
        rho_t = self.rho_fn(t64)
        alpha_t = self.alpha_fn(t64)
        rho_bar_t = (self.rho_T**2 - rho_t**2).sqrt()
        return alpha_t, alpha_t / self.alpha_T, rho_t, rho_bar_t


class I2SBNoiseSchedule(NoiseSchedule):
    """Discrete I2SB schedule over ``n_timestep`` steps with a mirrored beta profile.

    A continuous ``t`` in [0, 1] is mapped to step index
    ``round((n_timestep - 1) * t)``. ``alpha_t`` is 1 and ``f`` is 0; ``rho_t``
    and ``rho_bar_t`` are square roots of the forward / backward cumulative sums
    of ``betas``, and ``g2`` is the beta at that step.

    Args:
        n_timestep: Number of discrete steps (table length).
        beta_min, beta_max: Endpoints of the sqrt-linear beta ramp before
            division by ``n_timestep``.
        device: Device for the lookup tables. Defaults to ``"cuda"``, matching
            the previous hard-coded ``.cuda()``.
    """

    def __init__(self, n_timestep=1000, beta_min=0.1, beta_max=1.0, device="cuda"):
        self.n_timestep, self.linear_start, self.linear_end = (
            n_timestep,
            beta_min / n_timestep,
            beta_max / n_timestep,
        )
        betas = (
            torch.linspace(
                self.linear_start**0.5,
                self.linear_end**0.5,
                n_timestep,
                dtype=torch.float64,
            ).to(device)
            ** 2
        )
        # Mirror the ramp so betas rise and then fall. Use ceil(n/2) entries for
        # the first half and floor(n/2) mirrored for the second so the table has
        # exactly n_timestep entries. With n // 2 on both sides an odd n_timestep
        # gave n - 1 entries and index n - 1 (t = 1) was out of range.
        betas = torch.cat(
            [
                betas[: (n_timestep + 1) // 2],
                torch.flip(betas[: n_timestep // 2], dims=(0,)),
            ]
        )
        std_fwd = torch.sqrt(torch.cumsum(betas, dim=0))
        std_bwd = torch.sqrt(torch.flip(torch.cumsum(torch.flip(betas, dims=(0,)), dim=0), dims=(0,)))

        self.alpha_fn = lambda t: torch.ones_like(t).float()
        self.alpha_T = 1
        self.rho_fn = lambda t: std_fwd[t]
        self.rho_T = std_fwd[-1]
        self.rho_bar_fn = lambda t: std_bwd[t]

        self.f_fn = lambda t: torch.zeros_like(t).float()
        self.g2_fn = lambda t: betas[t]

    def get_f_g2(self, t):
        """Return ``(f, g2)`` at the discrete step nearest to ``t``."""
        t = ((self.n_timestep - 1) * t).round().long()
        f, g2 = self.f_fn(t), self.g2_fn(t)
        return f, g2

    def get_alpha_rho(self, t):
        """Return ``(alpha_t, alpha_bar_t, rho_t, rho_bar_t)`` at the discrete step nearest to ``t``."""
        t = ((self.n_timestep - 1) * t).round().long()
        alpha_t = self.alpha_fn(t)
        alpha_bar_t = alpha_t / self.alpha_T
        rho_t = self.rho_fn(t)
        rho_bar_t = self.rho_bar_fn(t)
        return alpha_t, alpha_bar_t, rho_t, rho_bar_t


class PreCond:
    """Base class for preconditioners that yield ``(c_skip, c_in, c_out, c_noise, weightings)``."""

    def __init__(self, ns):
        # Abstract: concrete preconditioners define their own constructor.
        raise NotImplementedError

    def _get_scalings_and_weightings(self, t):
        """Return per-sample ``(c_skip, c_in, c_out, c_noise, weightings)`` (subclass responsibility)."""
        raise NotImplementedError

    def get_scalings_and_weightings(self, t, ndim):
        """Return the scalings with every entry except ``c_noise`` passed through
        ``append_dims(..., ndim)`` so they broadcast against an ``ndim``-D input."""
        c_skip, c_in, c_out, c_noise, weightings = self._get_scalings_and_weightings(t)
        # c_noise is handed to the network unchanged, so it is not expanded.
        c_skip = append_dims(c_skip, ndim)
        c_in = append_dims(c_in, ndim)
        c_out = append_dims(c_out, ndim)
        weightings = append_dims(weightings, ndim)
        return c_skip, c_in, c_out, c_noise, weightings


class I2SBPreCond(PreCond):
    """Preconditioning for the discrete I2SB schedule.

    Uses identity skip/input scalings, ``c_out = -rho_t`` and loss weighting
    ``1 / c_out**2``. ``c_noise`` is looked up from
    ``linspace(t0, T, n_timestep) * n_timestep`` at step
    ``round((n_timestep - 1) * t)``.

    Args:
        ns: Noise schedule providing ``get_alpha_rho``.
        n_timestep: Number of discrete steps.
        t0, T: Range of the noise-level table before scaling by ``n_timestep``.
        device: Device for the noise-level table. Defaults to ``"cuda"``,
            matching the previous hard-coded ``.cuda()``.
    """

    def __init__(self, ns, n_timestep=1000, t0=1e-4, T=1.0, device="cuda"):
        self.ns = ns
        self.n_timestep = n_timestep
        # Build on CPU and move, exactly as before, so table values are unchanged.
        self.noise_levels = torch.linspace(t0, T, n_timestep).to(device) * n_timestep

    def _get_scalings_and_weightings(self, t):
        _, _, rho_t, _ = self.ns.get_alpha_rho(t)
        c_skip = torch.ones_like(t)
        c_in = torch.ones_like(t)
        c_out = -rho_t
        c_noise = self.noise_levels[((self.n_timestep - 1) * t).round().long()]
        weightings = 1 / c_out**2
        return c_skip, c_in, c_out, c_noise, weightings


class DDBMPreCond(PreCond):
    """DDBM-style preconditioning computed from the bridge coefficients (a_t, b_t, c_t).

    ``sigma_data`` is used for both endpoints (``sigma_data_end`` is set equal to
    it); ``cov_xy`` enters as the cross term between them.
    """

    def __init__(self, ns, sigma_data, cov_xy):
        self.ns = ns
        self.sigma_data = sigma_data
        self.sigma_data_end = sigma_data
        self.cov_xy = cov_xy

    def _get_scalings_and_weightings(self, t):
        a_t, b_t, c_t = self.ns.get_abc(t)
        sd = self.sigma_data
        sd_end = self.sigma_data_end
        cov = self.cov_xy

        # A equals Var(a_t * x_T + b_t * x_0 + c_t * noise) if sigma_data /
        # cov_xy are the endpoint std / covariance (assumption, not checked here).
        A = a_t**2 * sd_end**2 + b_t**2 * sd**2 + 2 * a_t * b_t * cov + c_t**2
        c_in = 1 / A**0.5
        c_skip = (b_t * sd**2 + a_t * cov) / A
        c_out = (a_t**2 * (sd_end**2 * sd**2 - cov**2) + sd**2 * c_t**2) ** 0.5 * c_in
        # The tiny offset keeps log() finite at t = 0.
        c_noise = 1000 * 0.25 * torch.log(t + 1e-44)
        weightings = 1 / c_out**2
        return c_skip, c_in, c_out, c_noise, weightings


class KarrasDenoiser:
    """Bridge denoiser wrapper: samples x_t on the bridge, applies the
    preconditioner around the network, and computes the training loss.

    Args:
        noise_schedule: Object exposing ``get_abc(t)``.
        precond: Object exposing ``get_scalings_and_weightings(t, ndim)``.
        t_max: Upper time bound; training clamps ``t`` to it and the samplers
            read it.
        t_min: Lower time bound read by the samplers.
        loss_norm: ``"lpips"`` additionally builds ``self.lpips_loss``.
    """

    def __init__(
        self,
        noise_schedule,
        precond,
        t_max=1.0,
        t_min=0.0001,
        loss_norm="lpips",
    ):

        self.t_max = t_max
        self.t_min = t_min

        self.noise_schedule = noise_schedule
        self.precond = precond

        self.loss_norm = loss_norm
        if loss_norm == "lpips":
            # Built here; the loss in training_bridge_losses does not use it.
            self.lpips_loss = LPIPS(replace_pooling=True, reduction="none")

    def bridge_sample(self, x0, xT, t, noise):
        """Return ``a_t * xT + b_t * x0 + c_t * noise`` with coefficients from ``get_abc(t)``."""
        a_t, b_t, c_t = [append_dims(item, x0.ndim) for item in self.noise_schedule.get_abc(t)]
        samples = a_t * xT + b_t * x0 + c_t * noise
        return samples

    def denoise(self, model, x_t, t, **model_kwargs):
        """Run the preconditioned network.

        Returns:
            ``(model_output, denoised, weightings)`` where
            ``denoised = c_out * model(c_in * x_t, c_noise) + c_skip * x_t``.
        """
        c_skip, c_in, c_out, c_noise, weightings = self.precond.get_scalings_and_weightings(t, x_t.ndim)
        model_output = model(c_in * x_t, c_noise, **model_kwargs)
        denoised = c_out * model_output + c_skip * x_t
        return model_output, denoised, weightings

    def training_bridge_losses(self, model, x_start, t, model_kwargs=None, noise=None):
        """Compute the denoising loss on a bridge sample between ``x_start`` and ``xT``.

        Args:
            model: Network passed to ``denoise``.
            x_start: Clean target x_0.
            t: Per-sample times; values above ``t_max`` are clamped to ``t_max``.
            model_kwargs: Must contain ``"xT"``. An optional ``"mask"`` entry
                restricts the error and is not forwarded to ``model``; every
                other entry (including ``xT``) is forwarded. The caller's dict
                is not modified.
            noise: Optional Gaussian noise; drawn with ``randn_like`` if omitted.

        Returns:
            dict with ``"xs_mse"`` (unweighted), ``"mse"`` (weighted by the
            preconditioner) and ``"loss"`` (same as ``"mse"``), each reduced with
            ``mean_flat``.

        Raises:
            ValueError: if ``model_kwargs`` is None.
        """
        if model_kwargs is None:
            raise ValueError("training_bridge_losses requires model_kwargs with an 'xT' entry")
        # Shallow copy: popping "mask" from the caller's dict would silently drop
        # the mask for any later call that reuses the same dict.
        model_kwargs = dict(model_kwargs)
        xT = model_kwargs["xT"]
        mask = model_kwargs.pop("mask", None)
        if noise is None:
            noise = torch.randn_like(x_start)
        t = torch.minimum(t, torch.ones_like(t) * self.t_max)

        x_t = self.bridge_sample(x_start, xT, t, noise)
        _, denoised, weights = self.denoise(model, x_t, t, **model_kwargs)

        sq_err = (denoised - x_start) ** 2
        terms = {}
        if mask is not None:
            terms["xs_mse"] = mean_flat(mask * sq_err)
            terms["mse"] = mean_flat(weights * mask * sq_err)
        else:
            terms["xs_mse"] = mean_flat(sq_err)
            terms["mse"] = mean_flat(weights * sq_err)

        terms["loss"] = terms["mse"]
        return terms


def karras_sample(
    diffusion,
    model,
    x_T,
    x_0,
    steps,
    mask=None,
    clip_denoised=True,
    model_kwargs=None,
    device=None,
    rho=7.0,
    sampler="heun",
    churn_step_ratio=0.0,
    eta=0.0,
    order=2,
    seed=None,
):
    """Run one of the bridge samplers starting from ``x_T``.

    Args:
        diffusion: ``KarrasDenoiser``-like object (uses ``denoise``, ``t_min``, ``t_max``).
        model: Network passed to ``diffusion.denoise``.
        x_T: Starting bridge endpoint.
        x_0: Reference image forwarded to the sampler as ``x_0``.
        steps: Number of time intervals in the schedule.
        mask: Optional mask forwarded to the sampler.
        clip_denoised: Clamp each denoised prediction to [-1, 1].
        model_kwargs: Extra keyword arguments for the network (default: none).
        device: Device for the time grid.
        rho: Karras exponent (``"heun"`` only).
        sampler: One of ``heun``, ``ground_truth``, ``dbim``, ``dbimt``, ``dbim_high_order``.
        churn_step_ratio, eta, order, seed: Forwarded to the sampler.

    Returns:
        ``(x_0, path, nfe, pred_x0, sigmas, noise)``; ``x_0``, ``path`` and
        ``pred_x0`` are clamped to [-1, 1].

    Raises:
        ValueError: if ``sampler`` is not supported.
    """
    sample_fns = {
        "heun": sample_heun,
        "ground_truth": sample_ground_truth,
        "dbim": sample_dbim,
        "dbimt": sample_dbimt,  # corrupts x_t via the linear decomposition at t instead of t-1
        "dbim_high_order": sample_dbim_high_order,
    }
    if sampler not in sample_fns:
        raise ValueError(
            f"only these sampler is supported currently: {sorted(sample_fns)} (got {sampler!r})"
        )
    if model_kwargs is None:
        # ``**None`` would raise TypeError inside the denoiser.
        model_kwargs = {}

    if sampler == "heun":
        ts = get_sigmas_karras(steps, diffusion.t_min, diffusion.t_max - 1e-4, rho, device=device)
    else:
        ts = get_sigmas_uniform(steps, diffusion.t_min, diffusion.t_max - 1e-3, device=device)

    sampler_args = dict(churn_step_ratio=churn_step_ratio, mask=mask, eta=eta, x_0=x_0, order=order, seed=seed)

    def denoiser(x_t, sigma):
        _, denoised, _ = diffusion.denoise(model, x_t, sigma, **model_kwargs)
        if clip_denoised:
            denoised = denoised.clamp(-1, 1)
        return denoised

    x_0, path, nfe, pred_x0, sigmas, noise = sample_fns[sampler](
        denoiser,
        diffusion,
        x_T,
        ts,
        **sampler_args,
    )
    # Also works in single-process runs where no process group is initialized.
    is_main_process = (not (dist.is_available() and dist.is_initialized())) or dist.get_rank() == 0
    if is_main_process:
        print("nfe:", nfe)
    return (
        x_0.clamp(-1, 1),
        [x.clamp(-1, 1) for x in path],
        nfe,
        [x.clamp(-1, 1) for x in pred_x0],
        sigmas,
        noise,
    )


def get_sigmas_karras(n, sigma_min, sigma_max, rho=7.0, device="cpu"):
    """Build the Karras et al. (2022) noise schedule.

    Returns ``n`` levels interpolated linearly in ``sigma**(1/rho)`` space from
    ``sigma_max`` down to ``sigma_min``, followed by a trailing zero
    (via ``append_zero``), on ``device``.
    """
    inv_rho = 1 / rho
    hi = sigma_max ** inv_rho
    lo = sigma_min ** inv_rho
    frac = torch.linspace(0, 1, n)
    levels = (hi + frac * (lo - hi)) ** rho
    return append_zero(levels).to(device)


def get_sigmas_uniform(n, t_min, t_max, device="cpu"):
    """Return ``n + 1`` evenly spaced times running from ``t_max`` down to ``t_min`` on ``device``."""
    num_points = n + 1
    grid = torch.linspace(t_max, t_min, num_points)
    return grid.to(device)


@torch.no_grad()
def sample_dbim_high_order(
    denoiser,
    diffusion,
    x,
    ts,
    mask=None,
    order=2,
    lower_order_final=True,
    seed=None,
    **kwargs,
):
    """Deterministic DBIM sampler with 2nd- or 3rd-order multistep updates.

    Each step s -> t re-denoises x at s. The clean-image prediction is then
    extrapolated in lambda = log(b / c), with (a, b, c) from ``get_abc``, using
    the predictions ``xu_hat`` cached at earlier times ``u``.

    Step selection:
      * The first step (i == 0) uses the first-order update.
      * The last step also uses it when ``lower_order_final`` is true.
      * With ``order == 3``, step i == 1 uses the second-order update because
        only one earlier prediction is available.

    No noise is injected after the initial bridge sample.

    Args:
        denoiser: Callable ``denoiser(x, t_batch) -> x0_hat``.
        diffusion: Object providing ``t_max``, ``bridge_sample`` and ``noise_schedule``.
        x: Starting point, also used as the endpoint ``x_T``.
        ts: Decreasing time grid; ``len(ts) - 1`` update steps are taken.
        mask: Optional mask; outside it every prediction is reset to ``x_T``.
        order: 2 or 3.
        lower_order_final: Use the first-order update on the final step.
        seed: Seed for ``BatchedSeedGenerator`` (initial noise only).

    Returns:
        ``(x, path, nfe, pred_x0, ts, first_noise)``. ``path`` and ``pred_x0``
        are CPU copies of each state / prediction.

    Raises:
        NotImplementedError: if ``order`` is not 2 or 3.
    """
    if order not in [2, 3]:
        raise NotImplementedError("Not supported")
    x_T = x
    path = []
    pred_x0 = []

    ones = x.new_ones([x.shape[0]])
    indices = range(len(ts) - 1)
    indices = tqdm(indices, disable=(dist.get_rank() != 0))

    nfe = 0
    # Initial prediction at t_max, then jump onto the bridge at ts[0].
    x0_hat = denoiser(x, diffusion.t_max * ones)
    generator = BatchedSeedGenerator(seed)
    noise = generator.randn_like(x0_hat)
    first_noise = noise
    if mask is not None:
        x0_hat = x0_hat * mask + x_T * (1 - mask)
    x = diffusion.bridge_sample(x0_hat, x_T, ts[0] * ones, noise)
    path.append(x.detach().cpu())
    pred_x0.append(x0_hat.detach().cpu())
    nfe += 1

    # History of (time, prediction) pairs for the multistep extrapolation.
    # The time is nudged below 1.0, presumably because b_u and c_u both vanish
    # there (rho_bar = 0), making log(b_u / c_u) undefined -- TODO confirm.
    u = diffusion.t_max
    if u == 1.0:
        u -= 5e-5
    u = [u for _ in range(order - 1)]
    xu_hat = [x0_hat.detach().clone() for _ in range(order - 1)]

    for _, i in enumerate(indices):
        s = ts[i]
        t = ts[i + 1]

        # First Order Update, t < s
        if (lower_order_final and i + 1 == len(ts) - 1) or (i == 0):
            if dist.get_rank() == 0:
                print("Step order 1")
            a_s, b_s, c_s = [append_dims(item, x0_hat.ndim) for item in diffusion.noise_schedule.get_abc(s * ones)]
            a_t, b_t, c_t = [append_dims(item, x0_hat.ndim) for item in diffusion.noise_schedule.get_abc(t * ones)]

            # x_t = (c_t/c_s) x_s + (b_t - (c_t/c_s) b_s) x0_hat + (a_t - (c_t/c_s) a_s) x_T
            tmp_var = c_t / c_s
            coeff_xs = tmp_var
            coeff_x0_hat = b_t - tmp_var * b_s
            coeff_xT = a_t - tmp_var * a_s

            x0_hat = denoiser(x, s * ones)
            if mask is not None:
                x0_hat = x0_hat * mask + x_T * (1 - mask)
            nfe += 1
            x_old = x
            x = coeff_xs * x_old + coeff_x0_hat * x0_hat + coeff_xT * x_T

        # Second Order Update, t < s < u
        elif order == 2 or i == 1:
            if dist.get_rank() == 0:
                print("Step order 2")
            a_u, b_u, c_u = [append_dims(item, x0_hat.ndim) for item in diffusion.noise_schedule.get_abc(u[-1] * ones)]
            a_s, b_s, c_s = [append_dims(item, x0_hat.ndim) for item in diffusion.noise_schedule.get_abc(s * ones)]
            a_t, b_t, c_t = [append_dims(item, x0_hat.ndim) for item in diffusion.noise_schedule.get_abc(t * ones)]
            lambda_u, lambda_s, lambda_t = (
                torch.log(b_u / c_u),
                torch.log(b_s / c_s),
                torch.log(b_t / c_t),
            )

            x0_hat = denoiser(x, s * ones)
            if mask is not None:
                x0_hat = x0_hat * mask + x_T * (1 - mask)
            nfe += 1
            # h: step size in lambda; h2: previous step size, used for the
            # first-order finite difference (x0_hat - x0_hat(u)) / h2.
            h = lambda_t - lambda_s
            h2 = lambda_s - lambda_u
            integral = torch.exp(lambda_t) * (
                (1 - torch.exp(-h)) * x0_hat + (torch.exp(-h) + h - 1) * (x0_hat - xu_hat[-1]) / h2
            )
            x_old = x
            x = x_old * (c_t / c_s) + x_T * (a_t - a_s * (c_t / c_s)) + c_t * integral

        elif order == 3:
            if dist.get_rank() == 0:
                print("Step order 3")
            a_u1, b_u1, c_u1 = [
                append_dims(item, x0_hat.ndim) for item in diffusion.noise_schedule.get_abc(u[-1] * ones)
            ]
            a_u2, b_u2, c_u2 = [
                append_dims(item, x0_hat.ndim) for item in diffusion.noise_schedule.get_abc(u[-2] * ones)
            ]
            a_s, b_s, c_s = [append_dims(item, x0_hat.ndim) for item in diffusion.noise_schedule.get_abc(s * ones)]
            a_t, b_t, c_t = [append_dims(item, x0_hat.ndim) for item in diffusion.noise_schedule.get_abc(t * ones)]
            lambda_u2, lambda_u1, lambda_s, lambda_t = (
                torch.log(b_u2 / c_u2),
                torch.log(b_u1 / c_u1),
                torch.log(b_s / c_s),
                torch.log(b_t / c_t),
            )
            x0_hat = denoiser(x, s * ones)
            if mask is not None:
                x0_hat = x0_hat * mask + x_T * (1 - mask)
            nfe += 1

            # Non-uniform finite differences over the three most recent
            # predictions give first (dx0_hat) and second (d2x0_hat) derivative
            # estimates of x0_hat with respect to lambda.
            h = lambda_t - lambda_s
            h1 = lambda_s - lambda_u1
            h2 = lambda_u1 - lambda_u2
            dx0_hat = ((x0_hat - xu_hat[-1]) * (2 * h1 + h2) / h1 - (xu_hat[-1] - xu_hat[-2]) * h1 / h2) / (h1 + h2)
            d2x0_hat = 2 * ((x0_hat - xu_hat[-1]) / h1 - (xu_hat[-1] - xu_hat[-2]) / h2) / (h1 + h2)
            integral = torch.exp(lambda_t) * (
                (1 - torch.exp(-h)) * x0_hat
                + (torch.exp(-h) + h - 1) * dx0_hat
                + (h**2 / 2 - h + 1 - torch.exp(-h)) * d2x0_hat
            )
            x_old = x
            x = x_old * (c_t / c_s) + x_T * (a_t - a_s * (c_t / c_s)) + c_t * integral

        # Slide the history window: drop the oldest (u, x0_hat) and record this step's.
        u.append(s)
        u.pop(0)
        xu_hat.append(x0_hat)
        xu_hat.pop(0)

        path.append(x.detach().cpu())
        pred_x0.append(x0_hat.detach().cpu())

    return x, path, nfe, pred_x0, ts, first_noise


import os
import math
import torch
import torch.distributed as dist
from tqdm import tqdm

from .random_util import BatchedSeedGenerator
from .karras_diffusion import corrupt_image, get_env_float



import torch

def frequency_cfg(
    x0_prior: torch.Tensor,
    x0_like: torch.Tensor,
    step_idx: int,
    num_steps: int,
    mask: torch.Tensor = None,
    base_scale: float = 1.0,  # scale applied to the low-frequency band
    hf_extra: float = 1.0,  # scale applied to the high-frequency band
    cutoff: float = 0.7,  # centre of the low/high split, normalized radius in [0, sqrt(2)]
    transition_width: float = 0.15,  # width of the smooth transition band around cutoff
    x_T: torch.Tensor = None,  # known/observed pixels used outside the mask
):
    """Frequency-split guidance for masked (inpainting) predictions.

    Takes ``delta = x0_like - x0_prior`` inside ``mask``. In the centred 2-D
    Fourier domain it scales the low-frequency band of ``delta`` by
    ``base_scale`` and the high-frequency band by ``hf_extra``, with a
    smoothstep transition of width ``transition_width`` around ``cutoff``.
    It returns ``x0_prior + scaled delta`` inside the mask and ``x_T`` outside.
    When ``base_scale == hf_extra == s`` this reduces to
    ``x0_prior + s * (x0_like - x0_prior)`` inside the mask.

    Args:
        x0_prior: "Bad"/prior prediction, shape (B, C, H, W).
        x0_like: "Good"/conditional prediction, same shape.
        step_idx, num_steps: Accepted for API compatibility; currently unused.
        mask: Mask of shape (B, H, W), (B, 1, H, W) or (B, C, H, W); 1 marks
            the region to generate.
        base_scale, hf_extra: Low- / high-frequency guidance scales.
        cutoff: Transition centre as a normalized radius in [0, sqrt(2)].
        transition_width: Width of the smooth band; ``<= 0`` gives a hard split.
        x_T: Known pixels used outside the mask.

    Returns:
        Guided prediction of shape (B, C, H, W).

    Raises:
        ValueError: if ``mask`` or ``x_T`` is None.
    """
    if mask is None:
        raise ValueError("frequency_cfg for inpainting requires mask")
    if x_T is None:
        raise ValueError("frequency_cfg for inpainting requires x_T (known pixels outside mask)")

    B, C, H, W = x0_prior.shape
    device = x0_prior.device
    dtype = x0_prior.dtype

    # Bring the mask to (B, C, H, W) in the prediction dtype.
    if mask.dim() == 3:
        mask = mask.unsqueeze(1)
    if mask.shape[1] != C:
        mask = mask.expand(B, C, H, W)
    mask = mask.to(dtype)
    known = 1 - mask

    # 1) Force both predictions to x_T outside the mask so the difference is
    #    exactly zero there and background content cannot leak into the spectrum.
    x0_prior = x0_prior * mask + x_T * known
    x0_like = x0_like * mask + x_T * known

    # 2) Guidance direction, restricted to the mask.
    delta = (x0_like - x0_prior) * mask

    # 3) Centred 2-D spectrum.
    delta_f = torch.fft.fftshift(torch.fft.fft2(delta, dim=(-2, -1)), dim=(-2, -1))

    # 4) Normalized radial frequency in [0, sqrt(2)]. fftfreq followed by
    #    fftshift puts DC exactly at (H // 2, W // 2) with radius 0. The grid is
    #    symmetric under k -> -k, so the band weights are Hermitian-symmetric and
    #    the inverse FFT is real. A linspace(-1, 1) grid is off-centre by half a
    #    bin for even sizes.
    freq_dtype = delta_f.real.dtype
    fy = torch.fft.fftshift(torch.fft.fftfreq(H, device=device, dtype=freq_dtype)) * 2
    fx = torch.fft.fftshift(torch.fft.fftfreq(W, device=device, dtype=freq_dtype)) * 2
    yy, xx = torch.meshgrid(fy, fx, indexing="ij")
    radius = torch.sqrt(xx**2 + yy**2)

    # 5) High-pass weight: 0 at low frequencies, 1 at high frequencies, with a
    #    smoothstep (3x^2 - 2x^3) ramp across [cutoff - w/2, cutoff + w/2].
    if transition_width > 0:
        r0 = cutoff - 0.5 * transition_width
        r1 = cutoff + 0.5 * transition_width
        ramp = ((radius - r0) / (r1 - r0)).clamp(0.0, 1.0)
        hp_mask = 3 * ramp**2 - 2 * ramp**3
    else:
        # A zero width would divide by zero above; use a hard split instead.
        hp_mask = (radius >= cutoff).to(freq_dtype)
    lp_mask = 1.0 - hp_mask
    band_gain = (lp_mask * base_scale + hp_mask * hf_extra).view(1, 1, H, W)

    # 6) Scale both bands in one pass, then invert.
    delta_new_f = torch.fft.ifftshift(delta_f * band_gain, dim=(-2, -1))
    delta_new = torch.fft.ifft2(delta_new_f, dim=(-2, -1)).real

    x0_hat = x0_prior + delta_new

    # 7) Band scaling is global, so ringing can leak outside the mask; re-impose x_T there.
    return x0_hat * mask + x_T * known
@torch.no_grad()
def sample_dbimt(
    denoiser,
    diffusion,
    x,
    ts,
    eta=1.0,
    mask=None,
    seed=None,
    **kwargs,
):
    """DBIM sampler with CFG-style guidance whose corrupted ("bad") input is
    rebuilt at the current time s from the current state.

    Configuration comes from environment variables:
      * ``guidance_type`` (default ``"none"``): ``none``, ``x0``, ``x1``, ``xt`` or ``mean``.
      * ``corrupt_type`` (default ``"none"``), ``corrupt_scale``, ``guidance_scale``:
        the last two are required only when guidance is enabled.
      * ``hf_scale`` (optional): high-frequency scale for ``frequency_cfg``;
        defaults to ``guidance_scale``, i.e. no frequency split.

    With a ``mask``, guidance uses ``frequency_cfg``; without one it falls back
    to ``x0_bad.lerp(x0_good, guidance_scale)``. Predictions outside the mask
    are always reset to ``x_T`` (the starting input ``x``).

    Returns:
        ``(x, path, nfe, pred_x0, ts, first_noise)``.

    Raises:
        ValueError: for an unknown ``guidance_type``.
    """
    # For inpainting-style bridges the starting input x also serves as the condition x_T.
    x_T = x
    path = []
    pred_x0 = []

    # ---- 1. Configuration from environment variables ----
    guidance_type = os.environ.get("guidance_type", "none")
    if guidance_type not in ("none", "x0", "x1", "xt", "mean"):
        raise ValueError(f"Unknown guidance type: {guidance_type}")
    use_guidance = guidance_type != "none"

    if use_guidance:
        # Read only when guidance is on, so guidance_type=none needs none of these.
        corrupt_type = os.environ.get("corrupt_type", "none")
        corrupt_scale = get_env_float("corrupt_scale")
        guidance_scale = get_env_float("guidance_scale")
        hf_env = os.environ.get("hf_scale")
        hf_scale = float(hf_env) if hf_env is not None else guidance_scale

        def corrupt_fn(image):
            return corrupt_image(image, corrupt_type, corrupt_scale)

        def apply_guidance(x0_bad, x0_good, step_idx, num_steps):
            """Combine bad/good predictions: frequency-split with a mask, plain lerp otherwise."""
            if mask is None:
                return x0_bad.lerp(x0_good, guidance_scale)
            # prior + scale * (like - prior), with prior = bad and like = good.
            return frequency_cfg(
                x0_prior=x0_bad,
                x0_like=x0_good,
                step_idx=step_idx,
                num_steps=num_steps,
                mask=mask,
                base_scale=guidance_scale,  # low-frequency weight (structure)
                hf_extra=hf_scale,  # high-frequency weight (texture/detail)
                x_T=x_T,
            )

    def fill_known(x0):
        """Reset pixels outside the mask to x_T (no-op without a mask)."""
        if mask is None:
            return x0
        return x0 * mask + x_T * (1 - mask)

    ones = x.new_ones([x.shape[0]])
    indices = tqdm(range(len(ts) - 1), disable=(dist.get_rank() != 0))
    nfe = 0

    # ---- 2. Initial prediction at t_max ----
    x0_good = denoiser(x, diffusion.t_max * ones)
    if guidance_type in ("x1", "xt", "mean"):
        # At t_max the bad input is simply the corrupted starting state.
        x0_bad = denoiser(corrupt_fn(x), diffusion.t_max * ones)
        x0_hat = apply_guidance(x0_bad, x0_good, 0, len(ts))
    else:
        # "none" and "x0": the first prediction is unguided.
        x0_hat = x0_good

    generator = BatchedSeedGenerator(seed)
    noise = generator.randn_like(x0_hat)
    first_noise = noise

    # frequency_cfg already resets the background, but the unguided path does not.
    x0_hat = fill_known(x0_hat)

    if len(ts) == 1:
        x = x0_hat
    else:
        x = diffusion.bridge_sample(x0_hat, x_T, ts[0] * ones, noise)

    path.append(x.detach().cpu())
    pred_x0.append(x0_hat.detach().cpu())
    nfe += 1

    # ---- 3. Sampling loop ----
    last_step = len(ts) - 2  # no stochastic noise on the final step
    for i in indices:
        s = ts[i]
        t = ts[i + 1]

        # A. Good prediction, with the background reset (also the base for frequency_cfg).
        x0_good = fill_known(denoiser(x, s * ones))

        a_s, b_s, c_s = [append_dims(item, x0_good.ndim) for item in diffusion.noise_schedule.get_abc(s * ones)]
        a_t, b_t, c_t = [append_dims(item, x0_good.ndim) for item in diffusion.noise_schedule.get_abc(t * ones)]

        if use_guidance:
            # B. Bad input: swap one component of the linear decomposition of
            #    x_s (x0 part via b_s, x_T part via a_s) for its corrupted version.
            if guidance_type == "x0":
                x_bad = x + b_s * (corrupt_fn(x0_good) - x0_good)
            elif guidance_type == "x1":
                x_bad = x + a_s * (corrupt_fn(x_T) - x_T)
            elif guidance_type == "xt":
                x_bad = corrupt_fn(x)
            else:  # "mean"
                x_bad = x + b_s * (corrupt_fn(x0_good) - x0_good) + a_s * (corrupt_fn(x_T) - x_T)

            # C. Guidance.
            x0_bad = denoiser(x_bad, s * ones)
            x0_hat = apply_guidance(x0_bad, x0_good, i, len(indices))
        else:
            x0_hat = x0_good

        # Re-impose the background in case frequency scaling leaked outside the mask.
        x0_hat = fill_known(x0_hat)

        # D. Bridge update s -> t.
        _, _, rho_s, _ = [append_dims(item, x0_hat.ndim) for item in diffusion.noise_schedule.get_alpha_rho(s * ones)]
        alpha_t, _, rho_t, _ = [
            append_dims(item, x0_hat.ndim) for item in diffusion.noise_schedule.get_alpha_rho(t * ones)
        ]

        omega_st = eta * (alpha_t * rho_t) * (1 - rho_t**2 / rho_s**2).sqrt()
        tmp_var = (c_t**2 - omega_st**2).sqrt() / c_s
        coeff_xs = tmp_var
        coeff_x0_hat = b_t - tmp_var * b_s
        coeff_xT = a_t - tmp_var * a_s

        noise = generator.randn_like(x0_hat)

        x = coeff_x0_hat * x0_hat + coeff_xT * x_T + coeff_xs * x + (1 if i != last_step else 0) * omega_st * noise

        # NOTE(review): the state x itself is not reset outside the mask; adding
        # ``x = x * mask + x_T * (1 - mask)`` may stabilise the background,
        # depending on the bridge definition.

        path.append(x.detach().cpu())
        pred_x0.append(x0_hat.detach().cpu())
        nfe += 1

    return x, path, nfe, pred_x0, ts, first_noise

@torch.no_grad()
def sample_dbim(
    denoiser,
    diffusion,
    x,
    ts,
    eta=1.0,
    mask=None,
    seed=None,
    **kwargs,
):
    """DBIM bridge sampler with optional frequency-split CFG and Doob-term rescaling.

    Configuration comes from environment variables:
      * ``guidance_type`` (default ``"none"``): ``none``, ``x0``, ``x1``, ``xt``,
        ``mean``, ``mx0`` or ``mx1``.
      * ``corrupt_type`` (default ``"none"``), ``corrupt_scale``, ``guidance_scale``:
        the last two are required only when guidance is enabled.
      * ``hf_extra`` (default 0.6) and ``cutoff`` (default 0.7): ``frequency_cfg`` settings.
      * ``doob_scale``: required whenever at least one update step is taken.

    When guidance is on, the "bad" prediction denoises ``corrupt_fn(x)``.
    With a ``mask`` the predictions are combined by ``frequency_cfg``; without
    one, by ``x0_bad.lerp(x0_good, guidance_scale)``.

    Returns:
        ``(x, path, nfe, pred_x0, ts, first_noise)``.

    Raises:
        ValueError: for an unknown ``guidance_type``.
    """
    guidance_type = os.environ.get("guidance_type", "none")
    if guidance_type not in ("none", "x0", "x1", "xt", "mean", "mx0", "mx1"):
        raise ValueError(f"Unknown guidance type: {guidance_type}")
    use_guidance = guidance_type != "none"

    corrupt_type = os.environ.get("corrupt_type", "none")
    # Only needed (and therefore only required) when guidance is enabled.
    corrupt_scale = get_env_float("corrupt_scale") if use_guidance else 0.0
    guidance_scale = get_env_float("guidance_scale") if use_guidance else 1.0
    # Optional frequency_cfg settings, exported by the launch script when tuned.
    hf_extra = float(os.environ.get("hf_extra", "0.6"))
    cutoff = float(os.environ.get("cutoff", "0.7"))

    def corrupt_fn(image):
        return corrupt_image(image, corrupt_type, corrupt_scale)

    x_T = x
    path = []
    pred_x0 = []

    ones = x.new_ones([x.shape[0]])
    indices = tqdm(range(len(ts) - 1), disable=(dist.get_rank() != 0))

    nfe = 0
    num_steps = len(ts)  # normalizer passed to frequency_cfg alongside step_idx
    last_step = len(ts) - 2  # no stochastic noise on the final step

    def guided_x0(x_in, time, step_idx):
        """Guided clean-image prediction at ``time`` (bad input = corrupt_fn(x_in))."""
        x0_good = denoiser(x_in, time * ones)
        x0_bad = denoiser(corrupt_fn(x_in), time * ones)
        if mask is None:
            # frequency_cfg requires a mask; fall back to plain CFG interpolation.
            return x0_bad.lerp(x0_good, guidance_scale)
        return frequency_cfg(
            x0_prior=x0_bad,
            x0_like=x0_good,
            step_idx=step_idx,
            num_steps=num_steps,
            mask=mask,
            base_scale=guidance_scale,
            hf_extra=hf_extra,
            cutoff=cutoff,
            x_T=x_T,
        )

    # ---------- First prediction at t_max ----------
    if use_guidance:
        x0_hat = guided_x0(x, diffusion.t_max, 0)
    else:
        x0_hat = denoiser(x, diffusion.t_max * ones)

    generator = BatchedSeedGenerator(seed)
    noise = generator.randn_like(x0_hat)
    first_noise = noise

    if mask is not None:
        x0_hat = x0_hat * mask + x_T * (1 - mask)

    # Bridge onto the first grid point ts[0].
    x = diffusion.bridge_sample(x0_hat, x_T, ts[0] * ones, noise)

    # "Bad" trajectory for the selected guidance_type.
    # NOTE(review): x_bad is propagated every step but never read, because
    # guided_x0 always denoises corrupt_fn(x), so all guided variants currently
    # behave alike. Either feed x_bad into the bad prediction or drop this
    # bookkeeping (it costs extra corrupt_fn calls).
    t0_batch = ts[0] * ones
    if guidance_type in ("x0", "mx0"):
        x_bad = diffusion.bridge_sample(corrupt_fn(x0_hat), x_T, t0_batch, noise)
    elif guidance_type in ("x1", "mx1"):
        x_bad = diffusion.bridge_sample(x0_hat, corrupt_fn(x_T), t0_batch, noise)
    elif guidance_type == "xt":
        x_bad = corrupt_fn(x)
    elif guidance_type == "mean":
        x_bad = diffusion.bridge_sample(corrupt_fn(x0_hat), corrupt_fn(x_T), t0_batch, noise)
    else:
        x_bad = x

    path.append(x.detach().cpu())
    pred_x0.append(x0_hat.detach().cpu())
    nfe += 1

    # Read once instead of every iteration; only needed when a step is taken.
    doob_scale = get_env_float("doob_scale") if len(ts) > 1 else None

    # ---------- Subsequent steps s -> t ----------
    for i in indices:
        s = ts[i]
        t = ts[i + 1]

        if use_guidance:
            x0_hat = guided_x0(x, s, i)
        else:
            x0_hat = denoiser(x, s * ones)

        if mask is not None:
            x0_hat = x0_hat * mask + x_T * (1 - mask)

        # ----- Bridge update with Doob correction -----
        a_s, b_s, c_s = [
            append_dims(item, x0_hat.ndim)
            for item in diffusion.noise_schedule.get_abc(s * ones)
        ]
        a_t, b_t, c_t = [
            append_dims(item, x0_hat.ndim)
            for item in diffusion.noise_schedule.get_abc(t * ones)
        ]
        alpha_s, alpha_bar_s, rho_s, rho_bar_s = [
            append_dims(item, x0_hat.ndim)
            for item in diffusion.noise_schedule.get_alpha_rho(s * ones)
        ]
        alpha_t, _, rho_t, _ = [
            append_dims(item, x0_hat.ndim)
            for item in diffusion.noise_schedule.get_alpha_rho(t * ones)
        ]
        _, g2_s = [
            append_dims(item, x.ndim)
            for item in diffusion.noise_schedule.get_f_g2(s * ones)
        ]

        omega_st = eta * (alpha_t * rho_t) * (1 - rho_t**2 / rho_s**2).sqrt()
        tmp_var = (c_t**2 - omega_st**2).sqrt() / c_s
        coeff_xs = tmp_var
        coeff_x0_hat = b_t - tmp_var * b_s
        coeff_xT = a_t - tmp_var * a_s

        noise = generator.randn_like(x0_hat)
        stoch = (1 if i != last_step else 0) * omega_st * noise

        if guidance_type == "x0":
            x_bad = coeff_x0_hat * corrupt_fn(x0_hat) + coeff_xT * x_T + coeff_xs * x + stoch
        elif guidance_type == "x1":
            x_bad = coeff_x0_hat * x0_hat + coeff_xT * corrupt_fn(x_T) + coeff_xs * x + stoch
        elif guidance_type == "xt":
            x_bad = corrupt_fn(coeff_x0_hat * x0_hat + coeff_xT * x_T + coeff_xs * x + stoch)
        elif guidance_type == "mean":
            x_bad = (
                b_t * corrupt_fn(x0_hat)
                + a_t * corrupt_fn(x_T)
                + coeff_xs * (x - a_s * x_T - b_s * x0_hat)
                + stoch
            )
        elif guidance_type == "mx0":
            x_bad = (
                b_t * corrupt_fn(x0_hat)
                + a_t * x_T
                + coeff_xs * (x - a_s * x_T - b_s * x0_hat)
                + stoch
            )
        elif guidance_type == "mx1":
            x_bad = (
                b_t * x0_hat
                + a_t * corrupt_fn(x_T)
                + coeff_xs * (x - a_s * x_T - b_s * x0_hat)
                + stoch
            )
        else:
            x_bad = x

        # Doob h-transform term, evaluated at the pre-update state x_s.
        grad_logpxTlxs = -(x - alpha_bar_s * x_T) / (alpha_s**2 * rho_bar_s**2)
        doob_term = (t - s) * g2_s * (doob_scale - 1) * grad_logpxTlxs

        x = coeff_x0_hat * x0_hat + coeff_xT * x_T + coeff_xs * x + stoch
        x = x + doob_term
        # Out-of-place: for guidance_type "none", x_bad aliases the previous x
        # (possibly the tensor stored in ``path`` on CPU), so ``+=`` would mutate it.
        x_bad = x_bad + doob_term

        path.append(x.detach().cpu())
        pred_x0.append(x0_hat.detach().cpu())
        nfe += 1

    return x, path, nfe, pred_x0, ts, first_noise




@torch.no_grad()
def sample_ground_truth(
    denoiser,
    diffusion,
    x,
    ts,
    x0=None,
    **kwargs,
):
    """Walk the bridge using the ground-truth clean endpoint ``x0``.

    The first state is drawn at ``ts[0]`` from the model's prediction at
    ``diffusion.t_max``. Every later state is drawn at ``ts[i + 1]`` with
    ``diffusion.bridge_sample`` conditioned on the supplied ``x0`` instead of
    a prediction. The model prediction at each visited time is still recorded
    in ``pred_x0``.

    Args:
        denoiser: callable ``(x_t, t_vec) -> x0 prediction``.
        diffusion: object exposing ``t_max`` and ``bridge_sample``.
        x: the bridge endpoint ``x_T`` the trajectory starts from.
        ts: decreasing time grid.
        x0: ground-truth clean sample; required.

    Returns:
        Tuple ``(x, path, nfe, pred_x0, ts, first_noise)``.
    """
    assert x0 is not None
    x_end = x
    batch_ones = x.new_ones([x.shape[0]])
    trajectory, x0_preds = [], []

    def record(state, x0_pred):
        # Keep detached CPU snapshots of the state and the model prediction.
        trajectory.append(state.detach().cpu())
        x0_preds.append(x0_pred.detach().cpu())

    # Only rank 0 shows a progress bar.
    progress = tqdm(range(len(ts) - 1), disable=(dist.get_rank() != 0))

    # Initial state: bridge sample at ts[0] from the t_max prediction.
    x0_hat = denoiser(x, diffusion.t_max * batch_ones)
    first_noise = torch.randn_like(x0)
    x = diffusion.bridge_sample(x0_hat, x_end, ts[0] * batch_ones, first_noise)
    record(x, x0_hat)
    nfe = 1

    for i in progress:
        # The prediction is only recorded; the next state uses the true x0.
        x0_hat = denoiser(x, ts[i] * batch_ones)
        x = diffusion.bridge_sample(
            x0, x_end, ts[i + 1] * batch_ones, torch.randn_like(x0)
        )
        record(x, x0_hat)
        nfe += 1

    return x, trajectory, nfe, x0_preds, ts, first_noise


def get_d(denoiser, noise_schedule, x, x_T, t, stochastic):
    """Evaluate the bridge ODE/SDE drift at time ``t``.

    The drift is ``f_t * x - g2_t * (w * grad_logq - doob_scale * grad_logp)``.
    Here ``w`` is 1 for the stochastic (SDE) form and 0.5 for the
    deterministic (probability-flow) form. ``doob_scale`` is read from the
    ``doob_scale`` environment variable.

    Args:
        denoiser: callable ``(x_t, t_vec) -> x0 prediction``.
        noise_schedule: provides ``get_f_g2``, ``get_alpha_rho`` and ``get_abc``.
        x: current state.
        x_T: bridge endpoint.
        t: scalar time (broadcast over the batch).
        stochastic: selects the SDE (True) or ODE (False) score weight.

    Returns:
        Tuple ``(d, g2_t, denoised)``.
    """
    ndim = x.ndim
    t_batch = t * x.new_ones([x.shape[0]])

    def expand(values):
        # Broadcast per-sample schedule values against x.
        return [append_dims(v, ndim) for v in values]

    f_t, g2_t = expand(noise_schedule.get_f_g2(t_batch))
    alpha_t, alpha_bar_t, _, rho_bar_t = expand(noise_schedule.get_alpha_rho(t_batch))
    a_t, b_t, c_t = expand(noise_schedule.get_abc(t_batch))

    denoised = denoiser(x, t_batch)
    bridge_mean = a_t * x_T + b_t * denoised
    score_q = -(x - bridge_mean) / c_t**2
    score_doob = -(x - alpha_bar_t * x_T) / (alpha_t**2 * rho_bar_t**2)

    score_weight = 1 if stochastic else 0.5
    doob_scale = get_env_float("doob_scale")
    d = f_t * x - g2_t * (score_weight * score_q - doob_scale * score_doob)
    return d, g2_t, denoised

def ddbm_simulate(denoiser, noise_schedule, x, x_T, t_cur, t_next, stochastic, second_order=False):
    """Advance ``x`` from ``t_cur`` to ``t_next`` with one Euler / Heun step.

    A first-order (Euler or Euler-Maruyama) step uses the drift from
    ``get_d`` at ``t_cur``. With ``second_order=True`` the drift is also
    evaluated at the Euler proposal at ``t_next``. The step is then redone
    from ``x`` with the average drift (Heun).

    Args:
        denoiser: callable ``(x_t, t_vec) -> x0 prediction``.
        noise_schedule: schedule passed through to ``get_d``.
        x: current state.
        x_T: bridge endpoint.
        t_cur: current time.
        t_next: target time.
        stochastic: if True, add diffusion noise scaled by ``sqrt(|dt| * g2_t)``.
        second_order: if True, apply the Heun correction (two drift evaluations).

    Returns:
        Tuple ``(x_new, pred_x0)``. ``pred_x0`` is the denoiser output from
        the last drift evaluation.
    """
    dt = t_next - t_cur
    if isinstance(noise_schedule, I2SBNoiseSchedule):
        # NOTE(review): the I2SB schedule rescales dt by (n_timestep - 1),
        # presumably because its time axis is discretised; confirm.
        dt = dt * (noise_schedule.n_timestep - 1)

    noise_gate = 1 if stochastic else 0
    d, g2_t, pred_x0 = get_d(denoiser, noise_schedule, x, x_T, t_cur, stochastic)
    # Gaussian noise is drawn even in the deterministic case (then multiplied
    # by 0), so the global RNG advances identically either way.
    noise = noise_gate * torch.randn_like(x) * (dt.abs() ** 0.5) * g2_t.sqrt()
    x_new = x + d * dt + noise

    if second_order:
        d_2, _, pred_x0 = get_d(denoiser, noise_schedule, x_new, x_T, t_next, stochastic)
        d_prime = (d + d_2) / 2
        x_new = (
            x + d_prime * dt + noise_gate * torch.randn_like(x) * (dt.abs() ** 0.5) * g2_t.sqrt()
        )
    return x_new, pred_x0

def frequency_cfg_lowpass_only(
    x0_prior: torch.Tensor,
    x0_like: torch.Tensor,
    t: torch.Tensor,                   # (B,)
    t_max: float,
    guidance_scale: float = 1.0,        # low-frequency scale
    guidance_scale_hf: float = 1.0,     # high-frequency scale
    cutoff: float = 0.3,               # centre of the frequency transition (not a hard cut)
    transition_width: float = 0.15,    # width of the transition band
    hf_start_ratio: float = 0.5,        # high-frequency turn-on time (currently unused)
) -> torch.Tensor:
    """Frequency-split classifier-free guidance between two x0 predictions.

    The difference ``delta = x0_like - x0_prior`` is split with a smooth,
    radially symmetric low-pass mask in the 2-D Fourier domain. The result is
    ``x0_prior + guidance_scale * delta_low + guidance_scale_hf * delta_high``.
    With both scales equal to 1 this returns ``x0_like`` up to float error.

    Args:
        x0_prior: tensor of shape ``(B, C, H, W)`` (the "bad" prediction).
        x0_like: tensor of the same shape (the "good" prediction).
        t: per-sample times. Currently unused; see note below.
        t_max: maximum time. Currently unused.
        guidance_scale: weight of the low-frequency part of ``delta``.
        guidance_scale_hf: weight of the high-frequency part of ``delta``.
        cutoff: normalised radius at the centre of the smoothstep transition.
            The radius is measured on a ``[-1, 1]`` grid over the shifted spectrum.
        transition_width: width of the transition band. A value ``<= 0``
            selects a hard cutoff at ``cutoff``.
        hf_start_ratio: currently unused.

    NOTE(review): a time gate for the high-frequency term
    (``t / t_max <= hf_start_ratio``) used to be computed but was never
    applied. ``t``, ``t_max`` and ``hf_start_ratio`` are kept only for
    interface compatibility and have no effect.
    """
    _, _, H, W = x0_prior.shape
    device = x0_prior.device
    spatial = (-2, -1)

    delta = x0_like - x0_prior

    # Centred 2-D spectrum of the guidance direction (spatial dims only).
    delta_f = torch.fft.fftshift(torch.fft.fft2(delta, dim=spatial), dim=spatial)

    # Smooth radial low-pass mask over the centred spectrum.
    yy, xx = torch.meshgrid(
        torch.linspace(-1.0, 1.0, H, device=device),
        torch.linspace(-1.0, 1.0, W, device=device),
        indexing="ij",
    )
    rr = torch.sqrt(xx**2 + yy**2)

    if transition_width > 0:
        r0 = cutoff - transition_width * 0.5
        r1 = cutoff + transition_width * 0.5
        ramp = ((rr - r0) / (r1 - r0)).clamp(0, 1)
        lp = 1.0 - (3 * ramp**2 - 2 * ramp**3)   # smoothstep
    else:
        # Zero-width band: a hard cutoff avoids a 0/0 -> NaN mask.
        lp = (rr <= cutoff).to(rr.dtype)
    lp = lp.view(1, 1, H, W)

    # The inverse shift must be restricted to the spatial dims. Shifting
    # every dim would roll the batch/channel axes and mix samples.
    delta_low = torch.fft.ifft2(
        torch.fft.ifftshift(delta_f * lp, dim=spatial), dim=spatial
    ).real
    delta_high = delta - delta_low

    return x0_prior + (guidance_scale * delta_low + guidance_scale_hf * delta_high)
    
@torch.no_grad()
def sample_heun(
    denoiser,
    diffusion,
    x,
    ts,
    churn_step_ratio=0.0,
    **kwargs,
):
    """Heun bridge sampler with frequency-split CFG on the x0 prediction.

    Each step optionally performs an unguided stochastic "churn" sub-step from
    ``ts[i]`` towards ``ts[i + 1]``. It then takes a deterministic step to
    ``ts[i + 1]`` using a guided denoiser. The step is second order (Heun)
    except when ``ts[i + 1] == 0``, where it falls back to Euler.

    The guided denoiser evaluates ``denoiser`` on the state and on
    ``corrupt_image(state)`` and combines both with
    ``frequency_cfg_lowpass_only``.

    Environment variables read: ``corrupt_type`` (default ``"none"``),
    ``corrupt_scale``, ``guidance_scale`` and ``hf_extra``.

    Returns:
        Tuple ``(x, path, nfe, pred_x0, ts, None)``.
    """
    x_T = x
    path = []
    pred_x0 = []

    corrupt_type = os.environ.get("corrupt_type", "none")
    corrupt_scale = get_env_float("corrupt_scale")
    guidance_scale = get_env_float("guidance_scale")
    hf_extra = get_env_float("hf_extra")

    def corrupt_fn(img):
        return corrupt_image(img, corrupt_type, corrupt_scale)

    indices = tqdm(range(len(ts) - 1), disable=(dist.get_rank() != 0))
    nfe = 0
    t_max = ts[0]

    def guided_denoiser(x_t, t_vec):
        # CFG between the prediction on x_t and on a corrupted x_t.
        x0_good = denoiser(x_t, t_vec)
        x0_bad = denoiser(corrupt_fn(x_t), t_vec)
        return frequency_cfg_lowpass_only(
            x0_prior=x0_bad,
            x0_like=x0_good,
            t=t_vec,
            t_max=t_max,
            guidance_scale=guidance_scale,
            guidance_scale_hf=hf_extra,
        )

    for i in indices:
        # 1) Optional churn sub-step with the raw (unguided) denoiser.
        if churn_step_ratio > 0:
            t_hat = (ts[i + 1] - ts[i]) * churn_step_ratio + ts[i]
            x, churn_x0 = ddbm_simulate(
                denoiser,
                diffusion.noise_schedule,
                x,
                x_T,
                ts[i],
                t_hat,
                stochastic=True,
            )
            nfe += 1
            path.append(x.detach().cpu())
            pred_x0.append(churn_x0.detach().cpu())
        else:
            t_hat = ts[i]

        # 2) Guided deterministic step. Heun except when landing on t == 0.
        if ts[i + 1] == 0:
            x, step_x0 = ddbm_simulate(
                guided_denoiser,
                diffusion.noise_schedule,
                x,
                x_T,
                t_hat,
                ts[i + 1],
                stochastic=False,
            )
            nfe += 1
        else:
            x, step_x0 = ddbm_simulate(
                guided_denoiser,
                diffusion.noise_schedule,
                x,
                x_T,
                t_hat,
                ts[i + 1],
                stochastic=False,
                second_order=True,
            )
            nfe += 2

        path.append(x.detach().cpu())
        pred_x0.append(step_x0.detach().cpu())

    return x, path, nfe, pred_x0, ts, None


@torch.no_grad()
def sample_dbimt(
    denoiser,
    diffusion,
    x,
    ts,
    eta=1.0,
    mask=None,
    seed=None,
    **kwargs,
):
    """DBIM sampler that applies ``frequency_cfgi`` guidance at every step.

    At each time, a "bad" state is built according to the ``guidance_type``
    environment variable and passed to ``denoiser``, giving ``x0_prior``.
    ``frequency_cfgi`` then combines it with the prediction on the unmodified
    state (``x0_like``). The DBIM update with stochasticity ``eta`` is used
    between grid points. With a ``mask``, the region where ``mask == 0`` is
    kept from ``x_T``.

    Environment variables read: ``corrupt_type`` (default ``"none"``),
    ``corrupt_scale``, ``guidance_type`` (default ``"none"``),
    ``guidance_scale`` and ``hf_extra``.

    Returns:
        Tuple ``(x, path, nfe, pred_x0, ts, first_noise)``.

    Raises:
        ValueError: if ``guidance_type`` is not one of ``"none"``, ``"x0"``,
            ``"x1"``, ``"xt"`` or ``"mean"``.
    """
    x_T = x
    path = []
    pred_x0 = []

    corrupt_type = os.environ.get("corrupt_type", "none")
    corrupt_scale = get_env_float("corrupt_scale")
    guidance_type = os.environ.get("guidance_type", "none")
    guidance_scale = get_env_float("guidance_scale")
    hf_extra = get_env_float("hf_extra")
    ones = x.new_ones([x.shape[0]])
    indices = tqdm(range(len(ts) - 1), disable=(dist.get_rank() != 0))

    def corrupt_fn(image):
        return corrupt_image(image, corrupt_type, corrupt_scale)

    def apply_mask(x0):
        # Keep the unmasked region (mask == 0) from x_T.
        return x0 if mask is None else x0 * mask + x_T * (1 - mask)

    def cfg(x0_prior, x0_like):
        return frequency_cfgi(
            x0_prior=x0_prior,
            x0_like=x0_like,
            mask=mask,
            base_scale=guidance_scale,
            hf_extra=hf_extra,
            x_T=x_T,
        )

    def expand(values):
        # Broadcast per-sample schedule values against the image tensors.
        return [append_dims(item, x_T.ndim) for item in values]

    nfe = 0
    if guidance_type in ["x0", "none"]:
        x0_hat = denoiser(x, diffusion.t_max * ones)
    elif guidance_type in ["x1", "xt", "mean"]:
        x0_prior = denoiser(corrupt_fn(x), diffusion.t_max * ones)
        x0_like = denoiser(x, diffusion.t_max * ones)
        x0_hat = cfg(x0_prior, x0_like)
    else:
        # A bare `raise` here would only produce "No active exception to reraise".
        raise ValueError(f"Unsupported guidance_type for sample_dbimt: {guidance_type!r}")

    generator = BatchedSeedGenerator(seed)
    noise = generator.randn_like(x0_hat)
    first_noise = noise
    x0_hat = apply_mask(x0_hat)

    if len(ts) == 1:
        x = x0_hat
    else:
        x = diffusion.bridge_sample(x0_hat, x_T, ts[0] * ones, noise)

    path.append(x.detach().cpu())
    pred_x0.append(x0_hat.detach().cpu())
    nfe += 1

    for i in indices:
        s = ts[i]
        t = ts[i + 1]

        # One evaluation on the clean state serves both as the (masked) x0
        # estimate used to build x_bad and as the unguided CFG prediction.
        x0_like = denoiser(x, s * ones)
        x0_hat = apply_mask(x0_like)

        a_s, b_s, c_s = expand(diffusion.noise_schedule.get_abc(s * ones))
        a_t, b_t, c_t = expand(diffusion.noise_schedule.get_abc(t * ones))

        if guidance_type == "x0":
            x_bad = x + b_s * (corrupt_fn(x0_hat) - x0_hat)
        elif guidance_type == "x1":
            x_bad = x + a_s * (corrupt_fn(x_T) - x_T)
        elif guidance_type == "xt":
            x_bad = corrupt_fn(x)
        elif guidance_type == "mean":
            x_bad = x + b_s * (corrupt_fn(x0_hat) - x0_hat) + a_s * (corrupt_fn(x_T) - x_T)
        else:  # "none" (validated above)
            x_bad = x
        x0_prior = denoiser(x_bad, s * ones)
        x0_hat = apply_mask(cfg(x0_prior, x0_like))

        _, _, rho_s, _ = expand(diffusion.noise_schedule.get_alpha_rho(s * ones))
        alpha_t, _, rho_t, _ = expand(diffusion.noise_schedule.get_alpha_rho(t * ones))

        # DBIM update: deterministic part plus eta-scaled fresh noise.
        omega_st = eta * (alpha_t * rho_t) * (1 - rho_t**2 / rho_s**2).sqrt()
        tmp_var = (c_t**2 - omega_st**2).sqrt() / c_s
        coeff_xs = tmp_var
        coeff_x0_hat = b_t - tmp_var * b_s
        coeff_xT = a_t - tmp_var * a_s

        noise = generator.randn_like(x0_hat)
        # No noise is injected on the final step.
        noise_gate = 1 if i != len(ts) - 2 else 0
        x = coeff_x0_hat * x0_hat + coeff_xT * x_T + coeff_xs * x + noise_gate * omega_st * noise

        path.append(x.detach().cpu())
        pred_x0.append(x0_hat.detach().cpu())
        nfe += 1

    return x, path, nfe, pred_x0, ts, first_noise


    @torch.no_grad()
def sample_dbim(
    denoiser,
    diffusion,
    x,
    ts,
    eta=1.0,
    mask=None,
    seed=None,
    **kwargs,
):
    corrupt_type = os.environ.get("corrupt_type", "none")
    corrupt_scale = get_env_float("corrupt_scale")
    corrupt_fn = lambda image: corrupt_image(image, corrupt_type, corrupt_scale)
    guidance_type = os.environ.get("guidance_type", "none")
    guidance_scale = get_env_float("guidance_scale")
    hf_extra = get_env_float("hf_extra")

    x_T = x
    path = []
    pred_x0 = []

    ones = x.new_ones([x.shape[0]])
    indices = range(len(ts) - 1)
    indices = tqdm(indices, disable=(dist.get_rank() != 0))

    nfe = 0
    num_steps = len(ts)  # 用来归一化 step_idx

    # ---------- 初始化 x_bad ----------
    if guidance_type == "x0":
        x_bad = x
    elif guidance_type in ["x1", "xt", "mean", "mx1"]:
        x_bad = corrupt_fn(x)
    elif guidance_type == "mx0":
        x_bad = x
    else:
        x_bad = x
        assert guidance_type == "none"

    # ---------- 第一步：t_max ----------
    if guidance_type == "none" or guidance_scale == 1.0:
        x0_hat = denoiser(x, diffusion.t_max * ones)
    else:
        x0_good = denoiser(x,  diffusion.t_max * ones)
        x0_bad  = denoiser(corrupt_fn(x), diffusion.t_max * ones)
        x0_hat = frequency_cfge(
            x0_prior=x0_bad,
            x0_like=x0_good,
            guidance_scale=guidance_scale,
            guidance_scale_hf=hf_extra,        # 这里传入处理好的变量
            cutoff=get_env_float("cutoff"),
        )

    generator = BatchedSeedGenerator(seed)
    noise = generator.randn_like(x0_hat)
    first_noise = noise

    if mask is not None:
        x0_hat = x0_hat * mask + x_T * (1 - mask)

    # bridge 到第一个 ts[0]
    x = diffusion.bridge_sample(x0_hat, x_T, ts[0] * ones, noise)

    # 根据 guidance_type 更新 x_bad
    if guidance_type == "x0":
        x_bad = diffusion.bridge_sample(corrupt_fn(x0_hat), x_T, ts[0] * ones, noise)
    elif guidance_type == "x1":
        x_bad = diffusion.bridge_sample(x0_hat, corrupt_fn(x_T), ts[0] * ones, noise)
    elif guidance_type == "xt":
        x_bad = corrupt_fn(x)
    elif guidance_type == "mean":
        x_bad = diffusion.bridge_sample(
            corrupt_fn(x0_hat), corrupt_fn(x_T), ts[0] * ones, noise
        )
    elif guidance_type == "mx0":
        x_bad = diffusion.bridge_sample(corrupt_fn(x0_hat), x_T, ts[0] * ones, noise)
    elif guidance_type == "mx1":
        x_bad = diffusion.bridge_sample(x0_hat, corrupt_fn(x_T), ts[0] * ones, noise)
    else:
        x_bad = x
        assert guidance_type == "none"

    path.append(x.detach().cpu())
    pred_x0.append(x0_hat.detach().cpu())
    nfe += 1
    for _, i in enumerate(indices):
        s = ts[i]
        t = ts[i + 1]

        # guidance：先算 prior / like，然后 frequency CFG
        if guidance_type == "none" or guidance_scale == 1.0:
            x0_hat = denoiser(x, s * ones)
        else:
            x0_good = denoiser(x,  s * ones)
            x0_bad  = denoiser(corrupt_fn(x), s * ones)
            x0_hat = frequency_cfge(
                x0_prior=x0_bad,
                x0_like=x0_good,
                guidance_scale=guidance_scale,
                guidance_scale_hf=hf_extra,          
                cutoff=get_env_float("cutoff"),
            )

        if mask is not None:
            x0_hat = x0_hat * mask + x_T * (1 - mask)

        # ----- 原来的 bridge + Doob correction 全保留 -----
        a_s, b_s, c_s = [
            append_dims(item, x0_hat.ndim)
            for item in diffusion.noise_schedule.get_abc(s * ones)
        ]
        a_t, b_t, c_t = [
            append_dims(item, x0_hat.ndim)
            for item in diffusion.noise_schedule.get_abc(t * ones)
        ]

        _, _, rho_s, _ = [
            append_dims(item, x0_hat.ndim)
            for item in diffusion.noise_schedule.get_alpha_rho(s * ones)
        ]
        alpha_t, _, rho_t, _ = [
            append_dims(item, x0_hat.ndim)
            for item in diffusion.noise_schedule.get_alpha_rho(t * ones)
        ]

        omega_st = eta * (alpha_t * rho_t) * (1 - rho_t**2 / rho_s**2).sqrt()
        tmp_var = (c_t**2 - omega_st**2).sqrt() / c_s
        coeff_xs = tmp_var
        coeff_x0_hat = b_t - tmp_var * b_s
        coeff_xT = a_t - tmp_var * a_s

        noise = generator.randn_like(x0_hat)

        doob_scale = get_env_float("doob_scale")

        if guidance_type == "x0":
            x_bad = (
                coeff_x0_hat * corrupt_fn(x0_hat)
                + coeff_xT * x_T
                + coeff_xs * x
                + (1 if i != len(ts) - 2 else 0) * omega_st * noise
            )
        elif guidance_type == "x1":
            x_bad = (
                coeff_x0_hat * x0_hat
                + coeff_xT * corrupt_fn(x_T)
                + coeff_xs * x
                + (1 if i != len(ts) - 2 else 0) * omega_st * noise
            )
        elif guidance_type == "xt":
            x_bad = corrupt_fn(
                coeff_x0_hat * x0_hat
                + coeff_xT * x_T
                + coeff_xs * x
                + (1 if i != len(ts) - 2 else 0) * omega_st * noise
            )
        elif guidance_type == "mean":
            x_bad = (
                b_t * corrupt_fn(x0_hat)
                + a_t * corrupt_fn(x_T)
                + coeff_xs * (x - a_s * x_T - b_s * x0_hat)
                + (1 if i != len(ts) - 2 else 0) * omega_st * noise
            )
        elif guidance_type == "mx0":
            x_bad = (
                b_t * corrupt_fn(x0_hat)
                + a_t * x_T
                + coeff_xs * (x - a_s * x_T - b_s * x0_hat)
                + (1 if i != len(ts) - 2 else 0) * omega_st * noise
            )
        elif guidance_type == "mx1":
            x_bad = (
                b_t * x0_hat
                + a_t * corrupt_fn(x_T)
                + coeff_xs * (x - a_s * x_T - b_s * x0_hat)
                + (1 if i != len(ts) - 2 else 0) * omega_st * noise
            )
        else:
            x_bad = x
            assert guidance_type == "none"

        f_s, g2_s = [
            append_dims(item, x.ndim)
            for item in diffusion.noise_schedule.get_f_g2(s * ones)
        ]
        alpha_s, alpha_bar_s, _, rho_bar_s = [
            append_dims(item, x.ndim)
            for item in diffusion.noise_schedule.get_alpha_rho(s * ones)
        ]
        grad_logpxTlxs = -(x - alpha_bar_s * x_T) / (alpha_s**2 * rho_bar_s**2)

        x = (
            coeff_x0_hat * x0_hat
            + coeff_xT * x_T
            + coeff_xs * x
            + (1 if i != len(ts) - 2 else 0) * omega_st * noise
        )
        x += (t - s) * g2_s * (doob_scale - 1) * grad_logpxTlxs
        x_bad += (t - s) * g2_s * (doob_scale - 1) * grad_logpxTlxs

        path.append(x.detach().cpu())
        pred_x0.append(x0_hat.detach().cpu())
        nfe += 1

    return x, path, nfe, pred_x0, ts, first_noise
