from common_imports import copy, np
from common_use_functions import softmax

"""
This module post-processes the generated activation levels: it converts the stored
outputs into probabilities and filters the examples the activation levels contain.
"""
def probability_evaluation(actLevels):
    """
    Turn the raw 'prob' outputs stored in the activation levels into probabilities.

    Intended for models whose last layer is a plain linear layer, so the stored
    'prob' values are presumably unnormalized scores rather than probabilities.

    actLevels: The activation levels. Its 'prob' entry is replaced (in place) by
        the output of `softmax`, and the same mapping object is returned.
    """
    raw_scores = actLevels['prob']
    normalized = softmax(raw_scores)
    actLevels['prob'] = normalized
    return actLevels

def build_selected_actLevels(actLevels, select_indices):
    """
    Keep only the requested examples in the activation levels (in place).

    actLevels: The original actLevels. Every entry is indexed by `select_indices`;
        the 'actLevel' entry is a mapping of per-layer arrays, each of which is
        indexed individually. The mapping is modified in place and returned.
    select_indices: The indices (or boolean mask) of the examples to keep.
    """
    for key, value in actLevels.items():
        if key != 'actLevel':
            actLevels[key] = value[select_indices]
            continue
        # 'actLevel' holds one array per layer; filter each layer separately.
        per_layer = value
        for layer in per_layer:
            per_layer[layer] = per_layer[layer][select_indices]

    return actLevels

def build_correct_actLevels(actLevels):
    """
    This function builds the activation levels only on the correctly predicted examples.

    actLevels: The original actLevels. It is not modified; a filtered deep copy is
        returned. Every entry except 'actLevel' is indexed by the "correct" mask;
        'actLevel' is a mapping of per-layer arrays, each indexed by the same mask.
    """
    correct_actLevels = copy.deepcopy(actLevels)
    # Flatten each label array *before* comparing. Comparing e.g. an (N, 1) array
    # with an (N,) array would broadcast to (N, N) and produce a mask of the wrong
    # length (matches how build_trustworthy_correct_actLevels builds its mask).
    ground_truths = correct_actLevels['class'].reshape(-1)
    predicted_classes = correct_actLevels['predict_class'].reshape(-1)
    correct_example_bools = ground_truths == predicted_classes
    for info_key in correct_actLevels:
        if info_key == 'actLevel':
            layers = correct_actLevels[info_key]
            for layerId in layers:
                layers[layerId] = layers[layerId][correct_example_bools]
        else:
            correct_actLevels[info_key] = correct_actLevels[info_key][correct_example_bools]

    return correct_actLevels

def build_correct_example_bools(actLevels):
    """
    This function evaluates the boolean indicating the correctly predicted examples.

    actLevels: The actLevels; its 'class' and 'predict_class' entries are compared.

    Returns a 1-D boolean array that is True where the predicted class equals the
    original class (one entry per example, assuming one label per example).
    """
    # Flatten each label array before comparing so that labels stored with
    # different shapes (e.g. (N, 1) vs (N,)) do not broadcast into an (N, N) result.
    ground_truths = actLevels['class'].reshape(-1)
    predicted_classes = actLevels['predict_class'].reshape(-1)
    correct_example_bools = ground_truths == predicted_classes

    return correct_example_bools

def build_trustworthy_correct_actLevels(actLevels, predict_class=False, trust_threshold=0.95):
    """
    This function builds the activation levels only on the correctly predicted and trustworthy examples.

    actLevels: The original actLevels. It is not modified; a filtered deep copy is returned.
    predict_class: Boolean determines whether we refer to the predicted or original class for taking the probabilities.
    trust_threshold: The threshold to determine whether one example is trustworthy: an example
        is trustworthy when its probability for the reference class is strictly greater than it.

    Note: When using this function, the probabilities should be real probabilities (between 0 and 1).
    Note: Only correctly predicted examples are kept, and for those the predicted and original
        classes coincide, so `predict_class` does not change which examples are returned.
    """
    valid_actLevels = copy.deepcopy(actLevels)
    ground_truths = valid_actLevels['class'].reshape(-1)
    predicted_classes = valid_actLevels['predict_class'].reshape(-1)
    example_indices = np.arange(ground_truths.shape[0])
    correct_example_bools = ground_truths == predicted_classes
    # 'prob' is indexed as (example, class): pick each example's probability for its
    # reference class (the predicted class or the ground truth).
    reference_classes = predicted_classes if predict_class else ground_truths
    reference_probs = valid_actLevels['prob'][example_indices, reference_classes]
    trustworthy_example_bools = (reference_probs > trust_threshold).reshape(-1)
    valid_example_bools = np.logical_and(correct_example_bools, trustworthy_example_bools)
    for info_key in valid_actLevels:
        if info_key == 'actLevel':
            layers = valid_actLevels[info_key]
            for layerId in layers:
                layers[layerId] = layers[layerId][valid_example_bools]
        else:
            valid_actLevels[info_key] = valid_actLevels[info_key][valid_example_bools]

    return valid_actLevels