import torch
from torch.autograd import Variable
import numpy as np

def to_var(x, requires_grad = False, volatile = False):
    """Wrap a tensor in an autograd ``Variable``, using the GPU when one exists.

    Args:
        x: tensor to wrap.
        requires_grad: forwarded unchanged to ``Variable``.
        volatile: forwarded unchanged to ``Variable``.

    Returns:
        The ``Variable`` built from ``x``; ``x`` is first moved to the
        ``"cuda"`` device when ``torch.cuda.is_available()`` is true.
    """
    use_gpu = torch.cuda.is_available()
    placed = x.to(torch.device("cuda")) if use_gpu else x
    return Variable(placed, volatile=volatile, requires_grad=requires_grad)

def generate_prune_masks(network, ratio, dev):
    """Build global magnitude-pruning masks for every weight tensor.

    Every parameter that is not 1-D is treated as a weight tensor; 1-D
    parameters are skipped. One threshold, the ``ratio``-th percentile of
    the absolute values of all weights in the network, is computed, and an
    entry is kept only when its magnitude is strictly above that threshold.

    Args:
        network: object exposing ``parameters()`` (e.g. a ``torch.nn.Module``).
        ratio: percentile in ``[0, 100]`` handed to ``np.percentile``.
        dev: device the bias masks are moved to.

    Returns:
        ``(weight_masks, bias_masks)``: two lists with one entry per weight
        tensor, in parameter order. Each weight mask is a float 0/1 tensor
        with the weight's shape, left on the weight's own device. Each bias
        mask is a 1-D 0/1 tensor on ``dev`` with one entry per slice along
        the weight's first dimension; it is 0 where that whole slice was
        pruned. Both lists are empty when the network has no weight tensors.
    """
    weights = [p.data for p in network.parameters() if len(p.data.size()) != 1]
    if not weights:
        # Nothing to prune; avoids calling np.percentile on an empty array.
        return [], []

    # Compute the global threshold once, over all weights together.
    magnitudes = np.concatenate(
        [w.abs().cpu().numpy().ravel() for w in weights]
    )
    threshold = float(np.percentile(magnitudes, ratio))

    weight_masks = [(w.abs() > threshold).float() for w in weights]

    bias_masks = []
    for w_mask in weight_masks:
        # Mask entries are 0/1, so a zero sum means the whole slice is pruned.
        alive = w_mask.flatten(start_dim=1).sum(dim=1) != 0
        # Tensor.to is not in-place: the moved tensor must be the one stored.
        bias_masks.append(
            alive.to(device=dev, dtype=torch.get_default_dtype())
        )
    return weight_masks, bias_masks

def generate_prune_masks_grasp(network, ratio, dev):
    """Build global pruning masks by thresholding signed parameter values.

    Every parameter that is not 1-D is treated as a weight tensor; 1-D
    parameters are skipped. One threshold, the ``ratio``-th percentile of
    the raw (signed) values of all weights in the network, is computed, and
    an entry is kept only when its value is strictly above that threshold.
    NOTE(review): the name suggests the parameters hold GraSP scores rather
    than ordinary weights when this is called -- confirm against the caller.

    Args:
        network: object exposing ``parameters()`` (e.g. a ``torch.nn.Module``).
        ratio: percentile in ``[0, 100]`` handed to ``np.percentile``.
        dev: device the bias masks are moved to.

    Returns:
        ``(weight_masks, bias_masks)``: two lists with one entry per weight
        tensor, in parameter order. Each weight mask is a float 0/1 tensor
        with the weight's shape, left on the weight's own device. Each bias
        mask is a 1-D 0/1 tensor on ``dev`` with one entry per slice along
        the weight's first dimension; it is 0 where that whole slice was
        pruned. Both lists are empty when the network has no weight tensors.
    """
    weights = [p.data for p in network.parameters() if len(p.data.size()) != 1]
    if not weights:
        # Nothing to prune; avoids calling np.percentile on an empty array.
        return [], []

    # Compute the global threshold once, over all signed values together.
    scores = np.concatenate([w.cpu().numpy().ravel() for w in weights])
    threshold = float(np.percentile(scores, ratio))

    weight_masks = [(w > threshold).float() for w in weights]

    bias_masks = []
    for w_mask in weight_masks:
        # Mask entries are 0/1, so a zero sum means the whole slice is pruned.
        alive = w_mask.flatten(start_dim=1).sum(dim=1) != 0
        # Tensor.to is not in-place: the moved tensor must be the one stored.
        bias_masks.append(
            alive.to(device=dev, dtype=torch.get_default_dtype())
        )
    return weight_masks, bias_masks

def count_weights(net):
    """Return the total element count of all parameters that are not 1-D.

    Args:
        net: object exposing ``parameters()``.

    Returns:
        Sum of ``numel()`` over every parameter whose data is not
        one-dimensional; ``0`` when there are none.
    """
    return sum(
        param.numel()
        for param in net.parameters()
        if len(param.data.size()) != 1
    )

def count(weight_masks):
    """Sum the entries of every mask, printing the running total as it goes.

    Args:
        weight_masks: sequence of tensors.

    Returns:
        The accumulated sum of all mask entries. It is a 0-dim tensor once
        at least one mask has been added, and the integer ``0`` for an
        empty sequence.
    """
    running_total = 0
    for current_mask in weight_masks:
        running_total = running_total + torch.sum(current_mask)
        # Emit the cumulative total after each mask.
        print(running_total)

    return running_total

def dot_product(network, mask):
    """Multiply each non-1-D parameter of ``network`` by its mask.

    Args:
        network: object exposing ``parameters()``; it is modified through
            reassignment of each selected parameter's ``.data``.
        mask: indexable collection; ``mask[k]`` is applied to the k-th
            parameter that is not one-dimensional, in parameter order.

    Returns:
        The same ``network`` object, after masking.
    """
    maskable = (p for p in network.parameters() if len(p.data.size()) != 1)
    for layer_idx, param in enumerate(maskable):
        param.data = param.data * mask[layer_idx]
    return network