import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import copy


def prune_saliency(args, saliency, mask):
    """Zero out saliency entries wherever the matching mask entry is 0.

    The ``saliency`` dict is updated in place and is also returned. For every
    name present in both dicts, the saliency array is multiplied by a 0/1
    array that is 0 exactly where ``mask[name] == 0``. Names in ``saliency``
    that do not appear in ``mask`` are left untouched. ``args`` is unused.
    """
    for key in saliency:
        if key not in mask:
            continue
        keep = np.where(mask[key] == 0, 0, 1)
        saliency[key] *= keep
    return saliency


def get_mask_by_saliency(args, origin_saliency, batch, device, num_alive, alive_idx,
                         alive_mask, data_shape, num_classes, masked=None):
    """Build a binary pruning mask by globally thresholding saliency scores.

    A single cut-off is chosen across all layers: the
    ``round(0.01 * args.prune_percent * num_alive)``-th smallest absolute
    saliency among the ranked entries. Every entry whose absolute saliency is
    ``<=`` that cut-off is zeroed in the returned mask. Ties at the cut-off
    are all pruned.

    Args:
        args: Object providing ``prune_percent`` (a percentage, 0-100).
        origin_saliency: Dict ``name -> array`` of saliency scores. Not
            modified; a deep copy is used internally.
        batch, device, data_shape, num_classes: Unused here; kept for
            interface compatibility.
        num_alive: Count used to turn ``prune_percent`` into a number of
            entries to prune.
        alive_idx: Dict ``name -> index`` selecting which entries of each
            saliency array take part in the global ranking.
        alive_mask: Dict ``name -> array`` used as the starting mask for
            every output entry.
        masked: Optional dict ``name -> array`` holding an existing mask.
            Entries that are 0 there stay 0 in the result and are treated as
            zero saliency for the ranking.

    Returns:
        Dict ``name -> array`` with one entry per key of ``alive_mask``.
        If the requested prune count rounds to 0, no entry is pruned by
        saliency; only ``masked`` zeros are applied.
    """
    mask = {}
    saliency = copy.deepcopy(origin_saliency)

    # Collect pieces and concatenate once. Repeated np.append copies the
    # whole accumulated buffer on every call (quadratic overall).
    chunks = []

    # Make masked saliency zero. Masked entries of layers that have no
    # saliency array still take part in the ranking as explicit zeros.
    if masked is not None:
        for name in masked:
            if name in saliency:
                saliency[name] *= np.where(masked[name] == 0, 0, 1)
            else:
                chunks.append(np.zeros(masked[name].size - np.count_nonzero(masked[name])))

    for name in saliency:
        chunks.append(np.abs(saliency[name][alive_idx[name]]))

    if chunks:
        global_s = np.concatenate(
            [np.ravel(np.asarray(chunk, dtype=np.float64)) for chunk in chunks]
        )
    else:
        global_s = np.zeros(0)

    # Number of entries to prune, clipped to the values actually ranked so
    # np.partition never receives an out-of-range kth.
    cutoff_index = int(np.round(0.01 * args.prune_percent * num_alive))
    cutoff_index = min(max(cutoff_index, 0), global_s.size)

    if cutoff_index == 0:
        # Nothing to prune. Without this guard kth would be -1, which selects
        # the LARGEST value and would prune every entry.
        cutoff = -np.inf
    else:
        cutoff = np.partition(global_s, cutoff_index - 1)[cutoff_index - 1]

    for name in alive_mask:
        mask[name] = copy.deepcopy(alive_mask[name])

        if name in saliency:
            mask[name] *= np.where(np.abs(saliency[name]) <= cutoff, 0, 1)

        # Entries that were already masked out stay masked out.
        if masked is not None and name in masked:
            mask[name] *= np.where(masked[name] == 0, 0, 1)

    return mask