import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import copy
import types


"""
Magnitude Pruning
"""


def get_saliency_weight_magnitude(args, origin_model, batch, device, num_alive, alive_idx,
                                  alive_mask, data_shape, num_classes):
    """Score every parameter of ``origin_model`` by its absolute magnitude.

    Only ``origin_model`` is read. The remaining arguments (``args``, ``batch``,
    ``device``, ``num_alive``, ``alive_idx``, ``alive_mask``, ``data_shape``,
    ``num_classes``) are accepted but unused. Presumably they exist so this
    function shares a call signature with other saliency functions; confirm
    against the callers.

    Args:
        origin_model: module whose ``named_parameters()`` are scored. No
            filtering is applied, so biases and every other registered
            parameter are included.

    Returns:
        dict mapping each parameter name (in ``named_parameters()`` order) to
        a NumPy array holding the element-wise absolute value of that
        parameter. The array is copied to CPU and has the parameter's shape.
        ``np.abs`` allocates a fresh array, so mutating the result does not
        touch the model.
    """
    # ``.cpu()`` lets CUDA-resident parameters be converted to NumPy.
    return {
        param_name: np.abs(param.data.cpu().numpy())
        for param_name, param in origin_model.named_parameters()
    }