{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "b3ab180d-64b7-451e-9b26-043b80729365",
   "metadata": {},
   "source": [
    "## Intro\n",
    "\n",
    "This notebook is provided to recreate a portion of the results on the Grassy MNIST dataset in \"Feature Selection in the Contrastive Analysis Setting\". To reproduce the results seen in the paper, the script `run_grassy_mnist_experiment.py` should be run as follows (after running `get_grassy_mnist.ipynb`):\n",
    "\n",
    "`python run_grassy_mnist_experiment.py <model> <num_features> <max_epochs>`, where \\<model\\> is one of `ConcreteAutoencoder`, `CFS_Joint`, `CFS_Pretrained`, or `CFS_Gates`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "16031fce-e88c-44f7-bffc-9fe9c8be8c75",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.ensemble import ExtraTreesClassifier"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a834ff32-e0cb-4cb4-92ad-2fc19e5a95da",
   "metadata": {},
   "source": [
    "### Load data and models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7043adf2-6649-4731-918e-5c00c72cf70c",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset_name = 'Grassy_MNIST'\n",
    "k = 20"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "28d35adb-9f05-420c-a9c8-81e73e1fae4f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_concrete_layer_gates(layer, title):\n",
    "    M = layer.sample(n_samples=256)\n",
    "    values = torch.mean(M, dim=0)\n",
    "    values = torch.sum(values, dim=0)\n",
    "\n",
    "    _, idxs = torch.topk(values, k=20)\n",
    "    idxs = idxs[:k]\n",
    "\n",
    "    blank_image = np.zeros(784)\n",
    "    blank_image[idxs] = 1\n",
    "\n",
    "    fig, ax = plt.subplots(figsize=(5, 5))\n",
    "\n",
    "    ax.imshow(blank_image.reshape(28, 28))\n",
    "    ax.set_xticks([])\n",
    "    ax.set_yticks([])\n",
    "    fontsize=20\n",
    "    fig.suptitle(title, fontsize=fontsize)\n",
    "    fig.subplots_adjust(top=0.9)\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "33d59bf6-9ac2-4ecd-b9d5-d0596412c04c",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "concrete_model = torch.load(\"results/{}/ConcreteAutoencoder/{}/checkpoint.chkpt\".format(dataset_name, k))\n",
    "layer = concrete_model.input_layer\n",
    "plot_concrete_layer_gates(layer, \"Concrete Autoencoder\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c3f6eed5-54d1-41c3-81a6-3abc3a96b018",
   "metadata": {},
   "outputs": [],
   "source": [
    "contrastive_concrete_model = torch.load(\"results/{}/CFS_Pretrained/{}/checkpoint.chkpt\".format(dataset_name, k))\n",
    "layer = contrastive_concrete_model.input_layer\n",
    "plot_concrete_layer_gates(layer, \"CFS (Pretrained)\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa4068d9-9a4a-4bc5-b1bc-327034c0a92c",
   "metadata": {},
   "outputs": [],
   "source": [
    "contrastive_concrete_model = torch.load(\"results/{}/CFS_Gates/{}/checkpoint.chkpt\".format(dataset_name, k))\n",
    "layer = contrastive_concrete_model.salient_input_layer\n",
    "plot_concrete_layer_gates(layer, \"CFS (Gates)\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e54302a5-10cb-4d1f-8a60-11a96eeb0a0e",
   "metadata": {},
   "outputs": [],
   "source": [
    "contrastive_concrete_model = torch.load(\"results/{}/CFS_Joint/{}/checkpoint.chkpt\".format(dataset_name, k))\n",
    "layer = contrastive_concrete_model.salient_input_layer\n",
    "plot_concrete_layer_gates(layer, \"CFS (Joint)\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa071b4c-0129-4633-9c01-d5b2272b173b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils import set_seeds\n",
    "\n",
    "def evaluate_features(method, k):\n",
    "    background = np.load(\"data/Grassy_MNIST/background.npy\")\n",
    "    target = np.load(\"data/Grassy_MNIST/target.npy\")\n",
    "    target_labels = np.load(\"data/Grassy_MNIST/target_labels.npy\") + 1\n",
    "    \n",
    "    \n",
    "    data_ = target\n",
    "    labels_ = target_labels\n",
    "\n",
    "    data_train, data_test, labels_train, labels_test = train_test_split(\n",
    "        data_,\n",
    "        labels_,\n",
    "        test_size=0.2,\n",
    "        random_state=42\n",
    "    )\n",
    "        \n",
    "        \n",
    "    if method == 'ConcreteAutoencoder':\n",
    "        concrete_model = torch.load(\"results/Grassy_MNIST/ConcreteAutoencoder/{}/checkpoint.chkpt\".format(k))\n",
    "        layer = concrete_model.input_layer\n",
    "        indices = layer.get_inds()\n",
    "        \n",
    "    elif method == 'CFS_Pretrained':\n",
    "        concrete_model = torch.load(\"results/Grassy_MNIST/CFS_Pretrained/{}/checkpoint.chkpt\".format(k))\n",
    "        layer = concrete_model.input_layer\n",
    "        indices = layer.get_inds()\n",
    "        \n",
    "    elif method == 'CFS_Gates':\n",
    "        concrete_model = torch.load(\"results/Grassy_MNIST/CFS_Gates/{}/checkpoint.chkpt\".format(k))\n",
    "        layer = concrete_model.salient_input_layer\n",
    "        indices = layer.get_inds()\n",
    "        \n",
    "    elif method == 'CFS_Joint':\n",
    "        concrete_model = torch.load(\"results/Grassy_MNIST/CFS_Joint/{}/checkpoint.chkpt\".format(k))\n",
    "        layer = concrete_model.salient_input_layer\n",
    "        indices = layer.get_inds()\n",
    "        \n",
    "    set_seeds()\n",
    "    data_train, data_test = data_train[:, indices], data_test[:, indices]\n",
    "    clf = ExtraTreesClassifier(n_estimators=100)\n",
    "    clf.fit(data_train, labels_train)\n",
    "        \n",
    "    return clf.score(data_test, labels_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4fa76e8a-20a3-464d-8be0-3412cf620fdd",
   "metadata": {},
   "outputs": [],
   "source": [
    "evaluate_features('ConcreteAutoencoder', 20)\n",
    "evaluate_features('CFS_Joint', 20)\n",
    "evaluate_features('CFS_Pretrained', 20)\n",
    "evaluate_features('CFS_Gates', 20)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
