from common_imports import np, copy, tqdm, torch, math, gc
from common_use_functions import softmax, class_index_dict_build
from sensitivity_analysis import evaluate_y


def shannon_entropy(prob_distrib, base=2, dim=None):
    """
    This function computes the shannon entropy on the pytorch platform.

    Non-positive entries are ignored, i.e. they contribute 0 to the sum (convention 0 * log(0) = 0).

    prob_distrib: The probability distribution (torch tensor).
    base: The base for the entropy computation. 2 uses log2, 'e' or None uses the natural
          logarithm, any other value is used as a custom base (positive and different from 1).
    dim: The dimension (or dimensions) to sum over. None (default) sums over all the elements.

    Returns the entropy as a tensor (0-dimensional when dim is None).
    Raises ValueError when a custom base is not positive or is equal to 1.
    """
    # Select the logarithm and the optional change-of-base divisor
    if base == 2:
        log_fn, base_divisor = torch.log2, None
    elif base is None or base == 'e':
        log_fn, base_divisor = torch.log, None
    else:
        # log(1) = 0 would silently turn the result into inf/nan
        if base <= 0 or base == 1:
            raise ValueError(f"The base must be positive and different from 1, got {base!r}.")
        log_fn, base_divisor = torch.log, math.log(base)

    # Replace the non-positive entries by 1 so that they contribute exactly 1 * log(1) = 0.
    # This keeps the tensor shape, which makes the reduction over a given dimension possible.
    safe_probs = torch.where(prob_distrib > 0, prob_distrib, torch.ones_like(prob_distrib))
    terms = safe_probs * log_fn(safe_probs)

    # Compute the entropy
    if dim is None:
        entropy = -torch.sum(terms)
    else:
        entropy = -torch.sum(terms, dim=dim)
    if base_divisor is not None:
        entropy = entropy / base_divisor

    return entropy

def shannon_entropy_base_2(prob_distrib, dim=None):
    """
    This function computes the shannon entropy of base 2 on the pytorch platform.

    Non-positive entries contribute exactly 0 (convention 0 * log2(0) = 0). No smoothing term is
    added inside the logarithm, so a one-hot distribution has an entropy of exactly 0.

    prob_distrib: The probability distribution (torch tensor).
    dim: The dimension (or dimensions) to sum over. None (default) sums over all the elements.

    Returns the entropy as a tensor (0-dimensional when dim is None).
    """
    # Replace the non-positive entries by 1 to avoid log(0): they contribute 1 * log2(1) = 0
    safe_probs = torch.where(prob_distrib > 0, prob_distrib, torch.ones_like(prob_distrib))

    # Terms of H = -sum(p * log2(p))
    terms = safe_probs * torch.log2(safe_probs)

    # Sum over all the elements or over the specified dimension
    if dim is None:
        return -torch.sum(terms)
    return -torch.sum(terms, dim=dim)

def shannon_entropy_base_e(prob_distrib, dim=None):
    """
    This function computes the shannon entropy of base e on the pytorch platform.

    Non-positive entries contribute exactly 0 (convention 0 * log(0) = 0). No smoothing term is
    added inside the logarithm, so a one-hot distribution has an entropy of exactly 0.

    prob_distrib: The probability distribution (torch tensor).
    dim: The dimension (or dimensions) to sum over. None (default) sums over all the elements.

    Returns the entropy as a tensor (0-dimensional when dim is None).
    """
    # Replace the non-positive entries by 1 to avoid log(0): they contribute 1 * log(1) = 0
    safe_probs = torch.where(prob_distrib > 0, prob_distrib, torch.ones_like(prob_distrib))

    # Terms of H = -sum(p * ln(p))
    terms = safe_probs * torch.log(safe_probs)

    # Sum over all the elements or over the specified dimension
    if dim is None:
        return -torch.sum(terms)
    return -torch.sum(terms, dim=dim)

def info_base_2(prob_distrib, dim=None):
    """
    This function computes the information of base 2, i.e. -sum(log2(p + 1e-10)), on the pytorch platform.

    Unlike the entropy functions, the logarithms are not weighted by the probabilities.
    Non-positive entries are set to 0 first, so each of them contributes -log2(1e-10).

    prob_distrib: The probability distribution (torch tensor).
    dim: The dimension forwarded to torch.sum.
    """
    # Set the non-positive entries to zero
    zeros = torch.zeros_like(prob_distrib)
    clipped_probs = torch.where(prob_distrib > 0, prob_distrib, zeros)

    # The 1e-10 offset keeps the logarithm finite for the zeroed entries
    log_probs = torch.log2(clipped_probs + 1e-10)

    # Information: I = -sum(log2(p)) over the specified dimension
    return -torch.sum(log_probs, dim=dim)

def shapley_score_evaluation(actLevels, params, last_hidden_layerId, predict_class_ver=False):
    """
    This function evaluates the shapley score of each last-hidden-layer neuron for each class.

    For every class, each neuron is zeroed in turn for the examples of that class, and the score is
    the mean absolute change of the class output returned by evaluate_y (the absolute value is taken
    per example, before the mean).

    actLevels: The activation level information that contains also the original and predicted class.
    params: The parameters in the final linear layer (forwarded to evaluate_y).
    last_hidden_layerId: The id of the last hidden layer.
    predict_class_ver: Build the class index dictionary with the predicted class or the original class.

    Returns a numpy array with one row per class (in the key order of the class index dictionary)
    and one column per neuron.
    """
    # Choose which labels group the examples
    label_key = 'predict_class' if predict_class_ver else 'class'
    labels = actLevels[label_key].reshape(-1)
    # Map each class to the indices of its examples
    indices_per_class = class_index_dict_build(labels)
    # Activation levels of the last hidden layer
    hidden_acts = actLevels['actLevel'][last_hidden_layerId]
    neuron_count = hidden_acts.shape[1]

    per_class_scores = []
    for class_id in tqdm(indices_per_class.keys(), desc='Processed classes'):
        # Only the examples of the current class are used
        class_acts = hidden_acts[indices_per_class[class_id]]
        baseline_y = evaluate_y(params, class_acts)[:, class_id].reshape(-1)

        # One entry per neuron, each holding the per-example absolute output change
        ablation_deltas = []
        for neuron_id in range(neuron_count):
            # Work on a copy so that the zeroed column never leaks into the next iteration
            ablated_acts = copy.deepcopy(class_acts)
            ablated_acts[:, neuron_id] = 0
            ablated_y = evaluate_y(params, ablated_acts)[:, class_id].reshape(-1)
            ablation_deltas.append(np.abs(baseline_y - ablated_y))

        # Rearrange to (examples, neurons) and average over the examples
        example_by_neuron = np.array(ablation_deltas).T
        per_class_scores.append(example_by_neuron.mean(axis=0))

    return np.array(per_class_scores)

def shapley_score_evaluation_GPU(actLevels, params, last_hidden_layerId, predict_class_ver=False, block_size=256):
    """
    This function evaluates the shapley score of each last-hidden-layer neuron for each class on the GPU.

    For every class, each neuron is zeroed in turn for the examples of that class, and the score is
    the mean absolute change of the class logit (activations @ W.T + b) over these examples.
    All the activation levels are moved to the GPU at once.

    actLevels: The activation level information that contains also the original and predicted class.
    params: The parameters in the final linear layer (numpy arrays under 'weight' and 'bias').
    last_hidden_layerId: The id of the last hidden layer.
    predict_class_ver: Build the class index dictionary with the predicted class or the original class.
    block_size: The number of neurons to be evaluated in each block.

    Returns a numpy array of shape (number of classes, number of neurons).
    """
    # Determine the referenced class type
    ref_class_type = 'predict_class' if predict_class_ver else 'class'
    # Get the referred class for all examples
    ref_classes = actLevels[ref_class_type].reshape(-1)
    # Get the label corresponding dictionary
    # The result rows below are addressed with the class ids, which relies on the keys being
    # 0 .. "number of classes-1" (stated for class_index_dict_build; not checked here).
    class_index_dict = class_index_dict_build(ref_classes)
    # Get the activation levels of the last hidden layer and move it to GPU
    last_hidden_actLevels = torch.from_numpy(actLevels['actLevel'][last_hidden_layerId]).float().cuda()
    # The number of neurons
    nb_neurons = last_hidden_actLevels.shape[1]
    # Take the weight and the bias and move it to GPU
    W = torch.from_numpy(params['weight']).float().cuda()
    b = torch.from_numpy(params['bias']).float().cuda()
    device = W.device
    # Pre-allocate result tensor on GPU
    shapley_score_matrix_GPU = torch.zeros((len(class_index_dict), nb_neurons), device=device)
    # The number of blocks does not depend on the class
    nb_blocks = math.ceil(nb_neurons / block_size)
    # NOTE(review): for this purely linear output the change caused by zeroing neuron j is
    # activation[:, j] * W[classId, j]; the explicit masking is kept to preserve the numerical results.
    with torch.no_grad():
        for classId in tqdm(class_index_dict.keys(), desc='Processed classes'):
            # Take the current class activation levels
            current_class_actLevels = last_hidden_actLevels[class_index_dict[classId]]
            # Get the original y values
            original_y = (current_class_actLevels @ W.T + b)[:, classId]
            # Iterate over the blocks
            for block_id in range(nb_blocks):
                # Determine the current block size
                block_start = block_id * block_size
                block_end = min(block_start + block_size, nb_neurons)
                current_block_size = block_end - block_start
                # Build the block masks directly on the GPU: row i zeroes neuron block_start + i
                # (one vectorized assignment instead of one GPU write per neuron)
                rows = torch.arange(current_block_size, device=device)
                masks = torch.ones(current_block_size, nb_neurons, device=device)
                masks[rows, block_start + rows] = 0
                # Get the masked activation levels: (block, examples, neurons)
                masked_actLevels = current_class_actLevels.unsqueeze(0) * masks.unsqueeze(1)
                # Compute all masked outputs in parallel
                masked_outputs = torch.matmul(masked_actLevels, W.T) + b
                masked_class_outputs = masked_outputs[:, :, classId]
                # Absolute change per example, then mean over the examples
                scores = torch.abs(original_y.unsqueeze(0) - masked_class_outputs)
                block_scores = torch.mean(scores, dim=1)
                # Save the result
                shapley_score_matrix_GPU[classId, block_start:block_end] = block_scores

    # Final result
    shapley_score_matrix = shapley_score_matrix_GPU.cpu().numpy()

    # Clean the cache
    del shapley_score_matrix_GPU, last_hidden_actLevels, W, b
    torch.cuda.empty_cache()
    gc.collect()

    return shapley_score_matrix

def unified_entropy_score_evaluation_GPU(actLevels, params, last_hidden_layerId, batch_size=1000, block_size=256):
    """
    This function evaluates one class-independent entropy score per neuron of the last hidden layer.

    Each neuron is zeroed in turn, the logits (activations @ W.T + b) are turned into probabilities
    with a softmax, and the score of the neuron is the absolute change of the base-2 shannon entropy
    of these probabilities, averaged over all the examples.

    actLevels: The activation level information (only actLevels['actLevel'][last_hidden_layerId] is read here).
    params: The parameters in the final linear layer (numpy arrays under 'weight' and 'bias').
    last_hidden_layerId: The id of the last hidden layer.
    batch_size: The number of examples to be evaluated in each batch.
    block_size: The number of neurons to be evaluated in each block.

    Returns a numpy array with one score per neuron.
    """
    # The activation levels stay on the CPU; only one batch at a time is copied to the GPU
    hidden_acts = torch.from_numpy(actLevels['actLevel'][last_hidden_layerId]).float()
    nb_examples, nb_neurons = hidden_acts.shape[0], hidden_acts.shape[1]
    # The weight and the bias of the final layer live on the GPU
    weight = torch.from_numpy(params['weight']).float().cuda()
    bias = torch.from_numpy(params['bias']).float().cuda()
    # Running sum of the absolute entropy changes, one entry per neuron
    score_sums = torch.zeros(nb_neurons).cuda()

    with torch.no_grad():
        for batch_id in tqdm(range(math.ceil(nb_examples / batch_size)), desc='Processed batches'):
            # Examples of the current batch
            first_example = batch_id * batch_size
            last_example = min(first_example + batch_size, nb_examples)
            batch_acts = hidden_acts[first_example:last_example].cuda()
            # Entropy of the unmodified predictions
            reference_probs = torch.softmax(batch_acts @ weight.T + bias, dim=1)
            reference_entropy = shannon_entropy_base_2(reference_probs, dim=1)

            for block_id in range(math.ceil(nb_neurons / block_size)):
                # Neurons handled by the current block
                first_neuron = block_id * block_size
                last_neuron = min(first_neuron + block_size, nb_neurons)
                neurons_in_block = last_neuron - first_neuron
                # Row i of the mask zeroes neuron first_neuron + i
                rows = torch.arange(neurons_in_block)
                ablation_mask = torch.ones(neurons_in_block, nb_neurons).cuda()
                ablation_mask[rows, first_neuron + rows] = 0
                # (block, examples, neurons) activations with one neuron zeroed per block row
                ablated_acts = batch_acts.unsqueeze(0) * ablation_mask.unsqueeze(1)
                ablated_logits = torch.matmul(ablated_acts, weight.T) + bias
                ablated_entropy = shannon_entropy_base_2(torch.softmax(ablated_logits, dim=2), dim=2)
                # Accumulate the absolute entropy changes over the examples of the batch
                entropy_change = torch.abs(ablated_entropy - reference_entropy.unsqueeze(0))
                score_sums[first_neuron:last_neuron] += torch.sum(entropy_change, dim=1)

        # Turn the sums into averages over all the examples
        mean_scores = score_sums / nb_examples

    # Final result
    entropy_score_matrix = mean_scores.cpu().numpy()

    # Clean the cache
    del mean_scores, score_sums, hidden_acts, weight, bias
    torch.cuda.empty_cache()
    gc.collect()

    return entropy_score_matrix

def _l2_normalized_class_entropy(logits, classId):
    """
    Helper of entropy_score_evaluation_GPU: the class-weighted entropy of L2-normalized logits.

    logits: The logits, with the classes on the last dimension.
    classId: The class whose normalized logit weights the entropy.
    """
    # Shift the logits so that the smallest one of each example is 0
    shifted = logits - torch.min(logits, dim=-1, keepdim=True)[0]
    shifted = torch.clamp(shifted, min=0)
    # Normalize by the L2 norm (the sum normalization and the softmax alternatives are not used).
    # NOTE(review): the norm is 0 when all the logits of an example are equal, which gives NaN here.
    normalized = shifted / torch.norm(shifted, dim=-1, keepdim=True)
    return (1 - normalized[..., classId]) * shannon_entropy_base_2(normalized, dim=-1)


def entropy_score_evaluation_GPU(actLevels, params, last_hidden_layerId, predict_class_ver=False, block_size=256):
    """
    This function evaluates the entropy scores to evaluate the global importances of the neurons.

    For every class, each neuron of the last hidden layer is zeroed in turn for the examples of that
    class. The logits are shifted by their minimum and divided by their L2 norm, and the quantity
    (1 - normalized class logit) * base-2 shannon entropy of the normalized logits is computed.
    The score is the mean absolute change of this quantity over the examples of the class.

    actLevels: The activation level information that contains also the original and predicted class.
    params: The parameters in the final linear layer (numpy arrays under 'weight' and 'bias').
    last_hidden_layerId: The id of the last hidden layer.
    predict_class_ver: Build the class index dictionary with the predicted class or the original class.
    block_size: The number of neurons to be evaluated in each block.

    Returns a numpy array of shape (number of classes, number of neurons).
    """
    # Determine the referenced class type
    ref_class_type = 'predict_class' if predict_class_ver else 'class'
    # Get the referred class for all examples
    ref_classes = actLevels[ref_class_type].reshape(-1)
    # Get the label corresponding dictionary
    # The result rows below are addressed with the class ids, which relies on the keys being
    # 0 .. "number of classes-1" (stated for class_index_dict_build; not checked here).
    class_index_dict = class_index_dict_build(ref_classes)
    # Get the activation levels of the last hidden layer (kept on the CPU, moved to GPU class by class)
    last_hidden_actLevels = torch.from_numpy(actLevels['actLevel'][last_hidden_layerId]).float()
    # The number of neurons
    nb_neurons = last_hidden_actLevels.shape[1]
    # Take the weight and the bias and move it to GPU
    W = torch.from_numpy(params['weight']).float().cuda()
    b = torch.from_numpy(params['bias']).float().cuda()
    device = W.device
    # Pre-allocate result tensor on GPU
    entropy_score_matrix_GPU = torch.zeros((len(class_index_dict), nb_neurons), device=device)
    # The number of blocks does not depend on the class
    nb_blocks = math.ceil(nb_neurons / block_size)
    with torch.no_grad():
        # Iterate over the classes
        for classId in tqdm(class_index_dict.keys(), desc='Processed classes'):
            # Take the current class activation levels
            class_actLevels = last_hidden_actLevels[class_index_dict[classId]].cuda()
            # Get the original entropy values
            original_entropy = _l2_normalized_class_entropy(class_actLevels @ W.T + b, classId)
            # Iterate over the blocks
            for block_id in range(nb_blocks):
                # Determine the current block size
                block_start = block_id * block_size
                block_end = min(block_start + block_size, nb_neurons)
                current_block_size = block_end - block_start
                # Build the block masks directly on the GPU: row i zeroes neuron block_start + i
                rows = torch.arange(current_block_size, device=device)
                masks = torch.ones(current_block_size, nb_neurons, device=device)
                masks[rows, block_start + rows] = 0
                # Get the masked activation levels: (block, examples, neurons)
                masked_actLevels = class_actLevels.unsqueeze(0) * masks.unsqueeze(1)
                # Compute all masked outputs in parallel
                masked_logits = torch.matmul(masked_actLevels, W.T) + b
                masked_entropy = _l2_normalized_class_entropy(masked_logits, classId)
                # Compute scores for all neurons
                block_scores = torch.mean(torch.abs(masked_entropy - original_entropy.unsqueeze(0)), dim=1)
                # Save the result
                entropy_score_matrix_GPU[classId, block_start:block_end] = block_scores

    # Final result
    entropy_score_matrix = entropy_score_matrix_GPU.cpu().numpy()

    # Clean the cache
    del entropy_score_matrix_GPU, last_hidden_actLevels, W, b
    torch.cuda.empty_cache()
    gc.collect()

    return entropy_score_matrix

def prob_shapley_score_evaluation_GPU(actLevels, params, last_hidden_layerId, predict_class_ver=False, block_size=256):
    """
    This function evaluates the shapley scores adjusted with probabilities to evaluate the global importances of the neurons.

    For every class, each neuron of the last hidden layer is zeroed in turn for the examples of that
    class. The "influence" of an example is softmax probability of the class * logit of the class,
    and the score is the mean absolute change of this influence over the examples of the class.

    actLevels: The activation level information that contains also the original and predicted class.
    params: The parameters in the final linear layer (numpy arrays under 'weight' and 'bias').
    last_hidden_layerId: The id of the last hidden layer.
    predict_class_ver: Build the class index dictionary with the predicted class or the original class.
    block_size: The number of neurons to be evaluated in each block.

    Returns a numpy array of shape (number of classes, number of neurons).
    """
    # Determine the referenced class type
    ref_class_type = 'predict_class' if predict_class_ver else 'class'
    # Get the referred class for all examples
    ref_classes = actLevels[ref_class_type].reshape(-1)
    # Get the label corresponding dictionary
    # The result rows below are addressed with the class ids, which relies on the keys being
    # 0 .. "number of classes-1" (stated for class_index_dict_build; not checked here).
    class_index_dict = class_index_dict_build(ref_classes)
    # Get the activation levels of the last hidden layer (kept on the CPU, moved to GPU class by class)
    last_hidden_actLevels = torch.from_numpy(actLevels['actLevel'][last_hidden_layerId]).float()
    # The number of neurons
    nb_neurons = last_hidden_actLevels.shape[1]
    # Take the weight and the bias and move it to GPU
    W = torch.from_numpy(params['weight']).float().cuda()
    b = torch.from_numpy(params['bias']).float().cuda()
    device = W.device
    # Pre-allocate result tensor on GPU
    prob_shapley_score_matrix_GPU = torch.zeros((len(class_index_dict), nb_neurons), device=device)
    # The number of blocks does not depend on the class
    nb_blocks = math.ceil(nb_neurons / block_size)
    with torch.no_grad():
        # Iterate over the classes
        for classId in tqdm(class_index_dict.keys(), desc='Processed classes'):
            # Take the current class activation levels
            class_actLevels = last_hidden_actLevels[class_index_dict[classId]].cuda()
            # Get the original influence: class probability * class logit
            original_logit = class_actLevels @ W.T + b
            original_y = original_logit[:, classId]
            original_prob = torch.softmax(original_logit, dim=1)
            original_influence = original_prob[:, classId] * original_y
            # Iterate over the blocks
            for block_id in range(nb_blocks):
                # Determine the current block size
                block_start = block_id * block_size
                block_end = min(block_start + block_size, nb_neurons)
                current_block_size = block_end - block_start
                # Build the block masks directly on the GPU: row i zeroes neuron block_start + i
                rows = torch.arange(current_block_size, device=device)
                masks = torch.ones(current_block_size, nb_neurons, device=device)
                masks[rows, block_start + rows] = 0
                # Get the masked activation levels: (block, examples, neurons)
                masked_actLevels = class_actLevels.unsqueeze(0) * masks.unsqueeze(1)
                # Compute all masked outputs in parallel
                masked_logits = torch.matmul(masked_actLevels, W.T) + b
                masked_y = masked_logits[:, :, classId]
                masked_probs = torch.softmax(masked_logits, dim=2)
                masked_influence = masked_probs[:, :, classId] * masked_y
                # Compute scores for all neurons
                block_scores = torch.mean(torch.abs(masked_influence - original_influence.unsqueeze(0)), dim=1)
                # Save the result
                prob_shapley_score_matrix_GPU[classId, block_start:block_end] = block_scores

    # Final result
    prob_shapley_score_matrix = prob_shapley_score_matrix_GPU.cpu().numpy()

    # Clean the cache
    del prob_shapley_score_matrix_GPU, last_hidden_actLevels, W, b
    torch.cuda.empty_cache()
    gc.collect()

    return prob_shapley_score_matrix

def entropy_shapley_score_evaluation_GPU(actLevels, params, last_hidden_layerId, predict_class_ver=False, block_size=256):
    """
    This function evaluates the shapley scores weighted by the entropy changes to evaluate the global importances of the neurons.

    For every class, each neuron of the last hidden layer is zeroed in turn for the examples of that
    class. For each example, the absolute change of the base-2 shannon entropy of the softmax
    probabilities is multiplied by the absolute change of the class logit, and the score is the mean
    of this product over the examples of the class.

    actLevels: The activation level information that contains also the original and predicted class.
    params: The parameters in the final linear layer (numpy arrays under 'weight' and 'bias').
    last_hidden_layerId: The id of the last hidden layer.
    predict_class_ver: Build the class index dictionary with the predicted class or the original class.
    block_size: The number of neurons to be evaluated in each block.

    Returns a numpy array of shape (number of classes, number of neurons).
    """
    # Determine the referenced class type
    ref_class_type = 'predict_class' if predict_class_ver else 'class'
    # Get the referred class for all examples
    ref_classes = actLevels[ref_class_type].reshape(-1)
    # Get the label corresponding dictionary
    # The result rows below are addressed with the class ids, which relies on the keys being
    # 0 .. "number of classes-1" (stated for class_index_dict_build; not checked here).
    class_index_dict = class_index_dict_build(ref_classes)
    # Get the activation levels of the last hidden layer (kept on the CPU, moved to GPU class by class)
    last_hidden_actLevels = torch.from_numpy(actLevels['actLevel'][last_hidden_layerId]).float()
    # The number of neurons
    nb_neurons = last_hidden_actLevels.shape[1]
    # Take the weight and the bias and move it to GPU
    W = torch.from_numpy(params['weight']).float().cuda()
    b = torch.from_numpy(params['bias']).float().cuda()
    device = W.device
    # Pre-allocate result tensor on GPU
    entropy_shapley_score_matrix_GPU = torch.zeros((len(class_index_dict), nb_neurons), device=device)
    # The number of blocks does not depend on the class
    nb_blocks = math.ceil(nb_neurons / block_size)
    with torch.no_grad():
        # Iterate over the classes
        for classId in tqdm(class_index_dict.keys(), desc='Processed classes'):
            # Take the current class activation levels
            class_actLevels = last_hidden_actLevels[class_index_dict[classId]].cuda()
            # Get the original class logit and the original entropy
            original_logit = class_actLevels @ W.T + b
            original_y = original_logit[:, classId]
            original_entropy = shannon_entropy_base_2(torch.softmax(original_logit, dim=1), dim=1)
            # Iterate over the blocks
            for block_id in range(nb_blocks):
                # Determine the current block size
                block_start = block_id * block_size
                block_end = min(block_start + block_size, nb_neurons)
                current_block_size = block_end - block_start
                # Build the block masks directly on the GPU: row i zeroes neuron block_start + i
                rows = torch.arange(current_block_size, device=device)
                masks = torch.ones(current_block_size, nb_neurons, device=device)
                masks[rows, block_start + rows] = 0
                # Get the masked activation levels: (block, examples, neurons)
                masked_actLevels = class_actLevels.unsqueeze(0) * masks.unsqueeze(1)
                # Compute all masked outputs in parallel
                masked_logits = torch.matmul(masked_actLevels, W.T) + b
                masked_y = masked_logits[:, :, classId]
                masked_entropy = shannon_entropy_base_2(torch.softmax(masked_logits, dim=2), dim=2)
                # Compute scores for all neurons: mean of |entropy change| * |class logit change|
                entropy_change = torch.abs(masked_entropy - original_entropy.unsqueeze(0))
                logit_change = torch.abs(masked_y - original_y.unsqueeze(0))
                block_scores = torch.mean(entropy_change * logit_change, dim=1)
                # Save the result
                entropy_shapley_score_matrix_GPU[classId, block_start:block_end] = block_scores

    # Final result
    entropy_shapley_score_matrix = entropy_shapley_score_matrix_GPU.cpu().numpy()

    # Clean the cache
    del entropy_shapley_score_matrix_GPU, last_hidden_actLevels, W, b
    torch.cuda.empty_cache()
    gc.collect()

    return entropy_shapley_score_matrix

def shapley_score_evaluation_GPU_memory_save(actLevels, params, last_hidden_layerId, predict_class_ver=False, block_size=256):
    """
    This function evaluates the shapley score of each last-hidden-layer neuron for each class on the GPU. (memory-saving version)

    For every class, each neuron is zeroed in turn for the examples of that class, and the score is
    the mean absolute change of the class logit (activations @ W.T + b) over these examples.
    The activation levels stay on the CPU and only the examples of the current class are moved to the GPU.

    actLevels: The activation level information that contains also the original and predicted class.
    params: The parameters in the final linear layer (numpy arrays under 'weight' and 'bias').
    last_hidden_layerId: The id of the last hidden layer.
    predict_class_ver: Build the class index dictionary with the predicted class or the original class.
    block_size: The number of neurons to be evaluated in each block.

    Returns a numpy array of shape (number of classes, number of neurons).
    """
    # Determine the referenced class type
    ref_class_type = 'predict_class' if predict_class_ver else 'class'
    # Get the referred class for all examples
    ref_classes = actLevels[ref_class_type].reshape(-1)
    # Get the label corresponding dictionary
    # The result rows below are addressed with the class ids, which relies on the keys being
    # 0 .. "number of classes-1" (stated for class_index_dict_build; not checked here).
    class_index_dict = class_index_dict_build(ref_classes)
    # Get the activation levels of the last hidden layer (kept on the CPU, moved to GPU class by class)
    last_hidden_actLevels = torch.from_numpy(actLevels['actLevel'][last_hidden_layerId]).float()
    # The number of neurons
    nb_neurons = last_hidden_actLevels.shape[1]
    # Take the weight and the bias and move it to GPU
    W = torch.from_numpy(params['weight']).float().cuda()
    b = torch.from_numpy(params['bias']).float().cuda()
    device = W.device
    # Pre-allocate result tensor on GPU
    shapley_score_matrix_GPU = torch.zeros((len(class_index_dict), nb_neurons), device=device)
    # The number of blocks does not depend on the class
    nb_blocks = math.ceil(nb_neurons / block_size)
    with torch.no_grad():
        for classId in tqdm(class_index_dict.keys(), desc='Processed classes'):
            # Take the current class activation levels
            current_class_actLevels = last_hidden_actLevels[class_index_dict[classId]].cuda()
            # Get the original y values
            original_y = (current_class_actLevels @ W.T + b)[:, classId]
            # Iterate over the blocks
            for block_id in range(nb_blocks):
                # Determine the current block size
                block_start = block_id * block_size
                block_end = min(block_start + block_size, nb_neurons)
                current_block_size = block_end - block_start
                # Build the block masks directly on the GPU: row i zeroes neuron block_start + i
                # (one vectorized assignment instead of one GPU write per neuron)
                rows = torch.arange(current_block_size, device=device)
                masks = torch.ones(current_block_size, nb_neurons, device=device)
                masks[rows, block_start + rows] = 0
                # Get the masked activation levels: (block, examples, neurons)
                masked_actLevels = current_class_actLevels.unsqueeze(0) * masks.unsqueeze(1)
                # Compute all masked outputs in parallel
                masked_outputs = torch.matmul(masked_actLevels, W.T) + b
                masked_class_outputs = masked_outputs[:, :, classId]
                # Absolute change per example, then mean over the examples
                scores = torch.abs(original_y.unsqueeze(0) - masked_class_outputs)
                block_scores = torch.mean(scores, dim=1)
                # Save the result
                shapley_score_matrix_GPU[classId, block_start:block_end] = block_scores
            # Release the activations of the current class before loading the next class
            del current_class_actLevels
            torch.cuda.empty_cache()

    # Final result
    shapley_score_matrix = shapley_score_matrix_GPU.cpu().numpy()

    # Clean the cache
    del shapley_score_matrix_GPU, last_hidden_actLevels, W, b
    torch.cuda.empty_cache()
    gc.collect()

    return shapley_score_matrix

def shapley_score_evaluation_batch_GPU(actLevels, params, last_hidden_layerId, predict_class_ver=False, batch_size=500, block_size=256):
    """
    This function evaluates the shapley score according to the taylor formula. (by-batch version)

    For every class and every neuron of the last hidden layer, the score is the mean, over the
    examples referred to that class, of the absolute difference between the class output of the
    final linear layer and the same output once the activation of that single neuron is set to zero.
    The examples are processed by batches and the neurons by blocks to bound the GPU memory usage.

    actLevels: The activation level information that contains also the original and predicted class.
    params: The parameters in the final linear layer.
    last_hidden_layerId: The id of the last hidden layer.
    predict_class_ver: Build the class index dictionary with the predicted class or the original class.
    batch_size: The number of examples to be evaluated in each batch.
    block_size: The number of neurons to be evaluated in each block.

    Returns: A numpy array of shape (number of classes, number of neurons) holding the scores.
    """
    # Determine the referenced class type
    ref_class_type = 'predict_class' if predict_class_ver else 'class'
    # Get the referred class for all examples
    ref_classes = actLevels[ref_class_type].reshape(-1)
    # Get the label corresponding dictionary
    # With the class_index_dict_build function, the classes (i.e., keys) are already ordered (i.e., from 0 to "number of classes-1").
    # The keys are used directly as row indices of the result matrix and as column indices of the linear layer output.
    class_index_dict = class_index_dict_build(ref_classes)
    # Get the activation levels of the last hidden layer (kept on CPU, only the current class is moved to GPU)
    last_hidden_actLevels = torch.from_numpy(actLevels['actLevel'][last_hidden_layerId])
    # The number of neurons
    nb_neurons = last_hidden_actLevels.shape[1]
    # Take the weight and the bias and move it to GPU
    W = torch.from_numpy(params['weight']).float().cuda()
    b = torch.from_numpy(params['bias']).float().cuda()
    # Pre-allocate result tensor on GPU
    shapley_score_matrix_GPU = torch.zeros((len(class_index_dict), nb_neurons)).cuda()
    # The number of neuron blocks (independent of the class and of the batch)
    nb_blocks = math.ceil(nb_neurons / block_size)
    # Build the masks only once: row j is all ones except a zero at position j (i.e., neuron j is removed).
    # The masks of a block are simply the corresponding rows, so there is no need to rebuild them
    # element by element for every class and every batch.
    leave_one_out_masks = torch.ones(nb_neurons, nb_neurons).cuda()
    leave_one_out_masks.fill_diagonal_(0)
    # Evaluates the shapley scores per example for different classes (Absolute value position according to the paper)
    with torch.no_grad():
        for classId in tqdm(class_index_dict.keys(), desc='Processed classes'):
            # Take the current class activation levels
            current_class_actLevels = last_hidden_actLevels[class_index_dict[classId]].float().cuda() # Original suggestion in the paper
            nb_class_examples = len(current_class_actLevels)
            ## Iterate over the examples by batch
            # Get the number of batches
            nb_batches = math.ceil(nb_class_examples / batch_size)
            for batch_id in range(nb_batches):
                # Get the current batch
                batch_start = batch_id * batch_size
                batch_end = min(batch_start + batch_size, nb_class_examples)
                current_batch_size = batch_end - batch_start
                current_batch_actLevels = current_class_actLevels[batch_start:batch_end]
                # Get the original y values
                original_y = (current_batch_actLevels @ W.T + b)[:, classId]
                # Iterate over the blocks
                for block_id in range(nb_blocks):
                    # Determine the current block limits
                    block_start = block_id * block_size
                    block_end = min(block_start + block_size, nb_neurons)
                    # Get the current block masks (a view on the pre-built masks)
                    masks = leave_one_out_masks[block_start:block_end]
                    # Get the masked activation levels: (block, batch, neurons)
                    masked_actLevels = current_batch_actLevels.unsqueeze(0) * masks.unsqueeze(1)
                    # Compute all masked outputs in parallel
                    masked_outputs = torch.matmul(masked_actLevels, W.T) + b
                    masked_class_outputs = masked_outputs[:, :, classId]
                    # Compute scores for all neurons
                    scores = torch.abs(original_y.unsqueeze(0) - masked_class_outputs)
                    block_scores = torch.mean(scores, dim=1)
                    # Save the result: the batch means are weighted by the batch proportion,
                    # so that their sum is the mean over all the examples of the class.
                    shapley_score_matrix_GPU[classId, block_start:block_end] += (
                        block_scores * current_batch_size / nb_class_examples
                    )
            # Clear the memory
            del current_class_actLevels
            torch.cuda.empty_cache()

    # Final result
    shapley_score_matrix = shapley_score_matrix_GPU.cpu().numpy()

    # Clean the cache
    del shapley_score_matrix_GPU, last_hidden_actLevels, W, b, leave_one_out_masks
    torch.cuda.empty_cache()
    gc.collect()

    return shapley_score_matrix

# def pruning_mask(array, k=10):
#     """
#     This function returns a pruning mask for a 1D or 2D numpy array. (Old version, not coherent with the paper)

#     array: The provided numpy array.
#     k: The index for top-k selection.
#     """
#     mask_array = np.zeros(array.shape)
#     if array.ndim == 1:
#         sorted_indices = np.argsort(-array) # -array because we would like to have the descending order.
#         top_k_indices = sorted_indices[:k]
#         mask_array[top_k_indices] = 1
#     else:
#         sorted_indices = np.argsort(-array, axis=1)
#         k_columns = list(range(k))
#         top_k_indices = sorted_indices[:, k_columns]
#         for row_index, row_top_k_indices in enumerate(top_k_indices):
#             mask_array[row_index, row_top_k_indices] = 1
    
#     return mask_array










    
