{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "consecutive-uncle",
   "metadata": {},
   "source": [
    "# Implementing REINFORCE\n",
    "\n",
    "- Rather than optimizing with the Concrete distribution, this version uses the REINFORCE gradient estimator.\n",
    "- REINFORCE works noticeably worse. Weight tying (see gafs v4) improves the results significantly, although they are still worse than simply using the Concrete distribution.\n",
    "- Note: the global feature selection baseline below is still trained with the Concrete relaxation (`ConcreteMask`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "affected-three",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "import torch.nn as nn\n",
    "import matplotlib.pyplot as plt\n",
    "from torchvision import transforms\n",
    "from torchvision.datasets import MNIST\n",
    "from copy import deepcopy\n",
    "from utils import *\n",
    "from models import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "crazy-ancient",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prefer GPU 6 (the device originally hard-coded here); fall back to the default\n",
    "# GPU, then to the CPU, so the notebook also runs on machines with fewer GPUs.\n",
    "if torch.cuda.device_count() > 6:\n",
    "    device = torch.device('cuda', 6)\n",
    "elif torch.cuda.is_available():\n",
    "    device = torch.device('cuda')\n",
    "else:\n",
    "    device = torch.device('cpu')\n",
    "\n",
    "# Seed PyTorch's global RNG so model initialisation is repeatable across runs.\n",
    "torch.manual_seed(0)\n",
    "device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "increasing-ordering",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load data\n",
    "class Flatten(object):\n",
    "    '''Transform that flattens an image tensor into a 1-D vector of pixels.'''\n",
    "    def __call__(self, pic):\n",
    "        return torch.flatten(pic)\n",
    "\n",
    "# One root directory and one preprocessing pipeline shared by train and test,\n",
    "# so the two splits cannot drift apart.\n",
    "mnist_root = '/tmp/mnist/'\n",
    "mnist_transform = transforms.Compose([transforms.ToTensor(), Flatten()])\n",
    "\n",
    "mnist_dataset = MNIST(mnist_root, download=True, train=True, transform=mnist_transform)\n",
    "images = mnist_dataset.data\n",
    "targets = mnist_dataset.targets\n",
    "\n",
    "# Hold out 10,000 training images for validation. The legacy global NumPy seed is\n",
    "# kept (rather than default_rng) so the split matches previously recorded results.\n",
    "np.random.seed(0)\n",
    "val_inds = np.sort(np.random.choice(len(images), size=10000, replace=False))\n",
    "train_inds = np.setdiff1d(np.arange(len(images)), val_inds)\n",
    "assert len(train_inds) + len(val_inds) == len(images)\n",
    "assert np.intersect1d(train_inds, val_inds).size == 0\n",
    "\n",
    "# Training dataset\n",
    "train_dataset = torch.utils.data.Subset(mnist_dataset, train_inds)\n",
    "\n",
    "# Validation dataset\n",
    "val_dataset = torch.utils.data.Subset(mnist_dataset, val_inds)\n",
    "\n",
    "# Test dataset\n",
    "test_dataset = MNIST(mnist_root, download=True, train=False, transform=mnist_transform)\n",
    "\n",
    "# Set input/output dimensions (28x28 pixels, 10 digit classes)\n",
    "d_in = 784\n",
    "d_out = 10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "dense-commander",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Feature budget: how many input pixels the selector is allowed to keep\n",
    "max_features = 10"
   ]
  },
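  {
   "cell_type": "markdown",
   "id": "theoretical-imagination",
   "metadata": {},
   "source": [
    "# Global feature selection\n",
    "\n",
    "Baseline: learn one fixed set of `max_features` pixels, shared by every input, using the Concrete relaxation (`ConcreteMask`)."
   ]
  },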
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "informed-cookie",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--------Epoch 1--------\n",
      "Val loss = -0.2784\n",
      "\n",
      "--------Epoch 2--------\n",
      "Val loss = -0.3027\n",
      "\n",
      "--------Epoch 3--------\n",
      "Val loss = -0.3163\n",
      "\n",
      "--------Epoch 4--------\n",
      "Val loss = -0.3135\n",
      "\n",
      "--------Epoch 5--------\n",
      "Val loss = -0.3233\n",
      "\n",
      "--------Epoch 6--------\n",
      "Val loss = -0.3311\n",
      "\n",
      "--------Epoch 7--------\n",
      "Val loss = -0.3360\n",
      "\n",
      "--------Epoch 8--------\n",
      "Val loss = -0.3385\n",
      "\n",
      "--------Epoch 9--------\n",
      "Val loss = -0.3384\n",
      "\n",
      "--------Epoch 10--------\n",
      "Val loss = -0.3483\n",
      "\n",
      "--------Epoch 11--------\n",
      "Val loss = -0.3527\n",
      "\n",
      "--------Epoch 12--------\n",
      "Val loss = -0.3572\n",
      "\n",
      "--------Epoch 13--------\n",
      "Val loss = -0.3664\n",
      "\n",
      "--------Epoch 14--------\n",
      "Val loss = -0.3725\n",
      "\n",
      "--------Epoch 15--------\n",
      "Val loss = -0.3680\n",
      "\n",
      "--------Epoch 16--------\n",
      "Val loss = -0.3907\n",
      "\n",
      "--------Epoch 17--------\n",
      "Val loss = -0.3887\n",
      "\n",
      "--------Epoch 18--------\n",
      "Val loss = -0.3892\n",
      "\n",
      "--------Epoch 19--------\n",
      "Val loss = -0.3890\n",
      "\n",
      "--------Epoch 20--------\n",
      "Val loss = -0.3892\n",
      "\n",
      "--------Epoch 21--------\n",
      "Val loss = -0.3958\n",
      "\n",
      "--------Epoch 22--------\n",
      "Val loss = -0.4029\n",
      "\n",
      "--------Epoch 23--------\n",
      "Val loss = -0.4028\n",
      "\n",
      "--------Epoch 24--------\n",
      "Val loss = -0.4077\n",
      "\n",
      "--------Epoch 25--------\n",
      "Val loss = -0.4156\n",
      "\n",
      "--------Epoch 26--------\n",
      "Val loss = -0.4155\n",
      "\n",
      "--------Epoch 27--------\n",
      "Val loss = -0.4122\n",
      "\n",
      "--------Epoch 28--------\n",
      "Val loss = -0.4210\n",
      "\n",
      "--------Epoch 29--------\n",
      "Val loss = -0.4167\n",
      "\n",
      "--------Epoch 30--------\n",
      "Val loss = -0.4335\n",
      "\n",
      "--------Epoch 31--------\n",
      "Val loss = -0.4385\n",
      "\n",
      "--------Epoch 32--------\n",
      "Val loss = -0.4425\n",
      "\n",
      "--------Epoch 33--------\n",
      "Val loss = -0.4521\n",
      "\n",
      "--------Epoch 34--------\n",
      "Val loss = -0.4586\n",
      "\n",
      "--------Epoch 35--------\n",
      "Val loss = -0.4537\n",
      "\n",
      "--------Epoch 36--------\n",
      "Val loss = -0.4521\n",
      "\n",
      "--------Epoch 37--------\n",
      "Val loss = -0.4721\n",
      "\n",
      "--------Epoch 38--------\n",
      "Val loss = -0.4660\n",
      "\n",
      "--------Epoch 39--------\n",
      "Val loss = -0.4641\n",
      "\n",
      "--------Epoch 40--------\n",
      "Val loss = -0.4746\n",
      "\n",
      "--------Epoch 41--------\n",
      "Val loss = -0.4764\n",
      "\n",
      "--------Epoch 42--------\n",
      "Val loss = -0.4829\n",
      "\n",
      "--------Epoch 43--------\n",
      "Val loss = -0.4895\n",
      "\n",
      "--------Epoch 44--------\n",
      "Val loss = -0.4829\n",
      "\n",
      "--------Epoch 45--------\n",
      "Val loss = -0.4968\n",
      "\n",
      "--------Epoch 46--------\n",
      "Val loss = -0.4955\n",
      "\n",
      "--------Epoch 47--------\n",
      "Val loss = -0.4955\n",
      "\n",
      "--------Epoch 48--------\n",
      "Val loss = -0.5045\n",
      "\n",
      "--------Epoch 49--------\n",
      "Val loss = -0.4973\n",
      "\n",
      "--------Epoch 50--------\n",
      "Val loss = -0.5012\n",
      "\n",
      "--------Epoch 51--------\n",
      "Val loss = -0.5029\n",
      "\n",
      "--------Epoch 52--------\n",
      "Val loss = -0.5164\n",
      "\n",
      "--------Epoch 53--------\n",
      "Val loss = -0.5193\n",
      "\n",
      "--------Epoch 54--------\n",
      "Val loss = -0.5162\n",
      "\n",
      "--------Epoch 55--------\n",
      "Val loss = -0.5263\n",
      "\n",
      "--------Epoch 56--------\n",
      "Val loss = -0.5233\n",
      "\n",
      "--------Epoch 57--------\n",
      "Val loss = -0.5297\n",
      "\n",
      "--------Epoch 58--------\n",
      "Val loss = -0.5343\n",
      "\n",
      "--------Epoch 59--------\n",
      "Val loss = -0.5290\n",
      "\n",
      "--------Epoch 60--------\n",
      "Val loss = -0.5391\n",
      "\n",
      "--------Epoch 61--------\n",
      "Val loss = -0.5441\n",
      "\n",
      "--------Epoch 62--------\n",
      "Val loss = -0.5421\n",
      "\n",
      "--------Epoch 63--------\n",
      "Val loss = -0.5548\n",
      "\n",
      "--------Epoch 64--------\n",
      "Val loss = -0.5471\n",
      "\n",
      "--------Epoch 65--------\n",
      "Val loss = -0.5542\n",
      "\n",
      "--------Epoch 66--------\n",
      "Val loss = -0.5498\n",
      "\n",
      "--------Epoch 67--------\n",
      "Val loss = -0.5571\n",
      "\n",
      "--------Epoch 68--------\n",
      "Val loss = -0.5630\n",
      "\n",
      "--------Epoch 69--------\n",
      "Val loss = -0.5636\n",
      "\n",
      "--------Epoch 70--------\n",
      "Val loss = -0.5668\n",
      "\n",
      "--------Epoch 71--------\n",
      "Val loss = -0.5831\n",
      "\n",
      "--------Epoch 72--------\n",
      "Val loss = -0.5773\n",
      "\n",
      "--------Epoch 73--------\n",
      "Val loss = -0.5805\n",
      "\n",
      "--------Epoch 74--------\n",
      "Val loss = -0.5752\n",
      "\n",
      "--------Epoch 75--------\n",
      "Val loss = -0.5857\n",
      "\n",
      "--------Epoch 76--------\n",
      "Val loss = -0.5873\n",
      "\n",
      "--------Epoch 77--------\n",
      "Val loss = -0.5860\n",
      "\n",
      "--------Epoch 78--------\n",
      "Val loss = -0.5881\n",
      "\n",
      "--------Epoch 79--------\n",
      "Val loss = -0.5890\n",
      "\n",
      "--------Epoch 80--------\n",
      "Val loss = -0.5976\n",
      "\n",
      "--------Epoch 81--------\n",
      "Val loss = -0.6014\n",
      "\n",
      "--------Epoch 82--------\n",
      "Val loss = -0.5970\n",
      "\n",
      "--------Epoch 83--------\n",
      "Val loss = -0.6078\n",
      "\n",
      "--------Epoch 84--------\n",
      "Val loss = -0.5973\n",
      "\n",
      "--------Epoch 85--------\n",
      "Val loss = -0.6001\n",
      "\n",
      "--------Epoch 86--------\n",
      "Val loss = -0.6065\n",
      "\n",
      "--------Epoch 87--------\n",
      "Val loss = -0.6104\n",
      "\n",
      "--------Epoch 88--------\n",
      "Val loss = -0.6076\n",
      "\n",
      "--------Epoch 89--------\n",
      "Val loss = -0.6174\n",
      "\n",
      "--------Epoch 90--------\n",
      "Val loss = -0.6211\n",
      "\n",
      "--------Epoch 91--------\n",
      "Val loss = -0.6237\n",
      "\n",
      "--------Epoch 92--------\n",
      "Val loss = -0.6179\n",
      "\n",
      "--------Epoch 93--------\n",
      "Val loss = -0.6230\n",
      "\n",
      "--------Epoch 94--------\n",
      "Val loss = -0.6272\n",
      "\n",
      "--------Epoch 95--------\n",
      "Val loss = -0.6279\n",
      "\n",
      "--------Epoch 96--------\n",
      "Val loss = -0.6319\n",
      "\n",
      "--------Epoch 97--------\n",
      "Val loss = -0.6308\n",
      "\n",
      "--------Epoch 98--------\n",
      "Val loss = -0.6394\n",
      "\n",
      "--------Epoch 99--------\n",
      "Val loss = -0.6314\n",
      "\n",
      "--------Epoch 100--------\n",
      "Val loss = -0.6386\n",
      "\n",
      "--------Epoch 101--------\n",
      "Val loss = -0.6344\n",
      "\n",
      "--------Epoch 102--------\n",
      "Val loss = -0.6460\n",
      "\n",
      "--------Epoch 103--------\n",
      "Val loss = -0.6413\n",
      "\n",
      "--------Epoch 104--------\n",
      "Val loss = -0.6449\n",
      "\n",
      "--------Epoch 105--------\n",
      "Val loss = -0.6454\n",
      "\n",
      "--------Epoch 106--------\n",
      "Val loss = -0.6518\n",
      "\n",
      "--------Epoch 107--------\n",
      "Val loss = -0.6510\n",
      "\n",
      "--------Epoch 108--------\n",
      "Val loss = -0.6643\n",
      "\n",
      "--------Epoch 109--------\n",
      "Val loss = -0.6637\n",
      "\n",
      "--------Epoch 110--------\n",
      "Val loss = -0.6637\n",
      "\n",
      "--------Epoch 111--------\n",
      "Val loss = -0.6703\n",
      "\n",
      "--------Epoch 112--------\n",
      "Val loss = -0.6676\n",
      "\n",
      "--------Epoch 113--------\n",
      "Val loss = -0.6645\n",
      "\n",
      "--------Epoch 114--------\n",
      "Val loss = -0.6679\n",
      "\n",
      "--------Epoch 115--------\n",
      "Val loss = -0.6744\n",
      "\n",
      "--------Epoch 116--------\n",
      "Val loss = -0.6808\n",
      "\n",
      "--------Epoch 117--------\n",
      "Val loss = -0.6833\n",
      "\n",
      "--------Epoch 118--------\n",
      "Val loss = -0.6864\n",
      "\n",
      "--------Epoch 119--------\n",
      "Val loss = -0.6874\n",
      "\n",
      "--------Epoch 120--------\n",
      "Val loss = -0.6904\n",
      "\n",
      "--------Epoch 121--------\n",
      "Val loss = -0.7023\n",
      "\n",
      "--------Epoch 122--------\n",
      "Val loss = -0.6994\n",
      "\n",
      "--------Epoch 123--------\n",
      "Val loss = -0.7032\n",
      "\n",
      "--------Epoch 124--------\n",
      "Val loss = -0.7023\n",
      "\n",
      "--------Epoch 125--------\n",
      "Val loss = -0.7022\n",
      "\n",
      "--------Epoch 126--------\n",
      "Val loss = -0.7119\n",
      "\n",
      "--------Epoch 127--------\n",
      "Val loss = -0.7087\n",
      "\n",
      "--------Epoch 128--------\n",
      "Val loss = -0.7163\n",
      "\n",
      "--------Epoch 129--------\n",
      "Val loss = -0.7172\n",
      "\n",
      "--------Epoch 130--------\n",
      "Val loss = -0.7254\n",
      "\n",
      "--------Epoch 131--------\n",
      "Val loss = -0.7179\n",
      "\n",
      "--------Epoch 132--------\n",
      "Val loss = -0.7208\n",
      "\n",
      "--------Epoch 133--------\n",
      "Val loss = -0.7241\n",
      "\n",
      "--------Epoch 134--------\n",
      "Val loss = -0.7209\n",
      "\n",
      "--------Epoch 135--------\n",
      "Val loss = -0.7370\n",
      "\n",
      "--------Epoch 136--------\n",
      "Val loss = -0.7313\n",
      "\n",
      "--------Epoch 137--------\n",
      "Val loss = -0.7375\n",
      "\n",
      "--------Epoch 138--------\n",
      "Val loss = -0.7392\n",
      "\n",
      "--------Epoch 139--------\n",
      "Val loss = -0.7361\n",
      "\n",
      "--------Epoch 140--------\n",
      "Val loss = -0.7336\n",
      "\n",
      "--------Epoch 141--------\n",
      "Val loss = -0.7424\n",
      "\n",
      "--------Epoch 142--------\n",
      "Val loss = -0.7446\n",
      "\n",
      "--------Epoch 143--------\n",
      "Val loss = -0.7476\n",
      "\n",
      "--------Epoch 144--------\n",
      "Val loss = -0.7443\n",
      "\n",
      "--------Epoch 145--------\n",
      "Val loss = -0.7492\n",
      "\n",
      "--------Epoch 146--------\n",
      "Val loss = -0.7484\n",
      "\n",
      "--------Epoch 147--------\n",
      "Val loss = -0.7477\n",
      "\n",
      "--------Epoch 148--------\n",
      "Val loss = -0.7531\n",
      "\n",
      "--------Epoch 149--------\n",
      "Val loss = -0.7550\n",
      "\n",
      "--------Epoch 150--------\n",
      "Val loss = -0.7459\n",
      "\n",
      "--------Epoch 151--------\n",
      "Val loss = -0.7557\n",
      "\n",
      "--------Epoch 152--------\n",
      "Val loss = -0.7596\n",
      "\n",
      "--------Epoch 153--------\n",
      "Val loss = -0.7547\n",
      "\n",
      "--------Epoch 154--------\n",
      "Val loss = -0.7558\n",
      "\n",
      "--------Epoch 155--------\n",
      "Val loss = -0.7533\n",
      "\n",
      "--------Epoch 156--------\n",
      "Val loss = -0.7594\n",
      "\n",
      "--------Epoch 157--------\n",
      "Val loss = -0.7555\n",
      "\n",
      "--------Epoch 158--------\n",
      "Val loss = -0.7601\n",
      "\n",
      "--------Epoch 159--------\n",
      "Val loss = -0.7571\n",
      "\n",
      "--------Epoch 160--------\n",
      "Val loss = -0.7569\n",
      "\n",
      "--------Epoch 161--------\n",
      "Val loss = -0.7645\n",
      "\n",
      "--------Epoch 162--------\n",
      "Val loss = -0.7643\n",
      "\n",
      "--------Epoch 163--------\n",
      "Val loss = -0.7651\n",
      "\n",
      "--------Epoch 164--------\n",
      "Val loss = -0.7617\n",
      "\n",
      "--------Epoch 165--------\n",
      "Val loss = -0.7621\n",
      "\n",
      "--------Epoch 166--------\n",
      "Val loss = -0.7577\n",
      "\n",
      "--------Epoch 167--------\n",
      "Val loss = -0.7614\n",
      "\n",
      "--------Epoch 168--------\n",
      "Val loss = -0.7680\n",
      "\n",
      "--------Epoch 169--------\n",
      "Val loss = -0.7664\n",
      "\n",
      "--------Epoch 170--------\n",
      "Val loss = -0.7599\n",
      "\n",
      "--------Epoch 171--------\n",
      "Val loss = -0.7698\n",
      "\n",
      "--------Epoch 172--------\n",
      "Val loss = -0.7665\n",
      "\n",
      "--------Epoch 173--------\n",
      "Val loss = -0.7662\n",
      "\n",
      "--------Epoch 174--------\n",
      "Val loss = -0.7671\n",
      "\n",
      "--------Epoch 175--------\n",
      "Val loss = -0.7698\n",
      "\n",
      "--------Epoch 176--------\n",
      "Val loss = -0.7683\n",
      "\n",
      "--------Epoch 177--------\n",
      "Val loss = -0.7670\n",
      "\n",
      "--------Epoch 178--------\n",
      "Val loss = -0.7708\n",
      "\n",
      "--------Epoch 179--------\n",
      "Val loss = -0.7715\n",
      "\n",
      "--------Epoch 180--------\n",
      "Val loss = -0.7667\n",
      "\n",
      "--------Epoch 181--------\n",
      "Val loss = -0.7700\n",
      "\n",
      "--------Epoch 182--------\n",
      "Val loss = -0.7649\n",
      "\n",
      "--------Epoch 183--------\n",
      "Val loss = -0.7678\n",
      "\n",
      "--------Epoch 184--------\n",
      "Val loss = -0.7726\n",
      "\n",
      "--------Epoch 185--------\n",
      "Val loss = -0.7715\n",
      "\n",
      "--------Epoch 186--------\n",
      "Val loss = -0.7687\n",
      "\n",
      "--------Epoch 187--------\n",
      "Val loss = -0.7702\n",
      "\n",
      "--------Epoch 188--------\n",
      "Val loss = -0.7714\n",
      "\n",
      "--------Epoch 189--------\n",
      "Val loss = -0.7716\n",
      "\n",
      "--------Epoch 190--------\n",
      "Val loss = -0.7752\n",
      "\n",
      "--------Epoch 191--------\n",
      "Val loss = -0.7674\n",
      "\n",
      "--------Epoch 192--------\n",
      "Val loss = -0.7666\n",
      "\n",
      "--------Epoch 193--------\n",
      "Val loss = -0.7711\n",
      "\n",
      "--------Epoch 194--------\n",
      "Val loss = -0.7697\n",
      "\n",
      "--------Epoch 195--------\n",
      "Val loss = -0.7713\n",
      "\n",
      "--------Epoch 196--------\n",
      "Val loss = -0.7722\n",
      "\n",
      "--------Epoch 197--------\n",
      "Val loss = -0.7679\n",
      "\n",
      "--------Epoch 198--------\n",
      "Val loss = -0.7727\n",
      "\n",
      "--------Epoch 199--------\n",
      "Val loss = -0.7715\n",
      "\n",
      "--------Epoch 200--------\n",
      "Val loss = -0.7702\n",
      "\n",
      "--------Epoch 201--------\n",
      "Val loss = -0.7697\n",
      "\n",
      "--------Epoch 202--------\n",
      "Val loss = -0.7641\n",
      "\n",
      "--------Epoch 203--------\n",
      "Val loss = -0.7673\n",
      "\n",
      "--------Epoch 204--------\n",
      "Val loss = -0.7722\n",
      "\n",
      "--------Epoch 205--------\n",
      "Val loss = -0.7680\n",
      "\n",
      "--------Epoch 206--------\n",
      "Val loss = -0.7697\n",
      "\n",
      "--------Epoch 207--------\n",
      "Val loss = -0.7655\n",
      "\n",
      "--------Epoch 208--------\n",
      "Val loss = -0.7690\n",
      "\n",
      "--------Epoch 209--------\n",
      "Val loss = -0.7661\n",
      "\n",
      "--------Epoch 210--------\n",
      "Val loss = -0.7697\n",
      "\n",
      "--------Epoch 211--------\n",
      "Val loss = -0.7680\n",
      "\n",
      "--------Epoch 212--------\n",
      "Val loss = -0.7696\n",
      "\n",
      "--------Epoch 213--------\n",
      "Val loss = -0.7702\n",
      "\n",
      "--------Epoch 214--------\n",
      "Val loss = -0.7702\n",
      "\n",
      "--------Epoch 215--------\n",
      "Val loss = -0.7649\n",
      "\n",
      "--------Epoch 216--------\n",
      "Val loss = -0.7694\n",
      "\n",
      "--------Epoch 217--------\n",
      "Val loss = -0.7644\n",
      "\n",
      "--------Epoch 218--------\n",
      "Val loss = -0.7637\n",
      "\n",
      "--------Epoch 219--------\n",
      "Val loss = -0.7652\n",
      "\n",
      "--------Epoch 220--------\n",
      "Val loss = -0.7619\n",
      "\n",
      "--------Epoch 221--------\n",
      "Val loss = -0.7651\n",
      "\n",
      "--------Epoch 222--------\n",
      "Val loss = -0.7662\n",
      "\n",
      "--------Epoch 223--------\n",
      "Val loss = -0.7618\n",
      "\n",
      "--------Epoch 224--------\n",
      "Val loss = -0.7597\n",
      "\n",
      "--------Epoch 225--------\n",
      "Val loss = -0.7637\n",
      "\n",
      "--------Epoch 226--------\n",
      "Val loss = -0.7593\n",
      "\n",
      "--------Epoch 227--------\n",
      "Val loss = -0.7636\n",
      "\n",
      "--------Epoch 228--------\n",
      "Val loss = -0.7595\n",
      "\n",
      "--------Epoch 229--------\n",
      "Val loss = -0.7605\n",
      "\n",
      "--------Epoch 230--------\n",
      "Val loss = -0.7624\n",
      "\n",
      "--------Epoch 231--------\n",
      "Val loss = -0.7619\n",
      "\n",
      "--------Epoch 232--------\n",
      "Val loss = -0.7577\n",
      "\n",
      "--------Epoch 233--------\n",
      "Val loss = -0.7576\n",
      "\n",
      "--------Epoch 234--------\n",
      "Val loss = -0.7611\n",
      "\n",
      "--------Epoch 235--------\n",
      "Val loss = -0.7623\n",
      "\n",
      "--------Epoch 236--------\n",
      "Val loss = -0.7595\n",
      "\n",
      "--------Epoch 237--------\n",
      "Val loss = -0.7576\n",
      "\n",
      "--------Epoch 238--------\n",
      "Val loss = -0.7573\n",
      "\n",
      "--------Epoch 239--------\n",
      "Val loss = -0.7533\n",
      "\n",
      "--------Epoch 240--------\n",
      "Val loss = -0.7593\n",
      "\n",
      "--------Epoch 241--------\n",
      "Val loss = -0.7569\n",
      "\n",
      "--------Epoch 242--------\n",
      "Val loss = -0.7515\n",
      "\n",
      "--------Epoch 243--------\n",
      "Val loss = -0.7569\n",
      "\n",
      "--------Epoch 244--------\n",
      "Val loss = -0.7563\n",
      "\n",
      "--------Epoch 245--------\n",
      "Val loss = -0.7564\n",
      "\n",
      "--------Epoch 246--------\n",
      "Val loss = -0.7576\n",
      "\n",
      "--------Epoch 247--------\n",
      "Val loss = -0.7576\n",
      "\n",
      "--------Epoch 248--------\n",
      "Val loss = -0.7483\n",
      "\n",
      "--------Epoch 249--------\n",
      "Val loss = -0.7524\n",
      "\n",
      "--------Epoch 250--------\n",
      "Val loss = -0.7518\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Classifier input is 2 * d_in: with append=True the selector presumably concatenates\n",
    "# the binary mask to the masked image (TODO confirm in models.py).\n",
    "hidden_dim = 512\n",
    "global_model = nn.Sequential(\n",
    "    nn.Linear(2 * d_in, hidden_dim), nn.ReLU(),\n",
    "    nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),\n",
    "    nn.Linear(hidden_dim, d_out),\n",
    ")\n",
    "selector = ConcreteMask(d_in, max_features, append=True)\n",
    "globalfs = GlobalSelector(global_model, selector).to(device)\n",
    "\n",
    "# Validation is scored with NegAccuracy, so a lower (more negative) val loss means\n",
    "# higher accuracy. The Concrete temperature runs from start_temp to end_temp\n",
    "# (presumably annealed over the epochs).\n",
    "fit_kwargs = dict(\n",
    "    mbsize=128,\n",
    "    lr=1e-3,\n",
    "    nepochs=250,\n",
    "    loss_fn=nn.CrossEntropyLoss(),\n",
    "    val_loss_fn=NegAccuracy(),\n",
    "    start_temp=1.0,\n",
    "    end_temp=0.01,\n",
    ")\n",
    "globalfs.fit(train_dataset, val_dataset, **fit_kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "acting-bleeding",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test acc = 77.86\n"
     ]
    }
   ],
   "source": [
    "# Accuracy on the held-out MNIST test split; the third argument is presumably the\n",
    "# evaluation minibatch size.\n",
    "eval_mbsize = 1024\n",
    "test_acc = globalfs.evaluate(test_dataset, Accuracy(), eval_mbsize)\n",
    "print('Test acc = {:.2f}'.format(100 * test_acc))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "danish-lodging",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAK3klEQVR4nO3dT6hc93mH8edbV1ZASUCqa6M6bpMGL2oKVcpFKbgUF9PU8cbOIiVeBBUCyiKGBLKoSRfx0pQmoYsSUGoRtaQOgcS1FqaJEAGTjfG1UW25amvXqIkiITVoYadQWXbeLu5xuZHvP8+c+SO/zwcuM3Nm7p2XQY/O3Dkz95eqQtK7368segBJ82HsUhPGLjVh7FITxi418avzvLMbs7vew5553qXUyv/yP7xeV7LRdVPFnuQe4G+AG4C/q6pHtrr9e9jDR3P3NHcpaQtP18lNr5v4aXySG4C/BT4O3AE8kOSOSX+epNma5nf2g8DLVfVKVb0OfBu4b5yxJI1tmthvBX6y7vK5YdsvSXI4yWqS1atcmeLuJE1jmtg3ehHgbe+9raojVbVSVSu72D3F3UmaxjSxnwNuW3f5A8D56caRNCvTxP4McHuSDyW5EfgUcHycsSSNbeJDb1X1RpIHge+zdujtaFW9ONpkkkY11XH2qnoSeHKkWSTNkG+XlZowdqkJY5eaMHapCWOXmjB2qQljl5owdqkJY5eaMHapCWOXmjB2qQljl5owdqkJY5eaMHapCWOXmjB2qQljl5owdqkJY5eaMHapCWOXmjB2qQljl5owdqkJY5eaMHapCWOXmjB2qYmplmxOchZ4DXgTeKOqVsYYStL4pop98MdV9bMRfo6kGfJpvNTEtLEX8IMkzyY5vNENkhxOsppk9SpXprw7SZOa9mn8nVV1PsnNwIkk/1ZVT62/QVUdAY4AvD/7asr7kzShqfbsVXV+OL0EPA4cHGMoSeObOPYke5K8763zwMeA02MNJmlc0zyNvwV4PMlbP+cfq+qfR5lK7xrfP39q0+v+9DcOzG0OTRF7Vb0C/N6Is0iaIQ+9SU0Yu9SEsUtNGLvUhLFLTYzxQRhpUx5eWx7u2aUmjF1qwtilJoxdasLYpSaMXWrC2KUmjF1qwtilJoxdasLYpSaMXWrC2KUmjF1qwtilJvw8u7a01Z+CBj+vfj1xzy41YexSE8YuNWHsUhPGLjVh7FITxi414XF2bWna4+izXLLZ9wC8M9vu2ZMcTXIpyel12/YlOZHkpeF072zHlDStnTyN/yZwzzXbHgJOVtXtwMnhsqQltm3sVfUUcPmazfcBx4bzx4D7xx1L0tgmfYHulqq6ADCc3rzZDZMcTrKaZPUqVya8O0nTmvmr8VV1pKpWqmplF7tnfXeSNjFp7BeT7AcYTi+NN5KkWZg09uPAoeH8IeCJccaRNCupqq1vkDwG3AXcBFwEvgz8E/Ad4DeBHwOfrKprX8R7m/dnX300d083cUMeT9ZOPV0nebUuZ6Prtn1TTVU9sMlVVitdR3y7rNSEsUtNGLvUhLFLTRi71IQfcb0OeGhNY3DPLjVh7FITxi41YexSE8YuNWHsUhPGLjVh7FITxi41YexSE8YuNWHsUhPGLjVh7FITxi414efZtaVZ/hlr/0T2fLlnl5owdqkJY5eaMHapCWOXmjB2qQljl5rwOLu2NMtj3R5Hn69t9+xJjia5lOT0um0PJ/lpklPD172zHVPStHbyNP6bwD0bbP9aVR0Yvp4cdyxJY9s29qp6Crg8h1kkzdA0L9A9mOT54Wn+3s1ulORwktUkq1e5MsXdSZrGpLF/HfgwcAC4AHxlsxtW1ZGqWqmqlV3snvDuJE1rotir6mJVvVlVvwC+ARwcdyxJY5so9iT71138BHB6s9tKWg7bHmdP8hhwF3BTknPAl4G7khwACjgLfHZ2I0oaw7axV9UDG2x+dAazSJoh3y4rNWHsUhPGLjVh7FIT
xi414Udcm/PPOffhnl1qwtilJoxdasLYpSaMXWrC2KUmjF1qwuPszXkcvQ/37FITxi41YexSE8YuNWHsUhPGLjVh7FITHme/DviZc43BPbvUhLFLTRi71ISxS00Yu9SEsUtNGLvUhMfZrwMeR9cYtt2zJ7ktyQ+TnEnyYpLPD9v3JTmR5KXhdO/sx5U0qZ08jX8D+GJV/Q7wB8DnktwBPAScrKrbgZPDZUlLatvYq+pCVT03nH8NOAPcCtwHHBtudgy4f0YzShrBO3qBLskHgY8ATwO3VNUFWPsPAbh5k+85nGQ1yepVrkw5rqRJ7Tj2JO8Fvgt8oape3en3VdWRqlqpqpVd7J5kRkkj2FHsSXaxFvq3qup7w+aLSfYP1+8HLs1mRElj2PbQW5IAjwJnquqr6646DhwCHhlOn5jJhO8CfkRVy2Anx9nvBD4NvJDk1LDtS6xF/p0knwF+DHxyJhNKGsW2sVfVj4BscvXd444jaVZ8u6zUhLFLTRi71ISxS00Yu9SEH3GdA4+jaxm4Z5eaMHapCWOXmjB2qQljl5owdqkJY5eaMHapCWOXmjB2qQljl5owdqkJY5eaMHapCWOXmjB2qQljl5owdqkJY5eaMHapCWOXmjB2qQljl5rYNvYktyX5YZIzSV5M8vlh+8NJfprk1PB17+zHlTSpnSwS8Qbwxap6Lsn7gGeTnBiu+1pV/fXsxpM0lp2sz34BuDCcfy3JGeDWWQ8maVzv6Hf2JB8EPgI8PWx6MMnzSY4m2bvJ9xxOsppk9SpXpptW0sR2HHuS9wLfBb5QVa8CXwc+DBxgbc//lY2+r6qOVNVKVa3sYvf0E0uayI5iT7KLtdC/VVXfA6iqi1X1ZlX9AvgGcHB2Y0qa1k5ejQ/wKHCmqr66bvv+dTf7BHB6/PEkjWUnr8bfCXwaeCHJqWHbl4AHkhwACjgLfHYG80kayU5ejf8RkA2uenL8cSTNiu+gk5owdqkJY5eaMHapCWOXmjB2qQljl5owdqkJY5eaMHapCWOXmjB2qQljl5owdqmJVNX87iz5b+C/1m26CfjZ3AZ4Z5Z1tmWdC5xtUmPO9ltV9esbXTHX2N9258lqVa0sbIAtLOtsyzoXONuk5jWbT+OlJoxdamLRsR9Z8P1vZVlnW9a5wNkmNZfZFvo7u6T5WfSeXdKcGLvUxEJiT3JPkn9P8nKShxYxw2aSnE3ywrAM9eqCZzma5FKS0+u27UtyIslLw+mGa+wtaLalWMZ7i2XGF/rYLXr587n/zp7kBuA/gD8BzgHPAA9U1b/OdZBNJDkLrFTVwt+AkeSPgJ8Df19Vvzts+yvgclU9MvxHubeq/mJJZnsY+Pmil/EeVivav36ZceB+4M9Z4GO3xVx/xhwet0Xs2Q8CL1fVK1X1OvBt4L4FzLH0quop4PI1m+8Djg3nj7H2j2XuNpltKVTVhap6bjj/GvDWMuMLfey2mGsuFhH7rcBP1l0+x3Kt917AD5I8m+TwoofZwC1VdQHW/vEANy94nmttu4z3PF2zzPjSPHaTLH8+rUXEvtFSUst0/O/Oqvp94OPA54anq9qZHS3jPS8bLDO+FCZd/nxai4j9HHDbussfAM4vYI4NVdX54fQS8DjLtxT1xbdW0B1OLy14nv+3TMt4b7TMOEvw2C1y+fNFxP4McHuSDyW5EfgUcHwBc7xNkj3DCyck2QN8jOVbivo4cGg4fwh4YoGz/JJlWcZ7s2XGWfBjt/Dlz6tq7l/Avay9Iv+fwF8uYoZN5vpt4F+GrxcXPRvwGGtP666y9ozoM8CvASeBl4bTfUs02z8ALwDPsxbW/gXN9oes/Wr4PHBq+Lp30Y/dFnPN5XHz7bJSE76DTmrC2KUmjF1qwtilJoxdasLYpSaMXWri/wCKFG0rCtAxvgAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Take the highest-logit input index along dim 1 of the selector's logits and mark\n",
    "# those pixels on a 28x28 grid (MNIST image size).\n",
    "# NOTE(review): assumes selector.logits is (max_features, d_in); duplicate picks\n",
    "# collapse onto one pixel, which the title reports.\n",
    "inds = selector.logits.argmax(dim=1).detach().cpu().numpy()\n",
    "selection_mask = np.zeros(d_in)\n",
    "selection_mask[inds] = 1\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.imshow(selection_mask.reshape(28, 28))\n",
    "ax.set_title(f'Selected pixels ({len(np.unique(inds))} unique of {len(inds)})')\n",
    "ax.set_axis_off()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "characteristic-drive",
   "metadata": {},
   "source": [
    "# Pretrain model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "black-breeding",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--------Epoch 1--------\n",
      "Val loss = -0.2443\n",
      "\n",
      "--------Epoch 2--------\n",
      "Val loss = -0.2638\n",
      "\n",
      "--------Epoch 3--------\n",
      "Val loss = -0.2896\n",
      "\n",
      "--------Epoch 4--------\n",
      "Val loss = -0.2932\n",
      "\n",
      "--------Epoch 5--------\n",
      "Val loss = -0.2976\n",
      "\n",
      "--------Epoch 6--------\n",
      "Val loss = -0.2992\n",
      "\n",
      "--------Epoch 7--------\n",
      "Val loss = -0.2917\n",
      "\n",
      "--------Epoch 8--------\n",
      "Val loss = -0.3107\n",
      "\n",
      "--------Epoch 9--------\n",
      "Val loss = -0.3035\n",
      "\n",
      "--------Epoch 10--------\n",
      "Val loss = -0.3000\n",
      "\n",
      "--------Epoch 11--------\n",
      "Val loss = -0.3024\n",
      "\n",
      "--------Epoch 12--------\n",
      "Val loss = -0.3009\n",
      "\n",
      "--------Epoch 13--------\n",
      "Val loss = -0.3092\n",
      "\n",
      "--------Epoch 14--------\n",
      "Val loss = -0.3108\n",
      "\n",
      "--------Epoch 15--------\n",
      "Val loss = -0.3161\n",
      "\n",
      "--------Epoch 16--------\n",
      "Val loss = -0.3097\n",
      "\n",
      "--------Epoch 17--------\n",
      "Val loss = -0.3183\n",
      "\n",
      "--------Epoch 18--------\n",
      "Val loss = -0.3084\n",
      "\n",
      "--------Epoch 19--------\n",
      "Val loss = -0.3117\n",
      "\n",
      "--------Epoch 20--------\n",
      "Val loss = -0.3147\n",
      "\n",
      "--------Epoch 21--------\n",
      "Val loss = -0.3257\n",
      "\n",
      "--------Epoch 22--------\n",
      "Val loss = -0.3161\n",
      "\n",
      "--------Epoch 23--------\n",
      "Val loss = -0.3084\n",
      "\n",
      "--------Epoch 24--------\n",
      "Val loss = -0.3152\n",
      "\n",
      "--------Epoch 25--------\n",
      "Val loss = -0.3179\n",
      "\n",
      "--------Epoch 26--------\n",
      "Val loss = -0.3221\n",
      "\n",
      "--------Epoch 27--------\n",
      "Val loss = -0.3196\n",
      "\n",
      "--------Epoch 28--------\n",
      "Val loss = -0.3166\n",
      "\n",
      "--------Epoch 29--------\n",
      "Val loss = -0.3155\n",
      "\n",
      "--------Epoch 30--------\n",
      "Val loss = -0.3269\n",
      "\n",
      "--------Epoch 31--------\n",
      "Val loss = -0.3190\n",
      "\n",
      "--------Epoch 32--------\n",
      "Val loss = -0.3212\n",
      "\n",
      "--------Epoch 33--------\n",
      "Val loss = -0.3230\n",
      "\n",
      "--------Epoch 34--------\n",
      "Val loss = -0.3221\n",
      "\n",
      "--------Epoch 35--------\n",
      "Val loss = -0.3201\n",
      "\n",
      "--------Epoch 36--------\n",
      "Val loss = -0.3215\n",
      "\n",
      "--------Epoch 37--------\n",
      "Val loss = -0.3295\n",
      "\n",
      "--------Epoch 38--------\n",
      "Val loss = -0.3303\n",
      "\n",
      "--------Epoch 39--------\n",
      "Val loss = -0.3295\n",
      "\n",
      "--------Epoch 40--------\n",
      "Val loss = -0.3326\n",
      "\n",
      "--------Epoch 41--------\n",
      "Val loss = -0.3283\n",
      "\n",
      "--------Epoch 42--------\n",
      "Val loss = -0.3141\n",
      "\n",
      "--------Epoch 43--------\n",
      "Val loss = -0.3294\n",
      "\n",
      "--------Epoch 44--------\n",
      "Val loss = -0.3270\n",
      "\n",
      "--------Epoch 45--------\n",
      "Val loss = -0.3276\n",
      "\n",
      "--------Epoch 46--------\n",
      "Val loss = -0.3229\n",
      "\n",
      "--------Epoch 47--------\n",
      "Val loss = -0.3288\n",
      "\n",
      "--------Epoch 48--------\n",
      "Val loss = -0.3282\n",
      "\n",
      "--------Epoch 49--------\n",
      "Val loss = -0.3230\n",
      "\n",
      "--------Epoch 50--------\n",
      "Val loss = -0.3301\n",
      "\n",
      "--------Epoch 51--------\n",
      "Val loss = -0.3291\n",
      "\n",
      "--------Epoch 52--------\n",
      "Val loss = -0.3298\n",
      "\n",
      "--------Epoch 53--------\n",
      "Val loss = -0.3308\n",
      "\n",
      "--------Epoch 54--------\n",
      "Val loss = -0.3298\n",
      "\n",
      "--------Epoch 55--------\n",
      "Val loss = -0.3220\n",
      "\n",
      "--------Epoch 56--------\n",
      "Val loss = -0.3312\n",
      "\n",
      "--------Epoch 57--------\n",
      "Val loss = -0.3368\n",
      "\n",
      "--------Epoch 58--------\n",
      "Val loss = -0.3317\n",
      "\n",
      "--------Epoch 59--------\n",
      "Val loss = -0.3343\n",
      "\n",
      "--------Epoch 60--------\n",
      "Val loss = -0.3385\n",
      "\n",
      "--------Epoch 61--------\n",
      "Val loss = -0.3314\n",
      "\n",
      "--------Epoch 62--------\n",
      "Val loss = -0.3333\n",
      "\n",
      "--------Epoch 63--------\n",
      "Val loss = -0.3317\n",
      "\n",
      "--------Epoch 64--------\n",
      "Val loss = -0.3361\n",
      "\n",
      "--------Epoch 65--------\n",
      "Val loss = -0.3265\n",
      "\n",
      "--------Epoch 66--------\n",
      "Val loss = -0.3381\n",
      "\n",
      "--------Epoch 67--------\n",
      "Val loss = -0.3282\n",
      "\n",
      "--------Epoch 68--------\n",
      "Val loss = -0.3337\n",
      "\n",
      "--------Epoch 69--------\n",
      "Val loss = -0.3435\n",
      "\n",
      "--------Epoch 70--------\n",
      "Val loss = -0.3264\n",
      "\n",
      "--------Epoch 71--------\n",
      "Val loss = -0.3379\n",
      "\n",
      "--------Epoch 72--------\n",
      "Val loss = -0.3409\n",
      "\n",
      "--------Epoch 73--------\n",
      "Val loss = -0.3366\n",
      "\n",
      "--------Epoch 74--------\n",
      "Val loss = -0.3264\n",
      "\n",
      "--------Epoch 75--------\n",
      "Val loss = -0.3398\n",
      "\n",
      "--------Epoch 76--------\n",
      "Val loss = -0.3366\n",
      "\n",
      "--------Epoch 77--------\n",
      "Val loss = -0.3378\n",
      "\n",
      "--------Epoch 78--------\n",
      "Val loss = -0.3321\n",
      "\n",
      "--------Epoch 79--------\n",
      "Val loss = -0.3332\n",
      "\n",
      "--------Epoch 80--------\n",
      "Val loss = -0.3377\n",
      "\n",
      "--------Epoch 81--------\n",
      "Val loss = -0.3373\n",
      "\n",
      "--------Epoch 82--------\n",
      "Val loss = -0.3381\n",
      "\n",
      "--------Epoch 83--------\n",
      "Val loss = -0.3405\n",
      "\n",
      "--------Epoch 84--------\n",
      "Val loss = -0.3355\n",
      "\n",
      "--------Epoch 85--------\n",
      "Val loss = -0.3273\n",
      "\n",
      "--------Epoch 86--------\n",
      "Val loss = -0.3373\n",
      "\n",
      "--------Epoch 87--------\n",
      "Val loss = -0.3271\n",
      "\n",
      "--------Epoch 88--------\n",
      "Val loss = -0.3319\n",
      "\n",
      "--------Epoch 89--------\n",
      "Val loss = -0.3408\n",
      "\n",
      "--------Epoch 90--------\n",
      "Val loss = -0.3422\n",
      "\n",
      "--------Epoch 91--------\n",
      "Val loss = -0.3363\n",
      "\n",
      "--------Epoch 92--------\n",
      "Val loss = -0.3379\n",
      "\n",
      "--------Epoch 93--------\n",
      "Val loss = -0.3403\n",
      "\n",
      "--------Epoch 94--------\n",
      "Val loss = -0.3455\n",
      "\n",
      "--------Epoch 95--------\n",
      "Val loss = -0.3331\n",
      "\n",
      "--------Epoch 96--------\n",
      "Val loss = -0.3380\n",
      "\n",
      "--------Epoch 97--------\n",
      "Val loss = -0.3491\n",
      "\n",
      "--------Epoch 98--------\n",
      "Val loss = -0.3372\n",
      "\n",
      "--------Epoch 99--------\n",
      "Val loss = -0.3397\n",
      "\n",
      "--------Epoch 100--------\n",
      "Val loss = -0.3395\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Predictor with two hidden layers of width 512. Its input width is\n",
    "# d_in * 2, presumably because MaskLayer(append=True) concatenates the\n",
    "# mask onto the masked features (confirm in models.MaskLayer).\n",
    "model = nn.Sequential(\n",
    "    nn.Linear(d_in * 2, 512), nn.ReLU(),\n",
    "    nn.Linear(512, 512), nn.ReLU(),\n",
    "    nn.Linear(512, d_out),\n",
    ")\n",
    "mask_layer = MaskLayer(append=True)\n",
    "pretrain = Pretrainer(model, mask_layer).to(device)\n",
    "\n",
    "# Pretrain. NegAccuracy presumably scores validation as negative accuracy,\n",
    "# consistent with the negative validation losses printed during training.\n",
    "pretrain.fit(\n",
    "    train_dataset,\n",
    "    val_dataset,\n",
    "    mbsize=128,\n",
    "    lr=1e-3,\n",
    "    nepochs=100,\n",
    "    max_features=max_features,\n",
    "    loss_fn=nn.CrossEntropyLoss(),\n",
    "    val_loss_fn=NegAccuracy(),\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "polish-statement",
   "metadata": {},
   "source": [
    "# REINFORCE training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "boxed-prague",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.distributions import Categorical\n",
    "from torch.utils.data import DataLoader\n",
    "from utils import restore_parameters\n",
    "from copy import deepcopy\n",
    "\n",
    "\n",
    "def make_onehot(x):\n",
    "    '''\n",
    "    Turn each row of a 2-D score tensor into an exact one-hot vector.\n",
    "\n",
    "    The 1 is placed at the row's argmax; the result has the same shape,\n",
    "    dtype and device as x.\n",
    "    '''\n",
    "    winners = torch.argmax(x, dim=1, keepdim=True)\n",
    "    return torch.zeros_like(x).scatter_(1, winners, 1)\n",
    "\n",
    "\n",
    "def logsoftmax(x, dim):\n",
    "    '''\n",
    "    Log-softmax of x along dimension dim.\n",
    "\n",
    "    Delegates to torch.log_softmax, PyTorch's built-in implementation,\n",
    "    which is documented as numerically stable (mathematically equal to\n",
    "    x - max - log(sum(exp(x - max))) along dim).\n",
    "    '''\n",
    "    return torch.log_softmax(x, dim=dim)\n",
    "\n",
    "\n",
    "class DiscreteSampler(nn.Module):\n",
    "    '''\n",
    "    Sample one category per row from unnormalized logits.\n",
    "\n",
    "    forward(logits) returns (inds, onehot): the sampled index for each row,\n",
    "    and a tensor shaped like logits that is 1 at those indices and 0\n",
    "    elsewhere (same dtype and device as logits).\n",
    "    '''\n",
    "\n",
    "    def forward(self, logits):\n",
    "        picks = Categorical(logits=logits).sample()\n",
    "        encoded = torch.zeros_like(logits)\n",
    "        encoded[torch.arange(logits.shape[0]), picks] = 1\n",
    "        return picks, encoded\n",
    "\n",
    "\n",
    "class GreedyAdaptiveFS(nn.Module):\n",
    "    '''\n",
    "    Greedy adaptive feature selection trained with REINFORCE.\n",
    "\n",
    "    A selector network scores every feature given the currently observed\n",
    "    (masked) input, one feature is chosen per step, and a predictor network\n",
    "    makes predictions from the features observed so far.\n",
    "\n",
    "    Args:\n",
    "      selector: network mapping the masked input to one logit per feature.\n",
    "      model: predictor network applied to the masked input.\n",
    "      mask_layer: module called as mask_layer(x, mask) with a 0/1 mask.\n",
    "    '''\n",
    "\n",
    "    def __init__(self, selector, model, mask_layer):\n",
    "        super().__init__()\n",
    "        self.selector = selector\n",
    "        self.model = model\n",
    "        self.mask_layer = mask_layer\n",
    "        self.sampler = DiscreteSampler()\n",
    "\n",
    "    def _select_next_feature(self, logits, m_hard, argmax, no_repeats):\n",
    "        '''\n",
    "        Choose one new feature per row at evaluation time.\n",
    "\n",
    "        Returns a one-hot tensor shaped like logits marking the choice.\n",
    "        '''\n",
    "        # Push already-observed features far below any realistic logit.\n",
    "        if no_repeats:\n",
    "            logits = logits - 1e6 * m_hard\n",
    "        if argmax:\n",
    "            return make_onehot(logits)\n",
    "        _, hard = self.sampler(logits)\n",
    "        return hard\n",
    "\n",
    "    def _validation_loss(self, loader, max_features, loss_fn, argmax,\n",
    "                         no_repeats, validation_mode, device):\n",
    "        '''\n",
    "        Compute the sample-weighted mean validation loss over loader.\n",
    "\n",
    "        With validation_mode='final' each batch is scored after its last\n",
    "        selection; with 'mean' its loss is averaged over all selection steps.\n",
    "        loss_fn must reduce each batch to a scalar.\n",
    "        '''\n",
    "        val_loss = 0\n",
    "        n = 0\n",
    "        with torch.no_grad():\n",
    "            for x, y in loader:\n",
    "                # Move to device.\n",
    "                x = x.to(device)\n",
    "                y = y.to(device)\n",
    "\n",
    "                # Select features one at a time, scoring after every step.\n",
    "                m_hard = torch.zeros(x.shape, dtype=x.dtype, device=device)\n",
    "                step_losses = []\n",
    "                for _ in range(max_features):\n",
    "                    logits = self.selector(self.mask_layer(x, m_hard))\n",
    "                    hard = self._select_next_feature(\n",
    "                        logits, m_hard, argmax, no_repeats)\n",
    "                    m_hard = torch.max(m_hard, hard)\n",
    "                    pred = self.model(self.mask_layer(x, m_hard))\n",
    "                    step_losses.append(loss_fn(pred, y).item())\n",
    "\n",
    "                # Update the running mean, weighting each batch by its size.\n",
    "                if validation_mode == 'final':\n",
    "                    batch_loss = step_losses[-1]\n",
    "                else:\n",
    "                    batch_loss = sum(step_losses) / max_features\n",
    "                val_loss = (batch_loss * len(x) + val_loss * n) / (len(x) + n)\n",
    "                n += len(x)\n",
    "        return val_loss\n",
    "\n",
    "    def fit(self,\n",
    "            train,\n",
    "            val,\n",
    "            mbsize,\n",
    "            lr,\n",
    "            nepochs,\n",
    "            max_features,\n",
    "            loss_fn,\n",
    "            val_loss_fn=None,\n",
    "            train_model=True,\n",
    "            train_selector=True,\n",
    "            argmax=True,\n",
    "            no_repeats=True,\n",
    "            validation_mode='final',\n",
    "            verbose=True):\n",
    "        '''\n",
    "        Train the selector and/or predictor with REINFORCE.\n",
    "\n",
    "        For each minibatch, max_features features are sampled one at a time.\n",
    "        After each selection the predictor is evaluated, and its per-sample\n",
    "        loss is used both to train the predictor and, detached, as the cost\n",
    "        weighting the log-probability of the sampled feature. The validation\n",
    "        loss is computed after every epoch, and the parameters with the\n",
    "        lowest validation loss are restored at the end.\n",
    "\n",
    "        Args:\n",
    "          train: training dataset yielding (x, y) pairs.\n",
    "          val: validation dataset yielding (x, y) pairs.\n",
    "          mbsize: minibatch size.\n",
    "          lr: Adam learning rate (shared by both optimizers).\n",
    "          nepochs: number of training epochs.\n",
    "          max_features: number of features selected per example (>= 1).\n",
    "          loss_fn: training loss. A copy with reduction='none' is used, so\n",
    "            the caller's object is never modified.\n",
    "          val_loss_fn: validation loss; defaults to a copy of loss_fn. If it\n",
    "            exposes a reduction other than 'mean', a copy with\n",
    "            reduction='mean' is used instead.\n",
    "          train_model: whether to update the predictor.\n",
    "          train_selector: whether to update the selector.\n",
    "          argmax: at validation, select by argmax instead of sampling.\n",
    "          no_repeats: at validation, forbid re-selecting observed features.\n",
    "          validation_mode: 'final' (loss after the last selection) or\n",
    "            'mean' (loss averaged over all selection steps).\n",
    "          verbose: whether to print the validation loss after each epoch.\n",
    "\n",
    "        Raises:\n",
    "          ValueError: if max_features < 1.\n",
    "        '''\n",
    "        assert validation_mode in ('final', 'mean')\n",
    "        assert train_model or train_selector\n",
    "        if max_features < 1:\n",
    "            raise ValueError('max_features must be at least 1')\n",
    "\n",
    "        # Set up data loaders.\n",
    "        train_loader = DataLoader(\n",
    "            train, batch_size=mbsize, shuffle=True, pin_memory=True,\n",
    "            drop_last=True, num_workers=4)\n",
    "        val_loader = DataLoader(\n",
    "            val, batch_size=mbsize, shuffle=False, pin_memory=True,\n",
    "            drop_last=False, num_workers=4)\n",
    "\n",
    "        # More setup.\n",
    "        selector = self.selector\n",
    "        model = self.model\n",
    "        mask_layer = self.mask_layer\n",
    "        sampler = self.sampler\n",
    "        device = next(model.parameters()).device\n",
    "        if train_model:\n",
    "            model_opt = optim.Adam(model.parameters(), lr=lr)\n",
    "        if train_selector:\n",
    "            selector_opt = optim.Adam(selector.parameters(), lr=lr)\n",
    "\n",
    "        # Configure loss reductions on copies so the caller's objects stay\n",
    "        # untouched: training needs per-sample losses (each sample's loss\n",
    "        # weights its own log-probability), validation needs a batch mean.\n",
    "        if val_loss_fn is None:\n",
    "            val_loss_fn = deepcopy(loss_fn)\n",
    "            val_loss_fn.reduction = 'mean'\n",
    "        elif getattr(val_loss_fn, 'reduction', 'mean') != 'mean':\n",
    "            val_loss_fn = deepcopy(val_loss_fn)\n",
    "            val_loss_fn.reduction = 'mean'\n",
    "        loss_fn = deepcopy(loss_fn)\n",
    "        loss_fn.reduction = 'none'\n",
    "\n",
    "        # For tracking best model.\n",
    "        best_model = None\n",
    "        best_selector = None\n",
    "        best_loss = float('inf')\n",
    "\n",
    "        for epoch in range(nepochs):\n",
    "            for x, y in train_loader:\n",
    "                # Move to device.\n",
    "                x = x.to(device)\n",
    "                y = y.to(device)\n",
    "\n",
    "                # Setup.\n",
    "                m_hard = torch.zeros(x.shape, dtype=x.dtype, device=device)\n",
    "                total_loss = 0\n",
    "\n",
    "                for _ in range(max_features):\n",
    "                    # Evaluate selector model.\n",
    "                    x_masked = mask_layer(x, m_hard)\n",
    "                    logits = selector(x_masked)\n",
    "\n",
    "                    # Take actions. NOTE: no_repeats is not applied during\n",
    "                    # training, so an already-observed feature may be sampled.\n",
    "                    inds, hard = sampler(logits)\n",
    "                    m_hard = torch.max(m_hard, hard)\n",
    "\n",
    "                    # Evaluate model.\n",
    "                    x_masked = mask_layer(x, m_hard)\n",
    "                    pred = model(x_masked)\n",
    "\n",
    "                    # Predictor loss plus the REINFORCE surrogate: detached\n",
    "                    # per-sample loss times log-prob of the sampled feature.\n",
    "                    model_loss = loss_fn(pred, y)\n",
    "                    log_prob = logsoftmax(logits, dim=1)[\n",
    "                        torch.arange(len(logits)), inds]\n",
    "                    selector_loss = model_loss.detach() * log_prob\n",
    "                    loss = torch.mean(model_loss) + torch.mean(selector_loss)\n",
    "                    total_loss = total_loss + loss\n",
    "\n",
    "                # Take gradient step.\n",
    "                total_loss = total_loss / max_features\n",
    "                total_loss.backward()\n",
    "                if train_model:\n",
    "                    model_opt.step()\n",
    "                if train_selector:\n",
    "                    selector_opt.step()\n",
    "                model.zero_grad()\n",
    "                selector.zero_grad()\n",
    "\n",
    "            # Calculate validation loss.\n",
    "            val_loss = self._validation_loss(\n",
    "                val_loader, max_features, val_loss_fn, argmax, no_repeats,\n",
    "                validation_mode, device)\n",
    "\n",
    "            # Print progress.\n",
    "            if verbose:\n",
    "                print(f'{\"-\"*8}Epoch {epoch+1}{\"-\"*8}')\n",
    "                print(f'Val loss = {val_loss:.4f}\\n')\n",
    "\n",
    "            # See if best model.\n",
    "            if val_loss < best_loss:\n",
    "                best_loss = val_loss\n",
    "                best_model = deepcopy(model)\n",
    "                best_selector = deepcopy(selector)\n",
    "\n",
    "        # Restore best parameters.\n",
    "        if best_model is not None:\n",
    "            restore_parameters(model, best_model)\n",
    "        if best_selector is not None:\n",
    "            restore_parameters(selector, best_selector)\n",
    "\n",
    "    def forward(self, x, max_features, argmax=True, no_repeats=True):\n",
    "        '''\n",
    "        Make predictions using selected features.\n",
    "\n",
    "        Args:\n",
    "          x: input data (torch.Tensor).\n",
    "          max_features: max features to observe.\n",
    "          argmax: whether to select the next feature using the max probability.\n",
    "          no_repeats: whether to ensure that no repeated selections occur.\n",
    "\n",
    "        Returns:\n",
    "          The predictor's output on x masked to the selected features.\n",
    "        '''\n",
    "        device = next(self.model.parameters()).device\n",
    "        m_hard = torch.zeros(x.shape, dtype=x.dtype, device=device)\n",
    "\n",
    "        for _ in range(max_features):\n",
    "            # Evaluate the selector and add its choice to the mask.\n",
    "            logits = self.selector(self.mask_layer(x, m_hard))\n",
    "            hard = self._select_next_feature(logits, m_hard, argmax, no_repeats)\n",
    "            m_hard = torch.max(m_hard, hard)\n",
    "\n",
    "        # Make predictions.\n",
    "        return self.model(self.mask_layer(x, m_hard))\n",
    "\n",
    "    def evaluate(self, dataset, max_features, loss_fn, batch_size, argmax=True,\n",
    "                 no_repeats=True):\n",
    "        '''\n",
    "        Evaluate mean performance across a dataset.\n",
    "\n",
    "        loss_fn must reduce each batch to a scalar (its result is read with\n",
    "        .item()); batch values are averaged weighted by batch size.\n",
    "        '''\n",
    "        # Setup.\n",
    "        device = next(self.model.parameters()).device\n",
    "        loader = DataLoader(\n",
    "            dataset, batch_size=batch_size, shuffle=False, pin_memory=True,\n",
    "            drop_last=False, num_workers=4)\n",
    "\n",
    "        # For calculating mean loss.\n",
    "        mean_loss = 0\n",
    "        n = 0\n",
    "\n",
    "        with torch.no_grad():\n",
    "            for x, y in loader:\n",
    "                # Move to GPU.\n",
    "                x = x.to(device)\n",
    "                y = y.to(device)\n",
    "\n",
    "                # Calculate loss.\n",
    "                pred = self.forward(x, max_features, argmax, no_repeats)\n",
    "                loss = loss_fn(pred, y).item()\n",
    "\n",
    "                # Update average.\n",
    "                mean_loss = (mean_loss * n + loss * len(x)) / (n + len(x))\n",
    "                n += len(x)\n",
    "\n",
    "        return mean_loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "figured-concert",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--------Epoch 1--------\n",
      "Val loss = -0.5638\n",
      "\n",
      "--------Epoch 2--------\n",
      "Val loss = -0.5759\n",
      "\n",
      "--------Epoch 3--------\n",
      "Val loss = -0.6213\n",
      "\n",
      "--------Epoch 4--------\n",
      "Val loss = -0.6349\n",
      "\n",
      "--------Epoch 5--------\n",
      "Val loss = -0.6439\n",
      "\n",
      "--------Epoch 6--------\n",
      "Val loss = -0.6689\n",
      "\n",
      "--------Epoch 7--------\n",
      "Val loss = -0.6757\n",
      "\n",
      "--------Epoch 8--------\n",
      "Val loss = -0.6625\n",
      "\n",
      "--------Epoch 9--------\n",
      "Val loss = -0.6672\n",
      "\n",
      "--------Epoch 10--------\n",
      "Val loss = -0.6822\n",
      "\n",
      "--------Epoch 11--------\n",
      "Val loss = -0.6780\n",
      "\n",
      "--------Epoch 12--------\n",
      "Val loss = -0.6799\n",
      "\n",
      "--------Epoch 13--------\n",
      "Val loss = -0.6842\n",
      "\n",
      "--------Epoch 14--------\n",
      "Val loss = -0.6846\n",
      "\n",
      "--------Epoch 15--------\n",
      "Val loss = -0.6873\n",
      "\n",
      "--------Epoch 16--------\n",
      "Val loss = -0.6899\n",
      "\n",
      "--------Epoch 17--------\n",
      "Val loss = -0.6808\n",
      "\n",
      "--------Epoch 18--------\n",
      "Val loss = -0.6916\n",
      "\n",
      "--------Epoch 19--------\n",
      "Val loss = -0.6865\n",
      "\n",
      "--------Epoch 20--------\n",
      "Val loss = -0.6954\n",
      "\n",
      "--------Epoch 21--------\n",
      "Val loss = -0.6978\n",
      "\n",
      "--------Epoch 22--------\n",
      "Val loss = -0.7037\n",
      "\n",
      "--------Epoch 23--------\n",
      "Val loss = -0.7060\n",
      "\n",
      "--------Epoch 24--------\n",
      "Val loss = -0.6847\n",
      "\n",
      "--------Epoch 25--------\n",
      "Val loss = -0.6999\n",
      "\n",
      "--------Epoch 26--------\n",
      "Val loss = -0.6984\n",
      "\n",
      "--------Epoch 27--------\n",
      "Val loss = -0.6967\n",
      "\n",
      "--------Epoch 28--------\n",
      "Val loss = -0.6984\n",
      "\n",
      "--------Epoch 29--------\n",
      "Val loss = -0.7007\n",
      "\n",
      "--------Epoch 30--------\n",
      "Val loss = -0.7045\n",
      "\n",
      "--------Epoch 31--------\n",
      "Val loss = -0.6940\n",
      "\n",
      "--------Epoch 32--------\n",
      "Val loss = -0.7058\n",
      "\n",
      "--------Epoch 33--------\n",
      "Val loss = -0.6908\n",
      "\n",
      "--------Epoch 34--------\n",
      "Val loss = -0.6930\n",
      "\n",
      "--------Epoch 35--------\n",
      "Val loss = -0.6957\n",
      "\n",
      "--------Epoch 36--------\n",
      "Val loss = -0.6935\n",
      "\n",
      "--------Epoch 37--------\n",
      "Val loss = -0.6967\n",
      "\n",
      "--------Epoch 38--------\n",
      "Val loss = -0.6983\n",
      "\n",
      "--------Epoch 39--------\n",
      "Val loss = -0.6937\n",
      "\n",
      "--------Epoch 40--------\n",
      "Val loss = -0.6927\n",
      "\n",
      "--------Epoch 41--------\n",
      "Val loss = -0.7011\n",
      "\n",
      "--------Epoch 42--------\n",
      "Val loss = -0.7025\n",
      "\n",
      "--------Epoch 43--------\n",
      "Val loss = -0.6954\n",
      "\n",
      "--------Epoch 44--------\n",
      "Val loss = -0.6918\n",
      "\n",
      "--------Epoch 45--------\n",
      "Val loss = -0.6937\n",
      "\n",
      "--------Epoch 46--------\n",
      "Val loss = -0.6934\n",
      "\n",
      "--------Epoch 47--------\n",
      "Val loss = -0.6966\n",
      "\n",
      "--------Epoch 48--------\n",
      "Val loss = -0.6906\n",
      "\n",
      "--------Epoch 49--------\n",
      "Val loss = -0.6894\n",
      "\n",
      "--------Epoch 50--------\n",
      "Val loss = -0.6908\n",
      "\n",
      "--------Epoch 51--------\n",
      "Val loss = -0.6969\n",
      "\n",
      "--------Epoch 52--------\n",
      "Val loss = -0.6952\n",
      "\n",
      "--------Epoch 53--------\n",
      "Val loss = -0.6920\n",
      "\n",
      "--------Epoch 54--------\n",
      "Val loss = -0.7069\n",
      "\n",
      "--------Epoch 55--------\n",
      "Val loss = -0.7240\n",
      "\n",
      "--------Epoch 56--------\n",
      "Val loss = -0.6974\n",
      "\n",
      "--------Epoch 57--------\n",
      "Val loss = -0.6957\n",
      "\n",
      "--------Epoch 58--------\n",
      "Val loss = -0.6999\n",
      "\n",
      "--------Epoch 59--------\n",
      "Val loss = -0.6992\n",
      "\n",
      "--------Epoch 60--------\n",
      "Val loss = -0.7005\n",
      "\n",
      "--------Epoch 61--------\n",
      "Val loss = -0.7110\n",
      "\n",
      "--------Epoch 62--------\n",
      "Val loss = -0.6944\n",
      "\n",
      "--------Epoch 63--------\n",
      "Val loss = -0.7091\n",
      "\n",
      "--------Epoch 64--------\n",
      "Val loss = -0.7067\n",
      "\n",
      "--------Epoch 65--------\n",
      "Val loss = -0.7043\n",
      "\n",
      "--------Epoch 66--------\n",
      "Val loss = -0.7014\n",
      "\n",
      "--------Epoch 67--------\n",
      "Val loss = -0.7063\n",
      "\n",
      "--------Epoch 68--------\n",
      "Val loss = -0.7067\n",
      "\n",
      "--------Epoch 69--------\n",
      "Val loss = -0.7039\n",
      "\n",
      "--------Epoch 70--------\n",
      "Val loss = -0.7008\n",
      "\n",
      "--------Epoch 71--------\n",
      "Val loss = -0.7049\n",
      "\n",
      "--------Epoch 72--------\n",
      "Val loss = -0.7110\n",
      "\n",
      "--------Epoch 73--------\n",
      "Val loss = -0.6981\n",
      "\n",
      "--------Epoch 74--------\n",
      "Val loss = -0.7087\n",
      "\n",
      "--------Epoch 75--------\n",
      "Val loss = -0.7050\n",
      "\n",
      "--------Epoch 76--------\n",
      "Val loss = -0.7092\n",
      "\n",
      "--------Epoch 77--------\n",
      "Val loss = -0.6998\n",
      "\n",
      "--------Epoch 78--------\n",
      "Val loss = -0.7011\n",
      "\n",
      "--------Epoch 79--------\n",
      "Val loss = -0.7016\n",
      "\n",
      "--------Epoch 80--------\n",
      "Val loss = -0.7079\n",
      "\n",
      "--------Epoch 81--------\n",
      "Val loss = -0.7000\n",
      "\n",
      "--------Epoch 82--------\n",
      "Val loss = -0.7266\n",
      "\n",
      "--------Epoch 83--------\n",
      "Val loss = -0.7036\n",
      "\n",
      "--------Epoch 84--------\n",
      "Val loss = -0.6980\n",
      "\n",
      "--------Epoch 85--------\n",
      "Val loss = -0.7155\n",
      "\n",
      "--------Epoch 86--------\n",
      "Val loss = -0.7078\n",
      "\n",
      "--------Epoch 87--------\n",
      "Val loss = -0.7196\n",
      "\n",
      "--------Epoch 88--------\n",
      "Val loss = -0.7107\n",
      "\n",
      "--------Epoch 89--------\n",
      "Val loss = -0.7343\n",
      "\n",
      "--------Epoch 90--------\n",
      "Val loss = -0.7103\n",
      "\n",
      "--------Epoch 91--------\n",
      "Val loss = -0.7074\n",
      "\n",
      "--------Epoch 92--------\n",
      "Val loss = -0.7028\n",
      "\n",
      "--------Epoch 93--------\n",
      "Val loss = -0.6958\n",
      "\n",
      "--------Epoch 94--------\n",
      "Val loss = -0.6989\n",
      "\n",
      "--------Epoch 95--------\n",
      "Val loss = -0.6989\n",
      "\n",
      "--------Epoch 96--------\n",
      "Val loss = -0.7008\n",
      "\n",
      "--------Epoch 97--------\n",
      "Val loss = -0.7049\n",
      "\n",
      "--------Epoch 98--------\n",
      "Val loss = -0.6997\n",
      "\n",
      "--------Epoch 99--------\n",
      "Val loss = -0.7034\n",
      "\n",
      "--------Epoch 100--------\n",
      "Val loss = -0.7056\n",
      "\n",
      "--------Epoch 101--------\n",
      "Val loss = -0.7111\n",
      "\n",
      "--------Epoch 102--------\n",
      "Val loss = -0.7069\n",
      "\n",
      "--------Epoch 103--------\n",
      "Val loss = -0.6990\n",
      "\n",
      "--------Epoch 104--------\n",
      "Val loss = -0.7031\n",
      "\n",
      "--------Epoch 105--------\n",
      "Val loss = -0.7135\n",
      "\n",
      "--------Epoch 106--------\n",
      "Val loss = -0.7065\n",
      "\n",
      "--------Epoch 107--------\n",
      "Val loss = -0.7125\n",
      "\n",
      "--------Epoch 108--------\n",
      "Val loss = -0.7132\n",
      "\n",
      "--------Epoch 109--------\n",
      "Val loss = -0.7113\n",
      "\n",
      "--------Epoch 110--------\n",
      "Val loss = -0.7131\n",
      "\n",
      "--------Epoch 111--------\n",
      "Val loss = -0.7132\n",
      "\n",
      "--------Epoch 112--------\n",
      "Val loss = -0.7135\n",
      "\n",
      "--------Epoch 113--------\n",
      "Val loss = -0.7062\n",
      "\n",
      "--------Epoch 114--------\n",
      "Val loss = -0.7062\n",
      "\n",
      "--------Epoch 115--------\n",
      "Val loss = -0.7053\n",
      "\n",
      "--------Epoch 116--------\n",
      "Val loss = -0.7025\n",
      "\n",
      "--------Epoch 117--------\n",
      "Val loss = -0.7045\n",
      "\n",
      "--------Epoch 118--------\n",
      "Val loss = -0.7069\n",
      "\n",
      "--------Epoch 119--------\n",
      "Val loss = -0.7071\n",
      "\n",
      "--------Epoch 120--------\n",
      "Val loss = -0.7015\n",
      "\n",
      "--------Epoch 121--------\n",
      "Val loss = -0.7084\n",
      "\n",
      "--------Epoch 122--------\n",
      "Val loss = -0.7074\n",
      "\n",
      "--------Epoch 123--------\n",
      "Val loss = -0.7086\n",
      "\n",
      "--------Epoch 124--------\n",
      "Val loss = -0.7061\n",
      "\n",
      "--------Epoch 125--------\n",
      "Val loss = -0.7018\n",
      "\n",
      "--------Epoch 126--------\n",
      "Val loss = -0.6996\n",
      "\n",
      "--------Epoch 127--------\n",
      "Val loss = -0.7050\n",
      "\n",
      "--------Epoch 128--------\n",
      "Val loss = -0.7060\n",
      "\n",
      "--------Epoch 129--------\n",
      "Val loss = -0.7048\n",
      "\n",
      "--------Epoch 130--------\n",
      "Val loss = -0.7048\n",
      "\n",
      "--------Epoch 131--------\n",
      "Val loss = -0.7136\n",
      "\n",
      "--------Epoch 132--------\n",
      "Val loss = -0.7053\n",
      "\n",
      "--------Epoch 133--------\n",
      "Val loss = -0.7077\n",
      "\n",
      "--------Epoch 134--------\n",
      "Val loss = -0.7086\n",
      "\n",
      "--------Epoch 135--------\n",
      "Val loss = -0.7045\n",
      "\n",
      "--------Epoch 136--------\n",
      "Val loss = -0.7087\n",
      "\n",
      "--------Epoch 137--------\n",
      "Val loss = -0.6906\n",
      "\n",
      "--------Epoch 138--------\n",
      "Val loss = -0.7034\n",
      "\n",
      "--------Epoch 139--------\n",
      "Val loss = -0.7042\n",
      "\n",
      "--------Epoch 140--------\n",
      "Val loss = -0.7068\n",
      "\n",
      "--------Epoch 141--------\n",
      "Val loss = -0.7173\n",
      "\n",
      "--------Epoch 142--------\n",
      "Val loss = -0.7153\n",
      "\n",
      "--------Epoch 143--------\n",
      "Val loss = -0.7104\n",
      "\n",
      "--------Epoch 144--------\n",
      "Val loss = -0.7163\n",
      "\n",
      "--------Epoch 145--------\n",
      "Val loss = -0.7091\n",
      "\n",
      "--------Epoch 146--------\n",
      "Val loss = -0.7152\n",
      "\n",
      "--------Epoch 147--------\n",
      "Val loss = -0.7139\n",
      "\n",
      "--------Epoch 148--------\n",
      "Val loss = -0.7163\n",
      "\n",
      "--------Epoch 149--------\n",
      "Val loss = -0.7045\n",
      "\n",
      "--------Epoch 150--------\n",
      "Val loss = -0.7101\n",
      "\n",
      "--------Epoch 151--------\n",
      "Val loss = -0.7094\n",
      "\n",
      "--------Epoch 152--------\n",
      "Val loss = -0.7083\n",
      "\n",
      "--------Epoch 153--------\n",
      "Val loss = -0.6980\n",
      "\n",
      "--------Epoch 154--------\n",
      "Val loss = -0.7037\n",
      "\n",
      "--------Epoch 155--------\n",
      "Val loss = -0.7090\n",
      "\n",
      "--------Epoch 156--------\n",
      "Val loss = -0.7080\n",
      "\n",
      "--------Epoch 157--------\n",
      "Val loss = -0.7107\n",
      "\n",
      "--------Epoch 158--------\n",
      "Val loss = -0.7099\n",
      "\n",
      "--------Epoch 159--------\n",
      "Val loss = -0.7118\n",
      "\n",
      "--------Epoch 160--------\n",
      "Val loss = -0.7089\n",
      "\n",
      "--------Epoch 161--------\n",
      "Val loss = -0.7072\n",
      "\n",
      "--------Epoch 162--------\n",
      "Val loss = -0.7104\n",
      "\n",
      "--------Epoch 163--------\n",
      "Val loss = -0.7062\n",
      "\n",
      "--------Epoch 164--------\n",
      "Val loss = -0.7064\n",
      "\n",
      "--------Epoch 165--------\n",
      "Val loss = -0.7096\n",
      "\n",
      "--------Epoch 166--------\n",
      "Val loss = -0.7029\n",
      "\n",
      "--------Epoch 167--------\n",
      "Val loss = -0.7052\n",
      "\n",
      "--------Epoch 168--------\n",
      "Val loss = -0.7056\n",
      "\n",
      "--------Epoch 169--------\n",
      "Val loss = -0.7093\n",
      "\n",
      "--------Epoch 170--------\n",
      "Val loss = -0.7093\n",
      "\n",
      "--------Epoch 171--------\n",
      "Val loss = -0.7146\n",
      "\n",
      "--------Epoch 172--------\n",
      "Val loss = -0.7098\n",
      "\n",
      "--------Epoch 173--------\n",
      "Val loss = -0.7122\n",
      "\n",
      "--------Epoch 174--------\n",
      "Val loss = -0.7106\n",
      "\n",
      "--------Epoch 175--------\n",
      "Val loss = -0.7062\n",
      "\n",
      "--------Epoch 176--------\n",
      "Val loss = -0.7141\n",
      "\n",
      "--------Epoch 177--------\n",
      "Val loss = -0.7171\n",
      "\n",
      "--------Epoch 178--------\n",
      "Val loss = -0.7180\n",
      "\n",
      "--------Epoch 179--------\n",
      "Val loss = -0.7118\n",
      "\n",
      "--------Epoch 180--------\n",
      "Val loss = -0.7126\n",
      "\n",
      "--------Epoch 181--------\n",
      "Val loss = -0.7089\n",
      "\n",
      "--------Epoch 182--------\n",
      "Val loss = -0.7124\n",
      "\n",
      "--------Epoch 183--------\n",
      "Val loss = -0.7059\n",
      "\n",
      "--------Epoch 184--------\n",
      "Val loss = -0.7097\n",
      "\n",
      "--------Epoch 185--------\n",
      "Val loss = -0.7107\n",
      "\n",
      "--------Epoch 186--------\n",
      "Val loss = -0.7086\n",
      "\n",
      "--------Epoch 187--------\n",
      "Val loss = -0.7088\n",
      "\n",
      "--------Epoch 188--------\n",
      "Val loss = -0.7063\n",
      "\n",
      "--------Epoch 189--------\n",
      "Val loss = -0.7068\n",
      "\n",
      "--------Epoch 190--------\n",
      "Val loss = -0.7040\n",
      "\n",
      "--------Epoch 191--------\n",
      "Val loss = -0.7040\n",
      "\n",
      "--------Epoch 192--------\n",
      "Val loss = -0.7064\n",
      "\n",
      "--------Epoch 193--------\n",
      "Val loss = -0.7120\n",
      "\n",
      "--------Epoch 194--------\n",
      "Val loss = -0.7124\n",
      "\n",
      "--------Epoch 195--------\n",
      "Val loss = -0.7064\n",
      "\n",
      "--------Epoch 196--------\n",
      "Val loss = -0.7147\n",
      "\n",
      "--------Epoch 197--------\n",
      "Val loss = -0.7108\n",
      "\n",
      "--------Epoch 198--------\n",
      "Val loss = -0.7165\n",
      "\n",
      "--------Epoch 199--------\n",
      "Val loss = -0.7117\n",
      "\n",
      "--------Epoch 200--------\n",
      "Val loss = -0.7101\n",
      "\n",
      "--------Epoch 201--------\n",
      "Val loss = -0.7101\n",
      "\n",
      "--------Epoch 202--------\n",
      "Val loss = -0.7164\n",
      "\n",
      "--------Epoch 203--------\n",
      "Val loss = -0.7105\n",
      "\n",
      "--------Epoch 204--------\n",
      "Val loss = -0.7159\n",
      "\n",
      "--------Epoch 205--------\n",
      "Val loss = -0.7104\n",
      "\n",
      "--------Epoch 206--------\n",
      "Val loss = -0.7001\n",
      "\n",
      "--------Epoch 207--------\n",
      "Val loss = -0.7131\n",
      "\n",
      "--------Epoch 208--------\n",
      "Val loss = -0.7023\n",
      "\n",
      "--------Epoch 209--------\n",
      "Val loss = -0.7149\n",
      "\n",
      "--------Epoch 210--------\n",
      "Val loss = -0.7088\n",
      "\n",
      "--------Epoch 211--------\n",
      "Val loss = -0.7028\n",
      "\n",
      "--------Epoch 212--------\n",
      "Val loss = -0.7094\n",
      "\n",
      "--------Epoch 213--------\n",
      "Val loss = -0.7080\n",
      "\n",
      "--------Epoch 214--------\n",
      "Val loss = -0.7088\n",
      "\n",
      "--------Epoch 215--------\n",
      "Val loss = -0.7116\n",
      "\n",
      "--------Epoch 216--------\n",
      "Val loss = -0.6956\n",
      "\n",
      "--------Epoch 217--------\n",
      "Val loss = -0.7150\n",
      "\n",
      "--------Epoch 218--------\n",
      "Val loss = -0.7123\n",
      "\n",
      "--------Epoch 219--------\n",
      "Val loss = -0.7107\n",
      "\n",
      "--------Epoch 220--------\n",
      "Val loss = -0.7189\n",
      "\n",
      "--------Epoch 221--------\n",
      "Val loss = -0.7183\n",
      "\n",
      "--------Epoch 222--------\n",
      "Val loss = -0.7116\n",
      "\n",
      "--------Epoch 223--------\n",
      "Val loss = -0.7036\n",
      "\n",
      "--------Epoch 224--------\n",
      "Val loss = -0.7204\n",
      "\n",
      "--------Epoch 225--------\n",
      "Val loss = -0.7155\n",
      "\n",
      "--------Epoch 226--------\n",
      "Val loss = -0.7195\n",
      "\n",
      "--------Epoch 227--------\n",
      "Val loss = -0.7161\n",
      "\n",
      "--------Epoch 228--------\n",
      "Val loss = -0.7144\n",
      "\n",
      "--------Epoch 229--------\n",
      "Val loss = -0.7070\n",
      "\n",
      "--------Epoch 230--------\n",
      "Val loss = -0.7067\n",
      "\n",
      "--------Epoch 231--------\n",
      "Val loss = -0.7102\n",
      "\n",
      "--------Epoch 232--------\n",
      "Val loss = -0.7123\n",
      "\n",
      "--------Epoch 233--------\n",
      "Val loss = -0.7027\n",
      "\n",
      "--------Epoch 234--------\n",
      "Val loss = -0.6946\n",
      "\n",
      "--------Epoch 235--------\n",
      "Val loss = -0.6979\n",
      "\n",
      "--------Epoch 236--------\n",
      "Val loss = -0.7066\n",
      "\n",
      "--------Epoch 237--------\n",
      "Val loss = -0.7119\n",
      "\n",
      "--------Epoch 238--------\n",
      "Val loss = -0.7120\n",
      "\n",
      "--------Epoch 239--------\n",
      "Val loss = -0.6988\n",
      "\n",
      "--------Epoch 240--------\n",
      "Val loss = -0.7032\n",
      "\n",
      "--------Epoch 241--------\n",
      "Val loss = -0.7074\n",
      "\n",
      "--------Epoch 242--------\n",
      "Val loss = -0.7060\n",
      "\n",
      "--------Epoch 243--------\n",
      "Val loss = -0.7043\n",
      "\n",
      "--------Epoch 244--------\n",
      "Val loss = -0.7123\n",
      "\n",
      "--------Epoch 245--------\n",
      "Val loss = -0.7056\n",
      "\n",
      "--------Epoch 246--------\n",
      "Val loss = -0.7071\n",
      "\n",
      "--------Epoch 247--------\n",
      "Val loss = -0.7068\n",
      "\n",
      "--------Epoch 248--------\n",
      "Val loss = -0.7033\n",
      "\n",
      "--------Epoch 249--------\n",
      "Val loss = -0.7047\n",
      "\n",
      "--------Epoch 250--------\n",
      "Val loss = -0.7059\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Set up selector: an MLP taking 2 * d_in inputs and producing d_in outputs\n",
    "selector = nn.Sequential(\n",
    "    nn.Linear(2 * d_in, 512), nn.ReLU(),\n",
    "    nn.Linear(512, 512), nn.ReLU(),\n",
    "    nn.Linear(512, d_in),\n",
    ")\n",
    "gafs = GreedyAdaptiveFS(selector, deepcopy(model), mask_layer)\n",
    "gafs = gafs.to(device)\n",
    "\n",
    "# Train with cross-entropy; validation is scored by NegAccuracy (negated accuracy)\n",
    "gafs.fit(\n",
    "    train_dataset,\n",
    "    val_dataset,\n",
    "    mbsize=128,\n",
    "    lr=2e-4,\n",
    "    nepochs=250,\n",
    "    max_features=max_features,\n",
    "    loss_fn=nn.CrossEntropyLoss(),\n",
    "    val_loss_fn=NegAccuracy(),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "commercial-element",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test acc = 73.95\n"
     ]
    }
   ],
   "source": [
    "# Test-set accuracy using up to max_features features (1024 is presumably the eval batch size)\n",
    "test_acc = gafs.evaluate(\n",
    "    test_dataset, max_features, Accuracy(), 1024)\n",
    "print('Test acc = {:.2f}'.format(100 * test_acc))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "unsigned-calcium",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--------Epoch 1--------\n",
      "Val loss = -0.4957\n",
      "\n",
      "--------Epoch 2--------\n",
      "Val loss = -0.5348\n",
      "\n",
      "--------Epoch 3--------\n",
      "Val loss = -0.5551\n",
      "\n",
      "--------Epoch 4--------\n",
      "Val loss = -0.5868\n",
      "\n",
      "--------Epoch 5--------\n",
      "Val loss = -0.6267\n",
      "\n",
      "--------Epoch 6--------\n",
      "Val loss = -0.6370\n",
      "\n",
      "--------Epoch 7--------\n",
      "Val loss = -0.6461\n",
      "\n",
      "--------Epoch 8--------\n",
      "Val loss = -0.6545\n",
      "\n",
      "--------Epoch 9--------\n",
      "Val loss = -0.6572\n",
      "\n",
      "--------Epoch 10--------\n",
      "Val loss = -0.6624\n",
      "\n",
      "--------Epoch 11--------\n",
      "Val loss = -0.6602\n",
      "\n",
      "--------Epoch 12--------\n",
      "Val loss = -0.6674\n",
      "\n",
      "--------Epoch 13--------\n",
      "Val loss = -0.6794\n",
      "\n",
      "--------Epoch 14--------\n",
      "Val loss = -0.6823\n",
      "\n",
      "--------Epoch 15--------\n",
      "Val loss = -0.6826\n",
      "\n",
      "--------Epoch 16--------\n",
      "Val loss = -0.6923\n",
      "\n",
      "--------Epoch 17--------\n",
      "Val loss = -0.6949\n",
      "\n",
      "--------Epoch 18--------\n",
      "Val loss = -0.6950\n",
      "\n",
      "--------Epoch 19--------\n",
      "Val loss = -0.6997\n",
      "\n",
      "--------Epoch 20--------\n",
      "Val loss = -0.6994\n",
      "\n",
      "--------Epoch 21--------\n",
      "Val loss = -0.7009\n",
      "\n",
      "--------Epoch 22--------\n",
      "Val loss = -0.6979\n",
      "\n",
      "--------Epoch 23--------\n",
      "Val loss = -0.7066\n",
      "\n",
      "--------Epoch 24--------\n",
      "Val loss = -0.7167\n",
      "\n",
      "--------Epoch 25--------\n",
      "Val loss = -0.7253\n",
      "\n",
      "--------Epoch 26--------\n",
      "Val loss = -0.7222\n",
      "\n",
      "--------Epoch 27--------\n",
      "Val loss = -0.7223\n",
      "\n",
      "--------Epoch 28--------\n",
      "Val loss = -0.7288\n",
      "\n",
      "--------Epoch 29--------\n",
      "Val loss = -0.7323\n",
      "\n",
      "--------Epoch 30--------\n",
      "Val loss = -0.7324\n",
      "\n",
      "--------Epoch 31--------\n",
      "Val loss = -0.7297\n",
      "\n",
      "--------Epoch 32--------\n",
      "Val loss = -0.7357\n",
      "\n",
      "--------Epoch 33--------\n",
      "Val loss = -0.7382\n",
      "\n",
      "--------Epoch 34--------\n",
      "Val loss = -0.7351\n",
      "\n",
      "--------Epoch 35--------\n",
      "Val loss = -0.7418\n",
      "\n",
      "--------Epoch 36--------\n",
      "Val loss = -0.7403\n",
      "\n",
      "--------Epoch 37--------\n",
      "Val loss = -0.7390\n",
      "\n",
      "--------Epoch 38--------\n",
      "Val loss = -0.7449\n",
      "\n",
      "--------Epoch 39--------\n",
      "Val loss = -0.7489\n",
      "\n",
      "--------Epoch 40--------\n",
      "Val loss = -0.7492\n",
      "\n",
      "--------Epoch 41--------\n",
      "Val loss = -0.7527\n",
      "\n",
      "--------Epoch 42--------\n",
      "Val loss = -0.7524\n",
      "\n",
      "--------Epoch 43--------\n",
      "Val loss = -0.7514\n",
      "\n",
      "--------Epoch 44--------\n",
      "Val loss = -0.7634\n",
      "\n",
      "--------Epoch 45--------\n",
      "Val loss = -0.7637\n",
      "\n",
      "--------Epoch 46--------\n",
      "Val loss = -0.7657\n",
      "\n",
      "--------Epoch 47--------\n",
      "Val loss = -0.7644\n",
      "\n",
      "--------Epoch 48--------\n",
      "Val loss = -0.7677\n",
      "\n",
      "--------Epoch 49--------\n",
      "Val loss = -0.7714\n",
      "\n",
      "--------Epoch 50--------\n",
      "Val loss = -0.7700\n",
      "\n",
      "--------Epoch 51--------\n",
      "Val loss = -0.7723\n",
      "\n",
      "--------Epoch 52--------\n",
      "Val loss = -0.7705\n",
      "\n",
      "--------Epoch 53--------\n",
      "Val loss = -0.7751\n",
      "\n",
      "--------Epoch 54--------\n",
      "Val loss = -0.7705\n",
      "\n",
      "--------Epoch 55--------\n",
      "Val loss = -0.7720\n",
      "\n",
      "--------Epoch 56--------\n",
      "Val loss = -0.7759\n",
      "\n",
      "--------Epoch 57--------\n",
      "Val loss = -0.7758\n",
      "\n",
      "--------Epoch 58--------\n",
      "Val loss = -0.7757\n",
      "\n",
      "--------Epoch 59--------\n",
      "Val loss = -0.7749\n",
      "\n",
      "--------Epoch 60--------\n",
      "Val loss = -0.7756\n",
      "\n",
      "--------Epoch 61--------\n",
      "Val loss = -0.7774\n",
      "\n",
      "--------Epoch 62--------\n",
      "Val loss = -0.7732\n",
      "\n",
      "--------Epoch 63--------\n",
      "Val loss = -0.7706\n",
      "\n",
      "--------Epoch 64--------\n",
      "Val loss = -0.7788\n",
      "\n",
      "--------Epoch 65--------\n",
      "Val loss = -0.7791\n",
      "\n",
      "--------Epoch 66--------\n",
      "Val loss = -0.7768\n",
      "\n",
      "--------Epoch 67--------\n",
      "Val loss = -0.7788\n",
      "\n",
      "--------Epoch 68--------\n",
      "Val loss = -0.7749\n",
      "\n",
      "--------Epoch 69--------\n",
      "Val loss = -0.7824\n",
      "\n",
      "--------Epoch 70--------\n",
      "Val loss = -0.7813\n",
      "\n",
      "--------Epoch 71--------\n",
      "Val loss = -0.7834\n",
      "\n",
      "--------Epoch 72--------\n",
      "Val loss = -0.7791\n",
      "\n",
      "--------Epoch 73--------\n",
      "Val loss = -0.7787\n",
      "\n",
      "--------Epoch 74--------\n",
      "Val loss = -0.7843\n",
      "\n",
      "--------Epoch 75--------\n",
      "Val loss = -0.7849\n",
      "\n",
      "--------Epoch 76--------\n",
      "Val loss = -0.7845\n",
      "\n",
      "--------Epoch 77--------\n",
      "Val loss = -0.7823\n",
      "\n",
      "--------Epoch 78--------\n",
      "Val loss = -0.7840\n",
      "\n",
      "--------Epoch 79--------\n",
      "Val loss = -0.7890\n",
      "\n",
      "--------Epoch 80--------\n",
      "Val loss = -0.7850\n",
      "\n",
      "--------Epoch 81--------\n",
      "Val loss = -0.7842\n",
      "\n",
      "--------Epoch 82--------\n",
      "Val loss = -0.7902\n",
      "\n",
      "--------Epoch 83--------\n",
      "Val loss = -0.7891\n",
      "\n",
      "--------Epoch 84--------\n",
      "Val loss = -0.7843\n",
      "\n",
      "--------Epoch 85--------\n",
      "Val loss = -0.7891\n",
      "\n",
      "--------Epoch 86--------\n",
      "Val loss = -0.7835\n",
      "\n",
      "--------Epoch 87--------\n",
      "Val loss = -0.7887\n",
      "\n",
      "--------Epoch 88--------\n",
      "Val loss = -0.7895\n",
      "\n",
      "--------Epoch 89--------\n",
      "Val loss = -0.7898\n",
      "\n",
      "--------Epoch 90--------\n",
      "Val loss = -0.7919\n",
      "\n",
      "--------Epoch 91--------\n",
      "Val loss = -0.7894\n",
      "\n",
      "--------Epoch 92--------\n",
      "Val loss = -0.7903\n",
      "\n",
      "--------Epoch 93--------\n",
      "Val loss = -0.7930\n",
      "\n",
      "--------Epoch 94--------\n",
      "Val loss = -0.7872\n",
      "\n",
      "--------Epoch 95--------\n",
      "Val loss = -0.7910\n",
      "\n",
      "--------Epoch 96--------\n",
      "Val loss = -0.7901\n",
      "\n",
      "--------Epoch 97--------\n",
      "Val loss = -0.7939\n",
      "\n",
      "--------Epoch 98--------\n",
      "Val loss = -0.7901\n",
      "\n",
      "--------Epoch 99--------\n",
      "Val loss = -0.7907\n",
      "\n",
      "--------Epoch 100--------\n",
      "Val loss = -0.7923\n",
      "\n",
      "--------Epoch 101--------\n",
      "Val loss = -0.7930\n",
      "\n",
      "--------Epoch 102--------\n",
      "Val loss = -0.7955\n",
      "\n",
      "--------Epoch 103--------\n",
      "Val loss = -0.7883\n",
      "\n",
      "--------Epoch 104--------\n",
      "Val loss = -0.7930\n",
      "\n",
      "--------Epoch 105--------\n",
      "Val loss = -0.7969\n",
      "\n",
      "--------Epoch 106--------\n",
      "Val loss = -0.7907\n",
      "\n",
      "--------Epoch 107--------\n",
      "Val loss = -0.7950\n",
      "\n",
      "--------Epoch 108--------\n",
      "Val loss = -0.7949\n",
      "\n",
      "--------Epoch 109--------\n",
      "Val loss = -0.7935\n",
      "\n",
      "--------Epoch 110--------\n",
      "Val loss = -0.7948\n",
      "\n",
      "--------Epoch 111--------\n",
      "Val loss = -0.7935\n",
      "\n",
      "--------Epoch 112--------\n",
      "Val loss = -0.7973\n",
      "\n",
      "--------Epoch 113--------\n",
      "Val loss = -0.7965\n",
      "\n",
      "--------Epoch 114--------\n",
      "Val loss = -0.8006\n",
      "\n",
      "--------Epoch 115--------\n",
      "Val loss = -0.7996\n",
      "\n",
      "--------Epoch 116--------\n",
      "Val loss = -0.7963\n",
      "\n",
      "--------Epoch 117--------\n",
      "Val loss = -0.7996\n",
      "\n",
      "--------Epoch 118--------\n",
      "Val loss = -0.7982\n",
      "\n",
      "--------Epoch 119--------\n",
      "Val loss = -0.7989\n",
      "\n",
      "--------Epoch 120--------\n",
      "Val loss = -0.8035\n",
      "\n",
      "--------Epoch 121--------\n",
      "Val loss = -0.7986\n",
      "\n",
      "--------Epoch 122--------\n",
      "Val loss = -0.8011\n",
      "\n",
      "--------Epoch 123--------\n",
      "Val loss = -0.8015\n",
      "\n",
      "--------Epoch 124--------\n",
      "Val loss = -0.7973\n",
      "\n",
      "--------Epoch 125--------\n",
      "Val loss = -0.8008\n",
      "\n",
      "--------Epoch 126--------\n",
      "Val loss = -0.8003\n",
      "\n",
      "--------Epoch 127--------\n",
      "Val loss = -0.8004\n",
      "\n",
      "--------Epoch 128--------\n",
      "Val loss = -0.7994\n",
      "\n",
      "--------Epoch 129--------\n",
      "Val loss = -0.7993\n",
      "\n",
      "--------Epoch 130--------\n",
      "Val loss = -0.8041\n",
      "\n",
      "--------Epoch 131--------\n",
      "Val loss = -0.8037\n",
      "\n",
      "--------Epoch 132--------\n",
      "Val loss = -0.8003\n",
      "\n",
      "--------Epoch 133--------\n",
      "Val loss = -0.8021\n",
      "\n",
      "--------Epoch 134--------\n",
      "Val loss = -0.8023\n",
      "\n",
      "--------Epoch 135--------\n",
      "Val loss = -0.8017\n",
      "\n",
      "--------Epoch 136--------\n",
      "Val loss = -0.8054\n",
      "\n",
      "--------Epoch 137--------\n",
      "Val loss = -0.8068\n",
      "\n",
      "--------Epoch 138--------\n",
      "Val loss = -0.8045\n",
      "\n",
      "--------Epoch 139--------\n",
      "Val loss = -0.8032\n",
      "\n",
      "--------Epoch 140--------\n",
      "Val loss = -0.8041\n",
      "\n",
      "--------Epoch 141--------\n",
      "Val loss = -0.8016\n",
      "\n",
      "--------Epoch 142--------\n",
      "Val loss = -0.8039\n",
      "\n",
      "--------Epoch 143--------\n",
      "Val loss = -0.8032\n",
      "\n",
      "--------Epoch 144--------\n",
      "Val loss = -0.8021\n",
      "\n",
      "--------Epoch 145--------\n",
      "Val loss = -0.8073\n",
      "\n",
      "--------Epoch 146--------\n",
      "Val loss = -0.8039\n",
      "\n",
      "--------Epoch 147--------\n",
      "Val loss = -0.8030\n",
      "\n",
      "--------Epoch 148--------\n",
      "Val loss = -0.8051\n",
      "\n",
      "--------Epoch 149--------\n",
      "Val loss = -0.8062\n",
      "\n",
      "--------Epoch 150--------\n",
      "Val loss = -0.8062\n",
      "\n",
      "--------Epoch 151--------\n",
      "Val loss = -0.8048\n",
      "\n",
      "--------Epoch 152--------\n",
      "Val loss = -0.8068\n",
      "\n",
      "--------Epoch 153--------\n",
      "Val loss = -0.8080\n",
      "\n",
      "--------Epoch 154--------\n",
      "Val loss = -0.8063\n",
      "\n",
      "--------Epoch 155--------\n",
      "Val loss = -0.8062\n",
      "\n",
      "--------Epoch 156--------\n",
      "Val loss = -0.8060\n",
      "\n",
      "--------Epoch 157--------\n",
      "Val loss = -0.8043\n",
      "\n",
      "--------Epoch 158--------\n",
      "Val loss = -0.8027\n",
      "\n",
      "--------Epoch 159--------\n",
      "Val loss = -0.8041\n",
      "\n",
      "--------Epoch 160--------\n",
      "Val loss = -0.8069\n",
      "\n",
      "--------Epoch 161--------\n",
      "Val loss = -0.8044\n",
      "\n",
      "--------Epoch 162--------\n",
      "Val loss = -0.8046\n",
      "\n",
      "--------Epoch 163--------\n",
      "Val loss = -0.8075\n",
      "\n",
      "--------Epoch 164--------\n",
      "Val loss = -0.8070\n",
      "\n",
      "--------Epoch 165--------\n",
      "Val loss = -0.8050\n",
      "\n",
      "--------Epoch 166--------\n",
      "Val loss = -0.8045\n",
      "\n",
      "--------Epoch 167--------\n",
      "Val loss = -0.8081\n",
      "\n",
      "--------Epoch 168--------\n",
      "Val loss = -0.8055\n",
      "\n",
      "--------Epoch 169--------\n",
      "Val loss = -0.8067\n",
      "\n",
      "--------Epoch 170--------\n",
      "Val loss = -0.8057\n",
      "\n",
      "--------Epoch 171--------\n",
      "Val loss = -0.8063\n",
      "\n",
      "--------Epoch 172--------\n",
      "Val loss = -0.8031\n",
      "\n",
      "--------Epoch 173--------\n",
      "Val loss = -0.8034\n",
      "\n",
      "--------Epoch 174--------\n",
      "Val loss = -0.8051\n",
      "\n",
      "--------Epoch 175--------\n",
      "Val loss = -0.8085\n",
      "\n",
      "--------Epoch 176--------\n",
      "Val loss = -0.8075\n",
      "\n",
      "--------Epoch 177--------\n",
      "Val loss = -0.8056\n",
      "\n",
      "--------Epoch 178--------\n",
      "Val loss = -0.8052\n",
      "\n",
      "--------Epoch 179--------\n",
      "Val loss = -0.8078\n",
      "\n",
      "--------Epoch 180--------\n",
      "Val loss = -0.8077\n",
      "\n",
      "--------Epoch 181--------\n",
      "Val loss = -0.8067\n",
      "\n",
      "--------Epoch 182--------\n",
      "Val loss = -0.8035\n",
      "\n",
      "--------Epoch 183--------\n",
      "Val loss = -0.8084\n",
      "\n",
      "--------Epoch 184--------\n",
      "Val loss = -0.8112\n",
      "\n",
      "--------Epoch 185--------\n",
      "Val loss = -0.8083\n",
      "\n",
      "--------Epoch 186--------\n",
      "Val loss = -0.8110\n",
      "\n",
      "--------Epoch 187--------\n",
      "Val loss = -0.8087\n",
      "\n",
      "--------Epoch 188--------\n",
      "Val loss = -0.8090\n",
      "\n",
      "--------Epoch 189--------\n",
      "Val loss = -0.8099\n",
      "\n",
      "--------Epoch 190--------\n",
      "Val loss = -0.8092\n",
      "\n",
      "--------Epoch 191--------\n",
      "Val loss = -0.8071\n",
      "\n",
      "--------Epoch 192--------\n",
      "Val loss = -0.8071\n",
      "\n",
      "--------Epoch 193--------\n",
      "Val loss = -0.8068\n",
      "\n",
      "--------Epoch 194--------\n",
      "Val loss = -0.8062\n",
      "\n",
      "--------Epoch 195--------\n",
      "Val loss = -0.8103\n",
      "\n",
      "--------Epoch 196--------\n",
      "Val loss = -0.8065\n",
      "\n",
      "--------Epoch 197--------\n",
      "Val loss = -0.8069\n",
      "\n",
      "--------Epoch 198--------\n",
      "Val loss = -0.8072\n",
      "\n",
      "--------Epoch 199--------\n",
      "Val loss = -0.8067\n",
      "\n",
      "--------Epoch 200--------\n",
      "Val loss = -0.8081\n",
      "\n",
      "--------Epoch 201--------\n",
      "Val loss = -0.8122\n",
      "\n",
      "--------Epoch 202--------\n",
      "Val loss = -0.8061\n",
      "\n",
      "--------Epoch 203--------\n",
      "Val loss = -0.8091\n",
      "\n",
      "--------Epoch 204--------\n",
      "Val loss = -0.8083\n",
      "\n",
      "--------Epoch 205--------\n",
      "Val loss = -0.8106\n",
      "\n",
      "--------Epoch 206--------\n",
      "Val loss = -0.8051\n",
      "\n",
      "--------Epoch 207--------\n",
      "Val loss = -0.8074\n",
      "\n",
      "--------Epoch 208--------\n",
      "Val loss = -0.8095\n",
      "\n",
      "--------Epoch 209--------\n",
      "Val loss = -0.8081\n",
      "\n",
      "--------Epoch 210--------\n",
      "Val loss = -0.8097\n",
      "\n",
      "--------Epoch 211--------\n",
      "Val loss = -0.8082\n",
      "\n",
      "--------Epoch 212--------\n",
      "Val loss = -0.8085\n",
      "\n",
      "--------Epoch 213--------\n",
      "Val loss = -0.8085\n",
      "\n",
      "--------Epoch 214--------\n",
      "Val loss = -0.8082\n",
      "\n",
      "--------Epoch 215--------\n",
      "Val loss = -0.8074\n",
      "\n",
      "--------Epoch 216--------\n",
      "Val loss = -0.8069\n",
      "\n",
      "--------Epoch 217--------\n",
      "Val loss = -0.8049\n",
      "\n",
      "--------Epoch 218--------\n",
      "Val loss = -0.8069\n",
      "\n",
      "--------Epoch 219--------\n",
      "Val loss = -0.8079\n",
      "\n",
      "--------Epoch 220--------\n",
      "Val loss = -0.8090\n",
      "\n",
      "--------Epoch 221--------\n",
      "Val loss = -0.8072\n",
      "\n",
      "--------Epoch 222--------\n",
      "Val loss = -0.8084\n",
      "\n",
      "--------Epoch 223--------\n",
      "Val loss = -0.8060\n",
      "\n",
      "--------Epoch 224--------\n",
      "Val loss = -0.8075\n",
      "\n",
      "--------Epoch 225--------\n",
      "Val loss = -0.8023\n",
      "\n",
      "--------Epoch 226--------\n",
      "Val loss = -0.8087\n",
      "\n",
      "--------Epoch 227--------\n",
      "Val loss = -0.8094\n",
      "\n",
      "--------Epoch 228--------\n",
      "Val loss = -0.8061\n",
      "\n",
      "--------Epoch 229--------\n",
      "Val loss = -0.8033\n",
      "\n",
      "--------Epoch 230--------\n",
      "Val loss = -0.8092\n",
      "\n",
      "--------Epoch 231--------\n",
      "Val loss = -0.8081\n",
      "\n",
      "--------Epoch 232--------\n",
      "Val loss = -0.8097\n",
      "\n",
      "--------Epoch 233--------\n",
      "Val loss = -0.8060\n",
      "\n",
      "--------Epoch 234--------\n",
      "Val loss = -0.8102\n",
      "\n",
      "--------Epoch 235--------\n",
      "Val loss = -0.8077\n",
      "\n",
      "--------Epoch 236--------\n",
      "Val loss = -0.8068\n",
      "\n",
      "--------Epoch 237--------\n",
      "Val loss = -0.8094\n",
      "\n",
      "--------Epoch 238--------\n",
      "Val loss = -0.8072\n",
      "\n",
      "--------Epoch 239--------\n",
      "Val loss = -0.8084\n",
      "\n",
      "--------Epoch 240--------\n",
      "Val loss = -0.8129\n",
      "\n",
      "--------Epoch 241--------\n",
      "Val loss = -0.8069\n",
      "\n",
      "--------Epoch 242--------\n",
      "Val loss = -0.8086\n",
      "\n",
      "--------Epoch 243--------\n",
      "Val loss = -0.8068\n",
      "\n",
      "--------Epoch 244--------\n",
      "Val loss = -0.8066\n",
      "\n",
      "--------Epoch 245--------\n",
      "Val loss = -0.8101\n",
      "\n",
      "--------Epoch 246--------\n",
      "Val loss = -0.8098\n",
      "\n",
      "--------Epoch 247--------\n",
      "Val loss = -0.8089\n",
      "\n",
      "--------Epoch 248--------\n",
      "Val loss = -0.8099\n",
      "\n",
      "--------Epoch 249--------\n",
      "Val loss = -0.8060\n",
      "\n",
      "--------Epoch 250--------\n",
      "Val loss = -0.8091\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Set up selector: an MLP taking 2 * d_in inputs and producing d_in outputs\n",
    "selector = nn.Sequential(\n",
    "    nn.Linear(d_in * 2, 512),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(512, 512),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(512, d_in))\n",
    "gafs = GreedyAdaptiveFS(selector, deepcopy(model), mask_layer).to(device)\n",
    "\n",
    "# Tie weights by sharing the *same* Parameter objects. Re-wrapping with\n",
    "# nn.Parameter(...) creates new leaf tensors that alias the storage but get\n",
    "# separate gradients (and separate optimizer state), so the layers would not\n",
    "# be truly tied. Assigning the existing Parameter registers it in both modules;\n",
    "# nn.Module.parameters() de-duplicates shared Parameters (assuming selector and\n",
    "# model are registered submodules of gafs).\n",
    "selector[0].weight = gafs.model[0].weight\n",
    "selector[2].weight = gafs.model[2].weight\n",
    "\n",
    "# Train\n",
    "gafs.fit(train_dataset,\n",
    "         val_dataset,\n",
    "         mbsize=128,\n",
    "         lr=2e-4,\n",
    "         nepochs=250,\n",
    "         max_features=max_features,\n",
    "         loss_fn=nn.CrossEntropyLoss(),\n",
    "         val_loss_fn=NegAccuracy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "solved-richmond",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test acc = 82.24\n"
     ]
    }
   ],
   "source": [
    "# Test-set accuracy using up to max_features features (1024 is presumably the eval batch size)\n",
    "test_acc = gafs.evaluate(\n",
    "    test_dataset, max_features, Accuracy(), 1024)\n",
    "print('Test acc = {:.2f}'.format(100 * test_acc))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "coated-motor",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
