{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append(\"../\")\n",
    "import os\n",
    "from IPython.display import display\n",
    "from functools import partial\n",
    "import math\n",
    "\n",
    "import matplotlib\n",
    "matplotlib.rcParams['pdf.fonttype'] = 42\n",
    "matplotlib.rcParams['ps.fonttype'] = 42\n",
    "matplotlib.rc('text', usetex=False)\n",
    "\n",
    "import scipy\n",
    "from mkdir_p import mkdir_p\n",
    "from PIL import Image\n",
    "import faiss\n",
    "import matplotlib.pyplot as plt\n",
    "import joblib\n",
    "import numpy as np\n",
    "import torch\n",
    "from torch import nn\n",
    "import torch.nn.functional as F\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "from tqdm.notebook import tqdm\n",
    "from lime import lime_image\n",
    "from skimage.segmentation import mark_boundaries\n",
    "import shap\n",
    "\n",
    "\n",
    "from spurious_ml.models.torch_utils import archs, data_augs\n",
    "from spurious_ml.models.torch_utils.archs.wideresnet import NetworkBlock, BasicBlock\n",
    "from spurious_ml.variables import auto_var\n",
    "\n",
    "fontsize=15"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_path = \"/tmp2/XXX/archives/_models/mnist\"\n",
    "arch_name = \"CNN002\"\n",
    "data_aug = None\n",
    "\n",
    "trnX, trny, tstX, tsty = auto_var.get_var_with_argument(\"dataset\", \"mnist\")\n",
    "# Deserialize onto the CPU so the checkpoint loads no matter which device it\n",
    "# was saved from; the weights are copied into `model` by load_state_dict below.\n",
    "res = torch.load(model_path, map_location=\"cpu\")\n",
    "if data_aug is None:\n",
    "    # No augmentation configured: pass inputs through untouched.\n",
    "    aug_fn = lambda x: x\n",
    "else:\n",
    "    # The augmentation factory is called with no arguments and element [1] of its result is used.\n",
    "    aug_fn = getattr(data_augs, data_aug)()[1]\n",
    "model = getattr(archs, arch_name)(n_channels=None, n_classes=10)\n",
    "# strict=False tolerates key mismatches; as the last expression of the cell,\n",
    "# the returned report is what the cell displays.\n",
    "model.load_state_dict(res['model_state_dict'], strict=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_path = \"/tmp2/XXX/archives/_models/pgd-cifar10-0.031-aug01-ce-tor-WRN_40_10-lrem2mo.9-inf-0-ep0120.pt\"\n",
    "arch_name = \"WRN_40_10\"\n",
    "data_aug = \"aug01\"\n",
    "\n",
    "trnX, trny, tstX, tsty = auto_var.get_var_with_argument(\"dataset\", \"cifar10\")\n",
    "# Deserialize onto the CPU so the checkpoint loads no matter which device it\n",
    "# was saved from; the weights are copied into `model` by load_state_dict below.\n",
    "res = torch.load(model_path, map_location=\"cpu\")\n",
    "if data_aug is None:\n",
    "    # No augmentation configured: pass inputs through untouched.\n",
    "    aug_fn = lambda x: x\n",
    "else:\n",
    "    # The augmentation factory is called with no arguments and element [1] of its result is used.\n",
    "    aug_fn = getattr(data_augs, data_aug)()[1]\n",
    "model = getattr(archs, arch_name)(n_channels=None, n_classes=10)\n",
    "# strict=False tolerates key mismatches, and this statement is not the last one\n",
    "# in the cell, so its report would otherwise be dropped. Surface any mismatch.\n",
    "load_result = model.load_state_dict(res['model_state_dict'], strict=False)\n",
    "if load_result.missing_keys or load_result.unexpected_keys:\n",
    "    print(f\"state_dict mismatch: {load_result}\")\n",
    "\n",
    "def get_features(X, model, aug_fn=None, device=None, batch_size=128):\n",
    "    \"\"\"Run `model.extract_features` over `X` in mini-batches and return the flattened features.\n",
    "\n",
    "    Args:\n",
    "        X: NumPy array that is transposed with (0, 3, 1, 2) before being fed to the\n",
    "            model, so it is assumed to be channels-last (N, H, W, C) -- TODO confirm.\n",
    "        model: torch module exposing an `extract_features` method.\n",
    "        aug_fn: optional callable applied to every batch after it is moved to `device`.\n",
    "        device: device to run on. Defaults to \"cuda\" when available, otherwise \"cpu\".\n",
    "        batch_size: number of examples per forward pass.\n",
    "\n",
    "    Returns:\n",
    "        2-D NumPy array with one row per example (features flattened from dim 1).\n",
    "    \"\"\"\n",
    "    if device is None:\n",
    "        device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "\n",
    "    dset = torch.utils.data.TensorDataset(torch.from_numpy(X.transpose(0, 3, 1, 2)).float())\n",
    "    loader = torch.utils.data.DataLoader(dset, batch_size=batch_size)\n",
    "\n",
    "    model.to(device).eval()\n",
    "    fetX = []\n",
    "    # Inference only: no_grad avoids building an autograd graph for every batch.\n",
    "    with torch.no_grad():\n",
    "        for (x, ) in tqdm(loader, desc=\"predicting\"):\n",
    "            x = x.to(device)\n",
    "            if aug_fn is not None:\n",
    "                x = aug_fn(x)\n",
    "            fetX.append(model.extract_features(x).cpu().detach().flatten(1).numpy())\n",
    "    return np.concatenate(fetX, axis=0)\n",
    "\n",
    "def get_model_preds(X, model, aug_fn=None, device=None, batch_size=128):\n",
    "    \"\"\"Run `model` over `X` in mini-batches and return the arg-max class index per example.\n",
    "\n",
    "    Args:\n",
    "        X: NumPy array that is transposed with (0, 3, 1, 2) before being fed to the\n",
    "            model, so it is assumed to be channels-last (N, H, W, C) -- TODO confirm.\n",
    "        model: torch module; the arg-max is taken over dim 1 of its output.\n",
    "        aug_fn: optional callable applied to every batch after it is moved to `device`.\n",
    "        device: device to run on. Defaults to \"cuda\" when available, otherwise \"cpu\".\n",
    "        batch_size: number of examples per forward pass.\n",
    "\n",
    "    Returns:\n",
    "        1-D NumPy array of predicted indices, one per example.\n",
    "    \"\"\"\n",
    "    if device is None:\n",
    "        device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "\n",
    "    dset = torch.utils.data.TensorDataset(torch.from_numpy(X.transpose(0, 3, 1, 2)).float())\n",
    "    loader = torch.utils.data.DataLoader(dset, batch_size=batch_size)\n",
    "\n",
    "    model.to(device).eval()\n",
    "    preds = []\n",
    "    # Inference only: no_grad avoids building an autograd graph for every batch.\n",
    "    with torch.no_grad():\n",
    "        for (x, ) in tqdm(loader, desc=\"predicting\"):\n",
    "            x = x.to(device)\n",
    "            if aug_fn is not None:\n",
    "                x = aug_fn(x)\n",
    "            preds.append(model(x).argmax(1).cpu().detach().numpy())\n",
    "    return np.concatenate(preds, axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Feature vectors for the whole training set. aug_fn is left at its default\n",
    "# (None), so batches reach the model without any extra transform.\n",
    "features = get_features(X=trnX, model=model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per feature dimension, the number of training examples with a positive value.\n",
    "# Computed once instead of twice over the full feature matrix.\n",
    "active_counts = (features > 0).sum(0)\n",
    "# Dimensions that are positive for at least one but fewer than five examples.\n",
    "np.where(np.logical_and(active_counts < 5, active_counts > 0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Row indices of the training examples whose value in feature dimension 16 is non-zero.\n",
    "(features[:, 16] != 0).nonzero()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fall back to the CPU when no GPU is visible so the cell can still run.\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']\n",
    "# Note: despite the name, this indexes the *training* arrays (features / trnX) below.\n",
    "tsti = 30821\n",
    "\n",
    "def pred_fn(x, fet_no):\n",
    "    \"\"\"Return column `fet_no` of `model.extract_features(x)` for every row of the batch.\"\"\"\n",
    "    return model.extract_features(x)[:, fet_no]\n",
    "\n",
    "model.eval()\n",
    "model.to(device)\n",
    "\n",
    "# Fixed-seed generator so the same 200 background rows are drawn on every run;\n",
    "# a local RandomState leaves NumPy's global random state untouched.\n",
    "background_rng = np.random.RandomState(0)\n",
    "background = features[background_rng.choice(features.shape[0], 200, replace=False)]\n",
    "# Cast to float32 like test_examples so both explainer inputs share a dtype.\n",
    "background = torch.from_numpy(background).float().to(device)\n",
    "\n",
    "test_examples = torch.from_numpy(features[tsti:tsti+2]).float().to(device)\n",
    "\n",
    "# Explain the model's `fc` layer directly on the pre-computed feature vectors.\n",
    "e = shap.GradientExplainer(model.fc, background)\n",
    "shap_values, indexes = e.shap_values(test_examples, ranked_outputs=2)\n",
    "\n",
    "# get the names for the classes\n",
    "index_names = np.vectorize(lambda x: class_names[x])(indexes.detach().cpu().numpy())\n",
    "\n",
    "# plot the explanations\n",
    "# NOTE(review): shap_values are attributions over feature vectors while the second\n",
    "# argument is raw images -- confirm image_plot accepts this pairing.\n",
    "shap.image_plot(shap_values, trnX[tsti:tsti+2], index_names)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Quick inspection: number of entries in shap_values.\n",
    "len(shap_values)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fall back to the CPU when no GPU is visible so the cell can still run.\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "\n",
    "def pred_fn(x, fet_no):\n",
    "    \"\"\"Return column `fet_no` of `model.extract_features(x)` for every row of the batch.\"\"\"\n",
    "    return model.extract_features(x)[:, fet_no]\n",
    "\n",
    "model.eval()\n",
    "model.to(device)\n",
    "\n",
    "# Fixed-seed generator so the same 200 background rows are drawn on every run;\n",
    "# a local RandomState leaves NumPy's global random state untouched.\n",
    "background_rng = np.random.RandomState(0)\n",
    "background = features[background_rng.choice(features.shape[0], 200, replace=False)]\n",
    "# Cast to float32 like test_examples so both explainer inputs share a dtype.\n",
    "background = torch.from_numpy(background).float().to(device)\n",
    "\n",
    "test_examples = torch.from_numpy(features[30821:30822]).float().to(device)\n",
    "\n",
    "# Explain the model's `fc` layer directly on the pre-computed feature vectors.\n",
    "e = shap.GradientExplainer(model.fc, background)\n",
    "shap_values = e.shap_values(test_examples)\n",
    "# NOTE(review): `features` is 2-D (get_features flattens from dim 1), so these\n",
    "# attributions presumably have no axis 2 for swapaxes(..., 1, 2) -- confirm this\n",
    "# line runs for feature-vector inputs, or replace the image-style plot.\n",
    "shap_numpy = [np.swapaxes(np.swapaxes(s, 1, -1), 1, 2) for s in shap_values]\n",
    "test_numpy = test_examples.cpu().detach().numpy()\n",
    "# plot the feature attributions\n",
    "shap.image_plot(shap_numpy, -test_numpy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Quick inspection: shape of the first entry of shap_values.\n",
    "shap_values[0].shape"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
