import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
from torch.autograd import Variable
from torchsummary import summary
import sys
import numpy as np

import torchvision
import torchvision.transforms as transforms
import models.VGG as VGG
import models.resnet as resnet
from tensorboardX import SummaryWriter
import models.mini_inception as mini_inception
import models.mini_alexnet as mini_alexnet
import tools.attack as attack
class Transferable_perturbations:
    """Adversarially train a small base model and track its robustness.

    The constructor builds the data loaders for one dataset, loads two
    pretrained "target" networks (ResNet18 and VGG16) from checkpoint files in
    the working directory, and instantiates the trainable base model on the
    GPU.  ``train`` then fits the base model on perturbed inputs produced by
    ``tools.attack.Adversarial_Attack``.

    NOTE(review): the target models are loaded but never queried in this file,
    and the ``*_transfer_*`` counters logged by ``train`` are never
    incremented, so those scalars are always 0.
    """

    _DATASETS = ("CIFAR10", "CIFAR100", "SVHN")
    _MODELS = ("Inception", "Alexnet")
    _ATTACK_MODES = ("ERM", "FGM_L2", "PGM_L2", "FGM_Linf", "PGM_Linf")

    def __init__(self, dataset, model, beta) -> None:
        """Set up loaders, target models and the base model.

        Args:
            dataset: One of "CIFAR10", "CIFAR100" or "SVHN".
            model: One of "Inception" or "Alexnet".
            beta: ``np.inf`` selects the plain model constructor; any other
                value is forwarded to the ``*_sn`` constructor
                (presumably a spectral-norm bound -- TODO confirm).

        Raises:
            ValueError: If ``dataset`` or ``model`` is not a supported name.
        """
        # Fail fast, before any data or checkpoint is touched.
        if dataset not in self._DATASETS:
            raise ValueError(
                "unknown dataset %r; expected one of %s" % (dataset, self._DATASETS))
        if model not in self._MODELS:
            raise ValueError(
                "unknown model %r; expected one of %s" % (model, self._MODELS))

        self.attacker = attack.Adversarial_Attack()
        # ==========dataset==============
        self.dataset = dataset
        self.batch_size = 64
        self.transform_train = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])
        self.transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])
        self.loss_fn = nn.CrossEntropyLoss()
        # 8/255 in pixel space, rescaled by 1/0.2 to be consistent with the
        # normalisation std (0.2023, 0.1994, 0.2010).  Same for every dataset.
        self.Linf_bound = 8/255*1/(0.2)

        # Per-dataset settings.  The L2 constants (78.6472 / 51.9595) are
        # presumably a typical L2 norm of a normalised image -- TODO confirm.
        if dataset == "CIFAR10":
            train_dir = './data/cifar-10-batches-py/train/robust_train'
            val_dir = './data/cifar-10-batches-py/validation'
            self.L2_bound = 0.05*78.6472
            n_class = 10
            resnet_weights = "ResNet18_cifar10_clean_half_transfer.pth"
            vgg_builder = VGG.VGG
            vgg_weights = "VGG16_cifar10_clean_half_transfer.pth"
        elif dataset == "CIFAR100":
            train_dir = './data/cifar-100-python/image/robust_train'
            val_dir = './data/cifar-100-python/image/validation'
            self.L2_bound = 0.05*78.6472
            n_class = 100
            resnet_weights = "resnet_cifar100.pth"
            vgg_builder = VGG.VGG100
            vgg_weights = "VGG16_cifar100.pth"
        else:  # "SVHN"
            train_dir = './data/SVHN/train/robust_train'
            val_dir = './data/SVHN/validation'
            self.L2_bound = 0.05*51.9595
            n_class = 10
            resnet_weights = "resnet_svhn.pth"
            vgg_builder = VGG.VGG
            vgg_weights = "VGG_SHVN.pth"

        # Validation always uses the deterministic test transform (no random
        # flip) so that metrics are comparable across epochs and datasets.
        self.trainloader = self._make_loader(train_dir, self.transform_train, shuffle=True)
        self.valloader = self._make_loader(val_dir, self.transform_test, shuffle=False)

        self.target_model1 = resnet.resnet18(num_classes=n_class).cuda()
        self.target_model1.load_state_dict(torch.load(resnet_weights))
        self.target_model1.eval()

        self.target_model2 = vgg_builder('VGG16').cuda()
        self.target_model2.load_state_dict(torch.load(vgg_weights))
        self.target_model2.eval()

        # ==========model==============
        if model == "Inception":
            if beta == np.inf:
                self.base_model = mini_inception.inception(n_class).cuda()
            else:
                self.base_model = mini_inception.inception_sn(n_class, beta).cuda()
        else:  # "Alexnet"
            if beta == np.inf:
                self.base_model = mini_alexnet.alexnet(n_class).cuda()
            else:
                self.base_model = mini_alexnet.alexnet_sn(n_class, beta).cuda()

    def _make_loader(self, root, transform, shuffle):
        """Build a DataLoader over the ImageFolder rooted at ``root``."""
        folder = torchvision.datasets.ImageFolder(root, transform=transform)
        return torch.utils.data.DataLoader(
            folder, batch_size=self.batch_size, shuffle=shuffle,
            num_workers=4, pin_memory=True)

    def _attack_config(self, attack_mode):
        """Return ``(attack_fn, eps, n_iteration, step_size)`` for a mode name."""
        if attack_mode not in self._ATTACK_MODES:
            raise ValueError(
                "unknown attack_mode %r; expected one of %s"
                % (attack_mode, self._ATTACK_MODES))
        if attack_mode == "ERM":
            return self.attacker.ERM, 0, 0, 0

        # The attacker exposes one method per mode name (FGM_L2, PGM_Linf, ...).
        attack_fn = getattr(self.attacker, attack_mode)
        eps = self.L2_bound if attack_mode.endswith("_L2") else self.Linf_bound
        if attack_mode.startswith("FGM"):
            # Single step covering the whole budget.
            return attack_fn, eps, 1, eps
        n_iteration = 15
        return attack_fn, eps, n_iteration, eps/n_iteration*1.5

    @staticmethod
    def _ratio(count, total):
        """Return ``count / total``, or 0.0 when ``total`` is zero."""
        return count/total if total else 0.0

    def train(self, attack_mode, name, writer=None):
        """Train the base model for 200 epochs on perturbed inputs.

        Each epoch trains on perturbed training batches, evaluates clean and
        adversarial accuracy on the validation loader, logs scalars, and saves
        checkpoints: ``name + ".pth"`` every epoch and
        ``"early_stop_" + name + ".pth"`` whenever validation adversarial
        accuracy improves.

        Args:
            attack_mode: "ERM", "FGM_L2", "PGM_L2", "FGM_Linf" or "PGM_Linf".
            name: Run name used in the checkpoint file names.
            writer: Object with an ``add_scalar(tag, value, global_step=...)``
                method.  When omitted, a module-level global named ``writer``
                is used, as set up by this file's script entry point.

        Raises:
            ValueError: If ``attack_mode`` is not a supported name.
            RuntimeError: If no writer is passed and no global one exists.
        """
        generating_perturbation, eps, n_iteration, step_size = self._attack_config(attack_mode)

        if writer is None:
            # Backward compatibility with callers that publish the writer as a
            # module-level global instead of passing it in.
            writer = globals().get("writer")
        if writer is None:
            raise RuntimeError(
                "train() needs a summary writer: pass writer=... or define a "
                "module-level `writer` before calling it")

        optimizer = torch.optim.Adam(self.base_model.parameters(), lr=3e-4)
        scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.97)
        best_test_acc = 0
        for epoch in range(200):
            # ---------------- training pass ----------------
            train_clean_correct = 0
            train_adv_correct = 0
            loss_sum = 0
            n_train_samples = 0
            n_train_batches = 0
            n_train_samples_transfer = 0
            # Never incremented: transfer to the target models is not measured.
            train_transfer_V16 = 0
            train_transfer_R18 = 0

            self.base_model.train()
            for data, y in self.trainloader:
                data, y = data.cuda(), y.cuda()
                with torch.no_grad():
                    _, predicted_clean = self.base_model(data).max(1)
                    n_correct = predicted_clean.eq(y).sum().item()
                train_clean_correct += n_correct
                n_train_samples_transfer += n_correct

                # The attack is driven by the model's own clean predictions,
                # not by the ground-truth labels.
                adv_data = generating_perturbation(
                    self.base_model, data, predicted_clean, eps, n_iteration, step_size)
                adv_logits = self.base_model(adv_data)
                _, predicted_adv = adv_logits.max(1)
                train_adv_correct += predicted_adv.eq(y).sum().item()

                loss = self.loss_fn(adv_logits, y)
                # Clear gradients right before backward so that anything the
                # attack may have accumulated on the parameters is discarded.
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                loss_sum += loss.item()
                n_train_samples += data.shape[0]
                n_train_batches += 1

            mean_loss = self._ratio(loss_sum, n_train_batches)
            print("epoch=", epoch, " loss=", mean_loss)
            writer.add_scalar('train_loss', mean_loss, global_step=epoch)
            writer.add_scalar('train_clean_acc', self._ratio(train_clean_correct, n_train_samples), global_step=epoch)
            writer.add_scalar('train_adv_acc', self._ratio(train_adv_correct, n_train_samples), global_step=epoch)
            writer.add_scalar('train_transfer_V16', self._ratio(train_transfer_V16, n_train_samples_transfer), global_step=epoch)
            writer.add_scalar('train_transfer_R18', self._ratio(train_transfer_R18, n_train_samples_transfer), global_step=epoch)

            # ---------------- validation pass ----------------
            test_clean_correct = 0
            test_adv_correct = 0
            n_test_samples = 0
            n_test_samples_transfer = 0
            # Never incremented, see the training counters above.
            test_transfer_V16 = 0
            test_transfer_R18 = 0

            # Evaluate in eval mode so normalisation/dropout layers use their
            # inference behaviour and validation data cannot alter running
            # statistics.  train() is re-enabled at the top of the next epoch.
            self.base_model.eval()
            for data, y in self.valloader:
                data, y = data.cuda(), y.cuda()
                n_test_samples += data.shape[0]
                with torch.no_grad():
                    _, predicted_clean = self.base_model(data).max(1)
                    n_correct = predicted_clean.eq(y).sum().item()
                test_clean_correct += n_correct
                if n_correct == 0:
                    # Nothing classified correctly: skip the attack for this batch.
                    continue
                n_test_samples_transfer += n_correct
                adv_data = generating_perturbation(
                    self.base_model, data, predicted_clean, eps, n_iteration, step_size)
                with torch.no_grad():
                    _, predicted_adv = self.base_model(adv_data).max(1)
                    test_adv_correct += predicted_adv.eq(y).sum().item()

            val_adv_acc = self._ratio(test_adv_correct, n_test_samples)
            writer.add_scalar('val_clean_acc', self._ratio(test_clean_correct, n_test_samples), global_step=epoch)
            writer.add_scalar('val_adv_acc', val_adv_acc, global_step=epoch)
            writer.add_scalar('val_transfer_V16', self._ratio(test_transfer_V16, n_test_samples_transfer), global_step=epoch)
            writer.add_scalar('val_transfer_R18', self._ratio(test_transfer_R18, n_test_samples_transfer), global_step=epoch)

            # Constant learning rate for the first 60 epochs, then exponential decay.
            if epoch >= 60:
                scheduler.step()

            # Early-stopping checkpoint: best validation adversarial accuracy so far.
            if best_test_acc < val_adv_acc:
                best_test_acc = val_adv_acc
                torch.save(self.base_model.state_dict(), "early_stop_"+name+".pth")
            torch.save(self.base_model.state_dict(), name+".pth")


if __name__ == '__main__':
    # Sweep every (model, dataset, attack) combination, one training run each.
    model=["Inception","Alexnet",]
    dataset=["CIFAR10","CIFAR100","SVHN"]
    attack_method=["FGM_L2","PGM_L2"]
    for m in model:
        for d in dataset:
            for a in attack_method:
                name=m+"_"+d+"_"+str(a)+"_early_stopping"
                # `writer` must remain a module-level name: train() picks the
                # summary writer up from the module globals.
                writer = SummaryWriter(comment="_"+name)
                try:
                    task=Transferable_perturbations(dataset=d,model=m,beta=np.inf)
                    task.train(a,name)
                finally:
                    # Flush and release the event file even if the run fails,
                    # so each run's logs are complete before the next starts.
                    writer.close()

    