from common_imports import np, copy
from common_use_functions import softmax, class_index_dict_build
from OOD_score_utils import pruning_mask, log_sum_exponential_score, weight_contribution
from deep_KNN import evaluation_with_ood_bools

def evaluate_probs_with_selected_neurons(params, data, selected_neuron_indices=None, softmax_eval=False):
    """
    Evaluate the output layer on `data`, optionally using only a subset of the last hidden layer's neurons.

    params: The parameters between the output layer and the last hidden layer, it should have the following form:
    {
    'weight' : ...,
    'bias' : ...,
    }
    where both weight and bias are numpy arrays. The weight's columns correspond to the hidden
    neurons (logits = data . weight^T + bias), so it is laid out as (num_outputs, num_hidden).

    data: The data to be evaluated, i.e., the flattened feature vectors. Either a 2-D batch with one
    feature vector per row, or a single 1-D feature vector; neurons are indexed along the last axis.

    selected_neuron_indices: The indices of the selected neurons, if it is None, it will use all neurons.
    otherwise, it will use the given neurons, which are provided in a list-form object.

    softmax_eval: If True, return softmax(logits) (probabilities); otherwise return the raw
    pre-activation outputs (logits).

    Returns: The logits, or their softmax when `softmax_eval` is True.
    """
    weight = params['weight']
    if selected_neuron_indices is not None:
        # This does not reproduce LINe's pruning; it is a simplified evaluation that uses only the
        # selected neurons, dropping the other features together with their matching weight columns.
        # Indexing the last axis of `data` (instead of axis 1) keeps a single 1-D feature vector
        # working, as it already does when all neurons are used; for 2-D data it is identical.
        data = data[..., selected_neuron_indices]
        weight = weight[:, selected_neuron_indices]
    # The bias is one value per output, so it is never subset.
    logits = np.dot(data, np.transpose(weight)) + params['bias']

    return softmax(logits) if softmax_eval else logits

def LINe_score_evaluation(feature_vecs, predictions, params, feature_pruning_masks, weight_pruning_masks):
    """
    This function calculates the OMS score from the LINe paper. (not the whole pipeline)

    Samples are grouped by predicted class. Each group's features and the final-layer weight are
    multiplied element-wise by that class's masks. The score is the log-sum-exp (over axis 1) of
    the resulting logits.

    feature_vecs: The penultimate layer's activation levels, i.e., the feature vectors (one row per sample).
    predictions: The predictions for the data contained in the feature vectors.
    params: The final linear parameters (weight and bias), as a dict with keys 'weight' and 'bias'.
    feature_pruning_masks: The feature pruning masks for different classes, keyed by class id
        (assumed to broadcast against a row of feature_vecs — TODO confirm).
    weight_pruning_masks: The weight pruning masks for different classes, keyed by class id
        (assumed to broadcast against params['weight'] — TODO confirm).

    Returns: (LINe_score, class_index_dict), where LINe_score is a 1-D array with one score per row
    of feature_vecs, and class_index_dict maps each predicted class to its sample indices.
    """
    # Separate the data into different classes
    class_index_dict = class_index_dict_build(predictions)
    # Evaluate the scores for the data from different classes
    LINe_score = np.zeros(feature_vecs.shape[0])
    for classId, sample_indices in class_index_dict.items():
        masked_features = np.multiply(feature_pruning_masks[classId], feature_vecs[sample_indices, :])
        # A shallow copy is enough: only 'weight' is replaced and the evaluation merely reads the
        # params. Deep-copying the whole weight matrix and bias for every class was pure overhead.
        masked_params = dict(params)
        masked_params['weight'] = np.multiply(weight_pruning_masks[classId], params['weight'])
        logits = evaluate_probs_with_selected_neurons(masked_params, masked_features,
                                                      selected_neuron_indices=None, softmax_eval=False)
        LINe_score[sample_indices] = log_sum_exponential_score(logits, sum_axis=1).reshape(-1) # Here the reshape may be unnecessary

    return LINe_score, class_index_dict

def LINe_score_evaluation_pipeline(feature_vecs, predictions, params, sensitivity_index_dict, p_n=20, p_w=20):
    """
    This function executes the evaluation of the OMS score from the LINe paper.

    It first builds, for every class in `sensitivity_index_dict`, a feature pruning mask and a weight
    pruning mask. It then scores the samples with LINe_score_evaluation using those masks.

    feature_vecs: The penultimate layer's activation levels, i.e., the feature vectors.
    predictions: The predictions for the data contained in the feature vectors.
    params: The final linear parameters (weight and bias).
    sensitivity_index_dict: The dictionary containing the sensitivity indices for different classes.
    p_n: The percentage used for the significant neurons (passed to pruning_mask).
    p_w: The percentage used for the significant weights (passed to pruning_mask).

    Returns: (LINe_score, class_index_dict), exactly as returned by LINe_score_evaluation.
    """
    # Get the list of all classes
    class_list = sorted(sensitivity_index_dict)
    # Evaluate the weight contribution and pruning masks for different classes
    weight_contributions = {classId: weight_contribution(params['weight'], sensitivity_index_dict[classId])
                            for classId in class_list}
    feature_pruning_masks = {classId: pruning_mask(sensitivity_index_dict[classId], p_n)
                             for classId in class_list}
    weight_pruning_masks = {classId: pruning_mask(weight_contributions[classId], p_w)
                            for classId in class_list}
    # The per-class scoring is identical to LINe_score_evaluation, so reuse it instead of duplicating the loop.
    return LINe_score_evaluation(feature_vecs, predictions, params, feature_pruning_masks, weight_pruning_masks)

def LINe_ood_evaluation(scores, class_index_dict, thresholds):
    """
    Flag samples as OOD by comparing each score with its predicted class's threshold.

    A sample is flagged (1) when its score is strictly below the threshold of the class it was
    grouped under, and 0 otherwise. Samples not listed in any group stay 0.

    scores: The evaluated ood scores (1-D numpy array).
    class_index_dict: The dictionary indicating the examples for different classes
        (class id -> indices into `scores`).
    thresholds: The ood thresholds for different classes, indexed by class id.

    Returns: An integer numpy array of 0/1 OOD flags, same length as `scores`.
    """
    flags = np.zeros(scores.shape[0], dtype=int)
    for class_id, member_indices in class_index_dict.items():
        below_threshold = scores[member_indices] < thresholds[class_id]
        flags[member_indices] = below_threshold.astype(int)
    return flags

def experim_ood_detection_LINe(actLevels, feature_vecs, predictions, params,
                                feature_pruning_masks, weight_pruning_masks, thresholds, set_name, display=True):
    """
    Run the full LINe OOD-detection experiment and return its evaluation.

    Steps: score the samples with LINe_score_evaluation, turn the scores into 0/1 OOD flags with
    the per-class thresholds (LINe_ood_evaluation), then hand the flags to evaluation_with_ood_bools.

    actLevels: the activation level information that contains also the original and predicted class.
    feature_vecs: The penultimate layer's activation levels, i.e., the feature vectors.
    predictions: The predictions for the data contained in the feature vectors.
    params: The final linear parameters (weight and bias).
    feature_pruning_masks: The feature pruning masks for different classes.
    weight_pruning_masks: The weight pruning masks for different classes.
    thresholds: The ood thresholds for different classes.
    set_name: The OOD set name.
    display: Boolean indicating if we want to display the result.

    Returns: Whatever evaluation_with_ood_bools returns.
    """
    line_scores, per_class_indices = LINe_score_evaluation(
        feature_vecs, predictions, params, feature_pruning_masks, weight_pruning_masks
    )
    ood_flags = LINe_ood_evaluation(line_scores, per_class_indices, thresholds)
    return evaluation_with_ood_bools(actLevels, ood_flags, set_name, display=display)


"""
New version for applying LINe ood detection (A simpler version)
"""
def LINe_score_evaluation_logit_ver(logits):
    """
    Compute the OMS score from the LINe paper directly from already-evaluated logits.

    No class-wise grouping or masking is done here. Presumably any pruning must already be
    reflected in `logits` (verify against callers).

    logits: The evaluated logits, one row per sample; the log-sum-exp is taken over axis 1.

    Returns: A flattened (1-D) array of scores.
    """
    # The reshape only guarantees a flat result; it may be redundant.
    return log_sum_exponential_score(logits, sum_axis=1).reshape(-1)