{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ['NOTEBOOK_MODE'] = '1'\n",
    "import cv2\n",
    "import math\n",
    "import dill\n",
    "import sys\n",
    "\n",
    "from efficientnet_pytorch import EfficientNet\n",
    "import torch as ch\n",
    "import torch.nn.functional as F\n",
    "from torchvision import transforms, utils\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "from sklearn.pipeline import make_pipeline\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from sklearn.decomposition import PCA\n",
    "from sklearn.svm import SVC\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "from PIL import Image\n",
    "from scipy import stats\n",
    "from collections import defaultdict\n",
    "from tqdm import tqdm, tqdm_notebook\n",
    "import matplotlib.pyplot as plt\n",
    "from robustness import model_utils, datasets\n",
    "from robustness.tools.vis_tools import show_image_row, show_image_column\n",
    "from robustness.tools.label_maps import CLASS_DICT\n",
    "from user_constants import DATA_PATH_DICT\n",
    "from causal_testing_utils import *\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Configuration\n",
    "DATA = 'ImageNet'  # one of 'CIFAR', 'ImageNet', 'RestrictedImageNet'\n",
    "BATCH_SIZE = 100   # samples per loader batch\n",
    "NUM_WORKERS = 8    # loader worker processes\n",
    "\n",
    "# Input resolution is fixed by the dataset: 32 px for CIFAR, 224 px otherwise.\n",
    "DATA_SHAPE = 224 if DATA != 'CIFAR' else 32\n",
    "# Width of the penultimate-layer representation (fixed by the model).\n",
    "REPRESENTATION_SIZE = 2048\n",
    "# Human-readable class names for the chosen dataset.\n",
    "CLASSES = CLASS_DICT[DATA]\n",
    "# NOTE(review): this is one less than the number of CLASSES entries -- confirm the -1 is intended.\n",
    "NUM_CLASSES = len(CLASSES) - 1\n",
    "GRAIN = 1 if DATA == 'CIFAR' else 4\n",
    "\n",
    "# Placeholder: point this at the directory that holds your checkpoints.\n",
    "MODELS_PATH = \"path to your models\"\n",
    "\n",
    "# Load dataset: the robustness dataset class is looked up by name.\n",
    "dataset_function = getattr(datasets, DATA)\n",
    "dataset = dataset_function(DATA_PATH_DICT[DATA])\n",
    "# Shuffling and augmentation are disabled so the iteration order is deterministic.\n",
    "train_loader, test_loader = dataset.make_loaders(\n",
    "    workers=NUM_WORKERS,\n",
    "    batch_size=BATCH_SIZE,\n",
    "    data_aug=False,\n",
    "    shuffle_train=False,\n",
    "    shuffle_val=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_accuracy(logits, labels):\n",
    "    \"\"\"Top-1 accuracy: fraction of rows of `logits` whose argmax equals `labels`.\"\"\"\n",
    "    top1 = np.argmax(logits, axis=1)\n",
    "    return compute_preds_accuracy(top1, labels)\n",
    "\n",
    "def compute_preds_accuracy(preds, labels):\n",
    "    \"\"\"Fraction of positions where `preds` equals `labels` (labels must expose `.size`).\"\"\"\n",
    "    hits = np.sum(preds == labels)\n",
    "    return hits / labels.size\n",
    "\n",
    "def compute_failure_labels(logits, labels):\n",
    "    \"\"\"Boolean mask that is True where the argmax prediction differs from the label.\"\"\"\n",
    "    is_correct = np.argmax(logits, axis=1) == labels\n",
    "    return np.logical_not(is_correct)\n",
    "\n",
    "def np_softmax(logits):\n",
    "    \"\"\"Row-wise softmax of a 2-D array; each row is shifted by its max for numerical stability.\"\"\"\n",
    "    assert logits.ndim==2\n",
    "    exps = np.exp(logits - np.max(logits, axis=1, keepdims=True))\n",
    "    return exps / np.sum(exps, axis=1, keepdims=True)\n",
    "\n",
    "def create_dict(unique, counts, dtype=int):\n",
    "    \"\"\"Sum `counts` per key of `unique` into a defaultdict whose missing keys default to dtype().\"\"\"\n",
    "    totals = defaultdict(dtype)\n",
    "    for key, count in zip(unique, counts):\n",
    "        totals[key] = totals[key] + count\n",
    "    return totals\n",
    "\n",
    "def compute_precision_recall(preds, labels):\n",
    "    \"\"\"Precision and recall of boolean `preds` against boolean ground-truth `labels`.\"\"\"\n",
    "    true_positives = np.sum(np.logical_and(preds, labels))\n",
    "    precision = true_positives / np.sum(preds)\n",
    "    recall = true_positives / np.sum(labels)\n",
    "    return precision, recall\n",
    "\n",
    "def load_features(model_name, dataset):\n",
    "    \"\"\"Load cached (features, logits, labels) arrays from imagenet_features/.\n",
    "\n",
    "    model_name: 'ImageNetNat' (files prefixed 'nonrobust_') or 'robust_resnet50'\n",
    "        (files prefixed 'robust_'); anything else raises ValueError.\n",
    "    dataset: split name, 'train' or 'test'. The labels file is shared by both models.\n",
    "    \"\"\"\n",
    "    assert dataset in ['train', 'test']\n",
    "    file_prefix = {'ImageNetNat': 'nonrobust_', 'robust_resnet50': 'robust_'}\n",
    "    if model_name not in file_prefix:\n",
    "        raise ValueError('Unidentified model name: ' + model_name)\n",
    "    stem = 'imagenet_features/' + file_prefix[model_name] + dataset\n",
    "    features = np.load(stem + '_features.npy')\n",
    "    logits = np.load(stem + '_logits.npy')\n",
    "    labels = np.load('imagenet_features/' + dataset + '_labels.npy')\n",
    "    return features, logits, labels\n",
    "\n",
    "def generate_dataset(model, data_loader, device='cuda'):\n",
    "    \"\"\"Run `model` over `data_loader` and collect latent features, logits and labels.\n",
    "\n",
    "    `model(ims, with_latent=True)` must return ((logits, features), _), as used\n",
    "    elsewhere in this notebook. Image batches are moved to `device` (default 'cuda',\n",
    "    the previously hard-coded target). The running sample count is printed per batch.\n",
    "\n",
    "    Returns (features_all, logits_all, labels_all) as numpy arrays concatenated\n",
    "    along axis 0 in loader order.\n",
    "    \"\"\"\n",
    "    features_all, logits_all, labels_all = [], [], []\n",
    "    total = 0\n",
    "    # Inference only: no_grad avoids building and holding an autograd graph per batch.\n",
    "    with ch.no_grad():\n",
    "        for ims, labels in data_loader:\n",
    "            ims = ims.to(device)\n",
    "            (logits, features), _ = model(ims, with_latent=True)\n",
    "            total += len(ims)\n",
    "            features_all.append(features.cpu().numpy())\n",
    "            logits_all.append(logits.cpu().numpy())\n",
    "            # Labels are only collected, so they never need to visit the device.\n",
    "            labels_all.append(labels.cpu().numpy())\n",
    "            print(total)\n",
    "\n",
    "    features_all = np.concatenate(features_all, axis=0)\n",
    "    logits_all = np.concatenate(logits_all, axis=0)\n",
    "    labels_all = np.concatenate(labels_all, axis=0)\n",
    "    return features_all, logits_all, labels_all\n",
    "\n",
    "def masked_topk_hot(probs, k=2):\n",
    "    \"\"\"Zero all but the k largest entries of each row and renormalise each row to sum to 1.\"\"\"\n",
    "    assert probs.ndim==2\n",
    "    assert k>=2\n",
    "    keep_idx = np.argsort(probs, axis=1)[:, -k:]\n",
    "    keep_vals = np.take_along_axis(probs, keep_idx, axis=1)\n",
    "    masked = np.zeros(probs.shape)\n",
    "    np.put_along_axis(masked, keep_idx, keep_vals, axis=1)\n",
    "    return masked / np.sum(masked, axis=1, keepdims=True)\n",
    "\n",
    "def reduce_probabilities(probs, n_probs):\n",
    "    \"\"\"Sum consecutive column groups of `probs` into `n_probs` buckets.\n",
    "\n",
    "    Each bucket spans int(n_columns / n_probs) columns; the last bucket also\n",
    "    absorbs any leftover columns.\n",
    "    \"\"\"\n",
    "    n_columns = probs.shape[1]\n",
    "    reduced = np.zeros((probs.shape[0], n_probs))\n",
    "    group = int(n_columns/n_probs)\n",
    "    for bucket in range(n_probs):\n",
    "        start = bucket * group\n",
    "        stop = n_columns if bucket == n_probs - 1 else start + group\n",
    "        reduced[:, bucket] = np.sum(probs[:, start:stop], axis=1)\n",
    "    return reduced\n",
    "\n",
    "def pca_reduce(train_data, test_data, n_components):\n",
    "    \"\"\"Project both splits onto `n_components` principal axes fitted on train_data only (no whitening).\"\"\"\n",
    "    projector = PCA(n_components=n_components, whiten=False)\n",
    "    projector.fit(train_data)\n",
    "    return projector.transform(train_data), projector.transform(test_data)\n",
    "\n",
    "def sample_data(train_data, train_labels, frac_samples, random_state=None):\n",
    "    \"\"\"Return a stratified random subset (train_data, train_labels).\n",
    "\n",
    "    frac_samples: subset size forwarded to train_test_split as train_size\n",
    "        (a float in (0, 1) is a fraction of the data, an int an absolute count).\n",
    "    random_state: optional seed so the subset is reproducible.\n",
    "    \"\"\"\n",
    "    # Honour the caller's frac_samples instead of a hard-coded 0.25.\n",
    "    train_data, _, train_labels, _ = train_test_split(train_data,\n",
    "                                                      train_labels,\n",
    "                                                      stratify=train_labels,\n",
    "                                                      train_size=frac_samples,\n",
    "                                                      random_state=random_state)\n",
    "    return train_data, train_labels\n",
    "\n",
    "def failure_statistics(logits, labels):\n",
    "    \"\"\"Per-class misclassification counts.\n",
    "\n",
    "    Returns a dict of int64 arrays of length logits.shape[1]:\n",
    "        'pred_failures'  - misclassified samples predicted as class i\n",
    "        'label_failures' - misclassified samples labelled as class i\n",
    "        'pred_counts'    - samples predicted as class i\n",
    "        'label_counts'   - samples labelled as class i\n",
    "    \"\"\"\n",
    "    preds = np.argmax(logits, axis=1)\n",
    "    num_classes = logits.shape[1]\n",
    "    # Loop-invariant, so computed once instead of twice per class.\n",
    "    misclassified = preds != labels\n",
    "\n",
    "    # np.long was removed in NumPy 1.24; use an explicit 64-bit integer dtype.\n",
    "    pred_failures_arr = np.zeros(num_classes, dtype=np.int64)\n",
    "    label_failures_arr = np.zeros(num_classes, dtype=np.int64)\n",
    "    count_preds_arr = np.zeros(num_classes, dtype=np.int64)\n",
    "    count_labels_arr = np.zeros(num_classes, dtype=np.int64)\n",
    "    for i in range(num_classes):\n",
    "        is_pred_i = preds == i\n",
    "        is_label_i = labels == i\n",
    "        pred_failures_arr[i] = np.sum(is_pred_i & misclassified)\n",
    "        label_failures_arr[i] = np.sum(is_label_i & misclassified)\n",
    "        count_preds_arr[i] = np.sum(is_pred_i)\n",
    "        count_labels_arr[i] = np.sum(is_label_i)\n",
    "\n",
    "    return {'pred_failures': pred_failures_arr,\n",
    "            'label_failures': label_failures_arr,\n",
    "            'pred_counts': count_preds_arr,\n",
    "            'label_counts': count_labels_arr}\n",
    "\n",
    "def print_failure_stats(failure_dict, class_idx):\n",
    "    \"\"\"Print and return the failure statistics of one class.\n",
    "\n",
    "    failure_dict: dict produced by failure_statistics.\n",
    "    The class name is the first two comma-separated parts of CLASSES[class_idx].\n",
    "    Returns (class_name, num_failures_pred, num_failures_label, num_preds, num_labels).\n",
    "    \"\"\"\n",
    "    pred_failures_arr = failure_dict['pred_failures']\n",
    "    label_failures_arr = failure_dict['label_failures']\n",
    "    pred_counts_arr = failure_dict['pred_counts']\n",
    "    label_counts_arr = failure_dict['label_counts']\n",
    "\n",
    "    class_name = ', '.join(CLASSES[class_idx].split(',')[:2])\n",
    "    num_failures_pred = pred_failures_arr[class_idx]\n",
    "    num_failures_label = label_failures_arr[class_idx]\n",
    "    num_preds = pred_counts_arr[class_idx]\n",
    "    num_labels = label_counts_arr[class_idx]\n",
    "\n",
    "    # The second failure count is keyed by ground-truth label, hence '(by label)'.\n",
    "    print('class_idx: {:d}, class_name: {:s}, num_failures (by pred): {:d}, num_failures (by label): {:d}, num_preds: {:d}, num_labels: {:d}'.\n",
    "          format(class_idx, class_name, num_failures_pred, num_failures_label, num_preds, num_labels))\n",
    "    return class_name, num_failures_pred, num_failures_label, num_preds, num_labels\n",
    "    \n",
    "def print_failure_stats_label(failure_dict, class_idx):\n",
    "    \"\"\"Print label-based failure statistics for one class and return them.\n",
    "\n",
    "    fraction is label failures / label count, i.e. the share of samples labelled\n",
    "    `class_idx` that were misclassified.\n",
    "    Returns (class_name, num_failures, num_preds, num_labels).\n",
    "    \"\"\"\n",
    "    failures_by_label = failure_dict['label_failures']\n",
    "    preds_per_class = failure_dict['pred_counts']\n",
    "    labels_per_class = failure_dict['label_counts']\n",
    "\n",
    "    class_name = CLASSES[class_idx]\n",
    "    num_failures = failures_by_label[class_idx]\n",
    "    num_preds = preds_per_class[class_idx]\n",
    "    num_labels = labels_per_class[class_idx]\n",
    "    fraction = num_failures/num_labels\n",
    "\n",
    "    print('class_idx: {:d}, class_name: {:s}, num_failures: {:d}, num_preds: {:d}, num_labels: {:d}, fraction: {:.4f}'.\n",
    "          format(class_idx, class_name, num_failures, num_preds, num_labels, fraction))\n",
    "    return class_name, num_failures, num_preds, num_labels\n",
    "    \n",
    "def predicted_class_indices(logits, class_idx):\n",
    "    \"\"\"Return the indices of rows whose argmax prediction equals `class_idx`.\"\"\"\n",
    "    (indices,) = np.nonzero(np.argmax(logits, axis=1) == class_idx)\n",
    "    return indices\n",
    "\n",
    "def failure_data(indices, features, logits, labels):\n",
    "    \"\"\"Select rows `indices` and flag which of them are misclassified.\n",
    "\n",
    "    Returns (features[indices], boolean mask that is True where argmax(logits) != label).\n",
    "    \"\"\"\n",
    "    selected_features = features[indices]\n",
    "    selected_preds = np.argmax(logits[indices], axis=1)\n",
    "    is_failure = np.logical_not(selected_preds == labels[indices])\n",
    "    return selected_features, is_failure\n",
    "\n",
    "def normalize(train_data, test_data):\n",
    "    \"\"\"Standardise both splits with a StandardScaler fitted on train_data only.\"\"\"\n",
    "    scaler = StandardScaler()\n",
    "    scaler.fit(train_data)\n",
    "    return scaler.transform(train_data), scaler.transform(test_data)\n",
    "\n",
    "def load_images(indices, data_loader):\n",
    "    \"\"\"Fetch samples from data_loader.dataset by index.\n",
    "\n",
    "    Returns (images stacked along a new dim 0 as a tensor, labels as a numpy array),\n",
    "    both in the order of `indices`.\n",
    "    \"\"\"\n",
    "    samples = [data_loader.dataset[idx] for idx in indices]\n",
    "    images = ch.stack([img for img, _ in samples], dim=0)\n",
    "    labels = np.array([label for _, label in samples])\n",
    "    return images, labels\n",
    "\n",
    "def load_model(model_name):\n",
    "    \"\"\"Build a ResNet-50 for the global `dataset` with robustness.model_utils, in eval mode.\n",
    "\n",
    "    'ImageNetNat.pt' is built without a resume_path (whatever make_and_restore_model\n",
    "    does by default -- confirm); any other name is restored from './models/<model_name>.pth'.\n",
    "    \"\"\"\n",
    "    model_kwargs = {'arch': 'resnet50', 'dataset': dataset, 'parallel': False}\n",
    "    if model_name != 'ImageNetNat.pt':\n",
    "        model_kwargs['resume_path'] = './models/' + model_name + '.pth'\n",
    "    model, _ = model_utils.make_and_restore_model(**model_kwargs)\n",
    "    model.eval()\n",
    "    return model\n",
    "\n",
    "def load_model_fc(model_name):\n",
    "    \"\"\"Return (W, b): weight and bias of the model's final `fc` layer as numpy arrays.\n",
    "\n",
    "    The model is built like load_model, except that non-'ImageNetNat.pt' checkpoints are\n",
    "    read from './models/<model_name>' with no '.pth' suffix appended.\n",
    "    NOTE(review): confirm the differing suffix handling versus load_model is intended.\n",
    "    \"\"\"\n",
    "    model_kwargs = {'arch': 'resnet50', 'dataset': dataset, 'parallel': False}\n",
    "    if model_name != 'ImageNetNat.pt':\n",
    "        model_kwargs['resume_path'] = './models/' + model_name\n",
    "    model, _ = model_utils.make_and_restore_model(**model_kwargs)\n",
    "    model.eval()\n",
    "\n",
    "    fc_layer = model.model.fc\n",
    "    W = fc_layer.weight.detach().cpu().numpy()\n",
    "    b = fc_layer.bias.detach().cpu().numpy()\n",
    "    return W, b"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Decision Tree"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.tree import DecisionTreeClassifier, plot_tree, export_graphviz\n",
    "import graphviz \n",
    "from collections import defaultdict\n",
    "\n",
    "class CustomDecisionTreeClassifier(DecisionTreeClassifier):\n",
    "    \"\"\"DecisionTreeClassifier with leaf-level bookkeeping for failure analysis.\n",
    "\n",
    "    Labels are binary: truthy labels form the 'true' class (presumably the\n",
    "    misclassified samples flagged by failure_data -- confirm at the call site).\n",
    "    fit_tree must be called before the other helpers, which read attributes it sets.\n",
    "    \"\"\"\n",
    "\n",
    "    def fit_tree(self, train_data, train_labels):\n",
    "        \"\"\"Fit the tree, then record leaf statistics on the training data; return self.\n",
    "\n",
    "        Sets float_class_weight (weight of a true sample relative to a false one),\n",
    "        train_true_dict / train_false_dict (per-leaf label counts), parent,\n",
    "        leaf_ids (leaves reached by at least one training sample) and\n",
    "        true_leaves (leaves whose weighted true count exceeds their false count).\n",
    "        \"\"\"\n",
    "        num_true = np.sum(train_labels)\n",
    "        num_false = np.sum(np.logical_not(train_labels))\n",
    "        if self.class_weight == \"balanced\":\n",
    "            # 'balanced' weights are inversely proportional to class frequency,\n",
    "            # so the true:false weight ratio is num_false / num_true.\n",
    "            self.float_class_weight = num_false/num_true\n",
    "        elif isinstance(self.class_weight, dict):\n",
    "            keys_list = list(self.class_weight.keys())\n",
    "            assert len(keys_list)==2\n",
    "            assert 0 in keys_list\n",
    "            assert 1 in keys_list\n",
    "            # A leaf votes 1 when w1 * true_count > w0 * false_count, so the relevant\n",
    "            # weight is the ratio w1 / w0 (identical to w1 when w0 == 1).\n",
    "            self.float_class_weight = self.class_weight[1]/self.class_weight[0]\n",
    "        else:\n",
    "            # Unweighted tree (class_weight=None): both classes count equally. Without\n",
    "            # this branch float_class_weight was never set and the leaf loop crashed.\n",
    "            self.float_class_weight = 1.0\n",
    "\n",
    "        self.fit(train_data, train_labels)\n",
    "        true_dict, false_dict = self.compute_TF_dict(train_data, train_labels)\n",
    "        self.train_true_dict = dict(true_dict)\n",
    "        self.train_false_dict = dict(false_dict)\n",
    "        \n",
    "        self._compute_parent()\n",
    "        \n",
    "        true_array = np.array(list(true_dict))\n",
    "        false_array = np.array(list(false_dict))\n",
    "        unique_leaf_ids = np.union1d(true_array, false_array)\n",
    "        self.leaf_ids = unique_leaf_ids\n",
    "        \n",
    "        # Same weighted-majority rule the fitted tree applies when labelling a leaf.\n",
    "        true_leaves = []\n",
    "        for leaf_id in unique_leaf_ids:\n",
    "            true_count = true_dict[leaf_id]\n",
    "            false_count = false_dict[leaf_id]\n",
    "            if true_count*self.float_class_weight > false_count:\n",
    "                true_leaves.append(leaf_id)\n",
    "        self.true_leaves = true_leaves\n",
    "        return self\n",
    "    \n",
    "    def _compute_parent(self):\n",
    "        \"\"\"Fill self.parent with the parent node id of every node (the root keeps 0).\"\"\"\n",
    "        n_nodes = self.tree_.node_count\n",
    "        children_left = self.tree_.children_left\n",
    "        children_right = self.tree_.children_right\n",
    "\n",
    "        self.parent = np.zeros(shape=n_nodes, dtype=np.int64)\n",
    "        stack = [0]\n",
    "        while len(stack) > 0:\n",
    "            node_id = stack.pop()\n",
    "\n",
    "            child_left = children_left[node_id]\n",
    "            child_right = children_right[node_id]\n",
    "            # Leaves have identical left/right child markers; only split nodes recurse.\n",
    "            if (child_left != child_right):\n",
    "                self.parent[child_left] = node_id\n",
    "                self.parent[child_right] = node_id\n",
    "                stack.append(child_left)\n",
    "                stack.append(child_right)\n",
    "    \n",
    "    def compute_leaf_data(self, data, leaf_id):\n",
    "        \"\"\"Return the row indices of `data` that land in leaf `leaf_id`.\"\"\"\n",
    "        leaf_ids = self.apply(data)\n",
    "        return np.nonzero(leaf_ids==leaf_id)[0]\n",
    "        \n",
    "    def compute_TF_dict(self, data, labels):\n",
    "        \"\"\"Count, per leaf, the true and false samples of `data` that reach it.\n",
    "\n",
    "        Returns (true_dict, false_dict): defaultdicts mapping leaf id -> count,\n",
    "        yielding 0 for leaves without such samples.\n",
    "        \"\"\"\n",
    "        leaf_ids = self.apply(data)\n",
    "        true_leaf_ids = leaf_ids[np.nonzero(labels)]\n",
    "        false_leaf_ids = leaf_ids[np.nonzero(np.logical_not(labels))]\n",
    "        \n",
    "        true_unique, _, true_unique_counts = np.unique(true_leaf_ids, return_index=True, return_counts=True)\n",
    "        true_dict = create_dict(true_unique, true_unique_counts)\n",
    "        false_unique, _, false_unique_counts = np.unique(false_leaf_ids, return_index=True, return_counts=True)\n",
    "        false_dict = create_dict(false_unique, false_unique_counts)\n",
    "        return true_dict, false_dict\n",
    "    \n",
    "    def compute_precision_recall(self, data, labels, compute_AP=True):\n",
    "        \"\"\"Precision and recall of predicting 'true' for samples in self.true_leaves.\n",
    "\n",
    "        Returns (precision, recall) or, when compute_AP is set,\n",
    "        (precision, recall, average_precision). A zero denominator yields 0.0\n",
    "        instead of a division error.\n",
    "        \"\"\"\n",
    "        true_dict, false_dict = self.compute_TF_dict(data, labels)\n",
    "        total_true = np.sum(labels)\n",
    "        total_pred = 0\n",
    "        total = 0\n",
    "        for leaf_id in self.true_leaves:\n",
    "            true_count = true_dict[leaf_id]\n",
    "            false_count = false_dict[leaf_id]\n",
    "\n",
    "            total_pred += true_count\n",
    "            total += true_count + false_count\n",
    "            \n",
    "        # No samples in the true leaves (or no true labels) would otherwise divide by zero.\n",
    "        precision = total_pred/total if total > 0 else 0.\n",
    "        recall = total_pred/total_true if total_true > 0 else 0.\n",
    "        \n",
    "        if compute_AP:\n",
    "            average_precision = self.compute_average_precision(data, labels)\n",
    "            return precision, recall, average_precision\n",
    "        else:\n",
    "            return precision, recall\n",
    "    \n",
    "    def compute_average_precision(self, data, labels):\n",
    "        \"\"\"Sum over non-empty leaves of (leaf recall * leaf precision) for the true class.\"\"\"\n",
    "        num_true = np.sum(labels)\n",
    "        true_dict, false_dict = self.compute_TF_dict(data, labels)\n",
    "    \n",
    "        avg_precision = 0\n",
    "        for leaf_id in self.leaf_ids:\n",
    "            true_count = true_dict[leaf_id]\n",
    "            false_count = false_dict[leaf_id]\n",
    "            if true_count + false_count > 0:\n",
    "                curr_recall = true_count/num_true\n",
    "                curr_precision = true_count/(true_count + false_count)\n",
    "\n",
    "                avg_precision += curr_recall*curr_precision\n",
    "        return avg_precision\n",
    "\n",
    "    def compute_decision_path(self, leaf_id, important_features_indices):\n",
    "        \"\"\"List the splits from the root down to `leaf_id`.\n",
    "\n",
    "        Each entry is (node_id, original_feature_index, threshold rounded to 6\n",
    "        decimals, direction), where direction ('left' / 'right') is the branch\n",
    "        taken at node_id and the tree's feature column is mapped back through\n",
    "        important_features_indices.\n",
    "        \"\"\"\n",
    "        assert leaf_id in self.leaf_ids\n",
    "\n",
    "        features_arr = self.tree_.feature\n",
    "        thresholds_arr = self.tree_.threshold\n",
    "        \n",
    "        children_left = self.tree_.children_left\n",
    "        children_right = self.tree_.children_right\n",
    "        path = []\n",
    "        curr_node = leaf_id\n",
    "        # Walk upwards until the root (node 0), prepending each split.\n",
    "        while curr_node > 0:\n",
    "            parent_node = self.parent[curr_node]\n",
    "            \n",
    "            is_left_child = (children_left[parent_node] == curr_node)\n",
    "            is_right_child = (children_right[parent_node] == curr_node)\n",
    "            assert (is_left_child ^ is_right_child)\n",
    "\n",
    "            if is_left_child:\n",
    "                direction = 'left'\n",
    "            else:\n",
    "                direction = 'right'\n",
    "            curr_node = parent_node\n",
    "            curr_feature = features_arr[curr_node]\n",
    "            curr_threshold = np.round(thresholds_arr[curr_node], 6)\n",
    "            curr_feature_original = important_features_indices[curr_feature]\n",
    "            path.insert(0, (curr_node, curr_feature_original, curr_threshold, direction))\n",
    "        return path\n",
    "    \n",
    "    def compute_average_precision_recall(self, data, labels):\n",
    "        \"\"\"Per-node precision and recall of the true class over `data`.\n",
    "\n",
    "        A leaf gets precision = true / total samples in the leaf and recall =\n",
    "        true / all true samples. An internal node gets the recall-weighted mean of its\n",
    "        children's precision and the sum of their recall (post-order traversal).\n",
    "        Returns (precision_array, recall_array), indexed by node id.\n",
    "        \"\"\"\n",
    "        total_failures = np.sum(labels)\n",
    "        \n",
    "        true_dict, false_dict = self.compute_TF_dict(data, labels)\n",
    "\n",
    "        n_nodes = self.tree_.node_count\n",
    "        children_left = self.tree_.children_left\n",
    "        children_right = self.tree_.children_right\n",
    "\n",
    "        precision_array = np.zeros(shape=n_nodes, dtype=float)\n",
    "        recall_array = np.zeros(shape=n_nodes, dtype=float)\n",
    "        \n",
    "        # (node, traverse): traverse=True expands a node; False aggregates it after its children.\n",
    "        stack = [(0, True)]\n",
    "        while len(stack) > 0:\n",
    "            node_id, traverse = stack.pop()\n",
    "            child_left = children_left[node_id]\n",
    "            child_right = children_right[node_id]\n",
    "\n",
    "            if traverse:\n",
    "                if (child_left != child_right):\n",
    "                    stack.append((node_id, False))\n",
    "                    stack.append((child_left, True))\n",
    "                    stack.append((child_right, True))\n",
    "                else:\n",
    "                    num_true_in_node = true_dict[node_id]\n",
    "                    num_false_in_node = false_dict[node_id]\n",
    "                    num_total_in_node = num_true_in_node + num_false_in_node\n",
    "\n",
    "                    if num_total_in_node > 0:\n",
    "                        precision = (num_true_in_node/num_total_in_node)\n",
    "                    else:\n",
    "                        precision = 0.\n",
    "                    recall = (num_true_in_node/total_failures)\n",
    "                    recall_array[node_id] = recall\n",
    "                    precision_array[node_id] = precision\n",
    "            else:\n",
    "                child_left_p = precision_array[child_left]\n",
    "                child_right_p = precision_array[child_right]\n",
    "                \n",
    "                child_left_r = recall_array[child_left]\n",
    "                child_right_r = recall_array[child_right]\n",
    "                \n",
    "                child_p = child_left_p*child_left_r + child_right_p*child_right_r\n",
    "                child_r = child_left_r + child_right_r\n",
    "\n",
    "                if child_r > 0:\n",
    "                    precision_array[node_id] = child_p/child_r\n",
    "                else:\n",
    "                    precision_array[node_id] = 0.\n",
    "                recall_array[node_id] = child_r\n",
    "\n",
    "        return precision_array, recall_array"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Visualize features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def feature_attack_new(model, init_images, feature_indices, small_eps=3, \n",
    "                       large_eps=500, step_size=0.5, iterations=20):\n",
    "    \"\"\"Perturb images to raise one latent feature per image; return a map of the change.\n",
    "\n",
    "    Each of `iterations` steps maximises sum(features[b, feature_indices[b]]) via\n",
    "    grad_step (from causal_testing_utils; presumably a gradient-ascent step -- confirm).\n",
    "    It then projects the perturbation onto an L2 ball of radius `small_eps` around\n",
    "    `init_images` and clamps the result to [0, 1].\n",
    "\n",
    "    Returns:\n",
    "        numpy array (batch, H, W): channel-mean of the last projected perturbation,\n",
    "            min-max normalised per image.\n",
    "        detached tensor of the selected feature values on the final images.\n",
    "\n",
    "    `large_eps` is accepted for interface compatibility but is not used.\n",
    "    \"\"\"\n",
    "    init_images = init_images.cuda()\n",
    "    seed_images = ch.clone(init_images).cuda()\n",
    "    \n",
    "    batch_size = seed_images.shape[0]\n",
    "    # Defined up front so iterations == 0 no longer raises NameError after the loop.\n",
    "    diff_images_new = ch.zeros_like(init_images)\n",
    "    for i in range(iterations+1):\n",
    "        seed_images.requires_grad_()\n",
    "\n",
    "        (_, features), _ = model(seed_images, with_latent=True)\n",
    "        features_select = features[ch.arange(batch_size), feature_indices]\n",
    "        if i==iterations:\n",
    "            break\n",
    "            \n",
    "        adv_loss = features_select.sum()\n",
    "        grads = ch.autograd.grad(adv_loss, [seed_images])[0]\n",
    "        # Detach so successive steps do not chain into an ever-growing autograd graph.\n",
    "        seed_images = grad_step(seed_images.detach(), grads, step_size)\n",
    "    \n",
    "        diff_images = seed_images - init_images\n",
    "        diff_images_flat = diff_images.view(batch_size, -1)\n",
    "        diff_images_norm = ch.norm(diff_images_flat, dim=1, keepdim=True)\n",
    "        \n",
    "        # Project onto the L2 ball of radius small_eps. The clamp only guards a zero\n",
    "        # norm (0/0 -> NaN); any norm above 1e-12 is divided exactly as before.\n",
    "        diff_images_unit_flat = diff_images_flat/diff_images_norm.clamp(min=1e-12)\n",
    "        diff_images_flat_scaled = small_eps * diff_images_unit_flat\n",
    "        diff_images_new = ((diff_images_norm < small_eps) * (diff_images_flat)) + ((diff_images_norm >= small_eps) * (diff_images_flat_scaled))\n",
    "        diff_images_new = diff_images_new.view_as(diff_images)\n",
    "        seed_images = ch.clamp(init_images + diff_images_new, min=0., max=1.)\n",
    "    \n",
    "    # Channel-mean of the final projected perturbation, min-max normalised per image.\n",
    "    diff_images_mean = diff_images_new.mean(dim=1)\n",
    "    diff_images_flat = diff_images_mean.view(batch_size, -1) \n",
    "    diff_images_min, _ = ch.min(diff_images_flat, dim=1, keepdim=True)\n",
    "    diff_images_max, _ = ch.max(diff_images_flat, dim=1, keepdim=True)\n",
    "    # Clamp guards a constant map (max == min), which previously produced NaN.\n",
    "    gradcam_maps_flat = (diff_images_flat - diff_images_min)/(diff_images_max - diff_images_min).clamp(min=1e-12)\n",
    "    gradcam_maps = gradcam_maps_flat.view_as(diff_images_mean)\n",
    "\n",
    "    return gradcam_maps.detach().cpu().numpy(), features_select.detach()\n",
    "\n",
    "def compute_feature_maps(images, model, layer_name='layer4'):\n",
    "    \"\"\"Return the activations of `model` right after its child module `layer_name`.\n",
    "\n",
    "    Applies model._modules['normalizer'], then the children of model._modules['model']\n",
    "    in registration order. It stops after the child named `layer_name`, or after the\n",
    "    last child if no child has that name.\n",
    "    \"\"\"\n",
    "    images = images.cuda()\n",
    "    normalize_layer = model._modules['normalizer']\n",
    "    backbone = model._modules['model']\n",
    "    activations = normalize_layer(images)\n",
    "    for child_name, child in backbone._modules.items():\n",
    "        activations = child(activations)\n",
    "        if child_name == layer_name:\n",
    "            return activations\n",
    "    return activations\n",
    "\n",
    "def feature_removal(model, seed_images, gradcam_maps, feature_index):\n",
    "    \"\"\"Grey out low-saliency pixels of each image and re-measure one latent feature.\n",
    "\n",
    "    seed_images: tensor of channels-last images (batch, H, W, C); the (H, W, 1)\n",
    "        mask below broadcasts over the last axis. Values presumably in [0, 1].\n",
    "    gradcam_maps: tensor of per-image saliency maps, presumably (batch, H, W) in [0, 1].\n",
    "\n",
    "    Each map is scaled to uint8. An Otsu threshold is computed and lowered to 80% of\n",
    "    its value. Pixels at or below it are replaced by 0.5 (mid-grey).\n",
    "    Returns (img_ns, features_select): the masked images as a (batch, C, H, W) CUDA\n",
    "    tensor and features[:, feature_index] computed on them.\n",
    "    \"\"\"\n",
    "    seed_images = seed_images.detach().cpu().numpy()\n",
    "    gradcam_maps = gradcam_maps.detach().cpu().numpy()\n",
    "    \n",
    "    img_n_list = []\n",
    "    for (img, mask) in zip(seed_images, gradcam_maps):\n",
    "        mask = np.uint8(255 * mask)\n",
    "        # Only Otsu's threshold value is used; relaxing it to 80% keeps more of the\n",
    "        # salient region (THRESH_BINARY keeps pixels strictly above the threshold).\n",
    "        otsu_thres, _ = cv2.threshold(mask, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)\n",
    "        (_, th_mask) = cv2.threshold(mask, int(0.8 * otsu_thres), 255, cv2.THRESH_BINARY)\n",
    "        th_mask = np.expand_dims(np.float32(th_mask)/255, axis=2)\n",
    "        # Keep salient pixels, fill the rest with mid-grey.\n",
    "        img_n = img * th_mask + 0.5 * (1 - th_mask)\n",
    "        img_n_list.append(img_n)\n",
    "    \n",
    "    img_ns = np.stack(img_n_list, axis=0)\n",
    "    img_ns = ch.from_numpy(img_ns).permute(0, 3, 1, 2).cuda()\n",
    "    \n",
    "    # Inference only: no autograd graph is needed for the re-evaluation.\n",
    "    with ch.no_grad():\n",
    "        (_, features), _ = model(img_ns, with_latent=True)\n",
    "    features_select = features[:, feature_index]\n",
    "    return img_ns, features_select\n",
    "\n",
    "def display_images(decision_path, precision_array, recall_array, image_indices, \n",
    "                   data_loader, model, features, num_images=6):\n",
    "    \"\"\"Visualise every split on a decision path with the images that activate it most.\n",
    "\n",
    "    decision_path: (node_id, feature_id, threshold, direction) tuples, as produced by\n",
    "        CustomDecisionTreeClassifier.compute_decision_path.\n",
    "    precision_array / recall_array: per-node values, printed as error rate / coverage.\n",
    "    features: rows aligned with image_indices (the dataset index of each row).\n",
    "\n",
    "    For each split, the num_images rows with the largest value of that feature are\n",
    "    loaded and shown three ways: the images, the images after feature_removal,\n",
    "    and heatmaps from compute_gradcam / compute_heatmaps.\n",
    "    \"\"\"\n",
    "    for node_id, feature_id, feature_threshold, direction in decision_path:\n",
    "        node_precision = precision_array[node_id]\n",
    "        node_recall = recall_array[node_id]\n",
    "\n",
    "        # sklearn routes samples with X[:, feature] <= threshold to the left child.\n",
    "        if direction == 'left':\n",
    "            print('\\n****************************** Feature[{:d}] <= {:.6f} (left branching, lower feature evidence) **************************'.format(feature_id, feature_threshold))\n",
    "        else:\n",
    "            print('\\n****************************** Feature[{:d}] > {:.6f} (right branching, higher feature evidence) ************************'.format(feature_id, feature_threshold))\n",
    "            \n",
    "        print('*************************************** node_error_rate: {:.4f}, node_coverage: {:.4f} ************************************'.format(node_precision, node_recall))\n",
    "        \n",
    "        # Rows of `features` with the highest values of this feature (ascending order).\n",
    "        indices_high = np.argsort(features[:, feature_id])[-num_images:]\n",
    "        images_highest, labels_highest = load_images(image_indices[indices_high], data_loader)\n",
    "\n",
    "        images_captions = [CLASSES[label_index].split(',')[0] for label_index in labels_highest]\n",
    "        heatmaps_captions = ['most, {:.2f}'.format(features[index, feature_id]) for index in indices_high]\n",
    "            \n",
    "        gradcam_maps = compute_gradcam(model, images_highest, feature_id, layer_name='layer4')\n",
    "        images_heatmaps = compute_heatmaps(images_highest.permute(0, 2, 3, 1), gradcam_maps)\n",
    "        \n",
    "        images_removed, features_removed = feature_removal(model, images_highest.permute(0, 2, 3, 1), \n",
    "                                                           gradcam_maps, feature_id)\n",
    "        # One caption per image actually loaded (fewer than num_images if `features` is short).\n",
    "        captions_removal = ['most, {:.2f}'.format(value) for value in features_removed]\n",
    "            \n",
    "        show_image_row([images_highest.cpu()], ['images'], tlist=[images_captions], fontsize=18)\n",
    "\n",
    "        show_image_row([images_removed.cpu()], ['feature removal'], \n",
    "                       tlist=[captions_removal], fontsize=18)\n",
    "        \n",
    "        show_image_row([images_heatmaps.cpu()], ['heatmaps'], \n",
    "                       tlist=[heatmaps_captions], fontsize=18)\n",
    "\n",
    "        \n",
    "def display_failures(image_indices, leaf_failure_indices, data_loader, num_images=6, num_rows=1):\n",
    "    \"\"\"Print and show a random sample of failure images, `num_images` per row.\n",
    "\n",
    "    Args:\n",
    "        image_indices: array mapping positions stored in `leaf_failure_indices`\n",
    "            to dataset indices understood by `load_images`.\n",
    "        leaf_failure_indices: positions of the failure samples to draw from.\n",
    "        data_loader: loader whose dataset supplies the images.\n",
    "        num_images: images per row.\n",
    "        num_rows: number of rows to show.\n",
    "    \"\"\"\n",
    "    num_total = num_images * num_rows\n",
    "    if len(leaf_failure_indices) == 0 or num_total <= 0:\n",
    "        # np.random.choice cannot sample from an empty population.\n",
    "        print('No failure samples to display.')\n",
    "        return\n",
    "    # Sample without replacement whenever there are enough failures; a strict\n",
    "    # `>` test would resample with replacement (duplicate images) when exactly\n",
    "    # `num_total` failures are available.\n",
    "    replace = len(leaf_failure_indices) < num_total\n",
    "    leaf_select_failures = np.random.choice(leaf_failure_indices, num_total, replace=replace)\n",
    "    image_indices_failures = image_indices[leaf_select_failures]\n",
    "\n",
    "    print('\\n****************************************************** Failure samples ****************************************************')\n",
    "    for row in range(num_rows):\n",
    "        start = row * num_images\n",
    "        image_indices_select = image_indices_failures[start:start + num_images]\n",
    "        images_failures, labels_failures = load_images(image_indices_select, data_loader)\n",
    "        full_images_captions = []\n",
    "        images_captions = []\n",
    "        for i in range(len(images_failures)):\n",
    "            label_index = labels_failures[i]\n",
    "            label_string = CLASSES[label_index].split(',')[0]\n",
    "            # Long class names crowd the figure; fall back to the numeric label.\n",
    "            if len(label_string) > 12:\n",
    "                label_string = str(label_index)\n",
    "            images_captions.append(label_string)\n",
    "            full_images_captions.append(CLASSES[label_index])\n",
    "        print(full_images_captions)\n",
    "        show_image_row([images_failures.cpu()], ['failures'], tlist=[images_captions], fontsize=18)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Mean feature selection"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Using a non-robust model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.feature_selection import mutual_info_classif\n",
    "\n",
    "def mutual_info_select_features(train_features_class, train_failure_class, num_features=20):\n",
    "    \"\"\"Pick the `num_features` columns with the highest mutual information with the failure labels.\n",
    "\n",
    "    Mutual information is estimated with a fixed random_state for reproducibility.\n",
    "    Returns (indices, scores), ordered from the lowest to the highest score among\n",
    "    the selected features (plain argsort order).\n",
    "    \"\"\"\n",
    "    mi_scores = mutual_info_classif(train_features_class, train_failure_class, random_state=0)\n",
    "    top_indices = np.argsort(mi_scores)[-num_features:]\n",
    "    return top_indices, mi_scores[top_indices]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def make_if_not_exists(folder_path):\n",
    "    \"\"\"Create `folder_path` (including missing parents) unless it already exists.\n",
    "\n",
    "    exist_ok=True removes the race between checking for the directory and\n",
    "    creating it; a non-directory already at `folder_path` raises FileExistsError.\n",
    "    \"\"\"\n",
    "    os.makedirs(folder_path, exist_ok=True)\n",
    "        \n",
    "def class_failure_data(robust_preds, nonrobust_preds, labels, num_classes=1000):\n",
    "    \"\"\"Per-class counts of labels, predictions and errors for a robust and a non-robust model.\n",
    "\n",
    "    For each class c in range(num_classes):\n",
    "      * num_*_failures_pred:  samples predicted as c whose label is not c;\n",
    "      * num_*_failures_label: samples labelled c that were predicted as another class.\n",
    "\n",
    "    Args:\n",
    "        robust_preds, nonrobust_preds, labels: equal-length 1-D sequences of class\n",
    "            ids (converted with np.asarray, so plain lists also work).\n",
    "        num_classes: number of classes to report; defaults to 1000 (ImageNet).\n",
    "            `CLASSES` must have an entry for every index below it.\n",
    "\n",
    "    Returns:\n",
    "        dict of equal-length lists, one entry per class (ready for pd.DataFrame).\n",
    "    \"\"\"\n",
    "    robust_preds = np.asarray(robust_preds)\n",
    "    nonrobust_preds = np.asarray(nonrobust_preds)\n",
    "    labels = np.asarray(labels)\n",
    "    dic = {\n",
    "        'class_idx': [],\n",
    "        'class_name': [],\n",
    "        'num_labels': [],\n",
    "        'num_robust_preds': [],\n",
    "        'num_nonrobust_preds': [],\n",
    "        'num_robust_failures_pred': [],\n",
    "        'num_robust_failures_label': [],\n",
    "        'num_nonrobust_failures_pred': [],\n",
    "        'num_nonrobust_failures_label': []\n",
    "    }\n",
    "    for class_idx in range(num_classes):\n",
    "        # Boolean masks are computed once per class and reused below.\n",
    "        is_label = labels == class_idx\n",
    "        is_robust_pred = robust_preds == class_idx\n",
    "        is_nonrobust_pred = nonrobust_preds == class_idx\n",
    "\n",
    "        dic['class_idx'].append(class_idx)\n",
    "        dic['class_name'].append(CLASSES[class_idx])\n",
    "        dic['num_labels'].append(np.sum(is_label))\n",
    "        dic['num_robust_preds'].append(np.sum(is_robust_pred))\n",
    "        dic['num_nonrobust_preds'].append(np.sum(is_nonrobust_pred))\n",
    "        dic['num_robust_failures_pred'].append(np.sum(~is_label & is_robust_pred))\n",
    "        dic['num_robust_failures_label'].append(np.sum(is_label & ~is_robust_pred))\n",
    "        dic['num_nonrobust_failures_pred'].append(np.sum(~is_label & is_nonrobust_pred))\n",
    "        dic['num_nonrobust_failures_label'].append(np.sum(is_label & ~is_nonrobust_pred))\n",
    "    return dic\n",
    "\n",
    "def feature_importance_data(feature_indices, feature_values):\n",
    "    \"\"\"Pack parallel sequences of feature indices and importances into a column dict.\n",
    "\n",
    "    Pairs are formed with zip, so the result is truncated to the shorter input.\n",
    "    \"\"\"\n",
    "    pairs = list(zip(feature_indices, feature_values))\n",
    "    return {\n",
    "        'feature_index': [idx for idx, _ in pairs],\n",
    "        'feature_importance': [imp for _, imp in pairs],\n",
    "    }\n",
    "\n",
    "def mean_select_features(robust_features, class_idx, \n",
    "                         robust_model, num_features=20):\n",
    "    \"\"\"Rank features by their mean contribution to the logit of `class_idx`.\n",
    "\n",
    "    The contribution of feature j on a sample is features[:, j] * W[class_idx, j],\n",
    "    with W the weight of `robust_model.model.fc`; it is averaged over samples.\n",
    "    Returns (indices, mean_contributions) of the `num_features` largest\n",
    "    contributions, in descending order.\n",
    "    \"\"\"\n",
    "    fc_weights = robust_model.model.fc.weight.detach().cpu().numpy()\n",
    "    # Slice (not index) so the row stays 2-D and broadcasts over the samples.\n",
    "    class_row = fc_weights[class_idx:class_idx + 1, :]\n",
    "    contributions = np.mean(robust_features * class_row, axis=0)\n",
    "\n",
    "    ranked = np.argsort(-contributions)[:num_features]\n",
    "    return ranked, contributions[ranked]\n",
    "\n",
    "def load_images_filenames(indices, data_loader):\n",
    "    \"\"\"Fetch images, file paths and labels for the given dataset indices.\n",
    "\n",
    "    Reads `data_loader.dataset.samples[idx]` as a (path, _) pair for the path and\n",
    "    indexes the dataset for the (image, label) pair. Returns a stacked image\n",
    "    tensor, a numpy array of paths and a numpy array of labels.\n",
    "    \"\"\"\n",
    "    dataset = data_loader.dataset\n",
    "    images, paths, labels = [], [], []\n",
    "    for idx in indices:\n",
    "        path, _ = dataset.samples[idx]\n",
    "        image, label = dataset[idx]\n",
    "        images.append(image)\n",
    "        paths.append(path)\n",
    "        labels.append(label)\n",
    "    return ch.stack(images, dim=0), np.array(paths), np.array(labels)\n",
    "\n",
    "def compute_activating_indices(feature_id, features, image_indices, num_images=10):\n",
    "    \"\"\"Return dataset indices and values of the rows most activating `feature_id`.\n",
    "\n",
    "    The `num_images` rows with the largest features[:, feature_id] are returned\n",
    "    in descending order of that value.\n",
    "    \"\"\"\n",
    "    column = features[:, feature_id]\n",
    "    top_rows = np.argsort(-column)[:num_images]\n",
    "    return image_indices[top_rows], column[top_rows]\n",
    "\n",
    "def compute_gradcam(model, images, feature_index, layer_name='layer4'):\n",
    "    \"\"\"Build an inverted, image-sized activation mask for one feature channel.\n",
    "\n",
    "    Despite the name, no gradients are used: the mask is channel `feature_index`\n",
    "    of the `layer_name` feature maps (from `compute_feature_maps`), divided per\n",
    "    image by its maximum, resized to the image resolution and inverted\n",
    "    (1 - value), so the most strongly activating regions map to 0.\n",
    "\n",
    "    Args:\n",
    "        model: passed through to `compute_feature_maps`.\n",
    "        images: (B, C, H, W) tensor.\n",
    "        feature_index: channel of the layer's feature maps.\n",
    "        layer_name: layer whose feature maps are read.\n",
    "\n",
    "    Returns:\n",
    "        CPU tensor of shape (B, H, W).\n",
    "    \"\"\"\n",
    "    b_size = images.shape[0]\n",
    "    out_h, out_w = int(images.shape[2]), int(images.shape[3])\n",
    "    feature_maps = compute_feature_maps(images, model, layer_name=layer_name)\n",
    "    gradcam_maps = (feature_maps[:, feature_index, :, :]).detach()\n",
    "    # reshape (not view) because a channel slice need not be contiguous.\n",
    "    gradcam_maps_flat = gradcam_maps.reshape(b_size, -1)\n",
    "    gradcam_maps_max, _ = ch.max(gradcam_maps_flat, dim=1, keepdim=True)\n",
    "    # An all-zero map would give 0/0 = NaN; divide it by 1 instead (stays 0).\n",
    "    gradcam_maps_max = ch.where(gradcam_maps_max == 0, ch.ones_like(gradcam_maps_max), gradcam_maps_max)\n",
    "    gradcam_maps = (gradcam_maps_flat / gradcam_maps_max).view_as(gradcam_maps)\n",
    "\n",
    "    gradcam_maps_resized = []\n",
    "    for gradcam_map in gradcam_maps:\n",
    "        # cv2.resize takes dsize as (width, height); (H, W) would transpose non-square outputs.\n",
    "        gradcam_maps_resized.append(cv2.resize(gradcam_map.cpu().numpy(), (out_w, out_h)))\n",
    "    gradcam_maps = np.stack(gradcam_maps_resized, axis=0)\n",
    "    return ch.from_numpy(1-gradcam_maps)\n",
    "\n",
    "def compute_heatmaps(imgs, masks):\n",
    "    \"\"\"Overlay a JET colour map of each mask on its image.\n",
    "\n",
    "    Args:\n",
    "        imgs: iterable of (H, W, 3) images, presumably with values in [0, 1].\n",
    "        masks: iterable of (H, W) masks in [0, 1], paired with `imgs`.\n",
    "\n",
    "    Returns:\n",
    "        (N, 3, H, W) tensor; each overlay is divided by its own maximum.\n",
    "\n",
    "    NOTE(review): cv2.applyColorMap returns BGR while `imgs` are presumably RGB,\n",
    "    and the channels are added as-is. Combined with the inverted masks from\n",
    "    `compute_gradcam`, this likely makes high-activation regions look red;\n",
    "    confirm before changing either side.\n",
    "    \"\"\"\n",
    "    overlays = []\n",
    "    for img, mask in zip(imgs, masks):\n",
    "        colour = np.float32(cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)) / 255\n",
    "        blended = colour + np.float32(img)\n",
    "        overlays.append(blended / np.max(blended))\n",
    "    stacked = np.stack(overlays, axis=0)\n",
    "    return ch.from_numpy(stacked).permute(0, 3, 1, 2)\n",
    "\n",
    "def grad_step(adv_inputs, grad, step_size):\n",
    "    \"\"\"Move each sample by `step_size` along its L2-normalised gradient.\n",
    "\n",
    "    Per-sample gradient norms get 1e-10 added to avoid division by zero.\n",
    "    \"\"\"\n",
    "    per_sample_norm = ch.norm(grad.view(grad.shape[0], -1), dim=1)\n",
    "    broadcast_shape = (-1,) + (1,) * (adv_inputs.dim() - 1)\n",
    "    unit_grad = grad / (per_sample_norm.view(*broadcast_shape) + 1e-10)\n",
    "    return adv_inputs + unit_grad * step_size\n",
    "\n",
    "def feature_attack(model, seed_images, feature_indices, eps=500, \n",
    "                   step_size=1, iterations=1000):\n",
    "    \"\"\"Maximise selected latent features by L2-normalised gradient ascent on the input.\n",
    "\n",
    "    Args:\n",
    "        model: called as `model(x, with_latent=True)` and unpacked as\n",
    "            ((_, latent_features), _).\n",
    "        seed_images: (B, C, H, W) starting images; moved to the GPU. The caller's\n",
    "            tensor is not modified.\n",
    "        feature_indices: int or per-sample index array of length B choosing the\n",
    "            latent feature to maximise for each image.\n",
    "        eps: NOTE(review): unused -- no projection onto an eps-ball is applied;\n",
    "            kept for interface compatibility.\n",
    "        step_size: L2 length of each step (see `grad_step`).\n",
    "        iterations: number of ascent steps.\n",
    "\n",
    "    Returns:\n",
    "        (images, features): detached images clamped to [0, 1] after the last step,\n",
    "        and a numpy array of the selected feature values at those images.\n",
    "    \"\"\"\n",
    "    # .cuda() returns the same tensor when it is already on the GPU; detach() gives\n",
    "    # a new tensor object so requires_grad_() never flips the caller's flag.\n",
    "    seed_images = seed_images.cuda().detach()\n",
    "    batch_arange = ch.arange(seed_images.shape[0])\n",
    "    for _step in range(iterations):\n",
    "        seed_images.requires_grad_()\n",
    "        (_, features), _ = model(seed_images, with_latent=True)\n",
    "        adv_loss = features[batch_arange, feature_indices].sum()\n",
    "        grads = ch.autograd.grad(adv_loss, [seed_images])[0]\n",
    "\n",
    "        seed_images = grad_step(seed_images.detach(), grads, step_size)\n",
    "        seed_images = ch.clamp(seed_images, min=0., max=1.)\n",
    "\n",
    "    # The final read-out needs no autograd graph.\n",
    "    with ch.no_grad():\n",
    "        (_, features), _ = model(seed_images, with_latent=True)\n",
    "        features_select = features[batch_arange, feature_indices]\n",
    "    return seed_images.detach(), features_select.detach().cpu().numpy()\n",
    "\n",
    "def save_image_data(feature_path, images, images_gradcams, \n",
    "                    images_heatmaps, images_attack, metadata):\n",
    "    \"\"\"Write metadata and four aligned image sets for one feature to disk.\n",
    "\n",
    "    Layout under `feature_path` (created if missing):\n",
    "        metadata.csv, images/<i>.jpg, gradcams/<i>.jpg,\n",
    "        heatmaps/<i>.jpg, feature_attack/<i>.jpg\n",
    "\n",
    "    Args:\n",
    "        feature_path: output directory.\n",
    "        images, images_heatmaps, images_attack: sequences of (3, H, W) tensors,\n",
    "            presumably RGB in [0, 1]; converted to uint8 BGR for cv2.\n",
    "        images_gradcams: sequence of (H, W) tensors in [0, 1], saved as grayscale.\n",
    "        metadata: dict of columns for `pd.DataFrame.from_dict`.\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: if the four image sequences differ in length.\n",
    "        IOError: if cv2.imwrite reports failure (it returns False instead of raising).\n",
    "\n",
    "    NOTE(review): everything, masks included, is stored as lossy JPEG.\n",
    "    \"\"\"\n",
    "    len_ims = len(images)\n",
    "    assert (len_ims==len(images_gradcams)) and (len_ims==len(images_heatmaps)) and (len_ims==len(images_attack))\n",
    "\n",
    "    # to_csv raises if the target directory does not exist yet.\n",
    "    make_if_not_exists(feature_path)\n",
    "    metadata_df = pd.DataFrame.from_dict(metadata)\n",
    "    metadata_df.to_csv(os.path.join(feature_path, 'metadata.csv'))\n",
    "\n",
    "    def to_bgr_uint8(tensor):\n",
    "        # (3, H, W) tensor in [0, 1] -> (H, W, 3) uint8 BGR array.\n",
    "        rgb = np.uint8(255*tensor.permute(1, 2, 0).cpu().numpy())\n",
    "        return cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)\n",
    "\n",
    "    def to_gray_uint8(tensor):\n",
    "        # (H, W) tensor in [0, 1] -> (H, W) uint8 array.\n",
    "        return np.uint8(255*tensor.cpu().numpy())\n",
    "\n",
    "    # (sub-directory, tensors, converter), written in this order for each index.\n",
    "    outputs = [\n",
    "        ('images', images, to_bgr_uint8),\n",
    "        ('gradcams', images_gradcams, to_gray_uint8),\n",
    "        ('heatmaps', images_heatmaps, to_bgr_uint8),\n",
    "        ('feature_attack', images_attack, to_bgr_uint8),\n",
    "    ]\n",
    "    for subdir, _, _ in outputs:\n",
    "        make_if_not_exists(os.path.join(feature_path, subdir))\n",
    "\n",
    "    for i in range(len_ims):\n",
    "        for subdir, tensors, convert in outputs:\n",
    "            out_name = os.path.join(feature_path, subdir, str(i) + '.jpg')\n",
    "            if not cv2.imwrite(out_name, convert(tensors[i])):\n",
    "                raise IOError('cv2.imwrite failed for ' + out_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_features_histograms(features_all, features_subset, class_index, binwidth = 0.16):\n",
    "    \"\"\"Plot three side-by-side histograms of one feature's values.\n",
    "\n",
    "    Panels: all images, the subset for `class_index`, and the cumulative\n",
    "    histogram of that subset. `class_index` must be an integer (formatted with {:d}).\n",
    "    \"\"\"\n",
    "    f, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 3))\n",
    "    panels = [\n",
    "        (ax1, features_all, False, \"histogram of all features\"),\n",
    "        (ax2, features_subset, False, \"histogram with prediction: {:d}\".format(class_index)),\n",
    "        (ax3, features_subset, True, \"cumulative histogram with prediction: {:d}\".format(class_index)),\n",
    "    ]\n",
    "    for ax, values, cumulative, label in panels:\n",
    "        sns.histplot(values, ax=ax, binwidth=binwidth, cumulative=cumulative,\n",
    "                     color=\"skyblue\", label=label)\n",
    "        ax.set_xlabel(\"Feature value\")\n",
    "        ax.set_ylabel(\"Number of images\")\n",
    "        # `label` is only rendered through a legend, which was never drawn before.\n",
    "        ax.legend()\n",
    "    f.tight_layout()\n",
    "    plt.show()\n",
    "    \n",
    "def visualize_feature_selected(indices_select, feature_index, image_indices, data_loader, \n",
    "                               model, features, step_size=40, iterations=25):\n",
    "    \"\"\"Show selected images, their activation heatmaps and feature-attacked versions.\n",
    "\n",
    "    Args:\n",
    "        indices_select: row positions into `features` and `image_indices`.\n",
    "        feature_index: latent feature to visualise.\n",
    "        image_indices: maps row positions to dataset indices for `load_images`.\n",
    "        data_loader: loader whose dataset supplies the images.\n",
    "        model: used for the activation maps and the feature attack.\n",
    "        features: 2-D array of latent features, rows aligned with `image_indices`.\n",
    "        step_size, iterations: feature-attack settings (defaults 40 and 25).\n",
    "    \"\"\"\n",
    "    image_indices_select = image_indices[indices_select]\n",
    "    images_select, labels_select = load_images(image_indices_select, data_loader)\n",
    "\n",
    "    images_captions = []\n",
    "    heatmaps_captions = []\n",
    "    for label_index, row in zip(labels_select, indices_select):\n",
    "        images_captions.append(CLASSES[label_index].split(',')[0])\n",
    "        heatmaps_captions.append('most, {:.2f}'.format(features[row, feature_index]))\n",
    "\n",
    "    gradcam_maps = compute_gradcam(model, images_select, feature_index, layer_name='layer4')\n",
    "    images_heatmaps = compute_heatmaps(images_select.permute(0, 2, 3, 1), gradcam_maps)\n",
    "\n",
    "    images_attack, features_attack = feature_attack(model, images_select, feature_index, \n",
    "                                                    step_size=step_size, iterations=iterations)\n",
    "    attack_captions = ['most, {:.2f}'.format(feature_val) for feature_val in features_attack]\n",
    "\n",
    "    show_image_row([images_select.cpu()], ['images'], \n",
    "                   tlist=[images_captions], fontsize=18)\n",
    "\n",
    "    show_image_row([images_heatmaps.cpu()], ['heatmaps'], \n",
    "                   tlist=[heatmaps_captions], fontsize=18)\n",
    "\n",
    "    show_image_row([images_attack.cpu()], ['feature attack'], \n",
    "                   tlist=[attack_captions], fontsize=18)\n",
    "    \n",
    "def _masked_noise_images(images, masks, noise_fn, **noise_kwargs):\n",
    "    \"\"\"Apply `noise_fn(image_hwc, mask, **noise_kwargs)` per image and stack as (N, C, H, W).\"\"\"\n",
    "    noisy = []\n",
    "    for image, mask in zip(images, masks):\n",
    "        image_hwc = np.transpose(image, (1, 2, 0))\n",
    "        noisy_hwc = noise_fn(image_hwc, mask, **noise_kwargs)\n",
    "        noisy.append(ch.from_numpy(np.transpose(noisy_hwc, (2, 0, 1))))\n",
    "    return ch.stack(noisy, dim=0)\n",
    "\n",
    "def visualize_feature_heatmap(image_indices_select, feature_index, data_loader, model, \n",
    "                              features, otsu_fraction=0.6, add_noise_std=0.25, \n",
    "                              replace_noise_std=0.01):\n",
    "    \"\"\"Show activation heatmaps plus noise-added and noise-replaced versions of images.\n",
    "\n",
    "    The activation mask from `compute_gradcam` drives two corruptions through the\n",
    "    module-level `add_gaussian_noise` and `replace_gaussian_noise` helpers, which\n",
    "    are not defined in this cell (presumably in causal_testing_utils -- verify).\n",
    "\n",
    "    NOTE(review): captions index `features` with `image_indices_select`, i.e. with\n",
    "    dataset indices rather than row positions (contrast `visualize_feature_selected`).\n",
    "    That is only right if `features` rows are aligned with dataset indices -- confirm.\n",
    "    \"\"\"\n",
    "    images_select, _ = load_images(image_indices_select, data_loader)\n",
    "\n",
    "    heatmaps_captions = ['most, {:.2f}'.format(features[index, feature_index])\n",
    "                         for index in image_indices_select]\n",
    "\n",
    "    gradcam_maps = compute_gradcam(model, images_select, feature_index, layer_name='layer4')\n",
    "    images_heatmaps = compute_heatmaps(images_select.permute(0, 2, 3, 1), gradcam_maps)\n",
    "\n",
    "    images_add_noise = _masked_noise_images(images_select, gradcam_maps, add_gaussian_noise,\n",
    "                                            otsu_fraction=otsu_fraction, noise_std=add_noise_std)\n",
    "    images_replace_noise = _masked_noise_images(images_select, gradcam_maps, replace_gaussian_noise,\n",
    "                                                otsu_fraction=otsu_fraction, noise_std=replace_noise_std)\n",
    "\n",
    "    show_image_row([images_heatmaps.cpu()], ['heatmaps'], \n",
    "                   tlist=[heatmaps_captions], fontsize=18)\n",
    "\n",
    "    show_image_row([images_add_noise.cpu()], ['add'], \n",
    "                   tlist=[heatmaps_captions], fontsize=18)\n",
    "\n",
    "    show_image_row([images_replace_noise.cpu()], ['replace'], \n",
    "                   tlist=[heatmaps_captions], fontsize=18)\n",
    "    \n",
    "def visualize_feature(feature_index, image_indices, data_loader, model, features, num_images=6, num_subset=50):\n",
    "    \"\"\"Visualise the `num_images` most strongly activating images for one feature.\n",
    "\n",
    "    Shows the images, their activation heatmaps and feature-attacked versions by\n",
    "    delegating to `visualize_feature_selected` with the rows of `features` that have\n",
    "    the largest features[:, feature_index] (highest first). Delegating also avoids\n",
    "    the IndexError the old `range(num_images)` loop raised when fewer rows existed.\n",
    "\n",
    "    `num_subset` is accepted for backward compatibility but is not used.\n",
    "    \"\"\"\n",
    "    top_rows = np.argsort(-features[:, feature_index])[:num_images]\n",
    "    visualize_feature_selected(top_rows, feature_index, image_indices, data_loader,\n",
    "                               model, features)\n",
    "    \n",
    "def sandwich_indices(features, feature_index, percentile=5):\n",
    "    \"\"\"Row indices of the top `percentile` percent of features[:, feature_index].\n",
    "\n",
    "    Indices are ordered from the highest to the lowest value; the subset size is\n",
    "    int(percentile * N / 100) and may be zero.\n",
    "    \"\"\"\n",
    "    column = features[:, feature_index]\n",
    "    subset_size = int((percentile * len(column)) / 100)\n",
    "    return np.argsort(-column)[:subset_size]\n",
    "\n",
    "def select_indices_class(train_logits, class_idx, num_select=None):\n",
    "    \"\"\"Select sample indices associated with `class_idx`.\n",
    "\n",
    "    With num_select=None: indices whose argmax prediction is `class_idx`.\n",
    "    Otherwise: the `num_select` samples with the highest softmax probability for\n",
    "    `class_idx`, in ascending order of that probability. num_select <= 0 yields an\n",
    "    empty array (a bare `[-0:]` slice would have returned every index).\n",
    "    \"\"\"\n",
    "    if num_select is None:\n",
    "        train_preds = np.argmax(train_logits, axis=1)\n",
    "        return np.nonzero(train_preds==class_idx)[0]\n",
    "    train_probs = ch.softmax(ch.from_numpy(train_logits), dim=1).cpu().numpy()\n",
    "    order = np.argsort(train_probs[:, class_idx])\n",
    "    num_keep = max(0, min(int(num_select), len(order)))\n",
    "    return order[len(order) - num_keep:]\n",
    "\n",
    "def compute_mask(cam, otsu_fraction):\n",
    "    \"\"\"Binary mask of the low-valued region of `cam` using a scaled Otsu threshold.\n",
    "\n",
    "    `cam` is cast with np.uint8 (so it presumably holds 0-255 values). Otsu's method\n",
    "    picks a threshold t; pixels above int(otsu_fraction * t) become 0 in the mask and\n",
    "    all other pixels become 1.\n",
    "\n",
    "    Returns:\n",
    "        (mask, thresh): float32 mask shaped like `cam`, and the scaled threshold.\n",
    "    \"\"\"\n",
    "    cam_u8 = np.uint8(cam)\n",
    "    otsu_thresh, _ = cv2.threshold(cam_u8, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)\n",
    "    scaled_thresh = int(otsu_fraction * otsu_thresh)\n",
    "    _, binary = cv2.threshold(cam_u8, scaled_thresh, 255, cv2.THRESH_BINARY)\n",
    "    return 1. - (np.float32(binary)/255.), scaled_thresh"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CausalFailureExplanation:\n",
    "    \"\"\"Measure how masked Gaussian noise changes a model's predictions on saved images.\n",
    "\n",
    "    Images and their feature masks are read through `CustomDataSet` from\n",
    "    `<data_root>/class_<k>/...`; noise is added where the mask is non-zero and the\n",
    "    fraction of images predicted as the class is compared before and after.\n",
    "\n",
    "    TODO: a control group (noise of matched L2 norm spread over the whole image)\n",
    "    is not implemented.\n",
    "    \"\"\"\n",
    "    def __init__(self, robust_model, train_loader, train_features, train_logits, train_labels, \n",
    "                 batch_size=32, otsu_fraction=0.6, add_noise_mean=0., add_noise_std=0.25, \n",
    "                 replace_noise_mean=0.5, replace_noise_std=0.01,  images_per_feature=65,\n",
    "                 data_root='causal_imagenet'):\n",
    "        self.robust_model = robust_model\n",
    "        self.train_loader = train_loader\n",
    "\n",
    "        self.train_features = train_features\n",
    "        self.train_logits = train_logits\n",
    "        self.train_preds = np.argmax(train_logits, axis=1)\n",
    "        self.train_labels = train_labels\n",
    "\n",
    "        self.batch_size = batch_size\n",
    "        self.otsu_fraction = otsu_fraction\n",
    "\n",
    "        self.add_noise_mean = add_noise_mean\n",
    "        self.add_noise_std = add_noise_std\n",
    "        # NOTE(review): the replace-noise settings are stored but no method of this\n",
    "        # class uses them.\n",
    "        self.replace_noise_mean = replace_noise_mean\n",
    "        self.replace_noise_std = replace_noise_std\n",
    "\n",
    "        self.images_per_feature = images_per_feature\n",
    "        # Root directory holding the saved class_<k>/feature_<j> folders.\n",
    "        self.data_root = data_root\n",
    "\n",
    "    def add_gaussian_noise(self, images, soft_masks, add_noise_std=None):\n",
    "        \"\"\"Add mask-weighted Gaussian noise to CPU `images` and clamp to [0, 1].\n",
    "\n",
    "        Noise is self.add_noise_mean + add_noise_std * N(0, 1) with the shape of\n",
    "        `images`, multiplied by `soft_masks`, so pixels whose mask is 0 change only\n",
    "        through clamping. add_noise_std=None falls back to self.add_noise_std\n",
    "        (previously None raised a TypeError).\n",
    "        \"\"\"\n",
    "        if add_noise_std is None:\n",
    "            add_noise_std = self.add_noise_std\n",
    "        gaussian_noise = self.add_noise_mean + add_noise_std * ch.randn(*images.shape)\n",
    "        images_n = images + (gaussian_noise * soft_masks)\n",
    "        return ch.clamp(images_n, 0., 1.)\n",
    "\n",
    "    def corrupted_images(self, images, masks, corruption_type=\"add\", add_noise_std=None):\n",
    "        \"\"\"Return float CPU versions of `images`, corrupted when corruption_type is \"add\".\n",
    "\n",
    "        \"add\" applies `add_gaussian_noise` with `masks`; \"none\" returns the images as-is.\n",
    "        \"\"\"\n",
    "        assert corruption_type in [\"none\", \"add\"]\n",
    "\n",
    "        images = images.float().cpu()\n",
    "        masks = masks.float().cpu()\n",
    "        if corruption_type == \"add\":\n",
    "            return self.add_gaussian_noise(images, masks, add_noise_std)\n",
    "        return images\n",
    "\n",
    "    def compute_causal_acc(self, inspection_model, class_index, feature_indices, add_noise_std):\n",
    "        \"\"\"Fraction of images predicted as `class_index` before and after masked noise.\n",
    "\n",
    "        Loads every unique saved image (and the union of its feature masks) for\n",
    "        `feature_indices` from `<data_root>/class_<class_index>`, predicts with\n",
    "        `inspection_model` on the GPU, then predicts again on the \"add\"-corrupted\n",
    "        images.\n",
    "\n",
    "        Returns:\n",
    "            (original_acc, noise_trmt_acc)\n",
    "\n",
    "        Raises:\n",
    "            ValueError: if no images are found.\n",
    "        \"\"\"\n",
    "        causal_dataset = CustomDataSet(self.data_root, class_index, feature_indices,\n",
    "                                       otsu_fraction=self.otsu_fraction,\n",
    "                                       images_per_feature=self.images_per_feature)\n",
    "        total = len(causal_dataset)\n",
    "        if total == 0:\n",
    "            raise ValueError('No causal images found for class ' + str(class_index))\n",
    "\n",
    "        data_loader = DataLoader(causal_dataset, batch_size=self.batch_size, shuffle=False)\n",
    "\n",
    "        preds_all = []\n",
    "        preds_trmt_all = []\n",
    "        for images_batch, soft_masks_batch in data_loader:\n",
    "            images_batch = images_batch.float().cuda()\n",
    "            with ch.no_grad():\n",
    "                logits_batch = inspection_model(images_batch)\n",
    "            preds_all.append(ch.argmax(logits_batch, dim=1).cpu().numpy())\n",
    "\n",
    "            images_trmt_batch = self.corrupted_images(images_batch, soft_masks_batch, \n",
    "                                                      corruption_type=\"add\", \n",
    "                                                      add_noise_std=add_noise_std)\n",
    "            images_trmt_batch = images_trmt_batch.float().cuda()\n",
    "            with ch.no_grad():\n",
    "                logits_trmt_batch = inspection_model(images_trmt_batch)\n",
    "            preds_trmt_all.append(ch.argmax(logits_trmt_batch, dim=1).cpu().numpy())\n",
    "\n",
    "        preds_all = np.concatenate(preds_all, axis=0)\n",
    "        preds_trmt_all = np.concatenate(preds_trmt_all, axis=0)\n",
    "\n",
    "        original_acc = np.sum(preds_all == class_index)/total\n",
    "        noise_trmt_acc = np.sum(preds_trmt_all == class_index)/total\n",
    "        return original_acc, noise_trmt_acc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CustomDataSet(Dataset):\n",
    "    def __init__(self, main_dir, class_index, feature_indices, otsu_fraction=0.6, \n",
    "                 images_per_feature=65, resize_size=224):\n",
    "        \"\"\"Index the saved images (and per-feature cam files) for one class.\n",
    "\n",
    "        Reads `<main_dir>/class_<class_index>/image_indices.csv`, which has one\n",
    "        column per feature (column name str(feature_index)); row i of a column is\n",
    "        the dataset index of `feature_<k>/images/<i>.jpg`.\n",
    "\n",
    "        Args:\n",
    "            main_dir: root directory of the saved data.\n",
    "            class_index: class whose sub-directory is read.\n",
    "            feature_indices: features whose saved images are collected.\n",
    "            otsu_fraction: stored on the instance; not used by the methods visible\n",
    "                in this cell.\n",
    "            images_per_feature: how many leading rows of each CSV column are read.\n",
    "                NOTE(review): a shorter column raises IndexError.\n",
    "            resize_size: images and masks are resized to (resize_size, resize_size).\n",
    "        \"\"\"\n",
    "        self.otsu_fraction = otsu_fraction\n",
    "        self.images_per_feature = images_per_feature\n",
    "        self.transform = transforms.Compose([\n",
    "            transforms.Resize((resize_size, resize_size)),\n",
    "            transforms.ToTensor()])\n",
    "\n",
    "        \n",
    "        class_path = os.path.join(main_dir, 'class_' + str(class_index))\n",
    "\n",
    "        image_indices_file = os.path.join(class_path, 'image_indices.csv')\n",
    "        image_indices_df = pd.read_csv(image_indices_file)\n",
    "        \n",
    "        # dataset image index -> list of (feature_index, position i in that feature's folder);\n",
    "        # used by __getitem__ to combine the masks of every feature an image was saved for.\n",
    "        feature_indices_dict = defaultdict(list)\n",
    "        # image_paths and image_indices are appended in lock-step, so they stay aligned.\n",
    "        image_paths = []\n",
    "        image_indices = []\n",
    "        for feature_index in feature_indices:\n",
    "            feature_path = os.path.join(class_path, 'feature_' + str(feature_index))\n",
    "            images_path = os.path.join(feature_path, 'images')            \n",
    "            \n",
    "            image_indices_curr = image_indices_df[str(feature_index)].to_numpy()\n",
    "            \n",
    "            for i in range(images_per_feature):                \n",
    "                image_index = image_indices_curr[i]\n",
    "                image_indices.append(image_index)\n",
    "                \n",
    "                image_path = os.path.join(images_path, str(i) + '.jpg')\n",
    "                image_paths.append(image_path)\n",
    "                \n",
    "                feature_indices_dict[image_index].append((feature_index, i))        \n",
    "        \n",
    "        # One entry per unique dataset index (sorted by np.unique). return_index gives the\n",
    "        # first occurrence, so each image is loaded from the first feature folder it appeared in.\n",
    "        image_indices = np.array(image_indices)\n",
    "        self.image_indices, unique_indices = np.unique(image_indices, return_index=True)        \n",
    "                \n",
    "        image_paths = np.array(image_paths)\n",
    "        self.image_paths = image_paths[unique_indices]\n",
    "        \n",
    "        self.feature_indices_dict = feature_indices_dict\n",
    "        \n",
    "        self.class_path = class_path\n",
    "\n",
    "    def __len__(self):\n",
    "        \"\"\"Number of unique images (length of the 1-D `image_paths` array).\"\"\"\n",
    "        return self.image_paths.shape[0]\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        \"\"\"Return (image_tensor, mask_tensor) for the `index`-th unique image.\n",
    "\n",
    "        The mask is the element-wise maximum, over every feature this image was saved\n",
    "        for, of (1 - cam / 255), with each cam read as grayscale from\n",
    "        `feature_<k>/cams/<i>.jpg`. Both outputs go through `self.transform`\n",
    "        (resize + ToTensor); the mask becomes a single-channel tensor in [0, 1].\n",
    "\n",
    "        Raises:\n",
    "            FileNotFoundError: if a cam file cannot be read.\n",
    "        \"\"\"\n",
    "        image_index = self.image_indices[index]\n",
    "        image_path = self.image_paths[index]\n",
    "\n",
    "        # The context manager closes the file handle that Image.open keeps open.\n",
    "        with Image.open(image_path) as raw_image:\n",
    "            image = raw_image.convert(\"RGB\")\n",
    "        image_tensor = self.transform(image)\n",
    "\n",
    "        # PIL's `size` is (width, height) while numpy/cv2 arrays are (height, width);\n",
    "        # np.zeros(image.size) only matched the cams for square images.\n",
    "        all_mask = np.zeros((image.height, image.width))\n",
    "        for feature_index, subdir_image_idx in self.feature_indices_dict[image_index]:\n",
    "            cam_path = os.path.join(self.class_path, 'feature_' + str(feature_index),\n",
    "                                    'cams', str(subdir_image_idx) + '.jpg')\n",
    "            cam = cv2.imread(cam_path, cv2.IMREAD_GRAYSCALE)\n",
    "            if cam is None:\n",
    "                # cv2.imread reports a missing/unreadable file by returning None.\n",
    "                raise FileNotFoundError('Could not read cam image: ' + cam_path)\n",
    "            all_mask = np.maximum(all_mask, 1. - (cam/255.))\n",
    "\n",
    "        mask_image = Image.fromarray(np.uint8(all_mask * 255))\n",
    "        mask_tensor = self.transform(mask_image)\n",
    "        return image_tensor, mask_tensor\n",
    "    \n",
    "    def visualize_random_subset(self, num_images=6):\n",
    "        images = []\n",
    "        masks1 = []\n",
    "        masks2 = []\n",
    "        all_masks = []\n",
    "        multi_indices = []\n",
    "        for index in range(len(self.image_indices)):\n",
    "            image_index = self.image_indices[index]\n",
    "            feature_indices = self.feature_indices_dict[image_index]\n",
    "            if len(feature_indices) == 2:                \n",
    "                multi_indices.append(index)\n",
    "                \n",
    "                image_path = self.image_paths[index]\n",
    "\n",
    "                image = Image.open(image_path).convert(\"RGB\")\n",
    "                image_tensor = self.transform(image)\n",
    "                images.append(image_tensor)\n",
    "                \n",
    "                \n",
    "                all_mask = np.zeros(image.size)\n",
    "                tup1, tup2 = feature_indices[0], feature_indices[1]\n",
    "                \n",
    "                feature_index1, subdir_image_idx1 = tup1\n",
    "                feature_index2, subdir_image_idx2 = tup2\n",
    "                \n",
    "                feature_path1 = os.path.join(self.class_path, 'feature_' + str(feature_index1))\n",
    "                cams_path1 = os.path.join(feature_path1, 'cams')\n",
    "                cam_path1 = os.path.join(cams_path1, str(subdir_image_idx1) + '.jpg')                \n",
    "                mask1 = cv2.imread(cam_path1, cv2.IMREAD_GRAYSCALE)\n",
    "                mask1 = 1. - (mask1/255.)\n",
    "                masks1.append(ch.Tensor(mask1).view(1, mask1.shape[0], mask1.shape[1]))\n",
    "                \n",
    "                feature_path2 = os.path.join(self.class_path, 'feature_' + str(feature_index2))\n",
    "                cams_path2 = os.path.join(feature_path2, 'cams')\n",
    "                cam_path2 = os.path.join(cams_path2, str(subdir_image_idx2) + '.jpg')                \n",
    "                mask2 = cv2.imread(cam_path2, cv2.IMREAD_GRAYSCALE)\n",
    "                mask2 = 1. - (mask2/255.)\n",
    "                masks2.append(ch.Tensor(mask2).view(1, mask2.shape[0], mask2.shape[1]))\n",
    "\n",
    "                all_mask = np.maximum(all_mask, mask1)\n",
    "                all_mask = np.maximum(all_mask, mask2)\n",
    "                \n",
    "                all_masks.append(ch.Tensor(all_mask).view(1, all_mask.shape[0], all_mask.shape[1]))\n",
    "\n",
    "        images = ch.stack(images, dim=0)\n",
    "        masks1 = ch.stack(masks1, dim=0)\n",
    "        masks2 = ch.stack(masks2, dim=0)\n",
    "        all_masks = ch.stack(all_masks, dim=0)\n",
    "        \n",
    "#         indices = np.array([2, 18, 20, 44, 92, 94, 103, 109, 120])\n",
    "#         indices = indices[:num_images]\n",
    "\n",
    "        images_select, all_masks_select = images[:6], all_masks[:6]\n",
    "        masks1_select, masks2_select = masks1[:6], masks2[:6]\n",
    "        \n",
    "        images_masked1 = images_select * masks1_select\n",
    "        images_masked2 = images_select * masks2_select\n",
    "        images_all_masked = images_select * all_masks_select\n",
    "        \n",
    "        show_image_row([images_select.cpu()], ['images'], \n",
    "                       tlist=[], fontsize=18)\n",
    "        \n",
    "        show_image_row([images_masked1.cpu()], ['masked1'], \n",
    "                       tlist=[], fontsize=18)\n",
    "        \n",
    "        show_image_row([images_masked2.cpu()], ['masked2'], \n",
    "                       tlist=[], fontsize=18)\n",
    "        \n",
    "        show_image_row([images_all_masked.cpu()], ['all_masked'], \n",
    "                       tlist=[], fontsize=18)\n",
    "        \n",
    "        noise = 0.25 * ch.randn(*images_select.shape)\n",
    "        \n",
    "        images_n1 = images_select + (noise * masks1_select)\n",
    "        images_n1 = ch.clamp(images_n1, 0., 1.)\n",
    "\n",
    "        images_n2 = images_select + (noise * masks2_select)\n",
    "        images_n2 = ch.clamp(images_n2, 0., 1.)\n",
    "\n",
    "        images_n_all = images_select + (noise * masks2_select)\n",
    "        images_n_all = ch.clamp(images_n_all, 0., 1.)\n",
    "        \n",
    "        show_image_row([images_select.cpu()], ['images'], \n",
    "                       tlist=[], fontsize=18)\n",
    "        \n",
    "        show_image_row([images_n1.cpu()], ['noise1'], \n",
    "                       tlist=[], fontsize=18)\n",
    "        \n",
    "        show_image_row([images_n2.cpu()], ['noise2'], \n",
    "                       tlist=[], fontsize=18)\n",
    "        \n",
    "        show_image_row([images_n_all.cpu()], ['noise_all'], \n",
    "                       tlist=[], fontsize=18)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torchvision.models as models\n",
    "\n",
    "class CompleteModel(ch.nn.Module):\n",
    "    def __init__(self, model):\n",
    "        super(CompleteModel, self).__init__()\n",
    "        self.mean = ch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).cuda()\n",
    "        self.std = ch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).cuda()\n",
    "        \n",
    "        self.model = model\n",
    "\n",
    "    def forward(self, inp):\n",
    "        normalized_inp = (inp - self.mean)/self.std        \n",
    "        logits = self.model(normalized_inp)\n",
    "        return logits\n",
    "\n",
    "def load_inspection_model(model_name):\n",
    "    models_dict = {\n",
    "        'resnet18': models.resnet18(pretrained=False, progress=False),\n",
    "        'resnet50': models.resnet50(pretrained=False, progress=False),\n",
    "        'vgg19_bn': models.vgg19_bn(pretrained=False, progress=False),\n",
    "        'inception_v3_google': models.inception_v3(pretrained=False, progress=False),\n",
    "        'googlenet': models.googlenet(pretrained=False, progress=False),\n",
    "        'shufflenetv2_x1': models.shufflenet_v2_x1_0(pretrained=False, progress=False),\n",
    "        'mobilenet_v2': models.mobilenet_v2(pretrained=False, progress=False),\n",
    "        'mobilenet_v3_large': models.mobilenet_v3_large(pretrained=False, progress=False),\n",
    "        'resnext50_32x4d': models.resnext50_32x4d(pretrained=False, progress=False),\n",
    "        'wide_resnet50_2': models.wide_resnet50_2(pretrained=False, progress=False),\n",
    "        'mnasnet1_0': models.mnasnet1_0(pretrained=False, progress=False),\n",
    "        'efficientnet-b0': EfficientNet.from_name('efficientnet-b0'),\n",
    "        'efficientnet-b4': EfficientNet.from_name('efficientnet-b4'),\n",
    "        'efficientnet-b7': EfficientNet.from_name('efficientnet-b7')\n",
    "    }\n",
    "    \n",
    "    if model_name == 'robust_resnet50':\n",
    "        model = load_model(model_name)\n",
    "        model = model.model\n",
    "    else:\n",
    "        model = models_dict[model_name]\n",
    "        \n",
    "        model_path = os.path.join(MODELS_PATH, \n",
    "                                  model_name + '.pth')\n",
    "        model_state_dict = ch.load(model_path)\n",
    "        model.load_state_dict(model_state_dict)\n",
    "\n",
    "    model.eval()\n",
    "    model = CompleteModel(model)\n",
    "    model = model.cuda()\n",
    "    return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "k = 10\n",
    "num_features = 5\n",
    "\n",
    "robust_model_name = 'robust_resnet50'\n",
    "robust_model = load_model(robust_model_name)\n",
    "\n",
    "\n",
    "train_features, train_logits, train_labels = load_features(robust_model_name, 'train')    \n",
    "train_preds = np.argmax(train_logits, axis=1)\n",
    "\n",
    "num_features = 5\n",
    "# top_k_indices = [638]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "spurious_features_dict = get_spurious_features_dict(threshold=3, return_rank=True, validate_heatmap=False)\n",
    "all_features_dict = get_spurious_features_dict(threshold=0, return_rank=True, validate_heatmap=False)\n",
    "\n",
    "spurious_list = []\n",
    "for class_index in spurious_features_dict.keys():\n",
    "    full_class_name = CLASSES[class_index]\n",
    "    class_name = ', '.join(full_class_name.split(',')[:2])\n",
    "\n",
    "    feature_indices = spurious_features_dict[class_index]\n",
    "    for tup in feature_indices:\n",
    "        feature_index, feature_rank = tup\n",
    "        \n",
    "        key = str(class_index) + '_' + str(feature_index)\n",
    "        spurious_list.append(key)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# add_noise_stds_list = [0.25, 0.5, 1.0, 1.25]\n",
    "add_noise_stds_list = [0.30, 0.35, 0.40, 0.45, 0.55, 0.60]\n",
    "model_name_list = ['resnet18', 'resnet50', 'vgg19_bn', 'inception_v3_google', \n",
    "                   'googlenet', 'shufflenetv2_x1', 'mobilenet_v2', 'mobilenet_v3_large', \n",
    "                   'resnext50_32x4d', 'wide_resnet50_2', 'mnasnet1_0', 'efficientnet-b0', \n",
    "                   'efficientnet-b4', 'efficientnet-b7', 'robust_resnet50']\n",
    "model_name_list = ['resnet50']\n",
    "\n",
    "\n",
    "num_images = 6\n",
    "batch_size = 32\n",
    "otsu_fraction = 0.6\n",
    "add_noise_mean = 0.\n",
    "add_noise_std = 0.25\n",
    "replace_noise_mean = 0.\n",
    "replace_noise_std = 0.01\n",
    "images_per_feature = 65\n",
    "num_decimal = 3\n",
    "num_runs = 10\n",
    "\n",
    "causal_failure_instance = CausalFailureExplanation(robust_model, train_loader, train_features, \n",
    "                                                   train_logits, train_labels, batch_size=batch_size, \n",
    "                                                   otsu_fraction=otsu_fraction, \n",
    "                                                   add_noise_mean=add_noise_mean, \n",
    "                                                   add_noise_std=add_noise_std, \n",
    "                                                   replace_noise_mean=replace_noise_mean, \n",
    "                                                   replace_noise_std=replace_noise_std, \n",
    "                                                   images_per_feature=images_per_feature)\n",
    "\n",
    "inspection_results_root = 'inspection_results_multiple_stds_all'\n",
    "os.makedirs(inspection_results_root, exist_ok=True)\n",
    "inspection_results_path = os.path.join(inspection_results_root, 'single_feature_per_class')\n",
    "os.makedirs(inspection_results_path, exist_ok=True)\n",
    "\n",
    "\n",
    "for model_name in model_name_list:\n",
    "    inspection_model = load_inspection_model(model_name)\n",
    "    print_with_stars(\" Model name: {:s} \".format(model_name))\n",
    "    model_results_path = os.path.join(inspection_results_path, model_name)\n",
    "    os.makedirs(model_results_path, exist_ok=True)\n",
    "\n",
    "    \n",
    "    for add_noise_std in add_noise_stds_list:\n",
    "        inspection_results_dict = defaultdict(list)\n",
    "        for class_index in spurious_features_dict.keys():\n",
    "            full_class_name = CLASSES[class_index]\n",
    "            class_name = ', '.join(full_class_name.split(',')[:2])\n",
    "\n",
    "            feature_indices = all_features_dict[class_index]\n",
    "            for tup in feature_indices:\n",
    "                feature_index, feature_rank = tup\n",
    "\n",
    "                key = str(class_index) + '_' + str(feature_index)\n",
    "                if key in spurious_list:\n",
    "                    feature_type = \"spurious\"\n",
    "                else:\n",
    "                    feature_type = \"causal\"\n",
    "\n",
    "                acc_tuple = causal_failure_instance.compute_causal_acc(inspection_model, \n",
    "                                                                       class_index, [feature_index], \n",
    "                                                                       add_noise_std=add_noise_std)\n",
    "                acc, causal_acc_add = acc_tuple\n",
    "                key = str(class_index) + '_' + str(feature_index)\n",
    "\n",
    "                \n",
    "                inspection_results_dict['class_index'].append(class_index)\n",
    "                inspection_results_dict['feature_index'].append(feature_index)\n",
    "                inspection_results_dict['feature_rank'].append(feature_rank)\n",
    "                inspection_results_dict['feature_type'].append(feature_type)\n",
    "                inspection_results_dict['class_name'].append(full_class_name)\n",
    "\n",
    "                inspection_results_dict['accuracy'].append(\n",
    "                    round(100 * acc, num_decimal))\n",
    "                inspection_results_dict['causal_accuracy_add'].append(\n",
    "                    round(100 * causal_acc_add, num_decimal))\n",
    "\n",
    "        inspection_results_df = pd.DataFrame.from_dict(inspection_results_dict)\n",
    "        curr_results_path = os.path.join(model_results_path, 'std_' + str(add_noise_std) + '.csv')\n",
    "        inspection_results_df.to_csv(curr_results_path, index=False)\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
