import math
import os
from functools import partial
import matplotlib.pyplot as plt
import numpy as np
import torch
from tqdm.auto import tqdm
import pywt  # <-- added for wavelet transforms

from util.img_utils import clear_color
from .posterior_mean_variance import get_mean_processor, get_var_processor

import cv2
import numpy as np
import pywt

def process_image_with_wavelet(input_tensor, debug_dir=None):
    """
    Round-trip an image tensor through a one-level Haar wavelet transform.

    The image is quantized to uint8. Each channel is decomposed with
    ``pywt.dwt2(..., 'haar')`` and immediately reconstructed with
    ``pywt.idwt2`` (no coefficients are modified). The result is returned as
    a float tensor scaled to [0, 1].

    Args:
        input_tensor (torch.Tensor): tensor of shape [1, 3, H, W].
            Floating-point inputs are assumed to lie in [0, 1]; values outside
            that range are clipped before quantization (x * 255, truncated).
            uint8 inputs are used as-is.
        debug_dir (str | None): if given, the quantized input and the
            reconstruction are written into this directory as ``np_img.jpg``
            and ``reconstructed_img.jpg``. The default (None) writes nothing.

    Returns:
        torch.Tensor: float32 CPU tensor of shape [1, 3, H, W] with values in
        [0, 1].

    Raises:
        ValueError: if the tensor is not of shape [1, 3, H, W].
        TypeError: if the dtype is neither floating point nor uint8.
    """
    # Validate shape: exactly one 3-channel image.
    if input_tensor.dim() != 4 or input_tensor.size(0) != 1 or input_tensor.size(1) != 3:
        raise ValueError("输入张量必须是形状为 [1, 3, H, W] 的4D张量")

    # CHW -> HWC on the CPU, without autograd history.
    hwc = input_tensor.detach().squeeze(0).permute(1, 2, 0).cpu()
    if input_tensor.is_floating_point():
        # .float() also covers float16/bfloat16 (bfloat16 has no NumPy dtype).
        # Clip before the cast: NumPy's float -> uint8 conversion does not
        # saturate, so out-of-range values would otherwise wrap or be undefined.
        np_img = np.clip(hwc.float().numpy() * 255, 0, 255).astype(np.uint8)
    elif input_tensor.dtype == torch.uint8:
        np_img = hwc.numpy()
    else:
        raise TypeError("不支持的张量数据类型: {}".format(input_tensor.dtype))

    def haar_roundtrip(channel):
        # One-level Haar DWT followed by its inverse, coefficients unchanged.
        cA, (cH, cV, cD) = pywt.dwt2(channel, 'haar')
        reconstructed = pywt.idwt2((cA, (cH, cV, cD)), 'haar')
        # For an odd H or W the inverse transform is one sample larger along
        # that axis; crop so the output keeps the input's spatial size.
        reconstructed = reconstructed[:channel.shape[0], :channel.shape[1]]
        # Round rather than truncate: the float reconstruction can land just
        # below an integer (e.g. 99.99999999999997), which astype would floor.
        return np.clip(np.rint(reconstructed), 0, 255).astype(np.uint8)

    reconstructed_img = np.stack(
        [haar_roundtrip(np_img[:, :, c]) for c in range(np_img.shape[2])],
        axis=-1,
    )

    if debug_dir is not None:
        # NOTE(review): cv2.imwrite treats the last axis as BGR; if the tensor
        # is RGB the saved debug images will have swapped red/blue.
        cv2.imwrite(os.path.join(debug_dir, "np_img.jpg"), np.ascontiguousarray(np_img))
        cv2.imwrite(os.path.join(debug_dir, "reconstructed_img.jpg"), reconstructed_img)

    # HWC uint8 -> [1, C, H, W] float32 in [0, 1].
    return torch.from_numpy(reconstructed_img).permute(2, 0, 1).unsqueeze(0).float() / 255



# ============== Wavelet Functions ==============
def wavelet_transform(x, wavelet='db1'):
    """
    Apply a single-level 2D DWT to ``x`` and split each subband into polar form.

    The subbands from ``pywt.dwt2`` are labelled, in order, 'LL'
    (approximation), 'LH', 'HL' and 'HH' (pywt's cH, cV, cD detail arrays).
    Each label maps to a tuple ``(np.abs(coef), np.angle(coef))``. For real
    input the "phase" is 0 or pi, encoding only the sign.

    Args:
        x: 2-D array-like accepted by ``pywt.dwt2``.
        wavelet: PyWavelets wavelet name.

    Returns:
        dict[str, tuple[np.ndarray, np.ndarray]]: subband -> (amplitude, phase).
    """
    approx, (horiz, vert, diag) = pywt.dwt2(x, wavelet)
    labelled = (('LL', approx), ('LH', horiz), ('HL', vert), ('HH', diag))
    return {label: (np.abs(band), np.angle(band)) for label, band in labelled}


def inverse_wavelet_transform(coeffs_dict, wavelet='db1'):
    """
    Reconstruct a 2-D signal from single-level wavelet coefficients.

    The real part of each subband is taken (so complex coefficient arrays are
    accepted), and the result is passed to ``pywt.idwt2``. Nothing here
    combines amplitude and phase. Each value must be a coefficient array, not
    the ``(amplitude, phase)`` tuple produced by ``wavelet_transform``.

    Args:
        coeffs_dict: mapping with keys 'LL', 'LH', 'HL', 'HH'.
        wavelet: PyWavelets wavelet name.

    Returns:
        np.ndarray: the reconstructed signal.
    """
    approx, horiz, vert, diag = (
        np.real(coeffs_dict[key]) for key in ('LL', 'LH', 'HL', 'HH')
    )
    return pywt.idwt2((approx, (horiz, vert, diag)), wavelet)


def inverse_wavelet_transform_combined(A, P, wavelet='db1'):
    """
    Inverse wavelet transform from a fused amplitude/phase pair.

    ``A * exp(1j * P)`` is used as the approximation (LL) subband. All three
    detail subbands are zero arrays of the same shape and dtype. The complex
    array is handed directly to ``pywt.idwt2``.

    Args:
        A (np.ndarray): combined amplitude.
        P (np.ndarray): combined phase (radians), broadcast-compatible with A.
        wavelet (str): PyWavelets wavelet name (default 'db1').

    Returns:
        np.ndarray: the reconstructed signal (complex, since LL is complex).
    """
    approx = A * np.exp(1j * P)
    empty_detail = np.zeros_like(approx)
    return pywt.idwt2((approx, (empty_detail,) * 3), wavelet)


# ============== End Wavelet Functions ==============

__SAMPLER__ = {}


def register_sampler(name: str):
    def wrapper(cls):
        if __SAMPLER__.get(name, None):
            raise NameError(f"Name {name} is already registered!")
        __SAMPLER__[name] = cls
        return cls

    return wrapper


def get_sampler(name: str):
    if __SAMPLER__.get(name, None) is None:
        raise NameError(f"Name {name} is not defined!")
    return __SAMPLER__[name]


def create_sampler(sampler,
                   steps,
                   noise_schedule,
                   model_mean_type,
                   model_var_type,
                   dynamic_threshold,
                   clip_denoised,
                   rescale_timesteps,
                   timestep_respacing=""):
    """
    Instantiate a registered sampler over a named beta schedule.

    Args:
        sampler: registry name passed to ``get_sampler``.
        steps: length of the full beta schedule.
        noise_schedule: schedule name for ``get_named_beta_schedule``.
        timestep_respacing: spec for ``space_timesteps``. If falsy, ``[steps]``
            is used, which keeps every step.
        The remaining arguments are forwarded unchanged to the sampler class.

    Returns:
        An instance of the registered sampler class.
    """
    sampler_cls = get_sampler(name=sampler)
    betas = get_named_beta_schedule(noise_schedule, steps)
    respacing = timestep_respacing if timestep_respacing else [steps]
    kept_steps = space_timesteps(steps, respacing)

    return sampler_cls(use_timesteps=kept_steps,
                       betas=betas,
                       model_mean_type=model_mean_type,
                       model_var_type=model_var_type,
                       dynamic_threshold=dynamic_threshold,
                       clip_denoised=clip_denoised,
                       rescale_timesteps=rescale_timesteps)


class GaussianDiffusion:
    """
    Base diffusion process: noise-schedule tables plus the sampling loop.

    All schedule arrays are NumPy float64 arrays with one entry per timestep.
    They are indexed per step and converted to torch by ``extract_and_expand``.
    Subclasses must implement ``p_sample``.
    """

    def __init__(self,
                 betas,
                 model_mean_type,
                 model_var_type,
                 dynamic_threshold,
                 clip_denoised,
                 rescale_timesteps):
        """
        Precompute forward-process and posterior coefficients.

        Args:
            betas: 1-D sequence of per-step noise variances, each in (0, 1].
            model_mean_type: key passed to ``get_mean_processor``.
            model_var_type: key passed to ``get_var_processor``.
            dynamic_threshold: forwarded to the mean processor.
            clip_denoised: forwarded to the mean processor.
            rescale_timesteps: if true, ``_scale_timesteps`` multiplies step
                indices by ``1000 / num_timesteps`` before they reach the model.
        """
        # use float64 for accuracy.
        betas = np.array(betas, dtype=np.float64)
        self.betas = betas
        assert self.betas.ndim == 1, "betas must be 1-D"
        assert (0 < self.betas).all() and (self.betas <= 1).all(), "betas must be in (0..1]"

        self.num_timesteps = int(self.betas.shape[0])
        self.rescale_timesteps = rescale_timesteps

        alphas = 1.0 - self.betas
        self.alphas_cumprod = np.cumprod(alphas, axis=0)
        self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
        self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
        assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
        self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
        self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
        self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        self.posterior_variance = (
            betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        )
        # posterior_variance[0] is 0 (alphas_cumprod_prev[0] == 1), so its log
        # would be -inf; the t=0 entry is replaced with the t=1 value.
        self.posterior_log_variance_clipped = np.log(
            np.append(self.posterior_variance[1], self.posterior_variance[1:])
        )
        self.posterior_mean_coef1 = (
            betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        )
        self.posterior_mean_coef2 = (
            (1.0 - self.alphas_cumprod_prev)
            * np.sqrt(alphas)
            / (1.0 - self.alphas_cumprod)
        )

        self.mean_processor = get_mean_processor(model_mean_type,
                                                 betas=betas,
                                                 dynamic_threshold=dynamic_threshold,
                                                 clip_denoised=clip_denoised)

        self.var_processor = get_var_processor(model_var_type,
                                               betas=betas)

    def q_mean_variance(self, x_start, t):
        """
        Get the distribution q(x_t | x_0).

        :param x_start: the [N x C x ...] tensor of noiseless inputs.
        :param t: the diffusion step index (0 means one step).
        :return: A tuple (mean, variance, log_variance) with the same shape as x_start.
        """
        mean = extract_and_expand(self.sqrt_alphas_cumprod, t, x_start) * x_start
        variance = extract_and_expand(1.0 - self.alphas_cumprod, t, x_start)
        log_variance = extract_and_expand(self.log_one_minus_alphas_cumprod, t, x_start)
        return mean, variance, log_variance

    def q_sample(self, x_start, t):
        """
        Diffuse the data for a given number of diffusion steps.

        Returns sqrt(alpha_bar_t) * x_start + sqrt(1 - alpha_bar_t) * noise,
        with fresh standard-normal noise.
        """
        noise = torch.randn_like(x_start)
        assert noise.shape == x_start.shape

        coef1 = extract_and_expand(self.sqrt_alphas_cumprod, t, x_start)
        coef2 = extract_and_expand(self.sqrt_one_minus_alphas_cumprod, t, x_start)
        return coef1 * x_start + coef2 * noise

    def q_posterior_mean_variance(self, x_start, x_t, t):
        """
        Compute the mean and variance of the diffusion posterior
        q(x_{t-1} | x_t, x_0).

        Returns:
            (posterior_mean, posterior_variance, posterior_log_variance_clipped),
            each broadcast to the shape of the inputs.
        """
        assert x_start.shape == x_t.shape
        coef1 = extract_and_expand(self.posterior_mean_coef1, t, x_start)
        coef2 = extract_and_expand(self.posterior_mean_coef2, t, x_t)
        posterior_mean = coef1 * x_start + coef2 * x_t
        posterior_variance = extract_and_expand(self.posterior_variance, t, x_t)
        posterior_log_variance_clipped = extract_and_expand(self.posterior_log_variance_clipped, t, x_t)
        assert (
            posterior_mean.shape[0]
            == posterior_variance.shape[0]
            == posterior_log_variance_clipped.shape[0]
            == x_start.shape[0]
        )
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def p_sample_loop(self,
                      model,
                      x_start,
                      measurement,
                      measurement_cond_fn,
                      record,
                      save_root,
                      wavelet_N=1.0,
                      wavelet_w=0.5,
                      wavelet_min_idx=998):
        """
        Run the reverse chain from ``x_start`` down to step 0.

        At each step the model proposes the next sample via ``p_sample``.
        ``measurement_cond_fn`` then applies the measurement-consistency update.
        For every step index strictly greater than ``wavelet_min_idx``, the
        conditioned image is additionally passed through
        ``process_image_with_wavelet``, a one-level Haar DWT followed by its
        inverse.

        Args:
            model: network passed to ``p_sample``.
            x_start: initial sample; its device is used for the timestep tensors.
            measurement: observed measurement, re-noised with ``q_sample`` each step.
            measurement_cond_fn: called with keyword arguments ``x_t``,
                ``measurement``, ``noisy_measurement``, ``x_prev`` and
                ``x_0_hat``; must return ``(new_img, distance)``, where
                ``distance`` supports ``.item()``.
            record: if truthy, save ``save_root/progress/x_XXXX.jpg`` for step
                indices divisible by 10.
            save_root: output directory for progress images.
            wavelet_N, wavelet_w: accepted for backward compatibility but
                currently unused. The amplitude/phase fusion scaled by
                N * exp(-w * t) is not applied.
            wavelet_min_idx: the wavelet pass runs for step indices greater than
                this. The default (998) reproduces the previous hard-coded
                behaviour: only index 999 is processed, and that index exists
                only when ``self.num_timesteps >= 1000``. With fewer (e.g.
                respaced) steps, lower this value to enable the pass.

        Returns:
            The final sample, detached from the autograd graph.
        """
        img = x_start
        device = x_start.device

        pbar = tqdm(list(range(self.num_timesteps))[::-1])
        for idx in pbar:
            time = torch.tensor([idx] * img.shape[0], device=device)

            img = img.requires_grad_()
            out = self.p_sample(x=img, t=time, model=model)

            # Give condition.
            noisy_measurement = self.q_sample(measurement, t=time)
            img, distance = measurement_cond_fn(x_t=out['sample'],
                                                measurement=measurement,
                                                noisy_measurement=noisy_measurement,
                                                x_prev=img,
                                                x_0_hat=out['pred_xstart'])

            # Wavelet round-trip on the conditioned sample.
            # NOTE(review): process_image_with_wavelet quantizes to uint8 and
            # treats float input as [0, 1]; confirm that matches the value range
            # of `img` at this point.
            if idx > wavelet_min_idx:
                # The helper returns a CPU tensor with no autograd history; move
                # it straight back to the sampling device.
                img = process_image_with_wavelet(img).to(device)

            img = img.detach_()
            pbar.set_postfix({'distance': distance.item()}, refresh=False)
            if record:
                if idx % 10 == 0:
                    file_path = os.path.join(save_root, f"progress/x_{str(idx).zfill(4)}.jpg")
                    plt.imsave(file_path, clear_color(img))
        return img

    def p_sample(self, model, x, t):
        """Draw one reverse step; implemented by subclasses."""
        raise NotImplementedError

    def p_mean_variance(self, model, x, t):
        """
        Run the model and derive the mean/variance of p(x_{t-1} | x_t).

        If the model output has twice the channels of ``x``, it is split into
        (mean output, variance values). Otherwise the full output is used for
        both processors.

        Returns:
            dict with 'mean', 'variance', 'log_variance' and 'pred_xstart'.
        """
        model_output = model(x, self._scale_timesteps(t))
        if model_output.shape[1] == 2 * x.shape[1]:
            model_output, model_var_values = torch.split(model_output, x.shape[1], dim=1)
        else:
            model_var_values = model_output

        model_mean, pred_xstart = self.mean_processor.get_mean_and_xstart(x, t, model_output)
        model_variance, model_log_variance = self.var_processor.get_variance(model_var_values, t)
        assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape

        return {'mean': model_mean,
                'variance': model_variance,
                'log_variance': model_log_variance,
                'pred_xstart': pred_xstart}

    def _scale_timesteps(self, t):
        """Map step indices onto the model's expected timestep scale."""
        if self.rescale_timesteps:
            return t.float() * (1000.0 / self.num_timesteps)
        return t


def space_timesteps(num_timesteps, section_counts):
    """
    Choose which steps of the original diffusion process to keep.

    Args:
        num_timesteps: number of steps in the original process.
        section_counts: one of
            - "ddimN": a single integer stride that yields exactly N steps
              starting at 0;
            - a comma-separated string such as "10,15,20";
            - an int or a list of ints. The process is split into that many
              near-equal sections, and each section is sampled with the given
              count at evenly spaced (rounded) positions.

    Returns:
        set[int]: original step indices to keep.

    Raises:
        ValueError: if no integer stride yields exactly N steps ("ddimN"), or if
            a section is asked for more steps than it contains.
    """
    if isinstance(section_counts, str):
        if section_counts.startswith("ddim"):
            desired_count = int(section_counts[len("ddim"):])
            # Strides go up to num_timesteps inclusive: a stride equal to
            # num_timesteps yields the single step [0], so "ddim1" is reachable.
            for i in range(1, num_timesteps + 1):
                if len(range(0, num_timesteps, i)) == desired_count:
                    return set(range(0, num_timesteps, i))
            raise ValueError(
                f"cannot create exactly {desired_count} steps with an integer stride"
            )
        section_counts = [int(x) for x in section_counts.split(",")]
    elif isinstance(section_counts, int):
        section_counts = [section_counts]

    # The first `extra` sections get one additional step each.
    size_per = num_timesteps // len(section_counts)
    extra = num_timesteps % len(section_counts)
    start_idx = 0
    all_steps = []
    for i, section_count in enumerate(section_counts):
        size = size_per + (1 if i < extra else 0)
        if size < section_count:
            raise ValueError(
                f"cannot divide section of {size} steps into {section_count}"
            )
        # Fractional stride so the first and last step of the section are hit.
        frac_stride = 1 if section_count <= 1 else (size - 1) / (section_count - 1)
        cur_idx = 0.0
        taken_steps = []
        for _ in range(section_count):
            taken_steps.append(start_idx + round(cur_idx))
            cur_idx += frac_stride
        all_steps += taken_steps
        start_idx += size
    return set(all_steps)


class SpacedDiffusion(GaussianDiffusion):
    """
    A diffusion process that keeps only a subset of a base process's steps.

    The new betas are chosen so that the cumulative alpha product at each kept
    step equals the base process's value at that step. ``timestep_map[k]`` is
    the original index of the k-th kept step, and the model is wrapped so it
    receives original indices.
    """

    def __init__(self, use_timesteps, **kwargs):
        """
        Args:
            use_timesteps: iterable of original step indices to keep.
            **kwargs: ``GaussianDiffusion`` arguments. ``kwargs["betas"]`` is
                the full base schedule and is replaced by the respaced betas.
        """
        self.use_timesteps = set(use_timesteps)
        self.timestep_map = []
        self.original_num_steps = len(kwargs["betas"])

        full_process = GaussianDiffusion(**kwargs)
        prev_cumprod = 1.0
        respaced_betas = []
        for step, cumprod in enumerate(full_process.alphas_cumprod):
            if step not in self.use_timesteps:
                continue
            # beta_k = 1 - abar_k / abar_{previous kept step}
            respaced_betas.append(1 - cumprod / prev_cumprod)
            prev_cumprod = cumprod
            self.timestep_map.append(step)
        kwargs["betas"] = np.array(respaced_betas)
        super().__init__(**kwargs)

    def p_mean_variance(self, model, *args, **kwargs):
        parent_impl = super().p_mean_variance
        return parent_impl(self._wrap_model(model), *args, **kwargs)

    # NOTE(review): GaussianDiffusion in this file defines no training_losses,
    # condition_mean or condition_score, so the three delegations below raise
    # AttributeError when called.
    def training_losses(self, model, *args, **kwargs):
        parent_impl = super().training_losses
        return parent_impl(self._wrap_model(model), *args, **kwargs)

    def condition_mean(self, cond_fn, *args, **kwargs):
        parent_impl = super().condition_mean
        return parent_impl(self._wrap_model(cond_fn), *args, **kwargs)

    def condition_score(self, cond_fn, *args, **kwargs):
        parent_impl = super().condition_score
        return parent_impl(self._wrap_model(cond_fn), *args, **kwargs)

    def _wrap_model(self, model):
        """Wrap ``model`` (once) so respaced step indices map to original ones."""
        if not isinstance(model, _WrappedModel):
            model = _WrappedModel(
                model, self.timestep_map, self.rescale_timesteps, self.original_num_steps
            )
        return model

    def _scale_timesteps(self, t):
        # Rescaling is done inside _WrappedModel, after index mapping.
        return t


class _WrappedModel:
    def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps):
        self.model = model
        self.timestep_map = timestep_map
        self.rescale_timesteps = rescale_timesteps
        self.original_num_steps = original_num_steps

    def __call__(self, x, ts, **kwargs):
        map_tensor = torch.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
        new_ts = map_tensor[ts]
        if self.rescale_timesteps:
            new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
        return self.model(x, new_ts, **kwargs)


@register_sampler(name='ddpm')
class DDPM(SpacedDiffusion):
    """Ancestral DDPM sampler, registered as 'ddpm'."""

    def p_sample(self, model, x, t):
        """
        Draw x_{t-1} from the model's Gaussian p(x_{t-1} | x_t).

        Args:
            model: network passed through ``p_mean_variance``.
            x: current sample x_t, batch-first.
            t: per-sample timestep indices, one per batch element.

        Returns:
            dict with 'sample' (x_{t-1}) and 'pred_xstart'.
        """
        out = self.p_mean_variance(model, x, t)
        noise = torch.randn_like(x)
        # Per-sample mask instead of `if t != 0`: the truth value of a
        # multi-element tensor is ambiguous, so that check raised for batch
        # sizes > 1. Samples at t == 0 receive no noise.
        nonzero_mask = (t != 0).to(out['mean'].dtype).view(-1, *([1] * (x.dim() - 1)))
        sample = out['mean'] + nonzero_mask * torch.exp(0.5 * out['log_variance']) * noise
        return {'sample': sample, 'pred_xstart': out['pred_xstart']}


@register_sampler(name='ddim')
class DDIM(SpacedDiffusion):
    """DDIM sampler, registered as 'ddim'. eta=0 gives a deterministic update."""

    def p_sample(self, model, x, t, eta=0.0):
        """
        One DDIM update from x_t to x_{t-1}.

        Args:
            model: network passed through ``p_mean_variance``.
            x: current sample x_t, batch-first.
            t: per-sample timestep indices, one per batch element.
            eta: scales the stochastic term sigma; 0 makes sigma zero.

        Returns:
            dict with 'sample' (x_{t-1}) and 'pred_xstart'.
        """
        out = self.p_mean_variance(model, x, t)
        eps = self.predict_eps_from_x_start(x, t, out['pred_xstart'])
        alpha_bar = extract_and_expand(self.alphas_cumprod, t, x)
        alpha_bar_prev = extract_and_expand(self.alphas_cumprod_prev, t, x)
        sigma = (
            eta
            * torch.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
            * torch.sqrt(1 - alpha_bar / alpha_bar_prev)
        )
        noise = torch.randn_like(x)
        mean_pred = (
            out["pred_xstart"] * torch.sqrt(alpha_bar_prev)
            + torch.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
        )
        # Per-sample mask instead of `if t != 0`, which raised for batch
        # sizes > 1. Samples at t == 0 receive no noise.
        nonzero_mask = (t != 0).to(mean_pred.dtype).view(-1, *([1] * (x.dim() - 1)))
        sample = mean_pred + nonzero_mask * sigma * noise
        return {"sample": sample, "pred_xstart": out["pred_xstart"]}

    def predict_eps_from_x_start(self, x_t, t, pred_xstart):
        """Recover the noise estimate implied by x_t and the predicted x_0."""
        coef1 = extract_and_expand(self.sqrt_recip_alphas_cumprod, t, x_t)
        coef2 = extract_and_expand(self.sqrt_recipm1_alphas_cumprod, t, x_t)
        return (coef1 * x_t - pred_xstart) / coef2


def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
    """
    Return a predefined beta schedule as a 1-D NumPy array.

    Args:
        schedule_name: "linear" or "cosine".
            - "linear": evenly spaced from 1e-4 to 0.02, both scaled by
              1000 / num_diffusion_timesteps, in float64.
            - "cosine": built by ``betas_for_alpha_bar`` from a squared cosine
              alpha_bar with offset 0.008.
        num_diffusion_timesteps: number of steps.

    Raises:
        NotImplementedError: for any other schedule name.
    """
    if schedule_name == "cosine":
        def cosine_alpha_bar(t):
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

        return betas_for_alpha_bar(num_diffusion_timesteps, cosine_alpha_bar)
    if schedule_name != "linear":
        raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
    step_scale = 1000 / num_diffusion_timesteps
    return np.linspace(step_scale * 0.0001, step_scale * 0.02,
                       num_diffusion_timesteps, dtype=np.float64)


def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    """
    Discretize a continuous alpha_bar(t), t in [0, 1], into per-step betas.

    beta_i = min(1 - alpha_bar((i + 1) / n) / alpha_bar(i / n), max_beta)

    Args:
        num_diffusion_timesteps: number of steps n.
        alpha_bar: callable mapping t in [0, 1] to the cumulative alpha product.
        max_beta: upper cap for each beta.

    Returns:
        np.ndarray of length n.
    """
    n = num_diffusion_timesteps
    return np.array([
        min(1 - alpha_bar((step + 1) / n) / alpha_bar(step / n), max_beta)
        for step in range(n)
    ])


def extract_and_expand(array, time, target):
    """
    Gather ``array[time]`` and broadcast it to ``target``.

    Args:
        array (np.ndarray): per-timestep values.
        time: index tensor into ``array``.
        target (torch.Tensor): tensor whose shape and device are matched.

    Returns:
        torch.Tensor: float32 values expanded to ``target.shape``, with trailing
        singleton dimensions added as needed.
    """
    values = torch.from_numpy(array).to(target.device)[time].float()
    missing = target.ndim - values.ndim
    if missing > 0:
        values = values.reshape(tuple(values.shape) + (1,) * missing)
    return values.expand_as(target)


def expand_as(array, target):
    """
    Broadcast ``array`` to ``target``'s shape on ``target``'s device.

    Args:
        array: NumPy array, Python/NumPy float scalar, or torch.Tensor. Trailing
            singleton dimensions are appended until its rank matches ``target``.
        target (torch.Tensor): tensor whose shape and device are matched.

    Returns:
        torch.Tensor: ``array`` expanded to ``target.shape`` on ``target.device``.
    """
    if isinstance(array, np.ndarray):
        array = torch.from_numpy(array)
    elif isinstance(array, (float, np.floating)):
        # `np.float` was only an alias of builtin `float` and was removed in
        # NumPy 1.24. Referencing it raised AttributeError for every
        # non-ndarray input, including tensors.
        array = torch.tensor([array])
    while array.ndim < target.ndim:
        array = array.unsqueeze(-1)
    return array.expand_as(target).to(target.device)


def _extract_into_tensor(arr, timesteps, broadcast_shape):
    res = torch.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
    while len(res.shape) < len(broadcast_shape):
        res = res[..., None]
    return res.expand(broadcast_shape)
