{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "reverse-light",
   "metadata": {},
   "source": [
    "# Warm starting and weight tying\n",
    "\n",
    "- This notebook tests two separate tricks that might improve performance.\n",
    "- The first idea, \"weight tying,\" is to share parameters for the initial layers of the selector and predictor models. The hope was that this would train faster and ultimately achieve better results. In the end, it didn't make much of a difference. I could see it making a bigger difference on other datasets that are harder to learn, because the selector is useless immediately after initialization.\n",
    "- The second idea, \"warm starting,\" refers to giving the selector model a stronger initialization prior to joint training. This is done by freezing the predictor model and training only the selector; after this converges, the two are trained jointly. Again, this didn't make a big difference."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "olympic-maximum",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "import torch.nn as nn\n",
    "import matplotlib.pyplot as plt\n",
    "from torchvision import transforms\n",
    "from torchvision.datasets import MNIST\n",
    "from copy import deepcopy\n",
    "from utils import *\n",
    "from models import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "fixed-decline",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preferred GPU for this experiment. torch.device('cuda', 3) alone only fails later\n",
    "# (at .to(device)) on machines with fewer than 4 GPUs, so fall back to the first\n",
    "# GPU, then to the CPU, to keep Restart & Run All working everywhere.\n",
    "GPU_INDEX = 3\n",
    "if torch.cuda.device_count() > GPU_INDEX:\n",
    "    device = torch.device('cuda', GPU_INDEX)\n",
    "elif torch.cuda.is_available():\n",
    "    device = torch.device('cuda', 0)\n",
    "else:\n",
    "    device = torch.device('cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "refined-store",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load data\n",
    "class Flatten(object):\n",
    "    \"\"\"Transform that flattens an image tensor into a 1-D vector (e.g. 1x28x28 -> 784).\"\"\"\n",
    "    def __call__(self, pic):\n",
    "        return torch.flatten(pic)\n",
    "\n",
    "\n",
    "# One preprocessing pipeline shared by every split: ToTensor, then flatten.\n",
    "mnist_transform = transforms.Compose([transforms.ToTensor(), Flatten()])\n",
    "\n",
    "mnist_dataset = MNIST('/tmp/mnist/', download=True, train=True, transform=mnist_transform)\n",
    "images = mnist_dataset.data\n",
    "targets = mnist_dataset.targets\n",
    "\n",
    "# Hold out 10,000 randomly chosen training images (fixed seed) for validation.\n",
    "np.random.seed(0)\n",
    "val_inds = np.sort(np.random.choice(len(images), size=10000, replace=False))\n",
    "train_inds = np.setdiff1d(np.arange(len(images)), val_inds)\n",
    "assert len(train_inds) + len(val_inds) == len(images)\n",
    "\n",
    "# Training dataset\n",
    "train_dataset = torch.utils.data.Subset(mnist_dataset, train_inds)\n",
    "\n",
    "# Validation dataset\n",
    "val_dataset = torch.utils.data.Subset(mnist_dataset, val_inds)\n",
    "\n",
    "# Test dataset\n",
    "test_dataset = MNIST('/tmp/mnist/', download=True, train=False, transform=mnist_transform)\n",
    "\n",
    "# Set input/output dimensions from the data instead of hard-coding them\n",
    "d_in = images[0].numel()  # pixels per image (28 * 28)\n",
    "d_out = len(mnist_dataset.classes)  # number of digit classes\n",
    "assert (d_in, d_out) == (784, 10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "earlier-coffee",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Budget: how many input features (pixels) the selector may keep\n",
    "max_features: int = 10"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "natural-hollow",
   "metadata": {},
   "source": [
    "# Global feature selection"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "excessive-paradise",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--------Epoch 1--------\n",
      "Val loss = -0.2946\n",
      "\n",
      "--------Epoch 2--------\n",
      "Val loss = -0.3245\n",
      "\n",
      "--------Epoch 3--------\n",
      "Val loss = -0.3048\n",
      "\n",
      "--------Epoch 4--------\n",
      "Val loss = -0.3243\n",
      "\n",
      "--------Epoch 5--------\n",
      "Val loss = -0.3410\n",
      "\n",
      "--------Epoch 6--------\n",
      "Val loss = -0.3410\n",
      "\n",
      "--------Epoch 7--------\n",
      "Val loss = -0.3505\n",
      "\n",
      "--------Epoch 8--------\n",
      "Val loss = -0.3491\n",
      "\n",
      "--------Epoch 9--------\n",
      "Val loss = -0.3482\n",
      "\n",
      "--------Epoch 10--------\n",
      "Val loss = -0.3586\n",
      "\n",
      "--------Epoch 11--------\n",
      "Val loss = -0.3588\n",
      "\n",
      "--------Epoch 12--------\n",
      "Val loss = -0.3709\n",
      "\n",
      "--------Epoch 13--------\n",
      "Val loss = -0.3702\n",
      "\n",
      "--------Epoch 14--------\n",
      "Val loss = -0.3684\n",
      "\n",
      "--------Epoch 15--------\n",
      "Val loss = -0.3778\n",
      "\n",
      "--------Epoch 16--------\n",
      "Val loss = -0.3829\n",
      "\n",
      "--------Epoch 17--------\n",
      "Val loss = -0.3896\n",
      "\n",
      "--------Epoch 18--------\n",
      "Val loss = -0.3897\n",
      "\n",
      "--------Epoch 19--------\n",
      "Val loss = -0.4051\n",
      "\n",
      "--------Epoch 20--------\n",
      "Val loss = -0.3979\n",
      "\n",
      "--------Epoch 21--------\n",
      "Val loss = -0.3957\n",
      "\n",
      "--------Epoch 22--------\n",
      "Val loss = -0.4163\n",
      "\n",
      "--------Epoch 23--------\n",
      "Val loss = -0.4121\n",
      "\n",
      "--------Epoch 24--------\n",
      "Val loss = -0.4116\n",
      "\n",
      "--------Epoch 25--------\n",
      "Val loss = -0.4149\n",
      "\n",
      "--------Epoch 26--------\n",
      "Val loss = -0.4196\n",
      "\n",
      "--------Epoch 27--------\n",
      "Val loss = -0.4276\n",
      "\n",
      "--------Epoch 28--------\n",
      "Val loss = -0.4266\n",
      "\n",
      "--------Epoch 29--------\n",
      "Val loss = -0.4300\n",
      "\n",
      "--------Epoch 30--------\n",
      "Val loss = -0.4374\n",
      "\n",
      "--------Epoch 31--------\n",
      "Val loss = -0.4337\n",
      "\n",
      "--------Epoch 32--------\n",
      "Val loss = -0.4396\n",
      "\n",
      "--------Epoch 33--------\n",
      "Val loss = -0.4560\n",
      "\n",
      "--------Epoch 34--------\n",
      "Val loss = -0.4518\n",
      "\n",
      "--------Epoch 35--------\n",
      "Val loss = -0.4559\n",
      "\n",
      "--------Epoch 36--------\n",
      "Val loss = -0.4616\n",
      "\n",
      "--------Epoch 37--------\n",
      "Val loss = -0.4642\n",
      "\n",
      "--------Epoch 38--------\n",
      "Val loss = -0.4673\n",
      "\n",
      "--------Epoch 39--------\n",
      "Val loss = -0.4805\n",
      "\n",
      "--------Epoch 40--------\n",
      "Val loss = -0.4668\n",
      "\n",
      "--------Epoch 41--------\n",
      "Val loss = -0.4836\n",
      "\n",
      "--------Epoch 42--------\n",
      "Val loss = -0.4743\n",
      "\n",
      "--------Epoch 43--------\n",
      "Val loss = -0.4947\n",
      "\n",
      "--------Epoch 44--------\n",
      "Val loss = -0.4810\n",
      "\n",
      "--------Epoch 45--------\n",
      "Val loss = -0.4960\n",
      "\n",
      "--------Epoch 46--------\n",
      "Val loss = -0.4988\n",
      "\n",
      "--------Epoch 47--------\n",
      "Val loss = -0.5057\n",
      "\n",
      "--------Epoch 48--------\n",
      "Val loss = -0.5028\n",
      "\n",
      "--------Epoch 49--------\n",
      "Val loss = -0.5049\n",
      "\n",
      "--------Epoch 50--------\n",
      "Val loss = -0.5146\n",
      "\n",
      "--------Epoch 51--------\n",
      "Val loss = -0.5028\n",
      "\n",
      "--------Epoch 52--------\n",
      "Val loss = -0.5148\n",
      "\n",
      "--------Epoch 53--------\n",
      "Val loss = -0.5191\n",
      "\n",
      "--------Epoch 54--------\n",
      "Val loss = -0.5216\n",
      "\n",
      "--------Epoch 55--------\n",
      "Val loss = -0.5249\n",
      "\n",
      "--------Epoch 56--------\n",
      "Val loss = -0.5252\n",
      "\n",
      "--------Epoch 57--------\n",
      "Val loss = -0.5373\n",
      "\n",
      "--------Epoch 58--------\n",
      "Val loss = -0.5338\n",
      "\n",
      "--------Epoch 59--------\n",
      "Val loss = -0.5370\n",
      "\n",
      "--------Epoch 60--------\n",
      "Val loss = -0.5419\n",
      "\n",
      "--------Epoch 61--------\n",
      "Val loss = -0.5482\n",
      "\n",
      "--------Epoch 62--------\n",
      "Val loss = -0.5454\n",
      "\n",
      "--------Epoch 63--------\n",
      "Val loss = -0.5485\n",
      "\n",
      "--------Epoch 64--------\n",
      "Val loss = -0.5507\n",
      "\n",
      "--------Epoch 65--------\n",
      "Val loss = -0.5581\n",
      "\n",
      "--------Epoch 66--------\n",
      "Val loss = -0.5489\n",
      "\n",
      "--------Epoch 67--------\n",
      "Val loss = -0.5632\n",
      "\n",
      "--------Epoch 68--------\n",
      "Val loss = -0.5617\n",
      "\n",
      "--------Epoch 69--------\n",
      "Val loss = -0.5676\n",
      "\n",
      "--------Epoch 70--------\n",
      "Val loss = -0.5721\n",
      "\n",
      "--------Epoch 71--------\n",
      "Val loss = -0.5684\n",
      "\n",
      "--------Epoch 72--------\n",
      "Val loss = -0.5800\n",
      "\n",
      "--------Epoch 73--------\n",
      "Val loss = -0.5853\n",
      "\n",
      "--------Epoch 74--------\n",
      "Val loss = -0.5765\n",
      "\n",
      "--------Epoch 75--------\n",
      "Val loss = -0.5835\n",
      "\n",
      "--------Epoch 76--------\n",
      "Val loss = -0.5795\n",
      "\n",
      "--------Epoch 77--------\n",
      "Val loss = -0.5834\n",
      "\n",
      "--------Epoch 78--------\n",
      "Val loss = -0.5960\n",
      "\n",
      "--------Epoch 79--------\n",
      "Val loss = -0.5862\n",
      "\n",
      "--------Epoch 80--------\n",
      "Val loss = -0.6051\n",
      "\n",
      "--------Epoch 81--------\n",
      "Val loss = -0.5994\n",
      "\n",
      "--------Epoch 82--------\n",
      "Val loss = -0.6000\n",
      "\n",
      "--------Epoch 83--------\n",
      "Val loss = -0.6035\n",
      "\n",
      "--------Epoch 84--------\n",
      "Val loss = -0.6142\n",
      "\n",
      "--------Epoch 85--------\n",
      "Val loss = -0.6009\n",
      "\n",
      "--------Epoch 86--------\n",
      "Val loss = -0.6086\n",
      "\n",
      "--------Epoch 87--------\n",
      "Val loss = -0.6071\n",
      "\n",
      "--------Epoch 88--------\n",
      "Val loss = -0.6051\n",
      "\n",
      "--------Epoch 89--------\n",
      "Val loss = -0.6108\n",
      "\n",
      "--------Epoch 90--------\n",
      "Val loss = -0.6171\n",
      "\n",
      "--------Epoch 91--------\n",
      "Val loss = -0.6117\n",
      "\n",
      "--------Epoch 92--------\n",
      "Val loss = -0.6144\n",
      "\n",
      "--------Epoch 93--------\n",
      "Val loss = -0.6264\n",
      "\n",
      "--------Epoch 94--------\n",
      "Val loss = -0.6214\n",
      "\n",
      "--------Epoch 95--------\n",
      "Val loss = -0.6283\n",
      "\n",
      "--------Epoch 96--------\n",
      "Val loss = -0.6281\n",
      "\n",
      "--------Epoch 97--------\n",
      "Val loss = -0.6321\n",
      "\n",
      "--------Epoch 98--------\n",
      "Val loss = -0.6296\n",
      "\n",
      "--------Epoch 99--------\n",
      "Val loss = -0.6410\n",
      "\n",
      "--------Epoch 100--------\n",
      "Val loss = -0.6322\n",
      "\n",
      "--------Epoch 101--------\n",
      "Val loss = -0.6353\n",
      "\n",
      "--------Epoch 102--------\n",
      "Val loss = -0.6464\n",
      "\n",
      "--------Epoch 103--------\n",
      "Val loss = -0.6423\n",
      "\n",
      "--------Epoch 104--------\n",
      "Val loss = -0.6469\n",
      "\n",
      "--------Epoch 105--------\n",
      "Val loss = -0.6487\n",
      "\n",
      "--------Epoch 106--------\n",
      "Val loss = -0.6457\n",
      "\n",
      "--------Epoch 107--------\n",
      "Val loss = -0.6483\n",
      "\n",
      "--------Epoch 108--------\n",
      "Val loss = -0.6495\n",
      "\n",
      "--------Epoch 109--------\n",
      "Val loss = -0.6518\n",
      "\n",
      "--------Epoch 110--------\n",
      "Val loss = -0.6498\n",
      "\n",
      "--------Epoch 111--------\n",
      "Val loss = -0.6482\n",
      "\n",
      "--------Epoch 112--------\n",
      "Val loss = -0.6574\n",
      "\n",
      "--------Epoch 113--------\n",
      "Val loss = -0.6599\n",
      "\n",
      "--------Epoch 114--------\n",
      "Val loss = -0.6638\n",
      "\n",
      "--------Epoch 115--------\n",
      "Val loss = -0.6577\n",
      "\n",
      "--------Epoch 116--------\n",
      "Val loss = -0.6612\n",
      "\n",
      "--------Epoch 117--------\n",
      "Val loss = -0.6657\n",
      "\n",
      "--------Epoch 118--------\n",
      "Val loss = -0.6705\n",
      "\n",
      "--------Epoch 119--------\n",
      "Val loss = -0.6712\n",
      "\n",
      "--------Epoch 120--------\n",
      "Val loss = -0.6701\n",
      "\n",
      "--------Epoch 121--------\n",
      "Val loss = -0.6732\n",
      "\n",
      "--------Epoch 122--------\n",
      "Val loss = -0.6807\n",
      "\n",
      "--------Epoch 123--------\n",
      "Val loss = -0.6752\n",
      "\n",
      "--------Epoch 124--------\n",
      "Val loss = -0.6696\n",
      "\n",
      "--------Epoch 125--------\n",
      "Val loss = -0.6854\n",
      "\n",
      "--------Epoch 126--------\n",
      "Val loss = -0.6875\n",
      "\n",
      "--------Epoch 127--------\n",
      "Val loss = -0.6843\n",
      "\n",
      "--------Epoch 128--------\n",
      "Val loss = -0.6846\n",
      "\n",
      "--------Epoch 129--------\n",
      "Val loss = -0.6949\n",
      "\n",
      "--------Epoch 130--------\n",
      "Val loss = -0.6903\n",
      "\n",
      "--------Epoch 131--------\n",
      "Val loss = -0.6992\n",
      "\n",
      "--------Epoch 132--------\n",
      "Val loss = -0.6902\n",
      "\n",
      "--------Epoch 133--------\n",
      "Val loss = -0.6933\n",
      "\n",
      "--------Epoch 134--------\n",
      "Val loss = -0.6934\n",
      "\n",
      "--------Epoch 135--------\n",
      "Val loss = -0.6932\n",
      "\n",
      "--------Epoch 136--------\n",
      "Val loss = -0.6979\n",
      "\n",
      "--------Epoch 137--------\n",
      "Val loss = -0.7029\n",
      "\n",
      "--------Epoch 138--------\n",
      "Val loss = -0.6971\n",
      "\n",
      "--------Epoch 139--------\n",
      "Val loss = -0.6981\n",
      "\n",
      "--------Epoch 140--------\n",
      "Val loss = -0.7049\n",
      "\n",
      "--------Epoch 141--------\n",
      "Val loss = -0.6998\n",
      "\n",
      "--------Epoch 142--------\n",
      "Val loss = -0.7130\n",
      "\n",
      "--------Epoch 143--------\n",
      "Val loss = -0.7065\n",
      "\n",
      "--------Epoch 144--------\n",
      "Val loss = -0.7087\n",
      "\n",
      "--------Epoch 145--------\n",
      "Val loss = -0.7078\n",
      "\n",
      "--------Epoch 146--------\n",
      "Val loss = -0.7091\n",
      "\n",
      "--------Epoch 147--------\n",
      "Val loss = -0.7146\n",
      "\n",
      "--------Epoch 148--------\n",
      "Val loss = -0.7122\n",
      "\n",
      "--------Epoch 149--------\n",
      "Val loss = -0.7068\n",
      "\n",
      "--------Epoch 150--------\n",
      "Val loss = -0.7108\n",
      "\n",
      "--------Epoch 151--------\n",
      "Val loss = -0.7212\n",
      "\n",
      "--------Epoch 152--------\n",
      "Val loss = -0.7218\n",
      "\n",
      "--------Epoch 153--------\n",
      "Val loss = -0.7184\n",
      "\n",
      "--------Epoch 154--------\n",
      "Val loss = -0.7221\n",
      "\n",
      "--------Epoch 155--------\n",
      "Val loss = -0.7243\n",
      "\n",
      "--------Epoch 156--------\n",
      "Val loss = -0.7226\n",
      "\n",
      "--------Epoch 157--------\n",
      "Val loss = -0.7260\n",
      "\n",
      "--------Epoch 158--------\n",
      "Val loss = -0.7296\n",
      "\n",
      "--------Epoch 159--------\n",
      "Val loss = -0.7244\n",
      "\n",
      "--------Epoch 160--------\n",
      "Val loss = -0.7282\n",
      "\n",
      "--------Epoch 161--------\n",
      "Val loss = -0.7251\n",
      "\n",
      "--------Epoch 162--------\n",
      "Val loss = -0.7330\n",
      "\n",
      "--------Epoch 163--------\n",
      "Val loss = -0.7365\n",
      "\n",
      "--------Epoch 164--------\n",
      "Val loss = -0.7342\n",
      "\n",
      "--------Epoch 165--------\n",
      "Val loss = -0.7363\n",
      "\n",
      "--------Epoch 166--------\n",
      "Val loss = -0.7327\n",
      "\n",
      "--------Epoch 167--------\n",
      "Val loss = -0.7349\n",
      "\n",
      "--------Epoch 168--------\n",
      "Val loss = -0.7389\n",
      "\n",
      "--------Epoch 169--------\n",
      "Val loss = -0.7463\n",
      "\n",
      "--------Epoch 170--------\n",
      "Val loss = -0.7479\n",
      "\n",
      "--------Epoch 171--------\n",
      "Val loss = -0.7420\n",
      "\n",
      "--------Epoch 172--------\n",
      "Val loss = -0.7426\n",
      "\n",
      "--------Epoch 173--------\n",
      "Val loss = -0.7502\n",
      "\n",
      "--------Epoch 174--------\n",
      "Val loss = -0.7524\n",
      "\n",
      "--------Epoch 175--------\n",
      "Val loss = -0.7522\n",
      "\n",
      "--------Epoch 176--------\n",
      "Val loss = -0.7482\n",
      "\n",
      "--------Epoch 177--------\n",
      "Val loss = -0.7469\n",
      "\n",
      "--------Epoch 178--------\n",
      "Val loss = -0.7484\n",
      "\n",
      "--------Epoch 179--------\n",
      "Val loss = -0.7494\n",
      "\n",
      "--------Epoch 180--------\n",
      "Val loss = -0.7531\n",
      "\n",
      "--------Epoch 181--------\n",
      "Val loss = -0.7546\n",
      "\n",
      "--------Epoch 182--------\n",
      "Val loss = -0.7508\n",
      "\n",
      "--------Epoch 183--------\n",
      "Val loss = -0.7568\n",
      "\n",
      "--------Epoch 184--------\n",
      "Val loss = -0.7553\n",
      "\n",
      "--------Epoch 185--------\n",
      "Val loss = -0.7586\n",
      "\n",
      "--------Epoch 186--------\n",
      "Val loss = -0.7622\n",
      "\n",
      "--------Epoch 187--------\n",
      "Val loss = -0.7602\n",
      "\n",
      "--------Epoch 188--------\n",
      "Val loss = -0.7587\n",
      "\n",
      "--------Epoch 189--------\n",
      "Val loss = -0.7606\n",
      "\n",
      "--------Epoch 190--------\n",
      "Val loss = -0.7639\n",
      "\n",
      "--------Epoch 191--------\n",
      "Val loss = -0.7625\n",
      "\n",
      "--------Epoch 192--------\n",
      "Val loss = -0.7590\n",
      "\n",
      "--------Epoch 193--------\n",
      "Val loss = -0.7640\n",
      "\n",
      "--------Epoch 194--------\n",
      "Val loss = -0.7628\n",
      "\n",
      "--------Epoch 195--------\n",
      "Val loss = -0.7655\n",
      "\n",
      "--------Epoch 196--------\n",
      "Val loss = -0.7620\n",
      "\n",
      "--------Epoch 197--------\n",
      "Val loss = -0.7677\n",
      "\n",
      "--------Epoch 198--------\n",
      "Val loss = -0.7624\n",
      "\n",
      "--------Epoch 199--------\n",
      "Val loss = -0.7683\n",
      "\n",
      "--------Epoch 200--------\n",
      "Val loss = -0.7609\n",
      "\n",
      "--------Epoch 201--------\n",
      "Val loss = -0.7683\n",
      "\n",
      "--------Epoch 202--------\n",
      "Val loss = -0.7646\n",
      "\n",
      "--------Epoch 203--------\n",
      "Val loss = -0.7682\n",
      "\n",
      "--------Epoch 204--------\n",
      "Val loss = -0.7648\n",
      "\n",
      "--------Epoch 205--------\n",
      "Val loss = -0.7635\n",
      "\n",
      "--------Epoch 206--------\n",
      "Val loss = -0.7616\n",
      "\n",
      "--------Epoch 207--------\n",
      "Val loss = -0.7669\n",
      "\n",
      "--------Epoch 208--------\n",
      "Val loss = -0.7644\n",
      "\n",
      "--------Epoch 209--------\n",
      "Val loss = -0.7659\n",
      "\n",
      "--------Epoch 210--------\n",
      "Val loss = -0.7668\n",
      "\n",
      "--------Epoch 211--------\n",
      "Val loss = -0.7646\n",
      "\n",
      "--------Epoch 212--------\n",
      "Val loss = -0.7618\n",
      "\n",
      "--------Epoch 213--------\n",
      "Val loss = -0.7651\n",
      "\n",
      "--------Epoch 214--------\n",
      "Val loss = -0.7638\n",
      "\n",
      "--------Epoch 215--------\n",
      "Val loss = -0.7618\n",
      "\n",
      "--------Epoch 216--------\n",
      "Val loss = -0.7691\n",
      "\n",
      "--------Epoch 217--------\n",
      "Val loss = -0.7690\n",
      "\n",
      "--------Epoch 218--------\n",
      "Val loss = -0.7665\n",
      "\n",
      "--------Epoch 219--------\n",
      "Val loss = -0.7673\n",
      "\n",
      "--------Epoch 220--------\n",
      "Val loss = -0.7622\n",
      "\n",
      "--------Epoch 221--------\n",
      "Val loss = -0.7667\n",
      "\n",
      "--------Epoch 222--------\n",
      "Val loss = -0.7665\n",
      "\n",
      "--------Epoch 223--------\n",
      "Val loss = -0.7662\n",
      "\n",
      "--------Epoch 224--------\n",
      "Val loss = -0.7636\n",
      "\n",
      "--------Epoch 225--------\n",
      "Val loss = -0.7677\n",
      "\n",
      "--------Epoch 226--------\n",
      "Val loss = -0.7646\n",
      "\n",
      "--------Epoch 227--------\n",
      "Val loss = -0.7649\n",
      "\n",
      "--------Epoch 228--------\n",
      "Val loss = -0.7637\n",
      "\n",
      "--------Epoch 229--------\n",
      "Val loss = -0.7628\n",
      "\n",
      "--------Epoch 230--------\n",
      "Val loss = -0.7634\n",
      "\n",
      "--------Epoch 231--------\n",
      "Val loss = -0.7611\n",
      "\n",
      "--------Epoch 232--------\n",
      "Val loss = -0.7588\n",
      "\n",
      "--------Epoch 233--------\n",
      "Val loss = -0.7677\n",
      "\n",
      "--------Epoch 234--------\n",
      "Val loss = -0.7659\n",
      "\n",
      "--------Epoch 235--------\n",
      "Val loss = -0.7690\n",
      "\n",
      "--------Epoch 236--------\n",
      "Val loss = -0.7666\n",
      "\n",
      "--------Epoch 237--------\n",
      "Val loss = -0.7647\n",
      "\n",
      "--------Epoch 238--------\n",
      "Val loss = -0.7612\n",
      "\n",
      "--------Epoch 239--------\n",
      "Val loss = -0.7595\n",
      "\n",
      "--------Epoch 240--------\n",
      "Val loss = -0.7621\n",
      "\n",
      "--------Epoch 241--------\n",
      "Val loss = -0.7616\n",
      "\n",
      "--------Epoch 242--------\n",
      "Val loss = -0.7598\n",
      "\n",
      "--------Epoch 243--------\n",
      "Val loss = -0.7622\n",
      "\n",
      "--------Epoch 244--------\n",
      "Val loss = -0.7680\n",
      "\n",
      "--------Epoch 245--------\n",
      "Val loss = -0.7605\n",
      "\n",
      "--------Epoch 246--------\n",
      "Val loss = -0.7596\n",
      "\n",
      "--------Epoch 247--------\n",
      "Val loss = -0.7578\n",
      "\n",
      "--------Epoch 248--------\n",
      "Val loss = -0.7572\n",
      "\n",
      "--------Epoch 249--------\n",
      "Val loss = -0.7588\n",
      "\n",
      "--------Epoch 250--------\n",
      "Val loss = -0.7617\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Predictor: a two-hidden-layer MLP. Its input width is 2 * d_in, presumably\n",
    "# because ConcreteMask(append=True) concatenates the mask to the masked input\n",
    "# (TODO confirm in models.py). Layers are built in the original order so the\n",
    "# random initialization is unchanged.\n",
    "global_model = nn.Sequential(\n",
    "    nn.Linear(2 * d_in, 512), nn.ReLU(),\n",
    "    nn.Linear(512, 512), nn.ReLU(),\n",
    "    nn.Linear(512, d_out),\n",
    ")\n",
    "selector = ConcreteMask(d_in, max_features, append=True)\n",
    "globalfs = GlobalSelector(global_model, selector).to(device)\n",
    "\n",
    "# Train selector and predictor together. Validation is scored with NegAccuracy;\n",
    "# start_temp/end_temp presumably set the Concrete temperature schedule.\n",
    "globalfs.fit(\n",
    "    train_dataset, val_dataset,\n",
    "    mbsize=128, lr=1e-3, nepochs=250,\n",
    "    loss_fn=nn.CrossEntropyLoss(), val_loss_fn=NegAccuracy(),\n",
    "    start_temp=1.0, end_temp=0.01,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "herbal-country",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test acc = 76.63\n"
     ]
    }
   ],
   "source": [
    "# Accuracy on the held-out MNIST test split; the third argument (1024) is\n",
    "# presumably the evaluation batch size -- TODO confirm in models.py.\n",
    "test_acc = globalfs.evaluate(test_dataset, Accuracy(), 1024)\n",
    "print('Test acc = {:.2f}'.format(100 * test_acc))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "square-video",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAK1klEQVR4nO3dT6hc93mH8edbV1ZASUGKa6M6pkmDFzWFKuWiFlyKi2nqeCNnkRItggoGZRFDAlnUpIt4aUqT0EUJKLWIWlKHQGKshWkiRMBkY3xtVFuu2to1SqJISA1exClUlp23i3tcbuT7zzNn/kTv84HLzJyZe+dl0KMzM2fu/aWqkHTj+7VFDyBpPoxdasLYpSaMXWrC2KUmfn2ed3Zzdtd72DPPu5Ra+V/+hzfqaja6bqrYk9wH/B1wE/APVfXoVrd/D3v4w9w7zV1K2sIzdXrT6yZ+Gp/kJuDvgY8BdwGHk9w16c+TNFvTvGY/CLxSVa9W1RvAN4FD44wlaWzTxH478ON1ly8M235JkqNJVpOsXuPqFHcnaRrTxL7RmwDv+OxtVR2rqpWqWtnF7inuTtI0pon9AnDHussfAC5ON46kWZkm9meBO5N8KMnNwCeBk+OMJWlsEx96q6o3kzwEfJe1Q2/Hq+ql0SaTNKqpjrNX1VPAUyPNImmG/Lis1ISxS00Yu9SEsUtNGLvUhLFLTRi71ISxS00Yu9SEsUtNGLvUhLFLTRi71ISxS00Yu9SEsUtNGLvUhLFLTRi71ISxS00Yu9SEsUtNGLvUhLFLTRi71ISxS00Yu9SEsUtNGLvUxFSruErb+e7FM5te9+e/dWBuc2jK2JOcB14H3gLerKqVMYaSNL4x9ux/WlU/HeHnSJohX7NLTUwbewHfS/JckqMb3SDJ0SSrSVavcXXKu5M0qWmfxt9dVReT3AqcSvLvVfX0+htU1THgGMBvZF9NeX+SJjTVnr2qLg6nV4AngINjDCVpfBPHnmRPkve9fR74KHB2rMEkjWuap/G3AU8kefvn/HNV/csoU+mG4bH05TFx7FX1KvD7I84iaYY89CY1YexSE8YuNWHsUhPGLjXhr7hqKlv9Cit46G2ZuGeXmjB2qQljl5owdqkJY5eaMHapCWOXmvA4u6bicfRfHe7ZpSaMXWrC2KUmjF1qwtilJoxdasLYpSY8zq6F8Xfh58s9u9SEsUtNGLvUhLFLTRi71ISxS00Yu9SEx9m1MLM+jr7VcfyOx/C33bMnOZ7kSpKz67btS3IqycvD6d7ZjilpWjt5Gv914L7rtj0MnK6qO4HTw2VJS2zb2KvqaeC16zYfAk4M508AD4w7lqSxTfoG3W1VdQlgOL11sxsmOZpkNcnqNa5OeHeSpjXzd+Or6lhVrVTVyi52z/ruJG1i0tgvJ9kPMJxeGW8kSbMwaewngSPD+SPAk+OMI2lWtj3OnuRx4B7gliQXgC8CjwLfSvIg8CPgE7McUppEx2PpW9k29qo6vMlV9448i6QZ8uOyUhPGLjVh7FITxi41YexSE/6Kq6bin4P+1eGeXWrC2KUmjF1qwtilJoxdasLYpSaMXWrC4+yayiyPo3sMf1zu2aUmjF1qwtilJoxdasLYpSaMXWrC2KUmPM5+A5jl0sSLPNbtcfRxuWeXmjB2qQljl5owdqkJY5eaMHapCWOXmvA4+w3AY93aiW337EmOJ7mS5Oy6bY8k+UmSM8PX/bMdU9K0dvI0/uvAfRts/0pVHRi+nhp3LElj2zb2qnoaeG0Os0iaoWneoHsoyQvD0/y9m90oydEkq0lWr3F1iruTNI1JY/8q8GHgAHAJ+NJmN6yqY1W1UlUru9g94d1JmtZEsVfV5ap6q6p+AXwNODjuWJLGNlHsSfavu/hx4Oxmt5W0HLY9zp7kceAe4JYkF4AvAvckOQAUcB749OxGlDSGbWOvqsMbbH5sBrNImiE/Lis1YexSE8YuNWHsUhPGLjVh
7FITxi41YexSE8YuNWHsUhPGLjVh7FITxi414Z+SvgHMcslm3Tjcs0tNGLvUhLFLTRi71ISxS00Yu9SEsUtNeJz9BuCxdO2Ee3apCWOXmjB2qQljl5owdqkJY5eaMHapCWOXmtg29iR3JPl+knNJXkry2WH7viSnkrw8nO6d/biSJrWTPfubwOer6neBPwI+k+Qu4GHgdFXdCZweLktaUtvGXlWXqur54fzrwDngduAQcGK42QnggRnNKGkE7+o1e5IPAh8BngFuq6pLsPYfAnDrJt9zNMlqktVrXJ1yXEmT2nHsSd4LfBv4XFX9bKffV1XHqmqlqlZ2sXuSGSWNYEexJ9nFWujfqKrvDJsvJ9k/XL8fuDKbESWNYdtfcU0S4DHgXFV9ed1VJ4EjwKPD6ZMzmfAGsNWfegZ/RVXzsZPfZ78b+BTwYpIzw7YvsBb5t5I8CPwI+MRMJpQ0im1jr6ofANnk6nvHHUfSrPgJOqkJY5eaMHapCWOXmjB2qQn/lPQceBxdy8A9u9SEsUtNGLvUhLFLTRi71ISxS00Yu9SEsUtNGLvUhLFLTRi71ISxS00Yu9SEsUtNGLvUhLFLTRi71ISxS00Yu9SEsUtNGLvUhLFLTRi71MS2sSe5I8n3k5xL8lKSzw7bH0nykyRnhq/7Zz+upEntZJGIN4HPV9XzSd4HPJfk1HDdV6rqb2c3nqSx7GR99kvApeH860nOAbfPejBJ43pXr9mTfBD4CPDMsOmhJC8kOZ5k7ybfczTJapLVa1ydblpJE9tx7EneC3wb+FxV/Qz4KvBh4ABre/4vbfR9VXWsqlaqamUXu6efWNJEdhR7kl2shf6NqvoOQFVdrqq3quoXwNeAg7MbU9K0dvJufIDHgHNV9eV12/evu9nHgbPjjydpLDt5N/5u4FPAi0nODNu+ABxOcgAo4Dzw6RnMJ2kkO3k3/gdANrjqqfHHkTQrfoJOasLYpSaMXWrC2KUmjF1qwtilJoxdasLYpSaMXWrC2KUmjF1qwtilJoxdasLYpSZSVfO7s+S/gR+u23QL8NO5DfDuLOtsyzoXONukxpztt6vqNze6Yq6xv+POk9WqWlnYAFtY1tmWdS5wtknNazafxktNGLvUxKJjP7bg+9/Kss62rHOBs01qLrMt9DW7pPlZ9J5d0pwYu9TEQmJPcl+S/0jySpKHFzHDZpKcT/LisAz16oJnOZ7kSpKz67btS3IqycvD6YZr7C1otqVYxnuLZcYX+tgtevnzub9mT3IT8J/AnwEXgGeBw1X1b3MdZBNJzgMrVbXwD2Ak+RPg58A/VtXvDdv+Bnitqh4d/qPcW1V/tSSzPQL8fNHLeA+rFe1fv8w48ADwlyzwsdtirr9gDo/bIvbsB4FXqurVqnoD+CZwaAFzLL2qehp47brNh4ATw/kTrP1jmbtNZlsKVXWpqp4fzr8OvL3M+EIfuy3mmotFxH478ON1ly+wXOu9F/C9JM8lObroYTZwW1VdgrV/PMCtC57netsu4z1P1y0zvjSP3STLn09rEbFvtJTUMh3/u7uq/gD4GPCZ4emqdmZHy3jPywbLjC+FSZc/n9YiYr8A3LHu8geAiwuYY0NVdXE4vQI8wfItRX357RV0h9MrC57n/y3TMt4bLTPOEjx2i1z+fBGxPwvcmeRDSW4GPgmcXMAc75Bkz/DGCUn2AB9l+ZaiPgkcGc4fAZ5c4Cy/ZFmW8d5smXEW/NgtfPnzqpr7F3A/a+/I/xfw14uYYZO5fgf41+HrpUXPBjzO2tO6a6w9I3oQeD9wGnh5ON23RLP9E/Ai8AJrYe1f0Gx/zNpLwxeAM8PX/Yt+7LaYay6Pmx+XlZrwE3RSE8YuNWHsUhPGLjVh7FITxi41YexSE/8H9CRgiL18UF8AAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Show which pixels the selector picked: for each selection slot, take the\n",
    "# highest-logit input index (assumes selector.logits is (max_features, d_in) --\n",
    "# TODO confirm in models.py). Repeated picks land on the same pixel, so the\n",
    "# title reports how many selections are actually distinct.\n",
    "inds = selector.logits.argmax(dim=1).detach().cpu().numpy()\n",
    "zeros = np.zeros(d_in)\n",
    "zeros[inds] = 1\n",
    "n_unique = len(np.unique(inds))\n",
    "fig, ax = plt.subplots()\n",
    "ax.imshow(zeros.reshape(28, 28))\n",
    "ax.set_title(f'Selected pixels ({n_unique} unique of {len(inds)})')\n",
    "ax.axis('off')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "perceived-blind",
   "metadata": {},
   "source": [
    "# Pretrain model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "separate-going",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--------Epoch 1--------\n",
      "Val loss = -0.2429\n",
      "\n",
      "--------Epoch 2--------\n",
      "Val loss = -0.2722\n",
      "\n",
      "--------Epoch 3--------\n",
      "Val loss = -0.2814\n",
      "\n",
      "--------Epoch 4--------\n",
      "Val loss = -0.2903\n",
      "\n",
      "--------Epoch 5--------\n",
      "Val loss = -0.3019\n",
      "\n",
      "--------Epoch 6--------\n",
      "Val loss = -0.3013\n",
      "\n",
      "--------Epoch 7--------\n",
      "Val loss = -0.3003\n",
      "\n",
      "--------Epoch 8--------\n",
      "Val loss = -0.3001\n",
      "\n",
      "--------Epoch 9--------\n",
      "Val loss = -0.3069\n",
      "\n",
      "--------Epoch 10--------\n",
      "Val loss = -0.3032\n",
      "\n",
      "--------Epoch 11--------\n",
      "Val loss = -0.3032\n",
      "\n",
      "--------Epoch 12--------\n",
      "Val loss = -0.3076\n",
      "\n",
      "--------Epoch 13--------\n",
      "Val loss = -0.3149\n",
      "\n",
      "--------Epoch 14--------\n",
      "Val loss = -0.3082\n",
      "\n",
      "--------Epoch 15--------\n",
      "Val loss = -0.3147\n",
      "\n",
      "--------Epoch 16--------\n",
      "Val loss = -0.3142\n",
      "\n",
      "--------Epoch 17--------\n",
      "Val loss = -0.3229\n",
      "\n",
      "--------Epoch 18--------\n",
      "Val loss = -0.3037\n",
      "\n",
      "--------Epoch 19--------\n",
      "Val loss = -0.3084\n",
      "\n",
      "--------Epoch 20--------\n",
      "Val loss = -0.3177\n",
      "\n",
      "--------Epoch 21--------\n",
      "Val loss = -0.3126\n",
      "\n",
      "--------Epoch 22--------\n",
      "Val loss = -0.3109\n",
      "\n",
      "--------Epoch 23--------\n",
      "Val loss = -0.3119\n",
      "\n",
      "--------Epoch 24--------\n",
      "Val loss = -0.3233\n",
      "\n",
      "--------Epoch 25--------\n",
      "Val loss = -0.3218\n",
      "\n",
      "--------Epoch 26--------\n",
      "Val loss = -0.3208\n",
      "\n",
      "--------Epoch 27--------\n",
      "Val loss = -0.3126\n",
      "\n",
      "--------Epoch 28--------\n",
      "Val loss = -0.3180\n",
      "\n",
      "--------Epoch 29--------\n",
      "Val loss = -0.3228\n",
      "\n",
      "--------Epoch 30--------\n",
      "Val loss = -0.3201\n",
      "\n",
      "--------Epoch 31--------\n",
      "Val loss = -0.3238\n",
      "\n",
      "--------Epoch 32--------\n",
      "Val loss = -0.3193\n",
      "\n",
      "--------Epoch 33--------\n",
      "Val loss = -0.3164\n",
      "\n",
      "--------Epoch 34--------\n",
      "Val loss = -0.3164\n",
      "\n",
      "--------Epoch 35--------\n",
      "Val loss = -0.3272\n",
      "\n",
      "--------Epoch 36--------\n",
      "Val loss = -0.3186\n",
      "\n",
      "--------Epoch 37--------\n",
      "Val loss = -0.3186\n",
      "\n",
      "--------Epoch 38--------\n",
      "Val loss = -0.3292\n",
      "\n",
      "--------Epoch 39--------\n",
      "Val loss = -0.3300\n",
      "\n",
      "--------Epoch 40--------\n",
      "Val loss = -0.3237\n",
      "\n",
      "--------Epoch 41--------\n",
      "Val loss = -0.3262\n",
      "\n",
      "--------Epoch 42--------\n",
      "Val loss = -0.3297\n",
      "\n",
      "--------Epoch 43--------\n",
      "Val loss = -0.3309\n",
      "\n",
      "--------Epoch 44--------\n",
      "Val loss = -0.3235\n",
      "\n",
      "--------Epoch 45--------\n",
      "Val loss = -0.3238\n",
      "\n",
      "--------Epoch 46--------\n",
      "Val loss = -0.3227\n",
      "\n",
      "--------Epoch 47--------\n",
      "Val loss = -0.3255\n",
      "\n",
      "--------Epoch 48--------\n",
      "Val loss = -0.3271\n",
      "\n",
      "--------Epoch 49--------\n",
      "Val loss = -0.3315\n",
      "\n",
      "--------Epoch 50--------\n",
      "Val loss = -0.3262\n",
      "\n",
      "--------Epoch 51--------\n",
      "Val loss = -0.3307\n",
      "\n",
      "--------Epoch 52--------\n",
      "Val loss = -0.3284\n",
      "\n",
      "--------Epoch 53--------\n",
      "Val loss = -0.3305\n",
      "\n",
      "--------Epoch 54--------\n",
      "Val loss = -0.3268\n",
      "\n",
      "--------Epoch 55--------\n",
      "Val loss = -0.3303\n",
      "\n",
      "--------Epoch 56--------\n",
      "Val loss = -0.3236\n",
      "\n",
      "--------Epoch 57--------\n",
      "Val loss = -0.3264\n",
      "\n",
      "--------Epoch 58--------\n",
      "Val loss = -0.3334\n",
      "\n",
      "--------Epoch 59--------\n",
      "Val loss = -0.3348\n",
      "\n",
      "--------Epoch 60--------\n",
      "Val loss = -0.3279\n",
      "\n",
      "--------Epoch 61--------\n",
      "Val loss = -0.3311\n",
      "\n",
      "--------Epoch 62--------\n",
      "Val loss = -0.3331\n",
      "\n",
      "--------Epoch 63--------\n",
      "Val loss = -0.3315\n",
      "\n",
      "--------Epoch 64--------\n",
      "Val loss = -0.3344\n",
      "\n",
      "--------Epoch 65--------\n",
      "Val loss = -0.3384\n",
      "\n",
      "--------Epoch 66--------\n",
      "Val loss = -0.3275\n",
      "\n",
      "--------Epoch 67--------\n",
      "Val loss = -0.3392\n",
      "\n",
      "--------Epoch 68--------\n",
      "Val loss = -0.3246\n",
      "\n",
      "--------Epoch 69--------\n",
      "Val loss = -0.3286\n",
      "\n",
      "--------Epoch 70--------\n",
      "Val loss = -0.3281\n",
      "\n",
      "--------Epoch 71--------\n",
      "Val loss = -0.3347\n",
      "\n",
      "--------Epoch 72--------\n",
      "Val loss = -0.3306\n",
      "\n",
      "--------Epoch 73--------\n",
      "Val loss = -0.3419\n",
      "\n",
      "--------Epoch 74--------\n",
      "Val loss = -0.3319\n",
      "\n",
      "--------Epoch 75--------\n",
      "Val loss = -0.3397\n",
      "\n",
      "--------Epoch 76--------\n",
      "Val loss = -0.3343\n",
      "\n",
      "--------Epoch 77--------\n",
      "Val loss = -0.3413\n",
      "\n",
      "--------Epoch 78--------\n",
      "Val loss = -0.3303\n",
      "\n",
      "--------Epoch 79--------\n",
      "Val loss = -0.3359\n",
      "\n",
      "--------Epoch 80--------\n",
      "Val loss = -0.3306\n",
      "\n",
      "--------Epoch 81--------\n",
      "Val loss = -0.3332\n",
      "\n",
      "--------Epoch 82--------\n",
      "Val loss = -0.3365\n",
      "\n",
      "--------Epoch 83--------\n",
      "Val loss = -0.3353\n",
      "\n",
      "--------Epoch 84--------\n",
      "Val loss = -0.3359\n",
      "\n",
      "--------Epoch 85--------\n",
      "Val loss = -0.3466\n",
      "\n",
      "--------Epoch 86--------\n",
      "Val loss = -0.3319\n",
      "\n",
      "--------Epoch 87--------\n",
      "Val loss = -0.3361\n",
      "\n",
      "--------Epoch 88--------\n",
      "Val loss = -0.3405\n",
      "\n",
      "--------Epoch 89--------\n",
      "Val loss = -0.3344\n",
      "\n",
      "--------Epoch 90--------\n",
      "Val loss = -0.3397\n",
      "\n",
      "--------Epoch 91--------\n",
      "Val loss = -0.3370\n",
      "\n",
      "--------Epoch 92--------\n",
      "Val loss = -0.3399\n",
      "\n",
      "--------Epoch 93--------\n",
      "Val loss = -0.3352\n",
      "\n",
      "--------Epoch 94--------\n",
      "Val loss = -0.3307\n",
      "\n",
      "--------Epoch 95--------\n",
      "Val loss = -0.3460\n",
      "\n",
      "--------Epoch 96--------\n",
      "Val loss = -0.3442\n",
      "\n",
      "--------Epoch 97--------\n",
      "Val loss = -0.3310\n",
      "\n",
      "--------Epoch 98--------\n",
      "Val loss = -0.3310\n",
      "\n",
      "--------Epoch 99--------\n",
      "Val loss = -0.3315\n",
      "\n",
      "--------Epoch 100--------\n",
      "Val loss = -0.3330\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Predictor network. The first layer is 2 * d_in wide because MaskLayer(append=True)\n",
    "# presumably concatenates the mask to the masked input (confirm in models.py).\n",
    "model = nn.Sequential(\n",
    "    nn.Linear(2 * d_in, 512), nn.ReLU(),\n",
    "    nn.Linear(512, 512), nn.ReLU(),\n",
    "    nn.Linear(512, d_out),\n",
    ")\n",
    "mask_layer = MaskLayer(append=True)\n",
    "pretrain = Pretrainer(model, mask_layer).to(device)\n",
    "\n",
    "# Pretrain the predictor on its own; the validation metric is negated accuracy (lower is better).\n",
    "pretrain.fit(\n",
    "    train_dataset,\n",
    "    val_dataset,\n",
    "    lr=1e-3,\n",
    "    mbsize=128,\n",
    "    nepochs=100,\n",
    "    max_features=max_features,\n",
    "    loss_fn=nn.CrossEntropyLoss(),\n",
    "    val_loss_fn=NegAccuracy(),\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dated-television",
   "metadata": {},
   "source": [
    "# Normal training (baseline: no weight tying, no warm start)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "cognitive-blind",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--------Epoch 1--------\n",
      "Val loss = -0.6435\n",
      "\n",
      "--------Epoch 2--------\n",
      "Val loss = -0.6944\n",
      "\n",
      "--------Epoch 3--------\n",
      "Val loss = -0.7070\n",
      "\n",
      "--------Epoch 4--------\n",
      "Val loss = -0.7170\n",
      "\n",
      "--------Epoch 5--------\n",
      "Val loss = -0.7076\n",
      "\n",
      "--------Epoch 6--------\n",
      "Val loss = -0.7123\n",
      "\n",
      "--------Epoch 7--------\n",
      "Val loss = -0.7092\n",
      "\n",
      "--------Epoch 8--------\n",
      "Val loss = -0.7031\n",
      "\n",
      "--------Epoch 9--------\n",
      "Val loss = -0.6897\n",
      "\n",
      "--------Epoch 10--------\n",
      "Val loss = -0.6983\n",
      "\n",
      "--------Epoch 11--------\n",
      "Val loss = -0.6886\n",
      "\n",
      "--------Epoch 12--------\n",
      "Val loss = -0.6835\n",
      "\n",
      "--------Epoch 13--------\n",
      "Val loss = -0.6842\n",
      "\n",
      "--------Epoch 14--------\n",
      "Val loss = -0.6785\n",
      "\n",
      "--------Epoch 15--------\n",
      "Val loss = -0.6668\n",
      "\n",
      "--------Epoch 16--------\n",
      "Val loss = -0.6688\n",
      "\n",
      "--------Epoch 17--------\n",
      "Val loss = -0.6456\n",
      "\n",
      "--------Epoch 18--------\n",
      "Val loss = -0.6321\n",
      "\n",
      "--------Epoch 19--------\n",
      "Val loss = -0.6414\n",
      "\n",
      "--------Epoch 20--------\n",
      "Val loss = -0.6539\n",
      "\n",
      "--------Epoch 21--------\n",
      "Val loss = -0.6453\n",
      "\n",
      "--------Epoch 22--------\n",
      "Val loss = -0.6362\n",
      "\n",
      "--------Epoch 23--------\n",
      "Val loss = -0.6533\n",
      "\n",
      "--------Epoch 24--------\n",
      "Val loss = -0.6330\n",
      "\n",
      "--------Epoch 25--------\n",
      "Val loss = -0.6656\n",
      "\n",
      "--------Epoch 26--------\n",
      "Val loss = -0.6719\n",
      "\n",
      "--------Epoch 27--------\n",
      "Val loss = -0.6671\n",
      "\n",
      "--------Epoch 28--------\n",
      "Val loss = -0.6437\n",
      "\n",
      "--------Epoch 29--------\n",
      "Val loss = -0.6564\n",
      "\n",
      "--------Epoch 30--------\n",
      "Val loss = -0.6514\n",
      "\n",
      "--------Epoch 31--------\n",
      "Val loss = -0.6542\n",
      "\n",
      "--------Epoch 32--------\n",
      "Val loss = -0.6388\n",
      "\n",
      "--------Epoch 33--------\n",
      "Val loss = -0.6508\n",
      "\n",
      "--------Epoch 34--------\n",
      "Val loss = -0.6614\n",
      "\n",
      "--------Epoch 35--------\n",
      "Val loss = -0.6565\n",
      "\n",
      "--------Epoch 36--------\n",
      "Val loss = -0.6484\n",
      "\n",
      "--------Epoch 37--------\n",
      "Val loss = -0.6533\n",
      "\n",
      "--------Epoch 38--------\n",
      "Val loss = -0.6606\n",
      "\n",
      "--------Epoch 39--------\n",
      "Val loss = -0.6573\n",
      "\n",
      "--------Epoch 40--------\n",
      "Val loss = -0.6475\n",
      "\n",
      "--------Epoch 41--------\n",
      "Val loss = -0.6541\n",
      "\n",
      "--------Epoch 42--------\n",
      "Val loss = -0.6460\n",
      "\n",
      "--------Epoch 43--------\n",
      "Val loss = -0.6559\n",
      "\n",
      "--------Epoch 44--------\n",
      "Val loss = -0.6585\n",
      "\n",
      "--------Epoch 45--------\n",
      "Val loss = -0.6665\n",
      "\n",
      "--------Epoch 46--------\n",
      "Val loss = -0.6499\n",
      "\n",
      "--------Epoch 47--------\n",
      "Val loss = -0.6533\n",
      "\n",
      "--------Epoch 48--------\n",
      "Val loss = -0.6469\n",
      "\n",
      "--------Epoch 49--------\n",
      "Val loss = -0.6422\n",
      "\n",
      "--------Epoch 50--------\n",
      "Val loss = -0.6397\n",
      "\n",
      "--------Epoch 51--------\n",
      "Val loss = -0.6486\n",
      "\n",
      "--------Epoch 52--------\n",
      "Val loss = -0.6605\n",
      "\n",
      "--------Epoch 53--------\n",
      "Val loss = -0.6466\n",
      "\n",
      "--------Epoch 54--------\n",
      "Val loss = -0.6596\n",
      "\n",
      "--------Epoch 55--------\n",
      "Val loss = -0.6490\n",
      "\n",
      "--------Epoch 56--------\n",
      "Val loss = -0.6650\n",
      "\n",
      "--------Epoch 57--------\n",
      "Val loss = -0.6424\n",
      "\n",
      "--------Epoch 58--------\n",
      "Val loss = -0.6531\n",
      "\n",
      "--------Epoch 59--------\n",
      "Val loss = -0.6594\n",
      "\n",
      "--------Epoch 60--------\n",
      "Val loss = -0.6617\n",
      "\n",
      "--------Epoch 61--------\n",
      "Val loss = -0.6560\n",
      "\n",
      "--------Epoch 62--------\n",
      "Val loss = -0.6582\n",
      "\n",
      "--------Epoch 63--------\n",
      "Val loss = -0.6570\n",
      "\n",
      "--------Epoch 64--------\n",
      "Val loss = -0.6634\n",
      "\n",
      "--------Epoch 65--------\n",
      "Val loss = -0.6695\n",
      "\n",
      "--------Epoch 66--------\n",
      "Val loss = -0.6589\n",
      "\n",
      "--------Epoch 67--------\n",
      "Val loss = -0.6675\n",
      "\n",
      "--------Epoch 68--------\n",
      "Val loss = -0.6713\n",
      "\n",
      "--------Epoch 69--------\n",
      "Val loss = -0.6611\n",
      "\n",
      "--------Epoch 70--------\n",
      "Val loss = -0.6664\n",
      "\n",
      "--------Epoch 71--------\n",
      "Val loss = -0.6757\n",
      "\n",
      "--------Epoch 72--------\n",
      "Val loss = -0.6835\n",
      "\n",
      "--------Epoch 73--------\n",
      "Val loss = -0.6716\n",
      "\n",
      "--------Epoch 74--------\n",
      "Val loss = -0.6783\n",
      "\n",
      "--------Epoch 75--------\n",
      "Val loss = -0.6878\n",
      "\n",
      "--------Epoch 76--------\n",
      "Val loss = -0.6809\n",
      "\n",
      "--------Epoch 77--------\n",
      "Val loss = -0.6857\n",
      "\n",
      "--------Epoch 78--------\n",
      "Val loss = -0.6841\n",
      "\n",
      "--------Epoch 79--------\n",
      "Val loss = -0.6896\n",
      "\n",
      "--------Epoch 80--------\n",
      "Val loss = -0.7039\n",
      "\n",
      "--------Epoch 81--------\n",
      "Val loss = -0.6972\n",
      "\n",
      "--------Epoch 82--------\n",
      "Val loss = -0.7015\n",
      "\n",
      "--------Epoch 83--------\n",
      "Val loss = -0.7070\n",
      "\n",
      "--------Epoch 84--------\n",
      "Val loss = -0.7097\n",
      "\n",
      "--------Epoch 85--------\n",
      "Val loss = -0.7211\n",
      "\n",
      "--------Epoch 86--------\n",
      "Val loss = -0.7249\n",
      "\n",
      "--------Epoch 87--------\n",
      "Val loss = -0.7234\n",
      "\n",
      "--------Epoch 88--------\n",
      "Val loss = -0.7257\n",
      "\n",
      "--------Epoch 89--------\n",
      "Val loss = -0.7269\n",
      "\n",
      "--------Epoch 90--------\n",
      "Val loss = -0.7388\n",
      "\n",
      "--------Epoch 91--------\n",
      "Val loss = -0.7364\n",
      "\n",
      "--------Epoch 92--------\n",
      "Val loss = -0.7353\n",
      "\n",
      "--------Epoch 93--------\n",
      "Val loss = -0.7488\n",
      "\n",
      "--------Epoch 94--------\n",
      "Val loss = -0.7464\n",
      "\n",
      "--------Epoch 95--------\n",
      "Val loss = -0.7507\n",
      "\n",
      "--------Epoch 96--------\n",
      "Val loss = -0.7600\n",
      "\n",
      "--------Epoch 97--------\n",
      "Val loss = -0.7591\n",
      "\n",
      "--------Epoch 98--------\n",
      "Val loss = -0.7666\n",
      "\n",
      "--------Epoch 99--------\n",
      "Val loss = -0.7676\n",
      "\n",
      "--------Epoch 100--------\n",
      "Val loss = -0.7683\n",
      "\n",
      "--------Epoch 101--------\n",
      "Val loss = -0.7666\n",
      "\n",
      "--------Epoch 102--------\n",
      "Val loss = -0.7773\n",
      "\n",
      "--------Epoch 103--------\n",
      "Val loss = -0.7870\n",
      "\n",
      "--------Epoch 104--------\n",
      "Val loss = -0.7790\n",
      "\n",
      "--------Epoch 105--------\n",
      "Val loss = -0.7849\n",
      "\n",
      "--------Epoch 106--------\n",
      "Val loss = -0.7841\n",
      "\n",
      "--------Epoch 107--------\n",
      "Val loss = -0.7885\n",
      "\n",
      "--------Epoch 108--------\n",
      "Val loss = -0.7868\n",
      "\n",
      "--------Epoch 109--------\n",
      "Val loss = -0.7941\n",
      "\n",
      "--------Epoch 110--------\n",
      "Val loss = -0.7949\n",
      "\n",
      "--------Epoch 111--------\n",
      "Val loss = -0.7984\n",
      "\n",
      "--------Epoch 112--------\n",
      "Val loss = -0.7969\n",
      "\n",
      "--------Epoch 113--------\n",
      "Val loss = -0.8042\n",
      "\n",
      "--------Epoch 114--------\n",
      "Val loss = -0.8075\n",
      "\n",
      "--------Epoch 115--------\n",
      "Val loss = -0.8058\n",
      "\n",
      "--------Epoch 116--------\n",
      "Val loss = -0.8122\n",
      "\n",
      "--------Epoch 117--------\n",
      "Val loss = -0.8106\n",
      "\n",
      "--------Epoch 118--------\n",
      "Val loss = -0.8159\n",
      "\n",
      "--------Epoch 119--------\n",
      "Val loss = -0.8156\n",
      "\n",
      "--------Epoch 120--------\n",
      "Val loss = -0.8117\n",
      "\n",
      "--------Epoch 121--------\n",
      "Val loss = -0.8137\n",
      "\n",
      "--------Epoch 122--------\n",
      "Val loss = -0.8139\n",
      "\n",
      "--------Epoch 123--------\n",
      "Val loss = -0.8143\n",
      "\n",
      "--------Epoch 124--------\n",
      "Val loss = -0.8116\n",
      "\n",
      "--------Epoch 125--------\n",
      "Val loss = -0.8213\n",
      "\n",
      "--------Epoch 126--------\n",
      "Val loss = -0.8228\n",
      "\n",
      "--------Epoch 127--------\n",
      "Val loss = -0.8201\n",
      "\n",
      "--------Epoch 128--------\n",
      "Val loss = -0.8273\n",
      "\n",
      "--------Epoch 129--------\n",
      "Val loss = -0.8223\n",
      "\n",
      "--------Epoch 130--------\n",
      "Val loss = -0.8292\n",
      "\n",
      "--------Epoch 131--------\n",
      "Val loss = -0.8295\n",
      "\n",
      "--------Epoch 132--------\n",
      "Val loss = -0.8256\n",
      "\n",
      "--------Epoch 133--------\n",
      "Val loss = -0.8262\n",
      "\n",
      "--------Epoch 134--------\n",
      "Val loss = -0.8280\n",
      "\n",
      "--------Epoch 135--------\n",
      "Val loss = -0.8303\n",
      "\n",
      "--------Epoch 136--------\n",
      "Val loss = -0.8260\n",
      "\n",
      "--------Epoch 137--------\n",
      "Val loss = -0.8323\n",
      "\n",
      "--------Epoch 138--------\n",
      "Val loss = -0.8272\n",
      "\n",
      "--------Epoch 139--------\n",
      "Val loss = -0.8310\n",
      "\n",
      "--------Epoch 140--------\n",
      "Val loss = -0.8330\n",
      "\n",
      "--------Epoch 141--------\n",
      "Val loss = -0.8305\n",
      "\n",
      "--------Epoch 142--------\n",
      "Val loss = -0.8301\n",
      "\n",
      "--------Epoch 143--------\n",
      "Val loss = -0.8361\n",
      "\n",
      "--------Epoch 144--------\n",
      "Val loss = -0.8335\n",
      "\n",
      "--------Epoch 145--------\n",
      "Val loss = -0.8341\n",
      "\n",
      "--------Epoch 146--------\n",
      "Val loss = -0.8365\n",
      "\n",
      "--------Epoch 147--------\n",
      "Val loss = -0.8353\n",
      "\n",
      "--------Epoch 148--------\n",
      "Val loss = -0.8360\n",
      "\n",
      "--------Epoch 149--------\n",
      "Val loss = -0.8374\n",
      "\n",
      "--------Epoch 150--------\n",
      "Val loss = -0.8377\n",
      "\n",
      "--------Epoch 151--------\n",
      "Val loss = -0.8347\n",
      "\n",
      "--------Epoch 152--------\n",
      "Val loss = -0.8405\n",
      "\n",
      "--------Epoch 153--------\n",
      "Val loss = -0.8373\n",
      "\n",
      "--------Epoch 154--------\n",
      "Val loss = -0.8405\n",
      "\n",
      "--------Epoch 155--------\n",
      "Val loss = -0.8358\n",
      "\n",
      "--------Epoch 156--------\n",
      "Val loss = -0.8380\n",
      "\n",
      "--------Epoch 157--------\n",
      "Val loss = -0.8433\n",
      "\n",
      "--------Epoch 158--------\n",
      "Val loss = -0.8412\n",
      "\n",
      "--------Epoch 159--------\n",
      "Val loss = -0.8439\n",
      "\n",
      "--------Epoch 160--------\n",
      "Val loss = -0.8378\n",
      "\n",
      "--------Epoch 161--------\n",
      "Val loss = -0.8369\n",
      "\n",
      "--------Epoch 162--------\n",
      "Val loss = -0.8422\n",
      "\n",
      "--------Epoch 163--------\n",
      "Val loss = -0.8382\n",
      "\n",
      "--------Epoch 164--------\n",
      "Val loss = -0.8441\n",
      "\n",
      "--------Epoch 165--------\n",
      "Val loss = -0.8384\n",
      "\n",
      "--------Epoch 166--------\n",
      "Val loss = -0.8440\n",
      "\n",
      "--------Epoch 167--------\n",
      "Val loss = -0.8404\n",
      "\n",
      "--------Epoch 168--------\n",
      "Val loss = -0.8465\n",
      "\n",
      "--------Epoch 169--------\n",
      "Val loss = -0.8414\n",
      "\n",
      "--------Epoch 170--------\n",
      "Val loss = -0.8418\n",
      "\n",
      "--------Epoch 171--------\n",
      "Val loss = -0.8406\n",
      "\n",
      "--------Epoch 172--------\n",
      "Val loss = -0.8388\n",
      "\n",
      "--------Epoch 173--------\n",
      "Val loss = -0.8390\n",
      "\n",
      "--------Epoch 174--------\n",
      "Val loss = -0.8359\n",
      "\n",
      "--------Epoch 175--------\n",
      "Val loss = -0.8438\n",
      "\n",
      "--------Epoch 176--------\n",
      "Val loss = -0.8424\n",
      "\n",
      "--------Epoch 177--------\n",
      "Val loss = -0.8419\n",
      "\n",
      "--------Epoch 178--------\n",
      "Val loss = -0.8434\n",
      "\n",
      "--------Epoch 179--------\n",
      "Val loss = -0.8425\n",
      "\n",
      "--------Epoch 180--------\n",
      "Val loss = -0.8449\n",
      "\n",
      "--------Epoch 181--------\n",
      "Val loss = -0.8424\n",
      "\n",
      "--------Epoch 182--------\n",
      "Val loss = -0.8506\n",
      "\n",
      "--------Epoch 183--------\n",
      "Val loss = -0.8458\n",
      "\n",
      "--------Epoch 184--------\n",
      "Val loss = -0.8451\n",
      "\n",
      "--------Epoch 185--------\n",
      "Val loss = -0.8438\n",
      "\n",
      "--------Epoch 186--------\n",
      "Val loss = -0.8497\n",
      "\n",
      "--------Epoch 187--------\n",
      "Val loss = -0.8422\n",
      "\n",
      "--------Epoch 188--------\n",
      "Val loss = -0.8424\n",
      "\n",
      "--------Epoch 189--------\n",
      "Val loss = -0.8405\n",
      "\n",
      "--------Epoch 190--------\n",
      "Val loss = -0.8476\n",
      "\n",
      "--------Epoch 191--------\n",
      "Val loss = -0.8438\n",
      "\n",
      "--------Epoch 192--------\n",
      "Val loss = -0.8416\n",
      "\n",
      "--------Epoch 193--------\n",
      "Val loss = -0.8425\n",
      "\n",
      "--------Epoch 194--------\n",
      "Val loss = -0.8394\n",
      "\n",
      "--------Epoch 195--------\n",
      "Val loss = -0.8477\n",
      "\n",
      "--------Epoch 196--------\n",
      "Val loss = -0.8467\n",
      "\n",
      "--------Epoch 197--------\n",
      "Val loss = -0.8394\n",
      "\n",
      "--------Epoch 198--------\n",
      "Val loss = -0.8498\n",
      "\n",
      "--------Epoch 199--------\n",
      "Val loss = -0.8463\n",
      "\n",
      "--------Epoch 200--------\n",
      "Val loss = -0.8418\n",
      "\n",
      "--------Epoch 201--------\n",
      "Val loss = -0.8456\n",
      "\n",
      "--------Epoch 202--------\n",
      "Val loss = -0.8487\n",
      "\n",
      "--------Epoch 203--------\n",
      "Val loss = -0.8460\n",
      "\n",
      "--------Epoch 204--------\n",
      "Val loss = -0.8418\n",
      "\n",
      "--------Epoch 205--------\n",
      "Val loss = -0.8466\n",
      "\n",
      "--------Epoch 206--------\n",
      "Val loss = -0.8483\n",
      "\n",
      "--------Epoch 207--------\n",
      "Val loss = -0.8275\n",
      "\n",
      "--------Epoch 208--------\n",
      "Val loss = -0.8442\n",
      "\n",
      "--------Epoch 209--------\n",
      "Val loss = -0.8410\n",
      "\n",
      "--------Epoch 210--------\n",
      "Val loss = -0.8419\n",
      "\n",
      "--------Epoch 211--------\n",
      "Val loss = -0.8431\n",
      "\n",
      "--------Epoch 212--------\n",
      "Val loss = -0.8455\n",
      "\n",
      "--------Epoch 213--------\n",
      "Val loss = -0.8409\n",
      "\n",
      "--------Epoch 214--------\n",
      "Val loss = -0.8425\n",
      "\n",
      "--------Epoch 215--------\n",
      "Val loss = -0.8487\n",
      "\n",
      "--------Epoch 216--------\n",
      "Val loss = -0.8446\n",
      "\n",
      "--------Epoch 217--------\n",
      "Val loss = -0.8428\n",
      "\n",
      "--------Epoch 218--------\n",
      "Val loss = -0.8384\n",
      "\n",
      "--------Epoch 219--------\n",
      "Val loss = -0.8306\n",
      "\n",
      "--------Epoch 220--------\n",
      "Val loss = -0.8443\n",
      "\n",
      "--------Epoch 221--------\n",
      "Val loss = -0.8356\n",
      "\n",
      "--------Epoch 222--------\n",
      "Val loss = -0.8430\n",
      "\n",
      "--------Epoch 223--------\n",
      "Val loss = -0.8410\n",
      "\n",
      "--------Epoch 224--------\n",
      "Val loss = -0.8474\n",
      "\n",
      "--------Epoch 225--------\n",
      "Val loss = -0.8475\n",
      "\n",
      "--------Epoch 226--------\n",
      "Val loss = -0.8356\n",
      "\n",
      "--------Epoch 227--------\n",
      "Val loss = -0.8426\n",
      "\n",
      "--------Epoch 228--------\n",
      "Val loss = -0.8397\n",
      "\n",
      "--------Epoch 229--------\n",
      "Val loss = -0.8402\n",
      "\n",
      "--------Epoch 230--------\n",
      "Val loss = -0.8420\n",
      "\n",
      "--------Epoch 231--------\n",
      "Val loss = -0.8403\n",
      "\n",
      "--------Epoch 232--------\n",
      "Val loss = -0.8426\n",
      "\n",
      "--------Epoch 233--------\n",
      "Val loss = -0.8435\n",
      "\n",
      "--------Epoch 234--------\n",
      "Val loss = -0.8431\n",
      "\n",
      "--------Epoch 235--------\n",
      "Val loss = -0.8400\n",
      "\n",
      "--------Epoch 236--------\n",
      "Val loss = -0.8412\n",
      "\n",
      "--------Epoch 237--------\n",
      "Val loss = -0.8367\n",
      "\n",
      "--------Epoch 238--------\n",
      "Val loss = -0.8384\n",
      "\n",
      "--------Epoch 239--------\n",
      "Val loss = -0.8351\n",
      "\n",
      "--------Epoch 240--------\n",
      "Val loss = -0.8333\n",
      "\n",
      "--------Epoch 241--------\n",
      "Val loss = -0.8298\n",
      "\n",
      "--------Epoch 242--------\n",
      "Val loss = -0.8371\n",
      "\n",
      "--------Epoch 243--------\n",
      "Val loss = -0.8292\n",
      "\n",
      "--------Epoch 244--------\n",
      "Val loss = -0.8344\n",
      "\n",
      "--------Epoch 245--------\n",
      "Val loss = -0.8356\n",
      "\n",
      "--------Epoch 246--------\n",
      "Val loss = -0.8324\n",
      "\n",
      "--------Epoch 247--------\n",
      "Val loss = -0.8197\n",
      "\n",
      "--------Epoch 248--------\n",
      "Val loss = -0.8307\n",
      "\n",
      "--------Epoch 249--------\n",
      "Val loss = -0.8255\n",
      "\n",
      "--------Epoch 250--------\n",
      "Val loss = -0.8331\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Selector network: same input width as the predictor (presumably features + mask),\n",
    "# with one output score per input feature.\n",
    "selector = nn.Sequential(\n",
    "    nn.Linear(2 * d_in, 512), nn.ReLU(),\n",
    "    nn.Linear(512, 512), nn.ReLU(),\n",
    "    nn.Linear(512, d_in),\n",
    ")\n",
    "selector_layer = ConcreteSelector()\n",
    "# deepcopy leaves the pretrained `model` untouched so later experiments can start from it too.\n",
    "gafs = GreedyAdaptiveFS(selector, deepcopy(model), mask_layer, selector_layer).to(device)\n",
    "\n",
    "# Train selector and predictor together (baseline: no warm start, no weight tying).\n",
    "gafs.fit(\n",
    "    train_dataset,\n",
    "    val_dataset,\n",
    "    lr=2e-4,\n",
    "    mbsize=128,\n",
    "    nepochs=250,\n",
    "    max_features=max_features,\n",
    "    loss_fn=nn.CrossEntropyLoss(),\n",
    "    val_loss_fn=NegAccuracy(),\n",
    "    argmax=False,\n",
    "    no_repeats=False,\n",
    "    start_temp=10.0,\n",
    "    end_temp=0.01,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "fallen-april",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test acc = 86.26\n"
     ]
    }
   ],
   "source": [
    "# Test-set accuracy when max_features features are selected\n",
    "# (the trailing 1024 is presumably the evaluation batch size).\n",
    "test_acc = gafs.evaluate(\n",
    "    test_dataset, max_features, Accuracy(), 1024\n",
    ")\n",
    "print(f'Test acc = {100*test_acc:.2f}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "strong-wiring",
   "metadata": {},
   "source": [
    "# Warm starting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "detailed-portfolio",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--------Epoch 1--------\n",
      "Val loss = -0.6720\n",
      "\n",
      "--------Epoch 2--------\n",
      "Val loss = -0.7117\n",
      "\n",
      "--------Epoch 3--------\n",
      "Val loss = -0.7234\n",
      "\n",
      "--------Epoch 4--------\n",
      "Val loss = -0.7304\n",
      "\n",
      "--------Epoch 5--------\n",
      "Val loss = -0.7235\n",
      "\n",
      "--------Epoch 6--------\n",
      "Val loss = -0.7337\n",
      "\n",
      "--------Epoch 7--------\n",
      "Val loss = -0.7357\n",
      "\n",
      "--------Epoch 8--------\n",
      "Val loss = -0.7235\n",
      "\n",
      "--------Epoch 9--------\n",
      "Val loss = -0.7230\n",
      "\n",
      "--------Epoch 10--------\n",
      "Val loss = -0.6971\n",
      "\n",
      "--------Epoch 11--------\n",
      "Val loss = -0.7187\n",
      "\n",
      "--------Epoch 12--------\n",
      "Val loss = -0.6973\n",
      "\n",
      "--------Epoch 13--------\n",
      "Val loss = -0.6807\n",
      "\n",
      "--------Epoch 14--------\n",
      "Val loss = -0.6831\n",
      "\n",
      "--------Epoch 15--------\n",
      "Val loss = -0.6815\n",
      "\n",
      "--------Epoch 16--------\n",
      "Val loss = -0.7028\n",
      "\n",
      "--------Epoch 17--------\n",
      "Val loss = -0.7090\n",
      "\n",
      "--------Epoch 18--------\n",
      "Val loss = -0.6946\n",
      "\n",
      "--------Epoch 19--------\n",
      "Val loss = -0.6692\n",
      "\n",
      "--------Epoch 20--------\n",
      "Val loss = -0.6698\n",
      "\n",
      "--------Epoch 21--------\n",
      "Val loss = -0.6573\n",
      "\n",
      "--------Epoch 22--------\n",
      "Val loss = -0.6856\n",
      "\n",
      "--------Epoch 23--------\n",
      "Val loss = -0.6838\n",
      "\n",
      "--------Epoch 24--------\n",
      "Val loss = -0.6829\n",
      "\n",
      "--------Epoch 25--------\n",
      "Val loss = -0.6444\n",
      "\n",
      "--------Epoch 26--------\n",
      "Val loss = -0.6740\n",
      "\n",
      "--------Epoch 27--------\n",
      "Val loss = -0.6792\n",
      "\n",
      "--------Epoch 28--------\n",
      "Val loss = -0.6705\n",
      "\n",
      "--------Epoch 29--------\n",
      "Val loss = -0.6815\n",
      "\n",
      "--------Epoch 30--------\n",
      "Val loss = -0.6772\n",
      "\n",
      "--------Epoch 31--------\n",
      "Val loss = -0.6642\n",
      "\n",
      "--------Epoch 32--------\n",
      "Val loss = -0.6729\n",
      "\n",
      "--------Epoch 33--------\n",
      "Val loss = -0.6565\n",
      "\n",
      "--------Epoch 34--------\n",
      "Val loss = -0.6607\n",
      "\n",
      "--------Epoch 35--------\n",
      "Val loss = -0.6563\n",
      "\n",
      "--------Epoch 36--------\n",
      "Val loss = -0.6584\n",
      "\n",
      "--------Epoch 37--------\n",
      "Val loss = -0.6676\n",
      "\n",
      "--------Epoch 38--------\n",
      "Val loss = -0.6571\n",
      "\n",
      "--------Epoch 39--------\n",
      "Val loss = -0.6605\n",
      "\n",
      "--------Epoch 40--------\n",
      "Val loss = -0.6742\n",
      "\n",
      "--------Epoch 41--------\n",
      "Val loss = -0.6574\n",
      "\n",
      "--------Epoch 42--------\n",
      "Val loss = -0.6623\n",
      "\n",
      "--------Epoch 43--------\n",
      "Val loss = -0.6632\n",
      "\n",
      "--------Epoch 44--------\n",
      "Val loss = -0.6609\n",
      "\n",
      "--------Epoch 45--------\n",
      "Val loss = -0.6582\n",
      "\n",
      "--------Epoch 46--------\n",
      "Val loss = -0.6570\n",
      "\n",
      "--------Epoch 47--------\n",
      "Val loss = -0.6533\n",
      "\n",
      "--------Epoch 48--------\n",
      "Val loss = -0.6517\n",
      "\n",
      "--------Epoch 49--------\n",
      "Val loss = -0.6653\n",
      "\n",
      "--------Epoch 50--------\n",
      "Val loss = -0.6586\n",
      "\n",
      "--------Epoch 51--------\n",
      "Val loss = -0.6519\n",
      "\n",
      "--------Epoch 52--------\n",
      "Val loss = -0.6538\n",
      "\n",
      "--------Epoch 53--------\n",
      "Val loss = -0.6552\n",
      "\n",
      "--------Epoch 54--------\n",
      "Val loss = -0.6605\n",
      "\n",
      "--------Epoch 55--------\n",
      "Val loss = -0.6691\n",
      "\n",
      "--------Epoch 56--------\n",
      "Val loss = -0.6666\n",
      "\n",
      "--------Epoch 57--------\n",
      "Val loss = -0.6601\n",
      "\n",
      "--------Epoch 58--------\n",
      "Val loss = -0.6629\n",
      "\n",
      "--------Epoch 59--------\n",
      "Val loss = -0.6700\n",
      "\n",
      "--------Epoch 60--------\n",
      "Val loss = -0.6651\n",
      "\n",
      "--------Epoch 61--------\n",
      "Val loss = -0.6576\n",
      "\n",
      "--------Epoch 62--------\n",
      "Val loss = -0.6634\n",
      "\n",
      "--------Epoch 63--------\n",
      "Val loss = -0.6484\n",
      "\n",
      "--------Epoch 64--------\n",
      "Val loss = -0.6619\n",
      "\n",
      "--------Epoch 65--------\n",
      "Val loss = -0.6672\n",
      "\n",
      "--------Epoch 66--------\n",
      "Val loss = -0.6681\n",
      "\n",
      "--------Epoch 67--------\n",
      "Val loss = -0.6726\n",
      "\n",
      "--------Epoch 68--------\n",
      "Val loss = -0.6623\n",
      "\n",
      "--------Epoch 69--------\n",
      "Val loss = -0.6672\n",
      "\n",
      "--------Epoch 70--------\n",
      "Val loss = -0.6711\n",
      "\n",
      "--------Epoch 71--------\n",
      "Val loss = -0.6801\n",
      "\n",
      "--------Epoch 72--------\n",
      "Val loss = -0.6780\n",
      "\n",
      "--------Epoch 73--------\n",
      "Val loss = -0.6731\n",
      "\n",
      "--------Epoch 74--------\n",
      "Val loss = -0.6813\n",
      "\n",
      "--------Epoch 75--------\n",
      "Val loss = -0.6853\n",
      "\n",
      "--------Epoch 76--------\n",
      "Val loss = -0.6795\n",
      "\n",
      "--------Epoch 77--------\n",
      "Val loss = -0.6897\n",
      "\n",
      "--------Epoch 78--------\n",
      "Val loss = -0.6920\n",
      "\n",
      "--------Epoch 79--------\n",
      "Val loss = -0.7061\n",
      "\n",
      "--------Epoch 80--------\n",
      "Val loss = -0.7032\n",
      "\n",
      "--------Epoch 81--------\n",
      "Val loss = -0.7052\n",
      "\n",
      "--------Epoch 82--------\n",
      "Val loss = -0.7055\n",
      "\n",
      "--------Epoch 83--------\n",
      "Val loss = -0.7042\n",
      "\n",
      "--------Epoch 84--------\n",
      "Val loss = -0.7171\n",
      "\n",
      "--------Epoch 85--------\n",
      "Val loss = -0.7093\n",
      "\n",
      "--------Epoch 86--------\n",
      "Val loss = -0.7179\n",
      "\n",
      "--------Epoch 87--------\n",
      "Val loss = -0.7262\n",
      "\n",
      "--------Epoch 88--------\n",
      "Val loss = -0.7227\n",
      "\n",
      "--------Epoch 89--------\n",
      "Val loss = -0.7342\n",
      "\n",
      "--------Epoch 90--------\n",
      "Val loss = -0.7322\n",
      "\n",
      "--------Epoch 91--------\n",
      "Val loss = -0.7391\n",
      "\n",
      "--------Epoch 92--------\n",
      "Val loss = -0.7369\n",
      "\n",
      "--------Epoch 93--------\n",
      "Val loss = -0.7506\n",
      "\n",
      "--------Epoch 94--------\n",
      "Val loss = -0.7432\n",
      "\n",
      "--------Epoch 95--------\n",
      "Val loss = -0.7518\n",
      "\n",
      "--------Epoch 96--------\n",
      "Val loss = -0.7620\n",
      "\n",
      "--------Epoch 97--------\n",
      "Val loss = -0.7603\n",
      "\n",
      "--------Epoch 98--------\n",
      "Val loss = -0.7649\n",
      "\n",
      "--------Epoch 99--------\n",
      "Val loss = -0.7636\n",
      "\n",
      "--------Epoch 100--------\n",
      "Val loss = -0.7711\n",
      "\n",
      "--------Epoch 101--------\n",
      "Val loss = -0.7685\n",
      "\n",
      "--------Epoch 102--------\n",
      "Val loss = -0.7760\n",
      "\n",
      "--------Epoch 103--------\n",
      "Val loss = -0.7755\n",
      "\n",
      "--------Epoch 104--------\n",
      "Val loss = -0.7854\n",
      "\n",
      "--------Epoch 105--------\n",
      "Val loss = -0.7891\n",
      "\n",
      "--------Epoch 106--------\n",
      "Val loss = -0.7842\n",
      "\n",
      "--------Epoch 107--------\n",
      "Val loss = -0.7864\n",
      "\n",
      "--------Epoch 108--------\n",
      "Val loss = -0.7937\n",
      "\n",
      "--------Epoch 109--------\n",
      "Val loss = -0.7886\n",
      "\n",
      "--------Epoch 110--------\n",
      "Val loss = -0.7981\n",
      "\n",
      "--------Epoch 111--------\n",
      "Val loss = -0.8005\n",
      "\n",
      "--------Epoch 112--------\n",
      "Val loss = -0.8034\n",
      "\n",
      "--------Epoch 113--------\n",
      "Val loss = -0.8009\n",
      "\n",
      "--------Epoch 114--------\n",
      "Val loss = -0.8018\n",
      "\n",
      "--------Epoch 115--------\n",
      "Val loss = -0.8104\n",
      "\n",
      "--------Epoch 116--------\n",
      "Val loss = -0.8082\n",
      "\n",
      "--------Epoch 117--------\n",
      "Val loss = -0.8086\n",
      "\n",
      "--------Epoch 118--------\n",
      "Val loss = -0.8153\n",
      "\n",
      "--------Epoch 119--------\n",
      "Val loss = -0.8160\n",
      "\n",
      "--------Epoch 120--------\n",
      "Val loss = -0.8165\n",
      "\n",
      "--------Epoch 121--------\n",
      "Val loss = -0.8116\n",
      "\n",
      "--------Epoch 122--------\n",
      "Val loss = -0.8076\n",
      "\n",
      "--------Epoch 123--------\n",
      "Val loss = -0.8172\n",
      "\n",
      "--------Epoch 124--------\n",
      "Val loss = -0.8212\n",
      "\n",
      "--------Epoch 125--------\n",
      "Val loss = -0.8195\n",
      "\n",
      "--------Epoch 126--------\n",
      "Val loss = -0.8205\n",
      "\n",
      "--------Epoch 127--------\n",
      "Val loss = -0.8264\n",
      "\n",
      "--------Epoch 128--------\n",
      "Val loss = -0.8185\n",
      "\n",
      "--------Epoch 129--------\n",
      "Val loss = -0.8314\n",
      "\n",
      "--------Epoch 130--------\n",
      "Val loss = -0.8294\n",
      "\n",
      "--------Epoch 131--------\n",
      "Val loss = -0.8317\n",
      "\n",
      "--------Epoch 132--------\n",
      "Val loss = -0.8323\n",
      "\n",
      "--------Epoch 133--------\n",
      "Val loss = -0.8312\n",
      "\n",
      "--------Epoch 134--------\n",
      "Val loss = -0.8337\n",
      "\n",
      "--------Epoch 135--------\n",
      "Val loss = -0.8346\n",
      "\n",
      "--------Epoch 136--------\n",
      "Val loss = -0.8348\n",
      "\n",
      "--------Epoch 137--------\n",
      "Val loss = -0.8302\n",
      "\n",
      "--------Epoch 138--------\n",
      "Val loss = -0.8365\n",
      "\n",
      "--------Epoch 139--------\n",
      "Val loss = -0.8344\n",
      "\n",
      "--------Epoch 140--------\n",
      "Val loss = -0.8336\n",
      "\n",
      "--------Epoch 141--------\n",
      "Val loss = -0.8365\n",
      "\n",
      "--------Epoch 142--------\n",
      "Val loss = -0.8334\n",
      "\n",
      "--------Epoch 143--------\n",
      "Val loss = -0.8348\n",
      "\n",
      "--------Epoch 144--------\n",
      "Val loss = -0.8349\n",
      "\n",
      "--------Epoch 145--------\n",
      "Val loss = -0.8309\n",
      "\n",
      "--------Epoch 146--------\n",
      "Val loss = -0.8410\n",
      "\n",
      "--------Epoch 147--------\n",
      "Val loss = -0.8345\n",
      "\n",
      "--------Epoch 148--------\n",
      "Val loss = -0.8387\n",
      "\n",
      "--------Epoch 149--------\n",
      "Val loss = -0.8395\n",
      "\n",
      "--------Epoch 150--------\n",
      "Val loss = -0.8386\n",
      "\n",
      "--------Epoch 151--------\n",
      "Val loss = -0.8408\n",
      "\n",
      "--------Epoch 152--------\n",
      "Val loss = -0.8391\n",
      "\n",
      "--------Epoch 153--------\n",
      "Val loss = -0.8380\n",
      "\n",
      "--------Epoch 154--------\n",
      "Val loss = -0.8476\n",
      "\n",
      "--------Epoch 155--------\n",
      "Val loss = -0.8423\n",
      "\n",
      "--------Epoch 156--------\n",
      "Val loss = -0.8429\n",
      "\n",
      "--------Epoch 157--------\n",
      "Val loss = -0.8422\n",
      "\n",
      "--------Epoch 158--------\n",
      "Val loss = -0.8449\n",
      "\n",
      "--------Epoch 159--------\n",
      "Val loss = -0.8449\n",
      "\n",
      "--------Epoch 160--------\n",
      "Val loss = -0.8458\n",
      "\n",
      "--------Epoch 161--------\n",
      "Val loss = -0.8445\n",
      "\n",
      "--------Epoch 162--------\n",
      "Val loss = -0.8469\n",
      "\n",
      "--------Epoch 163--------\n",
      "Val loss = -0.8480\n",
      "\n",
      "--------Epoch 164--------\n",
      "Val loss = -0.8436\n",
      "\n",
      "--------Epoch 165--------\n",
      "Val loss = -0.8474\n",
      "\n",
      "--------Epoch 166--------\n",
      "Val loss = -0.8447\n",
      "\n",
      "--------Epoch 167--------\n",
      "Val loss = -0.8474\n",
      "\n",
      "--------Epoch 168--------\n",
      "Val loss = -0.8468\n",
      "\n",
      "--------Epoch 169--------\n",
      "Val loss = -0.8487\n",
      "\n",
      "--------Epoch 170--------\n",
      "Val loss = -0.8505\n",
      "\n",
      "--------Epoch 171--------\n",
      "Val loss = -0.8444\n",
      "\n",
      "--------Epoch 172--------\n",
      "Val loss = -0.8477\n",
      "\n",
      "--------Epoch 173--------\n",
      "Val loss = -0.8486\n",
      "\n",
      "--------Epoch 174--------\n",
      "Val loss = -0.8475\n",
      "\n",
      "--------Epoch 175--------\n",
      "Val loss = -0.8434\n",
      "\n",
      "--------Epoch 176--------\n",
      "Val loss = -0.8472\n",
      "\n",
      "--------Epoch 177--------\n",
      "Val loss = -0.8458\n",
      "\n",
      "--------Epoch 178--------\n",
      "Val loss = -0.8484\n",
      "\n",
      "--------Epoch 179--------\n",
      "Val loss = -0.8505\n",
      "\n",
      "--------Epoch 180--------\n",
      "Val loss = -0.8534\n",
      "\n",
      "--------Epoch 181--------\n",
      "Val loss = -0.8499\n",
      "\n",
      "--------Epoch 182--------\n",
      "Val loss = -0.8510\n",
      "\n",
      "--------Epoch 183--------\n",
      "Val loss = -0.8471\n",
      "\n",
      "--------Epoch 184--------\n",
      "Val loss = -0.8481\n",
      "\n",
      "--------Epoch 185--------\n",
      "Val loss = -0.8482\n",
      "\n",
      "--------Epoch 186--------\n",
      "Val loss = -0.8516\n",
      "\n",
      "--------Epoch 187--------\n",
      "Val loss = -0.8493\n",
      "\n",
      "--------Epoch 188--------\n",
      "Val loss = -0.8535\n",
      "\n",
      "--------Epoch 189--------\n",
      "Val loss = -0.8520\n",
      "\n",
      "--------Epoch 190--------\n",
      "Val loss = -0.8491\n",
      "\n",
      "--------Epoch 191--------\n",
      "Val loss = -0.8558\n",
      "\n",
      "--------Epoch 192--------\n",
      "Val loss = -0.8531\n",
      "\n",
      "--------Epoch 193--------\n",
      "Val loss = -0.8472\n",
      "\n",
      "--------Epoch 194--------\n",
      "Val loss = -0.8559\n",
      "\n",
      "--------Epoch 195--------\n",
      "Val loss = -0.8546\n",
      "\n",
      "--------Epoch 196--------\n",
      "Val loss = -0.8529\n",
      "\n",
      "--------Epoch 197--------\n",
      "Val loss = -0.8542\n",
      "\n",
      "--------Epoch 198--------\n",
      "Val loss = -0.8473\n",
      "\n",
      "--------Epoch 199--------\n",
      "Val loss = -0.8479\n",
      "\n",
      "--------Epoch 200--------\n",
      "Val loss = -0.8543\n",
      "\n",
      "--------Epoch 201--------\n",
      "Val loss = -0.8547\n",
      "\n",
      "--------Epoch 202--------\n",
      "Val loss = -0.8493\n",
      "\n",
      "--------Epoch 203--------\n",
      "Val loss = -0.8368\n",
      "\n",
      "--------Epoch 204--------\n",
      "Val loss = -0.8521\n",
      "\n",
      "--------Epoch 205--------\n",
      "Val loss = -0.8550\n",
      "\n",
      "--------Epoch 206--------\n",
      "Val loss = -0.8459\n",
      "\n",
      "--------Epoch 207--------\n",
      "Val loss = -0.8499\n",
      "\n",
      "--------Epoch 208--------\n",
      "Val loss = -0.8523\n",
      "\n",
      "--------Epoch 209--------\n",
      "Val loss = -0.8501\n",
      "\n",
      "--------Epoch 210--------\n",
      "Val loss = -0.8559\n",
      "\n",
      "--------Epoch 211--------\n",
      "Val loss = -0.8515\n",
      "\n",
      "--------Epoch 212--------\n",
      "Val loss = -0.8547\n",
      "\n",
      "--------Epoch 213--------\n",
      "Val loss = -0.8516\n",
      "\n",
      "--------Epoch 214--------\n",
      "Val loss = -0.8519\n",
      "\n",
      "--------Epoch 215--------\n",
      "Val loss = -0.8558\n",
      "\n",
      "--------Epoch 216--------\n",
      "Val loss = -0.8527\n",
      "\n",
      "--------Epoch 217--------\n",
      "Val loss = -0.8464\n",
      "\n",
      "--------Epoch 218--------\n",
      "Val loss = -0.8505\n",
      "\n",
      "--------Epoch 219--------\n",
      "Val loss = -0.8514\n",
      "\n",
      "--------Epoch 220--------\n",
      "Val loss = -0.8550\n",
      "\n",
      "--------Epoch 221--------\n",
      "Val loss = -0.8517\n",
      "\n",
      "--------Epoch 222--------\n",
      "Val loss = -0.8456\n",
      "\n",
      "--------Epoch 223--------\n",
      "Val loss = -0.8520\n",
      "\n",
      "--------Epoch 224--------\n",
      "Val loss = -0.8492\n",
      "\n",
      "--------Epoch 225--------\n",
      "Val loss = -0.8522\n",
      "\n",
      "--------Epoch 226--------\n",
      "Val loss = -0.8530\n",
      "\n",
      "--------Epoch 227--------\n",
      "Val loss = -0.8494\n",
      "\n",
      "--------Epoch 228--------\n",
      "Val loss = -0.8503\n",
      "\n",
      "--------Epoch 229--------\n",
      "Val loss = -0.8545\n",
      "\n",
      "--------Epoch 230--------\n",
      "Val loss = -0.8555\n",
      "\n",
      "--------Epoch 231--------\n",
      "Val loss = -0.8544\n",
      "\n",
      "--------Epoch 232--------\n",
      "Val loss = -0.8501\n",
      "\n",
      "--------Epoch 233--------\n",
      "Val loss = -0.8407\n",
      "\n",
      "--------Epoch 234--------\n",
      "Val loss = -0.8478\n",
      "\n",
      "--------Epoch 235--------\n",
      "Val loss = -0.8454\n",
      "\n",
      "--------Epoch 236--------\n",
      "Val loss = -0.8515\n",
      "\n",
      "--------Epoch 237--------\n",
      "Val loss = -0.8553\n",
      "\n",
      "--------Epoch 238--------\n",
      "Val loss = -0.8544\n",
      "\n",
      "--------Epoch 239--------\n",
      "Val loss = -0.8454\n",
      "\n",
      "--------Epoch 240--------\n",
      "Val loss = -0.8508\n",
      "\n",
      "--------Epoch 241--------\n",
      "Val loss = -0.8439\n",
      "\n",
      "--------Epoch 242--------\n",
      "Val loss = -0.8492\n",
      "\n",
      "--------Epoch 243--------\n",
      "Val loss = -0.8458\n",
      "\n",
      "--------Epoch 244--------\n",
      "Val loss = -0.8493\n",
      "\n",
      "--------Epoch 245--------\n",
      "Val loss = -0.8416\n",
      "\n",
      "--------Epoch 246--------\n",
      "Val loss = -0.8374\n",
      "\n",
      "--------Epoch 247--------\n",
      "Val loss = -0.8490\n",
      "\n",
      "--------Epoch 248--------\n",
      "Val loss = -0.8491\n",
      "\n",
      "--------Epoch 249--------\n",
      "Val loss = -0.8538\n",
      "\n",
      "--------Epoch 250--------\n",
      "Val loss = -0.8516\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Set up selector\n",
    "selector = nn.Sequential(\n",
    "    nn.Linear(d_in * 2, 512),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(512, 512),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(512, d_in))\n",
    "selector_layer = ConcreteSelector()\n",
    "gafs = GreedyAdaptiveFS(selector, deepcopy(model), mask_layer, selector_layer).to(device)\n",
    "\n",
    "# Initialize the selector's Linear layers at indices 0 and 2 with the weights of the\n",
    "# same-index layers of gafs.model (presumably the deepcopy of `model` passed above).\n",
    "# Uses no_grad + copy_ rather than mutating `.data`, so bypassing autograd is explicit.\n",
    "# NOTE(review): only weights are copied; biases keep their default init. Copying `.bias`\n",
    "# too would be a fuller warm start, but would change the results recorded for this cell.\n",
    "with torch.no_grad():\n",
    "    for layer_idx in (0, 2):\n",
    "        selector[layer_idx].weight.copy_(gafs.model[layer_idx].weight)\n",
    "\n",
    "# Train\n",
    "gafs.fit(train_dataset,\n",
    "         val_dataset,\n",
    "         mbsize=128,\n",
    "         lr=2e-4,\n",
    "         nepochs=250,\n",
    "         max_features=max_features,\n",
    "         loss_fn=nn.CrossEntropyLoss(),\n",
    "         val_loss_fn=NegAccuracy(),\n",
    "         argmax=False,\n",
    "         no_repeats=False,\n",
    "         start_temp=10.0,\n",
    "         end_temp=0.01)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "neutral-grass",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test acc = 86.24\n"
     ]
    }
   ],
   "source": [
    "# Evaluate on the test split, passing the same max_features used for fit above;\n",
    "# the trailing 1024 is presumably the evaluation batch size (TODO confirm).\n",
    "test_acc = gafs.evaluate(test_dataset, max_features, Accuracy(), 1024)\n",
    "print('Test acc = {:.2f}'.format(100 * test_acc))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ethical-terry",
   "metadata": {},
   "source": [
    "# Weight tying"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "identical-single",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--------Epoch 1--------\n",
      "Val loss = -0.6547\n",
      "\n",
      "--------Epoch 2--------\n",
      "Val loss = -0.6881\n",
      "\n",
      "--------Epoch 3--------\n",
      "Val loss = -0.6917\n",
      "\n",
      "--------Epoch 4--------\n",
      "Val loss = -0.7068\n",
      "\n",
      "--------Epoch 5--------\n",
      "Val loss = -0.7306\n",
      "\n",
      "--------Epoch 6--------\n",
      "Val loss = -0.7141\n",
      "\n",
      "--------Epoch 7--------\n",
      "Val loss = -0.7199\n",
      "\n",
      "--------Epoch 8--------\n",
      "Val loss = -0.7319\n",
      "\n",
      "--------Epoch 9--------\n",
      "Val loss = -0.7081\n",
      "\n",
      "--------Epoch 10--------\n",
      "Val loss = -0.7211\n",
      "\n",
      "--------Epoch 11--------\n",
      "Val loss = -0.7150\n",
      "\n",
      "--------Epoch 12--------\n",
      "Val loss = -0.6900\n",
      "\n",
      "--------Epoch 13--------\n",
      "Val loss = -0.7097\n",
      "\n",
      "--------Epoch 14--------\n",
      "Val loss = -0.6998\n",
      "\n",
      "--------Epoch 15--------\n",
      "Val loss = -0.6934\n",
      "\n",
      "--------Epoch 16--------\n",
      "Val loss = -0.6871\n",
      "\n",
      "--------Epoch 17--------\n",
      "Val loss = -0.6783\n",
      "\n",
      "--------Epoch 18--------\n",
      "Val loss = -0.6943\n",
      "\n",
      "--------Epoch 19--------\n",
      "Val loss = -0.6921\n",
      "\n",
      "--------Epoch 20--------\n",
      "Val loss = -0.6914\n",
      "\n",
      "--------Epoch 21--------\n",
      "Val loss = -0.6796\n",
      "\n",
      "--------Epoch 22--------\n",
      "Val loss = -0.6955\n",
      "\n",
      "--------Epoch 23--------\n",
      "Val loss = -0.6881\n",
      "\n",
      "--------Epoch 24--------\n",
      "Val loss = -0.6697\n",
      "\n",
      "--------Epoch 25--------\n",
      "Val loss = -0.6525\n",
      "\n",
      "--------Epoch 26--------\n",
      "Val loss = -0.6739\n",
      "\n",
      "--------Epoch 27--------\n",
      "Val loss = -0.6691\n",
      "\n",
      "--------Epoch 28--------\n",
      "Val loss = -0.6544\n",
      "\n",
      "--------Epoch 29--------\n",
      "Val loss = -0.6675\n",
      "\n",
      "--------Epoch 30--------\n",
      "Val loss = -0.6627\n",
      "\n",
      "--------Epoch 31--------\n",
      "Val loss = -0.6641\n",
      "\n",
      "--------Epoch 32--------\n",
      "Val loss = -0.6562\n",
      "\n",
      "--------Epoch 33--------\n",
      "Val loss = -0.6496\n",
      "\n",
      "--------Epoch 34--------\n",
      "Val loss = -0.6513\n",
      "\n",
      "--------Epoch 35--------\n",
      "Val loss = -0.6505\n",
      "\n",
      "--------Epoch 36--------\n",
      "Val loss = -0.6579\n",
      "\n",
      "--------Epoch 37--------\n",
      "Val loss = -0.6549\n",
      "\n",
      "--------Epoch 38--------\n",
      "Val loss = -0.6406\n",
      "\n",
      "--------Epoch 39--------\n",
      "Val loss = -0.6578\n",
      "\n",
      "--------Epoch 40--------\n",
      "Val loss = -0.6546\n",
      "\n",
      "--------Epoch 41--------\n",
      "Val loss = -0.6409\n",
      "\n",
      "--------Epoch 42--------\n",
      "Val loss = -0.6438\n",
      "\n",
      "--------Epoch 43--------\n",
      "Val loss = -0.6480\n",
      "\n",
      "--------Epoch 44--------\n",
      "Val loss = -0.6581\n",
      "\n",
      "--------Epoch 45--------\n",
      "Val loss = -0.6600\n",
      "\n",
      "--------Epoch 46--------\n",
      "Val loss = -0.6531\n",
      "\n",
      "--------Epoch 47--------\n",
      "Val loss = -0.6505\n",
      "\n",
      "--------Epoch 48--------\n",
      "Val loss = -0.6485\n",
      "\n",
      "--------Epoch 49--------\n",
      "Val loss = -0.6341\n",
      "\n",
      "--------Epoch 50--------\n",
      "Val loss = -0.6474\n",
      "\n",
      "--------Epoch 51--------\n",
      "Val loss = -0.6503\n",
      "\n",
      "--------Epoch 52--------\n",
      "Val loss = -0.6510\n",
      "\n",
      "--------Epoch 53--------\n",
      "Val loss = -0.6468\n",
      "\n",
      "--------Epoch 54--------\n",
      "Val loss = -0.6517\n",
      "\n",
      "--------Epoch 55--------\n",
      "Val loss = -0.6487\n",
      "\n",
      "--------Epoch 56--------\n",
      "Val loss = -0.6409\n",
      "\n",
      "--------Epoch 57--------\n",
      "Val loss = -0.6442\n",
      "\n",
      "--------Epoch 58--------\n",
      "Val loss = -0.6443\n",
      "\n",
      "--------Epoch 59--------\n",
      "Val loss = -0.6457\n",
      "\n",
      "--------Epoch 60--------\n",
      "Val loss = -0.6591\n",
      "\n",
      "--------Epoch 61--------\n",
      "Val loss = -0.6560\n",
      "\n",
      "--------Epoch 62--------\n",
      "Val loss = -0.6536\n",
      "\n",
      "--------Epoch 63--------\n",
      "Val loss = -0.6544\n",
      "\n",
      "--------Epoch 64--------\n",
      "Val loss = -0.6492\n",
      "\n",
      "--------Epoch 65--------\n",
      "Val loss = -0.6577\n",
      "\n",
      "--------Epoch 66--------\n",
      "Val loss = -0.6674\n",
      "\n",
      "--------Epoch 67--------\n",
      "Val loss = -0.6543\n",
      "\n",
      "--------Epoch 68--------\n",
      "Val loss = -0.6717\n",
      "\n",
      "--------Epoch 69--------\n",
      "Val loss = -0.6720\n",
      "\n",
      "--------Epoch 70--------\n",
      "Val loss = -0.6678\n",
      "\n",
      "--------Epoch 71--------\n",
      "Val loss = -0.6561\n",
      "\n",
      "--------Epoch 72--------\n",
      "Val loss = -0.6724\n",
      "\n",
      "--------Epoch 73--------\n",
      "Val loss = -0.6794\n",
      "\n",
      "--------Epoch 74--------\n",
      "Val loss = -0.6726\n",
      "\n",
      "--------Epoch 75--------\n",
      "Val loss = -0.6755\n",
      "\n",
      "--------Epoch 76--------\n",
      "Val loss = -0.6821\n",
      "\n",
      "--------Epoch 77--------\n",
      "Val loss = -0.6797\n",
      "\n",
      "--------Epoch 78--------\n",
      "Val loss = -0.6906\n",
      "\n",
      "--------Epoch 79--------\n",
      "Val loss = -0.6915\n",
      "\n",
      "--------Epoch 80--------\n",
      "Val loss = -0.6942\n",
      "\n",
      "--------Epoch 81--------\n",
      "Val loss = -0.6971\n",
      "\n",
      "--------Epoch 82--------\n",
      "Val loss = -0.6958\n",
      "\n",
      "--------Epoch 83--------\n",
      "Val loss = -0.7090\n",
      "\n",
      "--------Epoch 84--------\n",
      "Val loss = -0.7052\n",
      "\n",
      "--------Epoch 85--------\n",
      "Val loss = -0.7070\n",
      "\n",
      "--------Epoch 86--------\n",
      "Val loss = -0.7183\n",
      "\n",
      "--------Epoch 87--------\n",
      "Val loss = -0.7074\n",
      "\n",
      "--------Epoch 88--------\n",
      "Val loss = -0.7277\n",
      "\n",
      "--------Epoch 89--------\n",
      "Val loss = -0.7157\n",
      "\n",
      "--------Epoch 90--------\n",
      "Val loss = -0.7275\n",
      "\n",
      "--------Epoch 91--------\n",
      "Val loss = -0.7294\n",
      "\n",
      "--------Epoch 92--------\n",
      "Val loss = -0.7369\n",
      "\n",
      "--------Epoch 93--------\n",
      "Val loss = -0.7281\n",
      "\n",
      "--------Epoch 94--------\n",
      "Val loss = -0.7435\n",
      "\n",
      "--------Epoch 95--------\n",
      "Val loss = -0.7430\n",
      "\n",
      "--------Epoch 96--------\n",
      "Val loss = -0.7569\n",
      "\n",
      "--------Epoch 97--------\n",
      "Val loss = -0.7565\n",
      "\n",
      "--------Epoch 98--------\n",
      "Val loss = -0.7536\n",
      "\n",
      "--------Epoch 99--------\n",
      "Val loss = -0.7597\n",
      "\n",
      "--------Epoch 100--------\n",
      "Val loss = -0.7672\n",
      "\n",
      "--------Epoch 101--------\n",
      "Val loss = -0.7643\n",
      "\n",
      "--------Epoch 102--------\n",
      "Val loss = -0.7703\n",
      "\n",
      "--------Epoch 103--------\n",
      "Val loss = -0.7697\n",
      "\n",
      "--------Epoch 104--------\n",
      "Val loss = -0.7742\n",
      "\n",
      "--------Epoch 105--------\n",
      "Val loss = -0.7771\n",
      "\n",
      "--------Epoch 106--------\n",
      "Val loss = -0.7816\n",
      "\n",
      "--------Epoch 107--------\n",
      "Val loss = -0.7845\n",
      "\n",
      "--------Epoch 108--------\n",
      "Val loss = -0.7893\n",
      "\n",
      "--------Epoch 109--------\n",
      "Val loss = -0.7914\n",
      "\n",
      "--------Epoch 110--------\n",
      "Val loss = -0.7906\n",
      "\n",
      "--------Epoch 111--------\n",
      "Val loss = -0.7947\n",
      "\n",
      "--------Epoch 112--------\n",
      "Val loss = -0.7961\n",
      "\n",
      "--------Epoch 113--------\n",
      "Val loss = -0.7988\n",
      "\n",
      "--------Epoch 114--------\n",
      "Val loss = -0.7997\n",
      "\n",
      "--------Epoch 115--------\n",
      "Val loss = -0.8039\n",
      "\n",
      "--------Epoch 116--------\n",
      "Val loss = -0.8079\n",
      "\n",
      "--------Epoch 117--------\n",
      "Val loss = -0.8055\n",
      "\n",
      "--------Epoch 118--------\n",
      "Val loss = -0.8101\n",
      "\n",
      "--------Epoch 119--------\n",
      "Val loss = -0.8116\n",
      "\n",
      "--------Epoch 120--------\n",
      "Val loss = -0.8132\n",
      "\n",
      "--------Epoch 121--------\n",
      "Val loss = -0.8204\n",
      "\n",
      "--------Epoch 122--------\n",
      "Val loss = -0.8223\n",
      "\n",
      "--------Epoch 123--------\n",
      "Val loss = -0.8183\n",
      "\n",
      "--------Epoch 124--------\n",
      "Val loss = -0.8129\n",
      "\n",
      "--------Epoch 125--------\n",
      "Val loss = -0.8205\n",
      "\n",
      "--------Epoch 126--------\n",
      "Val loss = -0.8217\n",
      "\n",
      "--------Epoch 127--------\n",
      "Val loss = -0.8228\n",
      "\n",
      "--------Epoch 128--------\n",
      "Val loss = -0.8243\n",
      "\n",
      "--------Epoch 129--------\n",
      "Val loss = -0.8267\n",
      "\n",
      "--------Epoch 130--------\n",
      "Val loss = -0.8318\n",
      "\n",
      "--------Epoch 131--------\n",
      "Val loss = -0.8302\n",
      "\n",
      "--------Epoch 132--------\n",
      "Val loss = -0.8335\n",
      "\n",
      "--------Epoch 133--------\n",
      "Val loss = -0.8308\n",
      "\n",
      "--------Epoch 134--------\n",
      "Val loss = -0.8303\n",
      "\n",
      "--------Epoch 135--------\n",
      "Val loss = -0.8353\n",
      "\n",
      "--------Epoch 136--------\n",
      "Val loss = -0.8357\n",
      "\n",
      "--------Epoch 137--------\n",
      "Val loss = -0.8368\n",
      "\n",
      "--------Epoch 138--------\n",
      "Val loss = -0.8323\n",
      "\n",
      "--------Epoch 139--------\n",
      "Val loss = -0.8328\n",
      "\n",
      "--------Epoch 140--------\n",
      "Val loss = -0.8359\n",
      "\n",
      "--------Epoch 141--------\n",
      "Val loss = -0.8303\n",
      "\n",
      "--------Epoch 142--------\n",
      "Val loss = -0.8372\n",
      "\n",
      "--------Epoch 143--------\n",
      "Val loss = -0.8411\n",
      "\n",
      "--------Epoch 144--------\n",
      "Val loss = -0.8330\n",
      "\n",
      "--------Epoch 145--------\n",
      "Val loss = -0.8389\n",
      "\n",
      "--------Epoch 146--------\n",
      "Val loss = -0.8418\n",
      "\n",
      "--------Epoch 147--------\n",
      "Val loss = -0.8422\n",
      "\n",
      "--------Epoch 148--------\n",
      "Val loss = -0.8326\n",
      "\n",
      "--------Epoch 149--------\n",
      "Val loss = -0.8419\n",
      "\n",
      "--------Epoch 150--------\n",
      "Val loss = -0.8396\n",
      "\n",
      "--------Epoch 151--------\n",
      "Val loss = -0.8395\n",
      "\n",
      "--------Epoch 152--------\n",
      "Val loss = -0.8425\n",
      "\n",
      "--------Epoch 153--------\n",
      "Val loss = -0.8441\n",
      "\n",
      "--------Epoch 154--------\n",
      "Val loss = -0.8396\n",
      "\n",
      "--------Epoch 155--------\n",
      "Val loss = -0.8445\n",
      "\n",
      "--------Epoch 156--------\n",
      "Val loss = -0.8475\n",
      "\n",
      "--------Epoch 157--------\n",
      "Val loss = -0.8455\n",
      "\n",
      "--------Epoch 158--------\n",
      "Val loss = -0.8470\n",
      "\n",
      "--------Epoch 159--------\n",
      "Val loss = -0.8444\n",
      "\n",
      "--------Epoch 160--------\n",
      "Val loss = -0.8445\n",
      "\n",
      "--------Epoch 161--------\n",
      "Val loss = -0.8466\n",
      "\n",
      "--------Epoch 162--------\n",
      "Val loss = -0.8465\n",
      "\n",
      "--------Epoch 163--------\n",
      "Val loss = -0.8482\n",
      "\n",
      "--------Epoch 164--------\n",
      "Val loss = -0.8442\n",
      "\n",
      "--------Epoch 165--------\n",
      "Val loss = -0.8434\n",
      "\n",
      "--------Epoch 166--------\n",
      "Val loss = -0.8426\n",
      "\n",
      "--------Epoch 167--------\n",
      "Val loss = -0.8530\n",
      "\n",
      "--------Epoch 168--------\n",
      "Val loss = -0.8432\n",
      "\n",
      "--------Epoch 169--------\n",
      "Val loss = -0.8496\n",
      "\n",
      "--------Epoch 170--------\n",
      "Val loss = -0.8474\n",
      "\n",
      "--------Epoch 171--------\n",
      "Val loss = -0.8513\n",
      "\n",
      "--------Epoch 172--------\n",
      "Val loss = -0.8540\n",
      "\n",
      "--------Epoch 173--------\n",
      "Val loss = -0.8503\n",
      "\n",
      "--------Epoch 174--------\n",
      "Val loss = -0.8471\n",
      "\n",
      "--------Epoch 175--------\n",
      "Val loss = -0.8485\n",
      "\n",
      "--------Epoch 176--------\n",
      "Val loss = -0.8487\n",
      "\n",
      "--------Epoch 177--------\n",
      "Val loss = -0.8484\n",
      "\n",
      "--------Epoch 178--------\n",
      "Val loss = -0.8491\n",
      "\n",
      "--------Epoch 179--------\n",
      "Val loss = -0.8506\n",
      "\n",
      "--------Epoch 180--------\n",
      "Val loss = -0.8506\n",
      "\n",
      "--------Epoch 181--------\n",
      "Val loss = -0.8478\n",
      "\n",
      "--------Epoch 182--------\n",
      "Val loss = -0.8482\n",
      "\n",
      "--------Epoch 183--------\n",
      "Val loss = -0.8505\n",
      "\n",
      "--------Epoch 184--------\n",
      "Val loss = -0.8480\n",
      "\n",
      "--------Epoch 185--------\n",
      "Val loss = -0.8545\n",
      "\n",
      "--------Epoch 186--------\n",
      "Val loss = -0.8473\n",
      "\n",
      "--------Epoch 187--------\n",
      "Val loss = -0.8459\n",
      "\n",
      "--------Epoch 188--------\n",
      "Val loss = -0.8465\n",
      "\n",
      "--------Epoch 189--------\n",
      "Val loss = -0.8506\n",
      "\n",
      "--------Epoch 190--------\n",
      "Val loss = -0.8577\n",
      "\n",
      "--------Epoch 191--------\n",
      "Val loss = -0.8519\n",
      "\n",
      "--------Epoch 192--------\n",
      "Val loss = -0.8458\n",
      "\n",
      "--------Epoch 193--------\n",
      "Val loss = -0.8515\n",
      "\n",
      "--------Epoch 194--------\n",
      "Val loss = -0.8505\n",
      "\n",
      "--------Epoch 195--------\n",
      "Val loss = -0.8518\n",
      "\n",
      "--------Epoch 196--------\n",
      "Val loss = -0.8504\n",
      "\n",
      "--------Epoch 197--------\n",
      "Val loss = -0.8540\n",
      "\n",
      "--------Epoch 198--------\n",
      "Val loss = -0.8496\n",
      "\n",
      "--------Epoch 199--------\n",
      "Val loss = -0.8539\n",
      "\n",
      "--------Epoch 200--------\n",
      "Val loss = -0.8505\n",
      "\n",
      "--------Epoch 201--------\n",
      "Val loss = -0.8528\n",
      "\n",
      "--------Epoch 202--------\n",
      "Val loss = -0.8491\n",
      "\n",
      "--------Epoch 203--------\n",
      "Val loss = -0.8537\n",
      "\n",
      "--------Epoch 204--------\n",
      "Val loss = -0.8522\n",
      "\n",
      "--------Epoch 205--------\n",
      "Val loss = -0.8462\n",
      "\n",
      "--------Epoch 206--------\n",
      "Val loss = -0.8479\n",
      "\n",
      "--------Epoch 207--------\n",
      "Val loss = -0.8521\n",
      "\n",
      "--------Epoch 208--------\n",
      "Val loss = -0.8501\n",
      "\n",
      "--------Epoch 209--------\n",
      "Val loss = -0.8487\n",
      "\n",
      "--------Epoch 210--------\n",
      "Val loss = -0.8459\n",
      "\n",
      "--------Epoch 211--------\n",
      "Val loss = -0.8430\n",
      "\n",
      "--------Epoch 212--------\n",
      "Val loss = -0.8502\n",
      "\n",
      "--------Epoch 213--------\n",
      "Val loss = -0.8504\n",
      "\n",
      "--------Epoch 214--------\n",
      "Val loss = -0.8476\n",
      "\n",
      "--------Epoch 215--------\n",
      "Val loss = -0.8481\n",
      "\n",
      "--------Epoch 216--------\n",
      "Val loss = -0.8516\n",
      "\n",
      "--------Epoch 217--------\n",
      "Val loss = -0.8488\n",
      "\n",
      "--------Epoch 218--------\n",
      "Val loss = -0.8493\n",
      "\n",
      "--------Epoch 219--------\n",
      "Val loss = -0.8490\n",
      "\n",
      "--------Epoch 220--------\n",
      "Val loss = -0.8535\n",
      "\n",
      "--------Epoch 221--------\n",
      "Val loss = -0.8480\n",
      "\n",
      "--------Epoch 222--------\n",
      "Val loss = -0.8474\n",
      "\n",
      "--------Epoch 223--------\n",
      "Val loss = -0.8478\n",
      "\n",
      "--------Epoch 224--------\n",
      "Val loss = -0.8429\n",
      "\n",
      "--------Epoch 225--------\n",
      "Val loss = -0.8522\n",
      "\n",
      "--------Epoch 226--------\n",
      "Val loss = -0.8465\n",
      "\n",
      "--------Epoch 227--------\n",
      "Val loss = -0.8446\n",
      "\n",
      "--------Epoch 228--------\n",
      "Val loss = -0.8507\n",
      "\n",
      "--------Epoch 229--------\n",
      "Val loss = -0.8392\n",
      "\n",
      "--------Epoch 230--------\n",
      "Val loss = -0.8470\n",
      "\n",
      "--------Epoch 231--------\n",
      "Val loss = -0.8459\n",
      "\n",
      "--------Epoch 232--------\n",
      "Val loss = -0.8517\n",
      "\n",
      "--------Epoch 233--------\n",
      "Val loss = -0.8493\n",
      "\n",
      "--------Epoch 234--------\n",
      "Val loss = -0.8510\n",
      "\n",
      "--------Epoch 235--------\n",
      "Val loss = -0.8500\n",
      "\n",
      "--------Epoch 236--------\n",
      "Val loss = -0.8457\n",
      "\n",
      "--------Epoch 237--------\n",
      "Val loss = -0.8488\n",
      "\n",
      "--------Epoch 238--------\n",
      "Val loss = -0.8545\n",
      "\n",
      "--------Epoch 239--------\n",
      "Val loss = -0.8550\n",
      "\n",
      "--------Epoch 240--------\n",
      "Val loss = -0.8499\n",
      "\n",
      "--------Epoch 241--------\n",
      "Val loss = -0.8474\n",
      "\n",
      "--------Epoch 242--------\n",
      "Val loss = -0.8471\n",
      "\n",
      "--------Epoch 243--------\n",
      "Val loss = -0.8499\n",
      "\n",
      "--------Epoch 244--------\n",
      "Val loss = -0.8492\n",
      "\n",
      "--------Epoch 245--------\n",
      "Val loss = -0.8511\n",
      "\n",
      "--------Epoch 246--------\n",
      "Val loss = -0.8492\n",
      "\n",
      "--------Epoch 247--------\n",
      "Val loss = -0.8516\n",
      "\n",
      "--------Epoch 248--------\n",
      "Val loss = -0.8538\n",
      "\n",
      "--------Epoch 249--------\n",
      "Val loss = -0.8477\n",
      "\n",
      "--------Epoch 250--------\n",
      "Val loss = -0.8467\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Set up selector\n",
    "selector = nn.Sequential(\n",
    "    nn.Linear(d_in * 2, 512),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(512, 512),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(512, d_in))\n",
    "selector_layer = ConcreteSelector()\n",
    "gafs = GreedyAdaptiveFS(selector, deepcopy(model), mask_layer, selector_layer).to(device)\n",
    "\n",
    "# Tie weights\n",
    "selector[0].weight = nn.Parameter(gafs.model[0].weight)\n",
    "selector[2].weight = nn.Parameter(gafs.model[2].weight)\n",
    "\n",
    "# Train\n",
    "gafs.fit(train_dataset,\n",
    "         val_dataset,\n",
    "         mbsize=128,\n",
    "         lr=2e-4,\n",
    "         nepochs=250,\n",
    "         max_features=max_features,\n",
    "         loss_fn=nn.CrossEntropyLoss(),\n",
    "         val_loss_fn=NegAccuracy(),\n",
    "         argmax=False,\n",
    "         no_repeats=False,\n",
    "         start_temp=10.0,\n",
    "         end_temp=0.01)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "known-chapter",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test acc = 87.02\n"
     ]
    }
   ],
   "source": [
    "test_acc = gafs.evaluate(test_dataset, max_features, Accuracy(), 1024)\n",
    "print(f'Test acc = {100*test_acc:.2f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sharp-layout",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
