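'''
Adversarial training entry point, with optional sparse training (static or
dynamic masks), SWA, dual-teacher knowledge distillation (LWF), consistency
regularization, and RSLAD.
'''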
import os
import sys
from numpy.core.numeric import outer 
import torch
import pickle
import argparse
import torch.optim
import torch.nn as nn
import torch.utils.data
import matplotlib.pyplot as plt
import torchvision.models as models
from utils import *
from sparselearning.core import Masking, CosineDecay
from sparselearning.pruning_utils import check_sparsity


parser = argparse.ArgumentParser(description='PyTorch Adversarial Sparse Training')

########################## data setting ##########################
parser.add_argument('--data', type=str, default='data/cifar10', help='location of the data corpus', required=True)
parser.add_argument('--dataset', type=str, default='cifar10', help='dataset [cifar10, cifar100, tinyimagenet]', required=True)

########################## model setting ##########################
parser.add_argument('--arch', type=str, default='resnet18', help='model architecture [resnet18, wideresnet, vgg16]', required=True)
parser.add_argument('--depth_factor', default=34, type=int, help='depth-factor of wideresnet')
parser.add_argument('--width_factor', default=10, type=int, help='width-factor of wideresnet')

########################## basic setting ##########################
parser.add_argument('--seed', default=1, type=int, help='random seed')
parser.add_argument('--gpu', type=int, default=0, help='gpu device id')
parser.add_argument('--resume', action="store_true", help="resume from checkpoint")
parser.add_argument('--resume_dir', help='The directory resume the trained models', default=None, type=str)
parser.add_argument('--pretrained', default=None, type=str, help='pretrained model')
parser.add_argument('--eval', action="store_true", help="evaluation pretrained model")
parser.add_argument('--print_freq', default=50, type=int, help='logging frequency during training')
parser.add_argument('--save_dir', help='The parent directory used to save the trained models', default=None, type=str)

########################## training setting ##########################
parser.add_argument('--batch_size', type=int, default=128, help='batch size')
parser.add_argument('--lr', default=0.1, type=float, help='initial learning rate')
parser.add_argument('--decreasing_lr', default='100,150', help='comma-separated epochs at which the learning rate is decayed (MultiStepLR milestones)')
parser.add_argument('--momentum', default=0.9, type=float, help='momentum')
parser.add_argument('--weight_decay', default=5e-4, type=float, help='weight decay')
parser.add_argument('--epochs', default=200, type=int, help='number of total epochs to run')

########################## attack setting ##########################
# eps / gamma values are given in 1/255 units; main() divides them by 255.
parser.add_argument('--norm', default='linf', type=str, help='linf or l2')
parser.add_argument('--train_eps', default=8, type=float, help='epsilon of attack during training (in 1/255 units)')
parser.add_argument('--train_step', default=10, type=int, help='iteration number of attack during training')
parser.add_argument('--train_gamma', default=2, type=float, help='step size of attack during training (in 1/255 units)')
parser.add_argument('--train_randinit', action='store_false', help='randinit usage flag; passing it turns randinit off (default: on)')
parser.add_argument('--test_eps', default=8, type=float, help='epsilon of attack during testing (in 1/255 units)')
parser.add_argument('--test_step', default=20, type=int, help='iteration number of attack during testing')
parser.add_argument('--test_gamma', default=2, type=float, help='step size of attack during testing (in 1/255 units)')
parser.add_argument('--test_randinit', action='store_false', help='randinit usage flag; passing it turns randinit off (default: on)')

########################## SWA setting ##########################
parser.add_argument('--swa', action='store_true', help='swa usage flag (default: off)')
parser.add_argument('--swa_start', type=float, default=55, metavar='N', help='SWA start epoch number (default: 55)')
parser.add_argument('--swa_c_epochs', type=int, default=1, metavar='N', help='SWA model collection frequency/cycle length in epochs (default: 1)')

########################## KD setting ##########################
parser.add_argument('--lwf', action='store_true', help='lwf usage flag (default: off)')
parser.add_argument('--t_weight1', type=str, default=None, required=False, help='pretrained weight for teacher1')
parser.add_argument('--t_weight2', type=str, default=None, required=False, help='pretrained weight for teacher2')
parser.add_argument('--coef_ce', type=float, default=0.3, help='coef for CE')
parser.add_argument('--coef_kd1', type=float, default=0.1, help='coef for KD1')
parser.add_argument('--coef_kd2', type=float, default=0.6, help='coef for KD2')
parser.add_argument('--temperature', type=float, default=2.0, help='temperature of knowledge distillation loss')
parser.add_argument('--lwf_start', type=int, default=0, metavar='N', help='start epoch of lwf, inclusive (default: 0)')
parser.add_argument('--lwf_end', type=int, default=200, metavar='N', help='end epoch of lwf, exclusive (default: 200)')

########################## sparse setting ##########################
parser.add_argument('--decay-schedule', type=str, default='cosine', help='The decay schedule for the pruning rate. Default: cosine. Choose from: cosine, linear.')
parser.add_argument('--update_frequency', type=int, default=100, metavar='N', help='how many iterations to train between mask update')

parser.add_argument('--growth', type=str, default='random', help='Growth mode. Choose from: momentum, random, and momentum_neuron.')
parser.add_argument('--death', type=str, default='magnitude', help='Death mode / pruning mode. Choose from: magnitude, SET, threshold, CS_death.')
parser.add_argument('--redistribution', type=str, default='none', help='Redistribution mode. Choose from: momentum, magnitude, nonzeros, or none.')
parser.add_argument('--death-rate', type=float, default=0.50, help='The pruning rate / death rate.')
parser.add_argument('--density', type=float, default=0.05, help='The density of the overall sparse network.')

parser.add_argument('--sparse', action='store_true', help='Enable sparse mode. Default: False.')
parser.add_argument('--snip', action='store_true', help='Enable snip initialization. Default: False.')
parser.add_argument('--fix', action='store_true', help='Fix topology during training. Default: False.')
parser.add_argument('--sparse_init', type=str, default='uniform', help='sparse initialization, e.g. uniform, SynFlow, snip, grasp, custom')
parser.add_argument('--reset', action='store_true', help='Reset flag. Default: False.')

########################## static sparse setting ##########################
parser.add_argument('--static_sparse', action='store_true', help='Enable static sparse mode. Default: False.')
parser.add_argument('--sparse_type', type=str, default='rp', help='static sparse mask initialization. choose from: rp omp gmp tp snip')
parser.add_argument('--custom_mask', type=str, default=None, help='path of a saved mask, loaded when --sparse_init is custom')

########################## Dynamic  sparse ##########################
parser.add_argument('--dynamic_sparse', action='store_true', help='Enable dynamic sparse mode. Default: False.')
parser.add_argument('--epoch_range', type=int, default=4, help='epoch range to decide sparse action')
parser.add_argument('--prune_rate', type=float, default=0.4, help='The rate of DST pruning')
parser.add_argument('--growth_rate', type=float, default=0.05, help='The rate of DST growth')
parser.add_argument('--ratio_threshold', type=float, default=0.5, help='The ratio_threshold of dst prune or growth')
parser.add_argument('--dynamic_epoch', default=100, type=int, help='epoch passed to the mask as the dynamic sparse training start epoch')

########################## Small Dense Test #############################
parser.add_argument('--small_dense', action='store_true', help='Enable small dense mode. Default: False.')
parser.add_argument('--small_dense_rate', type=float, default=0.8, help='Density of the small dense model; supported values: 0.05 0.1 0.2 0.6 0.8')

########################## Dynamic frequency #############################
parser.add_argument('--dynamic_fre', action='store_true', help='Enable dynamic frequency mode. Default: False.')
parser.add_argument('--second_frequency', type=int, default=1200, metavar='N', help='how many iterations to train between mask update in second stage')

########################## Save epoch #############################
parser.add_argument('--save_epoch', action='store_true', help='save checkpoint for every epoch. Default: False.')

########################## Combine test #############################
parser.add_argument('--consistency', action='store_true', help='apply consistency regularization')
parser.add_argument('--rslad', action='store_true', help='apply RSLAD method')
parser.add_argument('--rslad_teacher', type=str, default=None, required=False, help='rslad teacher path')
parser.add_argument('--robust_friendly', action='store_true', help='apply robust friendly dataset')


parser.add_argument('--mask_dir', help='The directory of mask', default=None, type=str)






def _unwrap_state_dict(checkpoint):
    """Return ``checkpoint['state_dict']`` if that key exists, else ``checkpoint`` itself.

    Lets weight files saved either as a bare state dict or as a full training
    checkpoint (the ``checkpoint_state`` dict written by ``main``) be loaded the
    same way.
    """
    if 'state_dict' in checkpoint.keys():
        return checkpoint['state_dict']
    return checkpoint


def _new_result_dict(swa):
    """Return an empty metrics-history dict; the ``*_swa`` lists are added when ``swa`` is true."""
    result = {
        'train_acc': [],
        'val_sa': [],
        'val_ra': [],
        'test_sa': [],
        'test_ra': [],
        'g_norm': [],
        'rc_ratio': [],
        'grad_cs': [],
        'sparsity': 1.0,
        'total_fired_weights': 0,
        'best_ra_epoch': 0,
    }
    if swa:
        for key in ('val_sa_swa', 'val_ra_swa', 'test_sa_swa', 'test_ra_swa'):
            result[key] = []
    return result


def _setup_mask(model, optimizer, train_loader, args):
    """Build and attach a ``Masking`` object for --sparse / --dynamic_sparse runs.

    Returns:
        The configured ``Masking`` instance, or ``None`` when neither sparse mode is enabled.
    """
    if not (args.sparse or args.dynamic_sparse):
        return None

    decay = CosineDecay(args.death_rate, len(train_loader) * args.epochs)
    mask = Masking(optimizer, death_rate=args.death_rate, death_mode=args.death, death_rate_decay=decay,
                   growth_mode=args.growth, redistribution_mode=args.redistribution, args=args)

    # --sparse_init values that load a precomputed mask via get_mask(), keyed to get_mask's method name.
    precomputed_methods = {'SynFlow': 'synflow', 'snip': 'snip', 'grasp': 'grasp'}
    if args.sparse_init in precomputed_methods:
        if args.sparse_init == 'grasp':
            print('grasp')
        init_mask = get_mask(precomputed_methods[args.sparse_init], args.density, args.seed, args.mask_dir)
        mask.set_init_mask(init_mask)
    elif args.sparse_init == 'custom':
        print('custom mask')
        mask.set_init_mask(torch.load(args.custom_mask))

    mask.add_module(model, sparse_init=args.sparse_init, density=args.density)
    mask.set_dst_start_epoch(args.dynamic_epoch)
    return mask


def _plot_curves(all_result, args, mask, save_path):
    """Plot the accuracy histories (and DST prune/growth epochs) to ``<save_path>/net_train.png``."""
    plt.plot(all_result['train_acc'], label='train_acc')
    plt.plot(all_result['test_sa'], label='SA')
    plt.plot(all_result['test_ra'], label='RA')

    if args.swa:
        plt.plot(all_result['test_sa_swa'], label='SWA_SA')
        plt.plot(all_result['test_ra_swa'], label='SWA_RA')

    if args.dynamic_sparse:
        prune_epochs, growth_epochs = mask.get_dst_epochs()
        # Shift the markers by +/- delta so prune and growth lines at the same epoch don't overlap.
        delta = 0.4
        for prune_epoch in prune_epochs:
            plt.axvline(prune_epoch - delta, linewidth=0.8, color='black', linestyle='--')
        for growth_epoch in growth_epochs:
            plt.axvline(growth_epoch + delta, linewidth=0.8, color='red', linestyle='--')

    plt.legend()
    plt.savefig(os.path.join(save_path, 'net_train.png'))
    plt.close()


def main():
    """Parse the CLI and run adversarial training (or evaluation only with --eval).

    Flow:
      1. Convert the attack eps/gamma from 1/255 units, set up save dir, logging and seed.
      2. Build loaders and models, an SGD optimizer and a MultiStepLR scheduler
         (milestones from --decreasing_lr).
      3. Optionally attach a sparse ``Masking`` object and/or apply a static sparse mask.
      4. With --eval, load --pretrained, report clean/adversarial test accuracy, and return.
      5. Optionally load the two LWF teachers and resume from ``<resume_dir>/checkpoint.pth.tar``.
      6. For each epoch, train, evaluate on val/test, optionally update the SWA model,
         write checkpoints, and redraw the training curves.

    Raises:
        ValueError: if --eval is given without --pretrained.
    """
    args = parser.parse_args()
    # Attack budgets and step sizes are given on the CLI in 1/255 units.
    args.train_eps = args.train_eps / 255
    args.train_gamma = args.train_gamma / 255
    args.test_eps = args.test_eps / 255
    args.test_gamma = args.test_gamma / 255

    save_path = get_save_path(args)
    os.makedirs(save_path, exist_ok=True)
    setup_logger(args)
    print_args(args)
    print_and_log(args)

    torch.cuda.set_device(int(args.gpu))
    device = torch.device('cuda:' + str(args.gpu))

    if args.seed:
        print('set random seed = ', args.seed)
        setup_seed(args.seed)

    train_loader, val_loader, test_loader, model, swa_model, teacher1, teacher2 = setup_dataset_models(args)

    if args.swa:
        swa_model.cuda()
    if args.lwf:
        teacher1.cuda()
        teacher2.cuda()
    model.cuda()

    ########################## optimizer and scheduler ##########################
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    milestones = [int(milestone) for milestone in args.decreasing_lr.split(',')]
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, last_epoch=-1)

    ########################## sparse mask ####################################
    mask = _setup_mask(model, optimizer, train_loader, args)

    # GMP is applied gradually inside the epoch loop rather than up front.
    if args.static_sparse and args.sparse_type != 'gmp':
        apply_static_sparse(model, optimizer, args.sparse_type, args.density, args.seed, args.mask_dir)

    ######################### only evaluation ###################################
    if args.eval:
        if not args.pretrained:
            raise ValueError('--eval requires --pretrained')
        checkpoint = torch.load(args.pretrained, map_location=device)
        if args.swa:
            print_and_log('loading from swa_state_dict')
            weights = checkpoint['swa_state_dict']
        else:
            print_and_log('loading from state_dict')
            weights = _unwrap_state_dict(checkpoint)
        model.load_state_dict(weights)
        test(test_loader, model, criterion, args)
        test_adv(test_loader, model, criterion, args)
        return

    ########################## loading teacher model weight ##########################
    if args.lwf:
        print_and_log('loading teacher model')
        for teacher, weight_path in ((teacher1, args.t_weight1), (teacher2, args.t_weight2)):
            teacher.load_state_dict(_unwrap_state_dict(torch.load(weight_path, map_location=device)))

        for teacher_name, teacher in (('teacher1', teacher1), ('teacher2', teacher2)):
            print_and_log('test for ' + teacher_name)
            test(test_loader, teacher, criterion, args)
            test_adv(test_loader, teacher, criterion, args)

    ########################## resume ##########################
    start_epoch = 0
    swa_n = 0
    best_sa_swa = 0
    best_ra_swa = 0
    if args.resume:
        print_and_log('resume from checkpoint.pth.tar')
        checkpoint = torch.load(os.path.join(args.resume_dir, 'checkpoint.pth.tar'), map_location=device)
        best_sa = checkpoint['best_sa']
        best_ra = checkpoint['best_ra']
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        all_result = checkpoint['result']
        if mask is not None:
            mask.load_info_resume(checkpoint['mask_info'])

        set_ite_step(checkpoint['iter_step'])

        if args.swa:
            # Older checkpoints only carry SWA state from epochs where the SWA model
            # was collected, so fall back to a fresh SWA state instead of a KeyError.
            best_sa_swa = checkpoint.get('best_sa_swa', 0)
            best_ra_swa = checkpoint.get('best_ra_swa', 0)
            swa_n = checkpoint.get('swa_n', 0)
            if 'swa_state_dict' in checkpoint:
                swa_model.load_state_dict(checkpoint['swa_state_dict'])
    else:
        all_result = _new_result_dict(args.swa)
        best_sa = 0
        best_ra = 0

    # Gradual magnitude pruning (GMP) schedule, in epochs (inclusive).
    init_prune_epoch = 20
    final_prune_epoch = 79

    ########################## training process ##########################
    for epoch in range(start_epoch, args.epochs):

        current_sparsity = check_sparsity(model)

        print_and_log(optimizer.param_groups[0]['lr'])

        # SWA "best" flags must describe this epoch only; otherwise a True value
        # from an earlier collection epoch would be passed to save_checkpoint again.
        is_sa_best_swa = False
        is_ra_best_swa = False

        if mask is not None:
            mask.set_dst_current_epoch(epoch)

        # GMP: ramp each prunable module's current prune rate from 0 at init_prune_epoch
        # toward m.prune_rate along a cubic schedule.
        if args.sparse_type == 'gmp' and init_prune_epoch <= epoch <= final_prune_epoch:
            total_prune_epochs = final_prune_epoch - init_prune_epoch + 1
            prune_decay = (1 - (epoch - init_prune_epoch) / total_prune_epochs) ** 3
            for module in model.modules():
                if hasattr(module, 'set_curr_prune_rate'):
                    module.set_curr_prune_rate(module.prune_rate - module.prune_rate * prune_decay)

        if args.lwf and args.lwf_start <= epoch < args.lwf_end:
            print_and_log('adversarial training with LWF')
            train_acc = train_epoch_adv_dual_teacher(train_loader, model, teacher1, teacher2, criterion, optimizer, epoch, args, mask)
        elif args.consistency:
            print_and_log('adversarial training with consistency regularization')
            train_acc, _, _ = train_epoch_adv_consistency(train_loader, model, criterion, optimizer, epoch, args, mask)
        elif args.rslad:
            print_and_log('adversarial training with RSLAD')
            train_acc, _, _ = train_epoch_adv_RSLAD(train_loader, model, criterion, optimizer, epoch, args, mask)
        else:
            print_and_log('baseline adversarial training')
            train_acc, _, _ = train_epoch_adv(train_loader, model, criterion, optimizer, epoch, args, mask)

        all_result['train_acc'].append(train_acc)
        scheduler.step()

        all_result['sparsity'] = current_sparsity

        ###validation###
        val_sa = test(val_loader, model, criterion, args)
        val_ra, val_loss = test_adv(val_loader, model, criterion, args)
        test_sa = test(test_loader, model, criterion, args)
        test_ra, _ = test_adv(test_loader, model, criterion, args)

        if args.dynamic_sparse:
            mask.update_loss_info(val_loss)
            mask.update_train_val_diff(train_acc, val_ra)

        all_result['val_sa'].append(val_sa)
        all_result['val_ra'].append(val_ra)
        all_result['test_sa'].append(test_sa)
        all_result['test_ra'].append(test_ra)

        is_sa_best = val_sa > best_sa
        best_sa = max(val_sa, best_sa)

        is_ra_best = val_ra > best_ra
        best_ra = max(val_ra, best_ra)

        if is_ra_best:
            all_result['best_ra_epoch'] = epoch

        # all_result is mutated in place below, so the saved dict sees the final values.
        checkpoint_state = {
            'best_sa': best_sa,
            'best_ra': best_ra,
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict(),
            'iter_step': get_ite_step(),
            'result': all_result,
        }

        if mask is not None:
            checkpoint_state['mask_info'] = mask.get_info_resume()

        if args.swa and epoch >= args.swa_start and (epoch - args.swa_start) % args.swa_c_epochs == 0:
            # Fold the current weights into the running SWA average, then refresh its BN statistics.
            moving_average(swa_model, model, 1.0 / (swa_n + 1))
            swa_n += 1
            bn_update(train_loader, swa_model)

            val_sa_swa = test(val_loader, swa_model, criterion, args)
            val_ra_swa, _ = test_adv(val_loader, swa_model, criterion, args)
            test_sa_swa = test(test_loader, swa_model, criterion, args)
            test_ra_swa, _ = test_adv(test_loader, swa_model, criterion, args)

            all_result['val_sa_swa'].append(val_sa_swa)
            all_result['val_ra_swa'].append(val_ra_swa)
            all_result['test_sa_swa'].append(test_sa_swa)
            all_result['test_ra_swa'].append(test_ra_swa)

            is_sa_best_swa = val_sa_swa > best_sa_swa
            best_sa_swa = max(val_sa_swa, best_sa_swa)

            is_ra_best_swa = val_ra_swa > best_ra_swa
            best_ra_swa = max(val_ra_swa, best_ra_swa)

        elif args.swa:
            # Not a collection epoch: record the plain model's metrics so the SWA
            # lists still get one entry per epoch.
            all_result['val_sa_swa'].append(val_sa)
            all_result['val_ra_swa'].append(val_ra)
            all_result['test_sa_swa'].append(test_sa)
            all_result['test_ra_swa'].append(test_ra)

        if args.swa:
            # Saved every epoch so any checkpoint can be resumed with --swa.
            checkpoint_state.update({
                'swa_state_dict': swa_model.state_dict(),
                'swa_n': swa_n,
                'best_sa_swa': best_sa_swa,
                'best_ra_swa': best_ra_swa,
            })

        if args.dynamic_sparse:
            prune_epochs, growth_epochs = mask.get_dst_epochs()
            all_result['prune_epochs'] = prune_epochs
            all_result['growth_epochs'] = growth_epochs

        if mask is not None:
            all_result['total_fired_weights'] = mask.total_fired_weights

        save_checkpoint(checkpoint_state, is_sa_best, is_ra_best, is_sa_best_swa, is_ra_best_swa, save_path)

        if args.save_epoch:
            save_checkpoint_epochs(checkpoint_state, epoch, save_path)

        _plot_curves(all_result, args, mask, save_path)


if __name__ == '__main__':
    # Run training only when executed as a script, not when imported.
    main()


