{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "7e88ded8",
   "metadata": {},
   "source": [
    "### *Module Loading*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0edbfb6a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "import faiss\n",
    "from subprocess import PIPE, run\n",
    "from IPython.display import display as ip_display"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7ccbc82f",
   "metadata": {},
   "source": [
    "### *External Module Loading*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dcf4ad4e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Make the project's helper modules importable.\n",
    "# Forward slashes work on both Windows and POSIX (a backslash separator only works on Windows).\n",
    "external_modules_path = '../nn_likelihood_modules'\n",
    "# Guard so that re-running this cell does not append duplicate entries to sys.path\n",
    "if external_modules_path not in sys.path:\n",
    "    sys.path.append(external_modules_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe63d714",
   "metadata": {},
   "outputs": [],
   "source": [
    "from common_imports import *\n",
    "from common_use_functions import *\n",
    "from constant import *\n",
    "from experim_neural_network import *\n",
    "from experim_preparation import *\n",
    "from generate_activation_level import *\n",
    "from pytorch_model_predict import *\n",
    "from cifar_10_data_prep import *\n",
    "from sensitivity_analysis import *\n",
    "from deep_KNN import *\n",
    "from novelty_data_prep import *\n",
    "from activation_level_processing import *\n",
    "from CoNNGuide_sensitivity_indices import *\n",
    "from LINe_OOD_score import *\n",
    "from multscore_utils import *\n",
    "from densenet import *\n",
    "from CoNNGuide_OOD_datasets import *\n",
    "from OOD_score_utils import *\n",
    "from DICE_OOD_score import *\n",
    "from knn_search_GPU import *\n",
    "from sota_ood_scores import *\n",
    "from pytorch_training_preparation import *"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b8dc22ef",
   "metadata": {},
   "source": [
    "### *GPU verification*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "930c998b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get the GPU\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "nb_gpu = torch.cuda.device_count()\n",
    "if nb_gpu > 0:\n",
    "    print(torch.cuda.get_device_name(0))\n",
    "else:\n",
    "    print(\"CPU\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0701537f",
   "metadata": {},
   "source": [
    "### *Working directory*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "89c6568b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Current path\n",
    "current_path = os.path.abspath(os.getcwd())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9f5c8167",
   "metadata": {},
   "source": [
    "### *Load configurations and data*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8cd2037f",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "All the parameters in this part should be configured\n",
    "\"\"\"\n",
    "# Experiment path\n",
    "experim_path = current_path\n",
    "\n",
    "# File extensions\n",
    "json_ext = '.json'\n",
    "np_ext = '.npy'\n",
    "csv_ext = '.csv'\n",
    "\n",
    "# Network model prefix\n",
    "model_name_prefix = 'cifar10'\n",
    "\n",
    "# Image max pixel value\n",
    "image_max_pix_val = 255\n",
    "\n",
    "# Tested sets name\n",
    "train_set_name = 'train'\n",
    "test_set_name = 'test'\n",
    "valid_set_name = 'valid'\n",
    "input_extension = 'X'\n",
    "label_extension = 'Y'\n",
    "\n",
    "# Save paths\n",
    "# NOTE(review): the folder name mentions resnet while the loaded model is a densenet - confirm it is intended\n",
    "model_save_path = path_join(experim_path, 'experim_models_resnet_paper')\n",
    "\n",
    "\"\"\"\n",
    "The following parameters should be configured according to your experiments\n",
    "\"\"\"\n",
    "\n",
    "# Trained model name\n",
    "trained_net_name = 'cifar10_densenet_pretrained'\n",
    "\n",
    "# Network related params\n",
    "net_model_name = 'densenet'\n",
    "\n",
    "# Dataset general information\n",
    "data_set_infos = {\n",
    "    'nb_classes' : 10\n",
    "}\n",
    "\n",
    "# The maximum number of considered k-nearest neighbors\n",
    "k_max = 1000\n",
    "\n",
    "# The batch size of the knn search\n",
    "knn_batch_size = 50\n",
    "\n",
    "# The k value for the deep-KNN method\n",
    "k_cifar10 = 50\n",
    "\n",
    "\"\"\"\n",
    "Output location and data loading settings\n",
    "\"\"\"\n",
    "\n",
    "# The output folder\n",
    "output_path = path_join(experim_path, 'output_SOTA_CIFAR10')\n",
    "\n",
    "# Build the class list\n",
    "class_list = list(range(data_set_infos['nb_classes']))\n",
    "\n",
    "# Batch size for the dataloader creation\n",
    "torch_batch_size = 128"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ab83bc84",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the folders\n",
    "create_directory(output_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6a1ae1fb",
   "metadata": {},
   "source": [
    "### *Experiment preparation*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "55636cc8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get the dataset\n",
    "cifar10_train_dataset, cifar10_test_dataset = get_cifar10_dataset_with_only_normalization()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b9931900",
   "metadata": {},
   "source": [
    "### *Load the trained Network*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "132f3376",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the network\n",
    "trained_net = load_model_by_net_name(model_save_path, trained_net_name)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "959a79b0",
   "metadata": {},
   "source": [
    "### *Move the model to GPU*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1a9c01d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Move the model to the device selected in the GPU verification cell\n",
    "# (uses the CPU fallback instead of failing here when CUDA is unavailable).\n",
    "# NOTE(review): downstream helpers may still assume CUDA - TODO confirm before running on CPU.\n",
    "# Assigning the result also avoids printing the full model repr as cell output.\n",
    "trained_net = trained_net.to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d849492d",
   "metadata": {},
   "source": [
    "### *Cifar10 dataset preparation*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "493c5931",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataloader building\n",
    "train_loader = create_loader_from_torch_dataset(cifar10_train_dataset, batch_size=torch_batch_size, shuffle=False, num_workers=0)\n",
    "test_loader = create_loader_from_torch_dataset(cifar10_test_dataset, batch_size=torch_batch_size, shuffle=False, num_workers=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2ed01114",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert the training set to numpy arrays through a single full-size batch\n",
    "no_divide_into_batch_train_loader = create_loader_from_torch_dataset(cifar10_train_dataset, batch_size=len(cifar10_train_dataset), shuffle=False, num_workers=0)\n",
    "# Fetch the batch only once: each next(iter(...)) call re-runs the whole dataset pipeline\n",
    "train_full_batch = next(iter(no_divide_into_batch_train_loader))\n",
    "X_train = train_full_batch[0].numpy()\n",
    "y_train = train_full_batch[1].numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "abbab086",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert the test set to numpy arrays through a single full-size batch\n",
    "no_divide_into_batch_test_loader = create_loader_from_torch_dataset(cifar10_test_dataset, batch_size=len(cifar10_test_dataset), shuffle=False, num_workers=0)\n",
    "# Fetch the batch only once: each next(iter(...)) call re-runs the whole dataset pipeline\n",
    "test_full_batch = next(iter(no_divide_into_batch_test_loader))\n",
    "X_test = test_full_batch[0].numpy()\n",
    "y_test = test_full_batch[1].numpy()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b13de364",
   "metadata": {},
   "source": [
    "### *Evaluate the activation levels for the original training and test sets of CIFAR-10*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f0de364c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get the training set activation levels \n",
    "train_actLevels = obtain_activation_levels(trained_net,\n",
    "                                           train_loader, 'train', with_predict_class=True, loss_type='cross_entropy')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f924888a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get the test set activation levels \n",
    "test_actLevels = obtain_activation_levels(trained_net,\n",
    "                                           test_loader, 'test', with_predict_class=True, loss_type='cross_entropy')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aaed65f4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate the probabilities\n",
    "train_actLevels = probability_evaluation(train_actLevels)\n",
    "test_actLevels = probability_evaluation(test_actLevels)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "88a0ad10",
   "metadata": {},
   "source": [
    "### *Load the novelty OOD datasets and evaluate the activation levels*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b9400c9c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get the svhn dataset\n",
    "svhn_test_dataset = get_CIFAR_ood_datasets(set_name='svhn', normalize=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "124e4580",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataloader building\n",
    "svhn_test_loader = create_loader_from_torch_dataset(svhn_test_dataset, batch_size=torch_batch_size, shuffle=False, num_workers=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "40cb32ac",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get the dtd dataset\n",
    "dtd_test_dataset = get_CIFAR_ood_datasets(set_name='dtd', normalize=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "acd92743",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert the dtd test set to numpy arrays through a single full-size batch\n",
    "no_divide_into_batch_dtd_test_loader = create_loader_from_torch_dataset(dtd_test_dataset, batch_size=len(dtd_test_dataset), shuffle=False, num_workers=0)\n",
    "# Fetch the batch only once: each next(iter(...)) call re-runs the whole dataset pipeline\n",
    "dtd_full_batch = next(iter(no_divide_into_batch_dtd_test_loader))\n",
    "X_test_dtd = dtd_full_batch[0].numpy()\n",
    "y_test_dtd = dtd_full_batch[1].numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee06996c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the dtd loader with random placeholder labels (the labels are not important for OOD detection).\n",
    "# A fixed seed makes the placeholder labels, and therefore every re-run, reproducible.\n",
    "dtd_placeholder_labels = np.random.RandomState(0).randint(0, data_set_infos['nb_classes'], y_test_dtd.shape[0])\n",
    "dtd_test_loader = create_dataloader(X_test_dtd, dtd_placeholder_labels,\n",
    "                                     batch_size=torch_batch_size, shuffle=False, type_conversion=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "21293730",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get the places365 dataset\n",
    "places_test_dataset = get_CIFAR_ood_datasets(set_name='places', normalize=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5fe04ceb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the places365 loader directly from the torch dataset (its own labels are kept; they are not important here)\n",
    "places_test_loader = create_loader_from_torch_dataset(places_test_dataset, batch_size=torch_batch_size, shuffle=False, num_workers=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6865bee5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the dictionary that contains all the OOD dataset loaders\n",
    "novelty_loaders = {}\n",
    "novelty_loaders['svhn'] = svhn_test_loader\n",
    "novelty_loaders['dtd'] = dtd_test_loader\n",
    "novelty_loaders['places'] = places_test_loader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1d37841a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get the other ood loaders\n",
    "other_ood_types = ['LSUN', 'LSUN_resize', 'iSUN']\n",
    "for ood_type in other_ood_types:\n",
    "    current_ood_datasset = get_CIFAR_ood_datasets(set_name=ood_type, normalize=True)\n",
    "    novelty_loaders[ood_type] = create_loader_from_torch_dataset(current_ood_datasset, batch_size=torch_batch_size, shuffle=False, num_workers=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2d300857",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Iterate over the OOD datasets for generating the normalized feature vectors\n",
    "novelty_actLevels = {}\n",
    "for ood_type in novelty_loaders:\n",
    "    novelty_actLevels[ood_type] = obtain_activation_levels(trained_net,\n",
    "                                                           novelty_loaders[ood_type], ood_type + ' test',\n",
    "                                                           with_predict_class=True, loss_type='cross_entropy')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4f57e977",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate the probabilities\n",
    "for ood_type in novelty_actLevels:\n",
    "    novelty_actLevels[ood_type] = probability_evaluation(novelty_actLevels[ood_type])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fe065212",
   "metadata": {},
   "source": [
    "### *Identify the activation threshold*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "83c1523c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get the last hidden layer Id\n",
    "last_hidden_layerId = list(train_actLevels['actLevel'].keys())[-1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "71faeb67",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Determine the activation threshold for the ReAct (May use the correctly predicted examples for more precise evaluation)\n",
    "act_threshold = np.percentile(train_actLevels['actLevel'][last_hidden_layerId], 96)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "69177f60",
   "metadata": {},
   "source": [
    "### *Get the correctly predicted examples*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a848e3bf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get the correctly predicted training set activation levels\n",
    "correct_train_actLevels = build_correct_actLevels(train_actLevels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ef4bf642",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Free the memory for the original training set activation levels\n",
    "del train_actLevels\n",
    "_ = gc.collect()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3c7fbbf0",
   "metadata": {},
   "source": [
    "### *Get the model parameters*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e4e38d5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the correspondance dictionary for the examples of each class\n",
    "class_index_dict = class_index_dict_build(correct_train_actLevels['class'].reshape(-1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "891bb225",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get the last layer parameters\n",
    "model_params = get_model_parameters(trained_net, to_numpy=True)\n",
    "final_linear_params = model_params['fc']\n",
    "nb_vars = final_linear_params['weight'].shape[1]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4eed07b7",
   "metadata": {},
   "source": [
    "### *Evaluate the OOD scores using only logits and their performance*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "95984e80",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get the correct train logits\n",
    "correct_train_origin_logits = get_logits_without_masking_torch_ver(correct_train_actLevels['actLevel'][last_hidden_layerId],\n",
    "                                                            final_linear_params, display=False)\n",
    "# Get the test set logits\n",
    "test_origin_logits = get_logits_without_masking_torch_ver(test_actLevels['actLevel'][last_hidden_layerId],\n",
    "                                                   final_linear_params, display=False)\n",
    "# Get the logits for the novelty ood sets\n",
    "novelty_origin_logits = {}\n",
    "for ood_type in novelty_actLevels:\n",
    "    novelty_origin_logits[ood_type]  = get_logits_without_masking_torch_ver(novelty_actLevels[ood_type]['actLevel'][last_hidden_layerId],\n",
    "                                                                     final_linear_params, display=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a7334ff0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The scores using logits\n",
    "logit_methods = ['msp', 'maxlogit', 'kl', 'energy']\n",
    "test_logit_scores = {}\n",
    "novelty_logit_scores = {}\n",
    "logit_ood_performances = {}\n",
    "for method in logit_methods:\n",
    "    test_logit_scores[method] = get_score(test_origin_logits, method)\n",
    "    novelty_logit_scores[method] = {}\n",
    "    for ood_type in novelty_origin_logits:\n",
    "        novelty_logit_scores[method][ood_type] = get_score(novelty_origin_logits[ood_type], method)\n",
    "    logit_ood_performances[method] = ood_performance_evaluation(test_logit_scores[method], novelty_logit_scores[method], method, display=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "41d7249e",
   "metadata": {},
   "source": [
    "### *Evaluate the Mahalanobis scores using the last hidden layer features and their performance*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "50fe1412",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The mahalanobis model\n",
    "mahalanobis_model = get_mahalanobis_model(correct_train_actLevels['actLevel'][last_hidden_layerId],\n",
    "                                          correct_train_actLevels['class'].reshape(-1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3ec0eaf3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The mahalanobis scores\n",
    "test_mahalanobis_scores = get_mahalanobis_score(test_actLevels['actLevel'][last_hidden_layerId], mahalanobis_model)\n",
    "novelty_mahalanobis_scores = {}\n",
    "for ood_type in novelty_actLevels:\n",
    "    novelty_mahalanobis_scores[ood_type] = get_mahalanobis_score(novelty_actLevels[ood_type]['actLevel'][last_hidden_layerId],\n",
    "                                                                 mahalanobis_model)\n",
    "mahalanobis_ood_performance = ood_performance_evaluation(test_mahalanobis_scores, novelty_mahalanobis_scores, 'mahalanobis', display=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2b6721b2",
   "metadata": {},
   "source": [
    "### *The possible pruning rates*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f29378f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The possible values of p_n and p_w\n",
    "p_n_list = [0, 10, 20, 30, 40, 50, 60, 70, 80, 90]\n",
    "p_w_list = [0, 10, 20, 30, 40, 50, 60, 70, 80, 90]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5023f8ef",
   "metadata": {},
   "source": [
    "### *Evaluate the DICE scores without ReAct and its performance*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c3cb71cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate the contribution matrix for the DICE scores\n",
    "DICE_weight_contrib = DICE_weight_contribution_matrix(final_linear_params, correct_train_actLevels, last_hidden_layerId)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d6e3c3df",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate the DICE performance\n",
    "DICE_ood_performance = {}\n",
    "DICE_best_avg_FPR = 1\n",
    "DICE_best_p_w = None\n",
    "for p_w in tqdm(p_w_list, desc='Processed weight pruning rate'):\n",
    "    # Evaluate the weight pruning mask     \n",
    "    DICE_weight_pruning_mask = pruning_mask(DICE_weight_contrib, p=p_w)\n",
    "    # Get the test set logits\n",
    "    test_DICE_logits = get_logits_torch_ver(test_actLevels['actLevel'][last_hidden_layerId],\n",
    "                                            final_linear_params, np.ones(nb_vars, dtype=int), DICE_weight_pruning_mask,\n",
    "                                            display=False)\n",
    "    # Get the logits for the novelty ood sets\n",
    "    novelty_DICE_logits = {}\n",
    "    for ood_type in novelty_actLevels:\n",
    "        novelty_DICE_logits[ood_type]  = get_logits_torch_ver(novelty_actLevels[ood_type]['actLevel'][last_hidden_layerId],\n",
    "                                                              final_linear_params, np.ones(nb_vars, dtype=int), DICE_weight_pruning_mask,\n",
    "                                                              display=False)\n",
    "    # Compute the scores     \n",
    "    test_DICE_scores = DICE_score_evaluation_logit_ver(test_DICE_logits)\n",
    "    novelty_DICE_scores = {}\n",
    "    for ood_type in novelty_actLevels:\n",
    "        novelty_DICE_scores[ood_type] = DICE_score_evaluation_logit_ver(novelty_DICE_logits[ood_type])\n",
    "    DICE_ood_performance[p_w] = ood_performance_evaluation(test_DICE_scores, novelty_DICE_scores, 'DICE', display=False)\n",
    "    if DICE_ood_performance[p_w]['FPR'] < DICE_best_avg_FPR:\n",
    "        DICE_best_avg_FPR = DICE_ood_performance[p_w]['FPR']\n",
    "        DICE_best_p_w = p_w"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "787b0007",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the results\n",
    "DICE_FPR_results = [DICE_ood_performance[test_p_w]['FPR'] for test_p_w in DICE_ood_performance]\n",
    "DICE_AUROC_results = [DICE_ood_performance[test_p_w]['AUROC'] for test_p_w in DICE_ood_performance]\n",
    "DICE_FPR_df = pd.DataFrame(data=np.array(DICE_FPR_results).reshape(1,-1), columns=p_w_list)\n",
    "DICE_AUROC_df = pd.DataFrame(data=np.array(DICE_AUROC_results).reshape(1,-1), columns=p_w_list)\n",
    "print('FPR:')\n",
    "ip_display(DICE_FPR_df)\n",
    "print('AUROC:')\n",
    "ip_display(DICE_AUROC_df)\n",
    "print('Best result:')\n",
    "print('The average performance of DICE: FPR95:', DICE_ood_performance[DICE_best_p_w]['FPR'],\n",
    "      \"AUROC:\", DICE_ood_performance[DICE_best_p_w]['AUROC'])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "73896a63-f91b-4499-9ce0-4d46ea28559f",
   "metadata": {},
   "source": [
    "### *Deep-KNN*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "39ed4bed-2154-45c0-a3ef-d0b827c110dc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get the feature vectors\n",
    "KNN_correct_train_zs = normalize_feature_vecs_knn(correct_train_actLevels, last_hidden_layerId)\n",
    "KNN_test_zs = normalize_feature_vecs_knn(test_actLevels, last_hidden_layerId)\n",
    "KNN_novelty_zs = {}\n",
    "for ood_type in novelty_actLevels:\n",
    "    KNN_novelty_zs[ood_type] = normalize_feature_vecs_knn(novelty_actLevels[ood_type], last_hidden_layerId)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "952988d4-b3a1-40e9-a413-481cc672fabe",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the faiss train set index\n",
    "KNN_train_index = faiss.IndexFlatL2(KNN_correct_train_zs.shape[1])\n",
    "KNN_train_index.add(KNN_correct_train_zs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "780f288b-eeba-4540-a5b2-6b41c7f1525c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate the test set k-nearst neighbor scores\n",
    "KNN_test_S = k_nearst_neighbor_scores(KNN_train_index, KNN_test_zs, k_cifar10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d5c39d8b-c628-45fb-b7ad-3b33d0bcfa6c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate the k-nearest neighbor scores of each novelty (OOD) set against the training set index\n",
    "KNN_novelty_S = {}\n",
    "for ood_type in KNN_novelty_zs:\n",
    "    KNN_novelty_S[ood_type] = k_nearst_neighbor_scores(KNN_train_index, KNN_novelty_zs[ood_type], k_cifar10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "64eb7508-d2d4-4707-8310-39569dcd5f4a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate the performance\n",
    "KNN_ood_performance = ood_performance_evaluation(KNN_test_S, KNN_novelty_S, 'KNN', display=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "aeb4ffaa",
   "metadata": {},
   "source": [
    "### *Prepare the normalized feature vectors for NNguide*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dfeb4b01",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the training set normalized feature vectors (only the correctly predicted examples)\n",
    "correct_train_zs = normalize_feature_vecs_knn(correct_train_actLevels, last_hidden_layerId)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "386a96d6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the test set normalized feature vectors\n",
    "test_zs = normalize_feature_vecs_knn(test_actLevels, last_hidden_layerId)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6bd7312f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get the normalized feature vectors of the novelty ood set\n",
    "novelty_zs = {}\n",
    "for ood_type in novelty_actLevels:\n",
    "    novelty_zs[ood_type] = normalize_feature_vecs_knn(novelty_actLevels[ood_type], last_hidden_layerId)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ce0e7819",
   "metadata": {},
   "source": [
    "### *Clip the activations*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "891d55a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Clip the test set activations of the last hidden layer at the ReAct threshold (act_threshold)\n",
    "test_actLevels = clip_activations(test_actLevels, last_hidden_layerId, act_threshold)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0ddf27b1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Execute the activation clip\n",
    "correct_train_actLevels = clip_activations(correct_train_actLevels, last_hidden_layerId, act_threshold)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b8b8e169",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Clip the activations\n",
    "for ood_type in novelty_actLevels:\n",
    "    novelty_actLevels[ood_type] = clip_activations(novelty_actLevels[ood_type], last_hidden_layerId, act_threshold)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "43a5f97c",
   "metadata": {},
   "source": [
    "### *Prepare the clipped feature vectors with all neurons*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "64f7570c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get the correctly predicted training set feature vectors\n",
    "correct_train_vecs = correct_train_actLevels['actLevel'][last_hidden_layerId]\n",
    "correct_train_preds = correct_train_actLevels['predict_class'].reshape(-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "44f3092a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the test set feature vectors\n",
    "test_vecs = test_actLevels['actLevel'][last_hidden_layerId]\n",
    "test_preds = test_actLevels['predict_class'].reshape(-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0dcd8e12",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get the feature vectors of the novelty ood set\n",
    "novelty_vecs = {}\n",
    "novelty_preds = {}\n",
    "for ood_type in novelty_actLevels:\n",
    "    novelty_vecs[ood_type] = novelty_actLevels[ood_type]['actLevel'][last_hidden_layerId]\n",
    "    novelty_preds[ood_type] = novelty_actLevels[ood_type]['predict_class'].reshape(-1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "106d4c17",
   "metadata": {},
   "source": [
    "### *Evaluate various OOD scores and their performance*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cfae2177",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The logits after the pruning\n",
    "# Get the correct train logits\n",
    "correct_train_logits = get_logits_without_masking_torch_ver(correct_train_vecs, final_linear_params, display=False)\n",
    "# Get the test set logits\n",
    "test_logits = get_logits_without_masking_torch_ver(test_vecs, final_linear_params, display=False)\n",
    "# Get the logits for the novelty ood sets\n",
    "novelty_logits = {}\n",
    "for ood_type in novelty_vecs:\n",
    "    novelty_logits[ood_type]  = get_logits_without_masking_torch_ver(novelty_vecs[ood_type], final_linear_params, display=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cd26679b",
   "metadata": {},
   "source": [
    "#### *Energy score with ReAct clipping (the ReAct method)*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d43e2ca7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The scores using logits\n",
    "test_ReAct_scores = get_score(test_logits, 'energy')\n",
    "novelty_ReAct_scores = {}\n",
    "for ood_type in novelty_logits:\n",
    "    novelty_ReAct_scores[ood_type] = get_score(novelty_logits[ood_type], 'energy')\n",
    "ReAct_ood_performances = ood_performance_evaluation(test_ReAct_scores, novelty_ReAct_scores, 'ReAct', display=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "16062ece",
   "metadata": {},
   "source": [
    "#### *DICE with ReAct*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8b7e307d",
   "metadata": {},
   "outputs": [],
   "source": [
    "## The DICE_ReAct scores\n",
    "# Evaluate the contribution matrix for the DICE scores (on the clipped training activations)\n",
    "DICE_ReAct_weight_contrib = DICE_weight_contribution_matrix(final_linear_params, correct_train_actLevels, last_hidden_layerId)\n",
    "# Evaluate the DICE performance for every weight pruning rate and keep the one with the lowest average FPR\n",
    "DICE_ReAct_ood_performance = {}\n",
    "# Start from +inf so the first evaluated rate is always accepted, whatever the scale of the FPR values\n",
    "# (a sentinel of 1 leaves DICE_ReAct_best_p_w as None, and breaks the display cell, when every FPR is >= 1)\n",
    "DICE_ReAct_best_avg_FPR = float('inf')\n",
    "DICE_ReAct_best_p_w = None\n",
    "for p_w in tqdm(p_w_list, desc='Processed weight pruning rate'):\n",
    "    # Evaluate the weight pruning mask\n",
    "    DICE_ReAct_weight_pruning_mask = pruning_mask(DICE_ReAct_weight_contrib, p=p_w)\n",
    "    # Get the test set logits\n",
    "    test_DICE_ReAct_logits = get_logits_torch_ver(test_vecs, final_linear_params,\n",
    "                                                  np.ones(nb_vars, dtype=int), DICE_ReAct_weight_pruning_mask,\n",
    "                                                  display=False)\n",
    "    # Get the logits for the novelty ood sets\n",
    "    novelty_DICE_ReAct_logits = {}\n",
    "    for ood_type in novelty_vecs:\n",
    "        novelty_DICE_ReAct_logits[ood_type] = get_logits_torch_ver(novelty_vecs[ood_type], final_linear_params,\n",
    "                                                                   np.ones(nb_vars, dtype=int), DICE_ReAct_weight_pruning_mask,\n",
    "                                                                   display=False)\n",
    "    # Compute the scores\n",
    "    test_DICE_ReAct_scores = DICE_score_evaluation_logit_ver(test_DICE_ReAct_logits)\n",
    "    novelty_DICE_ReAct_scores = {}\n",
    "    for ood_type in novelty_vecs:\n",
    "        novelty_DICE_ReAct_scores[ood_type] = DICE_score_evaluation_logit_ver(novelty_DICE_ReAct_logits[ood_type])\n",
    "    DICE_ReAct_ood_performance[p_w] = ood_performance_evaluation(test_DICE_ReAct_scores, novelty_DICE_ReAct_scores,\n",
    "                                                                 'DICE_ReAct', display=False)\n",
    "    # Track the pruning rate with the lowest average FPR\n",
    "    if DICE_ReAct_ood_performance[p_w]['FPR'] < DICE_ReAct_best_avg_FPR:\n",
    "        DICE_ReAct_best_avg_FPR = DICE_ReAct_ood_performance[p_w]['FPR']\n",
    "        DICE_ReAct_best_p_w = p_w"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "42ddc03f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the results: a single row with one column per weight pruning rate p_w\n",
    "DICE_ReAct_FPR_results = [perf['FPR'] for perf in DICE_ReAct_ood_performance.values()]\n",
    "DICE_ReAct_AUROC_results = [perf['AUROC'] for perf in DICE_ReAct_ood_performance.values()]\n",
    "DICE_ReAct_FPR_df = pd.DataFrame(data=np.array(DICE_ReAct_FPR_results).reshape(1, -1), columns=p_w_list)\n",
    "DICE_ReAct_AUROC_df = pd.DataFrame(data=np.array(DICE_ReAct_AUROC_results).reshape(1, -1), columns=p_w_list)\n",
    "print('FPR:')\n",
    "ip_display(DICE_ReAct_FPR_df)\n",
    "print('AUROC:')\n",
    "ip_display(DICE_ReAct_AUROC_df)\n",
    "print('Best result:')\n",
    "# The rate that achieved the lowest average FPR95 during the sweep\n",
    "DICE_ReAct_best_result = DICE_ReAct_ood_performance[DICE_ReAct_best_p_w]\n",
    "print('The average performance of DICE_ReAct: FPR95:', DICE_ReAct_best_result['FPR'],\n",
    "      'AUROC:', DICE_ReAct_best_result['AUROC'])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "db2e390a",
   "metadata": {},
   "source": [
    "#### *NNGuide*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eb11bb70",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the NNGuide model from the correct_train_* logits, feature vectors (zs) and class labels;\n",
    "# the labels are flattened to 1-D before being passed in\n",
    "nnguide_model = get_nnguide_model(\n",
    "    correct_train_logits,\n",
    "    correct_train_zs,\n",
    "    correct_train_actLevels['class'].reshape(-1),\n",
    "    k_max,\n",
    "    knn_batch_size=50,\n",
    "    half_precision=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b68ef34d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The NNGuide scores for the test set and for every novelty (OOD) set\n",
    "test_nnguide_scores = get_nnguide_score(test_zs, test_logits, nnguide_model)\n",
    "novelty_nnguide_scores = {\n",
    "    ood_type: get_nnguide_score(novelty_zs[ood_type], ood_logits, nnguide_model)\n",
    "    for ood_type, ood_logits in novelty_logits.items()\n",
    "}\n",
    "# Evaluate (and display) the OOD detection performance\n",
    "nnguide_ood_performance = ood_performance_evaluation(test_nnguide_scores, novelty_nnguide_scores, 'nnguide', display=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c59d6a86-bd4c-4ea5-9d0d-7362669d8102",
   "metadata": {},
   "source": [
    "#### *Deep-KNN with ReAct*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "333a25d3-ae89-4906-8466-fa4f91a51594",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Last-hidden-layer feature vectors, normalised by normalize_feature_vecs_knn, for every split\n",
    "KNN_ReAct_correct_train_zs = normalize_feature_vecs_knn(correct_train_actLevels, last_hidden_layerId)\n",
    "KNN_ReAct_test_zs = normalize_feature_vecs_knn(test_actLevels, last_hidden_layerId)\n",
    "KNN_ReAct_novelty_zs = {\n",
    "    ood_type: normalize_feature_vecs_knn(ood_actLevels, last_hidden_layerId)\n",
    "    for ood_type, ood_actLevels in novelty_actLevels.items()\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "89cc748b-bbd6-498d-9760-c8803497ee05",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the faiss train set index: exact (brute-force) L2 search over the train feature vectors\n",
    "KNN_ReAct_feature_dim = KNN_ReAct_correct_train_zs.shape[1]\n",
    "KNN_ReAct_train_index = faiss.IndexFlatL2(KNN_ReAct_feature_dim)\n",
    "KNN_ReAct_train_index.add(KNN_ReAct_correct_train_zs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2494c7e3-c512-4abe-92a7-b05db875ec55",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate the test set k-nearest neighbor scores against the train set index\n",
    "KNN_ReAct_test_S = k_nearst_neighbor_scores(\n",
    "    KNN_ReAct_train_index,\n",
    "    KNN_ReAct_test_zs,\n",
    "    k_cifar10,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d25bd3ad-864d-431f-8e16-059fb4ae884c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate the k-nearest neighbor scores of every novelty (OOD) set against the train set index\n",
    "KNN_ReAct_novelty_S = {\n",
    "    ood_type: k_nearst_neighbor_scores(KNN_ReAct_train_index, ood_zs, k_cifar10)\n",
    "    for ood_type, ood_zs in KNN_ReAct_novelty_zs.items()\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bcde1848-d6a5-4c79-9b62-b9d43f27e559",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compare test vs. novelty KNN scores and display the OOD detection metrics\n",
    "KNN_ReAct_ood_performance = ood_performance_evaluation(\n",
    "    KNN_ReAct_test_S, KNN_ReAct_novelty_S, 'KNN_ReAct', display=True\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c19dba6d",
   "metadata": {},
   "source": [
    "#### *LINe*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c6d0e45-c6d8-48f7-b682-bdb58a7bad47",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate the Shapley contribution matrix of the last hidden layer from the correct_train_* activations\n",
    "shapley_contrib_matrix = shapley_score_evaluation_GPU(\n",
    "    correct_train_actLevels,\n",
    "    final_linear_params,\n",
    "    last_hidden_layerId,\n",
    "    predict_class_ver=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bc3f2f28-a178-463f-b741-ffe8f8faba67",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Grid-search LINe over the feature pruning rate p_n and the weight pruning rate p_w,\n",
    "# keeping the (p_n, p_w) pair with the lowest average FPR95 over the novelty sets.\n",
    "# Start from +inf rather than 1 so that the first evaluated pair is always recorded\n",
    "LINe_best_avg_FPR = np.inf\n",
    "LINe_best_p_n = None\n",
    "LINe_best_p_w = None\n",
    "LINe_registered_evaluation_results = {}\n",
    "# The per-class weight contributions depend on neither p_n nor p_w, so compute them once\n",
    "LINe_weight_contributions = {}\n",
    "for classId in class_list:\n",
    "    LINe_weight_contributions[classId] = weight_contribution(final_linear_params['weight'], shapley_contrib_matrix[classId])\n",
    "for test_p_n in p_n_list:\n",
    "    LINe_registered_evaluation_results[test_p_n] = {}\n",
    "    # The feature pruning masks depend only on p_n\n",
    "    LINe_feature_pruning_masks = {}\n",
    "    for classId in class_list:\n",
    "        LINe_feature_pruning_masks[classId] = pruning_mask(shapley_contrib_matrix[classId], test_p_n)\n",
    "    for test_p_w in tqdm(p_w_list, desc='Processed weight pruning rate'):\n",
    "        # The weight pruning masks depend only on p_w\n",
    "        LINe_weight_pruning_masks = {}\n",
    "        for classId in class_list:\n",
    "            LINe_weight_pruning_masks[classId] = pruning_mask(LINe_weight_contributions[classId], test_p_w)\n",
    "        # Get the test set logits\n",
    "        test_LINe_logits = get_logits_by_preds_torch_ver(test_vecs, test_preds,\n",
    "                                                         final_linear_params,\n",
    "                                                         LINe_feature_pruning_masks,\n",
    "                                                         LINe_weight_pruning_masks,\n",
    "                                                         display=False)\n",
    "        # Get the logits for the novelty ood sets\n",
    "        novelty_LINe_logits = {}\n",
    "        for ood_type, ood_vecs in novelty_vecs.items():\n",
    "            novelty_LINe_logits[ood_type] = get_logits_by_preds_torch_ver(ood_vecs,\n",
    "                                                                          novelty_preds[ood_type],\n",
    "                                                                          final_linear_params,\n",
    "                                                                          LINe_feature_pruning_masks,\n",
    "                                                                          LINe_weight_pruning_masks,\n",
    "                                                                          display=False)\n",
    "        ## Get the scores\n",
    "        # Test set\n",
    "        test_LINe_scores = LINe_score_evaluation_logit_ver(test_LINe_logits)\n",
    "        # Novelty sets\n",
    "        novelty_LINe_scores = {}\n",
    "        for ood_type, ood_logits in novelty_LINe_logits.items():\n",
    "            novelty_LINe_scores[ood_type] = LINe_score_evaluation_logit_ver(ood_logits)\n",
    "        # Evaluate according to the metric in the DICE paper\n",
    "        novelty_metric_results = {}\n",
    "        for ood_type, ood_scores in novelty_LINe_scores.items():\n",
    "            # cal_metric may reorder its inputs in place; pass deep copies to keep the stored scores intact\n",
    "            novelty_metric_results[ood_type] = cal_metric(copy.deepcopy(test_LINe_scores), copy.deepcopy(ood_scores), method=None)\n",
    "        novelty_avg_FPR = np.mean([result['FPR'] for result in novelty_metric_results.values()])\n",
    "        novelty_avg_AUROC = np.mean([result['AUROC'] for result in novelty_metric_results.values()])\n",
    "        # Save the result\n",
    "        LINe_registered_evaluation_results[test_p_n][test_p_w] = {\n",
    "            'Detail': novelty_metric_results,\n",
    "            'FPR': novelty_avg_FPR,\n",
    "            'AUROC': novelty_avg_AUROC,\n",
    "        }\n",
    "        # Determine if it is the best result\n",
    "        if novelty_avg_FPR < LINe_best_avg_FPR:\n",
    "            LINe_best_avg_FPR = novelty_avg_FPR\n",
    "            LINe_best_p_n = test_p_n\n",
    "            LINe_best_p_w = test_p_w"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b1b3ed30",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the results: rows are the p_n values, columns are the p_w values\n",
    "LINe_FPR_results = [[entry['FPR'] for entry in row.values()]\n",
    "                    for row in LINe_registered_evaluation_results.values()]\n",
    "LINe_AUROC_results = [[entry['AUROC'] for entry in row.values()]\n",
    "                      for row in LINe_registered_evaluation_results.values()]\n",
    "LINe_FPR_df = pd.DataFrame(data=np.array(LINe_FPR_results), index=p_n_list, columns=p_w_list)\n",
    "LINe_AUROC_df = pd.DataFrame(data=np.array(LINe_AUROC_results), index=p_n_list, columns=p_w_list)\n",
    "print('FPR:')\n",
    "ip_display(LINe_FPR_df)\n",
    "print('AUROC:')\n",
    "ip_display(LINe_AUROC_df)\n",
    "print('Best result:')\n",
    "# The (p_n, p_w) pair that achieved the lowest average FPR95 during the grid search\n",
    "LINe_best_result = LINe_registered_evaluation_results[LINe_best_p_n][LINe_best_p_w]\n",
    "print('The average performance of LINe: FPR95:', LINe_best_result['FPR'],\n",
    "      'AUROC:', LINe_best_result['AUROC'])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
