#encoding:utf-8
import argparse
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms

# Project-local star imports. The tqdm import stays after them so that `tqdm`
# is always the tqdm function, whatever these modules happen to export.
from models.binresnet import *
from utils import *
from tqdm import tqdm
seed = 1234
# Seed the run through the project helper (brought in by a star import;
# presumably seeds Python/NumPy/torch RNGs — confirm in utils). Pass the named
# constant rather than a duplicate literal so the two cannot drift apart.
seed_everything(seed)

def _str2bool(value):
    """Parse a command-line boolean flag value.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is True,
    so every non-empty string would enable the flag.

    Args:
        value: the raw string from the command line (or an already-parsed bool).

    Returns:
        True for 1/true/t/yes/y/on, False for 0/false/f/no/n/off (case-insensitive).

    Raises:
        argparse.ArgumentTypeError: for any other string.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got %r" % value)


parser = argparse.ArgumentParser(description='PyTorch CIFAR Adversarial Training')
# Accepts `--resume`, `--resume True` and `--resume False` (the latter now really disables it).
parser.add_argument('--resume', type=_str2bool, nargs='?', const=True, default=False,
                    help='if resume')
parser.add_argument('--savepath', type=str, default='checkpoint',
                    help='where to save')
parser.add_argument('--device', type=str, default='cuda',
                    help='device')
parser.add_argument('--norm_type', type=str, default='bn',
                    help='normlization type,aviable: bn bin bbn')
parser.add_argument('--classnum', type=int, default=10,
                    help='class numbers')
# NOTE(review): --eps is not referenced anywhere in this script; it is not
# forwarded to inf_pgd / white_box_test, so changing it has no effect here.
parser.add_argument('--eps', type=float, default=0.03125,
                    help='epsilon')
parser.add_argument('--resume_epoch', type=int, default=1,
                    help='resume epoch')
parser.add_argument('--weight_decay', type=float, default=2e-4, help='weight decay')
# Only training samples with this label are replaced by adversarial examples.
parser.add_argument('--target', type=int, default=3, help='target class')


args = parser.parse_args()

# Mini-batch sizes: augmented training batches vs. evaluation batches.
trainsize = 128
testsize = 200

# Training-time pipeline: random 32x32 crop from a 4-pixel-padded image, random
# horizontal flip, then tensor conversion (no mean/std normalisation).
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])
# Evaluation-time pipeline: tensor conversion only.
transform_test = transforms.Compose([transforms.ToTensor()])


def _cifar10_loader(train, transform, batch_size, shuffle):
    """Build a CIFAR-10 split under ./data (downloading if needed) plus its DataLoader.

    Returns:
        (dataset, loader) tuple.
    """
    dataset = torchvision.datasets.CIFAR10(
        root='./data', train=train, download=True, transform=transform)
    loader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=shuffle,
        num_workers=16, pin_memory=True)
    return dataset, loader


# Augmented, shuffled training stream.
trainset, trainloader = _cifar10_loader(
    train=True, transform=transform_train, batch_size=trainsize, shuffle=True)
# The same training images without augmentation or shuffling (for evaluation).
trainset1, trainloader1 = _cifar10_loader(
    train=True, transform=transform_test, batch_size=testsize, shuffle=False)
# Held-out test split.
testset, testloader = _cifar10_loader(
    train=False, transform=transform_test, batch_size=testsize, shuffle=False)
# Where to run and where to write/read checkpoints.
device = args.device
root = 'checkpoints/'
savepath = f"{root}{args.savepath}"
# Runs resume from the same directory they write into.
resumepath = savepath

# Fixed hyper-parameters.
step_size = 0.008  # same value as the step_size given to white_box_test
init_lr = 0.1      # SGD learning rate at construction time
end = 50           # exclusive upper bound of the epoch range

if args.resume:
    # Continue with the epoch after the one whose checkpoint is loaded.
    resume_epoch = args.resume_epoch
    start = resume_epoch + 1
    resume_model = resumepath + '/model/' + str(resume_epoch) + '.pth'
    resume_tmp = resumepath + '/log_file/' + 'log.pt'

    print("resume from", resume_model)
    # map_location lets a checkpoint saved on one device (e.g. CUDA) be reloaded
    # on another (e.g. CPU).
    # NOTE(review): this unpickles a whole nn.Module. PyTorch >= 2.6 defaults to
    # weights_only=True, which rejects such files — pass weights_only=False there,
    # and only load checkpoint files you trust.
    model = torch.load(resume_model, map_location=device).to(device)
    tmp = torch.load(resume_tmp, map_location=device)
    # The log is [train_nat, train_adv, test_nat, test_adv], one entry per epoch
    # starting at 0. Keep epochs 0..resume_epoch and drop any entries written
    # after the checkpoint being resumed.
    train_nature_acc_list = tmp[0][:resume_epoch + 1]
    train_adv_acc_list = tmp[1][:resume_epoch + 1]
    test_nature_acc_list = tmp[2][:resume_epoch + 1]
    test_adv_acc_list = tmp[3][:resume_epoch + 1]

else:
    print('new start')
    start = 0
    # Honour --classnum (default 10, matching CIFAR-10) instead of a hard-coded 10.
    model = BINResNet18_2d(args.norm_type, num_classes=args.classnum).to(device)
    train_nature_acc_list = []
    train_adv_acc_list = []
    test_nature_acc_list = []
    test_adv_acc_list = []
print("Single adv train the " + str(args.target) + " class")
# Model checkpoints and the accuracy log live in separate sub-directories.
model_savepath = savepath + '/model'
log_savepath = savepath + '/log_file'
# exist_ok=True replaces the check-then-create pattern (os.path.exists followed
# by os.makedirs), which races if the directory appears in between.
os.makedirs(model_savepath, exist_ok=True)
os.makedirs(log_savepath, exist_ok=True)

optimizer = torch.optim.SGD(
    model.parameters(), lr=init_lr, momentum=0.9, weight_decay=args.weight_decay)
# NOTE(review): only the model and the accuracy log are checkpointed below, not
# optimizer.state_dict(), so a resumed run restarts SGD momentum from zero.

# Passed to white_box_test as test_pic_num for both splits (loop-invariant).
test_num = 1000

for epoch in range(start, end):
    print("epoch:%d" % epoch)
    model.train()
    # Project LR-schedule helper, called once per epoch; the 'bin' flag is set
    # when the BIN normalisation variant is selected.
    adjust_learning_rate(optimizer, epoch, bin=args.norm_type == 'bin', end=end)

    # Iterating the DataLoader directly lets tqdm read len(trainloader) as the total.
    for images, labels in tqdm(trainloader):
        images = images.to(device)
        labels = labels.to(device)

        # Only samples of the target class are swapped for adversarial
        # examples; all other samples are trained on clean.
        pos = labels == args.target
        if pos.any():
            # The attack runs with the model in eval mode (presumably so that
            # normalisation running statistics are untouched by attack passes),
            # and training mode is restored immediately after.
            model.eval()
            adv_images = inf_pgd(model, images[pos], labels[pos], iter_time=10)
            model.train()
            # detach(): never back-propagate the training loss into a graph
            # inf_pgd may have built (a no-op if it already returns a detached tensor).
            images[pos] = adv_images.detach()

        logits = model(images)
        loss = F.cross_entropy(logits, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Per-epoch white-box evaluation on the un-augmented training stream and on
    # the test split; results are unpacked as (natural, adversarial) accuracy.
    model.eval()
    train_nature_acc, train_adv_acc = white_box_test(
        model, trainloader1, device=device, step_size=step_size, pgd_time=20,
        test_pic_num=test_num, slogan="Training ")
    test_nature_acc, test_adv_acc = white_box_test(
        model, testloader, device=device, step_size=step_size, pgd_time=20,
        test_pic_num=test_num, slogan="Testing ")
    train_nature_acc_list.append(train_nature_acc)
    train_adv_acc_list.append(train_adv_acc)
    test_nature_acc_list.append(test_nature_acc)
    test_adv_acc_list.append(test_adv_acc)

    # Rewrite the full accuracy history each epoch, then pickle the whole model
    # as <epoch>.pth (the resume path loads these two files).
    torch.save([train_nature_acc_list, train_adv_acc_list, test_nature_acc_list, test_adv_acc_list],
               os.path.join(log_savepath, "log.pt"))
    torch.save(model, os.path.join(model_savepath, str(epoch) + ".pth"))

