import copy
import sys
import datetime
import errno
import hashlib
import os
import time
from collections import defaultdict, deque, OrderedDict
from typing import List, Optional, Tuple
import sys

import numpy as np
import torch
from torch import nn
import torch.distributed as dist
from torch.utils.data import Dataset
from PIL import Image
import torch.nn.utils.prune as prune
import torch.nn.functional as F
import logging


class SmoothedValue:
    """Running statistic over a stream of scalars.

    The most recent ``window_size`` values are kept for windowed statistics
    (median, mean, max, latest), while a running ``total``/``count`` pair
    gives the average over the whole series.
    """

    def __init__(self, window_size=20, fmt=None):
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = "{median:.4f} ({global_avg:.4f})" if fmt is None else fmt

    def update(self, value, n=1):
        """Record ``value``; it is weighted by ``n`` in the global average."""
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """Sum ``count`` and ``total`` over all processes.

        Warning: does not synchronize the deque! Windowed statistics stay local.
        """
        count, total = reduce_across_processes([self.count, self.total]).tolist()
        self.count = int(count)
        self.total = total

    @property
    def median(self):
        """Median of the values currently in the window."""
        window = torch.tensor(list(self.deque))
        return window.median().item()

    @property
    def avg(self):
        """Mean (computed in float32) of the values currently in the window."""
        window = torch.tensor(list(self.deque), dtype=torch.float32)
        return window.mean().item()

    @property
    def global_avg(self):
        """Weighted mean of every value ever recorded."""
        return self.total / self.count

    @property
    def max(self):
        """Largest value currently in the window."""
        return max(self.deque)

    @property
    def value(self):
        """Most recently recorded value."""
        return self.deque[-1]

    def __str__(self):
        stats = {
            "median": self.median,
            "avg": self.avg,
            "global_avg": self.global_avg,
            "max": self.max,
            "value": self.value,
        }
        return self.fmt.format(**stats)


class MetricLogger:
    """Collection of named :class:`SmoothedValue` meters with periodic logging.

    Meters are created on first use by :meth:`update` (``meters`` is a
    ``defaultdict``) and can be read back as attributes, e.g. ``logger.loss``.
    """

    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        """Record one value per keyword argument in the meter of that name.

        Tensors are converted with ``.item()``; every value must then be a
        Python ``float`` or ``int``.
        """
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        # Only reached when normal attribute lookup fails. ``meters`` is read
        # from the instance dict directly: on an instance whose __init__ has
        # not run (e.g. during copy/pickle reconstruction), ``self.meters``
        # would re-enter __getattr__ and recurse until RecursionError instead
        # of raising the AttributeError that hasattr()/getattr() expect.
        meters = self.__dict__.get("meters")
        if meters is not None and attr in meters:
            return meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{attr}'")

    def __str__(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append(f"{name}: {str(meter)}")
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        """Synchronize every meter's count/total across processes."""
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        """Register ``meter`` under ``name``, replacing any existing one."""
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        """Yield the items of ``iterable`` and log progress every ``print_freq`` items.

        Each progress line (emitted with ``logging.info``) shows the index,
        ETA, all meters, the average iteration time and the average time spent
        waiting for data, plus peak CUDA memory in MB when CUDA is available.
        A total-time line is logged once the iterable is exhausted.
        ``iterable`` must support ``len()``.
        """
        i = 0
        if not header:
            header = ""
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt="{avg:.4f}")
        data_time = SmoothedValue(fmt="{avg:.4f}")
        # Pad the index to the width of len(iterable) so columns line up.
        space_fmt = ":" + str(len(str(len(iterable)))) + "d"
        if torch.cuda.is_available():
            log_msg = self.delimiter.join(
                [
                    header,
                    "[{0" + space_fmt + "}/{1}]",
                    "eta: {eta}",
                    "{meters}",
                    "time: {time}",
                    "data: {data}",
                    "max mem: {memory:.0f}",
                ]
            )
        else:
            log_msg = self.delimiter.join(
                [header, "[{0" + space_fmt + "}/{1}]", "eta: {eta}", "{meters}", "time: {time}", "data: {data}"]
            )
        MB = 1024.0 * 1024.0
        for obj in iterable:
            # Time between the end of the previous step and the next item.
            data_time.update(time.time() - end)
            yield obj
            # Full step time, including the caller's work on ``obj``.
            iter_time.update(time.time() - end)
            if i % print_freq == 0:
                eta_seconds = iter_time.global_avg * (len(iterable) - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                if torch.cuda.is_available():
                    logging.info(
                        log_msg.format(
                            i,
                            len(iterable),
                            eta=eta_string,
                            meters=str(self),
                            time=str(iter_time),
                            data=str(data_time),
                            memory=torch.cuda.max_memory_allocated() / MB,
                        )
                    )
                else:
                    logging.info(
                        log_msg.format(
                            i, len(iterable), eta=eta_string, meters=str(self), time=str(iter_time), data=str(data_time)
                        )
                    )
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        logging.info(f"{header} Total time: {total_time_str}")


class ExponentialMovingAverage(torch.optim.swa_utils.AveragedModel):
    """Exponential moving average (EMA) of a model's parameters and buffers.

    Each ``update_parameters`` call blends the tracked copy toward the live
    model as ``ema = decay * ema + (1 - decay) * current``. The bookkeeping is
    delegated to `torch.optim.swa_utils.AveragedModel
    <https://pytorch.org/docs/stable/optim.html#custom-averaging-strategies>`_,
    with ``use_buffers=True`` so buffers are averaged as well.
    """

    def __init__(self, model, decay, device="cpu"):
        def _blend(averaged, current, _num_averaged):
            # The step count is ignored: a fixed decay gives an exponential average.
            return decay * averaged + (1 - decay) * current

        super().__init__(model, device=device, avg_fn=_blend, use_buffers=True)


def accuracy(output, target, topk=(1,)):
    """Return top-k accuracy (in percent) of ``output`` for each ``k`` in ``topk``.

    ``output`` holds per-class scores along dim 1. ``target`` is either a
    vector of class indices or a 2-D tensor, in which case the argmax over
    dim 1 is used as the label. The result is a list of float32 scalar
    tensors, one per entry of ``topk``, in the same order.
    """
    with torch.inference_mode():
        if target.ndim == 2:
            target = target.max(dim=1)[1]
        batch_size = target.size(0)

        # Row j holds every sample's j-th best class guess.
        ranked = output.topk(max(topk), dim=1, largest=True, sorted=True).indices.t()
        hits = ranked.eq(target[None])

        scale = 100.0 / batch_size
        return [hits[:k].flatten().sum(dtype=torch.float32) * scale for k in topk]


def mkdir(path):
    """Create directory ``path`` and any missing parents.

    Does nothing if ``path`` already is a directory.

    Raises:
        FileExistsError: if ``path`` exists but is not a directory, so callers
            do not go on to write into a "directory" that is really a file.
        OSError: for any other failure (e.g. missing permissions).
    """
    os.makedirs(path, exist_ok=True)


def setup_for_distributed(is_master):
    """
    This function disables printing when not in master process
    """
    import builtins as __builtin__

    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        force = kwargs.pop("force", False)
        if is_master or force:
            builtin_print(*args, **kwargs)

    __builtin__.print = print


def is_dist_avail_and_initialized():
    """True only if torch.distributed is built in and a process group is initialized."""
    return dist.is_available() and dist.is_initialized()


def get_world_size():
    """Number of processes in the default group, or 1 when not running distributed."""
    return dist.get_world_size() if is_dist_avail_and_initialized() else 1


def get_rank():
    """Rank of this process in the default group, or 0 when not running distributed."""
    return dist.get_rank() if is_dist_avail_and_initialized() else 0


def is_main_process():
    """True on rank 0, which is also every non-distributed run."""
    rank = get_rank()
    return rank == 0


def save_on_master(*args, **kwargs):
    """Forward all arguments to :func:`torch.save`, but only on the main process."""
    if not is_main_process():
        return
    torch.save(*args, **kwargs)


def init_distributed_mode(args):
    """Configure ``args`` for (and start) NCCL distributed training if applicable.

    Rank information is read from, in order of precedence:

    1. ``RANK`` + ``WORLD_SIZE`` (+ ``LOCAL_RANK`` for the GPU) env vars;
    2. ``SLURM_PROCID`` (GPU = rank modulo the local CUDA device count);
    3. an existing ``args.rank`` attribute.

    If none is present, ``args.distributed`` is set to False and nothing else
    happens. Otherwise ``args.distributed``/``args.dist_backend`` are set, the
    CUDA device is selected, the process group is initialized from
    ``args.dist_url``/``args.world_size``/``args.rank``, all ranks meet at a
    barrier, and ``print`` is silenced on non-zero ranks.
    """
    env = os.environ
    if "RANK" in env and "WORLD_SIZE" in env:
        args.rank = int(env["RANK"])
        args.world_size = int(env["WORLD_SIZE"])
        args.gpu = int(env["LOCAL_RANK"])
    elif "SLURM_PROCID" in env:
        args.rank = int(env["SLURM_PROCID"])
        args.gpu = args.rank % torch.cuda.device_count()
    elif not hasattr(args, "rank"):
        print("Not using distributed mode")
        args.distributed = False
        return

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = "nccl"
    print(f"| distributed init (rank {args.rank}): {args.dist_url}", flush=True)
    torch.distributed.init_process_group(
        backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank
    )
    torch.distributed.barrier()
    setup_for_distributed(args.rank == 0)


def average_checkpoints(inputs):
    """Load checkpoints from ``inputs`` and return one with averaged model weights.

    Original implementation taken from:
    https://github.com/pytorch/fairseq/blob/a48f235636557b8d3bc4922a6fa90f3a0fa57955/scripts/average_checkpoints.py#L16

    Every entry of ``state["model"]`` is summed across checkpoints and then
    divided by the number of checkpoints. Floating tensors use true division
    and other tensors use floor division. CPU fp16 tensors are promoted to
    fp32 before summing. All other keys of the returned dict are taken from
    the first checkpoint.

    Args:
      inputs (List[str]): paths of the checkpoints to load (must support ``len``).

    Returns:
      The first checkpoint's dict, with its ``"model"`` entry replaced by an
      OrderedDict mapping parameter names to averaged tensors.

    Raises:
      ValueError: if ``inputs`` is empty.
      KeyError: if a checkpoint's ``"model"`` keys differ (in content or
        order) from the first checkpoint's.
    """
    num_models = len(inputs)
    if num_models == 0:
        raise ValueError("average_checkpoints() needs at least one checkpoint path")

    params_dict = OrderedDict()
    params_keys = None
    new_state = None
    for fpath in inputs:
        with open(fpath, "rb") as f:
            state = torch.load(
                f,
                map_location=(lambda s, _: torch.serialization.default_restore_location(s, "cpu")),
            )
        # Copies over the settings from the first checkpoint
        if new_state is None:
            new_state = state
        model_params = state["model"]
        model_params_keys = list(model_params.keys())
        if params_keys is None:
            params_keys = model_params_keys
        elif params_keys != model_params_keys:
            # Report the path; the file object ``f`` is already closed here.
            raise KeyError(
                f"For checkpoint {fpath}, expected list of params: {params_keys}, but found: {model_params_keys}"
            )
        for k in params_keys:
            p = model_params[k]
            if isinstance(p, torch.HalfTensor):
                p = p.float()
            if k not in params_dict:
                # clone() is needed in case p is a shared parameter: the
                # in-place accumulation below must not write into it.
                params_dict[k] = p.clone()
            else:
                params_dict[k] += p

    averaged_params = OrderedDict()
    for k, v in params_dict.items():
        if v.is_floating_point():
            v.div_(num_models)
        else:
            v //= num_models
        averaged_params[k] = v
    new_state["model"] = averaged_params
    return new_state


def store_model_weights(model, checkpoint_path, checkpoint_key="model", strict=True):
    """
    Prepare a release weights file from a training checkpoint.

    The weights stored under ``checkpoint_key`` are loaded into a deep copy
    of ``model``. This validates them and drops anything the model does not
    own, such as auxiliary heads. The copy's ``state_dict`` is then saved next
    to the checkpoint as ``weights-<first 8 hex digits of its sha256>.pth``.

    Examples:
        from torchvision import models as M

        # Classification
        model = M.mobilenet_v3_large(weights=None)
        print(store_model_weights(model, './class.pth'))

        # Quantized Classification
        model = M.quantization.mobilenet_v3_large(weights=None, quantize=False)
        model.fuse_model(is_qat=True)
        model.qconfig = torch.ao.quantization.get_default_qat_qconfig('qnnpack')
        _ = torch.ao.quantization.prepare_qat(model, inplace=True)
        print(store_model_weights(model, './qat.pth'))

        # Object Detection
        model = M.detection.fasterrcnn_mobilenet_v3_large_fpn(weights=None, weights_backbone=None)
        print(store_model_weights(model, './obj.pth'))

        # Segmentation
        model = M.segmentation.deeplabv3_mobilenet_v3_large(weights=None, weights_backbone=None, aux_loss=True)
        print(store_model_weights(model, './segm.pth', strict=False))

    Args:
        model (pytorch.nn.Module): architecture the weights are loaded into for validation.
        checkpoint_path (str): path of the training checkpoint.
        checkpoint_key (str, optional): checkpoint entry holding the model weights.
            Default: "model". For "model_ema", the "n_averaged" counter and a
            leading "module." prefix are removed first.
        strict (bool): whether to strictly enforce that the keys
            in :attr:`state_dict` match the keys returned by this module's
            :meth:`~torch.nn.Module.state_dict` function. Default: ``True``

    Returns:
        output_path (str): The location where the weights are saved.
    """
    checkpoint_path = os.path.abspath(checkpoint_path)
    output_dir = os.path.dirname(checkpoint_path)

    # Work on a copy so the caller's model is left untouched.
    model = copy.deepcopy(model)
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    state_dict = checkpoint[checkpoint_key]

    if checkpoint_key == "model_ema":
        del state_dict["n_averaged"]
        torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(state_dict, "module.")
    model.load_state_dict(state_dict, strict=strict)

    # Save under a temporary name first; the final name depends on the file hash.
    tmp_path = os.path.join(output_dir, str(model.__hash__()))
    torch.save(model.state_dict(), tmp_path)

    digest = hashlib.sha256()
    with open(tmp_path, "rb") as fh:
        while True:
            chunk = fh.read(4096)
            if not chunk:
                break
            digest.update(chunk)
    short_hash = digest.hexdigest()[:8]

    output_path = os.path.join(output_dir, f"weights-{short_hash}.pth")
    os.replace(tmp_path, output_path)
    return output_path


def reduce_across_processes(val):
    """Return ``val`` as a tensor, summed over all processes when running distributed.

    With an initialized process group, the tensor is created on CUDA, all
    ranks meet at a barrier, and ``dist.all_reduce`` sums it in place (default
    op). Otherwise it is simply converted with ``torch.tensor``.
    """
    if is_dist_avail_and_initialized():
        reduced = torch.tensor(val, device="cuda")
        dist.barrier()
        dist.all_reduce(reduced)
        return reduced
    # Nothing to sync, but still return a tensor so both paths look the same to callers.
    return torch.tensor(val)


def set_weight_decay(
    model: torch.nn.Module,
    weight_decay: float,
    norm_weight_decay: Optional[float] = None,
    norm_classes: Optional[List[type]] = None,
    custom_keys_weight_decay: Optional[List[Tuple[str, float]]] = None,
):
    """Split the trainable parameters of ``model`` into optimizer param groups.

    Each trainable parameter is assigned to exactly one group:

    * the first entry of ``custom_keys_weight_decay`` whose key matches it.
      A key containing "." is compared with the dotted name ``prefix.param``,
      except for parameters owned by the root module; any other key is
      compared with the bare parameter name.
    * ``"norm"``, if ``norm_weight_decay`` is given and the owning module is an
      instance of ``norm_classes`` (default: the common normalization layers);
    * ``"other"`` otherwise.

    Returns a list of ``{"params": [...], "weight_decay": value}`` dicts in the
    order other, norm, then custom keys. Empty groups are omitted.
    """
    if norm_classes:
        norm_types = tuple(norm_classes)
    else:
        norm_types = (
            torch.nn.modules.batchnorm._BatchNorm,
            torch.nn.LayerNorm,
            torch.nn.GroupNorm,
            torch.nn.modules.instancenorm._InstanceNorm,
            torch.nn.LocalResponseNorm,
        )

    # Group name -> parameters; dict insertion order fixes the output order.
    groups = {"other": [], "norm": []}
    decay_of = {"other": weight_decay, "norm": norm_weight_decay}
    custom_keys = []
    if custom_keys_weight_decay is not None:
        for key, key_decay in custom_keys_weight_decay:
            groups[key] = []
            decay_of[key] = key_decay
            custom_keys.append(key)

    def _group_of(module, prefix, name):
        for key in custom_keys:
            candidate = f"{prefix}.{name}" if prefix != "" and "." in key else name
            if candidate == key:
                return key
        if norm_weight_decay is not None and isinstance(module, norm_types):
            return "norm"
        return "other"

    def _collect(module, prefix=""):
        for name, param in module.named_parameters(recurse=False):
            if param.requires_grad:
                groups[_group_of(module, prefix, name)].append(param)
        for child_name, child in module.named_children():
            _collect(child, prefix=f"{prefix}.{child_name}" if prefix != "" else child_name)

    _collect(model)

    return [
        {"params": members, "weight_decay": decay_of[group]}
        for group, members in groups.items()
        if len(members) > 0
    ]

def _snip(args, model, params, loader, device):
    criterion = torch.nn.CrossEntropyLoss()
    gw = [0. for i in range(len(params))]
    for i, (data, target) in enumerate(loader):
        data, target = data.to(device), target.to(device)
        output = model(data)
        loss = criterion(output, target)
        gw_i = torch.autograd.grad(loss, params)
        gw = list((gw_i[i].detach() / float(args.num_prune_batch) + accu for i, accu in enumerate(gw)))
        if (i+1) == args.num_prune_batch:
            break
    scores = {}
    k = 0
    ''' computing scores '''
    with torch.no_grad():
        for name, module in model.named_modules():
                if isinstance(module, torch.nn.Conv2d):
                    scores[(module, 'weight')] = torch.abs(module.weight * gw[k])
                    k += 1
    return scores

def _get_batch_cos_dist(f1, f2):
    f1_normalize = F.normalize(f1.view(f1.size(0), -1), dim=1)
    f2_normalize = F.normalize(f2.view(f1.size(0), -1), dim=1)
    return 1.0 - (f1_normalize * f2_normalize).sum(dim=1).mean()


def _grasp(args, model, params, loader, device):
    """GraSP-style importance scores for every Conv2d weight of ``model``.

    Pass 1 runs over the first ``args.num_prune_batch`` batches of ``loader``
    in chunks of 20 samples. For each chunk it computes the gradient of the
    temperature-scaled, per-chunk mean cross-entropy w.r.t. ``params``
    (without a graph) and sums these into ``grad_w``.

    Pass 2 re-runs the same chunks, recomputes the gradient ``grad_f`` with
    ``create_graph=True`` and back-propagates ``z = sum_k <grad_w[k], grad_f[k]>``.
    Since ``grad_w`` is constant, ``params[k].grad`` accumulates a
    Hessian-vector product ``H @ grad_w``. Each score is
    ``module.weight * params[k].grad`` (signed, not abs).

    Args:
        args: namespace; ``num_prune_batch`` is read.
        model: network to score. The ``.grad`` of its parameters is written
            by ``z.backward()`` and cleared with ``model.zero_grad()`` at the end.
        params: tensors to differentiate. ``params[k]`` is paired with the
            k-th Conv2d module in ``model.named_modules()`` order.
        loader: iterable of ``(data, target)`` batches.
        device: device the chunks are moved to for the forward passes.

    Returns:
        dict mapping ``(conv_module, 'weight')`` to a detached score tensor.
    """
    criterion = torch.nn.CrossEntropyLoss(reduction='sum')
    # Logits are divided by this temperature before the loss.
    temp = 200.
    inputs_one = []
    targets_one = []
    grad_w = None
    grad_f = None
    for i, (data, target) in enumerate(loader):
        N = data.shape[0]
        # Copies of the batch, sliced into chunks that pass 2 replays.
        din = copy.deepcopy(data)
        dtarget = copy.deepcopy(target)
        start = 0
        # Chunk size (samples per forward pass).
        intv = 20
        while start < N:
            end = min(start+intv, N)
            logging.info('(1):  %d -> %d.' % (start, end))
            inputs_one.append(din[start:end])
            targets_one.append(dtarget[start:end])
            outputs = model(data[start:end].to(device)) / temp
            # reduction='sum' divided by the chunk size = per-chunk mean loss.
            loss = criterion(outputs, target[start:end].to(device)) / (end - start)
            grad_w_p = torch.autograd.grad(loss, params, create_graph=False)
            # grad_w is the sum (not the mean) of the per-chunk gradients.
            if grad_w is None:
                grad_w = list(grad_w_p)
            else:
                for idx in range(len(grad_w)):
                    grad_w[idx] += grad_w_p[idx]
            start = end
        if (i+1) == args.num_prune_batch:
            break
    # Pass 2: pop(0) drains the chunk lists in the order they were recorded.
    for it in range(len(inputs_one)):
        logging.info("(2): Iterations %d/%d." % (it, len(inputs_one)))
        inputs = inputs_one.pop(0).to(device)
        targets = targets_one.pop(0).to(device)
        outputs = model(inputs) / temp
        loss = criterion(outputs, targets) / len(targets)
        # create_graph=True so grad_f can itself be differentiated below.
        grad_f = torch.autograd.grad(loss, params, create_graph=True)
        z = 0
        for k in range(len(params)):
            z += (grad_w[k] * grad_f[k]).sum()
        # Accumulates H @ grad_w into each params[k].grad across chunks.
        z.backward()
    scores = {}
    k = 0
    for name, module in model.named_modules():
            if isinstance(module, torch.nn.Conv2d):
                scores[(module, 'weight')] = torch.clone(module.weight.data * params[k].grad).detach()
                k += 1
    model.zero_grad()
    return scores

def _synflow(args, model, params, loader, device):
    """SynFlow importance scores ``|grad * weight|`` for every Conv2d weight.

    The model is temporarily "linearized". It is cast to float64 and every
    state-dict entry (parameters and buffers) is replaced in place by its
    absolute value. A single all-ones input, shaped like one sample of the
    first ``loader`` batch, is pushed through, and the sum of the output is
    back-propagated. The model is then cast back with ``model.float()`` and
    the saved signs are multiplied back in.

    ``args`` is not read. Only the first batch of ``loader`` is used, and only
    for its sample shape. ``params[k]`` is paired with the k-th Conv2d module
    in ``model.named_modules()`` order.

    NOTE(review): ``model.float()`` assumes the model was float32 before the
    call — confirm for mixed/half-precision models.

    Returns:
        dict mapping ``(conv_module, 'weight')`` to a detached, non-negative
        score tensor (float64, since it is computed while the model is doubled).
    """
    
    @torch.no_grad()
    def linearize(model):
        # Cast to float64, then take |x| in place for every state-dict entry;
        # return the signs so nonlinearize() can undo it.
        model.double()
        signs = {}
        for name, param in model.state_dict().items():
            #print(name)
            signs[name] = torch.sign(param)
            #print(signs[name].dtype)
            param.abs_()
        return signs

    @torch.no_grad()
    def nonlinearize(model, signs):
        # Cast back to float32 and restore the original signs (|x| * sign(x) == x).
        model.float()
        for name, param in model.state_dict().items():
            param.mul_(signs[name])
    
    signs = linearize(model)

    (data, _) = next(iter(loader))
    # Shape of a single sample, without the batch dimension.
    input_dim = list(data[0,:].shape)
    input = torch.ones([1] + input_dim).to(torch.double).to(device)#, dtype=torch.float64).to(device)
    output = model(input)
    torch.sum(output).backward()
    
    scores = {}
    k = 0 
    for name, module in model.named_modules():
            if isinstance(module, torch.nn.Conv2d):
                scores[(module, 'weight')] = torch.clone(params[k].grad * module.weight).detach().abs_()
                k += 1
    model.zero_grad()
    nonlinearize(model, signs)
    return scores

def _prospr(args, model, params, loader, device, steps):
    """ProsPr-style importance scores ``|grad * w_0|`` for every Conv2d weight.

    1. Snapshot every Conv2d weight of ``model`` (``w_0``, stored on CPU).
    2. Train ``model`` in place for ``steps`` SGD iterations over ``loader``,
       cycling through it as often as needed. Hyper-parameters come from
       ``args.lr``, ``args.momentum``, ``args.weight_decay`` and
       ``args.norm_weight_decay``.
    3. Accumulate the cross-entropy gradient at the trained weights over one
       full pass of ``loader``, and score each weight as ``|grad * w_0|``.

    Args:
        args: namespace with the optimizer hyper-parameters listed above.
        model: network; trained in place. Its ``.grad`` fields are cleared
            with ``model.zero_grad()`` before returning.
        params: ``params[k]`` is paired with the k-th Conv2d module in
            ``model.named_modules()`` order.
        loader: iterable of ``(image, target)`` batches; iterated repeatedly.
        device: device batches are moved to.
        steps: number of SGD steps to take before scoring.

    Returns:
        dict mapping ``(conv_module, 'weight')`` to a detached CPU score tensor.

    Raises:
        ValueError: if ``steps > 0`` but ``loader`` yields no batches
            (otherwise the step loop could never finish).
    """
    # Snapshot the initial Conv2d weights.
    w0 = []
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Conv2d):
            w0.append(torch.clone(module.weight).detach().cpu())

    parameters = set_weight_decay(model, args.weight_decay, norm_weight_decay=args.norm_weight_decay, custom_keys_weight_decay=None)
    optimizer = torch.optim.SGD(parameters, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=False)
    criterion = torch.nn.CrossEntropyLoss()

    # Train for exactly ``steps`` iterations, restarting the loader as needed.
    model.train()
    step = 0
    while step < steps:
        step_at_pass_start = step
        for image, target in loader:
            image, target = image.to(device), target.to(device)
            output = model(image)
            loss = criterion(output, target)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            step += 1
            if step == steps:
                break
        if step == step_at_pass_start:
            raise ValueError("_prospr: loader yielded no batches; cannot run the requested training steps")

    # Score: clear gradients once, then let backward() accumulate them over
    # every batch, so the whole pass contributes (not just the last batch).
    optimizer.zero_grad()
    for image, target in loader:
        image, target = image.to(device), target.to(device)
        output = model(image)
        loss = criterion(output, target)
        loss.backward()

    scores = {}
    k = 0
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Conv2d):
            scores[(module, 'weight')] = torch.clone(params[k].grad.cpu() * w0[k]).detach().abs_()
            k += 1
    model.zero_grad()
    return scores

def get_importance_scores(args, model_tmp, tmp_parameters, loader, device, model_comp):
    """Compute pruning importance scores with the criterion named by ``args.prune``.

    Lower-scoring parameters will be pruned. Supported criteria:

    * ``'snip'``    -> ``_snip``
    * ``'grasp'``   -> ``_grasp``
    * ``'synflow'`` -> ``_synflow`` (``model_tmp`` is switched to eval mode first)
    * ``'prospr'``  -> ``_prospr`` with ``args.prospr_steps`` training steps

    Args:
        args: namespace; ``prune`` selects the criterion.
        model_tmp: model to score; the criteria may modify it.
        tmp_parameters: Conv2d weights of ``model_tmp``, in ``named_modules()`` order.
        loader: data loader used by the criterion.
        device: device for forward passes.
        model_comp: unused; kept so existing call sites keep working.

    Returns:
        dict mapping ``(module, 'weight')`` to a score tensor.

    Raises:
        ValueError: if ``args.prune`` is not a supported criterion. Returning
            None instead would make ``prune.global_unstructured`` silently
            score by the raw (signed) parameter values.
    """
    method = args.prune
    if method == 'snip':
        return _snip(args, model_tmp, tmp_parameters, loader, device)
    if method == 'grasp':
        return _grasp(args, model_tmp, tmp_parameters, loader, device)
    if method == 'synflow':
        model_tmp.eval()
        return _synflow(args, model_tmp, tmp_parameters, loader, device)
    if method == 'prospr':
        return _prospr(args, model_tmp, tmp_parameters, loader, device, args.prospr_steps)
    raise ValueError(f"Unsupported pruning criterion for importance scores: {method!r}")

def Prune(args, params_to_prune, num_total, model, device, prune_loader=None, imagenet_eval=False):
    """Apply global unstructured pruning to ``params_to_prune`` (in place on ``model``).

    The criterion is chosen by ``args.prune``:

    * ``'mag'``: one-shot global L1-magnitude pruning (``prune.L1Unstructured``)
      of ``int(num_total * args.prune_ratio)`` entries.
    * ``'random'``: one-shot global random pruning (``prune.RandomUnstructured``)
      of the same number of entries.
    * anything else: a score-based criterion (``get_importance_scores``),
      evaluated on a deep copy ``model_tmp`` over ``args.prune_iters`` rounds.
      The cumulative pruned count follows a cubic schedule that reaches
      ``int(num_total * args.prune_ratio)`` in the last round. The
      ``*_mask`` entries of the copy's state dict are then loaded into ``model``.

    Args:
        args: namespace; reads ``prune`` and ``prune_ratio``. On the
            score-based path it also reads ``prune_iters``, ``model`` (with
            ``imagenet_eval``) and whatever the chosen criterion reads.
        params_to_prune: list of ``(module, name)`` pairs of ``model`` to prune.
        num_total: number of prunable entries that ``args.prune_ratio`` refers to.
        model: network to prune; modified in place. Nothing is returned.
        device: device for the replacement classifier head and for scoring.
        prune_loader: data loader used by the score-based criteria.
        imagenet_eval: if true, the copy's classifier head is replaced by a fresh
            ``nn.Linear`` with ``len(prune_loader.dataset.classes)`` outputs
            before scoring. Only ``model_tmp`` is changed.

    Raises:
        NotImplementedError: ``imagenet_eval`` with an ``args.model`` whose head
            replacement is not implemented here.

    NOTE(review): on the score-based path masks are only computed for the
    Conv2d weights of ``model_tmp``. Any other entry of ``params_to_prune``
    presumably keeps an all-ones mask — confirm that callers only pass Conv2d
    weights.
    """
    if args.prune == 'mag':
        prune.global_unstructured(
            params_to_prune,
            pruning_method=prune.L1Unstructured,
            amount=int(num_total*args.prune_ratio),)
    elif args.prune == 'random':
        prune.global_unstructured(
            params_to_prune,
            pruning_method=prune.RandomUnstructured,
            amount=int(num_total*args.prune_ratio),)
    else:
        print('Prune tmp model')
        # Scores are computed on a copy so the target model's weights are untouched.
        model_tmp = copy.deepcopy(model)
        if imagenet_eval:
            num_classes = len(prune_loader.dataset.classes)
            ''' FC change '''
            # Resize the copy's classifier to the prune loader's number of classes.
            if args.model in ['mobilenet_v2', 'mnasnet1_0', 'mnasnet0_5']:
                num_ftrs = model_tmp.classifier[-1].in_features
                model_tmp.classifier[-1] = nn.Linear(num_ftrs, num_classes).to(device)
            elif args.model in ['densenet121', 'densenet169', 'densenet201']:
                num_ftrs = model_tmp.classifier.in_features
                model_tmp.classifier = nn.Linear(num_ftrs, num_classes).to(device)
            elif args.model in ['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'googlenet', 'inception_v3']:
                num_ftrs = model_tmp.fc.in_features
                model_tmp.fc = nn.Linear(num_ftrs, num_classes).to(device)
            else:
                # try your customized model
                raise NotImplementedError
        model_tmp.zero_grad()
        model_tmp.train()
        # Prune / score every Conv2d weight of the copy.
        params_to_prune_tmp = []
        tmp_params = []
        model_comp = None
        for name, module in model_tmp.named_modules():
            if isinstance(module, torch.nn.Conv2d):
                params_to_prune_tmp.append((module, 'weight'))
                tmp_params.append(module.weight)
        prune_init = 0.
        pnum_prev = 0
        with torch.autograd.set_detect_anomaly(False):
            for it in range(args.prune_iters):
                # Cumulative prune ratio after this round: cubic ramp from
                # prune_init to args.prune_ratio, reached when it == prune_iters - 1.
                r = args.prune_ratio + (prune_init - args.prune_ratio) * (1 - (it+1) / args.prune_iters)**3
                pnum = int(r * num_total)
                # amount is incremental (pnum - pnum_prev): on already-pruned
                # tensors, torch's pruning container only considers entries that
                # are still unmasked.
                prune.global_unstructured(
                    params_to_prune_tmp,
                    pruning_method=ScoreUnstructured,
                    amount=pnum - pnum_prev,
                    importance_scores=get_importance_scores(args, model_tmp, tmp_params, prune_loader, device, model_comp)
                    )
                pnum_prev = pnum
        ''' copy masks from tmp model to target model '''
        # Guards against a target whose first pruned module already has buffers
        # (e.g. an existing pruning mask).
        assert len(list(params_to_prune[0][0].named_buffers())) == 0
        state_dict_tmp = model_tmp.state_dict()
        state_dict_new = {}
        for k, v in state_dict_tmp.items():
            if 'mask' in k:
                print('load %s'%k)
                state_dict_new[k] = v
        # amount=0 pruning only creates the ``*_orig`` / ``*_mask``
        # reparametrization so the mask keys exist in ``model``.
        prune.global_unstructured(params_to_prune, pruning_method=prune.L1Unstructured, amount=0)
        # strict=False: only the mask buffers are loaded.
        msg = model.load_state_dict(state_dict_new, strict=False)
        # Re-apply with amount=0, presumably to recompute the masked weights from
        # the loaded masks. NOTE(review): this stacks a second pruning step onto
        # each tensor — confirm it is intended.
        prune.global_unstructured(params_to_prune, pruning_method=prune.L1Unstructured, amount=0)
        print("Load mask from tmp model with msg: {}".format(msg))

class ScoreUnstructured(prune.BasePruningMethod):
    r"""Prune (currently unpruned) units in a tensor by zeroing out the ones
    with the lowest score.

    The tensor given to :meth:`compute_mask` is the importance-score tensor, or
    the parameter itself when no scores are supplied. Its ``amount`` smallest
    entries are masked out. Unlike ``prune.L1Unstructured``, no absolute value
    is taken, so negative scores rank below positive ones.

    Args:
        amount (int or float): quantity of parameters to prune.
            If ``float``, should be between 0.0 and 1.0 and represent the
            fraction of parameters to prune. If ``int``, it represents the
            absolute number of parameters to prune.
    """

    PRUNING_TYPE = "unstructured"

    def __init__(self, amount):
        # Check range of validity of pruning amount
        prune._validate_pruning_amount_init(amount)
        self.amount = amount

    def compute_mask(self, t, default_mask):
        """Return a copy of ``default_mask`` with the ``amount`` lowest entries of ``t`` zeroed.

        Args:
            t (torch.Tensor): scores, same shape as ``default_mask``; any
                memory layout is accepted.
            default_mask (torch.Tensor): mask from previous pruning iterations.

        Returns:
            A contiguous mask tensor of the same shape as ``default_mask``.

        Raises:
            ValueError: if ``amount`` exceeds the number of entries of ``t``
                (raised by torch's pruning validation helper).
        """
        tensor_size = t.nelement()
        # Number of units to prune: ``amount`` if int, else amount * tensor_size.
        nparams_toprune = prune._compute_nparams_toprune(self.amount, tensor_size)
        prune._validate_pruning_amount(nparams_toprune, tensor_size)

        mask = default_mask.clone(memory_format=torch.contiguous_format)

        if nparams_toprune != 0:  # nothing to prune for k == 0
            # reshape (not view): scores handed straight to ``apply`` may be
            # non-contiguous (e.g. a transposed tensor), which view() rejects.
            # Both flattenings use logical row-major order, so indices line up
            # with the contiguous ``mask``.
            lowest = torch.topk(t.reshape(-1), k=nparams_toprune, largest=False)
            mask.view(-1)[lowest.indices] = 0

        return mask

    @classmethod
    def apply(cls, module, name, amount, importance_scores=None):
        r"""Adds the forward pre-hook that enables pruning on the fly and
        the reparametrization of a tensor in terms of the original tensor
        and the pruning mask.

        Args:
            module (nn.Module): module containing the tensor to prune
            name (str): parameter name within ``module`` on which pruning
                will act.
            amount (int or float): quantity of parameters to prune.
                If ``float``, should be between 0.0 and 1.0 and represent the
                fraction of parameters to prune. If ``int``, it represents the
                absolute number of parameters to prune.
            importance_scores (torch.Tensor): tensor of importance scores (of same
                shape as module parameter) used to compute mask for pruning.
                The values in this tensor indicate the importance of the corresponding
                elements in the parameter being pruned.
                If unspecified or None, the module parameter will be used in its place.
        """
        return super(ScoreUnstructured, cls).apply(
            module, name, amount=amount, importance_scores=importance_scores
        )

def set_lr_and_wd(args):
    """Override ``args.lr`` and ``args.weight_decay`` with tuned values.

    The values depend on:

    * whether a checkpoint is requested (``args.weights`` or ``args.pretrained``
      is not None -> fine-tuning table, otherwise from-scratch table);
    * the architecture (``'resnet18'``/``'resnet50'`` exactly; for fine-tuning
      also any model whose name contains ``'vit'``);
    * the first dataset keyword found as a substring of ``args.data_path``.

    Combinations without a tuned entry leave ``args`` untouched. Keywords are
    tried in table order, so ``'cifar100'`` is listed before ``'cifar10'``
    (the latter is a substring of the former).
    """
    # (dataset keyword, (lr, weight_decay)), in matching order.
    from_scratch = {
        'resnet18': [
            ('cars', (0.1, 0.000316)),
            ('cifar100', (0.1, 0.001)),
            ('cifar10', (0.1, 0.001)),
            ('caltech', (1.0, 0.001)),
            ('pets', (0.1, 0.00316)),
        ],
        'resnet50': [
            ('cars', (0.0316, 0.00316)),
            ('cifar100', (0.0316, 0.00316)),
            ('cifar10', (0.1, 0.001)),
            ('caltech', (0.0316, 0.01)),
            ('pets', (0.0316, 0.00316)),
        ],
    }
    fine_tune = {
        'resnet18': [
            ('cars', (0.0316, 0.00316)),
            ('cifar100', (0.00316, 0.000316)),
            ('cifar10', (0.01, 0.0001)),
            ('caltech', (0.000316, 0.0000316)),
            ('pets', (0.000316, 0.000316)),
        ],
        'resnet50': [
            ('cars', (0.0316, 0.00316)),
            ('cifar100', (0.00316, 0.00316)),
            ('cifar10', (0.01, 0.0001)),
            ('caltech', (0.000316, 0.)),
            ('pets', (0.000316, 0.0001)),
        ],
        'vit': [
            ('cars', (0.00316, 0.001)),
            ('cifar100', (0.00316, 0.00001)),
            ('cifar10', (0.001, 0.0001)),
            ('caltech', (0.000316, 0.001)),
            ('pets', (0.000316, 0.001)),
        ],
    }

    pretrained = not (args.weights is None and args.pretrained is None)
    tables = fine_tune if pretrained else from_scratch

    if args.model == 'resnet18':
        table = tables['resnet18']
    elif args.model == 'resnet50':
        table = tables['resnet50']
    elif pretrained and 'vit' in args.model:
        table = tables['vit']
    else:
        return

    for keyword, (lr, wd) in table:
        if keyword in args.data_path:
            args.lr, args.weight_decay = lr, wd
            return
def set_pretrained_eb_lr(args):
    """Set ``args.lr_imp`` for fine-tuning a pretrained ResNet-50.

    Only applies when ``args.weights`` or ``args.pretrained`` is not None and
    ``args.model == 'resnet50'``. The first dataset keyword found in
    ``args.data_path`` decides the value. ``'cars'`` matches but deliberately
    leaves ``args.lr_imp`` unchanged.
    """
    if args.weights is None and args.pretrained is None:
        return
    if args.model != 'resnet50':
        return
    # None = keyword recognised, keep the current lr_imp.
    lr_imp_by_dataset = (
        ('cars', None),
        ('cifar100', 0.01),
        ('cifar10', 0.01),
        ('caltech', 0.00316),
        ('pets', 0.01),
    )
    for keyword, lr_imp in lr_imp_by_dataset:
        if keyword in args.data_path:
            if lr_imp is not None:
                args.lr_imp = lr_imp
            return
def set_pretrained_gmp_lr_and_freq(args):
    """Set ``args.lr_mask`` / ``args.prune_freq_it`` for pretrained GMP runs.

    Only applies when ``args.weights`` or ``args.pretrained`` is not None,
    ``args.prune_freq_it <= 100`` and ``args.prune_end_it == 10000``. The
    table is chosen by model (``'resnet50'`` exactly, or any name containing
    ``'vit'``). The first dataset keyword found in ``args.data_path`` supplies
    the overrides; for ResNet-50 on ``'cars'`` only ``prune_freq_it`` is set.
    """
    if args.weights is None and args.pretrained is None:
        return
    if not (args.prune_freq_it <= 100 and args.prune_end_it == 10000):
        return

    # (dataset keyword, attribute overrides applied in dict order)
    resnet50_overrides = (
        ('cars', {'prune_freq_it': 50}),
        ('cifar100', {'lr_mask': 0.01, 'prune_freq_it': 2}),
        ('cifar10', {'lr_mask': 0.01, 'prune_freq_it': 50}),
        ('caltech', {'lr_mask': 0.00316, 'prune_freq_it': 2}),
        ('pets', {'lr_mask': 0.01, 'prune_freq_it': 2}),
    )
    vit_overrides = (
        ('cars', {'lr_mask': 0.00316, 'prune_freq_it': 100}),
        ('cifar100', {'lr_mask': 0.00316, 'prune_freq_it': 50}),
        ('cifar10', {'lr_mask': 0.01, 'prune_freq_it': 50}),
        ('caltech', {'lr_mask': 0.001, 'prune_freq_it': 100}),
        ('pets', {'lr_mask': 0.00316, 'prune_freq_it': 100}),
    )

    if args.model == 'resnet50':
        table = resnet50_overrides
    elif 'vit' in args.model:
        table = vit_overrides
    else:
        return

    for keyword, overrides in table:
        if keyword in args.data_path:
            for attr, value in overrides.items():
                setattr(args, attr, value)
            return
