from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F
import numpy as np



def struc_in_layer(num_layers, net_shape, struc_idx):
    """Return the index of the layer that contains global structure ``struc_idx``.

    Structures are numbered from 0, consecutively across layers; layer ``lay``
    owns ``int(net_shape[lay][0])`` of them. Only the first ``num_layers``
    entries of ``net_shape`` are considered.

    Args:
        num_layers: number of leading entries of ``net_shape`` to scan.
        net_shape: per-layer shapes; only the first dimension is used here.
        struc_idx: 0-based global structure index.

    Returns:
        The layer index owning ``struc_idx``. If ``struc_idx`` lies beyond the
        scanned layers, 0 is returned (kept for backward compatibility).
    """
    end = 0
    for lay in range(num_layers):
        end += int(net_shape[lay][0])
        # ``end`` is one past the last 0-based index owned by this layer, so the
        # comparison must be strict (``>=`` assigned boundary indices, and
        # anything in an empty leading layer, to the wrong layer).
        if struc_idx < end:
            return lay
    return 0


# Implementation of structured pruning driven by an importance matrix (a ranking of structure indices).
class struct_importance_pruning(prune.BasePruningMethod):
    """Use an importance vector to prune a NN.

    Structures are numbered 0..N-1 consecutively across layers: layer ``lay``
    owns ``int(shape[lay][0])`` structures, and each structure covers
    ``prod(shape[lay][1:])`` consecutive entries of the row-major flattened
    mask (layers laid out one after another). Structures are pruned in the
    order listed in ``importance_matrix_idx`` until
    ``int(len(importance_matrix_idx) * amount)`` have been removed.

    Args:
        amount: fraction in [0, 1] of structures to prune.
        shape: per-layer shapes (list of the shapes of the network's pruned
            parameters).
        importance_matrix_idx: global structure indices in pruning order; the
            earliest entries are pruned first. Presumably a permutation of all
            structure indices -- TODO confirm with callers.
        pruning_chanel_oh: list that receives one 0/1 array per layer (that
            layer's structure mask) on every ``compute_mask`` call.
        shuffle_mask_layer: if true, shuffle each layer's structure mask in
            place, keeping the number of pruned structures per layer.
        pruning_layer_cap: if true, never prune more than ``layer_cap`` of the
            structures of any single layer.
        layer_cap: maximum pruned fraction per layer when
            ``pruning_layer_cap`` is enabled.

    Raises:
        ValueError: if ``amount`` is outside [0, 1].
    """

    # Declared 'unstructured' even though the pruning is structured, so the
    # method can be passed to prune.global_unstructured.
    PRUNING_TYPE = 'unstructured'

    def __init__(self, amount, shape, importance_matrix_idx, pruning_chanel_oh,
                 shuffle_mask_layer=False, pruning_layer_cap=False, layer_cap=0.95):
        if not 0 <= amount <= 1:
            raise ValueError(f"amount must be in [0, 1], got {amount}")
        self.amount = amount
        self.num_layers = len(shape)
        self.net_shape = shape
        self.importance_matrix_idx = importance_matrix_idx
        self.pruning_chanel_oh = pruning_chanel_oh
        self.shuffle_mask_layer = shuffle_mask_layer
        self.pruning_layer_cap = pruning_layer_cap
        self.layer_cap = layer_cap

    def compute_mask(self, t, default_mask):
        """Return a copy of ``default_mask`` with every pruned structure zeroed.

        ``t`` is not used: the decision depends only on the importance ranking.
        As a side effect, each layer's structure mask is appended to
        ``self.pruning_chanel_oh``.
        """
        struct_mask = self._structured_mask()
        self._shuffle_and_record(struct_mask)
        return self._expand_to_weight_mask(struct_mask, default_mask.clone())

    def _structured_mask(self):
        """Return a 0/1 array over all structures; 0 marks a pruned structure."""
        order = self.importance_matrix_idx
        struct_mask = np.ones_like(order)
        n_to_prune = int(len(order) * self.amount)

        if not self.pruning_layer_cap:
            for idx in order[:n_to_prune]:
                struct_mask[idx] = 0
            return struct_mask

        strucs_per_layer = np.array(
            [int(layer_shape[0]) for layer_shape in self.net_shape], dtype=np.int64
        )
        # layer_of_struc[i] is the layer owning global structure index i.
        layer_of_struc = np.repeat(np.arange(len(strucs_per_layer)), strucs_per_layer)
        pruned_per_layer = np.zeros_like(strucs_per_layer)
        pruned = 0
        # Iterating over ``order`` (rather than an open-ended while loop) stops
        # cleanly if the target cannot be reached under the cap.
        for idx in order:
            if pruned >= n_to_prune:
                break
            lay = layer_of_struc[idx]
            # Check the fraction *after* this prune so the cap is never
            # exceeded; checking before it let small layers vanish entirely.
            if (pruned_per_layer[lay] + 1) / strucs_per_layer[lay] > self.layer_cap:
                continue
            pruned_per_layer[lay] += 1
            struct_mask[idx] = 0
            pruned += 1
        return struct_mask

    def _shuffle_and_record(self, struct_mask):
        """Optionally shuffle each layer's slice in place, then record it."""
        start = 0
        for layer_shape in self.net_shape:
            stop = start + int(layer_shape[0])
            # A numpy slice is a view, so shuffling it also updates struct_mask.
            layer_slice = struct_mask[start:stop]
            if self.shuffle_mask_layer:
                np.random.shuffle(layer_slice)
            self.pruning_chanel_oh.append(layer_slice)
            start = stop

    def _expand_to_weight_mask(self, struct_mask, mask):
        """Zero the parameters of every pruned structure; return the mask."""
        # Offsets are flat (row-major) indices, so write through a flat view;
        # this also handles N-D masks, e.g. when applied to a single module.
        flat = mask.reshape(-1)
        struc_offset = 0
        weight_offset = 0
        for layer_shape in self.net_shape:
            n_strucs = int(layer_shape[0])
            per_struc = 1
            for dim in layer_shape[1:]:
                per_struc *= int(dim)
            layer_mask = struct_mask[struc_offset:struc_offset + n_strucs]
            for local in np.flatnonzero(layer_mask == 0):
                begin = weight_offset + int(local) * per_struc
                flat[begin:begin + per_struc] = 0
            struc_offset += n_strucs
            weight_offset += n_strucs * per_struc
        return flat.reshape(mask.shape)


def importance_structured(module, name, *args, **kwargs):
    """Prune parameter ``name`` of ``module`` with ``struct_importance_pruning``.

    Extra positional/keyword arguments are forwarded to the pruning method's
    constructor (``amount``, ``shape``, ``importance_matrix_idx``,
    ``pruning_chanel_oh``, ...). The constructor requires them, so calling
    ``apply`` without them always failed. ``shape`` presumably has to describe
    this single parameter -- TODO confirm with callers.

    Returns:
        ``module``, to allow chaining.
    """
    struct_importance_pruning.apply(module, name, *args, **kwargs)
    return module


