#encoding:utf-8
import torch
import torchvision
import os
from torchvision import transforms
import torch.nn as nn
from models.binresnet import *
from utils import *
from tqdm import tqdm
import argparse
from trades import *
#args
def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``type=bool`` is unusable with argparse because ``bool("False")`` and
    ``bool("0")`` are both ``True``; this helper accepts the usual spellings
    instead and rejects anything else with a proper argparse error.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    # The empty string is kept falsy because bool('') was False before.
    if token in ('0', 'false', 'f', 'no', 'n', 'off', ''):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean (1/0, true/false, yes/no), got %r' % value)


parser = argparse.ArgumentParser(description='PyTorch--Adversarial Training')

######### save setting ##############

parser.add_argument('--resumepath',type=str, default="the_robustness_king",
                    help='resumepath')
parser.add_argument('--resume_epoch',type=int, default=1,
                    help='resume epoch')
# nargs='?' + const=True additionally allows a bare `--resume` flag while
# still accepting the documented `--resume 1` / `--resume 0` forms.
parser.add_argument('--resume', type=_str2bool, nargs='?', const=True, default=False,
                    help='1: resume, 0: not resume')
parser.add_argument('--savepath', type=str, default='the_robustness_king',
                    help='savepath')

######## basic setting ###############

parser.add_argument('--device', type=str, default='cuda',
                    help='device')
parser.add_argument('--norm_type', type=str, default='bn',
                    help='normlization layer type,aviable: bn bin in')
parser.add_argument('--classnum', type=int, default=10,
                    help='class numbers')
parser.add_argument('--eps', type=float, default=0.03125,
                    help='epsilon')
parser.add_argument('--weight_decay',type=float,default=2e-4,help='weight decay')
parser.add_argument('--seed',type=int,default=10,help='random seed')
parser.add_argument('--train_bs',type=int,default=64)
parser.add_argument('--test_bs',type=int,default=200)
parser.add_argument('--dataset',type=str,default='cifar10')
parser.add_argument('--epoch',type=int,default=100)
parser.add_argument('--init_lr',type=float,default=0.1)

########  loss hyperparameter ########

parser.add_argument('--beta', type=float, default=1,
                    help='hyperparameter of the maximum entropy')
parser.add_argument('--alpha', type=float, default=0.1,
                    help='hyperparameter of the label smoothing')
parser.add_argument('--Lambda', type=float, default=6.0,
                    help='hyperparameter of the kl-divegence term in TRADES')

########### training mode #############

# The script handles modes 0..6; restricting the choices here turns an
# unsupported value into an immediate usage error instead of a crash later.
parser.add_argument('--mode',type=int, default=0, choices=range(7),
                    help="training mode; available:[0:'PGD-AT',1:'PGD-AT+LS',2:'ME-AT',3:'TRADES',4:'TRADES+LS',5:'ME-TRADES',6:'Standard training']")

# Read the command line, then hand the requested seed to seed_everything
# before any model or data loader is created.
args = parser.parse_args()
seed = args.seed
seed_everything(seed)

###### Init #######
# Batch sizes for the training and evaluation loaders, and the target device.
trainsize, testsize = args.train_bs, args.test_bs
device = args.device

# Checkpoints live under 'checkpoints/<name>'; saving and resuming may use
# different run names.
root = 'checkpoints/'
savepath = '%s%s' % (root, args.savepath)
resumepath = '%s%s' % (root, args.resumepath)

# Optimisation schedule: initial learning rate and (exclusive) final epoch.
init_lr = args.init_lr
end = args.epoch

# Step size constant; the attack/evaluation calls later in the script use the
# same value.
step_size = 0.008


# Settings shared by the three CIFAR-10 dataset objects and their loaders.
_cifar_kwargs = dict(root='./data', download=True)
_loader_kwargs = dict(num_workers=16, pin_memory=True)

# Training images get a padded random crop and a random horizontal flip;
# evaluation images are only converted to tensors.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])
transform_test = transforms.Compose([transforms.ToTensor()])

# Augmented, shuffled training split used for optimisation.
trainset = torchvision.datasets.CIFAR10(
    train=True, transform=transform_train, **_cifar_kwargs)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=trainsize, shuffle=True, **_loader_kwargs)

# Same training split, but un-augmented and in fixed order, batched with the
# test batch size.
trainset1 = torchvision.datasets.CIFAR10(
    train=True, transform=transform_test, **_cifar_kwargs)
trainloader1 = torch.utils.data.DataLoader(
    trainset1, batch_size=testsize, shuffle=False, **_loader_kwargs)

# Held-out test split.
testset = torchvision.datasets.CIFAR10(
    train=False, transform=transform_test, **_cifar_kwargs)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=testsize, shuffle=False, **_loader_kwargs)

print("train on CIFAR10")

# Build the starting state: either restore a saved model together with its
# accuracy history, or create a fresh network with empty histories.
if args.resume:
    resume_epoch = args.resume_epoch
    # Training continues with the epoch after the one that was checkpointed.
    start = resume_epoch+1
    resume_model = resumepath +'/model/'+str(resume_epoch)+'.pth'
    resume_tmp =  resumepath + '/log_file/'+'log.pt'

    print("resume from",resume_model)
    # map_location lets a checkpoint written on one device be restored on
    # another (e.g. saved on GPU, resumed on CPU).
    # NOTE(review): these files hold pickled Python objects (the whole model,
    # not a state_dict). torch.load unpickles them, so only load checkpoints
    # you trust; recent PyTorch releases that default to weights_only=True
    # may additionally require weights_only=False here - confirm against the
    # installed version.
    model = torch.load(resume_model, map_location=device).to(device)
    tmp = torch.load(resume_tmp, map_location=device)
    # Keep only the entries for epochs 0..resume_epoch so that histories from
    # a longer previous run are truncated at the resume point.
    train_nature_acc_list = tmp[0][0:resume_epoch+1]
    train_adv_acc_list = tmp[1][0:resume_epoch+1]
    test_nature_acc_list = tmp[2][0:resume_epoch+1]
    test_adv_acc_list = tmp[3][0:resume_epoch+1]

else:
    print('new start')
    start = 0
    model = BINResNet18(args.norm_type,num_classes=args.classnum).to(device)
    train_nature_acc_list = []
    train_adv_acc_list = []
    test_nature_acc_list = []
    test_adv_acc_list = []

# Output directories for per-epoch model snapshots and the accuracy log.
# exist_ok avoids the check-then-create race of a separate exists() test.
model_savepath = savepath+'/model'
log_savepath = savepath+'/log_file'
os.makedirs(model_savepath, exist_ok=True)
os.makedirs(log_savepath, exist_ok=True)

# The 'bin' normalisation variant gets its optimizer from set_bin_optimizer;
# every other variant uses plain SGD with momentum.
if args.norm_type == 'bin':
    optimizer = set_bin_optimizer(model,init_lr=init_lr,momentum=0.9,weight_decay=args.weight_decay)
else:
    optimizer= torch.optim.SGD(model.parameters(), lr=init_lr, momentum=0.9, weight_decay=args.weight_decay)


mode_lis = ['PGD-AT','PGD-AT+LS','ME-AT','TRADES','TRADES+LS','ME-TRADES',"Standard training"]
# Reject unsupported modes up front: a negative value would otherwise be
# silently mapped to a label by Python's negative indexing, and a too-large
# one would surface as a bare IndexError.
if not 0 <= args.mode < len(mode_lis):
    raise ValueError(
        "unsupported --mode %r; expected an integer in [0, %d]"
        % (args.mode, len(mode_lis) - 1))
print("training mode:",mode_lis[args.mode])

############## training #####################

# Stateless criterion shared by every cross-entropy based branch.
cross_entropy = torch.nn.CrossEntropyLoss()

# Parameters flagged with `bin_gate` are clamped to [0, 1] after each update.
# The set of parameters does not change during training, so collect it once.
bin_gates = []
if args.norm_type == 'bin':
    bin_gates = [p for p in model.parameters() if getattr(p, 'bin_gate', False)]

for epoch in range(start,end):
    print("epoch:%d"%epoch)
    model.train()
    adjust_learning_rate(optimizer,epoch,bin=args.norm_type=='bin',end=end)
    # Wrapping the loader directly lets tqdm report the number of batches.
    for images,labels in tqdm(trainloader):
        images = images.to(device)
        labels = labels.to(device)

        if args.mode in (0, 1, 2):
            # PGD-based modes: craft adversarial examples with the model in
            # eval mode, then switch back to train mode for the forward pass
            # that is actually optimised. The perturbation budget comes from
            # --eps (default 0.03125, the value previously hard-coded here).
            # NOTE(review): the TRADES losses and white_box_test are called
            # without an eps argument and keep whatever defaults they define.
            model.eval()
            adv_images = inf_pgd(model,images,labels,iter_time=10,step_size=step_size,eps=args.eps)
            model.train()
            adv_logits = model(adv_images)

            if args.mode == 0:
                #PGD-AT
                loss = cross_entropy(adv_logits,labels)
            elif args.mode == 1:
                #PGD-AT+LS
                loss = LabelSmoothingLoss(smoothing=args.alpha)(adv_logits,labels)
            else:
                #ME-AT: cross-entropy minus beta times the mean prediction
                # entropy on the adversarial batch.
                loss_entropy = - torch.softmax(adv_logits,dim=1) * torch.log_softmax(adv_logits, dim=1)
                loss_entropy = loss_entropy.sum(dim=1).mean()
                loss_robust = cross_entropy(adv_logits,labels)
                loss = loss_robust - args.beta*loss_entropy
        elif args.mode == 3:
            #TRADES
            loss = trades_loss(model,images,labels,optimizer,Lambda=args.Lambda)
        elif args.mode == 4:
            #TRADES+LS
            loss = ls_trades_loss(model,images,labels,optimizer,Lambda=args.Lambda,alpha=args.alpha)
        elif args.mode == 5:
            #ME-TRADES
            loss = me_trades_loss(model,images,labels,optimizer,Lambda=args.Lambda,beta=args.beta)
        elif args.mode == 6:
            #standard training
            logits = model(images)
            loss = cross_entropy(logits,labels)
        else:
            # Without this branch an unknown mode would fail on an undefined
            # `loss` (or silently reuse a stale one).
            raise ValueError("unsupported training mode: %r" % (args.mode,))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Empty unless norm_type == 'bin'.
        for p in bin_gates:
            p.data.clamp_(min=0, max=1)

    # Per-epoch evaluation on the un-augmented training loader and the test
    # loader. Early epochs use test_pic_num=1000; once past 80% of the
    # schedule it is raised to 10000.
    model.eval()
    test_num = 10000 if epoch > end*0.8 else 1000
    train_nature_acc,train_adv_acc = white_box_test(model,trainloader1,device = device,step_size = step_size,pgd_time = 20,test_pic_num=test_num,slogan="Training ")
    test_nature_acc,test_adv_acc = white_box_test(model,testloader,device = device,step_size = step_size,pgd_time = 20,test_pic_num=test_num,slogan="Testing  ")
    train_nature_acc_list.append(train_nature_acc)
    train_adv_acc_list.append(train_adv_acc)
    test_nature_acc_list.append(test_nature_acc)
    test_adv_acc_list.append(test_adv_acc)
    # Persist the full accuracy history and a snapshot of the model after
    # every epoch; the resume path reads exactly these two files.
    torch.save([train_nature_acc_list,train_adv_acc_list,test_nature_acc_list,test_adv_acc_list],os.path.join(log_savepath,"log.pt"))
    torch.save(model,os.path.join(model_savepath,str(epoch)+".pth"))


