{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "7e88ded8",
   "metadata": {},
   "source": [
    "### *Module Loading*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0edbfb6a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "import faiss\n",
    "import pickle\n",
    "from subprocess import PIPE, run\n",
    "from IPython.display import display as ip_display"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7ccbc82f",
   "metadata": {},
   "source": [
    "### *External Module Loading*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dcf4ad4e",
   "metadata": {},
   "outputs": [],
   "source": [
    "external_modules_path = '..\\\\nn_likelihood_modules'\n",
    "sys.path.append(external_modules_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe63d714",
   "metadata": {},
   "outputs": [],
   "source": [
    "from common_imports import *\n",
    "from common_use_functions import *\n",
    "from constant import *\n",
    "from experim_neural_network import *\n",
    "from experim_preparation import *\n",
    "from generate_activation_level import *\n",
    "from pytorch_model_predict import *\n",
    "from ResNet_OOD import *\n",
    "from imagenet_data_prep import *\n",
    "from sensitivity_analysis import *\n",
    "from deep_KNN import *\n",
    "from novelty_data_prep import *\n",
    "from activation_level_processing import *\n",
    "from CoNNGuide_sensitivity_indices import *\n",
    "from LINe_OOD_score import *\n",
    "from neuron_importance_analysis import *\n",
    "from multscore_utils import *\n",
    "from CoNNGuide_OOD_datasets import *\n",
    "from OOD_score_utils import *\n",
    "from DICE_OOD_score import *\n",
    "from knn_search_GPU import *\n",
    "from RegNet import *\n",
    "from sota_ood_scores import *\n",
    "from pytorch_training_preparation import *"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b8dc22ef",
   "metadata": {},
   "source": [
    "### *GPU verification*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "930c998b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get the GPU\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "nb_gpu = torch.cuda.device_count()\n",
    "if nb_gpu > 0:\n",
    "    print(torch.cuda.get_device_name(0))\n",
    "else:\n",
    "    print(\"CPU\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0701537f",
   "metadata": {},
   "source": [
    "### *Working directory*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "89c6568b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Current path\n",
    "current_path = os.path.abspath(os.getcwd())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9f5c8167",
   "metadata": {},
   "source": [
    "### *Load configurations and data*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8cd2037f",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "All the parameters in this part should be configured\n",
    "\"\"\"\n",
    "# Experiment path\n",
    "experim_path = current_path\n",
    "\n",
    "# File extensions\n",
    "json_ext = '.json'\n",
    "np_ext = '.npy'\n",
    "pkl_ext = '.pkl'\n",
    "csv_ext = '.csv'\n",
    "\n",
    "# Network model prefix\n",
    "model_name_prefix = 'imagenet'\n",
    "\n",
    "# Image max pixel value\n",
    "image_max_pix_val = 255\n",
    "\n",
    "# Tested sets name\n",
    "train_set_name = 'train'\n",
    "test_set_name = 'test'\n",
    "valid_set_name = 'valid'\n",
    "input_extension = 'X'\n",
    "label_extension = 'Y'\n",
    "\n",
    "# Save paths\n",
    "model_save_path = path_join(experim_path, 'experim_models_imagenet')\n",
    "\n",
    "\"\"\"\n",
    "The following parameters should be configured according to your experiments\n",
    "\"\"\"\n",
    "\n",
    "# Trained model name \n",
    "trained_net_name = 'imagenet_regnet'\n",
    "\n",
    "# Network related params\n",
    "net_model_name = 'regnet'\n",
    "\n",
    "# General dataset information\n",
    "data_set_infos = {\n",
    "    'nb_classes' : 1000\n",
    "}\n",
    "\n",
    "# The maximum number of considered k-nearest neighbors\n",
    "k_max = 1000\n",
    "\n",
    "# The path to the original imagenet data\n",
    "data_path = 'Your path here'\n",
    "\n",
    "# The path to the preprocessed feature vectors\n",
    "precompute_data_path = path_join(experim_path, 'precomputed_data_'+trained_net_name)\n",
    "\n",
    "# The number of features\n",
    "nb_vars = 3024\n",
    "\n",
    "# The batch size of the knn search\n",
    "knn_batch_size = 50\n",
    "\n",
    "# The k value for the deep-KNN method\n",
    "k_imagenet = 1000\n",
    "\n",
    "\"\"\"\n",
    "\"\"\"\n",
    "# The output folder\n",
    "output_path = path_join(experim_path, 'output_SOTA_Imagenet_RegNet')\n",
    "\n",
    "# Build the class list\n",
    "class_list = list(range(data_set_infos['nb_classes']))\n",
    "\n",
    "# Batch size for the dataloader creation\n",
    "torch_batch_size = 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ab83bc84",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the folders\n",
    "create_directory(output_path)\n",
    "# The activation-level caches are written into precompute_data_path; create it up front\n",
    "# so that the pickle dumps cannot fail on a missing folder after a long computation\n",
    "os.makedirs(precompute_data_path, exist_ok=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6a1ae1fb",
   "metadata": {},
   "source": [
    "### *Experiment preparation*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "55636cc8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get the dataset\n",
    "imagenet_train_dataset, imagenet_test_dataset = get_imagenet_dataset_without_transform_for_regnet(data_path, normalize=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b9931900",
   "metadata": {},
   "source": [
    "### *Load the trained Network*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "132f3376",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the network\n",
    "trained_net = load_model_by_net_name(model_save_path, trained_net_name)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "959a79b0",
   "metadata": {},
   "source": [
    "### *Move the model to GPU*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1a9c01d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Move the model to the device selected in the GPU verification cell\n",
    "# (CUDA when available, otherwise the model stays on the CPU instead of raising)\n",
    "trained_net.to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d849492d",
   "metadata": {},
   "source": [
    "### *Evaluate the activation levels for the original training and test sets*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "493c5931",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the training set activation levels\n",
    "train_actLevels = None\n",
    "train_actLevels_file_path = path_join(precompute_data_path, 'train'+pkl_ext)\n",
    "if file_or_folder_existence(train_actLevels_file_path):\n",
    "    # Load the data     \n",
    "    with open(train_actLevels_file_path, 'rb') as f:\n",
    "        train_actLevels = pickle.load(f)\n",
    "    # Delete the not used probability data\n",
    "    train_actLevels.pop('prob', None)\n",
    "else:\n",
    "    # Dataloader building\n",
    "    train_loader = create_loader_from_torch_dataset(imagenet_train_dataset, batch_size=torch_batch_size, shuffle=False, num_workers=4)\n",
    "    # Get the training set activation levels \n",
    "    train_actLevels = obtain_activation_levels(trained_net,\n",
    "                                               train_loader, 'train', with_predict_class=True, loss_type='cross_entropy')\n",
    "    # Evaluate the probabilities\n",
    "    train_actLevels = probability_evaluation(train_actLevels)\n",
    "    # Save the activation levels\n",
    "    with open(train_actLevels_file_path, 'wb') as f:\n",
    "        pickle.dump(train_actLevels, f, protocol=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2ed01114",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the test set activation levels\n",
    "test_actLevels = None\n",
    "test_actLevels_file_path = path_join(precompute_data_path, 'test'+pkl_ext)\n",
    "if file_or_folder_existence(test_actLevels_file_path):\n",
    "    # Load the data\n",
    "    with open(test_actLevels_file_path, 'rb') as f:\n",
    "        test_actLevels = pickle.load(f)\n",
    "    # Delete the not used probability data\n",
    "    test_actLevels.pop('prob', None)\n",
    "else:\n",
    "    # Dataloader building\n",
    "    test_loader = create_loader_from_torch_dataset(imagenet_test_dataset, batch_size=torch_batch_size, shuffle=False, num_workers=4)\n",
    "    # Get the test set activation levels \n",
    "    test_actLevels = obtain_activation_levels(trained_net,\n",
    "                                               test_loader, 'test', with_predict_class=True, loss_type='cross_entropy')\n",
    "    # Evaluate the probabilities\n",
    "    test_actLevels = probability_evaluation(test_actLevels)\n",
    "    # Save the activation levels\n",
    "    with open(test_actLevels_file_path, 'wb') as f:\n",
    "        pickle.dump(test_actLevels, f, protocol=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "abbab086",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reduce the memory use\n",
    "del imagenet_train_dataset, imagenet_test_dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b13de364",
   "metadata": {},
   "source": [
    "### *Evaluate the activation levels for the novelty sets*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f0de364c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The considered novelty types\n",
    "ood_types = ['dtd', 'inat', 'places', 'sun', 'openimage']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f924888a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get the novelty datasets\n",
    "novelty_datasets = {}\n",
    "for ood_type in ood_types:\n",
    "    novelty_datasets[ood_type] = get_imagenet_ood_datasets_for_regnet(set_name=ood_type, normalize=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aaed65f4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Iterate over the ood sets to generate the activation levels\n",
    "novelty_actLevels = {}\n",
    "for ood_type in ood_types:\n",
    "    current_novelty_actLevels_file_path = path_join(precompute_data_path, ood_type+pkl_ext)\n",
    "    if file_or_folder_existence(current_novelty_actLevels_file_path):\n",
    "        # Load the data\n",
    "        with open(current_novelty_actLevels_file_path, 'rb') as f:\n",
    "            novelty_actLevels[ood_type] = pickle.load(f)\n",
    "        # Delete the not used probability data\n",
    "        novelty_actLevels[ood_type].pop('prob', None)\n",
    "    else:\n",
    "        # Dataloader building\n",
    "        current_novelty_loader = create_loader_from_torch_dataset(novelty_datasets[ood_type], batch_size=torch_batch_size, shuffle=False, num_workers=4)\n",
    "        # Get the activation levels of the current novelty set\n",
    "        current_novelty_actLevels = obtain_activation_levels(trained_net, current_novelty_loader, ood_type,\n",
    "                                                               with_predict_class=True, loss_type='cross_entropy')\n",
    "        # Evaluate the probabilities\n",
    "        current_novelty_actLevels = probability_evaluation(current_novelty_actLevels)\n",
    "        # Save the activation levels\n",
    "        with open(current_novelty_actLevels_file_path, 'wb') as f:\n",
    "            pickle.dump(current_novelty_actLevels, f, protocol=4)\n",
    "        # Assign the activation levels\n",
    "        novelty_actLevels[ood_type] = current_novelty_actLevels"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fe065212",
   "metadata": {},
   "source": [
    "### *Identify the activation threshold*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "83c1523c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get the last hidden layer Id\n",
    "last_hidden_layerId = list(train_actLevels['actLevel'].keys())[-1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "71faeb67",
   "metadata": {},
   "outputs": [],
   "source": [
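    "# Determine the activation threshold for ReAct (the correctly predicted examples may be used for a more precise evaluation)\n",
    "# NOTE: the 100th percentile is the maximum, so the threshold equals the largest training activation of the last hidden layer\n",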
    "act_threshold = np.percentile(train_actLevels['actLevel'][last_hidden_layerId], 100)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "69177f60",
   "metadata": {},
   "source": [
    "### *Get the correctly predicted examples*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a848e3bf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get the correctly predicted training set activation levels\n",
    "correct_train_actLevels = build_correct_actLevels(train_actLevels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ef4bf642",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Free the memory for the original training set activation levels\n",
    "del train_actLevels\n",
    "_ = gc.collect()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3c7fbbf0",
   "metadata": {},
   "source": [
    "### *Sensitivity Index Evaluation Preparation*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e4e38d5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the correspondence dictionary for the examples of each class\n",
    "class_index_dict = class_index_dict_build(correct_train_actLevels['class'].reshape(-1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "891bb225",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get the last layer parameters\n",
    "model_params = get_model_parameters(trained_net, to_numpy=True)\n",
    "final_linear_params = model_params['fc']\n",
    "nb_vars = final_linear_params['weight'].shape[1]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4eed07b7",
   "metadata": {},
   "source": [
    "### *Evaluate the OOD scores using only logits and their performance*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "95984e80",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get the test set logits\n",
    "test_origin_logits = get_logits_without_masking_torch_ver(test_actLevels['actLevel'][last_hidden_layerId],\n",
    "                                                   final_linear_params, display=False)\n",
    "# Get the logits for the novelty ood sets\n",
    "novelty_origin_logits = {}\n",
    "for ood_type in novelty_actLevels:\n",
    "    novelty_origin_logits[ood_type]  = get_logits_without_masking_torch_ver(novelty_actLevels[ood_type]['actLevel'][last_hidden_layerId],\n",
    "                                                                     final_linear_params, display=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a7334ff0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The scores using logits\n",
    "logit_methods = ['msp', 'maxlogit', 'kl', 'energy']\n",
    "test_logit_scores = {}\n",
    "novelty_logit_scores = {}\n",
    "logit_ood_performances = {}\n",
    "for method in logit_methods:\n",
    "    test_logit_scores[method] = get_score(test_origin_logits, method)\n",
    "    novelty_logit_scores[method] = {}\n",
    "    for ood_type in novelty_origin_logits:\n",
    "        novelty_logit_scores[method][ood_type] = get_score(novelty_origin_logits[ood_type], method)\n",
    "    logit_ood_performances[method] = ood_performance_evaluation_imagenet(test_logit_scores[method], novelty_logit_scores[method], method, display=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2b6721b2",
   "metadata": {},
   "source": [
    "### *The possible pruning rates*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f29378f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The possible values of p_n and p_w\n",
    "p_n_list = [0, 10, 20, 30, 40, 50, 60, 70, 80, 90]\n",
    "p_w_list = [0, 10, 20, 30, 40, 50, 60, 70, 80, 90]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5023f8ef",
   "metadata": {},
   "source": [
    "### *Evaluate the DICE scores without ReAct and its performance*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c3cb71cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate the contribution matrix for the DICE scores\n",
    "DICE_weight_contrib = DICE_weight_contribution_matrix_torch_ver(final_linear_params, correct_train_actLevels, last_hidden_layerId)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d6e3c3df",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate the DICE performance\n",
    "DICE_ood_performance = {}\n",
    "DICE_best_avg_FPR = 1\n",
    "DICE_best_p_w = None\n",
    "for p_w in tqdm(p_w_list, desc='Processed weight pruning rate'):\n",
    "    # Evaluate the weight pruning mask     \n",
    "    DICE_weight_pruning_mask = pruning_mask(DICE_weight_contrib, p=p_w)\n",
    "    # Get the test set logits\n",
    "    test_DICE_logits = get_logits_torch_ver(test_actLevels['actLevel'][last_hidden_layerId],\n",
    "                                            final_linear_params, np.ones(nb_vars, dtype=int), DICE_weight_pruning_mask,\n",
    "                                            display=False)\n",
    "    # Get the logits for the novelty ood sets\n",
    "    novelty_DICE_logits = {}\n",
    "    for ood_type in novelty_actLevels:\n",
    "        novelty_DICE_logits[ood_type]  = get_logits_torch_ver(novelty_actLevels[ood_type]['actLevel'][last_hidden_layerId],\n",
    "                                                              final_linear_params, np.ones(nb_vars, dtype=int), DICE_weight_pruning_mask,\n",
    "                                                              display=False)\n",
    "    # Compute the scores     \n",
    "    test_DICE_scores = DICE_score_evaluation_logit_ver(test_DICE_logits)\n",
    "    novelty_DICE_scores = {}\n",
    "    for ood_type in novelty_actLevels:\n",
    "        novelty_DICE_scores[ood_type] = DICE_score_evaluation_logit_ver(novelty_DICE_logits[ood_type])\n",
    "    # Evaluate the performance     \n",
    "    DICE_ood_performance[p_w] = ood_performance_evaluation_imagenet(test_DICE_scores, novelty_DICE_scores, 'DICE', display=False)\n",
    "    if DICE_ood_performance[p_w]['FPR_curated'] < DICE_best_avg_FPR:\n",
    "        DICE_best_avg_FPR = DICE_ood_performance[p_w]['FPR_curated']\n",
    "        DICE_best_p_w = p_w"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fb7517a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Free the memory\n",
    "del test_DICE_logits, novelty_DICE_logits\n",
    "torch.cuda.empty_cache()\n",
    "_= gc.collect()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f8ff14b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the results\n",
    "DICE_FPR_curated_results = [DICE_ood_performance[test_p_w]['FPR_curated'] for test_p_w in DICE_ood_performance]\n",
    "DICE_AUROC_curated_results = [DICE_ood_performance[test_p_w]['AUROC_curated'] for test_p_w in DICE_ood_performance]\n",
    "DICE_FPR_curated_df = pd.DataFrame(data=np.array(DICE_FPR_curated_results).reshape(1,-1), columns=p_w_list)\n",
    "DICE_AUROC_curated_df = pd.DataFrame(data=np.array(DICE_AUROC_curated_results).reshape(1,-1), columns=p_w_list)\n",
    "print('FPR_curated:')\n",
    "ip_display(DICE_FPR_curated_df)\n",
    "print('AUROC_curated:')\n",
    "ip_display(DICE_AUROC_curated_df)\n",
    "print('Best result:')\n",
    "print('The average performance: FPR95 (curated):', DICE_ood_performance[DICE_best_p_w]['FPR_curated'],\n",
    "      \"AUROC (curated):\", DICE_ood_performance[DICE_best_p_w]['AUROC_curated'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "787b0007",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the results\n",
    "DICE_FPR_results = [DICE_ood_performance[test_p_w]['FPR'] for test_p_w in DICE_ood_performance]\n",
    "DICE_AUROC_results = [DICE_ood_performance[test_p_w]['AUROC'] for test_p_w in DICE_ood_performance]\n",
    "DICE_FPR_df = pd.DataFrame(data=np.array(DICE_FPR_results).reshape(1,-1), columns=p_w_list)\n",
    "DICE_AUROC_df = pd.DataFrame(data=np.array(DICE_AUROC_results).reshape(1,-1), columns=p_w_list)\n",
    "print('FPR:')\n",
    "ip_display(DICE_FPR_df)\n",
    "print('AUROC:')\n",
    "ip_display(DICE_AUROC_df)\n",
    "print('Best result:')\n",
    "print('The average performance of DICE: FPR95:', DICE_ood_performance[DICE_best_p_w]['FPR'],\n",
    "      \"AUROC:\", DICE_ood_performance[DICE_best_p_w]['AUROC'])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "73896a63-f91b-4499-9ce0-4d46ea28559f",
   "metadata": {},
   "source": [
    "### *Deep-KNN*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "39ed4bed-2154-45c0-a3ef-d0b827c110dc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get the feature vectors\n",
    "KNN_correct_train_zs = normalize_feature_vecs_knn(correct_train_actLevels, last_hidden_layerId)\n",
    "KNN_test_zs = normalize_feature_vecs_knn(test_actLevels, last_hidden_layerId)\n",
    "KNN_novelty_zs = {}\n",
    "for ood_type in novelty_actLevels:\n",
    "    KNN_novelty_zs[ood_type] = normalize_feature_vecs_knn(novelty_actLevels[ood_type], last_hidden_layerId)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "780f288b-eeba-4540-a5b2-6b41c7f1525c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate the test set k-nearest neighbor scores\n",
    "# (the batch size comes from the configuration cell instead of a duplicated literal)\n",
    "KNN_test_dists, _ = knn_L2_dist_GPU(KNN_correct_train_zs, KNN_test_zs, batch_size=knn_batch_size, k=k_imagenet, display=True, half_precision=True)\n",
    "# The score is the negated distance stored in the last returned column\n",
    "KNN_test_S = -KNN_test_dists[:,-1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d5c39d8b-c628-45fb-b7ad-3b33d0bcfa6c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate the novelty set k-nearest neighbor scores\n",
    "# (the batch size comes from the configuration cell instead of a duplicated literal)\n",
    "KNN_novelty_S = {}\n",
    "for ood_type in KNN_novelty_zs:\n",
    "    current_novelty_dists, _ = knn_L2_dist_GPU(KNN_correct_train_zs, KNN_novelty_zs[ood_type], batch_size=knn_batch_size,\n",
    "                                            k=k_imagenet, display=True, half_precision=True)\n",
    "    # The score is the negated distance stored in the last returned column\n",
    "    KNN_novelty_S[ood_type] = -current_novelty_dists[:,-1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "64eb7508-d2d4-4707-8310-39569dcd5f4a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate the performance\n",
    "KNN_ood_performance = ood_performance_evaluation_imagenet(KNN_test_S, KNN_novelty_S, 'KNN', display=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2f210716",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Free the memory\n",
    "del KNN_correct_train_zs, KNN_test_zs, KNN_novelty_zs\n",
    "torch.cuda.empty_cache()\n",
    "_= gc.collect()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "aeb4ffaa",
   "metadata": {},
   "source": [
    "### *Prepare the normalized feature vectors for NNguide*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "67cf6e52",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compute the confidence scores\n",
    "nnguide_correct_train_logits = get_logits_without_masking_torch_ver(correct_train_actLevels['actLevel'][last_hidden_layerId],\n",
    "                                                   final_linear_params, display=False)\n",
    "confs_train = log_sum_exponential_score(nnguide_correct_train_logits, sum_axis=1).reshape(-1,1)\n",
    "del nnguide_correct_train_logits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dfeb4b01",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the training set normalized feature vectors (only the correctly predicted examples)\n",
    "correct_train_zs = normalize_feature_vecs_knn(correct_train_actLevels, last_hidden_layerId)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "94059ce9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Estimate the k value\n",
    "nnguide_train_sims,_ = knn_sim_IP_GPU(correct_train_zs, correct_train_zs, batch_size=knn_batch_size,\n",
    "                                        k=k_max, display=True, half_precision=True)\n",
    "nnguide_knn_k = estimate_dense_k(nnguide_train_sims, verify_steps=3, variation_threshold=0.1, min_k=5, smooth_sigma=0)\n",
    "del nnguide_train_sims\n",
    "_ = gc.collect()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "166764a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Adjust the vectors\n",
    "correct_train_zs = confs_train * correct_train_zs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "386a96d6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the test set normalized feature vectors\n",
    "test_zs = normalize_feature_vecs_knn(test_actLevels, last_hidden_layerId)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6bd7312f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get the normalized feature vectors of the novelty ood set\n",
    "novelty_zs = {}\n",
    "for ood_type in novelty_actLevels:\n",
    "    novelty_zs[ood_type] = normalize_feature_vecs_knn(novelty_actLevels[ood_type], last_hidden_layerId)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ce0e7819",
   "metadata": {},
   "source": [
    "### *Clip the activations*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0ddf27b1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Execute the activation clip\n",
    "correct_train_actLevels = clip_activations(correct_train_actLevels, last_hidden_layerId, act_threshold)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "891d55a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Clip the test set activations\n",
    "test_actLevels = clip_activations(test_actLevels, last_hidden_layerId, act_threshold)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b8b8e169",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Clip the activations\n",
    "for ood_type in novelty_actLevels:\n",
    "    novelty_actLevels[ood_type] = clip_activations(novelty_actLevels[ood_type], last_hidden_layerId, act_threshold)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "43a5f97c",
   "metadata": {},
   "source": [
    "### *Prepare the feature vectors with all neurons*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "64f7570c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get the correctly predicted training set feature vectors\n",
    "correct_train_vecs = correct_train_actLevels['actLevel'][last_hidden_layerId]\n",
    "correct_train_preds = correct_train_actLevels['predict_class'].reshape(-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "44f3092a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the test set feature vectors\n",
    "test_vecs = test_actLevels['actLevel'][last_hidden_layerId]\n",
    "test_preds = test_actLevels['predict_class'].reshape(-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0dcd8e12",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get the feature vectors of the novelty ood set\n",
    "novelty_vecs = {}\n",
    "novelty_preds = {}\n",
    "for ood_type in novelty_actLevels:\n",
    "    novelty_vecs[ood_type] = novelty_actLevels[ood_type]['actLevel'][last_hidden_layerId]\n",
    "    novelty_preds[ood_type] = novelty_actLevels[ood_type]['predict_class'].reshape(-1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "106d4c17",
   "metadata": {},
   "source": [
    "### *Evaluate various OOD scores and their performance*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cfae2177",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The logits computed from the clipped activations\n",
    "# Get the test set logits\n",
    "test_logits = get_logits_without_masking_torch_ver(test_vecs, final_linear_params, display=False)\n",
    "# Get the logits for the novelty ood sets\n",
    "novelty_logits = {}\n",
    "for ood_type in novelty_vecs:\n",
    "    novelty_logits[ood_type]  = get_logits_without_masking_torch_ver(novelty_vecs[ood_type], final_linear_params, display=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "db2e390a",
   "metadata": {},
   "source": [
    "#### *NNguide*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eb11bb70",
   "metadata": {},
   "outputs": [],
   "source": [
    "def nnguide_score(scaled_train_zs, zs, logits, knn_k, knn_batch_size):\n",
    "    \"\"\"Combine a logit-based confidence with a nearest-neighbor guidance term.\n",
    "\n",
    "    The confidence is log_sum_exponential_score(logits, sum_axis=1). The guidance is\n",
    "    the mean, over the first knn_k columns, of the similarities returned by\n",
    "    knn_sim_IP_GPU(scaled_train_zs, zs, ...). The returned scores are the product\n",
    "    guidance * confidence.\n",
    "    \"\"\"\n",
    "    confidence = log_sum_exponential_score(logits, sum_axis=1)\n",
    "    neighbor_sims, _ = knn_sim_IP_GPU(\n",
    "        scaled_train_zs, zs,\n",
    "        batch_size=knn_batch_size, k=knn_k, display=True, half_precision=True,\n",
    "    )\n",
    "    guidance = np.mean(neighbor_sims[:, :knn_k], axis=1)\n",
    "    # Release the similarity matrix before returning to keep the memory footprint low\n",
    "    del neighbor_sims\n",
    "    gc.collect()\n",
    "    return guidance * confidence"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b68ef34d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The NNguide scores\n",
    "test_nnguide_scores = nnguide_score(correct_train_zs, test_zs, test_logits, nnguide_knn_k, knn_batch_size)\n",
    "novelty_nnguide_scores = {}\n",
    "for ood_type in novelty_logits:\n",
    "    novelty_nnguide_scores[ood_type] = nnguide_score(correct_train_zs, novelty_zs[ood_type], novelty_logits[ood_type],\n",
    "                                                     nnguide_knn_k, knn_batch_size)\n",
    "nnguide_ood_performance = ood_performance_evaluation_imagenet(test_nnguide_scores, novelty_nnguide_scores, 'nnguide', display=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "84c3b9d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Free the memory\n",
    "del correct_train_zs, test_zs, novelty_zs\n",
    "_ = gc.collect()\n",
    "torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c19dba6d",
   "metadata": {},
   "source": [
    "#### *LINe*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c6d0e45-c6d8-48f7-b682-bdb58a7bad47",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shapley contribution matrix: computed once and pickled under output_path,\n",
    "# later runs reuse the pickled copy instead of recomputing it\n",
    "shapley_contrib_matrix = None\n",
    "shapley_contrib_matrix_file_path = path_join(output_path, 'shapley_contrib_matrix'+pkl_ext)\n",
    "if not file_or_folder_existence(shapley_contrib_matrix_file_path):\n",
    "    # No cached copy yet: evaluate the matrix from the activation levels of the training samples\n",
    "    shapley_contrib_matrix = shapley_score_evaluation_batch_GPU(correct_train_actLevels, final_linear_params, last_hidden_layerId,\n",
    "                                                                predict_class_ver=False,\n",
    "                                                                batch_size=1000, block_size=32)\n",
    "    # Cache the result for the next run\n",
    "    with open(shapley_contrib_matrix_file_path, 'wb') as cache_file:\n",
    "        pickle.dump(shapley_contrib_matrix, cache_file, protocol=4)\n",
    "else:\n",
    "    # Reuse the cached copy\n",
    "    with open(shapley_contrib_matrix_file_path, 'rb') as cache_file:\n",
    "        shapley_contrib_matrix = pickle.load(cache_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bc3f2f28-a178-463f-b741-ffe8f8faba67",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Grid search over the LINe pruning rates: test_p_n is passed to the per-class feature pruning masks\n",
    "# and test_p_w to the per-class weight pruning masks. The scores of every setting are cached on disk,\n",
    "# and the setting with the lowest curated average FPR is remembered as the best one.\n",
    "def compute_LINe_scores(p_n, p_w):\n",
    "    \"\"\"Return (test scores, {ood_type: scores}) for one (p_n, p_w) pruning setting.\"\"\"\n",
    "    # Build the pruning masks class by class so only one weight contribution is alive at a time\n",
    "    feature_pruning_masks = {}\n",
    "    weight_pruning_masks = {}\n",
    "    for classId in class_list:\n",
    "        class_weight_contribution = weight_contribution(final_linear_params['weight'], shapley_contrib_matrix[classId])\n",
    "        feature_pruning_masks[classId] = pruning_mask(shapley_contrib_matrix[classId], p_n)\n",
    "        weight_pruning_masks[classId] = pruning_mask(class_weight_contribution, p_w)\n",
    "    # Logits of the test set computed with the masks\n",
    "    masked_test_logits = get_logits_by_preds_torch_ver(test_vecs, test_preds,\n",
    "                                                       final_linear_params,\n",
    "                                                       feature_pruning_masks,\n",
    "                                                       weight_pruning_masks,\n",
    "                                                       display=False)\n",
    "    # Logits of every novelty (OOD) set computed with the masks\n",
    "    masked_novelty_logits = {}\n",
    "    for ood_type in novelty_vecs:\n",
    "        masked_novelty_logits[ood_type] = get_logits_by_preds_torch_ver(novelty_vecs[ood_type],\n",
    "                                                                        novelty_preds[ood_type],\n",
    "                                                                        final_linear_params,\n",
    "                                                                        feature_pruning_masks,\n",
    "                                                                        weight_pruning_masks,\n",
    "                                                                        display=False)\n",
    "    # Free the memory for the masks before scoring\n",
    "    del feature_pruning_masks, weight_pruning_masks\n",
    "    gc.collect()\n",
    "    test_scores = LINe_score_evaluation_logit_ver(masked_test_logits)\n",
    "    novelty_scores = {ood_type: LINe_score_evaluation_logit_ver(ood_logits)\n",
    "                      for ood_type, ood_logits in masked_novelty_logits.items()}\n",
    "    return test_scores, novelty_scores\n",
    "\n",
    "\n",
    "def evaluate_LINe_metrics(test_scores, novelty_scores):\n",
    "    \"\"\"Evaluate one setting with cal_metric and average FPR / AUROC over the novelty sets.\n",
    "\n",
    "    Returns a dict with the per-set results under 'Detail', the averages over all sets under\n",
    "    'FPR' / 'AUROC', and the averages without the 'openimage' set under 'FPR_curated' / 'AUROC_curated'.\n",
    "    \"\"\"\n",
    "    # Evaluate according to the metric in the DICE paper; deep copies are passed so that\n",
    "    # the order of the stored scores cannot be changed by the evaluation\n",
    "    detail = {ood_type: cal_metric(copy.deepcopy(test_scores), copy.deepcopy(ood_scores), method=None)\n",
    "              for ood_type, ood_scores in novelty_scores.items()}\n",
    "    curated_types = [ood_type for ood_type in detail if ood_type != 'openimage']\n",
    "    return {'Detail': detail,\n",
    "            'FPR': np.mean([detail[ood_type]['FPR'] for ood_type in detail]),\n",
    "            'AUROC': np.mean([detail[ood_type]['AUROC'] for ood_type in detail]),\n",
    "            'FPR_curated': np.mean([detail[ood_type]['FPR'] for ood_type in curated_types]),\n",
    "            'AUROC_curated': np.mean([detail[ood_type]['AUROC'] for ood_type in curated_types])}\n",
    "\n",
    "\n",
    "# Start from +inf so that the first evaluated setting is always registered as the best one,\n",
    "# whatever the scale of the FPR values (otherwise the best setting could stay None)\n",
    "LINe_best_avg_FPR = float('inf')\n",
    "LINe_best_p_n = None\n",
    "LINe_best_p_w = None\n",
    "LINe_registered_evaluation_results = {}\n",
    "test_LINe_scores = {}\n",
    "novelty_LINe_scores = {}\n",
    "test_LINe_scores_file_path = path_join(output_path, 'test_LINe_scores'+pkl_ext)\n",
    "novelty_LINe_scores_file_path = path_join(output_path, 'novelty_LINe_scores'+pkl_ext)\n",
    "LINe_scores_cached = bool(file_or_folder_existence(test_LINe_scores_file_path)\n",
    "                          and file_or_folder_existence(novelty_LINe_scores_file_path))\n",
    "if LINe_scores_cached:\n",
    "    # Load the data\n",
    "    # NOTE: the cached scores must cover every (p_n, p_w) of the current grid, otherwise the lookup\n",
    "    # in the loop below raises a KeyError; delete both pickle files to force a recomputation\n",
    "    with open(test_LINe_scores_file_path, 'rb') as f:\n",
    "        test_LINe_scores = pickle.load(f)\n",
    "    with open(novelty_LINe_scores_file_path, 'rb') as f:\n",
    "        novelty_LINe_scores = pickle.load(f)\n",
    "for test_p_n in p_n_list:\n",
    "    LINe_registered_evaluation_results[test_p_n] = {}\n",
    "    if not LINe_scores_cached:\n",
    "        test_LINe_scores[test_p_n] = {}\n",
    "        novelty_LINe_scores[test_p_n] = {}\n",
    "    for test_p_w in tqdm(p_w_list, desc='Processed weight pruning rate'):\n",
    "        if not LINe_scores_cached:\n",
    "            # Get the scores of the test set and of the novelty sets for this setting\n",
    "            test_LINe_scores[test_p_n][test_p_w], novelty_LINe_scores[test_p_n][test_p_w] = compute_LINe_scores(test_p_n, test_p_w)\n",
    "        # Save the result\n",
    "        LINe_setting_results = evaluate_LINe_metrics(test_LINe_scores[test_p_n][test_p_w],\n",
    "                                                     novelty_LINe_scores[test_p_n][test_p_w])\n",
    "        LINe_registered_evaluation_results[test_p_n][test_p_w] = LINe_setting_results\n",
    "        # Determine if it is the best result\n",
    "        if LINe_setting_results['FPR_curated'] < LINe_best_avg_FPR:\n",
    "            LINe_best_avg_FPR = LINe_setting_results['FPR_curated']\n",
    "            LINe_best_p_n = test_p_n\n",
    "            LINe_best_p_w = test_p_w\n",
    "if not LINe_scores_cached:\n",
    "    # Save the data\n",
    "    with open(test_LINe_scores_file_path, 'wb') as f:\n",
    "        pickle.dump(test_LINe_scores, f, protocol=4)\n",
    "    with open(novelty_LINe_scores_file_path, 'wb') as f:\n",
    "        pickle.dump(novelty_LINe_scores, f, protocol=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b1b3ed30",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tabulate the curated FPR / AUROC of every evaluated setting (rows: p_n_list, columns: p_w_list)\n",
    "FPR_curated_results = [[p_w_result['FPR_curated'] for p_w_result in p_n_results.values()]\n",
    "                       for p_n_results in LINe_registered_evaluation_results.values()]\n",
    "AUROC_curated_results = [[p_w_result['AUROC_curated'] for p_w_result in p_n_results.values()]\n",
    "                         for p_n_results in LINe_registered_evaluation_results.values()]\n",
    "FPR_curated_df = pd.DataFrame(data=np.array(FPR_curated_results), index=p_n_list, columns=p_w_list)\n",
    "AUROC_curated_df = pd.DataFrame(data=np.array(AUROC_curated_results), index=p_n_list, columns=p_w_list)\n",
    "print('FPR_curated:')\n",
    "ip_display(FPR_curated_df)\n",
    "print('AUROC_curated:')\n",
    "ip_display(AUROC_curated_df)\n",
    "# Report the curated averages of the setting registered as best (LINe_best_p_n, LINe_best_p_w)\n",
    "print('Best result:')\n",
    "LINe_best_result = LINe_registered_evaluation_results[LINe_best_p_n][LINe_best_p_w]\n",
    "print('The average performance: FPR95 (curated):', LINe_best_result['FPR_curated'],\n",
    "      \"AUROC (curated):\", LINe_best_result['AUROC_curated'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fee94cd0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tabulate the FPR / AUROC averaged over all novelty sets (rows: p_n_list, columns: p_w_list)\n",
    "FPR_results = [[p_w_result['FPR'] for p_w_result in p_n_results.values()]\n",
    "               for p_n_results in LINe_registered_evaluation_results.values()]\n",
    "AUROC_results = [[p_w_result['AUROC'] for p_w_result in p_n_results.values()]\n",
    "                 for p_n_results in LINe_registered_evaluation_results.values()]\n",
    "FPR_df = pd.DataFrame(data=np.array(FPR_results), index=p_n_list, columns=p_w_list)\n",
    "AUROC_df = pd.DataFrame(data=np.array(AUROC_results), index=p_n_list, columns=p_w_list)\n",
    "print('FPR:')\n",
    "ip_display(FPR_df)\n",
    "print('AUROC:')\n",
    "ip_display(AUROC_df)\n",
    "# Report the overall averages of the setting registered as best (LINe_best_p_n, LINe_best_p_w)\n",
    "print('Best result:')\n",
    "LINe_best_result = LINe_registered_evaluation_results[LINe_best_p_n][LINe_best_p_w]\n",
    "print('The average performance: FPR95:', LINe_best_result['FPR'],\n",
    "      \"AUROC:\", LINe_best_result['AUROC'])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
