import os
import time
from numpy.lib.function_base import gradient 
import torch
import random
import shutil
import numpy as np  
import torch.nn as nn 
from torch.autograd import Variable
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from advertorch.attacks import LinfPGDAttack, L2PGDAttack
from advertorch.context import ctx_noparamgrad
from advertorch.utils import NormalizeByChannelMeanStd
from datasets import *
from models.preactivate_resnet import *
from models.vgg import *
from models.wideresnet import *
import hashlib
import logging
from sparselearning.pruning_utils import apply_mask_to_model
from sparselearning.gmp import gmp_prune_conv_linear
import copy
import pdb
from models.small_dense_resnet import getSmallDenseResNet18
from autoattack import AutoAttack

# Public API exported by `from <this module> import *`.
# NOTE(review): some names listed here (e.g. get_mask, get_save_path,
# apply_static_sparse, GradCosineSimilarity, input_a_sample,
# add_default_setting, generate_adv) are not defined in the part of the file
# shown here — presumably they are defined further down. Verify they exist:
# a star-import raises AttributeError for any listed name that is missing.
__all__ = ['save_checkpoint', 'setup_dataset_models', 'setup_dataset_models_standard', 'setup_seed', 'moving_average', 'bn_update', 'print_args',
            'train_epoch', 'train_epoch_adv', 'train_epoch_adv_dual_teacher', 'get_mask', 'get_ite_step', 'set_ite_step',
            'test', 'test_adv', 'get_save_path', 'setup_logger', 'print_and_log', 'apply_static_sparse', 'GradCosineSimilarity', 
            'input_a_sample', 'save_checkpoint_epochs', 'train_epoch_adv_consistency', 'train_epoch_adv_RSLAD', 'add_default_setting',
            'generate_adv','setup_dataset_models_standardzzy']

# Module-wide logger handle: None until setup_logger() assigns the root
# logger; print_and_log() only writes to it once it is set.
logger = None

def setup_logger(args):
    """Route log records to ``<save_path>/result.log`` via the root logger.

    On the first call the module-level ``logger`` is bound to the root logger.
    On later calls every handler currently attached to it is detached *and
    closed* before the new file handler is added. Closing them releases
    their file descriptors and avoids duplicated output when the logger is
    reconfigured.

    Args:
        args: run configuration; passed to ``get_save_path`` to locate the
            output directory.
    """
    global logger
    if logger is None:
        logger = logging.getLogger()
    else:
        # Iterate over a copy because removeHandler mutates logger.handlers.
        for handler in logger.handlers[:]:
            logger.removeHandler(handler)
            handler.close()  # release the underlying file/stream resources

    save_path = get_save_path(args)

    log_path = os.path.join(save_path, 'result.log')

    logger.setLevel(logging.INFO)
    formatter = logging.Formatter(fmt='%(asctime)s: %(message)s', datefmt='%H:%M:%S')

    fh = logging.FileHandler(log_path)
    fh.setFormatter(formatter)
    logger.addHandler(fh)

def print_and_log(msg):
    """Echo ``msg`` to stdout, and also to the module logger once it is configured."""
    print(msg)
    active_logger = logger
    if not active_logger:
        # setup_logger() has not run yet: stdout only.
        return
    active_logger.info(msg)

def save_checkpoint(state, is_SA_best, is_RA_best, is_SA_best_swa, is_RA_best_swa, save_path, filename='checkpoint.pth.tar'):
    """Serialize ``state`` to ``save_path/filename`` and mirror it to "best" files.

    For each flag that is truthy, the freshly written checkpoint is copied to
    the matching ``model_*_best.pth.tar`` file in ``save_path``. The SWA
    copies are made first, then the regular-model copies.
    """
    checkpoint_file = os.path.join(save_path, filename)
    torch.save(state, checkpoint_file)

    best_copies = (
        (is_SA_best_swa, 'model_SWA_SA_best.pth.tar'),
        (is_RA_best_swa, 'model_SWA_RA_best.pth.tar'),
        (is_SA_best, 'model_SA_best.pth.tar'),
        (is_RA_best, 'model_RA_best.pth.tar'),
    )
    for wanted, best_name in best_copies:
        if wanted:
            shutil.copyfile(checkpoint_file, os.path.join(save_path, best_name))

def save_checkpoint_epochs(checkpoint, epoch, save_path):
    """Write a slimmed copy of ``checkpoint`` to ``save_path/checkpoint_epochs/checkpoint_<epoch>.pth.tar``.

    Only the ``best_sa``, ``best_ra``, ``epoch``, ``result`` and ``state_dict``
    entries are kept; any other keys in ``checkpoint`` are dropped. The
    ``checkpoint_epochs`` sub-directory is created if needed.
    """
    # 'result' holds the per-epoch history: train_acc, val_sa, val_ra, test_sa, test_ra.
    end_epoch = checkpoint['epoch']
    history = checkpoint['result']
    snapshot = {
        'best_sa': checkpoint['best_sa'],
        'best_ra': checkpoint['best_ra'],
        'epoch': end_epoch,
        'result': history,
        'state_dict': checkpoint['state_dict'],
    }

    epochs_dir = os.path.join(save_path, 'checkpoint_epochs')
    os.makedirs(epochs_dir, exist_ok=True)
    torch.save(snapshot, os.path.join(epochs_dir, f'checkpoint_{epoch}.pth.tar'))

#print training configuration
def print_args(args):
    """Print a human-readable summary of the run configuration to stdout.

    Always prints the dataset/model and the *test-time* attack settings; then
    either the pretrained-weight path (``args.eval``) or the training attack,
    SWA and LwF settings.
    """
    print('*'*50)
    print('Dataset: {}'.format(args.dataset))
    print('Model: {}'.format(args.arch))
    if args.arch == 'wideresnet':
        print('Depth {}'.format(args.depth_factor))
        print('Width {}'.format(args.width_factor))
    print('*'*50)        
    print('Attack Norm {}'.format(args.norm))  
    print('Test Epsilon {}'.format(args.test_eps))
    print('Test Steps {}'.format(args.test_step))
    # This is the *test* attack step size (was mislabelled "Train Steps Size").
    print('Test Steps Size {}'.format(args.test_gamma))
    print('Test Randinit {}'.format(args.test_randinit))
    if args.eval:
        print('Evaluation')
        print('Loading weight {}'.format(args.pretrained))
    else:
        print('Training')
        print('Train Epsilon {}'.format(args.train_eps))
        print('Train Steps {}'.format(args.train_step))
        print('Train Steps Size {}'.format(args.train_gamma))
        print('Train Randinit {}'.format(args.train_randinit))
        print('SWA={0}, start point={1}, swa_c={2}'.format(args.swa, args.swa_start, args.swa_c_epochs))
        print('LWF={0}, coef_ce={1}, coef_kd1={2}, coef_kd2={3}, start={4}, end={5}'.format(
            args.lwf, args.coef_ce, args.coef_kd1, args.coef_kd2, args.lwf_start, args.lwf_end
        ))

# prepare dataset and models
def setup_dataset_models(args):
    """Build the data loaders plus the main, SWA and LwF-teacher networks.

    Every network gets the dataset's ``NormalizeByChannelMeanStd`` module
    attached as ``.normalize`` (the same module instance is shared).

    When ``args.static_sparse`` is set with ``args.sparse_type == 'gmp'``,
    the main model is passed through ``gmp_prune_conv_linear`` and every
    submodule exposing ``set_prune_rate`` receives ``1 - args.density``.

    Returns:
        ``(train_loader, val_loader, test_loader, model, swa_model, teacher1, teacher2)``
        where ``swa_model`` is None unless ``args.swa`` and both teachers are
        None unless ``args.lwf``.

    Raises:
        ValueError: for an unknown dataset or architecture.
    """
    # ---- data ---------------------------------------------------------------
    if args.dataset == 'cifar10':
        classes = 10
        dataset_normalization = NormalizeByChannelMeanStd(
            mean=[0.4914, 0.4822, 0.4465], std=[0.2470, 0.2435, 0.2616])
        cifar_outputs = cifar10_dataloaders(batch_size=args.batch_size, data_dir=args.data,
                                            consistency=args.consistency, robust_friendly=args.robust_friendly)
        # The trailing data_sample / mask_loader outputs are not used here.
        train_loader, val_loader, test_loader, _, _ = cifar_outputs
    elif args.dataset == 'cifar100':
        classes = 100
        dataset_normalization = NormalizeByChannelMeanStd(
            mean=[0.5071, 0.4865, 0.4409], std=[0.2673, 0.2564, 0.2762])
        train_loader, val_loader, test_loader = cifar100_dataloaders(batch_size=args.batch_size, data_dir=args.data)
    elif args.dataset == 'tinyimagenet':
        classes = 200
        dataset_normalization = NormalizeByChannelMeanStd(
            mean=[0.4802, 0.4481, 0.3975], std=[0.2302, 0.2265, 0.2262])
        train_loader, val_loader, test_loader = tiny_imagenet_dataloaders(batch_size=args.batch_size, data_dir=args.data)
    else:
        raise ValueError("Unknown Dataset")

    # ---- networks -----------------------------------------------------------
    # `make_net` builds one fresh network of the selected architecture.
    if args.arch == 'resnet18':
        def make_net():
            return ResNet18(num_classes=classes)
        # Only the main model can be the reduced ResNet-18; SWA / teacher copies
        # always use the full ResNet18. NOTE(review): confirm this is intended.
        model = getSmallDenseResNet18(args.small_dense_rate, classes) if args.small_dense else make_net()
    elif args.arch == 'wideresnet':
        def make_net():
            return WideResNet(args.depth_factor, classes, widen_factor=args.width_factor, dropRate=0.0)
        model = make_net()
    elif args.arch == 'vgg16':
        def make_net():
            return vgg16_bn(num_classes=classes)
        model = make_net()
    else:
        raise ValueError("Unknown Model")
    model.normalize = dataset_normalization

    def make_normalized_net():
        net = make_net()
        net.normalize = dataset_normalization
        return net

    # Construction order (model, swa, teacher1, teacher2) is kept fixed so the
    # random initialisations consume the RNG in the same sequence.
    swa_model = make_normalized_net() if args.swa else None
    if args.lwf:
        teacher1 = make_normalized_net()
        teacher2 = make_normalized_net()
    else:
        teacher1 = teacher2 = None

    if args.static_sparse and args.sparse_type == 'gmp':
        gmp_prune_conv_linear(model)
        prune_rate = 1 - args.density
        print(f"==> Setting prune rate of network to {prune_rate}")

        def _assign_prune_rate(layer):
            if hasattr(layer, "set_prune_rate"):
                layer.set_prune_rate(prune_rate)

        model.apply(_assign_prune_rate)

    return train_loader, val_loader, test_loader, model, swa_model, teacher1, teacher2

def setup_dataset_models_standard(args):
    """Build CIFAR-10 loaders and a single normalized model (no SWA / teachers).

    Only ``args.dataset == 'cifar10'`` is supported. This function also
    returns the ``data_sample`` produced by ``cifar10_dataloaders``; the other
    dataset loaders used in this module are unpacked into three values and
    provide no such sample.

    Returns:
        ``(train_loader, val_loader, test_loader, model, data_sample)``.

    Raises:
        ValueError: for an unsupported dataset or an unknown architecture.
    """
    # prepare dataset
    if args.dataset == 'cifar10':
        classes = 10
        dataset_normalization = NormalizeByChannelMeanStd(
            mean=[0.4914, 0.4822, 0.4465], std=[0.2470, 0.2435, 0.2616])
        train_loader, val_loader, test_loader, data_sample, _mask_loader = cifar10_dataloaders(batch_size = args.batch_size, data_dir = args.data)
    else:
        raise ValueError("Unknown Dataset")

    # prepare model
    if args.arch == 'resnet18':
        model = ResNet18(num_classes = classes)
    elif args.arch == 'wideresnet':
        model = WideResNet(args.depth_factor, classes, widen_factor=args.width_factor, dropRate=0.0)
    elif args.arch == 'vgg16':
        # Use the dataset's class count (was hard-coded to 10), consistent
        # with the other model branches and setup functions.
        model = vgg16_bn(num_classes = classes)
    else:
        raise ValueError("Unknown Model")
    model.normalize = dataset_normalization

    return train_loader, val_loader, test_loader, model, data_sample



def setup_dataset_models_standardzzy(args):
    """Build data loaders and one normalized model for CIFAR-10/100 or Tiny-ImageNet.

    Returns:
        ``(train_loader, val_loader, test_loader, model)``.

    Raises:
        ValueError: for an unknown dataset or architecture.
    """
    # dataset name -> (number of classes, per-channel mean, per-channel std)
    dataset_stats = {
        'cifar10': (10, [0.4914, 0.4822, 0.4465], [0.2470, 0.2435, 0.2616]),
        'cifar100': (100, [0.5071, 0.4865, 0.4409], [0.2673, 0.2564, 0.2762]),
        'tinyimagenet': (200, [0.4802, 0.4481, 0.3975], [0.2302, 0.2265, 0.2262]),
    }
    if args.dataset not in dataset_stats:
        raise ValueError("Unknown Dataset")
    classes, channel_mean, channel_std = dataset_stats[args.dataset]
    dataset_normalization = NormalizeByChannelMeanStd(mean=channel_mean, std=channel_std)

    if args.dataset == 'cifar10':
        # cifar10_dataloaders also yields data_sample and mask_loader (unused here).
        train_loader, val_loader, test_loader, _, _ = cifar10_dataloaders(batch_size=args.batch_size, data_dir=args.data)
    elif args.dataset == 'cifar100':
        train_loader, val_loader, test_loader = cifar100_dataloaders(batch_size=args.batch_size, data_dir=args.data)
    else:
        train_loader, val_loader, test_loader = tiny_imagenet_dataloaders(batch_size=args.batch_size, data_dir=args.data)

    if args.arch == 'resnet18':
        model = ResNet18(num_classes=classes)
    elif args.arch == 'wideresnet':
        model = WideResNet(args.depth_factor, classes, widen_factor=args.width_factor, dropRate=0.0)
    elif args.arch == 'vgg16':
        model = vgg16_bn(num_classes=classes)
    else:
        raise ValueError("Unknown Model")
    model.normalize = dataset_normalization

    return train_loader, val_loader, test_loader, model




class AverageMeter(object):
    """Track the most recent value and the running weighted mean of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the last value, running sum, sample count and mean."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the mean over ``n`` new samples and refresh ``avg``."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy (in percent) for each k in ``topk``.

    Args:
        output: score tensor; classes are along dim 1 (``topk`` is taken on dim 1).
        target: integer labels, one per row of ``output``.
        topk: the k values to evaluate.

    Returns:
        A list with one 0-dim tensor per k: the percentage of samples whose
        label is among the k highest-scoring classes.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()  # (maxk, batch): row j holds every sample's j-th best class
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # `correct` derives from a transposed tensor and can be non-contiguous,
        # so `.view(-1)` raises for k > 1; `.reshape` copies only when needed.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res

def setup_seed(seed):
    """Seed Python, NumPy and PyTorch (CPU and all CUDA devices) and request deterministic cuDNN kernels."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True

# knowledge distillation loss function
def loss_fn_kd(scores, target_scores, T=2.):
    """Temperature-scaled knowledge-distillation loss (Li & Hoiem, 2017).

    Computes the cross-entropy between ``softmax(target_scores / T)`` and
    ``log_softmax(scores / T)``. It is summed over classes, averaged over the
    batch and multiplied by ``T**2``. When the student has more classes than
    the teacher, the teacher distribution is right-padded with zeros (and
    detached). A size mismatch is reported on stdout.

    Args:
        scores: student logits, classes along dim 1.
        target_scores: teacher logits; callers are expected to pass them
            without a gradient history.
        T: distillation temperature.

    Returns:
        Scalar loss tensor.
    """
    student_log_probs = F.log_softmax(scores / T, dim=1)
    teacher_probs = F.softmax(target_scores / T, dim=1)

    n_student = scores.size(1)
    n_teacher = target_scores.size(1)
    if n_student != n_teacher:
        print('size does not match')

    if n_student > n_teacher:
        # Extra student classes get zero teacher probability.
        padding = torch.zeros(scores.size(0), n_student - n_teacher).to(scores.device)
        teacher_probs = torch.cat([teacher_probs.detach(), padding], dim=1)

    per_sample = (-(teacher_probs * student_log_probs)).sum(dim=1)
    return per_sample.mean() * T**2

def moving_average(net1, net2, alpha=1):
    """Blend parameters in place: ``net1 <- (1 - alpha) * net1 + alpha * net2``.

    Parameters are paired positionally; with the default ``alpha=1`` net1's
    parameters become copies of net2's.
    """
    keep = 1.0 - alpha
    for avg_param, new_param in zip(net1.parameters(), net2.parameters()):
        avg_param.data *= keep
        avg_param.data += new_param.data * alpha

def _check_bn(module, flag):
    if issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm):
        flag[0] = True

def check_bn(model):
    """Return True if ``model`` or any of its submodules is a BatchNorm layer."""
    bn_base = torch.nn.modules.batchnorm._BatchNorm
    return any(isinstance(sub, bn_base) for sub in model.modules())

def reset_bn(module):
    """Reset a BatchNorm layer's running statistics (mean -> 0, variance -> 1).

    Non-BatchNorm modules are ignored. BatchNorm layers created with
    ``track_running_stats=False`` keep ``running_mean``/``running_var`` as
    None. They are skipped here instead of crashing on
    ``torch.zeros_like(None)``.
    """
    if not issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm):
        return
    if module.running_mean is not None:
        module.running_mean = torch.zeros_like(module.running_mean)
    if module.running_var is not None:
        module.running_var = torch.ones_like(module.running_var)

def _get_momenta(module, momenta):
    if issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm):
        momenta[module] = module.momentum

def _set_momenta(module, momenta):
    if issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm):
        module.momentum = momenta[module]

def bn_update(loader, model):
    """
        BatchNorm buffers update (if any).
        Performs 1 epoch over ``loader`` to re-estimate the BatchNorm running
        statistics of ``model`` (e.g. after SWA weight averaging).

        Running stats are reset, then each batch is forwarded with the
        BN momentum set to ``b / (n + b)``. As a result the buffers end up as
        the sample-weighted average over all batches. The original momenta are
        restored afterwards. The model is left in train mode.

        The forward passes run under ``torch.no_grad()``: only the BN
        buffers are needed, so building an autograd graph would just waste
        memory.

        :param loader: train dataset loader yielding ``(input, target)`` pairs;
            inputs are moved to CUDA.
        :param model: model being update
        :return: None
    """
    if not check_bn(model):
        return
    model.train()
    momenta = {}
    model.apply(reset_bn)
    model.apply(lambda module: _get_momenta(module, momenta))
    n = 0
    with torch.no_grad():
        for inputs, _ in loader:
            inputs = inputs.cuda()
            b = inputs.size(0)

            # Cumulative moving average: new batch weighted by its share of samples.
            momentum = b / (n + b)
            for module in momenta:
                module.momentum = momentum

            model(inputs)
            n += b

    model.apply(lambda module: _set_momenta(module, momenta))


# training 
def train_epoch(train_loader, model, criterion, optimizer, epoch, args):
    """Run one epoch of standard (clean-input) training.

    Progress is printed every ``args.print_freq`` batches and the epoch's
    mean accuracy is printed at the end.

    Returns:
        Average top-1 training accuracy (percent) over the epoch.
    """
    loss_meter = AverageMeter()
    acc_meter = AverageMeter()

    model.train()
    tic = time.time()
    for step, (inputs, labels) in enumerate(train_loader):
        inputs = inputs.cuda()
        labels = labels.cuda()

        logits = model(inputs)
        batch_loss = criterion(logits, labels)

        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        batch_acc = accuracy(logits.float().data, labels)[0]
        loss_meter.update(batch_loss.float().item(), inputs.size(0))
        acc_meter.update(batch_acc.item(), inputs.size(0))

        if step % args.print_freq == 0:
            toc = time.time()
            print('Epoch: [{0}][{1}/{2}]\t'
                'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                'Accuracy {top1.val:.3f} ({top1.avg:.3f})\t'
                'Time {3:.2f}'.format(
                    epoch, step, len(train_loader), toc - tic, loss=loss_meter, top1=acc_meter))
            tic = time.time()

    print('train_accuracy {top1.avg:.3f}'.format(top1=acc_meter))

    return acc_meter.avg

# Global iteration counter shared by the train_epoch_adv* functions. Each
# training batch increments it, and it decides when mask.at_end_of_epoch() is
# triggered. Read/overwritten via get_ite_step() / set_ite_step().
ite_step = 0

def get_ite_step():
    """Return the current value of the module-wide iteration counter ``ite_step``."""
    # Reading a module global needs no `global` declaration.
    return ite_step

def set_ite_step(step):
    """Overwrite the module-wide iteration counter (e.g. when resuming), logging the new value first."""
    global ite_step
    print_and_log(f"set ite_step: {step}")
    ite_step = step

def train_epoch_adv(train_loader, model, criterion, optimizer, epoch, args, mask):
    """Run one epoch of PGD adversarial training, optionally with sparse masks.

    Every batch is replaced by a PGD adversarial example (L-inf or L2 per
    ``args.norm``). The example is crafted with parameter gradients disabled,
    and the model is then updated on ``criterion`` of the adversarial logits.

    Args:
        train_loader: iterable of ``(input, target)`` batches (moved to CUDA here).
        model: network being trained.
        criterion: loss used both by the attack and for the update.
        optimizer: its ``step`` is used when ``mask`` is None.
        epoch: current epoch (logging and the ``args.dynamic_fre`` switch).
        args: attack settings (``norm``, ``train_eps``, ``train_step``,
            ``train_gamma``, ``train_randinit``), ``print_freq``, and sparsity
            settings (``static_sparse``, ``sparse_type``, ``density``, ``seed``,
            ``mask_dir``, ``sparse``, ``dynamic_sparse``, ``update_frequency``,
            ``dynamic_fre``, ``second_frequency``, ``fix``).
        mask: sparsity-mask object or None; when given, ``mask.step()`` is
            called instead of ``optimizer.step()`` (NOTE(review): presumably it
            steps the optimizer and re-applies the mask — confirm).

    Returns:
        ``(avg_top1_accuracy, avg_gradient_norm, rl_ratio_list)``. Gradient-norm
        and clean/robust ratio tracking are currently disabled, so the last
        two are always ``0`` and ``[]``.

    Raises:
        ValueError: if ``args.norm`` is neither ``'linf'`` nor ``'l2'``
            (previously this surfaced as an UnboundLocalError on the first batch).
    """
    global ite_step

    if args.norm == 'linf':
        attack_cls = LinfPGDAttack
    elif args.norm == 'l2':
        attack_cls = L2PGDAttack
    else:
        raise ValueError("Unknown attack norm: {}".format(args.norm))
    adversary = attack_cls(
        model, loss_fn=criterion, eps=args.train_eps, nb_iter=args.train_step, eps_iter=args.train_gamma,
        rand_init=args.train_randinit, clip_min=0.0, clip_max=1.0, targeted=False
    )

    losses = AverageMeter()
    top1 = AverageMeter()
    gradient_norm = AverageMeter()  # never updated: gradient-norm tracking is disabled
    rl_ratio_list = []              # never populated: ratio tracking is disabled

    model.train()
    start = time.time()
    for i, (inputs, target) in enumerate(train_loader):

        inputs = inputs.cuda()
        target = target.cuda()

        # Craft adversarial examples without accumulating parameter gradients.
        with ctx_noparamgrad(model):
            input_adv = adversary.perturb(inputs, target)

        output_adv = model(input_adv)
        loss = criterion(output_adv, target)

        optimizer.zero_grad()
        loss.backward()

        if mask is not None:
            mask.step()
        else:
            optimizer.step()

        # Static non-GMP sparsity is re-imposed after every update; GMP models
        # are instead configured at setup time via gmp_prune_conv_linear.
        if args.static_sparse and args.sparse_type != 'gmp':
            apply_static_sparse(model, optimizer, args.sparse_type, args.density, args.seed, args.mask_dir)

        output = output_adv.float()
        loss = loss.float()
        # measure accuracy (on adversarial inputs) and record loss
        prec1 = accuracy(output.data, target)[0]

        losses.update(loss.item(), inputs.size(0))
        top1.update(prec1.item(), inputs.size(0))

        if i % args.print_freq == 0:
            end = time.time()
            print_and_log('Epoch: [{0}][{1}/{2}]\t'
                'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                'Accuracy {top1.val:.3f} ({top1.avg:.3f})\t'
                'Time {3:.2f}'.format(
                    epoch, i, len(train_loader), end-start, loss=losses, top1=top1))
            start = time.time()

        # Sparse-topology update every `update_frequency` global iterations;
        # after epoch 100 `second_frequency` is used when `dynamic_fre` is set.
        # Despite its name, mask.at_end_of_epoch() is invoked per schedule here.
        update_frequency = args.update_frequency
        if args.dynamic_fre and epoch > 100:
            update_frequency = args.second_frequency

        ite_step += 1
        if (args.sparse or args.dynamic_sparse) and ite_step % update_frequency == 0 and not args.fix:
            mask.at_end_of_epoch()

    print_and_log('train_accuracy {top1.avg:.3f}'.format(top1=top1))

    return top1.avg, gradient_norm.avg, rl_ratio_list

def _kl_div(logit1, logit2):
    return F.kl_div(F.log_softmax(logit1, dim=1), F.softmax(logit2, dim=1), reduction='batchmean')


def _jensen_shannon_div(logit1, logit2, T=1.):
    prob1 = F.softmax(logit1/T, dim=1)
    prob2 = F.softmax(logit2/T, dim=1)
    mean_prob = 0.5 * (prob1 + prob2)

    logsoftmax = torch.log(mean_prob.clamp(min=1e-8))
    jsd = F.kl_div(logsoftmax, prob1, reduction='batchmean')
    jsd += F.kl_div(logsoftmax, prob2, reduction='batchmean')
    return jsd * 0.5

def train_epoch_adv_consistency(train_loader, model, criterion, optimizer, epoch, args, mask):
    """Run one epoch of PGD adversarial training with a consistency regularizer.

    Each batch's ``input`` holds two augmented views (``input[0]``,
    ``input[1]``) of the same ``target`` labels. Both views are concatenated
    (2B samples) and attacked together. The loss is ``criterion`` on all 2B
    adversarial logits plus ``lam * JSD_T`` between the two views'
    adversarial predictions (``lam = 1.0``, ``T = 0.5``).

    Args / sparsity handling are the same as ``train_epoch_adv``; ``mask.step()``
    replaces ``optimizer.step()`` when a mask is given.

    Returns:
        ``(avg_top1_accuracy, avg_gradient_norm, rl_ratio_list)``; the last two
        are never populated here and are always ``0`` and ``[]``.

    Raises:
        ValueError: if ``args.norm`` is neither ``'linf'`` nor ``'l2'``
            (previously an UnboundLocalError on the first batch).
    """
    global ite_step

    if args.norm == 'linf':
        attack_cls = LinfPGDAttack
    elif args.norm == 'l2':
        attack_cls = L2PGDAttack
    else:
        raise ValueError("Unknown attack norm: {}".format(args.norm))
    adversary = attack_cls(
        model, loss_fn=criterion, eps=args.train_eps, nb_iter=args.train_step, eps_iter=args.train_gamma,
        rand_init=args.train_randinit, clip_min=0.0, clip_max=1.0, targeted=False
    )

    top1 = AverageMeter()
    gradient_norm = AverageMeter()  # never updated
    rl_ratio_list = []              # never populated

    losses = {'cls': AverageMeter(), 'con': AverageMeter()}

    model.train()
    start = time.time()
    lam = 1.0  # weight of the consistency term
    T = 0.5    # softmax temperature for the JS divergence
    for i, (inputs, target) in enumerate(train_loader):

        batch_size = inputs[0].size(0)
        target = target.cuda()
        target_pair = target.repeat(2)  # labels for [view1; view2]

        input_aug1, input_aug2 = inputs[0].cuda(), inputs[1].cuda()
        input_pair = torch.cat([input_aug1, input_aug2], dim=0)  # 2B

        # Craft adversarial examples without accumulating parameter gradients.
        with ctx_noparamgrad(model):
            input_adv = adversary.perturb(input_pair, target_pair)

        output_adv = model(input_adv)
        loss_ce = criterion(output_adv, target_pair)

        ### consistency regularization between the two views ###
        outputs_adv1, outputs_adv2 = output_adv.chunk(2)
        loss_con = lam * _jensen_shannon_div(outputs_adv1, outputs_adv2, T)

        ### total loss ###
        loss = loss_ce + loss_con
        optimizer.zero_grad()
        loss.backward()

        if mask is not None:
            mask.step()
        else:
            optimizer.step()

        # Static non-GMP sparsity is re-imposed after every update.
        if args.static_sparse and args.sparse_type != 'gmp':
            apply_static_sparse(model, optimizer, args.sparse_type, args.density, args.seed, args.mask_dir)

        ### Log losses ###
        losses['cls'].update(loss_ce.item(), batch_size)
        losses['con'].update(loss_con.item(), batch_size)

        output = output_adv.float()
        # accuracy over all 2B adversarial samples
        prec1 = accuracy(output.data, target_pair)[0]
        top1.update(prec1.item(), batch_size * 2)

        if i % args.print_freq == 0:
            end = time.time()
            print_and_log('Epoch: [{0}][{1}/{2}]\t'
                'Loss_cls {loss1.val:.4f} ({loss1.avg:.4f})\t'
                'Loss_con {loss2.val:.4f} ({loss2.avg:.4f})\t'
                'Accuracy {top1.val:.3f} ({top1.avg:.3f})\t'
                'Time {3:.2f}'.format(
                    epoch, i, len(train_loader), end-start, loss1=losses['cls'], loss2=losses['con'], top1=top1))
            start = time.time()

        # Sparse-topology update every `update_frequency` global iterations
        # (`second_frequency` after epoch 100 when `dynamic_fre` is set).
        update_frequency = args.update_frequency
        if args.dynamic_fre and epoch > 100:
            update_frequency = args.second_frequency

        ite_step += 1
        if (args.sparse or args.dynamic_sparse) and ite_step % update_frequency == 0 and not args.fix:
            mask.at_end_of_epoch()

    print_and_log('train_accuracy {top1.avg:.3f}'.format(top1=top1))

    return top1.avg, gradient_norm.avg, rl_ratio_list

def kl_loss(a, b):
    """Element-wise KL term ``b * (log(b + 1e-5) - a)``.

    ``a`` is expected to hold log-probabilities and ``b`` probabilities, so that
    summing over classes gives KL(b || exp(a)); no reduction is applied.
    """
    return torch.log(b + 1e-5) * b + (-a) * b

def rslad_inner_loss(model,
                teacher_logits,
                x_natural,
                y,
                optimizer,
                step_size=0.003,
                epsilon=0.031,
                perturb_steps=10,
                beta=6.0):
    """Craft RSLAD adversarial examples and return the student's logits on them.

    Starting from ``x_natural`` plus small Gaussian noise, the function runs
    ``perturb_steps`` signed-gradient ascent steps. Each step maximizes the
    summed KL between the student's prediction and ``softmax(teacher_logits)``.
    After each step the result is projected to the L-inf ball of radius
    ``epsilon`` around ``x_natural`` and clipped to [0, 1]. The model is put
    in eval mode for the attack and switched back to train mode (unconditionally)
    afterwards; ``optimizer``'s gradients are zeroed before the final forward.

    Args:
        model: student network.
        teacher_logits: teacher outputs on ``x_natural`` (used as soft targets).
        x_natural: clean inputs with values in [0, 1].
        y: unused (kept for signature compatibility).
        optimizer: optimizer whose gradients are zeroed.
        step_size, epsilon, perturb_steps: PGD step size, radius and iterations.
        beta: unused (kept for signature compatibility).

    Returns:
        Student logits on the final adversarial inputs (with autograd history).
    """
    # reduction='none' is the non-deprecated spelling of
    # size_average=False, reduce=False (element-wise, unreduced loss).
    criterion_kl = nn.KLDivLoss(reduction='none')
    model.eval()
    # Random start. The noise is still drawn from the CPU generator (as before,
    # so seeded runs see the same draws) but is placed on x_natural's device
    # rather than a hard-coded `.cuda()`, so the function also runs on CPU.
    noise = torch.randn(x_natural.shape).to(x_natural.device)
    x_adv = x_natural.detach() + 0.001 * noise.detach()
    # Soft targets do not depend on x_adv: compute them once.
    teacher_probs = F.softmax(teacher_logits, dim=1)

    for _ in range(perturb_steps):
        x_adv.requires_grad_()
        with torch.enable_grad():
            loss_kl = criterion_kl(F.log_softmax(model(x_adv), dim=1), teacher_probs)
            loss_kl = torch.sum(loss_kl)
        grad = torch.autograd.grad(loss_kl, [x_adv])[0]
        x_adv = x_adv.detach() + step_size * torch.sign(grad.detach())
        # Project onto the epsilon-ball around x_natural, then the valid pixel range.
        x_adv = torch.min(torch.max(x_adv, x_natural - epsilon), x_natural + epsilon)
        x_adv = torch.clamp(x_adv, 0.0, 1.0)

    model.train()

    # Detached, non-differentiable input (replaces deprecated autograd.Variable).
    x_adv = torch.clamp(x_adv, 0.0, 1.0).detach()
    # zero gradient
    optimizer.zero_grad()
    logits = model(x_adv)
    return logits

def train_epoch_adv_RSLAD(train_loader, model, criterion, optimizer, epoch, args, mask):
    """Run one epoch of RSLAD-style robust distillation from a fixed teacher.

    A ResNet18 teacher (10 output classes) is loaded from
    ``args.rslad_teacher`` (its ``'state_dict'`` entry, mapped to ``cuda:0``)
    at the start of every call. For each batch the student minimizes
    ``5/6 * mean(kl_loss(adv)) + 1/6 * mean(kl_loss(natural))`` against the
    teacher's soft labels on the natural input. The adversarial logits come
    from ``rslad_inner_loss``.

    NOTE(review): the teacher is reloaded from disk on every epoch and its
    class count / device are hard-coded (10, cuda:0) — confirm this is only
    used with 10-class datasets on a single GPU.

    Args:
        criterion: not used by this function (kept so the signature matches
            ``train_epoch_adv``).
        mask: sparsity-mask object or None; ``mask.step()`` replaces
            ``optimizer.step()`` when given. Static sparsity is NOT re-applied
            here (that call is commented out).
        Other arguments as in ``train_epoch_adv``.

    Returns:
        ``(avg top-1 accuracy on adversarial inputs, gradient_norm.avg,
        rl_ratio_list)``; the last two are never updated here (``0`` and ``[]``).
    """
    
    # Teacher is rebuilt and reloaded on every call (i.e. every epoch).
    teacher = ResNet18(num_classes = 10)
    checkpoint = torch.load(os.path.join(args.rslad_teacher), map_location = torch.device('cuda:0'))
    teacher.load_state_dict(checkpoint['state_dict'])
    teacher = teacher.cuda()
    teacher.eval()

    losses = AverageMeter()
    top1 = AverageMeter()
    gradient_norm = AverageMeter()
    rl_ratio_list = []

    model.train()
    start = time.time()
    for i, (input, target) in enumerate(train_loader):

        input = input.cuda()
        target = target.cuda()

        # rslad_inner_loss zeroes the gradients again before its final forward,
        # so this call is redundant but harmless.
        optimizer.zero_grad()
        with torch.no_grad():
            teacher_logits = teacher(input)

        adv_logits = rslad_inner_loss(model, teacher_logits, input, target, optimizer, step_size=args.train_gamma, epsilon=args.train_eps, perturb_steps=args.train_step)
        
        nat_logits = model(input)
        # kl_loss(log-probs, probs) is the element-wise KL term; averaged over all elements.
        kl_Loss1 = kl_loss(F.log_softmax(adv_logits,dim=1),F.softmax(teacher_logits.detach(),dim=1))
        kl_Loss2 = kl_loss(F.log_softmax(nat_logits,dim=1),F.softmax(teacher_logits.detach(),dim=1))
        kl_Loss1 = torch.mean(kl_Loss1)
        kl_Loss2 = torch.mean(kl_Loss2)
        # Adversarial term weighted 5/6, natural term 1/6.
        loss = 5/6.0*kl_Loss1 + 1/6.0*kl_Loss2
        loss.backward()

        # Needed because ite_step is assigned (+= 1) later in this loop.
        global ite_step

        if mask is not None: mask.step()
        else: optimizer.step()

        # if args.static_sparse:
        #     if args.sparse_type != 'gmp':
        #         apply_static_sparse(model, optimizer, args.sparse_type, args.density, args.seed)

        output = adv_logits.float()
        loss = loss.float()
        # measure accuracy (on adversarial inputs) and record loss
        prec1 = accuracy(output.data, target)[0]

        losses.update(loss.item(), input.size(0))
        top1.update(prec1.item(), input.size(0))

        if i % args.print_freq == 0:
            end = time.time()
            print_and_log('Epoch: [{0}][{1}/{2}]\t'
                'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                'Accuracy {top1.val:.3f} ({top1.avg:.3f})\t'
                'Time {3:.2f}'.format(
                    epoch, i, len(train_loader), end-start, loss=losses, top1=top1))
            start = time.time()
        
        
        # update sparse topology every `update_frequency` global iterations
        # (`second_frequency` after epoch 100 when `dynamic_fre` is set)
        # global ite_step
        update_frequency = args.update_frequency
        if args.dynamic_fre and epoch > 100:
            update_frequency = args.second_frequency

        ite_step += 1
        if (args.sparse or args.dynamic_sparse) and ite_step % update_frequency == 0 and not args.fix:
            mask.at_end_of_epoch()

    print_and_log('train_accuracy {top1.avg:.3f}'.format(top1=top1))

    return top1.avg, gradient_norm.avg, rl_ratio_list

def train_epoch_adv_dual_teacher(train_loader, model, teacher1, teacher2, criterion, optimizer, epoch, args, mask):
    """Run one epoch of PGD adversarial training with distillation from two teachers.

    For each batch a PGD adversarial example is crafted (L-inf or L2 per
    ``args.norm``). The two teachers score it without gradients, and the
    student minimizes
    ``coef_ce * criterion + coef_kd1 * KD(teacher1) + coef_kd2 * KD(teacher2)``
    where KD is ``loss_fn_kd`` at temperature ``args.temperature``.

    Args:
        teacher1, teacher2: teacher networks (put in eval mode here; must not be None).
        mask: sparsity-mask object or None; ``mask.step()`` replaces
            ``optimizer.step()`` when given. Static sparsity is not re-applied
            in this function.
        Other arguments as in ``train_epoch_adv``.

    Returns:
        Average top-1 accuracy (percent) on the adversarial inputs.

    Raises:
        ValueError: if ``args.norm`` is neither ``'linf'`` nor ``'l2'``
            (previously an UnboundLocalError on the first batch).
    """
    global ite_step

    if args.norm == 'linf':
        attack_cls = LinfPGDAttack
    elif args.norm == 'l2':
        attack_cls = L2PGDAttack
    else:
        raise ValueError("Unknown attack norm: {}".format(args.norm))
    adversary = attack_cls(
        model, loss_fn=criterion, eps=args.train_eps, nb_iter=args.train_step, eps_iter=args.train_gamma,
        rand_init=args.train_randinit, clip_min=0.0, clip_max=1.0, targeted=False
    )

    losses = AverageMeter()
    top1 = AverageMeter()

    model.train()
    teacher1.eval()
    teacher2.eval()
    start = time.time()
    for i, (inputs, target) in enumerate(train_loader):

        inputs = inputs.cuda()
        target = target.cuda()

        # Craft adversarial examples without accumulating parameter gradients.
        with ctx_noparamgrad(model):
            input_adv = adversary.perturb(inputs, target)

        output_adv = model(input_adv)

        # Teachers only provide soft targets: no graph needed.
        with torch.no_grad():
            target_score1 = teacher1(input_adv)
            target_score2 = teacher2(input_adv)

        loss_KD = loss_fn_kd(output_adv, target_score1, T=args.temperature)
        loss_KD2 = loss_fn_kd(output_adv, target_score2, T=args.temperature)

        loss = criterion(output_adv, target)*args.coef_ce + loss_KD*args.coef_kd1 + loss_KD2*args.coef_kd2

        optimizer.zero_grad()
        loss.backward()

        if mask is not None:
            mask.step()
        else:
            optimizer.step()

        output = output_adv.float()
        loss = loss.float()
        # measure accuracy (on adversarial inputs) and record loss
        prec1 = accuracy(output.data, target)[0]

        losses.update(loss.item(), inputs.size(0))
        top1.update(prec1.item(), inputs.size(0))

        if i % args.print_freq == 0:
            end = time.time()
            print_and_log('Epoch: [{0}][{1}/{2}]\t'
                'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                'Accuracy {top1.val:.3f} ({top1.avg:.3f})\t'
                'Time {3:.2f}'.format(
                    epoch, i, len(train_loader), end-start, loss=losses, top1=top1))
            start = time.time()

        # Sparse-topology update every `update_frequency` global iterations
        # (`second_frequency` after epoch 100 when `dynamic_fre` is set).
        update_frequency = args.update_frequency
        if args.dynamic_fre and epoch > 100:
            update_frequency = args.second_frequency

        ite_step += 1
        if (args.sparse or args.dynamic_sparse) and ite_step % update_frequency == 0 and not args.fix:
            mask.at_end_of_epoch()

    print_and_log('train_accuracy {top1.avg:.3f}'.format(top1=top1))

    return top1.avg

#testing
def test(val_loader, model, criterion, args):
    """
    Evaluate ``model`` on clean (unperturbed) batches from ``val_loader``.

    Every ``args.print_freq`` batches a progress line (running loss/accuracy
    and elapsed time since the previous report) is printed to stdout, and a
    final 'Standard Accuracy' summary is printed at the end.

    Args:
        val_loader: iterable yielding ``(input, target)`` batches; both are
            moved to the GPU with ``.cuda()``.
        model: network to evaluate; it is put into eval mode.
        criterion: loss callable used as ``criterion(output, target)``.
        args: namespace providing ``print_freq``.

    Returns:
        The running average held by the accuracy meter (``top1.avg``).
    """
    loss_meter = AverageMeter()
    acc_meter = AverageMeter()

    model.eval()
    tic = time.time()
    for step, (images, labels) in enumerate(val_loader):
        images, labels = images.cuda(), labels.cuda()
        batch_size = images.size(0)

        # Pure forward pass: no gradients are needed for clean evaluation.
        with torch.no_grad():
            logits = model(images)
            batch_loss = criterion(logits, labels)

        logits = logits.float()
        batch_loss = batch_loss.float()

        # accuracy(...)[0] is the metric tracked as "top1" throughout this file.
        batch_acc = accuracy(logits.data, labels)[0]
        loss_meter.update(batch_loss.item(), batch_size)
        acc_meter.update(batch_acc.item(), batch_size)

        if step % args.print_freq == 0:
            toc = time.time()
            print('Test: [{0}/{1}]\t'
                'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                'Accuracy {top1.val:.3f} ({top1.avg:.3f})\t'
                'Time {2:.2f}'.format(
                    step, len(val_loader), toc - tic, loss=loss_meter, top1=acc_meter))
            tic = time.time()

    print('Standard Accuracy {top1.avg:.3f}'.format(top1=acc_meter))
    return acc_meter.avg

def test_adv(val_loader, model, criterion, args):
    """
    Run adversarial evaluation.

    For every batch of ``val_loader`` a PGD adversarial example is crafted
    against ``model`` (in eval mode) and the loss/accuracy on those examples
    is accumulated. The attack is configured from ``args.norm`` ('linf' or
    'l2'), ``args.test_eps``, ``args.test_step``, ``args.test_gamma`` and
    ``args.test_randinit``, with inputs clipped to [0, 1]. Progress is
    reported through ``print_and_log`` every ``args.print_freq`` batches.

    Args:
        val_loader: iterable yielding ``(input, target)`` batches.
        model: network under attack and evaluation.
        criterion: loss used both by the attack and for reporting.
        args: namespace with the attack settings above plus ``print_freq``.

    Returns:
        tuple: ``(top1.avg, losses.avg)`` -- average robust accuracy and
        average loss over all evaluated samples.

    Raises:
        ValueError: if ``args.norm`` is neither 'linf' nor 'l2'.
    """
    # Select the attack before doing any work so an unsupported norm fails
    # immediately, instead of surfacing later as an UnboundLocalError.
    if args.norm == 'linf':
        attack_cls = LinfPGDAttack
    elif args.norm == 'l2':
        attack_cls = L2PGDAttack
    else:
        raise ValueError("Unsupported norm {!r}; expected 'linf' or 'l2'".format(args.norm))
    adversary = attack_cls(
        model, loss_fn=criterion, eps=args.test_eps, nb_iter=args.test_step, eps_iter=args.test_gamma,
        rand_init=args.test_randinit, clip_min=0.0, clip_max=1.0, targeted=False
    )

    losses = AverageMeter()
    top1 = AverageMeter()

    model.eval()
    start = time.time()
    for i, (images, target) in enumerate(val_loader):

        images = images.cuda()
        target = target.cuda()

        # PGD relies on input gradients, so perturbation runs outside no_grad.
        input_adv = adversary.perturb(images, target)
        # compute output on the adversarial batch
        with torch.no_grad():
            output = model(input_adv)
            loss = criterion(output, target)

        output = output.float()
        loss = loss.float()

        # measure accuracy and record loss
        prec1 = accuracy(output.data, target)[0]
        losses.update(loss.item(), images.size(0))
        top1.update(prec1.item(), images.size(0))

        if i % args.print_freq == 0:
            end = time.time()
            print_and_log('Test: [{0}/{1}]\t'
                'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                'Accuracy {top1.val:.3f} ({top1.avg:.3f})\t'
                'Time {2:.2f}'.format(
                    i, len(val_loader), end-start, loss=losses, top1=top1))
            start = time.time()

    print_and_log('Robust Accuracy {top1.avg:.3f}'.format(top1=top1))

    return top1.avg, losses.avg



def generate_adv(val_loader, model, criterion, args):
    """
    Generate PGD adversarial examples for ``val_loader`` and evaluate on them.

    Works like ``test_adv``, but also collects every adversarial batch (and its
    targets) on the CPU. The attack is configured from ``args.norm`` ('linf'
    or 'l2'), ``args.test_eps``, ``args.test_step``, ``args.test_gamma`` and
    ``args.test_randinit``, with inputs clipped to [0, 1].

    Args:
        val_loader: iterable yielding ``(input, target)`` batches; it must yield
            at least one batch, because the collected batches are joined with
            ``torch.cat``.
        model: network under attack and evaluation (put into eval mode).
        criterion: loss used both by the attack and for reporting.
        args: namespace with the attack settings above plus ``print_freq``.

    Returns:
        tuple: ``(top1.avg, losses.avg, all_adv_image, all_target)``, where the
        last two are CPU tensors formed by concatenating all batches along dim 0.

    Raises:
        ValueError: if ``args.norm`` is neither 'linf' nor 'l2'.
    """
    # Validate the norm first so a bad value fails fast with a clear message
    # instead of an UnboundLocalError on ``adversary`` inside the loop.
    if args.norm == 'linf':
        attack_cls = LinfPGDAttack
    elif args.norm == 'l2':
        attack_cls = L2PGDAttack
    else:
        raise ValueError("Unsupported norm {!r}; expected 'linf' or 'l2'".format(args.norm))
    adversary = attack_cls(
        model, loss_fn=criterion, eps=args.test_eps, nb_iter=args.test_step, eps_iter=args.test_gamma,
        rand_init=args.test_randinit, clip_min=0.0, clip_max=1.0, targeted=False
    )

    losses = AverageMeter()
    top1 = AverageMeter()

    model.eval()
    start = time.time()

    all_adv_image = []
    all_target = []

    for i, (images, target) in enumerate(val_loader):

        images = images.cuda()
        target = target.cuda()

        # PGD relies on input gradients, so perturbation runs outside no_grad.
        input_adv = adversary.perturb(images, target)
        # compute output on the adversarial batch
        with torch.no_grad():
            output = model(input_adv)
            loss = criterion(output, target)

        # Keep a CPU copy so GPU memory does not grow with the dataset size.
        all_adv_image.append(input_adv.detach().cpu())
        all_target.append(target.cpu())

        output = output.float()
        loss = loss.float()

        # measure accuracy and record loss
        prec1 = accuracy(output.data, target)[0]
        losses.update(loss.item(), images.size(0))
        top1.update(prec1.item(), images.size(0))

        if i % args.print_freq == 0:
            end = time.time()
            print_and_log('Test: [{0}/{1}]\t'
                'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                'Accuracy {top1.val:.3f} ({top1.avg:.3f})\t'
                'Time {2:.2f}'.format(
                    i, len(val_loader), end-start, loss=losses, top1=top1))
            start = time.time()

    all_adv_image = torch.cat(all_adv_image, dim=0)
    all_target = torch.cat(all_target, dim=0)
    print('Image shape = {}, Target shape = {}'.format(all_adv_image.shape, all_target.shape))

    print_and_log('Robust Accuracy {top1.avg:.3f}'.format(top1=top1))

    return top1.avg, losses.avg, all_adv_image, all_target

def get_save_path(args):
    """Return the run-specific output directory under ``args.save_dir``.

    The directory name is chosen by the first enabled mode flag, checked in
    order ``sparse`` ('dst_...'), ``static_sparse`` ('static_sparse_...'),
    ``dynamic_sparse`` ('flying_...'), ``small_dense`` ('..._smalldense...').
    When none is set, the dense template ('dense_...') is used. Each name embeds
    the relevant hyper-parameters from ``args``. The enabled combination flags
    (``consistency``, ``lwf``, ``rslad``, ``robust_friendly``) are then
    prepended in that order, so the last one applied ends up outermost. The
    ``{flag}`` placeholder is filled with the first four hex digits of the MD5
    of ``str(args)``.
    """
    mode_templates = (
        ('sparse', 'dst_{args.arch}_{args.dataset}_d{args.density}_{args.growth}_T{args.update_frequency}_b{args.batch_size}_r{args.death_rate}_{flag}'),
        ('static_sparse', 'static_sparse_{args.arch}_{args.dataset}_{args.sparse_type}_d{args.density}_b{args.batch_size}_seed{args.seed}{flag}'),
        ('dynamic_sparse', 'flying_{args.arch}_{args.dataset}_{args.sparse_init}_T{args.update_frequency}_d{args.density}_dr{args.death_rate}_{args.growth}_p{args.prune_rate}_g{args.growth_rate}_b{args.batch_size}_e{args.epoch_range}_r{args.ratio_threshold}_seed{args.seed}{flag}'),
        ('small_dense', '{args.arch}_{args.dataset}_smalldense{args.small_dense_rate}_b{args.batch_size}_{flag}'),
    )
    dense_template = 'dense_{args.arch}_{args.dataset}_b{args.batch_size}_{flag}'

    # First truthy mode flag wins; flags are read lazily in the listed order.
    template = next(
        (tpl for attr_name, tpl in mode_templates if getattr(args, attr_name)),
        dense_template,
    )

    # Each enabled flag is prepended in turn (combination tests, lwf, rslad, ...).
    prefixes = (
        (args.consistency, 'consistency_'),
        (args.lwf, 'lwf_'),
        (args.rslad, 'rslad_'),
        (args.robust_friendly, 'robust_friendly_'),
    )
    for enabled, prefix in prefixes:
        if enabled:
            template = prefix + template

    run_tag = hashlib.md5(str(args).encode('utf-8')).hexdigest()[:4]
    run_dir = template.format(args=args, flag=run_tag)
    return os.path.join(args.save_dir, run_dir)

def apply_static_sparse(model, optimizer, sparse_type, density, seed, mask_dir=None, map_location='cuda:0'):
    """Load a saved static sparsity mask and apply it to ``model``.

    Args:
        model: network to sparsify.
        optimizer: optimizer passed through to ``apply_mask_to_model``.
        sparse_type, density, seed: used to build the default mask file name
            ``./masks_seed/{sparse_type}-{density}-seed{seed}-mask.pt``.
        mask_dir: if truthy, used directly as the full path to the mask file
            (despite the name, it is not treated as a directory).
        map_location: forwarded to ``torch.load``. It defaults to ``'cuda:0'``,
            as before, and can be overridden, e.g. when training on another GPU
            selected via ``--gpu`` or on CPU.
    """
    if mask_dir:
        mask_path = mask_dir
    else:
        masks_name = '{0}-{1}-seed{2}-mask.pt'.format(sparse_type, density, seed)
        mask_path = os.path.join('./masks_seed', masks_name)
    # NOTE(review): the mask's structure is whatever apply_mask_to_model expects.
    mask = torch.load(mask_path, map_location=map_location)
    apply_mask_to_model(model, optimizer, mask)

def get_mask(sparse_type, density, seed, mask_dir=None, map_location=None):
    """Load and return a saved sparsity mask.

    Args:
        sparse_type, density, seed: used to build the default mask file name
            ``./masks_seed/{sparse_type}-{density}-seed{seed}-mask.pt``.
        mask_dir: if truthy, used directly as the full path to the mask file
            (despite the name, it is not treated as a directory).
        map_location: forwarded to ``torch.load``. The default ``None`` keeps
            ``torch.load``'s behaviour of restoring tensors to the device they
            were saved from.

    Returns:
        Whatever object was stored in the mask file.
    """
    if mask_dir:
        mask_path = mask_dir
    else:
        masks_name = '{0}-{1}-seed{2}-mask.pt'.format(sparse_type, density, seed)
        mask_path = os.path.join('./masks_seed', masks_name)
    print(mask_path)
    mask = torch.load(mask_path, map_location=map_location)
    return mask

def get_grandient_norm(parameters):
    """Return the global L2 norm of the gradients of ``parameters`` as a float.

    Computes ``sqrt(sum_p ||p.grad||_2 ** 2)``. Parameters whose ``.grad`` is
    ``None`` (frozen, or not used in the last backward pass) are skipped rather
    than raising ``AttributeError``, the same convention as
    ``torch.nn.utils.clip_grad_norm_``. An empty iterable yields ``0.0``.
    """
    squared_sum = 0.0
    for p in parameters:
        if p.grad is None:
            continue
        squared_sum += p.grad.detach().norm(2).item() ** 2
    return squared_sum ** 0.5

class GradCosineSimilarity(object):
    """Track the cosine similarity between consecutive gradient snapshots.

    Each :meth:`update` call snapshots the current ``.grad`` of every given
    parameter. From the second call on, it computes, per parameter, the cosine
    similarity between the previous and the current gradient (each flattened
    to a vector). It averages those values over parameters and appends the
    result (a 0-dim tensor) to ``res_list``.
    """
    def __init__(self):
        # One averaged similarity per update() call after the first.
        self.res_list = []
        # Gradient snapshot taken by the previous update() call, or None.
        self.pregrads = None

    def update(self, params):
        """Snapshot gradients of ``params`` and record similarity to the previous snapshot.

        Args:
            params: iterable of tensors with populated ``.grad``. It should yield
                the same parameters in the same order on every call, because
                snapshots are paired positionally.
        """
        grad_list = [p.grad.detach().clone() for p in params]

        if self.pregrads is not None:
            # Flatten so every parameter (bias vectors, conv kernels, ...) gives
            # one scalar similarity. The default dim=1 fails on 1-D gradients and
            # yields differently-shaped tensors for other ranks.
            sims = [F.cosine_similarity(prev.flatten(), cur.flatten(), dim=0)
                    for prev, cur in zip(self.pregrads, grad_list)]
            if sims:
                self.res_list.append(sum(sims) / float(len(sims)))

        self.pregrads = grad_list
            

def input_a_sample(model, criterion, optimizer, args, data_sample):
    """Populate the model's gradients with the adversarial loss of one sample.

    Builds a PGD attack from the training attack settings (``args.norm``,
    ``args.train_eps``, ``args.train_step``, ``args.train_gamma``,
    ``args.train_randinit``; inputs clipped to [0, 1]). It perturbs the single
    ``data_sample`` with ``model`` in eval mode, then runs
    ``optimizer.zero_grad()`` and ``loss.backward()`` on the adversarial loss.
    The optimizer is NOT stepped, so only ``.grad`` fields change.

    Args:
        model: network to attack and back-propagate through.
        criterion: loss used by the attack and for the backward pass.
        optimizer: used only to zero existing gradients.
        args: namespace carrying the attack settings above.
        data_sample: ``(input, target)`` pair for a single example without a
            batch dimension; ``target`` must be convertible with ``int()``.

    Raises:
        ValueError: if ``args.norm`` is neither 'linf' nor 'l2'.
    """
    # Fail fast on an unsupported norm instead of an UnboundLocalError later.
    if args.norm == 'linf':
        attack_cls = LinfPGDAttack
    elif args.norm == 'l2':
        attack_cls = L2PGDAttack
    else:
        raise ValueError("Unsupported norm {!r}; expected 'linf' or 'l2'".format(args.norm))
    adversary = attack_cls(
        model, loss_fn=criterion, eps=args.train_eps, nb_iter=args.train_step, eps_iter=args.train_gamma,
        rand_init=args.train_randinit, clip_min=0.0, clip_max=1.0, targeted=False
    )

    model.eval()
    image, label = data_sample

    # Add a batch dimension of size 1. Build the label directly as int64
    # rather than going through a float32 tensor.
    batch = image.unsqueeze(dim=0).cuda()
    target = torch.tensor([int(label)], dtype=torch.long).cuda()

    # adv samples: parameter grads are disabled while crafting the perturbation
    with ctx_noparamgrad(model):
        input_adv = adversary.perturb(batch, target)
    # compute output and back-propagate the adversarial loss
    output_adv = model(input_adv)
    loss = criterion(output_adv, target)

    optimizer.zero_grad()
    loss.backward()

def add_default_setting(parser):
    ########################## data setting ##########################
    parser.add_argument('--data', type=str, default='data/cifar10', help='location of the data corpus', required=True)
    parser.add_argument('--dataset', type=str, default='cifar10', help='dataset [cifar10, cifar100, tinyimagenet]', required=True)

    ########################## model setting ##########################
    parser.add_argument('--arch', type=str, default='resnet18', help='model architecture [resnet18, wideresnet, vgg16]', required=True)
    parser.add_argument('--depth_factor', default=34, type=int, help='depth-factor of wideresnet')
    parser.add_argument('--width_factor', default=10, type=int, help='width-factor of wideresnet')

    ########################## basic setting ##########################
    parser.add_argument('--seed', default=1, type=int, help='random seed')
    parser.add_argument('--gpu', type=int, default=0, help='gpu device id')
    parser.add_argument('--resume', action="store_true", help="resume from checkpoint")
    parser.add_argument('--resume_dir', help='The directory resume the trained models', default=None, type=str)
    parser.add_argument('--pretrained', default=None, type=str, help='pretrained model')
    parser.add_argument('--eval', action="store_true", help="evaluation pretrained model")
    parser.add_argument('--print_freq', default=50, type=int, help='logging frequency during training')
    parser.add_argument('--save_dir', help='The parent directory used to save the trained models', default=None, type=str)

    ########################## training setting ##########################
    parser.add_argument('--batch_size', type=int, default=128, help='batch size')
    parser.add_argument('--lr', default=0.1, type=float, help='initial learning rate')
    parser.add_argument('--decreasing_lr', default='100,150', help='decreasing strategy')
    parser.add_argument('--momentum', default=0.9, type=float, help='momentum')
    parser.add_argument('--weight_decay', default=5e-4, type=float, help='weight decay')
    parser.add_argument('--epochs', default=200, type=int, help='number of total epochs to run')

    ########################## attack setting ##########################
    parser.add_argument('--norm', default='linf', type=str, help='linf or l2')
    parser.add_argument('--train_eps', default=8, type=float, help='epsilon of attack during training')
    parser.add_argument('--train_step', default=10, type=int, help='itertion number of attack during training')
    parser.add_argument('--train_gamma', default=2, type=float, help='step size of attack during training')
    parser.add_argument('--train_randinit', action='store_false', help='randinit usage flag (default: on)')
    parser.add_argument('--test_eps', default=8, type=float, help='epsilon of attack during testing')
    parser.add_argument('--test_step', default=20, type=int, help='itertion number of attack during testing')
    parser.add_argument('--test_gamma', default=2, type=float, help='step size of attack during testing')
    parser.add_argument('--test_randinit', action='store_false', help='randinit usage flag (default: on)')

    ########################## SWA setting ##########################
    parser.add_argument('--swa', action='store_true', help='swa usage flag (default: off)')
    parser.add_argument('--swa_start', type=float, default=55, metavar='N', help='SWA start epoch number (default: 55)')
    parser.add_argument('--swa_c_epochs', type=int, default=1, metavar='N', help='SWA model collection frequency/cycle length in epochs (default: 1)')

    ########################## KD setting ##########################
    parser.add_argument('--lwf', action='store_true', help='lwf usage flag (default: off)')
    parser.add_argument('--t_weight1', type=str, default=None, required=False, help='pretrained weight for teacher1')
    parser.add_argument('--t_weight2', type=str, default=None, required=False, help='pretrained weight for teacher2')
    parser.add_argument('--coef_ce', type=float, default=0.3, help='coef for CE')
    parser.add_argument('--coef_kd1', type=float, default=0.1, help='coef for KD1')
    parser.add_argument('--coef_kd2', type=float, default=0.6, help='coef for KD2')
    parser.add_argument('--temperature', type=float, default=2.0, help='temperature of knowledge distillation loss')
    parser.add_argument('--lwf_start', type=int, default=0, metavar='N', help='start point of lwf (default: 200)')
    parser.add_argument('--lwf_end', type=int, default=200, metavar='N', help='end point of lwf (default: 200)')

    ########################## sparse setting ##########################
    # parser.add_argument('--no_exploration', action='store_true', default=False, help='if ture, only do explore for the typical training time')
    # parser.add_argument('--multiplier', type=int, default=1, metavar='N', help='extend training time by multiplier times')
    parser.add_argument('--decay-schedule', type=str, default='cosine', help='The decay schedule for the pruning rate. Default: cosine. Choose from: cosine, linear.')
    parser.add_argument('--update_frequency', type=int, default=100, metavar='N', help='how many iterations to train between mask update')

    parser.add_argument('--growth', type=str, default='random', help='Growth mode. Choose from: momentum, random, and momentum_neuron.')
    parser.add_argument('--death', type=str, default='magnitude', help='Death mode / pruning mode. Choose from: magnitude, SET, threshold, CS_death.')
    parser.add_argument('--redistribution', type=str, default='none', help='Redistribution mode. Choose from: momentum, magnitude, nonzeros, or none.')
    parser.add_argument('--death-rate', type=float, default=0.50, help='The pruning rate / death rate.')
    parser.add_argument('--density', type=float, default=0.05, help='The density of the overall sparse network.')

    # parser.add_argument('--final_density', type=float, default=0.05, help='The density of the overall sparse network.')
    parser.add_argument('--sparse', action='store_true', help='Enable sparse mode. Default: True.')
    parser.add_argument('--snip', action='store_true', help='Enable snip initialization. Default: True.')
    parser.add_argument('--fix', action='store_true', help='Fix topology during training. Default: True.')
    parser.add_argument('--sparse_init', type=str, default='uniform', help='sparse initialization')
    parser.add_argument('--reset', action='store_true', help='Fix topology during training. Default: True.')

    ########################## static sparse setting ##########################
    parser.add_argument('--static_sparse', action='store_true', help='Enable static sparse mode. Default: True.')
    parser.add_argument('--sparse_type', type=str, default='rp', help='static sparse mask initialization. choose from: rp omp gmp tp snip')

    ########################## Dynamic  sparse ##########################
    parser.add_argument('--dynamic_sparse', action='store_true', help='Enable dynamic sparse mode. Default: True.')
    parser.add_argument('--epoch_range', type=int, default=4, help='epoch range to decide sparse action')
    parser.add_argument('--prune_rate', type=float, default=0.4, help='The rate of dst prune ')
    parser.add_argument('--growth_rate', type=float, default=0.05, help='The rate of dst growth ')
    parser.add_argument('--ratio_threshold', type=float, default=0.5, help='The ratio_threshold of dst prune or growth')

    ########################## Small Dense Test #############################
    parser.add_argument('--small_dense', action='store_true', help='Enable small dense mode. Default: True.')
    parser.add_argument('--small_dense_rate', type=float, default=0.8, help='The density of small density, support 0.05 0.1 0.2 0.6 0.8')

    ########################## Dynamic frequency #############################
    parser.add_argument('--dynamic_fre', action='store_true', help='Enable dynamic frequency mode. Default: True.')
    parser.add_argument('--second_frequency', type=int, default=1200, metavar='N', help='how many iterations to train between mask update in second stage')

    ########################## Save epoch #############################
    parser.add_argument('--save_epoch', action='store_true', help='save checkpoint for every epoch. Default: True.')

    ########################## Combine test #############################
    parser.add_argument('--consistency', action='store_true', help='apply consistency regularization')
    parser.add_argument('--rslad', action='store_true', help='apply RSLAD method')
    parser.add_argument('--rslad_teacher', type=str, default=None, required=False, help='rslad teacher path')
    parser.add_argument('--robust_friendly', action='store_true', help='apply robust friendly dataset')
    parser.add_argument('--mask_dir', help='The directory of mask', default=None, type=str)