{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "isolated-encoding",
   "metadata": {},
   "source": [
    "# RL version\n",
    "\n",
    "- Unlike the other versions, this implementation does not stop gradients from flowing between selection steps. As a result, the learned selection procedure is not a sequence of independent greedy steps: earlier selections can be optimized for how they help later ones, which may yield a better overall sequence, as one would expect from reinforcement learning (RL).\n",
    "- However, the results are worse than those of the other versions, suggesting that this version is harder to optimize."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "narrative-reflection",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "import torch.nn as nn\n",
    "import matplotlib.pyplot as plt\n",
    "from torchvision import transforms\n",
    "from torchvision.datasets import MNIST\n",
    "from copy import deepcopy\n",
    "from utils import *\n",
    "from models import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "rental-missile",
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device('cuda', 3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "matched-outdoors",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load data\n",
    "class Flatten(object):\n",
    "    def __call__(self, pic):\n",
    "        return torch.flatten(pic)\n",
    "    \n",
    "mnist_dataset = MNIST('/tmp/mnist/', download=True, train=True,\n",
    "                      transform=transforms.Compose([transforms.ToTensor(), Flatten()]))\n",
    "images = mnist_dataset.data\n",
    "targets = mnist_dataset.targets\n",
    "np.random.seed(0)\n",
    "val_inds = np.sort(np.random.choice(len(images), size=10000, replace=False))\n",
    "train_inds = np.setdiff1d(np.arange(len(images)), val_inds)\n",
    "\n",
    "# Training dataset\n",
    "train_dataset = torch.utils.data.Subset(mnist_dataset, train_inds)\n",
    "\n",
    "# Validation dataset\n",
    "val_dataset = torch.utils.data.Subset(mnist_dataset, val_inds)\n",
    "\n",
    "# Test dataset\n",
    "test_dataset = MNIST('/tmp/mnist/', download=True, train=False,\n",
    "                     transform=transforms.Compose([transforms.ToTensor(), Flatten()]))\n",
    "\n",
    "# Set input/output dimensions\n",
    "d_in = 784\n",
    "d_out = 10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "pharmaceutical-brooks",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of features to select\n",
    "max_features = 10"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "opposed-complexity",
   "metadata": {},
   "source": [
    "# Global feature selection (FS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "classical-capability",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--------Epoch 1--------\n",
      "Val loss = -0.2832\n",
      "\n",
      "--------Epoch 2--------\n",
      "Val loss = -0.2997\n",
      "\n",
      "--------Epoch 3--------\n",
      "Val loss = -0.3102\n",
      "\n",
      "--------Epoch 4--------\n",
      "Val loss = -0.3283\n",
      "\n",
      "--------Epoch 5--------\n",
      "Val loss = -0.3179\n",
      "\n",
      "--------Epoch 6--------\n",
      "Val loss = -0.3346\n",
      "\n",
      "--------Epoch 7--------\n",
      "Val loss = -0.3476\n",
      "\n",
      "--------Epoch 8--------\n",
      "Val loss = -0.3450\n",
      "\n",
      "--------Epoch 9--------\n",
      "Val loss = -0.3408\n",
      "\n",
      "--------Epoch 10--------\n",
      "Val loss = -0.3372\n",
      "\n",
      "--------Epoch 11--------\n",
      "Val loss = -0.3626\n",
      "\n",
      "--------Epoch 12--------\n",
      "Val loss = -0.3634\n",
      "\n",
      "--------Epoch 13--------\n",
      "Val loss = -0.3619\n",
      "\n",
      "--------Epoch 14--------\n",
      "Val loss = -0.3714\n",
      "\n",
      "--------Epoch 15--------\n",
      "Val loss = -0.3721\n",
      "\n",
      "--------Epoch 16--------\n",
      "Val loss = -0.3791\n",
      "\n",
      "--------Epoch 17--------\n",
      "Val loss = -0.3865\n",
      "\n",
      "--------Epoch 18--------\n",
      "Val loss = -0.3751\n",
      "\n",
      "--------Epoch 19--------\n",
      "Val loss = -0.3830\n",
      "\n",
      "--------Epoch 20--------\n",
      "Val loss = -0.3878\n",
      "\n",
      "--------Epoch 21--------\n",
      "Val loss = -0.3944\n",
      "\n",
      "--------Epoch 22--------\n",
      "Val loss = -0.3931\n",
      "\n",
      "--------Epoch 23--------\n",
      "Val loss = -0.3905\n",
      "\n",
      "--------Epoch 24--------\n",
      "Val loss = -0.4102\n",
      "\n",
      "--------Epoch 25--------\n",
      "Val loss = -0.4160\n",
      "\n",
      "--------Epoch 26--------\n",
      "Val loss = -0.4175\n",
      "\n",
      "--------Epoch 27--------\n",
      "Val loss = -0.4234\n",
      "\n",
      "--------Epoch 28--------\n",
      "Val loss = -0.4321\n",
      "\n",
      "--------Epoch 29--------\n",
      "Val loss = -0.4269\n",
      "\n",
      "--------Epoch 30--------\n",
      "Val loss = -0.4343\n",
      "\n",
      "--------Epoch 31--------\n",
      "Val loss = -0.4471\n",
      "\n",
      "--------Epoch 32--------\n",
      "Val loss = -0.4517\n",
      "\n",
      "--------Epoch 33--------\n",
      "Val loss = -0.4559\n",
      "\n",
      "--------Epoch 34--------\n",
      "Val loss = -0.4543\n",
      "\n",
      "--------Epoch 35--------\n",
      "Val loss = -0.4648\n",
      "\n",
      "--------Epoch 36--------\n",
      "Val loss = -0.4641\n",
      "\n",
      "--------Epoch 37--------\n",
      "Val loss = -0.4660\n",
      "\n",
      "--------Epoch 38--------\n",
      "Val loss = -0.4664\n",
      "\n",
      "--------Epoch 39--------\n",
      "Val loss = -0.4676\n",
      "\n",
      "--------Epoch 40--------\n",
      "Val loss = -0.4810\n",
      "\n",
      "--------Epoch 41--------\n",
      "Val loss = -0.4760\n",
      "\n",
      "--------Epoch 42--------\n",
      "Val loss = -0.4895\n",
      "\n",
      "--------Epoch 43--------\n",
      "Val loss = -0.4846\n",
      "\n",
      "--------Epoch 44--------\n",
      "Val loss = -0.4920\n",
      "\n",
      "--------Epoch 45--------\n",
      "Val loss = -0.5153\n",
      "\n",
      "--------Epoch 46--------\n",
      "Val loss = -0.5011\n",
      "\n",
      "--------Epoch 47--------\n",
      "Val loss = -0.5141\n",
      "\n",
      "--------Epoch 48--------\n",
      "Val loss = -0.5181\n",
      "\n",
      "--------Epoch 49--------\n",
      "Val loss = -0.5170\n",
      "\n",
      "--------Epoch 50--------\n",
      "Val loss = -0.5145\n",
      "\n",
      "--------Epoch 51--------\n",
      "Val loss = -0.5249\n",
      "\n",
      "--------Epoch 52--------\n",
      "Val loss = -0.5248\n",
      "\n",
      "--------Epoch 53--------\n",
      "Val loss = -0.5313\n",
      "\n",
      "--------Epoch 54--------\n",
      "Val loss = -0.5410\n",
      "\n",
      "--------Epoch 55--------\n",
      "Val loss = -0.5381\n",
      "\n",
      "--------Epoch 56--------\n",
      "Val loss = -0.5350\n",
      "\n",
      "--------Epoch 57--------\n",
      "Val loss = -0.5409\n",
      "\n",
      "--------Epoch 58--------\n",
      "Val loss = -0.5559\n",
      "\n",
      "--------Epoch 59--------\n",
      "Val loss = -0.5662\n",
      "\n",
      "--------Epoch 60--------\n",
      "Val loss = -0.5496\n",
      "\n",
      "--------Epoch 61--------\n",
      "Val loss = -0.5762\n",
      "\n",
      "--------Epoch 62--------\n",
      "Val loss = -0.5756\n",
      "\n",
      "--------Epoch 63--------\n",
      "Val loss = -0.5768\n",
      "\n",
      "--------Epoch 64--------\n",
      "Val loss = -0.5674\n",
      "\n",
      "--------Epoch 65--------\n",
      "Val loss = -0.5767\n",
      "\n",
      "--------Epoch 66--------\n",
      "Val loss = -0.5832\n",
      "\n",
      "--------Epoch 67--------\n",
      "Val loss = -0.5892\n",
      "\n",
      "--------Epoch 68--------\n",
      "Val loss = -0.5960\n",
      "\n",
      "--------Epoch 69--------\n",
      "Val loss = -0.5962\n",
      "\n",
      "--------Epoch 70--------\n",
      "Val loss = -0.5943\n",
      "\n",
      "--------Epoch 71--------\n",
      "Val loss = -0.5910\n",
      "\n",
      "--------Epoch 72--------\n",
      "Val loss = -0.6005\n",
      "\n",
      "--------Epoch 73--------\n",
      "Val loss = -0.5969\n",
      "\n",
      "--------Epoch 74--------\n",
      "Val loss = -0.6122\n",
      "\n",
      "--------Epoch 75--------\n",
      "Val loss = -0.6044\n",
      "\n",
      "--------Epoch 76--------\n",
      "Val loss = -0.6111\n",
      "\n",
      "--------Epoch 77--------\n",
      "Val loss = -0.6235\n",
      "\n",
      "--------Epoch 78--------\n",
      "Val loss = -0.6312\n",
      "\n",
      "--------Epoch 79--------\n",
      "Val loss = -0.6231\n",
      "\n",
      "--------Epoch 80--------\n",
      "Val loss = -0.6341\n",
      "\n",
      "--------Epoch 81--------\n",
      "Val loss = -0.6334\n",
      "\n",
      "--------Epoch 82--------\n",
      "Val loss = -0.6392\n",
      "\n",
      "--------Epoch 83--------\n",
      "Val loss = -0.6370\n",
      "\n",
      "--------Epoch 84--------\n",
      "Val loss = -0.6477\n",
      "\n",
      "--------Epoch 85--------\n",
      "Val loss = -0.6551\n",
      "\n",
      "--------Epoch 86--------\n",
      "Val loss = -0.6600\n",
      "\n",
      "--------Epoch 87--------\n",
      "Val loss = -0.6594\n",
      "\n",
      "--------Epoch 88--------\n",
      "Val loss = -0.6641\n",
      "\n",
      "--------Epoch 89--------\n",
      "Val loss = -0.6719\n",
      "\n",
      "--------Epoch 90--------\n",
      "Val loss = -0.6691\n",
      "\n",
      "--------Epoch 91--------\n",
      "Val loss = -0.6713\n",
      "\n",
      "--------Epoch 92--------\n",
      "Val loss = -0.6748\n",
      "\n",
      "--------Epoch 93--------\n",
      "Val loss = -0.6744\n",
      "\n",
      "--------Epoch 94--------\n",
      "Val loss = -0.6854\n",
      "\n",
      "--------Epoch 95--------\n",
      "Val loss = -0.6831\n",
      "\n",
      "--------Epoch 96--------\n",
      "Val loss = -0.6913\n",
      "\n",
      "--------Epoch 97--------\n",
      "Val loss = -0.6910\n",
      "\n",
      "--------Epoch 98--------\n",
      "Val loss = -0.6987\n",
      "\n",
      "--------Epoch 99--------\n",
      "Val loss = -0.7004\n",
      "\n",
      "--------Epoch 100--------\n",
      "Val loss = -0.7017\n",
      "\n",
      "--------Epoch 101--------\n",
      "Val loss = -0.6994\n",
      "\n",
      "--------Epoch 102--------\n",
      "Val loss = -0.7020\n",
      "\n",
      "--------Epoch 103--------\n",
      "Val loss = -0.7056\n",
      "\n",
      "--------Epoch 104--------\n",
      "Val loss = -0.7177\n",
      "\n",
      "--------Epoch 105--------\n",
      "Val loss = -0.7068\n",
      "\n",
      "--------Epoch 106--------\n",
      "Val loss = -0.7199\n",
      "\n",
      "--------Epoch 107--------\n",
      "Val loss = -0.7265\n",
      "\n",
      "--------Epoch 108--------\n",
      "Val loss = -0.7264\n",
      "\n",
      "--------Epoch 109--------\n",
      "Val loss = -0.7250\n",
      "\n",
      "--------Epoch 110--------\n",
      "Val loss = -0.7233\n",
      "\n",
      "--------Epoch 111--------\n",
      "Val loss = -0.7249\n",
      "\n",
      "--------Epoch 112--------\n",
      "Val loss = -0.7318\n",
      "\n",
      "--------Epoch 113--------\n",
      "Val loss = -0.7328\n",
      "\n",
      "--------Epoch 114--------\n",
      "Val loss = -0.7356\n",
      "\n",
      "--------Epoch 115--------\n",
      "Val loss = -0.7377\n",
      "\n",
      "--------Epoch 116--------\n",
      "Val loss = -0.7374\n",
      "\n",
      "--------Epoch 117--------\n",
      "Val loss = -0.7390\n",
      "\n",
      "--------Epoch 118--------\n",
      "Val loss = -0.7437\n",
      "\n",
      "--------Epoch 119--------\n",
      "Val loss = -0.7524\n",
      "\n",
      "--------Epoch 120--------\n",
      "Val loss = -0.7485\n",
      "\n",
      "--------Epoch 121--------\n",
      "Val loss = -0.7496\n",
      "\n",
      "--------Epoch 122--------\n",
      "Val loss = -0.7543\n",
      "\n",
      "--------Epoch 123--------\n",
      "Val loss = -0.7489\n",
      "\n",
      "--------Epoch 124--------\n",
      "Val loss = -0.7583\n",
      "\n",
      "--------Epoch 125--------\n",
      "Val loss = -0.7564\n",
      "\n",
      "--------Epoch 126--------\n",
      "Val loss = -0.7545\n",
      "\n",
      "--------Epoch 127--------\n",
      "Val loss = -0.7641\n",
      "\n",
      "--------Epoch 128--------\n",
      "Val loss = -0.7579\n",
      "\n",
      "--------Epoch 129--------\n",
      "Val loss = -0.7629\n",
      "\n",
      "--------Epoch 130--------\n",
      "Val loss = -0.7680\n",
      "\n",
      "--------Epoch 131--------\n",
      "Val loss = -0.7654\n",
      "\n",
      "--------Epoch 132--------\n",
      "Val loss = -0.7724\n",
      "\n",
      "--------Epoch 133--------\n",
      "Val loss = -0.7684\n",
      "\n",
      "--------Epoch 134--------\n",
      "Val loss = -0.7691\n",
      "\n",
      "--------Epoch 135--------\n",
      "Val loss = -0.7713\n",
      "\n",
      "--------Epoch 136--------\n",
      "Val loss = -0.7714\n",
      "\n",
      "--------Epoch 137--------\n",
      "Val loss = -0.7723\n",
      "\n",
      "--------Epoch 138--------\n",
      "Val loss = -0.7740\n",
      "\n",
      "--------Epoch 139--------\n",
      "Val loss = -0.7756\n",
      "\n",
      "--------Epoch 140--------\n",
      "Val loss = -0.7767\n",
      "\n",
      "--------Epoch 141--------\n",
      "Val loss = -0.7777\n",
      "\n",
      "--------Epoch 142--------\n",
      "Val loss = -0.7744\n",
      "\n",
      "--------Epoch 143--------\n",
      "Val loss = -0.7859\n",
      "\n",
      "--------Epoch 144--------\n",
      "Val loss = -0.7788\n",
      "\n",
      "--------Epoch 145--------\n",
      "Val loss = -0.7804\n",
      "\n",
      "--------Epoch 146--------\n",
      "Val loss = -0.7812\n",
      "\n",
      "--------Epoch 147--------\n",
      "Val loss = -0.7813\n",
      "\n",
      "--------Epoch 148--------\n",
      "Val loss = -0.7792\n",
      "\n",
      "--------Epoch 149--------\n",
      "Val loss = -0.7832\n",
      "\n",
      "--------Epoch 150--------\n",
      "Val loss = -0.7840\n",
      "\n",
      "--------Epoch 151--------\n",
      "Val loss = -0.7861\n",
      "\n",
      "--------Epoch 152--------\n",
      "Val loss = -0.7837\n",
      "\n",
      "--------Epoch 153--------\n",
      "Val loss = -0.7795\n",
      "\n",
      "--------Epoch 154--------\n",
      "Val loss = -0.7808\n",
      "\n",
      "--------Epoch 155--------\n",
      "Val loss = -0.7828\n",
      "\n",
      "--------Epoch 156--------\n",
      "Val loss = -0.7845\n",
      "\n",
      "--------Epoch 157--------\n",
      "Val loss = -0.7844\n",
      "\n",
      "--------Epoch 158--------\n",
      "Val loss = -0.7868\n",
      "\n",
      "--------Epoch 159--------\n",
      "Val loss = -0.7890\n",
      "\n",
      "--------Epoch 160--------\n",
      "Val loss = -0.7838\n",
      "\n",
      "--------Epoch 161--------\n",
      "Val loss = -0.7821\n",
      "\n",
      "--------Epoch 162--------\n",
      "Val loss = -0.7882\n",
      "\n",
      "--------Epoch 163--------\n",
      "Val loss = -0.7868\n",
      "\n",
      "--------Epoch 164--------\n",
      "Val loss = -0.7842\n",
      "\n",
      "--------Epoch 165--------\n",
      "Val loss = -0.7891\n",
      "\n",
      "--------Epoch 166--------\n",
      "Val loss = -0.7832\n",
      "\n",
      "--------Epoch 167--------\n",
      "Val loss = -0.7902\n",
      "\n",
      "--------Epoch 168--------\n",
      "Val loss = -0.7877\n",
      "\n",
      "--------Epoch 169--------\n",
      "Val loss = -0.7860\n",
      "\n",
      "--------Epoch 170--------\n",
      "Val loss = -0.7824\n",
      "\n",
      "--------Epoch 171--------\n",
      "Val loss = -0.7894\n",
      "\n",
      "--------Epoch 172--------\n",
      "Val loss = -0.7852\n",
      "\n",
      "--------Epoch 173--------\n",
      "Val loss = -0.7882\n",
      "\n",
      "--------Epoch 174--------\n",
      "Val loss = -0.7875\n",
      "\n",
      "--------Epoch 175--------\n",
      "Val loss = -0.7896\n",
      "\n",
      "--------Epoch 176--------\n",
      "Val loss = -0.7838\n",
      "\n",
      "--------Epoch 177--------\n",
      "Val loss = -0.7840\n",
      "\n",
      "--------Epoch 178--------\n",
      "Val loss = -0.7837\n",
      "\n",
      "--------Epoch 179--------\n",
      "Val loss = -0.7831\n",
      "\n",
      "--------Epoch 180--------\n",
      "Val loss = -0.7834\n",
      "\n",
      "--------Epoch 181--------\n",
      "Val loss = -0.7811\n",
      "\n",
      "--------Epoch 182--------\n",
      "Val loss = -0.7826\n",
      "\n",
      "--------Epoch 183--------\n",
      "Val loss = -0.7834\n",
      "\n",
      "--------Epoch 184--------\n",
      "Val loss = -0.7838\n",
      "\n",
      "--------Epoch 185--------\n",
      "Val loss = -0.7812\n",
      "\n",
      "--------Epoch 186--------\n",
      "Val loss = -0.7800\n",
      "\n",
      "--------Epoch 187--------\n",
      "Val loss = -0.7836\n",
      "\n",
      "--------Epoch 188--------\n",
      "Val loss = -0.7789\n",
      "\n",
      "--------Epoch 189--------\n",
      "Val loss = -0.7794\n",
      "\n",
      "--------Epoch 190--------\n",
      "Val loss = -0.7792\n",
      "\n",
      "--------Epoch 191--------\n",
      "Val loss = -0.7808\n",
      "\n",
      "--------Epoch 192--------\n",
      "Val loss = -0.7803\n",
      "\n",
      "--------Epoch 193--------\n",
      "Val loss = -0.7824\n",
      "\n",
      "--------Epoch 194--------\n",
      "Val loss = -0.7729\n",
      "\n",
      "--------Epoch 195--------\n",
      "Val loss = -0.7783\n",
      "\n",
      "--------Epoch 196--------\n",
      "Val loss = -0.7776\n",
      "\n",
      "--------Epoch 197--------\n",
      "Val loss = -0.7825\n",
      "\n",
      "--------Epoch 198--------\n",
      "Val loss = -0.7767\n",
      "\n",
      "--------Epoch 199--------\n",
      "Val loss = -0.7761\n",
      "\n",
      "--------Epoch 200--------\n",
      "Val loss = -0.7757\n",
      "\n",
      "--------Epoch 201--------\n",
      "Val loss = -0.7788\n",
      "\n",
      "--------Epoch 202--------\n",
      "Val loss = -0.7762\n",
      "\n",
      "--------Epoch 203--------\n",
      "Val loss = -0.7759\n",
      "\n",
      "--------Epoch 204--------\n",
      "Val loss = -0.7796\n",
      "\n",
      "--------Epoch 205--------\n",
      "Val loss = -0.7727\n",
      "\n",
      "--------Epoch 206--------\n",
      "Val loss = -0.7730\n",
      "\n",
      "--------Epoch 207--------\n",
      "Val loss = -0.7722\n",
      "\n",
      "--------Epoch 208--------\n",
      "Val loss = -0.7725\n",
      "\n",
      "--------Epoch 209--------\n",
      "Val loss = -0.7736\n",
      "\n",
      "--------Epoch 210--------\n",
      "Val loss = -0.7744\n",
      "\n",
      "--------Epoch 211--------\n",
      "Val loss = -0.7749\n",
      "\n",
      "--------Epoch 212--------\n",
      "Val loss = -0.7705\n",
      "\n",
      "--------Epoch 213--------\n",
      "Val loss = -0.7715\n",
      "\n",
      "--------Epoch 214--------\n",
      "Val loss = -0.7729\n",
      "\n",
      "--------Epoch 215--------\n",
      "Val loss = -0.7699\n",
      "\n",
      "--------Epoch 216--------\n",
      "Val loss = -0.7717\n",
      "\n",
      "--------Epoch 217--------\n",
      "Val loss = -0.7724\n",
      "\n",
      "--------Epoch 218--------\n",
      "Val loss = -0.7699\n",
      "\n",
      "--------Epoch 219--------\n",
      "Val loss = -0.7702\n",
      "\n",
      "--------Epoch 220--------\n",
      "Val loss = -0.7702\n",
      "\n",
      "--------Epoch 221--------\n",
      "Val loss = -0.7703\n",
      "\n",
      "--------Epoch 222--------\n",
      "Val loss = -0.7749\n",
      "\n",
      "--------Epoch 223--------\n",
      "Val loss = -0.7688\n",
      "\n",
      "--------Epoch 224--------\n",
      "Val loss = -0.7693\n",
      "\n",
      "--------Epoch 225--------\n",
      "Val loss = -0.7728\n",
      "\n",
      "--------Epoch 226--------\n",
      "Val loss = -0.7698\n",
      "\n",
      "--------Epoch 227--------\n",
      "Val loss = -0.7682\n",
      "\n",
      "--------Epoch 228--------\n",
      "Val loss = -0.7657\n",
      "\n",
      "--------Epoch 229--------\n",
      "Val loss = -0.7694\n",
      "\n",
      "--------Epoch 230--------\n",
      "Val loss = -0.7639\n",
      "\n",
      "--------Epoch 231--------\n",
      "Val loss = -0.7636\n",
      "\n",
      "--------Epoch 232--------\n",
      "Val loss = -0.7667\n",
      "\n",
      "--------Epoch 233--------\n",
      "Val loss = -0.7698\n",
      "\n",
      "--------Epoch 234--------\n",
      "Val loss = -0.7687\n",
      "\n",
      "--------Epoch 235--------\n",
      "Val loss = -0.7676\n",
      "\n",
      "--------Epoch 236--------\n",
      "Val loss = -0.7656\n",
      "\n",
      "--------Epoch 237--------\n",
      "Val loss = -0.7595\n",
      "\n",
      "--------Epoch 238--------\n",
      "Val loss = -0.7659\n",
      "\n",
      "--------Epoch 239--------\n",
      "Val loss = -0.7659\n",
      "\n",
      "--------Epoch 240--------\n",
      "Val loss = -0.7647\n",
      "\n",
      "--------Epoch 241--------\n",
      "Val loss = -0.7654\n",
      "\n",
      "--------Epoch 242--------\n",
      "Val loss = -0.7618\n",
      "\n",
      "--------Epoch 243--------\n",
      "Val loss = -0.7666\n",
      "\n",
      "--------Epoch 244--------\n",
      "Val loss = -0.7656\n",
      "\n",
      "--------Epoch 245--------\n",
      "Val loss = -0.7660\n",
      "\n",
      "--------Epoch 246--------\n",
      "Val loss = -0.7644\n",
      "\n",
      "--------Epoch 247--------\n",
      "Val loss = -0.7602\n",
      "\n",
      "--------Epoch 248--------\n",
      "Val loss = -0.7610\n",
      "\n",
      "--------Epoch 249--------\n",
      "Val loss = -0.7657\n",
      "\n",
      "--------Epoch 250--------\n",
      "Val loss = -0.7598\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Set up model\n",
    "global_model = nn.Sequential(\n",
    "    nn.Linear(d_in * 2, 512),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(512, 512),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(512, d_out))\n",
    "selector = ConcreteMask(d_in, max_features, append=True)\n",
    "globalfs = GlobalSelector(global_model, selector).to(device)\n",
    "\n",
    "# Train\n",
    "globalfs.fit(train_dataset,\n",
    "             val_dataset,\n",
    "             mbsize=128,\n",
    "             lr=1e-3,\n",
    "             nepochs=250,\n",
    "             loss_fn=nn.CrossEntropyLoss(),\n",
    "             val_loss_fn=NegAccuracy(),\n",
    "             start_temp=1.0,\n",
    "             end_temp=0.01)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "senior-creation",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test acc = 79.29\n"
     ]
    }
   ],
   "source": [
    "test_acc = globalfs.evaluate(test_dataset, Accuracy(), 1024)\n",
    "print(f'Test acc = {100*test_acc:.2f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "attempted-count",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAKz0lEQVR4nO3dT4ic933H8fenrqyAkoKd1EZ1TJMGU2oKVcqiFlxKinHq+CLn0BIfggoG5RBDAjnUpIf6aEqT0EMJKLWIWlKHQmKsg2kiRMAEivHaqLYctZVj1EaRkBp8iFOoLDvfHvZx2cj7zzPP/PF+3y9YZuaZ2Z0vg956ZuaZ3V+qCkm73y8tegBJ82HsUhPGLjVh7FITxi418cvzvLMbs7few7553qXUyv/yP7xeV7PRdVPFnuRe4G+AG4C/q6pHt7r9e9jH7+Xuae5S0haeqVObXjfx0/gkNwB/C3wCuBN4IMmdk/48SbM1zWv2g8DLVfVKVb0OfBM4NM5YksY2Tey3AT9ad/nCsO0XJDmSZDXJ6jWuTnF3kqYxTewbvQnwts/eVtXRqlqpqpU97J3i7iRNY5rYLwC3r7v8QeDidONImpVpYn8WuCPJh5PcCHwKODHOWJLGNvGht6p6I8lDwHdYO/R2rKpeGm0ySaOa6jh7VT0FPDXSLJJmyI/LSk0Yu9SEsUtNGLvUhLFLTRi71ISxS00Yu9SEsUtNGLvUhLFLTRi71ISxS00Yu9SEsUtNGLvUhLFLTRi71ISxS00Yu9SEsUtNGLvUhLFLTRi71ISxS00Yu9SEsUtNGLvUhLFLTUy1iqv0nYunt7z+j3/twFzm0Pamij3JeeA14E3gjapaGWMoSeMbY8/+R1X1kxF+jqQZ8jW71MS0sRfw3STPJTmy0Q2SHEmymmT1GlenvDtJk5r2afxdVXUxyS3AyST/VlVPr79BVR0FjgL8Sm6uKe9P0oSm2rNX1cXh9ArwBHBwjKEkjW/i2JPsS/K+t84DHwfOjDWYpHFN8zT+VuCJJG/9nH+sqn8eZSq9a3gc/d1j4tir6hXgd0acRdIMeehNasLYpSaMXWrC2KUmjF1qwl9x1VT8Fdd3D/fsUhPGLjVh7FITxi41YexSE8YuNWHsUhMeZ9dUPI7+7uGeXWrC2KUmjF1qwtilJoxdasLYpSaMXWrC4+yayix/n93flR+Xe3apCWOXmjB2qQljl5owdqkJY5eaMHapCY+zayqzPNbtcfRxbbtnT3IsyZUkZ9ZtuznJySTnhtObZjumpGnt5Gn814F7r9v2MHCqqu4ATg2XJS2xbWOvqqeBV6/bfAg4Ppw/Dtw/7liSxjbpG3S3VtUlgOH0ls1umORIktUkq9e4OuHdSZrWzN+Nr6qjVbVSVSt72Dvru5O0iUljv5xkP8BwemW8kSTNwqSxnwAOD+cPA0+OM46kWdnJobfHgX8BfjPJhSQPAo8C9yQ5B9wzXJa0xLb9UE1VPbDJVXePPIukGfLjslITxi41YexSE8YuNWHsUhP+iqu25J9z3j3cs0tNGLvUhLFLTRi71ISxS00Yu9SEsUtNeJx9F9jqWPi0x8E9jr57uGeXmjB2qQljl5owdqkJY5eaMHapCWOXmvA4+y7gsXDthHt2qQljl5owdqkJY5eaMHapCWOXmjB2qQljl5rYyfrsx5JcSXJm3bZHkvw4yenh677ZjilpWjvZs38duHeD7V+pqgPD11PjjiVpbNvGXlVPA6/OYRZJMzTNa/aHkrwwPM2/abMbJTmSZDXJ6jWuTnF3kqYxaexfBT4CHAAuAV/a7IZVdbSqVqpqZQ97J7w7SdOaKPaqulxVb1bVz4GvAQfHHUvS2CaKPcn+dRc/CZzZ7LaSlsO2v8+e5HHgY8AHklwA/hL4WJIDQAHngc/MbkRtZ5Z/N9712XePbWOvqgc22PzYDGaRNEN+gk5qwtilJoxdasLYpSaMXWrCPyW9C8zy8JeH
1nYP9+xSE8YuNWHsUhPGLjVh7FITxi41YexSE8YuNWHsUhPGLjVh7FITxi41YexSE8YuNWHsUhP+PvsuMMs/Ja3dwz271ISxS00Yu9SEsUtNGLvUhLFLTRi71ITH2edg1sseeyxdO7Htnj3J7Um+l+RskpeSfG7YfnOSk0nODac3zX5cSZPaydP4N4AvVNVvAb8PfDbJncDDwKmqugM4NVyWtKS2jb2qLlXV88P514CzwG3AIeD4cLPjwP0zmlHSCN7RG3RJPgR8FHgGuLWqLsHafwjALZt8z5Ekq0lWr3F1ynElTWrHsSd5L/At4PNV9dOdfl9VHa2qlapa2cPeSWaUNIIdxZ5kD2uhf6Oqvj1svpxk/3D9fuDKbEaUNIZtD70lCfAYcLaqvrzuqhPAYeDR4fTJmUy4C3hoTMtgJ8fZ7wI+DbyY5PSw7YusRf5PSR4E/gv4k5lMKGkU28ZeVd8HssnVd487jqRZ8eOyUhPGLjVh7FITxi41YexSE8YuNWHsUhPGLjVh7FITxi41YexSE8YuNWHsUhPGLjVh7FITxi41YexSE8YuNWHsUhPGLjVh7FITxi41YexSE8YuNWHsUhPGLjVh7FITxi41YexSE9vGnuT2JN9LcjbJS0k+N2x/JMmPk5wevu6b/biSJrWT9dnfAL5QVc8neR/wXJKTw3Vfqaq/nt14ksayk/XZLwGXhvOvJTkL3DbrwSSN6x29Zk/yIeCjwDPDpoeSvJDkWJKbNvmeI0lWk6xe4+p000qa2I5jT/Je4FvA56vqp8BXgY8AB1jb839po++rqqNVtVJVK3vYO/3Ekiayo9iT7GEt9G9U1bcBqupyVb1ZVT8HvgYcnN2Ykqa1k3fjAzwGnK2qL6/bvn/dzT4JnBl/PElj2cm78XcBnwZeTHJ62PZF4IEkB4ACzgOfmcF8kkayk3fjvw9kg6ueGn8cSbPiJ+ikJoxdasLYpSaMXWrC2KUmjF1qwtilJoxdasLYpSaMXWrC2KUmjF1qwtilJoxdaiJVNb87S/4b+M91mz4A/GRuA7wzyzrbss4FzjapMWf79ar61Y2umGvsb7vzZLWqVhY2wBaWdbZlnQucbVLzms2n8VITxi41sejYjy74/reyrLMt61zgbJOay2wLfc0uaX4WvWeXNCfGLjWxkNiT3Jvk35O8nOThRcywmSTnk7w4LEO9uuBZjiW5kuTMum03JzmZ5NxwuuEaewuabSmW8d5imfGFPnaLXv587q/Zk9wA/AdwD3ABeBZ4oKp+MNdBNpHkPLBSVQv/AEaSPwR+Bvx9Vf32sO2vgFer6tHhP8qbqurPl2S2R4CfLXoZ72G1ov3rlxkH7gf+jAU+dlvM9afM4XFbxJ79IPByVb1SVa8D3wQOLWCOpVdVTwOvXrf5EHB8OH+ctX8sc7fJbEuhqi5V1fPD+deAt5YZX+hjt8Vcc7GI2G8DfrTu8gWWa733Ar6b5LkkRxY9zAZurapLsPaPB7hlwfNcb9tlvOfpumXGl+axm2T582ktIvaNlpJapuN/d1XV7wKfAD47PF3VzuxoGe952WCZ8aUw6fLn01pE7BeA29dd/iBwcQFzbKiqLg6nV4AnWL6lqC+/tYLucHplwfP8v2VaxnujZcZZgsdukcufLyL2Z4E7knw4yY3Ap4ATC5jjbZLsG944Ick+4OMs31LUJ4DDw/nDwJMLnOUXLMsy3pstM86CH7uFL39eVXP/Au5j7R35HwJ/sYgZNpnrN4B/Hb5eWvRswOOsPa27xtozogeB9wOngHPD6c1LNNs/AC8CL7AW1v4FzfYHrL00fAE4PXzdt+jHbou55vK4+XFZqQk/QSc1YexSE8YuNWHsUhPGLjVh7FITxi418X+75WCu6v7X1wAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "inds = selector.logits.argmax(dim=1).cpu().data.numpy()\n",
    "zeros = np.zeros(784)\n",
    "zeros[inds] = 1\n",
    "plt.figure()\n",
    "plt.imshow(zeros.reshape(28, 28))\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "colored-patrol",
   "metadata": {},
   "source": [
    "# Pretrain model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "delayed-liechtenstein",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--------Epoch 1--------\n",
      "Val loss = -0.2447\n",
      "\n",
      "--------Epoch 2--------\n",
      "Val loss = -0.2737\n",
      "\n",
      "--------Epoch 3--------\n",
      "Val loss = -0.2883\n",
      "\n",
      "--------Epoch 4--------\n",
      "Val loss = -0.2858\n",
      "\n",
      "--------Epoch 5--------\n",
      "Val loss = -0.2869\n",
      "\n",
      "--------Epoch 6--------\n",
      "Val loss = -0.2952\n",
      "\n",
      "--------Epoch 7--------\n",
      "Val loss = -0.2998\n",
      "\n",
      "--------Epoch 8--------\n",
      "Val loss = -0.2978\n",
      "\n",
      "--------Epoch 9--------\n",
      "Val loss = -0.3003\n",
      "\n",
      "--------Epoch 10--------\n",
      "Val loss = -0.3074\n",
      "\n",
      "--------Epoch 11--------\n",
      "Val loss = -0.3081\n",
      "\n",
      "--------Epoch 12--------\n",
      "Val loss = -0.3122\n",
      "\n",
      "--------Epoch 13--------\n",
      "Val loss = -0.3001\n",
      "\n",
      "--------Epoch 14--------\n",
      "Val loss = -0.3065\n",
      "\n",
      "--------Epoch 15--------\n",
      "Val loss = -0.3125\n",
      "\n",
      "--------Epoch 16--------\n",
      "Val loss = -0.3221\n",
      "\n",
      "--------Epoch 17--------\n",
      "Val loss = -0.3168\n",
      "\n",
      "--------Epoch 18--------\n",
      "Val loss = -0.3131\n",
      "\n",
      "--------Epoch 19--------\n",
      "Val loss = -0.3134\n",
      "\n",
      "--------Epoch 20--------\n",
      "Val loss = -0.3151\n",
      "\n",
      "--------Epoch 21--------\n",
      "Val loss = -0.3159\n",
      "\n",
      "--------Epoch 22--------\n",
      "Val loss = -0.3139\n",
      "\n",
      "--------Epoch 23--------\n",
      "Val loss = -0.3163\n",
      "\n",
      "--------Epoch 24--------\n",
      "Val loss = -0.3174\n",
      "\n",
      "--------Epoch 25--------\n",
      "Val loss = -0.3150\n",
      "\n",
      "--------Epoch 26--------\n",
      "Val loss = -0.3154\n",
      "\n",
      "--------Epoch 27--------\n",
      "Val loss = -0.3282\n",
      "\n",
      "--------Epoch 28--------\n",
      "Val loss = -0.3211\n",
      "\n",
      "--------Epoch 29--------\n",
      "Val loss = -0.3235\n",
      "\n",
      "--------Epoch 30--------\n",
      "Val loss = -0.3189\n",
      "\n",
      "--------Epoch 31--------\n",
      "Val loss = -0.3223\n",
      "\n",
      "--------Epoch 32--------\n",
      "Val loss = -0.3213\n",
      "\n",
      "--------Epoch 33--------\n",
      "Val loss = -0.3183\n",
      "\n",
      "--------Epoch 34--------\n",
      "Val loss = -0.3170\n",
      "\n",
      "--------Epoch 35--------\n",
      "Val loss = -0.3283\n",
      "\n",
      "--------Epoch 36--------\n",
      "Val loss = -0.3241\n",
      "\n",
      "--------Epoch 37--------\n",
      "Val loss = -0.3205\n",
      "\n",
      "--------Epoch 38--------\n",
      "Val loss = -0.3301\n",
      "\n",
      "--------Epoch 39--------\n",
      "Val loss = -0.3284\n",
      "\n",
      "--------Epoch 40--------\n",
      "Val loss = -0.3246\n",
      "\n",
      "--------Epoch 41--------\n",
      "Val loss = -0.3274\n",
      "\n",
      "--------Epoch 42--------\n",
      "Val loss = -0.3264\n",
      "\n",
      "--------Epoch 43--------\n",
      "Val loss = -0.3270\n",
      "\n",
      "--------Epoch 44--------\n",
      "Val loss = -0.3208\n",
      "\n",
      "--------Epoch 45--------\n",
      "Val loss = -0.3237\n",
      "\n",
      "--------Epoch 46--------\n",
      "Val loss = -0.3254\n",
      "\n",
      "--------Epoch 47--------\n",
      "Val loss = -0.3264\n",
      "\n",
      "--------Epoch 48--------\n",
      "Val loss = -0.3287\n",
      "\n",
      "--------Epoch 49--------\n",
      "Val loss = -0.3361\n",
      "\n",
      "--------Epoch 50--------\n",
      "Val loss = -0.3239\n",
      "\n",
      "--------Epoch 51--------\n",
      "Val loss = -0.3313\n",
      "\n",
      "--------Epoch 52--------\n",
      "Val loss = -0.3235\n",
      "\n",
      "--------Epoch 53--------\n",
      "Val loss = -0.3280\n",
      "\n",
      "--------Epoch 54--------\n",
      "Val loss = -0.3338\n",
      "\n",
      "--------Epoch 55--------\n",
      "Val loss = -0.3281\n",
      "\n",
      "--------Epoch 56--------\n",
      "Val loss = -0.3251\n",
      "\n",
      "--------Epoch 57--------\n",
      "Val loss = -0.3296\n",
      "\n",
      "--------Epoch 58--------\n",
      "Val loss = -0.3265\n",
      "\n",
      "--------Epoch 59--------\n",
      "Val loss = -0.3256\n",
      "\n",
      "--------Epoch 60--------\n",
      "Val loss = -0.3291\n",
      "\n",
      "--------Epoch 61--------\n",
      "Val loss = -0.3266\n",
      "\n",
      "--------Epoch 62--------\n",
      "Val loss = -0.3299\n",
      "\n",
      "--------Epoch 63--------\n",
      "Val loss = -0.3300\n",
      "\n",
      "--------Epoch 64--------\n",
      "Val loss = -0.3288\n",
      "\n",
      "--------Epoch 65--------\n",
      "Val loss = -0.3289\n",
      "\n",
      "--------Epoch 66--------\n",
      "Val loss = -0.3324\n",
      "\n",
      "--------Epoch 67--------\n",
      "Val loss = -0.3250\n",
      "\n",
      "--------Epoch 68--------\n",
      "Val loss = -0.3373\n",
      "\n",
      "--------Epoch 69--------\n",
      "Val loss = -0.3384\n",
      "\n",
      "--------Epoch 70--------\n",
      "Val loss = -0.3358\n",
      "\n",
      "--------Epoch 71--------\n",
      "Val loss = -0.3350\n",
      "\n",
      "--------Epoch 72--------\n",
      "Val loss = -0.3385\n",
      "\n",
      "--------Epoch 73--------\n",
      "Val loss = -0.3309\n",
      "\n",
      "--------Epoch 74--------\n",
      "Val loss = -0.3347\n",
      "\n",
      "--------Epoch 75--------\n",
      "Val loss = -0.3246\n",
      "\n",
      "--------Epoch 76--------\n",
      "Val loss = -0.3239\n",
      "\n",
      "--------Epoch 77--------\n",
      "Val loss = -0.3371\n",
      "\n",
      "--------Epoch 78--------\n",
      "Val loss = -0.3369\n",
      "\n",
      "--------Epoch 79--------\n",
      "Val loss = -0.3287\n",
      "\n",
      "--------Epoch 80--------\n",
      "Val loss = -0.3367\n",
      "\n",
      "--------Epoch 81--------\n",
      "Val loss = -0.3414\n",
      "\n",
      "--------Epoch 82--------\n",
      "Val loss = -0.3328\n",
      "\n",
      "--------Epoch 83--------\n",
      "Val loss = -0.3335\n",
      "\n",
      "--------Epoch 84--------\n",
      "Val loss = -0.3399\n",
      "\n",
      "--------Epoch 85--------\n",
      "Val loss = -0.3410\n",
      "\n",
      "--------Epoch 86--------\n",
      "Val loss = -0.3368\n",
      "\n",
      "--------Epoch 87--------\n",
      "Val loss = -0.3266\n",
      "\n",
      "--------Epoch 88--------\n",
      "Val loss = -0.3399\n",
      "\n",
      "--------Epoch 89--------\n",
      "Val loss = -0.3380\n",
      "\n",
      "--------Epoch 90--------\n",
      "Val loss = -0.3363\n",
      "\n",
      "--------Epoch 91--------\n",
      "Val loss = -0.3361\n",
      "\n",
      "--------Epoch 92--------\n",
      "Val loss = -0.3377\n",
      "\n",
      "--------Epoch 93--------\n",
      "Val loss = -0.3380\n",
      "\n",
      "--------Epoch 94--------\n",
      "Val loss = -0.3427\n",
      "\n",
      "--------Epoch 95--------\n",
      "Val loss = -0.3449\n",
      "\n",
      "--------Epoch 96--------\n",
      "Val loss = -0.3338\n",
      "\n",
      "--------Epoch 97--------\n",
      "Val loss = -0.3401\n",
      "\n",
      "--------Epoch 98--------\n",
      "Val loss = -0.3369\n",
      "\n",
      "--------Epoch 99--------\n",
      "Val loss = -0.3355\n",
      "\n",
      "--------Epoch 100--------\n",
      "Val loss = -0.3323\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Set up the predictor. Its input width is d_in * 2, presumably because\n",
    "# MaskLayer(append=True) appends the mask to the masked input.\n",
    "model = nn.Sequential(\n",
    "    nn.Linear(d_in * 2, 512), nn.ReLU(),\n",
    "    nn.Linear(512, 512), nn.ReLU(),\n",
    "    nn.Linear(512, d_out),\n",
    ")\n",
    "mask_layer = MaskLayer(append=True)\n",
    "pretrain = Pretrainer(model, mask_layer).to(device)\n",
    "\n",
    "# Pretrain the predictor before it is copied into the adaptive selector below.\n",
    "pretrain.fit(\n",
    "    train_dataset,\n",
    "    val_dataset,\n",
    "    mbsize=128,\n",
    "    lr=1e-3,\n",
    "    nepochs=100,\n",
    "    max_features=max_features,\n",
    "    loss_fn=nn.CrossEntropyLoss(),\n",
    "    val_loss_fn=NegAccuracy(),\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "vanilla-balloon",
   "metadata": {},
   "source": [
    "# RL version (no discretization)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "broke-assistant",
   "metadata": {},
   "outputs": [],
   "source": [
    "class AdaptiveFS(nn.Module):\n",
    "    '''\n",
    "    Adaptive feature selection trained across the whole selection sequence.\n",
    "\n",
    "    Unlike the greedy versions, the losses from all selection steps are combined\n",
    "    into one objective and backpropagated once per minibatch, so gradients are\n",
    "    not cut between steps and flow through every earlier (soft) selection.\n",
    "\n",
    "    Args:\n",
    "      selector: network producing selection logits from the masked input.\n",
    "      model: predictor network applied to the masked input.\n",
    "      mask_layer: module called as mask_layer(x, m) to combine input and mask.\n",
    "      selector_layer: selection layer called as selector_layer(logits, temp).\n",
    "    '''\n",
    "\n",
    "    def __init__(self, selector, model, mask_layer, selector_layer):\n",
    "        super().__init__()\n",
    "        self.selector = selector\n",
    "        self.model = model\n",
    "        self.mask_layer = mask_layer\n",
    "        self.selector_layer = selector_layer\n",
    "\n",
    "    @staticmethod\n",
    "    def _step_weight(i, max_features, weighting, beta):\n",
    "        '''\n",
    "        Loss weight for selection step i (0-indexed).\n",
    "\n",
    "        Both schemes give the final step weight 1. Linear weighting uses\n",
    "        (i + 1) / max_features so the first step does not get weight zero\n",
    "        (with max_features == 1 that would make the total weight zero).\n",
    "        '''\n",
    "        if weighting == 'linear':\n",
    "            return (i + 1) / max_features\n",
    "        if weighting == 'multiplicative':\n",
    "            return beta ** (max_features - 1 - i)\n",
    "        raise ValueError(f'unsupported weighting: {weighting}')\n",
    "\n",
    "    def _hard_select(self, logits, m_hard, argmax, no_repeats):\n",
    "        '''\n",
    "        Convert selector logits into a one-hot selection.\n",
    "\n",
    "        With no_repeats, logits of already-selected features (m_hard == 1) are\n",
    "        lowered by a large constant so they are effectively excluded.\n",
    "        '''\n",
    "        if no_repeats:\n",
    "            logits = logits - 1e6 * m_hard\n",
    "        if argmax:\n",
    "            return make_onehot(logits)\n",
    "        return make_onehot(self.selector_layer(logits, 1e-6))\n",
    "\n",
    "    def fit(self,\n",
    "            train,\n",
    "            val,\n",
    "            mbsize,\n",
    "            lr,\n",
    "            nepochs,\n",
    "            max_features,\n",
    "            loss_fn,\n",
    "            val_loss_fn=None,\n",
    "            weighting='linear',\n",
    "            beta=0.9,\n",
    "            train_model=True,\n",
    "            train_selector=True,\n",
    "            start_temp=10.0,\n",
    "            end_temp=0.01,\n",
    "            argmax=True,\n",
    "            no_repeats=True,\n",
    "            validation_mode='final',\n",
    "            verbose=True):\n",
    "        '''\n",
    "        Train the selector and/or predictor over full selection sequences.\n",
    "\n",
    "        Args:\n",
    "          train: training dataset.\n",
    "          val: validation dataset.\n",
    "          mbsize: minibatch size.\n",
    "          lr: learning rate for the Adam optimizers.\n",
    "          nepochs: number of training epochs.\n",
    "          max_features: number of selection steps per sample (>= 1).\n",
    "          loss_fn: training loss, called as loss_fn(pred, y).\n",
    "          val_loss_fn: validation loss (defaults to loss_fn); lower is better.\n",
    "          weighting: per-step loss weighting, 'linear' or 'multiplicative'.\n",
    "          beta: decay factor for 'multiplicative' weighting.\n",
    "          train_model: whether to update the predictor.\n",
    "          train_selector: whether to update the selector.\n",
    "          start_temp: initial selector temperature.\n",
    "          end_temp: final selector temperature (annealed geometrically per\n",
    "            minibatch).\n",
    "          argmax: validation selects by argmax (otherwise via selector_layer\n",
    "            at temperature 1e-6).\n",
    "          no_repeats: block repeated selections during validation. Training\n",
    "            selections are not masked.\n",
    "          validation_mode: 'final' (loss after the last step) or 'weighted'\n",
    "            (mean over steps using the training weights).\n",
    "          verbose: whether to print the validation loss after each epoch.\n",
    "        '''\n",
    "        # Set up data loaders.\n",
    "        train_loader = DataLoader(\n",
    "            train, batch_size=mbsize, shuffle=True, pin_memory=True,\n",
    "            drop_last=True, num_workers=4)\n",
    "        val_loader = DataLoader(\n",
    "            val, batch_size=mbsize, shuffle=False, pin_memory=True,\n",
    "            drop_last=False, num_workers=4)\n",
    "\n",
    "        # More setup.\n",
    "        selector = self.selector\n",
    "        model = self.model\n",
    "        mask_layer = self.mask_layer\n",
    "        selector_layer = self.selector_layer\n",
    "        device = next(model.parameters()).device\n",
    "        assert validation_mode in ('final', 'weighted')\n",
    "        assert weighting in ('linear', 'multiplicative')\n",
    "        assert train_model or train_selector\n",
    "        assert max_features >= 1, 'max_features must be at least 1'\n",
    "        if train_model:\n",
    "            model_opt = optim.Adam(model.parameters(), lr=lr)\n",
    "        if train_selector:\n",
    "            selector_opt = optim.Adam(selector.parameters(), lr=lr)\n",
    "        if val_loss_fn is None:\n",
    "            val_loss_fn = loss_fn\n",
    "\n",
    "        # Per-step loss weights, shared by training and 'weighted' validation.\n",
    "        weights = [self._step_weight(i, max_features, weighting, beta)\n",
    "                   for i in range(max_features)]\n",
    "        weight_sum = sum(weights)\n",
    "\n",
    "        # Temperature decays geometrically from start_temp to end_temp.\n",
    "        r = (end_temp / start_temp) ** (1 / (nepochs * len(train_loader)))\n",
    "        temp = start_temp\n",
    "\n",
    "        # For tracking best model.\n",
    "        best_model = None\n",
    "        best_selector = None\n",
    "        best_loss = float('inf')\n",
    "\n",
    "        for epoch in range(nepochs):\n",
    "            for x, y in train_loader:\n",
    "                # Move to device.\n",
    "                x = x.to(device)\n",
    "                y = y.to(device)\n",
    "\n",
    "                # Soft mask accumulated over steps (no gradient cancellation).\n",
    "                m = torch.zeros(x.shape, dtype=x.dtype, device=device)\n",
    "                total_loss = 0\n",
    "\n",
    "                for i in range(max_features):\n",
    "                    # Evaluate selector model.\n",
    "                    x_masked = mask_layer(x, m)\n",
    "                    logits = selector(x_masked)\n",
    "\n",
    "                    # Get soft selections and update the mask.\n",
    "                    soft = selector_layer(logits, temp)\n",
    "                    m = torch.max(m, soft)\n",
    "\n",
    "                    # Evaluate model and accumulate the weighted loss.\n",
    "                    pred = model(mask_layer(x, m))\n",
    "                    total_loss = total_loss + weights[i] * loss_fn(pred, y)\n",
    "\n",
    "                # Take gradient step on the weighted mean loss. Gradients are\n",
    "                # cleared before backward so stale gradients never contribute.\n",
    "                total_loss = total_loss / weight_sum\n",
    "                model.zero_grad()\n",
    "                selector.zero_grad()\n",
    "                total_loss.backward()\n",
    "                if train_model:\n",
    "                    model_opt.step()\n",
    "                if train_selector:\n",
    "                    selector_opt.step()\n",
    "                temp *= r\n",
    "\n",
    "            # Calculate validation loss with hard (one-hot) selections.\n",
    "            with torch.no_grad():\n",
    "                # For mean loss.\n",
    "                val_loss = 0\n",
    "                n = 0\n",
    "\n",
    "                for x, y in val_loader:\n",
    "                    # Move to device.\n",
    "                    x = x.to(device)\n",
    "                    y = y.to(device)\n",
    "\n",
    "                    # Setup.\n",
    "                    m_hard = torch.zeros(x.shape, dtype=x.dtype, device=device)\n",
    "                    weighted_loss = 0\n",
    "\n",
    "                    for i in range(max_features):\n",
    "                        # Select the next feature and update the mask.\n",
    "                        logits = selector(mask_layer(x, m_hard))\n",
    "                        hard = self._hard_select(\n",
    "                            logits, m_hard, argmax, no_repeats)\n",
    "                        m_hard = torch.max(m_hard, hard)\n",
    "\n",
    "                        # Evaluate model and accumulate the weighted loss.\n",
    "                        pred = model(mask_layer(x, m_hard))\n",
    "                        loss = val_loss_fn(pred, y).item()\n",
    "                        weighted_loss += weights[i] * loss\n",
    "\n",
    "                    # 'final' scores the last step only; 'weighted' uses the\n",
    "                    # same per-step weights as training.\n",
    "                    if validation_mode == 'final':\n",
    "                        batch_loss = loss\n",
    "                    else:\n",
    "                        batch_loss = weighted_loss / weight_sum\n",
    "                    val_loss = (\n",
    "                        (batch_loss * len(x) + val_loss * n) / (len(x) + n))\n",
    "                    n += len(x)\n",
    "\n",
    "            # Print progress.\n",
    "            if verbose:\n",
    "                print(f'{\"-\"*8}Epoch {epoch+1}{\"-\"*8}')\n",
    "                print(f'Val loss = {val_loss:.4f}\\n')\n",
    "\n",
    "            # See if best model.\n",
    "            if val_loss < best_loss:\n",
    "                best_loss = val_loss\n",
    "                best_model = deepcopy(model)\n",
    "                best_selector = deepcopy(selector)\n",
    "\n",
    "        # Restore best parameters.\n",
    "        if best_model is not None:\n",
    "            restore_parameters(model, best_model)\n",
    "        if best_selector is not None:\n",
    "            restore_parameters(selector, best_selector)\n",
    "\n",
    "    def forward(self, x, max_features, argmax=True, no_repeats=True):\n",
    "        '''\n",
    "        Make predictions using adaptively selected features.\n",
    "\n",
    "        Args:\n",
    "          x: input data (torch.Tensor).\n",
    "          max_features: number of features to select.\n",
    "          argmax: whether to select the next feature using the max logit\n",
    "            (otherwise via selector_layer at temperature 1e-6).\n",
    "          no_repeats: whether to ensure that no repeated selections occur.\n",
    "\n",
    "        Returns:\n",
    "          model predictions on x masked to the selected features.\n",
    "        '''\n",
    "        mask_layer = self.mask_layer\n",
    "        device = next(self.model.parameters()).device\n",
    "        m_hard = torch.zeros(x.shape, dtype=x.dtype, device=device)\n",
    "\n",
    "        for _ in range(max_features):\n",
    "            logits = self.selector(mask_layer(x, m_hard))\n",
    "            hard = self._hard_select(logits, m_hard, argmax, no_repeats)\n",
    "            m_hard = torch.max(m_hard, hard)\n",
    "\n",
    "        return self.model(mask_layer(x, m_hard))\n",
    "\n",
    "    def evaluate(self, dataset, max_features, loss_fn, batch_size, argmax=True,\n",
    "                 no_repeats=True):\n",
    "        '''\n",
    "        Evaluate mean performance across a dataset.\n",
    "\n",
    "        Args:\n",
    "          dataset: dataset yielding (x, y) pairs.\n",
    "          max_features: number of features to select per sample.\n",
    "          loss_fn: metric called as loss_fn(pred, y); must return a\n",
    "            single-element tensor.\n",
    "          batch_size: minibatch size.\n",
    "          argmax, no_repeats: passed through to forward().\n",
    "\n",
    "        Returns:\n",
    "          Sample-weighted mean of loss_fn over the dataset (float).\n",
    "        '''\n",
    "        # Setup.\n",
    "        device = next(self.model.parameters()).device\n",
    "        loader = DataLoader(\n",
    "            dataset, batch_size=batch_size, shuffle=False, pin_memory=True,\n",
    "            drop_last=False, num_workers=4)\n",
    "\n",
    "        # For calculating mean loss.\n",
    "        mean_loss = 0\n",
    "        n = 0\n",
    "\n",
    "        with torch.no_grad():\n",
    "            for x, y in loader:\n",
    "                # Move to device.\n",
    "                x = x.to(device)\n",
    "                y = y.to(device)\n",
    "\n",
    "                # Calculate loss.\n",
    "                pred = self.forward(x, max_features, argmax, no_repeats)\n",
    "                loss = loss_fn(pred, y).item()\n",
    "\n",
    "                # Update average.\n",
    "                mean_loss = (mean_loss * n + loss * len(x)) / (n + len(x))\n",
    "                n += len(x)\n",
    "\n",
    "        return mean_loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ongoing-romance",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--------Epoch 1--------\n",
      "Val loss = -0.7081\n",
      "\n",
      "--------Epoch 2--------\n",
      "Val loss = -0.7377\n",
      "\n",
      "--------Epoch 3--------\n",
      "Val loss = -0.7284\n",
      "\n",
      "--------Epoch 4--------\n",
      "Val loss = -0.7434\n",
      "\n",
      "--------Epoch 5--------\n",
      "Val loss = -0.7414\n",
      "\n",
      "--------Epoch 6--------\n",
      "Val loss = -0.7479\n",
      "\n",
      "--------Epoch 7--------\n",
      "Val loss = -0.7546\n",
      "\n",
      "--------Epoch 8--------\n",
      "Val loss = -0.7514\n",
      "\n",
      "--------Epoch 9--------\n",
      "Val loss = -0.7591\n",
      "\n",
      "--------Epoch 10--------\n",
      "Val loss = -0.7464\n",
      "\n",
      "--------Epoch 11--------\n",
      "Val loss = -0.7592\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Selector network: same hidden sizes as the predictor, with d_in outputs\n",
    "# (presumably one selection logit per input feature).\n",
    "selector = nn.Sequential(\n",
    "    nn.Linear(d_in * 2, 512), nn.ReLU(),\n",
    "    nn.Linear(512, 512), nn.ReLU(),\n",
    "    nn.Linear(512, d_in),\n",
    ")\n",
    "selector_layer = ConcreteSelector()\n",
    "afs = AdaptiveFS(selector, deepcopy(model), mask_layer, selector_layer).to(device)\n",
    "\n",
    "# Tie the selector's first two weight matrices to the copied predictor's.\n",
    "# NOTE(review): nn.Parameter(...) makes a new Parameter sharing storage with the\n",
    "# predictor weight, so with default settings both Adam optimizers in fit()\n",
    "# update that storage, each with its own gradient and optimizer state.\n",
    "selector[0].weight = nn.Parameter(afs.model[0].weight)\n",
    "selector[2].weight = nn.Parameter(afs.model[2].weight)\n",
    "\n",
    "# Train with linear per-step loss weighting.\n",
    "afs.fit(\n",
    "    train_dataset,\n",
    "    val_dataset,\n",
    "    mbsize=128,\n",
    "    lr=2e-4,\n",
    "    nepochs=250,\n",
    "    max_features=max_features,\n",
    "    loss_fn=nn.CrossEntropyLoss(),\n",
    "    val_loss_fn=NegAccuracy(),\n",
    "    weighting='linear',\n",
    "    start_temp=1.0,\n",
    "    end_temp=0.01,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "equipped-audit",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test acc = 83.86\n"
     ]
    }
   ],
   "source": [
    "# Linear version: test accuracy of the model trained with linear weighting.\n",
    "linear_acc = afs.evaluate(test_dataset, max_features, loss_fn=Accuracy(), batch_size=1024)\n",
    "print(f'Test acc = {linear_acc * 100:.2f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "treated-parameter",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Selector network: same hidden sizes as the predictor, with d_in outputs\n",
    "# (presumably one selection logit per input feature).\n",
    "selector = nn.Sequential(\n",
    "    nn.Linear(d_in * 2, 512), nn.ReLU(),\n",
    "    nn.Linear(512, 512), nn.ReLU(),\n",
    "    nn.Linear(512, d_in),\n",
    ")\n",
    "selector_layer = ConcreteSelector()\n",
    "afs = AdaptiveFS(selector, deepcopy(model), mask_layer, selector_layer).to(device)\n",
    "\n",
    "# Tie the selector's first two weight matrices to the copied predictor's.\n",
    "# NOTE(review): nn.Parameter(...) makes a new Parameter sharing storage with the\n",
    "# predictor weight, so with default settings both Adam optimizers in fit()\n",
    "# update that storage, each with its own gradient and optimizer state.\n",
    "selector[0].weight = nn.Parameter(afs.model[0].weight)\n",
    "selector[2].weight = nn.Parameter(afs.model[2].weight)\n",
    "\n",
    "# Train with multiplicative per-step loss weighting (decay beta).\n",
    "afs.fit(\n",
    "    train_dataset,\n",
    "    val_dataset,\n",
    "    mbsize=128,\n",
    "    lr=2e-4,\n",
    "    nepochs=250,\n",
    "    max_features=max_features,\n",
    "    loss_fn=nn.CrossEntropyLoss(),\n",
    "    val_loss_fn=NegAccuracy(),\n",
    "    weighting='multiplicative',\n",
    "    beta=0.9,\n",
    "    start_temp=1.0,\n",
    "    end_temp=0.01,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "progressive-steel",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test acc = 85.13\n"
     ]
    }
   ],
   "source": [
    "# Multiplicative version: test accuracy of the model trained with multiplicative weighting.\n",
    "mult_acc = afs.evaluate(test_dataset, max_features, loss_fn=Accuracy(), batch_size=1024)\n",
    "print(f'Test acc = {mult_acc * 100:.2f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "enhanced-greek",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--------Epoch 1--------\n",
      "Val loss = -0.7023\n",
      "\n",
      "--------Epoch 2--------\n",
      "Val loss = -0.7291\n",
      "\n",
      "--------Epoch 3--------\n",
      "Val loss = -0.7152\n",
      "\n",
      "--------Epoch 4--------\n",
      "Val loss = -0.7257\n",
      "\n",
      "--------Epoch 5--------\n",
      "Val loss = -0.7371\n",
      "\n",
      "--------Epoch 6--------\n",
      "Val loss = -0.7293\n",
      "\n",
      "--------Epoch 7--------\n",
      "Val loss = -0.7552\n",
      "\n",
      "--------Epoch 8--------\n",
      "Val loss = -0.7494\n",
      "\n",
      "--------Epoch 9--------\n",
      "Val loss = -0.7294\n",
      "\n",
      "--------Epoch 10--------\n",
      "Val loss = -0.7334\n",
      "\n",
      "--------Epoch 11--------\n",
      "Val loss = -0.7544\n",
      "\n",
      "--------Epoch 12--------\n",
      "Val loss = -0.7688\n",
      "\n",
      "--------Epoch 13--------\n",
      "Val loss = -0.7526\n",
      "\n",
      "--------Epoch 14--------\n",
      "Val loss = -0.7690\n",
      "\n",
      "--------Epoch 15--------\n",
      "Val loss = -0.7748\n",
      "\n",
      "--------Epoch 16--------\n",
      "Val loss = -0.7829\n",
      "\n",
      "--------Epoch 17--------\n",
      "Val loss = -0.7712\n",
      "\n",
      "--------Epoch 18--------\n",
      "Val loss = -0.7818\n",
      "\n",
      "--------Epoch 19--------\n",
      "Val loss = -0.7786\n",
      "\n",
      "--------Epoch 20--------\n",
      "Val loss = -0.7914\n",
      "\n",
      "--------Epoch 21--------\n",
      "Val loss = -0.7633\n",
      "\n",
      "--------Epoch 22--------\n",
      "Val loss = -0.7906\n",
      "\n",
      "--------Epoch 23--------\n",
      "Val loss = -0.7942\n",
      "\n",
      "--------Epoch 24--------\n",
      "Val loss = -0.7842\n",
      "\n",
      "--------Epoch 25--------\n",
      "Val loss = -0.7796\n",
      "\n",
      "--------Epoch 26--------\n",
      "Val loss = -0.7955\n",
      "\n",
      "--------Epoch 27--------\n",
      "Val loss = -0.7739\n",
      "\n",
      "--------Epoch 28--------\n",
      "Val loss = -0.8117\n",
      "\n",
      "--------Epoch 29--------\n",
      "Val loss = -0.7945\n",
      "\n",
      "--------Epoch 30--------\n",
      "Val loss = -0.7937\n",
      "\n",
      "--------Epoch 31--------\n",
      "Val loss = -0.8138\n",
      "\n",
      "--------Epoch 32--------\n",
      "Val loss = -0.7899\n",
      "\n",
      "--------Epoch 33--------\n",
      "Val loss = -0.8071\n",
      "\n",
      "--------Epoch 34--------\n",
      "Val loss = -0.8024\n",
      "\n",
      "--------Epoch 35--------\n",
      "Val loss = -0.7962\n",
      "\n",
      "--------Epoch 36--------\n",
      "Val loss = -0.8059\n",
      "\n",
      "--------Epoch 37--------\n",
      "Val loss = -0.7958\n",
      "\n",
      "--------Epoch 38--------\n",
      "Val loss = -0.8032\n",
      "\n",
      "--------Epoch 39--------\n",
      "Val loss = -0.8068\n",
      "\n",
      "--------Epoch 40--------\n",
      "Val loss = -0.8051\n",
      "\n",
      "--------Epoch 41--------\n",
      "Val loss = -0.8145\n",
      "\n",
      "--------Epoch 42--------\n",
      "Val loss = -0.8029\n",
      "\n",
      "--------Epoch 43--------\n",
      "Val loss = -0.8090\n",
      "\n",
      "--------Epoch 44--------\n",
      "Val loss = -0.8034\n",
      "\n",
      "--------Epoch 45--------\n",
      "Val loss = -0.7999\n",
      "\n",
      "--------Epoch 46--------\n",
      "Val loss = -0.8087\n",
      "\n",
      "--------Epoch 47--------\n",
      "Val loss = -0.8120\n",
      "\n",
      "--------Epoch 48--------\n",
      "Val loss = -0.8166\n",
      "\n",
      "--------Epoch 49--------\n",
      "Val loss = -0.8140\n",
      "\n",
      "--------Epoch 50--------\n",
      "Val loss = -0.8039\n",
      "\n",
      "--------Epoch 51--------\n",
      "Val loss = -0.8179\n",
      "\n",
      "--------Epoch 52--------\n",
      "Val loss = -0.8224\n",
      "\n",
      "--------Epoch 53--------\n",
      "Val loss = -0.8283\n",
      "\n",
      "--------Epoch 54--------\n",
      "Val loss = -0.8113\n",
      "\n",
      "--------Epoch 55--------\n",
      "Val loss = -0.8255\n",
      "\n",
      "--------Epoch 56--------\n",
      "Val loss = -0.8238\n",
      "\n",
      "--------Epoch 57--------\n",
      "Val loss = -0.8247\n",
      "\n",
      "--------Epoch 58--------\n",
      "Val loss = -0.8276\n",
      "\n",
      "--------Epoch 59--------\n",
      "Val loss = -0.8216\n",
      "\n",
      "--------Epoch 60--------\n",
      "Val loss = -0.8342\n",
      "\n",
      "--------Epoch 61--------\n",
      "Val loss = -0.8310\n",
      "\n",
      "--------Epoch 62--------\n",
      "Val loss = -0.8242\n",
      "\n",
      "--------Epoch 63--------\n",
      "Val loss = -0.8328\n",
      "\n",
      "--------Epoch 64--------\n",
      "Val loss = -0.8279\n",
      "\n",
      "--------Epoch 65--------\n",
      "Val loss = -0.8201\n",
      "\n",
      "--------Epoch 66--------\n",
      "Val loss = -0.8189\n",
      "\n",
      "--------Epoch 67--------\n",
      "Val loss = -0.8250\n",
      "\n",
      "--------Epoch 68--------\n",
      "Val loss = -0.8286\n",
      "\n",
      "--------Epoch 69--------\n",
      "Val loss = -0.8322\n",
      "\n",
      "--------Epoch 70--------\n",
      "Val loss = -0.8309\n",
      "\n",
      "--------Epoch 71--------\n",
      "Val loss = -0.8312\n",
      "\n",
      "--------Epoch 72--------\n",
      "Val loss = -0.8202\n",
      "\n",
      "--------Epoch 73--------\n",
      "Val loss = -0.8295\n",
      "\n",
      "--------Epoch 74--------\n",
      "Val loss = -0.8257\n",
      "\n",
      "--------Epoch 75--------\n",
      "Val loss = -0.8309\n",
      "\n",
      "--------Epoch 76--------\n",
      "Val loss = -0.8303\n",
      "\n",
      "--------Epoch 77--------\n",
      "Val loss = -0.8318\n",
      "\n",
      "--------Epoch 78--------\n",
      "Val loss = -0.8276\n",
      "\n",
      "--------Epoch 79--------\n",
      "Val loss = -0.8332\n",
      "\n",
      "--------Epoch 80--------\n",
      "Val loss = -0.8324\n",
      "\n",
      "--------Epoch 81--------\n",
      "Val loss = -0.8266\n",
      "\n",
      "--------Epoch 82--------\n",
      "Val loss = -0.8346\n",
      "\n",
      "--------Epoch 83--------\n",
      "Val loss = -0.8365\n",
      "\n",
      "--------Epoch 84--------\n",
      "Val loss = -0.8201\n",
      "\n",
      "--------Epoch 85--------\n",
      "Val loss = -0.8223\n",
      "\n",
      "--------Epoch 86--------\n",
      "Val loss = -0.8228\n",
      "\n",
      "--------Epoch 87--------\n",
      "Val loss = -0.8279\n",
      "\n",
      "--------Epoch 88--------\n",
      "Val loss = -0.8265\n",
      "\n",
      "--------Epoch 89--------\n",
      "Val loss = -0.8310\n",
      "\n",
      "--------Epoch 90--------\n",
      "Val loss = -0.8264\n",
      "\n",
      "--------Epoch 91--------\n",
      "Val loss = -0.8255\n",
      "\n",
      "--------Epoch 92--------\n",
      "Val loss = -0.8283\n",
      "\n",
      "--------Epoch 93--------\n",
      "Val loss = -0.8312\n",
      "\n",
      "--------Epoch 94--------\n",
      "Val loss = -0.8277\n",
      "\n",
      "--------Epoch 95--------\n",
      "Val loss = -0.8350\n",
      "\n",
      "--------Epoch 96--------\n",
      "Val loss = -0.8235\n",
      "\n",
      "--------Epoch 97--------\n",
      "Val loss = -0.8260\n",
      "\n",
      "--------Epoch 98--------\n",
      "Val loss = -0.8291\n",
      "\n",
      "--------Epoch 99--------\n",
      "Val loss = -0.8301\n",
      "\n",
      "--------Epoch 100--------\n",
      "Val loss = -0.8302\n",
      "\n",
      "--------Epoch 101--------\n",
      "Val loss = -0.8189\n",
      "\n",
      "--------Epoch 102--------\n",
      "Val loss = -0.8181\n",
      "\n",
      "--------Epoch 103--------\n",
      "Val loss = -0.8299\n",
      "\n",
      "--------Epoch 104--------\n",
      "Val loss = -0.8360\n",
      "\n",
      "--------Epoch 105--------\n",
      "Val loss = -0.8332\n",
      "\n",
      "--------Epoch 106--------\n",
      "Val loss = -0.8367\n",
      "\n",
      "--------Epoch 107--------\n",
      "Val loss = -0.8390\n",
      "\n",
      "--------Epoch 108--------\n",
      "Val loss = -0.8350\n",
      "\n",
      "--------Epoch 109--------\n",
      "Val loss = -0.8403\n",
      "\n",
      "--------Epoch 110--------\n",
      "Val loss = -0.8429\n",
      "\n",
      "--------Epoch 111--------\n",
      "Val loss = -0.8371\n",
      "\n",
      "--------Epoch 112--------\n",
      "Val loss = -0.8429\n",
      "\n",
      "--------Epoch 113--------\n",
      "Val loss = -0.8278\n",
      "\n",
      "--------Epoch 114--------\n",
      "Val loss = -0.8363\n",
      "\n",
      "--------Epoch 115--------\n",
      "Val loss = -0.8238\n",
      "\n",
      "--------Epoch 116--------\n",
      "Val loss = -0.8327\n",
      "\n",
      "--------Epoch 117--------\n",
      "Val loss = -0.8256\n",
      "\n",
      "--------Epoch 118--------\n",
      "Val loss = -0.8204\n",
      "\n",
      "--------Epoch 119--------\n",
      "Val loss = -0.8233\n",
      "\n",
      "--------Epoch 120--------\n",
      "Val loss = -0.8175\n",
      "\n",
      "--------Epoch 121--------\n",
      "Val loss = -0.8235\n",
      "\n",
      "--------Epoch 122--------\n",
      "Val loss = -0.8281\n",
      "\n",
      "--------Epoch 123--------\n",
      "Val loss = -0.8328\n",
      "\n",
      "--------Epoch 124--------\n",
      "Val loss = -0.8315\n",
      "\n",
      "--------Epoch 125--------\n",
      "Val loss = -0.8268\n",
      "\n",
      "--------Epoch 126--------\n",
      "Val loss = -0.8267\n",
      "\n",
      "--------Epoch 127--------\n",
      "Val loss = -0.8232\n",
      "\n",
      "--------Epoch 128--------\n",
      "Val loss = -0.8257\n",
      "\n",
      "--------Epoch 129--------\n",
      "Val loss = -0.8257\n",
      "\n",
      "--------Epoch 130--------\n",
      "Val loss = -0.8298\n",
      "\n",
      "--------Epoch 131--------\n",
      "Val loss = -0.8285\n",
      "\n",
      "--------Epoch 132--------\n",
      "Val loss = -0.8273\n",
      "\n",
      "--------Epoch 133--------\n",
      "Val loss = -0.8211\n",
      "\n",
      "--------Epoch 134--------\n",
      "Val loss = -0.8208\n",
      "\n",
      "--------Epoch 135--------\n",
      "Val loss = -0.8232\n",
      "\n",
      "--------Epoch 136--------\n",
      "Val loss = -0.8224\n",
      "\n",
      "--------Epoch 137--------\n",
      "Val loss = -0.8274\n",
      "\n",
      "--------Epoch 138--------\n",
      "Val loss = -0.8199\n",
      "\n",
      "--------Epoch 139--------\n",
      "Val loss = -0.8181\n",
      "\n",
      "--------Epoch 140--------\n",
      "Val loss = -0.8154\n",
      "\n",
      "--------Epoch 141--------\n",
      "Val loss = -0.8074\n",
      "\n",
      "--------Epoch 142--------\n",
      "Val loss = -0.8198\n",
      "\n",
      "--------Epoch 143--------\n",
      "Val loss = -0.8190\n",
      "\n",
      "--------Epoch 144--------\n",
      "Val loss = -0.8127\n",
      "\n",
      "--------Epoch 145--------\n",
      "Val loss = -0.8175\n",
      "\n",
      "--------Epoch 146--------\n",
      "Val loss = -0.8199\n",
      "\n",
      "--------Epoch 147--------\n",
      "Val loss = -0.8142\n",
      "\n",
      "--------Epoch 148--------\n",
      "Val loss = -0.8146\n",
      "\n",
      "--------Epoch 149--------\n",
      "Val loss = -0.8192\n",
      "\n",
      "--------Epoch 150--------\n",
      "Val loss = -0.8118\n",
      "\n",
      "--------Epoch 151--------\n",
      "Val loss = -0.8190\n",
      "\n",
      "--------Epoch 152--------\n",
      "Val loss = -0.8207\n",
      "\n",
      "--------Epoch 153--------\n",
      "Val loss = -0.8233\n",
      "\n",
      "--------Epoch 154--------\n",
      "Val loss = -0.8180\n",
      "\n",
      "--------Epoch 155--------\n",
      "Val loss = -0.8217\n",
      "\n",
      "--------Epoch 156--------\n",
      "Val loss = -0.8278\n",
      "\n",
      "--------Epoch 157--------\n",
      "Val loss = -0.8219\n",
      "\n",
      "--------Epoch 158--------\n",
      "Val loss = -0.8245\n",
      "\n",
      "--------Epoch 159--------\n",
      "Val loss = -0.8211\n",
      "\n",
      "--------Epoch 160--------\n",
      "Val loss = -0.8326\n",
      "\n",
      "--------Epoch 161--------\n",
      "Val loss = -0.8292\n",
      "\n",
      "--------Epoch 162--------\n",
      "Val loss = -0.8307\n",
      "\n",
      "--------Epoch 163--------\n",
      "Val loss = -0.8255\n",
      "\n",
      "--------Epoch 164--------\n",
      "Val loss = -0.8310\n",
      "\n",
      "--------Epoch 165--------\n",
      "Val loss = -0.8240\n",
      "\n",
      "--------Epoch 166--------\n",
      "Val loss = -0.8224\n",
      "\n",
      "--------Epoch 167--------\n",
      "Val loss = -0.8035\n",
      "\n",
      "--------Epoch 168--------\n",
      "Val loss = -0.7974\n",
      "\n",
      "--------Epoch 169--------\n",
      "Val loss = -0.8076\n",
      "\n",
      "--------Epoch 170--------\n",
      "Val loss = -0.8138\n",
      "\n",
      "--------Epoch 171--------\n",
      "Val loss = -0.8199\n",
      "\n",
      "--------Epoch 172--------\n",
      "Val loss = -0.8261\n",
      "\n",
      "--------Epoch 173--------\n",
      "Val loss = -0.8069\n",
      "\n",
      "--------Epoch 174--------\n",
      "Val loss = -0.8154\n",
      "\n",
      "--------Epoch 175--------\n",
      "Val loss = -0.8157\n",
      "\n",
      "--------Epoch 176--------\n",
      "Val loss = -0.8146\n",
      "\n",
      "--------Epoch 177--------\n",
      "Val loss = -0.8106\n",
      "\n",
      "--------Epoch 178--------\n",
      "Val loss = -0.8129\n",
      "\n",
      "--------Epoch 179--------\n",
      "Val loss = -0.8142\n",
      "\n",
      "--------Epoch 180--------\n",
      "Val loss = -0.8146\n",
      "\n",
      "--------Epoch 181--------\n",
      "Val loss = -0.8124\n",
      "\n",
      "--------Epoch 182--------\n",
      "Val loss = -0.8206\n",
      "\n",
      "--------Epoch 183--------\n",
      "Val loss = -0.8238\n",
      "\n",
      "--------Epoch 184--------\n",
      "Val loss = -0.8228\n",
      "\n",
      "--------Epoch 185--------\n",
      "Val loss = -0.8226\n",
      "\n",
      "--------Epoch 186--------\n",
      "Val loss = -0.8205\n",
      "\n",
      "--------Epoch 187--------\n",
      "Val loss = -0.8210\n",
      "\n",
      "--------Epoch 188--------\n",
      "Val loss = -0.8232\n",
      "\n",
      "--------Epoch 189--------\n",
      "Val loss = -0.8253\n",
      "\n",
      "--------Epoch 190--------\n",
      "Val loss = -0.8220\n",
      "\n",
      "--------Epoch 191--------\n",
      "Val loss = -0.8326\n",
      "\n",
      "--------Epoch 192--------\n",
      "Val loss = -0.8317\n",
      "\n",
      "--------Epoch 193--------\n",
      "Val loss = -0.8342\n",
      "\n",
      "--------Epoch 194--------\n",
      "Val loss = -0.8182\n",
      "\n",
      "--------Epoch 195--------\n",
      "Val loss = -0.8171\n",
      "\n",
      "--------Epoch 196--------\n",
      "Val loss = -0.8178\n",
      "\n",
      "--------Epoch 197--------\n",
      "Val loss = -0.8186\n",
      "\n",
      "--------Epoch 198--------\n",
      "Val loss = -0.8178\n",
      "\n",
      "--------Epoch 199--------\n",
      "Val loss = -0.8120\n",
      "\n",
      "--------Epoch 200--------\n",
      "Val loss = -0.8168\n",
      "\n",
      "--------Epoch 201--------\n",
      "Val loss = -0.8119\n",
      "\n",
      "--------Epoch 202--------\n",
      "Val loss = -0.8173\n",
      "\n",
      "--------Epoch 203--------\n",
      "Val loss = -0.8194\n",
      "\n",
      "--------Epoch 204--------\n",
      "Val loss = -0.8169\n",
      "\n",
      "--------Epoch 205--------\n",
      "Val loss = -0.8147\n",
      "\n",
      "--------Epoch 206--------\n",
      "Val loss = -0.8156\n",
      "\n",
      "--------Epoch 207--------\n",
      "Val loss = -0.8156\n",
      "\n",
      "--------Epoch 208--------\n",
      "Val loss = -0.8177\n",
      "\n",
      "--------Epoch 209--------\n",
      "Val loss = -0.8231\n",
      "\n",
      "--------Epoch 210--------\n",
      "Val loss = -0.8242\n",
      "\n",
      "--------Epoch 211--------\n",
      "Val loss = -0.8225\n",
      "\n",
      "--------Epoch 212--------\n",
      "Val loss = -0.8200\n",
      "\n",
      "--------Epoch 213--------\n",
      "Val loss = -0.8061\n",
      "\n",
      "--------Epoch 214--------\n",
      "Val loss = -0.8160\n",
      "\n",
      "--------Epoch 215--------\n",
      "Val loss = -0.8237\n",
      "\n",
      "--------Epoch 216--------\n",
      "Val loss = -0.8128\n",
      "\n",
      "--------Epoch 217--------\n",
      "Val loss = -0.8076\n",
      "\n",
      "--------Epoch 218--------\n",
      "Val loss = -0.8096\n",
      "\n",
      "--------Epoch 219--------\n",
      "Val loss = -0.8095\n",
      "\n",
      "--------Epoch 220--------\n",
      "Val loss = -0.8202\n",
      "\n",
      "--------Epoch 221--------\n",
      "Val loss = -0.8201\n",
      "\n",
      "--------Epoch 222--------\n",
      "Val loss = -0.8202\n",
      "\n",
      "--------Epoch 223--------\n",
      "Val loss = -0.8121\n",
      "\n",
      "--------Epoch 224--------\n",
      "Val loss = -0.8103\n",
      "\n",
      "--------Epoch 225--------\n",
      "Val loss = -0.8060\n",
      "\n",
      "--------Epoch 226--------\n",
      "Val loss = -0.8103\n",
      "\n",
      "--------Epoch 227--------\n",
      "Val loss = -0.8151\n",
      "\n",
      "--------Epoch 228--------\n",
      "Val loss = -0.8178\n",
      "\n",
      "--------Epoch 229--------\n",
      "Val loss = -0.8159\n",
      "\n",
      "--------Epoch 230--------\n",
      "Val loss = -0.8102\n",
      "\n",
      "--------Epoch 231--------\n",
      "Val loss = -0.8018\n",
      "\n",
      "--------Epoch 232--------\n",
      "Val loss = -0.8014\n",
      "\n",
      "--------Epoch 233--------\n",
      "Val loss = -0.8043\n",
      "\n",
      "--------Epoch 234--------\n",
      "Val loss = -0.8017\n",
      "\n",
      "--------Epoch 235--------\n",
      "Val loss = -0.8042\n",
      "\n",
      "--------Epoch 236--------\n",
      "Val loss = -0.8058\n",
      "\n",
      "--------Epoch 237--------\n",
      "Val loss = -0.8062\n",
      "\n",
      "--------Epoch 238--------\n",
      "Val loss = -0.8121\n",
      "\n",
      "--------Epoch 239--------\n",
      "Val loss = -0.8147\n",
      "\n",
      "--------Epoch 240--------\n",
      "Val loss = -0.8139\n",
      "\n",
      "--------Epoch 241--------\n",
      "Val loss = -0.8184\n",
      "\n",
      "--------Epoch 242--------\n",
      "Val loss = -0.8162\n",
      "\n",
      "--------Epoch 243--------\n",
      "Val loss = -0.8101\n",
      "\n",
      "--------Epoch 244--------\n",
      "Val loss = -0.8172\n",
      "\n",
      "--------Epoch 245--------\n",
      "Val loss = -0.8141\n",
      "\n",
      "--------Epoch 246--------\n",
      "Val loss = -0.8098\n",
      "\n",
      "--------Epoch 247--------\n",
      "Val loss = -0.8139\n",
      "\n",
      "--------Epoch 248--------\n",
      "Val loss = -0.8201\n",
      "\n",
      "--------Epoch 249--------\n",
      "Val loss = -0.8199\n",
      "\n",
      "--------Epoch 250--------\n",
      "Val loss = -0.8198\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Set up selector\n",
    "selector = nn.Sequential(\n",
    "    nn.Linear(d_in * 2, 512),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(512, 512),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(512, d_in))\n",
    "selector_layer = ConcreteSelector()\n",
    "afs = AdaptiveFS(selector, deepcopy(model), mask_layer, selector_layer).to(device)\n",
    "\n",
    "# Tie weights by sharing the model's Parameter objects themselves. Wrapping them in\n",
    "# nn.Parameter(...) creates a *separate* leaf tensor that merely aliases the storage,\n",
    "# so it gets its own .grad (and its own optimizer entry if the optimizer is built from\n",
    "# afs.parameters()) instead of accumulating gradients from both networks.\n",
    "selector[0].weight = afs.model[0].weight\n",
    "selector[2].weight = afs.model[2].weight\n",
    "assert selector[0].weight is afs.model[0].weight\n",
    "assert selector[2].weight is afs.model[2].weight\n",
    "\n",
    "# Train (beta=1.0: the equal-weight version evaluated below)\n",
    "afs.fit(train_dataset,\n",
    "        val_dataset,\n",
    "        mbsize=128,\n",
    "        lr=2e-4,\n",
    "        nepochs=250,\n",
    "        max_features=max_features,\n",
    "        loss_fn=nn.CrossEntropyLoss(),\n",
    "        val_loss_fn=NegAccuracy(),\n",
    "        weighting='multiplicative',\n",
    "        beta=1.0,\n",
    "        start_temp=1.0,\n",
    "        end_temp=0.01)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "million-belly",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test acc = 84.53\n"
     ]
    }
   ],
   "source": [
    "# Test accuracy of the equal-weight (beta=1.0) selector trained above\n",
    "eval_mbsize = 1024  # 4th positional arg of afs.evaluate; presumably the eval minibatch size\n",
    "equal_acc = afs.evaluate(test_dataset, max_features, Accuracy(), eval_mbsize)\n",
    "print(f'Test acc = {100*equal_acc:.2f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "expired-technique",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--------Epoch 1--------\n",
      "Val loss = -0.7116\n",
      "\n",
      "--------Epoch 2--------\n",
      "Val loss = -0.7099\n",
      "\n",
      "--------Epoch 3--------\n",
      "Val loss = -0.7221\n",
      "\n",
      "--------Epoch 4--------\n",
      "Val loss = -0.7212\n",
      "\n",
      "--------Epoch 5--------\n",
      "Val loss = -0.7541\n",
      "\n",
      "--------Epoch 6--------\n",
      "Val loss = -0.7660\n",
      "\n",
      "--------Epoch 7--------\n",
      "Val loss = -0.7605\n",
      "\n",
      "--------Epoch 8--------\n",
      "Val loss = -0.7589\n",
      "\n",
      "--------Epoch 9--------\n",
      "Val loss = -0.7485\n",
      "\n",
      "--------Epoch 10--------\n",
      "Val loss = -0.7666\n",
      "\n",
      "--------Epoch 11--------\n",
      "Val loss = -0.7724\n",
      "\n",
      "--------Epoch 12--------\n",
      "Val loss = -0.7529\n",
      "\n",
      "--------Epoch 13--------\n",
      "Val loss = -0.7444\n",
      "\n",
      "--------Epoch 14--------\n",
      "Val loss = -0.7644\n",
      "\n",
      "--------Epoch 15--------\n",
      "Val loss = -0.7708\n",
      "\n",
      "--------Epoch 16--------\n",
      "Val loss = -0.7830\n",
      "\n",
      "--------Epoch 17--------\n",
      "Val loss = -0.7972\n",
      "\n",
      "--------Epoch 18--------\n",
      "Val loss = -0.8057\n",
      "\n",
      "--------Epoch 19--------\n",
      "Val loss = -0.7839\n",
      "\n",
      "--------Epoch 20--------\n",
      "Val loss = -0.7862\n",
      "\n",
      "--------Epoch 21--------\n",
      "Val loss = -0.7714\n",
      "\n",
      "--------Epoch 22--------\n",
      "Val loss = -0.7724\n",
      "\n",
      "--------Epoch 23--------\n",
      "Val loss = -0.7918\n",
      "\n",
      "--------Epoch 24--------\n",
      "Val loss = -0.7973\n",
      "\n",
      "--------Epoch 25--------\n",
      "Val loss = -0.7934\n",
      "\n",
      "--------Epoch 26--------\n",
      "Val loss = -0.8040\n",
      "\n",
      "--------Epoch 27--------\n",
      "Val loss = -0.7756\n",
      "\n",
      "--------Epoch 28--------\n",
      "Val loss = -0.8089\n",
      "\n",
      "--------Epoch 29--------\n",
      "Val loss = -0.8071\n",
      "\n",
      "--------Epoch 30--------\n",
      "Val loss = -0.7902\n",
      "\n",
      "--------Epoch 31--------\n",
      "Val loss = -0.8039\n",
      "\n",
      "--------Epoch 32--------\n",
      "Val loss = -0.7956\n",
      "\n",
      "--------Epoch 33--------\n",
      "Val loss = -0.7882\n",
      "\n",
      "--------Epoch 34--------\n",
      "Val loss = -0.8078\n",
      "\n",
      "--------Epoch 35--------\n",
      "Val loss = -0.8035\n",
      "\n",
      "--------Epoch 36--------\n",
      "Val loss = -0.7839\n",
      "\n",
      "--------Epoch 37--------\n",
      "Val loss = -0.8147\n",
      "\n",
      "--------Epoch 38--------\n",
      "Val loss = -0.8141\n",
      "\n",
      "--------Epoch 39--------\n",
      "Val loss = -0.7988\n",
      "\n",
      "--------Epoch 40--------\n",
      "Val loss = -0.8064\n",
      "\n",
      "--------Epoch 41--------\n",
      "Val loss = -0.8160\n",
      "\n",
      "--------Epoch 42--------\n",
      "Val loss = -0.8123\n",
      "\n",
      "--------Epoch 43--------\n",
      "Val loss = -0.8148\n",
      "\n",
      "--------Epoch 44--------\n",
      "Val loss = -0.8122\n",
      "\n",
      "--------Epoch 45--------\n",
      "Val loss = -0.8167\n",
      "\n",
      "--------Epoch 46--------\n",
      "Val loss = -0.7992\n",
      "\n",
      "--------Epoch 47--------\n",
      "Val loss = -0.8080\n",
      "\n",
      "--------Epoch 48--------\n",
      "Val loss = -0.8099\n",
      "\n",
      "--------Epoch 49--------\n",
      "Val loss = -0.8156\n",
      "\n",
      "--------Epoch 50--------\n",
      "Val loss = -0.8029\n",
      "\n",
      "--------Epoch 51--------\n",
      "Val loss = -0.8036\n",
      "\n",
      "--------Epoch 52--------\n",
      "Val loss = -0.8235\n",
      "\n",
      "--------Epoch 53--------\n",
      "Val loss = -0.8114\n",
      "\n",
      "--------Epoch 54--------\n",
      "Val loss = -0.8197\n",
      "\n",
      "--------Epoch 55--------\n",
      "Val loss = -0.8220\n",
      "\n",
      "--------Epoch 56--------\n",
      "Val loss = -0.8214\n",
      "\n",
      "--------Epoch 57--------\n",
      "Val loss = -0.8170\n",
      "\n",
      "--------Epoch 58--------\n",
      "Val loss = -0.8324\n",
      "\n",
      "--------Epoch 59--------\n",
      "Val loss = -0.8274\n",
      "\n",
      "--------Epoch 60--------\n",
      "Val loss = -0.8199\n",
      "\n",
      "--------Epoch 61--------\n",
      "Val loss = -0.8022\n",
      "\n",
      "--------Epoch 62--------\n",
      "Val loss = -0.8022\n",
      "\n",
      "--------Epoch 63--------\n",
      "Val loss = -0.8143\n",
      "\n",
      "--------Epoch 64--------\n",
      "Val loss = -0.8004\n",
      "\n",
      "--------Epoch 65--------\n",
      "Val loss = -0.7976\n",
      "\n",
      "--------Epoch 66--------\n",
      "Val loss = -0.8062\n",
      "\n",
      "--------Epoch 67--------\n",
      "Val loss = -0.7932\n",
      "\n",
      "--------Epoch 68--------\n",
      "Val loss = -0.8073\n",
      "\n",
      "--------Epoch 69--------\n",
      "Val loss = -0.8001\n",
      "\n",
      "--------Epoch 70--------\n",
      "Val loss = -0.7880\n",
      "\n",
      "--------Epoch 71--------\n",
      "Val loss = -0.8051\n",
      "\n",
      "--------Epoch 72--------\n",
      "Val loss = -0.8016\n",
      "\n",
      "--------Epoch 73--------\n",
      "Val loss = -0.8051\n",
      "\n",
      "--------Epoch 74--------\n",
      "Val loss = -0.8129\n",
      "\n",
      "--------Epoch 75--------\n",
      "Val loss = -0.8022\n",
      "\n",
      "--------Epoch 76--------\n",
      "Val loss = -0.7953\n",
      "\n",
      "--------Epoch 77--------\n",
      "Val loss = -0.7987\n",
      "\n",
      "--------Epoch 78--------\n",
      "Val loss = -0.8066\n",
      "\n",
      "--------Epoch 79--------\n",
      "Val loss = -0.8059\n",
      "\n",
      "--------Epoch 80--------\n",
      "Val loss = -0.8024\n",
      "\n",
      "--------Epoch 81--------\n",
      "Val loss = -0.8137\n",
      "\n",
      "--------Epoch 82--------\n",
      "Val loss = -0.7934\n",
      "\n",
      "--------Epoch 83--------\n",
      "Val loss = -0.7865\n",
      "\n",
      "--------Epoch 84--------\n",
      "Val loss = -0.8205\n",
      "\n",
      "--------Epoch 85--------\n",
      "Val loss = -0.7959\n",
      "\n",
      "--------Epoch 86--------\n",
      "Val loss = -0.8008\n",
      "\n",
      "--------Epoch 87--------\n",
      "Val loss = -0.8014\n",
      "\n",
      "--------Epoch 88--------\n",
      "Val loss = -0.8027\n",
      "\n",
      "--------Epoch 89--------\n",
      "Val loss = -0.8168\n",
      "\n",
      "--------Epoch 90--------\n",
      "Val loss = -0.8127\n",
      "\n",
      "--------Epoch 91--------\n",
      "Val loss = -0.8174\n",
      "\n",
      "--------Epoch 92--------\n",
      "Val loss = -0.8147\n",
      "\n",
      "--------Epoch 93--------\n",
      "Val loss = -0.8093\n",
      "\n",
      "--------Epoch 94--------\n",
      "Val loss = -0.8088\n",
      "\n",
      "--------Epoch 95--------\n",
      "Val loss = -0.8089\n",
      "\n",
      "--------Epoch 96--------\n",
      "Val loss = -0.8177\n",
      "\n",
      "--------Epoch 97--------\n",
      "Val loss = -0.8071\n",
      "\n",
      "--------Epoch 98--------\n",
      "Val loss = -0.8170\n",
      "\n",
      "--------Epoch 99--------\n",
      "Val loss = -0.8092\n",
      "\n",
      "--------Epoch 100--------\n",
      "Val loss = -0.8173\n",
      "\n",
      "--------Epoch 101--------\n",
      "Val loss = -0.8189\n",
      "\n",
      "--------Epoch 102--------\n",
      "Val loss = -0.8290\n",
      "\n",
      "--------Epoch 103--------\n",
      "Val loss = -0.8221\n",
      "\n",
      "--------Epoch 104--------\n",
      "Val loss = -0.8172\n",
      "\n",
      "--------Epoch 105--------\n",
      "Val loss = -0.8018\n",
      "\n",
      "--------Epoch 106--------\n",
      "Val loss = -0.8056\n",
      "\n",
      "--------Epoch 107--------\n",
      "Val loss = -0.7962\n",
      "\n",
      "--------Epoch 108--------\n",
      "Val loss = -0.8059\n",
      "\n",
      "--------Epoch 109--------\n",
      "Val loss = -0.8147\n",
      "\n",
      "--------Epoch 110--------\n",
      "Val loss = -0.8132\n",
      "\n",
      "--------Epoch 111--------\n",
      "Val loss = -0.8018\n",
      "\n",
      "--------Epoch 112--------\n",
      "Val loss = -0.8135\n",
      "\n",
      "--------Epoch 113--------\n",
      "Val loss = -0.8128\n",
      "\n",
      "--------Epoch 114--------\n",
      "Val loss = -0.8134\n",
      "\n",
      "--------Epoch 115--------\n",
      "Val loss = -0.8158\n",
      "\n",
      "--------Epoch 116--------\n",
      "Val loss = -0.8125\n",
      "\n",
      "--------Epoch 117--------\n",
      "Val loss = -0.8091\n",
      "\n",
      "--------Epoch 118--------\n",
      "Val loss = -0.8053\n",
      "\n",
      "--------Epoch 119--------\n",
      "Val loss = -0.8012\n",
      "\n",
      "--------Epoch 120--------\n",
      "Val loss = -0.7987\n",
      "\n",
      "--------Epoch 121--------\n",
      "Val loss = -0.7961\n",
      "\n",
      "--------Epoch 122--------\n",
      "Val loss = -0.8085\n",
      "\n",
      "--------Epoch 123--------\n",
      "Val loss = -0.8086\n",
      "\n",
      "--------Epoch 124--------\n",
      "Val loss = -0.8066\n",
      "\n",
      "--------Epoch 125--------\n",
      "Val loss = -0.8052\n",
      "\n",
      "--------Epoch 126--------\n",
      "Val loss = -0.8037\n",
      "\n",
      "--------Epoch 127--------\n",
      "Val loss = -0.8149\n",
      "\n",
      "--------Epoch 128--------\n",
      "Val loss = -0.8088\n",
      "\n",
      "--------Epoch 129--------\n",
      "Val loss = -0.8015\n",
      "\n",
      "--------Epoch 130--------\n",
      "Val loss = -0.8004\n",
      "\n",
      "--------Epoch 131--------\n",
      "Val loss = -0.7994\n",
      "\n",
      "--------Epoch 132--------\n",
      "Val loss = -0.8039\n",
      "\n",
      "--------Epoch 133--------\n",
      "Val loss = -0.7988\n",
      "\n",
      "--------Epoch 134--------\n",
      "Val loss = -0.7904\n",
      "\n",
      "--------Epoch 135--------\n",
      "Val loss = -0.7977\n",
      "\n",
      "--------Epoch 136--------\n",
      "Val loss = -0.8025\n",
      "\n",
      "--------Epoch 137--------\n",
      "Val loss = -0.8061\n",
      "\n",
      "--------Epoch 138--------\n",
      "Val loss = -0.8095\n",
      "\n",
      "--------Epoch 139--------\n",
      "Val loss = -0.8121\n",
      "\n",
      "--------Epoch 140--------\n",
      "Val loss = -0.8086\n",
      "\n",
      "--------Epoch 141--------\n",
      "Val loss = -0.8043\n",
      "\n",
      "--------Epoch 142--------\n",
      "Val loss = -0.8034\n",
      "\n",
      "--------Epoch 143--------\n",
      "Val loss = -0.8024\n",
      "\n",
      "--------Epoch 144--------\n",
      "Val loss = -0.8017\n",
      "\n",
      "--------Epoch 145--------\n",
      "Val loss = -0.8090\n",
      "\n",
      "--------Epoch 146--------\n",
      "Val loss = -0.8172\n",
      "\n",
      "--------Epoch 147--------\n",
      "Val loss = -0.8167\n",
      "\n",
      "--------Epoch 148--------\n",
      "Val loss = -0.8161\n",
      "\n",
      "--------Epoch 149--------\n",
      "Val loss = -0.8146\n",
      "\n",
      "--------Epoch 150--------\n",
      "Val loss = -0.8107\n",
      "\n",
      "--------Epoch 151--------\n",
      "Val loss = -0.8129\n",
      "\n",
      "--------Epoch 152--------\n",
      "Val loss = -0.8075\n",
      "\n",
      "--------Epoch 153--------\n",
      "Val loss = -0.8121\n",
      "\n",
      "--------Epoch 154--------\n",
      "Val loss = -0.8119\n",
      "\n",
      "--------Epoch 155--------\n",
      "Val loss = -0.8118\n",
      "\n",
      "--------Epoch 156--------\n",
      "Val loss = -0.8129\n",
      "\n",
      "--------Epoch 157--------\n",
      "Val loss = -0.8056\n",
      "\n",
      "--------Epoch 158--------\n",
      "Val loss = -0.8062\n",
      "\n",
      "--------Epoch 159--------\n",
      "Val loss = -0.8096\n",
      "\n",
      "--------Epoch 160--------\n",
      "Val loss = -0.8087\n",
      "\n",
      "--------Epoch 161--------\n",
      "Val loss = -0.8099\n",
      "\n",
      "--------Epoch 162--------\n",
      "Val loss = -0.8207\n",
      "\n",
      "--------Epoch 163--------\n",
      "Val loss = -0.8191\n",
      "\n",
      "--------Epoch 164--------\n",
      "Val loss = -0.8260\n",
      "\n",
      "--------Epoch 165--------\n",
      "Val loss = -0.8147\n",
      "\n",
      "--------Epoch 166--------\n",
      "Val loss = -0.8241\n",
      "\n",
      "--------Epoch 167--------\n",
      "Val loss = -0.8162\n",
      "\n",
      "--------Epoch 168--------\n",
      "Val loss = -0.8184\n",
      "\n",
      "--------Epoch 169--------\n",
      "Val loss = -0.8154\n",
      "\n",
      "--------Epoch 170--------\n",
      "Val loss = -0.8170\n",
      "\n",
      "--------Epoch 171--------\n",
      "Val loss = -0.8163\n",
      "\n",
      "--------Epoch 172--------\n",
      "Val loss = -0.8156\n",
      "\n",
      "--------Epoch 173--------\n",
      "Val loss = -0.8171\n",
      "\n",
      "--------Epoch 174--------\n",
      "Val loss = -0.8114\n",
      "\n",
      "--------Epoch 175--------\n",
      "Val loss = -0.8124\n",
      "\n",
      "--------Epoch 176--------\n",
      "Val loss = -0.8037\n",
      "\n",
      "--------Epoch 177--------\n",
      "Val loss = -0.8039\n",
      "\n",
      "--------Epoch 178--------\n",
      "Val loss = -0.8059\n",
      "\n",
      "--------Epoch 179--------\n",
      "Val loss = -0.8042\n",
      "\n",
      "--------Epoch 180--------\n",
      "Val loss = -0.8066\n",
      "\n",
      "--------Epoch 181--------\n",
      "Val loss = -0.8033\n",
      "\n",
      "--------Epoch 182--------\n",
      "Val loss = -0.8061\n",
      "\n",
      "--------Epoch 183--------\n",
      "Val loss = -0.8038\n",
      "\n",
      "--------Epoch 184--------\n",
      "Val loss = -0.8053\n",
      "\n",
      "--------Epoch 185--------\n",
      "Val loss = -0.8018\n",
      "\n",
      "--------Epoch 186--------\n",
      "Val loss = -0.8089\n",
      "\n",
      "--------Epoch 187--------\n",
      "Val loss = -0.8099\n",
      "\n",
      "--------Epoch 188--------\n",
      "Val loss = -0.8140\n",
      "\n",
      "--------Epoch 189--------\n",
      "Val loss = -0.8157\n",
      "\n",
      "--------Epoch 190--------\n",
      "Val loss = -0.8149\n",
      "\n",
      "--------Epoch 191--------\n",
      "Val loss = -0.8183\n",
      "\n",
      "--------Epoch 192--------\n",
      "Val loss = -0.8144\n",
      "\n",
      "--------Epoch 193--------\n",
      "Val loss = -0.8167\n",
      "\n",
      "--------Epoch 194--------\n",
      "Val loss = -0.8117\n",
      "\n",
      "--------Epoch 195--------\n",
      "Val loss = -0.8156\n",
      "\n",
      "--------Epoch 196--------\n",
      "Val loss = -0.8149\n",
      "\n",
      "--------Epoch 197--------\n",
      "Val loss = -0.8161\n",
      "\n",
      "--------Epoch 198--------\n",
      "Val loss = -0.8188\n",
      "\n",
      "--------Epoch 199--------\n",
      "Val loss = -0.8242\n",
      "\n",
      "--------Epoch 200--------\n",
      "Val loss = -0.8264\n",
      "\n",
      "--------Epoch 201--------\n",
      "Val loss = -0.8189\n",
      "\n",
      "--------Epoch 202--------\n",
      "Val loss = -0.8194\n",
      "\n",
      "--------Epoch 203--------\n",
      "Val loss = -0.8213\n",
      "\n",
      "--------Epoch 204--------\n",
      "Val loss = -0.8147\n",
      "\n",
      "--------Epoch 205--------\n",
      "Val loss = -0.8179\n",
      "\n",
      "--------Epoch 206--------\n",
      "Val loss = -0.8179\n",
      "\n",
      "--------Epoch 207--------\n",
      "Val loss = -0.8174\n",
      "\n",
      "--------Epoch 208--------\n",
      "Val loss = -0.8169\n",
      "\n",
      "--------Epoch 209--------\n",
      "Val loss = -0.8175\n",
      "\n",
      "--------Epoch 210--------\n",
      "Val loss = -0.8187\n",
      "\n",
      "--------Epoch 211--------\n",
      "Val loss = -0.8209\n",
      "\n",
      "--------Epoch 212--------\n",
      "Val loss = -0.8184\n",
      "\n",
      "--------Epoch 213--------\n",
      "Val loss = -0.8287\n",
      "\n",
      "--------Epoch 214--------\n",
      "Val loss = -0.8268\n",
      "\n",
      "--------Epoch 215--------\n",
      "Val loss = -0.8262\n",
      "\n",
      "--------Epoch 216--------\n",
      "Val loss = -0.8252\n",
      "\n",
      "--------Epoch 217--------\n",
      "Val loss = -0.8247\n",
      "\n",
      "--------Epoch 218--------\n",
      "Val loss = -0.8259\n",
      "\n",
      "--------Epoch 219--------\n",
      "Val loss = -0.8229\n",
      "\n",
      "--------Epoch 220--------\n",
      "Val loss = -0.8228\n",
      "\n",
      "--------Epoch 221--------\n",
      "Val loss = -0.8255\n",
      "\n",
      "--------Epoch 222--------\n",
      "Val loss = -0.8272\n",
      "\n",
      "--------Epoch 223--------\n",
      "Val loss = -0.8213\n",
      "\n",
      "--------Epoch 224--------\n",
      "Val loss = -0.7965\n",
      "\n",
      "--------Epoch 225--------\n",
      "Val loss = -0.8170\n",
      "\n",
      "--------Epoch 226--------\n",
      "Val loss = -0.8175\n",
      "\n",
      "--------Epoch 227--------\n",
      "Val loss = -0.8010\n",
      "\n",
      "--------Epoch 228--------\n",
      "Val loss = -0.7963\n",
      "\n",
      "--------Epoch 229--------\n",
      "Val loss = -0.7987\n",
      "\n",
      "--------Epoch 230--------\n",
      "Val loss = -0.8001\n",
      "\n",
      "--------Epoch 231--------\n",
      "Val loss = -0.8004\n",
      "\n",
      "--------Epoch 232--------\n",
      "Val loss = -0.7918\n",
      "\n",
      "--------Epoch 233--------\n",
      "Val loss = -0.7914\n",
      "\n",
      "--------Epoch 234--------\n",
      "Val loss = -0.8000\n",
      "\n",
      "--------Epoch 235--------\n",
      "Val loss = -0.7856\n",
      "\n",
      "--------Epoch 236--------\n",
      "Val loss = -0.7916\n",
      "\n",
      "--------Epoch 237--------\n",
      "Val loss = -0.7929\n",
      "\n",
      "--------Epoch 238--------\n",
      "Val loss = -0.7956\n",
      "\n",
      "--------Epoch 239--------\n",
      "Val loss = -0.7927\n",
      "\n",
      "--------Epoch 240--------\n",
      "Val loss = -0.7905\n",
      "\n",
      "--------Epoch 241--------\n",
      "Val loss = -0.7872\n",
      "\n",
      "--------Epoch 242--------\n",
      "Val loss = -0.7824\n",
      "\n",
      "--------Epoch 243--------\n",
      "Val loss = -0.7889\n",
      "\n",
      "--------Epoch 244--------\n",
      "Val loss = -0.7941\n",
      "\n",
      "--------Epoch 245--------\n",
      "Val loss = -0.7900\n",
      "\n",
      "--------Epoch 246--------\n",
      "Val loss = -0.7867\n",
      "\n",
      "--------Epoch 247--------\n",
      "Val loss = -0.7853\n",
      "\n",
      "--------Epoch 248--------\n",
      "Val loss = -0.7869\n",
      "\n",
      "--------Epoch 249--------\n",
      "Val loss = -0.7901\n",
      "\n",
      "--------Epoch 250--------\n",
      "Val loss = -0.7880\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Set up selector\n",
    "selector = nn.Sequential(\n",
    "    nn.Linear(d_in * 2, 512),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(512, 512),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(512, d_in))\n",
    "selector_layer = ConcreteSelector()\n",
    "afs = AdaptiveFS(selector, deepcopy(model), mask_layer, selector_layer).to(device)\n",
    "\n",
    "# Tie weights by sharing the model's Parameter objects themselves. Wrapping them in\n",
    "# nn.Parameter(...) creates a *separate* leaf tensor that merely aliases the storage,\n",
    "# so it gets its own .grad (and its own optimizer entry if the optimizer is built from\n",
    "# afs.parameters()) instead of accumulating gradients from both networks.\n",
    "selector[0].weight = afs.model[0].weight\n",
    "selector[2].weight = afs.model[2].weight\n",
    "assert selector[0].weight is afs.model[0].weight\n",
    "assert selector[2].weight is afs.model[2].weight\n",
    "\n",
    "# Train (beta=0.5: the strong-decay version evaluated below)\n",
    "afs.fit(train_dataset,\n",
    "        val_dataset,\n",
    "        mbsize=128,\n",
    "        lr=2e-4,\n",
    "        nepochs=250,\n",
    "        max_features=max_features,\n",
    "        loss_fn=nn.CrossEntropyLoss(),\n",
    "        val_loss_fn=NegAccuracy(),\n",
    "        weighting='multiplicative',\n",
    "        beta=0.5,\n",
    "        start_temp=1.0,\n",
    "        end_temp=0.01)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "id": "parallel-africa",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test acc = 83.35\n"
     ]
    }
   ],
   "source": [
    "# Test accuracy of the strong-decay (beta=0.5) selector trained above\n",
    "eval_mbsize = 1024  # 4th positional arg of afs.evaluate; presumably the eval minibatch size\n",
    "decay_acc = afs.evaluate(test_dataset, max_features, Accuracy(), eval_mbsize)\n",
    "print(f'Test acc = {100*decay_acc:.2f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "id": "searching-navigation",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--------Epoch 1--------\n",
      "Val loss = -0.7505\n",
      "\n",
      "--------Epoch 2--------\n",
      "Val loss = -0.7626\n",
      "\n",
      "--------Epoch 3--------\n",
      "Val loss = -0.7646\n",
      "\n",
      "--------Epoch 4--------\n",
      "Val loss = -0.7698\n",
      "\n",
      "--------Epoch 5--------\n",
      "Val loss = -0.7740\n",
      "\n",
      "--------Epoch 6--------\n",
      "Val loss = -0.7839\n",
      "\n",
      "--------Epoch 7--------\n",
      "Val loss = -0.7870\n",
      "\n",
      "--------Epoch 8--------\n",
      "Val loss = -0.7885\n",
      "\n",
      "--------Epoch 9--------\n",
      "Val loss = -0.7851\n",
      "\n",
      "--------Epoch 10--------\n",
      "Val loss = -0.7904\n",
      "\n",
      "--------Epoch 11--------\n",
      "Val loss = -0.7986\n",
      "\n",
      "--------Epoch 12--------\n",
      "Val loss = -0.7907\n",
      "\n",
      "--------Epoch 13--------\n",
      "Val loss = -0.8018\n",
      "\n",
      "--------Epoch 14--------\n",
      "Val loss = -0.8048\n",
      "\n",
      "--------Epoch 15--------\n",
      "Val loss = -0.8014\n",
      "\n",
      "--------Epoch 16--------\n",
      "Val loss = -0.8010\n",
      "\n",
      "--------Epoch 17--------\n",
      "Val loss = -0.8085\n",
      "\n",
      "--------Epoch 18--------\n",
      "Val loss = -0.8088\n",
      "\n",
      "--------Epoch 19--------\n",
      "Val loss = -0.8091\n",
      "\n",
      "--------Epoch 20--------\n",
      "Val loss = -0.8151\n",
      "\n",
      "--------Epoch 21--------\n",
      "Val loss = -0.8142\n",
      "\n",
      "--------Epoch 22--------\n",
      "Val loss = -0.8145\n",
      "\n",
      "--------Epoch 23--------\n",
      "Val loss = -0.8242\n",
      "\n",
      "--------Epoch 24--------\n",
      "Val loss = -0.8177\n",
      "\n",
      "--------Epoch 25--------\n",
      "Val loss = -0.8223\n",
      "\n",
      "--------Epoch 26--------\n",
      "Val loss = -0.8277\n",
      "\n",
      "--------Epoch 27--------\n",
      "Val loss = -0.8193\n",
      "\n",
      "--------Epoch 28--------\n",
      "Val loss = -0.8174\n",
      "\n",
      "--------Epoch 29--------\n",
      "Val loss = -0.8228\n",
      "\n",
      "--------Epoch 30--------\n",
      "Val loss = -0.8246\n",
      "\n",
      "--------Epoch 31--------\n",
      "Val loss = -0.8259\n",
      "\n",
      "--------Epoch 32--------\n",
      "Val loss = -0.8252\n",
      "\n",
      "--------Epoch 33--------\n",
      "Val loss = -0.8297\n",
      "\n",
      "--------Epoch 34--------\n",
      "Val loss = -0.8250\n",
      "\n",
      "--------Epoch 35--------\n",
      "Val loss = -0.8288\n",
      "\n",
      "--------Epoch 36--------\n",
      "Val loss = -0.8305\n",
      "\n",
      "--------Epoch 37--------\n",
      "Val loss = -0.8236\n",
      "\n",
      "--------Epoch 38--------\n",
      "Val loss = -0.8351\n",
      "\n",
      "--------Epoch 39--------\n",
      "Val loss = -0.8268\n",
      "\n",
      "--------Epoch 40--------\n",
      "Val loss = -0.8315\n",
      "\n",
      "--------Epoch 41--------\n",
      "Val loss = -0.8349\n",
      "\n",
      "--------Epoch 42--------\n",
      "Val loss = -0.8351\n",
      "\n",
      "--------Epoch 43--------\n",
      "Val loss = -0.8265\n",
      "\n",
      "--------Epoch 44--------\n",
      "Val loss = -0.8320\n",
      "\n",
      "--------Epoch 45--------\n",
      "Val loss = -0.8278\n",
      "\n",
      "--------Epoch 46--------\n",
      "Val loss = -0.8290\n",
      "\n",
      "--------Epoch 47--------\n",
      "Val loss = -0.8294\n",
      "\n",
      "--------Epoch 48--------\n",
      "Val loss = -0.8312\n",
      "\n",
      "--------Epoch 49--------\n",
      "Val loss = -0.8298\n",
      "\n",
      "--------Epoch 50--------\n",
      "Val loss = -0.8341\n",
      "\n",
      "--------Epoch 51--------\n",
      "Val loss = -0.8394\n",
      "\n",
      "--------Epoch 52--------\n",
      "Val loss = -0.8302\n",
      "\n",
      "--------Epoch 53--------\n",
      "Val loss = -0.8384\n",
      "\n",
      "--------Epoch 54--------\n",
      "Val loss = -0.8399\n",
      "\n",
      "--------Epoch 55--------\n",
      "Val loss = -0.8386\n",
      "\n",
      "--------Epoch 56--------\n",
      "Val loss = -0.8336\n",
      "\n",
      "--------Epoch 57--------\n",
      "Val loss = -0.8368\n",
      "\n",
      "--------Epoch 58--------\n",
      "Val loss = -0.8398\n",
      "\n",
      "--------Epoch 59--------\n",
      "Val loss = -0.8385\n",
      "\n",
      "--------Epoch 60--------\n",
      "Val loss = -0.8426\n",
      "\n",
      "--------Epoch 61--------\n",
      "Val loss = -0.8411\n",
      "\n",
      "--------Epoch 62--------\n",
      "Val loss = -0.8429\n",
      "\n",
      "--------Epoch 63--------\n",
      "Val loss = -0.8430\n",
      "\n",
      "--------Epoch 64--------\n",
      "Val loss = -0.8413\n",
      "\n",
      "--------Epoch 65--------\n",
      "Val loss = -0.8441\n",
      "\n",
      "--------Epoch 66--------\n",
      "Val loss = -0.8387\n",
      "\n",
      "--------Epoch 67--------\n",
      "Val loss = -0.8439\n",
      "\n",
      "--------Epoch 68--------\n",
      "Val loss = -0.8461\n",
      "\n",
      "--------Epoch 69--------\n",
      "Val loss = -0.8451\n",
      "\n",
      "--------Epoch 70--------\n",
      "Val loss = -0.8456\n",
      "\n",
      "--------Epoch 71--------\n",
      "Val loss = -0.8451\n",
      "\n",
      "--------Epoch 72--------\n",
      "Val loss = -0.8410\n",
      "\n",
      "--------Epoch 73--------\n",
      "Val loss = -0.8428\n",
      "\n",
      "--------Epoch 74--------\n",
      "Val loss = -0.8418\n",
      "\n",
      "--------Epoch 75--------\n",
      "Val loss = -0.8463\n",
      "\n",
      "--------Epoch 76--------\n",
      "Val loss = -0.8411\n",
      "\n",
      "--------Epoch 77--------\n",
      "Val loss = -0.8497\n",
      "\n",
      "--------Epoch 78--------\n",
      "Val loss = -0.8434\n",
      "\n",
      "--------Epoch 79--------\n",
      "Val loss = -0.8482\n",
      "\n",
      "--------Epoch 80--------\n",
      "Val loss = -0.8448\n",
      "\n",
      "--------Epoch 81--------\n",
      "Val loss = -0.8470\n",
      "\n",
      "--------Epoch 82--------\n",
      "Val loss = -0.8446\n",
      "\n",
      "--------Epoch 83--------\n",
      "Val loss = -0.8450\n",
      "\n",
      "--------Epoch 84--------\n",
      "Val loss = -0.8522\n",
      "\n",
      "--------Epoch 85--------\n",
      "Val loss = -0.8514\n",
      "\n",
      "--------Epoch 86--------\n",
      "Val loss = -0.8460\n",
      "\n",
      "--------Epoch 87--------\n",
      "Val loss = -0.8501\n",
      "\n",
      "--------Epoch 88--------\n",
      "Val loss = -0.8498\n",
      "\n",
      "--------Epoch 89--------\n",
      "Val loss = -0.8501\n",
      "\n",
      "--------Epoch 90--------\n",
      "Val loss = -0.8453\n",
      "\n",
      "--------Epoch 91--------\n",
      "Val loss = -0.8487\n",
      "\n",
      "--------Epoch 92--------\n",
      "Val loss = -0.8516\n",
      "\n",
      "--------Epoch 93--------\n",
      "Val loss = -0.8534\n",
      "\n",
      "--------Epoch 94--------\n",
      "Val loss = -0.8507\n",
      "\n",
      "--------Epoch 95--------\n",
      "Val loss = -0.8510\n",
      "\n",
      "--------Epoch 96--------\n",
      "Val loss = -0.8542\n",
      "\n",
      "--------Epoch 97--------\n",
      "Val loss = -0.8501\n",
      "\n",
      "--------Epoch 98--------\n",
      "Val loss = -0.8487\n",
      "\n",
      "--------Epoch 99--------\n",
      "Val loss = -0.8472\n",
      "\n",
      "--------Epoch 100--------\n",
      "Val loss = -0.8483\n",
      "\n",
      "--------Epoch 101--------\n",
      "Val loss = -0.8523\n",
      "\n",
      "--------Epoch 102--------\n",
      "Val loss = -0.8523\n",
      "\n",
      "--------Epoch 103--------\n",
      "Val loss = -0.8424\n",
      "\n",
      "--------Epoch 104--------\n",
      "Val loss = -0.8491\n",
      "\n",
      "--------Epoch 105--------\n",
      "Val loss = -0.8533\n",
      "\n",
      "--------Epoch 106--------\n",
      "Val loss = -0.8537\n",
      "\n",
      "--------Epoch 107--------\n",
      "Val loss = -0.8559\n",
      "\n",
      "--------Epoch 108--------\n",
      "Val loss = -0.8533\n",
      "\n",
      "--------Epoch 109--------\n",
      "Val loss = -0.8529\n",
      "\n",
      "--------Epoch 110--------\n",
      "Val loss = -0.8493\n",
      "\n",
      "--------Epoch 111--------\n",
      "Val loss = -0.8509\n",
      "\n",
      "--------Epoch 112--------\n",
      "Val loss = -0.8517\n",
      "\n",
      "--------Epoch 113--------\n",
      "Val loss = -0.8544\n",
      "\n",
      "--------Epoch 114--------\n",
      "Val loss = -0.8497\n",
      "\n",
      "--------Epoch 115--------\n",
      "Val loss = -0.8554\n",
      "\n",
      "--------Epoch 116--------\n",
      "Val loss = -0.8476\n",
      "\n",
      "--------Epoch 117--------\n",
      "Val loss = -0.8515\n",
      "\n",
      "--------Epoch 118--------\n",
      "Val loss = -0.8531\n",
      "\n",
      "--------Epoch 119--------\n",
      "Val loss = -0.8499\n",
      "\n",
      "--------Epoch 120--------\n",
      "Val loss = -0.8541\n",
      "\n",
      "--------Epoch 121--------\n",
      "Val loss = -0.8562\n",
      "\n",
      "--------Epoch 122--------\n",
      "Val loss = -0.8531\n",
      "\n",
      "--------Epoch 123--------\n",
      "Val loss = -0.8520\n",
      "\n",
      "--------Epoch 124--------\n",
      "Val loss = -0.8517\n",
      "\n",
      "--------Epoch 125--------\n",
      "Val loss = -0.8543\n",
      "\n",
      "--------Epoch 126--------\n",
      "Val loss = -0.8576\n",
      "\n",
      "--------Epoch 127--------\n",
      "Val loss = -0.8548\n",
      "\n",
      "--------Epoch 128--------\n",
      "Val loss = -0.8565\n",
      "\n",
      "--------Epoch 129--------\n",
      "Val loss = -0.8538\n",
      "\n",
      "--------Epoch 130--------\n",
      "Val loss = -0.8525\n",
      "\n",
      "--------Epoch 131--------\n",
      "Val loss = -0.8530\n",
      "\n",
      "--------Epoch 132--------\n",
      "Val loss = -0.8594\n",
      "\n",
      "--------Epoch 133--------\n",
      "Val loss = -0.8597\n",
      "\n",
      "--------Epoch 134--------\n",
      "Val loss = -0.8543\n",
      "\n",
      "--------Epoch 135--------\n",
      "Val loss = -0.8568\n",
      "\n",
      "--------Epoch 136--------\n",
      "Val loss = -0.8505\n",
      "\n",
      "--------Epoch 137--------\n",
      "Val loss = -0.8535\n",
      "\n",
      "--------Epoch 138--------\n",
      "Val loss = -0.8529\n",
      "\n",
      "--------Epoch 139--------\n",
      "Val loss = -0.8545\n",
      "\n",
      "--------Epoch 140--------\n",
      "Val loss = -0.8586\n",
      "\n",
      "--------Epoch 141--------\n",
      "Val loss = -0.8549\n",
      "\n",
      "--------Epoch 142--------\n",
      "Val loss = -0.8540\n",
      "\n",
      "--------Epoch 143--------\n",
      "Val loss = -0.8578\n",
      "\n",
      "--------Epoch 144--------\n",
      "Val loss = -0.8542\n",
      "\n",
      "--------Epoch 145--------\n",
      "Val loss = -0.8574\n",
      "\n",
      "--------Epoch 146--------\n",
      "Val loss = -0.8550\n",
      "\n",
      "--------Epoch 147--------\n",
      "Val loss = -0.8553\n",
      "\n",
      "--------Epoch 148--------\n",
      "Val loss = -0.8593\n",
      "\n",
      "--------Epoch 149--------\n",
      "Val loss = -0.8531\n",
      "\n",
      "--------Epoch 150--------\n",
      "Val loss = -0.8512\n",
      "\n",
      "--------Epoch 151--------\n",
      "Val loss = -0.8607\n",
      "\n",
      "--------Epoch 152--------\n",
      "Val loss = -0.8530\n",
      "\n",
      "--------Epoch 153--------\n",
      "Val loss = -0.8525\n",
      "\n",
      "--------Epoch 154--------\n",
      "Val loss = -0.8566\n",
      "\n",
      "--------Epoch 155--------\n",
      "Val loss = -0.8516\n",
      "\n",
      "--------Epoch 156--------\n",
      "Val loss = -0.8539\n",
      "\n",
      "--------Epoch 157--------\n",
      "Val loss = -0.8550\n",
      "\n",
      "--------Epoch 158--------\n",
      "Val loss = -0.8591\n",
      "\n",
      "--------Epoch 159--------\n",
      "Val loss = -0.8513\n",
      "\n",
      "--------Epoch 160--------\n",
      "Val loss = -0.8538\n",
      "\n",
      "--------Epoch 161--------\n",
      "Val loss = -0.8533\n",
      "\n",
      "--------Epoch 162--------\n",
      "Val loss = -0.8532\n",
      "\n",
      "--------Epoch 163--------\n",
      "Val loss = -0.8608\n",
      "\n",
      "--------Epoch 164--------\n",
      "Val loss = -0.8621\n",
      "\n",
      "--------Epoch 165--------\n",
      "Val loss = -0.8541\n",
      "\n",
      "--------Epoch 166--------\n",
      "Val loss = -0.8509\n",
      "\n",
      "--------Epoch 167--------\n",
      "Val loss = -0.8553\n",
      "\n",
      "--------Epoch 168--------\n",
      "Val loss = -0.8543\n",
      "\n",
      "--------Epoch 169--------\n",
      "Val loss = -0.8552\n",
      "\n",
      "--------Epoch 170--------\n",
      "Val loss = -0.8537\n",
      "\n",
      "--------Epoch 171--------\n",
      "Val loss = -0.8572\n",
      "\n",
      "--------Epoch 172--------\n",
      "Val loss = -0.8546\n",
      "\n",
      "--------Epoch 173--------\n",
      "Val loss = -0.8563\n",
      "\n",
      "--------Epoch 174--------\n",
      "Val loss = -0.8518\n",
      "\n",
      "--------Epoch 175--------\n",
      "Val loss = -0.8542\n",
      "\n",
      "--------Epoch 176--------\n",
      "Val loss = -0.8527\n",
      "\n",
      "--------Epoch 177--------\n",
      "Val loss = -0.8499\n",
      "\n",
      "--------Epoch 178--------\n",
      "Val loss = -0.8556\n",
      "\n",
      "--------Epoch 179--------\n",
      "Val loss = -0.8596\n",
      "\n",
      "--------Epoch 180--------\n",
      "Val loss = -0.8532\n",
      "\n",
      "--------Epoch 181--------\n",
      "Val loss = -0.8580\n",
      "\n",
      "--------Epoch 182--------\n",
      "Val loss = -0.8546\n",
      "\n",
      "--------Epoch 183--------\n",
      "Val loss = -0.8552\n",
      "\n",
      "--------Epoch 184--------\n",
      "Val loss = -0.8532\n",
      "\n",
      "--------Epoch 185--------\n",
      "Val loss = -0.8534\n",
      "\n",
      "--------Epoch 186--------\n",
      "Val loss = -0.8553\n",
      "\n",
      "--------Epoch 187--------\n",
      "Val loss = -0.8570\n",
      "\n",
      "--------Epoch 188--------\n",
      "Val loss = -0.8566\n",
      "\n",
      "--------Epoch 189--------\n",
      "Val loss = -0.8585\n",
      "\n",
      "--------Epoch 190--------\n",
      "Val loss = -0.8529\n",
      "\n",
      "--------Epoch 191--------\n",
      "Val loss = -0.8512\n",
      "\n",
      "--------Epoch 192--------\n",
      "Val loss = -0.8537\n",
      "\n",
      "--------Epoch 193--------\n",
      "Val loss = -0.8560\n",
      "\n",
      "--------Epoch 194--------\n",
      "Val loss = -0.8564\n",
      "\n",
      "--------Epoch 195--------\n",
      "Val loss = -0.8538\n",
      "\n",
      "--------Epoch 196--------\n",
      "Val loss = -0.8545\n",
      "\n",
      "--------Epoch 197--------\n",
      "Val loss = -0.8555\n",
      "\n",
      "--------Epoch 198--------\n",
      "Val loss = -0.8565\n",
      "\n",
      "--------Epoch 199--------\n",
      "Val loss = -0.8585\n",
      "\n",
      "--------Epoch 200--------\n",
      "Val loss = -0.8583\n",
      "\n",
      "--------Epoch 201--------\n",
      "Val loss = -0.8589\n",
      "\n",
      "--------Epoch 202--------\n",
      "Val loss = -0.8542\n",
      "\n",
      "--------Epoch 203--------\n",
      "Val loss = -0.8562\n",
      "\n",
      "--------Epoch 204--------\n",
      "Val loss = -0.8550\n",
      "\n",
      "--------Epoch 205--------\n",
      "Val loss = -0.8546\n",
      "\n",
      "--------Epoch 206--------\n",
      "Val loss = -0.8578\n",
      "\n",
      "--------Epoch 207--------\n",
      "Val loss = -0.8528\n",
      "\n",
      "--------Epoch 208--------\n",
      "Val loss = -0.8518\n",
      "\n",
      "--------Epoch 209--------\n",
      "Val loss = -0.8573\n",
      "\n",
      "--------Epoch 210--------\n",
      "Val loss = -0.8512\n",
      "\n",
      "--------Epoch 211--------\n",
      "Val loss = -0.8549\n",
      "\n",
      "--------Epoch 212--------\n",
      "Val loss = -0.8566\n",
      "\n",
      "--------Epoch 213--------\n",
      "Val loss = -0.8556\n",
      "\n",
      "--------Epoch 214--------\n",
      "Val loss = -0.8498\n",
      "\n",
      "--------Epoch 215--------\n",
      "Val loss = -0.8518\n",
      "\n",
      "--------Epoch 216--------\n",
      "Val loss = -0.8528\n",
      "\n",
      "--------Epoch 217--------\n",
      "Val loss = -0.8520\n",
      "\n",
      "--------Epoch 218--------\n",
      "Val loss = -0.8553\n",
      "\n",
      "--------Epoch 219--------\n",
      "Val loss = -0.8576\n",
      "\n",
      "--------Epoch 220--------\n",
      "Val loss = -0.8542\n",
      "\n",
      "--------Epoch 221--------\n",
      "Val loss = -0.8508\n",
      "\n",
      "--------Epoch 222--------\n",
      "Val loss = -0.8537\n",
      "\n",
      "--------Epoch 223--------\n",
      "Val loss = -0.8504\n",
      "\n",
      "--------Epoch 224--------\n",
      "Val loss = -0.8528\n",
      "\n",
      "--------Epoch 225--------\n",
      "Val loss = -0.8507\n",
      "\n",
      "--------Epoch 226--------\n",
      "Val loss = -0.8545\n",
      "\n",
      "--------Epoch 227--------\n",
      "Val loss = -0.8559\n",
      "\n",
      "--------Epoch 228--------\n",
      "Val loss = -0.8525\n",
      "\n",
      "--------Epoch 229--------\n",
      "Val loss = -0.8515\n",
      "\n",
      "--------Epoch 230--------\n",
      "Val loss = -0.8546\n",
      "\n",
      "--------Epoch 231--------\n",
      "Val loss = -0.8538\n",
      "\n",
      "--------Epoch 232--------\n",
      "Val loss = -0.8543\n",
      "\n",
      "--------Epoch 233--------\n",
      "Val loss = -0.8564\n",
      "\n",
      "--------Epoch 234--------\n",
      "Val loss = -0.8558\n",
      "\n",
      "--------Epoch 235--------\n",
      "Val loss = -0.8574\n",
      "\n",
      "--------Epoch 236--------\n",
      "Val loss = -0.8559\n",
      "\n",
      "--------Epoch 237--------\n",
      "Val loss = -0.8569\n",
      "\n",
      "--------Epoch 238--------\n",
      "Val loss = -0.8606\n",
      "\n",
      "--------Epoch 239--------\n",
      "Val loss = -0.8544\n",
      "\n",
      "--------Epoch 240--------\n",
      "Val loss = -0.8567\n",
      "\n",
      "--------Epoch 241--------\n",
      "Val loss = -0.8576\n",
      "\n",
      "--------Epoch 242--------\n",
      "Val loss = -0.8549\n",
      "\n",
      "--------Epoch 243--------\n",
      "Val loss = -0.8525\n",
      "\n",
      "--------Epoch 244--------\n",
      "Val loss = -0.8536\n",
      "\n",
      "--------Epoch 245--------\n",
      "Val loss = -0.8540\n",
      "\n",
      "--------Epoch 246--------\n",
      "Val loss = -0.8526\n",
      "\n",
      "--------Epoch 247--------\n",
      "Val loss = -0.8486\n",
      "\n",
      "--------Epoch 248--------\n",
      "Val loss = -0.8517\n",
      "\n",
      "--------Epoch 249--------\n",
      "Val loss = -0.8530\n",
      "\n",
      "--------Epoch 250--------\n",
      "Val loss = -0.8523\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Set up selector\n",
    "selector = nn.Sequential(\n",
    "    nn.Linear(d_in * 2, 512),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(512, 512),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(512, d_in))\n",
    "selector_layer = ConcreteSelector()\n",
    "gafs = GreedyAdaptiveFS(selector, deepcopy(model), mask_layer, selector_layer).to(device)\n",
    "\n",
    "# Tie weights: assign the predictor's Parameter objects themselves, so each tied\n",
    "# layer is a single Parameter with a single .grad. Re-wrapping them in\n",
    "# nn.Parameter(...) would create separate leaf Parameters that merely alias the\n",
    "# same storage, each receiving its own gradient and its own optimizer update.\n",
    "selector[0].weight = gafs.model[0].weight\n",
    "selector[2].weight = gafs.model[2].weight\n",
    "\n",
    "# Train\n",
    "gafs.fit(train_dataset,\n",
    "         val_dataset,\n",
    "         mbsize=128,\n",
    "         lr=2e-4,\n",
    "         nepochs=250,\n",
    "         max_features=max_features,\n",
    "         loss_fn=nn.CrossEntropyLoss(),\n",
    "         val_loss_fn=NegAccuracy(),\n",
    "         start_temp=1.0,\n",
    "         end_temp=0.01)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "id": "pleased-ordinance",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test acc = 86.51\n"
     ]
    }
   ],
   "source": [
    "# Greedy version: accuracy of `gafs` on the held-out test set\n",
    "test_acc = gafs.evaluate(test_dataset, max_features, Accuracy(), 1024)\n",
    "print('Test acc = {:.2f}'.format(100 * test_acc))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "elegant-pledge",
   "metadata": {},
   "source": [
    "**Note:** an earlier run of the greedy-version evaluation (kernel execution 23, same `gafs.evaluate(test_dataset, max_features, Accuracy(), 1024)` call, code since removed) printed `Test acc = 87.02`. The `gafs` training configuration that produced that number is not recorded alongside it, so it should not be read as the result of the cells above."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
