import math
from common_imports import np, torch, tqdm
from common_use_functions import softmax
from OOD_score_utils import log_sum_exponential_score, wasserstein_score, pruning_mask
from deep_KNN import faiss_knn_search

def DICE_weight_contribution_matrix(params, actLevels, last_hidden_layerId):
    """
    Evaluate the weight contribution matrix defined in the DICE paper.

    The contribution of weight W[c, j] is the average over examples of W[c, j] * x_j,
    which equals W[c, j] * mean(x_j). The feature mean is therefore computed once,
    instead of building an (nb_examples, nb_features) temporary for every class.

    params: The parameters of the final linear layer; params['weight'] is indexed as
        (class, feature).
    actLevels: The activation level information; actLevels['actLevel'][last_hidden_layerId]
        holds the feature vectors, one row per example
        (generally those of the correctly predicted training set examples).
    last_hidden_layerId: The key of the last hidden layer inside actLevels['actLevel'].

    Returns: A float64 array with the same shape as params['weight'].
    """
    feature_vecs = actLevels['actLevel'][last_hidden_layerId]
    weights = params['weight']
    # Average in float64 for accuracy; the result is float64, as it was before.
    mean_feature = np.mean(feature_vecs, axis=0, dtype=np.float64)
    return (weights * mean_feature).astype(np.float64, copy=False)

def DICE_weight_contribution_matrix_torch_ver(params, actLevels, last_hidden_layerId, batch_size=1000,
                                              device='cuda', display=True):
    """
    PyTorch-accelerated evaluation of the weight contribution defined in the DICE paper.

    The contribution of weight W[c, j] is W[c, j] * mean(x_j). The per-neuron sum of the
    activations is accumulated batch by batch on `device`, and the weights are multiplied
    by the resulting mean once at the end.

    params: The parameters of the final linear layer; params['weight'] is a numpy array
        indexed as (class, feature).
    actLevels: The activation level information; actLevels['actLevel'][last_hidden_layerId]
        is a numpy array of feature vectors, one row per example
        (generally those of the correctly predicted training set examples).
    last_hidden_layerId: The key of the last hidden layer inside actLevels['actLevel'].
    batch_size: The number of examples transferred to `device` in each batch.
    device: The torch device used for the computation (default 'cuda', as before).
    display: Whether to show a tqdm progress bar over the batches.

    Returns: A numpy array with the shape and dtype of params['weight']. With no examples
    the mean is undefined and the result is filled with NaN.
    """
    feature_vecs = actLevels['actLevel'][last_hidden_layerId]
    weights = params['weight']
    nb_examples = feature_vecs.shape[0]
    nb_batches = math.ceil(nb_examples / batch_size)

    weights_dev = torch.from_numpy(weights).to(device)
    # Accumulate in float64 so the running sum stays accurate whatever the feature dtype is.
    feature_sum = torch.zeros(feature_vecs.shape[1], dtype=torch.float64, device=device)
    batch_ids = tqdm(range(nb_batches), desc='Processed batches') if display else range(nb_batches)
    with torch.no_grad():
        for batch_index in batch_ids:
            start = batch_index * batch_size
            stop = min(start + batch_size, nb_examples)
            batch_feature_vecs = torch.from_numpy(feature_vecs[start:stop]).to(device)
            # Summing directly (rather than multiplying by a float32 ones column)
            # works for float64 features as well as float32 ones.
            feature_sum += batch_feature_vecs.sum(dim=0)
        mean_feature = (feature_sum / nb_examples).to(weights_dev.dtype)
        contrib_matrix_dev = weights_dev * mean_feature

    # Move the result back to the CPU
    contrib_matrix = contrib_matrix_dev.cpu().numpy()

    # Release the device memory (empty_cache does nothing if CUDA was never initialised)
    del weights_dev, feature_sum, contrib_matrix_dev
    torch.cuda.empty_cache()

    return contrib_matrix

def get_logits_and_norm_logits(feature_vecs, params, weight_pruning_masks):
    """
    Return the logits of the pruned final layer together with their L2-normalized version.

    feature_vecs: The penultimate layer's activation levels (feature vectors), one row per example.
    params: The final linear parameters, a dict with 'weight' and 'bias'.
    weight_pruning_masks: A mask multiplied element-wise with params['weight'] before use.

    Returns: (logits, norm_logits). Each row of norm_logits is the matching row of
    logits divided by its Euclidean norm.
    """
    weights, bias = params['weight'], params['bias']
    pruned_weights = weights * weight_pruning_masks
    logits = np.dot(feature_vecs, pruned_weights.T) + bias
    row_norms = np.linalg.norm(logits, axis=1, keepdims=True)
    return logits, logits / row_norms

def DICE_score_evaluation(feature_vecs, params, weight_pruning_masks):
    """
    Compute the OMS score from the DICE paper (the scoring step only, not the whole pipeline).

    feature_vecs: The penultimate layer's activation levels (feature vectors), one row per example.
    params: The final linear parameters, a dict with 'weight' and 'bias'.
    weight_pruning_masks: A mask multiplied element-wise with params['weight'] before use.

    Returns: log_sum_exponential_score of the pruned logits over axis 1, flattened to 1-D.
    """
    weights, bias = params['weight'], params['bias']
    pruned_logits = np.dot(feature_vecs, (weights * weight_pruning_masks).T) + bias
    scores = log_sum_exponential_score(pruned_logits, sum_axis=1)
    # The flatten may be redundant if the helper already returns a 1-D array.
    return scores.reshape(-1)

def LOG_score_evaluation(norm_logit_search_index, norm_logits, k, display=True):
    """
    Compute the LOG score from a k-nearest-neighbour search over normalized logits.

    norm_logit_search_index: The search index built from the training-set normalized logits
        (it should be an inner-product index).
    norm_logits: The normalized logits used as queries, one row per example.
    k: The number of considered nearest neighbours.
    display: Forwarded to faiss_knn_search.

    Returns: The last column of the first array (D) returned by faiss_knn_search, i.e. the
    value attached to the k-th neighbour of each query. With an inner-product index this is
    presumably the k-th highest similarity; confirm against faiss_knn_search.
    """
    distances, _neighbour_ids = faiss_knn_search(norm_logit_search_index, norm_logits, k, display=display)
    return distances[:, -1]

def DICE_LOG_score_evaluation(logits, norm_logits, norm_logit_search_index, k, correct_factor, display=True):
    """
    Combine the DICE score with the LOG nearest-neighbour score.

    logits: The original (unnormalized) logits, one row per example.
    norm_logits: The normalized logits used as kNN queries.
    norm_logit_search_index: The search index built from the training-set normalized logits
        (it should be an inner-product index).
    k: The number of considered nearest neighbours.
    correct_factor: The factor applied to the LOG score before it is added to the DICE score
        (corrects for the difference in magnitude between the two scores).
    display: Forwarded to faiss_knn_search (default True, matching the previous hard-coded value).

    Returns: DICE_scores + correct_factor * LOG_scores, where DICE_scores is
    log_sum_exponential_score(logits, sum_axis=1) flattened to 1-D and LOG_scores is the
    last column of the first array returned by faiss_knn_search.
    """
    # LOG score: value attached to the k-th neighbour of each query
    distances, _neighbour_ids = faiss_knn_search(norm_logit_search_index, norm_logits, k, display=display)
    LOG_scores = distances[:, -1]
    # DICE score on the original logits (the flatten may be redundant)
    DICE_scores = log_sum_exponential_score(logits, sum_axis=1).reshape(-1)
    return DICE_scores + correct_factor * LOG_scores

def WS_score_evaluation(feature_vecs, params, weight_pruning_masks, nb_classes):
    """
    Compute the Wasserstein-based OMS score on the pruned logits. (Not performant)

    feature_vecs: The penultimate layer's activation levels (feature vectors), one row per example.
    params: The final linear parameters, a dict with 'weight' and 'bias'.
    weight_pruning_masks: A mask multiplied element-wise with params['weight'] before use.
    nb_classes: The number of classes, forwarded to wasserstein_score.

    Returns: wasserstein_score(logits, nb_classes) flattened to 1-D.
    """
    weights, bias = params['weight'], params['bias']
    masked_weights = weights * weight_pruning_masks
    logits_WS = np.dot(feature_vecs, masked_weights.T) + bias
    # The flatten may be redundant if the helper already returns a 1-D array.
    return wasserstein_score(logits_WS, nb_classes).reshape(-1)

# def DICE_AJ_score_evaluation(feature_vecs, params, weight_pruning_masks, prob_threshold=0.95):
#     """
#     This function calculates the DICE OMS score with probability adjustment. (Not performant)

#     feature_vecs: The penultimate layer's activation levels, i.e., the feature vectors. 
#     params: The final linear parameters (weight and bias).
#     weight_pruning_masks: The weight pruning masks for different classes.
#     nb_classes: The number of classes.
#     prob_threshold: The threshold of probability to apply the probability adjustment.
#     """
#     # Evaluate the scores for the data from different classes
#     weights = params['weight']
#     bias = params['bias']
#     logits_DICE = np.dot(feature_vecs, np.transpose(weights*weight_pruning_masks))+bias
#     probs = softmax(logits_DICE)
#     DICE_scores = log_sum_exponential_score(logits_DICE, sum_axis=1).reshape(-1) # Here the reshape may be unnecessary
#     prob_adjust = np.max(probs, axis=1)
#     prob_adjust[prob_adjust > prob_threshold] = 1
#     DICE_AJ_scores = DICE_scores * prob_adjust 

#     return DICE_AJ_scores

def build_weight_pruning_masks_sig_ver(weight_shape, weight_contrib_matrix_sig, important_neurons, p):
    """
    Build the full-size weight pruning mask from a mask computed on the significant neurons only.

    weight_shape: The shape of the full weight matrix.
    weight_contrib_matrix_sig: The weight contribution matrix restricted to the significant
        neurons (its columns correspond to important_neurons).
    important_neurons: The column indices of the significant neurons.
    p: The percentage forwarded to pruning_mask.

    Returns: An integer array of shape weight_shape. Columns outside important_neurons are 0;
    the other columns hold pruning_mask(weight_contrib_matrix_sig, p=p).
    """
    full_mask = np.zeros(weight_shape, dtype=int)
    sig_mask = pruning_mask(weight_contrib_matrix_sig, p=p)
    full_mask[:, important_neurons] = sig_mask
    return full_mask

def DICE_weight_contribution_matrix_sig_ver(params, feature_vecs, important_neurons):
    """
    Evaluate the DICE weight contribution restricted to the significant neurons.

    Because the class weights do not vary across examples, mean_n(W[c, j] * x_{n, j})
    equals W[c, j] * mean_n(x_{n, j}). The feature mean is therefore computed once,
    instead of once per class.

    params: The parameters of the final linear layer; params['weight'] is indexed as
        (class, feature).
    feature_vecs: The used feature vectors, one row per example
        (generally those of the correctly predicted training set examples).
    important_neurons: A list containing the indices of the significant neurons.

    Returns: A float64 array of shape [nb_classes, nb_important_neurons].
    """
    sig_weights = params['weight'][:, important_neurons]
    sig_mean_feature = np.mean(feature_vecs[:, important_neurons], axis=0, dtype=np.float64)
    return (sig_weights * sig_mean_feature).astype(np.float64, copy=False)

def DICE_weight_contribution_matrix_sig_and_torch_ver(params, feature_vecs, important_neurons, batch_size=1000,
                                                      device='cuda'):
    """
    Evaluate the weight contribution related only to the significant neurons, using PyTorch.

    The contribution of W[c, j] is W[c, j] * mean_n(x_{n, j}). All the significant features
    are moved to `device` at once, their per-neuron sum is accumulated batch by batch, and
    the weights are multiplied by the mean once at the end.

    params: The parameters of the final linear layer; params['weight'] is a numpy array
        indexed as (class, feature).
    feature_vecs: The used feature vectors (numpy array, one row per example),
        generally those of the correctly predicted training set examples.
    important_neurons: A list containing the indices of the significant neurons.
    batch_size: The size of each batch to be processed.
    device: The torch device used for the computation (default 'cuda', as before).

    Returns: A numpy array of shape [nb_classes, nb_important_neurons] with the dtype of
    params['weight']. With no examples, the result is filled with NaN.
    """
    sig_weights = params['weight'][:, important_neurons]
    sig_feature_vecs = feature_vecs[:, important_neurons]
    nb_examples = sig_feature_vecs.shape[0]
    nb_batches = math.ceil(nb_examples / batch_size)

    sig_weights_dev = torch.from_numpy(sig_weights).to(device)
    sig_feature_vecs_dev = torch.from_numpy(sig_feature_vecs).to(device)
    # Accumulate in float64 so the running sum stays accurate whatever the feature dtype is.
    feature_sum = torch.zeros(sig_feature_vecs.shape[1], dtype=torch.float64, device=device)
    with torch.no_grad():
        for batch_index in range(nb_batches):
            start = batch_index * batch_size
            stop = min(start + batch_size, nb_examples)
            # A direct sum (instead of a matmul with a float32 ones column) accepts float64 features.
            feature_sum += sig_feature_vecs_dev[start:stop].sum(dim=0)
        mean_feature = (feature_sum / nb_examples).to(sig_weights_dev.dtype)
        sig_contrib_matrix_dev = sig_weights_dev * mean_feature

    # Move the result back to the CPU
    sig_contrib_matrix = sig_contrib_matrix_dev.cpu().numpy()

    # Release the device memory (empty_cache does nothing if CUDA was never initialised)
    del sig_weights_dev, sig_feature_vecs_dev, feature_sum, sig_contrib_matrix_dev
    torch.cuda.empty_cache()

    return sig_contrib_matrix

def DICE_weight_contribution_matrix_sig_and_torch_ver_memory_save(params, feature_vecs, important_neurons, batch_size=1000,
                                                                  device='cuda'):
    """
    Evaluate the weight contribution related only to the significant neurons. (with memory saving)

    Same result as DICE_weight_contribution_matrix_sig_and_torch_ver, but the feature vectors
    stay on the host and only one batch at a time is transferred to `device`.

    params: The parameters of the final linear layer; params['weight'] is a numpy array
        indexed as (class, feature).
    feature_vecs: The used feature vectors (numpy array, one row per example),
        generally those of the correctly predicted training set examples.
    important_neurons: A list containing the indices of the significant neurons.
    batch_size: The size of each batch to be processed.
    device: The torch device used for the computation (default 'cuda', as before).

    Returns: A numpy array of shape [nb_classes, nb_important_neurons] with the dtype of
    params['weight']. With no examples, the result is filled with NaN.
    """
    sig_weights = params['weight'][:, important_neurons]
    sig_feature_vecs = feature_vecs[:, important_neurons]
    nb_examples = sig_feature_vecs.shape[0]
    nb_batches = math.ceil(nb_examples / batch_size)

    sig_weights_dev = torch.from_numpy(sig_weights).to(device)
    # Features stay on the host; each batch is transferred separately to bound device memory.
    sig_feature_vecs_host = torch.from_numpy(sig_feature_vecs)
    # Accumulate in float64 so the running sum stays accurate whatever the feature dtype is.
    feature_sum = torch.zeros(sig_feature_vecs.shape[1], dtype=torch.float64, device=device)
    with torch.no_grad():
        for batch_index in range(nb_batches):
            start = batch_index * batch_size
            stop = min(start + batch_size, nb_examples)
            batch_feature_vecs = sig_feature_vecs_host[start:stop].to(device)
            # A direct sum (instead of a matmul with a float32 ones column) accepts float64 features.
            feature_sum += batch_feature_vecs.sum(dim=0)
        mean_feature = (feature_sum / nb_examples).to(sig_weights_dev.dtype)
        sig_contrib_matrix_dev = sig_weights_dev * mean_feature

    # Move the result back to the CPU
    sig_contrib_matrix = sig_contrib_matrix_dev.cpu().numpy()

    # Release the device memory (empty_cache does nothing if CUDA was never initialised)
    del sig_weights_dev, sig_feature_vecs_host, feature_sum, sig_contrib_matrix_dev
    torch.cuda.empty_cache()

    return sig_contrib_matrix

def DICE_unified_weight_contribution_matrix_sig_and_torch_ver(params, feature_vecs, classId, important_neurons, batch_size=1000,
                                                              device='cuda'):
    """
    Evaluate, for one class, the weight contribution related only to the significant neurons.

    The contribution of W[classId, j] is W[classId, j] * mean_n(x_{n, j}). All the significant
    features are moved to `device` at once and their per-neuron sum is accumulated batch by batch.

    params: The parameters of the final linear layer; params['weight'] is a numpy array
        indexed as (class, feature).
    feature_vecs: The used feature vectors (numpy array, one row per example),
        generally those of the correctly predicted training set examples.
    classId: The id of the class to be evaluated (row of params['weight']).
    important_neurons: A list containing the indices of the significant neurons.
    batch_size: The size of each batch to be processed.
    device: The torch device used for the computation (default 'cuda', as before).

    Returns: A 1-D numpy array of shape [nb_important_neurons] (a vector, not a matrix) with
    the dtype of params['weight']. With no examples, the result is filled with NaN.
    """
    sig_weights = params['weight'][classId, important_neurons]
    sig_feature_vecs = feature_vecs[:, important_neurons]
    nb_examples = sig_feature_vecs.shape[0]
    nb_batches = math.ceil(nb_examples / batch_size)

    sig_weights_dev = torch.from_numpy(sig_weights).to(device)
    sig_feature_vecs_dev = torch.from_numpy(sig_feature_vecs).to(device)
    # Accumulate in float64 so the running sum stays accurate whatever the feature dtype is.
    feature_sum = torch.zeros(sig_feature_vecs.shape[1], dtype=torch.float64, device=device)
    with torch.no_grad():
        for batch_index in range(nb_batches):
            start = batch_index * batch_size
            stop = min(start + batch_size, nb_examples)
            # A direct sum (instead of a matmul with a float32 ones column) accepts float64 features.
            feature_sum += sig_feature_vecs_dev[start:stop].sum(dim=0)
        mean_feature = (feature_sum / nb_examples).to(sig_weights_dev.dtype)
        sig_contrib_vec_dev = sig_weights_dev * mean_feature

    # Move the result back to the CPU
    sig_contrib_vec = sig_contrib_vec_dev.cpu().numpy()

    # Release the device memory (empty_cache does nothing if CUDA was never initialised)
    del sig_weights_dev, sig_feature_vecs_dev, feature_sum, sig_contrib_vec_dev
    torch.cuda.empty_cache()

    return sig_contrib_vec

def DICE_unified_weight_contribution_matrix_sig_and_torch_ver_memory_save(params, feature_vecs, classId, important_neurons, batch_size=1000,
                                                                          device='cuda'):
    """
    Evaluate, for one class, the weight contribution related only to the significant neurons.
    (with memory saving)

    Same result as DICE_unified_weight_contribution_matrix_sig_and_torch_ver, but the feature
    vectors stay on the host and only one batch at a time is transferred to `device`.

    params: The parameters of the final linear layer; params['weight'] is a numpy array
        indexed as (class, feature).
    feature_vecs: The used feature vectors (numpy array, one row per example),
        generally those of the correctly predicted training set examples.
    classId: The id of the class to be evaluated (row of params['weight']).
    important_neurons: A list containing the indices of the significant neurons.
    batch_size: The size of each batch to be processed.
    device: The torch device used for the computation (default 'cuda', as before).

    Returns: A 1-D numpy array of shape [nb_important_neurons] (a vector, not a matrix) with
    the dtype of params['weight']. With no examples, the result is filled with NaN.
    """
    sig_weights = params['weight'][classId, important_neurons]
    sig_feature_vecs = feature_vecs[:, important_neurons]
    nb_examples = sig_feature_vecs.shape[0]
    nb_batches = math.ceil(nb_examples / batch_size)

    sig_weights_dev = torch.from_numpy(sig_weights).to(device)
    # Features stay on the host; each batch is transferred separately to bound device memory.
    sig_feature_vecs_host = torch.from_numpy(sig_feature_vecs)
    # Accumulate in float64 so the running sum stays accurate whatever the feature dtype is.
    feature_sum = torch.zeros(sig_feature_vecs.shape[1], dtype=torch.float64, device=device)
    with torch.no_grad():
        for batch_index in range(nb_batches):
            start = batch_index * batch_size
            stop = min(start + batch_size, nb_examples)
            batch_feature_vecs = sig_feature_vecs_host[start:stop].to(device)
            # A direct sum (instead of a matmul with a float32 ones column) accepts float64 features.
            feature_sum += batch_feature_vecs.sum(dim=0)
        mean_feature = (feature_sum / nb_examples).to(sig_weights_dev.dtype)
        sig_contrib_vec_dev = sig_weights_dev * mean_feature

    # Move the result back to the CPU
    sig_contrib_vec = sig_contrib_vec_dev.cpu().numpy()

    # Release the device memory (empty_cache does nothing if CUDA was never initialised)
    del sig_weights_dev, sig_feature_vecs_host, feature_sum, sig_contrib_vec_dev
    torch.cuda.empty_cache()

    return sig_contrib_vec

def get_sensitivity_DICE_pruning_masks(params, feature_vecs, class_index_dict, sensitivity_contrib_matrix, p_n=20, p_w=20):
    """
    Evaluate the weight contributions and pruning masks required by the OMS score
    from the LINe paper (not the whole pipeline).

    params: The final linear parameters (weight and bias).
    feature_vecs: The used feature vectors
        (generally those of the correctly predicted training set examples).
    class_index_dict: Maps each class id to the row indices of its examples in feature_vecs.
    sensitivity_contrib_matrix: The sensitivity indices, one row per class.
    p_n: The percentage used for the significant neurons (forwarded to pruning_mask).
    p_w: The percentage used for the significant weights.

    Returns: (weight_contributions, feature_pruning_masks, weight_pruning_masks), three dicts
    keyed by class id.
    """
    nb_classes = sensitivity_contrib_matrix.shape[0]
    # One feature mask per class, computed from that class's sensitivity row.
    feature_pruning_masks = {
        class_id: pruning_mask(sensitivity_contrib_matrix[class_id], p_n)
        for class_id in range(nb_classes)
    }
    weight_contributions, weight_pruning_masks = {}, {}
    for class_id in range(nb_classes):
        class_feature_vecs = feature_vecs[class_index_dict[class_id]]
        sig_neurons = [pos for pos, flag in enumerate(feature_pruning_masks[class_id]) if flag == 1]
        class_contrib = DICE_weight_contribution_matrix_sig_ver(params, class_feature_vecs, sig_neurons)
        weight_contributions[class_id] = class_contrib
        weight_pruning_masks[class_id] = build_weight_pruning_masks_sig_ver(params['weight'].shape,
                                                                            class_contrib,
                                                                            sig_neurons, p_w)
    return weight_contributions, feature_pruning_masks, weight_pruning_masks

def get_sensitivity_DICE_pruning_masks_torch_ver(params, feature_vecs, class_index_dict, sensitivity_contrib_matrix, p_n=20, p_w=20, display=True):
    """
    Evaluate the pruning masks required by the OMS score from the LINe paper
    (not the whole pipeline), using the PyTorch contribution helper.

    params: The final linear parameters (weight and bias).
    feature_vecs: The used feature vectors
        (generally those of the correctly predicted training set examples).
    class_index_dict: Maps each class id to the row indices of its examples in feature_vecs.
    sensitivity_contrib_matrix: The sensitivity indices, one row per class.
    p_n: The percentage used for the significant neurons (forwarded to pruning_mask).
    p_w: The percentage used for the significant weights.
    display: Whether to show a tqdm progress bar over the classes.

    Returns: (feature_pruning_masks, weight_pruning_masks), two dicts keyed by class id.
    """
    nb_classes = sensitivity_contrib_matrix.shape[0]
    feature_pruning_masks = {
        class_id: pruning_mask(sensitivity_contrib_matrix[class_id], p_n)
        for class_id in range(nb_classes)
    }
    weight_pruning_masks = {}
    class_ids = tqdm(range(nb_classes), desc='Processed classes (mask)') if display else range(nb_classes)
    for class_id in class_ids:
        class_feature_vecs = feature_vecs[class_index_dict[class_id]]
        sig_neurons = [pos for pos, flag in enumerate(feature_pruning_masks[class_id]) if flag == 1]
        class_contrib = DICE_weight_contribution_matrix_sig_and_torch_ver(params, class_feature_vecs,
                                                                          sig_neurons, batch_size=1000)
        weight_pruning_masks[class_id] = build_weight_pruning_masks_sig_ver(params['weight'].shape,
                                                                            class_contrib,
                                                                            sig_neurons, p_w)
    return feature_pruning_masks, weight_pruning_masks

def get_sensitivity_unified_DICE_pruning_masks_torch_ver(params, feature_vecs, class_index_dict, sensitivity_contrib_matrix, p_n=20, p_w=20, display=True):
    """
    Evaluate the pruning masks required by the OMS score from the LINe paper
    (not the whole pipeline). A single weight pruning mask is built over the whole
    contribution matrix, instead of one mask per class.

    params: The final linear parameters (weight and bias).
    feature_vecs: The used feature vectors
        (generally those of the correctly predicted training set examples).
    class_index_dict: Maps each class id to the row indices of its examples in feature_vecs.
    sensitivity_contrib_matrix: The sensitivity indices, one row per class.
    p_n: The percentage used for the significant neurons (forwarded to pruning_mask).
    p_w: The percentage used for the significant weights.
    display: Whether to show a tqdm progress bar over the classes.

    Returns: (feature_pruning_masks, weight_pruning_masks). The first is a dict keyed by
    class id; the second is pruning_mask applied to the full contribution matrix.
    """
    nb_classes = sensitivity_contrib_matrix.shape[0]
    feature_pruning_masks = {
        class_id: pruning_mask(sensitivity_contrib_matrix[class_id], p_n)
        for class_id in range(nb_classes)
    }
    # Entries outside each class's significant neurons keep a contribution of 0.
    weight_contrib_matrix = np.zeros(params['weight'].shape)
    class_ids = tqdm(range(nb_classes), desc='Processed classes (mask)') if display else range(nb_classes)
    for class_id in class_ids:
        class_feature_vecs = feature_vecs[class_index_dict[class_id]]
        sig_neurons = [pos for pos, flag in enumerate(feature_pruning_masks[class_id]) if flag == 1]
        weight_contrib_matrix[class_id, sig_neurons] = DICE_unified_weight_contribution_matrix_sig_and_torch_ver(
            params, class_feature_vecs, class_id, sig_neurons, batch_size=1000)
    return feature_pruning_masks, pruning_mask(weight_contrib_matrix, p=p_w)

def DICE_weight_contribution_matrix_sensitivity_reweighted(params, feature_vecs, sensitivity):
    """
    Evaluate the sensitivity-reweighted weight contribution.

    The contribution of W[c, j] is mean_n(W[c, j] * x_{n, j} * s_{(n,) j}). Because the class
    weights do not vary across examples, this equals W[c, j] * mean_n(x_{n, j} * s_{(n,) j}).
    The reweighted mean is therefore computed once, instead of once per class.

    params: The parameters of the final linear layer; params['weight'] is indexed as
        (class, feature).
    feature_vecs: The used feature vectors, one row per example
        (generally those of the correctly predicted training set examples).
    sensitivity: The sensitivity indices for the neurons; any array broadcastable against
        feature_vecs (e.g. one value per neuron).

    Returns: A float64 array with the same shape as params['weight'], i.e.
    [nb_classes, nb_features].
    """
    weights = params['weight']
    reweighted_mean = np.mean(feature_vecs * sensitivity, axis=0, dtype=np.float64)
    return (weights * reweighted_mean).astype(np.float64, copy=False)

# def get_sensitivity_reweighted_DICE_pruning_masks(params, feature_vecs, class_index_dict, sensitivity_contrib_matrix, p_n=20, p_w=20):
#     """
#     This function evaluates the weight contribution and pruning masks 
#     required by the newly designed DICE OMS score. (Not performant)

#     params: The final linear parameters (weight and bias).
#     feature_vecs: The used feature vectors. 
#     (generally should be the one of the correclty predicted training set examples.)
#     class_index_dict: The correspondance dictionary for the examples of each class. 
#     (It should be coherent with the feature vectors. Generally related to the original class.)
#     sensitivity_contrib_matrix: The matrix containing the sensitivity indices for different classes.
#     p_n: The percentage used for the significant neurons.
#     p_w: The percentage used for the significant weights.
#     """
#     # Get the list of all classes
#     class_list = sorted(list(range(sensitivity_contrib_matrix.shape[0])))
#     # Determine the feature pruning masks
#     feature_pruning_masks = {}
#     for classId in class_list:
#         feature_pruning_masks[classId] = pruning_mask(sensitivity_contrib_matrix[classId], p_n)
#     # Determine the weight contribution and pruning masks.
#     weight_contributions = {}
#     weight_pruning_masks = {}
#     for classId in class_list:
#         current_class_feature_vecs = feature_vecs[class_index_dict[classId]]
#         current_class_sensitivity = sensitivity_contrib_matrix[classId]
#         weight_contributions[classId] = DICE_weight_contribution_matrix_sensitivity_reweighted(params, current_class_feature_vecs,
#                                                                                         current_class_sensitivity)
#         weight_pruning_masks[classId] = pruning_mask(weight_contributions[classId], p_w)

#     return weight_contributions, feature_pruning_masks, weight_pruning_masks

"""
The functions below are not performant when used with the significant neurons.
"""
def build_feature_pruning_masks(nb_features, important_neurons):
    """
    Build a binary feature pruning mask. (Generally not used)

    nb_features: The number of features (length of the mask).
    important_neurons: A list containing the indices of the significant neurons;
        these entries are set to 1.

    Returns: An integer numpy array of length nb_features with 1 at important_neurons
    and 0 elsewhere.
    """
    mask = np.zeros(nb_features, dtype=int)
    mask[important_neurons] = 1
    return mask

# def LINe_weight_contribution_matrix(params, sensitivity_indices):
#     """
#     This function evaluates the weight contribution related only to the significant neurons.

#     params: The parameters of the final linear layer.
#     actLevels: The activation level information that contains also the original and predicted class.
#     sensitivity_indices: The sensitivity indices for all neurons

#     Note: The returned contribution matrix will have the shape of [nb_classes, nb_important_neurons].
#     """
#     # Get the last hidden layer activation levels and weights
#     weights = params['weight']
#     # Evaluate the contribution matrix
#     contrib_matrix = weights*sensitivity_indices

#     return contrib_matrix

def get_logits_and_norm_logits_sig_ver(feature_vecs, params, feature_pruning_masks, weight_pruning_masks):
    """
    Return the logits computed with masked features and masked weights, and their
    L2-normalized version.

    feature_vecs: The penultimate layer's activation levels (feature vectors), one row per example.
    params: The final linear parameters, a dict with 'weight' and 'bias'.
    feature_pruning_masks: A mask multiplied element-wise with feature_vecs.
    weight_pruning_masks: A mask multiplied element-wise with params['weight'].

    Returns: (logits, norm_logits). Each row of norm_logits is the matching row of
    logits divided by its Euclidean norm.
    """
    weights, bias = params['weight'], params['bias']
    masked_features = feature_vecs * feature_pruning_masks
    masked_weights = weights * weight_pruning_masks
    logits = np.dot(masked_features, masked_weights.T) + bias
    row_norms = np.linalg.norm(logits, axis=1, keepdims=True)
    return logits, logits / row_norms

def DICE_score_evaluation_sig_ver(feature_vecs, params, feature_pruning_masks, weight_pruning_masks):
    """
    Compute the OMS score from the DICE paper with both feature and weight masks
    (the scoring step only, not the whole pipeline).

    feature_vecs: The penultimate layer's activation levels (feature vectors), one row per example.
    params: The final linear parameters, a dict with 'weight' and 'bias'.
    feature_pruning_masks: A mask multiplied element-wise with feature_vecs.
    weight_pruning_masks: A mask multiplied element-wise with params['weight'].

    Returns: log_sum_exponential_score of the masked logits over axis 1, flattened to 1-D.
    """
    weights, bias = params['weight'], params['bias']
    masked_features = feature_vecs * feature_pruning_masks
    masked_weights = weights * weight_pruning_masks
    logits_DICE = np.dot(masked_features, masked_weights.T) + bias
    # The flatten may be redundant if the helper already returns a 1-D array.
    return log_sum_exponential_score(logits_DICE, sum_axis=1).reshape(-1)


def DICE_score_evaluation_logit_ver(logits):
    """
    Compute the OMS score from the DICE paper directly from precomputed logits.

    logits: The evaluated logits, one row per example.

    Returns: log_sum_exponential_score(logits, sum_axis=1), flattened to 1-D
    (the flatten may be redundant).
    """
    scores = log_sum_exponential_score(logits, sum_axis=1)
    return scores.reshape(-1)

