{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "virtual-headline",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "import torch.nn as nn\n",
    "import matplotlib.pyplot as plt\n",
    "from torchvision import transforms\n",
    "from torchvision.datasets import MNIST\n",
    "from copy import deepcopy\n",
    "from utils import *\n",
    "from models import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "direct-arrow",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the GPU this notebook was written for (index 4) when it exists;\n",
    "# otherwise fall back to the first GPU, or to the CPU if CUDA is unavailable,\n",
    "# so the notebook still runs on machines with fewer devices.\n",
    "preferred_gpu = 4\n",
    "if torch.cuda.device_count() > preferred_gpu:\n",
    "    device = torch.device('cuda', preferred_gpu)\n",
    "elif torch.cuda.is_available():\n",
    "    device = torch.device('cuda', 0)\n",
    "else:\n",
    "    device = torch.device('cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "adjusted-pointer",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load data\n",
    "class Flatten(object):\n",
    "    '''Transform that flattens an image tensor into a 1-D tensor.'''\n",
    "    def __call__(self, pic):\n",
    "        return torch.flatten(pic)\n",
    "\n",
    "# ToTensor followed by Flatten; shared by the train/val and test datasets\n",
    "flatten_transform = transforms.Compose([transforms.ToTensor(), Flatten()])\n",
    "mnist_dataset = MNIST('/tmp/mnist/', download=True, train=True,\n",
    "                      transform=flatten_transform)\n",
    "images = mnist_dataset.data\n",
    "targets = mnist_dataset.targets\n",
    "\n",
    "# Seed NumPy (drives the validation split below) and PyTorch (drives\n",
    "# nn.Linear weight initialization in later cells) so that a fresh\n",
    "# Restart & Run All starts from the same random state.\n",
    "np.random.seed(0)\n",
    "torch.manual_seed(0)\n",
    "\n",
    "# Hold out 10000 randomly chosen training images for validation\n",
    "val_inds = np.sort(np.random.choice(len(images), size=10000, replace=False))\n",
    "train_inds = np.setdiff1d(np.arange(len(images)), val_inds)\n",
    "\n",
    "# Training dataset\n",
    "train_dataset = torch.utils.data.Subset(mnist_dataset, train_inds)\n",
    "\n",
    "# Validation dataset\n",
    "val_dataset = torch.utils.data.Subset(mnist_dataset, val_inds)\n",
    "\n",
    "# Test dataset\n",
    "test_dataset = MNIST('/tmp/mnist/', download=True, train=False,\n",
    "                     transform=flatten_transform)\n",
    "\n",
    "# Set input/output dimensions: 28 * 28 flattened pixels in, 10 digit classes out\n",
    "d_in = 784\n",
    "d_out = 10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "brazilian-postcard",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of features to select; passed to the global ConcreteMask and to the\n",
    "# fit/evaluate calls of the pretrainer and greedy adaptive selector below\n",
    "max_features = 50"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "potential-anniversary",
   "metadata": {},
   "source": [
    "# Global feature selection (FS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "separated-beatles",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--------Epoch 1--------\n",
      "Val loss = -0.5627\n",
      "\n",
      "--------Epoch 2--------\n",
      "Val loss = -0.5968\n",
      "\n",
      "--------Epoch 3--------\n",
      "Val loss = -0.6331\n",
      "\n",
      "--------Epoch 4--------\n",
      "Val loss = -0.6248\n",
      "\n",
      "--------Epoch 5--------\n",
      "Val loss = -0.6484\n",
      "\n",
      "--------Epoch 6--------\n",
      "Val loss = -0.6632\n",
      "\n",
      "--------Epoch 7--------\n",
      "Val loss = -0.6592\n",
      "\n",
      "--------Epoch 8--------\n",
      "Val loss = -0.6696\n",
      "\n",
      "--------Epoch 9--------\n",
      "Val loss = -0.6889\n",
      "\n",
      "--------Epoch 10--------\n",
      "Val loss = -0.6846\n",
      "\n",
      "--------Epoch 11--------\n",
      "Val loss = -0.6811\n",
      "\n",
      "--------Epoch 12--------\n",
      "Val loss = -0.7010\n",
      "\n",
      "--------Epoch 13--------\n",
      "Val loss = -0.7036\n",
      "\n",
      "--------Epoch 14--------\n",
      "Val loss = -0.7112\n",
      "\n",
      "--------Epoch 15--------\n",
      "Val loss = -0.7114\n",
      "\n",
      "--------Epoch 16--------\n",
      "Val loss = -0.7171\n",
      "\n",
      "--------Epoch 17--------\n",
      "Val loss = -0.7217\n",
      "\n",
      "--------Epoch 18--------\n",
      "Val loss = -0.7239\n",
      "\n",
      "--------Epoch 19--------\n",
      "Val loss = -0.7289\n",
      "\n",
      "--------Epoch 20--------\n",
      "Val loss = -0.7255\n",
      "\n",
      "--------Epoch 21--------\n",
      "Val loss = -0.7400\n",
      "\n",
      "--------Epoch 22--------\n",
      "Val loss = -0.7509\n",
      "\n",
      "--------Epoch 23--------\n",
      "Val loss = -0.7320\n",
      "\n",
      "--------Epoch 24--------\n",
      "Val loss = -0.7374\n",
      "\n",
      "--------Epoch 25--------\n",
      "Val loss = -0.7444\n",
      "\n",
      "--------Epoch 26--------\n",
      "Val loss = -0.7567\n",
      "\n",
      "--------Epoch 27--------\n",
      "Val loss = -0.7575\n",
      "\n",
      "--------Epoch 28--------\n",
      "Val loss = -0.7676\n",
      "\n",
      "--------Epoch 29--------\n",
      "Val loss = -0.7731\n",
      "\n",
      "--------Epoch 30--------\n",
      "Val loss = -0.7771\n",
      "\n",
      "--------Epoch 31--------\n",
      "Val loss = -0.7767\n",
      "\n",
      "--------Epoch 32--------\n",
      "Val loss = -0.7931\n",
      "\n",
      "--------Epoch 33--------\n",
      "Val loss = -0.7814\n",
      "\n",
      "--------Epoch 34--------\n",
      "Val loss = -0.8024\n",
      "\n",
      "--------Epoch 35--------\n",
      "Val loss = -0.8003\n",
      "\n",
      "--------Epoch 36--------\n",
      "Val loss = -0.8018\n",
      "\n",
      "--------Epoch 37--------\n",
      "Val loss = -0.7999\n",
      "\n",
      "--------Epoch 38--------\n",
      "Val loss = -0.8021\n",
      "\n",
      "--------Epoch 39--------\n",
      "Val loss = -0.8114\n",
      "\n",
      "--------Epoch 40--------\n",
      "Val loss = -0.8061\n",
      "\n",
      "--------Epoch 41--------\n",
      "Val loss = -0.8016\n",
      "\n",
      "--------Epoch 42--------\n",
      "Val loss = -0.8163\n",
      "\n",
      "--------Epoch 43--------\n",
      "Val loss = -0.8146\n",
      "\n",
      "--------Epoch 44--------\n",
      "Val loss = -0.8166\n",
      "\n",
      "--------Epoch 45--------\n",
      "Val loss = -0.8277\n",
      "\n",
      "--------Epoch 46--------\n",
      "Val loss = -0.8340\n",
      "\n",
      "--------Epoch 47--------\n",
      "Val loss = -0.8315\n",
      "\n",
      "--------Epoch 48--------\n",
      "Val loss = -0.8330\n",
      "\n",
      "--------Epoch 49--------\n",
      "Val loss = -0.8441\n",
      "\n",
      "--------Epoch 50--------\n",
      "Val loss = -0.8463\n",
      "\n",
      "--------Epoch 51--------\n",
      "Val loss = -0.8435\n",
      "\n",
      "--------Epoch 52--------\n",
      "Val loss = -0.8507\n",
      "\n",
      "--------Epoch 53--------\n",
      "Val loss = -0.8491\n",
      "\n",
      "--------Epoch 54--------\n",
      "Val loss = -0.8596\n",
      "\n",
      "--------Epoch 55--------\n",
      "Val loss = -0.8577\n",
      "\n",
      "--------Epoch 56--------\n",
      "Val loss = -0.8621\n",
      "\n",
      "--------Epoch 57--------\n",
      "Val loss = -0.8544\n",
      "\n",
      "--------Epoch 58--------\n",
      "Val loss = -0.8645\n",
      "\n",
      "--------Epoch 59--------\n",
      "Val loss = -0.8649\n",
      "\n",
      "--------Epoch 60--------\n",
      "Val loss = -0.8661\n",
      "\n",
      "--------Epoch 61--------\n",
      "Val loss = -0.8699\n",
      "\n",
      "--------Epoch 62--------\n",
      "Val loss = -0.8666\n",
      "\n",
      "--------Epoch 63--------\n",
      "Val loss = -0.8756\n",
      "\n",
      "--------Epoch 64--------\n",
      "Val loss = -0.8741\n",
      "\n",
      "--------Epoch 65--------\n",
      "Val loss = -0.8736\n",
      "\n",
      "--------Epoch 66--------\n",
      "Val loss = -0.8686\n",
      "\n",
      "--------Epoch 67--------\n",
      "Val loss = -0.8760\n",
      "\n",
      "--------Epoch 68--------\n",
      "Val loss = -0.8811\n",
      "\n",
      "--------Epoch 69--------\n",
      "Val loss = -0.8850\n",
      "\n",
      "--------Epoch 70--------\n",
      "Val loss = -0.8801\n",
      "\n",
      "--------Epoch 71--------\n",
      "Val loss = -0.8819\n",
      "\n",
      "--------Epoch 72--------\n",
      "Val loss = -0.8873\n",
      "\n",
      "--------Epoch 73--------\n",
      "Val loss = -0.8852\n",
      "\n",
      "--------Epoch 74--------\n",
      "Val loss = -0.8894\n",
      "\n",
      "--------Epoch 75--------\n",
      "Val loss = -0.8896\n",
      "\n",
      "--------Epoch 76--------\n",
      "Val loss = -0.8916\n",
      "\n",
      "--------Epoch 77--------\n",
      "Val loss = -0.8901\n",
      "\n",
      "--------Epoch 78--------\n",
      "Val loss = -0.8956\n",
      "\n",
      "--------Epoch 79--------\n",
      "Val loss = -0.9006\n",
      "\n",
      "--------Epoch 80--------\n",
      "Val loss = -0.8973\n",
      "\n",
      "--------Epoch 81--------\n",
      "Val loss = -0.9007\n",
      "\n",
      "--------Epoch 82--------\n",
      "Val loss = -0.9004\n",
      "\n",
      "--------Epoch 83--------\n",
      "Val loss = -0.8993\n",
      "\n",
      "--------Epoch 84--------\n",
      "Val loss = -0.9069\n",
      "\n",
      "--------Epoch 85--------\n",
      "Val loss = -0.9058\n",
      "\n",
      "--------Epoch 86--------\n",
      "Val loss = -0.8988\n",
      "\n",
      "--------Epoch 87--------\n",
      "Val loss = -0.9103\n",
      "\n",
      "--------Epoch 88--------\n",
      "Val loss = -0.9061\n",
      "\n",
      "--------Epoch 89--------\n",
      "Val loss = -0.9109\n",
      "\n",
      "--------Epoch 90--------\n",
      "Val loss = -0.9074\n",
      "\n",
      "--------Epoch 91--------\n",
      "Val loss = -0.9101\n",
      "\n",
      "--------Epoch 92--------\n",
      "Val loss = -0.9125\n",
      "\n",
      "--------Epoch 93--------\n",
      "Val loss = -0.9113\n",
      "\n",
      "--------Epoch 94--------\n",
      "Val loss = -0.9134\n",
      "\n",
      "--------Epoch 95--------\n",
      "Val loss = -0.9188\n",
      "\n",
      "--------Epoch 96--------\n",
      "Val loss = -0.9206\n",
      "\n",
      "--------Epoch 97--------\n",
      "Val loss = -0.9146\n",
      "\n",
      "--------Epoch 98--------\n",
      "Val loss = -0.9175\n",
      "\n",
      "--------Epoch 99--------\n",
      "Val loss = -0.9169\n",
      "\n",
      "--------Epoch 100--------\n",
      "Val loss = -0.9192\n",
      "\n",
      "--------Epoch 101--------\n",
      "Val loss = -0.9200\n",
      "\n",
      "--------Epoch 102--------\n",
      "Val loss = -0.9196\n",
      "\n",
      "--------Epoch 103--------\n",
      "Val loss = -0.9196\n",
      "\n",
      "--------Epoch 104--------\n",
      "Val loss = -0.9231\n",
      "\n",
      "--------Epoch 105--------\n",
      "Val loss = -0.9208\n",
      "\n",
      "--------Epoch 106--------\n",
      "Val loss = -0.9227\n",
      "\n",
      "--------Epoch 107--------\n",
      "Val loss = -0.9201\n",
      "\n",
      "--------Epoch 108--------\n",
      "Val loss = -0.9259\n",
      "\n",
      "--------Epoch 109--------\n",
      "Val loss = -0.9315\n",
      "\n",
      "--------Epoch 110--------\n",
      "Val loss = -0.9268\n",
      "\n",
      "--------Epoch 111--------\n",
      "Val loss = -0.9276\n",
      "\n",
      "--------Epoch 112--------\n",
      "Val loss = -0.9258\n",
      "\n",
      "--------Epoch 113--------\n",
      "Val loss = -0.9320\n",
      "\n",
      "--------Epoch 114--------\n",
      "Val loss = -0.9291\n",
      "\n",
      "--------Epoch 115--------\n",
      "Val loss = -0.9353\n",
      "\n",
      "--------Epoch 116--------\n",
      "Val loss = -0.9288\n",
      "\n",
      "--------Epoch 117--------\n",
      "Val loss = -0.9250\n",
      "\n",
      "--------Epoch 118--------\n",
      "Val loss = -0.9337\n",
      "\n",
      "--------Epoch 119--------\n",
      "Val loss = -0.9350\n",
      "\n",
      "--------Epoch 120--------\n",
      "Val loss = -0.9319\n",
      "\n",
      "--------Epoch 121--------\n",
      "Val loss = -0.9348\n",
      "\n",
      "--------Epoch 122--------\n",
      "Val loss = -0.9311\n",
      "\n",
      "--------Epoch 123--------\n",
      "Val loss = -0.9340\n",
      "\n",
      "--------Epoch 124--------\n",
      "Val loss = -0.9326\n",
      "\n",
      "--------Epoch 125--------\n",
      "Val loss = -0.9337\n",
      "\n",
      "--------Epoch 126--------\n",
      "Val loss = -0.9375\n",
      "\n",
      "--------Epoch 127--------\n",
      "Val loss = -0.9389\n",
      "\n",
      "--------Epoch 128--------\n",
      "Val loss = -0.9393\n",
      "\n",
      "--------Epoch 129--------\n",
      "Val loss = -0.9362\n",
      "\n",
      "--------Epoch 130--------\n",
      "Val loss = -0.9380\n",
      "\n",
      "--------Epoch 131--------\n",
      "Val loss = -0.9365\n",
      "\n",
      "--------Epoch 132--------\n",
      "Val loss = -0.9392\n",
      "\n",
      "--------Epoch 133--------\n",
      "Val loss = -0.9369\n",
      "\n",
      "--------Epoch 134--------\n",
      "Val loss = -0.9405\n",
      "\n",
      "--------Epoch 135--------\n",
      "Val loss = -0.9404\n",
      "\n",
      "--------Epoch 136--------\n",
      "Val loss = -0.9415\n",
      "\n",
      "--------Epoch 137--------\n",
      "Val loss = -0.9400\n",
      "\n",
      "--------Epoch 138--------\n",
      "Val loss = -0.9421\n",
      "\n",
      "--------Epoch 139--------\n",
      "Val loss = -0.9391\n",
      "\n",
      "--------Epoch 140--------\n",
      "Val loss = -0.9389\n",
      "\n",
      "--------Epoch 141--------\n",
      "Val loss = -0.9413\n",
      "\n",
      "--------Epoch 142--------\n",
      "Val loss = -0.9414\n",
      "\n",
      "--------Epoch 143--------\n",
      "Val loss = -0.9421\n",
      "\n",
      "--------Epoch 144--------\n",
      "Val loss = -0.9395\n",
      "\n",
      "--------Epoch 145--------\n",
      "Val loss = -0.9414\n",
      "\n",
      "--------Epoch 146--------\n",
      "Val loss = -0.9414\n",
      "\n",
      "--------Epoch 147--------\n",
      "Val loss = -0.9451\n",
      "\n",
      "--------Epoch 148--------\n",
      "Val loss = -0.9423\n",
      "\n",
      "--------Epoch 149--------\n",
      "Val loss = -0.9430\n",
      "\n",
      "--------Epoch 150--------\n",
      "Val loss = -0.9403\n",
      "\n",
      "--------Epoch 151--------\n",
      "Val loss = -0.9421\n",
      "\n",
      "--------Epoch 152--------\n",
      "Val loss = -0.9446\n",
      "\n",
      "--------Epoch 153--------\n",
      "Val loss = -0.9420\n",
      "\n",
      "--------Epoch 154--------\n",
      "Val loss = -0.9444\n",
      "\n",
      "--------Epoch 155--------\n",
      "Val loss = -0.9480\n",
      "\n",
      "--------Epoch 156--------\n",
      "Val loss = -0.9473\n",
      "\n",
      "--------Epoch 157--------\n",
      "Val loss = -0.9457\n",
      "\n",
      "--------Epoch 158--------\n",
      "Val loss = -0.9444\n",
      "\n",
      "--------Epoch 159--------\n",
      "Val loss = -0.9453\n",
      "\n",
      "--------Epoch 160--------\n",
      "Val loss = -0.9413\n",
      "\n",
      "--------Epoch 161--------\n",
      "Val loss = -0.9471\n",
      "\n",
      "--------Epoch 162--------\n",
      "Val loss = -0.9474\n",
      "\n",
      "--------Epoch 163--------\n",
      "Val loss = -0.9443\n",
      "\n",
      "--------Epoch 164--------\n",
      "Val loss = -0.9477\n",
      "\n",
      "--------Epoch 165--------\n",
      "Val loss = -0.9483\n",
      "\n",
      "--------Epoch 166--------\n",
      "Val loss = -0.9444\n",
      "\n",
      "--------Epoch 167--------\n",
      "Val loss = -0.9448\n",
      "\n",
      "--------Epoch 168--------\n",
      "Val loss = -0.9464\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Predictor network. Its first layer takes d_in * 2 inputs; presumably\n",
    "# append=True makes the mask layer concatenate the mask to the features\n",
    "# (TODO confirm in models.py).\n",
    "global_layers = [\n",
    "    nn.Linear(d_in * 2, 512),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(512, 512),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(512, d_out),\n",
    "]\n",
    "global_model = nn.Sequential(*global_layers)\n",
    "selector = ConcreteMask(d_in, max_features, append=True)\n",
    "globalfs = GlobalSelector(global_model, selector)\n",
    "globalfs = globalfs.to(device)\n",
    "\n",
    "# Training settings, gathered in one place so they can be read at a glance\n",
    "global_fit_kwargs = dict(\n",
    "    mbsize=128,\n",
    "    lr=1e-3,\n",
    "    nepochs=250,\n",
    "    loss_fn=nn.CrossEntropyLoss(),\n",
    "    val_loss_fn=NegAccuracy(),\n",
    "    start_temp=1.0,\n",
    "    end_temp=0.01,\n",
    ")\n",
    "globalfs.fit(train_dataset, val_dataset, **global_fit_kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "exterior-attraction",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test acc = 95.07\n"
     ]
    }
   ],
   "source": [
    "# Score the trained global selector on the MNIST test split\n",
    "accuracy_metric = Accuracy()\n",
    "test_acc = globalfs.evaluate(test_dataset, accuracy_metric, 1024)\n",
    "print(f'Test acc = {100*test_acc:.2f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "joined-canadian",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One selected pixel index per row of the selector logits\n",
    "inds = selector.logits.argmax(dim=1).cpu().data.numpy()\n",
    "\n",
    "# Binary mask over the d_in input pixels (1 = selected)\n",
    "selected_mask = np.zeros(d_in)\n",
    "selected_mask[inds] = 1\n",
    "\n",
    "# Several rows can share the same argmax, so the title reports how many\n",
    "# distinct pixels are actually lit in the mask.\n",
    "plt.figure()\n",
    "plt.imshow(selected_mask.reshape(28, 28))\n",
    "plt.title(f'Globally selected pixels ({len(np.unique(inds))} unique)')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "wound-melissa",
   "metadata": {},
   "source": [
    "# Greedy adaptive feature selection (FS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "emerging-trading",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Predictor network (d_in * 2 inputs, two hidden layers of 512, d_out outputs)\n",
    "predictor_layers = [\n",
    "    nn.Linear(d_in * 2, 512),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(512, 512),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(512, d_out),\n",
    "]\n",
    "model = nn.Sequential(*predictor_layers)\n",
    "mask_layer = MaskLayer(append=True)\n",
    "pretrain = Pretrainer(model, mask_layer)\n",
    "pretrain = pretrain.to(device)\n",
    "\n",
    "# Pretrain the predictor; the greedy adaptive cell later starts from a copy of it\n",
    "pretrain_fit_kwargs = dict(\n",
    "    mbsize=128,\n",
    "    lr=1e-3,\n",
    "    nepochs=100,\n",
    "    max_features=max_features,\n",
    "    loss_fn=nn.CrossEntropyLoss(),\n",
    "    val_loss_fn=NegAccuracy(),\n",
    ")\n",
    "pretrain.fit(train_dataset, val_dataset, **pretrain_fit_kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "specified-principal",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Set up selector: maps the d_in * 2 inputs to one output per input feature.\n",
    "# It is named gafs_selector (not selector) so the ConcreteMask bound to\n",
    "# `selector` in the global FS section is not overwritten; the mask\n",
    "# visualization cell reads selector.logits and would otherwise raise an\n",
    "# AttributeError when re-run after this cell.\n",
    "gafs_selector = nn.Sequential(\n",
    "    nn.Linear(d_in * 2, 512),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(512, 512),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(512, d_in))\n",
    "selector_layer = ConcreteSelector()\n",
    "# deepcopy gives the greedy model its own copy of the pretrained predictor,\n",
    "# leaving `model` itself untouched by this training run.\n",
    "gafs = GreedyAdaptiveFS(gafs_selector, deepcopy(model), mask_layer, selector_layer).to(device)\n",
    "\n",
    "# Train\n",
    "gafs.fit(train_dataset,\n",
    "         val_dataset,\n",
    "         mbsize=128,\n",
    "         lr=2e-4,\n",
    "         nepochs=250,\n",
    "         max_features=max_features,\n",
    "         loss_fn=nn.CrossEntropyLoss(),\n",
    "         val_loss_fn=NegAccuracy(),\n",
    "         start_temp=10.0,\n",
    "         end_temp=0.01)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "speaking-clone",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test acc = 96.75\n"
     ]
    }
   ],
   "source": [
    "# Score the greedy adaptive selector on the MNIST test split, using the\n",
    "# same max_features value it was trained with\n",
    "accuracy_metric = Accuracy()\n",
    "test_acc = gafs.evaluate(test_dataset, max_features, accuracy_metric, 1024)\n",
    "print(f'Test acc = {100*test_acc:.2f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "crazy-closure",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
