
from common_imports import np, tqdm
from common_use_functions import class_index_dict_build, calculate_smd
from scipy.stats import wasserstein_distance

def smd_score_evaluation(actLevels, last_hidden_layerId, predict_class_ver=False):
    """
    Evaluate, for each class, the SMD score of every neuron in the last hidden layer.

    For each class, the activation levels of the examples belonging to that class are
    compared, neuron by neuron, with the activation levels of all the remaining examples
    using ``calculate_smd`` (called with ``absolute=True, display=False``).

    actLevels: The activation level information. It must provide 'actLevel' (indexed by
        layer id) and 'class' (the original classes). When predict_class_ver is True, it
        must also provide 'predict_class' (the predicted classes).
    last_hidden_layerId: The id of the last hidden layer in actLevels['actLevel'].
    predict_class_ver: If True, group the examples by predicted class; otherwise group
        them by original class.

    Returns a numpy array with one row per class, in the key order of the class index
    dictionary, and one column per neuron of the last hidden layer.
    """
    # Select which class labels are used to group the examples
    ref_class_type = 'predict_class' if predict_class_ver else 'class'
    ref_classes = actLevels[ref_class_type].reshape(-1)
    # Map each class to the indices of its examples. With class_index_dict_build, the
    # classes (i.e., keys) are already ordered (i.e., from 0 to "number of classes-1").
    class_index_dict = class_index_dict_build(ref_classes)
    # Activation levels of the last hidden layer: axis 0 indexes the examples,
    # axis 1 indexes the neurons
    last_hidden_actLevels = actLevels['actLevel'][last_hidden_layerId]
    nb_examples = last_hidden_actLevels.shape[0]
    nb_neurons = last_hidden_actLevels.shape[1]

    # One row of per-neuron SMD scores per class
    smd_score_matrix = []
    for classId in tqdm(class_index_dict.keys(), desc='Processed classes'):
        class_indices = class_index_dict[classId]
        current_class_actLevels = last_hidden_actLevels[class_indices]
        # Indices (ascending) of the examples that do NOT belong to the current class.
        # A boolean mask keeps this O(n); a per-example `in` test on the index
        # container would be O(n^2) per class.
        other_mask = np.ones(nb_examples, dtype=bool)
        other_mask[class_indices] = False
        other_example_actLevels = last_hidden_actLevels[np.flatnonzero(other_mask)]
        # The SMD scores of every neuron for the current class
        class_neuron_scores = []
        for neuronId in range(nb_neurons):
            # reshape(-1) flattens any trailing axes (a no-op for 2-D activations)
            current_class_neuron_levels = current_class_actLevels[:, neuronId].reshape(-1)
            other_example_neuron_levels = other_example_actLevels[:, neuronId].reshape(-1)
            smd_score = calculate_smd(current_class_neuron_levels, other_example_neuron_levels, absolute=True, display=False)
            class_neuron_scores.append(smd_score)
        smd_score_matrix.append(class_neuron_scores)

    return np.array(smd_score_matrix)

def wasserstein_score_evaluation(actLevels, last_hidden_layerId, predict_class_ver=False):
    """
    Evaluate, for each class, the Wasserstein score of every neuron in the last hidden layer.

    For each class, the activation levels of the examples belonging to that class are
    compared, neuron by neuron, with the activation levels of all the remaining examples
    using scipy's 1-D ``wasserstein_distance`` between the two empirical samples.

    actLevels: The activation level information. It must provide 'actLevel' (indexed by
        layer id) and 'class' (the original classes). When predict_class_ver is True, it
        must also provide 'predict_class' (the predicted classes).
    last_hidden_layerId: The id of the last hidden layer in actLevels['actLevel'].
    predict_class_ver: If True, group the examples by predicted class; otherwise group
        them by original class.

    Returns a numpy array with one row per class, in the key order of the class index
    dictionary, and one column per neuron of the last hidden layer.
    """
    # Select which class labels are used to group the examples
    ref_class_type = 'predict_class' if predict_class_ver else 'class'
    ref_classes = actLevels[ref_class_type].reshape(-1)
    # Map each class to the indices of its examples. With class_index_dict_build, the
    # classes (i.e., keys) are already ordered (i.e., from 0 to "number of classes-1").
    class_index_dict = class_index_dict_build(ref_classes)
    # Activation levels of the last hidden layer: axis 0 indexes the examples,
    # axis 1 indexes the neurons
    last_hidden_actLevels = actLevels['actLevel'][last_hidden_layerId]
    nb_examples = last_hidden_actLevels.shape[0]
    nb_neurons = last_hidden_actLevels.shape[1]

    # One row of per-neuron Wasserstein scores per class
    wasserstein_score_matrix = []
    for classId in tqdm(class_index_dict.keys(), desc='Processed classes'):
        class_indices = class_index_dict[classId]
        current_class_actLevels = last_hidden_actLevels[class_indices]
        # Indices (ascending) of the examples that do NOT belong to the current class.
        # A boolean mask keeps this O(n); a per-example `in` test on the index
        # container would be O(n^2) per class.
        other_mask = np.ones(nb_examples, dtype=bool)
        other_mask[class_indices] = False
        other_example_actLevels = last_hidden_actLevels[np.flatnonzero(other_mask)]
        # The Wasserstein scores of every neuron for the current class
        class_neuron_scores = []
        for neuronId in range(nb_neurons):
            # reshape(-1) flattens any trailing axes (a no-op for 2-D activations)
            current_class_neuron_levels = current_class_actLevels[:, neuronId].reshape(-1)
            other_example_neuron_levels = other_example_actLevels[:, neuronId].reshape(-1)
            wasserstein_score = wasserstein_distance(current_class_neuron_levels, other_example_neuron_levels)
            class_neuron_scores.append(wasserstein_score)
        wasserstein_score_matrix.append(class_neuron_scores)

    return np.array(wasserstein_score_matrix)