import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import yaml
from score_sde.models import utils as mutils
from score_sde.models.ema import ExponentialMovingAverage
from score_sde.losses import get_optimizer
from utils import dict2namespace, restore_checkpoint
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from utils import clf2diff, diff2clf
from dct import dct_2d, idct_2d
from torchvision.utils import save_image
import clip
from diffusers import DDPMPipeline
from Diffpure import RevGuidedDiffusion
import os
import torchvision.utils as vutils
from guided_diffusion.script_util import create_model_and_diffusion, model_and_diffusion_defaults
import torchvision
from torchvision.transforms import ToPILImage


# Default compute device for the whole module: the GPU when PyTorch can see
# one, otherwise the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

def set_seed(seed):
    """Seed NumPy's global RNG, torch's CPU RNG and, if available, every CUDA RNG.

    Args:
        seed: integer seed forwarded unchanged to each generator.
    """
    np.random.seed(seed)
    torch.manual_seed(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)

def SDE_DIFFUSION_CIFAR10_Creator(config_path='score_sde/cifar10.yml',
                                  checkpoint_path="./checkpoint/cifar10/sde_diffusion/checkpoint_8.pth"):
    """Build the score-SDE CIFAR-10 network and load its EMA weights.

    Args:
        config_path: YAML model config; defaults to the repository's CIFAR-10 config.
        checkpoint_path: training checkpoint passed to ``restore_checkpoint``.

    Returns:
        The network on ``device``, with the EMA parameters copied into it and
        switched to eval mode (purification is inference-only, so train-time
        layers such as dropout must be disabled).
    """
    with open(config_path, 'r') as f:
        # NOTE(review): yaml.Loader can construct arbitrary Python objects;
        # only point this at trusted config files.
        config = yaml.load(f, Loader=yaml.Loader)
    config = dict2namespace(config)
    config.data.image_size = 32
    config.data.num_channels = 3
    diffusion = mutils.create_model(config).to(device)
    # The state dict mirrors a training checkpoint (model/optimizer/ema/step),
    # presumably the layout restore_checkpoint loads into; only the EMA
    # weights are actually kept afterwards.
    optimizer = get_optimizer(config, diffusion.parameters())
    ema = ExponentialMovingAverage(diffusion.parameters(), decay=config.model.ema_rate)
    state = dict(model=diffusion, optimizer=optimizer, ema=ema, step=0)
    restore_checkpoint(checkpoint_path, state, device)
    ema.copy_to(diffusion.parameters())
    diffusion.eval()
    return diffusion

def Guided_Diffusion_ImageNet_Creator(config_path='guided_diffusion/imagenet.yml',
                                      checkpoint_path="./checkpoint//imagenet/guide_diffusion/256x256_diffusion_uncond.pt"):
    """Build the guided-diffusion 256x256 unconditional ImageNet UNet.

    The YAML ``model`` section overrides ``model_and_diffusion_defaults()``;
    the returned diffusion-process object from ``create_model_and_diffusion``
    is discarded.

    Args:
        config_path: YAML config whose ``model`` section is merged into the defaults.
        checkpoint_path: state dict loaded (on CPU) into the network.

    Returns:
        The network, converted to fp16 when the config sets ``use_fp16``,
        in eval mode and moved to ``device``.
    """
    with open(config_path, 'r') as f:
        # NOTE(review): yaml.Loader can construct arbitrary Python objects;
        # only point this at trusted config files.
        config = yaml.load(f, Loader=yaml.Loader)
    config = dict2namespace(config)
    model_config = model_and_diffusion_defaults()
    model_config.update(vars(config.model))
    Diffusion_pipe, _ = create_model_and_diffusion(**model_config)
    Diffusion_pipe.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
    # With use_fp16 the UNet casts activations to half precision, so its
    # torso must be converted as well or the first conv hits a dtype mismatch.
    if model_config.get('use_fp16', False):
        Diffusion_pipe.convert_to_fp16()
    Diffusion_pipe.eval()
    return Diffusion_pipe.to(device)


def get_beta_schedule(beta_start, beta_end, num_diffusion_timesteps):
    """Return a linear beta schedule as a float32 tensor on ``device``.

    The ``num_diffusion_timesteps`` evenly spaced values from ``beta_start``
    to ``beta_end`` (inclusive) are computed in float64 by NumPy and only
    then rounded to float32.
    """
    grid = np.linspace(beta_start, beta_end, num=num_diffusion_timesteps, dtype=np.float64)
    return torch.as_tensor(grid, dtype=torch.float32, device=device)

def normalize_to_0_1(x):
    """Affinely map ``x`` from [-1, 1] to [0, 1], clamping values outside that range."""
    return torch.clamp(x * 0.5 + 0.5, 0, 1)

def to_minus1_1(x):
    """Affinely map ``x`` from [0, 1] to [-1, 1], clamping values outside that range."""
    centred = x - 0.5
    return torch.clamp(centred * 2, -1, 1)

class AP_Forward_MANI(nn.Module):
    """Diffusion purification with frequency-domain guidance from the input image.

    ``forward`` un-normalizes a CLIP-normalized batch to [0, 1], diffuses it to
    timestep ``attack_steps`` with spatially re-weighted Gaussian noise, runs a
    strided reverse process in which every predicted ``x0`` takes its
    low-frequency amplitude (and, within ``delta``, its low-frequency phase)
    from the input image, and returns the result resized to 224x224 and
    re-normalized with the CLIP statistics.

    Args:
        diffusion: noise-prediction network, called as ``diffusion(x_t, t)``.
        attack_steps: timestep the input is diffused to.
        denoising_steps: number of reverse steps; timesteps are strided by
            ``attack_steps // denoising_steps``.
        sampling_method: ``'DDIM'`` sets ``eta = 0``; anything else sets ``eta = 1``.
        clip_mean, clip_std: per-channel statistics, viewed as (1, 3, 1, 1).
        amplitude: radius (in centred FFT index units) inside which the
            amplitude spectrum is copied from the input image.
        phase: radius inside which the phase is projected onto
            ``[phase_adv - delta, phase_adv + delta]``.
        Imagenet: if true, work at 256x256 with a network whose first 3 output
            channels are the noise prediction; otherwise work at 32x32.
    """

    def __init__(self, diffusion, attack_steps, denoising_steps, sampling_method, clip_mean, clip_std,  amplitude,phase,Imagenet):
        super().__init__()
        self.diffusion = diffusion
        self.attack_steps = attack_steps
        self.denoising_steps = denoising_steps
        # Linear beta schedule over 1000 training timesteps.
        self.betas = get_beta_schedule(1e-4, 2e-2, 1000)
        self.eta = 0 if sampling_method == 'DDIM' else 1
        self.clip_mean = clip_mean
        self.clip_std = clip_std
        self.amplitude = amplitude
        self.phase = phase
        # Half-width (radians) of the phase projection window.
        self.delta = 0.3
        self.Imagenet = Imagenet
        self.sequence = self._timestep_sequence()

    def _timestep_sequence(self):
        """Return the 0-based reverse-process timesteps ``stride-1, 2*stride-1, ...``.

        ``stride = attack_steps // denoising_steps``; the last entry is the
        largest multiple of ``stride`` not exceeding ``attack_steps``, minus 1.
        """
        stride = self.attack_steps // self.denoising_steps
        return [s - 1 for s in range(stride, self.attack_steps + 1, stride)]

    def _to_model_resolution(self, x):
        """Bilinearly resize ``x`` to 256x256 (ImageNet) or 32x32 (otherwise)."""
        size = 256 if self.Imagenet else 32
        return F.interpolate(x, size, mode='bilinear', align_corners=False)

    @staticmethod
    def _centered_frequency_radius(height, width, device):
        """Return an (H, W) tensor of distances from DC in an ``fftshift``-ed spectrum.

        ``arange(n) - n // 2`` matches fftshift's layout for both even and odd sizes.
        """
        u = torch.arange(width, device=device) - width // 2
        v = torch.arange(height, device=device) - height // 2
        U, V = torch.meshgrid(u, v, indexing='xy')
        return torch.sqrt(U ** 2 + V ** 2)

    def compute_alpha(self, t):
        """Return cumulative ``prod(1 - beta)`` at integer timesteps ``t``, shaped (N, 1, 1, 1).

        A zero beta is prepended, so ``t = -1`` maps to 1 (the clean sample).
        """
        beta = torch.cat([torch.zeros(1).to(t.device), self.betas.to(t.device)], dim=0)
        return (1 - beta).cumprod(0).index_select(0, t + 1).view(-1, 1, 1, 1)

    def frequency_mask_to_spatial(self, x, energy_weight):
        """Filter ``x`` in the frequency domain by ``energy_weight`` and return a spatial map.

        The centred spectrum of each channel of ``x`` is multiplied by
        ``energy_weight`` (same layout, shifts applied only to the last two
        dims), inverted to the spatial domain (real part), then min-max
        normalized per channel to [0, 1).

        NOTE(review): the min/max are taken over the whole batch, so one
        sample's map depends on the other samples in the batch — confirm
        this coupling is intended.
        """
        num_channels = energy_weight.shape[1]
        spectrum = torch.fft.fftshift(torch.fft.fft2(x[:, :num_channels], norm="ortho"), dim=(-2, -1))
        weighted = torch.fft.ifftshift(spectrum * energy_weight, dim=(-2, -1))
        spatial = torch.fft.ifft2(weighted, norm="ortho").real
        lo = spatial.amin(dim=(0, 2, 3), keepdim=True)
        hi = spatial.amax(dim=(0, 2, 3), keepdim=True)
        spatial = (spatial - lo) / (hi - lo + 1e-8)
        return spatial.to(energy_weight.dtype)

    def compute_frequency_energy_map(self, x, gamma=1.8):
        """Return per-frequency weights that favour low-energy radial bands.

        The centred spectrum of ``x`` (B, C, H, W) is split into 8 equal-width
        rings from radius 0 to the maximum radius (the last ring includes its
        outer edge). Each ring gets weight ``(1 / mean magnitude) ** gamma``
        per sample and channel, so rings carrying less energy get larger
        weights. The result is divided by its per-channel maximum over the
        batch. Output has the same shape as ``x`` in fftshift layout.
        """
        _, _, H, W = x.shape
        radius = self._centered_frequency_radius(H, W, x.device)
        num_bins = 8
        bin_edges = torch.linspace(0, float(radius.max()), num_bins + 1, device=x.device)
        # Shift only the spatial dims: a bare fftshift would also roll the batch axis.
        magnitude = torch.abs(torch.fft.fftshift(torch.fft.fft2(x, norm="ortho"), dim=(-2, -1)))
        energy_weight = torch.zeros_like(magnitude)
        for i in range(num_bins):
            if i == num_bins - 1:
                # Closed on the right so the maximum-radius (corner) frequency is covered.
                mask = radius >= bin_edges[i]
            else:
                mask = (radius >= bin_edges[i]) & (radius < bin_edges[i + 1])
            region_size = mask.sum().clamp(min=1)
            bin_energy = (magnitude * mask).sum(dim=(-2, -1), keepdim=True) / region_size + 1e-6
            energy_weight = energy_weight + mask * ((1.0 / bin_energy) ** gamma)
        return energy_weight / (energy_weight.amax(dim=(0, 2, 3), keepdim=True) + 1e-8)

    def add_energy_adaptive_noise(self,x, t):
        """Diffuse ``x`` to timestep ``t`` with spatially re-weighted Gaussian noise.

        ``x`` is expected in [0, 1]; it is mapped to [-1, 1] (clamped) and
        resized to the model resolution. The noise is scaled per pixel by a
        [0, 1] map derived from the frequency energy weights, so its spread
        never exceeds that of the standard forward process.

        Returns:
            ``sqrt(alpha_bar_t) * x + sqrt(1 - alpha_bar_t) * weighted_noise``.
        """
        x = self._to_model_resolution(to_minus1_1(x))
        energy_mask = self.compute_frequency_energy_map(x)
        spatial_weight = self.frequency_mask_to_spatial(x, energy_mask)
        spatial_weight = (spatial_weight - spatial_weight.min()) / (spatial_weight.max() - spatial_weight.min() + 1e-8)
        noise = torch.randn_like(x) * spatial_weight
        timesteps = torch.full((x.size(0),), t, device=x.device, dtype=torch.long)
        alpha_bar = self.compute_alpha(timesteps)
        return x * alpha_bar.sqrt() + noise * (1 - alpha_bar).sqrt()

    def add_noise(self, x, t):
        """Diffuse ``x`` (expected in [0, 1]) to timestep ``t`` with standard Gaussian noise."""
        x = self._to_model_resolution(to_minus1_1(x))
        noise = torch.randn_like(x)
        timesteps = torch.full((x.size(0),), t, device=x.device, dtype=torch.long)
        a = self.compute_alpha(timesteps)
        return x * a.sqrt() + noise * (1 - a).sqrt()

    def freqpure_amplitude_phase_exchange(self, x_adv, x0_t):
        """Inject low-frequency content of ``x_adv`` into the prediction ``x0_t``.

        Both inputs are (B, C, H, W) images (presumably in [0, 1], since the
        output is clamped there); ``x_adv`` is bilinearly resized to
        ``x0_t``'s spatial size if needed. Inside radius ``self.amplitude``
        the amplitude spectrum is taken from ``x_adv``. Inside radius
        ``self.phase`` the phase of ``x0_t`` is clipped to within
        ``self.delta`` of ``x_adv``'s phase. Returns the real part of the
        inverse transform clamped to [0, 1].
        """
        if x_adv.shape[-2:] != x0_t.shape[-2:]:
            x_adv = F.interpolate(x_adv, size=x0_t.shape[-2:], mode='bilinear', align_corners=False)
        H, W = x0_t.shape[-2:]
        radius = self._centered_frequency_radius(H, W, x0_t.device)
        f_adv = torch.fft.fftshift(torch.fft.fft2(x_adv, norm="ortho"), dim=(-2, -1))
        f_x0 = torch.fft.fftshift(torch.fft.fft2(x0_t, norm="ortho"), dim=(-2, -1))
        phase_adv = torch.angle(f_adv)
        phase_x0 = torch.angle(f_x0)
        amp = torch.where(radius <= self.amplitude, torch.abs(f_adv), torch.abs(f_x0))
        projected_phase = torch.clip(phase_x0, min=phase_adv - self.delta, max=phase_adv + self.delta)
        phase = torch.where(radius <= self.phase, projected_phase, phase_x0)
        f_recon = torch.fft.ifftshift(amp * torch.exp(1j * phase), dim=(-2, -1))
        return torch.fft.ifft2(f_recon, norm="ortho").real.clamp(0, 1)

    def denoising_step(self, x, seq, x_adv_clean):
        """Run the strided reverse process for the low-resolution (non-ImageNet) model.

        Each step predicts noise, forms ``x0``, applies the frequency exchange
        against ``x_adv_clean`` in [0, 1] space, then takes a generalized DDIM
        update with ``self.eta`` and clamps to [-1, 1].

        Args:
            x: noisy sample at timestep ``seq[-1]``, in [-1, 1] model space.
            seq: increasing timesteps visited in reverse; the final step goes to -1.
            x_adv_clean: reference image used by the frequency exchange.
        """
        n = x.size(0)
        xt = x
        seq_next = [-1] + list(seq[:-1])
        for i, j in zip(reversed(seq), reversed(seq_next)):
            t = torch.full((n,), i, device=x.device, dtype=torch.long)
            next_t = torch.full((n,), j, device=x.device, dtype=torch.long)
            alpha = self.compute_alpha(t)
            alpha_next = self.compute_alpha(next_t)
            xt = xt.requires_grad_()
            eps = self.diffusion(xt, t)
            x0 = (xt - eps * (1 - alpha).sqrt()) / alpha.sqrt()
            x0 = normalize_to_0_1(x0)
            x0 = self.freqpure_amplitude_phase_exchange(x_adv_clean, x0)
            x0 = to_minus1_1(x0)
            c1 = self.eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
            c2 = ((1 - alpha_next) - c1 ** 2).sqrt()
            xt = alpha_next.sqrt() * x0 + c1 * torch.randn_like(xt) + c2 * eps
            xt = xt.clamp(-1, 1)
        return xt

    def denoising_process_ImageNet(self, x, seq, x_adv_clean):
        """Run the strided reverse process for the ImageNet network.

        Same update as ``denoising_step``, except timesteps are passed to the
        network as float tensors and only the first 3 output channels are
        used as the noise prediction.
        """
        n = x.size(0)
        seq_next = [-1] + list(seq[:-1])
        xt = x
        for i, j in zip(reversed(seq), reversed(seq_next)):
            t = (torch.ones(n) * i).to(x.device)
            next_t = (torch.ones(n) * j).to(x.device)
            at = self.compute_alpha(t.long())
            at_next = self.compute_alpha(next_t.long())
            et = self.diffusion(xt, t)
            et, _ = torch.split(et, 3, dim=1)
            x0_t = (xt - et * (1 - at).sqrt()) / at.sqrt()
            x0_t = normalize_to_0_1(x0_t)
            x0_t = self.freqpure_amplitude_phase_exchange(x_adv_clean, x0_t)
            x0_t = to_minus1_1(x0_t)
            c1 = (
                self.eta * ((1 - at / at_next) *
                            (1 - at_next) / (1 - at)).sqrt()
            )
            c2 = ((1 - at_next) - c1 ** 2).sqrt()
            xt = at_next.sqrt() * x0_t + c1 * torch.randn_like(x) + c2 * et
            xt = xt.clamp(-1, 1)
        return xt

    def forward(self, x, y=None):
        """Purify a CLIP-normalized batch and return it CLIP-normalized at 224x224.

        Args:
            x: CLIP-normalized images; assumed (B, 3, H, W) since the
                statistics are viewed as (1, 3, 1, 1).
            y: unused.
        """
        mean = self.clip_mean.view(1, 3, 1, 1)
        std = self.clip_std.view(1, 3, 1, 1)
        x_clean = (x * std + mean).clamp(0, 1)
        noisy_x = self.add_energy_adaptive_noise(x_clean, self.attack_steps)
        # Recomputed rather than using self.sequence so that changes to
        # attack_steps / denoising_steps after construction still apply.
        sequence = self._timestep_sequence()
        # Resize the reference once instead of inside every denoising step.
        if x_clean.shape[-2:] != noisy_x.shape[-2:]:
            x_clean = F.interpolate(x_clean, size=noisy_x.shape[-2:], mode='bilinear', align_corners=False)
        if self.Imagenet:
            purified = self.denoising_process_ImageNet(noisy_x, sequence, x_clean)
        else:
            purified = self.denoising_step(noisy_x, sequence, x_clean)
        purified = normalize_to_0_1(purified)
        purified = F.interpolate(purified, 224, mode='bilinear', align_corners=False)
        return (purified - mean) / std
    

