import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
from torch.autograd import Variable
from torchsummary import summary
import sys
import numpy as np

import torchvision
import torchvision.transforms as transforms
import models.VGG as VGG
import models.resnet as resnet
from tensorboardX import SummaryWriter
import models.mini_inception as mini_inception
import models.mini_alexnet as mini_alexnet
import tools.attack as attack
import csv
class Transferable_perturbations:
    """Measure how adversarial examples crafted on a base model transfer to two target models.

    The constructor builds the train/test loaders for one dataset, loads two
    target classifiers (a ResNet-18 and a VGG-16) from checkpoint files and
    instantiates the base model. ``test`` then attacks the base model and
    appends accuracy and transfer statistics to a CSV file.

    Every model and every batch is moved to the GPU with ``.cuda()``.
    """

    # Per-dataset settings: image folders, L2 budget, number of classes and the
    # target-model checkpoints.  ``vgg_builder`` is the attribute of the
    # ``VGG`` module used to construct the second target model.
    # NOTE(review): the L2 budgets look like 5% of a dataset-level norm
    # (78.6472 / 51.9595) -- confirm where these constants come from.
    _DATASET_CONFIGS = {
        "CIFAR10": {
            "train_dir": './data/cifar-10-batches-py/train/robust_train',
            "test_dir": './data/cifar-10-batches-py/test',
            "L2_bound": 0.05*78.6472,
            "n_class": 10,
            "resnet_weights": "ResNet18_cifar10_clean_half_transfer.pth",
            "vgg_builder": "VGG",
            "vgg_weights": "VGG16_cifar10_clean_half_transfer.pth",
        },
        "CIFAR100": {
            "train_dir": './data/cifar-100-python/image/robust_train',
            "test_dir": './data/cifar-100-python/image/test',
            "L2_bound": 0.05*78.6472,
            "n_class": 100,
            "resnet_weights": "resnet_cifar100.pth",
            "vgg_builder": "VGG100",
            "vgg_weights": "VGG16_cifar100.pth",
        },
        "SVHN": {
            "train_dir": './data/SVHN/train/robust_train',
            "test_dir": './data/SVHN/test',
            "L2_bound": 0.05*51.9595,
            "n_class": 10,
            "resnet_weights": "resnet_svhn.pth",
            "vgg_builder": "VGG",
            "vgg_weights": "VGG_SHVN.pth",
        },
    }

    # Supported base-model architectures.
    _BASE_MODELS = ("Inception", "Alexnet")

    # attack_mode -> (name of the attribute holding the budget, iteration count).
    # The attack callable itself is the attacker attribute of the same name.
    _ATTACK_MODES = {
        "ERM": (None, 0),
        "FGM_L2": ("L2_bound", 1),
        "PGM_L2": ("L2_bound", 15),
        "FGM_Linf": ("Linf_bound", 1),
        "PGM_Linf": ("Linf_bound", 15),
    }

    def __init__(self, dataset: str, model: str, beta: float) -> None:
        """Set up loaders, target models and the base model.

        Args:
            dataset: One of ``"CIFAR10"``, ``"CIFAR100"`` or ``"SVHN"``.
            model: Base-model architecture, ``"Inception"`` or ``"Alexnet"``.
            beta: ``np.inf`` selects the plain architecture; any other value
                is forwarded to the ``*_sn`` constructor of the architecture.

        Raises:
            ValueError: If ``dataset`` or ``model`` is not supported.
        """
        # Fail fast with a clear message before any data or checkpoint is loaded.
        if dataset not in self._DATASET_CONFIGS:
            raise ValueError(
                "Unknown dataset %r; expected one of %s"
                % (dataset, sorted(self._DATASET_CONFIGS))
            )
        if model not in self._BASE_MODELS:
            raise ValueError(
                "Unknown model %r; expected one of %s"
                % (model, list(self._BASE_MODELS))
            )
        config = self._DATASET_CONFIGS[dataset]

        self.attacker = attack.Adversarial_Attack()
        # ==========dataset==============
        self.dataset = dataset
        self.batch_size = 64
        self.transform_train = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])
        self.transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])
        self.loss_fn = nn.CrossEntropyLoss()

        trainset = torchvision.datasets.ImageFolder(config["train_dir"], transform=self.transform_train)
        self.trainloader = torch.utils.data.DataLoader(
            trainset, batch_size=self.batch_size, shuffle=True, num_workers=4, pin_memory=True)

        # The test split always uses the deterministic transform (no random flip),
        # for every dataset.
        testset = torchvision.datasets.ImageFolder(config["test_dir"], transform=self.transform_test)
        self.testloader = torch.utils.data.DataLoader(
            testset, batch_size=self.batch_size, shuffle=False, num_workers=4, pin_memory=True)

        self.L2_bound = config["L2_bound"]
        # 8/255 divided by 0.2, to be consistent with the Normalize std
        # (0.2023, 0.1994, 0.2010) used in the transforms above.
        self.Linf_bound = 8/255*1/(0.2)
        n_class = config["n_class"]

        # ==========target models==============
        self.target_model1 = resnet.resnet18(num_classes=n_class).cuda()
        self.target_model1.load_state_dict(torch.load(config["resnet_weights"]))
        self.target_model1.eval()

        self.target_model2 = getattr(VGG, config["vgg_builder"])('VGG16').cuda()
        self.target_model2.load_state_dict(torch.load(config["vgg_weights"]))
        self.target_model2.eval()

        # ==========model==============
        if model == "Inception":
            if beta == np.inf:
                self.base_model = mini_inception.inception(n_class).cuda()
            else:
                self.base_model = mini_inception.inception_sn(n_class, beta).cuda()
        else:  # "Alexnet" (validated above)
            if beta == np.inf:
                self.base_model = mini_alexnet.alexnet(n_class).cuda()
            else:
                self.base_model = mini_alexnet.alexnet_sn(n_class, beta).cuda()

    def _attack_settings(self, attack_mode):
        """Return ``(eps, n_iteration, step_size)`` for ``attack_mode``.

        Raises:
            ValueError: If ``attack_mode`` is not a key of ``_ATTACK_MODES``.
        """
        if attack_mode not in self._ATTACK_MODES:
            raise ValueError(
                "Unknown attack_mode %r; expected one of %s"
                % (attack_mode, sorted(self._ATTACK_MODES))
            )
        bound_attr, n_iteration = self._ATTACK_MODES[attack_mode]
        if bound_attr is None:
            # "ERM": zero budget, zero iterations, zero step.
            return 0, 0, 0
        eps = getattr(self, bound_attr)
        # Single-step attacks spend the whole budget in one step; iterative
        # ones use 1.5x the even split of the budget over the iterations.
        step_size = eps if n_iteration == 1 else eps/n_iteration*1.5
        return eps, n_iteration, step_size

    @staticmethod
    def _ratio(count, total):
        """Return ``count / total``, or NaN when ``total`` is zero."""
        return count / total if total else float("nan")

    def _evaluate_loader(self, loader, generating_perturbation, eps, n_iteration, step_size):
        """Attack the base model on every batch of ``loader`` and collect statistics.

        Returns:
            A tuple ``(clean_acc, adv_acc, transfer_V16, transfer_R18)``.
            The first two are fractions of all samples that the base model
            classifies correctly before / after the perturbation.  The last
            two are, among samples the base model classified correctly while
            clean, the fraction whose perturbed version is still classified
            correctly by ``target_model2`` / ``target_model1``.
        """
        n_samples = 0
        clean_correct = 0
        adv_correct = 0
        transfer_R18 = 0
        transfer_V16 = 0

        for data, y in loader:
            data, y = data.cuda(), y.cuda()
            n_samples += data.shape[0]

            # No gradients are needed just to find the clean-correct samples.
            with torch.no_grad():
                _, prediction_clean = self.base_model(data).max(1)
                select_index = prediction_clean.eq(y)
            n_selected = select_index.sum().item()
            clean_correct += n_selected

            # Called outside no_grad: the attack presumably needs input gradients.
            adv_data = generating_perturbation(self.base_model, data, y, eps, n_iteration, step_size)

            with torch.no_grad():
                # Adversarial accuracy is normalized by all samples, so it is
                # accumulated for every batch, including batches where no
                # sample was classified correctly while clean.
                _, prediction = self.base_model(adv_data).max(1)
                adv_correct += prediction.eq(y).sum().item()

                if n_selected == 0:
                    # Nothing to feed to the target models for this batch.
                    continue
                selected_x = adv_data[select_index]
                selected_y = y[select_index]

                _, prediction = self.target_model1(selected_x).max(1)
                transfer_R18 += prediction.eq(selected_y).sum().item()

                _, prediction = self.target_model2(selected_x).max(1)
                transfer_V16 += prediction.eq(selected_y).sum().item()

        return (
            self._ratio(clean_correct, n_samples),
            self._ratio(adv_correct, n_samples),
            # The transfer denominator is the number of clean-correct samples.
            self._ratio(transfer_V16, clean_correct),
            self._ratio(transfer_R18, clean_correct),
        )

    def test(self, attack_mode: str, name: str, n_repeats: int = 10,
             results_path: str = "FGM_results.csv") -> None:
        """Attack the base model and append the results to a CSV file.

        Loads ``model_weights/<name>.pth`` into the base model, then for each
        repetition evaluates the train and test loaders and appends one row:
        ``name, attack_mode, train clean/adv/V16/R18, test clean/adv/V16/R18``.

        Args:
            attack_mode: ``"ERM"``, ``"FGM_L2"``, ``"PGM_L2"``, ``"FGM_Linf"``
                or ``"PGM_Linf"``; also the name of the attacker method used.
            name: Checkpoint stem and first column of each CSV row.
            n_repeats: Number of full evaluation passes (one CSV row each).
            results_path: CSV file the rows are appended to.

        Raises:
            ValueError: If ``attack_mode`` is not supported.
        """
        # Validate before touching the disk so a typo fails immediately.
        eps, n_iteration, step_size = self._attack_settings(attack_mode)

        self.base_model.load_state_dict(torch.load("model_weights/"+name+".pth"))
        generating_perturbation = getattr(self.attacker, attack_mode)

        for _ in range(n_repeats):
            self.base_model.eval()

            train_stats = self._evaluate_loader(
                self.trainloader, generating_perturbation, eps, n_iteration, step_size)
            test_stats = self._evaluate_loader(
                self.testloader, generating_perturbation, eps, n_iteration, step_size)

            with open(results_path, "a+", newline="") as csvfile:
                writer = csv.writer(csvfile)
                writer.writerow([name, attack_mode, *train_stats, *test_stats])



if __name__ == '__main__':
    # Sweep every (architecture, dataset, beta) combination and attack each
    # one with all four attack modes; results are appended to FGM_results.csv.
    with open("FGM_results.csv","a+",newline="") as csvfile:
        # Write the column names first, but only when the file is still
        # empty, so that re-running the script does not insert another
        # header row in the middle of existing results.
        csvfile.seek(0, 2)
        if csvfile.tell() == 0:
            writer = csv.writer(csvfile)
            writer.writerow(["method","attack_mode","train_clean_acc","train_adv_acc","train_transfer_V16","train_transfer_R18","test_clean_acc","test_adv_acc","test_transfer_V16","test_transfer_R18",])

    models=["Inception","Alexnet",]
    datasets=["CIFAR10","CIFAR100","SVHN"]
    betas=[np.inf,1,1.3,1.6,2,3]
    attack_modes=["FGM_L2","PGM_L2","FGM_Linf","PGM_Linf"]
    for m in models:
        for d in datasets:
            for b in betas:
                name=m+"_"+d+"_"+str(b)
                # The task (data loaders, target models, base model) does not
                # depend on the attack mode, so build it once per
                # configuration instead of once per attack.
                task=Transferable_perturbations(dataset=d,model=m,beta=b)
                for attack_mode in attack_modes:
                    task.test(attack_mode,name)

    