import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import copy
import types


"""
Template for a custom saliency criterion.

To test your own criterion, return a dictionary that maps each parameter's
name to an array of scores with the same shape as that parameter.

Example:
    saliency[<param's name>] = np.random.random(<param's shape>)
"""


def get_saliency_your_criterion(args, origin_model, batch, device, num_alive, alive_idx,
                                alive_mask, data_shape, num_classes):
    """Return a random saliency score for every parameter of ``origin_model``.

    This is a template / random baseline for plugging in a custom criterion:
    each parameter name maps to an array of scores with the same shape as the
    parameter. Scores are drawn uniformly from [0, 1) using NumPy's global
    RNG, so seed ``np.random`` beforehand for reproducible results.

    Args:
        args, batch, device, num_alive, alive_idx, alive_mask, data_shape,
        num_classes: Accepted so this function keeps the signature its caller
            expects (presumably shared with other saliency criteria — confirm
            against the caller); none of them is used by this random baseline.
        origin_model: Model whose ``named_parameters()`` are scored.

    Returns:
        dict mapping each parameter name to a float64 NumPy array of scores
        with that parameter's shape.
    """
    saliency = {}

    for name, p in origin_model.named_parameters():
        # Only the shape is needed. Reading it directly avoids copying the
        # tensor to host memory and converting it to NumPy, which is wasted
        # work and also fails for dtypes NumPy does not support (e.g. bfloat16).
        # np.random.random already returns non-negative values, so no abs().
        saliency[name] = np.random.random(tuple(p.shape))

    return saliency