import os
import pdb
import sys 
import torch
import pickle
import argparse
import torch.optim
import torch.nn as nn
import torch.utils.data
import matplotlib.pyplot as plt
import torchvision.models as models
from advertorch.utils import NormalizeByChannelMeanStd

from models.resnetv2 import ResNet18
from models.wideresnet import WideResNet
from models.vgg import VGG
from datasets import *
import utils

# Command-line interface for the adversarial-training script.
# Argument names, types and defaults are part of the script's interface; only
# the help texts and section headers below describe them.
parser = argparse.ArgumentParser(description='PyTorch Cifar10 Training')

########################## base setting ##########################
parser.add_argument('--data', type=str, default='../data', help='location of the data corpus')
parser.add_argument('--dataset', type=str, default='cifar10', help='dataset')
# 'resnet18' and 'wideresnet' are matched by name in main(); any other value is passed to VGG.
parser.add_argument('--arch', type=str, default='resnet18', help='model architecture: resnet18, wideresnet, or a VGG configuration name')
parser.add_argument('--print_freq', default=50, type=int, help='print frequency')
parser.add_argument('--gpu', type=int, default=0, help='gpu device id')
parser.add_argument('--save_dir', help='The directory used to save the trained models', default='adv', type=str)
parser.add_argument('--resume', action="store_true", help="resume from checkpoint")
parser.add_argument('--seed', default=None, type=int, help='random seed')
parser.add_argument('--width_factor', default=10, type=int, help='width-factor of wideresnet')
parser.add_argument('--save_all', action='store_true', help='whether to save all checkpoint')
parser.add_argument('--result_file', help='file name for result', default='result.pkl', type=str)

########################## training setting ##########################
parser.add_argument('--batch_size', type=int, default=128, help='batch size')
parser.add_argument('--lr', default=0.1, type=float, help='initial learning rate')
parser.add_argument('--decreasing_lr', default='50,150', help='decreasing strategy')
parser.add_argument('--momentum', default=0.9, type=float, help='momentum')
parser.add_argument('--weight_decay', default=5e-4, type=float, help='weight decay')
parser.add_argument('--epochs', default=200, type=int, help='number of total epochs to run')


########################## attack setting ##########################
# eps/gamma values are divided by 255 in main() before use.
parser.add_argument('--norm', default='linf', type=str, help='linf or l2')
parser.add_argument('--train_eps', default=8, type=float, help='train_eps')
parser.add_argument('--train_step', default=10, type=int, help='train_steps')
parser.add_argument('--train_gamma', default=2, type=float, help='train_gamma')
# store_false: the value is True unless the flag is given on the command line.
parser.add_argument('--train_randinit', action='store_false', help='randinit usage flag; passing this flag turns it off (default: on)')

parser.add_argument('--test_eps', default=8, type=float, help='test_eps')
parser.add_argument('--test_step', default=20, type=int, help='test_step')
parser.add_argument('--test_gamma', default=2, type=float, help='test_gamma')
parser.add_argument('--test_randinit', action='store_false', help='randinit usage flag; passing this flag turns it off (default: on)')

########################## SWA setting ##########################
parser.add_argument('--swa', action='store_true', help='swa usage flag (default: off)')
parser.add_argument('--swa_start', type=float, default=55, metavar='N', help='SWA start epoch number (default: 55)')
parser.add_argument('--swa_c_epochs', type=int, default=1, metavar='N', help='SWA model collection frequency/cycle length in epochs (default: 1)')

########################## Lwf setting ##########################
parser.add_argument('--lwf', action='store_true', help='lwf usage flag (default: off)')
parser.add_argument('--t_weight1', type=str, default=None, required=False, help='teacher model1')
parser.add_argument('--t_weight2', type=str, default=None, required=False, help='teacher model2')
parser.add_argument('--coef1', type=float, default=0.3, help='coef for XE')
parser.add_argument('--coef2', type=float, default=0.1, help='coef for KD1')
parser.add_argument('--coef3', type=float, default=0.6, help='coef for KD2')
# LWF training is used for epochs in the half-open range [lwf_start, lwf_end).
parser.add_argument('--lwf_start', type=int, default=0, metavar='N', help='first epoch trained with lwf (default: 0)')
parser.add_argument('--lwf_end', type=int, default=200, metavar='N', help='epoch at which lwf training stops, exclusive (default: 200)')


# Best validation accuracies seen so far; main() updates these through `global`.
best_prec1 = 0       # best standard accuracy of the trained model
best_ata = 0         # best adversarial accuracy of the trained model
best_prec1_swa = 0   # best standard accuracy of the SWA model
best_ata_swa = 0     # best adversarial accuracy of the SWA model

# Per-dataset settings: the startup message, the number of output classes and
# the per-channel mean/std handed to NormalizeByChannelMeanStd.
_DATASET_CONFIGS = {
    'cifar10': {
        'banner': 'training on cifar10 dataset',
        'num_classes': 10,
        'mean': [0.4914, 0.4822, 0.4465],
        'std': [0.2023, 0.1994, 0.2010],
    },
    'cifar100': {
        'banner': 'training on cifar100 dataset',
        'num_classes': 100,
        'mean': [0.5071, 0.4866, 0.4409],
        'std': [0.2009, 0.1984, 0.2023],
    },
    'tinyimagenet': {
        'banner': 'training on tiny-imagenet dataset',
        'num_classes': 200,
        'mean': [0.4802, 0.4481, 0.3975],
        'std': [0.2302, 0.2265, 0.2262],
    },
}


def _build_network(arch, num_classes, width_factor, mean, std):
    """Create one network of architecture *arch* and attach its normalization layer as ``.normal``."""
    if arch == 'resnet18':
        net = ResNet18(num_classes=num_classes)
    elif arch == 'wideresnet':
        net = WideResNet(34, num_classes, widen_factor=width_factor, dropRate=0.0)
    else:
        # Any other value is treated as a VGG configuration name.
        net = VGG(arch, num_classes=num_classes)

    # Copy the lists so networks never share the config's list objects.
    net.normal = NormalizeByChannelMeanStd(mean=list(mean), std=list(std))
    return net


def _make_dataloaders(dataset, batch_size, data_dir):
    """Return the (train, val, test) loaders for *dataset*; train and test use the same batch size."""
    # Resolved with if/elif so only the requested loader factory has to exist.
    if dataset == 'cifar10':
        factory = cifar10_dataloaders
    elif dataset == 'cifar100':
        factory = cifar100_dataloaders
    else:
        factory = tiny_imagenet_dataloaders
    return factory(train_batch_size=batch_size, test_batch_size=batch_size, data_dir=data_dir)


def main():
    """Run adversarial training as configured by the command line.

    Parses the arguments, builds the model (plus an SWA copy with ``--swa`` and
    two teacher networks with ``--lwf``), optionally resumes from
    ``<save_dir>/checkpoint.pt``, then trains for the remaining epochs.  After
    every epoch the model is evaluated on the validation and test loaders
    (clean and adversarial), checkpoints are written through
    ``utils.save_checkpoint``, the accuracy curves are plotted to
    ``<save_dir>/net_train.png`` and the accuracy history is pickled to
    ``<save_dir>/<result_file>``.

    Raises:
        ValueError: if ``--dataset`` is not one of the supported datasets, or
            ``--lwf`` is given without both ``--t_weight1`` and ``--t_weight2``.
    """
    global args, best_prec1, best_ata, best_prec1_swa, best_ata_swa
    args = parser.parse_args()

    # Attack budgets and step sizes are divided by 255 before use
    # (presumably given on the 0-255 pixel scale -- confirm against utils).
    args.train_eps = args.train_eps / 255
    args.train_gamma = args.train_gamma / 255
    args.test_eps = args.test_eps / 255
    args.test_gamma = args.test_gamma / 255

    print(args)

    # Fail fast on configuration errors instead of crashing later with an
    # unrelated unbound-variable or torch.load error.
    if args.dataset not in _DATASET_CONFIGS:
        raise ValueError('dataset not support: {!r} (expected one of {})'.format(
            args.dataset, sorted(_DATASET_CONFIGS)))
    if args.lwf and (args.t_weight1 is None or args.t_weight2 is None):
        raise ValueError('--lwf requires both --t_weight1 and --t_weight2')

    torch.cuda.set_device(int(args.gpu))
    device = torch.device('cuda:' + str(args.gpu))

    # `is not None` so that an explicit --seed 0 is honoured.
    if args.seed is not None:
        print('set random seed = ', args.seed)
        utils.setup_seed(args.seed)

    ########################## prepare dataset ##########################
    config = _DATASET_CONFIGS[args.dataset]
    print(config['banner'])

    net_spec = (args.arch, config['num_classes'], args.width_factor, config['mean'], config['std'])

    # Construction order (model, SWA model, teachers, then loaders) is kept
    # fixed so a given seed yields the same initial weights.
    model = _build_network(*net_spec)

    swa_model, swa_n = None, 0
    if args.swa:
        swa_model = _build_network(*net_spec)

    teacher1 = teacher2 = None
    if args.lwf:
        teacher1 = _build_network(*net_spec)
        teacher2 = _build_network(*net_spec)

    train_loader, val_loader, test_loader = _make_dataloaders(args.dataset, args.batch_size, args.data)

    if args.swa:
        swa_model.cuda()
    if args.lwf:
        teacher1.cuda()
        teacher2.cuda()

    model.cuda()

    ########################## optimizer and scheduler ##########################
    decreasing_lr = list(map(int, args.decreasing_lr.split(',')))
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=decreasing_lr, gamma=0.1)

    ########################## resume ##########################
    start_epoch = 0
    if args.resume:
        print('resume from checkpoint')
        checkpoint = torch.load(os.path.join(args.save_dir, 'checkpoint.pt'), map_location=device)
        best_prec1 = checkpoint['best_prec1']
        best_ata = checkpoint['best_ata']
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])

        if args.swa:
            best_prec1_swa = checkpoint['best_prec1_swa']
            best_ata_swa = checkpoint['best_ata_swa']
            swa_model.load_state_dict(checkpoint['swa_state_dict'])
            swa_n = checkpoint['swa_n']

    ########################## loading teacher model weight ##########################
    if args.lwf:
        print('loading teacher model')
        t1_checkpoint = torch.load(args.t_weight1, map_location=device)
        teacher1.load_state_dict(t1_checkpoint)
        t2_checkpoint = torch.load(args.t_weight2, map_location=device)
        teacher2.load_state_dict(t2_checkpoint)

        print('test for teacher1')
        utils.validate(test_loader, teacher1, criterion)
        utils.validate_adv(test_loader, teacher1, criterion, args)
        print('test for teacher2')
        utils.validate(test_loader, teacher2, criterion)
        utils.validate_adv(test_loader, teacher2, criterion, args)

    ########################## record result ##########################
    # NOTE: the histories start empty even when resuming, and the result file
    # is rewritten each epoch, so a resumed run only records its own epochs.
    all_result = {}
    all_train_acc = []
    all_val_sa = []
    all_val_ra = []
    all_test_sa = []
    all_test_ra = []

    all_val_sa_swa = []
    all_val_ra_swa = []
    all_test_sa_swa = []
    all_test_ra_swa = []

    os.makedirs(args.save_dir, exist_ok=True)
    result_path = os.path.join(args.save_dir, args.result_file)

    def save_checkpoint(epoch, sa_best, ra_best, swa, **extra):
        """Save the current training state; *extra* is forwarded to utils.save_checkpoint."""
        utils.save_checkpoint(
            args.save_dir,
            epoch + 1,
            sa_best=sa_best,
            ra_best=ra_best,
            swa=swa,
            state_dict=model.state_dict(),
            swa_state_dict=swa_model.state_dict() if args.swa else None,
            swa_n=swa_n if args.swa else None,
            best_prec1=best_prec1,
            best_ata=best_ata,
            best_prec1_swa=best_prec1_swa,
            best_ata_swa=best_ata_swa,
            optimizer=optimizer.state_dict(),
            scheduler=scheduler.state_dict(),
            **extra
        )

    ########################## training process ##########################
    for epoch in range(start_epoch, args.epochs):

        print(optimizer.state_dict()['param_groups'][0]['lr'])

        if args.lwf and args.lwf_start <= epoch < args.lwf_end:
            print('adversarial training with LWF')
            train_loss, train_acc = utils.train_epoch_adv_teacher2(
                train_loader, model, teacher1, teacher2, criterion, optimizer, epoch, args)
        else:
            print('baseline adversarial training')
            train_loss, train_acc = utils.train_epoch_adv(
                train_loader, model, criterion, optimizer, epoch, args)

        all_train_acc.append(train_acc)
        scheduler.step()

        ###validation###
        val_loss, val_sa = utils.validate(val_loader, model, criterion)
        val_loss_adv, val_ra = utils.validate_adv(val_loader, model, criterion, args)
        test_loss, test_sa = utils.validate(test_loader, model, criterion)
        test_loss_adv, test_ra = utils.validate_adv(test_loader, model, criterion, args)

        all_val_sa.append(val_sa)
        all_val_ra.append(val_ra)
        all_test_sa.append(test_sa)
        all_test_ra.append(test_ra)

        swa_active = (
            args.swa
            and epoch >= args.swa_start
            and (epoch - args.swa_start) % args.swa_c_epochs == 0
        )

        if swa_active:
            # SWA: fold the current weights into the average, then refresh
            # the averaged model's batch-norm state on the training data.
            utils.moving_average(swa_model, model, 1.0 / (swa_n + 1))
            swa_n += 1
            utils.bn_update(train_loader, swa_model)

            val_loss_swa, val_sa_swa = utils.validate(val_loader, swa_model, criterion)
            val_loss_adv_swa, val_ra_swa = utils.validate_adv(val_loader, swa_model, criterion, args)
            test_loss_swa, test_sa_swa = utils.validate(test_loader, swa_model, criterion)
            test_loss_adv_swa, test_ra_swa = utils.validate_adv(test_loader, swa_model, criterion, args)

            all_val_sa_swa.append(val_sa_swa)
            all_val_ra_swa.append(val_ra_swa)
            all_test_sa_swa.append(test_sa_swa)
            all_test_ra_swa.append(test_ra_swa)

            # Both running bests are updated before either checkpoint is
            # written, so each saved file carries the up-to-date values.
            is_sa_best_swa = val_sa_swa > best_prec1_swa
            best_prec1_swa = max(val_sa_swa, best_prec1_swa)

            is_ra_best_swa = val_ra_swa > best_ata_swa
            best_ata_swa = max(val_ra_swa, best_ata_swa)

            if is_sa_best_swa:
                save_checkpoint(epoch, sa_best=True, ra_best=False, swa=True)

            if is_ra_best_swa:
                save_checkpoint(epoch, sa_best=False, ra_best=True, swa=True)

        else:
            # No SWA evaluation this epoch: record the plain model's numbers
            # so the SWA histories stay aligned with the per-epoch histories.
            all_val_sa_swa.append(val_sa)
            all_val_ra_swa.append(val_ra)
            all_test_sa_swa.append(test_sa)
            all_test_ra_swa.append(test_ra)

        is_sa_best = val_sa > best_prec1
        best_prec1 = max(val_sa, best_prec1)

        is_ra_best = val_ra > best_ata
        best_ata = max(val_ra, best_ata)

        if is_sa_best:
            save_checkpoint(epoch, sa_best=True, ra_best=False, swa=False)

        if is_ra_best:
            save_checkpoint(epoch, sa_best=False, ra_best=True, swa=False)

        # Regular end-of-epoch checkpoint.
        save_checkpoint(epoch, sa_best=False, ra_best=False, swa=False)

        if args.save_all:
            # Same state again, with inplace=False passed to utils.save_checkpoint
            # (presumably writes a separate per-epoch file -- confirm in utils).
            save_checkpoint(epoch, sa_best=False, ra_best=False, swa=False, inplace=False)

        plt.plot(all_train_acc, label='train_acc')
        plt.plot(all_test_sa, label='SA')
        plt.plot(all_test_ra, label='RA')

        if args.swa:
            plt.plot(all_test_sa_swa, label='SWA_SA')
            plt.plot(all_test_ra_swa, label='SWA_RA')

        plt.legend()
        plt.savefig(os.path.join(args.save_dir, 'net_train.png'))
        plt.close()

        all_result['train'] = all_train_acc
        all_result['test_sa'] = all_test_sa
        all_result['test_ra'] = all_test_ra
        all_result['val_sa'] = all_val_sa
        all_result['val_ra'] = all_val_ra
        all_result['test_sa_swa'] = all_test_sa_swa
        all_result['test_ra_swa'] = all_test_ra_swa
        all_result['val_sa_swa'] = all_val_sa_swa
        all_result['val_ra_swa'] = all_val_ra_swa

        # Context manager closes the handle every epoch instead of leaking it.
        with open(result_path, 'wb') as result_fh:
            pickle.dump(all_result, result_fh)


# Start training only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()


