{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "committed-malawi",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "import torch.nn as nn\n",
    "import matplotlib.pyplot as plt\n",
    "from torchvision import transforms\n",
    "from torchvision.datasets import MNIST\n",
    "from copy import deepcopy\n",
    "from utils import *\n",
    "from models import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "pressing-publicity",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use GPU 3 when it exists (as in the original run); otherwise fall back to the\n",
    "# first GPU, or to the CPU, so Restart & Run All works on any machine.\n",
    "if torch.cuda.is_available() and torch.cuda.device_count() > 3:\n",
    "    device = torch.device('cuda', 3)\n",
    "elif torch.cuda.is_available():\n",
    "    device = torch.device('cuda', 0)\n",
    "else:\n",
    "    device = torch.device('cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "valued-variation",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load data\n",
    "class Flatten:\n",
    "    \"\"\"Callable transform that collapses a tensor into one dimension.\"\"\"\n",
    "\n",
    "    def __call__(self, pic):\n",
    "        \"\"\"Return `pic` flattened across all of its dimensions.\"\"\"\n",
    "        flattened = torch.flatten(pic, start_dim=0, end_dim=-1)\n",
    "        return flattened\n",
    "    \n",
    "# MNIST training split; every sample goes through ToTensor() and then Flatten().\n",
    "mnist_dataset = MNIST(\n",
    "    '/tmp/mnist/',\n",
    "    train=True,\n",
    "    download=True,\n",
    "    transform=transforms.Compose([transforms.ToTensor(), Flatten()]),\n",
    ")\n",
    "images = mnist_dataset.data\n",
    "targets = mnist_dataset.targets\n",
    "\n",
    "# Hold out 10,000 randomly chosen examples for validation. NumPy's global RNG is\n",
    "# seeded first so the same split is drawn on every run.\n",
    "np.random.seed(0)\n",
    "val_inds = np.random.choice(len(images), size=10000, replace=False)\n",
    "val_inds.sort()\n",
    "train_inds = np.setdiff1d(np.arange(len(images)), val_inds)\n",
    "\n",
    "# Training / validation subsets are index views over the same dataset object\n",
    "train_dataset = torch.utils.data.Subset(mnist_dataset, train_inds)\n",
    "val_dataset = torch.utils.data.Subset(mnist_dataset, val_inds)\n",
    "\n",
    "# Test dataset (train=False split), built with the same transform pipeline\n",
    "test_dataset = MNIST(\n",
    "    '/tmp/mnist/',\n",
    "    train=False,\n",
    "    download=True,\n",
    "    transform=transforms.Compose([transforms.ToTensor(), Flatten()]),\n",
    ")\n",
    "\n",
    "# Input/output dimensions used when building the models below\n",
    "d_in, d_out = 784, 10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "martial-necessity",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Feature budget: number of input features (out of d_in) to select\n",
    "max_features = 10"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "simplified-integration",
   "metadata": {},
   "source": [
    "# Global feature selection (FS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "surrounded-stroke",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--------Epoch 1--------\n",
      "Val loss = -0.2925\n",
      "\n",
      "--------Epoch 2--------\n",
      "Val loss = -0.3034\n",
      "\n",
      "--------Epoch 3--------\n",
      "Val loss = -0.3138\n",
      "\n",
      "--------Epoch 4--------\n",
      "Val loss = -0.3276\n",
      "\n",
      "--------Epoch 5--------\n",
      "Val loss = -0.3303\n",
      "\n",
      "--------Epoch 6--------\n",
      "Val loss = -0.3334\n",
      "\n",
      "--------Epoch 7--------\n",
      "Val loss = -0.3281\n",
      "\n",
      "--------Epoch 8--------\n",
      "Val loss = -0.3442\n",
      "\n",
      "--------Epoch 9--------\n",
      "Val loss = -0.3570\n",
      "\n",
      "--------Epoch 10--------\n",
      "Val loss = -0.3551\n",
      "\n",
      "--------Epoch 11--------\n",
      "Val loss = -0.3610\n",
      "\n",
      "--------Epoch 12--------\n",
      "Val loss = -0.3640\n",
      "\n",
      "--------Epoch 13--------\n",
      "Val loss = -0.3657\n",
      "\n",
      "--------Epoch 14--------\n",
      "Val loss = -0.3788\n",
      "\n",
      "--------Epoch 15--------\n",
      "Val loss = -0.3717\n",
      "\n",
      "--------Epoch 16--------\n",
      "Val loss = -0.3803\n",
      "\n",
      "--------Epoch 17--------\n",
      "Val loss = -0.3904\n",
      "\n",
      "--------Epoch 18--------\n",
      "Val loss = -0.3777\n",
      "\n",
      "--------Epoch 19--------\n",
      "Val loss = -0.3961\n",
      "\n",
      "--------Epoch 20--------\n",
      "Val loss = -0.3948\n",
      "\n",
      "--------Epoch 21--------\n",
      "Val loss = -0.3969\n",
      "\n",
      "--------Epoch 22--------\n",
      "Val loss = -0.4066\n",
      "\n",
      "--------Epoch 23--------\n",
      "Val loss = -0.4175\n",
      "\n",
      "--------Epoch 24--------\n",
      "Val loss = -0.4144\n",
      "\n",
      "--------Epoch 25--------\n",
      "Val loss = -0.4181\n",
      "\n",
      "--------Epoch 26--------\n",
      "Val loss = -0.4221\n",
      "\n",
      "--------Epoch 27--------\n",
      "Val loss = -0.4208\n",
      "\n",
      "--------Epoch 28--------\n",
      "Val loss = -0.4152\n",
      "\n",
      "--------Epoch 29--------\n",
      "Val loss = -0.4332\n",
      "\n",
      "--------Epoch 30--------\n",
      "Val loss = -0.4373\n",
      "\n",
      "--------Epoch 31--------\n",
      "Val loss = -0.4391\n",
      "\n",
      "--------Epoch 32--------\n",
      "Val loss = -0.4405\n",
      "\n",
      "--------Epoch 33--------\n",
      "Val loss = -0.4451\n",
      "\n",
      "--------Epoch 34--------\n",
      "Val loss = -0.4466\n",
      "\n",
      "--------Epoch 35--------\n",
      "Val loss = -0.4532\n",
      "\n",
      "--------Epoch 36--------\n",
      "Val loss = -0.4567\n",
      "\n",
      "--------Epoch 37--------\n",
      "Val loss = -0.4548\n",
      "\n",
      "--------Epoch 38--------\n",
      "Val loss = -0.4632\n",
      "\n",
      "--------Epoch 39--------\n",
      "Val loss = -0.4641\n",
      "\n",
      "--------Epoch 40--------\n",
      "Val loss = -0.4744\n",
      "\n",
      "--------Epoch 41--------\n",
      "Val loss = -0.4902\n",
      "\n",
      "--------Epoch 42--------\n",
      "Val loss = -0.4835\n",
      "\n",
      "--------Epoch 43--------\n",
      "Val loss = -0.4781\n",
      "\n",
      "--------Epoch 44--------\n",
      "Val loss = -0.4947\n",
      "\n",
      "--------Epoch 45--------\n",
      "Val loss = -0.4995\n",
      "\n",
      "--------Epoch 46--------\n",
      "Val loss = -0.4895\n",
      "\n",
      "--------Epoch 47--------\n",
      "Val loss = -0.4914\n",
      "\n",
      "--------Epoch 48--------\n",
      "Val loss = -0.5053\n",
      "\n",
      "--------Epoch 49--------\n",
      "Val loss = -0.5061\n",
      "\n",
      "--------Epoch 50--------\n",
      "Val loss = -0.4981\n",
      "\n",
      "--------Epoch 51--------\n",
      "Val loss = -0.5105\n",
      "\n",
      "--------Epoch 52--------\n",
      "Val loss = -0.5089\n",
      "\n",
      "--------Epoch 53--------\n",
      "Val loss = -0.5167\n",
      "\n",
      "--------Epoch 54--------\n",
      "Val loss = -0.5273\n",
      "\n",
      "--------Epoch 55--------\n",
      "Val loss = -0.5285\n",
      "\n",
      "--------Epoch 56--------\n",
      "Val loss = -0.5265\n",
      "\n",
      "--------Epoch 57--------\n",
      "Val loss = -0.5284\n",
      "\n",
      "--------Epoch 58--------\n",
      "Val loss = -0.5348\n",
      "\n",
      "--------Epoch 59--------\n",
      "Val loss = -0.5384\n",
      "\n",
      "--------Epoch 60--------\n",
      "Val loss = -0.5379\n",
      "\n",
      "--------Epoch 61--------\n",
      "Val loss = -0.5298\n",
      "\n",
      "--------Epoch 62--------\n",
      "Val loss = -0.5471\n",
      "\n",
      "--------Epoch 63--------\n",
      "Val loss = -0.5498\n",
      "\n",
      "--------Epoch 64--------\n",
      "Val loss = -0.5534\n",
      "\n",
      "--------Epoch 65--------\n",
      "Val loss = -0.5640\n",
      "\n",
      "--------Epoch 66--------\n",
      "Val loss = -0.5568\n",
      "\n",
      "--------Epoch 67--------\n",
      "Val loss = -0.5509\n",
      "\n",
      "--------Epoch 68--------\n",
      "Val loss = -0.5666\n",
      "\n",
      "--------Epoch 69--------\n",
      "Val loss = -0.5596\n",
      "\n",
      "--------Epoch 70--------\n",
      "Val loss = -0.5729\n",
      "\n",
      "--------Epoch 71--------\n",
      "Val loss = -0.5689\n",
      "\n",
      "--------Epoch 72--------\n",
      "Val loss = -0.5733\n",
      "\n",
      "--------Epoch 73--------\n",
      "Val loss = -0.5767\n",
      "\n",
      "--------Epoch 74--------\n",
      "Val loss = -0.5765\n",
      "\n",
      "--------Epoch 75--------\n",
      "Val loss = -0.5747\n",
      "\n",
      "--------Epoch 76--------\n",
      "Val loss = -0.5788\n",
      "\n",
      "--------Epoch 77--------\n",
      "Val loss = -0.5883\n",
      "\n",
      "--------Epoch 78--------\n",
      "Val loss = -0.5926\n",
      "\n",
      "--------Epoch 79--------\n",
      "Val loss = -0.5954\n",
      "\n",
      "--------Epoch 80--------\n",
      "Val loss = -0.5939\n",
      "\n",
      "--------Epoch 81--------\n",
      "Val loss = -0.5971\n",
      "\n",
      "--------Epoch 82--------\n",
      "Val loss = -0.6038\n",
      "\n",
      "--------Epoch 83--------\n",
      "Val loss = -0.6007\n",
      "\n",
      "--------Epoch 84--------\n",
      "Val loss = -0.6036\n",
      "\n",
      "--------Epoch 85--------\n",
      "Val loss = -0.6107\n",
      "\n",
      "--------Epoch 86--------\n",
      "Val loss = -0.6088\n",
      "\n",
      "--------Epoch 87--------\n",
      "Val loss = -0.6088\n",
      "\n",
      "--------Epoch 88--------\n",
      "Val loss = -0.6157\n",
      "\n",
      "--------Epoch 89--------\n",
      "Val loss = -0.6173\n",
      "\n",
      "--------Epoch 90--------\n",
      "Val loss = -0.6207\n",
      "\n",
      "--------Epoch 91--------\n",
      "Val loss = -0.6186\n",
      "\n",
      "--------Epoch 92--------\n",
      "Val loss = -0.6317\n",
      "\n",
      "--------Epoch 93--------\n",
      "Val loss = -0.6382\n",
      "\n",
      "--------Epoch 94--------\n",
      "Val loss = -0.6378\n",
      "\n",
      "--------Epoch 95--------\n",
      "Val loss = -0.6352\n",
      "\n",
      "--------Epoch 96--------\n",
      "Val loss = -0.6308\n",
      "\n",
      "--------Epoch 97--------\n",
      "Val loss = -0.6394\n",
      "\n",
      "--------Epoch 98--------\n",
      "Val loss = -0.6429\n",
      "\n",
      "--------Epoch 99--------\n",
      "Val loss = -0.6377\n",
      "\n",
      "--------Epoch 100--------\n",
      "Val loss = -0.6441\n",
      "\n",
      "--------Epoch 101--------\n",
      "Val loss = -0.6487\n",
      "\n",
      "--------Epoch 102--------\n",
      "Val loss = -0.6577\n",
      "\n",
      "--------Epoch 103--------\n",
      "Val loss = -0.6572\n",
      "\n",
      "--------Epoch 104--------\n",
      "Val loss = -0.6533\n",
      "\n",
      "--------Epoch 105--------\n",
      "Val loss = -0.6578\n",
      "\n",
      "--------Epoch 106--------\n",
      "Val loss = -0.6666\n",
      "\n",
      "--------Epoch 107--------\n",
      "Val loss = -0.6651\n",
      "\n",
      "--------Epoch 108--------\n",
      "Val loss = -0.6705\n",
      "\n",
      "--------Epoch 109--------\n",
      "Val loss = -0.6701\n",
      "\n",
      "--------Epoch 110--------\n",
      "Val loss = -0.6797\n",
      "\n",
      "--------Epoch 111--------\n",
      "Val loss = -0.6715\n",
      "\n",
      "--------Epoch 112--------\n",
      "Val loss = -0.6785\n",
      "\n",
      "--------Epoch 113--------\n",
      "Val loss = -0.6814\n",
      "\n",
      "--------Epoch 114--------\n",
      "Val loss = -0.6797\n",
      "\n",
      "--------Epoch 115--------\n",
      "Val loss = -0.6860\n",
      "\n",
      "--------Epoch 116--------\n",
      "Val loss = -0.6820\n",
      "\n",
      "--------Epoch 117--------\n",
      "Val loss = -0.6942\n",
      "\n",
      "--------Epoch 118--------\n",
      "Val loss = -0.6946\n",
      "\n",
      "--------Epoch 119--------\n",
      "Val loss = -0.6990\n",
      "\n",
      "--------Epoch 120--------\n",
      "Val loss = -0.6966\n",
      "\n",
      "--------Epoch 121--------\n",
      "Val loss = -0.6957\n",
      "\n",
      "--------Epoch 122--------\n",
      "Val loss = -0.7061\n",
      "\n",
      "--------Epoch 123--------\n",
      "Val loss = -0.7083\n",
      "\n",
      "--------Epoch 124--------\n",
      "Val loss = -0.7062\n",
      "\n",
      "--------Epoch 125--------\n",
      "Val loss = -0.7125\n",
      "\n",
      "--------Epoch 126--------\n",
      "Val loss = -0.7109\n",
      "\n",
      "--------Epoch 127--------\n",
      "Val loss = -0.7108\n",
      "\n",
      "--------Epoch 128--------\n",
      "Val loss = -0.7136\n",
      "\n",
      "--------Epoch 129--------\n",
      "Val loss = -0.7192\n",
      "\n",
      "--------Epoch 130--------\n",
      "Val loss = -0.7185\n",
      "\n",
      "--------Epoch 131--------\n",
      "Val loss = -0.7205\n",
      "\n",
      "--------Epoch 132--------\n",
      "Val loss = -0.7277\n",
      "\n",
      "--------Epoch 133--------\n",
      "Val loss = -0.7214\n",
      "\n",
      "--------Epoch 134--------\n",
      "Val loss = -0.7294\n",
      "\n",
      "--------Epoch 135--------\n",
      "Val loss = -0.7248\n",
      "\n",
      "--------Epoch 136--------\n",
      "Val loss = -0.7307\n",
      "\n",
      "--------Epoch 137--------\n",
      "Val loss = -0.7229\n",
      "\n",
      "--------Epoch 138--------\n",
      "Val loss = -0.7305\n",
      "\n",
      "--------Epoch 139--------\n",
      "Val loss = -0.7368\n",
      "\n",
      "--------Epoch 140--------\n",
      "Val loss = -0.7384\n",
      "\n",
      "--------Epoch 141--------\n",
      "Val loss = -0.7442\n",
      "\n",
      "--------Epoch 142--------\n",
      "Val loss = -0.7416\n",
      "\n",
      "--------Epoch 143--------\n",
      "Val loss = -0.7381\n",
      "\n",
      "--------Epoch 144--------\n",
      "Val loss = -0.7407\n",
      "\n",
      "--------Epoch 145--------\n",
      "Val loss = -0.7398\n",
      "\n",
      "--------Epoch 146--------\n",
      "Val loss = -0.7425\n",
      "\n",
      "--------Epoch 147--------\n",
      "Val loss = -0.7457\n",
      "\n",
      "--------Epoch 148--------\n",
      "Val loss = -0.7497\n",
      "\n",
      "--------Epoch 149--------\n",
      "Val loss = -0.7539\n",
      "\n",
      "--------Epoch 150--------\n",
      "Val loss = -0.7507\n",
      "\n",
      "--------Epoch 151--------\n",
      "Val loss = -0.7553\n",
      "\n",
      "--------Epoch 152--------\n",
      "Val loss = -0.7504\n",
      "\n",
      "--------Epoch 153--------\n",
      "Val loss = -0.7534\n",
      "\n",
      "--------Epoch 154--------\n",
      "Val loss = -0.7545\n",
      "\n",
      "--------Epoch 155--------\n",
      "Val loss = -0.7629\n",
      "\n",
      "--------Epoch 156--------\n",
      "Val loss = -0.7592\n",
      "\n",
      "--------Epoch 157--------\n",
      "Val loss = -0.7650\n",
      "\n",
      "--------Epoch 158--------\n",
      "Val loss = -0.7596\n",
      "\n",
      "--------Epoch 159--------\n",
      "Val loss = -0.7574\n",
      "\n",
      "--------Epoch 160--------\n",
      "Val loss = -0.7588\n",
      "\n",
      "--------Epoch 161--------\n",
      "Val loss = -0.7616\n",
      "\n",
      "--------Epoch 162--------\n",
      "Val loss = -0.7604\n",
      "\n",
      "--------Epoch 163--------\n",
      "Val loss = -0.7679\n",
      "\n",
      "--------Epoch 164--------\n",
      "Val loss = -0.7695\n",
      "\n",
      "--------Epoch 165--------\n",
      "Val loss = -0.7630\n",
      "\n",
      "--------Epoch 166--------\n",
      "Val loss = -0.7661\n",
      "\n",
      "--------Epoch 167--------\n",
      "Val loss = -0.7613\n",
      "\n",
      "--------Epoch 168--------\n",
      "Val loss = -0.7689\n",
      "\n",
      "--------Epoch 169--------\n",
      "Val loss = -0.7673\n",
      "\n",
      "--------Epoch 170--------\n",
      "Val loss = -0.7645\n",
      "\n",
      "--------Epoch 171--------\n",
      "Val loss = -0.7685\n",
      "\n",
      "--------Epoch 172--------\n",
      "Val loss = -0.7634\n",
      "\n",
      "--------Epoch 173--------\n",
      "Val loss = -0.7711\n",
      "\n",
      "--------Epoch 174--------\n",
      "Val loss = -0.7688\n",
      "\n",
      "--------Epoch 175--------\n",
      "Val loss = -0.7658\n",
      "\n",
      "--------Epoch 176--------\n",
      "Val loss = -0.7630\n",
      "\n",
      "--------Epoch 177--------\n",
      "Val loss = -0.7724\n",
      "\n",
      "--------Epoch 178--------\n",
      "Val loss = -0.7715\n",
      "\n",
      "--------Epoch 179--------\n",
      "Val loss = -0.7752\n",
      "\n",
      "--------Epoch 180--------\n",
      "Val loss = -0.7696\n",
      "\n",
      "--------Epoch 181--------\n",
      "Val loss = -0.7739\n",
      "\n",
      "--------Epoch 182--------\n",
      "Val loss = -0.7750\n",
      "\n",
      "--------Epoch 183--------\n",
      "Val loss = -0.7689\n",
      "\n",
      "--------Epoch 184--------\n",
      "Val loss = -0.7765\n",
      "\n",
      "--------Epoch 185--------\n",
      "Val loss = -0.7717\n",
      "\n",
      "--------Epoch 186--------\n",
      "Val loss = -0.7716\n",
      "\n",
      "--------Epoch 187--------\n",
      "Val loss = -0.7733\n",
      "\n",
      "--------Epoch 188--------\n",
      "Val loss = -0.7740\n",
      "\n",
      "--------Epoch 189--------\n",
      "Val loss = -0.7738\n",
      "\n",
      "--------Epoch 190--------\n",
      "Val loss = -0.7743\n",
      "\n",
      "--------Epoch 191--------\n",
      "Val loss = -0.7748\n",
      "\n",
      "--------Epoch 192--------\n",
      "Val loss = -0.7769\n",
      "\n",
      "--------Epoch 193--------\n",
      "Val loss = -0.7770\n",
      "\n",
      "--------Epoch 194--------\n",
      "Val loss = -0.7745\n",
      "\n",
      "--------Epoch 195--------\n",
      "Val loss = -0.7766\n",
      "\n",
      "--------Epoch 196--------\n",
      "Val loss = -0.7762\n",
      "\n",
      "--------Epoch 197--------\n",
      "Val loss = -0.7811\n",
      "\n",
      "--------Epoch 198--------\n",
      "Val loss = -0.7769\n",
      "\n",
      "--------Epoch 199--------\n",
      "Val loss = -0.7769\n",
      "\n",
      "--------Epoch 200--------\n",
      "Val loss = -0.7751\n",
      "\n",
      "--------Epoch 201--------\n",
      "Val loss = -0.7737\n",
      "\n",
      "--------Epoch 202--------\n",
      "Val loss = -0.7777\n",
      "\n",
      "--------Epoch 203--------\n",
      "Val loss = -0.7757\n",
      "\n",
      "--------Epoch 204--------\n",
      "Val loss = -0.7757\n",
      "\n",
      "--------Epoch 205--------\n",
      "Val loss = -0.7740\n",
      "\n",
      "--------Epoch 206--------\n",
      "Val loss = -0.7769\n",
      "\n",
      "--------Epoch 207--------\n",
      "Val loss = -0.7759\n",
      "\n",
      "--------Epoch 208--------\n",
      "Val loss = -0.7702\n",
      "\n",
      "--------Epoch 209--------\n",
      "Val loss = -0.7724\n",
      "\n",
      "--------Epoch 210--------\n",
      "Val loss = -0.7762\n",
      "\n",
      "--------Epoch 211--------\n",
      "Val loss = -0.7744\n",
      "\n",
      "--------Epoch 212--------\n",
      "Val loss = -0.7708\n",
      "\n",
      "--------Epoch 213--------\n",
      "Val loss = -0.7711\n",
      "\n",
      "--------Epoch 214--------\n",
      "Val loss = -0.7741\n",
      "\n",
      "--------Epoch 215--------\n",
      "Val loss = -0.7725\n",
      "\n",
      "--------Epoch 216--------\n",
      "Val loss = -0.7755\n",
      "\n",
      "--------Epoch 217--------\n",
      "Val loss = -0.7734\n",
      "\n",
      "--------Epoch 218--------\n",
      "Val loss = -0.7740\n",
      "\n",
      "--------Epoch 219--------\n",
      "Val loss = -0.7699\n",
      "\n",
      "--------Epoch 220--------\n",
      "Val loss = -0.7725\n",
      "\n",
      "--------Epoch 221--------\n",
      "Val loss = -0.7678\n",
      "\n",
      "--------Epoch 222--------\n",
      "Val loss = -0.7742\n",
      "\n",
      "--------Epoch 223--------\n",
      "Val loss = -0.7703\n",
      "\n",
      "--------Epoch 224--------\n",
      "Val loss = -0.7703\n",
      "\n",
      "--------Epoch 225--------\n",
      "Val loss = -0.7680\n",
      "\n",
      "--------Epoch 226--------\n",
      "Val loss = -0.7669\n",
      "\n",
      "--------Epoch 227--------\n",
      "Val loss = -0.7648\n",
      "\n",
      "--------Epoch 228--------\n",
      "Val loss = -0.7686\n",
      "\n",
      "--------Epoch 229--------\n",
      "Val loss = -0.7667\n",
      "\n",
      "--------Epoch 230--------\n",
      "Val loss = -0.7692\n",
      "\n",
      "--------Epoch 231--------\n",
      "Val loss = -0.7664\n",
      "\n",
      "--------Epoch 232--------\n",
      "Val loss = -0.7651\n",
      "\n",
      "--------Epoch 233--------\n",
      "Val loss = -0.7652\n",
      "\n",
      "--------Epoch 234--------\n",
      "Val loss = -0.7645\n",
      "\n",
      "--------Epoch 235--------\n",
      "Val loss = -0.7607\n",
      "\n",
      "--------Epoch 236--------\n",
      "Val loss = -0.7652\n",
      "\n",
      "--------Epoch 237--------\n",
      "Val loss = -0.7648\n",
      "\n",
      "--------Epoch 238--------\n",
      "Val loss = -0.7644\n",
      "\n",
      "--------Epoch 239--------\n",
      "Val loss = -0.7625\n",
      "\n",
      "--------Epoch 240--------\n",
      "Val loss = -0.7636\n",
      "\n",
      "--------Epoch 241--------\n",
      "Val loss = -0.7635\n",
      "\n",
      "--------Epoch 242--------\n",
      "Val loss = -0.7654\n",
      "\n",
      "--------Epoch 243--------\n",
      "Val loss = -0.7656\n",
      "\n",
      "--------Epoch 244--------\n",
      "Val loss = -0.7643\n",
      "\n",
      "--------Epoch 245--------\n",
      "Val loss = -0.7602\n",
      "\n",
      "--------Epoch 246--------\n",
      "Val loss = -0.7627\n",
      "\n",
      "--------Epoch 247--------\n",
      "Val loss = -0.7580\n",
      "\n",
      "--------Epoch 248--------\n",
      "Val loss = -0.7616\n",
      "\n",
      "--------Epoch 249--------\n",
      "Val loss = -0.7587\n",
      "\n",
      "--------Epoch 250--------\n",
      "Val loss = -0.7617\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Predictor: MLP with two 512-unit hidden layers. The input width is 2 * d_in,\n",
    "# presumably because ConcreteMask(append=True) appends the mask to the masked\n",
    "# input -- TODO confirm in models.py.\n",
    "global_model = nn.Sequential(*[\n",
    "    nn.Linear(2 * d_in, 512), nn.ReLU(),\n",
    "    nn.Linear(512, 512), nn.ReLU(),\n",
    "    nn.Linear(512, d_out),\n",
    "])\n",
    "\n",
    "# Selector over the d_in inputs with a budget of max_features, wrapped together\n",
    "# with the predictor and moved to the chosen device.\n",
    "selector = ConcreteMask(d_in, max_features, append=True)\n",
    "globalfs = GlobalSelector(global_model, selector)\n",
    "globalfs = globalfs.to(device)\n",
    "\n",
    "# Train with cross-entropy. Validation uses NegAccuracy(), which is why the logged\n",
    "# 'Val loss' values are negative. start_temp / end_temp are presumably the\n",
    "# temperature schedule of the selector -- TODO confirm in models.py.\n",
    "globalfs.fit(\n",
    "    train_dataset, val_dataset,\n",
    "    mbsize=128, lr=1e-3, nepochs=250,\n",
    "    loss_fn=nn.CrossEntropyLoss(),\n",
    "    val_loss_fn=NegAccuracy(),\n",
    "    start_temp=1.0, end_temp=0.01,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "crazy-ensemble",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test acc = 78.47\n"
     ]
    }
   ],
   "source": [
    "# Score the trained global selector on the MNIST test dataset.\n",
    "test_acc = globalfs.evaluate(\n",
    "    test_dataset,\n",
    "    Accuracy(),\n",
    "    1024,  # presumably the evaluation minibatch size -- TODO confirm in models.py\n",
    ")\n",
    "# The value is scaled by 100 here, so it is reported as a percentage.\n",
    "print(f'Test acc = {100*test_acc:.2f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "annoying-decimal",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAKz0lEQVR4nO3dT4ic933H8fenrqyAkoKc1EZ1TJMGU2oKVcriFlxKinHq+CLn0BIdggoG5RBDAjnUpIf6aEqT0EMJKLWIWlKHQmLsg2kiRMAEivHaqLYctZVj1EaRkBp8iFOoLDvfHvZxWcu7q/HMM3+i7/sFy8w8M6vny6C3ntl5ZvVLVSHp+vdLyx5A0mIYu9SEsUtNGLvUhLFLTfzyInd2Y3bXe9izyF1Krfwv/8PrdTlb3TdT7EnuBf4GuAH4u6p6ZKfHv4c9/F7unmWXknbwTJ3Y9r6pX8YnuQH4W+ATwB3AwSR3TPvnSZqvWX5mvxN4uapeqarXgW8CB8YZS9LYZon9VuBHm26fG7a9TZLDSdaTrF/h8gy7kzSLWWLf6k2Ad3z2tqqOVNVaVa3tYvcMu5M0i1liPwfctun2B4Hzs40jaV5mif1Z4PYkH05yI/Ap4MlxxpI0tqlPvVXVG0keBL7Dxqm3o1X10miTSRrVTOfZq+op4KmRZpE0R35cVmrC2KUmjF1qwtilJoxdasLYpSaMXWrC2KUmjF1qwtilJoxdasLYpSaMXWrC2KUmjF1qwtilJoxdasLYpSaMXWrC2KUmjF1qwtilJoxdasLYpSaMXWrC2KUmjF1qwtilJoxdamKmVVy1Gr5z/uS29/3xr+1f2BxbWeXZupkp9iRngdeAN4E3qmptjKEkjW+MI/sfVdVPRvhzJM2RP7NLTcwaewHfTfJcksNbPSDJ4STrSdavcHnG3Uma1qwv4++qqvNJbgaOJ/m3qnp68wOq6ghwBOBXclPNuD9JU5rpyF5V54fLS8DjwJ1jDCVpfFPHnmRPkve9dR34OHBqrMEkjWuWl/G3AI8neevP+ceq+udRptK7ssrnq1d5tm6mjr2qXgF+Z8RZJM2Rp96kJoxdasLYpSaMXWrC2KUmjF1qwtilJoxdasLYpSaMXWrC2KUmjF1qwtilJvyvpDWTnf6raPBXXFeJR3apCWOXmjB2qQljl5owdqkJY5eaMHapCc+zayaznEef9zl6l4t+O4/sUhPGLjVh7FITxi41YexSE8YuNWHsUhOeZ9fSzPtcd8dz6Tu55pE9ydEkl5Kc2rTtpiTHk5wZLvfOd0xJs5rkZfzXgXuv2vYQcKKqbgdODLclrbBrxl5VTwOvXrX5AHBsuH4MuH/csSSNbdo36G6pqgsAw+XN2z0wyeEk60nWr3B5yt1JmtXc342vqiNVtVZVa7vYPe/dSdrGtLFfTLIPYLi8NN5IkuZh2tifBA4N1w8BT4wzjqR5meTU22PAvwC/meRckgeAR4B7kpwB7hluS1ph1/xQTVUd3Oauu0eeRdIc+XFZqQljl5owdqkJY5eaMHapCX/F9ReAyyJrDB7ZpSaMXWrC2KUmjF1qwtilJoxdasLYpSY8z/4LYJnn0T3Hf/3wyC41YexSE8YuNWHsUhPGLjVh7FITxi414Xl27cjz6NcPj+xSE8YuNWHsUhPGLjVh7FITxi41YexSE55nXwH+zrgWYZL12Y8muZTk1KZtDyf5cZKTw9d98x1T0qwmeRn/deDeLbZ/par2D19PjTuWpLFdM/aqehp4dQGzSJqjWd6gezDJC8PL/L3bPSjJ4STrSdavcHmG3UmaxbSxfxX4CLAfuAB8absHVtWRqlqrqrVd7J5yd5JmNVXsVXWxqt6sqp8DXwPuHHcsSWObKvYk+zbd/CRwarvHSloN1zzPnuQx4GPAB5KcA/4S+FiS/UABZ4HPzG/E698qn0f3MwDXj2vGXlUHt9j86BxmkTRHflxWasLYpSaMXWrC2KUm
jF1qwl9xvQ7sdHps1lNjnlq7fnhkl5owdqkJY5eaMHapCWOXmjB2qQljl5rwPPt1wHPhmoRHdqkJY5eaMHapCWOXmjB2qQljl5owdqkJY5eaMHapCWOXmjB2qQljl5owdqkJY5eaMHapCWOXmrhm7EluS/K9JKeTvJTkc8P2m5IcT3JmuNw7/3ElTWuSI/sbwBeq6reA3wc+m+QO4CHgRFXdDpwYbktaUdeMvaouVNXzw/XXgNPArcAB4NjwsGPA/XOaUdII3tXP7Ek+BHwUeAa4paouwMY/CMDN23zP4STrSdavcHnGcSVNa+LYk7wX+Bbw+ar66aTfV1VHqmqtqtZ2sXuaGSWNYKLYk+xiI/RvVNW3h80Xk+wb7t8HXJrPiJLGMMm78QEeBU5X1Zc33fUkcGi4fgh4YvzxJI1lkv83/i7g08CLSU4O274IPAL8U5IHgP8C/mQuE0oaxTVjr6rvA9nm7rvHHUfSvPgJOqkJY5eaMHapCWOXmjB2qQljl5owdqkJY5eaMHapCWOXmjB2qQljl5owdqkJY5eaMHapCWOXmjB2qQljl5owdqkJY5eaMHapCWOXmjB2qQljl5owdqkJY5eaMHapCWOXmjB2qYlJ1me/Lcn3kpxO8lKSzw3bH07y4yQnh6/75j+upGlNsj77G8AXqur5JO8DnktyfLjvK1X11/MbT9JYJlmf/QJwYbj+WpLTwK3zHkzSuN7Vz+xJPgR8FHhm2PRgkheSHE2yd5vvOZxkPcn6FS7PNq2kqU0ce5L3At8CPl9VPwW+CnwE2M/Gkf9LW31fVR2pqrWqWtvF7tknljSViWJPsouN0L9RVd8GqKqLVfVmVf0c+Bpw5/zGlDSrSd6ND/AocLqqvrxp+75ND/skcGr88SSNZZJ34+8CPg28mOTksO2LwMEk+4ECzgKfmcN8kkYyybvx3weyxV1PjT+OpHnxE3RSE8YuNWHsUhPGLjVh7FITxi41YexSE8YuNWHsUhPGLjVh7FITxi41YexSE8YuNZGqWtzOkv8G/nPTpg8AP1nYAO/Oqs62qnOBs01rzNl+vap+das7Fhr7O3aerFfV2tIG2MGqzraqc4GzTWtRs/kyXmrC2KUmlh37kSXvfyerOtuqzgXONq2FzLbUn9klLc6yj+ySFsTYpSaWEnuSe5P8e5KXkzy0jBm2k+RskheHZajXlzzL0SSXkpzatO2mJMeTnBkut1xjb0mzrcQy3jssM77U527Zy58v/Gf2JDcA/wHcA5wDngUOVtUPFjrINpKcBdaqaukfwEjyh8DPgL+vqt8etv0V8GpVPTL8Q7m3qv58RWZ7GPjZspfxHlYr2rd5mXHgfuDPWOJzt8Ncf8oCnrdlHNnvBF6uqleq6nXgm8CBJcyx8qrqaeDVqzYfAI4N14+x8Zdl4baZbSVU1YWqen64/hrw1jLjS33udphrIZYR+63AjzbdPsdqrfdewHeTPJfk8LKH2cItVXUBNv7yADcveZ6rXXMZ70W6apnxlXnupln+fFbLiH2rpaRW6fzfXVX1u8AngM8OL1c1mYmW8V6ULZYZXwnTLn8+q2XEfg64bdPtDwLnlzDHlqrq/HB5CXic1VuK+uJbK+gOl5eWPM//W6VlvLdaZpwVeO6Wufz5MmJ/Frg9yYeT3Ah8CnhyCXO8Q5I9wxsnJNkDfJzVW4r6SeDQcP0Q8MQSZ3mbVVnGe7tlxlnyc7f05c+rauFfwH1svCP/Q+AvljHDNnP9BvCvw9dLy54NeIyNl3VX2HhF9ADwfuAEcGa4vGmFZvsH4EXgBTbC2rek2f6AjR8NXwBODl/3Lfu522GuhTxvflxWasJP0ElNGLvUhLFLTRi71ISxS00Yu9SEsUtN/B9D72AikE1negAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# One index per row of selector.logits: the position of that row's largest logit.\n",
    "# detach() replaces the legacy `.data` attribute access.\n",
    "inds = selector.logits.argmax(dim=1).detach().cpu().numpy()\n",
    "\n",
    "# Binary vector over the d_in inputs: 1 at every selected index, 0 elsewhere.\n",
    "# Rows that pick the same index collapse into a single marked pixel.\n",
    "zeros = np.zeros(d_in)\n",
    "zeros[inds] = 1\n",
    "\n",
    "# Show the selection on the 28 x 28 image grid, with a title and axis labels so\n",
    "# the figure can be read on its own.\n",
    "fig, ax = plt.subplots()\n",
    "ax.imshow(zeros.reshape(28, 28))\n",
    "ax.set_title(f'Global FS: {len(np.unique(inds))} selected pixels')\n",
    "ax.set_xlabel('pixel column')\n",
    "ax.set_ylabel('pixel row')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "monthly-tender",
   "metadata": {},
   "source": [
    "# Greedy adaptive feature selection (FS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "absolute-gauge",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--------Epoch 1--------\n",
      "Val loss = -0.2346\n",
      "\n",
      "--------Epoch 2--------\n",
      "Val loss = -0.2760\n",
      "\n",
      "--------Epoch 3--------\n",
      "Val loss = -0.2819\n",
      "\n",
      "--------Epoch 4--------\n",
      "Val loss = -0.2871\n",
      "\n",
      "--------Epoch 5--------\n",
      "Val loss = -0.3011\n",
      "\n",
      "--------Epoch 6--------\n",
      "Val loss = -0.2977\n",
      "\n",
      "--------Epoch 7--------\n",
      "Val loss = -0.3007\n",
      "\n",
      "--------Epoch 8--------\n",
      "Val loss = -0.2988\n",
      "\n",
      "--------Epoch 9--------\n",
      "Val loss = -0.3031\n",
      "\n",
      "--------Epoch 10--------\n",
      "Val loss = -0.3041\n",
      "\n",
      "--------Epoch 11--------\n",
      "Val loss = -0.3091\n",
      "\n",
      "--------Epoch 12--------\n",
      "Val loss = -0.3013\n",
      "\n",
      "--------Epoch 13--------\n",
      "Val loss = -0.3094\n",
      "\n",
      "--------Epoch 14--------\n",
      "Val loss = -0.3069\n",
      "\n",
      "--------Epoch 15--------\n",
      "Val loss = -0.3156\n",
      "\n",
      "--------Epoch 16--------\n",
      "Val loss = -0.3192\n",
      "\n",
      "--------Epoch 17--------\n",
      "Val loss = -0.3194\n",
      "\n",
      "--------Epoch 18--------\n",
      "Val loss = -0.3221\n",
      "\n",
      "--------Epoch 19--------\n",
      "Val loss = -0.3146\n",
      "\n",
      "--------Epoch 20--------\n",
      "Val loss = -0.3187\n",
      "\n",
      "--------Epoch 21--------\n",
      "Val loss = -0.3180\n",
      "\n",
      "--------Epoch 22--------\n",
      "Val loss = -0.3090\n",
      "\n",
      "--------Epoch 23--------\n",
      "Val loss = -0.3253\n",
      "\n",
      "--------Epoch 24--------\n",
      "Val loss = -0.3184\n",
      "\n",
      "--------Epoch 25--------\n",
      "Val loss = -0.3212\n",
      "\n",
      "--------Epoch 26--------\n",
      "Val loss = -0.3194\n",
      "\n",
      "--------Epoch 27--------\n",
      "Val loss = -0.3204\n",
      "\n",
      "--------Epoch 28--------\n",
      "Val loss = -0.3270\n",
      "\n",
      "--------Epoch 29--------\n",
      "Val loss = -0.3225\n",
      "\n",
      "--------Epoch 30--------\n",
      "Val loss = -0.3177\n",
      "\n",
      "--------Epoch 31--------\n",
      "Val loss = -0.3209\n",
      "\n",
      "--------Epoch 32--------\n",
      "Val loss = -0.3185\n",
      "\n",
      "--------Epoch 33--------\n",
      "Val loss = -0.3314\n",
      "\n",
      "--------Epoch 34--------\n",
      "Val loss = -0.3283\n",
      "\n",
      "--------Epoch 35--------\n",
      "Val loss = -0.3298\n",
      "\n",
      "--------Epoch 36--------\n",
      "Val loss = -0.3319\n",
      "\n",
      "--------Epoch 37--------\n",
      "Val loss = -0.3233\n",
      "\n",
      "--------Epoch 38--------\n",
      "Val loss = -0.3157\n",
      "\n",
      "--------Epoch 39--------\n",
      "Val loss = -0.3218\n",
      "\n",
      "--------Epoch 40--------\n",
      "Val loss = -0.3262\n",
      "\n",
      "--------Epoch 41--------\n",
      "Val loss = -0.3264\n",
      "\n",
      "--------Epoch 42--------\n",
      "Val loss = -0.3320\n",
      "\n",
      "--------Epoch 43--------\n",
      "Val loss = -0.3311\n",
      "\n",
      "--------Epoch 44--------\n",
      "Val loss = -0.3303\n",
      "\n",
      "--------Epoch 45--------\n",
      "Val loss = -0.3301\n",
      "\n",
      "--------Epoch 46--------\n",
      "Val loss = -0.3235\n",
      "\n",
      "--------Epoch 47--------\n",
      "Val loss = -0.3275\n",
      "\n",
      "--------Epoch 48--------\n",
      "Val loss = -0.3319\n",
      "\n",
      "--------Epoch 49--------\n",
      "Val loss = -0.3353\n",
      "\n",
      "--------Epoch 50--------\n",
      "Val loss = -0.3406\n",
      "\n",
      "--------Epoch 51--------\n",
      "Val loss = -0.3269\n",
      "\n",
      "--------Epoch 52--------\n",
      "Val loss = -0.3312\n",
      "\n",
      "--------Epoch 53--------\n",
      "Val loss = -0.3332\n",
      "\n",
      "--------Epoch 54--------\n",
      "Val loss = -0.3275\n",
      "\n",
      "--------Epoch 55--------\n",
      "Val loss = -0.3354\n",
      "\n",
      "--------Epoch 56--------\n",
      "Val loss = -0.3345\n",
      "\n",
      "--------Epoch 57--------\n",
      "Val loss = -0.3316\n",
      "\n",
      "--------Epoch 58--------\n",
      "Val loss = -0.3299\n",
      "\n",
      "--------Epoch 59--------\n",
      "Val loss = -0.3339\n",
      "\n",
      "--------Epoch 60--------\n",
      "Val loss = -0.3332\n",
      "\n",
      "--------Epoch 61--------\n",
      "Val loss = -0.3338\n",
      "\n",
      "--------Epoch 62--------\n",
      "Val loss = -0.3314\n",
      "\n",
      "--------Epoch 63--------\n",
      "Val loss = -0.3362\n",
      "\n",
      "--------Epoch 64--------\n",
      "Val loss = -0.3349\n",
      "\n",
      "--------Epoch 65--------\n",
      "Val loss = -0.3344\n",
      "\n",
      "--------Epoch 66--------\n",
      "Val loss = -0.3350\n",
      "\n",
      "--------Epoch 67--------\n",
      "Val loss = -0.3369\n",
      "\n",
      "--------Epoch 68--------\n",
      "Val loss = -0.3284\n",
      "\n",
      "--------Epoch 69--------\n",
      "Val loss = -0.3328\n",
      "\n",
      "--------Epoch 70--------\n",
      "Val loss = -0.3349\n",
      "\n",
      "--------Epoch 71--------\n",
      "Val loss = -0.3349\n",
      "\n",
      "--------Epoch 72--------\n",
      "Val loss = -0.3334\n",
      "\n",
      "--------Epoch 73--------\n",
      "Val loss = -0.3306\n",
      "\n",
      "--------Epoch 74--------\n",
      "Val loss = -0.3344\n",
      "\n",
      "--------Epoch 75--------\n",
      "Val loss = -0.3387\n",
      "\n",
      "--------Epoch 76--------\n",
      "Val loss = -0.3409\n",
      "\n",
      "--------Epoch 77--------\n",
      "Val loss = -0.3339\n",
      "\n",
      "--------Epoch 78--------\n",
      "Val loss = -0.3278\n",
      "\n",
      "--------Epoch 79--------\n",
      "Val loss = -0.3376\n",
      "\n",
      "--------Epoch 80--------\n",
      "Val loss = -0.3463\n",
      "\n",
      "--------Epoch 81--------\n",
      "Val loss = -0.3367\n",
      "\n",
      "--------Epoch 82--------\n",
      "Val loss = -0.3412\n",
      "\n",
      "--------Epoch 83--------\n",
      "Val loss = -0.3374\n",
      "\n",
      "--------Epoch 84--------\n",
      "Val loss = -0.3425\n",
      "\n",
      "--------Epoch 85--------\n",
      "Val loss = -0.3395\n",
      "\n",
      "--------Epoch 86--------\n",
      "Val loss = -0.3314\n",
      "\n",
      "--------Epoch 87--------\n",
      "Val loss = -0.3348\n",
      "\n",
      "--------Epoch 88--------\n",
      "Val loss = -0.3369\n",
      "\n",
      "--------Epoch 89--------\n",
      "Val loss = -0.3440\n",
      "\n",
      "--------Epoch 90--------\n",
      "Val loss = -0.3408\n",
      "\n",
      "--------Epoch 91--------\n",
      "Val loss = -0.3426\n",
      "\n",
      "--------Epoch 92--------\n",
      "Val loss = -0.3429\n",
      "\n",
      "--------Epoch 93--------\n",
      "Val loss = -0.3361\n",
      "\n",
      "--------Epoch 94--------\n",
      "Val loss = -0.3344\n",
      "\n",
      "--------Epoch 95--------\n",
      "Val loss = -0.3408\n",
      "\n",
      "--------Epoch 96--------\n",
      "Val loss = -0.3384\n",
      "\n",
      "--------Epoch 97--------\n",
      "Val loss = -0.3427\n",
      "\n",
      "--------Epoch 98--------\n",
      "Val loss = -0.3336\n",
      "\n",
      "--------Epoch 99--------\n",
      "Val loss = -0.3401\n",
      "\n",
      "--------Epoch 100--------\n",
      "Val loss = -0.3461\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Set up model\n",
    "model = nn.Sequential(\n",
    "    nn.Linear(d_in * 2, 512),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(512, 512),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(512, d_out))\n",
    "mask_layer = MaskLayer(append=True)\n",
    "pretrain = Pretrainer(model, mask_layer).to(device)\n",
    "\n",
    "# Pretrain\n",
    "pretrain.fit(train_dataset,\n",
    "             val_dataset,\n",
    "             mbsize=128,\n",
    "             lr=1e-3,\n",
    "             nepochs=100,\n",
    "             max_features=max_features,\n",
    "             loss_fn=nn.CrossEntropyLoss(),\n",
    "             val_loss_fn=NegAccuracy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "equal-architect",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--------Epoch 1--------\n",
      "Val loss = -0.6582\n",
      "\n",
      "--------Epoch 2--------\n",
      "Val loss = -0.7033\n",
      "\n",
      "--------Epoch 3--------\n",
      "Val loss = -0.7166\n",
      "\n",
      "--------Epoch 4--------\n",
      "Val loss = -0.7210\n",
      "\n",
      "--------Epoch 5--------\n",
      "Val loss = -0.7236\n",
      "\n",
      "--------Epoch 6--------\n",
      "Val loss = -0.7134\n",
      "\n",
      "--------Epoch 7--------\n",
      "Val loss = -0.7144\n",
      "\n",
      "--------Epoch 8--------\n",
      "Val loss = -0.6723\n",
      "\n",
      "--------Epoch 9--------\n",
      "Val loss = -0.7068\n",
      "\n",
      "--------Epoch 10--------\n",
      "Val loss = -0.7195\n",
      "\n",
      "--------Epoch 11--------\n",
      "Val loss = -0.7207\n",
      "\n",
      "--------Epoch 12--------\n",
      "Val loss = -0.7239\n",
      "\n",
      "--------Epoch 13--------\n",
      "Val loss = -0.7009\n",
      "\n",
      "--------Epoch 14--------\n",
      "Val loss = -0.7123\n",
      "\n",
      "--------Epoch 15--------\n",
      "Val loss = -0.7055\n",
      "\n",
      "--------Epoch 16--------\n",
      "Val loss = -0.7171\n",
      "\n",
      "--------Epoch 17--------\n",
      "Val loss = -0.6970\n",
      "\n",
      "--------Epoch 18--------\n",
      "Val loss = -0.6949\n",
      "\n",
      "--------Epoch 19--------\n",
      "Val loss = -0.6875\n",
      "\n",
      "--------Epoch 20--------\n",
      "Val loss = -0.6821\n",
      "\n",
      "--------Epoch 21--------\n",
      "Val loss = -0.6978\n",
      "\n",
      "--------Epoch 22--------\n",
      "Val loss = -0.6899\n",
      "\n",
      "--------Epoch 23--------\n",
      "Val loss = -0.6710\n",
      "\n",
      "--------Epoch 24--------\n",
      "Val loss = -0.6670\n",
      "\n",
      "--------Epoch 25--------\n",
      "Val loss = -0.6876\n",
      "\n",
      "--------Epoch 26--------\n",
      "Val loss = -0.6658\n",
      "\n",
      "--------Epoch 27--------\n",
      "Val loss = -0.6846\n",
      "\n",
      "--------Epoch 28--------\n",
      "Val loss = -0.6623\n",
      "\n",
      "--------Epoch 29--------\n",
      "Val loss = -0.6865\n",
      "\n",
      "--------Epoch 30--------\n",
      "Val loss = -0.6812\n",
      "\n",
      "--------Epoch 31--------\n",
      "Val loss = -0.6735\n",
      "\n",
      "--------Epoch 32--------\n",
      "Val loss = -0.6765\n",
      "\n",
      "--------Epoch 33--------\n",
      "Val loss = -0.6732\n",
      "\n",
      "--------Epoch 34--------\n",
      "Val loss = -0.6690\n",
      "\n",
      "--------Epoch 35--------\n",
      "Val loss = -0.6627\n",
      "\n",
      "--------Epoch 36--------\n",
      "Val loss = -0.6687\n",
      "\n",
      "--------Epoch 37--------\n",
      "Val loss = -0.6734\n",
      "\n",
      "--------Epoch 38--------\n",
      "Val loss = -0.6710\n",
      "\n",
      "--------Epoch 39--------\n",
      "Val loss = -0.6650\n",
      "\n",
      "--------Epoch 40--------\n",
      "Val loss = -0.6702\n",
      "\n",
      "--------Epoch 41--------\n",
      "Val loss = -0.6652\n",
      "\n",
      "--------Epoch 42--------\n",
      "Val loss = -0.6611\n",
      "\n",
      "--------Epoch 43--------\n",
      "Val loss = -0.6808\n",
      "\n",
      "--------Epoch 44--------\n",
      "Val loss = -0.6653\n",
      "\n",
      "--------Epoch 45--------\n",
      "Val loss = -0.6586\n",
      "\n",
      "--------Epoch 46--------\n",
      "Val loss = -0.6541\n",
      "\n",
      "--------Epoch 47--------\n",
      "Val loss = -0.6708\n",
      "\n",
      "--------Epoch 48--------\n",
      "Val loss = -0.6649\n",
      "\n",
      "--------Epoch 49--------\n",
      "Val loss = -0.6672\n",
      "\n",
      "--------Epoch 50--------\n",
      "Val loss = -0.6697\n",
      "\n",
      "--------Epoch 51--------\n",
      "Val loss = -0.6713\n",
      "\n",
      "--------Epoch 52--------\n",
      "Val loss = -0.6612\n",
      "\n",
      "--------Epoch 53--------\n",
      "Val loss = -0.6611\n",
      "\n",
      "--------Epoch 54--------\n",
      "Val loss = -0.6614\n",
      "\n",
      "--------Epoch 55--------\n",
      "Val loss = -0.6614\n",
      "\n",
      "--------Epoch 56--------\n",
      "Val loss = -0.6506\n",
      "\n",
      "--------Epoch 57--------\n",
      "Val loss = -0.6711\n",
      "\n",
      "--------Epoch 58--------\n",
      "Val loss = -0.6614\n",
      "\n",
      "--------Epoch 59--------\n",
      "Val loss = -0.6700\n",
      "\n",
      "--------Epoch 60--------\n",
      "Val loss = -0.6570\n",
      "\n",
      "--------Epoch 61--------\n",
      "Val loss = -0.6661\n",
      "\n",
      "--------Epoch 62--------\n",
      "Val loss = -0.6748\n",
      "\n",
      "--------Epoch 63--------\n",
      "Val loss = -0.6582\n",
      "\n",
      "--------Epoch 64--------\n",
      "Val loss = -0.6707\n",
      "\n",
      "--------Epoch 65--------\n",
      "Val loss = -0.6641\n",
      "\n",
      "--------Epoch 66--------\n",
      "Val loss = -0.6703\n",
      "\n",
      "--------Epoch 67--------\n",
      "Val loss = -0.6739\n",
      "\n",
      "--------Epoch 68--------\n",
      "Val loss = -0.6694\n",
      "\n",
      "--------Epoch 69--------\n",
      "Val loss = -0.6790\n",
      "\n",
      "--------Epoch 70--------\n",
      "Val loss = -0.6705\n",
      "\n",
      "--------Epoch 71--------\n",
      "Val loss = -0.6729\n",
      "\n",
      "--------Epoch 72--------\n",
      "Val loss = -0.6789\n",
      "\n",
      "--------Epoch 73--------\n",
      "Val loss = -0.6762\n",
      "\n",
      "--------Epoch 74--------\n",
      "Val loss = -0.6796\n",
      "\n",
      "--------Epoch 75--------\n",
      "Val loss = -0.6766\n",
      "\n",
      "--------Epoch 76--------\n",
      "Val loss = -0.6828\n",
      "\n",
      "--------Epoch 77--------\n",
      "Val loss = -0.6933\n",
      "\n",
      "--------Epoch 78--------\n",
      "Val loss = -0.6904\n",
      "\n",
      "--------Epoch 79--------\n",
      "Val loss = -0.6919\n",
      "\n",
      "--------Epoch 80--------\n",
      "Val loss = -0.6908\n",
      "\n",
      "--------Epoch 81--------\n",
      "Val loss = -0.7005\n",
      "\n",
      "--------Epoch 82--------\n",
      "Val loss = -0.7069\n",
      "\n",
      "--------Epoch 83--------\n",
      "Val loss = -0.7115\n",
      "\n",
      "--------Epoch 84--------\n",
      "Val loss = -0.7119\n",
      "\n",
      "--------Epoch 85--------\n",
      "Val loss = -0.7177\n",
      "\n",
      "--------Epoch 86--------\n",
      "Val loss = -0.7190\n",
      "\n",
      "--------Epoch 87--------\n",
      "Val loss = -0.7204\n",
      "\n",
      "--------Epoch 88--------\n",
      "Val loss = -0.7272\n",
      "\n",
      "--------Epoch 89--------\n",
      "Val loss = -0.7291\n",
      "\n",
      "--------Epoch 90--------\n",
      "Val loss = -0.7345\n",
      "\n",
      "--------Epoch 91--------\n",
      "Val loss = -0.7327\n",
      "\n",
      "--------Epoch 92--------\n",
      "Val loss = -0.7418\n",
      "\n",
      "--------Epoch 93--------\n",
      "Val loss = -0.7406\n",
      "\n",
      "--------Epoch 94--------\n",
      "Val loss = -0.7475\n",
      "\n",
      "--------Epoch 95--------\n",
      "Val loss = -0.7578\n",
      "\n",
      "--------Epoch 96--------\n",
      "Val loss = -0.7554\n",
      "\n",
      "--------Epoch 97--------\n",
      "Val loss = -0.7647\n",
      "\n",
      "--------Epoch 98--------\n",
      "Val loss = -0.7626\n",
      "\n",
      "--------Epoch 99--------\n",
      "Val loss = -0.7714\n",
      "\n",
      "--------Epoch 100--------\n",
      "Val loss = -0.7735\n",
      "\n",
      "--------Epoch 101--------\n",
      "Val loss = -0.7711\n",
      "\n",
      "--------Epoch 102--------\n",
      "Val loss = -0.7766\n",
      "\n",
      "--------Epoch 103--------\n",
      "Val loss = -0.7710\n",
      "\n",
      "--------Epoch 104--------\n",
      "Val loss = -0.7832\n",
      "\n",
      "--------Epoch 105--------\n",
      "Val loss = -0.7862\n",
      "\n",
      "--------Epoch 106--------\n",
      "Val loss = -0.7843\n",
      "\n",
      "--------Epoch 107--------\n",
      "Val loss = -0.7901\n",
      "\n",
      "--------Epoch 108--------\n",
      "Val loss = -0.7981\n",
      "\n",
      "--------Epoch 109--------\n",
      "Val loss = -0.7964\n",
      "\n",
      "--------Epoch 110--------\n",
      "Val loss = -0.8005\n",
      "\n",
      "--------Epoch 111--------\n",
      "Val loss = -0.7954\n",
      "\n",
      "--------Epoch 112--------\n",
      "Val loss = -0.8017\n",
      "\n",
      "--------Epoch 113--------\n",
      "Val loss = -0.8001\n",
      "\n",
      "--------Epoch 114--------\n",
      "Val loss = -0.7995\n",
      "\n",
      "--------Epoch 115--------\n",
      "Val loss = -0.8076\n",
      "\n",
      "--------Epoch 116--------\n",
      "Val loss = -0.8033\n",
      "\n",
      "--------Epoch 117--------\n",
      "Val loss = -0.8129\n",
      "\n",
      "--------Epoch 118--------\n",
      "Val loss = -0.8127\n",
      "\n",
      "--------Epoch 119--------\n",
      "Val loss = -0.8173\n",
      "\n",
      "--------Epoch 120--------\n",
      "Val loss = -0.8076\n",
      "\n",
      "--------Epoch 121--------\n",
      "Val loss = -0.8173\n",
      "\n",
      "--------Epoch 122--------\n",
      "Val loss = -0.8176\n",
      "\n",
      "--------Epoch 123--------\n",
      "Val loss = -0.8168\n",
      "\n",
      "--------Epoch 124--------\n",
      "Val loss = -0.8250\n",
      "\n",
      "--------Epoch 125--------\n",
      "Val loss = -0.8197\n",
      "\n",
      "--------Epoch 126--------\n",
      "Val loss = -0.8230\n",
      "\n",
      "--------Epoch 127--------\n",
      "Val loss = -0.8257\n",
      "\n",
      "--------Epoch 128--------\n",
      "Val loss = -0.8224\n",
      "\n",
      "--------Epoch 129--------\n",
      "Val loss = -0.8256\n",
      "\n",
      "--------Epoch 130--------\n",
      "Val loss = -0.8240\n",
      "\n",
      "--------Epoch 131--------\n",
      "Val loss = -0.8323\n",
      "\n",
      "--------Epoch 132--------\n",
      "Val loss = -0.8330\n",
      "\n",
      "--------Epoch 133--------\n",
      "Val loss = -0.8310\n",
      "\n",
      "--------Epoch 134--------\n",
      "Val loss = -0.8300\n",
      "\n",
      "--------Epoch 135--------\n",
      "Val loss = -0.8264\n",
      "\n",
      "--------Epoch 136--------\n",
      "Val loss = -0.8343\n",
      "\n",
      "--------Epoch 137--------\n",
      "Val loss = -0.8335\n",
      "\n",
      "--------Epoch 138--------\n",
      "Val loss = -0.8325\n",
      "\n",
      "--------Epoch 139--------\n",
      "Val loss = -0.8346\n",
      "\n",
      "--------Epoch 140--------\n",
      "Val loss = -0.8309\n",
      "\n",
      "--------Epoch 141--------\n",
      "Val loss = -0.8281\n",
      "\n",
      "--------Epoch 142--------\n",
      "Val loss = -0.8309\n",
      "\n",
      "--------Epoch 143--------\n",
      "Val loss = -0.8348\n",
      "\n",
      "--------Epoch 144--------\n",
      "Val loss = -0.8337\n",
      "\n",
      "--------Epoch 145--------\n",
      "Val loss = -0.8338\n",
      "\n",
      "--------Epoch 146--------\n",
      "Val loss = -0.8344\n",
      "\n",
      "--------Epoch 147--------\n",
      "Val loss = -0.8420\n",
      "\n",
      "--------Epoch 148--------\n",
      "Val loss = -0.8349\n",
      "\n",
      "--------Epoch 149--------\n",
      "Val loss = -0.8375\n",
      "\n",
      "--------Epoch 150--------\n",
      "Val loss = -0.8361\n",
      "\n",
      "--------Epoch 151--------\n",
      "Val loss = -0.8388\n",
      "\n",
      "--------Epoch 152--------\n",
      "Val loss = -0.8399\n",
      "\n",
      "--------Epoch 153--------\n",
      "Val loss = -0.8398\n",
      "\n",
      "--------Epoch 154--------\n",
      "Val loss = -0.8479\n",
      "\n",
      "--------Epoch 155--------\n",
      "Val loss = -0.8452\n",
      "\n",
      "--------Epoch 156--------\n",
      "Val loss = -0.8430\n",
      "\n",
      "--------Epoch 157--------\n",
      "Val loss = -0.8402\n",
      "\n",
      "--------Epoch 158--------\n",
      "Val loss = -0.8420\n",
      "\n",
      "--------Epoch 159--------\n",
      "Val loss = -0.8427\n",
      "\n",
      "--------Epoch 160--------\n",
      "Val loss = -0.8349\n",
      "\n",
      "--------Epoch 161--------\n",
      "Val loss = -0.8456\n",
      "\n",
      "--------Epoch 162--------\n",
      "Val loss = -0.8428\n",
      "\n",
      "--------Epoch 163--------\n",
      "Val loss = -0.8402\n",
      "\n",
      "--------Epoch 164--------\n",
      "Val loss = -0.8434\n",
      "\n",
      "--------Epoch 165--------\n",
      "Val loss = -0.8465\n",
      "\n",
      "--------Epoch 166--------\n",
      "Val loss = -0.8435\n",
      "\n",
      "--------Epoch 167--------\n",
      "Val loss = -0.8494\n",
      "\n",
      "--------Epoch 168--------\n",
      "Val loss = -0.8424\n",
      "\n",
      "--------Epoch 169--------\n",
      "Val loss = -0.8395\n",
      "\n",
      "--------Epoch 170--------\n",
      "Val loss = -0.8452\n",
      "\n",
      "--------Epoch 171--------\n",
      "Val loss = -0.8413\n",
      "\n",
      "--------Epoch 172--------\n",
      "Val loss = -0.8490\n",
      "\n",
      "--------Epoch 173--------\n",
      "Val loss = -0.8442\n",
      "\n",
      "--------Epoch 174--------\n",
      "Val loss = -0.8400\n",
      "\n",
      "--------Epoch 175--------\n",
      "Val loss = -0.8421\n",
      "\n",
      "--------Epoch 176--------\n",
      "Val loss = -0.8415\n",
      "\n",
      "--------Epoch 177--------\n",
      "Val loss = -0.8437\n",
      "\n",
      "--------Epoch 178--------\n",
      "Val loss = -0.8432\n",
      "\n",
      "--------Epoch 179--------\n",
      "Val loss = -0.8477\n",
      "\n",
      "--------Epoch 180--------\n",
      "Val loss = -0.8445\n",
      "\n",
      "--------Epoch 181--------\n",
      "Val loss = -0.8473\n",
      "\n",
      "--------Epoch 182--------\n",
      "Val loss = -0.8453\n",
      "\n",
      "--------Epoch 183--------\n",
      "Val loss = -0.8478\n",
      "\n",
      "--------Epoch 184--------\n",
      "Val loss = -0.8519\n",
      "\n",
      "--------Epoch 185--------\n",
      "Val loss = -0.8429\n",
      "\n",
      "--------Epoch 186--------\n",
      "Val loss = -0.8441\n",
      "\n",
      "--------Epoch 187--------\n",
      "Val loss = -0.8492\n",
      "\n",
      "--------Epoch 188--------\n",
      "Val loss = -0.8445\n",
      "\n",
      "--------Epoch 189--------\n",
      "Val loss = -0.8499\n",
      "\n",
      "--------Epoch 190--------\n",
      "Val loss = -0.8395\n",
      "\n",
      "--------Epoch 191--------\n",
      "Val loss = -0.8443\n",
      "\n",
      "--------Epoch 192--------\n",
      "Val loss = -0.8507\n",
      "\n",
      "--------Epoch 193--------\n",
      "Val loss = -0.8482\n",
      "\n",
      "--------Epoch 194--------\n",
      "Val loss = -0.8495\n",
      "\n",
      "--------Epoch 195--------\n",
      "Val loss = -0.8501\n",
      "\n",
      "--------Epoch 196--------\n",
      "Val loss = -0.8458\n",
      "\n",
      "--------Epoch 197--------\n",
      "Val loss = -0.8435\n",
      "\n",
      "--------Epoch 198--------\n",
      "Val loss = -0.8463\n",
      "\n",
      "--------Epoch 199--------\n",
      "Val loss = -0.8415\n",
      "\n",
      "--------Epoch 200--------\n",
      "Val loss = -0.8497\n",
      "\n",
      "--------Epoch 201--------\n",
      "Val loss = -0.8472\n",
      "\n",
      "--------Epoch 202--------\n",
      "Val loss = -0.8430\n",
      "\n",
      "--------Epoch 203--------\n",
      "Val loss = -0.8455\n",
      "\n",
      "--------Epoch 204--------\n",
      "Val loss = -0.8450\n",
      "\n",
      "--------Epoch 205--------\n",
      "Val loss = -0.8446\n",
      "\n",
      "--------Epoch 206--------\n",
      "Val loss = -0.8463\n",
      "\n",
      "--------Epoch 207--------\n",
      "Val loss = -0.8506\n",
      "\n",
      "--------Epoch 208--------\n",
      "Val loss = -0.8452\n",
      "\n",
      "--------Epoch 209--------\n",
      "Val loss = -0.8494\n",
      "\n",
      "--------Epoch 210--------\n",
      "Val loss = -0.8508\n",
      "\n",
      "--------Epoch 211--------\n",
      "Val loss = -0.8429\n",
      "\n",
      "--------Epoch 212--------\n",
      "Val loss = -0.8427\n",
      "\n",
      "--------Epoch 213--------\n",
      "Val loss = -0.8461\n",
      "\n",
      "--------Epoch 214--------\n",
      "Val loss = -0.8492\n",
      "\n",
      "--------Epoch 215--------\n",
      "Val loss = -0.8486\n",
      "\n",
      "--------Epoch 216--------\n",
      "Val loss = -0.8458\n",
      "\n",
      "--------Epoch 217--------\n",
      "Val loss = -0.8452\n",
      "\n",
      "--------Epoch 218--------\n",
      "Val loss = -0.8526\n",
      "\n",
      "--------Epoch 219--------\n",
      "Val loss = -0.8461\n",
      "\n",
      "--------Epoch 220--------\n",
      "Val loss = -0.8479\n",
      "\n",
      "--------Epoch 221--------\n",
      "Val loss = -0.8414\n",
      "\n",
      "--------Epoch 222--------\n",
      "Val loss = -0.8450\n",
      "\n",
      "--------Epoch 223--------\n",
      "Val loss = -0.8393\n",
      "\n",
      "--------Epoch 224--------\n",
      "Val loss = -0.8448\n",
      "\n",
      "--------Epoch 225--------\n",
      "Val loss = -0.8435\n",
      "\n",
      "--------Epoch 226--------\n",
      "Val loss = -0.8398\n",
      "\n",
      "--------Epoch 227--------\n",
      "Val loss = -0.8329\n",
      "\n",
      "--------Epoch 228--------\n",
      "Val loss = -0.8438\n",
      "\n",
      "--------Epoch 229--------\n",
      "Val loss = -0.8429\n",
      "\n",
      "--------Epoch 230--------\n",
      "Val loss = -0.8366\n",
      "\n",
      "--------Epoch 231--------\n",
      "Val loss = -0.8369\n",
      "\n",
      "--------Epoch 232--------\n",
      "Val loss = -0.8276\n",
      "\n",
      "--------Epoch 233--------\n",
      "Val loss = -0.8345\n",
      "\n",
      "--------Epoch 234--------\n",
      "Val loss = -0.8350\n",
      "\n",
      "--------Epoch 235--------\n",
      "Val loss = -0.8380\n",
      "\n",
      "--------Epoch 236--------\n",
      "Val loss = -0.8319\n",
      "\n",
      "--------Epoch 237--------\n",
      "Val loss = -0.8376\n",
      "\n",
      "--------Epoch 238--------\n",
      "Val loss = -0.8308\n",
      "\n",
      "--------Epoch 239--------\n",
      "Val loss = -0.8196\n",
      "\n",
      "--------Epoch 240--------\n",
      "Val loss = -0.8267\n",
      "\n",
      "--------Epoch 241--------\n",
      "Val loss = -0.8297\n",
      "\n",
      "--------Epoch 242--------\n",
      "Val loss = -0.8326\n",
      "\n",
      "--------Epoch 243--------\n",
      "Val loss = -0.8227\n",
      "\n",
      "--------Epoch 244--------\n",
      "Val loss = -0.8242\n",
      "\n",
      "--------Epoch 245--------\n",
      "Val loss = -0.8276\n",
      "\n",
      "--------Epoch 246--------\n",
      "Val loss = -0.8151\n",
      "\n",
      "--------Epoch 247--------\n",
      "Val loss = -0.8274\n",
      "\n",
      "--------Epoch 248--------\n",
      "Val loss = -0.8331\n",
      "\n",
      "--------Epoch 249--------\n",
      "Val loss = -0.8278\n",
      "\n",
      "--------Epoch 250--------\n",
      "Val loss = -0.8248\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Set up selector\n",
    "selector = nn.Sequential(\n",
    "    nn.Linear(d_in * 2, 512),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(512, 512),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(512, d_in))\n",
    "selector_layer = ConcreteSelector()\n",
    "gafs = GreedyAdaptiveFS(selector, deepcopy(model), mask_layer, selector_layer).to(device)\n",
    "\n",
    "# Train\n",
    "gafs.fit(train_dataset,\n",
    "         val_dataset,\n",
    "         mbsize=128,\n",
    "         lr=2e-4,\n",
    "         nepochs=250,\n",
    "         max_features=max_features,\n",
    "         loss_fn=nn.CrossEntropyLoss(),\n",
    "         val_loss_fn=NegAccuracy(),\n",
    "         start_temp=10.0,\n",
    "         end_temp=0.01)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "willing-compatibility",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test acc = 84.90\n"
     ]
    }
   ],
   "source": [
    "test_acc = gafs.evaluate(test_dataset, max_features, Accuracy(), 1024)\n",
    "print(f'Test acc = {100*test_acc:.2f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "appointed-motorcycle",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
