{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "6d6ec556-d331-4892-9914-f2916c3ab0a7",
   "metadata": {},
   "source": [
    "## Table of Contents\n",
    "\n",
    "* [Setup & Imports](#imports)\n",
    "* [Reproducibility study](#4.1)\n",
    "    * [Qualitative evaluation of counterfactual samples](#qual-eval-cf-samples) (HQC)\n",
    "    * [Evaluating invariant classifiers](#eval-inv-clf) (ODR)\n",
    "        * [Experiments on MNISTs](#repr-mnist)\n",
    "        * [Experiments on ImageNet-mini](#repr-in-mini)\n",
    "    * [Loss Ablation](#loss-ablation) (IBR)\n",
    "* [Additional experiments and analyses](#add-expts)\n",
    "    * [Explainability analysis on MNISTs](#expl-mnist)\n",
    "    * [Explainability analysis on ImageNet-mini](#expl-in-mini)\n",
    "    * [Robustness to OOD generalization](#ood)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "baf5dbf2-dab0-4ec9-952a-f7d34d874d14",
   "metadata": {},
   "source": [
    "## Setup & Imports  <a class=\"anchor\" id=\"imports\"></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2fd6f6a5-13ef-49b9-a03b-b7d4489babf0",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b0160665-1027-451c-97be-c601dc31b8e7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b667e041-62b9-4a6d-bc32-954db15c2675",
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "23b0f663-e30c-4fc8-98d6-9edb61de7c37",
   "metadata": {},
   "source": [
    "### Download data"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "34317b0b",
   "metadata": {},
   "source": [
    "The ImageNet-mini dataset needs to be downloaded from Kaggle. Please export your Kaggle credentials using the following commands, where the key is your Kaggle API key (it can be found in your Kaggle account settings).\n",
    "```sh\n",
    "export KAGGLE_USERNAME=<your_username>\n",
    "export KAGGLE_KEY=<your_key>\n",
    "```\n",
    "\n",
    "Alternatively, you can download your API key file `kaggle.json` and place it at `~/.kaggle/kaggle.json`.\n",
    "\n",
    "> Note: Downloading all datasets takes about 20 minutes and requires about 7 GB of free space. If the download fails at some point, you can re-run the cell; datasets that have already been downloaded will not be downloaded again.\n"
   ]
  },
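  {
   "cell_type": "markdown",
   "id": "kaggle-cred-check-md",
   "metadata": {},
   "source": [
    "Optionally, you can check that your Kaggle credentials are visible to this notebook before starting the download. The cell below is a minimal sketch that is not part of the original pipeline (`kaggle_credentials_available` is a hypothetical helper): it only looks for the `KAGGLE_USERNAME` and `KAGGLE_KEY` environment variables or a `~/.kaggle/kaggle.json` file, and does not verify that the key is valid. Variables exported in a shell are only visible here if Jupyter was started from that shell."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "kaggle-cred-check",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from pathlib import Path\n",
    "\n",
    "\n",
    "def kaggle_credentials_available():\n",
    "    \"\"\"Return True if Kaggle credentials are set via env vars or ~/.kaggle/kaggle.json.\"\"\"\n",
    "    has_env = bool(os.environ.get(\"KAGGLE_USERNAME\")) and bool(os.environ.get(\"KAGGLE_KEY\"))\n",
    "    has_file = (Path.home() / \".kaggle\" / \"kaggle.json\").is_file()\n",
    "    return has_env or has_file\n",
    "\n",
    "\n",
    "print(\"Kaggle credentials found:\", kaggle_credentials_available())"
   ]
  },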
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13ba36f2-5b69-416f-8226-f8d3adf1ae7d",
   "metadata": {},
   "outputs": [],
   "source": [
    "!python ../setup/download_datasets.py"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "14a5ad90",
   "metadata": {},
   "source": [
    "This should download datasets for both `mnists` and `imagenet` tasks.\n",
    "\n",
    "For MNISTs, the folder structure is as follows:\n",
    "```sh\n",
    "mnists/data\n",
    "├── colored_mnist\n",
    "└── textures\n",
    "    ├── background\n",
    "    └── object\n",
    "\n",
    "4 directories\n",
    "```\n",
    "\n",
    "For ImageNet, the folder structure is as follows:\n",
    "```sh\n",
    "imagenet/data\n",
    "├── cue_conflict\n",
    "├── in-a\n",
    "├── in-mini\n",
    "├── in-sketch\n",
    "├── in-stylized\n",
    "└── in9\n",
    "\n",
    "6 directories\n",
    "```"
   ]
  },
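  {
   "cell_type": "markdown",
   "id": "data-dir-check-md",
   "metadata": {},
   "source": [
    "To confirm that the download produced the folders listed above, you can run the sketch below. It is not part of the repository and assumes that the `mnists/` and `imagenet/` folders sit one level above this notebook (the hypothetical `repo_root` variable); adjust it if your layout differs. It only checks that the directories exist, not that their contents are complete."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "data-dir-check",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "# Expected sub-directories, taken from the folder trees listed above\n",
    "expected_dirs = {\n",
    "    \"mnists/data\": [\"colored_mnist\", \"textures/background\", \"textures/object\"],\n",
    "    \"imagenet/data\": [\"cue_conflict\", \"in-a\", \"in-mini\", \"in-sketch\", \"in-stylized\", \"in9\"],\n",
    "}\n",
    "\n",
    "repo_root = Path(\"..\")  # assumption: `mnists/` and `imagenet/` are one level up; adjust if needed\n",
    "for base, subdirs in expected_dirs.items():\n",
    "    missing = [d for d in subdirs if not (repo_root / base / d).is_dir()]\n",
    "    print(f\"{base}: {'OK' if not missing else 'missing ' + ', '.join(missing)}\")"
   ]
  },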
  {
   "cell_type": "markdown",
   "id": "7949a384-7cb9-457a-86ac-ee9a8fc062f0",
   "metadata": {},
   "source": [
    "### Download model weights\n",
    "\n",
    "> Note: This takes less than 5 minutes and requires about 6 GB of free space. If the download fails at some point, you can re-run the cell; weights that have already been downloaded will not be downloaded again."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "276fec4f-d415-4dfe-946b-dd35f167ca20",
   "metadata": {},
   "outputs": [],
   "source": [
    "!python ../setup/download_weights.py"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "31af44ca",
   "metadata": {},
   "source": [
    "This will download the weights for all tasks.\n",
    "\n",
    "```bash\n",
    "imagenet/weights/\n",
    "├── biggan256.pth\n",
    "├── cgn.pth\n",
    "├── u2net.pth\n",
    "├── :\n",
    "└── resnet50_from_scratch_model_best.pth.tar\n",
    "\n",
    "4 files\n",
    "```"
   ]
  },
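  {
   "cell_type": "markdown",
   "id": "weights-check-md",
   "metadata": {},
   "source": [
    "The sketch below checks whether the weight files named in the listing above are present and prints their sizes. It is not part of the repository and assumes that `imagenet/weights` sits one level above this notebook (the hypothetical `weights_dir` variable); adjust it if your layout differs. Because the listing above is abridged, only the files shown there are checked."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "weights-check",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "# assumption: `imagenet/` sits one level above this notebook; adjust if needed\n",
    "weights_dir = Path(\"..\") / \"imagenet\" / \"weights\"\n",
    "\n",
    "# Only the files named in the (abridged) listing above are checked\n",
    "listed_weights = [\"biggan256.pth\", \"cgn.pth\", \"u2net.pth\", \"resnet50_from_scratch_model_best.pth.tar\"]\n",
    "for name in listed_weights:\n",
    "    path = weights_dir / name\n",
    "    status = f\"{path.stat().st_size / 1e6:.0f} MB\" if path.is_file() else \"missing\"\n",
    "    print(f\"{name}: {status}\")"
   ]
  },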
  {
   "cell_type": "markdown",
   "id": "a462e32f-0a78-4a40-be17-423131327447",
   "metadata": {},
   "source": [
    "## Section 4.1. Reproducibility study <a class=\"anchor\" id=\"4.1\"></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "04761f4c",
   "metadata": {},
   "source": [
    "### Qualitative evaluation of counterfactual samples <a class=\"anchor\" id=\"qual-eval-cf-samples\"></a>\n",
    "\n",
    "**Addressed Claim: High-quality counterfactuals (HQC)**\n",
    "\n",
    "\n",
    "* Relevant section in the paper: Section 4.1.\n",
    "* Relevant figures in the paper: Figures 2 and 3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "85e05b66",
   "metadata": {},
   "outputs": [],
   "source": [
    "from counterfactual_mnist import main as mnist_main\n",
    "from counterfactual_imagenet import main as imagenet_main"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1a9af16c",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "mnist_main(dataset_size=100, no_cfs=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ef1eb74-62ad-4465-b88c-f542b1bd0be7",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "imagenet_main()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9b167ec5-d212-466f-a91f-cdc19fba5a7f",
   "metadata": {},
   "source": [
    "### Evaluating invariant classifiers <a class=\"anchor\" id=\"eval-inv-clf\"></a>\n",
    "\n",
    "**Addressed Claim: Out-of-Distribution Robustness (ODR)**"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6dcbad18-5cdf-4a1c-b996-53195e306ffb",
   "metadata": {},
   "source": [
    "#### Experiments on MNISTs <a class=\"anchor\" id=\"repr-mnist\"></a>\n",
    "\n",
    "* Relevant table in the paper: Table 2"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0a40779b-6ec0-4857-a4fe-2e725f38f4d2",
   "metadata": {},
   "source": [
    "**Step 1**: Preparation - run experiments on GPUs before visualizing the results\n",
    "\n",
    "> **Note**: Running this on a CPU-only machine can take a long time (~40 mins). Instead, we recommend\n",
    "> running it on a GPU machine using the instructions below, which run the experiments and generate all results;\n",
    "> the cells below then display them. The instructions assume a cluster managed by `Slurm`. If you have direct terminal access to a GPU machine, you can instead run the individual scripts (see the job files for the exact commands)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "86417a73-9d8e-4eb0-adc1-83eaf5655767",
   "metadata": {},
   "source": [
    "**Running using cached models wherever necessary**\n",
    "\n",
    "Follow these instructions to run it on a GPU on a cluster.\n",
    "\n",
    "1. We have created a job script `run_mnist.job`.\n",
    "    Please change it appropriately, if needed, to run on a GPU.\n",
    "\n",
    "2. Submit the job script:\n",
    "\n",
    "    ```sh\n",
    "    cd /path/to/repo/experiments/\n",
    "    sbatch run_mnist.job\n",
    "    ```\n",
    "    You can check logs in `slurm_output_*.out` files.\n",
    "\n",
    "3. Once the job has finished, run the cells below to display the results.\n",
    "\n",
    "> This step takes about 20-25 mins on a GPU with 2 CPUs. Please let it finish before running the cells below.\n",
    "\n",
    "\n",
    "**Re-running without using any cached models.**\n",
    "\n",
    "The steps above generate CF data and load the classification results from the cache. To run everything from scratch, set `ignore_cache=True` in the following line of `mnist_pipeline.py` and repeat the steps above.\n",
    "```python\n",
    "df = run_experiments(seed=0, ignore_cache=False)\n",
    "```\n",
    "You do not need to change the cell below."
   ]
  },
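  {
   "cell_type": "markdown",
   "id": "slurm-log-tail-md",
   "metadata": {},
   "source": [
    "While the job is queued or running, you can check its state with `squeue -u $USER` in a terminal. To follow its progress from here, the helper below prints the end of the most recent `slurm_output_*.out` log. It is a minimal sketch that is not part of the repository (`tail_latest_slurm_log` is a hypothetical name) and assumes the log files are written to the notebook's working directory, i.e. the `experiments/` folder the job is submitted from. The same helper can be reused for the ImageNet job below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "slurm-log-tail",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "\n",
    "def tail_latest_slurm_log(log_dir=\".\", n_lines=20):\n",
    "    \"\"\"Print the last `n_lines` of the most recently modified slurm_output_*.out file.\"\"\"\n",
    "    logs = sorted(Path(log_dir).glob(\"slurm_output_*.out\"), key=lambda p: p.stat().st_mtime)\n",
    "    if not logs:\n",
    "        print(\"No slurm_output_*.out files found in\", Path(log_dir).resolve())\n",
    "        return\n",
    "    latest = logs[-1]\n",
    "    lines = latest.read_text(errors=\"replace\").splitlines()\n",
    "    print(f\"--- {latest.name} (last {n_lines} lines) ---\")\n",
    "    print(\"\\n\".join(lines[-n_lines:]))\n",
    "\n",
    "\n",
    "tail_latest_slurm_log()"
   ]
  },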
  {
   "cell_type": "markdown",
   "id": "3b9780c0-6b81-44f6-a753-f1314fefa84f",
   "metadata": {},
   "source": [
    "**Step 2**: Replication of Table 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fcb0dd2a-e0ac-4607-8e60-ac7c62a2cadd",
   "metadata": {},
   "outputs": [],
   "source": [
    "from mnist_pipeline import run_experiments"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d549c5c3-37ed-4af3-ae17-ce89feac6857",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pass `ignore_cache=False` so that the results generated in Step 1 are loaded from the cache;\n",
    "# run this cell as is\n",
    "df = run_experiments(seed=0, show=False, ignore_cache=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0f911c62-ddfd-4277-87d5-2ab60ef198a9",
   "metadata": {},
   "outputs": [],
   "source": [
    "df.index = [\"Original\", \"GAN\", \"CGN\", \"Original + GAN\", \"Original + CGN\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "28a6cc2d-f3bb-4aaf-8de7-ce232d7a90c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# show the results\n",
    "df.astype(float).round(1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "83c270f3-5aa7-4efb-aa2a-b2771708e906",
   "metadata": {},
   "source": [
    "#### Experiments on ImageNet-mini <a class=\"anchor\" id=\"repr-in-mini\"></a>\n",
    "\n",
    "* Relevant tables in the paper: Tables 3 and 4"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6ccc35c9-4144-4835-af01-3773ef693a09",
   "metadata": {},
   "source": [
    "**Step 1**: Preparation - run experiments on GPUs before visualizing the results"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "99f10232",
   "metadata": {},
   "source": [
    "**Running using cached models wherever necessary**\n",
    "\n",
    "Follow these instructions to run it on a GPU on a cluster.\n",
    "\n",
    "1. We have created a job script `run_imagenet.job`.\n",
    "    Please change it appropriately, if needed, to run on a GPU.\n",
    "\n",
    "2. Submit the job script:\n",
    "\n",
    "    ```sh\n",
    "    cd /path/to/repo/experiments/\n",
    "    sbatch run_imagenet.job\n",
    "    ```\n",
    "    You can check logs in `slurm_output_*.out` files.\n",
    "\n",
    "3. Once the job has finished, run the cells below to display the results.\n",
    "\n",
    "> This step takes about 5 mins to run. Please let it finish before running the cells below.\n",
    "\n",
    "\n",
    "**Re-running without using any cached models.**\n",
    "\n",
    "The steps above generate CF data and load the classification results from the cache. To run everything from scratch, set `ignore_cache=True` and `generate_cf_data=True` in the following line of `experiments/imagenet_pipeline.py` and repeat the steps above.\n",
    "\n",
    "```python\n",
    "metrics_clf, df_ood = run_experiments(seed=0, generate_cf_data=False, disp_epoch=34, ignore_cache=False)\n",
    "```\n",
    "\n",
    "> Note: This section involves generating counterfactual samples and training classifiers on IN-mini.\n",
    "> Generating CF samples can take about 3.5 hours and training the classifier about 2 hours on a GPU."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "597d4b6d-a938-41f8-a6c3-bff51c985f7b",
   "metadata": {},
   "source": [
    "**Step 2**: Replication of Tables 3 and 4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "edb0eb63-63d0-4be5-8b5c-9f551331f2c6",
   "metadata": {},
   "outputs": [],
   "source": [
    "from imagenet_pipeline import run_experiments"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0776f3a4-12ad-4724-8b94-3b19fe79ec73",
   "metadata": {},
   "outputs": [],
   "source": [
    "# display the results for epoch 34 (set via `disp_epoch`)\n",
    "metrics_clf, df_ood = run_experiments(seed=0, disp_epoch=34)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "173b59f3-a15d-4b28-b8a2-97a43fb468e3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# construct Table 3 of the paper\n",
    "\n",
    "heads = [\"shape\", \"texture\", \"bg\"]\n",
    "table_3 = pd.DataFrame(\n",
    "    None,\n",
    "    columns=[\"Shape bias\", \"Top 1\", \"Top 5\"],\n",
    "    index=[f\"IN-mini + CGN/{h}\" for h in heads],\n",
    ")\n",
    "for i, h in enumerate(heads):\n",
    "    table_3.at[f\"IN-mini + CGN/{h}\", \"Shape bias\"] = metrics_clf[f\"shape_biases/{i}_m_{h}_bias\"]\n",
    "    table_3.at[f\"IN-mini + CGN/{h}\", \"Top 1\"] = metrics_clf[\"acc1/1_real\"]\n",
    "    table_3.at[f\"IN-mini + CGN/{h}\", \"Top 5\"] = metrics_clf[\"acc5/1_real\"]\n",
    "\n",
    "table_3[\"Shape bias\"] *= 100.0\n",
    "table_3 = table_3.astype(float).round(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "83de2efd-e746-4fb2-9d47-ed8c7557829c",
   "metadata": {},
   "outputs": [],
   "source": [
    "table_3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0e2d19e4-1a76-4a06-bfe3-a5f46e88a4da",
   "metadata": {},
   "outputs": [],
   "source": [
    "# construct Table 4 of the paper\n",
    "table_4 = pd.DataFrame(\n",
    "    None,\n",
    "    columns=[\"IN-9\", \"Mixed-same\", \"Mixed-rand\", \"BG-gap\"],\n",
    "    index=[\"IN-mini + CGN\"],\n",
    ")\n",
    "\n",
    "col_to_key = {\n",
    "    \"IN-9\": \"in_9_acc1_original/shape_texture\",\n",
    "    \"Mixed-same\": \"in_9_acc1_mixed_same/shape_texture\",\n",
    "    \"Mixed-rand\": \"in_9_acc1_mixed_rand/shape_texture\",\n",
    "    \"BG-gap\": \"in_9_gaps/bg_gap\",\n",
    "}\n",
    "\n",
    "for c in table_4.columns:\n",
    "    assert col_to_key[c] in metrics_clf\n",
    "    key = col_to_key[c]\n",
    "    table_4.at[\"IN-mini + CGN\", c] = metrics_clf[key]\n",
    "\n",
    "table_4 = table_4.astype(float).round(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9486cd7f-c6a2-4d0a-b9ab-4aec75d10633",
   "metadata": {},
   "outputs": [],
   "source": [
    "table_4"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "43c1d63c",
   "metadata": {},
   "source": [
    "### Loss ablation <a class=\"anchor\" id=\"loss-ablation\"></a>\n",
    "\n",
    "**Addressed Claim: Inductive Bias Requirements (IBR)**\n",
    "\n",
    "* Relevant section in the paper: `Evaluating loss ablation` within Section 4.1.\n",
    "* Relevant table in the paper: Table 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c08480f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "from ablation_study import run_experiments\n",
    "\n",
    "for loss_name, inception, avg_mask, sd_mask in run_experiments(ignore_cache=False):\n",
    "    print(f\"for {loss_name} inception_score = {inception[0]} and mu_mask = {avg_mask}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "510c7c8a-11de-4aa2-8779-6e63d484dbfe",
   "metadata": {},
   "source": [
    "## Additional experiments and analyses <a class=\"anchor\" id=\"add-expts\"></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f130a9b4-0d1b-4983-a235-59d5bb804764",
   "metadata": {},
   "source": [
    "### Explainability analysis on MNISTs <a class=\"anchor\" id=\"expl-mnist\"></a>\n",
    "\n",
    "* Relevant section in the paper: Section 4.2.2\n",
    "* Relevant figures in the paper: Figures 5 and 6"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2fd348db-e58d-4653-96e6-e325e8c7c885",
   "metadata": {},
   "outputs": [],
   "source": [
    "from mnist_analysis import run_analyses"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "23e4920d-234c-49bf-83ed-8fcaf42d3f20",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "run_analyses(\n",
    "    datasets=[\"colored_MNIST\", \"double_colored_MNIST\", \"wildlife_MNIST\"],\n",
    "    debug=False,\n",
    "    show=True,\n",
    "    ignore_cache=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2d8665a2-5361-4d44-94dd-45205ef4e9e9",
   "metadata": {},
   "source": [
    "### Explainability analysis on ImageNet-mini <a class=\"anchor\" id=\"expl-in-mini\"></a>\n",
    "\n",
    "* Relevant figure in the paper: Figure 7"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5a0e5ff5-bad7-4c5d-9008-7ca455c949cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "from gradio_demo import init_gradio_module"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "edcded5f-232c-4e85-9252-d3d84cdf8bad",
   "metadata": {},
   "outputs": [],
   "source": [
    "cgn_gradio = init_gradio_module(launch=True, share=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "15260f61-863a-4f72-ba95-c23ee628dfa6",
   "metadata": {},
   "source": [
    "### Robustness to out-of-distribution generalization <a class=\"anchor\" id=\"ood\"></a>\n",
    "\n",
    "* Relevant table in the paper: Table 5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4773c92e",
   "metadata": {},
   "outputs": [],
   "source": [
    "df_ood = df_ood.astype(float).round(2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ac064bbe-a579-43b2-b8f9-822df7057fbb",
   "metadata": {},
   "outputs": [],
   "source": [
    "df_ood"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "804a5f2d-14db-46f1-aeb9-988fd74267a9",
   "metadata": {},
   "source": [
    "## Appendix"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1ea3d400-02d8-4926-809e-622598ce72a3",
   "metadata": {},
   "source": [
    "### Ablation on number of counterfactuals"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dc56f87d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_mnist_ablation_on_cfs(\n",
    "        file,\n",
    "        datasets=[\"colored_MNIST\", \"double_colored_MNIST\", \"wildlife_MNIST\"],\n",
    "    ):\n",
    "    CF_ratios = [1, 5, 10, 20]\n",
    "    dataset_sizes = [10000, 100000, 1000000]\n",
    "\n",
    "    with open(file, \"r\") as f:\n",
    "        results = json.load(f)\n",
    "\n",
    "    fig, axs = plt.subplots(1, 3, figsize=(16,5))\n",
    "    plt.setp(axs, xticks=[0, 1, 2], xticklabels=[r'$10^4$', r'$10^5$', r'$10^6$'])\n",
    "\n",
    "    for i, dataset in enumerate(datasets):\n",
    "        for CF_ratio in CF_ratios:\n",
    "            # Skip the CF_ratio of 20 for the colored MNIST dataset, as there are only 10 possible colors\n",
    "            # per shape.\n",
    "            if CF_ratio == 20 and dataset == \"colored_MNIST\":\n",
    "                continue\n",
    "\n",
    "            # Test accuracy for each dataset size at this CF ratio\n",
    "            line = [results[f\"{dataset}_counterfactual_{size}_{CF_ratio}\"] for size in dataset_sizes]\n",
    "            axs[i].plot(np.arange(len(dataset_sizes)), line, label=f'CF ratio = {CF_ratio}', marker='o')\n",
    "\n",
    "        # Axis decorations only need to be set once per subplot\n",
    "        axs[i].set_xlabel(\"Num Counterfactual Datapoints\")\n",
    "        axs[i].set_ylabel(\"Test Accuracy (%)\")\n",
    "        axs[i].grid(True)\n",
    "        axs[i].legend()\n",
    "        axs[i].set_title(dataset)\n",
    "\n",
    "    plt.savefig('../media/figures/figure7_reproduced.pdf', bbox_inches='tight')\n",
    "    plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ba42a968-958e-4246-8a4b-aa031a5e0a0f",
   "metadata": {},
   "outputs": [],
   "source": [
    "results_path = \"../experiments/results/cache/mnist_ablation_on_cfs.json\"\n",
    "plot_mnist_ablation_on_cfs(results_path)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
