import os
import torch
os.environ["CUDA_VISIBLE_DEVICES"]="0,1,2,3,4,5,6,7"
import argparse
import torchvision
import torch.optim as optim
from torchvision import transforms
from models import *
from tqdm import tqdm
import numpy as np
import copy

from utils import Logger, save_checkpoint, torch_accuracy, AverageMeter
from attacks import *

def _str2bool(value):
    """Parse a command-line boolean.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is ``True``
    because every non-empty string is truthy, so ``--resume False`` would
    silently enable resuming.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))


# Command-line interface of the training script.
parser = argparse.ArgumentParser(description='Parameter interpolation Adversarial Training')
parser.add_argument('--epochs', type=int, default=120, metavar='N', help='number of epochs')
parser.add_argument('--arch', type=str, default="resnet18", help="choose from smallcnn, resnet18, WRN")
parser.add_argument('--num_classes', type=int, default=100)
parser.add_argument('--lr', default=0.01, type=float)

# Loss / attack hyper-parameters.
parser.add_argument('--loss_fn', type=str, default="cent", help="loss function")
parser.add_argument('--epsilon', type=float, default=0.031, help='perturbation bound')
parser.add_argument('--num-steps', type=int, default=10, help='maximum perturbation step')
parser.add_argument('--step-size', type=float, default=0.007, help='step size')

# `--resume`, `--resume True` and `--resume False` are all accepted; the bare
# flag means True.
parser.add_argument('--resume', type=_str2bool, nargs='?', const=True, default=False,
                    help='whether to resume training')
parser.add_argument('--out-dir',type=str, default='./logs',help='dir of output')
parser.add_argument('--ablation', type=str, default='', help='ablation study')


args = parser.parse_args()

# Training settings
# Each ablation run writes into its own sub-directory of --out-dir.
args.out_dir = os.path.join(args.out_dir, args.ablation)
# exist_ok avoids the check-then-create race of a separate os.path.exists test.
os.makedirs(args.out_dir, exist_ok=True)

# The class count is forced to 10 here, overriding whatever --num_classes was
# given on the command line (the loaders below use CIFAR-10).
args.num_classes = 10
weight_decay = 3.5e-3 if args.arch == 'resnet18' else 7e-4

# Seed every random number generator used by the script.
seed = 1
torch.manual_seed(seed)
np.random.seed(seed)
torch.cuda.manual_seed_all(seed)
# cuDNN autotuning (benchmark=True) may pick a different convolution algorithm
# on each run, which defeats the fixed seed and deterministic=True; disable it
# so that runs are reproducible.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Setup data loader
# Training images are randomly cropped (32x32 after 4 pixels of padding) and
# randomly flipped horizontally before conversion to tensors; test images are
# only converted to tensors.
_train_augmentations = [
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
]
transform_train = transforms.Compose(_train_augmentations)
transform_test = transforms.Compose([transforms.ToTensor()])

# CIFAR-10 is downloaded into ./data on first use.
trainset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform_train,
)
train_loader = torch.utils.data.DataLoader(
    trainset,
    batch_size=128,
    shuffle=True,
    num_workers=2,
)

testset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform_test,
)
test_loader = torch.utils.data.DataLoader(
    testset,
    batch_size=128,
    shuffle=False,
    num_workers=2,
)

class PIAT(object):
    """Hold a working copy of a model and interpolate its weights.

    ``self.model`` is a deep copy of the model passed to the constructor and is
    the network that gets trained.  ``update_params(reference)`` computes

        (1 - decay) * self.model's tensor + decay * reference's tensor

    for every parameter (and, optionally, every buffer) into ``self.model_state``;
    ``update_parameter()`` then loads that blended state back into
    ``self.model``.

    Args:
        model: the ``torch.nn.Module`` to copy.
        buffer_ema: when True, floating-point buffers are blended like
            parameters; when False, buffers are taken unchanged from
            ``self.model``.
    """

    def __init__(self, model, buffer_ema=True):
        self.step = 0  # number of update_params() calls so far
        self.model = copy.deepcopy(model)
        self.alpha = 0.99  # upper bound of the interpolation coefficient
        self.buffer_ema = buffer_ema
        # Private copies of every state tensor; blended values are written here.
        self.model_state = {
            k: v.clone().detach()
            for k, v in self.model.state_dict().items()
        }
        self.param_keys = [k for k, _ in self.model.named_parameters()]
        # named_buffers() also lists non-persistent buffers, which are absent
        # from state_dict(); keep only the ones that can actually be blended.
        self.buffer_keys = [
            k for k, _ in self.model.named_buffers() if k in self.model_state
        ]

    def update_params(self, model):
        """Blend ``self.model`` with ``model`` into ``self.model_state``.

        The weight given to ``model`` is
        ``min(self.alpha, (step + 1) / (step + 10))``, i.e. 0.1 on the first
        call and growing towards ``self.alpha`` afterwards.  ``self.model``
        itself is not modified; call ``update_parameter()`` to apply the result.
        """
        decay = min(self.alpha, (self.step + 1) / (self.step + 10))
        state = model.state_dict()
        now_state = self.model.state_dict()

        for name in self.param_keys:
            blended = (1 - decay) * now_state[name] + decay * state[name]
            self.model_state[name].copy_(blended)

        for name in self.buffer_keys:
            current = now_state[name]
            blendable = current.is_floating_point() or current.is_complex()
            if self.buffer_ema and blendable:
                blended = (1 - decay) * current + decay * state[name]
                self.model_state[name].copy_(blended)
            else:
                # Integer buffers (e.g. BatchNorm's num_batches_tracked counter)
                # would be blended in float and truncated on copy-back, which
                # yields a meaningless value; keep the current model's value.
                self.model_state[name].copy_(current)
        self.step += 1

    def update_parameter(self):
        """Load the most recently blended state into ``self.model``."""
        self.model.load_state_dict(self.model_state)

# WRN always trains with a base learning rate of 0.1 and ends at lr/20.
# Every other architecture uses --lr and ends at lr/100.  Previously only
# 'resnet18' and 'WRN' defined a schedule, so 'smallcnn' and 'preactresnet18'
# failed with a NameError on the first training step; they now share the
# resnet18 schedule.
if args.arch == 'WRN':
    args.lr = 0.1
    _final_lr_divisor = 20
else:
    _final_lr_divisor = 100


def adjust_learning_rate(epoch):
    """Return the learning rate to use at ``epoch``.

    The rate is piecewise linear in the epoch: constant at ``args.lr`` until
    ``args.epochs // 2``, falling to ``args.lr / 10`` at
    ``args.epochs * 3 // 4`` and to ``args.lr / _final_lr_divisor`` at
    ``args.epochs``.
    """
    milestones = [0, args.epochs // 2, args.epochs * 3 // 4, args.epochs]
    rates = [args.lr, args.lr, args.lr / 10, args.lr / _final_lr_divisor]
    return np.interp([epoch], milestones, rates)[0]

def TRADES_loss(adv_logits, natural_logits, target, beta=6.0):
    """Return the TRADES objective for one batch.

    Based on the TRADES repository: https://github.com/yaodongyu/TRADES

    The loss is the mean cross-entropy of ``natural_logits`` against ``target``
    plus ``beta`` times the KL divergence between the natural and adversarial
    softmax distributions, summed over the batch and divided by
    ``len(target)``.

    Args:
        adv_logits: logits for the adversarial inputs; softmax is taken over dim 1.
        natural_logits: logits for the clean inputs, same shape as ``adv_logits``.
        target: class indices, one per sample.
        beta: weight of the KL (robustness) term.

    Returns:
        A scalar tensor.
    """
    batch_size = len(target)
    loss_natural = F.cross_entropy(natural_logits, target, reduction='mean')
    # reduction='sum' is the supported spelling of the deprecated
    # size_average=False; the functional form also needs no device placement.
    kl_sum = F.kl_div(
        F.log_softmax(adv_logits, dim=1),
        F.softmax(natural_logits, dim=1),
        reduction='sum',
    )
    loss_robust = (1.0 / batch_size) * kl_sum
    return loss_natural + beta * loss_robust

def NMSE_loss(adv_logits, natural_logits, target, beta=4):
    """Weighted squared distance between L2-normalised logits.

    Each row of ``adv_logits`` and ``natural_logits`` is scaled to unit L2
    norm; the per-sample squared Euclidean distance between the two is
    multiplied by ``1 - p``, where ``p`` is the softmax probability that
    ``natural_logits`` assigns to the sample's target class.  The weighted
    distances are averaged over the batch and multiplied by ``beta``.

    Args:
        adv_logits: logits for the adversarial inputs, classes along dim 1.
        natural_logits: logits for the clean inputs, same shape.
        target: integer class indices, one per sample.
        beta: overall weight of the loss.

    Returns:
        A scalar tensor.
    """
    # F.normalize divides by max(norm, eps), so an all-zero logit row yields a
    # zero vector instead of the NaN produced by a plain division by the norm.
    adv_unit = F.normalize(adv_logits, p=2, dim=1)
    natural_unit = F.normalize(natural_logits, p=2, dim=1)

    # Probability of the true class for every sample, gathered in one call
    # rather than one Python-level indexing operation per sample.
    predict = F.softmax(natural_logits, dim=1)
    true_class_prob = predict.gather(1, target.view(-1, 1).long()).squeeze(1)
    sample_weight = 1 - true_class_prob

    squared_distance = (adv_unit - natural_unit).pow(2).sum(dim=1)
    return beta * (squared_distance * sample_weight).mean()

def train(epoch, PIAT_model, Attackers, teacher_optimizer,device, descrip_str):
    """Run one epoch of adversarial training, then interpolate the weights.

    For every batch of ``train_loader`` the function crafts adversarial
    examples with the 'PGD_10' attack, minimises the cross-entropy on the
    adversarial logits plus ``NMSE_loss(..., beta=4)`` between adversarial and
    clean logits, and takes one optimizer step.  After the last batch the
    trained weights are blended with a snapshot of the model taken before the
    epoch (``PIAT_model.update_params``) and the result is loaded back into
    ``PIAT_model.model``.

    Args:
        epoch: current epoch; passed to ``adjust_learning_rate``.
        PIAT_model: ``PIAT`` holder whose ``.model`` is trained.
        Attackers: object providing ``run_specified``.
        teacher_optimizer: optimizer over ``PIAT_model.model``'s parameters;
            only its first parameter group's learning rate is updated.
        device: device the batches are moved to.
        descrip_str: text shown on the progress bar.
    """
    adv_losses = AverageMeter()
    adv_losses1 = AverageMeter()
    adv_clean_accuracy = AverageMeter()
    adv_adv_accuracy = AverageMeter()

    pbar = tqdm(train_loader)
    pbar.set_description(descrip_str)

    # Snapshot of the weights before this epoch; used for the interpolation
    # once the epoch is finished.
    copy_model = copy.deepcopy(PIAT_model.model)

    # The learning rate depends only on the epoch, so set it once instead of
    # recomputing it (twice) for every batch.
    lr = adjust_learning_rate(epoch)
    teacher_optimizer.param_groups[0].update(lr=lr)
    criterion = nn.CrossEntropyLoss()

    for batch_idx, (inputs, target) in enumerate(pbar):
        inputs, target = inputs.to(device), target.to(device)

        PIAT_model.model.train()
        x_adv = Attackers.run_specified('PGD_10', PIAT_model.model, inputs, target, return_acc=False)
        # Cut any graph left over from the attack; no gradient with respect to
        # the adversarial input is needed for the weight update.
        x_adv = x_adv.detach()

        # NOTE(review): the attack may leave the model in eval mode and may
        # accumulate parameter gradients — restore train mode and clear the
        # gradients right before the training forward pass.
        PIAT_model.model.train()
        teacher_optimizer.zero_grad()

        logit = PIAT_model.model(x_adv)
        nat_logit = PIAT_model.model(inputs)
        loss1 = NMSE_loss(logit, nat_logit, target, beta=4)
        loss = criterion(logit, target) + loss1
        loss.backward()
        teacher_optimizer.step()

        adv_losses.update(loss.item())
        adv_losses1.update(loss1.item())
        adv_clean_accuracy.update(torch_accuracy(nat_logit, target, (1,))[0].item())
        adv_adv_accuracy.update(torch_accuracy(logit, target, (1,))[0].item())

        # Running averages shown next to the progress bar.
        pbar.set_postfix({
            'adv_losses': '{:.2f}'.format(adv_losses.mean),
            'adv_losses1': '{:.4f}'.format(adv_losses1.mean),
            'adv_clean_accuracy': '{:.2f}'.format(adv_clean_accuracy.mean),
            'adv_adv_accuracy': '{:.2f}'.format(adv_adv_accuracy.mean),
        })

    # Blend the freshly trained weights with the pre-epoch snapshot and make
    # the blended weights the starting point of the next epoch.
    PIAT_model.update_params(copy_model)
    PIAT_model.update_parameter()

def test(PIAT_model, Attackers, device):
    """Evaluate ``PIAT_model.model`` on ``test_loader``.

    Every batch is scored twice through ``Attackers.run_specified``: once with
    the 'NAT' setting and once with 'PGD_20' (category 'Madry').  The first
    element of each result is accumulated, weighted by the batch size.

    Returns:
        ``(clean_mean, adversarial_mean)`` — the running means of the two
        meters after the last batch.
    """
    network = PIAT_model.model
    network.eval()

    piat_clean_accuracy = AverageMeter()
    piat_adv_accuracy = AverageMeter()

    pbar = tqdm(test_loader)
    pbar.set_description('Testing')

    for inputs, target in pbar:
        inputs = inputs.to(device)
        target = target.to(device)
        n_samples = inputs.size(0)

        clean_result = Attackers.run_specified(
            'NAT', network, inputs, target, return_acc=True)
        adv_result = Attackers.run_specified(
            'PGD_20', network, inputs, target, category='Madry', return_acc=True)

        piat_clean_accuracy.update(clean_result[0].item(), n_samples)
        piat_adv_accuracy.update(adv_result[0].item(), n_samples)

        # Show the running means next to the progress bar.
        pbar.set_postfix({
            'piat_cleanAcc': '{:.2f}'.format(piat_clean_accuracy.mean),
            'ema_advAcc': '{:.2f}'.format(piat_adv_accuracy.mean),
        })

    return piat_clean_accuracy.mean, piat_adv_accuracy.mean

def warmup_train(epoch, PIAT_model, PIAT_optimizer,device, descrip_str):
    """Train ``PIAT_model.model`` for one epoch on clean inputs only.

    Each batch of ``train_loader`` is used for one optimizer step on the plain
    cross-entropy loss; no adversarial examples are generated and no weight
    interpolation is performed.

    Args:
        epoch: current epoch; passed to ``adjust_learning_rate``.
        PIAT_model: ``PIAT`` holder whose ``.model`` is trained.
        PIAT_optimizer: optimizer over the model's parameters; only its first
            parameter group's learning rate is updated.
        device: device the batches are moved to.
        descrip_str: text shown on the progress bar.
    """
    losses = AverageMeter()
    clean_accuracy = AverageMeter()

    pbar = tqdm(train_loader)
    pbar.set_description(descrip_str)

    # None of the following changes within an epoch, so do it once rather than
    # on every batch.
    lr = adjust_learning_rate(epoch)
    PIAT_optimizer.param_groups[0].update(lr=lr)
    criterion = nn.CrossEntropyLoss()
    PIAT_model.model.train()

    for batch_idx, (inputs, target) in enumerate(pbar):
        inputs, target = inputs.to(device), target.to(device)

        PIAT_optimizer.zero_grad()
        nat_logit = PIAT_model.model(inputs)
        loss = criterion(nat_logit, target)
        loss.backward()
        PIAT_optimizer.step()

        losses.update(loss.item())
        clean_accuracy.update(torch_accuracy(nat_logit, target, (1,))[0].item())

        # Running averages shown next to the progress bar.
        pbar.set_postfix({
            'losses': '{:.2f}'.format(losses.mean),
            'clean_accuracy': '{:.2f}'.format(clean_accuracy.mean),
        })

def attack(model, Attackers, device):
    """Evaluate ``model`` on ``test_loader`` under every attack of ``Attackers``.

    ``Attackers.run_all`` is called once per batch and is expected to return a
    mapping with the keys listed in ``attack_labels``; the first element of
    each entry is accumulated, weighted by the batch size.

    Returns:
        A list of the eleven running means, in the order clean, PGD-20,
        PGD-100, MIM, CW, APGD_ce, APGD_dlr, APGD_t, FAB_t, Square, AA.
    """
    model.eval()

    # (key in the mapping returned by run_all, label shown on the progress bar)
    attack_labels = [
        ('NAT', 'clean'),
        ('PGD_20', 'PGD20'),
        ('PGD_100', 'PGD100'),
        ('MIM', 'MIM'),
        ('CW', 'CW'),
        ('APGD_ce', 'APGD_ce'),
        ('APGD_dlr', 'APGD_dlr'),
        ('APGD_t', 'APGD_t'),
        ('FAB_t', 'FAB_t'),
        ('Square', 'Square'),
        ('AA', 'AA'),
    ]
    # One running-average meter per attack, in the same order as the labels.
    meters = [AverageMeter() for _ in attack_labels]

    pbar = tqdm(test_loader)
    pbar.set_description('Attacking all')

    for inputs, targets in pbar:
        inputs, targets = inputs.to(device), targets.to(device)
        n_samples = inputs.size(0)

        acc_dict = Attackers.run_all(model, inputs, targets)

        for (result_key, _), meter in zip(attack_labels, meters):
            meter.update(acc_dict[result_key][0].item(), n_samples)

        pbar.set_postfix({
            label: '{:.2f}'.format(meter.mean)
            for (_, label), meter in zip(attack_labels, meters)
        })

    return [meter.mean for meter in meters]


def main():
    """Train (unless resuming), then attack the best checkpoint and log results.

    Without ``--resume`` the model is first trained for 10 epochs on clean
    data, then for ``args.epochs`` epochs with ``train``; after every epoch it
    is evaluated with ``test`` and the weights with the highest PGD-20 value
    are saved as ``piat_bestpoint.pth.tar`` in ``args.out_dir``.  In both modes
    that checkpoint is then loaded, evaluated with ``attack`` and the results
    are appended to the log.

    Raises:
        ValueError: if ``args.arch`` is not one of the supported names.
    """
    start_epoch = 1
    best_piat_acc_adv = 0

    if args.arch == "smallcnn":
        model = SmallCNN()
    elif args.arch == "resnet18":
        model = ResNet18(num_classes=args.num_classes)
    elif args.arch == "preactresnet18":
        model = PreActResNet18(num_classes=args.num_classes)
    elif args.arch == "WRN":
        model = Wide_ResNet_Madry(depth=32, num_classes=args.num_classes, widen_factor=10, dropRate=0.0)
    else:
        # Fail here with a clear message instead of a NameError on `model` below.
        raise ValueError('unknown architecture: {!r}'.format(args.arch))

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    PIAT_model = PIAT(model)
    Attackers = AttackerPolymer(args.epsilon, args.num_steps, args.step_size, args.num_classes, device)

    if not args.resume:
        PIAT_optimizer = optim.SGD(PIAT_model.model.parameters(), lr=args.lr, momentum=0.9, weight_decay=weight_decay)

        logger_test = Logger(os.path.join(args.out_dir, 'log_results.txt'), title='reweight')
        logger_test.set_names(['Epoch', 'piat_nat_acc', 'piat_pgd20_acc'])

        # Warm-up: 10 epochs of training on clean inputs only.
        for epoch in range(start_epoch,11):
            descrip_str = 'Training epoch:{}/{}'.format(epoch, 10)
            warmup_train(epoch, PIAT_model, PIAT_optimizer, device, descrip_str)

        for epoch in range(start_epoch, args.epochs+1):
            descrip_str = 'Training epoch:{}/{}'.format(epoch, args.epochs)

            train(epoch, PIAT_model, Attackers, PIAT_optimizer, device, descrip_str)
            piat_nat_acc, piat_pgd20_acc = test(PIAT_model, Attackers, device=device)
            logger_test.append([epoch, piat_nat_acc, piat_pgd20_acc])

            # Keep only the weights with the best PGD-20 result so far.
            if piat_pgd20_acc > best_piat_acc_adv:
                print('==> Updating the teacher model..')
                best_piat_acc_adv = piat_pgd20_acc
                torch.save(PIAT_model.model.state_dict(), os.path.join(args.out_dir, 'piat_bestpoint.pth.tar'))

            # # Save the last checkpoint
            # torch.save(model.state_dict(), os.path.join(args.out_dir, 'lastpoint.pth.tar'))
    else:
        # Resume mode skips training, so the logger used below was never
        # created and the run ended in a NameError after the attacks finished.
        # NOTE(review): a separate file is used so an existing training log is
        # not overwritten — Logger's file-open mode is not visible here.
        logger_test = Logger(os.path.join(args.out_dir, 'log_attack_results.txt'), title='reweight')

    # map_location lets a checkpoint saved on a GPU be evaluated on whichever
    # device is available now.
    checkpoint_path = os.path.join(args.out_dir, 'piat_bestpoint.pth.tar')
    PIAT_model.model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    res_list = attack(PIAT_model.model, Attackers, device)

    logger_test.set_names(['Epoch', 'clean', 'PGD20', 'PGD100', 'MIM', 'CW', 'APGD_ce', 'APGD_dlr', 'APGD_t', 'FAB_t', 'Square', 'AA'])
    # 1000000 is a sentinel "epoch" marking the final attack row; attack()
    # returns the eleven values in the column order set above.
    logger_test.append([1000000, *res_list])

    logger_test.close()

if __name__ == '__main__':
    main()