{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "frequent-bookmark",
   "metadata": {},
   "source": [
    "# Implementing $\\epsilon$-greedy\n",
    "\n",
    "- Inspired by a shaky understanding of contextual bandits, this version performs selection using a different mechanism altogether. The previous versions learn a selector model that outputs a probability distribution over the candidate features. In contrast, this one attempts to predict the loss associated with selecting each feature, and updates those predictions during training by observing the loss for the features that were actually selected. When a feature is selected, it is chosen uniformly at random with probability $\\epsilon$ (exploration), and otherwise, with probability $1 - \\epsilon$, the feature with the best (i.e., lowest) predicted loss is chosen (exploitation).\n",
    "- This version kind of works, actually."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "confident-waste",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "import torch.nn as nn\n",
    "import matplotlib.pyplot as plt\n",
    "from torchvision import transforms\n",
    "from torchvision.datasets import MNIST\n",
    "from copy import deepcopy\n",
    "from utils import *\n",
    "from models import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "clinical-wholesale",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compute device. GPU index 4 (the original hard-coded choice) is still\n",
    "# preferred, but a missing GPU no longer breaks the later `.to(device)` calls:\n",
    "# fall back to the first GPU when index 4 does not exist, and to the CPU when\n",
    "# CUDA is unavailable.\n",
    "preferred_gpu = 4\n",
    "if not torch.cuda.is_available():\n",
    "    device = torch.device('cpu')\n",
    "elif torch.cuda.device_count() > preferred_gpu:\n",
    "    device = torch.device('cuda', preferred_gpu)\n",
    "else:\n",
    "    device = torch.device('cuda', 0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "loose-royal",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load data\n",
    "class Flatten(object):\n",
    "    # Callable transform: collapses the tensor it receives into one dimension.\n",
    "    def __call__(self, pic):\n",
    "        return torch.flatten(pic)\n",
    "\n",
    "# One shared pipeline for every split: ToTensor followed by Flatten\n",
    "mnist_root = '/tmp/mnist/'\n",
    "to_flat_tensor = transforms.Compose([transforms.ToTensor(), Flatten()])\n",
    "\n",
    "mnist_dataset = MNIST(mnist_root, train=True, download=True, transform=to_flat_tensor)\n",
    "images = mnist_dataset.data\n",
    "targets = mnist_dataset.targets\n",
    "\n",
    "# Fixed-seed split: 10000 sorted indices are held out for validation and\n",
    "# every remaining index is used for training\n",
    "np.random.seed(0)\n",
    "num_examples = len(images)\n",
    "val_inds = np.sort(np.random.choice(num_examples, size=10000, replace=False))\n",
    "train_inds = np.setdiff1d(np.arange(num_examples), val_inds)\n",
    "\n",
    "# Training / validation views over the same underlying dataset\n",
    "train_dataset = torch.utils.data.Subset(mnist_dataset, train_inds)\n",
    "val_dataset = torch.utils.data.Subset(mnist_dataset, val_inds)\n",
    "\n",
    "# Test dataset (the train=False MNIST split)\n",
    "test_dataset = MNIST(mnist_root, train=False, download=True, transform=to_flat_tensor)\n",
    "\n",
    "# Set input/output dimensions\n",
    "d_in, d_out = 784, 10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "intelligent-federal",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of features to select.\n",
    "# Passed as the second argument to ConcreteMask when the selector is built.\n",
    "max_features = 10"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "adult-distinction",
   "metadata": {},
   "source": [
    "# Global FS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "close-intent",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--------Epoch 1--------\n",
      "Val loss = -0.2784\n",
      "\n",
      "--------Epoch 2--------\n",
      "Val loss = -0.3027\n",
      "\n",
      "--------Epoch 3--------\n",
      "Val loss = -0.3163\n",
      "\n",
      "--------Epoch 4--------\n",
      "Val loss = -0.3135\n",
      "\n",
      "--------Epoch 5--------\n",
      "Val loss = -0.3233\n",
      "\n",
      "--------Epoch 6--------\n",
      "Val loss = -0.3311\n",
      "\n",
      "--------Epoch 7--------\n",
      "Val loss = -0.3360\n",
      "\n",
      "--------Epoch 8--------\n",
      "Val loss = -0.3385\n",
      "\n",
      "--------Epoch 9--------\n",
      "Val loss = -0.3384\n",
      "\n",
      "--------Epoch 10--------\n",
      "Val loss = -0.3483\n",
      "\n",
      "--------Epoch 11--------\n",
      "Val loss = -0.3527\n",
      "\n",
      "--------Epoch 12--------\n",
      "Val loss = -0.3572\n",
      "\n",
      "--------Epoch 13--------\n",
      "Val loss = -0.3664\n",
      "\n",
      "--------Epoch 14--------\n",
      "Val loss = -0.3725\n",
      "\n",
      "--------Epoch 15--------\n",
      "Val loss = -0.3680\n",
      "\n",
      "--------Epoch 16--------\n",
      "Val loss = -0.3907\n",
      "\n",
      "--------Epoch 17--------\n",
      "Val loss = -0.3887\n",
      "\n",
      "--------Epoch 18--------\n",
      "Val loss = -0.3892\n",
      "\n",
      "--------Epoch 19--------\n",
      "Val loss = -0.3890\n",
      "\n",
      "--------Epoch 20--------\n",
      "Val loss = -0.3892\n",
      "\n",
      "--------Epoch 21--------\n",
      "Val loss = -0.3958\n",
      "\n",
      "--------Epoch 22--------\n",
      "Val loss = -0.4029\n",
      "\n",
      "--------Epoch 23--------\n",
      "Val loss = -0.4028\n",
      "\n",
      "--------Epoch 24--------\n",
      "Val loss = -0.4077\n",
      "\n",
      "--------Epoch 25--------\n",
      "Val loss = -0.4156\n",
      "\n",
      "--------Epoch 26--------\n",
      "Val loss = -0.4155\n",
      "\n",
      "--------Epoch 27--------\n",
      "Val loss = -0.4122\n",
      "\n",
      "--------Epoch 28--------\n",
      "Val loss = -0.4210\n",
      "\n",
      "--------Epoch 29--------\n",
      "Val loss = -0.4167\n",
      "\n",
      "--------Epoch 30--------\n",
      "Val loss = -0.4335\n",
      "\n",
      "--------Epoch 31--------\n",
      "Val loss = -0.4385\n",
      "\n",
      "--------Epoch 32--------\n",
      "Val loss = -0.4425\n",
      "\n",
      "--------Epoch 33--------\n",
      "Val loss = -0.4521\n",
      "\n",
      "--------Epoch 34--------\n",
      "Val loss = -0.4586\n",
      "\n",
      "--------Epoch 35--------\n",
      "Val loss = -0.4537\n",
      "\n",
      "--------Epoch 36--------\n",
      "Val loss = -0.4521\n",
      "\n",
      "--------Epoch 37--------\n",
      "Val loss = -0.4721\n",
      "\n",
      "--------Epoch 38--------\n",
      "Val loss = -0.4660\n",
      "\n",
      "--------Epoch 39--------\n",
      "Val loss = -0.4641\n",
      "\n",
      "--------Epoch 40--------\n",
      "Val loss = -0.4746\n",
      "\n",
      "--------Epoch 41--------\n",
      "Val loss = -0.4764\n",
      "\n",
      "--------Epoch 42--------\n",
      "Val loss = -0.4829\n",
      "\n",
      "--------Epoch 43--------\n",
      "Val loss = -0.4895\n",
      "\n",
      "--------Epoch 44--------\n",
      "Val loss = -0.4829\n",
      "\n",
      "--------Epoch 45--------\n",
      "Val loss = -0.4968\n",
      "\n",
      "--------Epoch 46--------\n",
      "Val loss = -0.4955\n",
      "\n",
      "--------Epoch 47--------\n",
      "Val loss = -0.4955\n",
      "\n",
      "--------Epoch 48--------\n",
      "Val loss = -0.5045\n",
      "\n",
      "--------Epoch 49--------\n",
      "Val loss = -0.4973\n",
      "\n",
      "--------Epoch 50--------\n",
      "Val loss = -0.5012\n",
      "\n",
      "--------Epoch 51--------\n",
      "Val loss = -0.5029\n",
      "\n",
      "--------Epoch 52--------\n",
      "Val loss = -0.5164\n",
      "\n",
      "--------Epoch 53--------\n",
      "Val loss = -0.5193\n",
      "\n",
      "--------Epoch 54--------\n",
      "Val loss = -0.5162\n",
      "\n",
      "--------Epoch 55--------\n",
      "Val loss = -0.5263\n",
      "\n",
      "--------Epoch 56--------\n",
      "Val loss = -0.5233\n",
      "\n",
      "--------Epoch 57--------\n",
      "Val loss = -0.5297\n",
      "\n",
      "--------Epoch 58--------\n",
      "Val loss = -0.5343\n",
      "\n",
      "--------Epoch 59--------\n",
      "Val loss = -0.5290\n",
      "\n",
      "--------Epoch 60--------\n",
      "Val loss = -0.5391\n",
      "\n",
      "--------Epoch 61--------\n",
      "Val loss = -0.5441\n",
      "\n",
      "--------Epoch 62--------\n",
      "Val loss = -0.5421\n",
      "\n",
      "--------Epoch 63--------\n",
      "Val loss = -0.5548\n",
      "\n",
      "--------Epoch 64--------\n",
      "Val loss = -0.5471\n",
      "\n",
      "--------Epoch 65--------\n",
      "Val loss = -0.5542\n",
      "\n",
      "--------Epoch 66--------\n",
      "Val loss = -0.5498\n",
      "\n",
      "--------Epoch 67--------\n",
      "Val loss = -0.5571\n",
      "\n",
      "--------Epoch 68--------\n",
      "Val loss = -0.5630\n",
      "\n",
      "--------Epoch 69--------\n",
      "Val loss = -0.5636\n",
      "\n",
      "--------Epoch 70--------\n",
      "Val loss = -0.5668\n",
      "\n",
      "--------Epoch 71--------\n",
      "Val loss = -0.5831\n",
      "\n",
      "--------Epoch 72--------\n",
      "Val loss = -0.5773\n",
      "\n",
      "--------Epoch 73--------\n",
      "Val loss = -0.5805\n",
      "\n",
      "--------Epoch 74--------\n",
      "Val loss = -0.5752\n",
      "\n",
      "--------Epoch 75--------\n",
      "Val loss = -0.5857\n",
      "\n",
      "--------Epoch 76--------\n",
      "Val loss = -0.5873\n",
      "\n",
      "--------Epoch 77--------\n",
      "Val loss = -0.5860\n",
      "\n",
      "--------Epoch 78--------\n",
      "Val loss = -0.5881\n",
      "\n",
      "--------Epoch 79--------\n",
      "Val loss = -0.5890\n",
      "\n",
      "--------Epoch 80--------\n",
      "Val loss = -0.5976\n",
      "\n",
      "--------Epoch 81--------\n",
      "Val loss = -0.6014\n",
      "\n",
      "--------Epoch 82--------\n",
      "Val loss = -0.5970\n",
      "\n",
      "--------Epoch 83--------\n",
      "Val loss = -0.6078\n",
      "\n",
      "--------Epoch 84--------\n",
      "Val loss = -0.5973\n",
      "\n",
      "--------Epoch 85--------\n",
      "Val loss = -0.6001\n",
      "\n",
      "--------Epoch 86--------\n",
      "Val loss = -0.6065\n",
      "\n",
      "--------Epoch 87--------\n",
      "Val loss = -0.6104\n",
      "\n",
      "--------Epoch 88--------\n",
      "Val loss = -0.6076\n",
      "\n",
      "--------Epoch 89--------\n",
      "Val loss = -0.6174\n",
      "\n",
      "--------Epoch 90--------\n",
      "Val loss = -0.6211\n",
      "\n",
      "--------Epoch 91--------\n",
      "Val loss = -0.6237\n",
      "\n",
      "--------Epoch 92--------\n",
      "Val loss = -0.6179\n",
      "\n",
      "--------Epoch 93--------\n",
      "Val loss = -0.6230\n",
      "\n",
      "--------Epoch 94--------\n",
      "Val loss = -0.6272\n",
      "\n",
      "--------Epoch 95--------\n",
      "Val loss = -0.6279\n",
      "\n",
      "--------Epoch 96--------\n",
      "Val loss = -0.6319\n",
      "\n",
      "--------Epoch 97--------\n",
      "Val loss = -0.6308\n",
      "\n",
      "--------Epoch 98--------\n",
      "Val loss = -0.6394\n",
      "\n",
      "--------Epoch 99--------\n",
      "Val loss = -0.6314\n",
      "\n",
      "--------Epoch 100--------\n",
      "Val loss = -0.6386\n",
      "\n",
      "--------Epoch 101--------\n",
      "Val loss = -0.6344\n",
      "\n",
      "--------Epoch 102--------\n",
      "Val loss = -0.6460\n",
      "\n",
      "--------Epoch 103--------\n",
      "Val loss = -0.6413\n",
      "\n",
      "--------Epoch 104--------\n",
      "Val loss = -0.6449\n",
      "\n",
      "--------Epoch 105--------\n",
      "Val loss = -0.6454\n",
      "\n",
      "--------Epoch 106--------\n",
      "Val loss = -0.6518\n",
      "\n",
      "--------Epoch 107--------\n",
      "Val loss = -0.6510\n",
      "\n",
      "--------Epoch 108--------\n",
      "Val loss = -0.6643\n",
      "\n",
      "--------Epoch 109--------\n",
      "Val loss = -0.6637\n",
      "\n",
      "--------Epoch 110--------\n",
      "Val loss = -0.6637\n",
      "\n",
      "--------Epoch 111--------\n",
      "Val loss = -0.6703\n",
      "\n",
      "--------Epoch 112--------\n",
      "Val loss = -0.6676\n",
      "\n",
      "--------Epoch 113--------\n",
      "Val loss = -0.6645\n",
      "\n",
      "--------Epoch 114--------\n",
      "Val loss = -0.6679\n",
      "\n",
      "--------Epoch 115--------\n",
      "Val loss = -0.6744\n",
      "\n",
      "--------Epoch 116--------\n",
      "Val loss = -0.6808\n",
      "\n",
      "--------Epoch 117--------\n",
      "Val loss = -0.6833\n",
      "\n",
      "--------Epoch 118--------\n",
      "Val loss = -0.6864\n",
      "\n",
      "--------Epoch 119--------\n",
      "Val loss = -0.6874\n",
      "\n",
      "--------Epoch 120--------\n",
      "Val loss = -0.6904\n",
      "\n",
      "--------Epoch 121--------\n",
      "Val loss = -0.7023\n",
      "\n",
      "--------Epoch 122--------\n",
      "Val loss = -0.6994\n",
      "\n",
      "--------Epoch 123--------\n",
      "Val loss = -0.7032\n",
      "\n",
      "--------Epoch 124--------\n",
      "Val loss = -0.7023\n",
      "\n",
      "--------Epoch 125--------\n",
      "Val loss = -0.7022\n",
      "\n",
      "--------Epoch 126--------\n",
      "Val loss = -0.7119\n",
      "\n",
      "--------Epoch 127--------\n",
      "Val loss = -0.7087\n",
      "\n",
      "--------Epoch 128--------\n",
      "Val loss = -0.7163\n",
      "\n",
      "--------Epoch 129--------\n",
      "Val loss = -0.7172\n",
      "\n",
      "--------Epoch 130--------\n",
      "Val loss = -0.7254\n",
      "\n",
      "--------Epoch 131--------\n",
      "Val loss = -0.7179\n",
      "\n",
      "--------Epoch 132--------\n",
      "Val loss = -0.7208\n",
      "\n",
      "--------Epoch 133--------\n",
      "Val loss = -0.7241\n",
      "\n",
      "--------Epoch 134--------\n",
      "Val loss = -0.7209\n",
      "\n",
      "--------Epoch 135--------\n",
      "Val loss = -0.7370\n",
      "\n",
      "--------Epoch 136--------\n",
      "Val loss = -0.7313\n",
      "\n",
      "--------Epoch 137--------\n",
      "Val loss = -0.7375\n",
      "\n",
      "--------Epoch 138--------\n",
      "Val loss = -0.7392\n",
      "\n",
      "--------Epoch 139--------\n",
      "Val loss = -0.7361\n",
      "\n",
      "--------Epoch 140--------\n",
      "Val loss = -0.7336\n",
      "\n",
      "--------Epoch 141--------\n",
      "Val loss = -0.7424\n",
      "\n",
      "--------Epoch 142--------\n",
      "Val loss = -0.7446\n",
      "\n",
      "--------Epoch 143--------\n",
      "Val loss = -0.7476\n",
      "\n",
      "--------Epoch 144--------\n",
      "Val loss = -0.7443\n",
      "\n",
      "--------Epoch 145--------\n",
      "Val loss = -0.7492\n",
      "\n",
      "--------Epoch 146--------\n",
      "Val loss = -0.7484\n",
      "\n",
      "--------Epoch 147--------\n",
      "Val loss = -0.7477\n",
      "\n",
      "--------Epoch 148--------\n",
      "Val loss = -0.7531\n",
      "\n",
      "--------Epoch 149--------\n",
      "Val loss = -0.7550\n",
      "\n",
      "--------Epoch 150--------\n",
      "Val loss = -0.7459\n",
      "\n",
      "--------Epoch 151--------\n",
      "Val loss = -0.7557\n",
      "\n",
      "--------Epoch 152--------\n",
      "Val loss = -0.7596\n",
      "\n",
      "--------Epoch 153--------\n",
      "Val loss = -0.7547\n",
      "\n",
      "--------Epoch 154--------\n",
      "Val loss = -0.7558\n",
      "\n",
      "--------Epoch 155--------\n",
      "Val loss = -0.7533\n",
      "\n",
      "--------Epoch 156--------\n",
      "Val loss = -0.7594\n",
      "\n",
      "--------Epoch 157--------\n",
      "Val loss = -0.7555\n",
      "\n",
      "--------Epoch 158--------\n",
      "Val loss = -0.7601\n",
      "\n",
      "--------Epoch 159--------\n",
      "Val loss = -0.7571\n",
      "\n",
      "--------Epoch 160--------\n",
      "Val loss = -0.7569\n",
      "\n",
      "--------Epoch 161--------\n",
      "Val loss = -0.7645\n",
      "\n",
      "--------Epoch 162--------\n",
      "Val loss = -0.7643\n",
      "\n",
      "--------Epoch 163--------\n",
      "Val loss = -0.7651\n",
      "\n",
      "--------Epoch 164--------\n",
      "Val loss = -0.7617\n",
      "\n",
      "--------Epoch 165--------\n",
      "Val loss = -0.7621\n",
      "\n",
      "--------Epoch 166--------\n",
      "Val loss = -0.7577\n",
      "\n",
      "--------Epoch 167--------\n",
      "Val loss = -0.7614\n",
      "\n",
      "--------Epoch 168--------\n",
      "Val loss = -0.7680\n",
      "\n",
      "--------Epoch 169--------\n",
      "Val loss = -0.7664\n",
      "\n",
      "--------Epoch 170--------\n",
      "Val loss = -0.7599\n",
      "\n",
      "--------Epoch 171--------\n",
      "Val loss = -0.7698\n",
      "\n",
      "--------Epoch 172--------\n",
      "Val loss = -0.7665\n",
      "\n",
      "--------Epoch 173--------\n",
      "Val loss = -0.7662\n",
      "\n",
      "--------Epoch 174--------\n",
      "Val loss = -0.7671\n",
      "\n",
      "--------Epoch 175--------\n",
      "Val loss = -0.7698\n",
      "\n",
      "--------Epoch 176--------\n",
      "Val loss = -0.7683\n",
      "\n",
      "--------Epoch 177--------\n",
      "Val loss = -0.7670\n",
      "\n",
      "--------Epoch 178--------\n",
      "Val loss = -0.7708\n",
      "\n",
      "--------Epoch 179--------\n",
      "Val loss = -0.7715\n",
      "\n",
      "--------Epoch 180--------\n",
      "Val loss = -0.7667\n",
      "\n",
      "--------Epoch 181--------\n",
      "Val loss = -0.7700\n",
      "\n",
      "--------Epoch 182--------\n",
      "Val loss = -0.7649\n",
      "\n",
      "--------Epoch 183--------\n",
      "Val loss = -0.7678\n",
      "\n",
      "--------Epoch 184--------\n",
      "Val loss = -0.7726\n",
      "\n",
      "--------Epoch 185--------\n",
      "Val loss = -0.7715\n",
      "\n",
      "--------Epoch 186--------\n",
      "Val loss = -0.7687\n",
      "\n",
      "--------Epoch 187--------\n",
      "Val loss = -0.7702\n",
      "\n",
      "--------Epoch 188--------\n",
      "Val loss = -0.7714\n",
      "\n",
      "--------Epoch 189--------\n",
      "Val loss = -0.7716\n",
      "\n",
      "--------Epoch 190--------\n",
      "Val loss = -0.7752\n",
      "\n",
      "--------Epoch 191--------\n",
      "Val loss = -0.7674\n",
      "\n",
      "--------Epoch 192--------\n",
      "Val loss = -0.7666\n",
      "\n",
      "--------Epoch 193--------\n",
      "Val loss = -0.7711\n",
      "\n",
      "--------Epoch 194--------\n",
      "Val loss = -0.7697\n",
      "\n",
      "--------Epoch 195--------\n",
      "Val loss = -0.7713\n",
      "\n",
      "--------Epoch 196--------\n",
      "Val loss = -0.7722\n",
      "\n",
      "--------Epoch 197--------\n",
      "Val loss = -0.7679\n",
      "\n",
      "--------Epoch 198--------\n",
      "Val loss = -0.7727\n",
      "\n",
      "--------Epoch 199--------\n",
      "Val loss = -0.7715\n",
      "\n",
      "--------Epoch 200--------\n",
      "Val loss = -0.7702\n",
      "\n",
      "--------Epoch 201--------\n",
      "Val loss = -0.7697\n",
      "\n",
      "--------Epoch 202--------\n",
      "Val loss = -0.7641\n",
      "\n",
      "--------Epoch 203--------\n",
      "Val loss = -0.7673\n",
      "\n",
      "--------Epoch 204--------\n",
      "Val loss = -0.7722\n",
      "\n",
      "--------Epoch 205--------\n",
      "Val loss = -0.7680\n",
      "\n",
      "--------Epoch 206--------\n",
      "Val loss = -0.7697\n",
      "\n",
      "--------Epoch 207--------\n",
      "Val loss = -0.7655\n",
      "\n",
      "--------Epoch 208--------\n",
      "Val loss = -0.7690\n",
      "\n",
      "--------Epoch 209--------\n",
      "Val loss = -0.7661\n",
      "\n",
      "--------Epoch 210--------\n",
      "Val loss = -0.7697\n",
      "\n",
      "--------Epoch 211--------\n",
      "Val loss = -0.7680\n",
      "\n",
      "--------Epoch 212--------\n",
      "Val loss = -0.7696\n",
      "\n",
      "--------Epoch 213--------\n",
      "Val loss = -0.7702\n",
      "\n",
      "--------Epoch 214--------\n",
      "Val loss = -0.7702\n",
      "\n",
      "--------Epoch 215--------\n",
      "Val loss = -0.7649\n",
      "\n",
      "--------Epoch 216--------\n",
      "Val loss = -0.7694\n",
      "\n",
      "--------Epoch 217--------\n",
      "Val loss = -0.7644\n",
      "\n",
      "--------Epoch 218--------\n",
      "Val loss = -0.7637\n",
      "\n",
      "--------Epoch 219--------\n",
      "Val loss = -0.7652\n",
      "\n",
      "--------Epoch 220--------\n",
      "Val loss = -0.7619\n",
      "\n",
      "--------Epoch 221--------\n",
      "Val loss = -0.7651\n",
      "\n",
      "--------Epoch 222--------\n",
      "Val loss = -0.7662\n",
      "\n",
      "--------Epoch 223--------\n",
      "Val loss = -0.7618\n",
      "\n",
      "--------Epoch 224--------\n",
      "Val loss = -0.7597\n",
      "\n",
      "--------Epoch 225--------\n",
      "Val loss = -0.7637\n",
      "\n",
      "--------Epoch 226--------\n",
      "Val loss = -0.7593\n",
      "\n",
      "--------Epoch 227--------\n",
      "Val loss = -0.7636\n",
      "\n",
      "--------Epoch 228--------\n",
      "Val loss = -0.7595\n",
      "\n",
      "--------Epoch 229--------\n",
      "Val loss = -0.7605\n",
      "\n",
      "--------Epoch 230--------\n",
      "Val loss = -0.7624\n",
      "\n",
      "--------Epoch 231--------\n",
      "Val loss = -0.7619\n",
      "\n",
      "--------Epoch 232--------\n",
      "Val loss = -0.7577\n",
      "\n",
      "--------Epoch 233--------\n",
      "Val loss = -0.7576\n",
      "\n",
      "--------Epoch 234--------\n",
      "Val loss = -0.7611\n",
      "\n",
      "--------Epoch 235--------\n",
      "Val loss = -0.7623\n",
      "\n",
      "--------Epoch 236--------\n",
      "Val loss = -0.7595\n",
      "\n",
      "--------Epoch 237--------\n",
      "Val loss = -0.7576\n",
      "\n",
      "--------Epoch 238--------\n",
      "Val loss = -0.7573\n",
      "\n",
      "--------Epoch 239--------\n",
      "Val loss = -0.7533\n",
      "\n",
      "--------Epoch 240--------\n",
      "Val loss = -0.7593\n",
      "\n",
      "--------Epoch 241--------\n",
      "Val loss = -0.7569\n",
      "\n",
      "--------Epoch 242--------\n",
      "Val loss = -0.7515\n",
      "\n",
      "--------Epoch 243--------\n",
      "Val loss = -0.7569\n",
      "\n",
      "--------Epoch 244--------\n",
      "Val loss = -0.7563\n",
      "\n",
      "--------Epoch 245--------\n",
      "Val loss = -0.7564\n",
      "\n",
      "--------Epoch 246--------\n",
      "Val loss = -0.7576\n",
      "\n",
      "--------Epoch 247--------\n",
      "Val loss = -0.7576\n",
      "\n",
      "--------Epoch 248--------\n",
      "Val loss = -0.7483\n",
      "\n",
      "--------Epoch 249--------\n",
      "Val loss = -0.7524\n",
      "\n",
      "--------Epoch 250--------\n",
      "Val loss = -0.7518\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Seed PyTorch before any weights are created. NumPy is already seeded for\n",
    "# the train/val split, but PyTorch was not, so weight initialisation (and any\n",
    "# torch-driven shuffling inside fit) was not repeatable between runs.\n",
    "# NOTE(review): some CUDA kernels are non-deterministic, so GPU runs may still\n",
    "# vary slightly.\n",
    "torch.manual_seed(0)\n",
    "\n",
    "# Set up model\n",
    "# NOTE(review): the first layer is d_in * 2 wide, presumably because\n",
    "# append=True makes ConcreteMask concatenate the mask to its output - confirm\n",
    "# against models.py.\n",
    "global_model = nn.Sequential(\n",
    "    nn.Linear(d_in * 2, 512),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(512, 512),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(512, d_out))\n",
    "selector = ConcreteMask(d_in, max_features, append=True)\n",
    "globalfs = GlobalSelector(global_model, selector).to(device)\n",
    "\n",
    "# Train\n",
    "# val_loss_fn=NegAccuracy() is presumably why the logged 'Val loss' values are\n",
    "# negative (more negative = more accurate) - confirm against utils.py.\n",
    "# NOTE(review): start_temp/end_temp look like a temperature schedule for the\n",
    "# selector - confirm against GlobalSelector.fit.\n",
    "globalfs.fit(train_dataset,\n",
    "             val_dataset,\n",
    "             mbsize=128,\n",
    "             lr=1e-3,\n",
    "             nepochs=250,\n",
    "             loss_fn=nn.CrossEntropyLoss(),\n",
    "             val_loss_fn=NegAccuracy(),\n",
    "             start_temp=1.0,\n",
    "             end_temp=0.01)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "entertaining-charleston",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test acc = 77.86\n"
     ]
    }
   ],
   "source": [
    "# Score the trained global selector on the MNIST test split.\n",
    "# NOTE(review): the third positional argument (1024) is presumably the\n",
    "# evaluation batch size - confirm against GlobalSelector.evaluate.\n",
    "accuracy_metric = Accuracy()\n",
    "test_acc = globalfs.evaluate(test_dataset, accuracy_metric, 1024)\n",
    "print(f'Test acc = {100*test_acc:.2f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "beneficial-shadow",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAK3klEQVR4nO3dT6hc93mH8edbV1ZASUCqa6M6bpMGL2oKVcpFKbgUF9PU8cbOIiVeBBUCyiKGBLKoSRfx0pQmoYsSUGoRtaQOgcS1FqaJEAGTjfG1UW25amvXqIkiITVoYadQWXbeLu5xuZHvP8+c+SO/zwcuM3Nm7p2XQY/O3Dkz95eqQtK7368segBJ82HsUhPGLjVh7FITxi418avzvLMbs7vew5553qXUyv/yP7xeV7LRdVPFnuQe4G+AG4C/q6pHtrr9e9jDR3P3NHcpaQtP18lNr5v4aXySG4C/BT4O3AE8kOSOSX+epNma5nf2g8DLVfVKVb0OfBu4b5yxJI1tmthvBX6y7vK5YdsvSXI4yWqS1atcmeLuJE1jmtg3ehHgbe+9raojVbVSVSu72D3F3UmaxjSxnwNuW3f5A8D56caRNCvTxP4McHuSDyW5EfgUcHycsSSNbeJDb1X1RpIHge+zdujtaFW9ONpkkkY11XH2qnoSeHKkWSTNkG+XlZowdqkJY5eaMHapCWOXmjB2qQljl5owdqkJY5eaMHapCWOXmjB2qQljl5owdqkJY5eaMHapCWOXmjB2qQljl5owdqkJY5eaMHapCWOXmjB2qQljl5owdqkJY5eaMHapCWOXmjB2qYmplmxOchZ4DXgTeKOqVsYYStL4pop98MdV9bMRfo6kGfJpvNTEtLEX8IMkzyY5vNENkhxOsppk9SpXprw7SZOa9mn8nVV1PsnNwIkk/1ZVT62/QVUdAY4AvD/7asr7kzShqfbsVXV+OL0EPA4cHGMoSeObOPYke5K8763zwMeA02MNJmlc0zyNvwV4PMlbP+cfq+qfR5lK7xrfP39q0+v+9DcOzG0OTRF7Vb0C/N6Is0iaIQ+9SU0Yu9SEsUtNGLvUhLFLTYzxQRhpUx5eWx7u2aUmjF1qwtilJoxdasLYpSaMXWrC2KUmjF1qwtilJoxdasLYpSaMXWrC2KUmjF1qwtilJvw8u7a01Z+CBj+vfj1xzy41YexSE8YuNWHsUhPGLjVh7FITxi414XF2bWna4+izXLLZ9wC8M9vu2ZMcTXIpyel12/YlOZHkpeF072zHlDStnTyN/yZwzzXbHgJOVtXtwMnhsqQltm3sVfUUcPmazfcBx4bzx4D7xx1L0tgmfYHulqq6ADCc3rzZDZMcTrKaZPUqVya8O0nTmvmr8VV1pKpWqmplF7tnfXeSNjFp7BeT7AcYTi+NN5KkWZg09uPAoeH8IeCJccaRNCupqq1vkDwG3AXcBFwEvgz8E/Ad4DeBHwOfrKprX8R7m/dnX300d083cUMeT9ZOPV0nebUuZ6Prtn1TTVU9sMlVVitdR3y7rNSEsUtNGLvUhLFLTRi71IQfcb0OeGhNY3DPLjVh7FITxi41YexSE8YuNWHsUhPGLjVh7FITxi41YexSE8YuNWHsUhPGLjVh7FITxi414efZtaVZ/hlr/0T2fLlnl5owdqkJY5eaMHapCWOXmjB2qQljl5rwOLu2NMtj3R5Hn69t9+xJjia5lOT0um0PJ/lpklPD172zHVPStHbyNP6bwD0bbP9aVR0Yvp4cdyxJY9s29qp6Crg8h1kkzdA0L9A9mOT54Wn+3s1ulORwktUkq1e5MsXdSZrGpLF/HfgwcAC4AHxlsxtW1ZGqWqmqlV3snvDuJE1rotir6mJVvVlVvwC+ARwcdyxJY5so9iT71138BHB6s9tKWg7bHmdP8hhwF3BTknPAl4G7khwACjgLfHZ2I0oaw7axV9UDG2x+dAazSJoh3y4rNWHsUhPGLjVh7FIT
xi414Udcm/PPOffhnl1qwtilJoxdasLYpSaMXWrC2KUmjF1qwuPszXkcvQ/37FITxi41YexSE8YuNWHsUhPGLjVh7FITHme/DviZc43BPbvUhLFLTRi71ISxS00Yu9SEsUtNGLvUhMfZrwMeR9cYtt2zJ7ktyQ+TnEnyYpLPD9v3JTmR5KXhdO/sx5U0qZ08jX8D+GJV/Q7wB8DnktwBPAScrKrbgZPDZUlLatvYq+pCVT03nH8NOAPcCtwHHBtudgy4f0YzShrBO3qBLskHgY8ATwO3VNUFWPsPAbh5k+85nGQ1yepVrkw5rqRJ7Tj2JO8Fvgt8oape3en3VdWRqlqpqpVd7J5kRkkj2FHsSXaxFvq3qup7w+aLSfYP1+8HLs1mRElj2PbQW5IAjwJnquqr6646DhwCHhlOn5jJhO8CfkRVy2Anx9nvBD4NvJDk1LDtS6xF/p0knwF+DHxyJhNKGsW2sVfVj4BscvXd444jaVZ8u6zUhLFLTRi71ISxS00Yu9SEH3GdA4+jaxm4Z5eaMHapCWOXmjB2qQljl5owdqkJY5eaMHapCWOXmjB2qQljl5owdqkJY5eaMHapCWOXmjB2qQljl5owdqkJY5eaMHapCWOXmjB2qQljl5rYNvYktyX5YZIzSV5M8vlh+8NJfprk1PB17+zHlTSpnSwS8Qbwxap6Lsn7gGeTnBiu+1pV/fXsxpM0lp2sz34BuDCcfy3JGeDWWQ8maVzv6Hf2JB8EPgI8PWx6MMnzSY4m2bvJ9xxOsppk9SpXpptW0sR2HHuS9wLfBb5QVa8CXwc+DBxgbc//lY2+r6qOVNVKVa3sYvf0E0uayI5iT7KLtdC/VVXfA6iqi1X1ZlX9AvgGcHB2Y0qa1k5ejQ/wKHCmqr66bvv+dTf7BHB6/PEkjWUnr8bfCXwaeCHJqWHbl4AHkhwACjgLfHYG80kayU5ejf8RkA2uenL8cSTNiu+gk5owdqkJY5eaMHapCWOXmjB2qQljl5owdqkJY5eaMHapCWOXmjB2qQljl5owdqmJVNX87iz5b+C/1m26CfjZ3AZ4Z5Z1tmWdC5xtUmPO9ltV9esbXTHX2N9258lqVa0sbIAtLOtsyzoXONuk5jWbT+OlJoxdamLRsR9Z8P1vZVlnW9a5wNkmNZfZFvo7u6T5WfSeXdKcGLvUxEJiT3JPkn9P8nKShxYxw2aSnE3ywrAM9eqCZzma5FKS0+u27UtyIslLw+mGa+wtaLalWMZ7i2XGF/rYLXr587n/zp7kBuA/gD8BzgHPAA9U1b/OdZBNJDkLrFTVwt+AkeSPgJ8Df19Vvzts+yvgclU9MvxHubeq/mJJZnsY+Pmil/EeVivav36ZceB+4M9Z4GO3xVx/xhwet0Xs2Q8CL1fVK1X1OvBt4L4FzLH0quop4PI1m+8Djg3nj7H2j2XuNpltKVTVhap6bjj/GvDWMuMLfey2mGsuFhH7rcBP1l0+x3Kt917AD5I8m+TwoofZwC1VdQHW/vEANy94nmttu4z3PF2zzPjSPHaTLH8+rUXEvtFSUst0/O/Oqvp94OPA54anq9qZHS3jPS8bLDO+FCZd/nxai4j9HHDbussfAM4vYI4NVdX54fQS8DjLtxT1xbdW0B1OLy14nv+3TMt4b7TMOEvw2C1y+fNFxP4McHuSDyW5EfgUcHwBc7xNkj3DCyck2QN8jOVbivo4cGg4fwh4YoGz/JJlWcZ7s2XGWfBjt/Dlz6tq7l/Avay9Iv+fwF8uYoZN5vpt4F+GrxcXPRvwGGtP666y9ozoM8CvASeBl4bTfUs02z8ALwDPsxbW/gXN9oes/Wr4PHBq+Lp30Y/dFnPN5XHz7bJSE76DTmrC2KUmjF1qwtilJoxdasLYpSaMXWri/wCKFG0rCtAxvgAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Visualise which input pixels the global selector picked.\n",
    "# NOTE(review): assumes selector.logits has one row per selected feature and\n",
    "# one column per input pixel, so argmax over dim=1 yields pixel indices - confirm.\n",
    "# detach() replaces the legacy `.data` attribute for leaving the autograd graph.\n",
    "inds = selector.logits.argmax(dim=1).detach().cpu().numpy()\n",
    "\n",
    "# Binary mask over the d_in pixels (1 = selected); sized from d_in instead of a\n",
    "# repeated literal 784. The variable names `inds` and `zeros` are kept unchanged.\n",
    "zeros = np.zeros(d_in)\n",
    "zeros[inds] = 1\n",
    "\n",
    "# Images are square, so the side length is sqrt(d_in) (28 when d_in is 784).\n",
    "side = int(round(d_in ** 0.5))\n",
    "plt.figure()\n",
    "plt.imshow(zeros.reshape(side, side))\n",
    "plt.title('Pixels selected by the global selector')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "speaking-radiation",
   "metadata": {},
   "source": [
    "# Pretrain model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "indonesian-reader",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--------Epoch 1--------\n",
      "Val loss = -0.2443\n",
      "\n",
      "--------Epoch 2--------\n",
      "Val loss = -0.2668\n",
      "\n",
      "--------Epoch 3--------\n",
      "Val loss = -0.2833\n",
      "\n",
      "--------Epoch 4--------\n",
      "Val loss = -0.2905\n",
      "\n",
      "--------Epoch 5--------\n",
      "Val loss = -0.2966\n",
      "\n",
      "--------Epoch 6--------\n",
      "Val loss = -0.2956\n",
      "\n",
      "--------Epoch 7--------\n",
      "Val loss = -0.2972\n",
      "\n",
      "--------Epoch 8--------\n",
      "Val loss = -0.3093\n",
      "\n",
      "--------Epoch 9--------\n",
      "Val loss = -0.2995\n",
      "\n",
      "--------Epoch 10--------\n",
      "Val loss = -0.3049\n",
      "\n",
      "--------Epoch 11--------\n",
      "Val loss = -0.3048\n",
      "\n",
      "--------Epoch 12--------\n",
      "Val loss = -0.3075\n",
      "\n",
      "--------Epoch 13--------\n",
      "Val loss = -0.3029\n",
      "\n",
      "--------Epoch 14--------\n",
      "Val loss = -0.3082\n",
      "\n",
      "--------Epoch 15--------\n",
      "Val loss = -0.3174\n",
      "\n",
      "--------Epoch 16--------\n",
      "Val loss = -0.3150\n",
      "\n",
      "--------Epoch 17--------\n",
      "Val loss = -0.3068\n",
      "\n",
      "--------Epoch 18--------\n",
      "Val loss = -0.3067\n",
      "\n",
      "--------Epoch 19--------\n",
      "Val loss = -0.3230\n",
      "\n",
      "--------Epoch 20--------\n",
      "Val loss = -0.3173\n",
      "\n",
      "--------Epoch 21--------\n",
      "Val loss = -0.3139\n",
      "\n",
      "--------Epoch 22--------\n",
      "Val loss = -0.3099\n",
      "\n",
      "--------Epoch 23--------\n",
      "Val loss = -0.3161\n",
      "\n",
      "--------Epoch 24--------\n",
      "Val loss = -0.3165\n",
      "\n",
      "--------Epoch 25--------\n",
      "Val loss = -0.3138\n",
      "\n",
      "--------Epoch 26--------\n",
      "Val loss = -0.3140\n",
      "\n",
      "--------Epoch 27--------\n",
      "Val loss = -0.3205\n",
      "\n",
      "--------Epoch 28--------\n",
      "Val loss = -0.3147\n",
      "\n",
      "--------Epoch 29--------\n",
      "Val loss = -0.3177\n",
      "\n",
      "--------Epoch 30--------\n",
      "Val loss = -0.3092\n",
      "\n",
      "--------Epoch 31--------\n",
      "Val loss = -0.3181\n",
      "\n",
      "--------Epoch 32--------\n",
      "Val loss = -0.3161\n",
      "\n",
      "--------Epoch 33--------\n",
      "Val loss = -0.3163\n",
      "\n",
      "--------Epoch 34--------\n",
      "Val loss = -0.3213\n",
      "\n",
      "--------Epoch 35--------\n",
      "Val loss = -0.3273\n",
      "\n",
      "--------Epoch 36--------\n",
      "Val loss = -0.3242\n",
      "\n",
      "--------Epoch 37--------\n",
      "Val loss = -0.3268\n",
      "\n",
      "--------Epoch 38--------\n",
      "Val loss = -0.3231\n",
      "\n",
      "--------Epoch 39--------\n",
      "Val loss = -0.3210\n",
      "\n",
      "--------Epoch 40--------\n",
      "Val loss = -0.3259\n",
      "\n",
      "--------Epoch 41--------\n",
      "Val loss = -0.3255\n",
      "\n",
      "--------Epoch 42--------\n",
      "Val loss = -0.3280\n",
      "\n",
      "--------Epoch 43--------\n",
      "Val loss = -0.3279\n",
      "\n",
      "--------Epoch 44--------\n",
      "Val loss = -0.3316\n",
      "\n",
      "--------Epoch 45--------\n",
      "Val loss = -0.3174\n",
      "\n",
      "--------Epoch 46--------\n",
      "Val loss = -0.3300\n",
      "\n",
      "--------Epoch 47--------\n",
      "Val loss = -0.3278\n",
      "\n",
      "--------Epoch 48--------\n",
      "Val loss = -0.3262\n",
      "\n",
      "--------Epoch 49--------\n",
      "Val loss = -0.3370\n",
      "\n",
      "--------Epoch 50--------\n",
      "Val loss = -0.3289\n",
      "\n",
      "--------Epoch 51--------\n",
      "Val loss = -0.3289\n",
      "\n",
      "--------Epoch 52--------\n",
      "Val loss = -0.3176\n",
      "\n",
      "--------Epoch 53--------\n",
      "Val loss = -0.3322\n",
      "\n",
      "--------Epoch 54--------\n",
      "Val loss = -0.3316\n",
      "\n",
      "--------Epoch 55--------\n",
      "Val loss = -0.3421\n",
      "\n",
      "--------Epoch 56--------\n",
      "Val loss = -0.3321\n",
      "\n",
      "--------Epoch 57--------\n",
      "Val loss = -0.3270\n",
      "\n",
      "--------Epoch 58--------\n",
      "Val loss = -0.3277\n",
      "\n",
      "--------Epoch 59--------\n",
      "Val loss = -0.3324\n",
      "\n",
      "--------Epoch 60--------\n",
      "Val loss = -0.3309\n",
      "\n",
      "--------Epoch 61--------\n",
      "Val loss = -0.3308\n",
      "\n",
      "--------Epoch 62--------\n",
      "Val loss = -0.3279\n",
      "\n",
      "--------Epoch 63--------\n",
      "Val loss = -0.3363\n",
      "\n",
      "--------Epoch 64--------\n",
      "Val loss = -0.3340\n",
      "\n",
      "--------Epoch 65--------\n",
      "Val loss = -0.3219\n",
      "\n",
      "--------Epoch 66--------\n",
      "Val loss = -0.3255\n",
      "\n",
      "--------Epoch 67--------\n",
      "Val loss = -0.3375\n",
      "\n",
      "--------Epoch 68--------\n",
      "Val loss = -0.3314\n",
      "\n",
      "--------Epoch 69--------\n",
      "Val loss = -0.3336\n",
      "\n",
      "--------Epoch 70--------\n",
      "Val loss = -0.3351\n",
      "\n",
      "--------Epoch 71--------\n",
      "Val loss = -0.3354\n",
      "\n",
      "--------Epoch 72--------\n",
      "Val loss = -0.3317\n",
      "\n",
      "--------Epoch 73--------\n",
      "Val loss = -0.3383\n",
      "\n",
      "--------Epoch 74--------\n",
      "Val loss = -0.3346\n",
      "\n",
      "--------Epoch 75--------\n",
      "Val loss = -0.3341\n",
      "\n",
      "--------Epoch 76--------\n",
      "Val loss = -0.3363\n",
      "\n",
      "--------Epoch 77--------\n",
      "Val loss = -0.3389\n",
      "\n",
      "--------Epoch 78--------\n",
      "Val loss = -0.3387\n",
      "\n",
      "--------Epoch 79--------\n",
      "Val loss = -0.3338\n",
      "\n",
      "--------Epoch 80--------\n",
      "Val loss = -0.3230\n",
      "\n",
      "--------Epoch 81--------\n",
      "Val loss = -0.3316\n",
      "\n",
      "--------Epoch 82--------\n",
      "Val loss = -0.3427\n",
      "\n",
      "--------Epoch 83--------\n",
      "Val loss = -0.3375\n",
      "\n",
      "--------Epoch 84--------\n",
      "Val loss = -0.3371\n",
      "\n",
      "--------Epoch 85--------\n",
      "Val loss = -0.3354\n",
      "\n",
      "--------Epoch 86--------\n",
      "Val loss = -0.3350\n",
      "\n",
      "--------Epoch 87--------\n",
      "Val loss = -0.3405\n",
      "\n",
      "--------Epoch 88--------\n",
      "Val loss = -0.3402\n",
      "\n",
      "--------Epoch 89--------\n",
      "Val loss = -0.3340\n",
      "\n",
      "--------Epoch 90--------\n",
      "Val loss = -0.3430\n",
      "\n",
      "--------Epoch 91--------\n",
      "Val loss = -0.3338\n",
      "\n",
      "--------Epoch 92--------\n",
      "Val loss = -0.3367\n",
      "\n",
      "--------Epoch 93--------\n",
      "Val loss = -0.3386\n",
      "\n",
      "--------Epoch 94--------\n",
      "Val loss = -0.3451\n",
      "\n",
      "--------Epoch 95--------\n",
      "Val loss = -0.3431\n",
      "\n",
      "--------Epoch 96--------\n",
      "Val loss = -0.3328\n",
      "\n",
      "--------Epoch 97--------\n",
      "Val loss = -0.3436\n",
      "\n",
      "--------Epoch 98--------\n",
      "Val loss = -0.3378\n",
      "\n",
      "--------Epoch 99--------\n",
      "Val loss = -0.3364\n",
      "\n",
      "--------Epoch 100--------\n",
      "Val loss = -0.3459\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Set up model\n",
    "# Predictor MLP with two hidden layers of 512 units. Its input width is\n",
    "# d_in * 2, presumably because MaskLayer(append=True) concatenates the mask\n",
    "# to the masked features -- TODO confirm against MaskLayer.\n",
    "model = nn.Sequential(\n",
    "    nn.Linear(d_in * 2, 512),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(512, 512),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(512, d_out))\n",
    "mask_layer = MaskLayer(append=True)\n",
    "# Pretrainer wraps the predictor and the mask layer; the wrapper is moved to\n",
    "# the device chosen earlier in the notebook.\n",
    "pretrain = Pretrainer(model, mask_layer).to(device)\n",
    "\n",
    "# Pretrain\n",
    "# Training uses cross-entropy, validation uses NegAccuracy, so the printed\n",
    "# 'Val loss' values are presumably negated accuracies (lower is better) --\n",
    "# verify against the NegAccuracy definition.\n",
    "pretrain.fit(train_dataset,\n",
    "             val_dataset,\n",
    "             mbsize=128,\n",
    "             lr=1e-3,\n",
    "             nepochs=100,\n",
    "             max_features=max_features,\n",
    "             loss_fn=nn.CrossEntropyLoss(),\n",
    "             val_loss_fn=NegAccuracy())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "civilian-tuition",
   "metadata": {},
   "source": [
    "# Q-learning training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "driven-guitar",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.distributions import Categorical\n",
    "from torch.utils.data import DataLoader\n",
    "from utils import restore_parameters\n",
    "from copy import deepcopy\n",
    "\n",
    "\n",
    "def make_onehot(x):\n",
    "    '''Make an approximately one-hot vector one-hot.'''\n",
    "    argmax = torch.argmax(x, dim=1)\n",
    "    onehot = torch.zeros(x.shape, dtype=x.dtype, device=x.device)\n",
    "    onehot[torch.arange(len(x)), argmax] = 1\n",
    "    return onehot\n",
    "\n",
    "\n",
    "def ind_to_onehot(inds, n):\n",
    "    onehot = torch.zeros(len(inds), n, dtype=torch.float32, device=inds.device)\n",
    "    onehot[torch.arange(len(inds)), inds] = 1\n",
    "    return onehot\n",
    "\n",
    "\n",
    "def logsoftmax(x, dim):\n",
    "    m = torch.max(x, dim=dim, keepdim=True).values\n",
    "    x = x - m\n",
    "    return x - torch.log(torch.sum(torch.exp(x), dim=dim, keepdim=True))\n",
    "\n",
    "\n",
    "class DiscreteSampler(nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        \n",
    "    def forward(self, logits):\n",
    "        dist = Categorical(logits=logits)\n",
    "        inds = dist.sample()\n",
    "        onehot = torch.zeros(logits.shape, dtype=logits.dtype, device=logits.device)\n",
    "        onehot[torch.arange(len(logits)), inds] = 1\n",
    "        return inds, onehot\n",
    "\n",
    "\n",
    "class GreedyAdaptiveFS(nn.Module):\n",
    "    '''\n",
    "    Greedy adaptive feature selection.\n",
    "\n",
    "    A selector network outputs one predicted loss per candidate feature for\n",
    "    the current masked input, and a predictor network makes predictions from\n",
    "    the features selected so far. Features are added one at a time.\n",
    "\n",
    "    Args:\n",
    "      selector: module mapping the masked input to one predicted loss per\n",
    "        feature.\n",
    "      model: module mapping the masked input to predictions.\n",
    "      mask_layer: callable invoked as mask_layer(x, mask) to build the input\n",
    "        given to both networks.\n",
    "    '''\n",
    "\n",
    "    def __init__(self, selector, model, mask_layer):\n",
    "        super().__init__()\n",
    "        self.selector = selector\n",
    "        self.model = model\n",
    "        self.mask_layer = mask_layer\n",
    "        # Not used by the epsilon-greedy procedure in this class; kept so that\n",
    "        # existing attribute access keeps working.\n",
    "        self.sampler = DiscreteSampler()\n",
    "\n",
    "    def _select_next(self, x, m_hard, no_repeats):\n",
    "        '''Add the feature with the lowest predicted loss to each row of the mask.'''\n",
    "        pred_loss = self.selector(self.mask_layer(x, m_hard))\n",
    "        if no_repeats:\n",
    "            # Large penalty that keeps already selected features from being\n",
    "            # the argmin again.\n",
    "            pred_loss = pred_loss + 1e6 * m_hard\n",
    "        return torch.max(m_hard, make_onehot(-pred_loss))\n",
    "\n",
    "    def fit(self,\n",
    "            train,\n",
    "            val,\n",
    "            mbsize,\n",
    "            lr,\n",
    "            nepochs,\n",
    "            max_features,\n",
    "            eps,\n",
    "            loss_fn,\n",
    "            val_loss_fn=None,\n",
    "            train_model=True,\n",
    "            train_selector=True,\n",
    "            no_repeats=True,\n",
    "            validation_mode='final',\n",
    "            verbose=True):\n",
    "        '''\n",
    "        Train models.\n",
    "\n",
    "        Args:\n",
    "          train: training dataset yielding (x, y) pairs.\n",
    "          val: validation dataset yielding (x, y) pairs.\n",
    "          mbsize: minibatch size for both data loaders.\n",
    "          lr: learning rate for the Adam optimizers.\n",
    "          nepochs: number of passes over the training data.\n",
    "          max_features: number of selection steps per example.\n",
    "          eps: probability of picking a uniformly random feature instead of\n",
    "            the one with the lowest predicted loss during training.\n",
    "          loss_fn: training loss with a reduction attribute. It is copied,\n",
    "            so the object passed by the caller is left unmodified.\n",
    "          val_loss_fn: validation loss; defaults to a copy of loss_fn.\n",
    "          train_model: whether to update the predictor.\n",
    "          train_selector: whether to update the selector.\n",
    "          no_repeats: whether validation selections exclude features that\n",
    "            were already selected (training selections are not restricted).\n",
    "          validation_mode: 'final' scores the prediction after the last\n",
    "            selection, 'mean' averages the loss over all selection steps.\n",
    "          verbose: whether to print the validation loss after every epoch.\n",
    "        '''\n",
    "        # Set up data loaders.\n",
    "        train_loader = DataLoader(\n",
    "            train, batch_size=mbsize, shuffle=True, pin_memory=True,\n",
    "            drop_last=True, num_workers=4)\n",
    "        val_loader = DataLoader(\n",
    "            val, batch_size=mbsize, shuffle=False, pin_memory=True,\n",
    "            drop_last=False, num_workers=4)\n",
    "\n",
    "        # More setup.\n",
    "        selector = self.selector\n",
    "        model = self.model\n",
    "        mask_layer = self.mask_layer\n",
    "        device = next(model.parameters()).device\n",
    "        mse_loss_fn = nn.MSELoss(reduction='none')\n",
    "        assert validation_mode in ('final', 'mean')\n",
    "        assert train_model or train_selector\n",
    "        if train_model:\n",
    "            model_opt = optim.Adam(model.parameters(), lr=lr)\n",
    "        if train_selector:\n",
    "            selector_opt = optim.Adam(selector.parameters(), lr=lr)\n",
    "        if val_loss_fn is None:\n",
    "            val_loss_fn = deepcopy(loss_fn)\n",
    "\n",
    "        # Fix loss function reduction. The selector regresses on per-example\n",
    "        # losses, so training needs reduction 'none'. Work on a private copy\n",
    "        # so the caller's loss object keeps its own reduction.\n",
    "        loss_fn = deepcopy(loss_fn)\n",
    "        if loss_fn.reduction == 'mean':\n",
    "            loss_fn.reduction = 'none'\n",
    "        else:\n",
    "            val_loss_fn.reduction = 'mean'\n",
    "\n",
    "        # For tracking best model.\n",
    "        best_model = None\n",
    "        best_selector = None\n",
    "        best_loss = float('inf')\n",
    "\n",
    "        # Drop gradients left over from any earlier training so they do not\n",
    "        # leak into the first update.\n",
    "        model.zero_grad()\n",
    "        selector.zero_grad()\n",
    "\n",
    "        for epoch in range(nepochs):\n",
    "            for x, y in train_loader:\n",
    "                # Move to device.\n",
    "                x = x.to(device)\n",
    "                y = y.to(device)\n",
    "\n",
    "                # Setup.\n",
    "                m_hard = torch.zeros(x.shape, dtype=x.dtype, device=device)\n",
    "                total_loss = 0\n",
    "\n",
    "                for i in range(max_features):\n",
    "                    # Evaluate selector model.\n",
    "                    x_masked = mask_layer(x, m_hard)\n",
    "                    pred_loss = selector(x_masked)\n",
    "\n",
    "                    # Take actions: exploit the lowest predicted loss with\n",
    "                    # probability 1 - eps, otherwise pick a random feature.\n",
    "                    exploit = (torch.rand(len(x), device=device) > eps).int()\n",
    "                    best = torch.argmin(pred_loss, dim=1)\n",
    "                    random_inds = torch.tensor(\n",
    "                        np.random.choice(x.shape[1], size=len(x)), device=device)\n",
    "                    actions = exploit * best + (1 - exploit) * random_inds\n",
    "                    hard = ind_to_onehot(actions, x.shape[1])\n",
    "                    m_hard = torch.max(m_hard, hard)\n",
    "\n",
    "                    # Evaluate model.\n",
    "                    x_masked = mask_layer(x, m_hard)\n",
    "                    pred = model(x_masked)\n",
    "\n",
    "                    # Calculate loss. The selector is regressed onto the\n",
    "                    # observed (detached) loss of the action it took.\n",
    "                    model_loss = loss_fn(pred, y)\n",
    "                    selector_loss = mse_loss_fn(\n",
    "                        pred_loss[torch.arange(len(x)), actions],\n",
    "                        model_loss.detach())\n",
    "                    loss = torch.mean(model_loss) + torch.mean(selector_loss)\n",
    "                    total_loss = total_loss + loss\n",
    "\n",
    "                # Take gradient step.\n",
    "                total_loss = total_loss / max_features\n",
    "                total_loss.backward()\n",
    "                if train_model:\n",
    "                    model_opt.step()\n",
    "                if train_selector:\n",
    "                    selector_opt.step()\n",
    "                model.zero_grad()\n",
    "                selector.zero_grad()\n",
    "\n",
    "            # Calculate validation loss.\n",
    "            with torch.no_grad():\n",
    "                # For mean loss.\n",
    "                val_loss = 0\n",
    "                n = 0\n",
    "\n",
    "                for x, y in val_loader:\n",
    "                    # Move to device.\n",
    "                    x = x.to(device)\n",
    "                    y = y.to(device)\n",
    "\n",
    "                    # Setup.\n",
    "                    m_hard = torch.zeros(x.shape, dtype=x.dtype, device=device)\n",
    "                    total_loss = 0\n",
    "\n",
    "                    for i in range(max_features):\n",
    "                        # Greedy selection, then evaluate model.\n",
    "                        m_hard = self._select_next(x, m_hard, no_repeats)\n",
    "                        pred = model(mask_layer(x, m_hard))\n",
    "\n",
    "                        # Calculate loss.\n",
    "                        loss = val_loss_fn(pred, y).item()\n",
    "                        total_loss = total_loss + loss\n",
    "\n",
    "                    # Update mean loss.\n",
    "                    if validation_mode == 'final':\n",
    "                        batch_loss = loss\n",
    "                    elif validation_mode == 'mean':\n",
    "                        batch_loss = total_loss / max_features\n",
    "                    val_loss = (\n",
    "                        (batch_loss * len(x) + val_loss * n) / (len(x) + n))\n",
    "                    n += len(x)\n",
    "\n",
    "            # Print progress.\n",
    "            if verbose:\n",
    "                print(f'{\"-\"*8}Epoch {epoch+1}{\"-\"*8}')\n",
    "                print(f'Val loss = {val_loss:.4f}\\n')\n",
    "\n",
    "            # See if best model.\n",
    "            if val_loss < best_loss:\n",
    "                best_loss = val_loss\n",
    "                best_model = deepcopy(model)\n",
    "                best_selector = deepcopy(selector)\n",
    "\n",
    "        # Restore best parameters. Compare against None explicitly: container\n",
    "        # modules define __len__, so their truth value is not a reliable check.\n",
    "        if best_model is not None:\n",
    "            restore_parameters(model, best_model)\n",
    "        if best_selector is not None:\n",
    "            restore_parameters(selector, best_selector)\n",
    "\n",
    "    def forward(self, x, max_features, no_repeats=True):\n",
    "        '''\n",
    "        Make predictions using selected features.\n",
    "\n",
    "        Args:\n",
    "          x: input data (torch.Tensor).\n",
    "          max_features: max features to observe.\n",
    "          no_repeats: whether to ensure that no repeated selections occur.\n",
    "\n",
    "        Returns:\n",
    "          The predictor output for x masked to the selected features.\n",
    "        '''\n",
    "        # Setup.\n",
    "        device = next(self.model.parameters()).device\n",
    "        m_hard = torch.zeros(x.shape, dtype=x.dtype, device=device)\n",
    "\n",
    "        # Greedily add one feature per step.\n",
    "        for i in range(max_features):\n",
    "            m_hard = self._select_next(x, m_hard, no_repeats)\n",
    "\n",
    "        # Make predictions.\n",
    "        x_masked = self.mask_layer(x, m_hard)\n",
    "        pred = self.model(x_masked)\n",
    "        return pred\n",
    "\n",
    "    def evaluate(self, dataset, max_features, loss_fn, batch_size, no_repeats=True):\n",
    "        '''\n",
    "        Evaluate mean performance across a dataset.\n",
    "\n",
    "        Args:\n",
    "          dataset: dataset yielding (x, y) pairs.\n",
    "          max_features: max features to observe per example.\n",
    "          loss_fn: loss returning a single value per batch (.item() is called\n",
    "            on its result).\n",
    "          batch_size: evaluation batch size.\n",
    "          no_repeats: whether to ensure that no repeated selections occur.\n",
    "\n",
    "        Returns:\n",
    "          The batch losses averaged with weights equal to the batch sizes.\n",
    "        '''\n",
    "        # Setup.\n",
    "        device = next(self.model.parameters()).device\n",
    "        loader = DataLoader(\n",
    "            dataset, batch_size=batch_size, shuffle=False, pin_memory=True,\n",
    "            drop_last=False, num_workers=4)\n",
    "\n",
    "        # For calculating mean loss.\n",
    "        mean_loss = 0\n",
    "        n = 0\n",
    "\n",
    "        with torch.no_grad():\n",
    "            for x, y in loader:\n",
    "                # Move to GPU.\n",
    "                x = x.to(device)\n",
    "                y = y.to(device)\n",
    "\n",
    "                # Calculate loss.\n",
    "                pred = self.forward(x, max_features, no_repeats)\n",
    "                loss = loss_fn(pred, y).item()\n",
    "\n",
    "                # Update average.\n",
    "                mean_loss = (mean_loss * n + loss * len(x)) / (n + len(x))\n",
    "                n += len(x)\n",
    "\n",
    "        return mean_loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "according-rogers",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--------Epoch 1--------\n",
      "Val loss = -0.3991\n",
      "\n",
      "--------Epoch 2--------\n",
      "Val loss = -0.3819\n",
      "\n",
      "--------Epoch 3--------\n",
      "Val loss = -0.4453\n",
      "\n",
      "--------Epoch 4--------\n",
      "Val loss = -0.4273\n",
      "\n",
      "--------Epoch 5--------\n",
      "Val loss = -0.3158\n",
      "\n",
      "--------Epoch 6--------\n",
      "Val loss = -0.4161\n",
      "\n",
      "--------Epoch 7--------\n",
      "Val loss = -0.2640\n",
      "\n",
      "--------Epoch 8--------\n",
      "Val loss = -0.4190\n",
      "\n",
      "--------Epoch 9--------\n",
      "Val loss = -0.4417\n",
      "\n",
      "--------Epoch 10--------\n",
      "Val loss = -0.3166\n",
      "\n",
      "--------Epoch 11--------\n",
      "Val loss = -0.4031\n",
      "\n",
      "--------Epoch 12--------\n",
      "Val loss = -0.2833\n",
      "\n",
      "--------Epoch 13--------\n",
      "Val loss = -0.4075\n",
      "\n",
      "--------Epoch 14--------\n",
      "Val loss = -0.3961\n",
      "\n",
      "--------Epoch 15--------\n",
      "Val loss = -0.4533\n",
      "\n",
      "--------Epoch 16--------\n",
      "Val loss = -0.3652\n",
      "\n",
      "--------Epoch 17--------\n",
      "Val loss = -0.2525\n",
      "\n",
      "--------Epoch 18--------\n",
      "Val loss = -0.4169\n",
      "\n",
      "--------Epoch 19--------\n",
      "Val loss = -0.3887\n",
      "\n",
      "--------Epoch 20--------\n",
      "Val loss = -0.4839\n",
      "\n",
      "--------Epoch 21--------\n",
      "Val loss = -0.2912\n",
      "\n",
      "--------Epoch 22--------\n",
      "Val loss = -0.4384\n",
      "\n",
      "--------Epoch 23--------\n",
      "Val loss = -0.3653\n",
      "\n",
      "--------Epoch 24--------\n",
      "Val loss = -0.4411\n",
      "\n",
      "--------Epoch 25--------\n",
      "Val loss = -0.3940\n",
      "\n",
      "--------Epoch 26--------\n",
      "Val loss = -0.4037\n",
      "\n",
      "--------Epoch 27--------\n",
      "Val loss = -0.4929\n",
      "\n",
      "--------Epoch 28--------\n",
      "Val loss = -0.5707\n",
      "\n",
      "--------Epoch 29--------\n",
      "Val loss = -0.5737\n",
      "\n",
      "--------Epoch 30--------\n",
      "Val loss = -0.5839\n",
      "\n",
      "--------Epoch 31--------\n",
      "Val loss = -0.4879\n",
      "\n",
      "--------Epoch 32--------\n",
      "Val loss = -0.5307\n",
      "\n",
      "--------Epoch 33--------\n",
      "Val loss = -0.5140\n",
      "\n",
      "--------Epoch 34--------\n",
      "Val loss = -0.5688\n",
      "\n",
      "--------Epoch 35--------\n",
      "Val loss = -0.6306\n",
      "\n",
      "--------Epoch 36--------\n",
      "Val loss = -0.5639\n",
      "\n",
      "--------Epoch 37--------\n",
      "Val loss = -0.5847\n",
      "\n",
      "--------Epoch 38--------\n",
      "Val loss = -0.5146\n",
      "\n",
      "--------Epoch 39--------\n",
      "Val loss = -0.4024\n",
      "\n",
      "--------Epoch 40--------\n",
      "Val loss = -0.4967\n",
      "\n",
      "--------Epoch 41--------\n",
      "Val loss = -0.4069\n",
      "\n",
      "--------Epoch 42--------\n",
      "Val loss = -0.4985\n",
      "\n",
      "--------Epoch 43--------\n",
      "Val loss = -0.5551\n",
      "\n",
      "--------Epoch 44--------\n",
      "Val loss = -0.5335\n",
      "\n",
      "--------Epoch 45--------\n",
      "Val loss = -0.5459\n",
      "\n",
      "--------Epoch 46--------\n",
      "Val loss = -0.5276\n",
      "\n",
      "--------Epoch 47--------\n",
      "Val loss = -0.4743\n",
      "\n",
      "--------Epoch 48--------\n",
      "Val loss = -0.5471\n",
      "\n",
      "--------Epoch 49--------\n",
      "Val loss = -0.4981\n",
      "\n",
      "--------Epoch 50--------\n",
      "Val loss = -0.5914\n",
      "\n",
      "--------Epoch 51--------\n",
      "Val loss = -0.5690\n",
      "\n",
      "--------Epoch 52--------\n",
      "Val loss = -0.5380\n",
      "\n",
      "--------Epoch 53--------\n",
      "Val loss = -0.5545\n",
      "\n",
      "--------Epoch 54--------\n",
      "Val loss = -0.5852\n",
      "\n",
      "--------Epoch 55--------\n",
      "Val loss = -0.5894\n",
      "\n",
      "--------Epoch 56--------\n",
      "Val loss = -0.6194\n",
      "\n",
      "--------Epoch 57--------\n",
      "Val loss = -0.5714\n",
      "\n",
      "--------Epoch 58--------\n",
      "Val loss = -0.5858\n",
      "\n",
      "--------Epoch 59--------\n",
      "Val loss = -0.5791\n",
      "\n",
      "--------Epoch 60--------\n",
      "Val loss = -0.5566\n",
      "\n",
      "--------Epoch 61--------\n",
      "Val loss = -0.6241\n",
      "\n",
      "--------Epoch 62--------\n",
      "Val loss = -0.5741\n",
      "\n",
      "--------Epoch 63--------\n",
      "Val loss = -0.5385\n",
      "\n",
      "--------Epoch 64--------\n",
      "Val loss = -0.6455\n",
      "\n",
      "--------Epoch 65--------\n",
      "Val loss = -0.5796\n",
      "\n",
      "--------Epoch 66--------\n",
      "Val loss = -0.5591\n",
      "\n",
      "--------Epoch 67--------\n",
      "Val loss = -0.5897\n",
      "\n",
      "--------Epoch 68--------\n",
      "Val loss = -0.5646\n",
      "\n",
      "--------Epoch 69--------\n",
      "Val loss = -0.5904\n",
      "\n",
      "--------Epoch 70--------\n",
      "Val loss = -0.5794\n",
      "\n",
      "--------Epoch 71--------\n",
      "Val loss = -0.6172\n",
      "\n",
      "--------Epoch 72--------\n",
      "Val loss = -0.6133\n",
      "\n",
      "--------Epoch 73--------\n",
      "Val loss = -0.6028\n",
      "\n",
      "--------Epoch 74--------\n",
      "Val loss = -0.3521\n",
      "\n",
      "--------Epoch 75--------\n",
      "Val loss = -0.4077\n",
      "\n",
      "--------Epoch 76--------\n",
      "Val loss = -0.6340\n",
      "\n",
      "--------Epoch 77--------\n",
      "Val loss = -0.6234\n",
      "\n",
      "--------Epoch 78--------\n",
      "Val loss = -0.6330\n",
      "\n",
      "--------Epoch 79--------\n",
      "Val loss = -0.6286\n",
      "\n",
      "--------Epoch 80--------\n",
      "Val loss = -0.5854\n",
      "\n",
      "--------Epoch 81--------\n",
      "Val loss = -0.6113\n",
      "\n",
      "--------Epoch 82--------\n",
      "Val loss = -0.6081\n",
      "\n",
      "--------Epoch 83--------\n",
      "Val loss = -0.6385\n",
      "\n",
      "--------Epoch 84--------\n",
      "Val loss = -0.6409\n",
      "\n",
      "--------Epoch 85--------\n",
      "Val loss = -0.6471\n",
      "\n",
      "--------Epoch 86--------\n",
      "Val loss = -0.6469\n",
      "\n",
      "--------Epoch 87--------\n",
      "Val loss = -0.6336\n",
      "\n",
      "--------Epoch 88--------\n",
      "Val loss = -0.6523\n",
      "\n",
      "--------Epoch 89--------\n",
      "Val loss = -0.4727\n",
      "\n",
      "--------Epoch 90--------\n",
      "Val loss = -0.6803\n",
      "\n",
      "--------Epoch 91--------\n",
      "Val loss = -0.6509\n",
      "\n",
      "--------Epoch 92--------\n",
      "Val loss = -0.6567\n",
      "\n",
      "--------Epoch 93--------\n",
      "Val loss = -0.6832\n",
      "\n",
      "--------Epoch 94--------\n",
      "Val loss = -0.6392\n",
      "\n",
      "--------Epoch 95--------\n",
      "Val loss = -0.6828\n",
      "\n",
      "--------Epoch 96--------\n",
      "Val loss = -0.6860\n",
      "\n",
      "--------Epoch 97--------\n",
      "Val loss = -0.6822\n",
      "\n",
      "--------Epoch 98--------\n",
      "Val loss = -0.6684\n",
      "\n",
      "--------Epoch 99--------\n",
      "Val loss = -0.6332\n",
      "\n",
      "--------Epoch 100--------\n",
      "Val loss = -0.6777\n",
      "\n",
      "--------Epoch 101--------\n",
      "Val loss = -0.6826\n",
      "\n",
      "--------Epoch 102--------\n",
      "Val loss = -0.6573\n",
      "\n",
      "--------Epoch 103--------\n",
      "Val loss = -0.6976\n",
      "\n",
      "--------Epoch 104--------\n",
      "Val loss = -0.6654\n",
      "\n",
      "--------Epoch 105--------\n",
      "Val loss = -0.6609\n",
      "\n",
      "--------Epoch 106--------\n",
      "Val loss = -0.6938\n",
      "\n",
      "--------Epoch 107--------\n",
      "Val loss = -0.6902\n",
      "\n",
      "--------Epoch 108--------\n",
      "Val loss = -0.6759\n",
      "\n",
      "--------Epoch 109--------\n",
      "Val loss = -0.7018\n",
      "\n",
      "--------Epoch 110--------\n",
      "Val loss = -0.7112\n",
      "\n",
      "--------Epoch 111--------\n",
      "Val loss = -0.7040\n",
      "\n",
      "--------Epoch 112--------\n",
      "Val loss = -0.7042\n",
      "\n",
      "--------Epoch 113--------\n",
      "Val loss = -0.7374\n",
      "\n",
      "--------Epoch 114--------\n",
      "Val loss = -0.7284\n",
      "\n",
      "--------Epoch 115--------\n",
      "Val loss = -0.7299\n",
      "\n",
      "--------Epoch 116--------\n",
      "Val loss = -0.7247\n",
      "\n",
      "--------Epoch 117--------\n",
      "Val loss = -0.7059\n",
      "\n",
      "--------Epoch 118--------\n",
      "Val loss = -0.6877\n",
      "\n",
      "--------Epoch 119--------\n",
      "Val loss = -0.7229\n",
      "\n",
      "--------Epoch 120--------\n",
      "Val loss = -0.7115\n",
      "\n",
      "--------Epoch 121--------\n",
      "Val loss = -0.7075\n",
      "\n",
      "--------Epoch 122--------\n",
      "Val loss = -0.7213\n",
      "\n",
      "--------Epoch 123--------\n",
      "Val loss = -0.6938\n",
      "\n",
      "--------Epoch 124--------\n",
      "Val loss = -0.7337\n",
      "\n",
      "--------Epoch 125--------\n",
      "Val loss = -0.7254\n",
      "\n",
      "--------Epoch 126--------\n",
      "Val loss = -0.6841\n",
      "\n",
      "--------Epoch 127--------\n",
      "Val loss = -0.7315\n",
      "\n",
      "--------Epoch 128--------\n",
      "Val loss = -0.7285\n",
      "\n",
      "--------Epoch 129--------\n",
      "Val loss = -0.7354\n",
      "\n",
      "--------Epoch 130--------\n",
      "Val loss = -0.7187\n",
      "\n",
      "--------Epoch 131--------\n",
      "Val loss = -0.7394\n",
      "\n",
      "--------Epoch 132--------\n",
      "Val loss = -0.7280\n",
      "\n",
      "--------Epoch 133--------\n",
      "Val loss = -0.7353\n",
      "\n",
      "--------Epoch 134--------\n",
      "Val loss = -0.7456\n",
      "\n",
      "--------Epoch 135--------\n",
      "Val loss = -0.7367\n",
      "\n",
      "--------Epoch 136--------\n",
      "Val loss = -0.7540\n",
      "\n",
      "--------Epoch 137--------\n",
      "Val loss = -0.7338\n",
      "\n",
      "--------Epoch 138--------\n",
      "Val loss = -0.7015\n",
      "\n",
      "--------Epoch 139--------\n",
      "Val loss = -0.7346\n",
      "\n",
      "--------Epoch 140--------\n",
      "Val loss = -0.7343\n",
      "\n",
      "--------Epoch 141--------\n",
      "Val loss = -0.7672\n",
      "\n",
      "--------Epoch 142--------\n",
      "Val loss = -0.7567\n",
      "\n",
      "--------Epoch 143--------\n",
      "Val loss = -0.7451\n",
      "\n",
      "--------Epoch 144--------\n",
      "Val loss = -0.7441\n",
      "\n",
      "--------Epoch 145--------\n",
      "Val loss = -0.7523\n",
      "\n",
      "--------Epoch 146--------\n",
      "Val loss = -0.7454\n",
      "\n",
      "--------Epoch 147--------\n",
      "Val loss = -0.6990\n",
      "\n",
      "--------Epoch 148--------\n",
      "Val loss = -0.7431\n",
      "\n",
      "--------Epoch 149--------\n",
      "Val loss = -0.7241\n",
      "\n",
      "--------Epoch 150--------\n",
      "Val loss = -0.7241\n",
      "\n",
      "--------Epoch 151--------\n",
      "Val loss = -0.7295\n",
      "\n",
      "--------Epoch 152--------\n",
      "Val loss = -0.7349\n",
      "\n",
      "--------Epoch 153--------\n",
      "Val loss = -0.7496\n",
      "\n",
      "--------Epoch 154--------\n",
      "Val loss = -0.7470\n",
      "\n",
      "--------Epoch 155--------\n",
      "Val loss = -0.7497\n",
      "\n",
      "--------Epoch 156--------\n",
      "Val loss = -0.7462\n",
      "\n",
      "--------Epoch 157--------\n",
      "Val loss = -0.7504\n",
      "\n",
      "--------Epoch 158--------\n",
      "Val loss = -0.7220\n",
      "\n",
      "--------Epoch 159--------\n",
      "Val loss = -0.7499\n",
      "\n",
      "--------Epoch 160--------\n",
      "Val loss = -0.7363\n",
      "\n",
      "--------Epoch 161--------\n",
      "Val loss = -0.7344\n",
      "\n",
      "--------Epoch 162--------\n",
      "Val loss = -0.7531\n",
      "\n",
      "--------Epoch 163--------\n",
      "Val loss = -0.7487\n",
      "\n",
      "--------Epoch 164--------\n",
      "Val loss = -0.7506\n",
      "\n",
      "--------Epoch 165--------\n",
      "Val loss = -0.7482\n",
      "\n",
      "--------Epoch 166--------\n",
      "Val loss = -0.7552\n",
      "\n",
      "--------Epoch 167--------\n",
      "Val loss = -0.7670\n",
      "\n",
      "--------Epoch 168--------\n",
      "Val loss = -0.7458\n",
      "\n",
      "--------Epoch 169--------\n",
      "Val loss = -0.7486\n",
      "\n",
      "--------Epoch 170--------\n",
      "Val loss = -0.7419\n",
      "\n",
      "--------Epoch 171--------\n",
      "Val loss = -0.7593\n",
      "\n",
      "--------Epoch 172--------\n",
      "Val loss = -0.7648\n",
      "\n",
      "--------Epoch 173--------\n",
      "Val loss = -0.7332\n",
      "\n",
      "--------Epoch 174--------\n",
      "Val loss = -0.7581\n",
      "\n",
      "--------Epoch 175--------\n",
      "Val loss = -0.7579\n",
      "\n",
      "--------Epoch 176--------\n",
      "Val loss = -0.7470\n",
      "\n",
      "--------Epoch 177--------\n",
      "Val loss = -0.7488\n",
      "\n",
      "--------Epoch 178--------\n",
      "Val loss = -0.7532\n",
      "\n",
      "--------Epoch 179--------\n",
      "Val loss = -0.7614\n",
      "\n",
      "--------Epoch 180--------\n",
      "Val loss = -0.7617\n",
      "\n",
      "--------Epoch 181--------\n",
      "Val loss = -0.7738\n",
      "\n",
      "--------Epoch 182--------\n",
      "Val loss = -0.7363\n",
      "\n",
      "--------Epoch 183--------\n",
      "Val loss = -0.7609\n",
      "\n",
      "--------Epoch 184--------\n",
      "Val loss = -0.7667\n",
      "\n",
      "--------Epoch 185--------\n",
      "Val loss = -0.7558\n",
      "\n",
      "--------Epoch 186--------\n",
      "Val loss = -0.7514\n",
      "\n",
      "--------Epoch 187--------\n",
      "Val loss = -0.7767\n",
      "\n",
      "--------Epoch 188--------\n",
      "Val loss = -0.7302\n",
      "\n",
      "--------Epoch 189--------\n",
      "Val loss = -0.7637\n",
      "\n",
      "--------Epoch 190--------\n",
      "Val loss = -0.7602\n",
      "\n",
      "--------Epoch 191--------\n",
      "Val loss = -0.7663\n",
      "\n",
      "--------Epoch 192--------\n",
      "Val loss = -0.7479\n",
      "\n",
      "--------Epoch 193--------\n",
      "Val loss = -0.7676\n",
      "\n",
      "--------Epoch 194--------\n",
      "Val loss = -0.7683\n",
      "\n",
      "--------Epoch 195--------\n",
      "Val loss = -0.7573\n",
      "\n",
      "--------Epoch 196--------\n",
      "Val loss = -0.7681\n",
      "\n",
      "--------Epoch 197--------\n",
      "Val loss = -0.7632\n",
      "\n",
      "--------Epoch 198--------\n",
      "Val loss = -0.7726\n",
      "\n",
      "--------Epoch 199--------\n",
      "Val loss = -0.7747\n",
      "\n",
      "--------Epoch 200--------\n",
      "Val loss = -0.7642\n",
      "\n",
      "--------Epoch 201--------\n",
      "Val loss = -0.7653\n",
      "\n",
      "--------Epoch 202--------\n",
      "Val loss = -0.7586\n",
      "\n",
      "--------Epoch 203--------\n",
      "Val loss = -0.7789\n",
      "\n",
      "--------Epoch 204--------\n",
      "Val loss = -0.7480\n",
      "\n",
      "--------Epoch 205--------\n",
      "Val loss = -0.7584\n",
      "\n",
      "--------Epoch 206--------\n",
      "Val loss = -0.7578\n",
      "\n",
      "--------Epoch 207--------\n",
      "Val loss = -0.7483\n",
      "\n",
      "--------Epoch 208--------\n",
      "Val loss = -0.7663\n",
      "\n",
      "--------Epoch 209--------\n",
      "Val loss = -0.7694\n",
      "\n",
      "--------Epoch 210--------\n",
      "Val loss = -0.7757\n",
      "\n",
      "--------Epoch 211--------\n",
      "Val loss = -0.7746\n",
      "\n",
      "--------Epoch 212--------\n",
      "Val loss = -0.7588\n",
      "\n",
      "--------Epoch 213--------\n",
      "Val loss = -0.7604\n",
      "\n",
      "--------Epoch 214--------\n",
      "Val loss = -0.7672\n",
      "\n",
      "--------Epoch 215--------\n",
      "Val loss = -0.7595\n",
      "\n",
      "--------Epoch 216--------\n",
      "Val loss = -0.7403\n",
      "\n",
      "--------Epoch 217--------\n",
      "Val loss = -0.7712\n",
      "\n",
      "--------Epoch 218--------\n",
      "Val loss = -0.7844\n",
      "\n",
      "--------Epoch 219--------\n",
      "Val loss = -0.7559\n",
      "\n",
      "--------Epoch 220--------\n",
      "Val loss = -0.7794\n",
      "\n",
      "--------Epoch 221--------\n",
      "Val loss = -0.7763\n",
      "\n",
      "--------Epoch 222--------\n",
      "Val loss = -0.7949\n",
      "\n",
      "--------Epoch 223--------\n",
      "Val loss = -0.7444\n",
      "\n",
      "--------Epoch 224--------\n",
      "Val loss = -0.7741\n",
      "\n",
      "--------Epoch 225--------\n",
      "Val loss = -0.7763\n",
      "\n",
      "--------Epoch 226--------\n",
      "Val loss = -0.7726\n",
      "\n",
      "--------Epoch 227--------\n",
      "Val loss = -0.7688\n",
      "\n",
      "--------Epoch 228--------\n",
      "Val loss = -0.7638\n",
      "\n",
      "--------Epoch 229--------\n",
      "Val loss = -0.7875\n",
      "\n",
      "--------Epoch 230--------\n",
      "Val loss = -0.7849\n",
      "\n",
      "--------Epoch 231--------\n",
      "Val loss = -0.7816\n",
      "\n",
      "--------Epoch 232--------\n",
      "Val loss = -0.7813\n",
      "\n",
      "--------Epoch 233--------\n",
      "Val loss = -0.7654\n",
      "\n",
      "--------Epoch 234--------\n",
      "Val loss = -0.7836\n",
      "\n",
      "--------Epoch 235--------\n",
      "Val loss = -0.7927\n",
      "\n",
      "--------Epoch 236--------\n",
      "Val loss = -0.7722\n",
      "\n",
      "--------Epoch 237--------\n",
      "Val loss = -0.7858\n",
      "\n",
      "--------Epoch 238--------\n",
      "Val loss = -0.7860\n",
      "\n",
      "--------Epoch 239--------\n",
      "Val loss = -0.7701\n",
      "\n",
      "--------Epoch 240--------\n",
      "Val loss = -0.7720\n",
      "\n",
      "--------Epoch 241--------\n",
      "Val loss = -0.7750\n",
      "\n",
      "--------Epoch 242--------\n",
      "Val loss = -0.7368\n",
      "\n",
      "--------Epoch 243--------\n",
      "Val loss = -0.7879\n",
      "\n",
      "--------Epoch 244--------\n",
      "Val loss = -0.7827\n",
      "\n",
      "--------Epoch 245--------\n",
      "Val loss = -0.7734\n",
      "\n",
      "--------Epoch 246--------\n",
      "Val loss = -0.7791\n",
      "\n",
      "--------Epoch 247--------\n",
      "Val loss = -0.7785\n",
      "\n",
      "--------Epoch 248--------\n",
      "Val loss = -0.7902\n",
      "\n",
      "--------Epoch 249--------\n",
      "Val loss = -0.7720\n",
      "\n",
      "--------Epoch 250--------\n",
      "Val loss = -0.7839\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Selector network: an MLP with two hidden ReLU layers of width 512.\n",
    "# It takes a (2 * d_in)-wide input and emits d_in outputs\n",
    "# (presumably one predicted loss per candidate feature - confirm against\n",
    "# GreedyAdaptiveFS).\n",
    "selector = nn.Sequential(\n",
    "    nn.Linear(in_features=2 * d_in, out_features=512),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(in_features=512, out_features=512),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(in_features=512, out_features=d_in),\n",
    ")\n",
    "\n",
    "# Pair the selector with a deep copy of `model` (so `gafs` holds its own\n",
    "# copy rather than `model` itself) and the shared `mask_layer`, then move\n",
    "# the result to the target device.\n",
    "gafs = GreedyAdaptiveFS(selector, deepcopy(model), mask_layer)\n",
    "gafs = gafs.to(device)\n",
    "\n",
    "# Train with cross-entropy and track negative accuracy on the validation\n",
    "# split. `eps` is presumably the epsilon-greedy exploration probability\n",
    "# described in the notebook introduction - confirm against fit().\n",
    "gafs.fit(\n",
    "    train_dataset,\n",
    "    val_dataset,\n",
    "    max_features=max_features,\n",
    "    eps=0.1,\n",
    "    nepochs=250,\n",
    "    mbsize=128,\n",
    "    lr=2e-4,\n",
    "    loss_fn=nn.CrossEntropyLoss(),\n",
    "    val_loss_fn=NegAccuracy(),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "indian-session",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test acc = 79.98\n"
     ]
    }
   ],
   "source": [
    "# Evaluate the fitted GreedyAdaptiveFS instance (trained without weight\n",
    "# tying) on the test split and report the result as a percentage.\n",
    "test_acc = gafs.evaluate(\n",
    "    test_dataset,\n",
    "    max_features,\n",
    "    Accuracy(),\n",
    "    1024,  # NOTE(review): looks like the evaluation batch size - confirm\n",
    ")\n",
    "print(f'Test acc = {100*test_acc:.2f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "lovely-melbourne",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--------Epoch 1--------\n",
      "Val loss = -0.3361\n",
      "\n",
      "--------Epoch 2--------\n",
      "Val loss = -0.4623\n",
      "\n",
      "--------Epoch 3--------\n",
      "Val loss = -0.5317\n",
      "\n",
      "--------Epoch 4--------\n",
      "Val loss = -0.5007\n",
      "\n",
      "--------Epoch 5--------\n",
      "Val loss = -0.5292\n",
      "\n",
      "--------Epoch 6--------\n",
      "Val loss = -0.5525\n",
      "\n",
      "--------Epoch 7--------\n",
      "Val loss = -0.5775\n",
      "\n",
      "--------Epoch 8--------\n",
      "Val loss = -0.6118\n",
      "\n",
      "--------Epoch 9--------\n",
      "Val loss = -0.6296\n",
      "\n",
      "--------Epoch 10--------\n",
      "Val loss = -0.6437\n",
      "\n",
      "--------Epoch 11--------\n",
      "Val loss = -0.6527\n",
      "\n",
      "--------Epoch 12--------\n",
      "Val loss = -0.6661\n",
      "\n",
      "--------Epoch 13--------\n",
      "Val loss = -0.6198\n",
      "\n",
      "--------Epoch 14--------\n",
      "Val loss = -0.6750\n",
      "\n",
      "--------Epoch 15--------\n",
      "Val loss = -0.6384\n",
      "\n",
      "--------Epoch 16--------\n",
      "Val loss = -0.6651\n",
      "\n",
      "--------Epoch 17--------\n",
      "Val loss = -0.6747\n",
      "\n",
      "--------Epoch 18--------\n",
      "Val loss = -0.6876\n",
      "\n",
      "--------Epoch 19--------\n",
      "Val loss = -0.6699\n",
      "\n",
      "--------Epoch 20--------\n",
      "Val loss = -0.6909\n",
      "\n",
      "--------Epoch 21--------\n",
      "Val loss = -0.6610\n",
      "\n",
      "--------Epoch 22--------\n",
      "Val loss = -0.6870\n",
      "\n",
      "--------Epoch 23--------\n",
      "Val loss = -0.6873\n",
      "\n",
      "--------Epoch 24--------\n",
      "Val loss = -0.6884\n",
      "\n",
      "--------Epoch 25--------\n",
      "Val loss = -0.7114\n",
      "\n",
      "--------Epoch 26--------\n",
      "Val loss = -0.6771\n",
      "\n",
      "--------Epoch 27--------\n",
      "Val loss = -0.7245\n",
      "\n",
      "--------Epoch 28--------\n",
      "Val loss = -0.6851\n",
      "\n",
      "--------Epoch 29--------\n",
      "Val loss = -0.7156\n",
      "\n",
      "--------Epoch 30--------\n",
      "Val loss = -0.7047\n",
      "\n",
      "--------Epoch 31--------\n",
      "Val loss = -0.7340\n",
      "\n",
      "--------Epoch 32--------\n",
      "Val loss = -0.7128\n",
      "\n",
      "--------Epoch 33--------\n",
      "Val loss = -0.6968\n",
      "\n",
      "--------Epoch 34--------\n",
      "Val loss = -0.7313\n",
      "\n",
      "--------Epoch 35--------\n",
      "Val loss = -0.7362\n",
      "\n",
      "--------Epoch 36--------\n",
      "Val loss = -0.7361\n",
      "\n",
      "--------Epoch 37--------\n",
      "Val loss = -0.7387\n",
      "\n",
      "--------Epoch 38--------\n",
      "Val loss = -0.7376\n",
      "\n",
      "--------Epoch 39--------\n",
      "Val loss = -0.7135\n",
      "\n",
      "--------Epoch 40--------\n",
      "Val loss = -0.6782\n",
      "\n",
      "--------Epoch 41--------\n",
      "Val loss = -0.7442\n",
      "\n",
      "--------Epoch 42--------\n",
      "Val loss = -0.7272\n",
      "\n",
      "--------Epoch 43--------\n",
      "Val loss = -0.7255\n",
      "\n",
      "--------Epoch 44--------\n",
      "Val loss = -0.7399\n",
      "\n",
      "--------Epoch 45--------\n",
      "Val loss = -0.7311\n",
      "\n",
      "--------Epoch 46--------\n",
      "Val loss = -0.7325\n",
      "\n",
      "--------Epoch 47--------\n",
      "Val loss = -0.7432\n",
      "\n",
      "--------Epoch 48--------\n",
      "Val loss = -0.7456\n",
      "\n",
      "--------Epoch 49--------\n",
      "Val loss = -0.7352\n",
      "\n",
      "--------Epoch 50--------\n",
      "Val loss = -0.7419\n",
      "\n",
      "--------Epoch 51--------\n",
      "Val loss = -0.7413\n",
      "\n",
      "--------Epoch 52--------\n",
      "Val loss = -0.7502\n",
      "\n",
      "--------Epoch 53--------\n",
      "Val loss = -0.7590\n",
      "\n",
      "--------Epoch 54--------\n",
      "Val loss = -0.7282\n",
      "\n",
      "--------Epoch 55--------\n",
      "Val loss = -0.7564\n",
      "\n",
      "--------Epoch 56--------\n",
      "Val loss = -0.7392\n",
      "\n",
      "--------Epoch 57--------\n",
      "Val loss = -0.7568\n",
      "\n",
      "--------Epoch 58--------\n",
      "Val loss = -0.7216\n",
      "\n",
      "--------Epoch 59--------\n",
      "Val loss = -0.7417\n",
      "\n",
      "--------Epoch 60--------\n",
      "Val loss = -0.7586\n",
      "\n",
      "--------Epoch 61--------\n",
      "Val loss = -0.7594\n",
      "\n",
      "--------Epoch 62--------\n",
      "Val loss = -0.7549\n",
      "\n",
      "--------Epoch 63--------\n",
      "Val loss = -0.7553\n",
      "\n",
      "--------Epoch 64--------\n",
      "Val loss = -0.7711\n",
      "\n",
      "--------Epoch 65--------\n",
      "Val loss = -0.7564\n",
      "\n",
      "--------Epoch 66--------\n",
      "Val loss = -0.7642\n",
      "\n",
      "--------Epoch 67--------\n",
      "Val loss = -0.7153\n",
      "\n",
      "--------Epoch 68--------\n",
      "Val loss = -0.7189\n",
      "\n",
      "--------Epoch 69--------\n",
      "Val loss = -0.7541\n",
      "\n",
      "--------Epoch 70--------\n",
      "Val loss = -0.7643\n",
      "\n",
      "--------Epoch 71--------\n",
      "Val loss = -0.7463\n",
      "\n",
      "--------Epoch 72--------\n",
      "Val loss = -0.7558\n",
      "\n",
      "--------Epoch 73--------\n",
      "Val loss = -0.7663\n",
      "\n",
      "--------Epoch 74--------\n",
      "Val loss = -0.7651\n",
      "\n",
      "--------Epoch 75--------\n",
      "Val loss = -0.7646\n",
      "\n",
      "--------Epoch 76--------\n",
      "Val loss = -0.6949\n",
      "\n",
      "--------Epoch 77--------\n",
      "Val loss = -0.7590\n",
      "\n",
      "--------Epoch 78--------\n",
      "Val loss = -0.7488\n",
      "\n",
      "--------Epoch 79--------\n",
      "Val loss = -0.7570\n",
      "\n",
      "--------Epoch 80--------\n",
      "Val loss = -0.7632\n",
      "\n",
      "--------Epoch 81--------\n",
      "Val loss = -0.7711\n",
      "\n",
      "--------Epoch 82--------\n",
      "Val loss = -0.7623\n",
      "\n",
      "--------Epoch 83--------\n",
      "Val loss = -0.7698\n",
      "\n",
      "--------Epoch 84--------\n",
      "Val loss = -0.7664\n",
      "\n",
      "--------Epoch 85--------\n",
      "Val loss = -0.7441\n",
      "\n",
      "--------Epoch 86--------\n",
      "Val loss = -0.7355\n",
      "\n",
      "--------Epoch 87--------\n",
      "Val loss = -0.7402\n",
      "\n",
      "--------Epoch 88--------\n",
      "Val loss = -0.7599\n",
      "\n",
      "--------Epoch 89--------\n",
      "Val loss = -0.7671\n",
      "\n",
      "--------Epoch 90--------\n",
      "Val loss = -0.7799\n",
      "\n",
      "--------Epoch 91--------\n",
      "Val loss = -0.7689\n",
      "\n",
      "--------Epoch 92--------\n",
      "Val loss = -0.7650\n",
      "\n",
      "--------Epoch 93--------\n",
      "Val loss = -0.7444\n",
      "\n",
      "--------Epoch 94--------\n",
      "Val loss = -0.7306\n",
      "\n",
      "--------Epoch 95--------\n",
      "Val loss = -0.7632\n",
      "\n",
      "--------Epoch 96--------\n",
      "Val loss = -0.7527\n",
      "\n",
      "--------Epoch 97--------\n",
      "Val loss = -0.7691\n",
      "\n",
      "--------Epoch 98--------\n",
      "Val loss = -0.7728\n",
      "\n",
      "--------Epoch 99--------\n",
      "Val loss = -0.7348\n",
      "\n",
      "--------Epoch 100--------\n",
      "Val loss = -0.7605\n",
      "\n",
      "--------Epoch 101--------\n",
      "Val loss = -0.7578\n",
      "\n",
      "--------Epoch 102--------\n",
      "Val loss = -0.7801\n",
      "\n",
      "--------Epoch 103--------\n",
      "Val loss = -0.7498\n",
      "\n",
      "--------Epoch 104--------\n",
      "Val loss = -0.7665\n",
      "\n",
      "--------Epoch 105--------\n",
      "Val loss = -0.7791\n",
      "\n",
      "--------Epoch 106--------\n",
      "Val loss = -0.7792\n",
      "\n",
      "--------Epoch 107--------\n",
      "Val loss = -0.7648\n",
      "\n",
      "--------Epoch 108--------\n",
      "Val loss = -0.7635\n",
      "\n",
      "--------Epoch 109--------\n",
      "Val loss = -0.7641\n",
      "\n",
      "--------Epoch 110--------\n",
      "Val loss = -0.7728\n",
      "\n",
      "--------Epoch 111--------\n",
      "Val loss = -0.7791\n",
      "\n",
      "--------Epoch 112--------\n",
      "Val loss = -0.7727\n",
      "\n",
      "--------Epoch 113--------\n",
      "Val loss = -0.7713\n",
      "\n",
      "--------Epoch 114--------\n",
      "Val loss = -0.7871\n",
      "\n",
      "--------Epoch 115--------\n",
      "Val loss = -0.7757\n",
      "\n",
      "--------Epoch 116--------\n",
      "Val loss = -0.7618\n",
      "\n",
      "--------Epoch 117--------\n",
      "Val loss = -0.7885\n",
      "\n",
      "--------Epoch 118--------\n",
      "Val loss = -0.7630\n",
      "\n",
      "--------Epoch 119--------\n",
      "Val loss = -0.7716\n",
      "\n",
      "--------Epoch 120--------\n",
      "Val loss = -0.7590\n",
      "\n",
      "--------Epoch 121--------\n",
      "Val loss = -0.7619\n",
      "\n",
      "--------Epoch 122--------\n",
      "Val loss = -0.7621\n",
      "\n",
      "--------Epoch 123--------\n",
      "Val loss = -0.7866\n",
      "\n",
      "--------Epoch 124--------\n",
      "Val loss = -0.7434\n",
      "\n",
      "--------Epoch 125--------\n",
      "Val loss = -0.7863\n",
      "\n",
      "--------Epoch 126--------\n",
      "Val loss = -0.7838\n",
      "\n",
      "--------Epoch 127--------\n",
      "Val loss = -0.7882\n",
      "\n",
      "--------Epoch 128--------\n",
      "Val loss = -0.7854\n",
      "\n",
      "--------Epoch 129--------\n",
      "Val loss = -0.7946\n",
      "\n",
      "--------Epoch 130--------\n",
      "Val loss = -0.7611\n",
      "\n",
      "--------Epoch 131--------\n",
      "Val loss = -0.7850\n",
      "\n",
      "--------Epoch 132--------\n",
      "Val loss = -0.7808\n",
      "\n",
      "--------Epoch 133--------\n",
      "Val loss = -0.7568\n",
      "\n",
      "--------Epoch 134--------\n",
      "Val loss = -0.7686\n",
      "\n",
      "--------Epoch 135--------\n",
      "Val loss = -0.7487\n",
      "\n",
      "--------Epoch 136--------\n",
      "Val loss = -0.7825\n",
      "\n",
      "--------Epoch 137--------\n",
      "Val loss = -0.7805\n",
      "\n",
      "--------Epoch 138--------\n",
      "Val loss = -0.7803\n",
      "\n",
      "--------Epoch 139--------\n",
      "Val loss = -0.7852\n",
      "\n",
      "--------Epoch 140--------\n",
      "Val loss = -0.7786\n",
      "\n",
      "--------Epoch 141--------\n",
      "Val loss = -0.7838\n",
      "\n",
      "--------Epoch 142--------\n",
      "Val loss = -0.7995\n",
      "\n",
      "--------Epoch 143--------\n",
      "Val loss = -0.7679\n",
      "\n",
      "--------Epoch 144--------\n",
      "Val loss = -0.7933\n",
      "\n",
      "--------Epoch 145--------\n",
      "Val loss = -0.7888\n",
      "\n",
      "--------Epoch 146--------\n",
      "Val loss = -0.7608\n",
      "\n",
      "--------Epoch 147--------\n",
      "Val loss = -0.7787\n",
      "\n",
      "--------Epoch 148--------\n",
      "Val loss = -0.7479\n",
      "\n",
      "--------Epoch 149--------\n",
      "Val loss = -0.7891\n",
      "\n",
      "--------Epoch 150--------\n",
      "Val loss = -0.7863\n",
      "\n",
      "--------Epoch 151--------\n",
      "Val loss = -0.7887\n",
      "\n",
      "--------Epoch 152--------\n",
      "Val loss = -0.7938\n",
      "\n",
      "--------Epoch 153--------\n",
      "Val loss = -0.7826\n",
      "\n",
      "--------Epoch 154--------\n",
      "Val loss = -0.7557\n",
      "\n",
      "--------Epoch 155--------\n",
      "Val loss = -0.7838\n",
      "\n",
      "--------Epoch 156--------\n",
      "Val loss = -0.7995\n",
      "\n",
      "--------Epoch 157--------\n",
      "Val loss = -0.7837\n",
      "\n",
      "--------Epoch 158--------\n",
      "Val loss = -0.7885\n",
      "\n",
      "--------Epoch 159--------\n",
      "Val loss = -0.7930\n",
      "\n",
      "--------Epoch 160--------\n",
      "Val loss = -0.7909\n",
      "\n",
      "--------Epoch 161--------\n",
      "Val loss = -0.7766\n",
      "\n",
      "--------Epoch 162--------\n",
      "Val loss = -0.7835\n",
      "\n",
      "--------Epoch 163--------\n",
      "Val loss = -0.7916\n",
      "\n",
      "--------Epoch 164--------\n",
      "Val loss = -0.7785\n",
      "\n",
      "--------Epoch 165--------\n",
      "Val loss = -0.7962\n",
      "\n",
      "--------Epoch 166--------\n",
      "Val loss = -0.7817\n",
      "\n",
      "--------Epoch 167--------\n",
      "Val loss = -0.7833\n",
      "\n",
      "--------Epoch 168--------\n",
      "Val loss = -0.7779\n",
      "\n",
      "--------Epoch 169--------\n",
      "Val loss = -0.7932\n",
      "\n",
      "--------Epoch 170--------\n",
      "Val loss = -0.7805\n",
      "\n",
      "--------Epoch 171--------\n",
      "Val loss = -0.7761\n",
      "\n",
      "--------Epoch 172--------\n",
      "Val loss = -0.7725\n",
      "\n",
      "--------Epoch 173--------\n",
      "Val loss = -0.7858\n",
      "\n",
      "--------Epoch 174--------\n",
      "Val loss = -0.7787\n",
      "\n",
      "--------Epoch 175--------\n",
      "Val loss = -0.7972\n",
      "\n",
      "--------Epoch 176--------\n",
      "Val loss = -0.7851\n",
      "\n",
      "--------Epoch 177--------\n",
      "Val loss = -0.7692\n",
      "\n",
      "--------Epoch 178--------\n",
      "Val loss = -0.7834\n",
      "\n",
      "--------Epoch 179--------\n",
      "Val loss = -0.7616\n",
      "\n",
      "--------Epoch 180--------\n",
      "Val loss = -0.7942\n",
      "\n",
      "--------Epoch 181--------\n",
      "Val loss = -0.7802\n",
      "\n",
      "--------Epoch 182--------\n",
      "Val loss = -0.7871\n",
      "\n",
      "--------Epoch 183--------\n",
      "Val loss = -0.7947\n",
      "\n",
      "--------Epoch 184--------\n",
      "Val loss = -0.7978\n",
      "\n",
      "--------Epoch 185--------\n",
      "Val loss = -0.7550\n",
      "\n",
      "--------Epoch 186--------\n",
      "Val loss = -0.7895\n",
      "\n",
      "--------Epoch 187--------\n",
      "Val loss = -0.7825\n",
      "\n",
      "--------Epoch 188--------\n",
      "Val loss = -0.7936\n",
      "\n",
      "--------Epoch 189--------\n",
      "Val loss = -0.7597\n",
      "\n",
      "--------Epoch 190--------\n",
      "Val loss = -0.7897\n",
      "\n",
      "--------Epoch 191--------\n",
      "Val loss = -0.7945\n",
      "\n",
      "--------Epoch 192--------\n",
      "Val loss = -0.7930\n",
      "\n",
      "--------Epoch 193--------\n",
      "Val loss = -0.7915\n",
      "\n",
      "--------Epoch 194--------\n",
      "Val loss = -0.7708\n",
      "\n",
      "--------Epoch 195--------\n",
      "Val loss = -0.7846\n",
      "\n",
      "--------Epoch 196--------\n",
      "Val loss = -0.7777\n",
      "\n",
      "--------Epoch 197--------\n",
      "Val loss = -0.8058\n",
      "\n",
      "--------Epoch 198--------\n",
      "Val loss = -0.7923\n",
      "\n",
      "--------Epoch 199--------\n",
      "Val loss = -0.7907\n",
      "\n",
      "--------Epoch 200--------\n",
      "Val loss = -0.7735\n",
      "\n",
      "--------Epoch 201--------\n",
      "Val loss = -0.7832\n",
      "\n",
      "--------Epoch 202--------\n",
      "Val loss = -0.7964\n",
      "\n",
      "--------Epoch 203--------\n",
      "Val loss = -0.7906\n",
      "\n",
      "--------Epoch 204--------\n",
      "Val loss = -0.7966\n",
      "\n",
      "--------Epoch 205--------\n",
      "Val loss = -0.7926\n",
      "\n",
      "--------Epoch 206--------\n",
      "Val loss = -0.7963\n",
      "\n",
      "--------Epoch 207--------\n",
      "Val loss = -0.7990\n",
      "\n",
      "--------Epoch 208--------\n",
      "Val loss = -0.7979\n",
      "\n",
      "--------Epoch 209--------\n",
      "Val loss = -0.8010\n",
      "\n",
      "--------Epoch 210--------\n",
      "Val loss = -0.7755\n",
      "\n",
      "--------Epoch 211--------\n",
      "Val loss = -0.7593\n",
      "\n",
      "--------Epoch 212--------\n",
      "Val loss = -0.7561\n",
      "\n",
      "--------Epoch 213--------\n",
      "Val loss = -0.8019\n",
      "\n",
      "--------Epoch 214--------\n",
      "Val loss = -0.8066\n",
      "\n",
      "--------Epoch 215--------\n",
      "Val loss = -0.7917\n",
      "\n",
      "--------Epoch 216--------\n",
      "Val loss = -0.7968\n",
      "\n",
      "--------Epoch 217--------\n",
      "Val loss = -0.8054\n",
      "\n",
      "--------Epoch 218--------\n",
      "Val loss = -0.7964\n",
      "\n",
      "--------Epoch 219--------\n",
      "Val loss = -0.7336\n",
      "\n",
      "--------Epoch 220--------\n",
      "Val loss = -0.7925\n",
      "\n",
      "--------Epoch 221--------\n",
      "Val loss = -0.7948\n",
      "\n",
      "--------Epoch 222--------\n",
      "Val loss = -0.7964\n",
      "\n",
      "--------Epoch 223--------\n",
      "Val loss = -0.7892\n",
      "\n",
      "--------Epoch 224--------\n",
      "Val loss = -0.7966\n",
      "\n",
      "--------Epoch 225--------\n",
      "Val loss = -0.7965\n",
      "\n",
      "--------Epoch 226--------\n",
      "Val loss = -0.7893\n",
      "\n",
      "--------Epoch 227--------\n",
      "Val loss = -0.8049\n",
      "\n",
      "--------Epoch 228--------\n",
      "Val loss = -0.7886\n",
      "\n",
      "--------Epoch 229--------\n",
      "Val loss = -0.7965\n",
      "\n",
      "--------Epoch 230--------\n",
      "Val loss = -0.7977\n",
      "\n",
      "--------Epoch 231--------\n",
      "Val loss = -0.7666\n",
      "\n",
      "--------Epoch 232--------\n",
      "Val loss = -0.8020\n",
      "\n",
      "--------Epoch 233--------\n",
      "Val loss = -0.7901\n",
      "\n",
      "--------Epoch 234--------\n",
      "Val loss = -0.7957\n",
      "\n",
      "--------Epoch 235--------\n",
      "Val loss = -0.7939\n",
      "\n",
      "--------Epoch 236--------\n",
      "Val loss = -0.7899\n",
      "\n",
      "--------Epoch 237--------\n",
      "Val loss = -0.7977\n",
      "\n",
      "--------Epoch 238--------\n",
      "Val loss = -0.7819\n",
      "\n",
      "--------Epoch 239--------\n",
      "Val loss = -0.7832\n",
      "\n",
      "--------Epoch 240--------\n",
      "Val loss = -0.7648\n",
      "\n",
      "--------Epoch 241--------\n",
      "Val loss = -0.8006\n",
      "\n",
      "--------Epoch 242--------\n",
      "Val loss = -0.8056\n",
      "\n",
      "--------Epoch 243--------\n",
      "Val loss = -0.8012\n",
      "\n",
      "--------Epoch 244--------\n",
      "Val loss = -0.7878\n",
      "\n",
      "--------Epoch 245--------\n",
      "Val loss = -0.7764\n",
      "\n",
      "--------Epoch 246--------\n",
      "Val loss = -0.7777\n",
      "\n",
      "--------Epoch 247--------\n",
      "Val loss = -0.8026\n",
      "\n",
      "--------Epoch 248--------\n",
      "Val loss = -0.8024\n",
      "\n",
      "--------Epoch 249--------\n",
      "Val loss = -0.8021\n",
      "\n",
      "--------Epoch 250--------\n",
      "Val loss = -0.7940\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Selector network: an MLP with two hidden ReLU layers of width 512.\n",
    "# It takes a (2 * d_in)-wide input and emits d_in outputs\n",
    "# (presumably one predicted loss per candidate feature - confirm against\n",
    "# GreedyAdaptiveFS).\n",
    "selector = nn.Sequential(\n",
    "    nn.Linear(in_features=2 * d_in, out_features=512),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(in_features=512, out_features=512),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(in_features=512, out_features=d_in),\n",
    ")\n",
    "\n",
    "# Pair the selector with a deep copy of `model` (so `gafs` holds its own\n",
    "# copy rather than `model` itself) and the shared `mask_layer`, then move\n",
    "# the result to the target device.\n",
    "gafs = GreedyAdaptiveFS(selector, deepcopy(model), mask_layer)\n",
    "gafs = gafs.to(device)\n",
    "\n",
    "# Tie the selector's first two linear layers to the matching layers of the\n",
    "# copied model. Only the weights of layers 0 and 2 are reassigned; the\n",
    "# biases stay separate.\n",
    "# NOTE(review): nn.Parameter(...) wraps the model's weight tensor in a NEW\n",
    "# Parameter object. The layers then share weight values, but presumably\n",
    "# keep separate .grad buffers and optimizer state. Confirm this is the\n",
    "# intended form of tying; assigning the Parameter object directly would\n",
    "# share those too, provided fit() tolerates shared parameters.\n",
    "selector[0].weight = nn.Parameter(gafs.model[0].weight)\n",
    "selector[2].weight = nn.Parameter(gafs.model[2].weight)\n",
    "\n",
    "# Train with cross-entropy and track negative accuracy on the validation\n",
    "# split. `eps` is presumably the epsilon-greedy exploration probability\n",
    "# described in the notebook introduction - confirm against fit().\n",
    "gafs.fit(\n",
    "    train_dataset,\n",
    "    val_dataset,\n",
    "    max_features=max_features,\n",
    "    eps=0.1,\n",
    "    nepochs=250,\n",
    "    mbsize=128,\n",
    "    lr=2e-4,\n",
    "    loss_fn=nn.CrossEntropyLoss(),\n",
    "    val_loss_fn=NegAccuracy(),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "exempt-california",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test acc = 80.76\n"
     ]
    }
   ],
   "source": [
    "# Evaluate the fitted GreedyAdaptiveFS instance (trained with tied\n",
    "# selector/model weights) on the test split and report the result as a\n",
    "# percentage.\n",
    "test_acc = gafs.evaluate(\n",
    "    test_dataset,\n",
    "    max_features,\n",
    "    Accuracy(),\n",
    "    1024,  # NOTE(review): looks like the evaluation batch size - confirm\n",
    ")\n",
    "print(f'Test acc = {100*test_acc:.2f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "declared-oliver",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
