import torch
import torch.nn as nn
import numpy as np
from PIL import Image
import os
import random
import scipy.stats as st
import copy
from utils import ROOT_PATH
from functools import partial
import copy
import pickle as pkl
from torch.autograd import Variable
import torch.nn.functional as F

from dataset import params
from model import get_model
from utils_linbp import linbp_forw_resnet50, linbp_backw_resnet50
from utils_sgm import register_hook_for_densenet, register_hook_for_resnet, register_hook_for_vit
from utils_dct import dct_2d, idct_2d

class BaseAttack(object):
    """Shared scaffolding for the transfer attacks in this module.

    Subclasses implement ``forward(inps, labels)``; calling an instance
    dispatches to it.  Images fed to the surrogate model are normalised with the
    per-channel ``mean``/``std`` from ``params(model_name)``; the helpers below
    convert between that normalised space and [0, 1] pixel space.  Subclasses
    are expected to set ``self.epsilon`` (read by ``_update_perts``).
    """

    def __init__(self, attack_name, model_name, target, pre_trained=True, weight_path=None):
        self.attack_name = attack_name
        self.model_name = model_name
        self.target = target
        # The perturbation is moved along +sign(grad) (see _update_perts), so an
        # untargeted attack ascends the loss (+1) and a targeted one descends the
        # loss of the target label (-1).
        self.loss_flag = -1 if self.target else 1
        self.used_params = params(self.model_name)

        # Load the surrogate model on the GPU in inference mode.
        self.model = get_model(self.model_name, pre_trained=pre_trained, weight_path=weight_path)
        self.model.cuda()
        self.model.eval()

    def forward(self, *input):
        """Run the attack; must be overridden by subclasses."""
        raise NotImplementedError

    def _norm_stats(self, inps):
        """Return (mean, std) shaped (C, 1, 1) with ``inps``' dtype and device."""
        mean = torch.as_tensor(self.used_params['mean'], dtype=inps.dtype, device=inps.device)
        std = torch.as_tensor(self.used_params['std'], dtype=inps.dtype, device=inps.device)
        return mean[:, None, None], std[:, None, None]

    def _mul_std_add_mean(self, inps):
        """Map normalised images back to pixel space: ``inps * std + mean``.

        Out-of-place: the caller's tensor is left untouched.  (The previous
        in-place version silently de-normalised the batch passed to
        ``forward`` whenever it was already on the GPU.)
        """
        mean, std = self._norm_stats(inps)
        return inps * std + mean

    def _sub_mean_div_std(self, inps):
        """Normalise pixel-space images: ``(inps - mean) / std`` (out-of-place)."""
        mean, std = self._norm_stats(inps)
        return (inps - mean) / std

    def _save_images(self, inps, filenames, output_dir):
        """Write each normalised image of ``inps`` to ``output_dir/filename`` as 8-bit RGB."""
        unnorm_inps = self._mul_std_add_mean(inps.detach())
        for i, filename in enumerate(filenames):
            save_path = os.path.join(output_dir, filename)
            # c,h,w -> h,w,c, clipped to the valid pixel range
            image = unnorm_inps[i].permute([1, 2, 0]).clamp(0, 1)
            image = Image.fromarray((image.cpu().numpy() * 255).astype(np.uint8))
            image.save(save_path)

    def _update_inps(self, inps, grad, step_size):
        """Take one signed-gradient step in pixel space, clip to [0, 1], re-normalise."""
        unnorm_inps = self._mul_std_add_mean(inps.detach())
        unnorm_inps = unnorm_inps + step_size * grad.sign()
        unnorm_inps = torch.clamp(unnorm_inps, min=0, max=1).detach()
        return self._sub_mean_div_std(unnorm_inps)

    def _update_perts(self, perts, grad, step_size):
        """Signed-gradient step on ``perts``, projected onto [-epsilon, epsilon]."""
        perts = perts + step_size * grad.sign()
        return torch.clamp(perts, -self.epsilon, self.epsilon)

    def _return_perts(self, clean_inps, inps):
        """Return the pixel-space perturbation between normalised ``inps`` and ``clean_inps``."""
        clean_unnorm = self._mul_std_add_mean(clean_inps.detach())
        adv_unnorm = self._mul_std_add_mean(inps.detach())
        return adv_unnorm - clean_unnorm

    def __call__(self, *input, **kwargs):
        return self.forward(*input, **kwargs)

class PGD(BaseAttack):
    """Iterative sign-gradient attack with mean-|grad| normalisation.

    Each of ``steps`` iterations moves the perturbation by a fixed step of
    1/255 along the sign of the normalised loss gradient, projects it onto the
    L-inf ball of radius ``epsilon`` and keeps the adversarial image in [0, 1].

    NOTE(review): unlike the other attacks in this module, ``forward`` does not
    de-normalise ``inps`` first -- they are treated directly as [0, 1] pixels,
    while the returned tensor is normalised.  Confirm callers pass raw pixels.
    ``decay`` is stored but unused (momentum is disabled).
    """

    def __init__(self, model_name, steps=20, epsilon=4/255, target=False, decay=1.0, pre_trained=True, weight_path=False):
        super().__init__('PGD', model_name, target, pre_trained=pre_trained, weight_path=weight_path)
        self.steps = steps
        self.epsilon = epsilon
        self.step_size = 1/255
        self.decay = decay
        self.image_size = 224

    def forward(self, inps, labels):
        """Return ``(adv_inps, None)``; ``adv_inps`` is normalised and detached."""
        pixels = inps.cuda()
        targets = labels.cuda()
        criterion = nn.CrossEntropyLoss()

        delta = torch.zeros_like(pixels).cuda()
        delta.requires_grad_()

        for _ in range(self.steps):
            logits = self.model(self._sub_mean_div_std(pixels + delta))
            objective = self.loss_flag * criterion(logits, targets).cuda()
            objective.backward()

            raw = delta.grad.data
            scaled = raw / raw.abs().mean(dim=(1, 2, 3), keepdim=True)

            stepped = self._update_perts(delta.data, scaled, self.step_size)
            # keep pixels + delta inside the valid image range
            delta.data = (pixels.data + stepped).clamp(0.0, 1.0) - pixels.data
            delta.grad.data.zero_()

        adv = self._sub_mean_div_std(pixels + delta.data)
        return adv.detach(), None

class MI(BaseAttack):
    """MI-FGSM: iterative FGSM driven by an accumulated (momentum) gradient.

    Each iteration normalises the loss gradient by its per-sample mean
    absolute value, adds ``decay`` times the previous momentum, and steps the
    perturbation by ``step_size = epsilon / steps`` along its sign.  The
    perturbation stays within [-epsilon, epsilon] and the image within [0, 1].
    """

    def __init__(self, model_name, steps=10, epsilon=16/255, target=False, decay=1.0, pre_trained=True, weight_path=False):
        # Registered as 'MI' (was the copy-pasted 'PGD', which mislabelled this attack).
        super(MI, self).__init__('MI', model_name, target, pre_trained=pre_trained, weight_path=weight_path)
        self.epsilon = epsilon
        self.steps = steps
        self.step_size = self.epsilon/self.steps
        self.decay = decay
        self.image_size = 224

    def forward(self, inps, labels):
        """Return ``(adv_inps, None)``; ``adv_inps`` is normalised and detached."""
        inps = inps.cuda()  # normalised images; 4-D given the dim=[1,2,3] reductions below
        labels = labels.cuda()
        loss = nn.CrossEntropyLoss()

        momentum = torch.zeros_like(inps).cuda()
        unnorm_inps = self._mul_std_add_mean(inps)

        perts = torch.zeros_like(unnorm_inps).cuda()
        perts.requires_grad_()

        for i in range(self.steps):
            outputs = self.model((self._sub_mean_div_std(unnorm_inps + perts)))
            cost = self.loss_flag * loss(outputs, labels).cuda()
            cost.backward()
            grad = perts.grad.data
            grad = grad / torch.mean(torch.abs(grad), dim=[1,2,3], keepdim=True)

            # momentum accumulation
            grad += momentum*self.decay
            momentum = grad

            perts.data = self._update_perts(perts.data, grad, self.step_size)
            # keep the adversarial image in [0, 1]
            perts.data = torch.clamp(unnorm_inps.data + perts.data, 0.0, 1.0) - unnorm_inps.data
            perts.grad.data.zero_()
        return (self._sub_mean_div_std(unnorm_inps+perts.data)).detach(), None
    
class NI(BaseAttack):
    """NI-FGSM: momentum attack whose gradient is taken at a Nesterov look-ahead.

    Each iteration evaluates the loss at
    ``x_adv + step_size * decay * sign(momentum)`` rather than at ``x_adv``,
    then performs the usual MI-FGSM update (mean-|grad| normalisation,
    momentum accumulation with ``decay``, signed step, projection).
    """

    def __init__(self, model_name, steps=10, epsilon=16/255, target=False, decay=1.0, pre_trained=True, weight_path=False):
        super(NI, self).__init__('NI', model_name, target, pre_trained=pre_trained, weight_path=weight_path)
        self.epsilon = epsilon
        self.steps = steps
        self.step_size = self.epsilon/self.steps
        self.decay = decay
        self.image_size = 224

    def forward(self, inps, labels):
        """Return ``(adv_inps, None)``; ``adv_inps`` is normalised and detached."""
        inps = inps.cuda()  # normalised images; 4-D given the dim=[1,2,3] reductions below
        labels = labels.cuda()
        loss = nn.CrossEntropyLoss()

        momentum = torch.zeros_like(inps).cuda()
        unnorm_inps = self._mul_std_add_mean(inps)

        perts = torch.zeros_like(unnorm_inps).cuda()
        perts.requires_grad_()

        for i in range(self.steps):
            xadv = unnorm_inps + perts
            # Nesterov look-ahead x_nes = x_adv + alpha * mu * g: the jump is scaled
            # by the decay factor mu (previously ignored).  With the default
            # decay=1.0 this is identical to the old behaviour.
            xnes = xadv + self.step_size * self.decay * momentum.sign()
            outputs = self.model((self._sub_mean_div_std(xnes)))
            cost = self.loss_flag * loss(outputs, labels).cuda()
            cost.backward()
            grad = perts.grad.data
            grad = grad / torch.mean(torch.abs(grad), dim=[1,2,3], keepdim=True)

            grad += momentum*self.decay
            momentum = grad

            perts.data = self._update_perts(perts.data, grad, self.step_size)
            perts.data = torch.clamp(unnorm_inps.data + perts.data, 0.0, 1.0) - unnorm_inps.data
            perts.grad.data.zero_()
        return (self._sub_mean_div_std(unnorm_inps+perts.data)).detach(), None

class GI(BaseAttack):
    """GI-FGSM: MI-FGSM with a global momentum-initialisation warm-up.

    The loop runs ``pre_steps + steps`` iterations.  When the first
    ``pre_steps`` iterations are over, the perturbation is reset to zero while
    the accumulated momentum is kept, so the real attack starts from the clean
    image with a pre-converged momentum.
    """

    def __init__(self, model_name, steps=10, epsilon=16/255, target=False, decay=1.0, pre_trained=True, weight_path=False, pre_steps=5):
        super().__init__('GI', model_name, target, pre_trained=pre_trained, weight_path=weight_path)
        self.epsilon = epsilon
        self.steps = steps
        self.pre_steps = pre_steps
        self.step_size = epsilon / steps
        self.decay = decay
        self.image_size = 224

    def forward(self, inps, labels):
        """Return ``(adv_inps, None)``; ``adv_inps`` is normalised and detached."""
        inps = inps.cuda()
        labels = labels.cuda()
        criterion = nn.CrossEntropyLoss()

        velocity = torch.zeros_like(inps).cuda()
        clean = self._mul_std_add_mean(inps)

        def zero_delta():
            # fresh all-zero perturbation that tracks gradients
            return torch.zeros_like(clean).cuda().requires_grad_()

        delta = zero_delta()
        for it in range(self.pre_steps + self.steps):
            if it == self.pre_steps:
                # warm-up over: restart from the clean image, keep the momentum
                delta = zero_delta()

            logits = self.model(self._sub_mean_div_std(clean + delta))
            (self.loss_flag * criterion(logits, labels).cuda()).backward()

            raw = delta.grad.data
            velocity = raw / raw.abs().mean(dim=(1, 2, 3), keepdim=True) + velocity * self.decay

            stepped = self._update_perts(delta.data, velocity, self.step_size)
            delta.data = (clean.data + stepped).clamp(0.0, 1.0) - clean.data
            delta.grad.data.zero_()

        return self._sub_mean_div_std(clean + delta.data).detach(), None


class SSA(BaseAttack):
    """Spectrum Simulation Attack (SSA) without momentum.

    At every iteration the gradient is averaged over ``number`` spectrum
    "views" of the current adversarial image.  To build a view, Gaussian noise
    scaled by ``epsilon`` is added, the ``dct_2d`` coefficients are multiplied
    by a random factor in [0.5, 1.5) and the result is mapped back with
    ``idct_2d``.  The gradient is taken w.r.t. the view itself (a detached
    leaf).  ``decay`` is stored but unused.
    """

    def __init__(self, model_name, steps=10, epsilon=16/255, target=False, decay=1.0, pre_trained=True, weight_path=False,number = 20):
        super().__init__('SSA', model_name, target, pre_trained=pre_trained, weight_path=weight_path)
        self.epsilon = epsilon
        self.steps = steps
        self.number = number
        self.step_size = epsilon / steps
        self.decay = decay
        self.image_size = 224

    def forward(self, inps, labels):
        """Return ``(adv_inps, None)``; ``adv_inps`` is normalised and detached."""
        inps = inps.cuda()
        labels = labels.cuda()
        criterion = nn.CrossEntropyLoss()

        clean = self._mul_std_add_mean(inps)
        delta = torch.zeros_like(clean).cuda()

        def view_grad(x):
            # Draw one spectrum view of x (no graph for the transform itself)
            # and return d(loss)/d(view).
            with torch.no_grad():
                noise = (torch.randn_like(x) * self.epsilon).cuda()
                spectrum = dct_2d(x + noise)
                scale = (torch.rand_like(x) + 0.5).cuda()
                view = idct_2d(spectrum * scale).detach()
            view.requires_grad_()
            logits = self.model(self._sub_mean_div_std(view))
            (self.loss_flag * criterion(logits, labels).cuda()).backward()
            return view.grad.data

        for _ in range(self.steps):
            current = clean + delta
            grad = sum(view_grad(current) for _ in range(self.number)) / self.number
            grad = grad / grad.abs().mean(dim=(1, 2, 3), keepdim=True)

            stepped = self._update_perts(delta.data, grad, self.step_size)
            delta.data = (clean.data + stepped).clamp(0.0, 1.0) - clean.data

        return self._sub_mean_div_std(clean + delta.data).detach(), None
    
class SSA_MI_FGSM(BaseAttack):
    """Spectrum Simulation Attack combined with MI-FGSM momentum.

    The gradient is averaged over ``number`` spectrum views (Gaussian noise
    scaled by ``epsilon``, DCT coefficients scaled by a random factor in
    [0.5, 1.5), inverse DCT).  It is normalised by its per-sample mean absolute
    value and accumulated into a momentum buffer with ``decay`` before the
    signed step.
    """

    def __init__(self, model_name, steps=10, epsilon=16/255, target=False, decay=1.0, pre_trained=True, weight_path=False,number = 20):
        super().__init__('SSA_MI_FGSM', model_name, target, pre_trained=pre_trained, weight_path=weight_path)
        self.epsilon = epsilon
        self.steps = steps
        self.number = number
        self.step_size = epsilon / steps
        self.decay = decay
        self.image_size = 224

    def forward(self, inps, labels):
        """Return ``(adv_inps, None)``; ``adv_inps`` is normalised and detached."""
        inps = inps.cuda()
        labels = labels.cuda()
        criterion = nn.CrossEntropyLoss()

        velocity = torch.zeros_like(inps).cuda()
        clean = self._mul_std_add_mean(inps)
        delta = torch.zeros_like(clean).cuda()

        def view_grad(x):
            # One spectrum view of x; returns d(loss)/d(view).
            with torch.no_grad():
                noise = (torch.randn_like(x) * self.epsilon).cuda()
                spectrum = dct_2d(x + noise)
                scale = (torch.rand_like(x) + 0.5).cuda()
                view = idct_2d(spectrum * scale).detach()
            view.requires_grad_()
            logits = self.model(self._sub_mean_div_std(view))
            (self.loss_flag * criterion(logits, labels).cuda()).backward()
            return view.grad.data

        for _ in range(self.steps):
            current = clean + delta
            grad = sum(view_grad(current) for _ in range(self.number)) / self.number
            velocity = grad / grad.abs().mean(dim=(1, 2, 3), keepdim=True) + velocity * self.decay

            stepped = self._update_perts(delta.data, velocity, self.step_size)
            delta.data = (clean.data + stepped).clamp(0.0, 1.0) - clean.data

        return self._sub_mean_div_std(clean + delta.data).detach(), None


def DI(X_in, prob=0.7, min_size=299, max_size=330):
    """Input-diversity transform: random resize followed by random zero-padding.

    With probability ``prob`` the batch is resized (nearest) to a random square
    side in ``[min_size, max_size)`` and then zero-padded at a random offset to
    exactly ``max_size x max_size``; otherwise ``X_in`` is returned unchanged.

    Args:
        X_in: 4-D image batch ``(N, C, H, W)`` (required by ``F.interpolate``
            with a 2-element size).
        prob: probability of applying the transform.
        min_size: smallest resize side (inclusive).
        max_size: final padded side; the resize side is strictly smaller.

    Returns:
        The transformed batch, or ``X_in`` itself when the transform is skipped.
    """
    rnd = int(np.random.randint(min_size, max_size, size=1)[0])
    h_rem = max_size - rnd
    w_rem = max_size - rnd
    pad_top = int(np.random.randint(0, h_rem, size=1)[0])
    pad_bottom = h_rem - pad_top
    pad_left = int(np.random.randint(0, w_rem, size=1)[0])
    pad_right = w_rem - pad_left
    c = np.random.rand(1)[0]
    if c <= prob:
        # F.pad orders padding from the last dim backwards:
        # (left, right, top, bottom).  The old (left, top, right, bottom) order
        # produced non-square outputs of varying size.
        resized = F.interpolate(X_in, size=(rnd, rnd))
        return F.pad(resized, (pad_left, pad_right, pad_top, pad_bottom), mode='constant', value=0)
    return X_in

class DI_FGSM(BaseAttack):
    """DI-FGSM: iterative FGSM whose gradient is taken through ``DI``.

    ``DI`` is the random resize-and-pad transform, applied with probability
    ``prob``.  There is no momentum; ``decay`` is stored but unused.
    """

    def __init__(self, model_name, steps=10, epsilon=16/255, target=False, decay=1.0, pre_trained=True, weight_path=False, prob = 0.7):
        super(DI_FGSM, self).__init__('DI-FGSM', model_name, target, pre_trained=pre_trained, weight_path=weight_path)
        self.epsilon = epsilon
        self.steps = steps
        self.step_size = self.epsilon/self.steps
        self.decay = decay
        self.image_size = 224
        self.prob = prob

    def forward(self, inps, labels):
        """Return ``(adv_inps, None)``; ``adv_inps`` is normalised and detached."""
        inps = inps.cuda()  # normalised images; 4-D given the dim=[1,2,3] reductions below
        labels = labels.cuda()
        loss = nn.CrossEntropyLoss()

        unnorm_inps = self._mul_std_add_mean(inps)

        perts = torch.zeros_like(unnorm_inps).cuda()
        perts.requires_grad_()

        for i in range(self.steps):
            # Pass the configured probability; previously DI always used its
            # default of 0.7 and the `prob` argument had no effect.
            outputs = self.model((self._sub_mean_div_std(DI(unnorm_inps + perts, prob=self.prob))))
            cost = self.loss_flag * loss(outputs, labels).cuda()
            cost.backward()
            grad = perts.grad.data
            grad = grad / torch.mean(torch.abs(grad), dim=[1,2,3], keepdim=True)

            perts.data = self._update_perts(perts.data, grad, self.step_size)
            perts.data = torch.clamp(unnorm_inps.data + perts.data, 0.0, 1.0) - unnorm_inps.data
            perts.grad.data.zero_()
        return (self._sub_mean_div_std(unnorm_inps+perts.data)).detach(), None

class DI_MI_FGSM(BaseAttack):
    """DI-MI-FGSM: MI-FGSM with gradients taken through ``DI``.

    ``DI`` is the random resize-and-pad transform, applied with probability
    ``prob``.  Gradients are normalised by their per-sample mean absolute value
    and accumulated into a momentum buffer with ``decay``.
    """

    def __init__(self, model_name, steps=10, epsilon=16/255, target=False, decay=1.0, pre_trained=True, weight_path=False, prob = 0.7):
        # Registered as 'DI-MI-FGSM' (was the copy-pasted 'DI-FGSM').
        super(DI_MI_FGSM, self).__init__('DI-MI-FGSM', model_name, target, pre_trained=pre_trained, weight_path=weight_path)
        self.epsilon = epsilon
        self.steps = steps
        self.step_size = self.epsilon/self.steps
        self.decay = decay
        self.image_size = 224
        self.prob = prob

    def forward(self, inps, labels):
        """Return ``(adv_inps, None)``; ``adv_inps`` is normalised and detached."""
        inps = inps.cuda()  # normalised images; 4-D given the dim=[1,2,3] reductions below
        labels = labels.cuda()
        loss = nn.CrossEntropyLoss()

        momentum = torch.zeros_like(inps).cuda()
        unnorm_inps = self._mul_std_add_mean(inps)

        perts = torch.zeros_like(unnorm_inps).cuda()
        perts.requires_grad_()

        for i in range(self.steps):
            # Use the configured probability (it was previously ignored).
            outputs = self.model(self._sub_mean_div_std(DI(unnorm_inps + perts, prob=self.prob)))
            cost = self.loss_flag * loss(outputs, labels).cuda()
            cost.backward()
            grad = perts.grad.data
            grad = grad / torch.mean(torch.abs(grad), dim=[1,2,3], keepdim=True)

            grad += momentum*self.decay
            momentum = grad

            perts.data = self._update_perts(perts.data, grad, self.step_size)
            perts.data = torch.clamp(unnorm_inps.data + perts.data, 0.0, 1.0) - unnorm_inps.data
            perts.grad.data.zero_()
        return (self._sub_mean_div_std(unnorm_inps+perts.data)).detach(), None
    
def gkern(kernlen=15, nsig=3):
    """Return a normalised 2-D Gaussian kernel of shape ``(kernlen, kernlen)``.

    The kernel is the outer product of the standard-normal pdf sampled at
    ``kernlen`` evenly spaced points on ``[-nsig, nsig]``, scaled to sum to 1.
    """
    grid = np.linspace(-nsig, nsig, kernlen)
    pdf = st.norm.pdf(grid)
    raw = pdf[:, None] * pdf[None, :]
    return raw / raw.sum()
    
class TI_FGSM(BaseAttack):
    """TI-FGSM: iterative FGSM with translation-invariant gradient smoothing.

    Each gradient is convolved depthwise (one kernel per RGB channel) with a
    ``ti_kernel_size x ti_kernel_size`` Gaussian from ``gkern`` before
    normalisation.  There is no momentum; ``decay`` is stored but unused.
    """

    def __init__(self, model_name, steps=10, epsilon=16/255, target=False, decay=1.0, pre_trained=True, weight_path=False, ti_kernel_size=5):
        # Registered as 'TI-FGSM' (was the copy-pasted 'DI-FGSM').
        super(TI_FGSM, self).__init__('TI-FGSM', model_name, target, pre_trained=pre_trained, weight_path=weight_path)
        self.epsilon = epsilon
        self.steps = steps
        self.step_size = self.epsilon/self.steps
        self.decay = decay
        self.image_size = 224

        self.ti_kernel_size=ti_kernel_size
        kernel = gkern(ti_kernel_size, 3).astype(np.float32)
        # depthwise conv weight of shape (3, 1, k, k): the same kernel per channel
        gaussian_kernel = np.stack([kernel, kernel, kernel])
        gaussian_kernel = np.expand_dims(gaussian_kernel, 1)
        self.gaussian_kernel = torch.from_numpy(gaussian_kernel).cuda()

    def forward(self, inps, labels):
        """Return ``(adv_inps, None)``; ``adv_inps`` is normalised and detached."""
        inps = inps.cuda()  # normalised images; 4-D given the dim=[1,2,3] reductions below
        labels = labels.cuda()
        loss = nn.CrossEntropyLoss()

        unnorm_inps = self._mul_std_add_mean(inps)

        perts = torch.zeros_like(unnorm_inps).cuda()
        perts.requires_grad_()

        for i in range(self.steps):
            outputs = self.model((self._sub_mean_div_std(unnorm_inps + perts)))
            cost = self.loss_flag * loss(outputs, labels).cuda()
            cost.backward()
            grad = perts.grad.data

            # TI: smooth the gradient (padding keeps H and W for odd kernel sizes)
            grad = F.conv2d(grad, self.gaussian_kernel, bias=None, stride=1, padding=((self.ti_kernel_size-1)//2,(self.ti_kernel_size-1)//2), groups=3)

            grad = grad / torch.mean(torch.abs(grad), dim=[1,2,3], keepdim=True)

            perts.data = self._update_perts(perts.data, grad, self.step_size)
            perts.data = torch.clamp(unnorm_inps.data + perts.data, 0.0, 1.0) - unnorm_inps.data
            perts.grad.data.zero_()
        return (self._sub_mean_div_std(unnorm_inps+perts.data)).detach(), None

class TI_MI_FGSM(BaseAttack):
    """TI-MI-FGSM: MI-FGSM with translation-invariant gradient smoothing.

    Each gradient is convolved depthwise with a Gaussian from ``gkern``,
    normalised by its per-sample mean absolute value, and accumulated into a
    momentum buffer with ``decay``.
    """

    def __init__(self, model_name, steps=10, epsilon=16/255, target=False, decay=1.0, pre_trained=True, weight_path=False, ti_kernel_size=5):
        # Registered as 'TI-MI-FGSM' (was the copy-pasted 'DI-FGSM').
        super(TI_MI_FGSM, self).__init__('TI-MI-FGSM', model_name, target, pre_trained=pre_trained, weight_path=weight_path)
        self.epsilon = epsilon
        self.steps = steps
        self.step_size = self.epsilon/self.steps
        self.decay = decay
        self.image_size = 224

        self.ti_kernel_size=ti_kernel_size
        kernel = gkern(ti_kernel_size, 3).astype(np.float32)
        # depthwise conv weight of shape (3, 1, k, k): the same kernel per channel
        gaussian_kernel = np.stack([kernel, kernel, kernel])
        gaussian_kernel = np.expand_dims(gaussian_kernel, 1)
        self.gaussian_kernel = torch.from_numpy(gaussian_kernel).cuda()

    def forward(self, inps, labels):
        """Return ``(adv_inps, None)``; ``adv_inps`` is normalised and detached."""
        inps = inps.cuda()  # normalised images; 4-D given the dim=[1,2,3] reductions below
        labels = labels.cuda()
        loss = nn.CrossEntropyLoss()

        momentum = torch.zeros_like(inps).cuda()
        unnorm_inps = self._mul_std_add_mean(inps)

        perts = torch.zeros_like(unnorm_inps).cuda()
        perts.requires_grad_()

        for i in range(self.steps):
            outputs = self.model((self._sub_mean_div_std(unnorm_inps + perts)))
            cost = self.loss_flag * loss(outputs, labels).cuda()
            cost.backward()
            grad = perts.grad.data

            # TI: smooth the gradient
            grad = F.conv2d(grad, self.gaussian_kernel, bias=None, stride=1, padding=((self.ti_kernel_size-1)//2,(self.ti_kernel_size-1)//2), groups=3)

            grad = grad / torch.mean(torch.abs(grad), dim=[1,2,3], keepdim=True)

            # MI: momentum accumulation
            grad += momentum*self.decay
            momentum = grad

            perts.data = self._update_perts(perts.data, grad, self.step_size)
            perts.data = torch.clamp(unnorm_inps.data + perts.data, 0.0, 1.0) - unnorm_inps.data
            perts.grad.data.zero_()
        return (self._sub_mean_div_std(unnorm_inps+perts.data)).detach(), None

class SSA_TI_MI_FGSM(BaseAttack):
    """Spectrum Simulation Attack + translation-invariant smoothing + momentum.

    1. The gradient is averaged over ``number`` spectrum views (noise scaled
       by ``epsilon``, random DCT scaling in [0.5, 1.5), inverse DCT).
    2. It is smoothed depthwise with a Gaussian kernel from ``gkern``.
    3. It is normalised by its per-sample mean |grad| and accumulated into a
       momentum buffer with ``decay``.
    """

    def __init__(self, model_name, steps=10, epsilon=16/255, target=False, decay=1.0, pre_trained=True, weight_path=False,number = 20,ti_kernel_size=5):
        super().__init__('SSA_TI_MI_FGSM', model_name, target, pre_trained=pre_trained, weight_path=weight_path)
        self.epsilon = epsilon
        self.steps = steps
        self.number = number
        self.step_size = epsilon / steps
        self.decay = decay
        self.image_size = 224

        self.ti_kernel_size = ti_kernel_size
        base = gkern(ti_kernel_size, 3).astype(np.float32)
        # depthwise conv weight (3, 1, k, k): one identical kernel per RGB channel
        self.gaussian_kernel = torch.from_numpy(np.repeat(base[None, None], 3, axis=0)).cuda()

    def forward(self, inps, labels):
        """Return ``(adv_inps, None)``; ``adv_inps`` is normalised and detached."""
        inps = inps.cuda()
        labels = labels.cuda()
        criterion = nn.CrossEntropyLoss()

        velocity = torch.zeros_like(inps).cuda()
        clean = self._mul_std_add_mean(inps)
        delta = torch.zeros_like(clean).cuda()
        pad = (self.ti_kernel_size - 1) // 2

        def view_grad(x):
            # One spectrum view of x; returns d(loss)/d(view).
            with torch.no_grad():
                noise = (torch.randn_like(x) * self.epsilon).cuda()
                spectrum = dct_2d(x + noise)
                scale = (torch.rand_like(x) + 0.5).cuda()
                view = idct_2d(spectrum * scale).detach()
            view.requires_grad_()
            logits = self.model(self._sub_mean_div_std(view))
            (self.loss_flag * criterion(logits, labels).cuda()).backward()
            return view.grad.data

        for _ in range(self.steps):
            current = clean + delta
            grad = sum(view_grad(current) for _ in range(self.number)) / self.number
            smoothed = F.conv2d(grad, self.gaussian_kernel, bias=None, stride=1, padding=(pad, pad), groups=3)
            velocity = smoothed / smoothed.abs().mean(dim=(1, 2, 3), keepdim=True) + velocity * self.decay

            stepped = self._update_perts(delta.data, velocity, self.step_size)
            delta.data = (clean.data + stepped).clamp(0.0, 1.0) - clean.data

        return self._sub_mean_div_std(clean + delta.data).detach(), None

class SSA_TI_DI_MI_FGSM(BaseAttack):
    """SSA + translation-invariant smoothing + input diversity + momentum.

    This is the same as ``SSA_TI_MI_FGSM``, except that each spectrum view is
    additionally passed through ``DI`` (random resize-and-pad, default
    probability) before it is normalised and fed to the model.
    """

    def __init__(self, model_name, steps=10, epsilon=16/255, target=False, decay=1.0, pre_trained=True, weight_path=False,number = 20,ti_kernel_size=5):
        super().__init__('SSA_TI_DI_MI_FGSM', model_name, target, pre_trained=pre_trained, weight_path=weight_path)
        self.epsilon = epsilon
        self.steps = steps
        self.number = number
        self.step_size = epsilon / steps
        self.decay = decay
        self.image_size = 224

        self.ti_kernel_size = ti_kernel_size
        base = gkern(ti_kernel_size, 3).astype(np.float32)
        # depthwise conv weight (3, 1, k, k): one identical kernel per RGB channel
        self.gaussian_kernel = torch.from_numpy(np.repeat(base[None, None], 3, axis=0)).cuda()

    def forward(self, inps, labels):
        """Return ``(adv_inps, None)``; ``adv_inps`` is normalised and detached."""
        inps = inps.cuda()
        labels = labels.cuda()
        criterion = nn.CrossEntropyLoss()

        velocity = torch.zeros_like(inps).cuda()
        clean = self._mul_std_add_mean(inps)
        delta = torch.zeros_like(clean).cuda()
        pad = (self.ti_kernel_size - 1) // 2

        def view_grad(x):
            # One spectrum view of x, diversified by DI; returns d(loss)/d(view).
            with torch.no_grad():
                noise = (torch.randn_like(x) * self.epsilon).cuda()
                spectrum = dct_2d(x + noise)
                scale = (torch.rand_like(x) + 0.5).cuda()
                view = idct_2d(spectrum * scale).detach()
            view.requires_grad_()
            logits = self.model(self._sub_mean_div_std(DI(view)))
            (self.loss_flag * criterion(logits, labels).cuda()).backward()
            return view.grad.data

        for _ in range(self.steps):
            current = clean + delta
            grad = sum(view_grad(current) for _ in range(self.number)) / self.number
            smoothed = F.conv2d(grad, self.gaussian_kernel, bias=None, stride=1, padding=(pad, pad), groups=3)
            velocity = smoothed / smoothed.abs().mean(dim=(1, 2, 3), keepdim=True) + velocity * self.decay

            stepped = self._update_perts(delta.data, velocity, self.step_size)
            delta.data = (clean.data + stepped).clamp(0.0, 1.0) - clean.data

        return self._sub_mean_div_std(clean + delta.data).detach(), None

class SSA_TI_DI_MI_FGSM_vit(BaseAttack):
    """SSA + TI + DI + momentum for models with a fixed 224x224 input.

    This is the same as ``SSA_TI_DI_MI_FGSM``, except that the ``DI`` output is
    bilinearly resized back to 224x224 before normalisation.
    """

    def __init__(self, model_name, steps=10, epsilon=16/255, target=False, decay=1.0, pre_trained=True, weight_path=False,number = 20,ti_kernel_size=5):
        super().__init__('SSA_TI_DI_MI_FGSM_vit', model_name, target, pre_trained=pre_trained, weight_path=weight_path)
        self.epsilon = epsilon
        self.steps = steps
        self.number = number
        self.step_size = epsilon / steps
        self.decay = decay
        self.image_size = 224

        self.ti_kernel_size = ti_kernel_size
        base = gkern(ti_kernel_size, 3).astype(np.float32)
        # depthwise conv weight (3, 1, k, k): one identical kernel per RGB channel
        self.gaussian_kernel = torch.from_numpy(np.repeat(base[None, None], 3, axis=0)).cuda()

    def forward(self, inps, labels):
        """Return ``(adv_inps, None)``; ``adv_inps`` is normalised and detached."""
        inps = inps.cuda()
        labels = labels.cuda()
        criterion = nn.CrossEntropyLoss()

        velocity = torch.zeros_like(inps).cuda()
        clean = self._mul_std_add_mean(inps)
        delta = torch.zeros_like(clean).cuda()
        pad = (self.ti_kernel_size - 1) // 2

        def view_grad(x):
            # One spectrum view of x, diversified by DI and resized back to
            # 224x224; returns d(loss)/d(view).
            with torch.no_grad():
                noise = (torch.randn_like(x) * self.epsilon).cuda()
                spectrum = dct_2d(x + noise)
                scale = (torch.rand_like(x) + 0.5).cuda()
                view = idct_2d(spectrum * scale).detach()
            view.requires_grad_()
            resized = F.interpolate(DI(view), size=(224, 224), mode='bilinear')
            logits = self.model(self._sub_mean_div_std(resized))
            (self.loss_flag * criterion(logits, labels).cuda()).backward()
            return view.grad.data

        for _ in range(self.steps):
            current = clean + delta
            grad = sum(view_grad(current) for _ in range(self.number)) / self.number
            smoothed = F.conv2d(grad, self.gaussian_kernel, bias=None, stride=1, padding=(pad, pad), groups=3)
            velocity = smoothed / smoothed.abs().mean(dim=(1, 2, 3), keepdim=True) + velocity * self.decay

            stepped = self._update_perts(delta.data, velocity, self.step_size)
            delta.data = (clean.data + stepped).clamp(0.0, 1.0) - clean.data

        return self._sub_mean_div_std(clean + delta.data).detach(), None

class SSA_TI_DI_GI_FGSM(BaseAttack):
    """SSA + TI + DI with global momentum initialisation (GI).

    The loop runs ``pre_steps + steps`` iterations of the SSA/TI/DI/momentum
    update.  When the first ``pre_steps`` iterations end, the perturbation is
    reset to zero while the accumulated momentum is kept.
    """

    def __init__(self, model_name, steps=10, epsilon=16/255, target=False, decay=1.0, pre_trained=True, weight_path=False,number = 20,ti_kernel_size=5,pre_steps=5):
        # Registered under its own name (was the copy-pasted 'SSA_TI_DI_MI_FGSM').
        super(SSA_TI_DI_GI_FGSM, self).__init__('SSA_TI_DI_GI_FGSM', model_name, target, pre_trained=pre_trained, weight_path=weight_path)
        self.epsilon = epsilon
        self.steps = steps
        self.step_size = self.epsilon/self.steps
        self.decay = decay
        self.image_size = 224
        self.number = number

        self.ti_kernel_size=ti_kernel_size
        kernel = gkern(ti_kernel_size, 3).astype(np.float32)
        # depthwise conv weight of shape (3, 1, k, k): the same kernel per channel
        gaussian_kernel = np.stack([kernel, kernel, kernel])
        gaussian_kernel = np.expand_dims(gaussian_kernel, 1)
        self.gaussian_kernel = torch.from_numpy(gaussian_kernel).cuda()

        self.pre_steps = pre_steps

    def forward(self, inps, labels):
        """Return ``(adv_inps, None)``; ``adv_inps`` is normalised and detached."""
        inps = inps.cuda()  # normalised images; 4-D given the dim=[1,2,3] reductions below
        labels = labels.cuda()
        loss = nn.CrossEntropyLoss()

        momentum = torch.zeros_like(inps).cuda()
        unnorm_inps = self._mul_std_add_mean(inps)

        perts = torch.zeros_like(unnorm_inps).cuda()

        for i in range(self.pre_steps + self.steps):
            if i == self.pre_steps:
                # warm-up finished: restart from the clean image, keep the momentum
                perts = torch.zeros_like(unnorm_inps).cuda()
            grad = 0
            img_x = unnorm_inps + perts
            for _ in range(self.number):
                # SSA view: Gaussian noise + random DCT-coefficient scaling
                with torch.no_grad():
                    gauss = torch.randn_like(img_x) * self.epsilon
                    gauss = gauss.cuda()
                    x_dct = dct_2d(img_x + gauss)
                    mask = (torch.rand_like(img_x) + 0.5).cuda()
                    x_idct = idct_2d(x_dct * mask).detach()

                x_idct.requires_grad_()  # take the gradient w.r.t. the view itself
                # DI: random resize-and-pad
                outputs = self.model((self._sub_mean_div_std(DI(x_idct))))
                cost = self.loss_flag * loss(outputs, labels).cuda()
                cost.backward()
                grad += x_idct.grad.data

            grad = grad / self.number
            # TI: smooth the averaged gradient
            grad = F.conv2d(grad, self.gaussian_kernel, bias=None, stride=1, padding=((self.ti_kernel_size-1)//2,(self.ti_kernel_size-1)//2), groups=3)

            grad = grad / torch.mean(torch.abs(grad), dim=[1,2,3], keepdim=True)

            # MI: momentum accumulation
            grad += momentum*self.decay
            momentum = grad

            perts.data = self._update_perts(perts.data, grad, self.step_size)
            perts.data = torch.clamp(unnorm_inps.data + perts.data, 0.0, 1.0) - unnorm_inps.data
        return (self._sub_mean_div_std(unnorm_inps+perts.data)).detach(), None

class FDA(BaseAttack):
    """Feature Disruptive Attack (FDA), plain iterative variant (no momentum).

    Forward hooks registered in ``get_opt_layers`` record intermediate
    activations on every forward pass. Each step feeds the clean and the
    adversarial batch together, splits every recorded activation into its clean
    and adversarial halves, and updates the perturbation with
    ``self._update_perts`` using the gradient of
    ``self.loss_flag * get_fda_loss()``.
    """

    # Hooked module types whose 4-D outputs are assumed to be (N, C, H, W).
    _CHANNEL_FIRST_TYPES = (nn.ReLU, nn.AvgPool2d, nn.BatchNorm2d, nn.AdaptiveAvgPool2d)
    # Hooked module types that act on the last axis (feature-last layout).
    _CHANNEL_LAST_TYPES = (nn.LayerNorm, nn.Linear)

    def __init__(self, model_name, steps=10, epsilon=16/255, target=False, decay=1.0, pre_trained=True, weight_path=False):
        super(FDA, self).__init__('FDA', model_name, target, pre_trained=pre_trained, weight_path=weight_path)
        self.epsilon = epsilon
        self.steps = steps
        self.step_size = self.epsilon / self.steps
        self.decay = decay  # stored for API parity; this variant uses no momentum
        self.image_size = 224

        self.get_opt_layers()

    def get_opt_layers(self):
        """Register forward hooks on the layers whose activations FDA disrupts.

        CNN layers (ReLU, AvgPool2d, BatchNorm2d, AdaptiveAvgPool2d) and
        transformer layers (LayerNorm, Linear, and any module whose qualified
        name contains "attention") are hooked. Each hook appends the layer's
        output to ``self.opt_operations`` and the index of its channel axis to
        ``self.opt_channel_dims``. The axis is 1 for 4-D outputs of the CNN layer
        types and -1 otherwise. Non-tensor outputs (e.g. tuples) are skipped.
        """
        self.opt_operations = []
        self.opt_channel_dims = []

        def make_hook(channel_first):
            def hook_fn(module, inputs, output):
                if not torch.is_tensor(output):
                    return
                if inputs and torch.is_tensor(inputs[0]):
                    # keep the feature on the device of the layer input
                    output = output.to(inputs[0].device)
                self.opt_operations.append(output)
                self.opt_channel_dims.append(1 if channel_first and output.dim() == 4 else -1)
            return hook_fn

        num_hooked = 0
        # named_modules() yields every module once, so shared modules get one hook.
        for name, layer in self.model.named_modules():
            if isinstance(layer, self._CHANNEL_FIRST_TYPES):
                channel_first = True
            elif isinstance(layer, self._CHANNEL_LAST_TYPES) or "attention" in name.lower():
                channel_first = False
            else:
                continue
            layer.register_forward_hook(make_hook(channel_first))
            num_hooked += 1

        print(f"Registered hooks for layers of interest. Total: {num_hooked} layers")

    def _reset_features(self):
        """Drop all activations (and their channel axes) recorded by the hooks."""
        self.opt_operations.clear()
        self.opt_channel_dims.clear()

    def get_fda_loss(self):
        """Return the FDA objective averaged over all recorded activations.

        Each recorded tensor stacks the clean batch (first half) on top of the
        adversarial batch (second half). Positions where the clean activation
        is below its mean over the channel axis are "good"; the others are
        "bad". The objective for one layer is
        ``log(0.5*||good * adv||^2) - log(0.5*||bad * adv||^2)``. Increasing it
        raises the adversarial activations at "good" positions and lowers them
        at "bad" positions.

        Raises:
            RuntimeError: if no activation was recorded in the last forward pass.
        """
        if not self.opt_operations:
            raise RuntimeError('FDA: no activations were recorded; no hooked layer ran in the forward pass.')

        eps = 1e-8  # only guards log(0)
        total_loss = 0.0
        for layer, channel_dim in zip(self.opt_operations, self.opt_channel_dims):
            batch_size = layer.shape[0] // 2
            clean_feat, adv_feat = layer[:batch_size], layer[batch_size:]

            # Mean across channels: axis 1 for NCHW conv features, last axis otherwise.
            channel_mean = torch.mean(clean_feat, dim=channel_dim, keepdim=True)
            wts_good = (clean_feat < channel_mean).float()
            wts_bad = (clean_feat >= channel_mean).float()

            # No 1/numel rescaling: the log-ratio and its gradient are scale
            # invariant, while shrinking large activations pushed the l2 terms
            # under eps, where the clamp zeroed the gradient.
            l2_good = 0.5 * torch.sum((wts_good * adv_feat) ** 2)
            l2_bad = 0.5 * torch.sum((wts_bad * adv_feat) ** 2)

            total_loss += torch.log(torch.clamp(l2_good, min=eps)) - torch.log(torch.clamp(l2_bad, min=eps))

        return total_loss / len(self.opt_operations)

    def forward(self, inps, labels):
        """Craft FDA adversarial examples.

        Args:
            inps: normalised image batch; moved to CUDA and de-normalised in
                place by ``_mul_std_add_mean``.
            labels: unused by FDA (kept for a uniform attack interface).

        Returns:
            tuple: (normalised adversarial batch, detached; None)
        """
        inps = inps.cuda()
        unnorm_inps = self._mul_std_add_mean(inps)
        perts = torch.zeros_like(unnorm_inps).requires_grad_()

        for _ in range(self.steps):
            self._reset_features()
            # One pass over [clean; adversarial] so each recorded activation holds both halves.
            input_batch = torch.cat([unnorm_inps, unnorm_inps + perts], dim=0)
            self.model(self._sub_mean_div_std(input_batch))
            cost = self.loss_flag * self.get_fda_loss()
            cost.backward()

            grad = perts.grad / torch.mean(torch.abs(perts.grad), dim=[1, 2, 3], keepdim=True)

            perts.data = self._update_perts(perts.data, grad, self.step_size)
            perts.data = torch.clamp(unnorm_inps + perts.data, 0.0, 1.0) - unnorm_inps
            perts.grad.zero_()

        self._reset_features()  # release the last step's activations
        return (self._sub_mean_div_std(unnorm_inps + perts.data)).detach(), None
    
class FDA_MI(BaseAttack):
    """Feature Disruptive Attack (FDA) with momentum accumulation.

    Forward hooks registered in ``get_opt_layers`` record intermediate
    activations on every forward pass. Each step feeds the clean and the
    adversarial batch together, splits every recorded activation into its clean
    and adversarial halves, and updates the perturbation with
    ``self._update_perts`` using the momentum-accumulated gradient of
    ``self.loss_flag * get_fda_loss()``.
    """

    # Hooked module types whose 4-D outputs are assumed to be (N, C, H, W).
    _CHANNEL_FIRST_TYPES = (nn.ReLU, nn.AvgPool2d, nn.BatchNorm2d, nn.AdaptiveAvgPool2d)
    # Hooked module types that act on the last axis (feature-last layout).
    _CHANNEL_LAST_TYPES = (nn.LayerNorm, nn.Linear)

    def __init__(self, model_name, steps=10, epsilon=16/255, target=False, decay=1.0, pre_trained=True, weight_path=False):
        super(FDA_MI, self).__init__('FDA_MI', model_name, target, pre_trained=pre_trained, weight_path=weight_path)
        self.epsilon = epsilon
        self.steps = steps
        self.step_size = self.epsilon / self.steps
        self.decay = decay  # momentum decay factor
        self.image_size = 224

        self.get_opt_layers()

    def get_opt_layers(self):
        """Register forward hooks on the layers whose activations FDA disrupts.

        CNN layers (ReLU, AvgPool2d, BatchNorm2d, AdaptiveAvgPool2d) and
        transformer layers (LayerNorm, Linear, and any module whose qualified
        name contains "attention") are hooked. Each hook appends the layer's
        output to ``self.opt_operations`` and the index of its channel axis to
        ``self.opt_channel_dims``. The axis is 1 for 4-D outputs of the CNN layer
        types and -1 otherwise. Non-tensor outputs (e.g. tuples) are skipped.
        """
        self.opt_operations = []
        self.opt_channel_dims = []

        def make_hook(channel_first):
            def hook_fn(module, inputs, output):
                if not torch.is_tensor(output):
                    return
                if inputs and torch.is_tensor(inputs[0]):
                    # keep the feature on the device of the layer input
                    output = output.to(inputs[0].device)
                self.opt_operations.append(output)
                self.opt_channel_dims.append(1 if channel_first and output.dim() == 4 else -1)
            return hook_fn

        num_hooked = 0
        # named_modules() yields every module once, so shared modules get one hook.
        for name, layer in self.model.named_modules():
            if isinstance(layer, self._CHANNEL_FIRST_TYPES):
                channel_first = True
            elif isinstance(layer, self._CHANNEL_LAST_TYPES) or "attention" in name.lower():
                channel_first = False
            else:
                continue
            layer.register_forward_hook(make_hook(channel_first))
            num_hooked += 1

        print(f"Registered hooks for layers of interest. Total: {num_hooked} layers")

    def _reset_features(self):
        """Drop all activations (and their channel axes) recorded by the hooks."""
        self.opt_operations.clear()
        self.opt_channel_dims.clear()

    def get_fda_loss(self):
        """Return the FDA objective averaged over all recorded activations.

        Each recorded tensor stacks the clean batch (first half) on top of the
        adversarial batch (second half). Positions where the clean activation
        is below its mean over the channel axis are "good"; the others are
        "bad". The objective for one layer is
        ``log(0.5*||good * adv||^2) - log(0.5*||bad * adv||^2)``.

        Raises:
            RuntimeError: if no activation was recorded in the last forward pass.
        """
        if not self.opt_operations:
            raise RuntimeError('FDA_MI: no activations were recorded; no hooked layer ran in the forward pass.')

        eps = 1e-8  # only guards log(0)
        total_loss = 0.0
        for layer, channel_dim in zip(self.opt_operations, self.opt_channel_dims):
            batch_size = layer.shape[0] // 2
            clean_feat, adv_feat = layer[:batch_size], layer[batch_size:]

            # Mean across channels: axis 1 for NCHW conv features, last axis otherwise.
            channel_mean = torch.mean(clean_feat, dim=channel_dim, keepdim=True)
            wts_good = (clean_feat < channel_mean).float()
            wts_bad = (clean_feat >= channel_mean).float()

            # No 1/numel rescaling: the log-ratio and its gradient are scale
            # invariant, while shrinking large activations pushed the l2 terms
            # under eps, where the clamp zeroed the gradient.
            l2_good = 0.5 * torch.sum((wts_good * adv_feat) ** 2)
            l2_bad = 0.5 * torch.sum((wts_bad * adv_feat) ** 2)

            total_loss += torch.log(torch.clamp(l2_good, min=eps)) - torch.log(torch.clamp(l2_bad, min=eps))

        return total_loss / len(self.opt_operations)

    def forward(self, inps, labels):
        """Craft FDA adversarial examples with momentum.

        Args:
            inps: normalised image batch; moved to CUDA and de-normalised in
                place by ``_mul_std_add_mean``.
            labels: unused by FDA (kept for a uniform attack interface).

        Returns:
            tuple: (normalised adversarial batch, detached; None)
        """
        inps = inps.cuda()
        unnorm_inps = self._mul_std_add_mean(inps)
        momentum = torch.zeros_like(unnorm_inps)
        perts = torch.zeros_like(unnorm_inps).requires_grad_()

        for _ in range(self.steps):
            self._reset_features()
            # One pass over [clean; adversarial] so each recorded activation holds both halves.
            input_batch = torch.cat([unnorm_inps, unnorm_inps + perts], dim=0)
            self.model(self._sub_mean_div_std(input_batch))
            cost = self.loss_flag * self.get_fda_loss()
            cost.backward()

            grad = perts.grad / torch.mean(torch.abs(perts.grad), dim=[1, 2, 3], keepdim=True)

            #: MI
            grad = grad + momentum * self.decay
            momentum = grad

            perts.data = self._update_perts(perts.data, grad, self.step_size)
            perts.data = torch.clamp(unnorm_inps + perts.data, 0.0, 1.0) - unnorm_inps
            perts.grad.zero_()

        self._reset_features()  # release the last step's activations
        return (self._sub_mean_div_std(unnorm_inps + perts.data)).detach(), None
    
class FIA(BaseAttack):
    """Feature Importance-aware Attack (FIA), plain iterative variant (no momentum).

    A forward hook records the output of one intermediate layer. Importance
    weights for that feature map are computed once per batch, as constants.
    They come from the gradient of the true-class logits, aggregated over
    randomly masked copies of the clean images. The perturbation is then
    updated with ``self._update_perts`` using the gradient of
    ``self.loss_flag * get_fia_loss(weights)``.
    """

    # Default hooked layer per model (the commented entry is an alternative choice).
    _DEFAULT_LAYERS = {
        # 'resnet50': 'layer2.1.conv2'
        'resnet50': 'layer3.1.conv2',
    }

    def __init__(self, model_name, steps=10, epsilon=16/255, target=False, decay=1.0, pre_trained=True, weight_path=False, ens=30, probb=0.7, layer_name=None):
        """Set up the attack and hook the feature layer.

        Args:
            ens: number of randomly masked copies used to aggregate the
                gradient; ``ens <= 0`` uses the plain gradient on clean images.
            probb: probability that each pixel is kept by a random mask.
            layer_name: qualified module name to hook; defaults to
                ``_DEFAULT_LAYERS[model_name]``.

        Raises:
            KeyError: if ``layer_name`` is None and ``model_name`` has no default layer.
            ValueError: if the model has no module called ``layer_name``.
        """
        if layer_name is None:
            if model_name not in self._DEFAULT_LAYERS:
                raise KeyError(f"FIA has no default feature layer for {model_name!r}; pass layer_name explicitly")
            layer_name = self._DEFAULT_LAYERS[model_name]

        super(FIA, self).__init__('FIA', model_name, target, pre_trained=pre_trained, weight_path=weight_path)
        self.epsilon = epsilon
        self.steps = steps
        self.step_size = self.epsilon / self.steps
        self.decay = decay  # stored for API parity; this variant uses no momentum
        self.image_size = 224
        self.ens = ens
        self.probb = probb

        self.get_opt_layers(layer_name=layer_name)

    def get_opt_layers(self, layer_name):
        """Register a forward hook on the module named ``layer_name``.

        On every forward pass the hook appends the module's output to
        ``self.opt_operations``. That exact tensor is later the target of
        ``torch.autograd.grad``, so it is stored as-is. A device copy would not
        be part of the graph that produces the logits.

        Raises:
            ValueError: if the model has no module called ``layer_name``.
        """
        self.opt_operations = []

        def hook_fn(module, inputs, output):
            if not output.requires_grad:
                # Nothing upstream requires grad (e.g. frozen weights and a clean
                # input): turn the feature map into a leaf we can differentiate to.
                output.requires_grad_()
            self.opt_operations.append(output)

        target_layer = dict(self.model.named_modules()).get(layer_name)
        if target_layer is None:
            raise ValueError(f"FIA: layer {layer_name!r} not found in model {self.model_name!r}")
        target_layer.register_forward_hook(hook_fn)

        print("Registered hooks for layers of interest. Total: 1 layers")

    def _feature_importance(self, clean_inps, labels):
        """Return FIA's importance weights for the hooked feature map.

        The gradient of the summed true-class logits with respect to the hooked
        feature map is used. It is taken on the clean images when
        ``ens <= 0``. Otherwise it is summed over ``ens`` copies whose pixels
        are each kept with probability ``probb``. The aggregate is
        L2-normalised per sample over all feature dimensions, then negated.

        Args:
            clean_inps: de-normalised clean images.
            labels: true class indices, one per image.

        Returns:
            torch.Tensor: detached weights, shaped like the clean half of the feature map.
        """
        def feature_grad(images):
            self.opt_operations.clear()
            logits = self.model(self._sub_mean_div_std(images))
            true_logits = logits.gather(1, labels.view(-1, 1)).sum()
            # No create_graph: the weights are constants for the attack.
            return torch.autograd.grad(true_logits, self.opt_operations[0])[0]

        n_copies = int(self.ens)
        if n_copies <= 0:
            aggregate = feature_grad(clean_inps)
        else:
            aggregate = 0
            for _ in range(n_copies):
                mask = torch.bernoulli(torch.full_like(clean_inps, self.probb))
                aggregate = aggregate + feature_grad(clean_inps * mask)
        self.opt_operations.clear()

        weights = F.normalize(aggregate.flatten(1), p=2, dim=1).reshape_as(aggregate)
        return -weights.detach()

    def get_fia_loss(self, weights):
        """Return the FIA loss averaged over the recorded feature maps.

        Every recorded tensor stacks the clean batch (first half) on top of the
        adversarial batch (second half). The loss for one tensor is
        ``sum(adv_half * weights) / layer.numel()``.
        A variant would use ``weights * |adv - ori|`` instead.

        Args:
            weights: importance weights broadcastable to the adversarial half.

        Raises:
            RuntimeError: if no feature map was recorded.
        """
        if not self.opt_operations:
            raise RuntimeError('FIA: no feature map was recorded; the hooked layer did not run.')

        loss = 0.0
        for layer in self.opt_operations:
            adv_tensor = layer[layer.shape[0] // 2:]
            loss += torch.sum(adv_tensor * weights) / layer.numel()
        return loss / len(self.opt_operations)

    def forward(self, inps, labels):
        """Craft FIA adversarial examples.

        The importance weights are computed once, before the loop, and stay
        fixed. Previously the weights were re-normalised and negated on every
        step, which flipped their sign on alternate iterations. They also kept
        a second-order graph alive that leaked into every backward pass.

        Args:
            inps: normalised image batch; moved to CUDA and de-normalised in
                place by ``_mul_std_add_mean``.
            labels: true class indices used to compute the importance weights.

        Returns:
            tuple: (normalised adversarial batch, detached; None)
        """
        inps = inps.cuda()
        labels = labels.cuda()
        unnorm_inps = self._mul_std_add_mean(inps)

        weights = self._feature_importance(unnorm_inps, labels)

        perts = torch.zeros_like(unnorm_inps).requires_grad_()
        for _ in range(self.steps):
            self.opt_operations.clear()
            # One pass over [clean; adversarial]; get_fia_loss uses the adversarial half.
            input_batch = torch.cat([unnorm_inps, unnorm_inps + perts], dim=0)
            self.model(self._sub_mean_div_std(input_batch))
            cost = self.loss_flag * self.get_fia_loss(weights=weights)
            cost.backward()

            grad = perts.grad / torch.mean(torch.abs(perts.grad), dim=[1, 2, 3], keepdim=True)

            perts.data = self._update_perts(perts.data, grad, self.step_size)
            perts.data = torch.clamp(unnorm_inps + perts.data, 0.0, 1.0) - unnorm_inps
            perts.grad.zero_()

        self.opt_operations.clear()  # release the last step's feature map
        return (self._sub_mean_div_std(unnorm_inps + perts.data)).detach(), None
    
class FIA_MI(BaseAttack):
    """Feature Importance-aware Attack (FIA) with momentum accumulation.

    A forward hook records the output of one intermediate layer. Importance
    weights for that feature map are computed once per batch, as constants.
    They come from the gradient of the true-class logits, aggregated over
    randomly masked copies of the clean images. The perturbation is then
    updated with ``self._update_perts`` using the momentum-accumulated gradient
    of ``self.loss_flag * get_fia_loss(weights)``.
    """

    # Default hooked layer per model.
    _DEFAULT_LAYERS = {
        'resnet50': 'layer2.1.conv2',
    }

    def __init__(self, model_name, steps=10, epsilon=16/255, target=False, decay=1.0, pre_trained=True, weight_path=False, ens=30, probb=0.7, layer_name=None):
        """Set up the attack and hook the feature layer.

        Args:
            decay: momentum decay factor.
            ens: number of randomly masked copies used to aggregate the
                gradient; ``ens <= 0`` uses the plain gradient on clean images.
            probb: probability that each pixel is kept by a random mask.
            layer_name: qualified module name to hook; defaults to
                ``_DEFAULT_LAYERS[model_name]``.

        Raises:
            KeyError: if ``layer_name`` is None and ``model_name`` has no default layer.
            ValueError: if the model has no module called ``layer_name``.
        """
        if layer_name is None:
            if model_name not in self._DEFAULT_LAYERS:
                raise KeyError(f"FIA_MI has no default feature layer for {model_name!r}; pass layer_name explicitly")
            layer_name = self._DEFAULT_LAYERS[model_name]

        super(FIA_MI, self).__init__('FIA_MI', model_name, target, pre_trained=pre_trained, weight_path=weight_path)
        self.epsilon = epsilon
        self.steps = steps
        self.step_size = self.epsilon / self.steps
        self.decay = decay
        self.image_size = 224
        self.ens = ens
        self.probb = probb

        self.get_opt_layers(layer_name=layer_name)

    def get_opt_layers(self, layer_name):
        """Register a forward hook on the module named ``layer_name``.

        On every forward pass the hook appends the module's output to
        ``self.opt_operations``. That exact tensor is later the target of
        ``torch.autograd.grad``, so it is stored as-is. A device copy would not
        be part of the graph that produces the logits.

        Raises:
            ValueError: if the model has no module called ``layer_name``.
        """
        self.opt_operations = []

        def hook_fn(module, inputs, output):
            if not output.requires_grad:
                # Nothing upstream requires grad (e.g. frozen weights and a clean
                # input): turn the feature map into a leaf we can differentiate to.
                output.requires_grad_()
            self.opt_operations.append(output)

        target_layer = dict(self.model.named_modules()).get(layer_name)
        if target_layer is None:
            raise ValueError(f"FIA_MI: layer {layer_name!r} not found in model {self.model_name!r}")
        target_layer.register_forward_hook(hook_fn)

        print("Registered hooks for layers of interest. Total: 1 layers")

    def _feature_importance(self, clean_inps, labels):
        """Return FIA's importance weights for the hooked feature map.

        The gradient of the summed true-class logits with respect to the hooked
        feature map is used. It is taken on the clean images when
        ``ens <= 0``. Otherwise it is summed over ``ens`` copies whose pixels
        are each kept with probability ``probb``. The aggregate is
        L2-normalised per sample over all feature dimensions, then negated.

        Args:
            clean_inps: de-normalised clean images.
            labels: true class indices, one per image.

        Returns:
            torch.Tensor: detached weights, shaped like the clean half of the feature map.
        """
        def feature_grad(images):
            self.opt_operations.clear()
            logits = self.model(self._sub_mean_div_std(images))
            true_logits = logits.gather(1, labels.view(-1, 1)).sum()
            # No create_graph: the weights are constants for the attack.
            return torch.autograd.grad(true_logits, self.opt_operations[0])[0]

        n_copies = int(self.ens)
        if n_copies <= 0:
            aggregate = feature_grad(clean_inps)
        else:
            aggregate = 0
            for _ in range(n_copies):
                mask = torch.bernoulli(torch.full_like(clean_inps, self.probb))
                aggregate = aggregate + feature_grad(clean_inps * mask)
        self.opt_operations.clear()

        weights = F.normalize(aggregate.flatten(1), p=2, dim=1).reshape_as(aggregate)
        return -weights.detach()

    def get_fia_loss(self, weights):
        """Return the FIA loss averaged over the recorded feature maps.

        Every recorded tensor stacks the clean batch (first half) on top of the
        adversarial batch (second half). The loss for one tensor is
        ``sum(adv_half * weights) / layer.numel()``.
        A variant would use ``weights * |adv - ori|`` instead.

        Args:
            weights: importance weights broadcastable to the adversarial half.

        Raises:
            RuntimeError: if no feature map was recorded.
        """
        if not self.opt_operations:
            raise RuntimeError('FIA_MI: no feature map was recorded; the hooked layer did not run.')

        loss = 0.0
        for layer in self.opt_operations:
            adv_tensor = layer[layer.shape[0] // 2:]
            loss += torch.sum(adv_tensor * weights) / layer.numel()
        return loss / len(self.opt_operations)

    def forward(self, inps, labels):
        """Craft FIA adversarial examples with momentum.

        The importance weights are computed once, before the loop, and stay
        fixed. Previously the weights were re-normalised and negated on every
        step, which flipped their sign on alternate iterations. They also kept
        a second-order graph alive that leaked into every backward pass.

        Args:
            inps: normalised image batch; moved to CUDA and de-normalised in
                place by ``_mul_std_add_mean``.
            labels: true class indices used to compute the importance weights.

        Returns:
            tuple: (normalised adversarial batch, detached; None)
        """
        inps = inps.cuda()
        labels = labels.cuda()
        unnorm_inps = self._mul_std_add_mean(inps)
        momentum = torch.zeros_like(unnorm_inps)

        weights = self._feature_importance(unnorm_inps, labels)

        perts = torch.zeros_like(unnorm_inps).requires_grad_()
        for _ in range(self.steps):
            self.opt_operations.clear()
            # One pass over [clean; adversarial]; get_fia_loss uses the adversarial half.
            input_batch = torch.cat([unnorm_inps, unnorm_inps + perts], dim=0)
            self.model(self._sub_mean_div_std(input_batch))
            cost = self.loss_flag * self.get_fia_loss(weights=weights)
            cost.backward()

            grad = perts.grad / torch.mean(torch.abs(perts.grad), dim=[1, 2, 3], keepdim=True)

            #: MI
            grad = grad + momentum * self.decay
            momentum = grad

            perts.data = self._update_perts(perts.data, grad, self.step_size)
            perts.data = torch.clamp(unnorm_inps + perts.data, 0.0, 1.0) - unnorm_inps
            perts.grad.zero_()

        self.opt_operations.clear()  # release the last step's feature map
        return (self._sub_mean_div_std(unnorm_inps + perts.data)).detach(), None


class TGR(BaseAttack):
    """Token Gradient Regularization (TGR) attack with momentum updates.

    ``_register_model`` installs backward hooks on the attention-dropout,
    qkv / q / k / v and MLP modules of the supported transformer models. Each
    hook scales the incoming gradient by a constant ``gamma``. For some models
    it also zeroes gradient entries at the argmax / argmin positions computed
    from the first sample of the batch. ``forward`` then runs a
    momentum-accumulated cross-entropy attack through the hooked model.
    """
    def __init__(self, model_name, steps=10, epsilon=16/255, target=False, decay=1.0, pre_trained=True, weight_path=False, sample_num_batches=130):
        super(TGR, self).__init__('TGR', model_name, target, pre_trained=pre_trained, weight_path=weight_path)

        self.epsilon = epsilon
        self.steps = steps
        self.step_size = self.epsilon/self.steps
        self.decay = decay  # momentum decay factor

        self.image_size = 224
        self.crop_length = 16
        # Number of crop_length x crop_length patches kept by the (currently
        # disabled) patch-out sampling in _generate_samples_for_interactions.
        self.sample_num_batches = sample_num_batches
        # (224 / 16) ** 2 = 196 patches in total.
        self.max_num_batches = int((224/16)**2)
        assert self.sample_num_batches <= self.max_num_batches
        self._register_model()

    
    def _register_model(self):   
        """Register the TGR backward hooks for ``self.model_name``.

        Unknown model names register no hooks. NOTE(review):
        ``Module.register_backward_hook`` is deprecated in recent PyTorch in
        favour of ``register_full_backward_hook``, whose ``grad_input``
        semantics differ, so it is deliberately left unchanged here.
        """
        def attn_tgr(module, grad_in, grad_out, gamma):
            # Scale the attention-dropout gradient by gamma.
            mask = torch.ones_like(grad_in[0]) * gamma
            out_grad = mask * grad_in[0][:]
            if self.model_name in ['vit_base_patch16_224', 'visformer_small', 'pit_b_224']:
                # Layout read as (B, C, H, W). For each C, find the flat argmax /
                # argmin position in sample 0 and zero the whole row and column
                # through it, for every sample.
                # NOTE(review): the column index uses "% H", which is only
                # correct when H == W.
                B,C,H,W = grad_in[0].shape
                out_grad_cpu = out_grad.data.clone().cpu().numpy().reshape(B,C,H*W)
                max_all = np.argmax(out_grad_cpu[0,:,:], axis = 1)
                max_all_H = max_all//H
                max_all_W = max_all%H
                min_all = np.argmin(out_grad_cpu[0,:,:], axis = 1)
                min_all_H = min_all//H
                min_all_W = min_all%H
                out_grad[:,range(C),max_all_H,:] = 0.0
                out_grad[:,range(C),:,max_all_W] = 0.0
                out_grad[:,range(C),min_all_H,:] = 0.0
                out_grad[:,range(C),:,min_all_W] = 0.0
                
            if self.model_name in ['cait_s24_224']:
                # Same idea for a channel-last (B, H, W, C) layout.
                B,H,W,C = grad_in[0].shape
                out_grad_cpu = out_grad.data.clone().cpu().numpy().reshape(B, H*W, C)
                max_all = np.argmax(out_grad_cpu[0,:,:], axis = 0)
                max_all_H = max_all//H
                max_all_W = max_all%H
                min_all = np.argmin(out_grad_cpu[0,:,:], axis = 0)
                min_all_H = min_all//H
                min_all_W = min_all%H
                
                out_grad[:,max_all_H,:,range(C)] = 0.0
                out_grad[:,:,max_all_W,range(C)] = 0.0
                out_grad[:,min_all_H,:,range(C)] = 0.0
                out_grad[:,:,min_all_W,range(C)] = 0.0

            return (out_grad, )
        
        def attn_cait_tgr(module, grad_in, grad_out, gamma):
            # CaiT class-attention dropout: scale by gamma, then zero the
            # extreme positions (over axis 1, taken at index 0 of axis 2 of sample 0).
            mask = torch.ones_like(grad_in[0]) * gamma
            out_grad = mask * grad_in[0][:]
            
            B,H,W,C = grad_in[0].shape
            out_grad_cpu = out_grad.data.clone().cpu().numpy()
            max_all = np.argmax(out_grad_cpu[0,:,0,:], axis = 0)
            min_all = np.argmin(out_grad_cpu[0,:,0,:], axis = 0)
                
            out_grad[:,max_all,:,range(C)] = 0.0
            out_grad[:,min_all,:,range(C)] = 0.0
            return (out_grad, )
            
        def q_tgr(module, grad_in, grad_out, gamma):
            # cait Q only uses class token
            # The gradient w.r.t. grad_in[0] is zeroed entirely (gamma has no effect).
            mask = torch.ones_like(grad_in[0]) * gamma
            out_grad = mask * grad_in[0][:]
            out_grad[:] = 0.0
            return (out_grad, grad_in[1], grad_in[2])
            
        def v_tgr(module, grad_in, grad_out, gamma):
            # Scale the qkv / k / v gradient by gamma and zero the per-channel extremes.
            mask = torch.ones_like(grad_in[0]) * gamma
            out_grad = mask * grad_in[0][:]

            if self.model_name in ['visformer_small']:
                # 4-D (B, C, H, W): zero the single argmax / argmin pixel per channel.
                B,C,H,W = grad_in[0].shape
                out_grad_cpu = out_grad.data.clone().cpu().numpy().reshape(B,C,H*W)
                max_all = np.argmax(out_grad_cpu[0,:,:], axis = 1)
                max_all_H = max_all//H
                max_all_W = max_all%H
                min_all = np.argmin(out_grad_cpu[0,:,:], axis = 1)
                min_all_H = min_all//H
                min_all_W = min_all%H
                out_grad[:,range(C),max_all_H,max_all_W] = 0.0
                out_grad[:,range(C),min_all_H,min_all_W] = 0.0

            if self.model_name in ['vit_base_patch16_224', 'pit_b_224', 'cait_s24_224']:
                # 3-D: for each channel c, zero the token with the largest /
                # smallest gradient in sample 0, for every sample.
                c = grad_in[0].shape[2]
                out_grad_cpu = out_grad.data.clone().cpu().numpy()
                max_all = np.argmax(out_grad_cpu[0,:,:], axis = 0)
                min_all = np.argmin(out_grad_cpu[0,:,:], axis = 0)
                    
                out_grad[:,max_all,range(c)] = 0.0
                out_grad[:,min_all,range(c)] = 0.0
            # NOTE(review): assumes grad_in has exactly two entries for these modules.
            return (out_grad, grad_in[1])
        
        def mlp_tgr(module, grad_in, grad_out, gamma):
            # Scale the MLP gradient by gamma and zero the per-channel extremes.
            mask = torch.ones_like(grad_in[0]) * gamma
            out_grad = mask * grad_in[0][:]
            if self.model_name in ['visformer_small']:
                B,C,H,W = grad_in[0].shape
                out_grad_cpu = out_grad.data.clone().cpu().numpy().reshape(B,C,H*W)
                max_all = np.argmax(out_grad_cpu[0,:,:], axis = 1)
                max_all_H = max_all//H
                max_all_W = max_all%H
                min_all = np.argmin(out_grad_cpu[0,:,:], axis = 1)
                min_all_H = min_all//H
                min_all_W = min_all%H
                out_grad[:,range(C),max_all_H,max_all_W] = 0.0
                out_grad[:,range(C),min_all_H,min_all_W] = 0.0
            if self.model_name in ['vit_base_patch16_224', 'pit_b_224', 'cait_s24_224', 'resnetv2_101']:
                c = grad_in[0].shape[2]
                out_grad_cpu = out_grad.data.clone().cpu().numpy()
        
                max_all = np.argmax(out_grad_cpu[0,:,:], axis = 0)
                min_all = np.argmin(out_grad_cpu[0,:,:], axis = 0)
                out_grad[:,max_all,range(c)] = 0.0
                out_grad[:,min_all,range(c)] = 0.0
            # Replace only the first gradient; pass the remaining ones through.
            for i in range(len(grad_in)):
                if i == 0:
                    return_dics = (out_grad,)
                else:
                    return_dics = return_dics + (grad_in[i],)
            return return_dics
                

        # Per-module gradient scaling factors.
        attn_tgr_hook = partial(attn_tgr, gamma=0.25)
        attn_cait_tgr_hook = partial(attn_cait_tgr, gamma=0.25)
        v_tgr_hook = partial(v_tgr, gamma=0.75)
        q_tgr_hook = partial(q_tgr, gamma=0.75)
        
        mlp_tgr_hook = partial(mlp_tgr, gamma=0.5)

        if self.model_name in ['vit_base_patch16_224' ,'deit_base_distilled_patch16_224']:
            # 12 transformer blocks.
            for i in range(12):
                self.model.blocks[i].attn.attn_drop.register_backward_hook(attn_tgr_hook)
                self.model.blocks[i].attn.qkv.register_backward_hook(v_tgr_hook)
                self.model.blocks[i].mlp.register_backward_hook(mlp_tgr_hook)
        elif self.model_name == 'pit_b_224':
            # 13 blocks spread over three stages: 3 + 6 + 4.
            for block_ind in range(13):
                if block_ind < 3:
                    transformer_ind = 0
                    used_block_ind = block_ind
                elif block_ind < 9 and block_ind >= 3:
                    transformer_ind = 1
                    used_block_ind = block_ind - 3
                elif block_ind < 13 and block_ind >= 9:
                    transformer_ind = 2
                    used_block_ind = block_ind - 9
                self.model.transformers[transformer_ind].blocks[used_block_ind].attn.attn_drop.register_backward_hook(attn_tgr_hook)
                self.model.transformers[transformer_ind].blocks[used_block_ind].attn.qkv.register_backward_hook(v_tgr_hook)
                self.model.transformers[transformer_ind].blocks[used_block_ind].mlp.register_backward_hook(mlp_tgr_hook)
        elif self.model_name == 'cait_s24_224':
            # 24 self-attention blocks followed by class-attention blocks.
            # NOTE(review): block_ind == 24 matches neither branch, so
            # blocks_token_only[0] never gets hooks; only blocks_token_only[1]
            # does. Confirm this is intended.
            for block_ind in range(26):
                if block_ind < 24:
                    self.model.blocks[block_ind].attn.attn_drop.register_backward_hook(attn_tgr_hook)
                    self.model.blocks[block_ind].attn.qkv.register_backward_hook(v_tgr_hook)
                    self.model.blocks[block_ind].mlp.register_backward_hook(mlp_tgr_hook)
                elif block_ind > 24:
                    self.model.blocks_token_only[block_ind-24].attn.attn_drop.register_backward_hook(attn_cait_tgr_hook)
                    self.model.blocks_token_only[block_ind-24].attn.q.register_backward_hook(q_tgr_hook)
                    self.model.blocks_token_only[block_ind-24].attn.k.register_backward_hook(v_tgr_hook)
                    self.model.blocks_token_only[block_ind-24].attn.v.register_backward_hook(v_tgr_hook)
                    self.model.blocks_token_only[block_ind-24].mlp.register_backward_hook(mlp_tgr_hook)
        elif self.model_name == 'visformer_small':
            # 4 blocks in stage2, then 4 blocks in stage3.
            for block_ind in range(8):
                if block_ind < 4:
                    self.model.stage2[block_ind].attn.attn_drop.register_backward_hook(attn_tgr_hook)
                    self.model.stage2[block_ind].attn.qkv.register_backward_hook(v_tgr_hook)
                    self.model.stage2[block_ind].mlp.register_backward_hook(mlp_tgr_hook)
                elif block_ind >=4:
                    self.model.stage3[block_ind-4].attn.attn_drop.register_backward_hook(attn_tgr_hook)
                    self.model.stage3[block_ind-4].attn.qkv.register_backward_hook(v_tgr_hook)
                    self.model.stage3[block_ind-4].mlp.register_backward_hook(mlp_tgr_hook)

    def _generate_samples_for_interactions(self, perts, seed):
        """Return ``perts`` masked to ``sample_num_batches`` randomly chosen patches.

        The patches are distinct crop_length x crop_length squares, chosen by
        shuffling patch ids. NOTE: this reseeds Python's global ``random``
        module with ``seed``. Currently unused by ``forward``.
        """
        add_noise_mask = torch.zeros_like(perts)
        grid_num_axis = int(self.image_size/self.crop_length)

        # Unrepeatable sampling
        ids = [i for i in range(self.max_num_batches)]
        random.seed(seed)
        random.shuffle(ids)
        ids = np.array(ids[:self.sample_num_batches])

        # Repeatable sampling
        # ids = np.random.randint(0, self.max_num_batches, size=self.sample_num_batches)
        rows, cols = ids // grid_num_axis, ids % grid_num_axis
        flag = 0  # unused
        for r, c in zip(rows, cols):
            add_noise_mask[:,:,r*self.crop_length:(r+1)*self.crop_length,c*self.crop_length:(c+1)*self.crop_length] = 1
        add_perturbation = perts * add_noise_mask
        return add_perturbation

    def forward(self, inps, labels):
        """Craft TGR adversarial examples (cross-entropy loss with momentum).

        Args:
            inps: normalised image batch; moved to CUDA and de-normalised in
                place by ``_mul_std_add_mean``.
            labels: labels for the cross-entropy loss.

        Returns:
            tuple: (normalised adversarial batch, detached; None)
        """
        inps = inps.cuda()
        labels = labels.cuda()
        loss = nn.CrossEntropyLoss()

        momentum = torch.zeros_like(inps).cuda()
        #: de-normalise the input data
        unnorm_inps = self._mul_std_add_mean(inps)
        #: perturbation
        perts = torch.zeros_like(unnorm_inps).cuda()
        perts.requires_grad_()

        for i in range(self.steps):
            
            #add_perturbation = self._generate_samples_for_interactions(perts, i)
            #outputs = self.model((self._sub_mean_div_std(unnorm_inps + add_perturbation)))

            ##### If you use patch out, please uncomment the previous two lines and comment the next line.
            #: normalise the adversarial examples before the model
            outputs = self.model((self._sub_mean_div_std(unnorm_inps + perts)))
            cost = self.loss_flag * loss(outputs, labels).cuda()
            cost.backward()
            grad = perts.grad.data
            grad = grad / torch.mean(torch.abs(grad), dim=[1,2,3], keepdim=True)
            #: MI
            grad += momentum*self.decay
            momentum = grad

            perts.data = self._update_perts(perts.data, grad, self.step_size)
            # Keep unnorm_inps + perts inside [0, 1].
            perts.data = torch.clamp(unnorm_inps.data + perts.data, 0.0, 1.0) - unnorm_inps.data
            perts.grad.data.zero_()
        return (self._sub_mean_div_std(unnorm_inps+perts.data)).detach(), None
    
class SGM(BaseAttack):
    """Skip Gradient Method (SGM) attack with momentum-accumulated gradients.

    When ``gamma < 1.0``, architecture-specific backward hooks from
    ``utils_sgm`` are installed on the model. The attack loop itself
    maximises the flag-signed cross-entropy loss with momentum.
    """

    _RESNETS = ['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152']
    _DENSENETS = ['densenet121', 'densenet169', 'densenet201']
    _VITS = ['vit_base_patch16_224', 'deit_base_distilled_patch16_224']

    def __init__(self, model_name, steps=10, epsilon=16/255, target=False, decay=1.0, pre_trained=True, weight_path=False, gamma=0.5):
        super(SGM, self).__init__('SGM', model_name, target, pre_trained=pre_trained, weight_path=weight_path)
        self.epsilon = epsilon
        self.steps = steps
        self.step_size = self.epsilon / self.steps
        self.decay = decay  # momentum decay factor
        self.image_size = 224
        self.gamma = gamma
        self._register_hook()

    def _register_hook(self):
        """Install the SGM backward hooks for ``self.model_name`` (only when gamma < 1.0).

        Raises:
            ValueError: if gamma < 1.0 and the architecture is not supported.
        """
        if not self.gamma < 1.0:
            return
        if self.model_name in self._RESNETS:
            register_hook_for_resnet(self.model, arch=self.model_name, gamma=self.gamma)
        elif self.model_name in self._DENSENETS:
            register_hook_for_densenet(self.model, arch=self.model_name, gamma=self.gamma)
        elif self.model_name in self._VITS:
            print(f'using sgm in {self.model_name}')
            # NOTE(review): gamma is not forwarded here; confirm that
            # register_hook_for_vit's own default is intended.
            register_hook_for_vit(self.model, arch=self.model_name)
        else:
            # The old message claimed only resnet/densenet were supported.
            raise ValueError(f'SGM does not support {self.model_name!r}: current code only supports resnet/densenet/vit/deit. '
                             'You can extend this code to other architectures.')

    def forward(self, inps, labels):
        """Craft SGM adversarial examples.

        Args:
            inps: normalised image batch; moved to CUDA and de-normalised in
                place by ``_mul_std_add_mean``.
            labels: labels for the cross-entropy loss.

        Returns:
            tuple: (normalised adversarial batch, detached; None)
        """
        inps = inps.cuda()
        labels = labels.cuda()

        unnorm_inps = self._mul_std_add_mean(inps)
        momentum = torch.zeros_like(unnorm_inps)
        perts = torch.zeros_like(unnorm_inps).requires_grad_()

        for _ in range(self.steps):
            outputs = self.model(self._sub_mean_div_std(unnorm_inps + perts))
            cost = self.loss_flag * F.cross_entropy(outputs, labels)
            cost.backward()

            grad = perts.grad / torch.mean(torch.abs(perts.grad), dim=[1, 2, 3], keepdim=True)

            #: MI
            grad = grad + momentum * self.decay
            momentum = grad

            perts.data = self._update_perts(perts.data, grad, self.step_size)
            # Keep unnorm_inps + perts inside [0, 1].
            perts.data = torch.clamp(unnorm_inps + perts.data, 0.0, 1.0) - unnorm_inps
            perts.grad.zero_()
        return (self._sub_mean_div_std(unnorm_inps + perts.data)).detach(), None
    

#: PNA + PatchOut
class PNA(BaseAttack):
    """PNA attack with optional patch sampling and L2 term, toggled by ablation flags.

    ``ablation_study`` is a comma-separated string of three ``'0'``/``'1'`` flags:

    * ``[0]`` perturb only ``sample_num_batches`` seeded-random patches per step;
    * ``[1]`` add ``lamb * ||perts||_2`` (over the whole batch) to the loss;
    * ``[2]`` ("Skip") zero the gradient flowing through every ``attn.attn_drop``.

    ``decay`` is accepted for signature compatibility but is not used.
    """

    def __init__(self, model_name, steps=10, epsilon=16/255, target=False, decay=1.0, pre_trained=True, weight_path=False, ablation_study='0,1,1', sample_num_batches=130, lamb=0.1):
        super(PNA, self).__init__('PNA', model_name, target, pre_trained=pre_trained, weight_path=weight_path)
        self.epsilon = epsilon
        self.steps = steps
        self.step_size = self.epsilon / self.steps

        self.ablation_study = ablation_study.split(',')
        self.lamb = lamb
        self.image_size = 224
        self.crop_length = 16
        self.sample_num_batches = sample_num_batches
        # Number of crop_length x crop_length patches in the image grid.
        self.max_num_batches = (self.image_size // self.crop_length) ** 2
        assert self.sample_num_batches <= self.max_num_batches

        #: PNA
        if self.ablation_study[2] == '1':
            print('Using Skip')
            self._register_model()
        else:
            print('Not Using Skip')

    def _register_model(self):
        """Zero the gradient through each ``attn.attn_drop`` module of the surrogate.

        A backward hook scales the module's input gradient by ``gamma=0``; the
        forward pass is unaffected. Models without a mapping below get no hooks
        and a notice is printed.
        """
        def attn_drop_mask_grad(module, grad_in, grad_out, gamma):
            return (grad_in[0] * gamma, )

        drop_hook_func = partial(attn_drop_mask_grad, gamma=0)

        def register(block):
            block.attn.attn_drop.register_backward_hook(drop_hook_func)

        if self.model_name in ['vit_base_patch16_224', 'deit_base_distilled_patch16_224']:
            for block_ind in range(12):
                register(self.model.blocks[block_ind])
        elif self.model_name == 'pit_b_224':
            # Three transformer stages holding 3, 6 and 4 blocks (13 total).
            for transformer_ind, num_blocks in enumerate((3, 6, 4)):
                for used_block_ind in range(num_blocks):
                    register(self.model.transformers[transformer_ind].blocks[used_block_ind])
        elif self.model_name == 'cait_s24_224':
            # 24 ``blocks`` followed by 2 ``blocks_token_only``. The previous
            # ``block_ind > 24`` test skipped ``blocks_token_only[0]``.
            for block_ind in range(24):
                register(self.model.blocks[block_ind])
            for block_ind in range(2):
                register(self.model.blocks_token_only[block_ind])
        elif self.model_name == 'visformer_small':
            for block_ind in range(4):
                register(self.model.stage2[block_ind])
            for block_ind in range(4):
                register(self.model.stage3[block_ind])
        else:
            print(f'Skip is not implemented for {self.model_name}; no hooks registered')

    def _generate_samples_for_interactions(self, perts, seed):
        """Return ``perts`` kept only inside ``sample_num_batches`` patches chosen by ``seed``.

        The image is divided into ``crop_length`` x ``crop_length`` patches;
        distinct patch indices are drawn without replacement and the same patch
        mask is applied to every sample and channel of the batch.
        """
        add_noise_mask = torch.zeros_like(perts)
        grid_num_axis = self.image_size // self.crop_length

        # Sampling without replacement. A private RNG produces the same
        # permutation as ``random.seed(seed); random.shuffle(ids)`` but does not
        # reset the module-level ``random`` state shared with other code.
        ids = list(range(self.max_num_batches))
        random.Random(seed).shuffle(ids)

        length = self.crop_length
        for idx in ids[:self.sample_num_batches]:
            r, c = divmod(idx, grid_num_axis)
            add_noise_mask[:, :, r*length:(r+1)*length, c*length:(c+1)*length] = 1
        return perts * add_noise_mask

    def forward(self, inps, labels):
        """Run the attack and return ``(adv_inps, None)``.

        ``adv_inps`` is the detached, re-normalized adversarial batch whose
        pixel-space values are clamped to [0, 1]. The caller's ``inps`` tensor
        is not modified.
        """
        inps = inps.cuda()
        labels = labels.cuda()
        loss = nn.CrossEntropyLoss()

        use_patches = self.ablation_study[0] == '1'
        use_l2 = self.ablation_study[1] == '1'
        # Report the configuration once per call instead of once per step.
        print('Using Pathes' if use_patches else 'Not Using Pathes')
        print('Using L2' if use_l2 else 'Not Using L2')

        # ``_mul_std_add_mean`` works in place; copy so the caller's tensor
        # (returned as-is by ``.cuda()`` when already on the GPU) is untouched.
        unnorm_inps = self._mul_std_add_mean(inps.clone())
        perts = torch.zeros_like(unnorm_inps).cuda()
        perts.requires_grad_()

        for i in range(self.steps):
            if use_patches:
                # The step index seeds the patch subset: reproducible, and it
                # changes from step to step.
                adv = unnorm_inps + self._generate_samples_for_interactions(perts, i)
            else:
                adv = unnorm_inps + perts
            outputs = self.model(self._sub_mean_div_std(adv))

            cost = self.loss_flag * loss(outputs, labels).cuda()
            if use_l2:
                cost = cost + self.lamb * torch.norm(perts)
            cost.backward()
            grad = perts.grad.data
            perts.data = self._update_perts(perts.data, grad, self.step_size)
            perts.data = torch.clamp(unnorm_inps.data + perts.data, 0.0, 1.0) - unnorm_inps.data
            perts.grad.data.zero_()
        return (self._sub_mean_div_std(unnorm_inps + perts.data)).detach(), None
    
class PatchOut(BaseAttack):
    """PatchOut attack with optional PNA "Skip", toggled by ablation flags.

    ``ablation_study`` is a comma-separated string of three ``'0'``/``'1'`` flags
    (default ``'1,1,0'``):

    * ``[0]`` perturb only ``sample_num_batches`` seeded-random patches per step;
    * ``[1]`` add ``lamb * ||perts||_2`` (over the whole batch) to the loss;
    * ``[2]`` ("Skip") zero the gradient flowing through every ``attn.attn_drop``.

    ``decay`` is accepted for signature compatibility but is not used.
    """

    def __init__(self, model_name, steps=10, epsilon=16/255, target=False, decay=1.0, pre_trained=True, weight_path=False, ablation_study='1,1,0', sample_num_batches=130, lamb=0.1):
        super(PatchOut, self).__init__('PatchOut', model_name, target, pre_trained=pre_trained, weight_path=weight_path)
        self.epsilon = epsilon
        self.steps = steps
        self.step_size = self.epsilon / self.steps

        self.ablation_study = ablation_study.split(',')
        self.lamb = lamb
        self.image_size = 224
        self.crop_length = 16
        self.sample_num_batches = sample_num_batches
        # Number of crop_length x crop_length patches in the image grid.
        self.max_num_batches = (self.image_size // self.crop_length) ** 2
        assert self.sample_num_batches <= self.max_num_batches

        #: PNA
        if self.ablation_study[2] == '1':
            print('Using Skip')
            self._register_model()
        else:
            print('Not Using Skip')

    def _register_model(self):
        """Zero the gradient through each ``attn.attn_drop`` module of the surrogate.

        A backward hook scales the module's input gradient by ``gamma=0``; the
        forward pass is unaffected. Models without a mapping below get no hooks
        and a notice is printed.
        """
        def attn_drop_mask_grad(module, grad_in, grad_out, gamma):
            return (grad_in[0] * gamma, )

        drop_hook_func = partial(attn_drop_mask_grad, gamma=0)

        def register(block):
            block.attn.attn_drop.register_backward_hook(drop_hook_func)

        if self.model_name in ['vit_base_patch16_224', 'deit_base_distilled_patch16_224']:
            for block_ind in range(12):
                register(self.model.blocks[block_ind])
        elif self.model_name == 'pit_b_224':
            # Three transformer stages holding 3, 6 and 4 blocks (13 total).
            for transformer_ind, num_blocks in enumerate((3, 6, 4)):
                for used_block_ind in range(num_blocks):
                    register(self.model.transformers[transformer_ind].blocks[used_block_ind])
        elif self.model_name == 'cait_s24_224':
            # 24 ``blocks`` followed by 2 ``blocks_token_only``. The previous
            # ``block_ind > 24`` test skipped ``blocks_token_only[0]``.
            for block_ind in range(24):
                register(self.model.blocks[block_ind])
            for block_ind in range(2):
                register(self.model.blocks_token_only[block_ind])
        elif self.model_name == 'visformer_small':
            for block_ind in range(4):
                register(self.model.stage2[block_ind])
            for block_ind in range(4):
                register(self.model.stage3[block_ind])
        else:
            print(f'Skip is not implemented for {self.model_name}; no hooks registered')

    def _generate_samples_for_interactions(self, perts, seed):
        """Return ``perts`` kept only inside ``sample_num_batches`` patches chosen by ``seed``.

        The image is divided into ``crop_length`` x ``crop_length`` patches;
        distinct patch indices are drawn without replacement and the same patch
        mask is applied to every sample and channel of the batch.
        """
        add_noise_mask = torch.zeros_like(perts)
        grid_num_axis = self.image_size // self.crop_length

        # Sampling without replacement. A private RNG produces the same
        # permutation as ``random.seed(seed); random.shuffle(ids)`` but does not
        # reset the module-level ``random`` state shared with other code.
        ids = list(range(self.max_num_batches))
        random.Random(seed).shuffle(ids)

        length = self.crop_length
        for idx in ids[:self.sample_num_batches]:
            r, c = divmod(idx, grid_num_axis)
            add_noise_mask[:, :, r*length:(r+1)*length, c*length:(c+1)*length] = 1
        return perts * add_noise_mask

    def forward(self, inps, labels):
        """Run the attack and return ``(adv_inps, None)``.

        ``adv_inps`` is the detached, re-normalized adversarial batch whose
        pixel-space values are clamped to [0, 1]. The caller's ``inps`` tensor
        is not modified.
        """
        inps = inps.cuda()
        labels = labels.cuda()
        loss = nn.CrossEntropyLoss()

        use_patches = self.ablation_study[0] == '1'
        use_l2 = self.ablation_study[1] == '1'
        # Report the configuration once per call instead of once per step.
        print('Using Pathes' if use_patches else 'Not Using Pathes')
        print('Using L2' if use_l2 else 'Not Using L2')

        # ``_mul_std_add_mean`` works in place; copy so the caller's tensor
        # (returned as-is by ``.cuda()`` when already on the GPU) is untouched.
        unnorm_inps = self._mul_std_add_mean(inps.clone())
        perts = torch.zeros_like(unnorm_inps).cuda()
        perts.requires_grad_()

        for i in range(self.steps):
            if use_patches:
                # The step index seeds the patch subset: reproducible, and it
                # changes from step to step.
                adv = unnorm_inps + self._generate_samples_for_interactions(perts, i)
            else:
                adv = unnorm_inps + perts
            outputs = self.model(self._sub_mean_div_std(adv))

            cost = self.loss_flag * loss(outputs, labels).cuda()
            if use_l2:
                cost = cost + self.lamb * torch.norm(perts)
            cost.backward()
            grad = perts.grad.data
            perts.data = self._update_perts(perts.data, grad, self.step_size)
            perts.data = torch.clamp(unnorm_inps.data + perts.data, 0.0, 1.0) - unnorm_inps.data
            perts.grad.data.zero_()
        return (self._sub_mean_div_std(unnorm_inps + perts.data)).detach(), None

# LinBP attack on a ResNet-50 surrogate (uses the utils_linbp resnet50 helpers).
# NOTE(review): this class was previously marked "cannot run, use the official
# code". forward() below fixes the two visible crashes; verify its results
# against the official LinBP implementation.
class LinBp(BaseAttack):
    """Iterative LinBP attack (momentum is not applied; ``decay`` is stored but unused).

    The logits and the input gradient are produced by ``linbp_forw_resnet50`` /
    ``linbp_backw_resnet50`` (see utils_linbp); ``self.linbp_layer`` and
    ``self.sgm_lambda`` (as ``xp``) are passed through to them.
    """

    def __init__(self, model_name, steps=10, epsilon=16/255, target=False, decay=1.0, pre_trained=True, weight_path=False):
        super(LinBp, self).__init__('LinBp', model_name, target, pre_trained=pre_trained, weight_path=weight_path)
        self.epsilon = epsilon
        self.steps = steps
        self.step_size = self.epsilon / self.steps
        self.decay = decay
        self.image_size = 224
        self.sgm_lambda = 1.0
        self.linbp_layer = '3_1'

    def forward(self, inps, labels):
        """Run the attack and return ``(adv_inps, None)``.

        ``adv_inps`` is the detached, re-normalized adversarial batch whose
        pixel-space values are clamped to [0, 1]. The caller's ``inps`` tensor
        is not modified.
        """
        inps = inps.cuda()
        labels = labels.cuda()
        loss = nn.CrossEntropyLoss()

        # De-normalize to pixel space on a copy (``_mul_std_add_mean`` is in place).
        unnorm_inps = self._mul_std_add_mean(inps.clone())
        # Same per-channel std that ``_sub_mean_div_std`` divides by; needed to
        # map the gradient back to pixel space (chain rule below).
        std = torch.as_tensor(self.used_params['std'], dtype=unnorm_inps.dtype).cuda()[:, None, None]

        # Pixel-space perturbation.
        perts = torch.zeros_like(unnorm_inps).cuda()
        perts.requires_grad_()

        for _ in range(self.steps):
            # Build the normalized adversarial image ONCE and give the same
            # tensor to both helpers. The backward helper presumably
            # differentiates w.r.t. its first argument (as the official LinBP
            # code does — confirm in utils_linbp); a freshly rebuilt tensor is
            # not part of the forward graph and autograd rejects it.
            adv_norm = self._sub_mean_div_std(unnorm_inps + perts)
            att_out, ori_mask_ls, conv_out_ls, relu_out_ls, conv_input_ls = linbp_forw_resnet50(
                self.model, adv_norm, True, linbp_layer=self.linbp_layer
            )
            cost = self.loss_flag * loss(att_out, labels).cuda()
            self.model.zero_grad()

            grad = linbp_backw_resnet50(
                adv_norm, cost, conv_out_ls, ori_mask_ls, relu_out_ls, conv_input_ls, xp=self.sgm_lambda
            )
            # adv_norm = (unnorm_inps + perts - mean) / std, so the gradient
            # w.r.t. perts is the gradient w.r.t. adv_norm divided by std.
            grad = grad / std
            # Per-sample mean-|grad| normalization, guarding 0/0 -> NaN.
            grad_scale = torch.mean(torch.abs(grad), dim=[1, 2, 3], keepdim=True)
            grad = grad / grad_scale.masked_fill(grad_scale == 0, 1.0)

            perts.data = self._update_perts(perts.data, grad, self.step_size)
            perts.data = torch.clamp(unnorm_inps.data + perts.data, 0.0, 1.0) - unnorm_inps.data
            # The gradient is returned by the helper rather than accumulated in
            # perts.grad, which may therefore be None; only clear it if present.
            if perts.grad is not None:
                perts.grad.zero_()
        return (self._sub_mean_div_std(unnorm_inps + perts.data)).detach(), None
    
class Admix(BaseAttack):
    """Admix attack: average input gradients over mixed, intensity-scaled copies.

    Each step, for ``num_mix`` random pairings within the batch the current
    adversarial image ``x`` is mixed as ``x + portion * x[perm]`` and then
    multiplied by every factor in ``scales``; the gradients of the loss w.r.t.
    all ``num_mix * len(scales)`` copies are averaged. Momentum is not applied
    (``decay`` is stored but unused).

    Args (beyond the common ones):
        portion: weight of the shuffled partner image in the mix.
        num_mix: number of random pairings per step (default 3).
        scales: intensity factors applied to each mixed image
            (default ``1, 1/2, 1/4, 1/8, 1/16``).

    Raises:
        ValueError: if ``num_mix < 1`` or ``scales`` is empty (the gradient
            average would divide by zero).
    """

    def __init__(self, model_name, steps=10, epsilon=16/255, target=False, decay=1.0, pre_trained=True, weight_path=False, portion=0.2,
                 num_mix=3, scales=(1., 1./2, 1./4, 1./8, 1./16)):
        super(Admix, self).__init__('Admix', model_name, target, pre_trained=pre_trained, weight_path=weight_path)
        self.epsilon = epsilon
        self.steps = steps
        self.step_size = self.epsilon / self.steps
        self.decay = decay
        self.image_size = 224
        self.portion = portion
        self.num_mix = num_mix
        self.scales = tuple(scales)
        if self.num_mix < 1 or not self.scales:
            raise ValueError('Admix needs num_mix >= 1 and at least one scale')

    def forward(self, inps, labels):
        """Run the attack and return ``(adv_inps, None)``.

        ``adv_inps`` is the detached, re-normalized adversarial batch whose
        pixel-space values are clamped to [0, 1]. The caller's ``inps`` tensor
        is not modified.
        """
        inps = inps.cuda()
        labels = labels.cuda()
        loss = nn.CrossEntropyLoss()

        # De-normalize to pixel space on a copy (``_mul_std_add_mean`` is in place
        # and ``.cuda()`` returns the caller's tensor when already on the GPU).
        unnorm_inps = self._mul_std_add_mean(inps.clone())

        # Pixel-space perturbation; gradients come from the mixed copies below.
        perts = torch.zeros_like(unnorm_inps).cuda()
        num_copies = self.num_mix * len(self.scales)

        for _ in range(self.steps):
            grad = torch.zeros_like(unnorm_inps)
            img_x = unnorm_inps + perts
            for _ in range(self.num_mix):
                random_indices = list(range(img_x.shape[0]))
                random.shuffle(random_indices)
                mixed = img_x + self.portion * img_x[random_indices]
                for gamma in self.scales:
                    temp_img = (mixed * gamma).detach().requires_grad_()
                    outputs = self.model(self._sub_mean_div_std(temp_img))
                    cost = self.loss_flag * loss(outputs, labels).cuda()
                    cost.backward()
                    grad += temp_img.grad.data
            grad /= num_copies

            # Per-sample mean-|grad| normalization, guarding 0/0 -> NaN.
            grad_scale = torch.mean(torch.abs(grad), dim=[1, 2, 3], keepdim=True)
            grad = grad / grad_scale.masked_fill(grad_scale == 0, 1.0)

            perts.data = self._update_perts(perts.data, grad, self.step_size)
            # Keep the adversarial image inside the valid [0, 1] pixel range.
            perts.data = torch.clamp(unnorm_inps.data + perts.data, 0.0, 1.0) - unnorm_inps.data
        return (self._sub_mean_div_std(unnorm_inps + perts.data)).detach(), None