import datetime
import sys
import math
import os
import time
import warnings
import copy

import presets
import torch
import torch.utils.data
import torchvision
import transforms
import utils
from sampler import RASampler
from torch import nn
from torch.utils.data.dataloader import default_collate
from torchvision.transforms.functional import InterpolationMode
from collections import OrderedDict
from caltech_dataset import Caltech
import torch.nn.utils.prune as prune
import numpy as np
import statistics
import logging

import torch
import torch.utils.data

def compute_pruning_amount(args, it, prev_target_ratio, prev_num_pruned):
    """Return the pruning amount scheduled for training iteration ``it``.

    The target follows a cubic ramp that starts at 0 on
    ``args.prune_start_it`` and reaches the final target on
    ``args.prune_end_it``::

        progress   = (it - prune_start_it) / (prune_end_it - prune_start_it)
        target(it) = final * (1 - (1 - progress) ** 3)

    When there is no ramp to follow (``prune_end_it == 0`` or the start and
    end iterations coincide) the final target is returned directly, i.e.
    one-shot pruning.

    Args:
        args: Namespace providing ``prune_start_it``, ``prune_end_it``,
            ``prune_freq_it``, ``prune_ratio`` and ``prune_num``.
        it: Current training iteration.
        prev_target_ratio: Unused; kept for interface compatibility.
        prev_num_pruned: Unused; kept for interface compatibility.

    Returns:
        The truncated ``int`` number of parameters to prune when
        ``args.prune_num`` is non-zero, otherwise the ``float`` fraction
        derived from ``args.prune_ratio``.
    """
    schedule_len = args.prune_end_it - args.prune_start_it
    assert schedule_len % args.prune_freq_it == 0, (
        "prune_end_it - prune_start_it must be a multiple of prune_freq_it"
    )

    # One-shot pruning. The ``schedule_len == 0`` case used to fall through to
    # the ramp formulas below and divide by zero.
    if args.prune_end_it == 0 or schedule_len == 0:
        return args.prune_ratio if args.prune_num == 0 else args.prune_num

    # Fraction of the ramp that is still ahead of us (1 at start, 0 at end).
    remaining = 1. - (it - args.prune_start_it) / schedule_len

    if args.prune_num != 0:
        # Absolute parameter budget, truncated to an integer count.
        return int(args.prune_num * (1. - remaining ** 3))
    # Relative budget: prune_ratio * (1 - remaining ** 3), kept in the original
    # operation order so the floating-point result is unchanged.
    return args.prune_ratio + (0. - args.prune_ratio) * remaining ** 3

def _pruned_tensor_name(args, name, module):
    """Return the attribute name of ``module``'s prunable tensor, or ``None``.

    For models whose name contains ``'vit'`` the prunable tensors are the
    ``in_proj_weight`` of every ``nn.MultiheadAttention`` and the ``weight``
    of every ``nn.Linear`` whose module name contains ``'mlp'``; for all
    other models it is the ``weight`` of every ``nn.Conv2d``.
    """
    if 'vit' in args.model:
        if isinstance(module, torch.nn.MultiheadAttention):
            return 'in_proj_weight'
        if isinstance(module, torch.nn.Linear) and 'mlp' in name:
            return 'weight'
        return None
    if isinstance(module, torch.nn.Conv2d):
        return 'weight'
    return None


def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema=None, scaler=None, params_to_prune=None, it=None, prev_target_ratio=0, prev_num_pruned=0, n=None, num_total=0):
    """Train ``model`` for one pass over ``data_loader``, pruning on schedule.

    Args:
        model: Network to train (switched to train mode here).
        criterion: Loss callable applied to ``(output, target)``.
        optimizer: Optimizer stepped once per batch.
        data_loader: Iterable yielding ``(image, target)`` batches.
        device: Device the batches are moved to.
        epoch: Current epoch index (used for logging and EMA warm-up).
        args: Namespace; reads ``print_freq``, ``clip_grad_norm``,
            ``model_ema_steps``, ``lr_warmup_epochs``, ``model`` and the
            ``prune_*`` schedule options.
        model_ema: Optional EMA wrapper updated every ``model_ema_steps``.
        scaler: Optional ``GradScaler``; when given, autocast is enabled.
        params_to_prune: ``(module, tensor_name)`` pairs handed to
            ``prune.global_unstructured``.
        it: Global iteration counter at the start of the epoch. ``None`` is
            treated as 0.
        prev_target_ratio: Global sparsity reached by the last pruning step.
        prev_num_pruned: Number of zeroed weights after the last pruning step.
        n: Phase index; pruning is only performed when ``n == 0``.
        num_total: Total number of prunable weights (denominator of the
            global sparsity).

    Returns:
        Tuple ``(prev_target_ratio, it, prev_num_pruned)`` with the updated
        sparsity ratio, iteration counter and pruned-weight count.
    """
    # The default ``it=None`` would otherwise break ``it += 1`` below.
    if it is None:
        it = 0

    model.train()
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
    metric_logger.add_meter("img/s", utils.SmoothedValue(window_size=10, fmt="{value}"))

    header = f"Epoch: [{epoch}]"
    for i, (image, target) in enumerate(metric_logger.log_every(data_loader, args.print_freq, header)):
        # Pruning step: only in phase 0, every ``prune_freq_it`` iterations
        # inside the [prune_start_it, prune_end_it] window.
        if n == 0 and (it >= args.prune_start_it) and (it % args.prune_freq_it == 0) and (it <= args.prune_end_it):
            amount = compute_pruning_amount(args, it, prev_target_ratio, prev_num_pruned)

            prune.global_unstructured(
                params_to_prune,
                pruning_method=prune.L1Unstructured,
                amount=amount,)

            # Compute and log layer-wise sparsity.
            num_pruned = 0
            for name, module in model.named_modules():
                tensor_name = _pruned_tensor_name(args, name, module)
                if tensor_name is None:
                    continue
                weight = getattr(module, tensor_name)
                p = float(torch.sum(weight == 0))
                t = float(weight.nelement())
                num_pruned += p
                logging.info("Sparsity in {}.weight: {:.2f}%".format(name, 100. * p / t))
                # Before the final pruning step the re-parametrization is
                # dropped again, so the mask is not kept between steps; the
                # last step (it == prune_end_it) keeps its mask.
                if it < args.prune_end_it:
                    prune.remove(module, tensor_name)
            logging.info("Global pruned parameters: {}".format(num_pruned))
            prev_num_pruned = num_pruned
            prev_target_ratio = num_pruned / num_total
            logging.info("Global sparsity: {}%".format(100. * prev_target_ratio))

        start_time = time.time()
        image, target = image.to(device), target.to(device)
        with torch.cuda.amp.autocast(enabled=scaler is not None):
            output = model(image)
            loss = criterion(output, target)

        optimizer.zero_grad()
        if scaler is not None:
            scaler.scale(loss).backward()
            if args.clip_grad_norm is not None:
                # we should unscale the gradients of optimizer's assigned params if do gradient clipping
                scaler.unscale_(optimizer)
                nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            if args.clip_grad_norm is not None:
                nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
            optimizer.step()

        if model_ema and i % args.model_ema_steps == 0:
            model_ema.update_parameters(model)
            if epoch < args.lr_warmup_epochs:
                # Reset ema buffer to keep copying weights during warmup period
                model_ema.n_averaged.fill_(0)

        acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
        batch_size = image.shape[0]
        metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])
        metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
        metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
        metric_logger.meters["img/s"].update(batch_size / (time.time() - start_time))

        it += 1
    return prev_target_ratio, it, prev_num_pruned


def evaluate(model, criterion, data_loader, device, print_freq=100, log_suffix=""):
    """Evaluate ``model`` on ``data_loader`` and return the global top-1 accuracy.

    Args:
        model: Network to evaluate (switched to eval mode here).
        criterion: Loss callable applied to ``(output, target)``.
        data_loader: Iterable yielding ``(image, target)`` batches.
        device: Device the batches are moved to.
        print_freq: Logging period passed to ``MetricLogger.log_every``.
        log_suffix: Text appended to the ``"Test: "`` log header.

    Returns:
        ``metric_logger.acc1.global_avg`` after synchronising the meters
        between processes.
    """
    model.eval()
    metric_logger = utils.MetricLogger(delimiter="  ")
    header = f"Test: {log_suffix}"

    num_processed_samples = 0
    with torch.inference_mode():
        for image, target in metric_logger.log_every(data_loader, print_freq, header):
            image = image.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)
            output = model(image)
            loss = criterion(output, target)

            acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
            # FIXME need to take into account that the datasets
            # could have been padded in distributed setup
            batch_size = image.shape[0]
            metric_logger.update(loss=loss.item())
            metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
            metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
            num_processed_samples += batch_size

    # gather the stats from all processes
    num_processed_samples = utils.reduce_across_processes(num_processed_samples)
    # ``utils.is_main_process()`` replaces a direct
    # ``torch.distributed.get_rank()`` call, which raises when no process group
    # has been initialised (single-process runs).
    if (
        hasattr(data_loader.dataset, "__len__")
        and len(data_loader.dataset) != num_processed_samples
        and utils.is_main_process()
    ):
        # See FIXME above
        warnings.warn(
            f"It looks like the dataset has {len(data_loader.dataset)} samples, but {num_processed_samples} "
            "samples were used for the validation, which might bias the results. "
            "Try adjusting the batch size and / or the world size. "
            "Setting the world size to 1 is always a safe bet."
        )

    metric_logger.synchronize_between_processes()

    logging.info(f"{header} Acc@1 {metric_logger.acc1.global_avg:.3f} Acc@5 {metric_logger.acc5.global_avg:.3f}")
    return metric_logger.acc1.global_avg


def _build_dataset(args, train, transform):
    """Create the train or test split of the dataset named by ``args.data_path``.

    The dataset is selected by substring match on ``args.data_path``
    (``'cifar100'`` is tested before ``'cifar10'`` because the latter is a
    substring of the former).

    Raises:
        ValueError: If ``args.data_path`` matches none of the known datasets.
    """
    path = args.data_path
    if 'cars' in path:
        return torchvision.datasets.StanfordCars(
            path, split='train' if train else 'test', download=True, transform=transform)
    if 'caltech' in path:
        caltech_dir = os.path.join(path, 'Caltech101')
        return Caltech(caltech_dir, split='train' if train else 'test', transform=transform, n=30, seed=0)
    if 'cifar100' in path:
        return torchvision.datasets.CIFAR100(path, train=train, download=True, transform=transform)
    if 'cifar10' in path:
        return torchvision.datasets.CIFAR10(path, train=train, download=True, transform=transform)
    if 'pets' in path:
        return torchvision.datasets.OxfordIIITPet(
            path, split='trainval' if train else 'test', download=True, transform=transform)
    raise ValueError(
        f"Cannot infer the dataset from data path '{path}'. Expected it to contain one of: "
        "'cars', 'caltech', 'cifar100', 'cifar10', 'pets'."
    )


def load_data(args):
    """Build the train/test datasets and their samplers.

    Args:
        args: Namespace; reads ``data_path``, ``val_resize_size``,
            ``val_crop_size``, ``train_crop_size``, ``interpolation``,
            ``weights``, ``test_only``, ``torch_seed``, ``distributed`` and,
            optionally, ``auto_augment``, ``random_erase``, ``ra_sampler``
            and ``ra_reps``.

    Returns:
        Tuple ``(dataset, dataset_test, train_sampler, test_sampler)``.

    Raises:
        ValueError: If ``args.data_path`` does not name a supported dataset
            (previously this surfaced as an ``UnboundLocalError``).
    """
    # Data loading code
    logging.info("Loading data")
    val_resize_size, val_crop_size, train_crop_size = args.val_resize_size, args.val_crop_size, args.train_crop_size
    interpolation = InterpolationMode(args.interpolation)

    logging.info("Loading training data")
    st = time.time()
    transform_train = presets.ClassificationPresetTrain(
        crop_size=train_crop_size,
        interpolation=interpolation,
        auto_augment_policy=getattr(args, "auto_augment", None),
        random_erase_prob=getattr(args, "random_erase", 0.0),
    )
    dataset = _build_dataset(args, train=True, transform=transform_train)
    logging.info(f"Took {time.time() - st}")

    logging.info("Loading validation data")
    if args.weights and args.test_only:
        # Evaluate pretrained weights with the preprocessing bundled with them.
        weights = torchvision.models.get_weight(args.weights)
        preprocessing = weights.transforms()
    else:
        preprocessing = presets.ClassificationPresetEval(
            crop_size=val_crop_size, resize_size=val_resize_size, interpolation=interpolation
        )
    dataset_test = _build_dataset(args, train=False, transform=preprocessing)

    logging.info("Creating data loaders")
    # Seed before the samplers are created.
    if args.torch_seed is not None:
        torch.manual_seed(args.torch_seed)
    if args.distributed:
        if getattr(args, "ra_sampler", False):
            train_sampler = RASampler(dataset, shuffle=True, repetitions=args.ra_reps)
        else:
            train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
        test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test, shuffle=False)
    else:
        train_sampler = torch.utils.data.RandomSampler(dataset)
        test_sampler = torch.utils.data.SequentialSampler(dataset_test)

    return dataset, dataset_test, train_sampler, test_sampler


def main(args):
    """Run ``args.num_exp`` prune-then-retrain experiments and log accuracy statistics.

    Every experiment consists of two phases, each starting from a freshly
    constructed model:

    * ``n == 0`` ("Gradual pruning phase"): the initial weights are saved to
      ``w0.pth`` and the model is trained with learning rate ``args.lr_mask``
      while ``train_one_epoch`` prunes it on schedule. The resulting
      checkpoint is written to ``checkpoint_{exp}.pth``.
    * ``n == 1`` ("Sparse training phase"): the weights from ``w0.pth`` and the
      masks from ``checkpoint_{exp}.pth`` are loaded and the model is trained
      with learning rate ``args.lr``. Its final test accuracy is recorded and
      ``checkpoint_{exp}.pth`` is overwritten.

    After all experiments the mean and variance of the recorded accuracies
    are logged. With ``args.test_only`` the function evaluates once and
    returns.

    Args:
        args: Parsed command-line namespace (see ``get_args_parser``). It is
            mutated in place (``epochs``, ``lr_mask``, ``lr_scheduler``,
            ``start_epoch``).
    """
    if args.output_dir:
        utils.mkdir(args.output_dir)

    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s | %(message)s",
        handlers=[
            logging.FileHandler(os.path.join(args.output_dir, 'training.log')),
            logging.StreamHandler()
        ])

    utils.init_distributed_mode(args)
    utils.set_lr_and_wd(args)
    utils.set_pretrained_gmp_lr_and_freq(args)
    logging.info(args)

    device = torch.device(args.device)

    if args.use_deterministic_algorithms:
        torch.backends.cudnn.benchmark = False
        torch.use_deterministic_algorithms(True)
    else:
        torch.backends.cudnn.benchmark = True

    dataset, dataset_test, train_sampler, test_sampler = load_data(args)

    # ``iters_per_epoch`` is needed by the training loop even when
    # ``--iterations`` is not given, so it is always computed. ``max(1, ...)``
    # keeps the division valid when no CUDA device is visible.
    num_samples = len(dataset)
    iters_per_epoch = math.ceil(num_samples / (args.batch_size * max(1, torch.cuda.device_count())))
    if args.iterations:
        # Convert the iteration budget into a number of epochs.
        args.epochs = math.ceil(args.iterations / iters_per_epoch)
        # NOTE(review): this tests divisibility by ``args.epochs``; it looks
        # like ``iters_per_epoch`` may have been intended - confirm.
        if args.iterations % args.epochs == 0:
            args.epochs += 1
        logging.info('Change epochs to %d'%args.epochs)

    collate_fn = None
    num_classes = len(dataset.classes)
    logging.info("# of classes: %d"%num_classes)
    mixup_transforms = []
    if args.mixup_alpha > 0.0:
        mixup_transforms.append(transforms.RandomMixup(num_classes, p=1.0, alpha=args.mixup_alpha))
    if args.cutmix_alpha > 0.0:
        mixup_transforms.append(transforms.RandomCutmix(num_classes, p=1.0, alpha=args.cutmix_alpha))
    if mixup_transforms:
        mixupcutmix = torchvision.transforms.RandomChoice(mixup_transforms)
        collate_fn = lambda batch: mixupcutmix(*default_collate(batch))  # noqa: E731
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        sampler=train_sampler,
        num_workers=args.workers,
        pin_memory=True,
        collate_fn=collate_fn,
    )
    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=args.batch_size, sampler=test_sampler, num_workers=args.workers, pin_memory=True
    )

    test_accs = []
    if args.lr_mask is None:
        args.lr_mask =  args.lr
    # The sparse-training phase may overwrite ``args.start_epoch``; remember
    # the configured value so that it does not leak into later phases.
    base_start_epoch = getattr(args, "start_epoch", 0)
    for exp in range(args.num_exp):
        logging.info("="*10 + " Exp: %d" %exp + "="*10)
        for n in range(2):
            args.start_epoch = base_start_epoch
            # Phase 0 learns the mask, phase 1 trains the sparse network.
            if n == 0:
                learning_rate = args.lr_mask
            else:
                learning_rate = args.lr
            logging.info("Creating model")
            it = 0
            if args.weights:
                if args.model == 'inception_v3':
                    model = torchvision.models.__dict__[args.model](weights=args.weights, aux_logits=False, num_classes=1000)
                else:
                    model = torchvision.models.__dict__[args.model](weights=args.weights, num_classes=1000)
                # Replace the 1000-way classification head with a
                # ``num_classes``-way one.
                if args.model in ['mobilenet_v2', 'mnasnet1_0', 'mnasnet0_5']:
                    num_ftrs = model.classifier[-1].in_features
                    model.classifier[-1] = nn.Linear(num_ftrs, num_classes)
                elif args.model in ['densenet121', 'densenet169', 'densenet201']:
                    num_ftrs = model.classifier.in_features
                    # Assign to the model attribute (the original bound a
                    # throwaway local and left the head untouched).
                    model.classifier = nn.Linear(num_ftrs, num_classes)
                elif args.model in ['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'googlenet', 'inception_v3']:
                    num_ftrs = model.fc.in_features
                    model.fc = nn.Linear(num_ftrs, num_classes)
                elif 'vit' in args.model:
                    heads_layers: OrderedDict[str, nn.Module] = OrderedDict()
                    heads_layers["head"] = nn.Linear(model.hidden_dim, num_classes)
                    model.heads = nn.Sequential(heads_layers)
                else:
                    # try your customized model
                    raise NotImplementedError
            else:
                if args.model == 'inception_v3':
                    model = torchvision.models.__dict__[args.model](aux_logits=False, num_classes=num_classes)
                else:
                    model = torchvision.models.__dict__[args.model](num_classes=num_classes)

            # Collect the tensors subject to pruning and count their elements.
            params_to_prune = []
            num_total = 0
            for name, module in model.named_modules():
                if 'vit' in args.model:
                    if isinstance(module, torch.nn.MultiheadAttention):
                        params_to_prune.append((module, 'in_proj_weight'))
                        t = float(module.in_proj_weight.nelement())
                        num_total += t
                    elif isinstance(module, torch.nn.Linear) and 'mlp' in name:
                        params_to_prune.append((module, 'weight'))
                        t = float(module.weight.nelement())
                        num_total += t
                else:
                    if isinstance(module, torch.nn.Conv2d):
                        params_to_prune.append((module, 'weight'))
                        t = float(module.weight.nelement())
                        num_total += t

            if n == 0:
                logging.info("="*10 + " Gradual pruning phase " + "="*10)
                checkpoint = {"model": model.state_dict()}
                utils.save_on_master(checkpoint, os.path.join(args.output_dir, "w0.pth"))
            else:
                logging.info("="*10 + " Sparse training phase " + "="*10)
                # Load weights from the initial model.
                ckpt = os.path.join(args.output_dir, 'w0.pth')
                if os.path.isfile(ckpt):
                    state_dict = torch.load(ckpt, map_location="cpu")
                    state_dict = state_dict["model"]
                    msg = model.load_state_dict(state_dict, strict=True)
                    logging.info("Load pretrained model with msg: {}".format(msg))
                else:
                    logging.info("=> no checkpoint found at '{}'".format(ckpt))
                # Load the mask from the pruning-phase model. Pruning with
                # amount=0 first registers the mask buffers so that the
                # '*mask' entries of the checkpoint have somewhere to go.
                prune.global_unstructured(
                    params_to_prune,
                    pruning_method=prune.L1Unstructured,
                    amount=0,)
                ckpt = os.path.join(args.output_dir, f'checkpoint_{exp}.pth')
                if os.path.isfile(ckpt):
                    state_dict = torch.load(ckpt, map_location="cpu")
                    state_dict = state_dict["model"]
                    state_dict_new = {}
                    for k, v in state_dict.items():
                        if 'mask' in k:
                            state_dict_new[k] = v
                    msg = model.load_state_dict(state_dict_new, strict=False)
                    logging.info("Load pretrained mask with msg: {}".format(msg))
                    # Presumably re-applied so the masked weight attributes
                    # reflect the loaded masks before they are inspected
                    # below - TODO confirm.
                    prune.global_unstructured(
                        params_to_prune,
                        pruning_method=prune.L1Unstructured,
                        amount=0,)
                else:
                    logging.info("=> no checkpoint found at '{}'".format(ckpt))
                # Compute layer-wise sparsity. ``t`` is recomputed for every
                # module; reusing the value left over from the loop above made
                # the logged percentages wrong.
                num_pruned = 0
                for name, module in model.named_modules():
                    if 'vit' in args.model:
                        if isinstance(module, torch.nn.MultiheadAttention):
                            p = float(torch.sum(module.in_proj_weight == 0))
                            t = float(module.in_proj_weight.nelement())
                            num_pruned += p
                            logging.info("Sparsity in {}.weight: {:.2f}%".format(name, 100. * p / t))
                        elif isinstance(module, torch.nn.Linear) and 'mlp' in name:
                            p = float(torch.sum(module.weight == 0))
                            t = float(module.weight.nelement())
                            num_pruned += p
                            logging.info("Sparsity in {}.weight: {:.2f}%".format(name, 100. * p / t))
                    else:
                        if isinstance(module, torch.nn.Conv2d):
                            p = float(torch.sum(module.weight == 0))
                            t = float(module.weight.nelement())
                            num_pruned += p
                            logging.info("Sparsity in {}.weight: {:.2f}%".format(name, 100. * p / t))
                logging.info("Global pruned parameters: {}".format(num_pruned))
                logging.info("Global sparsity: {}%".format(100. * num_pruned / num_total))

            model.to(device)

            if args.distributed and args.sync_bn:
                model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

            criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)

            custom_keys_weight_decay = []
            if args.bias_weight_decay is not None:
                custom_keys_weight_decay.append(("bias", args.bias_weight_decay))
            if args.transformer_embedding_decay is not None:
                for key in ["class_token", "position_embedding", "relative_position_bias_table"]:
                    custom_keys_weight_decay.append((key, args.transformer_embedding_decay))
            parameters = utils.set_weight_decay(
                model,
                args.weight_decay,
                norm_weight_decay=args.norm_weight_decay,
                custom_keys_weight_decay=custom_keys_weight_decay if len(custom_keys_weight_decay) > 0 else None,
            )

            opt_name = args.opt.lower()
            if opt_name.startswith("sgd"):
                optimizer = torch.optim.SGD(
                    parameters,
                    lr=learning_rate,
                    momentum=args.momentum,
                    weight_decay=args.weight_decay,
                    nesterov="nesterov" in opt_name,
                )
            elif opt_name == "rmsprop":
                optimizer = torch.optim.RMSprop(
                    parameters, lr=learning_rate, momentum=args.momentum, weight_decay=args.weight_decay, eps=0.0316, alpha=0.9
                )
            elif opt_name == "adamw":
                optimizer = torch.optim.AdamW(parameters, lr=learning_rate, weight_decay=args.weight_decay)
            else:
                raise RuntimeError(f"Invalid optimizer {args.opt}. Only SGD, RMSprop and AdamW are supported.")

            scaler = torch.cuda.amp.GradScaler() if args.amp else None

            args.lr_scheduler = args.lr_scheduler.lower()
            if args.lr_scheduler == "steplr":
                main_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)
            elif args.lr_scheduler == "cosineannealinglr":
                main_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                    optimizer, T_max=args.epochs - args.lr_warmup_epochs, eta_min=args.lr_min
                )
            elif args.lr_scheduler == "exponentiallr":
                main_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=args.lr_gamma)
            else:
                raise RuntimeError(
                    f"Invalid lr scheduler '{args.lr_scheduler}'. Only StepLR, CosineAnnealingLR and ExponentialLR "
                    "are supported."
                )

            if args.lr_warmup_epochs > 0:
                if args.lr_warmup_method == "linear":
                    warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(
                        optimizer, start_factor=args.lr_warmup_decay, total_iters=args.lr_warmup_epochs
                    )
                elif args.lr_warmup_method == "constant":
                    warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(
                        optimizer, factor=args.lr_warmup_decay, total_iters=args.lr_warmup_epochs
                    )
                else:
                    raise RuntimeError(
                        f"Invalid warmup lr method '{args.lr_warmup_method}'. Only linear and constant are supported."
                    )
                lr_scheduler = torch.optim.lr_scheduler.SequentialLR(
                    optimizer, schedulers=[warmup_lr_scheduler, main_lr_scheduler], milestones=[args.lr_warmup_epochs]
                )
            else:
                lr_scheduler = main_lr_scheduler

            model_without_ddp = model
            if args.distributed:
                model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
                model_without_ddp = model.module

            model_ema = None
            if args.model_ema:
                # Decay adjustment that aims to keep the decay independent from other hyper-parameters originally proposed at:
                # https://github.com/facebookresearch/pycls/blob/f8cd9627/pycls/core/net.py#L123
                #
                # total_ema_updates = (Dataset_size / n_GPUs) * epochs / (batch_size_per_gpu * EMA_steps)
                # We consider constant = Dataset_size for a given dataset/setup and omit it. Thus:
                # adjust = 1 / total_ema_updates ~= n_GPUs * batch_size_per_gpu * EMA_steps / epochs
                adjust = args.world_size * args.batch_size * args.model_ema_steps / args.epochs
                alpha = 1.0 - args.model_ema_decay
                alpha = min(1.0, alpha * adjust)
                model_ema = utils.ExponentialMovingAverage(model_without_ddp, device=device, decay=1.0 - alpha)

            if args.resume:
                checkpoint = torch.load(args.resume, map_location="cpu")
                model_without_ddp.load_state_dict(checkpoint["model"])
                if not args.test_only:
                    optimizer.load_state_dict(checkpoint["optimizer"])
                    lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
                args.start_epoch = checkpoint["epoch"] + 1
                if model_ema:
                    model_ema.load_state_dict(checkpoint["model_ema"])
                if scaler:
                    scaler.load_state_dict(checkpoint["scaler"])

            if args.test_only:
                # We disable the cudnn benchmarking because it can noticeably affect the accuracy
                torch.backends.cudnn.benchmark = False
                torch.backends.cudnn.deterministic = True
                if model_ema:
                    evaluate(model_ema, criterion, data_loader_test, device=device, log_suffix="EMA")
                else:
                    evaluate(model, criterion, data_loader_test, device=device)
                return

            # When training from scratch, the sparse phase resumes from the
            # training state stored in w0.pth during the pruning phase
            # (optimizer, scheduler, epoch and iteration).
            if n == 1 and args.weights is None:
                ckpt = os.path.join(args.output_dir, 'w0.pth')
                if os.path.isfile(ckpt):
                    checkpoint = torch.load(ckpt, map_location="cpu")
                    optimizer.load_state_dict(checkpoint["optimizer"])
                    lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
                    args.start_epoch = checkpoint["epoch"]
                    it = checkpoint["iteration"]
                else:
                    logging.info("=> no checkpoint found at '{}'".format(ckpt))

            logging.info("Start training")
            start_time = time.time()
            num_pruned = 0
            prev_target_ratio = 0
            for epoch in range(args.start_epoch, args.epochs):
                # When training from scratch, snapshot the full training state
                # at the epoch in which pruning starts; this overwrites the
                # weights-only w0.pth written above.
                if n == 0 and epoch == (args.prune_start_it // iters_per_epoch) and args.weights is None:
                    checkpoint = {
                        "model": model_without_ddp.state_dict(),
                        "optimizer": optimizer.state_dict(),
                        "lr_scheduler": lr_scheduler.state_dict(),
                        "epoch": epoch,
                        "iteration": it,
                    }
                    utils.save_on_master(checkpoint, os.path.join(args.output_dir, "w0.pth"))
                if args.distributed:
                    train_sampler.set_epoch(epoch)
                prev_target_ratio, it, num_pruned = train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema, scaler, params_to_prune, it, prev_target_ratio, num_pruned, n, num_total)
                lr_scheduler.step()
                # Evaluate roughly 50 times per run; ``max(1, ...)`` avoids a
                # modulo by zero when fewer than 50 epochs are scheduled.
                if epoch % max(1, args.epochs // 50) == 0:
                    test_acc = evaluate(model, criterion, data_loader_test, device=device)
                if model_ema:
                    evaluate(model_ema, criterion, data_loader_test, device=device, log_suffix="EMA")
                if args.output_dir:
                    checkpoint = {
                        "model": model_without_ddp.state_dict(),
                        "optimizer": optimizer.state_dict(),
                        "lr_scheduler": lr_scheduler.state_dict(),
                        "epoch": epoch,
                        "args": args,
                    }
                    if model_ema:
                        checkpoint["model_ema"] = model_ema.state_dict()
                    if scaler:
                        checkpoint["scaler"] = scaler.state_dict()
                # The pruning phase only has to produce the mask; stop once
                # the pruning schedule has finished.
                if it > args.prune_end_it and args.get_mask and n == 0:
                    break
            test_acc = evaluate(model, criterion, data_loader_test, device=device)
            if n == 1:
                test_accs.append(test_acc)
            #utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}_{exp}.pth"))
            # NOTE(review): ``checkpoint`` is only refreshed inside the epoch
            # loop when ``args.output_dir`` is truthy - confirm an empty
            # output dir is not a supported configuration.
            utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"checkpoint_{exp}.pth"))
            total_time = time.time() - start_time
            total_time_str = str(datetime.timedelta(seconds=int(total_time)))
            logging.info(f"Training time {total_time_str}")
            # Pruning sanity check after training.
            if utils.is_main_process():
                num_pruned = 0
                for name, module in model.named_modules():
                    if 'vit' in args.model:
                        if isinstance(module, torch.nn.MultiheadAttention):
                            p = float(torch.sum(module.in_proj_weight == 0))
                            t = float(module.in_proj_weight.nelement())
                            num_pruned += p
                            logging.info("Sparsity in {}.weight: {:.2f}%".format(name, 100. * p / t))
                        elif isinstance(module, torch.nn.Linear) and 'mlp' in name:
                            p = float(torch.sum(module.weight == 0))
                            t = float(module.weight.nelement())
                            num_pruned += p
                            logging.info("Sparsity in {}.weight: {:.2f}%".format(name, 100. * p / t))
                    else:
                        if isinstance(module, torch.nn.Conv2d):
                            p = float(torch.sum(module.weight == 0))
                            t = float(module.weight.nelement())
                            num_pruned += p
                            logging.info("Sparsity in {}.weight: {:.2f}%".format(name, 100. * p / t))
                logging.info("Global pruned parameters: {}".format(num_pruned))
                logging.info("Global sparsity: {:.2f}%".format(100. * num_pruned / num_total))
    # ``statistics.variance`` needs at least two samples, hence the special case.
    if len(test_accs) == 1:
        logging.info('Mean: %f, Var: %f'%(test_acc, 0))
    else:
        logging.info('Mean: %f, Var: %f'%(statistics.mean(test_accs), statistics.variance(test_accs)))


def get_args_parser(add_help=True):
    """Build the command-line argument parser for the training script.

    Args:
        add_help: Forwarded to ``argparse.ArgumentParser``; when true the
            parser gets the automatic ``-h/--help`` option.

    Returns:
        argparse.ArgumentParser: parser with every option registered. Option
        names are exposed on the parsed namespace with dashes replaced by
        underscores (e.g. ``--prune-ratio`` -> ``args.prune_ratio``), except
        where an explicit ``dest`` is given.
    """
    import argparse

    parser = argparse.ArgumentParser(description="PyTorch Classification Training", add_help=add_help)

    # ---- Pruning schedule and related experiment options ----
    # NOTE: --prune-ratio has no default, so it is None unless given on the command line.
    parser.add_argument("--prune-ratio", type=float, help="Pruning ratio")
    parser.add_argument("--prune-num", type=int, default=0, help="the number of removed parameters")
    parser.add_argument("--num-exp", type=int, default=1)
    parser.add_argument("--prune-start-it", type=int, default=0, help="Pruning start iteration")
    parser.add_argument("--prune-end-it", type=int, default=10000, help="Pruning end iteration")
    parser.add_argument("--prune-freq-it", type=int, default=100, help="Pruning frequency")
    # get_mask defaults to True, so passing --get-mask is a no-op kept for
    # compatibility with existing command lines. --no-get-mask is the only way
    # to switch it off from the CLI.
    parser.add_argument("--get-mask", action="store_true", default=True)
    parser.add_argument(
        "--no-get-mask",
        dest="get_mask",
        action="store_false",
        default=True,
        help="disable --get-mask (which is enabled by default)",
    )
    parser.add_argument("--torch-seed", type=int)
    parser.add_argument("--lr-mask", type=float, help="learning rate for mask learning phase")

    # ---- Data, model and optimisation options ----
    parser.add_argument("--data-path", default="/datasets01/imagenet_full_size/061417/", type=str, help="dataset path")
    parser.add_argument("--model", default="resnet18", type=str, help="model name")
    parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)")
    parser.add_argument(
        "-b", "--batch-size", default=32, type=int, help="images per gpu, the total batch size is $NGPU x batch_size"
    )
    parser.add_argument("--epochs", default=90, type=int, metavar="N", help="number of total epochs to run")
    parser.add_argument("--iterations", type=int)
    # Help text states the actual default (8).
    parser.add_argument(
        "-j", "--workers", default=8, type=int, metavar="N", help="number of data loading workers (default: 8)"
    )
    parser.add_argument("--opt", default="sgd", type=str, help="optimizer")
    parser.add_argument("--lr", default=0.1, type=float, help="initial learning rate")
    parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
    parser.add_argument(
        "--wd",
        "--weight-decay",
        default=1e-4,
        type=float,
        metavar="W",
        help="weight decay (default: 1e-4)",
        dest="weight_decay",
    )
    parser.add_argument(
        "--norm-weight-decay",
        default=None,
        type=float,
        help="weight decay for Normalization layers (default: None, same value as --wd)",
    )
    parser.add_argument(
        "--bias-weight-decay",
        default=None,
        type=float,
        help="weight decay for bias parameters of all layers (default: None, same value as --wd)",
    )
    parser.add_argument(
        "--transformer-embedding-decay",
        default=None,
        type=float,
        help="weight decay for embedding parameters for vision transformer models (default: None, same value as --wd)",
    )
    parser.add_argument(
        "--label-smoothing", default=0.0, type=float, help="label smoothing (default: 0.0)", dest="label_smoothing"
    )
    parser.add_argument("--mixup-alpha", default=0.0, type=float, help="mixup alpha (default: 0.0)")
    parser.add_argument("--cutmix-alpha", default=0.0, type=float, help="cutmix alpha (default: 0.0)")
    parser.add_argument("--lr-scheduler", default="steplr", type=str, help="the lr scheduler (default: steplr)")
    parser.add_argument("--lr-warmup-epochs", default=0, type=int, help="the number of epochs to warmup (default: 0)")
    parser.add_argument(
        "--lr-warmup-method", default="constant", type=str, help="the warmup method (default: constant)"
    )
    parser.add_argument("--lr-warmup-decay", default=0.01, type=float, help="the decay for lr")
    parser.add_argument("--lr-step-size", default=30, type=int, help="decrease lr every step-size epochs")
    parser.add_argument("--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma")
    parser.add_argument("--lr-min", default=0.0, type=float, help="minimum lr of lr schedule (default: 0.0)")
    parser.add_argument("--print-freq", default=10, type=int, help="print frequency")
    parser.add_argument("--output-dir", default=".", type=str, help="path to save outputs")
    parser.add_argument("--resume", default="", type=str, help="path of checkpoint")
    parser.add_argument("--start-epoch", default=0, type=int, metavar="N", help="start epoch")
    parser.add_argument(
        "--sync-bn",
        dest="sync_bn",
        help="Use sync batch norm",
        action="store_true",
    )
    parser.add_argument(
        "--test-only",
        dest="test_only",
        help="Only test the model",
        action="store_true",
    )
    parser.add_argument("--auto-augment", default=None, type=str, help="auto augment policy (default: None)")
    parser.add_argument("--random-erase", default=0.0, type=float, help="random erasing probability (default: 0.0)")

    # Mixed precision training parameters
    parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training")

    # distributed training parameters
    parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
    parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")
    parser.add_argument(
        "--model-ema", action="store_true", help="enable tracking Exponential Moving Average of model parameters"
    )
    parser.add_argument(
        "--model-ema-steps",
        type=int,
        default=32,
        help="the number of iterations that controls how often to update the EMA model (default: 32)",
    )
    parser.add_argument(
        "--model-ema-decay",
        type=float,
        default=0.99998,
        help="decay factor for Exponential Moving Average of model parameters (default: 0.99998)",
    )
    parser.add_argument(
        "--use-deterministic-algorithms", action="store_true", help="Forces the use of deterministic algorithms only."
    )
    parser.add_argument(
        "--interpolation", default="bilinear", type=str, help="the interpolation method (default: bilinear)"
    )
    parser.add_argument(
        "--val-resize-size", default=256, type=int, help="the resize size used for validation (default: 256)"
    )
    parser.add_argument(
        "--val-crop-size", default=224, type=int, help="the central crop size used for validation (default: 224)"
    )
    parser.add_argument(
        "--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)"
    )
    parser.add_argument("--clip-grad-norm", default=None, type=float, help="the maximum gradient norm (default None)")
    parser.add_argument("--ra-sampler", action="store_true", help="whether to use Repeated Augmentation in training")
    parser.add_argument(
        "--ra-reps", default=3, type=int, help="number of repetitions for Repeated Augmentation (default: 3)"
    )
    parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
    parser.add_argument("--pretrained", default=None, type=str, help="the weights of self-sl to load")

    return parser


if __name__ == "__main__":
    # Script entry point: parse the command line, then run main() with the
    # parsed namespace. `args` stays bound at module level.
    cli_parser = get_args_parser()
    args = cli_parser.parse_args()
    main(args)
