import torch
import torch.nn as nn
import numpy as np
import copy


"""
Random Pruning (Baseline)
"""


def get_saliency_random(args, origin_model, batch, device, num_alive, alive_idx,
                        alive_mask, data_shape, num_classes):
    """Return a random saliency score for every parameter entry (baseline).

    Each named parameter of ``origin_model`` gets a NumPy array of the same
    shape filled with samples drawn uniformly from ``[0, 1)`` using NumPy's
    global random state (seed it with ``np.random.seed`` for reproducibility).

    Args:
        origin_model: Object exposing ``named_parameters()`` that yields
            ``(name, tensor)`` pairs; only each tensor's ``shape`` is read.
        args, batch, device, num_alive, alive_idx, alive_mask, data_shape,
        num_classes: Not used by this function. NOTE(review): presumably
            kept so the signature matches other saliency functions -- confirm
            against the caller.

    Returns:
        dict mapping parameter name to a float64 ``numpy.ndarray`` of random
        scores with the parameter's shape.
    """
    saliency = {}

    for name, param in origin_model.named_parameters():
        # Only the shape is needed, so read it from the tensor directly
        # instead of copying the weights to host memory; this also works for
        # dtypes that cannot be converted to NumPy (e.g. bfloat16).
        # Samples already lie in [0, 1), so no absolute value is required.
        saliency[name] = np.random.random(tuple(param.shape))

    return saliency