{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "7e88ded8",
   "metadata": {},
   "source": [
    "### *Module Loading*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0edbfb6a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "from subprocess import PIPE, run\n",
    "from IPython.display import display as ip_display"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7ccbc82f",
   "metadata": {},
   "source": [
    "### *External Module Loading*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dcf4ad4e",
   "metadata": {},
   "outputs": [],
   "source": [
    "external_modules_path = '..\\\\nn_likelihood_modules'\n",
    "sys.path.append(external_modules_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe63d714",
   "metadata": {},
   "outputs": [],
   "source": [
    "from common_imports import *\n",
    "from common_use_functions import *\n",
    "from constant import *\n",
    "from experim_neural_network import *\n",
    "from experim_preparation import *\n",
    "from generate_activation_level import *\n",
    "from pytorch_model_predict import *\n",
    "from cifar_10_data_prep import *\n",
    "from sensitivity_analysis import *\n",
    "from deep_KNN import *\n",
    "from novelty_data_prep import *\n",
    "from activation_level_processing import *\n",
    "from CoNNGuide_sensitivity_indices import *\n",
    "from multscore_utils import *\n",
    "from densenet import *\n",
    "from CoNNGuide_OOD_datasets import *\n",
    "from OOD_score_utils import *\n",
    "from DICE_OOD_score import *\n",
    "from knn_search_GPU import *\n",
    "from sota_ood_scores import *\n",
    "from pytorch_training_preparation import *"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b8dc22ef",
   "metadata": {},
   "source": [
    "### *GPU verification*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "930c998b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get the GPU\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "nb_gpu = torch.cuda.device_count()\n",
    "if nb_gpu > 0:\n",
    "    print(torch.cuda.get_device_name(0))\n",
    "else:\n",
    "    print(\"CPU\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0701537f",
   "metadata": {},
   "source": [
    "### *Working directory*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "89c6568b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Current path\n",
    "current_path = os.path.abspath(os.getcwd())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9f5c8167",
   "metadata": {},
   "source": [
    "### *Load configurations and data*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8cd2037f",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "All the parameters in this part should be configured\n",
    "\"\"\"\n",
    "# Experience path\n",
    "experim_path = current_path\n",
    "\n",
    "# File extensions\n",
    "json_ext = '.json'\n",
    "np_ext = '.npy'\n",
    "csv_ext = '.csv'\n",
    "\n",
    "# Network model prefix\n",
    "model_name_prefix = 'cifar10'\n",
    "\n",
    "# Image max pixel value\n",
    "image_max_pix_val = 255\n",
    "\n",
    "# Tested sets name\n",
    "train_set_name = 'train'\n",
    "test_set_name = 'test'\n",
    "valid_set_name = 'valid'\n",
    "input_extension = 'X'\n",
    "label_extension = 'Y'\n",
    "\n",
    "# Save paths\n",
    "model_save_path = path_join(experim_path, 'experim_models_resnet_paper')\n",
    "\n",
    "\"\"\"\n",
    "The following parameters should be configured according to your experiments\n",
    "\"\"\"\n",
    "\n",
    "# Trained model name \n",
    "trained_net_name = 'cifar10_densenet_pretrained' # You can select any model from the \"experim_models_resnet\" folder\n",
    "\n",
    "# Network related params\n",
    "net_model_name = 'densenet'# The model name should be coherent with your chosen model\n",
    "\n",
    "# Dataset general informations\n",
    "data_set_infos = {\n",
    "    'nb_classes' : 10\n",
    "}\n",
    "\n",
    "# The maximum number of considered k-nearst neighbors\n",
    "k_max = 1000\n",
    "\n",
    "# The batch size of the knn search\n",
    "knn_batch_size = 50\n",
    "\n",
    "\"\"\"\n",
    "\"\"\"\n",
    "\n",
    "# The output folder\n",
    "output_path = path_join(experim_path, 'output_CoNNGuide_NNGuide++_CIFAR10')\n",
    "\n",
    "# Build the class list\n",
    "class_list = list(range(data_set_infos['nb_classes']))\n",
    "\n",
    "# Batch size for the dataloader creation\n",
    "torch_batch_size = 128"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ab83bc84",
   "metadata": {},
   "outputs": [],
   "source": [
     "# Create the output folder (output_path) where the results of this experiment go\n",
     "create_directory(output_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6a1ae1fb",
   "metadata": {},
   "source": [
    "### *Experiment preparation*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "55636cc8",
   "metadata": {},
   "outputs": [],
   "source": [
     "# Get the CIFAR-10 training and test datasets\n",
     "# (the helper's name suggests only normalization is applied, i.e. no augmentation -- TODO confirm)\n",
     "cifar10_train_dataset, cifar10_test_dataset = get_cifar10_dataset_with_only_normalization()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b9931900",
   "metadata": {},
   "source": [
    "### *Load the trained Network*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "132f3376",
   "metadata": {},
   "outputs": [],
   "source": [
     "# Load the trained network named trained_net_name from the model_save_path folder\n",
     "trained_net = load_model_by_net_name(model_save_path, trained_net_name)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "959a79b0",
   "metadata": {},
   "source": [
    "### *Move the model to GPU*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1a9c01d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Move to gpu\n",
    "trained_net.cuda()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d849492d",
   "metadata": {},
   "source": [
    "### *CIFAR-10 dataset preparation*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "493c5931",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataloader building\n",
    "train_loader = create_loader_from_torch_dataset(cifar10_train_dataset, batch_size=torch_batch_size, shuffle=False, num_workers=0)\n",
    "test_loader = create_loader_from_torch_dataset(cifar10_test_dataset, batch_size=torch_batch_size, shuffle=False, num_workers=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2ed01114",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert the training set to numpy array\n",
    "no_divide_into_batch_train_loader = create_loader_from_torch_dataset(cifar10_train_dataset, batch_size=len(cifar10_train_dataset), shuffle=False, num_workers=0)\n",
    "X_train = next(iter(no_divide_into_batch_train_loader))[0].numpy()\n",
    "y_train = next(iter(no_divide_into_batch_train_loader))[1].numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "abbab086",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert the test set to numpy array\n",
    "no_divide_into_batch_test_loader = create_loader_from_torch_dataset(cifar10_test_dataset, batch_size=len(cifar10_test_dataset), shuffle=False, num_workers=0)\n",
    "X_test = next(iter(no_divide_into_batch_test_loader))[0].numpy()\n",
    "y_test = next(iter(no_divide_into_batch_test_loader))[1].numpy()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b13de364",
   "metadata": {},
   "source": [
    "### *Evaluate the activation levels for the original training and test sets of CIFAR-10*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f0de364c",
   "metadata": {},
   "outputs": [],
   "source": [
     "# Get the training set activation levels, together with the predicted classes (with_predict_class=True)\n",
     "train_actLevels = obtain_activation_levels(trained_net,\n",
     "                                           train_loader, 'train', with_predict_class=True, loss_type='cross_entropy')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f924888a",
   "metadata": {},
   "outputs": [],
   "source": [
     "# Get the test set activation levels, together with the predicted classes (with_predict_class=True)\n",
     "test_actLevels = obtain_activation_levels(trained_net,\n",
     "                                           test_loader, 'test', with_predict_class=True, loss_type='cross_entropy')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aaed65f4",
   "metadata": {},
   "outputs": [],
   "source": [
     "# Evaluate the probabilities; probability_evaluation returns the updated activation-level record,\n",
     "# which replaces the previous one\n",
     "train_actLevels = probability_evaluation(train_actLevels)\n",
     "test_actLevels = probability_evaluation(test_actLevels)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "88a0ad10",
   "metadata": {},
   "source": [
    "### *Load the novelty OOD dataset and evaluate the activation levels*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b9400c9c",
   "metadata": {},
   "outputs": [],
   "source": [
     "# Get the SVHN OOD test dataset, with normalization enabled (normalize=True)\n",
     "svhn_test_dataset = get_CIFAR_ood_datasets(set_name='svhn', normalize=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "124e4580",
   "metadata": {},
   "outputs": [],
   "source": [
     "# SVHN dataloader; shuffle=False keeps the dataset order\n",
     "svhn_test_loader = create_loader_from_torch_dataset(svhn_test_dataset, batch_size=torch_batch_size, shuffle=False, num_workers=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "40cb32ac",
   "metadata": {},
   "outputs": [],
   "source": [
     "# Get the DTD OOD test dataset, with normalization enabled (normalize=True)\n",
     "dtd_test_dataset = get_CIFAR_ood_datasets(set_name='dtd', normalize=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "acd92743",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert the test set to numpy array\n",
    "no_divide_into_batch_dtd_test_loader = create_loader_from_torch_dataset(dtd_test_dataset, batch_size=len(dtd_test_dataset), shuffle=False, num_workers=2)\n",
    "X_test_dtd = next(iter(no_divide_into_batch_dtd_test_loader))[0].numpy()\n",
    "y_test_dtd = next(iter(no_divide_into_batch_dtd_test_loader))[1].numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee06996c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the dtd loaders\n",
    "dtd_test_loader = create_dataloader(X_test_dtd, np.random.randint(0, data_set_infos['nb_classes'], y_test_dtd.shape[0]), \n",
    "                                     batch_size=torch_batch_size, shuffle=False, type_conversion=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "21293730",
   "metadata": {},
   "outputs": [],
   "source": [
     "# Get the Places365 OOD test dataset, with normalization enabled (normalize=True)\n",
     "places_test_dataset = get_CIFAR_ood_datasets(set_name='places', normalize=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5fe04ceb",
   "metadata": {},
   "outputs": [],
   "source": [
     "# Build the Places365 loader directly from the torch dataset (unlike DTD, no random labels are substituted)\n",
     "places_test_loader = create_loader_from_torch_dataset(places_test_dataset, batch_size=torch_batch_size, shuffle=False, num_workers=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6865bee5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the dictionary that contains all the OOD dataset loaders\n",
    "novelty_loaders = {}\n",
    "novelty_loaders['svhn'] = svhn_test_loader\n",
    "novelty_loaders['dtd'] = dtd_test_loader\n",
    "novelty_loaders['places'] = places_test_loader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1d37841a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get the other ood loaders\n",
    "other_ood_types = ['LSUN', 'LSUN_resize', 'iSUN']\n",
    "for ood_type in other_ood_types:\n",
    "    current_ood_datasset = get_CIFAR_ood_datasets(set_name=ood_type, normalize=True)\n",
    "    novelty_loaders[ood_type] = create_loader_from_torch_dataset(current_ood_datasset, batch_size=torch_batch_size, shuffle=False, num_workers=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2d300857",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Iterate over the OOD datasets for generating the normalized feature vectors\n",
    "novelty_actLevels = {}\n",
    "for ood_type in novelty_loaders:\n",
    "    novelty_actLevels[ood_type] = obtain_activation_levels(trained_net,\n",
    "                                                           novelty_loaders[ood_type], ood_type + ' test',\n",
    "                                                           with_predict_class=True, loss_type='cross_entropy')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4f57e977",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate the probabilities\n",
    "for ood_type in novelty_actLevels:\n",
    "    novelty_actLevels[ood_type] = probability_evaluation(novelty_actLevels[ood_type])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fe065212",
   "metadata": {},
   "source": [
    "### *Identify the activation threshold*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "83c1523c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get the last hidden layer Id\n",
    "last_hidden_layerId = list(train_actLevels['actLevel'].keys())[-1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "71faeb67",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Determine the activation threshold for the ReAct (May use the correctly predicted examples for more precise evaluation)\n",
    "act_threshold = np.percentile(train_actLevels['actLevel'][last_hidden_layerId], 96)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "69177f60",
   "metadata": {},
   "source": [
    "### *Get the correctly predicted examples*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a848e3bf",
   "metadata": {},
   "outputs": [],
   "source": [
     "# Keep only the correctly predicted training examples; they form the reference set used from here on\n",
     "correct_train_actLevels = build_correct_actLevels(train_actLevels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ef4bf642",
   "metadata": {},
   "outputs": [],
   "source": [
     "# Free the memory held by the full training set activation levels (only the correct subset is used later);\n",
     "# the gc.collect() result is bound to _ so it is not displayed as cell output\n",
     "del train_actLevels\n",
     "_ = gc.collect()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d25b294b-b53f-48e5-9440-041c46fb561c",
   "metadata": {},
   "source": [
    "### *Get the parameters of the model*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "15ba00b4-ccd6-4341-9fb1-6671f95a8b21",
   "metadata": {},
   "outputs": [],
   "source": [
     "# Build the per-class correspondence dictionary from the class labels of the correctly predicted\n",
     "# training examples (labels flattened to 1-D first)\n",
     "class_index_dict = class_index_dict_build(correct_train_actLevels['class'].reshape(-1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "384d1462-0fc4-49ea-a82f-6a4f21656dab",
   "metadata": {},
   "outputs": [],
   "source": [
     "# Get the parameters of the final linear layer ('fc') as numpy arrays\n",
     "model_params = get_model_parameters(trained_net, to_numpy=True)\n",
     "final_linear_params = model_params['fc']\n",
     "# Number of input features of the final layer\n",
     "# (assumes the torch Linear layout: weight is (out_features, in_features) -- TODO confirm)\n",
     "nb_vars = final_linear_params['weight'].shape[1]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "317cdf0b",
   "metadata": {},
   "source": [
    "### *Evaluate the knn correction factors with ablation study*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "95ab4d35",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Register the evaluated knn correction factors\n",
    "# Build the training set normalized feature vectors (only the correctly predicted examples)\n",
    "correct_train_zs = normalize_feature_vecs_knn(correct_train_actLevels, last_hidden_layerId)\n",
    "# Build the test set normalized feature vectors\n",
    "test_zs = normalize_feature_vecs_knn(test_actLevels, last_hidden_layerId)\n",
    "# Get the normalized feature vectors of the novelty ood set\n",
    "novelty_zs = {}\n",
    "for ood_type in novelty_actLevels:\n",
    "    novelty_zs[ood_type] = normalize_feature_vecs_knn(novelty_actLevels[ood_type], last_hidden_layerId)\n",
    "# Calculate or load the knn scores of the correctly predicted training set examples set \n",
    "knn_k = None\n",
    "# Determine the correct k for the evaluation\n",
    "correct_train_knn_dists, _ = knn_search_IP_GPU(correct_train_zs, correct_train_zs,\n",
    "                                               batch_size=knn_batch_size, k=k_max,\n",
    "                                               display=False, half_precision=False)\n",
    "# Estimate the k for the correction factors and example confidence score     \n",
    "knn_k = estimate_dense_k(correct_train_knn_dists, verify_steps=3, variation_threshold=0.05, min_k=5, smooth_sigma=0)\n",
    "# Calculate or load the test set knn scores\n",
    "test_knn_dists, _ = knn_search_IP_GPU(correct_train_zs, test_zs,\n",
    "                                                     batch_size=knn_batch_size, k=knn_k, display=False,\n",
    "                                                     half_precision=False)\n",
    "test_knn_factors = np.mean(test_knn_dists[:, :knn_k], axis=1)\n",
    "# Calculate or load the knn scores for the novelty sets\n",
    "novelty_knn_factors = {}\n",
    "for ood_type in novelty_zs:\n",
    "    current_novelty_knn_dists, _ = knn_search_IP_GPU(correct_train_zs, novelty_zs[ood_type],\n",
    "                                                       batch_size=knn_batch_size, k=knn_k, display=False,\n",
    "                                                       half_precision=False)\n",
    "    novelty_knn_factors[ood_type] = np.mean(current_novelty_knn_dists[:, :knn_k], axis=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8c7e39c2",
   "metadata": {},
   "source": [
    "### *Evaluate the global importance of the neurons*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f0d4661d",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
     "# Evaluate the overall (global) importance of each last-hidden-layer neuron from the correctly\n",
     "# predicted training examples and the final linear layer parameters\n",
     "entropy_contrib_vec = unified_entropy_score_evaluation_GPU(correct_train_actLevels, final_linear_params, last_hidden_layerId, batch_size=1000, block_size=32)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "15e62f8e",
   "metadata": {},
   "source": [
    "### *Clip the activations*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ff6521e7",
   "metadata": {},
   "outputs": [],
   "source": [
     "# Clip the test set last-hidden-layer activations at the ReAct threshold (act_threshold)\n",
     "test_actLevels = clip_activations(test_actLevels, last_hidden_layerId, act_threshold)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5c5efd5f",
   "metadata": {},
   "outputs": [],
   "source": [
     "# Clip the correctly predicted training set activations at the same ReAct threshold\n",
     "correct_train_actLevels = clip_activations(correct_train_actLevels, last_hidden_layerId, act_threshold)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b2cc325f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Clip the activations\n",
    "for ood_type in novelty_actLevels:\n",
    "    novelty_actLevels[ood_type] = clip_activations(novelty_actLevels[ood_type], last_hidden_layerId, act_threshold)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "60467384",
   "metadata": {},
   "source": [
    "### *Evaluate the sensitivity indices*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e8b52bbd",
   "metadata": {},
   "outputs": [],
   "source": [
     "# Evaluate the Shapley contribution matrix on the clipped, correctly predicted training examples\n",
     "shapley_contrib_matrix = shapley_score_evaluation_GPU(correct_train_actLevels, final_linear_params, last_hidden_layerId, predict_class_ver=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "63f0a84b",
   "metadata": {},
   "outputs": [],
   "source": [
     "# Evaluate the probability-based Shapley contribution matrix (it drives the DICE pruning masks later on)\n",
     "prob_shapley_contrib_matrix = prob_shapley_score_evaluation_GPU(correct_train_actLevels, final_linear_params, last_hidden_layerId, predict_class_ver=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6f1916a4-5ef6-4a24-9308-82af3da180d8",
   "metadata": {},
   "source": [
    "### *Evaluate the LOG (logit guidance) correction factors*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "528228da",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The potential pruning rate list\n",
    "LOG_p_n_list = [0, 10, 20, 30, 40, 50, 60, 70, 80, 90]\n",
    "LOG_p_w_list = [0, 1, 5, 10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3af17f98-5c9e-4722-b33b-c72743c80cb5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get the correct activation levels for the logit evaluation\n",
    "LOG_correct_train_vecs = correct_train_actLevels['actLevel'][last_hidden_layerId]\n",
    "LOG_test_vecs = test_actLevels['actLevel'][last_hidden_layerId]\n",
    "LOG_novelty_vecs = {}\n",
    "for ood_type in novelty_actLevels:\n",
    "    LOG_novelty_vecs[ood_type] = novelty_actLevels[ood_type]['actLevel'][last_hidden_layerId]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5109ab7b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Register the evaluated LOG correction factors\n",
    "test_LOG_factors = {}\n",
    "novelty_LOG_factors = {}\n",
    "for test_LOG_p_n in LOG_p_n_list:\n",
    "    # Initialize the values\n",
    "    test_LOG_factors[test_LOG_p_n] = {}\n",
    "    novelty_LOG_factors[test_LOG_p_n] = {}\n",
    "    for test_LOG_p_w in tqdm(LOG_p_w_list, desc='Processed pruning rates'):\n",
    "        # Initialize the values\n",
    "        novelty_LOG_factors[test_LOG_p_n][test_LOG_p_w] = {}\n",
    "        # Get the weight pruning mask\n",
    "        LOG_feat_prun_mask = pruning_mask(entropy_contrib_vec, p=test_LOG_p_n)\n",
    "        LOG_important_neurons = [index for index, value in enumerate(LOG_feat_prun_mask) if value == 1]\n",
    "        LOG_weight_contrib = DICE_weight_contribution_matrix_sig_and_torch_ver(final_linear_params, LOG_correct_train_vecs,\n",
    "                                                                               LOG_important_neurons, batch_size=1000)\n",
    "        LOG_weight_prun_mask = build_weight_pruning_masks_sig_ver(final_linear_params['weight'].shape,\n",
    "                                                                LOG_weight_contrib, LOG_important_neurons, p=test_LOG_p_w)\n",
    "        # Get the logits\n",
    "        LOG_correct_train_logits = get_logits_torch_ver(LOG_correct_train_vecs, final_linear_params,\n",
    "                                                        LOG_feat_prun_mask, LOG_weight_prun_mask, batch_size=500, display=False)\n",
    "        LOG_test_logits = get_logits_torch_ver(LOG_test_vecs, final_linear_params,\n",
    "                                               LOG_feat_prun_mask, LOG_weight_prun_mask, batch_size=500, display=False)\n",
    "        LOG_novelty_logits = {}\n",
    "        for ood_type in LOG_novelty_vecs:\n",
    "            LOG_novelty_logits[ood_type] = get_logits_torch_ver(LOG_novelty_vecs[ood_type], final_linear_params,\n",
    "                                                                LOG_feat_prun_mask, LOG_weight_prun_mask, batch_size=500, display=False)\n",
    "        # Evaluate the normalized logits\n",
    "        LOG_correct_train_norm_logits = LOG_correct_train_logits / np.linalg.norm(LOG_correct_train_logits, axis=1, keepdims=True)\n",
    "        LOG_test_norm_logits = LOG_test_logits / np.linalg.norm(LOG_test_logits, axis=1, keepdims=True)\n",
    "        LOG_novelty_norm_logits = {}\n",
    "        for ood_type in LOG_novelty_logits:\n",
    "            # Compute the scores\n",
    "            LOG_novelty_norm_logits[ood_type] = LOG_novelty_logits[ood_type] / np.linalg.norm(LOG_novelty_logits[ood_type],\n",
    "                                                                                              axis=1, keepdims=True)\n",
    "        # Free the memory\n",
    "        del LOG_correct_train_logits, LOG_test_logits, LOG_novelty_logits\n",
    "        _ = gc.collect()\n",
    "        ### Evaluate the LOG correction factors\n",
    "        ## Determine the correct k for the evaluation\n",
    "        # Training set        \n",
    "        correct_train_LOG_dists,_ = knn_search_IP_GPU(LOG_correct_train_norm_logits, LOG_correct_train_norm_logits, batch_size=knn_batch_size,\n",
    "                                                       k=k_max, display=False, half_precision=False)\n",
    "        LOG_k = estimate_dense_k(correct_train_LOG_dists, verify_steps=3, variation_threshold=0.1, min_k=5, smooth_sigma=0)\n",
    "        # Test set\n",
    "        test_LOG_dists, _ = knn_search_IP_GPU(LOG_correct_train_norm_logits, LOG_test_norm_logits, batch_size=knn_batch_size,\n",
    "                                                       k=LOG_k, display=False, half_precision=False)\n",
    "        test_LOG_factors[test_LOG_p_n][test_LOG_p_w] = np.mean(test_LOG_dists[:, :LOG_k], axis=1)\n",
    "        # Novelty sets\n",
    "        for ood_type in LOG_novelty_norm_logits:\n",
    "            # Compute the scores\n",
    "            current_novelty_LOG_dists, _ = knn_search_IP_GPU(LOG_correct_train_norm_logits,  LOG_novelty_norm_logits[ood_type],\n",
    "                                                               batch_size=knn_batch_size,\n",
    "                                                               k=LOG_k, display=False, half_precision=False)\n",
    "            novelty_LOG_factors[test_LOG_p_n][test_LOG_p_w][ood_type] = np.mean(current_novelty_LOG_dists[:, :LOG_k], axis=1)\n",
    "        # Free the memory\n",
    "        del LOG_correct_train_norm_logits, LOG_test_norm_logits, LOG_novelty_norm_logits\n",
    "        _ = gc.collect()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "43a5f97c",
   "metadata": {},
   "source": [
    "### *Prepare the feature vectors with all neurons*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "64f7570c",
   "metadata": {},
   "outputs": [],
   "source": [
     "# Clipped last-hidden-layer feature vectors and flattened predicted classes of the correctly\n",
     "# predicted training examples\n",
     "correct_train_vecs = correct_train_actLevels['actLevel'][last_hidden_layerId]\n",
     "correct_train_preds = correct_train_actLevels['predict_class'].reshape(-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "44f3092a",
   "metadata": {},
   "outputs": [],
   "source": [
     "# Clipped last-hidden-layer feature vectors and flattened predicted classes of the test set\n",
     "test_vecs = test_actLevels['actLevel'][last_hidden_layerId]\n",
     "test_preds = test_actLevels['predict_class'].reshape(-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0dcd8e12",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get the feature vectors of the novelty ood set\n",
    "novelty_vecs = {}\n",
    "novelty_preds = {}\n",
    "for ood_type in novelty_actLevels:\n",
    "    novelty_vecs[ood_type] = novelty_actLevels[ood_type]['actLevel'][last_hidden_layerId]\n",
    "    novelty_preds[ood_type] = novelty_actLevels[ood_type]['predict_class'].reshape(-1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "106d4c17",
   "metadata": {},
   "source": [
    "### *Evaluate the quality for various combinations of pruning rates*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "69727f84",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The possible values of p_n and p_w\n",
    "p_n_list = [0, 10, 20, 30, 40, 50, 60, 70, 80, 90]\n",
    "p_w_list = [0, 10, 20, 30, 40, 50, 60, 70, 80, 90]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5887dfa8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluates the original energy score with feature and weight pruning\n",
    "test_DICE_scores = {}\n",
    "novelty_DICE_scores = {}\n",
    "for test_p_n in p_n_list:\n",
    "    test_DICE_scores[test_p_n] = {}\n",
    "    novelty_DICE_scores[test_p_n] = {}\n",
    "    for test_p_w in tqdm(p_w_list, desc='Processed weight pruning rates'):\n",
    "        # Build the weight contribution matrix only for the significant neurons\n",
    "        feature_pruning_masks, weight_pruning_masks = get_sensitivity_DICE_pruning_masks_torch_ver(final_linear_params,\n",
    "                                                                                                    correct_train_vecs,\n",
    "                                                                                                    class_index_dict,\n",
    "                                                                                                    prob_shapley_contrib_matrix,\n",
    "                                                                                                    p_n=test_p_n, p_w=test_p_w,\n",
    "                                                                                                    display=False)\n",
    "        # Get the test set logits\n",
    "        test_logits = get_logits_by_preds_torch_ver(test_vecs, test_preds,\n",
    "                                                    final_linear_params,\n",
    "                                                    feature_pruning_masks,\n",
    "                                                    weight_pruning_masks,\n",
    "                                                    display=False)\n",
    "        # Get the logits for the novelty ood sets\n",
    "        novelty_logits = {}\n",
    "        for ood_type in novelty_vecs:\n",
    "            novelty_logits[ood_type]  = get_logits_by_preds_torch_ver(novelty_vecs[ood_type], \n",
    "                                                                       novelty_preds[ood_type],\n",
    "                                                                       final_linear_params,\n",
    "                                                                       feature_pruning_masks,\n",
    "                                                                       weight_pruning_masks,\n",
    "                                                                       display=False)\n",
    "        # Free the memory for the masks\n",
    "        del feature_pruning_masks, weight_pruning_masks\n",
    "        _ = gc.collect()\n",
    "            \n",
    "        ## Evaluate the DICE scores\n",
    "        # Test set         \n",
    "        test_DICE_scores[test_p_n][test_p_w] = DICE_score_evaluation_logit_ver(test_logits)\n",
    "        # Novelty sets\n",
    "        novelty_DICE_scores[test_p_n][test_p_w] = {}\n",
    "        for ood_type in novelty_logits:\n",
    "            # Compute the scores\n",
    "            novelty_DICE_scores[test_p_n][test_p_w][ood_type] = DICE_score_evaluation_logit_ver(novelty_logits[ood_type])\n",
    "            \n",
    "        # Free the memory\n",
    "        del test_logits, novelty_logits\n",
    "        _ = gc.collect()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "74c3559f",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Grid-search the logit guidance (LOG) pruning rates: for every (LOG_p_n, LOG_p_w)\n",
    "# pair, evaluate the DICE scores combined with the feature and the logit guidance\n",
    "# factors, and keep the configuration whose best (p_n, p_w) pair has the lowest FPR.\n",
    "# The running best starts at +inf so that the first configuration is always\n",
    "# recorded; a start value of 1 leaves best_evaluation_results as None whenever\n",
    "# no configuration reaches an FPR strictly below 1.\n",
    "best_avg_FPR = float('inf')\n",
    "best_LOG_p_n = None\n",
    "best_LOG_p_w = None\n",
    "best_p_n = None\n",
    "best_p_w = None\n",
    "best_evaluation_results = None\n",
    "final_evaluation_results = {}\n",
    "# Take the feature guidance factors (shared by every LOG configuration)\n",
    "feat_test_knn_factors = test_knn_factors\n",
    "feat_novelty_knn_factors = novelty_knn_factors\n",
    "for LOG_p_n in tqdm(LOG_p_n_list, desc='Processed logit guidance pruning rates'):\n",
    "    # Initialize the values\n",
    "    final_evaluation_results[LOG_p_n] = {}\n",
    "    for LOG_p_w in LOG_p_w_list:\n",
    "        # Take the logit guidance factors of this LOG configuration\n",
    "        LOG_test_knn_factors = test_LOG_factors[LOG_p_n][LOG_p_w]\n",
    "        LOG_novelty_knn_factors = novelty_LOG_factors[LOG_p_n][LOG_p_w]\n",
    "        # Get the performance evaluation result over the (p_n, p_w) grid\n",
    "        current_best_p_n, current_best_p_w, current_evaluation_results = evaluate_best_performance_with_double_guidance_cifar(\n",
    "            test_DICE_scores,\n",
    "            novelty_DICE_scores,\n",
    "            feat_test_knn_factors,\n",
    "            LOG_test_knn_factors,\n",
    "            feat_novelty_knn_factors,\n",
    "            LOG_novelty_knn_factors,\n",
    "            p_n_list, p_w_list)\n",
    "        # Metrics of the best pair: indexed by the best p_n AND the best p_w\n",
    "        current_best_metrics = current_evaluation_results[current_best_p_n][current_best_p_w]\n",
    "        # Register the result\n",
    "        final_evaluation_results[LOG_p_n][LOG_p_w] = {\n",
    "            'Detail': current_evaluation_results,\n",
    "            'Best_P_N': current_best_p_n,\n",
    "            'Best_P_W': current_best_p_w,\n",
    "            'Best_FPR': current_best_metrics['FPR'],\n",
    "            'Best_AUROC': current_best_metrics['AUROC'],\n",
    "        }\n",
    "        # Update the best evaluation result (strict '<': the first minimum wins)\n",
    "        if current_best_metrics['FPR'] < best_avg_FPR:\n",
    "            best_avg_FPR = current_best_metrics['FPR']\n",
    "            best_LOG_p_n = LOG_p_n\n",
    "            best_LOG_p_w = LOG_p_w\n",
    "            best_p_n = current_best_p_n\n",
    "            best_p_w = current_best_p_w\n",
    "            best_evaluation_results = current_evaluation_results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "645644f5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the results of the best logit guidance configuration.\n",
    "# Rows are labelled with p_n_list and columns with p_w_list positionally;\n",
    "# NOTE(review): assumes the evaluation result is ordered like those lists.\n",
    "FPR_results = [[metrics['FPR'] for metrics in p_w_results.values()]\n",
    "               for p_w_results in best_evaluation_results.values()]\n",
    "AUROC_results = [[metrics['AUROC'] for metrics in p_w_results.values()]\n",
    "                 for p_w_results in best_evaluation_results.values()]\n",
    "FPR_df = pd.DataFrame(data=np.array(FPR_results), index=p_n_list, columns=p_w_list)\n",
    "AUROC_df = pd.DataFrame(data=np.array(AUROC_results), index=p_n_list, columns=p_w_list)\n",
    "for table_title, metric_table in (('FPR:', FPR_df), ('AUROC:', AUROC_df)):\n",
    "    print(table_title)\n",
    "    ip_display(metric_table)\n",
    "print('Best result:')\n",
    "best_metrics = best_evaluation_results[best_p_n][best_p_w]\n",
    "print('The average performance: FPR95:', best_metrics['FPR'],\n",
    "      \"AUROC:\", best_metrics['AUROC'])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7f79a4e0",
   "metadata": {},
   "source": [
    "### *Evaluate the quality of the strong baseline*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e651d0c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Strong baseline: for every (p_n, p_w) pruning configuration, build the pruning\n",
    "# masks, compute the pruned logits of the test and novelty (OOD) sets and score\n",
    "# them with DICE_score_evaluation_logit_ver.\n",
    "# Layout: test_baseline_scores[p_n][p_w], novelty_baseline_scores[p_n][p_w][ood_type]\n",
    "test_baseline_scores = {}\n",
    "novelty_baseline_scores = {}\n",
    "for test_p_n in p_n_list:\n",
    "    test_baseline_scores[test_p_n], novelty_baseline_scores[test_p_n] = {}, {}\n",
    "    for test_p_w in tqdm(p_w_list, desc='Processed weight pruning rates'):\n",
    "        # Pruning masks for this (p_n, p_w), computed from shapley_contrib_matrix among others\n",
    "        feature_pruning_masks, weight_pruning_masks = get_sensitivity_DICE_pruning_masks_torch_ver(\n",
    "            final_linear_params, correct_train_vecs, class_index_dict, shapley_contrib_matrix,\n",
    "            p_n=test_p_n, p_w=test_p_w, display=False)\n",
    "        # Pruned logits: in-distribution test set first, then every novelty (OOD) set\n",
    "        test_logits = get_logits_by_preds_torch_ver(\n",
    "            test_vecs, test_preds, final_linear_params,\n",
    "            feature_pruning_masks, weight_pruning_masks, display=False)\n",
    "        novelty_logits = {\n",
    "            ood_type: get_logits_by_preds_torch_ver(\n",
    "                novelty_vecs[ood_type], novelty_preds[ood_type], final_linear_params,\n",
    "                feature_pruning_masks, weight_pruning_masks, display=False)\n",
    "            for ood_type in novelty_vecs\n",
    "        }\n",
    "        # The masks are not needed once the logits exist: release them before scoring\n",
    "        del feature_pruning_masks, weight_pruning_masks\n",
    "        _ = gc.collect()\n",
    "\n",
    "        # Score the pruned logits\n",
    "        test_baseline_scores[test_p_n][test_p_w] = DICE_score_evaluation_logit_ver(test_logits)\n",
    "        novelty_baseline_scores[test_p_n][test_p_w] = {\n",
    "            ood_type: DICE_score_evaluation_logit_ver(novelty_logits[ood_type])\n",
    "            for ood_type in novelty_logits\n",
    "        }\n",
    "        # Release the logits before moving to the next configuration\n",
    "        del test_logits, novelty_logits\n",
    "        _ = gc.collect()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b3591224",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate the strong baseline with the feature guidance only (single guidance):\n",
    "# yields the best (p_n, p_w) pair and the per-configuration evaluation results.\n",
    "(baseline_best_p_n,\n",
    " baseline_best_p_w,\n",
    " baseline_evaluation_results) = evaluate_best_performance_with_single_guidance_cifar(\n",
    "    test_baseline_scores,       # test scores, indexed [p_n][p_w]\n",
    "    novelty_baseline_scores,    # novelty scores, indexed [p_n][p_w][ood_type]\n",
    "    feat_test_knn_factors,      # feature guidance factors of the test set\n",
    "    feat_novelty_knn_factors,   # feature guidance factors of the novelty sets\n",
    "    p_n_list, p_w_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "03bd3682",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the results of the strong baseline.\n",
    "# Rows are labelled with p_n_list and columns with p_w_list positionally;\n",
    "# NOTE(review): assumes the evaluation result is ordered like those lists.\n",
    "FPR_results = [[metrics['FPR'] for metrics in p_w_results.values()]\n",
    "               for p_w_results in baseline_evaluation_results.values()]\n",
    "AUROC_results = [[metrics['AUROC'] for metrics in p_w_results.values()]\n",
    "                 for p_w_results in baseline_evaluation_results.values()]\n",
    "FPR_df = pd.DataFrame(data=np.array(FPR_results), index=p_n_list, columns=p_w_list)\n",
    "AUROC_df = pd.DataFrame(data=np.array(AUROC_results), index=p_n_list, columns=p_w_list)\n",
    "for table_title, metric_table in (('FPR:', FPR_df), ('AUROC:', AUROC_df)):\n",
    "    print(table_title)\n",
    "    ip_display(metric_table)\n",
    "print('Best result:')\n",
    "baseline_best_metrics = baseline_evaluation_results[baseline_best_p_n][baseline_best_p_w]\n",
    "print('The average performance: FPR95:', baseline_best_metrics['FPR'],\n",
    "      \"AUROC:\", baseline_best_metrics['AUROC'])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dbae0613",
   "metadata": {},
   "source": [
    "### *Ablation study*"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cb529ea3",
   "metadata": {},
   "source": [
    "#### *With P-Shapley*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "83940980",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ablation 'With P-Shapley': DICE scores combined with the feature guidance only\n",
    "# (no logit guidance); yields the best (p_n, p_w) pair and the per-configuration results.\n",
    "(p_shap_best_p_n,\n",
    " p_shap_best_p_w,\n",
    " p_shap_evaluation_results) = evaluate_best_performance_with_single_guidance_cifar(\n",
    "    test_DICE_scores,           # test scores, indexed [p_n][p_w]\n",
    "    novelty_DICE_scores,        # novelty scores, indexed [p_n][p_w][ood_type]\n",
    "    feat_test_knn_factors,      # feature guidance factors of the test set\n",
    "    feat_novelty_knn_factors,   # feature guidance factors of the novelty sets\n",
    "    p_n_list, p_w_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7c3b96c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the results of the P-Shapley ablation.\n",
    "# Rows are labelled with p_n_list and columns with p_w_list positionally;\n",
    "# NOTE(review): assumes the evaluation result is ordered like those lists.\n",
    "FPR_results = [[metrics['FPR'] for metrics in p_w_results.values()]\n",
    "               for p_w_results in p_shap_evaluation_results.values()]\n",
    "AUROC_results = [[metrics['AUROC'] for metrics in p_w_results.values()]\n",
    "                 for p_w_results in p_shap_evaluation_results.values()]\n",
    "FPR_df = pd.DataFrame(data=np.array(FPR_results), index=p_n_list, columns=p_w_list)\n",
    "AUROC_df = pd.DataFrame(data=np.array(AUROC_results), index=p_n_list, columns=p_w_list)\n",
    "for table_title, metric_table in (('FPR:', FPR_df), ('AUROC:', AUROC_df)):\n",
    "    print(table_title)\n",
    "    ip_display(metric_table)\n",
    "print('Best result:')\n",
    "p_shap_best_metrics = p_shap_evaluation_results[p_shap_best_p_n][p_shap_best_p_w]\n",
    "print('The average performance: FPR95:', p_shap_best_metrics['FPR'],\n",
    "      \"AUROC:\", p_shap_best_metrics['AUROC'])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1c671978",
   "metadata": {},
   "source": [
    "#### *With logit guidance*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6ea86d27",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ablation 'With logit guidance': grid-search the LOG pruning rates on top of the\n",
    "# strong baseline scores and keep the configuration whose best (p_n, p_w) pair has\n",
    "# the lowest FPR.\n",
    "# The running best starts at +inf so that the first configuration is always\n",
    "# recorded; a start value of 1 leaves best_base_dg_evaluation_results as None\n",
    "# whenever no configuration reaches an FPR strictly below 1.\n",
    "best_base_dg_avg_FPR = float('inf')\n",
    "best_base_dg_LOG_p_n = None\n",
    "best_base_dg_LOG_p_w = None\n",
    "best_base_dg_p_n = None\n",
    "best_base_dg_p_w = None\n",
    "best_base_dg_evaluation_results = None\n",
    "final_base_dg_evaluation_results = {}\n",
    "# Take the feature guidance factors (shared by every LOG configuration)\n",
    "feat_test_knn_factors = test_knn_factors\n",
    "feat_novelty_knn_factors = novelty_knn_factors\n",
    "for LOG_p_n in tqdm(LOG_p_n_list, desc='Processed logit guidance pruning rates'):\n",
    "    # Initialize the values\n",
    "    final_base_dg_evaluation_results[LOG_p_n] = {}\n",
    "    for LOG_p_w in LOG_p_w_list:\n",
    "        # Take the logit guidance factors of this LOG configuration\n",
    "        LOG_test_knn_factors = test_LOG_factors[LOG_p_n][LOG_p_w]\n",
    "        LOG_novelty_knn_factors = novelty_LOG_factors[LOG_p_n][LOG_p_w]\n",
    "        # Get the performance evaluation result over the (p_n, p_w) grid\n",
    "        current_best_p_n, current_best_p_w, current_evaluation_results = evaluate_best_performance_with_double_guidance_cifar(\n",
    "            test_baseline_scores,\n",
    "            novelty_baseline_scores,\n",
    "            feat_test_knn_factors,\n",
    "            LOG_test_knn_factors,\n",
    "            feat_novelty_knn_factors,\n",
    "            LOG_novelty_knn_factors,\n",
    "            p_n_list, p_w_list)\n",
    "        # Metrics of the best pair: indexed by the best p_n AND the best p_w\n",
    "        current_best_metrics = current_evaluation_results[current_best_p_n][current_best_p_w]\n",
    "        # Register the result\n",
    "        final_base_dg_evaluation_results[LOG_p_n][LOG_p_w] = {\n",
    "            'Detail': current_evaluation_results,\n",
    "            'Best_P_N': current_best_p_n,\n",
    "            'Best_P_W': current_best_p_w,\n",
    "            'Best_FPR': current_best_metrics['FPR'],\n",
    "            'Best_AUROC': current_best_metrics['AUROC'],\n",
    "        }\n",
    "        # Update the best evaluation result (strict '<': the first minimum wins)\n",
    "        if current_best_metrics['FPR'] < best_base_dg_avg_FPR:\n",
    "            best_base_dg_avg_FPR = current_best_metrics['FPR']\n",
    "            best_base_dg_LOG_p_n = LOG_p_n\n",
    "            best_base_dg_LOG_p_w = LOG_p_w\n",
    "            best_base_dg_p_n = current_best_p_n\n",
    "            best_base_dg_p_w = current_best_p_w\n",
    "            best_base_dg_evaluation_results = current_evaluation_results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f6ac6a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the results of the baseline + logit guidance ablation.\n",
    "# Rows are labelled with p_n_list and columns with p_w_list positionally;\n",
    "# NOTE(review): assumes the evaluation result is ordered like those lists.\n",
    "FPR_results = [[metrics['FPR'] for metrics in p_w_results.values()]\n",
    "               for p_w_results in best_base_dg_evaluation_results.values()]\n",
    "AUROC_results = [[metrics['AUROC'] for metrics in p_w_results.values()]\n",
    "                 for p_w_results in best_base_dg_evaluation_results.values()]\n",
    "FPR_df = pd.DataFrame(data=np.array(FPR_results), index=p_n_list, columns=p_w_list)\n",
    "AUROC_df = pd.DataFrame(data=np.array(AUROC_results), index=p_n_list, columns=p_w_list)\n",
    "for table_title, metric_table in (('FPR:', FPR_df), ('AUROC:', AUROC_df)):\n",
    "    print(table_title)\n",
    "    ip_display(metric_table)\n",
    "print('Best result:')\n",
    "best_base_dg_metrics = best_base_dg_evaluation_results[best_base_dg_p_n][best_base_dg_p_w]\n",
    "print('The average performance: FPR95:', best_base_dg_metrics['FPR'],\n",
    "      \"AUROC:\", best_base_dg_metrics['AUROC'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "634031b6",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
