{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "average-genealogy",
   "metadata": {},
   "source": [
    "# Implementing REBAR\n",
    "\n",
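    "- Rather than using the Concrete relaxation alone or plain REINFORCE, this version trains the selector with REBAR gradient estimates (REINFORCE with a control variate built from the Concrete relaxation).\n",
    "- In practice, it doesn't work as well as simply using the Concrete distribution by itself."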
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "pharmaceutical-professor",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "import torch.nn as nn\n",
    "import matplotlib.pyplot as plt\n",
    "from torchvision import transforms\n",
    "from torchvision.datasets import MNIST\n",
    "from copy import deepcopy\n",
    "from utils import *\n",
    "from models import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "unknown-danger",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prefer GPU 3 (the original experiment's device); fall back to the first GPU, then CPU,\n",
    "# so the notebook still runs top-to-bottom on machines with fewer GPUs or no CUDA.\n",
    "DEVICE_INDEX = 3\n",
    "if torch.cuda.device_count() > DEVICE_INDEX:\n",
    "    device = torch.device('cuda', DEVICE_INDEX)\n",
    "elif torch.cuda.is_available():\n",
    "    device = torch.device('cuda', 0)\n",
    "else:\n",
    "    device = torch.device('cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "ambient-stretch",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load MNIST with every image flattened into a 1-D tensor so it can feed the MLPs below.\n",
    "class Flatten(object):\n",
    "    # Transform: flattens a tensor image into a 1-D tensor.\n",
    "    def __call__(self, pic):\n",
    "        return torch.flatten(pic)\n",
    "\n",
    "MNIST_ROOT = '/tmp/mnist/'\n",
    "mnist_transform = transforms.Compose([transforms.ToTensor(), Flatten()])\n",
    "mnist_dataset = MNIST(MNIST_ROOT, download=True, train=True, transform=mnist_transform)\n",
    "images = mnist_dataset.data\n",
    "targets = mnist_dataset.targets\n",
    "\n",
    "# Hold out a fixed, seeded random sample of 10,000 training images for validation;\n",
    "# the remaining images make up the training split.\n",
    "np.random.seed(0)\n",
    "num_images = len(images)\n",
    "val_inds = np.sort(np.random.choice(num_images, size=10000, replace=False))\n",
    "train_inds = np.setdiff1d(np.arange(num_images), val_inds)\n",
    "train_dataset = torch.utils.data.Subset(mnist_dataset, train_inds)\n",
    "val_dataset = torch.utils.data.Subset(mnist_dataset, val_inds)\n",
    "\n",
    "# Official MNIST test split, preprocessed the same way\n",
    "test_dataset = MNIST(MNIST_ROOT, download=True, train=False, transform=mnist_transform)\n",
    "\n",
    "# Input/output dimensions: 28x28 = 784 pixels in, 10 digit classes out\n",
    "d_in, d_out = 784, 10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "direct-indian",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Feature budget: how many of the d_in input pixels the selector is allowed to keep\n",
    "max_features = 10"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "binary-sauce",
   "metadata": {},
   "source": [
    "# Global feature selection (FS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "specific-collector",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--------Epoch 1--------\n",
      "Val loss = -0.2784\n",
      "\n",
      "--------Epoch 2--------\n",
      "Val loss = -0.3027\n",
      "\n",
      "--------Epoch 3--------\n",
      "Val loss = -0.3163\n",
      "\n",
      "--------Epoch 4--------\n",
      "Val loss = -0.3135\n",
      "\n",
      "--------Epoch 5--------\n",
      "Val loss = -0.3233\n",
      "\n",
      "--------Epoch 6--------\n",
      "Val loss = -0.3311\n",
      "\n",
      "--------Epoch 7--------\n",
      "Val loss = -0.3360\n",
      "\n",
      "--------Epoch 8--------\n",
      "Val loss = -0.3385\n",
      "\n",
      "--------Epoch 9--------\n",
      "Val loss = -0.3384\n",
      "\n",
      "--------Epoch 10--------\n",
      "Val loss = -0.3483\n",
      "\n",
      "--------Epoch 11--------\n",
      "Val loss = -0.3527\n",
      "\n",
      "--------Epoch 12--------\n",
      "Val loss = -0.3572\n",
      "\n",
      "--------Epoch 13--------\n",
      "Val loss = -0.3664\n",
      "\n",
      "--------Epoch 14--------\n",
      "Val loss = -0.3725\n",
      "\n",
      "--------Epoch 15--------\n",
      "Val loss = -0.3680\n",
      "\n",
      "--------Epoch 16--------\n",
      "Val loss = -0.3907\n",
      "\n",
      "--------Epoch 17--------\n",
      "Val loss = -0.3887\n",
      "\n",
      "--------Epoch 18--------\n",
      "Val loss = -0.3892\n",
      "\n",
      "--------Epoch 19--------\n",
      "Val loss = -0.3890\n",
      "\n",
      "--------Epoch 20--------\n",
      "Val loss = -0.3892\n",
      "\n",
      "--------Epoch 21--------\n",
      "Val loss = -0.3958\n",
      "\n",
      "--------Epoch 22--------\n",
      "Val loss = -0.4029\n",
      "\n",
      "--------Epoch 23--------\n",
      "Val loss = -0.4028\n",
      "\n",
      "--------Epoch 24--------\n",
      "Val loss = -0.4077\n",
      "\n",
      "--------Epoch 25--------\n",
      "Val loss = -0.4156\n",
      "\n",
      "--------Epoch 26--------\n",
      "Val loss = -0.4155\n",
      "\n",
      "--------Epoch 27--------\n",
      "Val loss = -0.4122\n",
      "\n",
      "--------Epoch 28--------\n",
      "Val loss = -0.4210\n",
      "\n",
      "--------Epoch 29--------\n",
      "Val loss = -0.4167\n",
      "\n",
      "--------Epoch 30--------\n",
      "Val loss = -0.4335\n",
      "\n",
      "--------Epoch 31--------\n",
      "Val loss = -0.4385\n",
      "\n",
      "--------Epoch 32--------\n",
      "Val loss = -0.4425\n",
      "\n",
      "--------Epoch 33--------\n",
      "Val loss = -0.4521\n",
      "\n",
      "--------Epoch 34--------\n",
      "Val loss = -0.4586\n",
      "\n",
      "--------Epoch 35--------\n",
      "Val loss = -0.4537\n",
      "\n",
      "--------Epoch 36--------\n",
      "Val loss = -0.4521\n",
      "\n",
      "--------Epoch 37--------\n",
      "Val loss = -0.4721\n",
      "\n",
      "--------Epoch 38--------\n",
      "Val loss = -0.4660\n",
      "\n",
      "--------Epoch 39--------\n",
      "Val loss = -0.4641\n",
      "\n",
      "--------Epoch 40--------\n",
      "Val loss = -0.4746\n",
      "\n",
      "--------Epoch 41--------\n",
      "Val loss = -0.4764\n",
      "\n",
      "--------Epoch 42--------\n",
      "Val loss = -0.4829\n",
      "\n",
      "--------Epoch 43--------\n",
      "Val loss = -0.4895\n",
      "\n",
      "--------Epoch 44--------\n",
      "Val loss = -0.4829\n",
      "\n",
      "--------Epoch 45--------\n",
      "Val loss = -0.4968\n",
      "\n",
      "--------Epoch 46--------\n",
      "Val loss = -0.4955\n",
      "\n",
      "--------Epoch 47--------\n",
      "Val loss = -0.4955\n",
      "\n",
      "--------Epoch 48--------\n",
      "Val loss = -0.5045\n",
      "\n",
      "--------Epoch 49--------\n",
      "Val loss = -0.4973\n",
      "\n",
      "--------Epoch 50--------\n",
      "Val loss = -0.5012\n",
      "\n",
      "--------Epoch 51--------\n",
      "Val loss = -0.5029\n",
      "\n",
      "--------Epoch 52--------\n",
      "Val loss = -0.5164\n",
      "\n",
      "--------Epoch 53--------\n",
      "Val loss = -0.5193\n",
      "\n",
      "--------Epoch 54--------\n",
      "Val loss = -0.5162\n",
      "\n",
      "--------Epoch 55--------\n",
      "Val loss = -0.5263\n",
      "\n",
      "--------Epoch 56--------\n",
      "Val loss = -0.5233\n",
      "\n",
      "--------Epoch 57--------\n",
      "Val loss = -0.5297\n",
      "\n",
      "--------Epoch 58--------\n",
      "Val loss = -0.5343\n",
      "\n",
      "--------Epoch 59--------\n",
      "Val loss = -0.5290\n",
      "\n",
      "--------Epoch 60--------\n",
      "Val loss = -0.5391\n",
      "\n",
      "--------Epoch 61--------\n",
      "Val loss = -0.5441\n",
      "\n",
      "--------Epoch 62--------\n",
      "Val loss = -0.5421\n",
      "\n",
      "--------Epoch 63--------\n",
      "Val loss = -0.5548\n",
      "\n",
      "--------Epoch 64--------\n",
      "Val loss = -0.5471\n",
      "\n",
      "--------Epoch 65--------\n",
      "Val loss = -0.5542\n",
      "\n",
      "--------Epoch 66--------\n",
      "Val loss = -0.5498\n",
      "\n",
      "--------Epoch 67--------\n",
      "Val loss = -0.5571\n",
      "\n",
      "--------Epoch 68--------\n",
      "Val loss = -0.5630\n",
      "\n",
      "--------Epoch 69--------\n",
      "Val loss = -0.5636\n",
      "\n",
      "--------Epoch 70--------\n",
      "Val loss = -0.5668\n",
      "\n",
      "--------Epoch 71--------\n",
      "Val loss = -0.5831\n",
      "\n",
      "--------Epoch 72--------\n",
      "Val loss = -0.5773\n",
      "\n",
      "--------Epoch 73--------\n",
      "Val loss = -0.5805\n",
      "\n",
      "--------Epoch 74--------\n",
      "Val loss = -0.5752\n",
      "\n",
      "--------Epoch 75--------\n",
      "Val loss = -0.5857\n",
      "\n",
      "--------Epoch 76--------\n",
      "Val loss = -0.5873\n",
      "\n",
      "--------Epoch 77--------\n",
      "Val loss = -0.5860\n",
      "\n",
      "--------Epoch 78--------\n",
      "Val loss = -0.5881\n",
      "\n",
      "--------Epoch 79--------\n",
      "Val loss = -0.5890\n",
      "\n",
      "--------Epoch 80--------\n",
      "Val loss = -0.5976\n",
      "\n",
      "--------Epoch 81--------\n",
      "Val loss = -0.6014\n",
      "\n",
      "--------Epoch 82--------\n",
      "Val loss = -0.5970\n",
      "\n",
      "--------Epoch 83--------\n",
      "Val loss = -0.6078\n",
      "\n",
      "--------Epoch 84--------\n",
      "Val loss = -0.5973\n",
      "\n",
      "--------Epoch 85--------\n",
      "Val loss = -0.6001\n",
      "\n",
      "--------Epoch 86--------\n",
      "Val loss = -0.6065\n",
      "\n",
      "--------Epoch 87--------\n",
      "Val loss = -0.6104\n",
      "\n",
      "--------Epoch 88--------\n",
      "Val loss = -0.6076\n",
      "\n",
      "--------Epoch 89--------\n",
      "Val loss = -0.6174\n",
      "\n",
      "--------Epoch 90--------\n",
      "Val loss = -0.6211\n",
      "\n",
      "--------Epoch 91--------\n",
      "Val loss = -0.6237\n",
      "\n",
      "--------Epoch 92--------\n",
      "Val loss = -0.6179\n",
      "\n",
      "--------Epoch 93--------\n",
      "Val loss = -0.6230\n",
      "\n",
      "--------Epoch 94--------\n",
      "Val loss = -0.6272\n",
      "\n",
      "--------Epoch 95--------\n",
      "Val loss = -0.6279\n",
      "\n",
      "--------Epoch 96--------\n",
      "Val loss = -0.6319\n",
      "\n",
      "--------Epoch 97--------\n",
      "Val loss = -0.6308\n",
      "\n",
      "--------Epoch 98--------\n",
      "Val loss = -0.6394\n",
      "\n",
      "--------Epoch 99--------\n",
      "Val loss = -0.6314\n",
      "\n",
      "--------Epoch 100--------\n",
      "Val loss = -0.6386\n",
      "\n",
      "--------Epoch 101--------\n",
      "Val loss = -0.6344\n",
      "\n",
      "--------Epoch 102--------\n",
      "Val loss = -0.6460\n",
      "\n",
      "--------Epoch 103--------\n",
      "Val loss = -0.6413\n",
      "\n",
      "--------Epoch 104--------\n",
      "Val loss = -0.6449\n",
      "\n",
      "--------Epoch 105--------\n",
      "Val loss = -0.6454\n",
      "\n",
      "--------Epoch 106--------\n",
      "Val loss = -0.6518\n",
      "\n",
      "--------Epoch 107--------\n",
      "Val loss = -0.6510\n",
      "\n",
      "--------Epoch 108--------\n",
      "Val loss = -0.6643\n",
      "\n",
      "--------Epoch 109--------\n",
      "Val loss = -0.6637\n",
      "\n",
      "--------Epoch 110--------\n",
      "Val loss = -0.6637\n",
      "\n",
      "--------Epoch 111--------\n",
      "Val loss = -0.6703\n",
      "\n",
      "--------Epoch 112--------\n",
      "Val loss = -0.6676\n",
      "\n",
      "--------Epoch 113--------\n",
      "Val loss = -0.6645\n",
      "\n",
      "--------Epoch 114--------\n",
      "Val loss = -0.6679\n",
      "\n",
      "--------Epoch 115--------\n",
      "Val loss = -0.6744\n",
      "\n",
      "--------Epoch 116--------\n",
      "Val loss = -0.6808\n",
      "\n",
      "--------Epoch 117--------\n",
      "Val loss = -0.6833\n",
      "\n",
      "--------Epoch 118--------\n",
      "Val loss = -0.6864\n",
      "\n",
      "--------Epoch 119--------\n",
      "Val loss = -0.6874\n",
      "\n",
      "--------Epoch 120--------\n",
      "Val loss = -0.6904\n",
      "\n",
      "--------Epoch 121--------\n",
      "Val loss = -0.7023\n",
      "\n",
      "--------Epoch 122--------\n",
      "Val loss = -0.6994\n",
      "\n",
      "--------Epoch 123--------\n",
      "Val loss = -0.7032\n",
      "\n",
      "--------Epoch 124--------\n",
      "Val loss = -0.7023\n",
      "\n",
      "--------Epoch 125--------\n",
      "Val loss = -0.7022\n",
      "\n",
      "--------Epoch 126--------\n",
      "Val loss = -0.7119\n",
      "\n",
      "--------Epoch 127--------\n",
      "Val loss = -0.7087\n",
      "\n",
      "--------Epoch 128--------\n",
      "Val loss = -0.7163\n",
      "\n",
      "--------Epoch 129--------\n",
      "Val loss = -0.7172\n",
      "\n",
      "--------Epoch 130--------\n",
      "Val loss = -0.7254\n",
      "\n",
      "--------Epoch 131--------\n",
      "Val loss = -0.7179\n",
      "\n",
      "--------Epoch 132--------\n",
      "Val loss = -0.7208\n",
      "\n",
      "--------Epoch 133--------\n",
      "Val loss = -0.7241\n",
      "\n",
      "--------Epoch 134--------\n",
      "Val loss = -0.7209\n",
      "\n",
      "--------Epoch 135--------\n",
      "Val loss = -0.7370\n",
      "\n",
      "--------Epoch 136--------\n",
      "Val loss = -0.7313\n",
      "\n",
      "--------Epoch 137--------\n",
      "Val loss = -0.7375\n",
      "\n",
      "--------Epoch 138--------\n",
      "Val loss = -0.7392\n",
      "\n",
      "--------Epoch 139--------\n",
      "Val loss = -0.7361\n",
      "\n",
      "--------Epoch 140--------\n",
      "Val loss = -0.7336\n",
      "\n",
      "--------Epoch 141--------\n",
      "Val loss = -0.7424\n",
      "\n",
      "--------Epoch 142--------\n",
      "Val loss = -0.7446\n",
      "\n",
      "--------Epoch 143--------\n",
      "Val loss = -0.7476\n",
      "\n",
      "--------Epoch 144--------\n",
      "Val loss = -0.7443\n",
      "\n",
      "--------Epoch 145--------\n",
      "Val loss = -0.7492\n",
      "\n",
      "--------Epoch 146--------\n",
      "Val loss = -0.7484\n",
      "\n",
      "--------Epoch 147--------\n",
      "Val loss = -0.7477\n",
      "\n",
      "--------Epoch 148--------\n",
      "Val loss = -0.7531\n",
      "\n",
      "--------Epoch 149--------\n",
      "Val loss = -0.7550\n",
      "\n",
      "--------Epoch 150--------\n",
      "Val loss = -0.7459\n",
      "\n",
      "--------Epoch 151--------\n",
      "Val loss = -0.7557\n",
      "\n",
      "--------Epoch 152--------\n",
      "Val loss = -0.7596\n",
      "\n",
      "--------Epoch 153--------\n",
      "Val loss = -0.7547\n",
      "\n",
      "--------Epoch 154--------\n",
      "Val loss = -0.7558\n",
      "\n",
      "--------Epoch 155--------\n",
      "Val loss = -0.7533\n",
      "\n",
      "--------Epoch 156--------\n",
      "Val loss = -0.7594\n",
      "\n",
      "--------Epoch 157--------\n",
      "Val loss = -0.7555\n",
      "\n",
      "--------Epoch 158--------\n",
      "Val loss = -0.7601\n",
      "\n",
      "--------Epoch 159--------\n",
      "Val loss = -0.7571\n",
      "\n",
      "--------Epoch 160--------\n",
      "Val loss = -0.7569\n",
      "\n",
      "--------Epoch 161--------\n",
      "Val loss = -0.7645\n",
      "\n",
      "--------Epoch 162--------\n",
      "Val loss = -0.7643\n",
      "\n",
      "--------Epoch 163--------\n",
      "Val loss = -0.7651\n",
      "\n",
      "--------Epoch 164--------\n",
      "Val loss = -0.7617\n",
      "\n",
      "--------Epoch 165--------\n",
      "Val loss = -0.7621\n",
      "\n",
      "--------Epoch 166--------\n",
      "Val loss = -0.7577\n",
      "\n",
      "--------Epoch 167--------\n",
      "Val loss = -0.7614\n",
      "\n",
      "--------Epoch 168--------\n",
      "Val loss = -0.7680\n",
      "\n",
      "--------Epoch 169--------\n",
      "Val loss = -0.7664\n",
      "\n",
      "--------Epoch 170--------\n",
      "Val loss = -0.7599\n",
      "\n",
      "--------Epoch 171--------\n",
      "Val loss = -0.7698\n",
      "\n",
      "--------Epoch 172--------\n",
      "Val loss = -0.7665\n",
      "\n",
      "--------Epoch 173--------\n",
      "Val loss = -0.7662\n",
      "\n",
      "--------Epoch 174--------\n",
      "Val loss = -0.7671\n",
      "\n",
      "--------Epoch 175--------\n",
      "Val loss = -0.7698\n",
      "\n",
      "--------Epoch 176--------\n",
      "Val loss = -0.7683\n",
      "\n",
      "--------Epoch 177--------\n",
      "Val loss = -0.7670\n",
      "\n",
      "--------Epoch 178--------\n",
      "Val loss = -0.7708\n",
      "\n",
      "--------Epoch 179--------\n",
      "Val loss = -0.7715\n",
      "\n",
      "--------Epoch 180--------\n",
      "Val loss = -0.7667\n",
      "\n",
      "--------Epoch 181--------\n",
      "Val loss = -0.7700\n",
      "\n",
      "--------Epoch 182--------\n",
      "Val loss = -0.7649\n",
      "\n",
      "--------Epoch 183--------\n",
      "Val loss = -0.7678\n",
      "\n",
      "--------Epoch 184--------\n",
      "Val loss = -0.7726\n",
      "\n",
      "--------Epoch 185--------\n",
      "Val loss = -0.7715\n",
      "\n",
      "--------Epoch 186--------\n",
      "Val loss = -0.7687\n",
      "\n",
      "--------Epoch 187--------\n",
      "Val loss = -0.7702\n",
      "\n",
      "--------Epoch 188--------\n",
      "Val loss = -0.7714\n",
      "\n",
      "--------Epoch 189--------\n",
      "Val loss = -0.7716\n",
      "\n",
      "--------Epoch 190--------\n",
      "Val loss = -0.7752\n",
      "\n",
      "--------Epoch 191--------\n",
      "Val loss = -0.7674\n",
      "\n",
      "--------Epoch 192--------\n",
      "Val loss = -0.7666\n",
      "\n",
      "--------Epoch 193--------\n",
      "Val loss = -0.7711\n",
      "\n",
      "--------Epoch 194--------\n",
      "Val loss = -0.7697\n",
      "\n",
      "--------Epoch 195--------\n",
      "Val loss = -0.7713\n",
      "\n",
      "--------Epoch 196--------\n",
      "Val loss = -0.7722\n",
      "\n",
      "--------Epoch 197--------\n",
      "Val loss = -0.7679\n",
      "\n",
      "--------Epoch 198--------\n",
      "Val loss = -0.7727\n",
      "\n",
      "--------Epoch 199--------\n",
      "Val loss = -0.7715\n",
      "\n",
      "--------Epoch 200--------\n",
      "Val loss = -0.7702\n",
      "\n",
      "--------Epoch 201--------\n",
      "Val loss = -0.7697\n",
      "\n",
      "--------Epoch 202--------\n",
      "Val loss = -0.7641\n",
      "\n",
      "--------Epoch 203--------\n",
      "Val loss = -0.7673\n",
      "\n",
      "--------Epoch 204--------\n",
      "Val loss = -0.7722\n",
      "\n",
      "--------Epoch 205--------\n",
      "Val loss = -0.7680\n",
      "\n",
      "--------Epoch 206--------\n",
      "Val loss = -0.7697\n",
      "\n",
      "--------Epoch 207--------\n",
      "Val loss = -0.7655\n",
      "\n",
      "--------Epoch 208--------\n",
      "Val loss = -0.7690\n",
      "\n",
      "--------Epoch 209--------\n",
      "Val loss = -0.7661\n",
      "\n",
      "--------Epoch 210--------\n",
      "Val loss = -0.7697\n",
      "\n",
      "--------Epoch 211--------\n",
      "Val loss = -0.7680\n",
      "\n",
      "--------Epoch 212--------\n",
      "Val loss = -0.7696\n",
      "\n",
      "--------Epoch 213--------\n",
      "Val loss = -0.7702\n",
      "\n",
      "--------Epoch 214--------\n",
      "Val loss = -0.7702\n",
      "\n",
      "--------Epoch 215--------\n",
      "Val loss = -0.7649\n",
      "\n",
      "--------Epoch 216--------\n",
      "Val loss = -0.7694\n",
      "\n",
      "--------Epoch 217--------\n",
      "Val loss = -0.7644\n",
      "\n",
      "--------Epoch 218--------\n",
      "Val loss = -0.7637\n",
      "\n",
      "--------Epoch 219--------\n",
      "Val loss = -0.7652\n",
      "\n",
      "--------Epoch 220--------\n",
      "Val loss = -0.7619\n",
      "\n",
      "--------Epoch 221--------\n",
      "Val loss = -0.7651\n",
      "\n",
      "--------Epoch 222--------\n",
      "Val loss = -0.7662\n",
      "\n",
      "--------Epoch 223--------\n",
      "Val loss = -0.7618\n",
      "\n",
      "--------Epoch 224--------\n",
      "Val loss = -0.7597\n",
      "\n",
      "--------Epoch 225--------\n",
      "Val loss = -0.7637\n",
      "\n",
      "--------Epoch 226--------\n",
      "Val loss = -0.7593\n",
      "\n",
      "--------Epoch 227--------\n",
      "Val loss = -0.7636\n",
      "\n",
      "--------Epoch 228--------\n",
      "Val loss = -0.7595\n",
      "\n",
      "--------Epoch 229--------\n",
      "Val loss = -0.7605\n",
      "\n",
      "--------Epoch 230--------\n",
      "Val loss = -0.7624\n",
      "\n",
      "--------Epoch 231--------\n",
      "Val loss = -0.7619\n",
      "\n",
      "--------Epoch 232--------\n",
      "Val loss = -0.7577\n",
      "\n",
      "--------Epoch 233--------\n",
      "Val loss = -0.7576\n",
      "\n",
      "--------Epoch 234--------\n",
      "Val loss = -0.7611\n",
      "\n",
      "--------Epoch 235--------\n",
      "Val loss = -0.7623\n",
      "\n",
      "--------Epoch 236--------\n",
      "Val loss = -0.7595\n",
      "\n",
      "--------Epoch 237--------\n",
      "Val loss = -0.7576\n",
      "\n",
      "--------Epoch 238--------\n",
      "Val loss = -0.7573\n",
      "\n",
      "--------Epoch 239--------\n",
      "Val loss = -0.7533\n",
      "\n",
      "--------Epoch 240--------\n",
      "Val loss = -0.7593\n",
      "\n",
      "--------Epoch 241--------\n",
      "Val loss = -0.7569\n",
      "\n",
      "--------Epoch 242--------\n",
      "Val loss = -0.7515\n",
      "\n",
      "--------Epoch 243--------\n",
      "Val loss = -0.7569\n",
      "\n",
      "--------Epoch 244--------\n",
      "Val loss = -0.7563\n",
      "\n",
      "--------Epoch 245--------\n",
      "Val loss = -0.7564\n",
      "\n",
      "--------Epoch 246--------\n",
      "Val loss = -0.7576\n",
      "\n",
      "--------Epoch 247--------\n",
      "Val loss = -0.7576\n",
      "\n",
      "--------Epoch 248--------\n",
      "Val loss = -0.7483\n",
      "\n",
      "--------Epoch 249--------\n",
      "Val loss = -0.7524\n",
      "\n",
      "--------Epoch 250--------\n",
      "Val loss = -0.7518\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Predictor MLP. Its input width is 2 * d_in because the selector is built with append=True\n",
    "# (presumably the masked image concatenated with the mask itself; TODO confirm in models.py).\n",
    "hidden_dim = 512\n",
    "global_model = nn.Sequential(\n",
    "    nn.Linear(2 * d_in, hidden_dim), nn.ReLU(),\n",
    "    nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),\n",
    "    nn.Linear(hidden_dim, d_out),\n",
    ")\n",
    "\n",
    "# Global selector: a single set of max_features inputs, shared by every example\n",
    "selector = ConcreteMask(d_in, max_features, append=True)\n",
    "globalfs = GlobalSelector(global_model, selector).to(device)\n",
    "\n",
    "# Train with cross-entropy. Validation is scored with NegAccuracy, so the 'Val loss'\n",
    "# printed each epoch appears to be minus the validation accuracy. start_temp/end_temp\n",
    "# presumably define the Concrete temperature schedule.\n",
    "fit_kwargs = dict(\n",
    "    mbsize=128,\n",
    "    lr=1e-3,\n",
    "    nepochs=250,\n",
    "    loss_fn=nn.CrossEntropyLoss(),\n",
    "    val_loss_fn=NegAccuracy(),\n",
    "    start_temp=1.0,\n",
    "    end_temp=0.01,\n",
    ")\n",
    "globalfs.fit(train_dataset, val_dataset, **fit_kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "figured-pregnancy",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test acc = 77.86\n"
     ]
    }
   ],
   "source": [
    "# Accuracy on the held-out MNIST test split (1024 is presumably the evaluation batch size)\n",
    "test_acc = globalfs.evaluate(test_dataset, Accuracy(), 1024)\n",
    "print('Test acc = {:.2f}'.format(100 * test_acc))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "recreational-vault",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAK3klEQVR4nO3dT6hc93mH8edbV1ZASUCqa6M6bpMGL2oKVcpFKbgUF9PU8cbOIiVeBBUCyiKGBLKoSRfx0pQmoYsSUGoRtaQOgcS1FqaJEAGTjfG1UW25amvXqIkiITVoYadQWXbeLu5xuZHvP8+c+SO/zwcuM3Nm7p2XQY/O3Dkz95eqQtK7368segBJ82HsUhPGLjVh7FITxi418avzvLMbs7vew5553qXUyv/yP7xeV7LRdVPFnuQe4G+AG4C/q6pHtrr9e9jDR3P3NHcpaQtP18lNr5v4aXySG4C/BT4O3AE8kOSOSX+epNma5nf2g8DLVfVKVb0OfBu4b5yxJI1tmthvBX6y7vK5YdsvSXI4yWqS1atcmeLuJE1jmtg3ehHgbe+9raojVbVSVSu72D3F3UmaxjSxnwNuW3f5A8D56caRNCvTxP4McHuSDyW5EfgUcHycsSSNbeJDb1X1RpIHge+zdujtaFW9ONpkkkY11XH2qnoSeHKkWSTNkG+XlZowdqkJY5eaMHapCWOXmjB2qQljl5owdqkJY5eaMHapCWOXmjB2qQljl5owdqkJY5eaMHapCWOXmjB2qQljl5owdqkJY5eaMHapCWOXmjB2qQljl5owdqkJY5eaMHapCWOXmjB2qYmplmxOchZ4DXgTeKOqVsYYStL4pop98MdV9bMRfo6kGfJpvNTEtLEX8IMkzyY5vNENkhxOsppk9SpXprw7SZOa9mn8nVV1PsnNwIkk/1ZVT62/QVUdAY4AvD/7asr7kzShqfbsVXV+OL0EPA4cHGMoSeObOPYke5K8763zwMeA02MNJmlc0zyNvwV4PMlbP+cfq+qfR5lK7xrfP39q0+v+9DcOzG0OTRF7Vb0C/N6Is0iaIQ+9SU0Yu9SEsUtNGLvUhLFLTYzxQRhpUx5eWx7u2aUmjF1qwtilJoxdasLYpSaMXWrC2KUmjF1qwtilJoxdasLYpSaMXWrC2KUmjF1qwtilJvw8u7a01Z+CBj+vfj1xzy41YexSE8YuNWHsUhPGLjVh7FITxi414XF2bWna4+izXLLZ9wC8M9vu2ZMcTXIpyel12/YlOZHkpeF072zHlDStnTyN/yZwzzXbHgJOVtXtwMnhsqQltm3sVfUUcPmazfcBx4bzx4D7xx1L0tgmfYHulqq6ADCc3rzZDZMcTrKaZPUqVya8O0nTmvmr8VV1pKpWqmplF7tnfXeSNjFp7BeT7AcYTi+NN5KkWZg09uPAoeH8IeCJccaRNCupqq1vkDwG3AXcBFwEvgz8E/Ad4DeBHwOfrKprX8R7m/dnX300d083cUMeT9ZOPV0nebUuZ6Prtn1TTVU9sMlVVitdR3y7rNSEsUtNGLvUhLFLTRi71IQfcb0OeGhNY3DPLjVh7FITxi41YexSE8YuNWHsUhPGLjVh7FITxi41YexSE8YuNWHsUhPGLjVh7FITxi414efZtaVZ/hlr/0T2fLlnl5owdqkJY5eaMHapCWOXmjB2qQljl5rwOLu2NMtj3R5Hn69t9+xJjia5lOT0um0PJ/lpklPD172zHVPStHbyNP6bwD0bbP9aVR0Yvp4cdyxJY9s29qp6Crg8h1kkzdA0L9A9mOT54Wn+3s1ulORwktUkq1e5MsXdSZrGpLF/HfgwcAC4AHxlsxtW1ZGqWqmqlV3snvDuJE1rotir6mJVvVlVvwC+ARwcdyxJY5so9iT71138BHB6s9tKWg7bHmdP8hhwF3BTknPAl4G7khwACjgLfHZ2I0oaw7axV9UDG2x+dAazSJoh3y4rNWHsUhPGLjVh7FIT
xi414Udcm/PPOffhnl1qwtilJoxdasLYpSaMXWrC2KUmjF1qwuPszXkcvQ/37FITxi41YexSE8YuNWHsUhPGLjVh7FITHme/DviZc43BPbvUhLFLTRi71ISxS00Yu9SEsUtNGLvUhMfZrwMeR9cYtt2zJ7ktyQ+TnEnyYpLPD9v3JTmR5KXhdO/sx5U0qZ08jX8D+GJV/Q7wB8DnktwBPAScrKrbgZPDZUlLatvYq+pCVT03nH8NOAPcCtwHHBtudgy4f0YzShrBO3qBLskHgY8ATwO3VNUFWPsPAbh5k+85nGQ1yepVrkw5rqRJ7Tj2JO8Fvgt8oape3en3VdWRqlqpqpVd7J5kRkkj2FHsSXaxFvq3qup7w+aLSfYP1+8HLs1mRElj2PbQW5IAjwJnquqr6646DhwCHhlOn5jJhO8CfkRVy2Anx9nvBD4NvJDk1LDtS6xF/p0knwF+DHxyJhNKGsW2sVfVj4BscvXd444jaVZ8u6zUhLFLTRi71ISxS00Yu9SEH3GdA4+jaxm4Z5eaMHapCWOXmjB2qQljl5owdqkJY5eaMHapCWOXmjB2qQljl5owdqkJY5eaMHapCWOXmjB2qQljl5owdqkJY5eaMHapCWOXmjB2qQljl5rYNvYktyX5YZIzSV5M8vlh+8NJfprk1PB17+zHlTSpnSwS8Qbwxap6Lsn7gGeTnBiu+1pV/fXsxpM0lp2sz34BuDCcfy3JGeDWWQ8maVzv6Hf2JB8EPgI8PWx6MMnzSY4m2bvJ9xxOsppk9SpXpptW0sR2HHuS9wLfBb5QVa8CXwc+DBxgbc//lY2+r6qOVNVKVa3sYvf0E0uayI5iT7KLtdC/VVXfA6iqi1X1ZlX9AvgGcHB2Y0qa1k5ejQ/wKHCmqr66bvv+dTf7BHB6/PEkjWUnr8bfCXwaeCHJqWHbl4AHkhwACjgLfHYG80kayU5ejf8RkA2uenL8cSTNiu+gk5owdqkJY5eaMHapCWOXmjB2qQljl5owdqkJY5eaMHapCWOXmjB2qQljl5owdqmJVNX87iz5b+C/1m26CfjZ3AZ4Z5Z1tmWdC5xtUmPO9ltV9esbXTHX2N9258lqVa0sbIAtLOtsyzoXONuk5jWbT+OlJoxdamLRsR9Z8P1vZVlnW9a5wNkmNZfZFvo7u6T5WfSeXdKcGLvUxEJiT3JPkn9P8nKShxYxw2aSnE3ywrAM9eqCZzma5FKS0+u27UtyIslLw+mGa+wtaLalWMZ7i2XGF/rYLXr587n/zp7kBuA/gD8BzgHPAA9U1b/OdZBNJDkLrFTVwt+AkeSPgJ8Df19Vvzts+yvgclU9MvxHubeq/mJJZnsY+Pmil/EeVivav36ZceB+4M9Z4GO3xVx/xhwet0Xs2Q8CL1fVK1X1OvBt4L4FzLH0quop4PI1m+8Djg3nj7H2j2XuNpltKVTVhap6bjj/GvDWMuMLfey2mGsuFhH7rcBP1l0+x3Kt917AD5I8m+TwoofZwC1VdQHW/vEANy94nmttu4z3PF2zzPjSPHaTLH8+rUXEvtFSUst0/O/Oqvp94OPA54anq9qZHS3jPS8bLDO+FCZd/nxai4j9HHDbussfAM4vYI4NVdX54fQS8DjLtxT1xbdW0B1OLy14nv+3TMt4b7TMOEvw2C1y+fNFxP4McHuSDyW5EfgUcHwBc7xNkj3DCyck2QN8jOVbivo4cGg4fwh4YoGz/JJlWcZ7s2XGWfBjt/Dlz6tq7l/Avay9Iv+fwF8uYoZN5vpt4F+GrxcXPRvwGGtP666y9ozoM8CvASeBl4bTfUs02z8ALwDPsxbW/gXN9oes/Wr4PHBq+Lp30Y/dFnPN5XHz7bJSE76DTmrC2KUmjF1qwtilJoxdasLYpSaMXWri/wCKFG0rCtAxvgAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Visualize the selected pixels: each row of the selector logits is argmax-ed to one pixel\n",
    "# index, and those pixels are lit up on a 28x28 grid. Duplicate picks collapse to one pixel,\n",
    "# so the title reports how many distinct pixels were actually chosen.\n",
    "inds = selector.logits.argmax(dim=1).detach().cpu().numpy()\n",
    "zeros = np.zeros(d_in)\n",
    "zeros[inds] = 1\n",
    "fig, ax = plt.subplots()\n",
    "ax.imshow(zeros.reshape(28, 28))\n",
    "ax.set_title(f'Selected pixels ({len(np.unique(inds))} unique of max_features={max_features})')\n",
    "ax.axis('off')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "coastal-poison",
   "metadata": {},
   "source": [
    "# Pretrain model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "controlled-handy",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--------Epoch 1--------\n",
      "Val loss = -0.2442\n",
      "\n",
      "--------Epoch 2--------\n",
      "Val loss = -0.2721\n",
      "\n",
      "--------Epoch 3--------\n",
      "Val loss = -0.2917\n",
      "\n",
      "--------Epoch 4--------\n",
      "Val loss = -0.2911\n",
      "\n",
      "--------Epoch 5--------\n",
      "Val loss = -0.3001\n",
      "\n",
      "--------Epoch 6--------\n",
      "Val loss = -0.3019\n",
      "\n",
      "--------Epoch 7--------\n",
      "Val loss = -0.3061\n",
      "\n",
      "--------Epoch 8--------\n",
      "Val loss = -0.3010\n",
      "\n",
      "--------Epoch 9--------\n",
      "Val loss = -0.3051\n",
      "\n",
      "--------Epoch 10--------\n",
      "Val loss = -0.3075\n",
      "\n",
      "--------Epoch 11--------\n",
      "Val loss = -0.3042\n",
      "\n",
      "--------Epoch 12--------\n",
      "Val loss = -0.3102\n",
      "\n",
      "--------Epoch 13--------\n",
      "Val loss = -0.3138\n",
      "\n",
      "--------Epoch 14--------\n",
      "Val loss = -0.3112\n",
      "\n",
      "--------Epoch 15--------\n",
      "Val loss = -0.3148\n",
      "\n",
      "--------Epoch 16--------\n",
      "Val loss = -0.3101\n",
      "\n",
      "--------Epoch 17--------\n",
      "Val loss = -0.3156\n",
      "\n",
      "--------Epoch 18--------\n",
      "Val loss = -0.3137\n",
      "\n",
      "--------Epoch 19--------\n",
      "Val loss = -0.3180\n",
      "\n",
      "--------Epoch 20--------\n",
      "Val loss = -0.3116\n",
      "\n",
      "--------Epoch 21--------\n",
      "Val loss = -0.3113\n",
      "\n",
      "--------Epoch 22--------\n",
      "Val loss = -0.3154\n",
      "\n",
      "--------Epoch 23--------\n",
      "Val loss = -0.3247\n",
      "\n",
      "--------Epoch 24--------\n",
      "Val loss = -0.3214\n",
      "\n",
      "--------Epoch 25--------\n",
      "Val loss = -0.3166\n",
      "\n",
      "--------Epoch 26--------\n",
      "Val loss = -0.3118\n",
      "\n",
      "--------Epoch 27--------\n",
      "Val loss = -0.3148\n",
      "\n",
      "--------Epoch 28--------\n",
      "Val loss = -0.3262\n",
      "\n",
      "--------Epoch 29--------\n",
      "Val loss = -0.3181\n",
      "\n",
      "--------Epoch 30--------\n",
      "Val loss = -0.3205\n",
      "\n",
      "--------Epoch 31--------\n",
      "Val loss = -0.3142\n",
      "\n",
      "--------Epoch 32--------\n",
      "Val loss = -0.3210\n",
      "\n",
      "--------Epoch 33--------\n",
      "Val loss = -0.3175\n",
      "\n",
      "--------Epoch 34--------\n",
      "Val loss = -0.3241\n",
      "\n",
      "--------Epoch 35--------\n",
      "Val loss = -0.3208\n",
      "\n",
      "--------Epoch 36--------\n",
      "Val loss = -0.3282\n",
      "\n",
      "--------Epoch 37--------\n",
      "Val loss = -0.3205\n",
      "\n",
      "--------Epoch 38--------\n",
      "Val loss = -0.3193\n",
      "\n",
      "--------Epoch 39--------\n",
      "Val loss = -0.3246\n",
      "\n",
      "--------Epoch 40--------\n",
      "Val loss = -0.3249\n",
      "\n",
      "--------Epoch 41--------\n",
      "Val loss = -0.3267\n",
      "\n",
      "--------Epoch 42--------\n",
      "Val loss = -0.3312\n",
      "\n",
      "--------Epoch 43--------\n",
      "Val loss = -0.3232\n",
      "\n",
      "--------Epoch 44--------\n",
      "Val loss = -0.3244\n",
      "\n",
      "--------Epoch 45--------\n",
      "Val loss = -0.3259\n",
      "\n",
      "--------Epoch 46--------\n",
      "Val loss = -0.3285\n",
      "\n",
      "--------Epoch 47--------\n",
      "Val loss = -0.3198\n",
      "\n",
      "--------Epoch 48--------\n",
      "Val loss = -0.3323\n",
      "\n",
      "--------Epoch 49--------\n",
      "Val loss = -0.3267\n",
      "\n",
      "--------Epoch 50--------\n",
      "Val loss = -0.3226\n",
      "\n",
      "--------Epoch 51--------\n",
      "Val loss = -0.3273\n",
      "\n",
      "--------Epoch 52--------\n",
      "Val loss = -0.3327\n",
      "\n",
      "--------Epoch 53--------\n",
      "Val loss = -0.3300\n",
      "\n",
      "--------Epoch 54--------\n",
      "Val loss = -0.3302\n",
      "\n",
      "--------Epoch 55--------\n",
      "Val loss = -0.3315\n",
      "\n",
      "--------Epoch 56--------\n",
      "Val loss = -0.3311\n",
      "\n",
      "--------Epoch 57--------\n",
      "Val loss = -0.3266\n",
      "\n",
      "--------Epoch 58--------\n",
      "Val loss = -0.3313\n",
      "\n",
      "--------Epoch 59--------\n",
      "Val loss = -0.3263\n",
      "\n",
      "--------Epoch 60--------\n",
      "Val loss = -0.3366\n",
      "\n",
      "--------Epoch 61--------\n",
      "Val loss = -0.3265\n",
      "\n",
      "--------Epoch 62--------\n",
      "Val loss = -0.3332\n",
      "\n",
      "--------Epoch 63--------\n",
      "Val loss = -0.3327\n",
      "\n",
      "--------Epoch 64--------\n",
      "Val loss = -0.3324\n",
      "\n",
      "--------Epoch 65--------\n",
      "Val loss = -0.3330\n",
      "\n",
      "--------Epoch 66--------\n",
      "Val loss = -0.3351\n",
      "\n",
      "--------Epoch 67--------\n",
      "Val loss = -0.3260\n",
      "\n",
      "--------Epoch 68--------\n",
      "Val loss = -0.3355\n",
      "\n",
      "--------Epoch 69--------\n",
      "Val loss = -0.3386\n",
      "\n",
      "--------Epoch 70--------\n",
      "Val loss = -0.3362\n",
      "\n",
      "--------Epoch 71--------\n",
      "Val loss = -0.3350\n",
      "\n",
      "--------Epoch 72--------\n",
      "Val loss = -0.3412\n",
      "\n",
      "--------Epoch 73--------\n",
      "Val loss = -0.3364\n",
      "\n",
      "--------Epoch 74--------\n",
      "Val loss = -0.3368\n",
      "\n",
      "--------Epoch 75--------\n",
      "Val loss = -0.3432\n",
      "\n",
      "--------Epoch 76--------\n",
      "Val loss = -0.3431\n",
      "\n",
      "--------Epoch 77--------\n",
      "Val loss = -0.3450\n",
      "\n",
      "--------Epoch 78--------\n",
      "Val loss = -0.3408\n",
      "\n",
      "--------Epoch 79--------\n",
      "Val loss = -0.3403\n",
      "\n",
      "--------Epoch 80--------\n",
      "Val loss = -0.3306\n",
      "\n",
      "--------Epoch 81--------\n",
      "Val loss = -0.3329\n",
      "\n",
      "--------Epoch 82--------\n",
      "Val loss = -0.3358\n",
      "\n",
      "--------Epoch 83--------\n",
      "Val loss = -0.3305\n",
      "\n",
      "--------Epoch 84--------\n",
      "Val loss = -0.3312\n",
      "\n",
      "--------Epoch 85--------\n",
      "Val loss = -0.3388\n",
      "\n",
      "--------Epoch 86--------\n",
      "Val loss = -0.3408\n",
      "\n",
      "--------Epoch 87--------\n",
      "Val loss = -0.3351\n",
      "\n",
      "--------Epoch 88--------\n",
      "Val loss = -0.3396\n",
      "\n",
      "--------Epoch 89--------\n",
      "Val loss = -0.3404\n",
      "\n",
      "--------Epoch 90--------\n",
      "Val loss = -0.3307\n",
      "\n",
      "--------Epoch 91--------\n",
      "Val loss = -0.3463\n",
      "\n",
      "--------Epoch 92--------\n",
      "Val loss = -0.3455\n",
      "\n",
      "--------Epoch 93--------\n",
      "Val loss = -0.3474\n",
      "\n",
      "--------Epoch 94--------\n",
      "Val loss = -0.3406\n",
      "\n",
      "--------Epoch 95--------\n",
      "Val loss = -0.3363\n",
      "\n",
      "--------Epoch 96--------\n",
      "Val loss = -0.3398\n",
      "\n",
      "--------Epoch 97--------\n",
      "Val loss = -0.3459\n",
      "\n",
      "--------Epoch 98--------\n",
      "Val loss = -0.3401\n",
      "\n",
      "--------Epoch 99--------\n",
      "Val loss = -0.3336\n",
      "\n",
      "--------Epoch 100--------\n",
      "Val loss = -0.3420\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Predictor network. Its input width is 2 * d_in, presumably because\n",
    "# MaskLayer(append=True) appends the mask to the masked features.\n",
    "model = nn.Sequential(\n",
    "    nn.Linear(2 * d_in, 512), nn.ReLU(),\n",
    "    nn.Linear(512, 512), nn.ReLU(),\n",
    "    nn.Linear(512, d_out))\n",
    "mask_layer = MaskLayer(append=True)\n",
    "pretrain = Pretrainer(model, mask_layer).to(device)\n",
    "\n",
    "# Pretrain the predictor before REBAR training. Validation uses\n",
    "# NegAccuracy (presumably negative accuracy, so more negative is better).\n",
    "pretrain.fit(\n",
    "    train_dataset, val_dataset,\n",
    "    mbsize=128, lr=1e-3, nepochs=100,\n",
    "    max_features=max_features,\n",
    "    loss_fn=nn.CrossEntropyLoss(),\n",
    "    val_loss_fn=NegAccuracy())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "insured-encounter",
   "metadata": {},
   "source": [
    "# REBAR training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "armed-still",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.distributions import Categorical\n",
    "from torch.distributions.utils import clamp_probs\n",
    "from torch.utils.data import DataLoader\n",
    "from utils import restore_parameters\n",
    "from copy import deepcopy\n",
    "\n",
    "\n",
    "def make_onehot(x):\n",
    "    '''Make an approximately one-hot vector one-hot.'''\n",
    "    argmax = torch.argmax(x, dim=1)\n",
    "    onehot = torch.zeros(x.shape, dtype=x.dtype, device=x.device)\n",
    "    onehot[torch.arange(len(x)), argmax] = 1\n",
    "    return onehot\n",
    "\n",
    "\n",
    "def logsoftmax(x, dim):\n",
    "    m = torch.max(x, dim=dim, keepdim=True).values\n",
    "    x = x - m\n",
    "    return x - torch.log(torch.sum(torch.exp(x), dim=dim, keepdim=True))\n",
    "\n",
    "\n",
    "class DiscreteSampler(nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        \n",
    "    def forward(self, logits):\n",
    "        dist = Categorical(logits=logits)\n",
    "        inds = dist.sample()\n",
    "        onehot = torch.zeros(logits.shape, dtype=logits.dtype, device=logits.device)\n",
    "        onehot[torch.arange(len(logits)), inds] = 1\n",
    "        return inds, onehot\n",
    "\n",
    "\n",
    "class GreedyAdaptiveFS(nn.Module):\n",
    "    '''\n",
    "    Greedy adaptive feature selection.\n",
    "\n",
    "    A selector network picks one feature per step based on the features\n",
    "    observed so far, and a predictor network makes predictions from the\n",
    "    masked input. The selector is trained with REBAR gradient estimates.\n",
    "\n",
    "    Args:\n",
    "      selector: network mapping masked inputs to one logit per feature.\n",
    "      model: predictor network applied to masked inputs.\n",
    "      mask_layer: callable applied as mask_layer(x, mask).\n",
    "    '''\n",
    "\n",
    "    def __init__(self, selector, model, mask_layer):\n",
    "        super().__init__()\n",
    "        self.selector = selector\n",
    "        self.model = model\n",
    "        self.mask_layer = mask_layer\n",
    "        self.sampler = DiscreteSampler()\n",
    "\n",
    "    def fit(self,\n",
    "            train,\n",
    "            val,\n",
    "            mbsize,\n",
    "            lr,\n",
    "            nepochs,\n",
    "            max_features,\n",
    "            loss_fn,\n",
    "            temp=0.1,\n",
    "            eta=1.0,\n",
    "            val_loss_fn=None,\n",
    "            train_model=True,\n",
    "            train_selector=True,\n",
    "            argmax=True,\n",
    "            no_repeats=True,\n",
    "            validation_mode='final',\n",
    "            verbose=True):\n",
    "        '''\n",
    "        Train models, keeping the parameters with the best validation loss.\n",
    "\n",
    "        Args:\n",
    "          train: training dataset of (x, y) pairs.\n",
    "          val: validation dataset of (x, y) pairs.\n",
    "          mbsize: minibatch size.\n",
    "          lr: Adam learning rate for both the predictor and the selector.\n",
    "          nepochs: number of training epochs.\n",
    "          max_features: number of features selected per example (>= 1).\n",
    "          loss_fn: training loss. A copy with reduction='none' is used,\n",
    "            because REBAR needs one loss value per example; the caller's\n",
    "            object is not modified.\n",
    "          temp: temperature of the softmax relaxation.\n",
    "          eta: scale of the REBAR control variate.\n",
    "          val_loss_fn: validation loss, defaulting to a copy of loss_fn. A\n",
    "            reduction of 'none' is replaced by 'mean' on a copy.\n",
    "          train_model: whether to update the predictor.\n",
    "          train_selector: whether to update the selector.\n",
    "          argmax: during validation, pick the most likely feature instead of\n",
    "            sampling one.\n",
    "          no_repeats: during validation, forbid re-selecting a feature.\n",
    "          validation_mode: 'final' scores the prediction after all\n",
    "            selections; 'mean' averages the loss over every selection step.\n",
    "          verbose: whether to print the validation loss after each epoch.\n",
    "\n",
    "        Raises:\n",
    "          ValueError: if max_features < 1.\n",
    "        '''\n",
    "        if max_features < 1:\n",
    "            raise ValueError('max_features must be at least 1')\n",
    "        assert validation_mode in ('final', 'mean')\n",
    "        assert train_model or train_selector\n",
    "\n",
    "        # Set up data loaders.\n",
    "        train_loader = DataLoader(\n",
    "            train, batch_size=mbsize, shuffle=True, pin_memory=True,\n",
    "            drop_last=True, num_workers=4)\n",
    "        val_loader = DataLoader(\n",
    "            val, batch_size=mbsize, shuffle=False, pin_memory=True,\n",
    "            drop_last=False, num_workers=4)\n",
    "\n",
    "        # More setup.\n",
    "        selector = self.selector\n",
    "        model = self.model\n",
    "        device = next(model.parameters()).device\n",
    "        if train_model:\n",
    "            model_opt = optim.Adam(model.parameters(), lr=lr)\n",
    "        if train_selector:\n",
    "            selector_opt = optim.Adam(selector.parameters(), lr=lr)\n",
    "        loss_fn, val_loss_fn = self._prepare_loss_fns(loss_fn, val_loss_fn)\n",
    "\n",
    "        # For tracking best model.\n",
    "        best_model = None\n",
    "        best_selector = None\n",
    "        best_loss = float('inf')\n",
    "\n",
    "        for epoch in range(nepochs):\n",
    "            for x, y in train_loader:\n",
    "                # Move to device.\n",
    "                x = x.to(device)\n",
    "                y = y.to(device)\n",
    "\n",
    "                # Take gradient step.\n",
    "                total_loss = self._rebar_loss(\n",
    "                    x, y, max_features, loss_fn, temp, eta)\n",
    "                total_loss.backward()\n",
    "                if train_model:\n",
    "                    model_opt.step()\n",
    "                if train_selector:\n",
    "                    selector_opt.step()\n",
    "                model.zero_grad()\n",
    "                selector.zero_grad()\n",
    "\n",
    "            # Calculate validation loss.\n",
    "            val_loss = self._validation_loss(\n",
    "                val_loader, max_features, val_loss_fn, argmax, no_repeats,\n",
    "                validation_mode)\n",
    "\n",
    "            # Print progress.\n",
    "            if verbose:\n",
    "                print(f'{\"-\"*8}Epoch {epoch+1}{\"-\"*8}')\n",
    "                print(f'Val loss = {val_loss:.4f}\\n')\n",
    "\n",
    "            # See if best model.\n",
    "            if val_loss < best_loss:\n",
    "                best_loss = val_loss\n",
    "                best_model = deepcopy(model)\n",
    "                best_selector = deepcopy(selector)\n",
    "\n",
    "        # Restore best parameters. Explicit None checks: module truthiness\n",
    "        # can depend on __len__ (e.g. an empty nn.Sequential is falsy).\n",
    "        if best_model is not None:\n",
    "            restore_parameters(model, best_model)\n",
    "        if best_selector is not None:\n",
    "            restore_parameters(selector, best_selector)\n",
    "\n",
    "    @staticmethod\n",
    "    def _prepare_loss_fns(loss_fn, val_loss_fn):\n",
    "        '''\n",
    "        Return (train_loss_fn, val_loss_fn) with the reductions fit() needs.\n",
    "\n",
    "        Objects are copied before their reduction is changed, so losses\n",
    "        passed in by the caller are never mutated.\n",
    "        '''\n",
    "        if val_loss_fn is None:\n",
    "            val_loss_fn = deepcopy(loss_fn)\n",
    "\n",
    "        # Training multiplies per-example losses with per-example log-probs,\n",
    "        # so it needs reduction='none'. Losses without a reduction attribute\n",
    "        # are used as given.\n",
    "        if getattr(loss_fn, 'reduction', 'none') != 'none':\n",
    "            loss_fn = deepcopy(loss_fn)\n",
    "            loss_fn.reduction = 'none'\n",
    "\n",
    "        # Validation calls .item() on the loss, so it must be a scalar.\n",
    "        if getattr(val_loss_fn, 'reduction', None) == 'none':\n",
    "            val_loss_fn = deepcopy(val_loss_fn)\n",
    "            val_loss_fn.reduction = 'mean'\n",
    "        return loss_fn, val_loss_fn\n",
    "\n",
    "    def _rebar_loss(self, x, y, max_features, loss_fn, temp, eta):\n",
    "        '''\n",
    "        Compute the training objective for one minibatch.\n",
    "\n",
    "        Features are selected for max_features steps. At each step the\n",
    "        predictor loss uses the hard selection b = H(z), and the selector\n",
    "        objective combines the relaxed samples sigma(z / temp) and\n",
    "        sigma(tilde_z / temp) with a score-function term for b (REBAR).\n",
    "        Returns the objective averaged over selection steps.\n",
    "        '''\n",
    "        selector = self.selector\n",
    "        model = self.model\n",
    "        mask_layer = self.mask_layer\n",
    "        m_hard = torch.zeros(x.shape, dtype=x.dtype, device=x.device)\n",
    "        rows = torch.arange(len(x))\n",
    "        total_loss = 0\n",
    "\n",
    "        # NOTE(review): unlike validation/forward with no_repeats=True, the\n",
    "        # selector logits are not masked here, so training may re-select an\n",
    "        # already observed feature. Confirm this is intended.\n",
    "        for _ in range(max_features):\n",
    "            # Evaluate selector model.\n",
    "            x_masked = mask_layer(x, m_hard)\n",
    "            logits = selector(x_masked)\n",
    "\n",
    "            # Generate z = logits + Gumbel noise.\n",
    "            u = clamp_probs(torch.rand(\n",
    "                logits.shape, dtype=logits.dtype, device=logits.device))\n",
    "            g = - torch.log(- torch.log(u))\n",
    "            z = logits + g\n",
    "\n",
    "            # Generate b = H(z).\n",
    "            k = torch.argmax(z, dim=1)\n",
    "            b = torch.zeros(z.shape, dtype=z.dtype, device=z.device)\n",
    "            b[rows, k] = 1\n",
    "\n",
    "            # Generate tilde_z, a sample of z conditioned on b.\n",
    "            v = clamp_probs(torch.rand(\n",
    "                logits.shape, dtype=logits.dtype, device=logits.device))\n",
    "            p = clamp_probs(torch.softmax(logits, dim=1))\n",
    "            vk = torch.unsqueeze(v[rows, k], 1)\n",
    "            tilde_z = - torch.log(- torch.log(vk) - (1 - b) * torch.log(v) / p)\n",
    "\n",
    "            # Evaluate with z.\n",
    "            sigma_z = torch.softmax(z / temp, dim=1)\n",
    "            pred_z = model(mask_layer(x, torch.max(m_hard, sigma_z)))\n",
    "            loss_z = loss_fn(pred_z, y)\n",
    "\n",
    "            # Evaluate with b = H(z).\n",
    "            pred_b = model(mask_layer(x, torch.max(m_hard, b)))\n",
    "            loss_b = loss_fn(pred_b, y)\n",
    "\n",
    "            # Evaluate with tilde_z.\n",
    "            sigma_tilde_z = torch.softmax(tilde_z / temp, dim=1)\n",
    "            pred_tilde_z = model(\n",
    "                mask_layer(x, torch.max(m_hard, sigma_tilde_z)))\n",
    "            loss_tilde_z = loss_fn(pred_tilde_z, y)\n",
    "\n",
    "            # REBAR surrogate: relaxed terms plus a score-function term whose\n",
    "            # weight is detached, so it only scales grad log p(b).\n",
    "            # NOTE(review): the eta * (loss_z - loss_tilde_z) terms also\n",
    "            # backpropagate into the predictor, presumably adding gradient\n",
    "            # variance there; consider restricting them to the selector.\n",
    "            logprobs_b = logsoftmax(logits, 1)[rows, k]\n",
    "            rebar_loss = (\n",
    "                eta * loss_z\n",
    "                - eta * loss_tilde_z\n",
    "                + (loss_b - eta * loss_tilde_z).detach() * logprobs_b)\n",
    "\n",
    "            # The predictor itself is trained on the hard selection b.\n",
    "            step_loss = torch.mean(loss_b) + torch.mean(rebar_loss)\n",
    "            total_loss = total_loss + step_loss\n",
    "\n",
    "            # For next selection.\n",
    "            m_hard = torch.max(m_hard, b)\n",
    "\n",
    "        return total_loss / max_features\n",
    "\n",
    "    def _select_next(self, x, m_hard, argmax, no_repeats):\n",
    "        '''Select one more feature per example and return the updated mask.'''\n",
    "        logits = self.selector(self.mask_layer(x, m_hard))\n",
    "        if no_repeats:\n",
    "            # Push already-selected features far below every other logit.\n",
    "            logits = logits - 1e6 * m_hard\n",
    "        if argmax:\n",
    "            hard = make_onehot(logits)\n",
    "        else:\n",
    "            _, hard = self.sampler(logits)\n",
    "        return torch.max(m_hard, hard)\n",
    "\n",
    "    def _validation_loss(self, loader, max_features, val_loss_fn, argmax,\n",
    "                         no_repeats, validation_mode):\n",
    "        '''\n",
    "        Compute the example-weighted validation loss over a data loader.\n",
    "\n",
    "        'final' uses the loss after all max_features selections; 'mean'\n",
    "        averages the loss over every selection step.\n",
    "        '''\n",
    "        model = self.model\n",
    "        mask_layer = self.mask_layer\n",
    "        device = next(model.parameters()).device\n",
    "        val_loss = 0\n",
    "        n = 0\n",
    "\n",
    "        with torch.no_grad():\n",
    "            for x, y in loader:\n",
    "                # Move to device.\n",
    "                x = x.to(device)\n",
    "                y = y.to(device)\n",
    "\n",
    "                # Select features one at a time, scoring after each step.\n",
    "                m_hard = torch.zeros(x.shape, dtype=x.dtype, device=device)\n",
    "                step_losses = []\n",
    "                for _ in range(max_features):\n",
    "                    m_hard = self._select_next(x, m_hard, argmax, no_repeats)\n",
    "                    pred = model(mask_layer(x, m_hard))\n",
    "                    step_losses.append(val_loss_fn(pred, y).item())\n",
    "\n",
    "                if validation_mode == 'final':\n",
    "                    batch_loss = step_losses[-1]\n",
    "                else:\n",
    "                    batch_loss = sum(step_losses) / max_features\n",
    "\n",
    "                # Running example-weighted mean across batches.\n",
    "                val_loss = (\n",
    "                    (batch_loss * len(x) + val_loss * n) / (len(x) + n))\n",
    "                n += len(x)\n",
    "\n",
    "        return val_loss\n",
    "\n",
    "    def forward(self, x, max_features, argmax=True, no_repeats=True):\n",
    "        '''\n",
    "        Make predictions using selected features.\n",
    "\n",
    "        Args:\n",
    "          x: input data (torch.Tensor).\n",
    "          max_features: max features to observe.\n",
    "          argmax: whether to select the next feature using the max probability.\n",
    "          no_repeats: whether to ensure that no repeated selections occur.\n",
    "        '''\n",
    "        # Setup.\n",
    "        device = next(self.model.parameters()).device\n",
    "        m_hard = torch.zeros(x.shape, dtype=x.dtype, device=device)\n",
    "\n",
    "        # Select features one at a time.\n",
    "        for _ in range(max_features):\n",
    "            m_hard = self._select_next(x, m_hard, argmax, no_repeats)\n",
    "\n",
    "        # Make predictions.\n",
    "        x_masked = self.mask_layer(x, m_hard)\n",
    "        return self.model(x_masked)\n",
    "\n",
    "    def evaluate(self, dataset, max_features, loss_fn, batch_size, argmax=True,\n",
    "                 no_repeats=True):\n",
    "        '''\n",
    "        Evaluate mean performance across a dataset.\n",
    "\n",
    "        Args:\n",
    "          dataset: dataset of (x, y) pairs.\n",
    "          max_features: number of features to select per example.\n",
    "          loss_fn: loss that returns a scalar per batch (reduced with .item()).\n",
    "          batch_size: evaluation batch size.\n",
    "          argmax: whether to select the most likely feature at each step.\n",
    "          no_repeats: whether to forbid repeated selections.\n",
    "\n",
    "        Returns:\n",
    "          Example-weighted mean of the per-batch losses.\n",
    "        '''\n",
    "        # Setup.\n",
    "        device = next(self.model.parameters()).device\n",
    "        loader = DataLoader(\n",
    "            dataset, batch_size=batch_size, shuffle=False, pin_memory=True,\n",
    "            drop_last=False, num_workers=4)\n",
    "\n",
    "        # For calculating mean loss.\n",
    "        mean_loss = 0\n",
    "        n = 0\n",
    "\n",
    "        with torch.no_grad():\n",
    "            for x, y in loader:\n",
    "                # Move to the model's device.\n",
    "                x = x.to(device)\n",
    "                y = y.to(device)\n",
    "\n",
    "                # Calculate loss.\n",
    "                pred = self.forward(x, max_features, argmax, no_repeats)\n",
    "                loss = loss_fn(pred, y).item()\n",
    "\n",
    "                # Update average.\n",
    "                mean_loss = (mean_loss * n + loss * len(x)) / (n + len(x))\n",
    "                n += len(x)\n",
    "\n",
    "        return mean_loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "proud-rainbow",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--------Epoch 1--------\n",
      "Val loss = -0.6205\n",
      "\n",
      "--------Epoch 2--------\n",
      "Val loss = -0.6519\n",
      "\n",
      "--------Epoch 3--------\n",
      "Val loss = -0.6881\n",
      "\n",
      "--------Epoch 4--------\n",
      "Val loss = -0.6914\n",
      "\n",
      "--------Epoch 5--------\n",
      "Val loss = -0.6871\n",
      "\n",
      "--------Epoch 6--------\n",
      "Val loss = -0.6951\n",
      "\n",
      "--------Epoch 7--------\n",
      "Val loss = -0.6955\n",
      "\n",
      "--------Epoch 8--------\n",
      "Val loss = -0.6981\n",
      "\n",
      "--------Epoch 9--------\n",
      "Val loss = -0.6974\n",
      "\n",
      "--------Epoch 10--------\n",
      "Val loss = -0.6997\n",
      "\n",
      "--------Epoch 11--------\n",
      "Val loss = -0.6977\n",
      "\n",
      "--------Epoch 12--------\n",
      "Val loss = -0.6952\n",
      "\n",
      "--------Epoch 13--------\n",
      "Val loss = -0.7087\n",
      "\n",
      "--------Epoch 14--------\n",
      "Val loss = -0.6863\n",
      "\n",
      "--------Epoch 15--------\n",
      "Val loss = -0.6878\n",
      "\n",
      "--------Epoch 16--------\n",
      "Val loss = -0.6929\n",
      "\n",
      "--------Epoch 17--------\n",
      "Val loss = -0.6849\n",
      "\n",
      "--------Epoch 18--------\n",
      "Val loss = -0.6935\n",
      "\n",
      "--------Epoch 19--------\n",
      "Val loss = -0.6880\n",
      "\n",
      "--------Epoch 20--------\n",
      "Val loss = -0.7037\n",
      "\n",
      "--------Epoch 21--------\n",
      "Val loss = -0.6837\n",
      "\n",
      "--------Epoch 22--------\n",
      "Val loss = -0.6995\n",
      "\n",
      "--------Epoch 23--------\n",
      "Val loss = -0.6902\n",
      "\n",
      "--------Epoch 24--------\n",
      "Val loss = -0.7000\n",
      "\n",
      "--------Epoch 25--------\n",
      "Val loss = -0.7003\n",
      "\n",
      "--------Epoch 26--------\n",
      "Val loss = -0.6769\n",
      "\n",
      "--------Epoch 27--------\n",
      "Val loss = -0.6796\n",
      "\n",
      "--------Epoch 28--------\n",
      "Val loss = -0.6878\n",
      "\n",
      "--------Epoch 29--------\n",
      "Val loss = -0.6915\n",
      "\n",
      "--------Epoch 30--------\n",
      "Val loss = -0.6851\n",
      "\n",
      "--------Epoch 31--------\n",
      "Val loss = -0.7024\n",
      "\n",
      "--------Epoch 32--------\n",
      "Val loss = -0.6851\n",
      "\n",
      "--------Epoch 33--------\n",
      "Val loss = -0.7027\n",
      "\n",
      "--------Epoch 34--------\n",
      "Val loss = -0.7012\n",
      "\n",
      "--------Epoch 35--------\n",
      "Val loss = -0.7048\n",
      "\n",
      "--------Epoch 36--------\n",
      "Val loss = -0.7093\n",
      "\n",
      "--------Epoch 37--------\n",
      "Val loss = -0.7062\n",
      "\n",
      "--------Epoch 38--------\n",
      "Val loss = -0.6946\n",
      "\n",
      "--------Epoch 39--------\n",
      "Val loss = -0.6963\n",
      "\n",
      "--------Epoch 40--------\n",
      "Val loss = -0.6981\n",
      "\n",
      "--------Epoch 41--------\n",
      "Val loss = -0.7059\n",
      "\n",
      "--------Epoch 42--------\n",
      "Val loss = -0.6906\n",
      "\n",
      "--------Epoch 43--------\n",
      "Val loss = -0.6998\n",
      "\n",
      "--------Epoch 44--------\n",
      "Val loss = -0.7014\n",
      "\n",
      "--------Epoch 45--------\n",
      "Val loss = -0.7015\n",
      "\n",
      "--------Epoch 46--------\n",
      "Val loss = -0.6972\n",
      "\n",
      "--------Epoch 47--------\n",
      "Val loss = -0.6971\n",
      "\n",
      "--------Epoch 48--------\n",
      "Val loss = -0.6909\n",
      "\n",
      "--------Epoch 49--------\n",
      "Val loss = -0.6941\n",
      "\n",
      "--------Epoch 50--------\n",
      "Val loss = -0.6942\n",
      "\n",
      "--------Epoch 51--------\n",
      "Val loss = -0.6940\n",
      "\n",
      "--------Epoch 52--------\n",
      "Val loss = -0.7046\n",
      "\n",
      "--------Epoch 53--------\n",
      "Val loss = -0.7045\n",
      "\n",
      "--------Epoch 54--------\n",
      "Val loss = -0.7036\n",
      "\n",
      "--------Epoch 55--------\n",
      "Val loss = -0.7049\n",
      "\n",
      "--------Epoch 56--------\n",
      "Val loss = -0.6979\n",
      "\n",
      "--------Epoch 57--------\n",
      "Val loss = -0.6957\n",
      "\n",
      "--------Epoch 58--------\n",
      "Val loss = -0.7108\n",
      "\n",
      "--------Epoch 59--------\n",
      "Val loss = -0.7045\n",
      "\n",
      "--------Epoch 60--------\n",
      "Val loss = -0.7140\n",
      "\n",
      "--------Epoch 61--------\n",
      "Val loss = -0.7022\n",
      "\n",
      "--------Epoch 62--------\n",
      "Val loss = -0.7056\n",
      "\n",
      "--------Epoch 63--------\n",
      "Val loss = -0.7107\n",
      "\n",
      "--------Epoch 64--------\n",
      "Val loss = -0.7104\n",
      "\n",
      "--------Epoch 65--------\n",
      "Val loss = -0.7064\n",
      "\n",
      "--------Epoch 66--------\n",
      "Val loss = -0.7114\n",
      "\n",
      "--------Epoch 67--------\n",
      "Val loss = -0.7048\n",
      "\n",
      "--------Epoch 68--------\n",
      "Val loss = -0.7114\n",
      "\n",
      "--------Epoch 69--------\n",
      "Val loss = -0.7093\n",
      "\n",
      "--------Epoch 70--------\n",
      "Val loss = -0.7074\n",
      "\n",
      "--------Epoch 71--------\n",
      "Val loss = -0.7102\n",
      "\n",
      "--------Epoch 72--------\n",
      "Val loss = -0.7082\n",
      "\n",
      "--------Epoch 73--------\n",
      "Val loss = -0.6927\n",
      "\n",
      "--------Epoch 74--------\n",
      "Val loss = -0.7076\n",
      "\n",
      "--------Epoch 75--------\n",
      "Val loss = -0.7015\n",
      "\n",
      "--------Epoch 76--------\n",
      "Val loss = -0.7078\n",
      "\n",
      "--------Epoch 77--------\n",
      "Val loss = -0.7121\n",
      "\n",
      "--------Epoch 78--------\n",
      "Val loss = -0.7124\n",
      "\n",
      "--------Epoch 79--------\n",
      "Val loss = -0.6915\n",
      "\n",
      "--------Epoch 80--------\n",
      "Val loss = -0.7109\n",
      "\n",
      "--------Epoch 81--------\n",
      "Val loss = -0.7177\n",
      "\n",
      "--------Epoch 82--------\n",
      "Val loss = -0.7128\n",
      "\n",
      "--------Epoch 83--------\n",
      "Val loss = -0.7190\n",
      "\n",
      "--------Epoch 84--------\n",
      "Val loss = -0.7156\n",
      "\n",
      "--------Epoch 85--------\n",
      "Val loss = -0.7163\n",
      "\n",
      "--------Epoch 86--------\n",
      "Val loss = -0.7238\n",
      "\n",
      "--------Epoch 87--------\n",
      "Val loss = -0.7238\n",
      "\n",
      "--------Epoch 88--------\n",
      "Val loss = -0.7210\n",
      "\n",
      "--------Epoch 89--------\n",
      "Val loss = -0.7257\n",
      "\n",
      "--------Epoch 90--------\n",
      "Val loss = -0.7275\n",
      "\n",
      "--------Epoch 91--------\n",
      "Val loss = -0.7290\n",
      "\n",
      "--------Epoch 92--------\n",
      "Val loss = -0.7329\n",
      "\n",
      "--------Epoch 93--------\n",
      "Val loss = -0.7321\n",
      "\n",
      "--------Epoch 94--------\n",
      "Val loss = -0.7339\n",
      "\n",
      "--------Epoch 95--------\n",
      "Val loss = -0.7368\n",
      "\n",
      "--------Epoch 96--------\n",
      "Val loss = -0.7355\n",
      "\n",
      "--------Epoch 97--------\n",
      "Val loss = -0.7379\n",
      "\n",
      "--------Epoch 98--------\n",
      "Val loss = -0.7356\n",
      "\n",
      "--------Epoch 99--------\n",
      "Val loss = -0.7383\n",
      "\n",
      "--------Epoch 100--------\n",
      "Val loss = -0.7406\n",
      "\n",
      "--------Epoch 101--------\n",
      "Val loss = -0.7426\n",
      "\n",
      "--------Epoch 102--------\n",
      "Val loss = -0.7400\n",
      "\n",
      "--------Epoch 103--------\n",
      "Val loss = -0.7430\n",
      "\n",
      "--------Epoch 104--------\n",
      "Val loss = -0.7419\n",
      "\n",
      "--------Epoch 105--------\n",
      "Val loss = -0.7434\n",
      "\n",
      "--------Epoch 106--------\n",
      "Val loss = -0.7452\n",
      "\n",
      "--------Epoch 107--------\n",
      "Val loss = -0.7454\n",
      "\n",
      "--------Epoch 108--------\n",
      "Val loss = -0.7479\n",
      "\n",
      "--------Epoch 109--------\n",
      "Val loss = -0.7460\n",
      "\n",
      "--------Epoch 110--------\n",
      "Val loss = -0.7463\n",
      "\n",
      "--------Epoch 111--------\n",
      "Val loss = -0.7443\n",
      "\n",
      "--------Epoch 112--------\n",
      "Val loss = -0.7431\n",
      "\n",
      "--------Epoch 113--------\n",
      "Val loss = -0.7478\n",
      "\n",
      "--------Epoch 114--------\n",
      "Val loss = -0.7470\n",
      "\n",
      "--------Epoch 115--------\n",
      "Val loss = -0.7461\n",
      "\n",
      "--------Epoch 116--------\n",
      "Val loss = -0.7445\n",
      "\n",
      "--------Epoch 117--------\n",
      "Val loss = -0.7481\n",
      "\n",
      "--------Epoch 118--------\n",
      "Val loss = -0.7462\n",
      "\n",
      "--------Epoch 119--------\n",
      "Val loss = -0.7483\n",
      "\n",
      "--------Epoch 120--------\n",
      "Val loss = -0.7480\n",
      "\n",
      "--------Epoch 121--------\n",
      "Val loss = -0.7484\n",
      "\n",
      "--------Epoch 122--------\n",
      "Val loss = -0.7486\n",
      "\n",
      "--------Epoch 123--------\n",
      "Val loss = -0.7479\n",
      "\n",
      "--------Epoch 124--------\n",
      "Val loss = -0.7457\n",
      "\n",
      "--------Epoch 125--------\n",
      "Val loss = -0.7467\n",
      "\n",
      "--------Epoch 126--------\n",
      "Val loss = -0.7489\n",
      "\n",
      "--------Epoch 127--------\n",
      "Val loss = -0.7498\n",
      "\n",
      "--------Epoch 128--------\n",
      "Val loss = -0.7487\n",
      "\n",
      "--------Epoch 129--------\n",
      "Val loss = -0.7499\n",
      "\n",
      "--------Epoch 130--------\n",
      "Val loss = -0.7517\n",
      "\n",
      "--------Epoch 131--------\n",
      "Val loss = -0.7502\n",
      "\n",
      "--------Epoch 132--------\n",
      "Val loss = -0.7476\n",
      "\n",
      "--------Epoch 133--------\n",
      "Val loss = -0.7504\n",
      "\n",
      "--------Epoch 134--------\n",
      "Val loss = -0.7501\n",
      "\n",
      "--------Epoch 135--------\n",
      "Val loss = -0.7465\n",
      "\n",
      "--------Epoch 136--------\n",
      "Val loss = -0.7485\n",
      "\n",
      "--------Epoch 137--------\n",
      "Val loss = -0.7512\n",
      "\n",
      "--------Epoch 138--------\n",
      "Val loss = -0.7491\n",
      "\n",
      "--------Epoch 139--------\n",
      "Val loss = -0.7495\n",
      "\n",
      "--------Epoch 140--------\n",
      "Val loss = -0.7490\n",
      "\n",
      "--------Epoch 141--------\n",
      "Val loss = -0.7491\n",
      "\n",
      "--------Epoch 142--------\n",
      "Val loss = -0.7520\n",
      "\n",
      "--------Epoch 143--------\n",
      "Val loss = -0.7514\n",
      "\n",
      "--------Epoch 144--------\n",
      "Val loss = -0.7513\n",
      "\n",
      "--------Epoch 145--------\n",
      "Val loss = -0.7514\n",
      "\n",
      "--------Epoch 146--------\n",
      "Val loss = -0.7500\n",
      "\n",
      "--------Epoch 147--------\n",
      "Val loss = -0.7513\n",
      "\n",
      "--------Epoch 148--------\n",
      "Val loss = -0.7505\n",
      "\n",
      "--------Epoch 149--------\n",
      "Val loss = -0.7517\n",
      "\n",
      "--------Epoch 150--------\n",
      "Val loss = -0.7524\n",
      "\n",
      "--------Epoch 151--------\n",
      "Val loss = -0.7521\n",
      "\n",
      "--------Epoch 152--------\n",
      "Val loss = -0.7506\n",
      "\n",
      "--------Epoch 153--------\n",
      "Val loss = -0.7519\n",
      "\n",
      "--------Epoch 154--------\n",
      "Val loss = -0.7496\n",
      "\n",
      "--------Epoch 155--------\n",
      "Val loss = -0.7519\n",
      "\n",
      "--------Epoch 156--------\n",
      "Val loss = -0.7517\n",
      "\n",
      "--------Epoch 157--------\n",
      "Val loss = -0.7518\n",
      "\n",
      "--------Epoch 158--------\n",
      "Val loss = -0.7499\n",
      "\n",
      "--------Epoch 159--------\n",
      "Val loss = -0.7503\n",
      "\n",
      "--------Epoch 160--------\n",
      "Val loss = -0.7516\n",
      "\n",
      "--------Epoch 161--------\n",
      "Val loss = -0.7530\n",
      "\n",
      "--------Epoch 162--------\n",
      "Val loss = -0.7512\n",
      "\n",
      "--------Epoch 163--------\n",
      "Val loss = -0.7530\n",
      "\n",
      "--------Epoch 164--------\n",
      "Val loss = -0.7521\n",
      "\n",
      "--------Epoch 165--------\n",
      "Val loss = -0.7523\n",
      "\n",
      "--------Epoch 166--------\n",
      "Val loss = -0.7522\n",
      "\n",
      "--------Epoch 167--------\n",
      "Val loss = -0.7516\n",
      "\n",
      "--------Epoch 168--------\n",
      "Val loss = -0.7522\n",
      "\n",
      "--------Epoch 169--------\n",
      "Val loss = -0.7515\n",
      "\n",
      "--------Epoch 170--------\n",
      "Val loss = -0.7543\n",
      "\n",
      "--------Epoch 171--------\n",
      "Val loss = -0.7535\n",
      "\n",
      "--------Epoch 172--------\n",
      "Val loss = -0.7529\n",
      "\n",
      "--------Epoch 173--------\n",
      "Val loss = -0.7519\n",
      "\n",
      "--------Epoch 174--------\n",
      "Val loss = -0.7532\n",
      "\n",
      "--------Epoch 175--------\n",
      "Val loss = -0.7535\n",
      "\n",
      "--------Epoch 176--------\n",
      "Val loss = -0.7503\n",
      "\n",
      "--------Epoch 177--------\n",
      "Val loss = -0.7549\n",
      "\n",
      "--------Epoch 178--------\n",
      "Val loss = -0.7544\n",
      "\n",
      "--------Epoch 179--------\n",
      "Val loss = -0.7552\n",
      "\n",
      "--------Epoch 180--------\n",
      "Val loss = -0.7536\n",
      "\n",
      "--------Epoch 181--------\n",
      "Val loss = -0.7542\n",
      "\n",
      "--------Epoch 182--------\n",
      "Val loss = -0.7534\n",
      "\n",
      "--------Epoch 183--------\n",
      "Val loss = -0.7532\n",
      "\n",
      "--------Epoch 184--------\n",
      "Val loss = -0.7534\n",
      "\n",
      "--------Epoch 185--------\n",
      "Val loss = -0.7527\n",
      "\n",
      "--------Epoch 186--------\n",
      "Val loss = -0.7550\n",
      "\n",
      "--------Epoch 187--------\n",
      "Val loss = -0.7542\n",
      "\n",
      "--------Epoch 188--------\n",
      "Val loss = -0.7551\n",
      "\n",
      "--------Epoch 189--------\n",
      "Val loss = -0.7541\n",
      "\n",
      "--------Epoch 190--------\n",
      "Val loss = -0.7533\n",
      "\n",
      "--------Epoch 191--------\n",
      "Val loss = -0.7548\n",
      "\n",
      "--------Epoch 192--------\n",
      "Val loss = -0.7541\n",
      "\n",
      "--------Epoch 193--------\n",
      "Val loss = -0.7538\n",
      "\n",
      "--------Epoch 194--------\n",
      "Val loss = -0.7520\n",
      "\n",
      "--------Epoch 195--------\n",
      "Val loss = -0.7543\n",
      "\n",
      "--------Epoch 196--------\n",
      "Val loss = -0.7545\n",
      "\n",
      "--------Epoch 197--------\n",
      "Val loss = -0.7534\n",
      "\n",
      "--------Epoch 198--------\n",
      "Val loss = -0.7530\n",
      "\n",
      "--------Epoch 199--------\n",
      "Val loss = -0.7556\n",
      "\n",
      "--------Epoch 200--------\n",
      "Val loss = -0.7544\n",
      "\n",
      "--------Epoch 201--------\n",
      "Val loss = -0.7541\n",
      "\n",
      "--------Epoch 202--------\n",
      "Val loss = -0.7525\n",
      "\n",
      "--------Epoch 203--------\n",
      "Val loss = -0.7532\n",
      "\n",
      "--------Epoch 204--------\n",
      "Val loss = -0.7541\n",
      "\n",
      "--------Epoch 205--------\n",
      "Val loss = -0.7540\n",
      "\n",
      "--------Epoch 206--------\n",
      "Val loss = -0.7558\n",
      "\n",
      "--------Epoch 207--------\n",
      "Val loss = -0.7557\n",
      "\n",
      "--------Epoch 208--------\n",
      "Val loss = -0.7548\n",
      "\n",
      "--------Epoch 209--------\n",
      "Val loss = -0.7545\n",
      "\n",
      "--------Epoch 210--------\n",
      "Val loss = -0.7555\n",
      "\n",
      "--------Epoch 211--------\n",
      "Val loss = -0.7548\n",
      "\n",
      "--------Epoch 212--------\n",
      "Val loss = -0.7560\n",
      "\n",
      "--------Epoch 213--------\n",
      "Val loss = -0.7546\n",
      "\n",
      "--------Epoch 214--------\n",
      "Val loss = -0.7530\n",
      "\n",
      "--------Epoch 215--------\n",
      "Val loss = -0.7556\n",
      "\n",
      "--------Epoch 216--------\n",
      "Val loss = -0.7570\n",
      "\n",
      "--------Epoch 217--------\n",
      "Val loss = -0.7552\n",
      "\n",
      "--------Epoch 218--------\n",
      "Val loss = -0.7553\n",
      "\n",
      "--------Epoch 219--------\n",
      "Val loss = -0.7563\n",
      "\n",
      "--------Epoch 220--------\n",
      "Val loss = -0.7546\n",
      "\n",
      "--------Epoch 221--------\n",
      "Val loss = -0.7565\n",
      "\n",
      "--------Epoch 222--------\n",
      "Val loss = -0.7561\n",
      "\n",
      "--------Epoch 223--------\n",
      "Val loss = -0.7607\n",
      "\n",
      "--------Epoch 224--------\n",
      "Val loss = -0.7643\n",
      "\n",
      "--------Epoch 225--------\n",
      "Val loss = -0.7654\n",
      "\n",
      "--------Epoch 226--------\n",
      "Val loss = -0.7689\n",
      "\n",
      "--------Epoch 227--------\n",
      "Val loss = -0.7715\n",
      "\n",
      "--------Epoch 228--------\n",
      "Val loss = -0.7790\n",
      "\n",
      "--------Epoch 229--------\n",
      "Val loss = -0.7780\n",
      "\n",
      "--------Epoch 230--------\n",
      "Val loss = -0.7788\n",
      "\n",
      "--------Epoch 231--------\n",
      "Val loss = -0.7843\n",
      "\n",
      "--------Epoch 232--------\n",
      "Val loss = -0.7808\n",
      "\n",
      "--------Epoch 233--------\n",
      "Val loss = -0.7800\n",
      "\n",
      "--------Epoch 234--------\n",
      "Val loss = -0.7842\n",
      "\n",
      "--------Epoch 235--------\n",
      "Val loss = -0.7811\n",
      "\n",
      "--------Epoch 236--------\n",
      "Val loss = -0.7781\n",
      "\n",
      "--------Epoch 237--------\n",
      "Val loss = -0.7807\n",
      "\n",
      "--------Epoch 238--------\n",
      "Val loss = -0.7814\n",
      "\n",
      "--------Epoch 239--------\n",
      "Val loss = -0.7816\n",
      "\n",
      "--------Epoch 240--------\n",
      "Val loss = -0.7823\n",
      "\n",
      "--------Epoch 241--------\n",
      "Val loss = -0.7822\n",
      "\n",
      "--------Epoch 242--------\n",
      "Val loss = -0.7847\n",
      "\n",
      "--------Epoch 243--------\n",
      "Val loss = -0.7843\n",
      "\n",
      "--------Epoch 244--------\n",
      "Val loss = -0.7878\n",
      "\n",
      "--------Epoch 245--------\n",
      "Val loss = -0.7838\n",
      "\n",
      "--------Epoch 246--------\n",
      "Val loss = -0.7859\n",
      "\n",
      "--------Epoch 247--------\n",
      "Val loss = -0.7831\n",
      "\n",
      "--------Epoch 248--------\n",
      "Val loss = -0.7849\n",
      "\n",
      "--------Epoch 249--------\n",
      "Val loss = -0.7863\n",
      "\n",
      "--------Epoch 250--------\n",
      "Val loss = -0.7839\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Selector network with d_in outputs (presumably one selection logit per feature).\n",
    "# Its input width is d_in * 2 -- presumably the masked features concatenated with\n",
    "# the mask; confirm against mask_layer.\n",
    "selector = nn.Sequential(nn.Linear(d_in * 2, 512), nn.ReLU(),\n",
    "                         nn.Linear(512, 512), nn.ReLU(),\n",
    "                         nn.Linear(512, d_in))\n",
    "\n",
    "# Pair the selector with an independent copy of the predictor; deepcopy leaves the\n",
    "# original `model` untouched so later cells can start from it again.\n",
    "gafs = GreedyAdaptiveFS(\n",
    "    selector, deepcopy(model), mask_layer\n",
    ").to(device)\n",
    "\n",
    "# Train. Validation uses NegAccuracy, so the logged 'Val loss' is presumably the\n",
    "# negative validation accuracy (lower is better).\n",
    "gafs.fit(\n",
    "    train_dataset, val_dataset,\n",
    "    mbsize=128, lr=2e-4, nepochs=250, eta=1.0,\n",
    "    max_features=max_features,\n",
    "    loss_fn=nn.CrossEntropyLoss(),\n",
    "    val_loss_fn=NegAccuracy(),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "driven-september",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test acc = 78.74\n"
     ]
    }
   ],
   "source": [
    "# Test-set accuracy of the trained gafs, reported as a percentage.\n",
    "# NOTE(review): 1024 is presumably the evaluation batch size -- confirm in evaluate().\n",
    "test_acc = gafs.evaluate(test_dataset, max_features, Accuracy(), 1024)\n",
    "print(f'Test acc = {test_acc * 100:.2f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "charitable-moscow",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--------Epoch 1--------\n",
      "Val loss = -0.6031\n",
      "\n",
      "--------Epoch 2--------\n",
      "Val loss = -0.6410\n",
      "\n",
      "--------Epoch 3--------\n",
      "Val loss = -0.6488\n",
      "\n",
      "--------Epoch 4--------\n",
      "Val loss = -0.6599\n",
      "\n",
      "--------Epoch 5--------\n",
      "Val loss = -0.6679\n",
      "\n",
      "--------Epoch 6--------\n",
      "Val loss = -0.6733\n",
      "\n",
      "--------Epoch 7--------\n",
      "Val loss = -0.6707\n",
      "\n",
      "--------Epoch 8--------\n",
      "Val loss = -0.6730\n",
      "\n",
      "--------Epoch 9--------\n",
      "Val loss = -0.6739\n",
      "\n",
      "--------Epoch 10--------\n",
      "Val loss = -0.6790\n",
      "\n",
      "--------Epoch 11--------\n",
      "Val loss = -0.6812\n",
      "\n",
      "--------Epoch 12--------\n",
      "Val loss = -0.6760\n",
      "\n",
      "--------Epoch 13--------\n",
      "Val loss = -0.6803\n",
      "\n",
      "--------Epoch 14--------\n",
      "Val loss = -0.6727\n",
      "\n",
      "--------Epoch 15--------\n",
      "Val loss = -0.6768\n",
      "\n",
      "--------Epoch 16--------\n",
      "Val loss = -0.6832\n",
      "\n",
      "--------Epoch 17--------\n",
      "Val loss = -0.6882\n",
      "\n",
      "--------Epoch 18--------\n",
      "Val loss = -0.7008\n",
      "\n",
      "--------Epoch 19--------\n",
      "Val loss = -0.7049\n",
      "\n",
      "--------Epoch 20--------\n",
      "Val loss = -0.7048\n",
      "\n",
      "--------Epoch 21--------\n",
      "Val loss = -0.7052\n",
      "\n",
      "--------Epoch 22--------\n",
      "Val loss = -0.7019\n",
      "\n",
      "--------Epoch 23--------\n",
      "Val loss = -0.7032\n",
      "\n",
      "--------Epoch 24--------\n",
      "Val loss = -0.7036\n",
      "\n",
      "--------Epoch 25--------\n",
      "Val loss = -0.7068\n",
      "\n",
      "--------Epoch 26--------\n",
      "Val loss = -0.7062\n",
      "\n",
      "--------Epoch 27--------\n",
      "Val loss = -0.7041\n",
      "\n",
      "--------Epoch 28--------\n",
      "Val loss = -0.6983\n",
      "\n",
      "--------Epoch 29--------\n",
      "Val loss = -0.7037\n",
      "\n",
      "--------Epoch 30--------\n",
      "Val loss = -0.7068\n",
      "\n",
      "--------Epoch 31--------\n",
      "Val loss = -0.7124\n",
      "\n",
      "--------Epoch 32--------\n",
      "Val loss = -0.7096\n",
      "\n",
      "--------Epoch 33--------\n",
      "Val loss = -0.7036\n",
      "\n",
      "--------Epoch 34--------\n",
      "Val loss = -0.7066\n",
      "\n",
      "--------Epoch 35--------\n",
      "Val loss = -0.7068\n",
      "\n",
      "--------Epoch 36--------\n",
      "Val loss = -0.7109\n",
      "\n",
      "--------Epoch 37--------\n",
      "Val loss = -0.7110\n",
      "\n",
      "--------Epoch 38--------\n",
      "Val loss = -0.7091\n",
      "\n",
      "--------Epoch 39--------\n",
      "Val loss = -0.7162\n",
      "\n",
      "--------Epoch 40--------\n",
      "Val loss = -0.7105\n",
      "\n",
      "--------Epoch 41--------\n",
      "Val loss = -0.7149\n",
      "\n",
      "--------Epoch 42--------\n",
      "Val loss = -0.7201\n",
      "\n",
      "--------Epoch 43--------\n",
      "Val loss = -0.7218\n",
      "\n",
      "--------Epoch 44--------\n",
      "Val loss = -0.7157\n",
      "\n",
      "--------Epoch 45--------\n",
      "Val loss = -0.7379\n",
      "\n",
      "--------Epoch 46--------\n",
      "Val loss = -0.7390\n",
      "\n",
      "--------Epoch 47--------\n",
      "Val loss = -0.7405\n",
      "\n",
      "--------Epoch 48--------\n",
      "Val loss = -0.7430\n",
      "\n",
      "--------Epoch 49--------\n",
      "Val loss = -0.7435\n",
      "\n",
      "--------Epoch 50--------\n",
      "Val loss = -0.7404\n",
      "\n",
      "--------Epoch 51--------\n",
      "Val loss = -0.7450\n",
      "\n",
      "--------Epoch 52--------\n",
      "Val loss = -0.7457\n",
      "\n",
      "--------Epoch 53--------\n",
      "Val loss = -0.7460\n",
      "\n",
      "--------Epoch 54--------\n",
      "Val loss = -0.7545\n",
      "\n",
      "--------Epoch 55--------\n",
      "Val loss = -0.7566\n",
      "\n",
      "--------Epoch 56--------\n",
      "Val loss = -0.7594\n",
      "\n",
      "--------Epoch 57--------\n",
      "Val loss = -0.7704\n",
      "\n",
      "--------Epoch 58--------\n",
      "Val loss = -0.7717\n",
      "\n",
      "--------Epoch 59--------\n",
      "Val loss = -0.7731\n",
      "\n",
      "--------Epoch 60--------\n",
      "Val loss = -0.7700\n",
      "\n",
      "--------Epoch 61--------\n",
      "Val loss = -0.7717\n",
      "\n",
      "--------Epoch 62--------\n",
      "Val loss = -0.7770\n",
      "\n",
      "--------Epoch 63--------\n",
      "Val loss = -0.7759\n",
      "\n",
      "--------Epoch 64--------\n",
      "Val loss = -0.7777\n",
      "\n",
      "--------Epoch 65--------\n",
      "Val loss = -0.7762\n",
      "\n",
      "--------Epoch 66--------\n",
      "Val loss = -0.7798\n",
      "\n",
      "--------Epoch 67--------\n",
      "Val loss = -0.7794\n",
      "\n",
      "--------Epoch 68--------\n",
      "Val loss = -0.7796\n",
      "\n",
      "--------Epoch 69--------\n",
      "Val loss = -0.7791\n",
      "\n",
      "--------Epoch 70--------\n",
      "Val loss = -0.7795\n",
      "\n",
      "--------Epoch 71--------\n",
      "Val loss = -0.7762\n",
      "\n",
      "--------Epoch 72--------\n",
      "Val loss = -0.7812\n",
      "\n",
      "--------Epoch 73--------\n",
      "Val loss = -0.7815\n",
      "\n",
      "--------Epoch 74--------\n",
      "Val loss = -0.7815\n",
      "\n",
      "--------Epoch 75--------\n",
      "Val loss = -0.7848\n",
      "\n",
      "--------Epoch 76--------\n",
      "Val loss = -0.7841\n",
      "\n",
      "--------Epoch 77--------\n",
      "Val loss = -0.7823\n",
      "\n",
      "--------Epoch 78--------\n",
      "Val loss = -0.7825\n",
      "\n",
      "--------Epoch 79--------\n",
      "Val loss = -0.7834\n",
      "\n",
      "--------Epoch 80--------\n",
      "Val loss = -0.7858\n",
      "\n",
      "--------Epoch 81--------\n",
      "Val loss = -0.7863\n",
      "\n",
      "--------Epoch 82--------\n",
      "Val loss = -0.7865\n",
      "\n",
      "--------Epoch 83--------\n",
      "Val loss = -0.7869\n",
      "\n",
      "--------Epoch 84--------\n",
      "Val loss = -0.7849\n",
      "\n",
      "--------Epoch 85--------\n",
      "Val loss = -0.7878\n",
      "\n",
      "--------Epoch 86--------\n",
      "Val loss = -0.7858\n",
      "\n",
      "--------Epoch 87--------\n",
      "Val loss = -0.7894\n",
      "\n",
      "--------Epoch 88--------\n",
      "Val loss = -0.7885\n",
      "\n",
      "--------Epoch 89--------\n",
      "Val loss = -0.7854\n",
      "\n",
      "--------Epoch 90--------\n",
      "Val loss = -0.7880\n",
      "\n",
      "--------Epoch 91--------\n",
      "Val loss = -0.7864\n",
      "\n",
      "--------Epoch 92--------\n",
      "Val loss = -0.7882\n",
      "\n",
      "--------Epoch 93--------\n",
      "Val loss = -0.7918\n",
      "\n",
      "--------Epoch 94--------\n",
      "Val loss = -0.7874\n",
      "\n",
      "--------Epoch 95--------\n",
      "Val loss = -0.7883\n",
      "\n",
      "--------Epoch 96--------\n",
      "Val loss = -0.7876\n",
      "\n",
      "--------Epoch 97--------\n",
      "Val loss = -0.7882\n",
      "\n",
      "--------Epoch 98--------\n",
      "Val loss = -0.7883\n",
      "\n",
      "--------Epoch 99--------\n",
      "Val loss = -0.7902\n",
      "\n",
      "--------Epoch 100--------\n",
      "Val loss = -0.7905\n",
      "\n",
      "--------Epoch 101--------\n",
      "Val loss = -0.7877\n",
      "\n",
      "--------Epoch 102--------\n",
      "Val loss = -0.7886\n",
      "\n",
      "--------Epoch 103--------\n",
      "Val loss = -0.7866\n",
      "\n",
      "--------Epoch 104--------\n",
      "Val loss = -0.7917\n",
      "\n",
      "--------Epoch 105--------\n",
      "Val loss = -0.7919\n",
      "\n",
      "--------Epoch 106--------\n",
      "Val loss = -0.7909\n",
      "\n",
      "--------Epoch 107--------\n",
      "Val loss = -0.7909\n",
      "\n",
      "--------Epoch 108--------\n",
      "Val loss = -0.7918\n",
      "\n",
      "--------Epoch 109--------\n",
      "Val loss = -0.7900\n",
      "\n",
      "--------Epoch 110--------\n",
      "Val loss = -0.7906\n",
      "\n",
      "--------Epoch 111--------\n",
      "Val loss = -0.7932\n",
      "\n",
      "--------Epoch 112--------\n",
      "Val loss = -0.7922\n",
      "\n",
      "--------Epoch 113--------\n",
      "Val loss = -0.7895\n",
      "\n",
      "--------Epoch 114--------\n",
      "Val loss = -0.7895\n",
      "\n",
      "--------Epoch 115--------\n",
      "Val loss = -0.7917\n",
      "\n",
      "--------Epoch 116--------\n",
      "Val loss = -0.7936\n",
      "\n",
      "--------Epoch 117--------\n",
      "Val loss = -0.7901\n",
      "\n",
      "--------Epoch 118--------\n",
      "Val loss = -0.7921\n",
      "\n",
      "--------Epoch 119--------\n",
      "Val loss = -0.7940\n",
      "\n",
      "--------Epoch 120--------\n",
      "Val loss = -0.7993\n",
      "\n",
      "--------Epoch 121--------\n",
      "Val loss = -0.7927\n",
      "\n",
      "--------Epoch 122--------\n",
      "Val loss = -0.7933\n",
      "\n",
      "--------Epoch 123--------\n",
      "Val loss = -0.7953\n",
      "\n",
      "--------Epoch 124--------\n",
      "Val loss = -0.7957\n",
      "\n",
      "--------Epoch 125--------\n",
      "Val loss = -0.7951\n",
      "\n",
      "--------Epoch 126--------\n",
      "Val loss = -0.7957\n",
      "\n",
      "--------Epoch 127--------\n",
      "Val loss = -0.7962\n",
      "\n",
      "--------Epoch 128--------\n",
      "Val loss = -0.7962\n",
      "\n",
      "--------Epoch 129--------\n",
      "Val loss = -0.7974\n",
      "\n",
      "--------Epoch 130--------\n",
      "Val loss = -0.7996\n",
      "\n",
      "--------Epoch 131--------\n",
      "Val loss = -0.8001\n",
      "\n",
      "--------Epoch 132--------\n",
      "Val loss = -0.7977\n",
      "\n",
      "--------Epoch 133--------\n",
      "Val loss = -0.8013\n",
      "\n",
      "--------Epoch 134--------\n",
      "Val loss = -0.7986\n",
      "\n",
      "--------Epoch 135--------\n",
      "Val loss = -0.7986\n",
      "\n",
      "--------Epoch 136--------\n",
      "Val loss = -0.7982\n",
      "\n",
      "--------Epoch 137--------\n",
      "Val loss = -0.7966\n",
      "\n",
      "--------Epoch 138--------\n",
      "Val loss = -0.8045\n",
      "\n",
      "--------Epoch 139--------\n",
      "Val loss = -0.7994\n",
      "\n",
      "--------Epoch 140--------\n",
      "Val loss = -0.7965\n",
      "\n",
      "--------Epoch 141--------\n",
      "Val loss = -0.7993\n",
      "\n",
      "--------Epoch 142--------\n",
      "Val loss = -0.8005\n",
      "\n",
      "--------Epoch 143--------\n",
      "Val loss = -0.8014\n",
      "\n",
      "--------Epoch 144--------\n",
      "Val loss = -0.8011\n",
      "\n",
      "--------Epoch 145--------\n",
      "Val loss = -0.7963\n",
      "\n",
      "--------Epoch 146--------\n",
      "Val loss = -0.7977\n",
      "\n",
      "--------Epoch 147--------\n",
      "Val loss = -0.7981\n",
      "\n",
      "--------Epoch 148--------\n",
      "Val loss = -0.7995\n",
      "\n",
      "--------Epoch 149--------\n",
      "Val loss = -0.7992\n",
      "\n",
      "--------Epoch 150--------\n",
      "Val loss = -0.7997\n",
      "\n",
      "--------Epoch 151--------\n",
      "Val loss = -0.8028\n",
      "\n",
      "--------Epoch 152--------\n",
      "Val loss = -0.8000\n",
      "\n",
      "--------Epoch 153--------\n",
      "Val loss = -0.7999\n",
      "\n",
      "--------Epoch 154--------\n",
      "Val loss = -0.8021\n",
      "\n",
      "--------Epoch 155--------\n",
      "Val loss = -0.7982\n",
      "\n",
      "--------Epoch 156--------\n",
      "Val loss = -0.8029\n",
      "\n",
      "--------Epoch 157--------\n",
      "Val loss = -0.8018\n",
      "\n",
      "--------Epoch 158--------\n",
      "Val loss = -0.8012\n",
      "\n",
      "--------Epoch 159--------\n",
      "Val loss = -0.8006\n",
      "\n",
      "--------Epoch 160--------\n",
      "Val loss = -0.8020\n",
      "\n",
      "--------Epoch 161--------\n",
      "Val loss = -0.8027\n",
      "\n",
      "--------Epoch 162--------\n",
      "Val loss = -0.8020\n",
      "\n",
      "--------Epoch 163--------\n",
      "Val loss = -0.8028\n",
      "\n",
      "--------Epoch 164--------\n",
      "Val loss = -0.7997\n",
      "\n",
      "--------Epoch 165--------\n",
      "Val loss = -0.8039\n",
      "\n",
      "--------Epoch 166--------\n",
      "Val loss = -0.8017\n",
      "\n",
      "--------Epoch 167--------\n",
      "Val loss = -0.8038\n",
      "\n",
      "--------Epoch 168--------\n",
      "Val loss = -0.8029\n",
      "\n",
      "--------Epoch 169--------\n",
      "Val loss = -0.8030\n",
      "\n",
      "--------Epoch 170--------\n",
      "Val loss = -0.8072\n",
      "\n",
      "--------Epoch 171--------\n",
      "Val loss = -0.8045\n",
      "\n",
      "--------Epoch 172--------\n",
      "Val loss = -0.8050\n",
      "\n",
      "--------Epoch 173--------\n",
      "Val loss = -0.8038\n",
      "\n",
      "--------Epoch 174--------\n",
      "Val loss = -0.8023\n",
      "\n",
      "--------Epoch 175--------\n",
      "Val loss = -0.8034\n",
      "\n",
      "--------Epoch 176--------\n",
      "Val loss = -0.8019\n",
      "\n",
      "--------Epoch 177--------\n",
      "Val loss = -0.8042\n",
      "\n",
      "--------Epoch 178--------\n",
      "Val loss = -0.8049\n",
      "\n",
      "--------Epoch 179--------\n",
      "Val loss = -0.8054\n",
      "\n",
      "--------Epoch 180--------\n",
      "Val loss = -0.8029\n",
      "\n",
      "--------Epoch 181--------\n",
      "Val loss = -0.8072\n",
      "\n",
      "--------Epoch 182--------\n",
      "Val loss = -0.8064\n",
      "\n",
      "--------Epoch 183--------\n",
      "Val loss = -0.8067\n",
      "\n",
      "--------Epoch 184--------\n",
      "Val loss = -0.8028\n",
      "\n",
      "--------Epoch 185--------\n",
      "Val loss = -0.8082\n",
      "\n",
      "--------Epoch 186--------\n",
      "Val loss = -0.8038\n",
      "\n",
      "--------Epoch 187--------\n",
      "Val loss = -0.8043\n",
      "\n",
      "--------Epoch 188--------\n",
      "Val loss = -0.8067\n",
      "\n",
      "--------Epoch 189--------\n",
      "Val loss = -0.8089\n",
      "\n",
      "--------Epoch 190--------\n",
      "Val loss = -0.8074\n",
      "\n",
      "--------Epoch 191--------\n",
      "Val loss = -0.8066\n",
      "\n",
      "--------Epoch 192--------\n",
      "Val loss = -0.8072\n",
      "\n",
      "--------Epoch 193--------\n",
      "Val loss = -0.8079\n",
      "\n",
      "--------Epoch 194--------\n",
      "Val loss = -0.8103\n",
      "\n",
      "--------Epoch 195--------\n",
      "Val loss = -0.8083\n",
      "\n",
      "--------Epoch 196--------\n",
      "Val loss = -0.8055\n",
      "\n",
      "--------Epoch 197--------\n",
      "Val loss = -0.8051\n",
      "\n",
      "--------Epoch 198--------\n",
      "Val loss = -0.8084\n",
      "\n",
      "--------Epoch 199--------\n",
      "Val loss = -0.8092\n",
      "\n",
      "--------Epoch 200--------\n",
      "Val loss = -0.8124\n",
      "\n",
      "--------Epoch 201--------\n",
      "Val loss = -0.8081\n",
      "\n",
      "--------Epoch 202--------\n",
      "Val loss = -0.8075\n",
      "\n",
      "--------Epoch 203--------\n",
      "Val loss = -0.8108\n",
      "\n",
      "--------Epoch 204--------\n",
      "Val loss = -0.8112\n",
      "\n",
      "--------Epoch 205--------\n",
      "Val loss = -0.8110\n",
      "\n",
      "--------Epoch 206--------\n",
      "Val loss = -0.8107\n",
      "\n",
      "--------Epoch 207--------\n",
      "Val loss = -0.8111\n",
      "\n",
      "--------Epoch 208--------\n",
      "Val loss = -0.8102\n",
      "\n",
      "--------Epoch 209--------\n",
      "Val loss = -0.8084\n",
      "\n",
      "--------Epoch 210--------\n",
      "Val loss = -0.8106\n",
      "\n",
      "--------Epoch 211--------\n",
      "Val loss = -0.8152\n",
      "\n",
      "--------Epoch 212--------\n",
      "Val loss = -0.8126\n",
      "\n",
      "--------Epoch 213--------\n",
      "Val loss = -0.8106\n",
      "\n",
      "--------Epoch 214--------\n",
      "Val loss = -0.8127\n",
      "\n",
      "--------Epoch 215--------\n",
      "Val loss = -0.8125\n",
      "\n",
      "--------Epoch 216--------\n",
      "Val loss = -0.8145\n",
      "\n",
      "--------Epoch 217--------\n",
      "Val loss = -0.8147\n",
      "\n",
      "--------Epoch 218--------\n",
      "Val loss = -0.8110\n",
      "\n",
      "--------Epoch 219--------\n",
      "Val loss = -0.8157\n",
      "\n",
      "--------Epoch 220--------\n",
      "Val loss = -0.8138\n",
      "\n",
      "--------Epoch 221--------\n",
      "Val loss = -0.8145\n",
      "\n",
      "--------Epoch 222--------\n",
      "Val loss = -0.8112\n",
      "\n",
      "--------Epoch 223--------\n",
      "Val loss = -0.8152\n",
      "\n",
      "--------Epoch 224--------\n",
      "Val loss = -0.8137\n",
      "\n",
      "--------Epoch 225--------\n",
      "Val loss = -0.8137\n",
      "\n",
      "--------Epoch 226--------\n",
      "Val loss = -0.8138\n",
      "\n",
      "--------Epoch 227--------\n",
      "Val loss = -0.8128\n",
      "\n",
      "--------Epoch 228--------\n",
      "Val loss = -0.8155\n",
      "\n",
      "--------Epoch 229--------\n",
      "Val loss = -0.8141\n",
      "\n",
      "--------Epoch 230--------\n",
      "Val loss = -0.8133\n",
      "\n",
      "--------Epoch 231--------\n",
      "Val loss = -0.8143\n",
      "\n",
      "--------Epoch 232--------\n",
      "Val loss = -0.8152\n",
      "\n",
      "--------Epoch 233--------\n",
      "Val loss = -0.8110\n",
      "\n",
      "--------Epoch 234--------\n",
      "Val loss = -0.8131\n",
      "\n",
      "--------Epoch 235--------\n",
      "Val loss = -0.8139\n",
      "\n",
      "--------Epoch 236--------\n",
      "Val loss = -0.8153\n",
      "\n",
      "--------Epoch 237--------\n",
      "Val loss = -0.8120\n",
      "\n",
      "--------Epoch 238--------\n",
      "Val loss = -0.8153\n",
      "\n",
      "--------Epoch 239--------\n",
      "Val loss = -0.8127\n",
      "\n",
      "--------Epoch 240--------\n",
      "Val loss = -0.8124\n",
      "\n",
      "--------Epoch 241--------\n",
      "Val loss = -0.8146\n",
      "\n",
      "--------Epoch 242--------\n",
      "Val loss = -0.8116\n",
      "\n",
      "--------Epoch 243--------\n",
      "Val loss = -0.8111\n",
      "\n",
      "--------Epoch 244--------\n",
      "Val loss = -0.8089\n",
      "\n",
      "--------Epoch 245--------\n",
      "Val loss = -0.8128\n",
      "\n",
      "--------Epoch 246--------\n",
      "Val loss = -0.8118\n",
      "\n",
      "--------Epoch 247--------\n",
      "Val loss = -0.8130\n",
      "\n",
      "--------Epoch 248--------\n",
      "Val loss = -0.8134\n",
      "\n",
      "--------Epoch 249--------\n",
      "Val loss = -0.8149\n",
      "\n",
      "--------Epoch 250--------\n",
      "Val loss = -0.8128\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Set up selector\n",
    "selector = nn.Sequential(\n",
    "    nn.Linear(d_in * 2, 512),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(512, 512),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(512, d_in))\n",
    "gafs = GreedyAdaptiveFS(selector, deepcopy(model), mask_layer).to(device)\n",
    "\n",
    "# Tie weights: the selector's Linear layers at indices 0 and 2 reuse the\n",
    "# predictor's weight Parameters (biases stay separate). Assigning the Parameter\n",
    "# objects themselves makes both networks share ONE leaf tensor, so gradients from\n",
    "# both paths sum into a single .grad with a single set of optimizer state.\n",
    "# Wrapping them in nn.Parameter(...) instead would create new leaves that alias\n",
    "# the same storage but get their own .grad and optimizer state, are yielded\n",
    "# separately by gafs.parameters(), and always have requires_grad=True.\n",
    "# NOTE: outputs recorded before this fix came from the nn.Parameter(...) version.\n",
    "selector[0].weight = gafs.model[0].weight\n",
    "selector[2].weight = gafs.model[2].weight\n",
    "\n",
    "# Train\n",
    "gafs.fit(train_dataset,\n",
    "         val_dataset,\n",
    "         mbsize=128,\n",
    "         lr=2e-4,\n",
    "         nepochs=250,\n",
    "         eta=1.0,\n",
    "         max_features=max_features,\n",
    "         loss_fn=nn.CrossEntropyLoss(),\n",
    "         val_loss_fn=NegAccuracy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "increasing-mouse",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test acc = 82.93\n"
     ]
    }
   ],
   "source": [
    "# Test-set accuracy of the weight-tied gafs trained above, as a percentage.\n",
    "# NOTE(review): 1024 is presumably the evaluation batch size -- confirm in evaluate().\n",
    "test_acc = gafs.evaluate(test_dataset, max_features, Accuracy(), 1024)\n",
    "print(f'Test acc = {test_acc * 100:.2f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "proprietary-swaziland",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
