{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "7e88ded8",
   "metadata": {},
   "source": [
    "### *Module Loading*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0edbfb6a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "import faiss\n",
    "import pickle\n",
    "from subprocess import PIPE, run\n",
    "from IPython.display import display as ip_display"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7ccbc82f",
   "metadata": {},
   "source": [
    "### *External Module Loading*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dcf4ad4e",
   "metadata": {},
   "outputs": [],
   "source": [
    "external_modules_path = '..\\\\nn_likelihood_modules'\n",
    "sys.path.append(external_modules_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe63d714",
   "metadata": {},
   "outputs": [],
   "source": [
    "from common_imports import *\n",
    "from common_use_functions import *\n",
    "from constant import *\n",
    "from experim_neural_network import *\n",
    "from experim_preparation import *\n",
    "from generate_activation_level import *\n",
    "from pytorch_model_predict import *\n",
    "from ResNet_OOD import *\n",
    "from imagenet_data_prep import *\n",
    "from sensitivity_analysis import *\n",
    "from deep_KNN import *\n",
    "from novelty_data_prep import *\n",
    "from activation_level_processing import *\n",
    "from CoNNGuide_sensitivity_indices import *\n",
    "from neuron_importance_analysis import *\n",
    "from multscore_utils import *\n",
    "from CoNNGuide_OOD_datasets import *\n",
    "from OOD_score_utils import *\n",
    "from DICE_OOD_score import *\n",
    "from knn_search_GPU import *\n",
    "from RegNet import *\n",
    "from sota_ood_scores import *\n",
    "from pytorch_training_preparation import *"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b8dc22ef",
   "metadata": {},
   "source": [
    "### *GPU verification*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "930c998b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Select the compute device (GPU if available) and print which one is used\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "nb_gpu = torch.cuda.device_count()\n",
    "if nb_gpu > 0:\n",
    "    print(torch.cuda.get_device_name(0))\n",
    "else:\n",
    "    print(\"CPU\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0701537f",
   "metadata": {},
   "source": [
    "### *Working directory*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "89c6568b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Current path\n",
    "current_path = os.path.abspath(os.getcwd())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9f5c8167",
   "metadata": {},
   "source": [
    "### *Load configurations and data*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8cd2037f",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "All the parameters in this part should be configured\n",
    "\"\"\"\n",
    "# Experiment path\n",
    "experim_path = current_path\n",
    "\n",
    "# File extensions\n",
    "json_ext = '.json'\n",
    "np_ext = '.npy'\n",
    "pkl_ext = '.pkl'\n",
    "csv_ext = '.csv'\n",
    "\n",
    "# Network model prefix\n",
    "model_name_prefix = 'imagenet'\n",
    "\n",
    "# Image max pixel value\n",
    "image_max_pix_val = 255\n",
    "\n",
    "# Tested sets name\n",
    "train_set_name = 'train'\n",
    "test_set_name = 'test'\n",
    "valid_set_name = 'valid'\n",
    "input_extension = 'X'\n",
    "label_extension = 'Y'\n",
    "\n",
    "# Save paths\n",
    "model_save_path = path_join(experim_path, 'experim_models_imagenet')\n",
    "\n",
    "\"\"\"\n",
    "The following parameters should be configured according to your experiments\n",
    "\"\"\"\n",
    "\n",
    "# Trained model name \n",
    "trained_net_name = 'imagenet_regnet'\n",
    "\n",
    "# Network related params\n",
    "net_model_name = 'regnet'\n",
    "\n",
    "# Dataset general information\n",
    "data_set_infos = {\n",
    "    'nb_classes' : 1000\n",
    "}\n",
    "\n",
    "# The maximum number of k-nearest neighbors considered\n",
    "k_max = 1000\n",
    "\n",
    "# The path to the original imagenet data\n",
    "data_path = 'Your path here'\n",
    "\n",
    "# The path to the preprocessed feature vectors\n",
    "precompute_data_path = path_join(experim_path, 'precomputed_data_'+trained_net_name)\n",
    "\n",
    "# The number of features (recomputed later from the shape of the final linear layer weights)\n",
    "nb_vars = 3024\n",
    "\n",
    "# The batch size of the knn search\n",
    "knn_batch_size = 50\n",
    "\n",
    "\"\"\"\n",
    "\"\"\"\n",
    "# The output folder\n",
    "output_path = path_join(experim_path, 'output_CoNNGuide_NNGuide++_RegNet')\n",
    "\n",
    "# Build the class list\n",
    "class_list = list(range(data_set_infos['nb_classes']))\n",
    "\n",
    "# Batch size for the dataloader creation\n",
    "torch_batch_size = 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ab83bc84",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the folders\n",
    "create_directory(output_path)\n",
    "create_directory(precompute_data_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6a1ae1fb",
   "metadata": {},
   "source": [
    "### *Experiment preparation*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "55636cc8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get the dataset\n",
    "imagenet_train_dataset, imagenet_test_dataset = get_imagenet_dataset_without_transform_for_regnet(data_path, normalize=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b9931900",
   "metadata": {},
   "source": [
    "### *Load the trained Network*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "132f3376",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the network\n",
    "trained_net = load_model_by_net_name(model_save_path, trained_net_name)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "959a79b0",
   "metadata": {},
   "source": [
    "### *Move the model to GPU*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1a9c01d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Move to gpu\n",
    "trained_net.cuda()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b13de364",
   "metadata": {},
   "source": [
    "### *Evaluate the activation levels for the original training and test sets*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a1d3c3d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load (or compute and cache) the training set activation levels\n",
    "train_actLevels = None\n",
    "train_actLevels_file_path = path_join(precompute_data_path, 'train'+pkl_ext)\n",
    "if file_or_folder_existence(train_actLevels_file_path):\n",
    "    # Load the data     \n",
    "    with open(train_actLevels_file_path, 'rb') as f:\n",
    "        train_actLevels = pickle.load(f)\n",
    "    # Discard the unused probability data\n",
    "    train_actLevels.pop('prob', None)\n",
    "else:\n",
    "    # Dataloader building\n",
    "    train_loader = create_loader_from_torch_dataset(imagenet_train_dataset, batch_size=torch_batch_size, shuffle=False, num_workers=4)\n",
    "    # Get the training set activation levels \n",
    "    train_actLevels = obtain_activation_levels(trained_net,\n",
    "                                               train_loader, 'train', with_predict_class=True, loss_type='cross_entropy')\n",
    "    # Evaluate the probabilities\n",
    "    train_actLevels = probability_evaluation(train_actLevels)\n",
    "    # Save the activation levels\n",
    "    with open(train_actLevels_file_path, 'wb') as f:\n",
    "        pickle.dump(train_actLevels, f, protocol=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "80699289",
   "metadata": {},
   "outputs": [],
   "source": [
    "# # Take only the desired number of images for each class\n",
    "# # The desired number of examples per class\n",
    "# select_nb_examples = 700\n",
    "# # Find the class index dict\n",
    "# train_class_index_dict = class_index_dict_build(train_actLevels['class'].reshape(-1))\n",
    "# # Generate the random indices\n",
    "# selected_train_indices = []\n",
    "# for classId in train_class_index_dict:\n",
    "#     selected_train_indices.extend(random.sample(train_class_index_dict[classId], select_nb_examples))\n",
    "# selected_train_indices = sorted(selected_train_indices)\n",
    "# # Get the selected activation levels\n",
    "# train_actLevels = build_selected_actLevels(train_actLevels, selected_train_indices)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "780bbfe1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load (or compute and cache) the test set activation levels\n",
    "test_actLevels = None\n",
    "test_actLevels_file_path = path_join(precompute_data_path, 'test'+pkl_ext)\n",
    "if file_or_folder_existence(test_actLevels_file_path):\n",
    "    # Load the data\n",
    "    with open(test_actLevels_file_path, 'rb') as f:\n",
    "        test_actLevels = pickle.load(f)\n",
    "    # Discard the unused probability data\n",
    "    test_actLevels.pop('prob', None)\n",
    "else:\n",
    "    # Dataloader building\n",
    "    test_loader = create_loader_from_torch_dataset(imagenet_test_dataset, batch_size=torch_batch_size, shuffle=False, num_workers=4)\n",
    "    # Get the test set activation levels \n",
    "    test_actLevels = obtain_activation_levels(trained_net,\n",
    "                                               test_loader, 'test', with_predict_class=True, loss_type='cross_entropy')\n",
    "    # Evaluate the probabilities\n",
    "    test_actLevels = probability_evaluation(test_actLevels)\n",
    "    # Save the activation levels\n",
    "    with open(test_actLevels_file_path, 'wb') as f:\n",
    "        pickle.dump(test_actLevels, f, protocol=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa2a18f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reduce the memory use\n",
    "del imagenet_train_dataset, imagenet_test_dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8f92c1a1",
   "metadata": {},
   "source": [
    "### *Evaluate the activation levels for the novelty sets*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6bcf9c2d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The considered novelty types\n",
    "ood_types = ['dtd', 'inat', 'places', 'sun', 'openimage']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c546b0b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get the novelty datasets\n",
    "novelty_datasets = {}\n",
    "for ood_type in ood_types:\n",
    "    novelty_datasets[ood_type] = get_imagenet_ood_datasets_for_regnet(set_name=ood_type, normalize=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a994f3aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Iterate over the ood sets to generate the activation levels\n",
    "novelty_actLevels = {}\n",
    "for ood_type in ood_types:\n",
    "    current_novelty_actLevels_file_path = path_join(precompute_data_path, ood_type+pkl_ext)\n",
    "    if file_or_folder_existence(current_novelty_actLevels_file_path):\n",
    "        # Load the data\n",
    "        with open(current_novelty_actLevels_file_path, 'rb') as f:\n",
    "            novelty_actLevels[ood_type] = pickle.load(f)\n",
    "        # Discard the unused probability data\n",
    "        novelty_actLevels[ood_type].pop('prob', None)\n",
    "    else:\n",
    "        # Dataloader building\n",
    "        current_novelty_loader = create_loader_from_torch_dataset(novelty_datasets[ood_type], batch_size=torch_batch_size, shuffle=False, num_workers=4)\n",
    "        # Get the novelty set activation levels\n",
    "        current_novelty_actLevels = obtain_activation_levels(trained_net, current_novelty_loader, ood_type,\n",
    "                                                               with_predict_class=True, loss_type='cross_entropy')\n",
    "        # Evaluate the probabilities\n",
    "        current_novelty_actLevels = probability_evaluation(current_novelty_actLevels)\n",
    "        # Save the activation levels\n",
    "        with open(current_novelty_actLevels_file_path, 'wb') as f:\n",
    "            pickle.dump(current_novelty_actLevels, f, protocol=4)\n",
    "        # Assign the activation levels\n",
    "        novelty_actLevels[ood_type] = current_novelty_actLevels"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fe065212",
   "metadata": {},
   "source": [
    "### *Identify the activation threshold*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "83c1523c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get the last hidden layer Id\n",
    "last_hidden_layerId = list(train_actLevels['actLevel'].keys())[-1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "71faeb67",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Determine the ReAct activation threshold as the 85th percentile of the training set activations (the correctly predicted examples could be used for a more precise estimate)\n",
    "act_threshold = np.percentile(train_actLevels['actLevel'][last_hidden_layerId], 85)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7dc4f22e",
   "metadata": {},
   "source": [
    "### *Build the activation levels for the correctly predicted training set examples*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a848e3bf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get the correctly predicted training set examples\n",
    "correct_train_actLevels = build_correct_actLevels(train_actLevels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "32aa1326",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Free the memory for the original training set activation levels\n",
    "del train_actLevels\n",
    "_ = gc.collect()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ff7df19d",
   "metadata": {},
   "source": [
    "### *Get the parameters of the model*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7c03cac8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the correspondence dictionary for the examples of each class\n",
    "class_index_dict = class_index_dict_build(correct_train_actLevels['class'].reshape(-1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c6bb2b81",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get the last layer parameters\n",
    "model_params = get_model_parameters(trained_net, to_numpy=True)\n",
    "final_linear_params = model_params['fc']\n",
    "nb_vars = final_linear_params['weight'].shape[1]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ee8c36a6",
   "metadata": {},
   "source": [
    "### *Evaluate the knn correction factors with ablation study*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13097ca7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Normalize the feature vectors and estimate the k used for the knn correction factors\n",
    "# Build the training set normalized feature vectors (only the correctly predicted examples)\n",
    "correct_train_zs = normalize_feature_vecs_knn(correct_train_actLevels, last_hidden_layerId)\n",
    "# Build the test set normalized feature vectors\n",
    "test_zs = normalize_feature_vecs_knn(test_actLevels, last_hidden_layerId)\n",
    "# Get the normalized feature vectors of the novelty ood set\n",
    "novelty_zs = {}\n",
    "for ood_type in novelty_actLevels:\n",
    "    novelty_zs[ood_type] = normalize_feature_vecs_knn(novelty_actLevels[ood_type], last_hidden_layerId)\n",
    "# Run the knn search among the correctly predicted training set examples (always recomputed, not cached)\n",
    "knn_k = None\n",
    "# Determine the correct k for the evaluation\n",
    "correct_train_knn_dists, _ = knn_search_IP_GPU(correct_train_zs, correct_train_zs,\n",
    "                                               batch_size=knn_batch_size, k=k_max,\n",
    "                                               display=True, half_precision=True)\n",
    "# Estimate the k for the correction factors and example confidence score     \n",
    "knn_k = estimate_dense_k(correct_train_knn_dists, verify_steps=3, variation_threshold=0.05, min_k=5, smooth_sigma=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "df76c24c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the cached test set knn correction factors, or compute and cache them\n",
    "test_knn_factors = None\n",
    "test_knn_factors_file_path = path_join(precompute_data_path, 'test_knn_factors'+pkl_ext)\n",
    "if file_or_folder_existence(test_knn_factors_file_path):\n",
    "    # Load the data\n",
    "    with open(test_knn_factors_file_path, 'rb') as f:\n",
    "        test_knn_factors = pickle.load(f)\n",
    "else:\n",
    "    # Compute the test set knn distances to the correctly predicted training set examples\n",
    "    test_knn_dists, _ = knn_search_IP_GPU(correct_train_zs, test_zs,\n",
    "                                                         batch_size=knn_batch_size, k=knn_k, display=True,\n",
    "                                                         half_precision=True)\n",
    "    test_knn_factors = np.mean(test_knn_dists[:, :knn_k], axis=1)\n",
    "    # Save the data\n",
    "    with open(test_knn_factors_file_path, 'wb') as f:\n",
    "        pickle.dump(test_knn_factors, f, protocol=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c5307b8e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the cached novelty set knn correction factors, or compute and cache them\n",
    "novelty_knn_factors = None\n",
    "novelty_knn_factors_file_path = path_join(precompute_data_path, 'novelty_knn_factors'+pkl_ext)\n",
    "if file_or_folder_existence(novelty_knn_factors_file_path):\n",
    "    # Load the data\n",
    "    with open(novelty_knn_factors_file_path, 'rb') as f:\n",
    "        novelty_knn_factors = pickle.load(f)\n",
    "else:\n",
    "    # Compute the knn distances of each novelty set to the correctly predicted training set examples\n",
    "    novelty_knn_factors = {}\n",
    "    for ood_type in novelty_zs:\n",
    "        current_novelty_knn_dists, _ = knn_search_IP_GPU(correct_train_zs, novelty_zs[ood_type],\n",
    "                                                           batch_size=knn_batch_size, k=knn_k, display=True,\n",
    "                                                           half_precision=True)\n",
    "        novelty_knn_factors[ood_type] = np.mean(current_novelty_knn_dists[:, :knn_k], axis=1)\n",
    "    # Save the data\n",
    "    with open(novelty_knn_factors_file_path, 'wb') as f:\n",
    "        pickle.dump(novelty_knn_factors, f, protocol=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "11929b68",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Free the memory\n",
    "del correct_train_knn_dists\n",
    "del correct_train_zs, test_zs, novelty_zs\n",
    "_ = gc.collect()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "463e047f",
   "metadata": {},
   "source": [
    "### *Evaluate the global importance of the neurons*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "646e94e4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get the overall significance of neurons\n",
    "entropy_contrib_vec = None\n",
    "entropy_contrib_vec_file_path = path_join(precompute_data_path, 'entropy_contrib_vec'+pkl_ext)\n",
    "if file_or_folder_existence(entropy_contrib_vec_file_path):\n",
    "    # Load the data\n",
    "    with open(entropy_contrib_vec_file_path, 'rb') as f:\n",
    "        entropy_contrib_vec = pickle.load(f)\n",
    "else:\n",
    "    entropy_contrib_vec = unified_entropy_score_evaluation_GPU(correct_train_actLevels, final_linear_params,\n",
    "                                                                   last_hidden_layerId, batch_size=1000, block_size=32)\n",
    "    # Save the data\n",
    "    with open(entropy_contrib_vec_file_path, 'wb') as f:\n",
    "        pickle.dump(entropy_contrib_vec, f, protocol=4)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "63e73898",
   "metadata": {},
   "source": [
    "### *Clip the activations*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "69f5645b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Execute the activation clip\n",
    "correct_train_actLevels = clip_activations(correct_train_actLevels, last_hidden_layerId, act_threshold)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "01cb16f7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Execute the activation clip\n",
    "test_actLevels = clip_activations(test_actLevels, last_hidden_layerId, act_threshold)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4de15f25",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Clip the activations\n",
    "for ood_type in novelty_actLevels:\n",
    "    novelty_actLevels[ood_type] = clip_activations(novelty_actLevels[ood_type], last_hidden_layerId, act_threshold)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "60467384",
   "metadata": {},
   "source": [
    "### *Evaluate the sensitivity indices*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "86714bfc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate the shapley matrix\n",
    "shapley_contrib_matrix = None\n",
    "shapley_contrib_matrix_file_path = path_join(precompute_data_path, 'shapley_contrib_matrix'+pkl_ext)\n",
    "if file_or_folder_existence(shapley_contrib_matrix_file_path):\n",
    "    # Load the data\n",
    "    with open(shapley_contrib_matrix_file_path, 'rb') as f:\n",
    "        shapley_contrib_matrix = pickle.load(f)\n",
    "else:\n",
    "    shapley_contrib_matrix = shapley_score_evaluation_batch_GPU(correct_train_actLevels, final_linear_params, last_hidden_layerId, \n",
    "                                                                predict_class_ver=False,\n",
    "                                                                batch_size=1000, block_size=32)\n",
    "    # Save the data\n",
    "    with open(shapley_contrib_matrix_file_path, 'wb') as f:\n",
    "        pickle.dump(shapley_contrib_matrix, f, protocol=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0c180f47",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate the p-shapley matrix\n",
    "prob_shapley_contrib_matrix = None\n",
    "prob_shapley_contrib_matrix_file_path = path_join(precompute_data_path, 'prob_shapley_contrib_matrix'+pkl_ext)\n",
    "if file_or_folder_existence(prob_shapley_contrib_matrix_file_path):\n",
    "    # Load the data\n",
    "    with open(prob_shapley_contrib_matrix_file_path, 'rb') as f:\n",
    "        prob_shapley_contrib_matrix = pickle.load(f)\n",
    "else:\n",
    "    prob_shapley_contrib_matrix = prob_shapley_score_evaluation_GPU(correct_train_actLevels, final_linear_params,\n",
    "                                                          last_hidden_layerId, predict_class_ver=False, block_size=32)\n",
    "    # Save the data\n",
    "    with open(prob_shapley_contrib_matrix_file_path, 'wb') as f:\n",
    "        pickle.dump(prob_shapley_contrib_matrix, f, protocol=4)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "667b2467",
   "metadata": {},
   "source": [
    "### *Evaluate the normalized logits for the LOG factor evaluation*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2581ed10",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The potential pruning rate list\n",
    "LOG_p_n_list = [0, 1, 5, 10]\n",
    "LOG_p_w_list = [0, 1, 5, 10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d05ef111",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get the correct activation levels for the logit evaluation\n",
    "LOG_correct_train_vecs = correct_train_actLevels['actLevel'][last_hidden_layerId]\n",
    "LOG_test_vecs = test_actLevels['actLevel'][last_hidden_layerId]\n",
    "LOG_novelty_vecs = {}\n",
    "for ood_type in novelty_actLevels:\n",
    "    LOG_novelty_vecs[ood_type] = novelty_actLevels[ood_type]['actLevel'][last_hidden_layerId]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "097b8505",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Register the evaluated LOG correction factors\n",
    "test_LOG_factors = {}\n",
    "novelty_LOG_factors = {}\n",
    "test_LOG_factors_file_path = path_join(precompute_data_path, 'test_LOG_factors'+pkl_ext)\n",
    "novelty_LOG_factors_file_path = path_join(precompute_data_path, 'novelty_LOG_factors'+pkl_ext)\n",
    "if file_or_folder_existence(test_LOG_factors_file_path) and file_or_folder_existence(novelty_LOG_factors_file_path):\n",
    "    # Load the data\n",
    "    with open(test_LOG_factors_file_path, 'rb') as f:\n",
    "        test_LOG_factors = pickle.load(f)\n",
    "    with open(novelty_LOG_factors_file_path, 'rb') as f:\n",
    "        novelty_LOG_factors = pickle.load(f)\n",
    "else:\n",
    "    # Compute the factors     \n",
    "    for test_LOG_p_n in LOG_p_n_list:\n",
    "        # Initialize the values\n",
    "        test_LOG_factors[test_LOG_p_n] = {}\n",
    "        novelty_LOG_factors[test_LOG_p_n] = {}\n",
    "        for test_LOG_p_w in tqdm(LOG_p_w_list, desc='Processed pruning rates'):\n",
    "            # Initialize the values\n",
    "            novelty_LOG_factors[test_LOG_p_n][test_LOG_p_w] = {}\n",
    "            # Build the feature pruning mask, then the weight pruning mask over the kept neurons\n",
    "            LOG_feat_prun_mask = pruning_mask(entropy_contrib_vec, p=test_LOG_p_n)\n",
    "            LOG_important_neurons = [index for index, value in enumerate(LOG_feat_prun_mask) if value == 1]\n",
    "            LOG_weight_contrib = DICE_weight_contribution_matrix_sig_and_torch_ver(final_linear_params, LOG_correct_train_vecs,\n",
    "                                                                                   LOG_important_neurons, batch_size=1000)\n",
    "            LOG_weight_prun_mask = build_weight_pruning_masks_sig_ver(final_linear_params['weight'].shape,\n",
    "                                                                    LOG_weight_contrib, LOG_important_neurons, p=test_LOG_p_w)\n",
    "            # Get the logits\n",
    "            LOG_correct_train_logits = get_logits_torch_ver(LOG_correct_train_vecs, final_linear_params,\n",
    "                                                            LOG_feat_prun_mask, LOG_weight_prun_mask, batch_size=500, display=False)\n",
    "            LOG_test_logits = get_logits_torch_ver(LOG_test_vecs, final_linear_params,\n",
    "                                                   LOG_feat_prun_mask, LOG_weight_prun_mask, batch_size=500, display=False)\n",
    "            LOG_novelty_logits = {}\n",
    "            for ood_type in LOG_novelty_vecs:\n",
    "                LOG_novelty_logits[ood_type] = get_logits_torch_ver(LOG_novelty_vecs[ood_type], final_linear_params,\n",
    "                                                                    LOG_feat_prun_mask, LOG_weight_prun_mask, batch_size=500, display=False)\n",
    "            # Evaluate the normalized logits\n",
    "            LOG_correct_train_norm_logits = LOG_correct_train_logits / np.linalg.norm(LOG_correct_train_logits, axis=1, keepdims=True)\n",
    "            LOG_test_norm_logits = LOG_test_logits / np.linalg.norm(LOG_test_logits, axis=1, keepdims=True)\n",
    "            LOG_novelty_norm_logits = {}\n",
    "            for ood_type in LOG_novelty_logits:\n",
    "                # Compute the scores\n",
    "                LOG_novelty_norm_logits[ood_type] = LOG_novelty_logits[ood_type] / np.linalg.norm(LOG_novelty_logits[ood_type],\n",
    "                                                                                                  axis=1, keepdims=True)\n",
    "            # Free the memory\n",
    "            del LOG_correct_train_logits, LOG_test_logits, LOG_novelty_logits\n",
    "            del LOG_feat_prun_mask, LOG_weight_contrib, LOG_weight_prun_mask\n",
    "            torch.cuda.empty_cache()\n",
    "            _ = gc.collect()\n",
    "            ### Evaluate the LOG correction factors\n",
    "            ## Determine the correct k for the evaluation\n",
    "            # Training set        \n",
    "            correct_train_LOG_dists,_ = knn_search_IP_GPU(LOG_correct_train_norm_logits, LOG_correct_train_norm_logits, batch_size=knn_batch_size,\n",
    "                                                           k=k_max, display=False, half_precision=True)\n",
    "            LOG_k = estimate_dense_k(correct_train_LOG_dists, verify_steps=3, variation_threshold=0.1, min_k=5, smooth_sigma=0)\n",
    "            # Test set\n",
    "            test_LOG_dists, _ = knn_search_IP_GPU(LOG_correct_train_norm_logits, LOG_test_norm_logits, batch_size=knn_batch_size,\n",
    "                                                           k=LOG_k, display=False, half_precision=True)\n",
    "            test_LOG_factors[test_LOG_p_n][test_LOG_p_w] = np.mean(test_LOG_dists[:, :LOG_k], axis=1)\n",
    "            # Novelty sets\n",
    "            for ood_type in LOG_novelty_norm_logits:\n",
    "                # Compute the scores\n",
    "                current_novelty_LOG_dists, _ = knn_search_IP_GPU(LOG_correct_train_norm_logits,  LOG_novelty_norm_logits[ood_type],\n",
    "                                                                   batch_size=knn_batch_size,\n",
    "                                                                   k=LOG_k, display=False, half_precision=True)\n",
    "                novelty_LOG_factors[test_LOG_p_n][test_LOG_p_w][ood_type] = np.mean(current_novelty_LOG_dists[:, :LOG_k], axis=1)\n",
    "            # Free the memory\n",
    "            del LOG_correct_train_norm_logits, LOG_test_norm_logits, LOG_novelty_norm_logits\n",
    "            del correct_train_LOG_dists, test_LOG_dists, current_novelty_LOG_dists\n",
    "            _ = gc.collect()\n",
    "    # Save the data\n",
    "    with open(test_LOG_factors_file_path, 'wb') as f:\n",
    "        pickle.dump(test_LOG_factors, f, protocol=4)\n",
    "    with open(novelty_LOG_factors_file_path, 'wb') as f:\n",
    "        pickle.dump(novelty_LOG_factors, f, protocol=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4a9bad1e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Clear the GPU memory\n",
    "torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "43a5f97c",
   "metadata": {},
   "source": [
    "### *Prepare the feature vectors with all neurons*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "64f7570c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get the correctly predicted training set feature vectors\n",
    "correct_train_vecs = correct_train_actLevels['actLevel'][last_hidden_layerId]\n",
    "correct_train_preds = correct_train_actLevels['predict_class'].reshape(-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "44f3092a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test set: last-hidden-layer activations and predicted classes (flattened to 1-D)\n",
    "test_vecs, test_preds = (\n",
    "    test_actLevels['actLevel'][last_hidden_layerId],\n",
    "    test_actLevels['predict_class'].reshape(-1),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0dcd8e12",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Novelty (OOD) sets, keyed by OOD type: last-hidden-layer activations and\n",
    "# flattened predicted classes\n",
    "novelty_vecs, novelty_preds = {}, {}\n",
    "for ood_type, ood_act_levels in novelty_actLevels.items():\n",
    "    novelty_vecs[ood_type] = ood_act_levels['actLevel'][last_hidden_layerId]\n",
    "    novelty_preds[ood_type] = ood_act_levels['predict_class'].reshape(-1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "106d4c17",
   "metadata": {},
   "source": [
    "### *Evaluate the quality for various combinations of pruning rates*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "69727f84",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Candidate feature (p_n) and weight (p_w) pruning rates: 0, 10, ..., 90\n",
    "p_n_list = list(range(0, 100, 10))\n",
    "p_w_list = list(range(0, 100, 10))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "74c3559f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate the DICE scores (DICE_score_evaluation_logit_ver on the pruned logits) for every\n",
    "# (p_n, p_w) pair, building the feature and weight pruning masks from prob_shapley_contrib_matrix.\n",
    "# The scores are cached on disk; a cache is reused only if it covers the whole current grid.\n",
    "test_DICE_scores = {}\n",
    "novelty_DICE_scores = {}\n",
    "test_DICE_scores_file_path = path_join(precompute_data_path, 'test_DICE_scores'+pkl_ext)\n",
    "novelty_DICE_scores_file_path = path_join(precompute_data_path, 'novelty_DICE_scores'+pkl_ext)\n",
    "DICE_cache_is_complete = False\n",
    "if file_or_folder_existence(test_DICE_scores_file_path) and file_or_folder_existence(novelty_DICE_scores_file_path):\n",
    "    # Load the data (pickle can run arbitrary code on load: only reuse caches written by this notebook)\n",
    "    with open(test_DICE_scores_file_path, 'rb') as f:\n",
    "        test_DICE_scores = pickle.load(f)\n",
    "    with open(novelty_DICE_scores_file_path, 'rb') as f:\n",
    "        novelty_DICE_scores = pickle.load(f)\n",
    "    # A cache computed for another p_n_list / p_w_list would not match the grid used downstream\n",
    "    DICE_cache_is_complete = all(\n",
    "        test_p_n in test_DICE_scores and test_p_n in novelty_DICE_scores\n",
    "        and set(p_w_list) <= set(test_DICE_scores[test_p_n])\n",
    "        and set(p_w_list) <= set(novelty_DICE_scores[test_p_n])\n",
    "        for test_p_n in p_n_list)\n",
    "    if not DICE_cache_is_complete:\n",
    "        print('Cached DICE scores do not cover the current p_n_list / p_w_list; recomputing them.')\n",
    "        test_DICE_scores = {}\n",
    "        novelty_DICE_scores = {}\n",
    "if not DICE_cache_is_complete:\n",
    "    for test_p_n in p_n_list:\n",
    "        test_DICE_scores[test_p_n] = {}\n",
    "        novelty_DICE_scores[test_p_n] = {}\n",
    "        for test_p_w in tqdm(p_w_list, desc=f'Processed weight pruning rates (p_n={test_p_n})'):\n",
    "            # Build the feature and weight pruning masks for this (p_n, p_w) pair\n",
    "            feature_pruning_masks, weight_pruning_masks = get_sensitivity_DICE_pruning_masks_torch_ver(final_linear_params,\n",
    "                                                                                                        correct_train_vecs,\n",
    "                                                                                                        class_index_dict,\n",
    "                                                                                                        prob_shapley_contrib_matrix,\n",
    "                                                                                                        p_n=test_p_n, p_w=test_p_w,\n",
    "                                                                                                        display=False)\n",
    "            # Get the test set logits\n",
    "            test_logits = get_logits_by_preds_torch_ver(test_vecs, test_preds,\n",
    "                                                        final_linear_params,\n",
    "                                                        feature_pruning_masks,\n",
    "                                                        weight_pruning_masks,\n",
    "                                                        display=False)\n",
    "            # Get the logits for the novelty ood sets\n",
    "            novelty_logits = {}\n",
    "            for ood_type in novelty_vecs:\n",
    "                novelty_logits[ood_type] = get_logits_by_preds_torch_ver(novelty_vecs[ood_type],\n",
    "                                                                          novelty_preds[ood_type],\n",
    "                                                                          final_linear_params,\n",
    "                                                                          feature_pruning_masks,\n",
    "                                                                          weight_pruning_masks,\n",
    "                                                                          display=False)\n",
    "            # Free the memory for the masks\n",
    "            del feature_pruning_masks, weight_pruning_masks\n",
    "            _ = gc.collect()\n",
    "\n",
    "            ## Evaluate the DICE scores\n",
    "            # Test set\n",
    "            test_DICE_scores[test_p_n][test_p_w] = DICE_score_evaluation_logit_ver(test_logits)\n",
    "            # Novelty sets\n",
    "            novelty_DICE_scores[test_p_n][test_p_w] = {}\n",
    "            for ood_type in novelty_logits:\n",
    "                novelty_DICE_scores[test_p_n][test_p_w][ood_type] = DICE_score_evaluation_logit_ver(novelty_logits[ood_type])\n",
    "\n",
    "            # Free the memory\n",
    "            del test_logits, novelty_logits\n",
    "            _ = gc.collect()\n",
    "    # Save the data\n",
    "    with open(test_DICE_scores_file_path, 'wb') as f:\n",
    "        pickle.dump(test_DICE_scores, f, protocol=4)\n",
    "    with open(novelty_DICE_scores_file_path, 'wb') as f:\n",
    "        pickle.dump(novelty_DICE_scores, f, protocol=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "405dd261",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Search the logit-guidance pruning rates (LOG_p_n, LOG_p_w). For each pair, the double-guidance\n",
    "# evaluation combines the DICE scores with the feature-kNN and logit-kNN guidance factors and\n",
    "# returns its best (p_n, p_w); the overall winner is the pair with the lowest curated FPR.\n",
    "best_avg_FPR = float('inf')  # any comparable FPR beats this, so the first result is always recorded\n",
    "best_LOG_p_n = None\n",
    "best_LOG_p_w = None\n",
    "best_p_n = None\n",
    "best_p_w = None\n",
    "best_evaluation_results = None\n",
    "final_evaluation_results = {}\n",
    "# Take the guidance factors\n",
    "feat_test_knn_factors = test_knn_factors\n",
    "feat_novelty_knn_factors = novelty_knn_factors\n",
    "for LOG_p_n in tqdm(LOG_p_n_list, desc='Processed logit guidance pruning rates'):\n",
    "    # Initialize the values\n",
    "    final_evaluation_results[LOG_p_n] = {}\n",
    "    for LOG_p_w in LOG_p_w_list:\n",
    "        # Take the guidance factors\n",
    "        LOG_test_knn_factors = test_LOG_factors[LOG_p_n][LOG_p_w]\n",
    "        LOG_novelty_knn_factors = novelty_LOG_factors[LOG_p_n][LOG_p_w]\n",
    "        # Get the performance evaluation result\n",
    "        current_best_p_n, current_best_p_w, current_evaluation_results = evaluate_best_performance_with_double_guidance_imagenet(\n",
    "            test_DICE_scores, novelty_DICE_scores,\n",
    "            feat_test_knn_factors, LOG_test_knn_factors,\n",
    "            feat_novelty_knn_factors, LOG_novelty_knn_factors,\n",
    "            p_n_list, p_w_list)\n",
    "        current_best_metrics = current_evaluation_results[current_best_p_n][current_best_p_w]\n",
    "        # Register the result\n",
    "        final_evaluation_results[LOG_p_n][LOG_p_w] = {\n",
    "            'Detail': current_evaluation_results,\n",
    "            'Best_P_N': current_best_p_n,\n",
    "            'Best_P_W': current_best_p_w,\n",
    "            'Best_FPR': current_best_metrics['FPR'],\n",
    "            'Best_AUROC': current_best_metrics['AUROC'],\n",
    "            'Best_FPR_curated': current_best_metrics['FPR_curated'],\n",
    "            'Best_AUROC_curated': current_best_metrics['AUROC_curated'],\n",
    "        }\n",
    "        # Update the best evaluation result\n",
    "        if current_best_metrics['FPR_curated'] < best_avg_FPR:\n",
    "            best_avg_FPR = current_best_metrics['FPR_curated']\n",
    "            best_LOG_p_n = LOG_p_n\n",
    "            best_LOG_p_w = LOG_p_w\n",
    "            best_p_n = current_best_p_n\n",
    "            best_p_w = current_best_p_w\n",
    "            best_evaluation_results = current_evaluation_results\n",
    "# The display cells below index best_evaluation_results, so fail here with a clear message instead\n",
    "assert best_evaluation_results is not None, 'No logit-guidance setting produced a comparable curated FPR (empty grid or NaN FPRs?)'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4caa4fc2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the best double-guidance configuration (logit-guidance and DICE pruning rates)\n",
    "{'LOG_p_n': best_LOG_p_n, 'LOG_p_w': best_LOG_p_w, 'p_n': best_p_n, 'p_w': best_p_w}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cf3926d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the results: curated FPR95 / AUROC over the (p_n, p_w) grid of the best logit-guidance setting\n",
    "FPR_curated_results = [[cell_metrics['FPR_curated'] for cell_metrics in p_w_row.values()]\n",
    "                       for p_w_row in best_evaluation_results.values()]\n",
    "AUROC_curated_results = [[cell_metrics['AUROC_curated'] for cell_metrics in p_w_row.values()]\n",
    "                         for p_w_row in best_evaluation_results.values()]\n",
    "FPR_curated_df = pd.DataFrame(data=np.array(FPR_curated_results), index=p_n_list, columns=p_w_list)\n",
    "AUROC_curated_df = pd.DataFrame(data=np.array(AUROC_curated_results), index=p_n_list, columns=p_w_list)\n",
    "for table_title, table_df in (('FPR_curated:', FPR_curated_df), ('AUROC_curated:', AUROC_curated_df)):\n",
    "    print(table_title)\n",
    "    ip_display(table_df)\n",
    "print('Best result:')\n",
    "best_metrics = best_evaluation_results[best_p_n][best_p_w]\n",
    "print('The average performance: FPR95 (curated):', best_metrics['FPR_curated'],\n",
    "      \"AUROC (curated):\", best_metrics['AUROC_curated'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2dbe2514",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the results: FPR95 / AUROC over the (p_n, p_w) grid of the best logit-guidance setting\n",
    "FPR_results = [[cell_metrics['FPR'] for cell_metrics in p_w_row.values()]\n",
    "               for p_w_row in best_evaluation_results.values()]\n",
    "AUROC_results = [[cell_metrics['AUROC'] for cell_metrics in p_w_row.values()]\n",
    "                 for p_w_row in best_evaluation_results.values()]\n",
    "FPR_df = pd.DataFrame(data=np.array(FPR_results), index=p_n_list, columns=p_w_list)\n",
    "AUROC_df = pd.DataFrame(data=np.array(AUROC_results), index=p_n_list, columns=p_w_list)\n",
    "for table_title, table_df in (('FPR:', FPR_df), ('AUROC:', AUROC_df)):\n",
    "    print(table_title)\n",
    "    ip_display(table_df)\n",
    "print('Best result:')\n",
    "best_metrics = best_evaluation_results[best_p_n][best_p_w]\n",
    "print('The average performance: FPR95:', best_metrics['FPR'],\n",
    "      \"AUROC:\", best_metrics['AUROC'])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7e2436f1",
   "metadata": {},
   "source": [
    "### *Evaluate the quality of the baseline method*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c319a1db",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate the baseline scores for every (p_n, p_w) pair: the pruning masks are built from\n",
    "# shapley_contrib_matrix and the pruned logits are scored with DICE_score_evaluation_logit_ver.\n",
    "# The scores are cached on disk; a cache is reused only if it covers the whole current grid.\n",
    "test_baseline_scores = {}\n",
    "novelty_baseline_scores = {}\n",
    "test_baseline_scores_file_path = path_join(precompute_data_path, 'test_baseline_scores'+pkl_ext)\n",
    "novelty_baseline_scores_file_path = path_join(precompute_data_path, 'novelty_baseline_scores'+pkl_ext)\n",
    "baseline_cache_is_complete = False\n",
    "if file_or_folder_existence(test_baseline_scores_file_path) and file_or_folder_existence(novelty_baseline_scores_file_path):\n",
    "    # Load the data (pickle can run arbitrary code on load: only reuse caches written by this notebook)\n",
    "    with open(test_baseline_scores_file_path, 'rb') as f:\n",
    "        test_baseline_scores = pickle.load(f)\n",
    "    with open(novelty_baseline_scores_file_path, 'rb') as f:\n",
    "        novelty_baseline_scores = pickle.load(f)\n",
    "    # A cache computed for another p_n_list / p_w_list would not match the grid used downstream\n",
    "    baseline_cache_is_complete = all(\n",
    "        test_p_n in test_baseline_scores and test_p_n in novelty_baseline_scores\n",
    "        and set(p_w_list) <= set(test_baseline_scores[test_p_n])\n",
    "        and set(p_w_list) <= set(novelty_baseline_scores[test_p_n])\n",
    "        for test_p_n in p_n_list)\n",
    "    if not baseline_cache_is_complete:\n",
    "        print('Cached baseline scores do not cover the current p_n_list / p_w_list; recomputing them.')\n",
    "        test_baseline_scores = {}\n",
    "        novelty_baseline_scores = {}\n",
    "if not baseline_cache_is_complete:\n",
    "    for test_p_n in p_n_list:\n",
    "        test_baseline_scores[test_p_n] = {}\n",
    "        novelty_baseline_scores[test_p_n] = {}\n",
    "        for test_p_w in tqdm(p_w_list, desc=f'Processed weight pruning rates (p_n={test_p_n})'):\n",
    "            # Build the feature and weight pruning masks for this (p_n, p_w) pair\n",
    "            feature_pruning_masks, weight_pruning_masks = get_sensitivity_DICE_pruning_masks_torch_ver(final_linear_params,\n",
    "                                                                                                        correct_train_vecs,\n",
    "                                                                                                        class_index_dict,\n",
    "                                                                                                        shapley_contrib_matrix,\n",
    "                                                                                                        p_n=test_p_n, p_w=test_p_w,\n",
    "                                                                                                        display=False)\n",
    "            # Get the test set logits\n",
    "            test_logits = get_logits_by_preds_torch_ver(test_vecs, test_preds,\n",
    "                                                        final_linear_params,\n",
    "                                                        feature_pruning_masks,\n",
    "                                                        weight_pruning_masks,\n",
    "                                                        display=False)\n",
    "            # Get the logits for the novelty ood sets\n",
    "            novelty_logits = {}\n",
    "            for ood_type in novelty_vecs:\n",
    "                novelty_logits[ood_type] = get_logits_by_preds_torch_ver(novelty_vecs[ood_type],\n",
    "                                                                          novelty_preds[ood_type],\n",
    "                                                                          final_linear_params,\n",
    "                                                                          feature_pruning_masks,\n",
    "                                                                          weight_pruning_masks,\n",
    "                                                                          display=False)\n",
    "            # Free the memory for the masks\n",
    "            del feature_pruning_masks, weight_pruning_masks\n",
    "            _ = gc.collect()\n",
    "\n",
    "            ## Evaluate the DICE scores\n",
    "            # Test set\n",
    "            test_baseline_scores[test_p_n][test_p_w] = DICE_score_evaluation_logit_ver(test_logits)\n",
    "            # Novelty sets\n",
    "            novelty_baseline_scores[test_p_n][test_p_w] = {}\n",
    "            for ood_type in novelty_logits:\n",
    "                novelty_baseline_scores[test_p_n][test_p_w][ood_type] = DICE_score_evaluation_logit_ver(novelty_logits[ood_type])\n",
    "\n",
    "            # Free the memory\n",
    "            del test_logits, novelty_logits\n",
    "            _ = gc.collect()\n",
    "    # Save the data\n",
    "    with open(test_baseline_scores_file_path, 'wb') as f:\n",
    "        pickle.dump(test_baseline_scores, f, protocol=4)\n",
    "    with open(novelty_baseline_scores_file_path, 'wb') as f:\n",
    "        pickle.dump(novelty_baseline_scores, f, protocol=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "afb20464",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluates the baseline performance with the feature-kNN guidance only.\n",
    "# test_knn_factors / novelty_knn_factors are passed directly (feat_*_knn_factors are only aliases\n",
    "# of them), so this cell does not depend on the double-guidance search cell having run.\n",
    "baseline_best_p_n, baseline_best_p_w, baseline_evaluation_results = evaluate_best_performance_with_single_guidance_imagenet(\n",
    "    test_baseline_scores, novelty_baseline_scores,\n",
    "    test_knn_factors, novelty_knn_factors,\n",
    "    p_n_list, p_w_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "05f0aa17",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the baseline results: curated FPR95 / AUROC over the (p_n, p_w) grid\n",
    "FPR_curated_results = []\n",
    "AUROC_curated_results = []\n",
    "# Iterate over the baseline results themselves, so this cell depends neither on the\n",
    "# double-guidance search (best_evaluation_results) nor on its key order\n",
    "for test_p_n in baseline_evaluation_results:\n",
    "    FPR_curated_results.append([baseline_evaluation_results[test_p_n][test_p_w]['FPR_curated'] for test_p_w in baseline_evaluation_results[test_p_n]])\n",
    "    AUROC_curated_results.append([baseline_evaluation_results[test_p_n][test_p_w]['AUROC_curated'] for test_p_w in baseline_evaluation_results[test_p_n]])\n",
    "FPR_curated_df = pd.DataFrame(data=np.array(FPR_curated_results), index=p_n_list, columns=p_w_list)\n",
    "AUROC_curated_df = pd.DataFrame(data=np.array(AUROC_curated_results), index=p_n_list, columns=p_w_list)\n",
    "print('FPR_curated:')\n",
    "ip_display(FPR_curated_df)\n",
    "print('AUROC_curated:')\n",
    "ip_display(AUROC_curated_df)\n",
    "print('Best result:')\n",
    "print('The average performance: FPR95 (curated):', baseline_evaluation_results[baseline_best_p_n][baseline_best_p_w]['FPR_curated'],\n",
    "      \"AUROC (curated):\", baseline_evaluation_results[baseline_best_p_n][baseline_best_p_w]['AUROC_curated'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2b9c7127",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the baseline results: FPR95 / AUROC over the (p_n, p_w) grid\n",
    "FPR_results = []\n",
    "AUROC_results = []\n",
    "# Iterate over the baseline results themselves, so this cell depends neither on the\n",
    "# double-guidance search (best_evaluation_results) nor on its key order\n",
    "for test_p_n in baseline_evaluation_results:\n",
    "    FPR_results.append([baseline_evaluation_results[test_p_n][test_p_w]['FPR'] for test_p_w in baseline_evaluation_results[test_p_n]])\n",
    "    AUROC_results.append([baseline_evaluation_results[test_p_n][test_p_w]['AUROC'] for test_p_w in baseline_evaluation_results[test_p_n]])\n",
    "FPR_df = pd.DataFrame(data=np.array(FPR_results), index=p_n_list, columns=p_w_list)\n",
    "AUROC_df = pd.DataFrame(data=np.array(AUROC_results), index=p_n_list, columns=p_w_list)\n",
    "print('FPR:')\n",
    "ip_display(FPR_df)\n",
    "print('AUROC:')\n",
    "ip_display(AUROC_df)\n",
    "print('Best result:')\n",
    "print('The average performance: FPR95:', baseline_evaluation_results[baseline_best_p_n][baseline_best_p_w]['FPR'],\n",
    "      \"AUROC:\", baseline_evaluation_results[baseline_best_p_n][baseline_best_p_w]['AUROC'])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ccc50d1e",
   "metadata": {},
   "source": [
    "### *Ablation study*"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "db67832b",
   "metadata": {},
   "source": [
    "#### *With P-Shapley*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6848d99e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ablation: the P-Shapley DICE scores with the feature-kNN guidance only.\n",
    "# test_knn_factors / novelty_knn_factors are passed directly (feat_*_knn_factors are only aliases\n",
    "# of them), so this cell does not depend on the double-guidance search cell having run.\n",
    "p_shap_best_p_n, p_shap_best_p_w, p_shap_evaluation_results = evaluate_best_performance_with_single_guidance_imagenet(\n",
    "    test_DICE_scores, novelty_DICE_scores,\n",
    "    test_knn_factors, novelty_knn_factors,\n",
    "    p_n_list, p_w_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "892e73f2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the results: curated FPR95 / AUROC over the (p_n, p_w) grid of the P-Shapley ablation\n",
    "FPR_curated_results = [[cell_metrics['FPR_curated'] for cell_metrics in p_w_row.values()]\n",
    "                       for p_w_row in p_shap_evaluation_results.values()]\n",
    "AUROC_curated_results = [[cell_metrics['AUROC_curated'] for cell_metrics in p_w_row.values()]\n",
    "                         for p_w_row in p_shap_evaluation_results.values()]\n",
    "FPR_curated_df = pd.DataFrame(data=np.array(FPR_curated_results), index=p_n_list, columns=p_w_list)\n",
    "AUROC_curated_df = pd.DataFrame(data=np.array(AUROC_curated_results), index=p_n_list, columns=p_w_list)\n",
    "for table_title, table_df in (('FPR_curated:', FPR_curated_df), ('AUROC_curated:', AUROC_curated_df)):\n",
    "    print(table_title)\n",
    "    ip_display(table_df)\n",
    "print('Best result:')\n",
    "best_metrics = p_shap_evaluation_results[p_shap_best_p_n][p_shap_best_p_w]\n",
    "print('The average performance: FPR95 (curated):', best_metrics['FPR_curated'],\n",
    "      \"AUROC (curated):\", best_metrics['AUROC_curated'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d269cf67",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the results: FPR95 / AUROC over the (p_n, p_w) grid of the P-Shapley ablation\n",
    "FPR_results = [[cell_metrics['FPR'] for cell_metrics in p_w_row.values()]\n",
    "               for p_w_row in p_shap_evaluation_results.values()]\n",
    "AUROC_results = [[cell_metrics['AUROC'] for cell_metrics in p_w_row.values()]\n",
    "                 for p_w_row in p_shap_evaluation_results.values()]\n",
    "FPR_df = pd.DataFrame(data=np.array(FPR_results), index=p_n_list, columns=p_w_list)\n",
    "AUROC_df = pd.DataFrame(data=np.array(AUROC_results), index=p_n_list, columns=p_w_list)\n",
    "for table_title, table_df in (('FPR:', FPR_df), ('AUROC:', AUROC_df)):\n",
    "    print(table_title)\n",
    "    ip_display(table_df)\n",
    "print('Best result:')\n",
    "best_metrics = p_shap_evaluation_results[p_shap_best_p_n][p_shap_best_p_w]\n",
    "print('The average performance: FPR95:', best_metrics['FPR'],\n",
    "      \"AUROC:\", best_metrics['AUROC'])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ad83aedb",
   "metadata": {},
   "source": [
    "#### *With logit guidance*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a0d37e73",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ablation: the baseline scores with double (feature-kNN + logit-kNN) guidance. For each logit-guidance\n",
    "# pair (LOG_p_n, LOG_p_w) the evaluation returns its best (p_n, p_w); the overall winner is the pair\n",
    "# with the lowest curated FPR.\n",
    "best_base_dg_avg_FPR = float('inf')  # any comparable FPR beats this, so the first result is always recorded\n",
    "best_base_dg_LOG_p_n = None\n",
    "best_base_dg_LOG_p_w = None\n",
    "best_base_dg_p_n = None\n",
    "best_base_dg_p_w = None\n",
    "best_base_dg_evaluation_results = None\n",
    "final_base_dg_evaluation_results = {}\n",
    "# Take the guidance factors\n",
    "feat_test_knn_factors = test_knn_factors\n",
    "feat_novelty_knn_factors = novelty_knn_factors\n",
    "for LOG_p_n in tqdm(LOG_p_n_list, desc='Processed logit guidance pruning rates'):\n",
    "    # Initialize the values\n",
    "    final_base_dg_evaluation_results[LOG_p_n] = {}\n",
    "    for LOG_p_w in LOG_p_w_list:\n",
    "        # Take the guidance factors\n",
    "        LOG_test_knn_factors = test_LOG_factors[LOG_p_n][LOG_p_w]\n",
    "        LOG_novelty_knn_factors = novelty_LOG_factors[LOG_p_n][LOG_p_w]\n",
    "        # Get the performance evaluation result\n",
    "        current_best_p_n, current_best_p_w, current_evaluation_results = evaluate_best_performance_with_double_guidance_imagenet(\n",
    "            test_baseline_scores, novelty_baseline_scores,\n",
    "            feat_test_knn_factors, LOG_test_knn_factors,\n",
    "            feat_novelty_knn_factors, LOG_novelty_knn_factors,\n",
    "            p_n_list, p_w_list)\n",
    "        current_best_metrics = current_evaluation_results[current_best_p_n][current_best_p_w]\n",
    "        # Register the result\n",
    "        final_base_dg_evaluation_results[LOG_p_n][LOG_p_w] = {\n",
    "            'Detail': current_evaluation_results,\n",
    "            'Best_P_N': current_best_p_n,\n",
    "            'Best_P_W': current_best_p_w,\n",
    "            'Best_FPR': current_best_metrics['FPR'],\n",
    "            'Best_AUROC': current_best_metrics['AUROC'],\n",
    "            'Best_FPR_curated': current_best_metrics['FPR_curated'],\n",
    "            'Best_AUROC_curated': current_best_metrics['AUROC_curated'],\n",
    "        }\n",
    "        # Update the best evaluation result\n",
    "        if current_best_metrics['FPR_curated'] < best_base_dg_avg_FPR:\n",
    "            best_base_dg_avg_FPR = current_best_metrics['FPR_curated']\n",
    "            best_base_dg_LOG_p_n = LOG_p_n\n",
    "            best_base_dg_LOG_p_w = LOG_p_w\n",
    "            best_base_dg_p_n = current_best_p_n\n",
    "            best_base_dg_p_w = current_best_p_w\n",
    "            best_base_dg_evaluation_results = current_evaluation_results\n",
    "# The display cells below index best_base_dg_evaluation_results, so fail here with a clear message instead\n",
    "assert best_base_dg_evaluation_results is not None, 'No logit-guidance setting produced a comparable curated FPR (empty grid or NaN FPRs?)'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7b0e0eab",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the results: curated FPR95 / AUROC over the (p_n, p_w) grid of the best baseline double-guidance setting\n",
    "FPR_curated_results = [[cell_metrics['FPR_curated'] for cell_metrics in p_w_row.values()]\n",
    "                       for p_w_row in best_base_dg_evaluation_results.values()]\n",
    "AUROC_curated_results = [[cell_metrics['AUROC_curated'] for cell_metrics in p_w_row.values()]\n",
    "                         for p_w_row in best_base_dg_evaluation_results.values()]\n",
    "FPR_curated_df = pd.DataFrame(data=np.array(FPR_curated_results), index=p_n_list, columns=p_w_list)\n",
    "AUROC_curated_df = pd.DataFrame(data=np.array(AUROC_curated_results), index=p_n_list, columns=p_w_list)\n",
    "for table_title, table_df in (('FPR_curated:', FPR_curated_df), ('AUROC_curated:', AUROC_curated_df)):\n",
    "    print(table_title)\n",
    "    ip_display(table_df)\n",
    "print('Best result:')\n",
    "best_metrics = best_base_dg_evaluation_results[best_base_dg_p_n][best_base_dg_p_w]\n",
    "print('The average performance: FPR95 (curated):', best_metrics['FPR_curated'],\n",
    "      \"AUROC (curated):\", best_metrics['AUROC_curated'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "539ee703",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the results: FPR95 / AUROC over the (p_n, p_w) grid of the best baseline double-guidance setting\n",
    "FPR_results = [[cell_metrics['FPR'] for cell_metrics in p_w_row.values()]\n",
    "               for p_w_row in best_base_dg_evaluation_results.values()]\n",
    "AUROC_results = [[cell_metrics['AUROC'] for cell_metrics in p_w_row.values()]\n",
    "                 for p_w_row in best_base_dg_evaluation_results.values()]\n",
    "FPR_df = pd.DataFrame(data=np.array(FPR_results), index=p_n_list, columns=p_w_list)\n",
    "AUROC_df = pd.DataFrame(data=np.array(AUROC_results), index=p_n_list, columns=p_w_list)\n",
    "for table_title, table_df in (('FPR:', FPR_df), ('AUROC:', AUROC_df)):\n",
    "    print(table_title)\n",
    "    ip_display(table_df)\n",
    "print('Best result:')\n",
    "best_metrics = best_base_dg_evaluation_results[best_base_dg_p_n][best_base_dg_p_w]\n",
    "print('The average performance: FPR95:', best_metrics['FPR'],\n",
    "      \"AUROC:\", best_metrics['AUROC'])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
