from utils.prepare_dataset import get_dataset
from utils.train import normal_train,attack_train,clip,removeSPP_train,remove_attack_train,adversarial_train,\
    RPB_train,mixup_train,L1_train,L1_RPB_train
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
from utils.visualization import save_images
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import numpy as np

class ADGT():
    # Models kept by this wrapper; set by the *_train helpers and the load_* methods.
    normal_model=None
    gt_model=None
    improve_model=None
    RPB_model=None

    # Dataset selection and data pipeline: dataset_name is set in __init__,
    # the sets/loaders by prepare_dataset_loader.
    dataset_name=None
    trainset=None
    testset=None
    trainloader=None
    testloader=None
    signal_estimator=None  # NOTE(review): not referenced in the visible part of this class

    # Trigger pattern and data statistics. `min`/`max` are per-channel bounds and `mu`
    # a mean estimate, filled lazily by obtain_statistics; `attack` is filled by
    # obtain_attack / random_attack. Any of them may also be passed to __init__ or loaded.
    attack=None
    min=None
    max=None
    mu=None
    use_cuda=False  # when True the training helpers move the model to the GPU
    aug=False       # in the visible code only used as a tag in log/checkpoint directory names
    color=False     # NOTE(review): not read by the visible methods (they take a `color` argument)

    # Number of classes per supported dataset key. This is a class-level dict, so it is
    # shared by every ADGT instance unless an instance rebinds self.nclass.
    nclass={'MNIST':10,'C10':10,'C100':100,'Flower102':102,'RestrictedImageNet':9,'ImageNet':1000}

    def __init__(self,name='MNIST',nclass=None,use_cuda=False,min=None,max=None,attack=None,normal_model=None,
                 gt_model=None,aug=False):
        self.aug=aug
        self.use_cuda=use_cuda
        self.dataset_name=name
        if nclass is not None:
            self.nclass[name]=nclass
        self.min=min
        self.max=max
        self.attack=attack
        self.normal_model=normal_model
        self.gt_model=gt_model
        return
    def save_gt(self,checkpointdir):
        if not os.path.exists(checkpointdir):  # 如果路径不存在
            os.makedirs(checkpointdir)
        print('save checkpoints to :', checkpointdir)
        torch.save(self.gt_model,os.path.join(checkpointdir,'model.ckpt'))
        np.save(os.path.join(checkpointdir,'min.npy'),self.min.numpy())
        np.save(os.path.join(checkpointdir, 'max.npy'), self.max.numpy())
        np.save(os.path.join(checkpointdir, 'attack.npy'), self.attack.numpy())
    def load_gt(self,checkpointdir):
        model=torch.load(os.path.join(checkpointdir,'model.ckpt'))
        min=np.load(os.path.join(checkpointdir,'min.npy'))
        max=np.load(os.path.join(checkpointdir,'max.npy'))
        attack=np.load(os.path.join(checkpointdir,'attack.npy'))
        self.gt_model=model
        self.min=torch.Tensor(min)
        self.max=torch.Tensor(max)
        self.attack=torch.Tensor(attack)
    def load_normal(self,checkpointdir):
        model = torch.load(os.path.join(checkpointdir, 'model.ckpt'))
        self.normal_model=model
    def load_improve(self,checkpointdir):
        model = torch.load(os.path.join(checkpointdir, 'model.ckpt'))
        self.improve_model=model

    def load_RPB(self,checkpointdir):
        model = torch.load(os.path.join(checkpointdir, 'model.ckpt'))
        self.RPB_model=model
    def prepare_dataset_loader(self,root='../data',transform=transforms.Compose([transforms.ToTensor()]),
                               train=True,batch_size=128,shuffle=True,num_workers=4):
        '''
        Build one split of ``self.dataset_name`` and wrap it in a DataLoader.

        Input:
            root: dataset root directory passed to ``get_dataset``.
            transform: transform passed to ``get_dataset``.
            train: truthy -> fill ``trainset``/``trainloader``; falsy -> ``testset``/``testloader``.
            batch_size, shuffle, num_workers: forwarded to ``torch.utils.data.DataLoader``.

        Output: None
        '''
        prefix = 'train' if train else 'test'
        dataset = get_dataset(self.dataset_name, root, transform, train)
        # Store the dataset before building the loader, so it is kept even if the loader fails.
        setattr(self, prefix + 'set', dataset)
        setattr(self, prefix + 'loader',
                torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
                                            num_workers=num_workers))

    def normal_train(self,model,logdir,checkpointdir,trainloader=None,testloader=None,optimizer=None,
                     schedule=None,criterion=nn.CrossEntropyLoss(),max_epoch=50,suffix='',img=None,target=None,
                     method=None,save=False,explain=None):
        '''
        Train ``model`` by delegating to ``utils.train.normal_train``.

        Input:
            model: network to train; moved to the GPU first when ``self.use_cuda``.
            logdir, checkpointdir: root directories; this run uses
                ``<root>/<dataset_name>/normal<aug><suffix>`` under each.
            trainloader, testloader: default to ``self.trainloader`` / ``self.testloader``.
            optimizer: defaults to Adam(lr=1e-3, betas=(0.5, 0.9)).
            schedule, criterion, max_epoch, img, target, method, explain: forwarded unchanged.
            save: when True, the whole model is written to ``model.ckpt``.
        Side effect: the first trained model is remembered as ``self.normal_model``.

        Output: model
        '''
        if trainloader is None:
            trainloader = self.trainloader
        if testloader is None:
            testloader = self.testloader
        # Move the model before building the default optimizer, as the torch.optim docs advise.
        if self.use_cuda:
            model = model.cuda()
        if optimizer is None:
            optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, betas=(0.5, 0.9))

        run_name = 'normal' + str(self.aug) + suffix
        logdir = os.path.join(logdir, self.dataset_name, run_name)
        writer = SummaryWriter(log_dir=logdir)
        print('save logs to :', logdir)
        try:
            normal_train(model, trainloader, testloader, optimizer, schedule, criterion, max_epoch, writer,
                         self.use_cuda, img, target, method, explain, explain_dir=logdir)
        finally:
            # Flush and close the event file even if training raises.
            writer.close()
        if self.normal_model is None:
            self.normal_model = model

        if save:
            checkpointdir = os.path.join(checkpointdir, self.dataset_name, run_name)
            os.makedirs(checkpointdir, exist_ok=True)
            print('save checkpoints to :', checkpointdir)
            torch.save(model, os.path.join(checkpointdir, 'model.ckpt'))
        return model
    def L1_train(self,model,logdir,checkpointdir,trainloader=None,testloader=None,optimizer=None,
                     schedule=None,criterion=nn.CrossEntropyLoss(),max_epoch=50,alpha=0.01):
        '''
        Train ``model`` by delegating to ``utils.train.L1_train``, then always save it.

        Input:
            model: network to train; moved to the GPU first when ``self.use_cuda``.
            logdir, checkpointdir: root directories; this run uses
                ``<root>/<dataset_name>/L1_<alpha><aug>`` under each.
            trainloader, testloader: default to ``self.trainloader`` / ``self.testloader``.
            optimizer: defaults to Adam(lr=1e-3, betas=(0.5, 0.9)).
            schedule, criterion, max_epoch, alpha: forwarded unchanged.

        Output: model
        '''
        if trainloader is None:
            trainloader = self.trainloader
        if testloader is None:
            testloader = self.testloader
        # Move the model before building the default optimizer, as the torch.optim docs advise.
        if self.use_cuda:
            model = model.cuda()
        if optimizer is None:
            optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, betas=(0.5, 0.9))

        run_name = 'L1_' + str(alpha) + str(self.aug)
        logdir = os.path.join(logdir, self.dataset_name, run_name)
        writer = SummaryWriter(log_dir=logdir)
        print('save logs to :', logdir)
        try:
            L1_train(model, trainloader, testloader, optimizer, schedule, criterion, max_epoch, writer,
                     self.use_cuda, alpha)
        finally:
            writer.close()

        checkpointdir = os.path.join(checkpointdir, self.dataset_name, run_name)
        os.makedirs(checkpointdir, exist_ok=True)
        print('save checkpoints to :', checkpointdir)
        torch.save(model, os.path.join(checkpointdir, 'model.ckpt'))
        return model
    def mixup_train(self,model,logdir,checkpointdir,trainloader=None,testloader=None,optimizer=None,
                     schedule=None,criterion=nn.CrossEntropyLoss(),max_epoch=50,alpha=1):
        '''
        Train ``model`` by delegating to ``utils.train.mixup_train``, then always save it.

        Input:
            model: network to train; moved to the GPU first when ``self.use_cuda``.
            logdir, checkpointdir: root directories; this run uses
                ``<root>/<dataset_name>/mixup_<alpha><aug>`` under each.
            trainloader, testloader: default to ``self.trainloader`` / ``self.testloader``.
            optimizer: defaults to Adam(lr=1e-3, betas=(0.5, 0.9)).
            schedule, criterion, max_epoch, alpha: forwarded unchanged.

        Output: model
        '''
        if trainloader is None:
            trainloader = self.trainloader
        if testloader is None:
            testloader = self.testloader
        # Move the model before building the default optimizer, as the torch.optim docs advise.
        if self.use_cuda:
            model = model.cuda()
        if optimizer is None:
            optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, betas=(0.5, 0.9))

        run_name = 'mixup_' + str(alpha) + str(self.aug)
        logdir = os.path.join(logdir, self.dataset_name, run_name)
        writer = SummaryWriter(log_dir=logdir)
        print('save logs to :', logdir)
        try:
            # NOTE(review): alpha precedes use_cuda here, the reverse of the L1_train call;
            # confirm against utils.train.mixup_train's signature.
            mixup_train(model, trainloader, testloader, optimizer, schedule, criterion, max_epoch, writer,
                        alpha, self.use_cuda)
        finally:
            writer.close()

        checkpointdir = os.path.join(checkpointdir, self.dataset_name, run_name)
        os.makedirs(checkpointdir, exist_ok=True)
        print('save checkpoints to :', checkpointdir)
        torch.save(model, os.path.join(checkpointdir, 'model.ckpt'))
        return model
    def attack_train(self,model,logdir,checkpointdir,trainloader=None,testloader=None,optimizer=None,
                     schedule=None,criterion=nn.CrossEntropyLoss(),max_epoch=50,inject_num=1,random=False,alpha=0.0,suffix='',
                     img=None,target=None,method=None,save=False,explain=None):
        '''
        Train ``model`` on data with the per-class trigger ``self.attack`` injected,
        by delegating to ``utils.train.attack_train``.

        Input:
            model: network to train; moved to the GPU first when ``self.use_cuda``.
            logdir, checkpointdir: root directories; this run uses
                ``<root>/<dataset_name>/attack_<inject_num>_<alpha>[_random]<aug><suffix>``.
            trainloader, testloader: default to ``self.trainloader`` / ``self.testloader``.
            optimizer: defaults to Adam(lr=1e-3, betas=(0.5, 0.9)).
            inject_num: trigger pixels per class when the trigger has to be built.
            random: build the trigger with ``random_attack`` instead of ``obtain_attack``.
            alpha: score weight passed to ``obtain_attack``.
            schedule, criterion, max_epoch, img, target, method, explain: forwarded unchanged.
            save: when True, write the model and trigger via ``save_gt``.
        Side effects: statistics and trigger are computed lazily when still None, and the
        first trained model is remembered as ``self.gt_model``.

        Output: model
        '''
        self.inject_num = inject_num
        if trainloader is None:
            trainloader = self.trainloader
        if testloader is None:
            testloader = self.testloader
        # Move the model before building the default optimizer, as the torch.optim docs advise.
        if self.use_cuda:
            model = model.cuda()
        if optimizer is None:
            optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, betas=(0.5, 0.9))

        if self.min is None:
            self.obtain_statistics()
        if self.attack is None:
            if random:
                self.random_attack(inject_num=inject_num)
            else:
                self.obtain_attack(inject_num=inject_num, alpha=alpha)

        r = '_random' if random else ''
        run_name = 'attack_' + str(inject_num) + '_' + str(alpha) + r + str(self.aug) + suffix
        logdir = os.path.join(logdir, self.dataset_name, run_name)
        writer = SummaryWriter(log_dir=logdir)
        print('save logs to :', logdir)
        try:
            attack_train(model, trainloader, testloader, optimizer, schedule, criterion, max_epoch, writer,
                         self.use_cuda, self.attack, self.min, self.max, img, target, method, explain,
                         explain_dir=logdir)
        finally:
            writer.close()
        if self.gt_model is None:
            self.gt_model = model

        if save:
            self.save_gt(os.path.join(checkpointdir, self.dataset_name, run_name))
        return model

    def attack_img(self,img,label):
        """Return ``img`` with the trigger of class ``label`` added.

        The sum is passed through ``clip`` together with the stored bounds
        ``self.min`` / ``self.max``.
        """
        return clip(img + self.attack[label], self.min, self.max)

    def obtain_statistics(self):
        K=self.nclass[self.dataset_name]
        mu=None
        X2=None
        num = None  # numbers of samples except class
        mu_in=None
        X2_in=None
        num_in=None
        print('obtain statistics')
        for data, label in self.trainloader:
            C,H,W=data.size(1),data.size(2),data.size(3)
            self.channels,self.heights,self.width=C,H,W
            data_temp = data.permute([1, 0, 2, 3])
            data_temp = data_temp.reshape(C, -1)
            if self.min is None:
                self.min=torch.min(data_temp,1)[0].view(1,-1,1,1)
                self.max=torch.max(data_temp,1)[0].view(1,-1,1,1)
            else:
                m1=torch.cat([data_temp,self.min.view(-1,1)],1)
                m2=torch.cat([data_temp,self.max.view(-1,1)],1)
                self.min = torch.min(m1, 1)[0].view(1, -1, 1, 1)
                self.max = torch.max(m2, 1)[0].view(1, -1, 1, 1)

            if mu is None:
                mu=torch.zeros(K,C,H,W)
                X2=torch.zeros(K,C,H,W)
                num=torch.zeros(K,1,1,1)
                mu_in=torch.zeros(K,C,H,W)
                X2_in = torch.zeros(K, C, H, W)
                num_in = torch.zeros(K, 1, 1, 1)

            for i in range(K):
                temp=data[label!=i]
                mu[i]=mu[i]+torch.sum(temp,0,keepdim=True)
                X2[i]=X2[i]+torch.sum(temp**2,0,keepdim=True)
                num[i]+=temp.size(0)

                temp_in = data[label == i]
                if temp_in.size(0)>0:
                    mu_in[i] = mu[i] + torch.sum(temp_in, 0, keepdim=True)
                    X2_in[i] = X2[i] + torch.sum(temp_in ** 2, 0, keepdim=True)
                    num_in[i] += temp_in.size(0)

        self.mu=mu/num
        X2=X2/num
        self.var=X2-self.mu**2

        self.mu_in = mu_in / num_in
        X2_in = X2_in / num_in
        self.var_in = X2_in - self.mu_in ** 2

        print('min:',self.min,'max:',self.max)
        print('mean:',self.mu)
        print('var:',self.var)

        self.right_prob = torch.zeros(K, C, H, W, 2)
        epsilon=1e-4
        for data, label in self.trainloader:
            for i in range(K):
                temp_in = data[label == i]
                if temp_in.size(0) > 0:
                    temp_min=torch.sign(F.relu(self.min-temp_in+epsilon))
                    temp_max=torch.sign(F.relu(temp_in+epsilon-self.max))
                    self.right_prob[i,:,:,:,0]+=torch.sum(temp_min,0)
                    self.right_prob[i, :, :, :, 1] += torch.sum(temp_max, 0)
        self.right_prob=self.right_prob/num_in.view(K,1,1,1,1)
    def parallel(self):
        if self.normal_model is not None:
            self.normal_model=nn.DataParallel(self.normal_model)
        if self.gt_model is not None:
            self.gt_model=nn.DataParallel(self.gt_model)
        if self.improve_model is not None:
            self.improve_model=nn.DataParallel(self.improve_model)
        if self.RPB_model is not None:
            self.RPB_model=nn.DataParallel(self.RPB_model)
    def random_attack(self,inject_num=1):
        from scipy.stats import norm
        K,C,H,W=self.nclass[self.dataset_name],self.channels,self.heights,self.width
        self.attack = torch.zeros(K, C, H, W)
        jilu = torch.zeros(K)
        pan = torch.zeros(1, C, H, W, 2)
        print('find attack position ...')
        i=now=0
        maxnorm = (self.max - self.min).squeeze()
        while i<inject_num*K:
            index=int(np.random.rand()*K*C*H*W*2)
            now+=1
            n4=index %2
            index=int(index/2)
            n3=index % W
            index=int(index/W)
            n2=index %H
            index=int(index/H)
            n1=index % C
            index=int(index/C)
            n0=index
            if jilu[n0]<inject_num and pan[0,n1,n2,n3,n4]==0:
                jilu[n0]+=1
                pan[0, n1, n2, n3, n4]=1
                if n4==0:
                    self.attack[n0,n1,n2,n3]=-maxnorm[n1]*2
                else:
                    self.attack[n0, n1, n2, n3] = maxnorm[n1] * 2
                i+=1
                print('class',n0,'now',now)
        print(self.attack)


    def obtain_attack(self,inject_num=1,alpha=0.5):
        """Choose trigger pixels for every class and store them in ``self.attack``.

        Every candidate (class k, channel, row, col, side) gets the score
        ``alpha * eloss + (1 - alpha) * right_prob``. Candidates are visited in
        ascending score order:

        * ``eloss`` is ``E[(min - X)^+]`` for side 0 and ``E[(X - max)^+]`` for side 1,
          with ``X ~ N(self.mu, self.var)``, evaluated in closed form from the normal
          pdf/cdf. ``obtain_statistics`` computes ``self.mu`` / ``self.var`` from the
          samples whose label is not k.
        * ``right_prob`` is ``self.right_prob`` from ``obtain_statistics``.

        A candidate is accepted while its class has fewer than ``inject_num`` pixels and
        its slot is still free. Accepting it marks that (row, col, side) as used for
        every channel. The pixel is set to ``-2 * (max - min)`` (side 0) or
        ``+2 * (max - min)`` (side 1) of its channel. Every side-0 slot is pre-marked as
        used, so only side 1 is ever chosen.

        Requires ``obtain_statistics`` to have run first. It reads ``channels``,
        ``heights``, ``width``, ``mu``, ``var``, ``min``, ``max`` and ``right_prob``.

        NOTE(review): if the sorted candidates run out before every class has
        ``inject_num`` pixels, ``indices[now]`` raises IndexError.
        """
        from scipy.stats import norm
        K,C,H,W=self.nclass[self.dataset_name],self.channels,self.heights,self.width
        eloss=torch.zeros(K,C,H,W,2)
        # last axis: 0 = min side, 1 = max side
        sigma=torch.sqrt(self.var)+1e-8  # +1e-8 keeps sigma nonzero for constant pixels
        mu=self.mu
        T = 1 / (np.sqrt(2 * np.pi))  # standard normal pdf at 0

        # side 0: t0 = (min - mu) / sigma;  E[(min - X)^+] = sigma * (pdf(t0) + t0 * cdf(t0))
        t0=-(mu-self.min)/sigma
        phi0=torch.Tensor(norm.cdf(t0.numpy()))
        eloss[:,:,:,:,0]=sigma*(T*torch.exp(-0.5*t0**2)+t0*phi0)

        # side 1: t1 = (max - mu) / sigma;  E[(X - max)^+] = sigma * (pdf(t1) - t1 * (1 - cdf(t1)))
        t1 = -(mu - self.max) / sigma
        phi1 = torch.Tensor(norm.cdf(t1.numpy()))
        eloss[:, :, :, :, 1] = -sigma * (-T * torch.exp(-0.5 * t1 ** 2) + t1 * (1-phi1))

        # per-channel value range, flattened to shape (C,)
        maxnorm=(self.max-self.min).view(-1)

        self.attack=torch.zeros(K,C,H,W)
        jilu=torch.zeros(K)  # number of pixels chosen so far for each class
        pan=torch.zeros(1,C,H,W,2)  # 1 marks a slot that is no longer available
        pan[:,:,:,:,0]=1  # block every min-side slot up front
        #========================================
        right_prob=self.right_prob
        # NOTE: the string below is disabled code (a Gaussian estimate of right_prob from
        # mu_in / var_in), kept for reference; it has no effect.
        '''
        right_prob=torch.zeros(K,C,H,W,2)
        sigma_in=torch.sqrt(self.var_in)+1e-8
        mu_in=self.mu_in
        t_in0 = -(mu_in - self.min) / sigma_in
        phi_in0 = torch.Tensor(norm.cdf(t_in0.numpy()))
        right_prob[:,:,:,:,0]=phi_in0

        t_in1 = -(mu_in - self.max) / sigma_in
        phi_in1 = torch.Tensor(norm.cdf(t_in1.numpy()))
        right_prob[:, :, :, :, 1] = 1-phi_in1
        '''
        #====================
        value=eloss*alpha+(1-alpha)*right_prob
        value_temp=value.view(-1)

        i=now=0  # i: pixels placed so far; now: candidates examined so far
        print('sort start')
        # `sorted` shadows the builtin inside this method; only `indices` is used.
        sorted, indices = torch.sort(value_temp, descending=False)
        print('find attack position ...')
        while i<inject_num*K:
            index=indices[now]
            now+=1
            # Decode the flat index into (class, channel, row, col, side), C order.
            n4 = index % 2
            index = index//2
            n3 = index % W
            index = index// W
            n2 = index % H
            index = index// H
            n1 = index % C
            index = index// C
            n0 = index
            if jilu[n0]<inject_num and pan[0,n1,n2,n3,n4]==0:
                jilu[n0]+=1
                #pan[0,n1,n2,n3,n4]=1
                # Reserve this (row, col, side) for every channel, not just channel n1.
                pan[0, :, n2, n3, n4] = 1
                if n4==0:
                    self.attack[n0,n1,n2,n3]=-maxnorm[n1]*2
                else:
                    self.attack[n0, n1, n2, n3] = maxnorm[n1] * 2
                i+=1
                print('class',n0,'now',now)
        print(self.attack)

    def RPB_train(self,model,logdir,checkpointdir,trainloader=None,testloader=None,optimizer=None,
                     schedule=None,criterion=nn.CrossEntropyLoss(),max_epoch=50,prob=0.5,alpha=0.2,point_size=1,suffix='',
                  img=None,target=None,method=None,save=False,explain=None):
        '''
        Train ``model`` with an ``RPB`` module by delegating to ``utils.train.RPB_train``.

        Input:
            model: network to train; moved to the GPU first when ``self.use_cuda``.
            logdir, checkpointdir: root directories; this run uses
                ``<root>/<dataset_name>/RPB_<prob>_<alpha>_<point_size><aug><suffix>``.
            trainloader, testloader: default to ``self.trainloader`` / ``self.testloader``.
            optimizer: defaults to Adam(lr=1e-3, betas=(0.5, 0.9)).
            prob, alpha, point_size: see below; forwarded to ``RPB`` / ``RPB_train``.
            schedule, criterion, max_epoch, img, target, method, explain: forwarded unchanged.
            save: when True, the whole model is written to ``model.ckpt``.
        Side effect: the first trained model is remembered as ``self.RPB_model``.

        Output: model
        '''
        if trainloader is None:
            trainloader = self.trainloader
        if testloader is None:
            testloader = self.testloader
        # Move the model before building the default optimizer, as the torch.optim docs advise.
        if self.use_cuda:
            model = model.cuda()
        if optimizer is None:
            optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, betas=(0.5, 0.9))

        from model.resnet_RPB import RPB
        # NOTE(review): RPB receives prob=alpha, while `prob` itself goes to RPB_train;
        # confirm that this mapping is intended.
        rpb = RPB(prob=alpha, point_size=point_size)

        run_name = 'RPB_' + str(prob) + '_' + str(alpha) + '_' + str(point_size) + str(self.aug) + suffix
        logdir = os.path.join(logdir, self.dataset_name, run_name)
        writer = SummaryWriter(log_dir=logdir)
        print('save logs to :', logdir)
        try:
            RPB_train(model, rpb, trainloader, testloader, optimizer, schedule, criterion, max_epoch, writer,
                      self.use_cuda, prob, alpha, img, target, method, explain, explain_dir=logdir)
        finally:
            writer.close()
        if self.RPB_model is None:
            self.RPB_model = model

        if save:
            checkpointdir = os.path.join(checkpointdir, self.dataset_name, run_name)
            os.makedirs(checkpointdir, exist_ok=True)
            print('save checkpoints to :', checkpointdir)
            torch.save(model, os.path.join(checkpointdir, 'model.ckpt'))
        return model

    def adversarial_train(self,model,logdir,checkpointdir,trainloader=None,testloader=None,optimizer=None,
                     schedule=None,criterion=nn.CrossEntropyLoss(),max_epoch=50,perturbation_type='l2',eps=0.3):
        '''
        Adversarially train ``model`` by delegating to ``utils.train.adversarial_train``,
        then always save it.

        Input:
            model: network to train; moved to the GPU first when ``self.use_cuda``.
            logdir, checkpointdir: root directories; this run uses
                ``<root>/<dataset_name>/adversarial_<perturbation_type>_<eps><aug>``.
            trainloader, testloader: default to ``self.trainloader`` / ``self.testloader``.
            optimizer: defaults to Adam(lr=1e-3, betas=(0.5, 0.9)).
            perturbation_type, eps: forwarded together with ``self.min`` / ``self.max``,
                which are computed first if still unset.
            schedule, criterion, max_epoch: forwarded unchanged.

        Output: model
        '''
        if trainloader is None:
            trainloader = self.trainloader
        if testloader is None:
            testloader = self.testloader
        # Move the model before building the default optimizer, as the torch.optim docs advise.
        if self.use_cuda:
            model = model.cuda()
        if optimizer is None:
            optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, betas=(0.5, 0.9))
        # Compute the bounds before opening the writer, so a failure here cannot leak it.
        if self.min is None:
            self.obtain_statistics()

        run_name = 'adversarial_' + perturbation_type + '_' + str(eps) + str(self.aug)
        logdir = os.path.join(logdir, self.dataset_name, run_name)
        writer = SummaryWriter(log_dir=logdir)
        print('save logs to :', logdir)
        try:
            adversarial_train(model, trainloader, testloader, optimizer, schedule, criterion, max_epoch, writer,
                              self.use_cuda, self.min, self.max, perturbation_type, eps)
        finally:
            writer.close()

        checkpointdir = os.path.join(checkpointdir, self.dataset_name, run_name)
        os.makedirs(checkpointdir, exist_ok=True)
        print('save checkpoints to :', checkpointdir)
        torch.save(model, os.path.join(checkpointdir, 'model.ckpt'))
        return model

    def pure_explain(self,img,model,method,file_name=None,color=False,suffix='', grad=False,top1=None):
        pred = model(img)
        K=4
        if pred.size(1)>=K:
            _, topklabel = torch.topk(pred, K)
        else:
            _, topklabel = torch.topk(pred, 1)
        #print(topklabel)
        def split(mask,k=K,top1=0):
            temp=mask.view(k,-1,mask.size(1),mask.size(2),mask.size(3))
            R=temp.cuda()
            B=torch.mean(R,0,keepdim=True)
            #B = (torch.sum(R, 0, keepdim=True)-R[top1])/(k-1)
            #print(R[:,top1].size())
            C = R - B
            return B.view(-1,mask.size(1),mask.size(2),mask.size(3)),\
                   C.view(-1,mask.size(1),mask.size(2),mask.size(3))
        def obtain_explain(alg, minus_mean=False):
            obj = alg.Explainer(model,nclass=self.nclass[self.dataset_name])
            result=[]
            if minus_mean:
                if self.nclass[self.dataset_name]<K:
                    for i in range(self.nclass[self.dataset_name]):
                        templabel=torch.ones_like(topklabel[:,0].squeeze())*i
                        mask = obj.get_attribution_map(img.clone(), templabel)
                        if not color:
                            mask = torch.mean(mask, 1, keepdim=True)
                        if mask.requires_grad and not grad:
                            mask=mask.detach()
                        result.append(mask)
                else:
                    for i in range(0,topklabel.size(1)):
                        templabel=topklabel[:,i]
                        mask = obj.get_attribution_map(img.clone(), templabel)
                        if not color:
                            mask = torch.mean(mask, 1, keepdim=True)#.cpu()+torch.zeros_like(img).cpu()
                        if mask.requires_grad and not grad:
                            mask = mask.detach()
                        result.append(mask)

                result=torch.cat(tuple(result), 0)
                B,result=split(result,K)
                if top1 is not None:
                    final = obj.get_attribution_map(img.clone(), top1)
                    result=final-B
                    #result=result[top1*img.size(0):(top1+1)*img.size(0)]
                else:
                    result=result[:img.size(0)]
            else:
                if top1 is None:
                    result=obj.get_attribution_map(img.clone(),topklabel[:,0].squeeze())
                else:
                    result = obj.get_attribution_map(img.clone(), top1)
            return result
        if method=='GradientSHAP':
            from attribution_methods import GradientSHAP
            mask=obtain_explain(GradientSHAP)
        elif method=='DeepLIFTSHAP':
            from attribution_methods import DeepLIFTSHAP
            mask=obtain_explain(DeepLIFTSHAP)
        elif method=='Guided_BackProp':
            from attribution_methods import Guided_BackProp
            mask=obtain_explain(Guided_BackProp)
        elif method=='DeepLIFT':
            from attribution_methods import DeepLIFT
            mask=obtain_explain(DeepLIFT)
        elif method=='IntegratedGradients':
            from attribution_methods import IntegratedGradients
            mask=obtain_explain(IntegratedGradients)
        elif method=='InputXGradient':
            from attribution_methods import InputXGradient
            mask=obtain_explain(InputXGradient)
        elif method == 'Occlusion':
            from attribution_methods import Occlusion
            mask = obtain_explain(Occlusion)
        elif method == 'Saliency':
            from attribution_methods import Saliency
            mask = obtain_explain(Saliency)
        elif method=='GradCAM':
            from attribution_methods import Grad_CAM,Grad_CAM_batch
            mask= obtain_explain(Grad_CAM)
            #mask, mask_random = obtain_explain(Grad_CAM_batch, random)
        elif method=='SmoothGrad':
            from attribution_methods import SmoothGrad
            mask = obtain_explain(SmoothGrad)
        elif method=='RectGrad':
            from attribution_methods import RectGrad
            mask = obtain_explain(RectGrad)
        elif method=='FullGrad':
            from attribution_methods import Full_Grad
            mask = obtain_explain(Full_Grad)
        elif method=='FGour':
            from attribution_methods import FGour
            mask = obtain_explain(FGour)
        elif method=='Our_GBP':
            from attribution_methods import Our_GBP
            mask = obtain_explain(Our_GBP)
        elif method=='our_method':
            from attribution_methods import our_method
            mask = obtain_explain(our_method)
        elif method=='our_method_minus_mean':
            from attribution_methods import our_method
            mask = obtain_explain(our_method,minus_mean=True)
        elif method=='our+':
            from attribution_methods import our_plus
            mask = obtain_explain(our_plus,minus_mean=True)
        elif method=='our_no_input':
            from attribution_methods import our_no_input
            mask = obtain_explain(our_no_input, minus_mean=True)
        elif method=='our_fullgrad':
            from attribution_methods import our_fullgrad
            mask = obtain_explain(our_fullgrad)
        elif method=='random':
            return torch.mean(torch.rand_like(img),1,keepdim=True)
        elif method=='CAMERAS':
            from attribution_methods import CAMERAS
            mask = obtain_explain(CAMERAS)
        elif method == 'GIG':
            from attribution_methods import GIG
            mask = obtain_explain(GIG)
        else:
            print(method)
            print('no this method')
        if file_name is not None:
            temp=mask.detach().cpu().numpy()
            if not os.path.exists(file_name):
                os.mkdir(file_name)
            save_images(temp, os.path.join(file_name, method+suffix+'.png'))
        if not grad:
            mask=mask.detach()
        return mask
    def explain_our_full(self,img,model,file_name=None,color=False,suffix=''):
        """Mean-removed ``our_method`` attributions, one result per returned explainer map.

        ``model(img)`` is ranked to pick up to 5 labels (top-1 when the model has fewer than
        5 outputs). When the dataset has fewer than 5 classes, every class is explained
        instead. For each label the explainer returns a sequence of maps
        (``no_aggr=True``). Maps at the same position j are stacked across labels, the
        mean over labels is subtracted, and the first label's residual is kept.

        Args:
            img: input batch.
            model: classifier; called once to rank labels.
            file_name: optional directory; result i is saved as ``<i>_<suffix>.png``.
            color: keep channels instead of averaging ``|map|`` over dim 1.
            suffix: file-name suffix for saved images.

        Returns:
            List with one tensor per position in the explainer's returned sequence.
        """
        pred = model(img)
        _, topklabel = torch.topk(pred, 5 if pred.size(1) >= 5 else 1)

        def remove_label_mean(mask, k):
            """``mask`` stacks label groups of k images each; subtract the mean over labels."""
            tail = (mask.size(1), mask.size(2), mask.size(3))
            groups = mask.view(-1, k, *tail)
            return (groups - torch.mean(groups, 0, keepdim=True)).view(-1, *tail)

        def obtain_explain(alg):
            nclass = self.nclass[self.dataset_name]
            obj = alg.Explainer(model, nclass=nclass)
            if nclass < 5:
                labels = [torch.ones_like(topklabel[:, 0].squeeze()) * i for i in range(nclass)]
            else:
                labels = [topklabel[:, i] for i in range(topklabel.size(1))]
            per_position = [[]]  # per_position[j]: the j-th map for every label, in label order
            for templabel in labels:
                masks = obj.get_attribution_map(img.clone(), templabel, no_aggr=True)
                for j, mask in enumerate(masks):
                    if not color:
                        mask = torch.mean(torch.abs(mask), 1, keepdim=True)
                    if mask.requires_grad:
                        mask = mask.detach()
                    if len(per_position) <= j:
                        per_position.append([])
                    # Group by map position j, for both label sources.
                    per_position[j].append(mask)
            results = []
            for group in per_position:
                residual = remove_label_mean(torch.cat(tuple(group), 0), img.size(0))
                results.append(residual[:img.size(0)])
            return results

        from attribution_methods import our_method
        mask = obtain_explain(our_method)

        if file_name is not None:
            os.makedirs(file_name, exist_ok=True)
            for i, m in enumerate(mask):
                save_images(m.detach().cpu().numpy(), os.path.join(file_name, str(i) + '_' + suffix + '.png'))
        return mask
    def explain_top5(self,img,model,method,file_name=None,color=False):
        pred = model(img)
        _, topklabel = torch.topk(pred, 5)
        #print(topklabel)
        def split(mask,k):
            #mask=torch.Tensor(mask)
            temp=mask.view(-1,k,mask.size(1),mask.size(2),mask.size(3))
            #R=temp[1:].cuda()
            R=temp.cuda()
            #B=torch.rand(1,k,mask.size(1),mask.size(2),mask.size(3)).cuda()
            B=torch.mean(R,0,keepdim=True)/0.5
            B=torch.nn.Parameter(B)
            a=torch.nn.Parameter(torch.zeros(R.size(0),k,1,1,1).cuda())
            #a = torch.nn.Parameter(torch.ones(R.size(0), k, 1, 1, 1).cuda())
            #from utils.AdamW import AdamW

            #optimizer = AdamW([a,B], lr=1e-4, betas=(0.5, 0.9), weight_decay=0)
            #schedule=torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 100, eta_min=0, last_epoch=-1)
            '''
            for i in range(10000):
                loss = torch.mean(torch.abs(torch.sigmoid(a) * B - R))
                #loss = torch.mean(torch.abs(a * B - R))
                #loss = torch.mean((torch.sigmoid(a) * B - R)**2)
                #loss=torch.mean(torch.abs(torch.log(1+torch.exp(a))*B-R))
                #loss = torch.mean((torch.log(1 + torch.exp(a)) * B - R)**2)
                loss.backward()
                optimizer.step()
                if i %2000==0:
                    print(i,loss.item())
                #schedule.step()
            '''
            #C=R-torch.log(1+torch.exp(a))*B
            C = R - torch.sigmoid(a) * B
            #C = R - a * B
            #print(temp[0].size(),C.size())
            #C=torch.cat((temp[0],C))
            return B.view(-1,mask.size(1),mask.size(2),mask.size(3)).data,\
                   C.view(-1,mask.size(1),mask.size(2),mask.size(3)).data
        def obtain_explain(alg, minus_mean=False):
            obj = alg.Explainer(model,nclass=self.nclass[self.dataset_name])
            result=[]
            if True:
                if self.nclass[self.dataset_name]<5:
                    for i in range(self.nclass[self.dataset_name]):
                        templabel=torch.ones_like(topklabel[:,0].squeeze())*i
                        mask = obj.get_attribution_map(img.clone(), templabel)
                        if not color:
                            mask = torch.mean(torch.abs(mask), 1, keepdim=True)#.cpu()+torch.zeros_like(img).cpu()
                        if mask.requires_grad:
                            mask=mask.detach()
                        #mask = mask.cpu().numpy()
                        result.append(mask)
                else:
                    for i in range(topklabel.size(1)):
                        templabel=topklabel[:,i]
                        mask = obj.get_attribution_map(img.clone(), templabel)
                        if not color:
                            mask = torch.mean(torch.abs(mask), 1, keepdim=True)#.cpu()+torch.zeros_like(img).cpu()
                        if mask.requires_grad:
                            mask = mask.detach()
                        #mask = mask.cpu().numpy()
                        result.append(mask)
                #result = np.concatenate(tuple(result), 0)
                result=torch.cat(tuple(result), 0)
                if minus_mean:
                    _,result=split(result,img.size(0))
            return result
        if method=='GradientSHAP':
            from attribution_methods import GradientSHAP
            mask=obtain_explain(GradientSHAP)
        elif method=='DeepLIFTSHAP':
            from attribution_methods import DeepLIFTSHAP
            mask=obtain_explain(DeepLIFTSHAP)
        elif method=='Guided_BackProp':
            from attribution_methods import Guided_BackProp
            mask=obtain_explain(Guided_BackProp)
        elif method=='DeepLIFT':
            from attribution_methods import DeepLIFT
            mask=obtain_explain(DeepLIFT)
        elif method=='IntegratedGradients':
            from attribution_methods import IntegratedGradients
            mask=obtain_explain(IntegratedGradients)
        elif method=='InputXGradient':
            from attribution_methods import InputXGradient
            mask=obtain_explain(InputXGradient)
        elif method == 'Occlusion':
            from attribution_methods import Occlusion
            mask = obtain_explain(Occlusion)
        elif method == 'Saliency':
            from attribution_methods import Saliency
            mask = obtain_explain(Saliency)
        elif method=='GradCAM':
            from attribution_methods import Grad_CAM,Grad_CAM_batch
            mask= obtain_explain(Grad_CAM)
            #mask, mask_random = obtain_explain(Grad_CAM_batch, random)
        elif method=='SmoothGrad':
            from attribution_methods import SmoothGrad
            mask = obtain_explain(SmoothGrad)
        elif method=='RectGrad':
            from attribution_methods import RectGrad
            mask = obtain_explain(RectGrad)
        elif method=='FullGrad':
            from attribution_methods import Full_Grad
            mask = obtain_explain(Full_Grad)
        elif method=='FGour':
            from attribution_methods import FGour
            mask = obtain_explain(FGour)
        elif method=='Our_GBP':
            from attribution_methods import Our_GBP
            mask = obtain_explain(Our_GBP)
        elif method=='our_method':
            from attribution_methods import our_method
            mask = obtain_explain(our_method)
        elif method=='our_method_minus_mean':
            from attribution_methods import our_method
            mask = obtain_explain(our_method,minus_mean=True)
        elif method=='our_no_input':
            from attribution_methods import our_no_input
            mask = obtain_explain(our_no_input, minus_mean=True)
        elif method=='our_fullgrad':
            from attribution_methods import our_fullgrad
            mask = obtain_explain(our_fullgrad)
        elif method=='random':
            return torch.mean(torch.rand_like(img),1,keepdim=True)
        else:
            print(method)
            print('no this method')
        if file_name is not None:
            temp=mask.detach().cpu().numpy()
            if not os.path.exists(file_name):
                os.mkdir(file_name)
            save_images(temp, os.path.join(file_name, method+'.png'))
        return mask
    def get_mask(self,img,model,method,label,topklabel=None,train_loader=None):
        color=self.color
        random=None
        if topklabel is None:
            pred = model(img)
            _, topklabel = torch.topk(pred, 5)
            print(topklabel)
        def obtain_explain(alg, random=None, train_loader=None):
            if train_loader is None:
                obj = alg.Explainer(model,nclass=self.nclass[self.dataset_name])
            else:
                obj = alg.Explainer(model, train_loader)
            #result=img.clone().cpu().numpy()
            result=[]
            #print('resultshape',result.shape)
            if self.nclass[self.dataset_name]<15:
                for i in range(self.nclass[self.dataset_name]):
                    templabel=torch.ones_like(label)*i
                    mask = obj.get_attribution_map(img.clone(), templabel)
                    if not color:
                        mask = torch.mean(torch.abs(mask), 1, keepdim=True)#.cpu()+torch.zeros_like(img).cpu()
                    if mask.requires_grad:
                        mask=mask.detach()
                    mask = mask.cpu().numpy()
                    result.append(mask)
            else:
                for i in range(topklabel.size(1)):
                    templabel=topklabel[:,i]
                    mask = obj.get_attribution_map(img.clone(), templabel)
                    if not color:
                        mask = torch.mean(torch.abs(mask), 1, keepdim=True)#.cpu()+torch.zeros_like(img).cpu()
                    if mask.requires_grad:
                        mask = mask.detach()
                    mask = mask.cpu().numpy()
                    result.append(mask)
            result = np.concatenate(tuple(result), 0)
            if False:
                condition=obj.get_condition(img.clone())
                condition=condition.cpu().numpy()
                print('obtain condition from all class')
            else:
                condition=None
            #print(result.shape,condition.shape)
            return result,condition
        if method=='GradientSHAP':
            from attribution_methods import GradientSHAP
            mask,mask_random=obtain_explain(GradientSHAP,random)
        elif method=='DeepLIFTSHAP':
            from attribution_methods import DeepLIFTSHAP
            mask,mask_random=obtain_explain(DeepLIFTSHAP,random)
        elif method=='Guided_BackProp':
            from attribution_methods import Guided_BackProp
            mask,mask_random=obtain_explain(Guided_BackProp, random)
        elif method=='DeepLIFT':
            from attribution_methods import DeepLIFT
            mask,mask_random=obtain_explain(DeepLIFT, random)
        elif method=='IntegratedGradients':
            from attribution_methods import IntegratedGradients
            #model=nn.DataParallel(model)
            mask,mask_random=obtain_explain(IntegratedGradients, random)
        elif method=='InputXGradient':
            from attribution_methods import InputXGradient
            mask,mask_random=obtain_explain(InputXGradient, random)
        elif method == 'Occlusion':
            from attribution_methods import Occlusion
            mask, mask_random = obtain_explain(Occlusion, random)
        elif method == 'Saliency':
            from attribution_methods import Saliency
            mask, mask_random = obtain_explain(Saliency, random)
        elif method=='GradCAM':
            from attribution_methods import Grad_CAM,Grad_CAM_batch
            mask, mask_random = obtain_explain(Grad_CAM, random)
            #mask, mask_random = obtain_explain(Grad_CAM_batch, random)
        elif method=='SmoothGrad':
            from attribution_methods import SmoothGrad
            mask, mask_random = obtain_explain(SmoothGrad, random)
        elif method=='RectGrad':
            from attribution_methods import RectGrad
            mask, mask_random = obtain_explain(RectGrad, random)
        elif method=='PatternNet':
            from attribution_methods import PatternNet
            if self.signal_estimator is None:
                self.signal_estimator = PatternNet.SignalEstimator(model)
                if self.trainloader is None:
                    self.trainloader=train_loader
                self.signal_estimator.train_explain(self.trainloader)
            mask, mask_random = obtain_explain(PatternNet, random,self.signal_estimator)
        elif method=='AIR':
            from attribution_methods import AttrInvRec
            mask, mask_random = obtain_explain(AttrInvRec, random)
        elif method=='FullGrad':
            from attribution_methods import Full_Grad
            mask, mask_random = obtain_explain(Full_Grad, random)
        elif method=='FGour':
            from attribution_methods import FGour
            mask, mask_random = obtain_explain(FGour, random)
        elif method=='Our_GBP':
            from attribution_methods import Our_GBP
            mask, mask_random = obtain_explain(Our_GBP, random)
        elif method=='our_method':
            from attribution_methods import our_method
            mask, mask_random = obtain_explain(our_method, random)
        elif method=='our_fullgrad':
            from attribution_methods import our_fullgrad
            mask, mask_random = obtain_explain(our_fullgrad, random)
        else:
            print(method)
            print('no this method')
        return mask,mask_random
    def explain(self,img,label,logdir=None,model=None,method='GradientSHAP',attack=True,random=False,improve=False,suffix='',train_loader=None):
        '''
        input:
        img: batch X channels X height X width [BCHW], torch Tensor

        output:
        attribution_map: batch X height X width,numpy
        '''
        if not attack:
            if model is None:
                if improve:
                    model=self.improve_model
                else:
                    model=self.normal_model
        else:
            if model is None:
                if improve:
                    model = self.improve_model
                else:
                    model = self.gt_model
            img=self.attack_img(img,label)

        def weights_init(m):
            classname = m.__class__.__name__

            # print(classname)
            if classname.find('Conv') != -1:
                nn.init.xavier_normal_(m.weight.data)
            elif classname.find('Linear') != -1:
                nn.init.xavier_normal_(m.weight.data)
                m.bias.data.fill_(0)
        if random:
            import copy
            random_model=copy.deepcopy(model)
            random_model.apply(weights_init)

        if self.use_cuda:
            img=img.cuda()
            label=label.cuda()

        def obtain_explain(alg,random,appendix=None):
            if appendix is None:
                obj = alg.Explainer(model)
            else:
                obj = alg.Explainer(model,appendix)
            mask = obj.get_attribution_map(img, label)
            #mask = torch.mean(mask, 1, keepdim=True)
            if mask.requires_grad:
                mask=mask.detach()
            mask = mask.cpu().numpy()
            mask_random=None
            if random:
                obj = alg.Explainer(random_model)
                mask_random = obj.get_attribution_map(img, label)
                #mask_random = torch.mean(mask_random, 1, keepdim=True)
                if mask_random.requires_grad:
                    mask_random = mask_random.detach()
                mask_random = mask_random.cpu().numpy()
            return mask,mask_random
        model=model.eval()
        if method=='GradientSHAP':
            from attribution_methods import GradientSHAP
            mask,mask_random=obtain_explain(GradientSHAP,random)
        elif method=='DeepLIFTSHAP':
            from attribution_methods import DeepLIFTSHAP
            mask,mask_random=obtain_explain(DeepLIFTSHAP,random)
        elif method=='Guided_BackProb':
            from attribution_methods import Guided_BackProp
            mask,mask_random=obtain_explain(Guided_BackProp, random)
        elif method=='DeepLIFT':
            from attribution_methods import DeepLIFT
            mask,mask_random=obtain_explain(DeepLIFT, random)
        elif method=='IntegratedGradients':
            from attribution_methods import IntegratedGradients
            mask,mask_random=obtain_explain(IntegratedGradients, random)
        elif method=='InputXGradient':
            from attribution_methods import InputXGradient
            mask,mask_random=obtain_explain(InputXGradient, random)
        elif method == 'Occlusion':
            from attribution_methods import Occlusion
            mask, mask_random = obtain_explain(Occlusion, random)
        elif method == 'Saliency':
            from attribution_methods import Saliency
            mask, mask_random = obtain_explain(Saliency, random)
        elif method=='GradCAM':
            from attribution_methods import Grad_CAM
            mask, mask_random = obtain_explain(Grad_CAM, random)
        elif method=='SmoothGrad':
            from attribution_methods import SmoothGrad
            mask, mask_random = obtain_explain(SmoothGrad, random)
        elif method=='RectGrad':
            from attribution_methods import RectGrad
            mask, mask_random = obtain_explain(RectGrad, random)
        elif method=='PatternNet':
            from attribution_methods import PatternNet
            if self.signal_estimator is None:
                self.signal_estimator = PatternNet.SignalEstimator(model)
                if self.trainloader is None:
                    self.trainloader=train_loader
                self.signal_estimator.train_explain(self.trainloader)
            mask, mask_random = obtain_explain(PatternNet, random,self.signal_estimator)
        elif method=='AIR':
            from attribution_methods import AttrInvRec
            mask, mask_random = obtain_explain(AttrInvRec, random)
        else:
            print('no this method')

        if logdir is not None:
            if not os.path.exists(os.path.join(logdir,method+suffix)):  # 如果路径不存在
                os.makedirs(os.path.join(logdir,method+suffix))
            if img.requires_grad:
                img=img.detach()
            img=img.cpu().numpy()

            if attack:
                if self.min is not None:
                    save_images(img, os.path.join(logdir, method+suffix, 'raw_attack.png'), self.min.numpy(), self.max.numpy())
                if improve:
                    save_images(mask, os.path.join(logdir, method + suffix, 'mask_attack_improve.png'))
                    f=open(os.path.join(logdir, method + suffix, 'mask_attack_improve.txt'),'w')
                else:
                    save_images(mask, os.path.join(logdir, method+suffix, 'mask_attack.png'))
                    f = open(os.path.join(logdir, method+suffix, 'mask_attack.txt'), 'w')
                if random:
                    save_images(mask_random, os.path.join(logdir, method+suffix, 'mask_random_attack.png'))
            else:
                if self.min is not None:
                    save_images(img, os.path.join(logdir, method+suffix, 'raw.png'), self.min.numpy(), self.max.numpy())
                if improve:
                    save_images(mask, os.path.join(logdir, method+suffix, 'mask_improve.png'))
                    f = open(os.path.join(logdir, method+suffix, 'mask_improve.txt'), 'w')
                else:
                    save_images(mask, os.path.join(logdir, method+suffix, 'mask.png'))
                    f = open(os.path.join(logdir, method + suffix, 'mask.txt'), 'w')
                if random:
                    save_images(mask_random, os.path.join(logdir, method+suffix, 'mask_random.png'))
            if self.attack is not None:
                gt=torch.sign(torch.mean(torch.abs(self.attack[label]),1,keepdim=True)).numpy()
                Q_value=np.sum(np.abs(mask*gt))/np.sum(np.abs(mask)+1e-8)
                print('quantitative evaluation:', Q_value)
                print('quantitative evaluation:',Q_value,file=f)
            f.close()
            #cam=mask*0.5+img*0.5
            #save_images(cam, os.path.join(logdir, method, 'cam.jpg'))
            #if img.shape[0]==1:
            #    from utils.visualization import show_cam
            #    show_cam(img,mask, os.path.join(logdir, method, 'cam.jpg'))
    def explain_all(self,img,label,logdir=None,model=None,method='GradientSHAP',attack=False,random=False,improve=False,
                    suffix='',train_loader=None,topklabel=None):
        '''
        input:
        img: batch X channels X height X width [BCHW], torch Tensor

        output:
        attribution_map: batch X height X width,numpy
        '''
        if not attack:
            if model is None:
                if improve:
                    model=self.improve_model
                else:
                    model=self.normal_model
        else:
            if model is None:
                if improve:
                    model = self.improve_model
                else:
                    model = self.gt_model
            img=self.attack_img(img,label)

        def weights_init(m):
            classname = m.__class__.__name__

            # print(classname)
            if classname.find('Conv') != -1:
                nn.init.xavier_normal_(m.weight.data)
            elif classname.find('Linear') != -1:
                nn.init.xavier_normal_(m.weight.data)
                m.bias.data.fill_(0)
        if random:
            import copy
            random_model=copy.deepcopy(model)
            random_model.apply(weights_init)

        if self.use_cuda:
            img=img.cuda()
            if label is not None:
                label = label.cuda()

        model=model.eval()
        mask,mask_random=self.get_mask(img,model,method,label,topklabel,train_loader)

        if logdir is not None:
            if not os.path.exists(os.path.join(logdir,method+suffix)):  # 如果路径不存在
                os.makedirs(os.path.join(logdir,method+suffix))
            if img.requires_grad:
                img=img.detach()
            img=img.cpu().numpy()

            if attack:
                if self.min is not None:
                    save_images(img, os.path.join(logdir, method+suffix, 'raw_attack.png'), self.min.numpy(), self.max.numpy())
                if improve:
                    save_images(mask, os.path.join(logdir, method + suffix, 'mask_attack_improve.png'))
                else:
                    save_images(mask, os.path.join(logdir, method+suffix, 'mask_attack.png'))
                if random:
                    save_images(mask_random, os.path.join(logdir, method+suffix, 'mask_random_attack.png'))
            else:
                if self.min is not None:
                    save_images(img, os.path.join(logdir, method+suffix, 'raw.png'), self.min.numpy(), self.max.numpy())
                if improve:
                    save_images(mask, os.path.join(logdir, method+suffix, 'mask_improve.png'))

                else:
                    save_images(mask, os.path.join(logdir, method+suffix, 'mask.png'))

                if random:
                    save_images(mask_random, os.path.join(logdir, method+suffix, 'mask_random.png'))
            #cam=mask*0.5+img*0.5
            #save_images(cam, os.path.join(logdir, method, 'cam.jpg'))
            #if img.shape[0]==1:
            #    from utils.visualization import show_cam
            #    show_cam(img,mask, os.path.join(logdir, method, 'cam.jpg'))
        return mask

    def explain_split(self,img,label,logdir=None,model=None,method='GradientSHAP',attack=False,random=False,improve=False,
                    suffix='',train_loader=None,topklabel=None,color=True):
        '''
        input:
        img: batch X channels X height X width [BCHW], torch Tensor

        output:
        attribution_map: batch X height X width,numpy
        '''
        self.color=color
        if not attack:
            if model is None:
                if improve:
                    model=self.improve_model
                else:
                    model=self.normal_model
        else:
            if model is None:
                if improve:
                    model = self.improve_model
                else:
                    model = self.gt_model
            img=self.attack_img(img,label)

        def weights_init(m):
            classname = m.__class__.__name__

            # print(classname)
            if classname.find('Conv') != -1:
                nn.init.xavier_normal_(m.weight.data)
            elif classname.find('Linear') != -1:
                nn.init.xavier_normal_(m.weight.data)
                m.bias.data.fill_(0)

        import copy
        model = copy.deepcopy(model)
        if random:
            random_model=model
            random_model.apply(weights_init)

        if self.use_cuda:
            img=img.cuda()
            if label is not None:
                label=label.cuda()


        model=model.eval()
        mask,mask_random=self.get_mask(img,model,method,label,topklabel,train_loader)
        def split(mask,k):
            mask=torch.Tensor(mask)
            temp=mask.view(-1,k,mask.size(1),mask.size(2),mask.size(3))
            #R=temp[1:].cuda()
            R=temp.cuda()
            #B=torch.rand(1,k,mask.size(1),mask.size(2),mask.size(3)).cuda()
            B=torch.mean(R,0,keepdim=True)/0.5
            B=torch.nn.Parameter(B)
            a=torch.nn.Parameter(torch.zeros(R.size(0),k,1,1,1).cuda())
            #a = torch.nn.Parameter(torch.ones(R.size(0), k, 1, 1, 1).cuda())
            #from utils.AdamW import AdamW

            #optimizer = AdamW([a,B], lr=1e-4, betas=(0.5, 0.9), weight_decay=0)
            #schedule=torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 100, eta_min=0, last_epoch=-1)
            '''
            for i in range(10000):
                loss = torch.mean(torch.abs(torch.sigmoid(a) * B - R))
                #loss = torch.mean(torch.abs(a * B - R))
                #loss = torch.mean((torch.sigmoid(a) * B - R)**2)
                #loss=torch.mean(torch.abs(torch.log(1+torch.exp(a))*B-R))
                #loss = torch.mean((torch.log(1 + torch.exp(a)) * B - R)**2)
                loss.backward()
                optimizer.step()
                if i %2000==0:
                    print(i,loss.item())
                #schedule.step()
            '''
            #C=R-torch.log(1+torch.exp(a))*B
            C = R - torch.sigmoid(a) * B
            #C = R - a * B
            #print(temp[0].size(),C.size())
            #C=torch.cat((temp[0],C))
            return B.view(-1,mask.size(1),mask.size(2),mask.size(3)).data.cpu().numpy(),\
                   C.view(-1,mask.size(1),mask.size(2),mask.size(3)).data.cpu().numpy()
        mask_raw=mask.copy()
        if mask_random is None:
            condition,mask=split(mask_raw,img.size(0))
        else:
            #print(mask,mask_random)
            condition=torch.Tensor(mask_random).cuda()
            temp = torch.Tensor(mask).cuda()
            temp=temp.view(-1,img.size(0),temp.size(1),temp.size(2),temp.size(3))
            B=condition.view(1,condition.size(0),condition.size(1),condition.size(2),condition.size(3))
            C=temp-B
            mask=C.view(-1,condition.size(1),condition.size(2),condition.size(3)).data.cpu().numpy()
            condition=B.view(-1,condition.size(1),condition.size(2),condition.size(3)).data.cpu().numpy()
        k=img.size(0)

        if logdir is None:
            return condition,mask
        else:
            if not os.path.exists(os.path.join(logdir,method+suffix)):  # 如果路径不存在
                os.makedirs(os.path.join(logdir,method+suffix))
            if img.requires_grad:
                img=img.detach()
            img=img.cpu().numpy()
            save_images(img, os.path.join(logdir, method + suffix, 'raw.png'))
            save_images(condition, os.path.join(logdir, method + suffix, 'condition.png'),img_num=k)
            if attack:
                save_images(mask_raw, os.path.join(logdir, method + suffix, 'mask_raw_attack.png'),img_num=k)
                if improve:
                    save_images(mask, os.path.join(logdir, method + suffix, 'mask_attack_improve.png'),img_num=k)
                else:
                    save_images(mask, os.path.join(logdir, method+suffix, 'mask_attack.png'),img_num=k)
                if random:
                    save_images(mask_random, os.path.join(logdir, method+suffix, 'mask_random_attack.png'),img_num=k)
            else:
                save_images(mask_raw, os.path.join(logdir, method + suffix, 'mask_raw.png'),img_num=k)
                if improve:
                    save_images(mask, os.path.join(logdir, method+suffix, 'mask_improve.png'),img_num=k)

                else:
                    save_images(mask, os.path.join(logdir, method+suffix, 'mask.png'),img_num=k)

                if random:
                    save_images(mask_random, os.path.join(logdir, method+suffix, 'mask_random.png'),img_num=k)
            #cam=mask*0.5+img*0.5
            #save_images(cam, os.path.join(logdir, method, 'cam.jpg'))
            #if img.shape[0]==1:
            #    from utils.visualization import show_cam
            #    show_cam(img,mask, os.path.join(logdir, method, 'cam.jpg'))

    def standardize(self,X):
        '''
        Min-max normalise every sample of ``X`` independently.

        X: numpy array whose first axis indexes samples (any rank >= 1).
        Returns an array of the same shape equal to
        ``(X - min) / (max - min + 1e-8)``, where min/max are taken over all
        entries of each sample, so values lie in [0, 1].
        '''
        flat = X.reshape([X.shape[0], -1])
        # One (min, max) pair per sample, shaped to broadcast over every
        # remaining axis regardless of the input's rank.
        stat_shape = [X.shape[0]] + [1] * (X.ndim - 1)
        minn = np.min(flat, axis=1).reshape(stat_shape)
        maxx = np.max(flat, axis=1).reshape(stat_shape)
        return (X - minn) / (maxx - minn + 1e-8)
    def fft(self,img,logdir=None):
        import scipy.fftpack as fp
        img = self.standardize(img)
        #img = np.mean(img, axis=1)
        ## Functions to go from image to frequency-image and back
        im2freq = lambda data: fp.fft(fp.fft(data, axis=2),
                                       axis=3)
        freq2im = lambda f: fp.ifft(fp.ifft(f, axis=2),
                                     axis=3)
        #f = im2freq(img)

        f=np.fft.fftn(img,axes=(-2,-1))

        f_mag = np.fft.fftshift(f,axes=(-2,-1))
        f_mag=np.abs(f_mag)


        if logdir is not None:
            if not os.path.exists(logdir):  # 如果路径不存在
                os.makedirs(logdir)
            save_images(img, os.path.join(logdir, 'raw.png'))
            save_images(np.log(1+f_mag), os.path.join(logdir, 'f_mag.png'))
        return f


    def explain_split_fft(self,img,label,logdir=None,model=None,method='GradientSHAP',attack=True,random=False,improve=False,
                    suffix='',train_loader=None,topklabel=None):
        '''
        Frequency-domain variant of explain_split.

        Computes per-label attribution maps with ``get_mask``, takes their
        absolute value and transforms them with ``self.fft`` (per-sample
        min-max normalisation followed by a 2-D FFT over the last two axes).
        The spectrum of the first label group ``X`` is then split into ``R``
        (approximately ``X`` at frequencies where another group's magnitude
        exceeds ``X``'s, zero elsewhere) and ``D = X - R``.

        input:
        img: batch X channels X height X width [BCHW], torch Tensor
        label: per-sample labels; moved to CUDA when ``self.use_cuda`` is set
               (no None check here, unlike explain_split)
        model: model to explain (a deep copy is used); defaults to
               improve_model (improve=True), else gt_model (attack=True) or
               normal_model (attack=False)
        attack: if True, ``img`` is first passed through ``self.attack_img``
        random: if True, the copied model's Conv/Linear weights are
                re-initialised before it is explained

        output:
        No return statement here; when ``logdir`` is given, the input/mask
        spectra, R, D and their inverse transforms (IR, ID) are saved as
        images under ``logdir/<method><suffix>/``.
        '''
        if not attack:
            if model is None:
                if improve:
                    model=self.improve_model
                else:
                    model=self.normal_model
        else:
            if model is None:
                if improve:
                    model = self.improve_model
                else:
                    model = self.gt_model
            img=self.attack_img(img,label)

        def weights_init(m):
            # Re-initialise Conv*/Linear* layers (matched by class name).
            # NOTE(review): fails on Linear layers created with bias=False.
            classname = m.__class__.__name__

            # print(classname)
            if classname.find('Conv') != -1:
                nn.init.xavier_normal_(m.weight.data)
            elif classname.find('Linear') != -1:
                nn.init.xavier_normal_(m.weight.data)
                m.bias.data.fill_(0)

        import copy
        # Work on a copy so the caller's model is not modified.
        model = copy.deepcopy(model)
        if random:
            # random_model is the same object as model, so the explained
            # model itself gets random weights.
            random_model=model
            random_model.apply(weights_init)

        if self.use_cuda:
            img=img.cuda()
            label=label.cuda()


        model=model.eval()
        mask_raw,mask_random=self.get_mask(img,model,method,label,topklabel,train_loader)
        def split(mask,k):
            # Regroup the stacked spectra as (groups, k, C, H, W); X is the
            # first label group, R holds the remaining groups.
            temp=mask.reshape([-1,k,mask.shape[1],mask.shape[2],mask.shape[3]])
            X=temp[0]
            R=temp[1:]
            R_raw=R.copy()  # unused (only referenced by commented-out code)
            #tr=np.maximum(np.sign(np.abs(R.real)-np.abs(X.real)),0)
            #ti=np.maximum(np.sign(np.abs(R.imag)-np.abs(X.imag)),0)
            #R.real=R.real*tr+np.abs(X.real)*(1-tr)*np.sign(R.real)
            #R.imag=R.imag*ti+np.abs(X.imag)*(1-ti)*np.sign(R.imag)
            # tt is 1 where a group's magnitude strictly exceeds X's, else 0.
            tt=np.maximum(np.sign(np.abs(R)-np.abs(X)),0)
            A=np.abs(X)*tt#+np.abs(R)*(1-tt)
            #R=R/(np.abs(R)+1e-8)*A
            #for i in range(R.shape[0]):
            #    R[i].real=np.minimum(R[i].real,X.real)
            #    R[i].imag = np.minimum(R[i].imag, X.imag)
            #D=R_raw-R
            # X's phase with magnitude A: approximately X where tt == 1, 0 elsewhere.
            R = X / (np.abs(X) + 1e-8) * A
            D=X-R
            return R.reshape([-1,mask.shape[1],mask.shape[2],mask.shape[3]]),D.reshape([-1,mask.shape[1],mask.shape[2],mask.shape[3]])
        mask_raw=np.abs(mask_raw)
        mask=mask_raw.copy()
        R,D=split(self.fft(mask),img.size(0))

        if logdir is not None:
            if not os.path.exists(os.path.join(logdir,method+suffix)):  # create the output directory if it does not exist
                os.makedirs(os.path.join(logdir,method+suffix))
            if img.requires_grad:
                img=img.detach()
            img=img.cpu().numpy()
            # self.fft with a logdir also saves raw.png / f_mag.png in that subfolder.
            self.fft(img, os.path.join(logdir, method + suffix, 'img'))
            if attack:
                self.fft(mask_raw, os.path.join(logdir, method + suffix,'mask_attack'))
            else:
                self.fft(mask_raw, os.path.join(logdir, method + suffix,'mask_raw'))
            # Centred log-magnitude spectra of the two components.
            save_images(np.log(1 + np.abs(np.fft.fftshift(R,axes=(-2,-1)))), os.path.join(logdir,method + suffix, 'R.png'))
            save_images(np.log(1 + np.abs(np.fft.fftshift(D,axes=(-2,-1)))), os.path.join(logdir, method + suffix, 'D.png'))
            # Back to the spatial domain (real part only).
            IR=np.fft.ifftn(R,axes=(-2,-1)).real
            ID=np.fft.ifftn(D,axes=(-2,-1)).real
            save_images(IR, os.path.join(logdir, method + suffix, 'IR.png'))
            save_images(ID, os.path.join(logdir, method + suffix, 'ID.png'))



