{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "310f7920-ef89-49d3-98e4-53e1e07efb56",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "from functools import lru_cache\n",
    "import sys\n",
    "sys.path.append(\"../\")\n",
    "from os.path import join\n",
    "from IPython.display import display\n",
    "from functools import partial\n",
    "\n",
    "import matplotlib\n",
    "matplotlib.rcParams['pdf.fonttype'] = 42\n",
    "matplotlib.rcParams['ps.fonttype'] = 42\n",
    "matplotlib.rc('text', usetex=False)\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from scipy.stats import ttest_ind, sem\n",
    "from scipy.fft import fft2\n",
    "import scipy\n",
    "import joblib\n",
    "import numpy as np\n",
    "import torch\n",
    "from torch import nn\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "from tqdm.notebook import tqdm\n",
    "from skimage.segmentation import mark_boundaries\n",
    "from scipy.stats import pearsonr\n",
    "import pandas as pd\n",
    "pd.options.display.float_format = '{:,.3f}'.format\n",
    "pd.set_option('display.max_rows', 128)\n",
    "\n",
    "from spurious_ml.models.torch_utils import archs\n",
    "from spurious_ml.variables import auto_var\n",
    "from params import *\n",
    "from utils import params_to_dataframe\n",
    "\n",
    "fontsize=18"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "346f130e-60ea-4b5a-9c58-de079c427268",
   "metadata": {},
   "outputs": [],
   "source": [
    "def cnn_pred_fn(X, model, device=\"cuda\"):\n",
    "    if len(X.shape) == 4:\n",
    "        X = X.transpose(0, 3, 1, 2)\n",
    "        \n",
    "    dset = torch.utils.data.TensorDataset(torch.from_numpy(X).float())\n",
    "    loader = torch.utils.data.DataLoader(dset, batch_size=128)\n",
    "    \n",
    "    model.to(device).eval()\n",
    "    fetX = []\n",
    "    for (x, ) in tqdm(loader, desc=\"[pred_fn]\"):\n",
    "        fetX.append(model(x.to(device)).cpu().detach().numpy())\n",
    "    fetX = np.concatenate(fetX, axis=0)\n",
    "    return fetX\n",
    "\n",
    "def mlp_pred_fn(X, model, device=\"cuda\"):\n",
    "    dset = torch.utils.data.TensorDataset(torch.from_numpy(X.reshape(len(X), -1)).float())\n",
    "    loader = torch.utils.data.DataLoader(dset, batch_size=256)\n",
    "    \n",
    "    model.to(device).eval()\n",
    "    fetX = []\n",
    "    for (x, ) in tqdm(loader, desc=\"[pred_fn]\"):\n",
    "        fetX.append(model(x.to(device)).cpu().detach().numpy())\n",
    "    fetX = np.concatenate(fetX, axis=0)\n",
    "    return fetX"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d2306836-7db3-4c52-927a-5579f55b4dc3",
   "metadata": {},
   "outputs": [],
   "source": [
    "params = {\n",
    "    \"mnist\": [\n",
    "        [\"CNN002\", \"CNN\", \"../models/train_classifier/128-mnist-70-0.01-ce-tor-CNN002-0.0-adam-0-0.0.pt\"],\n",
    "        [\"MLP\", \"small MLP\", \"../models/train_classifier/128-mnist-70-0.01-ce-tor-MLP-0.0-adam-0-0.0.pt\"],\n",
    "        [\"LargeMLP\", \"MLP\", \"../models/train_classifier/128-mnist-70-0.01-ce-tor-LargeMLP-0.0-adam-0-0.0.pt\"],\n",
    "        [\"LargeMLPv2\", \"large MLP\", \"../models/train_classifier/128-mnist-70-0.01-ce-tor-LargeMLPv2-0.0-adam-0-0.0.pt\"],\n",
    "    ],\n",
    "    \"fashion\": [\n",
    "        [\"CNN002\", \"CNN\", \"../models/train_classifier/128-fashion-70-0.01-ce-tor-CNN002-0.0-adam-0-0.0.pt\"],\n",
    "        [\"MLP\", \"small MLP\", \"../models/train_classifier/128-fashion-70-0.01-ce-tor-MLP-0.0-adam-0-0.0.pt\"],\n",
    "        [\"LargeMLP\", \"MLP\", \"../models/train_classifier/128-fashion-70-0.01-ce-tor-LargeMLP-0.0-adam-0-0.0.pt\"],\n",
    "        [\"LargeMLPv2\", \"large MLP\", \"../models/train_classifier/128-fashion-70-0.01-ce-tor-LargeMLPv2-0.0-adam-0-0.0.pt\"],\n",
    "    ],\n",
    "    \"cifar10\": [\n",
    "        [\"altResNet20Norm02\", \"ResNet20\", \"../models/train_classifier/128-cifar10-70-0.1-aug01-ce-tor-altResNet20Norm02-0.9-sgd-0-0.0001.pt\"],\n",
    "        [\"altResNet32Norm02\", \"ResNet32\", \"../models/train_classifier/128-cifar10-70-0.1-aug01-ce-tor-altResNet32Norm02-0.9-sgd-0-0.0001.pt\"],\n",
    "        [\"altResNet110Norm02\", \"ResNet110\", \"../models/train_classifier/128-cifar10-70-0.1-aug01-ce-tor-altResNet110Norm02-0.9-sgd-0-0.0001.pt\"],\n",
    "        [\"Vgg16Norm02\", \"Vgg16\", \"../models/train_classifier/128-cifar10-70-0.01-aug01-ce-tor-Vgg16Norm02-0.9-sgd-0-0.0001.pt\"],\n",
    "    ],\n",
    "}\n",
    "\n",
    "def occlude_X(X, occlude, random_state):\n",
    "    if occlude == 0:\n",
    "        return np.copy(X)\n",
    "    if occlude == 1:\n",
    "        return np.zeros_like(X)\n",
    "    retX = []\n",
    "    tX = np.copy(X.reshape(X.shape[0], X.shape[1] * X.shape[2], X.shape[3]))\n",
    "    inds = np.where(tX != 0)\n",
    "    n_blocked = int(X.shape[1] * X.shape[2] * occlude)\n",
    "    blocked_pixels = random_state.choice(np.arange(2), size=len(inds[1]), p=[1-occlude, occlude], replace=True)\n",
    "    for i in tqdm(range(len(X))):\n",
    "        pixel_no = inds[1][inds[0] == i][blocked_pixels[inds[0] == i] == 1]\n",
    "        tX[i][pixel_no] = 0\n",
    "        retX.append(tX[i].reshape(1, X.shape[1], X.shape[2], X.shape[3]))\n",
    "    return np.concatenate(retX, axis=0)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "72a77988-4611-4a99-9b73-9cbfd8978915",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "n_samples = 500\n",
    "\n",
    "results = {}\n",
    "for ds_name, ds_params in params.items():\n",
    "    trnX, trny, tstX, tsty, spurious_ind = auto_var.get_var_with_argument(\"dataset\", ds_name)\n",
    "    n_classes = len(np.unique(trny))\n",
    "    n_channels = trnX.shape[-1]\n",
    "\n",
    "    ind = np.random.RandomState(0).choice(np.arange(len(trny)), size=n_samples, replace=False)\n",
    "    occuluded_Xs = []\n",
    "    for p in range(51):\n",
    "        occuluded_Xs.append(occlude_X(trnX[ind], p/50, np.random.RandomState(0)))\n",
    "        \n",
    "    for arch_name, arch_shown_name, model_path in ds_params:\n",
    "        model = getattr(archs, arch_name)(n_features=(np.prod(trnX.shape[1:]), ), n_channels=n_channels, n_classes=n_classes)\n",
    "        model.load_state_dict(torch.load(model_path)['model_state_dict'])\n",
    "    \n",
    "        preds = []\n",
    "        for p in range(51):\n",
    "            occ_trnX = occuluded_Xs[p]\n",
    "            if \"MLP\" in arch_name:\n",
    "                pred = mlp_pred_fn(occ_trnX, model, )[None, :, :]\n",
    "            else:\n",
    "                pred = cnn_pred_fn(occ_trnX, model, )[None, :, :]\n",
    "            preds.append(pred)\n",
    "        preds = scipy.special.softmax(np.concatenate(preds), axis=2)\n",
    "        \n",
    "        ttt = []\n",
    "        for i in range(n_samples):\n",
    "            ttt.append(preds[:, i, trny[ind[i]]])\n",
    "        results[(ds_name, arch_name)] = (np.mean(ttt), )\n",
    "        plt.plot(np.arange(51)/50, np.mean(ttt, axis=0), label=arch_shown_name)\n",
    "    plt.ylabel(\"Probability\", fontsize=fontsize)\n",
    "    plt.xlabel(\"Portion of pixels removed\", fontsize=fontsize)\n",
    "    plt.xticks(xaxis[::10], labels=xaxis[::10], fontsize=fontsize)\n",
    "    plt.yticks(fontsize=fontsize)\n",
    "    plt.legend(fontsize=fontsize)\n",
    "    plt.tight_layout()\n",
    "    plt.savefig(f\"./figs/pixel_removal/{ds_name}.png\")\n",
    "    plt.close()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b0a7c5c-d424-4478-ae8a-d4a516959f47",
   "metadata": {},
   "outputs": [],
   "source": [
    "data = {}\n",
    "for ds_name, ds_params in params.items():\n",
    "    trnX, trny, tstX, tsty, spurious_ind = auto_var.get_var_with_argument(\"dataset\", ds_name)\n",
    "    n_classes = len(np.unique(trny))\n",
    "    n_channels = trnX.shape[-1]\n",
    "\n",
    "    for arch_name, arch_shown_name, model_path in ds_params:\n",
    "        if arch_shown_name in data:\n",
    "            continue\n",
    "        model = getattr(archs, arch_name)(n_features=(np.prod(trnX.shape[1:]), ), n_channels=n_channels, n_classes=n_classes)\n",
    "        \n",
    "        model_parameters = filter(lambda p: p.requires_grad, model.parameters())\n",
    "        data[arch_shown_name] = sum([np.prod(p.size()) for p in model_parameters])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "363ee4ea-8157-4853-b266-f8210c65e59b",
   "metadata": {},
   "outputs": [],
   "source": [
    "data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "67b9c7c7-6f9e-4feb-ae40-1aa97877b5de",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.imshow(occ_trnX[0])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
