{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "expanded-cardiff",
   "metadata": {},
   "source": [
    "# Concrete with Gaussian noise\n",
    "\n",
    "- Someone in the lab suggested that the Concrete distribution might not be the only distribution that could yield good results for this algorithm. Rather than adding Gumbel noise to the learned logits and passing the sum through a softmax, why not add Gaussian noise instead? I had no strong prior belief about whether this would work, but I guessed that it would.\n",
    "- It did not work. Not sure why."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "preliminary-passenger",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "import torch.nn as nn\n",
    "import matplotlib.pyplot as plt\n",
    "from torchvision import transforms\n",
    "from torchvision.datasets import MNIST\n",
    "from copy import deepcopy\n",
    "from utils import *\n",
    "from models import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "impressive-humanity",
   "metadata": {},
   "outputs": [],
   "source": [
    "# GPU 7 is the card used for the original run; fall back to the first GPU or the CPU\n",
    "# so that Restart & Run All also works on machines with fewer devices.\n",
    "gpu_index = 7\n",
    "if torch.cuda.device_count() > gpu_index:\n",
    "    device = torch.device('cuda', gpu_index)\n",
    "elif torch.cuda.is_available():\n",
    "    device = torch.device('cuda', 0)\n",
    "else:\n",
    "    device = torch.device('cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "furnished-offset",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load data\n",
    "class Flatten(object):\n",
    "    '''Callable transform that collapses its input tensor to one dimension.'''\n",
    "    def __call__(self, pic):\n",
    "        flat = torch.flatten(pic)\n",
    "        return flat\n",
    "\n",
    "# Same preprocessing for every split: convert to tensor, then flatten\n",
    "flat_transform = transforms.Compose([transforms.ToTensor(), Flatten()])\n",
    "\n",
    "mnist_dataset = MNIST('/tmp/mnist/', download=True, train=True, transform=flat_transform)\n",
    "images = mnist_dataset.data\n",
    "targets = mnist_dataset.targets\n",
    "\n",
    "# Hold out 10000 randomly chosen training images for validation (fixed seed)\n",
    "np.random.seed(0)\n",
    "val_inds = np.random.choice(len(images), size=10000, replace=False)\n",
    "val_inds.sort()\n",
    "all_inds = np.arange(len(images))\n",
    "train_inds = np.setdiff1d(all_inds, val_inds)\n",
    "\n",
    "# Training / validation subsets of the official training split\n",
    "train_dataset = torch.utils.data.Subset(mnist_dataset, train_inds)\n",
    "val_dataset = torch.utils.data.Subset(mnist_dataset, val_inds)\n",
    "\n",
    "# Official test split\n",
    "test_dataset = MNIST('/tmp/mnist/', download=True, train=False, transform=flat_transform)\n",
    "\n",
    "# Input/output dimensions: flattened pixels in, class scores out\n",
    "d_in, d_out = 784, 10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "russian-representation",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of input features (pixels) to select; passed to the selector when it is constructed\n",
    "max_features = 10"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "periodic-variety",
   "metadata": {},
   "source": [
    "# Global feature selection (FS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "extended-courtesy",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--------Epoch 1--------\n",
      "Val loss = -0.2946\n",
      "\n",
      "--------Epoch 2--------\n",
      "Val loss = -0.3245\n",
      "\n",
      "--------Epoch 3--------\n",
      "Val loss = -0.3048\n",
      "\n",
      "--------Epoch 4--------\n",
      "Val loss = -0.3243\n",
      "\n",
      "--------Epoch 5--------\n",
      "Val loss = -0.3410\n",
      "\n",
      "--------Epoch 6--------\n",
      "Val loss = -0.3410\n",
      "\n",
      "--------Epoch 7--------\n",
      "Val loss = -0.3505\n",
      "\n",
      "--------Epoch 8--------\n",
      "Val loss = -0.3491\n",
      "\n",
      "--------Epoch 9--------\n",
      "Val loss = -0.3482\n",
      "\n",
      "--------Epoch 10--------\n",
      "Val loss = -0.3586\n",
      "\n",
      "--------Epoch 11--------\n",
      "Val loss = -0.3588\n",
      "\n",
      "--------Epoch 12--------\n",
      "Val loss = -0.3709\n",
      "\n",
      "--------Epoch 13--------\n",
      "Val loss = -0.3702\n",
      "\n",
      "--------Epoch 14--------\n",
      "Val loss = -0.3684\n",
      "\n",
      "--------Epoch 15--------\n",
      "Val loss = -0.3778\n",
      "\n",
      "--------Epoch 16--------\n",
      "Val loss = -0.3829\n",
      "\n",
      "--------Epoch 17--------\n",
      "Val loss = -0.3896\n",
      "\n",
      "--------Epoch 18--------\n",
      "Val loss = -0.3897\n",
      "\n",
      "--------Epoch 19--------\n",
      "Val loss = -0.4051\n",
      "\n",
      "--------Epoch 20--------\n",
      "Val loss = -0.3979\n",
      "\n",
      "--------Epoch 21--------\n",
      "Val loss = -0.3957\n",
      "\n",
      "--------Epoch 22--------\n",
      "Val loss = -0.4163\n",
      "\n",
      "--------Epoch 23--------\n",
      "Val loss = -0.4121\n",
      "\n",
      "--------Epoch 24--------\n",
      "Val loss = -0.4116\n",
      "\n",
      "--------Epoch 25--------\n",
      "Val loss = -0.4149\n",
      "\n",
      "--------Epoch 26--------\n",
      "Val loss = -0.4196\n",
      "\n",
      "--------Epoch 27--------\n",
      "Val loss = -0.4276\n",
      "\n",
      "--------Epoch 28--------\n",
      "Val loss = -0.4266\n",
      "\n",
      "--------Epoch 29--------\n",
      "Val loss = -0.4300\n",
      "\n",
      "--------Epoch 30--------\n",
      "Val loss = -0.4374\n",
      "\n",
      "--------Epoch 31--------\n",
      "Val loss = -0.4337\n",
      "\n",
      "--------Epoch 32--------\n",
      "Val loss = -0.4396\n",
      "\n",
      "--------Epoch 33--------\n",
      "Val loss = -0.4560\n",
      "\n",
      "--------Epoch 34--------\n",
      "Val loss = -0.4518\n",
      "\n",
      "--------Epoch 35--------\n",
      "Val loss = -0.4559\n",
      "\n",
      "--------Epoch 36--------\n",
      "Val loss = -0.4616\n",
      "\n",
      "--------Epoch 37--------\n",
      "Val loss = -0.4642\n",
      "\n",
      "--------Epoch 38--------\n",
      "Val loss = -0.4673\n",
      "\n",
      "--------Epoch 39--------\n",
      "Val loss = -0.4805\n",
      "\n",
      "--------Epoch 40--------\n",
      "Val loss = -0.4668\n",
      "\n",
      "--------Epoch 41--------\n",
      "Val loss = -0.4836\n",
      "\n",
      "--------Epoch 42--------\n",
      "Val loss = -0.4743\n",
      "\n",
      "--------Epoch 43--------\n",
      "Val loss = -0.4947\n",
      "\n",
      "--------Epoch 44--------\n",
      "Val loss = -0.4810\n",
      "\n",
      "--------Epoch 45--------\n",
      "Val loss = -0.4960\n",
      "\n",
      "--------Epoch 46--------\n",
      "Val loss = -0.4988\n",
      "\n",
      "--------Epoch 47--------\n",
      "Val loss = -0.5057\n",
      "\n",
      "--------Epoch 48--------\n",
      "Val loss = -0.5028\n",
      "\n",
      "--------Epoch 49--------\n",
      "Val loss = -0.5049\n",
      "\n",
      "--------Epoch 50--------\n",
      "Val loss = -0.5146\n",
      "\n",
      "--------Epoch 51--------\n",
      "Val loss = -0.5028\n",
      "\n",
      "--------Epoch 52--------\n",
      "Val loss = -0.5148\n",
      "\n",
      "--------Epoch 53--------\n",
      "Val loss = -0.5191\n",
      "\n",
      "--------Epoch 54--------\n",
      "Val loss = -0.5216\n",
      "\n",
      "--------Epoch 55--------\n",
      "Val loss = -0.5249\n",
      "\n",
      "--------Epoch 56--------\n",
      "Val loss = -0.5252\n",
      "\n",
      "--------Epoch 57--------\n",
      "Val loss = -0.5373\n",
      "\n",
      "--------Epoch 58--------\n",
      "Val loss = -0.5338\n",
      "\n",
      "--------Epoch 59--------\n",
      "Val loss = -0.5370\n",
      "\n",
      "--------Epoch 60--------\n",
      "Val loss = -0.5419\n",
      "\n",
      "--------Epoch 61--------\n",
      "Val loss = -0.5482\n",
      "\n",
      "--------Epoch 62--------\n",
      "Val loss = -0.5454\n",
      "\n",
      "--------Epoch 63--------\n",
      "Val loss = -0.5485\n",
      "\n",
      "--------Epoch 64--------\n",
      "Val loss = -0.5507\n",
      "\n",
      "--------Epoch 65--------\n",
      "Val loss = -0.5581\n",
      "\n",
      "--------Epoch 66--------\n",
      "Val loss = -0.5489\n",
      "\n",
      "--------Epoch 67--------\n",
      "Val loss = -0.5632\n",
      "\n",
      "--------Epoch 68--------\n",
      "Val loss = -0.5617\n",
      "\n",
      "--------Epoch 69--------\n",
      "Val loss = -0.5676\n",
      "\n",
      "--------Epoch 70--------\n",
      "Val loss = -0.5721\n",
      "\n",
      "--------Epoch 71--------\n",
      "Val loss = -0.5684\n",
      "\n",
      "--------Epoch 72--------\n",
      "Val loss = -0.5800\n",
      "\n",
      "--------Epoch 73--------\n",
      "Val loss = -0.5853\n",
      "\n",
      "--------Epoch 74--------\n",
      "Val loss = -0.5765\n",
      "\n",
      "--------Epoch 75--------\n",
      "Val loss = -0.5835\n",
      "\n",
      "--------Epoch 76--------\n",
      "Val loss = -0.5795\n",
      "\n",
      "--------Epoch 77--------\n",
      "Val loss = -0.5834\n",
      "\n",
      "--------Epoch 78--------\n",
      "Val loss = -0.5960\n",
      "\n",
      "--------Epoch 79--------\n",
      "Val loss = -0.5862\n",
      "\n",
      "--------Epoch 80--------\n",
      "Val loss = -0.6051\n",
      "\n",
      "--------Epoch 81--------\n",
      "Val loss = -0.5994\n",
      "\n",
      "--------Epoch 82--------\n",
      "Val loss = -0.6000\n",
      "\n",
      "--------Epoch 83--------\n",
      "Val loss = -0.6035\n",
      "\n",
      "--------Epoch 84--------\n",
      "Val loss = -0.6142\n",
      "\n",
      "--------Epoch 85--------\n",
      "Val loss = -0.6009\n",
      "\n",
      "--------Epoch 86--------\n",
      "Val loss = -0.6086\n",
      "\n",
      "--------Epoch 87--------\n",
      "Val loss = -0.6071\n",
      "\n",
      "--------Epoch 88--------\n",
      "Val loss = -0.6051\n",
      "\n",
      "--------Epoch 89--------\n",
      "Val loss = -0.6108\n",
      "\n",
      "--------Epoch 90--------\n",
      "Val loss = -0.6171\n",
      "\n",
      "--------Epoch 91--------\n",
      "Val loss = -0.6117\n",
      "\n",
      "--------Epoch 92--------\n",
      "Val loss = -0.6144\n",
      "\n",
      "--------Epoch 93--------\n",
      "Val loss = -0.6264\n",
      "\n",
      "--------Epoch 94--------\n",
      "Val loss = -0.6214\n",
      "\n",
      "--------Epoch 95--------\n",
      "Val loss = -0.6283\n",
      "\n",
      "--------Epoch 96--------\n",
      "Val loss = -0.6281\n",
      "\n",
      "--------Epoch 97--------\n",
      "Val loss = -0.6321\n",
      "\n",
      "--------Epoch 98--------\n",
      "Val loss = -0.6296\n",
      "\n",
      "--------Epoch 99--------\n",
      "Val loss = -0.6410\n",
      "\n",
      "--------Epoch 100--------\n",
      "Val loss = -0.6322\n",
      "\n",
      "--------Epoch 101--------\n",
      "Val loss = -0.6353\n",
      "\n",
      "--------Epoch 102--------\n",
      "Val loss = -0.6464\n",
      "\n",
      "--------Epoch 103--------\n",
      "Val loss = -0.6423\n",
      "\n",
      "--------Epoch 104--------\n",
      "Val loss = -0.6469\n",
      "\n",
      "--------Epoch 105--------\n",
      "Val loss = -0.6487\n",
      "\n",
      "--------Epoch 106--------\n",
      "Val loss = -0.6457\n",
      "\n",
      "--------Epoch 107--------\n",
      "Val loss = -0.6483\n",
      "\n",
      "--------Epoch 108--------\n",
      "Val loss = -0.6495\n",
      "\n",
      "--------Epoch 109--------\n",
      "Val loss = -0.6518\n",
      "\n",
      "--------Epoch 110--------\n",
      "Val loss = -0.6498\n",
      "\n",
      "--------Epoch 111--------\n",
      "Val loss = -0.6482\n",
      "\n",
      "--------Epoch 112--------\n",
      "Val loss = -0.6574\n",
      "\n",
      "--------Epoch 113--------\n",
      "Val loss = -0.6599\n",
      "\n",
      "--------Epoch 114--------\n",
      "Val loss = -0.6638\n",
      "\n",
      "--------Epoch 115--------\n",
      "Val loss = -0.6577\n",
      "\n",
      "--------Epoch 116--------\n",
      "Val loss = -0.6612\n",
      "\n",
      "--------Epoch 117--------\n",
      "Val loss = -0.6657\n",
      "\n",
      "--------Epoch 118--------\n",
      "Val loss = -0.6705\n",
      "\n",
      "--------Epoch 119--------\n",
      "Val loss = -0.6712\n",
      "\n",
      "--------Epoch 120--------\n",
      "Val loss = -0.6701\n",
      "\n",
      "--------Epoch 121--------\n",
      "Val loss = -0.6732\n",
      "\n",
      "--------Epoch 122--------\n",
      "Val loss = -0.6807\n",
      "\n",
      "--------Epoch 123--------\n",
      "Val loss = -0.6752\n",
      "\n",
      "--------Epoch 124--------\n",
      "Val loss = -0.6696\n",
      "\n",
      "--------Epoch 125--------\n",
      "Val loss = -0.6854\n",
      "\n",
      "--------Epoch 126--------\n",
      "Val loss = -0.6875\n",
      "\n",
      "--------Epoch 127--------\n",
      "Val loss = -0.6843\n",
      "\n",
      "--------Epoch 128--------\n",
      "Val loss = -0.6846\n",
      "\n",
      "--------Epoch 129--------\n",
      "Val loss = -0.6949\n",
      "\n",
      "--------Epoch 130--------\n",
      "Val loss = -0.6903\n",
      "\n",
      "--------Epoch 131--------\n",
      "Val loss = -0.6992\n",
      "\n",
      "--------Epoch 132--------\n",
      "Val loss = -0.6902\n",
      "\n",
      "--------Epoch 133--------\n",
      "Val loss = -0.6933\n",
      "\n",
      "--------Epoch 134--------\n",
      "Val loss = -0.6934\n",
      "\n",
      "--------Epoch 135--------\n",
      "Val loss = -0.6932\n",
      "\n",
      "--------Epoch 136--------\n",
      "Val loss = -0.6979\n",
      "\n",
      "--------Epoch 137--------\n",
      "Val loss = -0.7029\n",
      "\n",
      "--------Epoch 138--------\n",
      "Val loss = -0.6971\n",
      "\n",
      "--------Epoch 139--------\n",
      "Val loss = -0.6981\n",
      "\n",
      "--------Epoch 140--------\n",
      "Val loss = -0.7049\n",
      "\n",
      "--------Epoch 141--------\n",
      "Val loss = -0.6998\n",
      "\n",
      "--------Epoch 142--------\n",
      "Val loss = -0.7130\n",
      "\n",
      "--------Epoch 143--------\n",
      "Val loss = -0.7065\n",
      "\n",
      "--------Epoch 144--------\n",
      "Val loss = -0.7087\n",
      "\n",
      "--------Epoch 145--------\n",
      "Val loss = -0.7078\n",
      "\n",
      "--------Epoch 146--------\n",
      "Val loss = -0.7091\n",
      "\n",
      "--------Epoch 147--------\n",
      "Val loss = -0.7146\n",
      "\n",
      "--------Epoch 148--------\n",
      "Val loss = -0.7122\n",
      "\n",
      "--------Epoch 149--------\n",
      "Val loss = -0.7068\n",
      "\n",
      "--------Epoch 150--------\n",
      "Val loss = -0.7108\n",
      "\n",
      "--------Epoch 151--------\n",
      "Val loss = -0.7212\n",
      "\n",
      "--------Epoch 152--------\n",
      "Val loss = -0.7218\n",
      "\n",
      "--------Epoch 153--------\n",
      "Val loss = -0.7184\n",
      "\n",
      "--------Epoch 154--------\n",
      "Val loss = -0.7221\n",
      "\n",
      "--------Epoch 155--------\n",
      "Val loss = -0.7243\n",
      "\n",
      "--------Epoch 156--------\n",
      "Val loss = -0.7226\n",
      "\n",
      "--------Epoch 157--------\n",
      "Val loss = -0.7260\n",
      "\n",
      "--------Epoch 158--------\n",
      "Val loss = -0.7296\n",
      "\n",
      "--------Epoch 159--------\n",
      "Val loss = -0.7244\n",
      "\n",
      "--------Epoch 160--------\n",
      "Val loss = -0.7282\n",
      "\n",
      "--------Epoch 161--------\n",
      "Val loss = -0.7251\n",
      "\n",
      "--------Epoch 162--------\n",
      "Val loss = -0.7330\n",
      "\n",
      "--------Epoch 163--------\n",
      "Val loss = -0.7365\n",
      "\n",
      "--------Epoch 164--------\n",
      "Val loss = -0.7342\n",
      "\n",
      "--------Epoch 165--------\n",
      "Val loss = -0.7363\n",
      "\n",
      "--------Epoch 166--------\n",
      "Val loss = -0.7327\n",
      "\n",
      "--------Epoch 167--------\n",
      "Val loss = -0.7349\n",
      "\n",
      "--------Epoch 168--------\n",
      "Val loss = -0.7389\n",
      "\n",
      "--------Epoch 169--------\n",
      "Val loss = -0.7463\n",
      "\n",
      "--------Epoch 170--------\n",
      "Val loss = -0.7479\n",
      "\n",
      "--------Epoch 171--------\n",
      "Val loss = -0.7420\n",
      "\n",
      "--------Epoch 172--------\n",
      "Val loss = -0.7426\n",
      "\n",
      "--------Epoch 173--------\n",
      "Val loss = -0.7502\n",
      "\n",
      "--------Epoch 174--------\n",
      "Val loss = -0.7524\n",
      "\n",
      "--------Epoch 175--------\n",
      "Val loss = -0.7522\n",
      "\n",
      "--------Epoch 176--------\n",
      "Val loss = -0.7482\n",
      "\n",
      "--------Epoch 177--------\n",
      "Val loss = -0.7469\n",
      "\n",
      "--------Epoch 178--------\n",
      "Val loss = -0.7484\n",
      "\n",
      "--------Epoch 179--------\n",
      "Val loss = -0.7494\n",
      "\n",
      "--------Epoch 180--------\n",
      "Val loss = -0.7531\n",
      "\n",
      "--------Epoch 181--------\n",
      "Val loss = -0.7546\n",
      "\n",
      "--------Epoch 182--------\n",
      "Val loss = -0.7508\n",
      "\n",
      "--------Epoch 183--------\n",
      "Val loss = -0.7568\n",
      "\n",
      "--------Epoch 184--------\n",
      "Val loss = -0.7553\n",
      "\n",
      "--------Epoch 185--------\n",
      "Val loss = -0.7586\n",
      "\n",
      "--------Epoch 186--------\n",
      "Val loss = -0.7622\n",
      "\n",
      "--------Epoch 187--------\n",
      "Val loss = -0.7602\n",
      "\n",
      "--------Epoch 188--------\n",
      "Val loss = -0.7587\n",
      "\n",
      "--------Epoch 189--------\n",
      "Val loss = -0.7606\n",
      "\n",
      "--------Epoch 190--------\n",
      "Val loss = -0.7639\n",
      "\n",
      "--------Epoch 191--------\n",
      "Val loss = -0.7625\n",
      "\n",
      "--------Epoch 192--------\n",
      "Val loss = -0.7590\n",
      "\n",
      "--------Epoch 193--------\n",
      "Val loss = -0.7640\n",
      "\n",
      "--------Epoch 194--------\n",
      "Val loss = -0.7628\n",
      "\n",
      "--------Epoch 195--------\n",
      "Val loss = -0.7655\n",
      "\n",
      "--------Epoch 196--------\n",
      "Val loss = -0.7620\n",
      "\n",
      "--------Epoch 197--------\n",
      "Val loss = -0.7677\n",
      "\n",
      "--------Epoch 198--------\n",
      "Val loss = -0.7624\n",
      "\n",
      "--------Epoch 199--------\n",
      "Val loss = -0.7683\n",
      "\n",
      "--------Epoch 200--------\n",
      "Val loss = -0.7609\n",
      "\n",
      "--------Epoch 201--------\n",
      "Val loss = -0.7683\n",
      "\n",
      "--------Epoch 202--------\n",
      "Val loss = -0.7646\n",
      "\n",
      "--------Epoch 203--------\n",
      "Val loss = -0.7682\n",
      "\n",
      "--------Epoch 204--------\n",
      "Val loss = -0.7648\n",
      "\n",
      "--------Epoch 205--------\n",
      "Val loss = -0.7635\n",
      "\n",
      "--------Epoch 206--------\n",
      "Val loss = -0.7616\n",
      "\n",
      "--------Epoch 207--------\n",
      "Val loss = -0.7669\n",
      "\n",
      "--------Epoch 208--------\n",
      "Val loss = -0.7644\n",
      "\n",
      "--------Epoch 209--------\n",
      "Val loss = -0.7659\n",
      "\n",
      "--------Epoch 210--------\n",
      "Val loss = -0.7668\n",
      "\n",
      "--------Epoch 211--------\n",
      "Val loss = -0.7646\n",
      "\n",
      "--------Epoch 212--------\n",
      "Val loss = -0.7618\n",
      "\n",
      "--------Epoch 213--------\n",
      "Val loss = -0.7651\n",
      "\n",
      "--------Epoch 214--------\n",
      "Val loss = -0.7638\n",
      "\n",
      "--------Epoch 215--------\n",
      "Val loss = -0.7618\n",
      "\n",
      "--------Epoch 216--------\n",
      "Val loss = -0.7691\n",
      "\n",
      "--------Epoch 217--------\n",
      "Val loss = -0.7690\n",
      "\n",
      "--------Epoch 218--------\n",
      "Val loss = -0.7665\n",
      "\n",
      "--------Epoch 219--------\n",
      "Val loss = -0.7673\n",
      "\n",
      "--------Epoch 220--------\n",
      "Val loss = -0.7622\n",
      "\n",
      "--------Epoch 221--------\n",
      "Val loss = -0.7667\n",
      "\n",
      "--------Epoch 222--------\n",
      "Val loss = -0.7665\n",
      "\n",
      "--------Epoch 223--------\n",
      "Val loss = -0.7662\n",
      "\n",
      "--------Epoch 224--------\n",
      "Val loss = -0.7636\n",
      "\n",
      "--------Epoch 225--------\n",
      "Val loss = -0.7677\n",
      "\n",
      "--------Epoch 226--------\n",
      "Val loss = -0.7646\n",
      "\n",
      "--------Epoch 227--------\n",
      "Val loss = -0.7649\n",
      "\n",
      "--------Epoch 228--------\n",
      "Val loss = -0.7637\n",
      "\n",
      "--------Epoch 229--------\n",
      "Val loss = -0.7628\n",
      "\n",
      "--------Epoch 230--------\n",
      "Val loss = -0.7634\n",
      "\n",
      "--------Epoch 231--------\n",
      "Val loss = -0.7611\n",
      "\n",
      "--------Epoch 232--------\n",
      "Val loss = -0.7588\n",
      "\n",
      "--------Epoch 233--------\n",
      "Val loss = -0.7677\n",
      "\n",
      "--------Epoch 234--------\n",
      "Val loss = -0.7659\n",
      "\n",
      "--------Epoch 235--------\n",
      "Val loss = -0.7690\n",
      "\n",
      "--------Epoch 236--------\n",
      "Val loss = -0.7666\n",
      "\n",
      "--------Epoch 237--------\n",
      "Val loss = -0.7647\n",
      "\n",
      "--------Epoch 238--------\n",
      "Val loss = -0.7612\n",
      "\n",
      "--------Epoch 239--------\n",
      "Val loss = -0.7595\n",
      "\n",
      "--------Epoch 240--------\n",
      "Val loss = -0.7621\n",
      "\n",
      "--------Epoch 241--------\n",
      "Val loss = -0.7616\n",
      "\n",
      "--------Epoch 242--------\n",
      "Val loss = -0.7598\n",
      "\n",
      "--------Epoch 243--------\n",
      "Val loss = -0.7622\n",
      "\n",
      "--------Epoch 244--------\n",
      "Val loss = -0.7680\n",
      "\n",
      "--------Epoch 245--------\n",
      "Val loss = -0.7605\n",
      "\n",
      "--------Epoch 246--------\n",
      "Val loss = -0.7596\n",
      "\n",
      "--------Epoch 247--------\n",
      "Val loss = -0.7578\n",
      "\n",
      "--------Epoch 248--------\n",
      "Val loss = -0.7572\n",
      "\n",
      "--------Epoch 249--------\n",
      "Val loss = -0.7588\n",
      "\n",
      "--------Epoch 250--------\n",
      "Val loss = -0.7617\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Set up model\n",
    "# The classifier takes d_in * 2 inputs; presumably append=True makes the selector\n",
    "# concatenate the mask to the masked input -- TODO confirm in models.py.\n",
    "# NOTE(review): the notebook title says Gaussian noise but the class is still named\n",
    "# ConcreteMask -- confirm that models.py holds the Gaussian variant.\n",
    "hidden = 512\n",
    "layers = [\n",
    "    nn.Linear(d_in * 2, hidden),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(hidden, hidden),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(hidden, d_out),\n",
    "]\n",
    "global_model = nn.Sequential(*layers)\n",
    "selector = ConcreteMask(d_in, max_features, append=True)\n",
    "globalfs = GlobalSelector(global_model, selector)\n",
    "globalfs = globalfs.to(device)\n",
    "\n",
    "# Train: cross-entropy as the training loss, negative accuracy as the validation\n",
    "# metric, and start/end temperatures of 1.0 and 0.01 handed to fit\n",
    "fit_options = dict(\n",
    "    mbsize=128,\n",
    "    lr=1e-3,\n",
    "    nepochs=250,\n",
    "    loss_fn=nn.CrossEntropyLoss(),\n",
    "    val_loss_fn=NegAccuracy(),\n",
    "    start_temp=1.0,\n",
    "    end_temp=0.01,\n",
    ")\n",
    "globalfs.fit(train_dataset, val_dataset, **fit_options)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "boxed-uncle",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test acc = 76.63\n"
     ]
    }
   ],
   "source": [
    "# Score the trained global selector on the MNIST test split\n",
    "eval_metric = Accuracy()\n",
    "test_acc = globalfs.evaluate(test_dataset, eval_metric, 1024)\n",
    "print(f'Test acc = {100*test_acc:.2f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "dramatic-delight",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAK1klEQVR4nO3dT6hc93mH8edbV1ZASUGKa6M6pkmDFzWFKuWiFlyKi2nqeCNnkRItggoGZRFDAlnUpIt4aUqT0EUJKLWIWlKHQGKshWkiRMBkY3xtVFuu2to1SqJISA1exClUlp23i3tcbuT7zzNn/kTv84HLzJyZe+dl0KMzM2fu/aWqkHTj+7VFDyBpPoxdasLYpSaMXWrC2KUmfn2ed3Zzdtd72DPPu5Ra+V/+hzfqaja6bqrYk9wH/B1wE/APVfXoVrd/D3v4w9w7zV1K2sIzdXrT6yZ+Gp/kJuDvgY8BdwGHk9w16c+TNFvTvGY/CLxSVa9W1RvAN4FD44wlaWzTxH478ON1ly8M235JkqNJVpOsXuPqFHcnaRrTxL7RmwDv+OxtVR2rqpWqWtnF7inuTtI0pon9AnDHussfAC5ON46kWZkm9meBO5N8KMnNwCeBk+OMJWlsEx96q6o3kzwEfJe1Q2/Hq+ql0SaTNKqpjrNX1VPAUyPNImmG/Lis1ISxS00Yu9SEsUtNGLvUhLFLTRi71ISxS00Yu9SEsUtNGLvUhLFLTRi71ISxS00Yu9SEsUtNGLvUhLFLTRi71ISxS00Yu9SEsUtNGLvUhLFLTRi71ISxS00Yu9SEsUtNGLvUxFSruErb+e7FM5te9+e/dWBuc2jK2JOcB14H3gLerKqVMYaSNL4x9ux/WlU/HeHnSJohX7NLTUwbewHfS/JckqMb3SDJ0SSrSVavcXXKu5M0qWmfxt9dVReT3AqcSvLvVfX0+htU1THgGMBvZF9NeX+SJjTVnr2qLg6nV4AngINjDCVpfBPHnmRPkve9fR74KHB2rMEkjWuap/G3AU8kefvn/HNV/csoU+mG4bH05TFx7FX1KvD7I84iaYY89CY1YexSE8YuNWHsUhPGLjXhr7hqKlv9Cit46G2ZuGeXmjB2qQljl5owdqkJY5eaMHapCWOXmvA4u6bicfRfHe7ZpSaMXWrC2KUmjF1qwtilJoxdasLYpSY8zq6F8Xfh58s9u9SEsUtNGLvUhLFLTRi71ISxS00Yu9SEx9m1MLM+jr7VcfyOx/C33bMnOZ7kSpKz67btS3IqycvD6d7ZjilpWjt5Gv914L7rtj0MnK6qO4HTw2VJS2zb2KvqaeC16zYfAk4M508AD4w7lqSxTfoG3W1VdQlgOL11sxsmOZpkNcnqNa5OeHeSpjXzd+Or6lhVrVTVyi52z/ruJG1i0tgvJ9kPMJxeGW8kSbMwaewngSPD+SPAk+OMI2lWtj3OnuRx4B7gliQXgC8CjwLfSvIg8CPgE7McUppEx2PpW9k29qo6vMlV9448i6QZ8uOyUhPGLjVh7FITxi41YexSE/6Kq6bin4P+1eGeXWrC2KUmjF1qwtilJoxdasLYpSaMXWrC4+yayiyPo3sMf1zu2aUmjF1qwtilJoxdasLYpSaMXWrC2KUmPM5+A5jl0sSLPNbtcfRxuWeXmjB2qQljl5owdqkJY5eaMHapCWOXmvA4+w3AY93aiW337EmOJ7mS5Oy6bY8k+UmSM8PX/bMdU9K0dvI0/uvAfRts/0pVHRi+nhp3LElj2zb2qnoaeG0Os0iaoWneoHsoyQvD0/y9m90oydEkq0lWr3F1iruTNI1JY/8q8GHgAHAJ+NJmN6yqY1W1UlUru9g94d1JmtZEsVfV5ap6q6p+AXwNODjuWJLGNlHsSfavu/hx4Oxmt5W0HLY9zp7kceAe4JYkF4AvAvckOQAUcB749OxGlDSGbWOvqsMbbH5sBrNImiE/Lis1YexSE8YuNWHsUhPGLjVh
7FITxi41YexSE8YuNWHsUhPGLjVh7FITxi414Z+SvgHMcslm3Tjcs0tNGLvUhLFLTRi71ISxS00Yu9SEsUtNeJz9BuCxdO2Ee3apCWOXmjB2qQljl5owdqkJY5eaMHapCWOXmtg29iR3JPl+knNJXkry2WH7viSnkrw8nO6d/biSJrWTPfubwOer6neBPwI+k+Qu4GHgdFXdCZweLktaUtvGXlWXqur54fzrwDngduAQcGK42QnggRnNKGkE7+o1e5IPAh8BngFuq6pLsPYfAnDrJt9zNMlqktVrXJ1yXEmT2nHsSd4LfBv4XFX9bKffV1XHqmqlqlZ2sXuSGSWNYEexJ9nFWujfqKrvDJsvJ9k/XL8fuDKbESWNYdtfcU0S4DHgXFV9ed1VJ4EjwKPD6ZMzmfAGsNWfegZ/RVXzsZPfZ78b+BTwYpIzw7YvsBb5t5I8CPwI+MRMJpQ0im1jr6ofANnk6nvHHUfSrPgJOqkJY5eaMHapCWOXmjB2qQn/lPQceBxdy8A9u9SEsUtNGLvUhLFLTRi71ISxS00Yu9SEsUtNGLvUhLFLTRi71ISxS00Yu9SEsUtNGLvUhLFLTRi71ISxS00Yu9SEsUtNGLvUhLFLTRi71MS2sSe5I8n3k5xL8lKSzw7bH0nykyRnhq/7Zz+upEntZJGIN4HPV9XzSd4HPJfk1HDdV6rqb2c3nqSx7GR99kvApeH860nOAbfPejBJ43pXr9mTfBD4CPDMsOmhJC8kOZ5k7ybfczTJapLVa1ydblpJE9tx7EneC3wb+FxV/Qz4KvBh4ABre/4vbfR9VXWsqlaqamUXu6efWNJEdhR7kl2shf6NqvoOQFVdrqq3quoXwNeAg7MbU9K0dvJufIDHgHNV9eV12/evu9nHgbPjjydpLDt5N/5u4FPAi0nODNu+ABxOcgAo4Dzw6RnMJ2kkO3k3/gdANrjqqfHHkTQrfoJOasLYpSaMXWrC2KUmjF1qwtilJoxdasLYpSaMXWrC2KUmjF1qwtilJoxdasLYpSZSVfO7s+S/gR+u23QL8NO5DfDuLOtsyzoXONukxpztt6vqNze6Yq6xv+POk9WqWlnYAFtY1tmWdS5wtknNazafxktNGLvUxKJjP7bg+9/Kss62rHOBs01qLrMt9DW7pPlZ9J5d0pwYu9TEQmJPcl+S/0jySpKHFzHDZpKcT/LisAz16oJnOZ7kSpKz67btS3IqycvD6YZr7C1otqVYxnuLZcYX+tgtevnzub9mT3IT8J/AnwEXgGeBw1X1b3MdZBNJzgMrVbXwD2Ak+RPg58A/VtXvDdv+Bnitqh4d/qPcW1V/tSSzPQL8fNHLeA+rFe1fv8w48ADwlyzwsdtirr9gDo/bIvbsB4FXqurVqnoD+CZwaAFzLL2qehp47brNh4ATw/kTrP1jmbtNZlsKVXWpqp4fzr8OvL3M+EIfuy3mmotFxH478ON1ly+wXOu9F/C9JM8lObroYTZwW1VdgrV/PMCtC57netsu4z1P1y0zvjSP3STLn09rEbFvtJTUMh3/u7uq/gD4GPCZ4emqdmZHy3jPywbLjC+FSZc/n9YiYr8A3LHu8geAiwuYY0NVdXE4vQI8wfItRX357RV0h9MrC57n/y3TMt4bLTPOEjx2i1z+fBGxPwvcmeRDSW4GPgmcXMAc75Bkz/DGCUn2AB9l+ZaiPgkcGc4fAZ5c4Cy/ZFmW8d5smXEW/NgtfPnzqpr7F3A/a+/I/xfw14uYYZO5fgf41+HrpUXPBjzO2tO6a6w9I3oQeD9wGnh5ON23RLP9E/Ai8AJrYe1f0Gx/zNpLwxeAM8PX/Yt+7LaYay6Pmx+XlZrwE3RSE8YuNWHsUhPGLjVh7FITxi41YexSE/8H9CRgiL18UF8AAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Highlight the pixel each selector row prefers: argmax over dim 1 of the learned logits\n",
    "# (assumes logits has one row per selected feature and d_in columns -- TODO confirm in models.py)\n",
    "inds = selector.logits.detach().argmax(dim=1).cpu().numpy()\n",
    "zeros = np.zeros(d_in)\n",
    "zeros[inds] = 1  # repeated indices collapse, so fewer than max_features pixels may light up\n",
    "fig, ax = plt.subplots()\n",
    "ax.imshow(zeros.reshape(28, 28))  # MNIST images are 28 x 28\n",
    "ax.set_title('Pixels selected by the global selector')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "female-macintosh",
   "metadata": {},
   "source": [
    "# Pretrain model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "closed-dimension",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--------Epoch 1--------\n",
      "Val loss = -0.2434\n",
      "\n",
      "--------Epoch 2--------\n",
      "Val loss = -0.2746\n",
      "\n",
      "--------Epoch 3--------\n",
      "Val loss = -0.2877\n",
      "\n",
      "--------Epoch 4--------\n",
      "Val loss = -0.2843\n",
      "\n",
      "--------Epoch 5--------\n",
      "Val loss = -0.2951\n",
      "\n",
      "--------Epoch 6--------\n",
      "Val loss = -0.3083\n",
      "\n",
      "--------Epoch 7--------\n",
      "Val loss = -0.3009\n",
      "\n",
      "--------Epoch 8--------\n",
      "Val loss = -0.2999\n",
      "\n",
      "--------Epoch 9--------\n",
      "Val loss = -0.3125\n",
      "\n",
      "--------Epoch 10--------\n",
      "Val loss = -0.3036\n",
      "\n",
      "--------Epoch 11--------\n",
      "Val loss = -0.3053\n",
      "\n",
      "--------Epoch 12--------\n",
      "Val loss = -0.3032\n",
      "\n",
      "--------Epoch 13--------\n",
      "Val loss = -0.3072\n",
      "\n",
      "--------Epoch 14--------\n",
      "Val loss = -0.3095\n",
      "\n",
      "--------Epoch 15--------\n",
      "Val loss = -0.3169\n",
      "\n",
      "--------Epoch 16--------\n",
      "Val loss = -0.3122\n",
      "\n",
      "--------Epoch 17--------\n",
      "Val loss = -0.3155\n",
      "\n",
      "--------Epoch 18--------\n",
      "Val loss = -0.3228\n",
      "\n",
      "--------Epoch 19--------\n",
      "Val loss = -0.3183\n",
      "\n",
      "--------Epoch 20--------\n",
      "Val loss = -0.3096\n",
      "\n",
      "--------Epoch 21--------\n",
      "Val loss = -0.3181\n",
      "\n",
      "--------Epoch 22--------\n",
      "Val loss = -0.3210\n",
      "\n",
      "--------Epoch 23--------\n",
      "Val loss = -0.3226\n",
      "\n",
      "--------Epoch 24--------\n",
      "Val loss = -0.3298\n",
      "\n",
      "--------Epoch 25--------\n",
      "Val loss = -0.3206\n",
      "\n",
      "--------Epoch 26--------\n",
      "Val loss = -0.3223\n",
      "\n",
      "--------Epoch 27--------\n",
      "Val loss = -0.3250\n",
      "\n",
      "--------Epoch 28--------\n",
      "Val loss = -0.3129\n",
      "\n",
      "--------Epoch 29--------\n",
      "Val loss = -0.3218\n",
      "\n",
      "--------Epoch 30--------\n",
      "Val loss = -0.3230\n",
      "\n",
      "--------Epoch 31--------\n",
      "Val loss = -0.3204\n",
      "\n",
      "--------Epoch 32--------\n",
      "Val loss = -0.3222\n",
      "\n",
      "--------Epoch 33--------\n",
      "Val loss = -0.3237\n",
      "\n",
      "--------Epoch 34--------\n",
      "Val loss = -0.3257\n",
      "\n",
      "--------Epoch 35--------\n",
      "Val loss = -0.3182\n",
      "\n",
      "--------Epoch 36--------\n",
      "Val loss = -0.3245\n",
      "\n",
      "--------Epoch 37--------\n",
      "Val loss = -0.3233\n",
      "\n",
      "--------Epoch 38--------\n",
      "Val loss = -0.3331\n",
      "\n",
      "--------Epoch 39--------\n",
      "Val loss = -0.3219\n",
      "\n",
      "--------Epoch 40--------\n",
      "Val loss = -0.3264\n",
      "\n",
      "--------Epoch 41--------\n",
      "Val loss = -0.3238\n",
      "\n",
      "--------Epoch 42--------\n",
      "Val loss = -0.3280\n",
      "\n",
      "--------Epoch 43--------\n",
      "Val loss = -0.3283\n",
      "\n",
      "--------Epoch 44--------\n",
      "Val loss = -0.3255\n",
      "\n",
      "--------Epoch 45--------\n",
      "Val loss = -0.3315\n",
      "\n",
      "--------Epoch 46--------\n",
      "Val loss = -0.3279\n",
      "\n",
      "--------Epoch 47--------\n",
      "Val loss = -0.3341\n",
      "\n",
      "--------Epoch 48--------\n",
      "Val loss = -0.3216\n",
      "\n",
      "--------Epoch 49--------\n",
      "Val loss = -0.3191\n",
      "\n",
      "--------Epoch 50--------\n",
      "Val loss = -0.3244\n",
      "\n",
      "--------Epoch 51--------\n",
      "Val loss = -0.3350\n",
      "\n",
      "--------Epoch 52--------\n",
      "Val loss = -0.3279\n",
      "\n",
      "--------Epoch 53--------\n",
      "Val loss = -0.3385\n",
      "\n",
      "--------Epoch 54--------\n",
      "Val loss = -0.3317\n",
      "\n",
      "--------Epoch 55--------\n",
      "Val loss = -0.3283\n",
      "\n",
      "--------Epoch 56--------\n",
      "Val loss = -0.3319\n",
      "\n",
      "--------Epoch 57--------\n",
      "Val loss = -0.3352\n",
      "\n",
      "--------Epoch 58--------\n",
      "Val loss = -0.3301\n",
      "\n",
      "--------Epoch 59--------\n",
      "Val loss = -0.3395\n",
      "\n",
      "--------Epoch 60--------\n",
      "Val loss = -0.3292\n",
      "\n",
      "--------Epoch 61--------\n",
      "Val loss = -0.3432\n",
      "\n",
      "--------Epoch 62--------\n",
      "Val loss = -0.3402\n",
      "\n",
      "--------Epoch 63--------\n",
      "Val loss = -0.3287\n",
      "\n",
      "--------Epoch 64--------\n",
      "Val loss = -0.3390\n",
      "\n",
      "--------Epoch 65--------\n",
      "Val loss = -0.3328\n",
      "\n",
      "--------Epoch 66--------\n",
      "Val loss = -0.3289\n",
      "\n",
      "--------Epoch 67--------\n",
      "Val loss = -0.3423\n",
      "\n",
      "--------Epoch 68--------\n",
      "Val loss = -0.3384\n",
      "\n",
      "--------Epoch 69--------\n",
      "Val loss = -0.3381\n",
      "\n",
      "--------Epoch 70--------\n",
      "Val loss = -0.3298\n",
      "\n",
      "--------Epoch 71--------\n",
      "Val loss = -0.3364\n",
      "\n",
      "--------Epoch 72--------\n",
      "Val loss = -0.3426\n",
      "\n",
      "--------Epoch 73--------\n",
      "Val loss = -0.3345\n",
      "\n",
      "--------Epoch 74--------\n",
      "Val loss = -0.3313\n",
      "\n",
      "--------Epoch 75--------\n",
      "Val loss = -0.3430\n",
      "\n",
      "--------Epoch 76--------\n",
      "Val loss = -0.3402\n",
      "\n",
      "--------Epoch 77--------\n",
      "Val loss = -0.3317\n",
      "\n",
      "--------Epoch 78--------\n",
      "Val loss = -0.3362\n",
      "\n",
      "--------Epoch 79--------\n",
      "Val loss = -0.3372\n",
      "\n",
      "--------Epoch 80--------\n",
      "Val loss = -0.3358\n",
      "\n",
      "--------Epoch 81--------\n",
      "Val loss = -0.3381\n",
      "\n",
      "--------Epoch 82--------\n",
      "Val loss = -0.3425\n",
      "\n",
      "--------Epoch 83--------\n",
      "Val loss = -0.3367\n",
      "\n",
      "--------Epoch 84--------\n",
      "Val loss = -0.3339\n",
      "\n",
      "--------Epoch 85--------\n",
      "Val loss = -0.3387\n",
      "\n",
      "--------Epoch 86--------\n",
      "Val loss = -0.3398\n",
      "\n",
      "--------Epoch 87--------\n",
      "Val loss = -0.3330\n",
      "\n",
      "--------Epoch 88--------\n",
      "Val loss = -0.3387\n",
      "\n",
      "--------Epoch 89--------\n",
      "Val loss = -0.3395\n",
      "\n",
      "--------Epoch 90--------\n",
      "Val loss = -0.3345\n",
      "\n",
      "--------Epoch 91--------\n",
      "Val loss = -0.3377\n",
      "\n",
      "--------Epoch 92--------\n",
      "Val loss = -0.3411\n",
      "\n",
      "--------Epoch 93--------\n",
      "Val loss = -0.3420\n",
      "\n",
      "--------Epoch 94--------\n",
      "Val loss = -0.3314\n",
      "\n",
      "--------Epoch 95--------\n",
      "Val loss = -0.3337\n",
      "\n",
      "--------Epoch 96--------\n",
      "Val loss = -0.3423\n",
      "\n",
      "--------Epoch 97--------\n",
      "Val loss = -0.3442\n",
      "\n",
      "--------Epoch 98--------\n",
      "Val loss = -0.3408\n",
      "\n",
      "--------Epoch 99--------\n",
      "Val loss = -0.3366\n",
      "\n",
      "--------Epoch 100--------\n",
      "Val loss = -0.3382\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Set up model\n",
    "# Predictor network: an MLP with two 512-unit ReLU hidden layers that maps\n",
    "# d_in * 2 inputs to d_out outputs.\n",
    "# NOTE(review): the doubled input width presumably makes room for the mask that\n",
    "# MaskLayer(append=True) appends to the features; confirm in the MaskLayer source.\n",
    "model = nn.Sequential(\n",
    "    nn.Linear(d_in * 2, 512),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(512, 512),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(512, d_out))\n",
    "mask_layer = MaskLayer(append=True)\n",
    "pretrain = Pretrainer(model, mask_layer).to(device)\n",
    "\n",
    "# Pretrain\n",
    "# Runs for 100 epochs with loss_fn set to cross-entropy and val_loss_fn set to\n",
    "# NegAccuracy, which is why the log below prints negative 'Val loss' values\n",
    "# (presumably minus the validation accuracy, so lower is better; verify).\n",
    "pretrain.fit(train_dataset,\n",
    "             val_dataset,\n",
    "             mbsize=128,\n",
    "             lr=1e-3,\n",
    "             nepochs=100,\n",
    "             max_features=max_features,\n",
    "             loss_fn=nn.CrossEntropyLoss(),\n",
    "             val_loss_fn=NegAccuracy())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "quick-thread",
   "metadata": {},
   "source": [
    "# Baseline training (Concrete selector)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "passive-wallace",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--------Epoch 1--------\n",
      "Val loss = -0.6435\n",
      "\n",
      "--------Epoch 2--------\n",
      "Val loss = -0.6944\n",
      "\n",
      "--------Epoch 3--------\n",
      "Val loss = -0.7070\n",
      "\n",
      "--------Epoch 4--------\n",
      "Val loss = -0.7170\n",
      "\n",
      "--------Epoch 5--------\n",
      "Val loss = -0.7076\n",
      "\n",
      "--------Epoch 6--------\n",
      "Val loss = -0.7123\n",
      "\n",
      "--------Epoch 7--------\n",
      "Val loss = -0.7092\n",
      "\n",
      "--------Epoch 8--------\n",
      "Val loss = -0.7031\n",
      "\n",
      "--------Epoch 9--------\n",
      "Val loss = -0.6897\n",
      "\n",
      "--------Epoch 10--------\n",
      "Val loss = -0.6983\n",
      "\n",
      "--------Epoch 11--------\n",
      "Val loss = -0.6886\n",
      "\n",
      "--------Epoch 12--------\n",
      "Val loss = -0.6835\n",
      "\n",
      "--------Epoch 13--------\n",
      "Val loss = -0.6842\n",
      "\n",
      "--------Epoch 14--------\n",
      "Val loss = -0.6785\n",
      "\n",
      "--------Epoch 15--------\n",
      "Val loss = -0.6668\n",
      "\n",
      "--------Epoch 16--------\n",
      "Val loss = -0.6688\n",
      "\n",
      "--------Epoch 17--------\n",
      "Val loss = -0.6456\n",
      "\n",
      "--------Epoch 18--------\n",
      "Val loss = -0.6321\n",
      "\n",
      "--------Epoch 19--------\n",
      "Val loss = -0.6414\n",
      "\n",
      "--------Epoch 20--------\n",
      "Val loss = -0.6539\n",
      "\n",
      "--------Epoch 21--------\n",
      "Val loss = -0.6453\n",
      "\n",
      "--------Epoch 22--------\n",
      "Val loss = -0.6362\n",
      "\n",
      "--------Epoch 23--------\n",
      "Val loss = -0.6533\n",
      "\n",
      "--------Epoch 24--------\n",
      "Val loss = -0.6330\n",
      "\n",
      "--------Epoch 25--------\n",
      "Val loss = -0.6656\n",
      "\n",
      "--------Epoch 26--------\n",
      "Val loss = -0.6719\n",
      "\n",
      "--------Epoch 27--------\n",
      "Val loss = -0.6671\n",
      "\n",
      "--------Epoch 28--------\n",
      "Val loss = -0.6437\n",
      "\n",
      "--------Epoch 29--------\n",
      "Val loss = -0.6564\n",
      "\n",
      "--------Epoch 30--------\n",
      "Val loss = -0.6514\n",
      "\n",
      "--------Epoch 31--------\n",
      "Val loss = -0.6542\n",
      "\n",
      "--------Epoch 32--------\n",
      "Val loss = -0.6388\n",
      "\n",
      "--------Epoch 33--------\n",
      "Val loss = -0.6508\n",
      "\n",
      "--------Epoch 34--------\n",
      "Val loss = -0.6614\n",
      "\n",
      "--------Epoch 35--------\n",
      "Val loss = -0.6565\n",
      "\n",
      "--------Epoch 36--------\n",
      "Val loss = -0.6484\n",
      "\n",
      "--------Epoch 37--------\n",
      "Val loss = -0.6533\n",
      "\n",
      "--------Epoch 38--------\n",
      "Val loss = -0.6606\n",
      "\n",
      "--------Epoch 39--------\n",
      "Val loss = -0.6573\n",
      "\n",
      "--------Epoch 40--------\n",
      "Val loss = -0.6475\n",
      "\n",
      "--------Epoch 41--------\n",
      "Val loss = -0.6541\n",
      "\n",
      "--------Epoch 42--------\n",
      "Val loss = -0.6460\n",
      "\n",
      "--------Epoch 43--------\n",
      "Val loss = -0.6559\n",
      "\n",
      "--------Epoch 44--------\n",
      "Val loss = -0.6585\n",
      "\n",
      "--------Epoch 45--------\n",
      "Val loss = -0.6665\n",
      "\n",
      "--------Epoch 46--------\n",
      "Val loss = -0.6499\n",
      "\n",
      "--------Epoch 47--------\n",
      "Val loss = -0.6533\n",
      "\n",
      "--------Epoch 48--------\n",
      "Val loss = -0.6469\n",
      "\n",
      "--------Epoch 49--------\n",
      "Val loss = -0.6422\n",
      "\n",
      "--------Epoch 50--------\n",
      "Val loss = -0.6397\n",
      "\n",
      "--------Epoch 51--------\n",
      "Val loss = -0.6486\n",
      "\n",
      "--------Epoch 52--------\n",
      "Val loss = -0.6605\n",
      "\n",
      "--------Epoch 53--------\n",
      "Val loss = -0.6466\n",
      "\n",
      "--------Epoch 54--------\n",
      "Val loss = -0.6596\n",
      "\n",
      "--------Epoch 55--------\n",
      "Val loss = -0.6490\n",
      "\n",
      "--------Epoch 56--------\n",
      "Val loss = -0.6650\n",
      "\n",
      "--------Epoch 57--------\n",
      "Val loss = -0.6424\n",
      "\n",
      "--------Epoch 58--------\n",
      "Val loss = -0.6531\n",
      "\n",
      "--------Epoch 59--------\n",
      "Val loss = -0.6594\n",
      "\n",
      "--------Epoch 60--------\n",
      "Val loss = -0.6617\n",
      "\n",
      "--------Epoch 61--------\n",
      "Val loss = -0.6560\n",
      "\n",
      "--------Epoch 62--------\n",
      "Val loss = -0.6582\n",
      "\n",
      "--------Epoch 63--------\n",
      "Val loss = -0.6570\n",
      "\n",
      "--------Epoch 64--------\n",
      "Val loss = -0.6634\n",
      "\n",
      "--------Epoch 65--------\n",
      "Val loss = -0.6695\n",
      "\n",
      "--------Epoch 66--------\n",
      "Val loss = -0.6589\n",
      "\n",
      "--------Epoch 67--------\n",
      "Val loss = -0.6675\n",
      "\n",
      "--------Epoch 68--------\n",
      "Val loss = -0.6713\n",
      "\n",
      "--------Epoch 69--------\n",
      "Val loss = -0.6611\n",
      "\n",
      "--------Epoch 70--------\n",
      "Val loss = -0.6664\n",
      "\n",
      "--------Epoch 71--------\n",
      "Val loss = -0.6757\n",
      "\n",
      "--------Epoch 72--------\n",
      "Val loss = -0.6835\n",
      "\n",
      "--------Epoch 73--------\n",
      "Val loss = -0.6716\n",
      "\n",
      "--------Epoch 74--------\n",
      "Val loss = -0.6783\n",
      "\n",
      "--------Epoch 75--------\n",
      "Val loss = -0.6878\n",
      "\n",
      "--------Epoch 76--------\n",
      "Val loss = -0.6809\n",
      "\n",
      "--------Epoch 77--------\n",
      "Val loss = -0.6857\n",
      "\n",
      "--------Epoch 78--------\n",
      "Val loss = -0.6841\n",
      "\n",
      "--------Epoch 79--------\n",
      "Val loss = -0.6896\n",
      "\n",
      "--------Epoch 80--------\n",
      "Val loss = -0.7039\n",
      "\n",
      "--------Epoch 81--------\n",
      "Val loss = -0.6972\n",
      "\n",
      "--------Epoch 82--------\n",
      "Val loss = -0.7015\n",
      "\n",
      "--------Epoch 83--------\n",
      "Val loss = -0.7070\n",
      "\n",
      "--------Epoch 84--------\n",
      "Val loss = -0.7097\n",
      "\n",
      "--------Epoch 85--------\n",
      "Val loss = -0.7211\n",
      "\n",
      "--------Epoch 86--------\n",
      "Val loss = -0.7249\n",
      "\n",
      "--------Epoch 87--------\n",
      "Val loss = -0.7234\n",
      "\n",
      "--------Epoch 88--------\n",
      "Val loss = -0.7257\n",
      "\n",
      "--------Epoch 89--------\n",
      "Val loss = -0.7269\n",
      "\n",
      "--------Epoch 90--------\n",
      "Val loss = -0.7388\n",
      "\n",
      "--------Epoch 91--------\n",
      "Val loss = -0.7364\n",
      "\n",
      "--------Epoch 92--------\n",
      "Val loss = -0.7353\n",
      "\n",
      "--------Epoch 93--------\n",
      "Val loss = -0.7488\n",
      "\n",
      "--------Epoch 94--------\n",
      "Val loss = -0.7464\n",
      "\n",
      "--------Epoch 95--------\n",
      "Val loss = -0.7507\n",
      "\n",
      "--------Epoch 96--------\n",
      "Val loss = -0.7600\n",
      "\n",
      "--------Epoch 97--------\n",
      "Val loss = -0.7591\n",
      "\n",
      "--------Epoch 98--------\n",
      "Val loss = -0.7666\n",
      "\n",
      "--------Epoch 99--------\n",
      "Val loss = -0.7676\n",
      "\n",
      "--------Epoch 100--------\n",
      "Val loss = -0.7683\n",
      "\n",
      "--------Epoch 101--------\n",
      "Val loss = -0.7666\n",
      "\n",
      "--------Epoch 102--------\n",
      "Val loss = -0.7773\n",
      "\n",
      "--------Epoch 103--------\n",
      "Val loss = -0.7870\n",
      "\n",
      "--------Epoch 104--------\n",
      "Val loss = -0.7790\n",
      "\n",
      "--------Epoch 105--------\n",
      "Val loss = -0.7849\n",
      "\n",
      "--------Epoch 106--------\n",
      "Val loss = -0.7841\n",
      "\n",
      "--------Epoch 107--------\n",
      "Val loss = -0.7885\n",
      "\n",
      "--------Epoch 108--------\n",
      "Val loss = -0.7868\n",
      "\n",
      "--------Epoch 109--------\n",
      "Val loss = -0.7941\n",
      "\n",
      "--------Epoch 110--------\n",
      "Val loss = -0.7949\n",
      "\n",
      "--------Epoch 111--------\n",
      "Val loss = -0.7984\n",
      "\n",
      "--------Epoch 112--------\n",
      "Val loss = -0.7969\n",
      "\n",
      "--------Epoch 113--------\n",
      "Val loss = -0.8042\n",
      "\n",
      "--------Epoch 114--------\n",
      "Val loss = -0.8075\n",
      "\n",
      "--------Epoch 115--------\n",
      "Val loss = -0.8058\n",
      "\n",
      "--------Epoch 116--------\n",
      "Val loss = -0.8122\n",
      "\n",
      "--------Epoch 117--------\n",
      "Val loss = -0.8106\n",
      "\n",
      "--------Epoch 118--------\n",
      "Val loss = -0.8159\n",
      "\n",
      "--------Epoch 119--------\n",
      "Val loss = -0.8156\n",
      "\n",
      "--------Epoch 120--------\n",
      "Val loss = -0.8117\n",
      "\n",
      "--------Epoch 121--------\n",
      "Val loss = -0.8137\n",
      "\n",
      "--------Epoch 122--------\n",
      "Val loss = -0.8139\n",
      "\n",
      "--------Epoch 123--------\n",
      "Val loss = -0.8143\n",
      "\n",
      "--------Epoch 124--------\n",
      "Val loss = -0.8116\n",
      "\n",
      "--------Epoch 125--------\n",
      "Val loss = -0.8213\n",
      "\n",
      "--------Epoch 126--------\n",
      "Val loss = -0.8228\n",
      "\n",
      "--------Epoch 127--------\n",
      "Val loss = -0.8201\n",
      "\n",
      "--------Epoch 128--------\n",
      "Val loss = -0.8273\n",
      "\n",
      "--------Epoch 129--------\n",
      "Val loss = -0.8223\n",
      "\n",
      "--------Epoch 130--------\n",
      "Val loss = -0.8292\n",
      "\n",
      "--------Epoch 131--------\n",
      "Val loss = -0.8295\n",
      "\n",
      "--------Epoch 132--------\n",
      "Val loss = -0.8256\n",
      "\n",
      "--------Epoch 133--------\n",
      "Val loss = -0.8262\n",
      "\n",
      "--------Epoch 134--------\n",
      "Val loss = -0.8280\n",
      "\n",
      "--------Epoch 135--------\n",
      "Val loss = -0.8303\n",
      "\n",
      "--------Epoch 136--------\n",
      "Val loss = -0.8260\n",
      "\n",
      "--------Epoch 137--------\n",
      "Val loss = -0.8323\n",
      "\n",
      "--------Epoch 138--------\n",
      "Val loss = -0.8272\n",
      "\n",
      "--------Epoch 139--------\n",
      "Val loss = -0.8310\n",
      "\n",
      "--------Epoch 140--------\n",
      "Val loss = -0.8330\n",
      "\n",
      "--------Epoch 141--------\n",
      "Val loss = -0.8305\n",
      "\n",
      "--------Epoch 142--------\n",
      "Val loss = -0.8301\n",
      "\n",
      "--------Epoch 143--------\n",
      "Val loss = -0.8361\n",
      "\n",
      "--------Epoch 144--------\n",
      "Val loss = -0.8335\n",
      "\n",
      "--------Epoch 145--------\n",
      "Val loss = -0.8341\n",
      "\n",
      "--------Epoch 146--------\n",
      "Val loss = -0.8365\n",
      "\n",
      "--------Epoch 147--------\n",
      "Val loss = -0.8353\n",
      "\n",
      "--------Epoch 148--------\n",
      "Val loss = -0.8360\n",
      "\n",
      "--------Epoch 149--------\n",
      "Val loss = -0.8374\n",
      "\n",
      "--------Epoch 150--------\n",
      "Val loss = -0.8377\n",
      "\n",
      "--------Epoch 151--------\n",
      "Val loss = -0.8347\n",
      "\n",
      "--------Epoch 152--------\n",
      "Val loss = -0.8405\n",
      "\n",
      "--------Epoch 153--------\n",
      "Val loss = -0.8373\n",
      "\n",
      "--------Epoch 154--------\n",
      "Val loss = -0.8405\n",
      "\n",
      "--------Epoch 155--------\n",
      "Val loss = -0.8358\n",
      "\n",
      "--------Epoch 156--------\n",
      "Val loss = -0.8380\n",
      "\n",
      "--------Epoch 157--------\n",
      "Val loss = -0.8433\n",
      "\n",
      "--------Epoch 158--------\n",
      "Val loss = -0.8412\n",
      "\n",
      "--------Epoch 159--------\n",
      "Val loss = -0.8439\n",
      "\n",
      "--------Epoch 160--------\n",
      "Val loss = -0.8378\n",
      "\n",
      "--------Epoch 161--------\n",
      "Val loss = -0.8369\n",
      "\n",
      "--------Epoch 162--------\n",
      "Val loss = -0.8422\n",
      "\n",
      "--------Epoch 163--------\n",
      "Val loss = -0.8382\n",
      "\n",
      "--------Epoch 164--------\n",
      "Val loss = -0.8441\n",
      "\n",
      "--------Epoch 165--------\n",
      "Val loss = -0.8384\n",
      "\n",
      "--------Epoch 166--------\n",
      "Val loss = -0.8440\n",
      "\n",
      "--------Epoch 167--------\n",
      "Val loss = -0.8404\n",
      "\n",
      "--------Epoch 168--------\n",
      "Val loss = -0.8465\n",
      "\n",
      "--------Epoch 169--------\n",
      "Val loss = -0.8414\n",
      "\n",
      "--------Epoch 170--------\n",
      "Val loss = -0.8418\n",
      "\n",
      "--------Epoch 171--------\n",
      "Val loss = -0.8406\n",
      "\n",
      "--------Epoch 172--------\n",
      "Val loss = -0.8388\n",
      "\n",
      "--------Epoch 173--------\n",
      "Val loss = -0.8390\n",
      "\n",
      "--------Epoch 174--------\n",
      "Val loss = -0.8359\n",
      "\n",
      "--------Epoch 175--------\n",
      "Val loss = -0.8438\n",
      "\n",
      "--------Epoch 176--------\n",
      "Val loss = -0.8424\n",
      "\n",
      "--------Epoch 177--------\n",
      "Val loss = -0.8419\n",
      "\n",
      "--------Epoch 178--------\n",
      "Val loss = -0.8434\n",
      "\n",
      "--------Epoch 179--------\n",
      "Val loss = -0.8425\n",
      "\n",
      "--------Epoch 180--------\n",
      "Val loss = -0.8449\n",
      "\n",
      "--------Epoch 181--------\n",
      "Val loss = -0.8424\n",
      "\n",
      "--------Epoch 182--------\n",
      "Val loss = -0.8506\n",
      "\n",
      "--------Epoch 183--------\n",
      "Val loss = -0.8458\n",
      "\n",
      "--------Epoch 184--------\n",
      "Val loss = -0.8451\n",
      "\n",
      "--------Epoch 185--------\n",
      "Val loss = -0.8438\n",
      "\n",
      "--------Epoch 186--------\n",
      "Val loss = -0.8497\n",
      "\n",
      "--------Epoch 187--------\n",
      "Val loss = -0.8422\n",
      "\n",
      "--------Epoch 188--------\n",
      "Val loss = -0.8424\n",
      "\n",
      "--------Epoch 189--------\n",
      "Val loss = -0.8405\n",
      "\n",
      "--------Epoch 190--------\n",
      "Val loss = -0.8476\n",
      "\n",
      "--------Epoch 191--------\n",
      "Val loss = -0.8438\n",
      "\n",
      "--------Epoch 192--------\n",
      "Val loss = -0.8416\n",
      "\n",
      "--------Epoch 193--------\n",
      "Val loss = -0.8425\n",
      "\n",
      "--------Epoch 194--------\n",
      "Val loss = -0.8394\n",
      "\n",
      "--------Epoch 195--------\n",
      "Val loss = -0.8477\n",
      "\n",
      "--------Epoch 196--------\n",
      "Val loss = -0.8467\n",
      "\n",
      "--------Epoch 197--------\n",
      "Val loss = -0.8394\n",
      "\n",
      "--------Epoch 198--------\n",
      "Val loss = -0.8498\n",
      "\n",
      "--------Epoch 199--------\n",
      "Val loss = -0.8463\n",
      "\n",
      "--------Epoch 200--------\n",
      "Val loss = -0.8418\n",
      "\n",
      "--------Epoch 201--------\n",
      "Val loss = -0.8456\n",
      "\n",
      "--------Epoch 202--------\n",
      "Val loss = -0.8487\n",
      "\n",
      "--------Epoch 203--------\n",
      "Val loss = -0.8460\n",
      "\n",
      "--------Epoch 204--------\n",
      "Val loss = -0.8418\n",
      "\n",
      "--------Epoch 205--------\n",
      "Val loss = -0.8466\n",
      "\n",
      "--------Epoch 206--------\n",
      "Val loss = -0.8483\n",
      "\n",
      "--------Epoch 207--------\n",
      "Val loss = -0.8275\n",
      "\n",
      "--------Epoch 208--------\n",
      "Val loss = -0.8442\n",
      "\n",
      "--------Epoch 209--------\n",
      "Val loss = -0.8410\n",
      "\n",
      "--------Epoch 210--------\n",
      "Val loss = -0.8419\n",
      "\n",
      "--------Epoch 211--------\n",
      "Val loss = -0.8431\n",
      "\n",
      "--------Epoch 212--------\n",
      "Val loss = -0.8455\n",
      "\n",
      "--------Epoch 213--------\n",
      "Val loss = -0.8409\n",
      "\n",
      "--------Epoch 214--------\n",
      "Val loss = -0.8425\n",
      "\n",
      "--------Epoch 215--------\n",
      "Val loss = -0.8487\n",
      "\n",
      "--------Epoch 216--------\n",
      "Val loss = -0.8446\n",
      "\n",
      "--------Epoch 217--------\n",
      "Val loss = -0.8428\n",
      "\n",
      "--------Epoch 218--------\n",
      "Val loss = -0.8384\n",
      "\n",
      "--------Epoch 219--------\n",
      "Val loss = -0.8306\n",
      "\n",
      "--------Epoch 220--------\n",
      "Val loss = -0.8443\n",
      "\n",
      "--------Epoch 221--------\n",
      "Val loss = -0.8356\n",
      "\n",
      "--------Epoch 222--------\n",
      "Val loss = -0.8430\n",
      "\n",
      "--------Epoch 223--------\n",
      "Val loss = -0.8410\n",
      "\n",
      "--------Epoch 224--------\n",
      "Val loss = -0.8474\n",
      "\n",
      "--------Epoch 225--------\n",
      "Val loss = -0.8475\n",
      "\n",
      "--------Epoch 226--------\n",
      "Val loss = -0.8356\n",
      "\n",
      "--------Epoch 227--------\n",
      "Val loss = -0.8426\n",
      "\n",
      "--------Epoch 228--------\n",
      "Val loss = -0.8397\n",
      "\n",
      "--------Epoch 229--------\n",
      "Val loss = -0.8402\n",
      "\n",
      "--------Epoch 230--------\n",
      "Val loss = -0.8420\n",
      "\n",
      "--------Epoch 231--------\n",
      "Val loss = -0.8403\n",
      "\n",
      "--------Epoch 232--------\n",
      "Val loss = -0.8426\n",
      "\n",
      "--------Epoch 233--------\n",
      "Val loss = -0.8435\n",
      "\n",
      "--------Epoch 234--------\n",
      "Val loss = -0.8431\n",
      "\n",
      "--------Epoch 235--------\n",
      "Val loss = -0.8400\n",
      "\n",
      "--------Epoch 236--------\n",
      "Val loss = -0.8412\n",
      "\n",
      "--------Epoch 237--------\n",
      "Val loss = -0.8367\n",
      "\n",
      "--------Epoch 238--------\n",
      "Val loss = -0.8384\n",
      "\n",
      "--------Epoch 239--------\n",
      "Val loss = -0.8351\n",
      "\n",
      "--------Epoch 240--------\n",
      "Val loss = -0.8333\n",
      "\n",
      "--------Epoch 241--------\n",
      "Val loss = -0.8298\n",
      "\n",
      "--------Epoch 242--------\n",
      "Val loss = -0.8371\n",
      "\n",
      "--------Epoch 243--------\n",
      "Val loss = -0.8292\n",
      "\n",
      "--------Epoch 244--------\n",
      "Val loss = -0.8344\n",
      "\n",
      "--------Epoch 245--------\n",
      "Val loss = -0.8356\n",
      "\n",
      "--------Epoch 246--------\n",
      "Val loss = -0.8324\n",
      "\n",
      "--------Epoch 247--------\n",
      "Val loss = -0.8197\n",
      "\n",
      "--------Epoch 248--------\n",
      "Val loss = -0.8307\n",
      "\n",
      "--------Epoch 249--------\n",
      "Val loss = -0.8255\n",
      "\n",
      "--------Epoch 250--------\n",
      "Val loss = -0.8331\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Set up selector\n",
    "# Selector network: same MLP shape as the predictor (two 512-unit ReLU hidden\n",
    "# layers, d_in * 2 inputs) but with d_in outputs, which are passed to the\n",
    "# selector layer as logits (presumably one logit per input feature; verify).\n",
    "selector = nn.Sequential(\n",
    "    nn.Linear(d_in * 2, 512),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(512, 512),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(512, d_in))\n",
    "selector_layer = ConcreteSelector()\n",
    "# GreedyAdaptiveFS receives a deep copy of the predictor, so the `model` object\n",
    "# itself is not modified by the fit below.\n",
    "gafs = GreedyAdaptiveFS(selector, deepcopy(model), mask_layer, selector_layer).to(device)\n",
    "\n",
    "# Train\n",
    "# Same loss / validation metric as pretraining, but a smaller learning rate\n",
    "# (2e-4) and 250 epochs.\n",
    "# NOTE(review): start_temp / end_temp presumably define a temperature schedule\n",
    "# from 10.0 down to 0.01 for the selector layer; confirm in GreedyAdaptiveFS.fit.\n",
    "gafs.fit(train_dataset,\n",
    "         val_dataset,\n",
    "         mbsize=128,\n",
    "         lr=2e-4,\n",
    "         nepochs=250,\n",
    "         max_features=max_features,\n",
    "         loss_fn=nn.CrossEntropyLoss(),\n",
    "         val_loss_fn=NegAccuracy(),\n",
    "         argmax=False,\n",
    "         no_repeats=False,\n",
    "         start_temp=10.0,\n",
    "         end_temp=0.01)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "twelve-latino",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test acc = 86.26\n"
     ]
    }
   ],
   "source": [
    "# Evaluate the trained model on the test set with the Accuracy metric and\n",
    "# print the result scaled by 100 (i.e. as a percentage).\n",
    "# NOTE(review): the trailing 1024 is presumably the evaluation batch size;\n",
    "# confirm against GreedyAdaptiveFS.evaluate.\n",
    "test_acc = gafs.evaluate(test_dataset, max_features, Accuracy(), 1024)\n",
    "print(f'Test acc = {100*test_acc:.2f}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "aboriginal-corps",
   "metadata": {},
   "source": [
    "# Gaussian noise"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "comfortable-walker",
   "metadata": {},
   "outputs": [],
   "source": [
    "class ConcreteGaussian(nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "    \n",
    "    def forward(self, logits, temp):\n",
    "        rand = 1e-2 * torch.randn(logits.shape, device=logits.device)\n",
    "        return torch.softmax((logits + rand) / temp, dim=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "processed-margin",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--------Epoch 1--------\n",
      "Val loss = -0.1813\n",
      "\n",
      "--------Epoch 2--------\n",
      "Val loss = -0.2063\n",
      "\n",
      "--------Epoch 3--------\n",
      "Val loss = -0.3878\n",
      "\n",
      "--------Epoch 4--------\n",
      "Val loss = -0.3292\n",
      "\n",
      "--------Epoch 5--------\n",
      "Val loss = -0.4298\n",
      "\n",
      "--------Epoch 6--------\n",
      "Val loss = -0.2911\n",
      "\n",
      "--------Epoch 7--------\n",
      "Val loss = -0.3735\n",
      "\n",
      "--------Epoch 8--------\n",
      "Val loss = -0.4934\n",
      "\n",
      "--------Epoch 9--------\n",
      "Val loss = -0.4125\n",
      "\n",
      "--------Epoch 10--------\n",
      "Val loss = -0.4784\n",
      "\n",
      "--------Epoch 11--------\n",
      "Val loss = -0.5006\n",
      "\n",
      "--------Epoch 12--------\n",
      "Val loss = -0.5711\n",
      "\n",
      "--------Epoch 13--------\n",
      "Val loss = -0.5668\n",
      "\n",
      "--------Epoch 14--------\n",
      "Val loss = -0.5692\n",
      "\n",
      "--------Epoch 15--------\n",
      "Val loss = -0.6040\n",
      "\n",
      "--------Epoch 16--------\n",
      "Val loss = -0.5582\n",
      "\n",
      "--------Epoch 17--------\n",
      "Val loss = -0.5788\n",
      "\n",
      "--------Epoch 18--------\n",
      "Val loss = -0.6248\n",
      "\n",
      "--------Epoch 19--------\n",
      "Val loss = -0.5990\n",
      "\n",
      "--------Epoch 20--------\n",
      "Val loss = -0.5897\n",
      "\n",
      "--------Epoch 21--------\n",
      "Val loss = -0.5987\n",
      "\n",
      "--------Epoch 22--------\n",
      "Val loss = -0.6270\n",
      "\n",
      "--------Epoch 23--------\n",
      "Val loss = -0.5864\n",
      "\n",
      "--------Epoch 24--------\n",
      "Val loss = -0.6451\n",
      "\n",
      "--------Epoch 25--------\n",
      "Val loss = -0.6278\n",
      "\n",
      "--------Epoch 26--------\n",
      "Val loss = -0.6255\n",
      "\n",
      "--------Epoch 27--------\n",
      "Val loss = -0.6060\n",
      "\n",
      "--------Epoch 28--------\n",
      "Val loss = -0.6204\n",
      "\n",
      "--------Epoch 29--------\n",
      "Val loss = -0.6840\n",
      "\n",
      "--------Epoch 30--------\n",
      "Val loss = -0.5775\n",
      "\n",
      "--------Epoch 31--------\n",
      "Val loss = -0.6833\n",
      "\n",
      "--------Epoch 32--------\n",
      "Val loss = -0.6403\n",
      "\n",
      "--------Epoch 33--------\n",
      "Val loss = -0.6405\n",
      "\n",
      "--------Epoch 34--------\n",
      "Val loss = -0.6605\n",
      "\n",
      "--------Epoch 35--------\n",
      "Val loss = -0.6625\n",
      "\n",
      "--------Epoch 36--------\n",
      "Val loss = -0.6082\n",
      "\n",
      "--------Epoch 37--------\n",
      "Val loss = -0.6116\n",
      "\n",
      "--------Epoch 38--------\n",
      "Val loss = -0.6578\n",
      "\n",
      "--------Epoch 39--------\n",
      "Val loss = -0.6445\n",
      "\n",
      "--------Epoch 40--------\n",
      "Val loss = -0.6515\n",
      "\n",
      "--------Epoch 41--------\n",
      "Val loss = -0.6373\n",
      "\n",
      "--------Epoch 42--------\n",
      "Val loss = -0.6581\n",
      "\n",
      "--------Epoch 43--------\n",
      "Val loss = -0.6442\n",
      "\n",
      "--------Epoch 44--------\n",
      "Val loss = -0.6515\n",
      "\n",
      "--------Epoch 45--------\n",
      "Val loss = -0.6061\n",
      "\n",
      "--------Epoch 46--------\n",
      "Val loss = -0.6421\n",
      "\n",
      "--------Epoch 47--------\n",
      "Val loss = -0.6388\n",
      "\n",
      "--------Epoch 48--------\n",
      "Val loss = -0.6350\n",
      "\n",
      "--------Epoch 49--------\n",
      "Val loss = -0.6533\n",
      "\n",
      "--------Epoch 50--------\n",
      "Val loss = -0.6343\n",
      "\n",
      "--------Epoch 51--------\n",
      "Val loss = -0.6422\n",
      "\n",
      "--------Epoch 52--------\n",
      "Val loss = -0.6572\n",
      "\n",
      "--------Epoch 53--------\n",
      "Val loss = -0.6255\n",
      "\n",
      "--------Epoch 54--------\n",
      "Val loss = -0.6574\n",
      "\n",
      "--------Epoch 55--------\n",
      "Val loss = -0.6607\n",
      "\n",
      "--------Epoch 56--------\n",
      "Val loss = -0.6351\n",
      "\n",
      "--------Epoch 57--------\n",
      "Val loss = -0.6495\n",
      "\n",
      "--------Epoch 58--------\n",
      "Val loss = -0.6422\n",
      "\n",
      "--------Epoch 59--------\n",
      "Val loss = -0.6237\n",
      "\n",
      "--------Epoch 60--------\n",
      "Val loss = -0.6278\n",
      "\n",
      "--------Epoch 61--------\n",
      "Val loss = -0.6019\n",
      "\n",
      "--------Epoch 62--------\n",
      "Val loss = -0.6065\n",
      "\n",
      "--------Epoch 63--------\n",
      "Val loss = -0.6244\n",
      "\n",
      "--------Epoch 64--------\n",
      "Val loss = -0.6139\n",
      "\n",
      "--------Epoch 65--------\n",
      "Val loss = -0.6457\n",
      "\n",
      "--------Epoch 66--------\n",
      "Val loss = -0.6177\n",
      "\n",
      "--------Epoch 67--------\n",
      "Val loss = -0.5896\n",
      "\n",
      "--------Epoch 68--------\n",
      "Val loss = -0.6033\n",
      "\n",
      "--------Epoch 69--------\n",
      "Val loss = -0.5913\n",
      "\n",
      "--------Epoch 70--------\n",
      "Val loss = -0.6172\n",
      "\n",
      "--------Epoch 71--------\n",
      "Val loss = -0.6060\n",
      "\n",
      "--------Epoch 72--------\n",
      "Val loss = -0.6175\n",
      "\n",
      "--------Epoch 73--------\n",
      "Val loss = -0.6136\n",
      "\n",
      "--------Epoch 74--------\n",
      "Val loss = -0.6112\n",
      "\n",
      "--------Epoch 75--------\n",
      "Val loss = -0.6104\n",
      "\n",
      "--------Epoch 76--------\n",
      "Val loss = -0.6260\n",
      "\n",
      "--------Epoch 77--------\n",
      "Val loss = -0.5998\n",
      "\n",
      "--------Epoch 78--------\n",
      "Val loss = -0.6062\n",
      "\n",
      "--------Epoch 79--------\n",
      "Val loss = -0.5975\n",
      "\n",
      "--------Epoch 80--------\n",
      "Val loss = -0.6013\n",
      "\n",
      "--------Epoch 81--------\n",
      "Val loss = -0.5601\n",
      "\n",
      "--------Epoch 82--------\n",
      "Val loss = -0.6141\n",
      "\n",
      "--------Epoch 83--------\n",
      "Val loss = -0.6166\n",
      "\n",
      "--------Epoch 84--------\n",
      "Val loss = -0.6010\n",
      "\n",
      "--------Epoch 85--------\n",
      "Val loss = -0.5848\n",
      "\n",
      "--------Epoch 86--------\n",
      "Val loss = -0.5702\n",
      "\n",
      "--------Epoch 87--------\n",
      "Val loss = -0.5429\n",
      "\n",
      "--------Epoch 88--------\n",
      "Val loss = -0.5610\n",
      "\n",
      "--------Epoch 89--------\n",
      "Val loss = -0.5691\n",
      "\n",
      "--------Epoch 90--------\n",
      "Val loss = -0.5353\n",
      "\n",
      "--------Epoch 91--------\n",
      "Val loss = -0.5429\n",
      "\n",
      "--------Epoch 92--------\n",
      "Val loss = -0.5573\n",
      "\n",
      "--------Epoch 93--------\n",
      "Val loss = -0.5419\n",
      "\n",
      "--------Epoch 94--------\n",
      "Val loss = -0.5491\n",
      "\n",
      "--------Epoch 95--------\n",
      "Val loss = -0.5684\n",
      "\n",
      "--------Epoch 96--------\n",
      "Val loss = -0.5523\n",
      "\n",
      "--------Epoch 97--------\n",
      "Val loss = -0.5239\n",
      "\n",
      "--------Epoch 98--------\n",
      "Val loss = -0.5400\n",
      "\n",
      "--------Epoch 99--------\n",
      "Val loss = -0.5420\n",
      "\n",
      "--------Epoch 100--------\n",
      "Val loss = -0.5346\n",
      "\n",
      "--------Epoch 101--------\n",
      "Val loss = -0.5722\n",
      "\n",
      "--------Epoch 102--------\n",
      "Val loss = -0.5394\n",
      "\n",
      "--------Epoch 103--------\n",
      "Val loss = -0.5570\n",
      "\n",
      "--------Epoch 104--------\n",
      "Val loss = -0.5550\n",
      "\n",
      "--------Epoch 105--------\n",
      "Val loss = -0.5727\n",
      "\n",
      "--------Epoch 106--------\n",
      "Val loss = -0.5679\n",
      "\n",
      "--------Epoch 107--------\n",
      "Val loss = -0.5430\n",
      "\n",
      "--------Epoch 108--------\n",
      "Val loss = -0.5540\n",
      "\n",
      "--------Epoch 109--------\n",
      "Val loss = -0.5566\n",
      "\n",
      "--------Epoch 110--------\n",
      "Val loss = -0.5746\n",
      "\n",
      "--------Epoch 111--------\n",
      "Val loss = -0.5557\n",
      "\n",
      "--------Epoch 112--------\n",
      "Val loss = -0.5485\n",
      "\n",
      "--------Epoch 113--------\n",
      "Val loss = -0.5293\n",
      "\n",
      "--------Epoch 114--------\n",
      "Val loss = -0.5269\n",
      "\n",
      "--------Epoch 115--------\n",
      "Val loss = -0.5336\n",
      "\n",
      "--------Epoch 116--------\n",
      "Val loss = -0.5348\n",
      "\n",
      "--------Epoch 117--------\n",
      "Val loss = -0.5440\n",
      "\n",
      "--------Epoch 118--------\n",
      "Val loss = -0.5094\n",
      "\n",
      "--------Epoch 119--------\n",
      "Val loss = -0.5447\n",
      "\n",
      "--------Epoch 120--------\n",
      "Val loss = -0.5412\n",
      "\n",
      "--------Epoch 121--------\n",
      "Val loss = -0.5472\n",
      "\n",
      "--------Epoch 122--------\n",
      "Val loss = -0.5310\n",
      "\n",
      "--------Epoch 123--------\n",
      "Val loss = -0.5326\n",
      "\n",
      "--------Epoch 124--------\n",
      "Val loss = -0.5388\n",
      "\n",
      "--------Epoch 125--------\n",
      "Val loss = -0.5321\n",
      "\n",
      "--------Epoch 126--------\n",
      "Val loss = -0.5517\n",
      "\n",
      "--------Epoch 127--------\n",
      "Val loss = -0.5337\n",
      "\n",
      "--------Epoch 128--------\n",
      "Val loss = -0.5100\n",
      "\n",
      "--------Epoch 129--------\n",
      "Val loss = -0.5347\n",
      "\n",
      "--------Epoch 130--------\n",
      "Val loss = -0.5290\n",
      "\n",
      "--------Epoch 131--------\n",
      "Val loss = -0.5152\n",
      "\n",
      "--------Epoch 132--------\n",
      "Val loss = -0.5340\n",
      "\n",
      "--------Epoch 133--------\n",
      "Val loss = -0.5290\n",
      "\n",
      "--------Epoch 134--------\n",
      "Val loss = -0.5196\n",
      "\n",
      "--------Epoch 135--------\n",
      "Val loss = -0.5372\n",
      "\n",
      "--------Epoch 136--------\n",
      "Val loss = -0.5007\n",
      "\n",
      "--------Epoch 137--------\n",
      "Val loss = -0.4918\n",
      "\n",
      "--------Epoch 138--------\n",
      "Val loss = -0.4961\n",
      "\n",
      "--------Epoch 139--------\n",
      "Val loss = -0.5069\n",
      "\n",
      "--------Epoch 140--------\n",
      "Val loss = -0.5263\n",
      "\n",
      "--------Epoch 141--------\n",
      "Val loss = -0.5019\n",
      "\n",
      "--------Epoch 142--------\n",
      "Val loss = -0.5326\n",
      "\n",
      "--------Epoch 143--------\n",
      "Val loss = -0.5042\n",
      "\n",
      "--------Epoch 144--------\n",
      "Val loss = -0.5084\n",
      "\n",
      "--------Epoch 145--------\n",
      "Val loss = -0.5162\n",
      "\n",
      "--------Epoch 146--------\n",
      "Val loss = -0.5198\n",
      "\n",
      "--------Epoch 147--------\n",
      "Val loss = -0.5102\n",
      "\n",
      "--------Epoch 148--------\n",
      "Val loss = -0.5138\n",
      "\n",
      "--------Epoch 149--------\n",
      "Val loss = -0.5419\n",
      "\n",
      "--------Epoch 150--------\n",
      "Val loss = -0.5428\n",
      "\n",
      "--------Epoch 151--------\n",
      "Val loss = -0.5321\n",
      "\n",
      "--------Epoch 152--------\n",
      "Val loss = -0.5336\n",
      "\n",
      "--------Epoch 153--------\n",
      "Val loss = -0.5318\n",
      "\n",
      "--------Epoch 154--------\n",
      "Val loss = -0.5410\n",
      "\n",
      "--------Epoch 155--------\n",
      "Val loss = -0.4994\n",
      "\n",
      "--------Epoch 156--------\n",
      "Val loss = -0.5224\n",
      "\n",
      "--------Epoch 157--------\n",
      "Val loss = -0.5041\n",
      "\n",
      "--------Epoch 158--------\n",
      "Val loss = -0.5070\n",
      "\n",
      "--------Epoch 159--------\n",
      "Val loss = -0.5030\n",
      "\n",
      "--------Epoch 160--------\n",
      "Val loss = -0.5142\n",
      "\n",
      "--------Epoch 161--------\n",
      "Val loss = -0.4897\n",
      "\n",
      "--------Epoch 162--------\n",
      "Val loss = -0.4889\n",
      "\n",
      "--------Epoch 163--------\n",
      "Val loss = -0.4945\n",
      "\n",
      "--------Epoch 164--------\n",
      "Val loss = -0.4750\n",
      "\n",
      "--------Epoch 165--------\n",
      "Val loss = -0.4924\n",
      "\n",
      "--------Epoch 166--------\n",
      "Val loss = -0.5029\n",
      "\n",
      "--------Epoch 167--------\n",
      "Val loss = -0.5034\n",
      "\n",
      "--------Epoch 168--------\n",
      "Val loss = -0.4931\n",
      "\n",
      "--------Epoch 169--------\n",
      "Val loss = -0.4829\n",
      "\n",
      "--------Epoch 170--------\n",
      "Val loss = -0.4820\n",
      "\n",
      "--------Epoch 171--------\n",
      "Val loss = -0.5010\n",
      "\n",
      "--------Epoch 172--------\n",
      "Val loss = -0.5098\n",
      "\n",
      "--------Epoch 173--------\n",
      "Val loss = -0.4935\n",
      "\n",
      "--------Epoch 174--------\n",
      "Val loss = -0.4817\n",
      "\n",
      "--------Epoch 175--------\n",
      "Val loss = -0.4836\n",
      "\n",
      "--------Epoch 176--------\n",
      "Val loss = -0.4671\n",
      "\n",
      "--------Epoch 177--------\n",
      "Val loss = -0.4769\n",
      "\n",
      "--------Epoch 178--------\n",
      "Val loss = -0.4845\n",
      "\n",
      "--------Epoch 179--------\n",
      "Val loss = -0.4817\n",
      "\n",
      "--------Epoch 180--------\n",
      "Val loss = -0.4745\n",
      "\n",
      "--------Epoch 181--------\n",
      "Val loss = -0.4795\n",
      "\n",
      "--------Epoch 182--------\n",
      "Val loss = -0.4796\n",
      "\n",
      "--------Epoch 183--------\n",
      "Val loss = -0.4846\n",
      "\n",
      "--------Epoch 184--------\n",
      "Val loss = -0.4721\n",
      "\n",
      "--------Epoch 185--------\n",
      "Val loss = -0.4723\n",
      "\n",
      "--------Epoch 186--------\n",
      "Val loss = -0.4728\n",
      "\n",
      "--------Epoch 187--------\n",
      "Val loss = -0.4694\n",
      "\n",
      "--------Epoch 188--------\n",
      "Val loss = -0.4670\n",
      "\n",
      "--------Epoch 189--------\n",
      "Val loss = -0.4573\n",
      "\n",
      "--------Epoch 190--------\n",
      "Val loss = -0.4585\n",
      "\n",
      "--------Epoch 191--------\n",
      "Val loss = -0.4696\n",
      "\n",
      "--------Epoch 192--------\n",
      "Val loss = -0.4754\n",
      "\n",
      "--------Epoch 193--------\n",
      "Val loss = -0.4856\n",
      "\n",
      "--------Epoch 194--------\n",
      "Val loss = -0.4644\n",
      "\n",
      "--------Epoch 195--------\n",
      "Val loss = -0.4849\n",
      "\n",
      "--------Epoch 196--------\n",
      "Val loss = -0.4759\n",
      "\n",
      "--------Epoch 197--------\n",
      "Val loss = -0.4581\n",
      "\n",
      "--------Epoch 198--------\n",
      "Val loss = -0.4657\n",
      "\n",
      "--------Epoch 199--------\n",
      "Val loss = -0.4728\n",
      "\n",
      "--------Epoch 200--------\n",
      "Val loss = -0.4807\n",
      "\n",
      "--------Epoch 201--------\n",
      "Val loss = -0.4642\n",
      "\n",
      "--------Epoch 202--------\n",
      "Val loss = -0.4637\n",
      "\n",
      "--------Epoch 203--------\n",
      "Val loss = -0.4715\n",
      "\n",
      "--------Epoch 204--------\n",
      "Val loss = -0.4914\n",
      "\n",
      "--------Epoch 205--------\n",
      "Val loss = -0.4835\n",
      "\n",
      "--------Epoch 206--------\n",
      "Val loss = -0.4655\n",
      "\n",
      "--------Epoch 207--------\n",
      "Val loss = -0.4801\n",
      "\n",
      "--------Epoch 208--------\n",
      "Val loss = -0.4611\n",
      "\n",
      "--------Epoch 209--------\n",
      "Val loss = -0.4758\n",
      "\n",
      "--------Epoch 210--------\n",
      "Val loss = -0.4647\n",
      "\n",
      "--------Epoch 211--------\n",
      "Val loss = -0.4734\n",
      "\n",
      "--------Epoch 212--------\n",
      "Val loss = -0.4808\n",
      "\n",
      "--------Epoch 213--------\n",
      "Val loss = -0.4806\n",
      "\n",
      "--------Epoch 214--------\n",
      "Val loss = -0.4739\n",
      "\n",
      "--------Epoch 215--------\n",
      "Val loss = -0.4665\n",
      "\n",
      "--------Epoch 216--------\n",
      "Val loss = -0.4809\n",
      "\n",
      "--------Epoch 217--------\n",
      "Val loss = -0.4667\n",
      "\n",
      "--------Epoch 218--------\n",
      "Val loss = -0.4870\n",
      "\n",
      "--------Epoch 219--------\n",
      "Val loss = -0.4800\n",
      "\n",
      "--------Epoch 220--------\n",
      "Val loss = -0.4744\n",
      "\n",
      "--------Epoch 221--------\n",
      "Val loss = -0.4775\n",
      "\n",
      "--------Epoch 222--------\n",
      "Val loss = -0.4745\n",
      "\n",
      "--------Epoch 223--------\n",
      "Val loss = -0.4895\n",
      "\n",
      "--------Epoch 224--------\n",
      "Val loss = -0.4736\n",
      "\n",
      "--------Epoch 225--------\n",
      "Val loss = -0.4686\n",
      "\n",
      "--------Epoch 226--------\n",
      "Val loss = -0.4858\n",
      "\n",
      "--------Epoch 227--------\n",
      "Val loss = -0.4945\n",
      "\n",
      "--------Epoch 228--------\n",
      "Val loss = -0.4888\n",
      "\n",
      "--------Epoch 229--------\n",
      "Val loss = -0.4916\n",
      "\n",
      "--------Epoch 230--------\n",
      "Val loss = -0.4858\n",
      "\n",
      "--------Epoch 231--------\n",
      "Val loss = -0.4805\n",
      "\n",
      "--------Epoch 232--------\n",
      "Val loss = -0.4902\n",
      "\n",
      "--------Epoch 233--------\n",
      "Val loss = -0.4985\n",
      "\n",
      "--------Epoch 234--------\n",
      "Val loss = -0.5020\n",
      "\n",
      "--------Epoch 235--------\n",
      "Val loss = -0.4966\n",
      "\n",
      "--------Epoch 236--------\n",
      "Val loss = -0.4866\n",
      "\n",
      "--------Epoch 237--------\n",
      "Val loss = -0.4905\n",
      "\n",
      "--------Epoch 238--------\n",
      "Val loss = -0.5098\n",
      "\n",
      "--------Epoch 239--------\n",
      "Val loss = -0.5259\n",
      "\n",
      "--------Epoch 240--------\n",
      "Val loss = -0.5018\n",
      "\n",
      "--------Epoch 241--------\n",
      "Val loss = -0.5061\n",
      "\n",
      "--------Epoch 242--------\n",
      "Val loss = -0.5240\n",
      "\n",
      "--------Epoch 243--------\n",
      "Val loss = -0.4927\n",
      "\n",
      "--------Epoch 244--------\n",
      "Val loss = -0.5217\n",
      "\n",
      "--------Epoch 245--------\n",
      "Val loss = -0.5316\n",
      "\n",
      "--------Epoch 246--------\n",
      "Val loss = -0.5361\n",
      "\n",
      "--------Epoch 247--------\n",
      "Val loss = -0.5430\n",
      "\n",
      "--------Epoch 248--------\n",
      "Val loss = -0.5466\n",
      "\n",
      "--------Epoch 249--------\n",
      "Val loss = -0.5445\n",
      "\n",
      "--------Epoch 250--------\n",
      "Val loss = -0.5421\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Selector network: the first layer accepts a vector of width 2 * d_in and the\n",
    "# last layer emits d_in values, with two 512-unit ReLU hidden layers in between.\n",
    "selector = nn.Sequential(\n",
    "    nn.Linear(d_in * 2, 512), nn.ReLU(),\n",
    "    nn.Linear(512, 512), nn.ReLU(),\n",
    "    nn.Linear(512, d_in),\n",
    ")\n",
    "\n",
    "# Selection layer under test in this notebook (ConcreteGaussian).\n",
    "selector_layer = ConcreteGaussian()\n",
    "\n",
    "# The predictor is handed over as a deep copy, so `model` itself is not the\n",
    "# object being trained here.\n",
    "gafs = GreedyAdaptiveFS(\n",
    "    selector, deepcopy(model), mask_layer, selector_layer\n",
    ").to(device)\n",
    "\n",
    "# Train: cross-entropy as the training loss, NegAccuracy as the validation loss\n",
    "# (hence the negative 'Val loss' values in the log below).\n",
    "# NOTE(review): start_temp/end_temp are presumably the endpoints of a temperature\n",
    "# annealing schedule inside fit(); confirm against GreedyAdaptiveFS.fit.\n",
    "gafs.fit(\n",
    "    train_dataset,\n",
    "    val_dataset,\n",
    "    mbsize=128,\n",
    "    lr=2e-4,\n",
    "    nepochs=250,\n",
    "    max_features=max_features,\n",
    "    loss_fn=nn.CrossEntropyLoss(),\n",
    "    val_loss_fn=NegAccuracy(),\n",
    "    start_temp=10.0,\n",
    "    end_temp=0.01,\n",
    "    argmax=False,\n",
    "    no_repeats=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "accredited-hanging",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test acc = 72.74\n"
     ]
    }
   ],
   "source": [
    "# Evaluate the trained model on test_dataset using the Accuracy metric.\n",
    "# NOTE(review): the trailing 1024 is presumably the evaluation batch size;\n",
    "# confirm against GreedyAdaptiveFS.evaluate.\n",
    "test_acc = gafs.evaluate(\n",
    "    test_dataset,\n",
    "    max_features,\n",
    "    Accuracy(),\n",
    "    1024)\n",
    "print(f'Test acc = {100*test_acc:.2f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "daily-equivalent",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
