"""
This file contains functions useful for the OMS score evaluation
"""
import math
from common_imports import np, torch, tqdm, copy
from common_use_functions import softmax, class_index_dict_build
from scipy.stats import wasserstein_distance

def weight_contribution(weights, sensitivity_indices):
    """
    Scale every weight by the sensitivity index of the neuron it reads from.

    weights: The weight matrix (2D numpy array).
    sensitivity_indices: Per-neuron sensitivity indices, e.g. Sobol indices (1D numpy array).
        Under numpy broadcasting a 1D array is aligned with the last axis of `weights`.

    Returns the element-wise product, i.e. the weight contribution matrix.

    Note: The sensitivity indices may be per-class or unified (i.e., multi-output evaluation).
    """
    scaled_weights = np.multiply(weights, sensitivity_indices)
    return scaled_weights

def get_sensitivity_pruning_masks(params, sensitivity_contrib_matrix, p_n=20, p_w=20):
    """
    Compute the per-class weight contributions and pruning masks used by the OMS score
    from the LINe paper (only this step, not the whole pipeline).

    params: The final linear parameters; the weight matrix is read from params['weight'].
    sensitivity_contrib_matrix: Sensitivity indices per class; row `c` holds the indices for class `c`.
    p_n: Percentile (0-100) used by pruning_mask to select the significant neurons.
    p_w: Percentile (0-100) used by pruning_mask to select the significant weights.

    Returns (weight_contributions, feature_pruning_masks, weight_pruning_masks): three dicts keyed
    by the class indices 0 .. sensitivity_contrib_matrix.shape[0] - 1.
    """
    class_ids = range(sensitivity_contrib_matrix.shape[0])
    weight_matrix = params['weight']

    weight_contributions = {
        class_id: weight_contribution(weight_matrix, sensitivity_contrib_matrix[class_id])
        for class_id in class_ids
    }
    # Neuron masks come from the raw sensitivities; weight masks from the scaled contributions.
    feature_pruning_masks = {
        class_id: pruning_mask(sensitivity_contrib_matrix[class_id], p_n)
        for class_id in class_ids
    }
    weight_pruning_masks = {
        class_id: pruning_mask(contribution, p_w)
        for class_id, contribution in weight_contributions.items()
    }

    return weight_contributions, feature_pruning_masks, weight_pruning_masks

def get_logits_and_norm_logits_by_preds(feature_vecs, preds, params, feature_pruning_masks, weight_pruning_masks):
    """
    Compute masked logits and their L2-normalized version with numpy.

    Each example is masked with the feature and weight pruning masks of its predicted class
    before the final linear layer is applied.

    feature_vecs: The penultimate layer's activation levels, i.e., the feature vectors.
    preds: The predictions for all examples.
    params: The final linear parameters under the keys 'weight' and 'bias'.
    feature_pruning_masks: The feature pruning masks, indexed by class ID.
    weight_pruning_masks: The weight pruning masks, indexed by class ID.

    Returns (logits, norm_logits), both of shape (feature_vecs.shape[0], params['weight'].shape[0]);
    norm_logits divides every logit row by its L2 norm.
    """
    weight_matrix = params['weight']
    bias_vec = params['bias']
    indices_by_class = class_index_dict_build(preds)
    logits = np.zeros((feature_vecs.shape[0], weight_matrix.shape[0]))

    for class_id in indices_by_class:
        example_rows = indices_by_class[class_id]
        masked_inputs = feature_vecs[example_rows] * feature_pruning_masks[class_id]
        masked_weight_matrix = weight_matrix * weight_pruning_masks[class_id]
        logits[example_rows] = np.dot(masked_inputs, masked_weight_matrix.T) + bias_vec

    row_norms = np.linalg.norm(logits, axis=1, keepdims=True)
    return logits, logits / row_norms

def get_train_similarity_by_preds(feature_vecs, preds, train_mean_vecs):
    """
    This function returns, for every example, the dot-product similarity between its feature
    vector and the training-set mean vector of its predicted class.

    feature_vecs: The penultimate layer's activation levels, i.e., the feature vectors (one row per example).
    preds: The predictions for all examples.
    train_mean_vecs: The mean vector for each class in the training set, indexed by class ID.

    Returns a 1D numpy array of length feature_vecs.shape[0]. The similarity is an unnormalized
    dot product, not a cosine similarity.
    """
    class_index_dict = class_index_dict_build(preds)
    similarities = np.zeros(feature_vecs.shape[0])
    for classId in class_index_dict:
        class_indices = class_index_dict[classId]
        # Flatten the class mean so the product is a matrix-vector product (one value per example).
        class_mean_vec = np.asarray(train_mean_vecs[classId]).reshape(-1)
        similarities[class_indices] = np.dot(feature_vecs[class_indices], class_mean_vec)

    return similarities

# def get_logits_and_norm_logits_by_preds_torch_ver(feature_vecs, preds, params, feature_pruning_masks, weight_pruning_masks,
#                                                   batch_size=500, display=True):
#     """
#     This function return the logits and their normalized version.

#     feature_vecs: The penultimate layer's activation levels, i.e., the feature vectors. 
#     preds: The predictions for all examples.
#     params: The final linear parameters (weight and bias).
#     feature_pruning_masks: The feature pruning masks for different classes.
#     weight_pruning_masks: The weight pruning masks for different classes.
#     batch_size: The size of each batch to be processed.
#     display: Boolean indicating if we would like to display a progress bar.

#     Note: We use the pytorch to accelerate the computation.
#     """
#     # Evaluate the scores for the data from different classes
#     weights = params['weight']
#     bias = params['bias']
#     class_index_dict = class_index_dict_build(preds)
#     logits = np.zeros((feature_vecs.shape[0], weights.shape[0]))
#     # Move the bias and weights to GPU
#     weights = torch.from_numpy(weights).cuda()
#     bias = torch.from_numpy(bias).cuda()
#     # Execute the evaluation
#     class_list = sorted(list(class_index_dict.keys()))
#     progress_bar = tqdm(class_list, desc='Processed classes (logit)') if display else class_list
#     for classId in progress_bar:
#         current_class_vecs = feature_vecs[class_index_dict[classId]]
#         current_class_feature_mask = feature_pruning_masks[classId]
#         current_class_weight_mask = weight_pruning_masks[classId]
#         # Move the masks to GPU
#         current_class_feature_mask = torch.from_numpy(current_class_feature_mask).cuda()
#         current_class_weight_mask = torch.from_numpy(current_class_weight_mask).cuda()
#         # Execution with batches
#         nb_batches = math.ceil(current_class_vecs.shape[0] / batch_size)
#         end_index = current_class_vecs.shape[0]
#         current_class_logits = None
#         for batch_index in range(nb_batches):
#             current_batch_feature_vecs = None
#             if batch_index == nb_batches-1:
#                 current_batch_feature_vecs = current_class_vecs[(batch_index*batch_size):end_index]
#             else:
#                 current_batch_feature_vecs = current_class_vecs[(batch_index*batch_size):((batch_index+1)*batch_size)]
#             current_batch_feature_vecs = torch.from_numpy(current_batch_feature_vecs).cuda()
#             current_batch_logits = torch.matmul(torch.mul(current_batch_feature_vecs, current_class_feature_mask),
#                                                  torch.mul(weights, current_class_weight_mask).T)+bias
#             if batch_index == 0:
#                 current_class_logits = current_batch_logits.cpu().numpy()
#             else:
#                 current_class_logits = np.vstack((current_class_logits,current_batch_logits.cpu().numpy()))
#         logits[class_index_dict[classId]] = current_class_logits
#     # Apply the normalization in float32 precision
#     logits = logits.astype(np.float32)
#     norm_logits = logits / np.linalg.norm(logits, axis=1, keepdims=True)

#     return logits, norm_logits

def get_logits_and_norm_logits_by_preds_torch_ver(feature_vecs, preds, params, feature_pruning_masks, weight_pruning_masks,
                                                  batch_size=500, display=True, device='cuda'):
    """
    This function returns the masked logits and their L2-normalized version.

    Every example is routed to the pruning masks of its predicted class: its feature vector is
    multiplied by that class's feature mask and the final weights by that class's weight mask
    before the linear layer is applied.

    feature_vecs: The penultimate layer's activation levels, i.e., the feature vectors (numpy array, one row per example).
    preds: The predictions for all examples.
    params: The final linear parameters as numpy arrays under the keys 'weight' and 'bias'.
    feature_pruning_masks: The feature pruning masks (numpy arrays), indexed by class ID.
    weight_pruning_masks: The weight pruning masks (numpy arrays), indexed by class ID.
    batch_size: The number of examples processed per batch.
    display: Boolean indicating if we would like to display a progress bar.
    device: The torch device used for the computation. Defaults to 'cuda', the device that
        was previously hard-coded here.

    Returns (logits, norm_logits): float32 numpy arrays of shape
    (feature_vecs.shape[0], params['weight'].shape[0]); norm_logits divides each row by its L2 norm
    (a row of all-zero logits therefore yields NaN).

    Note: We use PyTorch to accelerate the computation.
    """
    weights = params['weight']
    bias = params['bias']
    class_index_dict = class_index_dict_build(preds)
    # Results are written into a float32 buffer, so the output is float32 whatever the params' dtype.
    logits = np.zeros((feature_vecs.shape[0], weights.shape[0]), dtype=np.float32)
    # Move the bias and weights to the device ("_t" symbolizes the torch version)
    weights_t = torch.from_numpy(weights).to(device)
    bias_t = torch.from_numpy(bias).to(device)
    class_list = sorted(class_index_dict.keys())
    progress_bar = tqdm(class_list, desc='Processed classes (logit)') if display else class_list
    with torch.no_grad():
        for classId in progress_bar:
            class_indices = class_index_dict[classId]
            current_class_vecs = feature_vecs[class_indices]
            feature_mask_t = torch.from_numpy(feature_pruning_masks[classId]).to(device)
            weight_mask_t = torch.from_numpy(weight_pruning_masks[classId]).to(device)
            # Mask the weights once per class; transposed so batches can be right-multiplied.
            masked_weights = torch.mul(weights_t, weight_mask_t).T
            batch_logits = []  # The logits of all batches of this class
            for batch_start in range(0, current_class_vecs.shape[0], batch_size):
                batch_vecs = torch.from_numpy(current_class_vecs[batch_start:batch_start + batch_size]).to(device)
                masked_features = torch.mul(batch_vecs, feature_mask_t)
                batch_logit = torch.matmul(masked_features, masked_weights) + bias_t
                batch_logits.append(batch_logit.cpu().numpy())
            # A class without examples leaves its (empty) rows untouched instead of crashing np.concatenate.
            if batch_logits:
                logits[class_indices] = np.concatenate(batch_logits, axis=0)
    norm_logits = logits / np.linalg.norm(logits, axis=1, keepdims=True)

    return logits, norm_logits

def get_logits_by_preds_torch_ver(feature_vecs, preds, params, feature_pruning_masks, weight_pruning_masks,
                                  batch_size=500, display=True, device='cuda'):
    """
    This function returns only the masked logits.

    Every example is routed to the pruning masks of its predicted class: its feature vector is
    multiplied by that class's feature mask and the final weights by that class's weight mask
    before the linear layer is applied.

    feature_vecs: The penultimate layer's activation levels, i.e., the feature vectors (numpy array, one row per example).
    preds: The predictions for all examples.
    params: The final linear parameters as numpy arrays under the keys 'weight' and 'bias'.
    feature_pruning_masks: The feature pruning masks (numpy arrays), indexed by class ID.
    weight_pruning_masks: The weight pruning masks (numpy arrays), indexed by class ID.
    batch_size: The number of examples processed per batch.
    display: Boolean indicating if we would like to display a progress bar.
    device: The torch device used for the computation. Defaults to 'cuda', the device that
        was previously hard-coded here.

    Returns a float32 numpy array of shape (feature_vecs.shape[0], params['weight'].shape[0]).

    Note: We use PyTorch to accelerate the computation.
    """
    weights = params['weight']
    bias = params['bias']
    class_index_dict = class_index_dict_build(preds)
    # Results are written into a float32 buffer, so the output is float32 whatever the params' dtype.
    logits = np.zeros((feature_vecs.shape[0], weights.shape[0]), dtype=np.float32)
    # Move the bias and weights to the device ("_t" symbolizes the torch version)
    weights_t = torch.from_numpy(weights).to(device)
    bias_t = torch.from_numpy(bias).to(device)
    class_list = sorted(class_index_dict.keys())
    progress_bar = tqdm(class_list, desc='Processed classes (logit)') if display else class_list
    with torch.no_grad():
        for classId in progress_bar:
            class_indices = class_index_dict[classId]
            current_class_vecs = feature_vecs[class_indices]
            feature_mask_t = torch.from_numpy(feature_pruning_masks[classId]).to(device)
            weight_mask_t = torch.from_numpy(weight_pruning_masks[classId]).to(device)
            # Mask the weights once per class; transposed so batches can be right-multiplied.
            masked_weights = torch.mul(weights_t, weight_mask_t).T
            batch_logits = []  # The logits of all batches of this class
            for batch_start in range(0, current_class_vecs.shape[0], batch_size):
                batch_vecs = torch.from_numpy(current_class_vecs[batch_start:batch_start + batch_size]).to(device)
                masked_features = torch.mul(batch_vecs, feature_mask_t)
                batch_logit = torch.matmul(masked_features, masked_weights) + bias_t
                batch_logits.append(batch_logit.cpu().numpy())
            # A class without examples leaves its (empty) rows untouched instead of crashing np.concatenate.
            if batch_logits:
                logits[class_indices] = np.concatenate(batch_logits, axis=0)

    return logits

def get_logits_by_preds_only_feat_prun_torch_ver(feature_vecs, preds, params, feature_pruning_masks,
                                                 batch_size=500, display=True, device='cuda'):
    """
    This function returns only the logits, applying feature pruning but no weight pruning.

    Every example's feature vector is multiplied by the feature mask of its predicted class before
    the (unmasked) final linear layer is applied.

    feature_vecs: The penultimate layer's activation levels, i.e., the feature vectors (numpy array, one row per example).
    preds: The predictions for all examples.
    params: The final linear parameters as numpy arrays under the keys 'weight' and 'bias'.
    feature_pruning_masks: The feature pruning masks (numpy arrays), indexed by class ID.
    batch_size: The number of examples processed per batch.
    display: Boolean indicating if we would like to display a progress bar.
    device: The torch device used for the computation. Defaults to 'cuda', the device that
        was previously hard-coded here.

    Returns a float32 numpy array of shape (feature_vecs.shape[0], params['weight'].shape[0]).

    Note: We use PyTorch to accelerate the computation.
    """
    weights = params['weight']
    bias = params['bias']
    class_index_dict = class_index_dict_build(preds)
    # Results are written into a float32 buffer, so the output is float32 whatever the params' dtype.
    logits = np.zeros((feature_vecs.shape[0], weights.shape[0]), dtype=np.float32)
    # Move the bias and the (transposed, unmasked) weights to the device once ("_t" = torch version)
    weights_t_T = torch.from_numpy(weights).to(device).T
    bias_t = torch.from_numpy(bias).to(device)
    class_list = sorted(class_index_dict.keys())
    progress_bar = tqdm(class_list, desc='Processed classes (logit)') if display else class_list
    with torch.no_grad():
        for classId in progress_bar:
            class_indices = class_index_dict[classId]
            current_class_vecs = feature_vecs[class_indices]
            feature_mask_t = torch.from_numpy(feature_pruning_masks[classId]).to(device)
            batch_logits = []  # The logits of all batches of this class
            for batch_start in range(0, current_class_vecs.shape[0], batch_size):
                batch_vecs = torch.from_numpy(current_class_vecs[batch_start:batch_start + batch_size]).to(device)
                masked_features = torch.mul(batch_vecs, feature_mask_t)
                batch_logit = torch.matmul(masked_features, weights_t_T) + bias_t
                batch_logits.append(batch_logit.cpu().numpy())
            # A class without examples leaves its (empty) rows untouched instead of crashing np.concatenate.
            if batch_logits:
                logits[class_indices] = np.concatenate(batch_logits, axis=0)

    return logits

def get_logits_by_preds_unified_weight_mask_torch_ver(feature_vecs, preds, params, feature_pruning_masks, weight_pruning_masks,
                                                      batch_size=500, display=True, device='cuda'):
    """
    This function returns only the logits, using per-class feature masks and one unified weight mask.

    Every example's feature vector is multiplied by the feature mask of its predicted class; the
    final weights are multiplied by the single `weight_pruning_masks` array shared by all classes.

    feature_vecs: The penultimate layer's activation levels, i.e., the feature vectors (numpy array, one row per example).
    preds: The predictions for all examples.
    params: The final linear parameters as numpy arrays under the keys 'weight' and 'bias'.
    feature_pruning_masks: The feature pruning masks (numpy arrays), indexed by class ID.
    weight_pruning_masks: The unified weight pruning mask (a single numpy array).
    batch_size: The number of examples processed per batch.
    display: Boolean indicating if we would like to display a progress bar.
    device: The torch device used for the computation. Defaults to 'cuda', the device that
        was previously hard-coded here.

    Returns a float32 numpy array of shape (feature_vecs.shape[0], params['weight'].shape[0]).

    Note: We use PyTorch to accelerate the computation.
    """
    weights = params['weight']
    bias = params['bias']
    class_index_dict = class_index_dict_build(preds)
    # Results are written into a float32 buffer, so the output is float32 whatever the params' dtype.
    logits = np.zeros((feature_vecs.shape[0], weights.shape[0]), dtype=np.float32)
    # Move the bias and weights to the device ("_t" symbolizes the torch version)
    weights_t = torch.from_numpy(weights).to(device)
    bias_t = torch.from_numpy(bias).to(device)
    # The weight mask is shared by all classes, so upload and apply it once (transposed for right-multiplication).
    weight_mask_t = torch.from_numpy(weight_pruning_masks).to(device)
    masked_weights = torch.mul(weights_t, weight_mask_t).T
    class_list = sorted(class_index_dict.keys())
    progress_bar = tqdm(class_list, desc='Processed classes (logit)') if display else class_list
    with torch.no_grad():
        for classId in progress_bar:
            class_indices = class_index_dict[classId]
            current_class_vecs = feature_vecs[class_indices]
            feature_mask_t = torch.from_numpy(feature_pruning_masks[classId]).to(device)
            batch_logits = []  # The logits of all batches of this class
            for batch_start in range(0, current_class_vecs.shape[0], batch_size):
                batch_vecs = torch.from_numpy(current_class_vecs[batch_start:batch_start + batch_size]).to(device)
                masked_features = torch.mul(batch_vecs, feature_mask_t)
                batch_logit = torch.matmul(masked_features, masked_weights) + bias_t
                batch_logits.append(batch_logit.cpu().numpy())
            # A class without examples leaves its (empty) rows untouched instead of crashing np.concatenate.
            if batch_logits:
                logits[class_indices] = np.concatenate(batch_logits, axis=0)

    return logits

def get_logits_torch_ver(feature_vecs, params, feature_pruning_masks, weight_pruning_masks,
                         batch_size=500, display=True, device='cuda'):
    """
    This function returns only the logits, applying one feature mask and one weight mask to all
    examples (no routing by predicted class).

    feature_vecs: The penultimate layer's activation levels, i.e., the feature vectors (numpy array, one row per example).
    params: The final linear parameters as numpy arrays under the keys 'weight' and 'bias'.
    feature_pruning_masks: The feature mask shared by all examples (numpy array broadcastable against a feature row).
    weight_pruning_masks: The weight mask shared by all classes (numpy array broadcastable against params['weight']).
    batch_size: The number of examples processed per batch.
    display: Boolean indicating if we would like to display a progress bar (one tick per batch).
    device: The torch device used for the computation. Defaults to 'cuda', the device that
        was previously hard-coded here.

    Returns a float32 numpy array of shape (feature_vecs.shape[0], params['weight'].shape[0]).

    Note: We use PyTorch to accelerate the computation.
    """
    weights = params['weight']
    bias = params['bias']
    nb_examples = feature_vecs.shape[0]
    # Results are written into a float32 buffer, so the output is float32 whatever the params' dtype.
    logits = np.zeros((nb_examples, weights.shape[0]), dtype=np.float32)
    # Move the bias, weights and masks to the device ("_t" symbolizes the torch version)
    weights_t = torch.from_numpy(weights).to(device)
    bias_t = torch.from_numpy(bias).to(device)
    feature_mask = torch.from_numpy(feature_pruning_masks).to(device)
    weight_mask = torch.from_numpy(weight_pruning_masks).to(device)
    # Mask the weights once; transposed so batches can be right-multiplied.
    masked_weights = torch.mul(weights_t, weight_mask).T
    batch_starts = list(range(0, nb_examples, batch_size))
    progress_bar = tqdm(batch_starts, desc='Processed classes (logit)') if display else batch_starts
    with torch.no_grad():
        for batch_start in progress_bar:
            batch_end = min(batch_start + batch_size, nb_examples)
            batch_feature_vecs = torch.from_numpy(feature_vecs[batch_start:batch_end]).to(device)
            masked_features = torch.mul(batch_feature_vecs, feature_mask)
            batch_logit = torch.matmul(masked_features, masked_weights) + bias_t
            logits[batch_start:batch_end] = batch_logit.cpu().numpy()

    return logits

def get_logits_without_masking_torch_ver(feature_vecs, params, batch_size=500, display=True, device='cuda'):
    """
    This function returns only the logits of the plain (unmasked) final linear layer.

    feature_vecs: The penultimate layer's activation levels, i.e., the feature vectors (numpy array, one row per example).
    params: The final linear parameters as numpy arrays under the keys 'weight' and 'bias'.
    batch_size: The number of examples processed per batch.
    display: Boolean indicating if we would like to display a progress bar (one tick per batch).
    device: The torch device used for the computation. Defaults to 'cuda', the device that
        was previously hard-coded here.

    Returns a float32 numpy array of shape (feature_vecs.shape[0], params['weight'].shape[0]).

    Note: We use PyTorch to accelerate the computation.
    """
    weights = params['weight']
    bias = params['bias']
    nb_examples = feature_vecs.shape[0]
    # Results are written into a float32 buffer, so the output is float32 whatever the params' dtype.
    logits = np.zeros((nb_examples, weights.shape[0]), dtype=np.float32)
    # Move the bias and the transposed weights to the device ("_t" symbolizes the torch version)
    weights_t = torch.from_numpy(weights).T.to(device)
    bias_t = torch.from_numpy(bias).to(device)
    batch_starts = list(range(0, nb_examples, batch_size))
    progress_bar = tqdm(batch_starts, desc='Processed classes (logit)') if display else batch_starts
    with torch.no_grad():
        for batch_start in progress_bar:
            batch_end = min(batch_start + batch_size, nb_examples)
            batch_features = torch.from_numpy(feature_vecs[batch_start:batch_end]).to(device)
            batch_logit = torch.matmul(batch_features, weights_t) + bias_t
            logits[batch_start:batch_end] = batch_logit.cpu().numpy()

    return logits

def get_logits_unimask_feature_torch_ver(feature_vecs, params, feature_pruning_mask, batch_size=500, display=True,
                                         device='cuda'):
    """
    This function returns only the logits, applying one feature mask to all examples and no weight mask.

    feature_vecs: The penultimate layer's activation levels, i.e., the feature vectors (numpy array, one row per example).
    params: The final linear parameters as numpy arrays under the keys 'weight' and 'bias'.
    feature_pruning_mask: The unified feature mask (numpy array broadcastable against a feature row).
    batch_size: The number of examples processed per batch.
    display: Boolean indicating if we would like to display a progress bar (one tick per batch).
    device: The torch device used for the computation. Defaults to 'cuda', the device that
        was previously hard-coded here.

    Returns a float32 numpy array of shape (feature_vecs.shape[0], params['weight'].shape[0]).

    Note: We use PyTorch to accelerate the computation.
    """
    weights = params['weight']
    bias = params['bias']
    nb_examples = feature_vecs.shape[0]
    # Results are written into a float32 buffer, so the output is float32 whatever the params' dtype.
    logits = np.zeros((nb_examples, weights.shape[0]), dtype=np.float32)
    # Move the bias, transposed weights and mask to the device ("_t" symbolizes the torch version)
    weights_t = torch.from_numpy(weights).T.to(device)
    bias_t = torch.from_numpy(bias).to(device)
    feature_mask = torch.from_numpy(feature_pruning_mask).to(device)
    batch_starts = list(range(0, nb_examples, batch_size))
    progress_bar = tqdm(batch_starts, desc='Processed classes (logit)') if display else batch_starts
    with torch.no_grad():
        for batch_start in progress_bar:
            batch_end = min(batch_start + batch_size, nb_examples)
            batch_features = torch.from_numpy(feature_vecs[batch_start:batch_end]).to(device)
            masked_features = torch.mul(batch_features, feature_mask)
            batch_logit = torch.matmul(masked_features, weights_t) + bias_t
            logits[batch_start:batch_end] = batch_logit.cpu().numpy()

    return logits

def OOD_detection(test_scores, threshold):
    """
    Flag examples as OOD by thresholding their scores.

    test_scores: The OOD scores of the test set, which may contain OOD examples.
    threshold: The detection threshold; a score strictly below it is flagged.

    Returns an integer array holding 1 where test_scores < threshold (flagged as OOD) and 0 elsewhere.
    """
    below_threshold = np.less(test_scores, threshold)
    return below_threshold.astype(int)

def pruning_mask(array, p=20):
    """
    Build a 0/1 mask keeping the entries of a 1D or 2D numpy array at or above its p-th percentile.

    array: The provided numpy array.
    p: Percentile (0-100) used as the cut-off; entries >= that percentile are kept (mask value 1).

    Returns an integer array with the same shape as `array`.
    """
    cutoff = np.percentile(array, p)
    keep = array >= cutoff
    return keep.astype(int)

def significant_indices(array, p=20):
    """
    Return the sorted indices of a 1D numpy array whose values lie strictly above its p-th percentile.

    array: The provided numpy array.
    p: Percentile (0-100) used as the cut-off; only values strictly greater than it are selected.
    """
    cutoff = np.percentile(array, p)
    is_significant = array > cutoff
    first_axis_hits = np.nonzero(is_significant)[0]
    return np.sort(first_axis_hits)

def log_sum_exponential_score(vecs, sum_axis=None):
    """
    The OMS score evaluated with the outputs of the final layer: log(sum(exp(vecs))).

    vecs: The values to be evaluated (numpy array, generally logits).
    sum_axis: The axis to apply the sum over; if None, all elements are summed.

    Note: According to the paper, we don't have the final log, but the code indicated that we should
    evaluate with the log.

    The naive log(sum(exp(x))) overflows to inf as soon as one value exceeds ~88 in float32
    (~709 in float64); scipy's logsumexp subtracts the maximum first and is numerically stable
    while returning the same value for well-conditioned inputs.
    """
    from scipy.special import logsumexp

    return logsumexp(vecs, axis=sum_axis)
    
def wasserstein_score(vecs, nb_classes):
    """
    Score each example by the Wasserstein distance between the uniform distribution over classes
    and the softmax of its output vector.

    vecs: The values to be evaluated (2D numpy array, generally logits; one row per example).
    nb_classes: The number of classes.

    Returns a 1D numpy array with one distance per row of `vecs`.
    """
    uniform_prob = np.full(vecs.shape[1], 1/nb_classes)
    class_support = np.arange(nb_classes)
    distances = [
        wasserstein_distance(class_support, class_support, uniform_prob, row_prob)
        for row_prob in softmax(vecs)
    ]
    return np.array(distances)
    
def clip_activations(actLevels, last_hidden_layerId, act_threshold):
    """
    Clip (from above only) the activations of one layer to the given threshold.

    actLevels: Container whose 'actLevel' entry is indexed by layer ID; it is modified in place.
    last_hidden_layerId: The ID of the layer to clip (the last hidden layer).
    act_threshold: The upper bound applied to the activations.

    Returns the same `actLevels` object, for convenience.
    """
    layer_activations = actLevels['actLevel']
    clipped = np.clip(layer_activations[last_hidden_layerId], None, act_threshold)
    layer_activations[last_hidden_layerId] = clipped
    return actLevels
    
def cal_metric(known, novel, method):
    """
    Compute OOD detection metrics from in-distribution and OOD scores.

    known: Scores of the in-distribution (positive) examples; higher scores are assumed to mean
        "more in-distribution" (this is how get_curve places its thresholds).
    novel: Scores of the OOD (negative) examples.
    method: Forwarded to get_curve; 'row' selects a fixed -0.5 threshold for the FPR metric.

    Returns a dict with:
        'FPR':   FPR on the OOD examples at the threshold keeping ~95% of the known examples (TPR 95%).
        'AUROC': Area under the ROC curve.
        'DTERR': Minimum over thresholds of (missed known + accepted novel) / (all examples).
        'AUIN':  Area under the precision-recall curve with the known examples as positives.
        'AUOUT': Area under the precision-recall curve with the novel examples as positives.

    Note: get_curve sorts its inputs in place, so pass copies if the original order matters.
    """
    # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid; use whichever exists.
    trapezoid = getattr(np, 'trapezoid', None) or np.trapz

    tp, fp, fpr_at_tpr95 = get_curve(known, novel, method)
    results = {}

    # FPR at 95% TPR
    results['FPR'] = fpr_at_tpr95

    # AUROC: curves padded with the (1, 1) and (0, 0) end points
    tpr = np.concatenate([[1.], tp/tp[0], [0.]])
    fpr = np.concatenate([[1.], fp/fp[0], [0.]])
    results['AUROC'] = -trapezoid(1.-fpr, tpr)

    # DTERR
    results['DTERR'] = ((tp[0] - tp + fp) / (tp[0] + fp[0])).min()

    # AUIN: precision of the known class; thresholds with no accepted example are skipped
    denom = tp+fp
    denom[denom == 0.] = -1.
    pin_ind = np.concatenate([[True], denom > 0., [True]])
    pin = np.concatenate([[.5], tp/denom, [0.]])
    results['AUIN'] = -trapezoid(pin[pin_ind], tpr[pin_ind])

    # AUOUT: precision of the novel class; thresholds with no rejected example are skipped
    denom = tp[0]-tp+fp[0]-fp
    denom[denom == 0.] = -1.
    pout_ind = np.concatenate([[True], denom > 0., [True]])
    pout = np.concatenate([[0.], (fp[0]-fp)/denom, [.5]])
    results['AUOUT'] = trapezoid(pout[pout_ind], 1.-fpr[pout_ind])

    return results

def get_curve(known, novel, method=None):
    """
    Build the true-positive / false-positive count curves used by cal_metric.

    known: 1D array of scores of the in-distribution examples (treated as positives).
    novel: 1D array of scores of the OOD examples (treated as negatives).
    method: If 'row', the FPR threshold is fixed at -0.5; otherwise it is
        known[round(0.05 * len(known))] after sorting, i.e. roughly 95% of the known scores
        are at or above it (TPR ~95%).

    Returns (tp, fp, fpr_at_tpr95):
        tp, fp: int arrays of length len(known) + len(novel) + 1. Entry l is the number of known
            (tp) / novel (fp) examples still at or above the threshold after the l lowest scores
            have been swept past; tp[0] = len(known), fp[0] = len(novel).
        fpr_at_tpr95: fraction of novel scores strictly greater than the threshold.

    Warning: `known` and `novel` are sorted IN PLACE (the callers in this file pass deep copies).
    Higher scores are assumed to indicate in-distribution examples.
    """
    # NOTE(review): these dict initialisations are overwritten below and have no effect.
    tp, fp = dict(), dict()
    fpr_at_tpr95 = dict()

    # In-place ascending sort: mutates the caller's arrays.
    known.sort()
    novel.sort()

    # NOTE(review): `end` and `start` are computed but never used.
    end = np.max([np.max(known), np.max(novel)])
    start = np.min([np.min(known),np.min(novel)])

    # NOTE(review): `all` shadows the Python builtin within this function.
    all = np.concatenate((known, novel))
    all.sort()

    num_k = known.shape[0]
    num_n = novel.shape[0]

    if method == 'row':
        threshold = -0.5
    else:
        # ~5% of the (sorted) known scores lie below this value -> TPR of ~95%.
        threshold = known[round(0.05 * num_k)]

    tp = -np.ones([num_k+num_n+1], dtype=int)
    fp = -np.ones([num_k+num_n+1], dtype=int)
    tp[0], fp[0] = num_k, num_n
    k, n = 0, 0
    # Merge-style sweep over both sorted lists: each step drops the lowest remaining score
    # below the threshold and decrements the matching counter.
    for l in range(num_k+num_n):
        if k == num_k:
            # Known scores exhausted: only novel examples remain to be dropped.
            tp[l+1:] = tp[l]
            fp[l+1:] = np.arange(fp[l]-1, -1, -1)
            break
        elif n == num_n:
            # Novel scores exhausted: only known examples remain to be dropped.
            tp[l+1:] = np.arange(tp[l]-1, -1, -1)
            fp[l+1:] = fp[l]
            break
        else:
            if novel[n] < known[k]:
                n += 1
                tp[l+1] = tp[l]
                fp[l+1] = fp[l] - 1
            else:
                # On equal scores the known example is dropped first.
                k += 1
                tp[l+1] = tp[l] - 1
                fp[l+1] = fp[l]

    # Walk backwards over runs of tied scores and copy the counts from the next position, so that
    # positions inside a run of equal scores do not act as separate operating points.
    j = num_k+num_n-1
    for l in range(num_k+num_n-1):
        if all[j] == all[j-1]:
            tp[j] = tp[j+1]
            fp[j] = fp[j+1]
        j -= 1

    fpr_at_tpr95 = np.sum(novel > threshold) / float(num_n)

    return tp, fp, fpr_at_tpr95

def ood_performance_evaluation(test_scores, novelty_scores, method_name, display=False):
    """
    Evaluate OOD scores against every OOD set and average the headline metrics.

    test_scores: The scores of the in-distribution test set.
    novelty_scores: The scores of each OOD set, keyed by OOD set name.
    method_name: The name of the OOD detection method (only used in the printed summary).
    display: Boolean indicating if the average performance should be printed.

    Returns a dict with 'Detail' (per-OOD-set metric dicts from cal_metric) plus the mean 'FPR'
    and mean 'AUROC' over all OOD sets. Metrics follow the evaluation of the DICE paper.
    """
    # cal_metric sorts its inputs in place, so each call receives deep copies.
    per_set_metrics = {
        ood_type: cal_metric(copy.deepcopy(test_scores), copy.deepcopy(novelty_scores[ood_type]), method=None)
        for ood_type in novelty_scores
    }
    mean_fpr = np.mean([metrics['FPR'] for metrics in per_set_metrics.values()])
    mean_auroc = np.mean([metrics['AUROC'] for metrics in per_set_metrics.values()])

    summary = {
        'Detail': per_set_metrics,
        'FPR': mean_fpr,
        'AUROC': mean_auroc,
    }
    if display:
        print('The average performance of', method_name, ': FPR95:', mean_fpr, "AUROC:", mean_auroc)

    return summary

def ood_performance_evaluation_imagenet(test_scores, novelty_scores, method_name, display=False):
    """
    Score every novelty (OOD) set against the test set and report two averages:
    one over the curated sets (every set except 'openimage') and one over all sets.

    test_scores: The scores of the in-distribution test set.
    novelty_scores: A mapping from OOD set name to the scores of that set.
    method_name: The name of the OOD detection method (only used when displaying).
    display: If True, print both the curated and the overall averages.

    Returns a dict with 'Detail' (per-set cal_metric results), 'FPR_curated',
    'AUROC_curated', 'FPR' and 'AUROC'. The curated means are NaN when
    'openimage' is the only set present (np.mean of an empty list).
    """
    # Metric from the DICE paper, computed per OOD set. The inputs are deep-copied
    # so that cal_metric cannot change the order of the caller's score arrays.
    per_set_results = {
        set_name: cal_metric(copy.deepcopy(test_scores), copy.deepcopy(novelty_scores[set_name]), method=None)
        for set_name in novelty_scores
    }

    curated_names = [set_name for set_name in per_set_results if set_name != 'openimage']
    curated_fpr = np.mean([per_set_results[set_name]['FPR'] for set_name in curated_names])
    curated_auroc = np.mean([per_set_results[set_name]['AUROC'] for set_name in curated_names])
    overall_fpr = np.mean([per_set_results[set_name]['FPR'] for set_name in per_set_results])
    overall_auroc = np.mean([per_set_results[set_name]['AUROC'] for set_name in per_set_results])

    summary = {
        'Detail': per_set_results,
        'FPR_curated': curated_fpr,
        'AUROC_curated': curated_auroc,
        'FPR': overall_fpr,
        'AUROC': overall_auroc,
    }

    if display:
        print('The average performance of', method_name, ': FPR95 (curated):', curated_fpr, "AUROC (curated):", curated_auroc)
        print('The average performance of', method_name, ': FPR95:', overall_fpr, "AUROC:", overall_auroc)
        print()

    return summary

def percentile_score_threshold(scores, p):
    """
    Return the score value located at the p-th percentile of ``scores``.

    scores: The reference scores (generally those of the training set).
    p: The percentile, in [0, 100], at which the threshold is placed.
    """
    threshold = np.percentile(scores, p)
    return threshold

def evaluate_best_performance_with_double_guidance_cifar(test_scores, novelty_scores,
                                            test_knn_factors, test_LOG_factors,
                                            novelty_knn_factors, novelty_LOG_factors, 
                                            p_n_list, p_w_list):
    """
    Grid-search the pruning rates (p_n, p_w) after applying double guidance and
    return the pair with the lowest average FPR (CIFAR variant).

    Every base score is multiplied by the product of its kNN and LOG guidance
    factors before being evaluated with cal_metric.

    test_scores: Base scores of the test set, indexed as test_scores[p_n][p_w].
    novelty_scores: Base scores of the novelty sets, indexed as
        novelty_scores[p_n][p_w][ood_type].
    test_knn_factors: The feature kNN guidance factors for the test set.
    test_LOG_factors: The logit-based (LOG) guidance factors for the test set.
    novelty_knn_factors: Mapping ood_type -> feature kNN guidance factors of that novelty set.
    novelty_LOG_factors: Mapping ood_type -> logit-based (LOG) guidance factors of that novelty set.
    p_n_list: The evaluated feature pruning rates.
    p_w_list: The evaluated weight pruning rates.
    (The factors are presumably arrays that broadcast against the scores. TODO: confirm.)

    Returns (best_p_n, best_p_w, registered_evaluation_results). Each
    registered_evaluation_results[p_n][p_w] holds 'Detail' (per-set cal_metric
    results) plus 'FPR' and 'AUROC' (averages over the novelty sets). On ties,
    the first pair evaluated wins. best_p_n and best_p_w are None only when the
    grid is empty or every average FPR is NaN.
    """
    # Start from +inf so the first evaluated pair is always recorded. The former
    # sentinel of 1 with a strict '<' returned None whenever every FPR was 1.0.
    best_avg_FPR = float('inf')
    best_p_n = None
    best_p_w = None
    # The guidance factors do not depend on (p_n, p_w), so compute them once.
    # Novelty factors are cached lazily, per OOD set actually evaluated.
    test_mult_factors = test_LOG_factors * test_knn_factors
    novelty_mult_factors = {}
    registered_evaluation_results = {}
    for test_p_n in p_n_list:
        registered_evaluation_results[test_p_n] = {}
        for test_p_w in p_w_list:
            current_test_scores = test_scores[test_p_n][test_p_w]
            current_novelty_scores = novelty_scores[test_p_n][test_p_w]
            test_mult_scores = test_mult_factors * current_test_scores

            # Evaluate each novelty set according to the metric in the DICE paper
            novelty_metric_results = {}
            for ood_type in current_novelty_scores:
                if ood_type not in novelty_mult_factors:
                    novelty_mult_factors[ood_type] = novelty_LOG_factors[ood_type] * novelty_knn_factors[ood_type]
                novelty_mult_scores = novelty_mult_factors[ood_type] * current_novelty_scores[ood_type]
                # Deep copies keep cal_metric from reordering test_mult_scores,
                # which is reused for every novelty set
                novelty_metric_results[ood_type] = cal_metric(copy.deepcopy(test_mult_scores), copy.deepcopy(novelty_mult_scores), method=None)
            novelty_avg_FPR = np.mean([result['FPR'] for result in novelty_metric_results.values()])
            novelty_avg_AUROC = np.mean([result['AUROC'] for result in novelty_metric_results.values()])

            registered_evaluation_results[test_p_n][test_p_w] = {
                'Detail': novelty_metric_results,
                'FPR': novelty_avg_FPR,
                'AUROC': novelty_avg_AUROC,
            }
            # Strict '<' keeps the first pair on ties
            if novelty_avg_FPR < best_avg_FPR:
                best_avg_FPR = novelty_avg_FPR
                best_p_n = test_p_n
                best_p_w = test_p_w

    return best_p_n, best_p_w, registered_evaluation_results

def evaluate_best_performance_with_single_guidance_cifar(test_scores, novelty_scores,
                                                        test_knn_factors,
                                                        novelty_knn_factors,
                                                        p_n_list, p_w_list):
    """
    Grid-search the pruning rates (p_n, p_w) after applying single (feature kNN)
    guidance and return the pair with the lowest average FPR (CIFAR variant).

    Every base score is multiplied by its kNN guidance factors before being
    evaluated with cal_metric.

    test_scores: Base scores of the test set, indexed as test_scores[p_n][p_w].
    novelty_scores: Base scores of the novelty sets, indexed as
        novelty_scores[p_n][p_w][ood_type].
    test_knn_factors: The feature kNN guidance factors for the test set.
    novelty_knn_factors: Mapping ood_type -> feature kNN guidance factors of that novelty set.
    p_n_list: The evaluated feature pruning rates.
    p_w_list: The evaluated weight pruning rates.
    (The factors are presumably arrays that broadcast against the scores. TODO: confirm.)

    Returns (best_p_n, best_p_w, registered_evaluation_results). Each
    registered_evaluation_results[p_n][p_w] holds 'Detail' (per-set cal_metric
    results) plus 'FPR' and 'AUROC' (averages over the novelty sets). On ties,
    the first pair evaluated wins. best_p_n and best_p_w are None only when the
    grid is empty or every average FPR is NaN.
    """
    # Start from +inf so the first evaluated pair is always recorded. The former
    # sentinel of 1 with a strict '<' returned None whenever every FPR was 1.0.
    best_avg_FPR = float('inf')
    best_p_n = None
    best_p_w = None
    registered_evaluation_results = {}
    for test_p_n in p_n_list:
        registered_evaluation_results[test_p_n] = {}
        for test_p_w in p_w_list:
            current_test_scores = test_scores[test_p_n][test_p_w]
            current_novelty_scores = novelty_scores[test_p_n][test_p_w]
            test_mult_scores = test_knn_factors * current_test_scores

            # Evaluate each novelty set according to the metric in the DICE paper
            novelty_metric_results = {}
            for ood_type in current_novelty_scores:
                novelty_mult_scores = novelty_knn_factors[ood_type] * current_novelty_scores[ood_type]
                # Deep copies keep cal_metric from reordering test_mult_scores,
                # which is reused for every novelty set
                novelty_metric_results[ood_type] = cal_metric(copy.deepcopy(test_mult_scores), copy.deepcopy(novelty_mult_scores), method=None)
            novelty_avg_FPR = np.mean([result['FPR'] for result in novelty_metric_results.values()])
            novelty_avg_AUROC = np.mean([result['AUROC'] for result in novelty_metric_results.values()])

            registered_evaluation_results[test_p_n][test_p_w] = {
                'Detail': novelty_metric_results,
                'FPR': novelty_avg_FPR,
                'AUROC': novelty_avg_AUROC,
            }
            # Strict '<' keeps the first pair on ties
            if novelty_avg_FPR < best_avg_FPR:
                best_avg_FPR = novelty_avg_FPR
                best_p_n = test_p_n
                best_p_w = test_p_w

    return best_p_n, best_p_w, registered_evaluation_results

def evaluate_best_performance_with_double_guidance_imagenet(test_scores, novelty_scores,
                                            test_knn_factors, test_LOG_factors,
                                            novelty_knn_factors, novelty_LOG_factors, 
                                            p_n_list, p_w_list):
    """
    Grid-search the pruning rates (p_n, p_w) after applying double guidance and
    return the pair with the lowest curated average FPR (ImageNet variant).

    Every base score is multiplied by the product of its kNN and LOG guidance
    factors before being evaluated with cal_metric. The "curated" averages
    cover every novelty set except 'openimage'. The best pair is selected on
    the curated FPR, and averages over all sets are recorded as well.

    test_scores: Base scores of the test set, indexed as test_scores[p_n][p_w].
    novelty_scores: Base scores of the novelty sets, indexed as
        novelty_scores[p_n][p_w][ood_type].
    test_knn_factors: The feature kNN guidance factors for the test set.
    test_LOG_factors: The logit-based (LOG) guidance factors for the test set.
    novelty_knn_factors: Mapping ood_type -> feature kNN guidance factors of that novelty set.
    novelty_LOG_factors: Mapping ood_type -> logit-based (LOG) guidance factors of that novelty set.
    p_n_list: The evaluated feature pruning rates.
    p_w_list: The evaluated weight pruning rates.
    (The factors are presumably arrays that broadcast against the scores. TODO: confirm.)

    Returns (best_p_n, best_p_w, registered_evaluation_results). Each
    registered_evaluation_results[p_n][p_w] holds 'Detail', 'FPR_curated',
    'AUROC_curated', 'FPR' and 'AUROC'. On ties, the first pair evaluated wins.
    best_p_n and best_p_w are None only when the grid is empty or every curated
    FPR is NaN (e.g. 'openimage' is the only novelty set).
    """
    # Start from +inf so the first evaluated pair is always recorded. The former
    # sentinel of 1 with a strict '<' returned None whenever every FPR was 1.0.
    best_avg_FPR = float('inf')
    best_p_n = None
    best_p_w = None
    # The guidance factors do not depend on (p_n, p_w), so compute them once.
    # Novelty factors are cached lazily, per OOD set actually evaluated.
    test_mult_factors = test_LOG_factors * test_knn_factors
    novelty_mult_factors = {}
    registered_evaluation_results = {}
    for test_p_n in p_n_list:
        registered_evaluation_results[test_p_n] = {}
        for test_p_w in p_w_list:
            current_test_scores = test_scores[test_p_n][test_p_w]
            current_novelty_scores = novelty_scores[test_p_n][test_p_w]
            test_mult_scores = test_mult_factors * current_test_scores

            # Evaluate each novelty set according to the metric in the DICE paper
            novelty_metric_results = {}
            for ood_type in current_novelty_scores:
                if ood_type not in novelty_mult_factors:
                    novelty_mult_factors[ood_type] = novelty_LOG_factors[ood_type] * novelty_knn_factors[ood_type]
                novelty_mult_scores = novelty_mult_factors[ood_type] * current_novelty_scores[ood_type]
                # Deep copies keep cal_metric from reordering test_mult_scores,
                # which is reused for every novelty set
                novelty_metric_results[ood_type] = cal_metric(copy.deepcopy(test_mult_scores), copy.deepcopy(novelty_mult_scores), method=None)

            curated_results = [result for ood_type, result in novelty_metric_results.items()
                               if ood_type != 'openimage']
            novelty_avg_FPR_curated = np.mean([result['FPR'] for result in curated_results])
            novelty_avg_AUROC_curated = np.mean([result['AUROC'] for result in curated_results])
            novelty_avg_FPR = np.mean([result['FPR'] for result in novelty_metric_results.values()])
            novelty_avg_AUROC = np.mean([result['AUROC'] for result in novelty_metric_results.values()])

            registered_evaluation_results[test_p_n][test_p_w] = {
                'Detail': novelty_metric_results,
                'FPR_curated': novelty_avg_FPR_curated,
                'AUROC_curated': novelty_avg_AUROC_curated,
                'FPR': novelty_avg_FPR,
                'AUROC': novelty_avg_AUROC,
            }
            # Selection uses the curated FPR; strict '<' keeps the first pair on ties
            if novelty_avg_FPR_curated < best_avg_FPR:
                best_avg_FPR = novelty_avg_FPR_curated
                best_p_n = test_p_n
                best_p_w = test_p_w

    return best_p_n, best_p_w, registered_evaluation_results

def evaluate_best_performance_with_single_guidance_imagenet(test_scores, novelty_scores,
                                                            test_knn_factors,
                                                            novelty_knn_factors,
                                                            p_n_list, p_w_list):
    """
    Grid-search the pruning rates (p_n, p_w) after applying single (feature kNN)
    guidance and return the pair with the lowest curated average FPR (ImageNet variant).

    Every base score is multiplied by its kNN guidance factors before being
    evaluated with cal_metric. The "curated" averages cover every novelty set
    except 'openimage'. The best pair is selected on the curated FPR, and
    averages over all sets are recorded as well.

    test_scores: Base scores of the test set, indexed as test_scores[p_n][p_w].
    novelty_scores: Base scores of the novelty sets, indexed as
        novelty_scores[p_n][p_w][ood_type].
    test_knn_factors: The feature kNN guidance factors for the test set.
    novelty_knn_factors: Mapping ood_type -> feature kNN guidance factors of that novelty set.
    p_n_list: The evaluated feature pruning rates.
    p_w_list: The evaluated weight pruning rates.
    (The factors are presumably arrays that broadcast against the scores. TODO: confirm.)

    Returns (best_p_n, best_p_w, registered_evaluation_results). Each
    registered_evaluation_results[p_n][p_w] holds 'Detail', 'FPR_curated',
    'AUROC_curated', 'FPR' and 'AUROC'. On ties, the first pair evaluated wins.
    best_p_n and best_p_w are None only when the grid is empty or every curated
    FPR is NaN (e.g. 'openimage' is the only novelty set).
    """
    # Start from +inf so the first evaluated pair is always recorded. The former
    # sentinel of 1 with a strict '<' returned None whenever every FPR was 1.0.
    best_avg_FPR = float('inf')
    best_p_n = None
    best_p_w = None
    registered_evaluation_results = {}
    for test_p_n in p_n_list:
        registered_evaluation_results[test_p_n] = {}
        for test_p_w in p_w_list:
            current_test_scores = test_scores[test_p_n][test_p_w]
            current_novelty_scores = novelty_scores[test_p_n][test_p_w]
            test_mult_scores = test_knn_factors * current_test_scores

            # Evaluate each novelty set according to the metric in the DICE paper
            novelty_metric_results = {}
            for ood_type in current_novelty_scores:
                novelty_mult_scores = novelty_knn_factors[ood_type] * current_novelty_scores[ood_type]
                # Deep copies keep cal_metric from reordering test_mult_scores,
                # which is reused for every novelty set
                novelty_metric_results[ood_type] = cal_metric(copy.deepcopy(test_mult_scores), copy.deepcopy(novelty_mult_scores), method=None)

            curated_results = [result for ood_type, result in novelty_metric_results.items()
                               if ood_type != 'openimage']
            novelty_avg_FPR_curated = np.mean([result['FPR'] for result in curated_results])
            novelty_avg_AUROC_curated = np.mean([result['AUROC'] for result in curated_results])
            novelty_avg_FPR = np.mean([result['FPR'] for result in novelty_metric_results.values()])
            novelty_avg_AUROC = np.mean([result['AUROC'] for result in novelty_metric_results.values()])

            registered_evaluation_results[test_p_n][test_p_w] = {
                'Detail': novelty_metric_results,
                'FPR_curated': novelty_avg_FPR_curated,
                'AUROC_curated': novelty_avg_AUROC_curated,
                'FPR': novelty_avg_FPR,
                'AUROC': novelty_avg_AUROC,
            }
            # Selection uses the curated FPR; strict '<' keeps the first pair on ties
            if novelty_avg_FPR_curated < best_avg_FPR:
                best_avg_FPR = novelty_avg_FPR_curated
                best_p_n = test_p_n
                best_p_w = test_p_w

    return best_p_n, best_p_w, registered_evaluation_results

