import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
from torch.autograd import Variable
from torchsummary import summary
import sys
import numpy as np

import torchvision
import torchvision.transforms as transforms
import models.VGG as VGG
import models.resnet as resnet
from tensorboardX import SummaryWriter
import models.mini_inception as mini_inception
import models.mini_alexnet as mini_alexnet
import tools.attack as attack
class Transferable_perturbations:
    """Measure how adversarial examples crafted on two base models transfer to fixed target models.

    The constructor builds the train/test loaders for the chosen dataset, loads
    two pre-trained target networks (``target_model1`` is a ResNet-18,
    ``target_model2`` is a VGG16) and instantiates the base network(s) the
    perturbations are crafted on. ``train`` then loads two checkpoints of the
    base network and appends one row of accuracies to ``ES_results_only.csv``.
    """

    # Per-dataset settings: image folders, the factor multiplied by 0.05 to
    # obtain the L2 budget, number of classes, and the target-model checkpoints.
    # "vgg_builder" names the constructor looked up on the ``VGG`` module.
    _DATASET_CONFIG = {
        "CIFAR10": {
            "train_dir": './data/cifar-10-batches-py/train/robust_train',
            "test_dir": './data/cifar-10-batches-py/test',
            "l2_scale": 78.6472,
            "n_class": 10,
            "resnet_weights": "ResNet18_cifar10_clean_half_transfer.pth",
            "vgg_builder": "VGG",
            "vgg_weights": "VGG16_cifar10_clean_half_transfer.pth",
        },
        "CIFAR100": {
            "train_dir": './data/cifar-100-python/image/robust_train',
            "test_dir": './data/cifar-100-python/image/test',
            "l2_scale": 78.6472,
            "n_class": 100,
            "resnet_weights": "resnet_cifar100.pth",
            "vgg_builder": "VGG100",
            "vgg_weights": "VGG16_cifar100.pth",
        },
        "SVHN": {
            "train_dir": './data/SVHN/train/robust_train',
            "test_dir": './data/SVHN/test',
            "l2_scale": 51.9595,
            "n_class": 10,
            "resnet_weights": "resnet_svhn.pth",
            "vgg_builder": "VGG",
            "vgg_weights": "VGG_SHVN.pth",
        },
    }

    # attack_mode -> (name of the bound attribute on self, or None for no
    # perturbation; number of attack iterations). The attack callable itself is
    # the attribute of ``self.attacker`` with the same name as the mode.
    _ATTACK_SPECS = {
        "ERM": (None, 0),
        "FGM_L2": ("L2_bound", 1),
        "PGM_L2": ("L2_bound", 15),
        "FGM_Linf": ("Linf_bound", 1),
        "PGM_Linf": ("Linf_bound", 15),
    }

    def __init__(self,dataset,model,beta) -> None:
        """Build data loaders, target models and base models.

        Args:
            dataset: one of "CIFAR10", "CIFAR100" or "SVHN".
            model: "Inception" or "Alexnet"; any other value creates no base model.
            beta: ``np.inf`` creates two plain base models (``base_model1`` and
                ``base_model2``); any other value creates a single ``*_sn``
                variant stored in ``base_model`` and parameterised by ``beta``.

        Raises:
            ValueError: if ``dataset`` is not a supported dataset name.
        """
        if dataset not in self._DATASET_CONFIG:
            raise ValueError(
                "unknown dataset %r; expected one of %s"
                % (dataset, sorted(self._DATASET_CONFIG))
            )
        cfg=self._DATASET_CONFIG[dataset]

        self.attacker=attack.Adversarial_Attack()
        # ==========dataset==============
        self.dataset=dataset
        self.batch_size=640
        normalize=transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
        self.transform_train = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
        self.transform_test = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])
        self.loss_fn=nn.CrossEntropyLoss()

        trainset = torchvision.datasets.ImageFolder(cfg["train_dir"], transform=self.transform_train)
        self.trainloader = torch.utils.data.DataLoader(trainset, batch_size=self.batch_size, shuffle=True, num_workers=4,pin_memory=True)

        # The test set always uses the deterministic transform; the CIFAR10
        # branch previously applied the random-flip training transform here.
        testset = torchvision.datasets.ImageFolder(cfg["test_dir"], transform=self.transform_test)
        self.testloader = torch.utils.data.DataLoader(testset, batch_size=self.batch_size, shuffle=False, num_workers=4,pin_memory=True)

        self.L2_bound=0.05*cfg["l2_scale"]
        self.Linf_bound=8/255*1/(0.2)# 0.2:to be consistent with the transform (0.2023, 0.1994, 0.2010)
        n_class=cfg["n_class"]

        self.target_model1=resnet.resnet18(num_classes=n_class).cuda()
        self.target_model1.load_state_dict(torch.load(cfg["resnet_weights"]))
        self.target_model1.eval()

        self.target_model2=getattr(VGG,cfg["vgg_builder"])('VGG16').cuda()
        self.target_model2.load_state_dict(torch.load(cfg["vgg_weights"]))
        self.target_model2.eval()

        # ==========model==============
        if model=="Inception":
            if beta==np.inf:
                self.base_model1=mini_inception.inception(n_class).cuda()
                self.base_model2=mini_inception.inception(n_class).cuda()
            else:
                self.base_model=mini_inception.inception_sn(n_class,beta).cuda()

        if model=="Alexnet":
            if beta==np.inf:
                self.base_model1=mini_alexnet.alexnet(n_class).cuda()
                self.base_model2=mini_alexnet.alexnet(n_class).cuda()
            else:
                self.base_model=mini_alexnet.alexnet_sn(n_class,beta).cuda()

    def _attack_settings(self,attack_mode):
        """Return ``(attack_fn, eps, n_iteration, step_size)`` for ``attack_mode``.

        Raises:
            ValueError: if ``attack_mode`` is not a key of ``_ATTACK_SPECS``.
        """
        if attack_mode not in self._ATTACK_SPECS:
            raise ValueError(
                "unknown attack_mode %r; expected one of %s"
                % (attack_mode, sorted(self._ATTACK_SPECS))
            )
        bound_attr,n_iteration=self._ATTACK_SPECS[attack_mode]
        generating_perturbation=getattr(self.attacker,attack_mode)
        if bound_attr is None:
            return generating_perturbation,0,n_iteration,0
        eps=getattr(self,bound_attr)
        # Single-step attacks spend the whole budget in one step; iterative
        # ones use 1.5x the even split of the budget per step.
        step_size=eps if n_iteration==1 else eps/n_iteration*1.5
        return generating_perturbation,eps,n_iteration,step_size

    @staticmethod
    def _safe_ratio(numerator,denominator):
        """Return ``numerator/denominator``, or NaN when the denominator is zero."""
        if denominator==0:
            return float("nan")
        return numerator/denominator

    @staticmethod
    def _count_correct(model,x,y):
        """Return how many rows of ``x`` the model assigns the label in ``y``.

        An empty batch is never passed to the model and counts as zero.
        """
        if x.shape[0]==0:
            return 0
        with torch.no_grad():
            _, predicted = model(x).max(1)
        return predicted.eq(y).sum().item()

    def _evaluate_loader(self,loader,generating_perturbation,eps,n_iteration,step_size):
        """Evaluate both base models and the transfer to both targets on one loader.

        Returns a list of 12 ratios: for base model 1 and then base model 2
        ``[clean acc, adversarial acc, transfer V16, transfer R18]``, followed
        by ``[V16, R18]`` for each base model restricted to the intersection
        subset. A ratio whose denominator is zero is reported as NaN.
        """
        base_models=(self.base_model1,self.base_model2)
        n_samples=0
        n_intersection=0
        clean_correct=[0,0]
        # Also the denominator of the per-model transfer ratios: the transfer
        # subset is exactly the adversarial examples counted here.
        adv_correct=[0,0]
        transfer_V16=[0,0]
        transfer_R18=[0,0]
        intersection_V16=[0,0]
        intersection_R18=[0,0]

        for data,y in loader:
            data,y=data.cuda(),y.cuda()
            n_samples+=data.shape[0]

            with torch.no_grad():
                clean_labels=[base(data).max(1)[1] for base in base_models]

            adv_batches=[]
            survived=[]
            for i,(base,labels) in enumerate(zip(base_models,clean_labels)):
                clean_correct[i]+=labels.eq(y).sum().item()
                # The attack is driven by the model's own clean predictions,
                # not by the ground-truth labels, and needs gradients enabled.
                adv_data=generating_perturbation(base,data,labels,eps,n_iteration,step_size)
                with torch.no_grad():
                    still_correct=base(adv_data).max(1)[1].eq(y)
                adv_correct[i]+=still_correct.sum().item()
                adv_batches.append(adv_data)
                survived.append(still_correct)

            # NOTE(review): transfer is measured on adversarial examples the
            # source model still classifies correctly (the original code
            # overwrote a clean-prediction mask with this one) -- confirm intent.
            intersection=survived[0]&survived[1]
            n_intersection+=intersection.sum().item()
            common_y=y[intersection]

            for i,(adv_data,mask) in enumerate(zip(adv_batches,survived)):
                selected_x=adv_data[mask]
                selected_y=y[mask]
                transfer_R18[i]+=self._count_correct(self.target_model1,selected_x,selected_y)
                transfer_V16[i]+=self._count_correct(self.target_model2,selected_x,selected_y)

                common_x=adv_data[intersection]
                intersection_R18[i]+=self._count_correct(self.target_model1,common_x,common_y)
                intersection_V16[i]+=self._count_correct(self.target_model2,common_x,common_y)

        ratio=self._safe_ratio
        row=[]
        for i in (0,1):
            row+=[
                ratio(clean_correct[i],n_samples),
                ratio(adv_correct[i],n_samples),
                ratio(transfer_V16[i],adv_correct[i]),
                ratio(transfer_R18[i],adv_correct[i]),
            ]
        for i in (0,1):
            row+=[
                ratio(intersection_V16[i],n_intersection),
                ratio(intersection_R18[i],n_intersection),
            ]
        return row

    def train(self,attack_mode,name):
        """Evaluate two checkpoints of the base network and append the results to a CSV.

        Despite its name this method performs no optimisation. It loads
        ``"early_stop_" + name + ".pth"`` into ``base_model1`` and
        ``name + ".pth"`` into ``base_model2``, evaluates both on the train and
        test loaders, and appends one row (12 train ratios followed by 12 test
        ratios) to ``ES_results_only.csv``. It requires ``base_model1`` and
        ``base_model2``, which the constructor only creates when ``beta`` is
        ``np.inf``.

        Args:
            attack_mode: "ERM", "FGM_L2", "PGM_L2", "FGM_Linf" or "PGM_Linf".
            name: checkpoint base name (without the ".pth" suffix).

        Raises:
            ValueError: if ``attack_mode`` is not supported.
        """
        # Validate before touching any checkpoint so a typo fails fast.
        generating_perturbation,eps,n_iteration,step_size=self._attack_settings(attack_mode)

        self.base_model1.load_state_dict(torch.load("early_stop_"+name+".pth"))
        self.base_model2.load_state_dict(torch.load(name+".pth"))
        self.base_model1.eval()
        self.base_model2.eval()

        train_row=self._evaluate_loader(self.trainloader,generating_perturbation,eps,n_iteration,step_size)
        test_row=self._evaluate_loader(self.testloader,generating_perturbation,eps,n_iteration,step_size)

        import csv
        with open("ES_results_only.csv","a+",newline="") as csvfile:
            csv.writer(csvfile).writerow(train_row+test_row)

        print("end")

            


if __name__ == '__main__':
    # Sweep every (architecture, dataset, attack) combination, appending one
    # result row per combination to ES_results_only.csv.
    import csv

    model=["Inception","Alexnet",]
    dataset=["CIFAR10","CIFAR100","SVHN"]
    attack_method=["FGM_L2","PGM_L2"]

    # Column names, in the order Transferable_perturbations.train writes the
    # values: train columns first, then test columns. The "ES_" prefix marks
    # the model loaded from the "early_stop_" checkpoint.
    train_columns=["ES_clean_train","ES_adv_train","ES_train_transfer_V16","ES_train_transfer_R18",
            "clean_train","adv_train","train_transfer_V16","train_transfer_R18",
            "ES_intersection_train_transfer_V16","ES_intersection_train_transfer_R18","intersection_train_transfer_V16","intersection_train_transfer_R18",
            ]
    test_columns=["ES_clean_test","ES_adv_test","ES_test_transfer_V16","ES_test_transfer_R18",
            "clean_test","adv_test","test_transfer_V16","test_transfer_R18",
            "ES_intersection_test_transfer_V16","ES_intersection_test_transfer_R18","intersection_test_transfer_V16","intersection_test_transfer_R18",
            ]

    # The file is opened in append mode, so a header row is added on every run.
    with open("ES_results_only.csv","a+",newline="") as csvfile:
        csv.writer(csvfile).writerow(train_columns+test_columns)

    for m in model:
        for d in dataset:
            for attack_mode in attack_method:
                name=m+"_"+d+"_"+str(attack_mode)+"_early_stopping"
                # The TensorBoard writer is only created (nothing is logged to
                # it here); close it so each run releases its event file.
                board_writer = SummaryWriter(comment="_测试_"+name)
                try:
                    task=Transferable_perturbations(dataset=d,model=m,beta=np.inf)
                    task.train(attack_mode,name)
                finally:
                    board_writer.close()

    