"""
This file contains the functions for the deep k-nearest neighbors (KNN) OMS detection.
"""
import faiss
import math
from common_imports import np, copy, tqdm
from common_use_functions import class_index_dict_build

"""
Functions for the original KNN algorithm
"""
def normalize_feature_vecs_knn(actLevels, last_hidden_layerId):
    """
    This function L2-normalizes, row by row, the last hidden layer activation levels (i.e., unnormalized feature vectors).

    actLevels: The extracted activation levels; the features are read from actLevels['actLevel'][last_hidden_layerId],
               one feature vector per row.
    last_hidden_layerId: The last hidden layer Id.

    Returns a new array of unit-norm rows; the stored activations are not modified.
    Rows whose norm is zero (e.g. no active neuron) are returned as zero vectors instead of NaN.
    """
    feature_vecs = actLevels['actLevel'][last_hidden_layerId]
    # np.linalg.norm returns a fresh array, so it can be patched in place safely
    feature_vec_norms = np.linalg.norm(feature_vecs, axis=1, keepdims=True)
    # An all-zero row would otherwise give 0/0 = NaN; dividing it by 1 keeps it a zero vector
    feature_vec_norms[feature_vec_norms == 0] = 1
    # The division allocates a new array, so no defensive copy of the input is needed
    zs = feature_vecs / feature_vec_norms

    return zs

def normalize_feature_vecs_knn_with_centralization(actLevels, last_hidden_layerId, mean_vec):
    """
    Center the last hidden layer feature vectors on mean_vec, then scale every row to unit L2 norm.

    actLevels: The extracted activation levels; the raw features are read from
               actLevels['actLevel'][last_hidden_layerId], one feature vector per row.
    last_hidden_layerId: The last hidden layer Id.
    mean_vec: The mean vector representing the center; it is subtracted from every row.

    A constant 1e-12 is added to every norm, so a row equal to mean_vec becomes a zero vector rather than NaN.
    """
    raw_features = actLevels['actLevel'][last_hidden_layerId]
    # The subtraction produces a new array, leaving the stored activations untouched
    centered = raw_features - mean_vec
    # Epsilon protects the division for rows that coincide with the center
    row_norms = np.linalg.norm(centered, axis=1, keepdims=True) + 1e-12
    return centered / row_norms

def normalize_feature_vecs_knn_with_act_prun(actLevels, last_hidden_layerId, act_threshold):
    """
    This function clips the last hidden layer activation levels from above at act_threshold (activation pruning)
    and then L2-normalizes each row (i.e., each unnormalized feature vector).

    actLevels: The extracted activation levels; the features are read from actLevels['actLevel'][last_hidden_layerId],
               one feature vector per row.
    last_hidden_layerId: The last hidden layer Id.
    act_threshold: The threshold to apply the activation pruning (upper clipping bound; no lower bound is applied).

    Returns a new array; the stored activations are not modified. Rows whose clipped norm is zero are returned
    as zero vectors instead of NaN.
    """
    # np.clip (without out=) returns a new array, so no defensive copy of the input is needed
    feature_vecs = np.clip(actLevels['actLevel'][last_hidden_layerId], a_min=None, a_max=act_threshold)
    feature_vec_norms = np.linalg.norm(feature_vecs, axis=1, keepdims=True)
    # An all-zero row would otherwise give 0/0 = NaN and propagate NaN into the knn distances
    feature_vec_norms[feature_vec_norms == 0] = 1
    zs = feature_vecs / feature_vec_norms

    return zs

# Deep k-nearest neighbors OOD detection function
def faiss_knn_search(search_index, feature_vecs, k, display=False):
    """
    Apply the faiss knn similarity search by iteration, one query vector at a time
    (see faiss_knn_search_batch_ver for the batched variant).

    search_index: The faiss search index (any object exposing search(x, k) -> (D, I) for a 2-D query x).
    feature_vecs: The vectors to apply the search, one query per row.
    k: The number of nearest neighbors.
    display: The boolean that indicates if we want to display the progress bar.

    Returns (D, I): the per-query distances and neighbor indices stacked into one row per query.
    An empty query set returns two empty (0, k) arrays instead of failing inside np.vstack.
    """
    # Get the number of examples
    nb_examples = feature_vecs.shape[0]
    if nb_examples == 0:
        # np.vstack([]) raises "need at least one array to concatenate".
        # dtypes chosen to match faiss's usual outputs (float32 distances, int64 labels) -- assumption for custom indices
        return np.empty((0, k), dtype=np.float32), np.empty((0, k), dtype=np.int64)
    # Build the iterated indices (wrapped in a progress bar when requested)
    example_indices = range(nb_examples)
    if display:
        example_indices = tqdm(example_indices, desc='Processed examples')
    D = []
    I = []
    for index in example_indices:
        # The search expects a 2-D query matrix, so the single row is reshaped to (1, dim)
        current_D, current_I = search_index.search(feature_vecs[index].reshape(1, -1), k)
        D.append(current_D)
        I.append(current_I)

    return np.vstack(D), np.vstack(I)

# Deep k-nearest neighbors OOD detection function
def faiss_knn_search_batch_ver(search_index, feature_vecs, k, batch_size=100, display=False):
    """
    Apply the faiss knn similarity search by iteration over batches of query vectors.

    search_index: The faiss search index (any object exposing search(x, k) -> (D, I) for a 2-D query x).
    feature_vecs: The vectors to apply the search, one query per row.
    k: The number of nearest neighbors.
    batch_size: The number of examples in each query (must be a positive integer; the last batch may be smaller).
    display: The boolean that indicates if we want to display the progress bar.

    Returns (D, I): the per-query distances and neighbor indices stacked into one row per query.
    An empty query set returns two empty (0, k) arrays instead of failing inside np.vstack.

    Raises ValueError if batch_size is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
    # Get the number of examples
    nb_examples = feature_vecs.shape[0]
    if nb_examples == 0:
        # np.vstack([]) raises "need at least one array to concatenate".
        # dtypes chosen to match faiss's usual outputs (float32 distances, int64 labels) -- assumption for custom indices
        return np.empty((0, k), dtype=np.float32), np.empty((0, k), dtype=np.int64)
    # One start position per batch; len() of this range is ceil(nb_examples / batch_size)
    batch_starts = range(0, nb_examples, batch_size)
    if display:
        batch_starts = tqdm(batch_starts, desc='Processed batches')
    D = []
    I = []
    for current_start_pos in batch_starts:
        # Slicing past the end is clamped, so the last (partial) batch needs no special case
        current_vecs = feature_vecs[current_start_pos:current_start_pos + batch_size]
        current_D, current_I = search_index.search(current_vecs, k)
        D.append(current_D)
        I.append(current_I)

    return np.vstack(D), np.vstack(I)

def k_nearst_neighbor_scores(train_index, test_zs, k):
    """
    Compute the k-nearest neighbor score of every test example.

    train_index: The "faiss" index for similarity search.
    test_zs: The normalized test set feature vectors that contain potentially OOD examples.
    k: The number of considered nearest neighbors.

    The score is the negated value in the last column of the returned distances (the k-th neighbor, assuming
    the index returns neighbors nearest-first), so a larger score means the example is closer to the training data.
    """
    distances, _neighbor_ids = faiss_knn_search(train_index, test_zs, k, display=True)
    kth_distance = distances[:, -1]
    return -kth_distance

def k_nearst_neighbor_OOD_detection(test_scores, threshold):
    """
    Turn k-nearest neighbor scores into OOD flags with a single threshold.

    test_scores: The k-nearest neighbor scores of the test set with potential OOD examples.
    threshold: The OOD detection threshold.

    Returns an integer array with 1 where the score is strictly below the threshold (OOD) and 0 otherwise (InD).
    """
    is_below_threshold = np.less(test_scores, threshold)
    return is_below_threshold.astype(int)

def k_nearst_neighbor_OOD_detection_pipeline(train_index, test_zs, k, threshold):
    """
    This function executes the complete k-nearest neighbor OOD detection pipeline.

    train_index: The "faiss" index for similarity search.
    test_zs: The normalized test set feature vectors that contain potentially OOD examples.
    k: The number of considered nearest neighbors.
    threshold: The OOD detection threshold.

    Returns an integer array with 1 for examples detected as OOD (score strictly below threshold) and 0 otherwise.
    """
    # Reuse the standalone scoring and thresholding functions so the pipeline cannot drift from them
    test_scores = k_nearst_neighbor_scores(train_index, test_zs, k)
    ood_result = k_nearst_neighbor_OOD_detection(test_scores, threshold)

    return ood_result

# def accuracy_eval_with_correct_bools(actLevels, ind_correct_bools, ood_correct_bools):
#     """
#     This function executes the evaluation of the accuray results related to the OOD detection with the booleans indicating the InD and OOD examples.

#     actLevels: The extracted activation levels that contais also example-related information (original class, predicted class etc.)
#     ind_correct_bools: The booleans indicating the ind examples.
#     ood_correct_bools: The booleans indicating the ood examples.
#     """
#     test_acc = (actLevels['class'] == actLevels['predict_class']).sum() / actLevels['class'].shape[0]
#     ind_test_acc = (actLevels['class'][ind_correct_bools] 
#                     == actLevels['predict_class'][ind_correct_bools]).sum() / actLevels['class'][ind_correct_bools].shape[0]
#     ood_test_acc = (actLevels['class'][ood_correct_bools] 
#                     == actLevels['predict_class'][ood_correct_bools]).sum() / actLevels['class'][ood_correct_bools].shape[0]
#     print("The accuracy of the test set:", test_acc)
#     print("The InD accuracy of the test set:", ind_test_acc)
#     print("The OOD accuracy of the test set:", ood_test_acc)
    
#     return test_acc, ind_test_acc, ood_test_acc

def evaluation_with_ood_bools(actLevels, ood_result, set_name, display=True):
    """
    This function evaluates the accuracy results related to the OOD detection, using the booleans indicating the InD and OOD examples.

    actLevels: The extracted activation levels that also contain example-related information; this function reads
               actLevels['class'] (ground truths) and actLevels['predict_class'] (predicted classes).
    ood_result: Array-like of 0 and 1, one entry per example, where 1 marks an example detected as OOD.
    set_name: The OOD set name.
    display: Boolean indicating if we want to display the result.

    Returns [set_name, nb_ood, nb_ind, ood_percent, ind_percent, acc_total, acc_ood, acc_ind].
    Any ratio whose denominator is zero (empty set, no OOD example, or no InD example) is reported as 0
    instead of NaN.
    """
    # Get the ground truths and predicted class
    groundtruths = actLevels['class']
    predictions = actLevels['predict_class']
    nb_examples = groundtruths.shape[0]
    # Accept lists/tuples as well as numpy arrays for the detection result
    ood_result = np.asarray(ood_result)
    ood_result_bool = ood_result.astype(bool)
    ind_result_bool = ~ood_result_bool
    # Number of InD and OOD examples and their percentages
    nb_results = ood_result.shape[0]
    nb_ood = ood_result.sum()
    nb_ind = nb_results - nb_ood
    ood_percent = nb_ood / nb_results if nb_results != 0 else 0
    ind_percent = nb_ind / nb_results if nb_results != 0 else 0
    # Get the total, ind and ood accuracy (0 when the corresponding subset is empty)
    acc_total = np.sum(groundtruths == predictions) / nb_examples if nb_examples != 0 else 0
    acc_ood = np.sum(groundtruths[ood_result_bool] == predictions[ood_result_bool]) / nb_ood if nb_ood != 0 else 0
    acc_ind = np.sum(groundtruths[ind_result_bool] == predictions[ind_result_bool]) / nb_ind if nb_ind != 0 else 0
    # Display the results
    if display:
        print('The number of OOD examples in', set_name, 'set:', nb_ood)
        print('The number of InD examples in', set_name, 'set:', nb_ind)
        print('The percentage of OOD examples in', set_name, 'set:', ood_percent)
        print('The percentage of InD examples in', set_name, 'set:', ind_percent)
        print('The total accuracy in', set_name, 'set:', acc_total)
        print('The accuracy on the OOD examples in', set_name, 'set:', acc_ood)
        print('The accuracy on the InD examples in', set_name, 'set:', acc_ind)

    return [set_name, nb_ood, nb_ind, ood_percent, ind_percent, acc_total, acc_ood, acc_ind]

def experim_ood_detection_knn(train_index, test_zs, test_actLevels, k, threshold, set_name, display=True):
    """
    Run one complete k-nearest neighbor OOD detection experiment and evaluate its outcome.

    train_index: The "faiss" index for similarity search.
    test_zs: The normalized test set feature vectors that contain potentially OOD examples.
    test_actLevels: The activation level information that also contains the original and predicted class.
    k: The number of considered nearest neighbors.
    threshold: The OOD detection threshold.
    set_name: The OOD set name.
    display: Boolean indicating if we want to display the result.

    Returns the list produced by evaluation_with_ood_bools for this set.
    """
    detected_ood = k_nearst_neighbor_OOD_detection_pipeline(train_index, test_zs, k, threshold)
    return evaluation_with_ood_bools(test_actLevels, detected_ood, set_name, display=display)

"""
Functions for the modified KNN algorithm (using significant neurons, where the set of significant neurons used can differ from one example to another).

Note: In this by-class version, the KNN search compares each normalized test feature vector with the normalized feature vectors of all training set examples (not only those of one class).
"""
def knn_scores_sig_ver(train_indices, test_zs_by_class, test_class_index_dict, k):
    """
    This function obtains the k-nearest neighbor scores (by class and significant neuron version).
    Its logic mirrors an earlier by-class version whose per-class search index was built from that class's
    examples only; it is kept as a separate function to distinguish the two uses.

    train_indices: The "faiss" index of all classes for similarity search, keyed by class.
    test_zs_by_class: The normalized test set feature vectors that contain potentially OOD examples for each class.
    test_class_index_dict: The dictionary indicating the positions of entries from different classes.
    k: The number of considered nearest neighbors.

    Returns a float array with one score per test example (negated distance to the last returned neighbor),
    placed at the positions given by test_class_index_dict.
    """
    # Built-in sum (not np.sum) so that an empty dict yields the integer 0:
    # np.sum([]) is the float 0.0, which np.zeros rejects as a shape.
    nb_examples = sum(class_zs.shape[0] for class_zs in test_zs_by_class.values())
    S_total = np.zeros(nb_examples)
    for uniq_class in tqdm(list(test_class_index_dict.keys()), desc='Processed classes'):
        # D represents distances; the neighbor indices are not needed here
        class_D, _ = faiss_knn_search(train_indices[uniq_class], test_zs_by_class[uniq_class], k, display=False)
        # Scatter the class scores back to the original example positions
        S_total[test_class_index_dict[uniq_class]] = -class_D[:, -1]

    return S_total

def knn_OOD_detection_pipeline_sig_ver(train_indices, test_zs_by_class, test_class_index_dict, k, thresholds):
    """
    This function executes the complete k-nearest neighbor OOD detection pipeline (significant neuron version).

    train_indices: The "faiss" index of all classes for similarity search, keyed by class.
    test_zs_by_class: The normalized test set feature vectors that contain potentially OOD examples for each class.
    test_class_index_dict: The dictionary indicating the positions of entries from different classes.
    k: The number of considered nearest neighbors.
    thresholds: The OOD detection thresholds for different classes, indexed by class.

    Returns an integer array with 1 where an example's score is strictly below its class threshold (OOD), else 0.

    Note: This version is different than the previous "by-class" version because the search index of each class contains all examples.
    """
    # Reuse the scoring function instead of duplicating its loop, so the two code paths cannot diverge
    S_total = knn_scores_sig_ver(train_indices, test_zs_by_class, test_class_index_dict, k)

    # OOD detection with the per-class thresholds (booleans are stored as 0/1 in the int array)
    ood_result_total = np.zeros(S_total.shape[0], dtype=int)
    for uniq_class, class_positions in test_class_index_dict.items():
        ood_result_total[class_positions] = np.less(S_total[class_positions], thresholds[uniq_class])

    return ood_result_total

def experim_ood_detection_knn_sig_ver(train_indices, test_zs_by_class, test_class_index_dict,
                                       test_actLevels, k, train_thresholds, set_name, display=True):
    """
    Run one complete k-nearest neighbor OOD detection experiment (significant neuron version) and evaluate it.

    train_indices: The "faiss" index of all classes for similarity search, keyed by class.
    test_zs_by_class: The normalized test set feature vectors that contain potentially OOD examples for each class.
    test_class_index_dict: The dictionary indicating the positions of entries from different classes.
    test_actLevels: The activation level information that also contains the original and predicted class.
    k: The number of considered nearest neighbors.
    train_thresholds: The OOD detection thresholds for different classes, indexed by class.
    set_name: The OOD set name.
    display: Boolean indicating if we want to display the result.

    Returns the list produced by evaluation_with_ood_bools for this set.
    """
    detected_ood = knn_OOD_detection_pipeline_sig_ver(
        train_indices, test_zs_by_class, test_class_index_dict, k, train_thresholds
    )
    return evaluation_with_ood_bools(test_actLevels, detected_ood, set_name, display=display)

def build_correct_actLevels(actLevels):
    """
    This function builds the activation levels restricted to the correctly classified examples.

    actLevels: The original actLevels. It must hold 'class' and 'predict_class'; 'actLevel' is treated as a
               dict of per-layer arrays; every other entry is assumed to be indexable by the same per-example
               boolean mask. It is not modified.

    Every entry (and every layer of 'actLevel') is filtered with the mask (class == predict_class).
    Returns a new container of the same type as actLevels.
    """
    correct_example_bools = (actLevels['class'] == actLevels['predict_class']).reshape(-1)
    # Shallow copies keep the container types and key order. Boolean-mask indexing already returns fresh
    # arrays, so deep-copying all (potentially large) activations first would only double peak memory.
    # (Elements of object-dtype arrays, if any, are shared with the input rather than deep-copied.)
    correct_actLevels = copy.copy(actLevels)
    for info_key, info_value in actLevels.items():
        if info_key == 'actLevel':
            layer_actLevels = copy.copy(info_value)
            for layerId, layer_acts in info_value.items():
                layer_actLevels[layerId] = layer_acts[correct_example_bools]
            correct_actLevels[info_key] = layer_actLevels
        else:
            correct_actLevels[info_key] = info_value[correct_example_bools]

    return correct_actLevels

def build_sig_zs(actLevels, last_hidden_layerId, sig_neuron_indices):
    """
    This function builds the L2-normalized feature vectors restricted to the significant neurons of the last hidden layer.

    actLevels: The activation levels; the features are read from actLevels['actLevel'][last_hidden_layerId],
               one example per row and one neuron per column.
    last_hidden_layerId: The last hidden layer Id.
    sig_neuron_indices: The significant neuron indices for this layer (any order; they are sorted before use).

    Returns an array with one row per example and one column per significant neuron, in ascending index order.
    Rows whose significant activations are all zero are returned as zero vectors instead of NaN.
    """
    last_hidden_actLevels = actLevels['actLevel'][last_hidden_layerId]
    # Sort the indices so the column order of the result does not depend on the order they were given in
    sorted_sig_neuron_indices = sorted(sig_neuron_indices)
    # Build the zs
    sig_neurons = last_hidden_actLevels[:, sorted_sig_neuron_indices]
    sig_vec_norms = np.linalg.norm(sig_neurons, axis=1, keepdims=True)
    # An all-zero row would otherwise give 0/0 = NaN; dividing it by 1 keeps it a zero vector
    sig_vec_norms[sig_vec_norms == 0] = 1
    sig_zs = sig_neurons / sig_vec_norms

    return sig_zs

"""
New ideas: Adjust the knn scores according to the prediction probabilities.
"""
def k_nearst_neighbor_scores_inner_product(train_index, test_zs, k):
    """
    Score every test example by its similarity to the last (k-th) neighbor returned by the search.

    train_index: The "faiss" index for similarity search. (Should be "IndexFlatIP")
    test_zs: The normalized test set feature vectors that contain potentially OOD examples.
    k: The number of considered nearest neighbors.

    The last column of the returned values is used as-is (no negation): with an inner-product index on
    normalized vectors it is a cosine similarity, so a higher score means closer to the training data.
    """
    similarities, _neighbor_ids = faiss_knn_search(train_index, test_zs, k, display=True)
    return similarities[:, -1]

def k_nearst_neighbor_scores_inner_product_batch_ver(train_index, test_zs, k, batch_size=100):
    """
    Batched variant of the inner-product k-nearest neighbor scoring.

    train_index: The "faiss" index for similarity search. (Should be "IndexFlatIP")
    test_zs: The normalized test set feature vectors that contain potentially OOD examples.
    k: The number of considered nearest neighbors.
    batch_size: The number of examples for each batch when doing the search.

    Returns, without negation, the last column of the returned similarities for each test example.
    """
    similarities, _neighbor_ids = faiss_knn_search_batch_ver(
        train_index, test_zs, k, batch_size=batch_size, display=True
    )
    return similarities[:, -1]

# def k_nearst_neighbor_scores_prob_adjust(train_index, test_actLevels, test_zs, k, predict_class_ver=False):
#     """
#     This function obtains the k-nearst neighbor scores. (Not performant)

#     train_index: The "faiss" index for similarity search.
#     test_actLevels: The activation level information that contains also the original and predicted class.
#     test_zs: The normalized test set feature vectors that contain potentially OOD examples.
#     k: The number of considered nearst neighbors.
#     predict_class_ver: Use the probabilities from the predicted classes or the original classes.
#     """
#     # D represents distances, I represents index
#     D, I = faiss_knn_search(train_index, test_zs, k, display=True)
#     # Get the probabilities
#     P = test_actLevels['prob']
#     # Take the correct probabilities for the adjustment
#     A = None
#     if predict_class_ver:
#         A = P[np.arange(test_actLevels['prob'].shape[0]), test_actLevels['predict_class'].reshape(-1)]
#     else:
#         A = P[np.arange(test_actLevels['prob'].shape[0]), test_actLevels['class'].reshape(-1)]
#     # Evaluate the scores (S)
#     S = -(D[:,-1]*(1/A))

#     return S