from functools import partial
import os
import argparse
import yaml

import torch
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

from guided_diffusion.condition_methods import get_conditioning_method
from guided_diffusion.measurements import get_noise, get_operator
from guided_diffusion.unet import create_model
from guided_diffusion.gaussian_diffusion_wavelet import create_sampler
from data.dataloader import get_dataset, get_dataloader
from util.img_utils import clear_color, mask_generator
from util.logger import get_logger


def load_yaml(file_path: str) -> dict:
    """Read a YAML configuration file and return its parsed content.

    Args:
        file_path: Path of the YAML file to read.

    Returns:
        The parsed document. An empty file (which PyYAML parses to ``None``)
        is returned as an empty dict so callers can always use mapping
        operations such as ``.get`` or ``**cfg`` on the result.
    """
    # YAML streams are UTF-8 unless a BOM says otherwise; being explicit avoids
    # decode errors on platforms whose default locale encoding is not UTF-8.
    with open(file_path, encoding="utf-8") as f:
        # NOTE(security): FullLoader is not meant for untrusted input. It is
        # kept so existing configs load exactly as before; prefer
        # yaml.safe_load if the configs only use plain YAML types.
        data = yaml.load(f, Loader=yaml.FullLoader)
    return {} if data is None else data


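# NOTE: legacy draft of make_radial_filter, kept for reference only and
# superseded by the live definition below. In this draft the channel count `C`
# is overwritten by the filter tensor before `.repeat(B, C, 1, 1)` is called.
# def make_radial_filter(shape, t, T_max, device):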
#     """
#     Example radial filter C_t that goes from low-pass (t≈T_max) to
#     high-pass (t≈0).
#     """
#     # shape = [B, C, H, W]
#     B, C, H, W = shape
#     # build normalized frequency grids
#     fy = torch.fft.fftfreq(H, device=device).unsqueeze(1).repeat(1, W)
#     fx = torch.fft.fftfreq(W, device=device).unsqueeze(0).repeat(H, 1)
#     radius = torch.sqrt(fx ** 2 + fy ** 2)  # [H, W]

#     # threshold shifts with t
#     frac = t / float(T_max)  # 1.0→low-pass; 0.0→high-pass
#     cutoff = frac * radius.max()
#     # low‑pass mask: 1 inside cutoff, 0 outside
#     mask_lp = (radius <= cutoff).float()
#     # high‑pass = 1−low-pass
#     mask_hp = 1.0 - mask_lp

#     # you can interpolate or pick one; here we do a simple mix:
#     C = frac * mask_lp + (1.0 - frac) * mask_hp
#     C = C.unsqueeze(0).unsqueeze(0).repeat(B, C, 1, 1)
#     return C

def make_radial_filter(shape, t, T_max, device):
    """
    Build a radial frequency-domain weight map C_t that varies with the
    diffusion step t.

    With frac = t / T_max, frequencies whose radius is at most
    frac * max_radius receive weight ``frac`` and all others receive
    ``1 - frac``:
        t near T_max  ->  low frequencies dominate (low-pass)
        t near 0      ->  high frequencies dominate (high-pass)

    Args:
        shape:  four-element size (B, C, H, W) of the tensor to be filtered.
        t:      current diffusion step.
        T_max:  total number of diffusion steps, used to normalise t.
        device: device on which the frequency grids are created.

    Returns:
        Tensor of shape [B, C, H, W]; the same [H, W] weight map is copied to
        every batch entry and channel, so it can be multiplied element-wise
        with a spectrum of that shape.
    """
    batch, channels, height, width = shape

    # Normalised FFT sample frequencies along each axis; broadcasting the
    # column against the row yields the full [H, W] radius grid.
    freq_y = torch.fft.fftfreq(height, device=device)[:, None]   # [H, 1]
    freq_x = torch.fft.fftfreq(width, device=device)[None, :]    # [1, W]
    radius = torch.sqrt(freq_x**2 + freq_y**2)                   # [H, W]

    # Cutoff radius grows linearly with the normalised step.
    frac = t / float(T_max)
    inside = (radius <= frac * radius.max()).float()   # low-pass mask
    outside = 1.0 - inside                             # complementary high-pass mask

    # Blend the two masks, weighting each by how far along the schedule t is.
    weights = frac * inside + (1.0 - frac) * outside   # [H, W]

    # Replicate the 2-D map across batch and channel dimensions.
    return weights[None, None].repeat(batch, channels, 1, 1)


def phi(y: torch.Tensor, C: torch.Tensor) -> torch.Tensor:
    """
    Filter ``y`` in the frequency domain: phi(y) = Re(iFFT2(C * FFT2(y))).

    The 2-D transforms run over the last two dimensions of ``y``, and ``C`` is
    multiplied element-wise with the spectrum. Only the real part of the
    inverse transform is returned; any imaginary component is discarded.
    """
    spectrum = torch.fft.fft2(y, dim=(-2, -1))
    filtered = torch.fft.ifft2(spectrum * C, dim=(-2, -1))
    return filtered.real


def main():
    """Run measurement-conditioned diffusion sampling over a dataset.

    Reads three YAML configs (model, diffusion, task) named on the command
    line, builds the model, sampler, measurement operator, noise model and
    conditioning method from them, then for every image in the dataset:
    applies the forward operator, adds noise, runs the sampler, and writes
    the noisy measurement, the reference image and the reconstruction as
    JPEG files under ``<save_dir>/<operator name>/{input,label,recon}``.
    """
    # ---- command line ----
    cli = argparse.ArgumentParser()
    for required_flag in ('--model_config', '--diffusion_config', '--task_config'):
        cli.add_argument(required_flag, type=str, required=True)
    cli.add_argument('--gpu', type=int, default=0)
    cli.add_argument('--save_dir', type=str, default='./results-new')
    args = cli.parse_args()

    logger = get_logger()

    # ---- device selection: requested GPU if CUDA is usable, else CPU ----
    if torch.cuda.is_available():
        device_str = f"cuda:{args.gpu}"
    else:
        device_str = 'cpu'
    device = torch.device(device_str)
    logger.info(f"Device set to {device_str}.")

    # ---- configuration ----
    model_cfg = load_yaml(args.model_config)
    diff_cfg = load_yaml(args.diffusion_config)
    task_cfg = load_yaml(args.task_config)
    # Only consumed by the (currently disabled) phi pre-filter further down.
    T_max = diff_cfg.get('num_timesteps', 1000)  # assumes this key holds the step count

    # ---- network and sampler ----
    model = create_model(**model_cfg).to(device)
    model = model.eval()
    sampler = create_sampler(**diff_cfg)
    sample_fn = partial(sampler.p_sample_loop, model=model)

    # ---- measurement operator, noise model, conditioning ----
    m_cfg = task_cfg['measurement']
    operator = get_operator(device=device, **m_cfg['operator'])
    noiser = get_noise(**m_cfg['noise'])
    cond_cfg = task_cfg['conditioning']
    cond_method = get_conditioning_method(
        cond_cfg['method'], operator, noiser, **cond_cfg['params']
    )
    measurement_cond = cond_method.conditioning

    # ---- output directory tree ----
    op_name = m_cfg['operator']['name']
    is_inpainting = op_name == 'inpainting'
    out_path = os.path.join(args.save_dir, op_name)
    os.makedirs(out_path, exist_ok=True)
    # 'progress' is not written by this function; presumably the sampler uses
    # it when record=True -- confirm against p_sample_loop.
    for sub_dir in ['input', 'recon', 'label', 'progress']:
        os.makedirs(os.path.join(out_path, sub_dir), exist_ok=True)

    # ---- data: tensors normalised with mean 0.5 / std 0.5 per channel ----
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,) * 3, (0.5,) * 3),
    ])
    dataset = get_dataset(**task_cfg['data'], transforms=transform)
    loader = get_dataloader(dataset, batch_size=1, num_workers=0, train=False)

    # A mask generator is only needed (and only configured) for inpainting.
    if is_inpainting:
        mask_gen = mask_generator(**m_cfg['mask_opt'])

    # ---- inference loop ----
    for idx, ref_img in enumerate(loader):
        ref_img = ref_img.to(device)
        fname = f"{idx:05d}.jpg"
        logger.info(f"Processing image {idx}")

        # Inpainting draws a fresh mask per image and threads it through both
        # the conditioning function and the forward operator.
        forward_kwargs = {}
        if is_inpainting:
            mask = mask_gen(ref_img)
            measurement_cond = partial(cond_method.conditioning, mask=mask)
            forward_kwargs['mask'] = mask
        y = operator.forward(ref_img, **forward_kwargs)

        # Corrupt the clean measurement with the configured noise model.
        y_n = noiser(y)

        # Optional phi pre-filter, currently disabled -- the sampler below
        # receives the unfiltered noisy measurement. To experiment with it,
        # pick a step (or loop over t for a time-varying filter):
        # t = T_max // 2
        # C_t = make_radial_filter(y_n.shape, t, T_max, device)
        # y_phi = phi(y_n, C_t)

        # Start from Gaussian noise; gradients are enabled on the start tensor.
        x0 = torch.randn_like(ref_img).requires_grad_()
        sample = sample_fn(
            x_start=x0,
            measurement=y_n,
            measurement_cond_fn=measurement_cond,
            record=True,
            save_root=out_path,
        )

        # Write measurement, ground truth and reconstruction side by side.
        for sub_dir, image in (('input', y_n), ('label', ref_img), ('recon', sample)):
            plt.imsave(os.path.join(out_path, sub_dir, fname), clear_color(image))


# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
