{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "fcd9578a-ccd9-48c0-a024-c2152a18cb9d",
   "metadata": {},
   "source": [
    "# No repeats\n",
    "\n",
    "- When using the selector model, during training and/or inference, it is possible to prevent it from selecting features it has already seen. This version tests to what extent this can help performance.\n",
    "- The results suggest that repeat prevention should be enabled during inference (where it obviously can only help), but that it is better not to use it during training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "generic-average",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "import torch.nn as nn\n",
    "import matplotlib.pyplot as plt\n",
    "from torchvision import transforms\n",
    "from torchvision.datasets import MNIST\n",
    "from copy import deepcopy\n",
    "from utils import *\n",
    "from models import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "controversial-efficiency",
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device('cuda', 3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "featured-lucas",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load data\n",
    "class Flatten(object):\n",
    "    def __call__(self, pic):\n",
    "        return torch.flatten(pic)\n",
    "    \n",
    "# MNIST training split, downloaded to /tmp/mnist/ if needed. Each image goes\n",
    "# through ToTensor and is then flattened to a 1-D tensor by Flatten.\n",
    "mnist_dataset = MNIST('/tmp/mnist/', download=True, train=True,\n",
    "                      transform=transforms.Compose([transforms.ToTensor(), Flatten()]))\n",
    "# Raw (untransformed) data and labels; in this cell `images` is only used for\n",
    "# its length, and `targets` is not used at all.\n",
    "images = mnist_dataset.data\n",
    "targets = mnist_dataset.targets\n",
    "# Fixed NumPy seed so the train/validation split is the same on every run.\n",
    "np.random.seed(0)\n",
    "# Hold out 10,000 distinct indices (sorted) for validation; every remaining\n",
    "# index goes to training, so the two index sets are disjoint.\n",
    "val_inds = np.sort(np.random.choice(len(images), size=10000, replace=False))\n",
    "train_inds = np.setdiff1d(np.arange(len(images)), val_inds)\n",
    "\n",
    "# Training dataset\n",
    "train_dataset = torch.utils.data.Subset(mnist_dataset, train_inds)\n",
    "\n",
    "# Validation dataset\n",
    "val_dataset = torch.utils.data.Subset(mnist_dataset, val_inds)\n",
    "\n",
    "# Test dataset (MNIST train=False split, same transform as above)\n",
    "test_dataset = MNIST('/tmp/mnist/', download=True, train=False,\n",
    "                     transform=transforms.Compose([transforms.ToTensor(), Flatten()]))\n",
    "\n",
    "# Set input/output dimensions: 784 = 28 * 28 flattened pixels, 10 output classes\n",
    "d_in = 784\n",
    "d_out = 10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "persistent-patch",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of features to select: the selection budget that the Global FS cell\n",
    "# passes to ConcreteMask as its second argument.\n",
    "max_features = 10"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "digital-monitor",
   "metadata": {},
   "source": [
    "# Global FS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "missing-sunrise",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--------Epoch 1--------\n",
      "Val loss = -0.2958\n",
      "\n",
      "--------Epoch 2--------\n",
      "Val loss = -0.3116\n",
      "\n",
      "--------Epoch 3--------\n",
      "Val loss = -0.3207\n",
      "\n",
      "--------Epoch 4--------\n",
      "Val loss = -0.3371\n",
      "\n",
      "--------Epoch 5--------\n",
      "Val loss = -0.3239\n",
      "\n",
      "--------Epoch 6--------\n",
      "Val loss = -0.3350\n",
      "\n",
      "--------Epoch 7--------\n",
      "Val loss = -0.3415\n",
      "\n",
      "--------Epoch 8--------\n",
      "Val loss = -0.3530\n",
      "\n",
      "--------Epoch 9--------\n",
      "Val loss = -0.3490\n",
      "\n",
      "--------Epoch 10--------\n",
      "Val loss = -0.3513\n",
      "\n",
      "--------Epoch 11--------\n",
      "Val loss = -0.3561\n",
      "\n",
      "--------Epoch 12--------\n",
      "Val loss = -0.3638\n",
      "\n",
      "--------Epoch 13--------\n",
      "Val loss = -0.3713\n",
      "\n",
      "--------Epoch 14--------\n",
      "Val loss = -0.3608\n",
      "\n",
      "--------Epoch 15--------\n",
      "Val loss = -0.3685\n",
      "\n",
      "--------Epoch 16--------\n",
      "Val loss = -0.3922\n",
      "\n",
      "--------Epoch 17--------\n",
      "Val loss = -0.3784\n",
      "\n",
      "--------Epoch 18--------\n",
      "Val loss = -0.3839\n",
      "\n",
      "--------Epoch 19--------\n",
      "Val loss = -0.3929\n",
      "\n",
      "--------Epoch 20--------\n",
      "Val loss = -0.4015\n",
      "\n",
      "--------Epoch 21--------\n",
      "Val loss = -0.4015\n",
      "\n",
      "--------Epoch 22--------\n",
      "Val loss = -0.4055\n",
      "\n",
      "--------Epoch 23--------\n",
      "Val loss = -0.4035\n",
      "\n",
      "--------Epoch 24--------\n",
      "Val loss = -0.4150\n",
      "\n",
      "--------Epoch 25--------\n",
      "Val loss = -0.4209\n",
      "\n",
      "--------Epoch 26--------\n",
      "Val loss = -0.4209\n",
      "\n",
      "--------Epoch 27--------\n",
      "Val loss = -0.4248\n",
      "\n",
      "--------Epoch 28--------\n",
      "Val loss = -0.4315\n",
      "\n",
      "--------Epoch 29--------\n",
      "Val loss = -0.4369\n",
      "\n",
      "--------Epoch 30--------\n",
      "Val loss = -0.4339\n",
      "\n",
      "--------Epoch 31--------\n",
      "Val loss = -0.4423\n",
      "\n",
      "--------Epoch 32--------\n",
      "Val loss = -0.4431\n",
      "\n",
      "--------Epoch 33--------\n",
      "Val loss = -0.4466\n",
      "\n",
      "--------Epoch 34--------\n",
      "Val loss = -0.4478\n",
      "\n",
      "--------Epoch 35--------\n",
      "Val loss = -0.4630\n",
      "\n",
      "--------Epoch 36--------\n",
      "Val loss = -0.4591\n",
      "\n",
      "--------Epoch 37--------\n",
      "Val loss = -0.4603\n",
      "\n",
      "--------Epoch 38--------\n",
      "Val loss = -0.4641\n",
      "\n",
      "--------Epoch 39--------\n",
      "Val loss = -0.4726\n",
      "\n",
      "--------Epoch 40--------\n",
      "Val loss = -0.4650\n",
      "\n",
      "--------Epoch 41--------\n",
      "Val loss = -0.4803\n",
      "\n",
      "--------Epoch 42--------\n",
      "Val loss = -0.4705\n",
      "\n",
      "--------Epoch 43--------\n",
      "Val loss = -0.4811\n",
      "\n",
      "--------Epoch 44--------\n",
      "Val loss = -0.4945\n",
      "\n",
      "--------Epoch 45--------\n",
      "Val loss = -0.5002\n",
      "\n",
      "--------Epoch 46--------\n",
      "Val loss = -0.5040\n",
      "\n",
      "--------Epoch 47--------\n",
      "Val loss = -0.5093\n",
      "\n",
      "--------Epoch 48--------\n",
      "Val loss = -0.5123\n",
      "\n",
      "--------Epoch 49--------\n",
      "Val loss = -0.5064\n",
      "\n",
      "--------Epoch 50--------\n",
      "Val loss = -0.5171\n",
      "\n",
      "--------Epoch 51--------\n",
      "Val loss = -0.5138\n",
      "\n",
      "--------Epoch 52--------\n",
      "Val loss = -0.5210\n",
      "\n",
      "--------Epoch 53--------\n",
      "Val loss = -0.5239\n",
      "\n",
      "--------Epoch 54--------\n",
      "Val loss = -0.5312\n",
      "\n",
      "--------Epoch 55--------\n",
      "Val loss = -0.5314\n",
      "\n",
      "--------Epoch 56--------\n",
      "Val loss = -0.5300\n",
      "\n",
      "--------Epoch 57--------\n",
      "Val loss = -0.5457\n",
      "\n",
      "--------Epoch 58--------\n",
      "Val loss = -0.5422\n",
      "\n",
      "--------Epoch 59--------\n",
      "Val loss = -0.5465\n",
      "\n",
      "--------Epoch 60--------\n",
      "Val loss = -0.5447\n",
      "\n",
      "--------Epoch 61--------\n",
      "Val loss = -0.5488\n",
      "\n",
      "--------Epoch 62--------\n",
      "Val loss = -0.5490\n",
      "\n",
      "--------Epoch 63--------\n",
      "Val loss = -0.5457\n",
      "\n",
      "--------Epoch 64--------\n",
      "Val loss = -0.5524\n",
      "\n",
      "--------Epoch 65--------\n",
      "Val loss = -0.5549\n",
      "\n",
      "--------Epoch 66--------\n",
      "Val loss = -0.5567\n",
      "\n",
      "--------Epoch 67--------\n",
      "Val loss = -0.5617\n",
      "\n",
      "--------Epoch 68--------\n",
      "Val loss = -0.5671\n",
      "\n",
      "--------Epoch 69--------\n",
      "Val loss = -0.5732\n",
      "\n",
      "--------Epoch 70--------\n",
      "Val loss = -0.5748\n",
      "\n",
      "--------Epoch 71--------\n",
      "Val loss = -0.5771\n",
      "\n",
      "--------Epoch 72--------\n",
      "Val loss = -0.5873\n",
      "\n",
      "--------Epoch 73--------\n",
      "Val loss = -0.5879\n",
      "\n",
      "--------Epoch 74--------\n",
      "Val loss = -0.5938\n",
      "\n",
      "--------Epoch 75--------\n",
      "Val loss = -0.5966\n",
      "\n",
      "--------Epoch 76--------\n",
      "Val loss = -0.5951\n",
      "\n",
      "--------Epoch 77--------\n",
      "Val loss = -0.5931\n",
      "\n",
      "--------Epoch 78--------\n",
      "Val loss = -0.6006\n",
      "\n",
      "--------Epoch 79--------\n",
      "Val loss = -0.5982\n",
      "\n",
      "--------Epoch 80--------\n",
      "Val loss = -0.5872\n",
      "\n",
      "--------Epoch 81--------\n",
      "Val loss = -0.5983\n",
      "\n",
      "--------Epoch 82--------\n",
      "Val loss = -0.6056\n",
      "\n",
      "--------Epoch 83--------\n",
      "Val loss = -0.6064\n",
      "\n",
      "--------Epoch 84--------\n",
      "Val loss = -0.6175\n",
      "\n",
      "--------Epoch 85--------\n",
      "Val loss = -0.6138\n",
      "\n",
      "--------Epoch 86--------\n",
      "Val loss = -0.6299\n",
      "\n",
      "--------Epoch 87--------\n",
      "Val loss = -0.6163\n",
      "\n",
      "--------Epoch 88--------\n",
      "Val loss = -0.6201\n",
      "\n",
      "--------Epoch 89--------\n",
      "Val loss = -0.6167\n",
      "\n",
      "--------Epoch 90--------\n",
      "Val loss = -0.6295\n",
      "\n",
      "--------Epoch 91--------\n",
      "Val loss = -0.6321\n",
      "\n",
      "--------Epoch 92--------\n",
      "Val loss = -0.6396\n",
      "\n",
      "--------Epoch 93--------\n",
      "Val loss = -0.6311\n",
      "\n",
      "--------Epoch 94--------\n",
      "Val loss = -0.6296\n",
      "\n",
      "--------Epoch 95--------\n",
      "Val loss = -0.6327\n",
      "\n",
      "--------Epoch 96--------\n",
      "Val loss = -0.6394\n",
      "\n",
      "--------Epoch 97--------\n",
      "Val loss = -0.6442\n",
      "\n",
      "--------Epoch 98--------\n",
      "Val loss = -0.6494\n",
      "\n",
      "--------Epoch 99--------\n",
      "Val loss = -0.6471\n",
      "\n",
      "--------Epoch 100--------\n",
      "Val loss = -0.6471\n",
      "\n",
      "--------Epoch 101--------\n",
      "Val loss = -0.6521\n",
      "\n",
      "--------Epoch 102--------\n",
      "Val loss = -0.6528\n",
      "\n",
      "--------Epoch 103--------\n",
      "Val loss = -0.6536\n",
      "\n",
      "--------Epoch 104--------\n",
      "Val loss = -0.6550\n",
      "\n",
      "--------Epoch 105--------\n",
      "Val loss = -0.6594\n",
      "\n",
      "--------Epoch 106--------\n",
      "Val loss = -0.6523\n",
      "\n",
      "--------Epoch 107--------\n",
      "Val loss = -0.6599\n",
      "\n",
      "--------Epoch 108--------\n",
      "Val loss = -0.6655\n",
      "\n",
      "--------Epoch 109--------\n",
      "Val loss = -0.6665\n",
      "\n",
      "--------Epoch 110--------\n",
      "Val loss = -0.6740\n",
      "\n",
      "--------Epoch 111--------\n",
      "Val loss = -0.6663\n",
      "\n",
      "--------Epoch 112--------\n",
      "Val loss = -0.6762\n",
      "\n",
      "--------Epoch 113--------\n",
      "Val loss = -0.6775\n",
      "\n",
      "--------Epoch 114--------\n",
      "Val loss = -0.6762\n",
      "\n",
      "--------Epoch 115--------\n",
      "Val loss = -0.6808\n",
      "\n",
      "--------Epoch 116--------\n",
      "Val loss = -0.6752\n",
      "\n",
      "--------Epoch 117--------\n",
      "Val loss = -0.6844\n",
      "\n",
      "--------Epoch 118--------\n",
      "Val loss = -0.6829\n",
      "\n",
      "--------Epoch 119--------\n",
      "Val loss = -0.6878\n",
      "\n",
      "--------Epoch 120--------\n",
      "Val loss = -0.6880\n",
      "\n",
      "--------Epoch 121--------\n",
      "Val loss = -0.6883\n",
      "\n",
      "--------Epoch 122--------\n",
      "Val loss = -0.6872\n",
      "\n",
      "--------Epoch 123--------\n",
      "Val loss = -0.6922\n",
      "\n",
      "--------Epoch 124--------\n",
      "Val loss = -0.6971\n",
      "\n",
      "--------Epoch 125--------\n",
      "Val loss = -0.7104\n",
      "\n",
      "--------Epoch 126--------\n",
      "Val loss = -0.6997\n",
      "\n",
      "--------Epoch 127--------\n",
      "Val loss = -0.6957\n",
      "\n",
      "--------Epoch 128--------\n",
      "Val loss = -0.7067\n",
      "\n",
      "--------Epoch 129--------\n",
      "Val loss = -0.6999\n",
      "\n",
      "--------Epoch 130--------\n",
      "Val loss = -0.7108\n",
      "\n",
      "--------Epoch 131--------\n",
      "Val loss = -0.7120\n",
      "\n",
      "--------Epoch 132--------\n",
      "Val loss = -0.7090\n",
      "\n",
      "--------Epoch 133--------\n",
      "Val loss = -0.7163\n",
      "\n",
      "--------Epoch 134--------\n",
      "Val loss = -0.7119\n",
      "\n",
      "--------Epoch 135--------\n",
      "Val loss = -0.7146\n",
      "\n",
      "--------Epoch 136--------\n",
      "Val loss = -0.7104\n",
      "\n",
      "--------Epoch 137--------\n",
      "Val loss = -0.7229\n",
      "\n",
      "--------Epoch 138--------\n",
      "Val loss = -0.7188\n",
      "\n",
      "--------Epoch 139--------\n",
      "Val loss = -0.7263\n",
      "\n",
      "--------Epoch 140--------\n",
      "Val loss = -0.7257\n",
      "\n",
      "--------Epoch 141--------\n",
      "Val loss = -0.7199\n",
      "\n",
      "--------Epoch 142--------\n",
      "Val loss = -0.7258\n",
      "\n",
      "--------Epoch 143--------\n",
      "Val loss = -0.7269\n",
      "\n",
      "--------Epoch 144--------\n",
      "Val loss = -0.7328\n",
      "\n",
      "--------Epoch 145--------\n",
      "Val loss = -0.7292\n",
      "\n",
      "--------Epoch 146--------\n",
      "Val loss = -0.7281\n",
      "\n",
      "--------Epoch 147--------\n",
      "Val loss = -0.7396\n",
      "\n",
      "--------Epoch 148--------\n",
      "Val loss = -0.7431\n",
      "\n",
      "--------Epoch 149--------\n",
      "Val loss = -0.7365\n",
      "\n",
      "--------Epoch 150--------\n",
      "Val loss = -0.7387\n",
      "\n",
      "--------Epoch 151--------\n",
      "Val loss = -0.7337\n",
      "\n",
      "--------Epoch 152--------\n",
      "Val loss = -0.7364\n",
      "\n",
      "--------Epoch 153--------\n",
      "Val loss = -0.7402\n",
      "\n",
      "--------Epoch 154--------\n",
      "Val loss = -0.7462\n",
      "\n",
      "--------Epoch 155--------\n",
      "Val loss = -0.7425\n",
      "\n",
      "--------Epoch 156--------\n",
      "Val loss = -0.7455\n",
      "\n",
      "--------Epoch 157--------\n",
      "Val loss = -0.7425\n",
      "\n",
      "--------Epoch 158--------\n",
      "Val loss = -0.7438\n",
      "\n",
      "--------Epoch 159--------\n",
      "Val loss = -0.7504\n",
      "\n",
      "--------Epoch 160--------\n",
      "Val loss = -0.7484\n",
      "\n",
      "--------Epoch 161--------\n",
      "Val loss = -0.7517\n",
      "\n",
      "--------Epoch 162--------\n",
      "Val loss = -0.7493\n",
      "\n",
      "--------Epoch 163--------\n",
      "Val loss = -0.7479\n",
      "\n",
      "--------Epoch 164--------\n",
      "Val loss = -0.7481\n",
      "\n",
      "--------Epoch 165--------\n",
      "Val loss = -0.7513\n",
      "\n",
      "--------Epoch 166--------\n",
      "Val loss = -0.7511\n",
      "\n",
      "--------Epoch 167--------\n",
      "Val loss = -0.7544\n",
      "\n",
      "--------Epoch 168--------\n",
      "Val loss = -0.7562\n",
      "\n",
      "--------Epoch 169--------\n",
      "Val loss = -0.7571\n",
      "\n",
      "--------Epoch 170--------\n",
      "Val loss = -0.7524\n",
      "\n",
      "--------Epoch 171--------\n",
      "Val loss = -0.7605\n",
      "\n",
      "--------Epoch 172--------\n",
      "Val loss = -0.7576\n",
      "\n",
      "--------Epoch 173--------\n",
      "Val loss = -0.7576\n",
      "\n",
      "--------Epoch 174--------\n",
      "Val loss = -0.7616\n",
      "\n",
      "--------Epoch 175--------\n",
      "Val loss = -0.7618\n",
      "\n",
      "--------Epoch 176--------\n",
      "Val loss = -0.7576\n",
      "\n",
      "--------Epoch 177--------\n",
      "Val loss = -0.7598\n",
      "\n",
      "--------Epoch 178--------\n",
      "Val loss = -0.7600\n",
      "\n",
      "--------Epoch 179--------\n",
      "Val loss = -0.7609\n",
      "\n",
      "--------Epoch 180--------\n",
      "Val loss = -0.7623\n",
      "\n",
      "--------Epoch 181--------\n",
      "Val loss = -0.7587\n",
      "\n",
      "--------Epoch 182--------\n",
      "Val loss = -0.7584\n",
      "\n",
      "--------Epoch 183--------\n",
      "Val loss = -0.7616\n",
      "\n",
      "--------Epoch 184--------\n",
      "Val loss = -0.7547\n",
      "\n",
      "--------Epoch 185--------\n",
      "Val loss = -0.7604\n",
      "\n",
      "--------Epoch 186--------\n",
      "Val loss = -0.7617\n",
      "\n",
      "--------Epoch 187--------\n",
      "Val loss = -0.7576\n",
      "\n",
      "--------Epoch 188--------\n",
      "Val loss = -0.7591\n",
      "\n",
      "--------Epoch 189--------\n",
      "Val loss = -0.7584\n",
      "\n",
      "--------Epoch 190--------\n",
      "Val loss = -0.7582\n",
      "\n",
      "--------Epoch 191--------\n",
      "Val loss = -0.7572\n",
      "\n",
      "--------Epoch 192--------\n",
      "Val loss = -0.7601\n",
      "\n",
      "--------Epoch 193--------\n",
      "Val loss = -0.7574\n",
      "\n",
      "--------Epoch 194--------\n",
      "Val loss = -0.7617\n",
      "\n",
      "--------Epoch 195--------\n",
      "Val loss = -0.7622\n",
      "\n",
      "--------Epoch 196--------\n",
      "Val loss = -0.7602\n",
      "\n",
      "--------Epoch 197--------\n",
      "Val loss = -0.7616\n",
      "\n",
      "--------Epoch 198--------\n",
      "Val loss = -0.7576\n",
      "\n",
      "--------Epoch 199--------\n",
      "Val loss = -0.7572\n",
      "\n",
      "--------Epoch 200--------\n",
      "Val loss = -0.7622\n",
      "\n",
      "--------Epoch 201--------\n",
      "Val loss = -0.7592\n",
      "\n",
      "--------Epoch 202--------\n",
      "Val loss = -0.7613\n",
      "\n",
      "--------Epoch 203--------\n",
      "Val loss = -0.7606\n",
      "\n",
      "--------Epoch 204--------\n",
      "Val loss = -0.7586\n",
      "\n",
      "--------Epoch 205--------\n",
      "Val loss = -0.7590\n",
      "\n",
      "--------Epoch 206--------\n",
      "Val loss = -0.7616\n",
      "\n",
      "--------Epoch 207--------\n",
      "Val loss = -0.7559\n",
      "\n",
      "--------Epoch 208--------\n",
      "Val loss = -0.7571\n",
      "\n",
      "--------Epoch 209--------\n",
      "Val loss = -0.7537\n",
      "\n",
      "--------Epoch 210--------\n",
      "Val loss = -0.7494\n",
      "\n",
      "--------Epoch 211--------\n",
      "Val loss = -0.7629\n",
      "\n",
      "--------Epoch 212--------\n",
      "Val loss = -0.7550\n",
      "\n",
      "--------Epoch 213--------\n",
      "Val loss = -0.7588\n",
      "\n",
      "--------Epoch 214--------\n",
      "Val loss = -0.7546\n",
      "\n",
      "--------Epoch 215--------\n",
      "Val loss = -0.7565\n",
      "\n",
      "--------Epoch 216--------\n",
      "Val loss = -0.7526\n",
      "\n",
      "--------Epoch 217--------\n",
      "Val loss = -0.7578\n",
      "\n",
      "--------Epoch 218--------\n",
      "Val loss = -0.7513\n",
      "\n",
      "--------Epoch 219--------\n",
      "Val loss = -0.7568\n",
      "\n",
      "--------Epoch 220--------\n",
      "Val loss = -0.7524\n",
      "\n",
      "--------Epoch 221--------\n",
      "Val loss = -0.7535\n",
      "\n",
      "--------Epoch 222--------\n",
      "Val loss = -0.7573\n",
      "\n",
      "--------Epoch 223--------\n",
      "Val loss = -0.7550\n",
      "\n",
      "--------Epoch 224--------\n",
      "Val loss = -0.7562\n",
      "\n",
      "--------Epoch 225--------\n",
      "Val loss = -0.7507\n",
      "\n",
      "--------Epoch 226--------\n",
      "Val loss = -0.7530\n",
      "\n",
      "--------Epoch 227--------\n",
      "Val loss = -0.7498\n",
      "\n",
      "--------Epoch 228--------\n",
      "Val loss = -0.7519\n",
      "\n",
      "--------Epoch 229--------\n",
      "Val loss = -0.7513\n",
      "\n",
      "--------Epoch 230--------\n",
      "Val loss = -0.7466\n",
      "\n",
      "--------Epoch 231--------\n",
      "Val loss = -0.7500\n",
      "\n",
      "--------Epoch 232--------\n",
      "Val loss = -0.7516\n",
      "\n",
      "--------Epoch 233--------\n",
      "Val loss = -0.7457\n",
      "\n",
      "--------Epoch 234--------\n",
      "Val loss = -0.7492\n",
      "\n",
      "--------Epoch 235--------\n",
      "Val loss = -0.7482\n",
      "\n",
      "--------Epoch 236--------\n",
      "Val loss = -0.7488\n",
      "\n",
      "--------Epoch 237--------\n",
      "Val loss = -0.7499\n",
      "\n",
      "--------Epoch 238--------\n",
      "Val loss = -0.7479\n",
      "\n",
      "--------Epoch 239--------\n",
      "Val loss = -0.7453\n",
      "\n",
      "--------Epoch 240--------\n",
      "Val loss = -0.7455\n",
      "\n",
      "--------Epoch 241--------\n",
      "Val loss = -0.7447\n",
      "\n",
      "--------Epoch 242--------\n",
      "Val loss = -0.7468\n",
      "\n",
      "--------Epoch 243--------\n",
      "Val loss = -0.7449\n",
      "\n",
      "--------Epoch 244--------\n",
      "Val loss = -0.7459\n",
      "\n",
      "--------Epoch 245--------\n",
      "Val loss = -0.7464\n",
      "\n",
      "--------Epoch 246--------\n",
      "Val loss = -0.7438\n",
      "\n",
      "--------Epoch 247--------\n",
      "Val loss = -0.7402\n",
      "\n",
      "--------Epoch 248--------\n",
      "Val loss = -0.7425\n",
      "\n",
      "--------Epoch 249--------\n",
      "Val loss = -0.7435\n",
      "\n",
      "--------Epoch 250--------\n",
      "Val loss = -0.7416\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Predictor: an MLP with two hidden layers. Its input width is 2 * d_in,\n",
    "# presumably because ConcreteMask(append=True) concatenates the mask to the\n",
    "# masked input (TODO confirm against models.py).\n",
    "hidden_width = 512\n",
    "predictor_layers = [\n",
    "    nn.Linear(2 * d_in, hidden_width),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(hidden_width, hidden_width),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(hidden_width, d_out),\n",
    "]\n",
    "global_model = nn.Sequential(*predictor_layers)\n",
    "\n",
    "# Selection layer built from the input size and the feature budget, then\n",
    "# wrapped together with the predictor and moved to the compute device.\n",
    "selector = ConcreteMask(d_in, max_features, append=True)\n",
    "globalfs = GlobalSelector(global_model, selector)\n",
    "globalfs = globalfs.to(device)\n",
    "\n",
    "# Training options. The validation metric is NegAccuracy, so the logged\n",
    "# 'Val loss' values are presumably negated accuracies; start_temp/end_temp look\n",
    "# like a temperature schedule for the selector (TODO confirm against fit()).\n",
    "fit_options = dict(\n",
    "    mbsize=128,\n",
    "    lr=1e-3,\n",
    "    nepochs=250,\n",
    "    loss_fn=nn.CrossEntropyLoss(),\n",
    "    val_loss_fn=NegAccuracy(),\n",
    "    start_temp=1.0,\n",
    "    end_temp=0.01,\n",
    ")\n",
    "globalfs.fit(train_dataset, val_dataset, **fit_options)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "focal-bathroom",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test acc = 76.30\n"
     ]
    }
   ],
   "source": [
    "# Score the trained global selector on the MNIST test split with the Accuracy\n",
    "# metric; 1024 is presumably the evaluation batch size (TODO confirm against\n",
    "# GlobalSelector.evaluate).\n",
    "test_acc = globalfs.evaluate(test_dataset, Accuracy(), 1024)\n",
    "# The result is scaled by 100 and shown with two decimals.\n",
    "print(f'Test acc = {100*test_acc:.2f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "blocked-flashing",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAKxElEQVR4nO3dT6hc93mH8edbV1ZASUGqa6M6pkmDFzWFKuWiFlyKi2nqeCNnkRItggoGZRFDAlnUpIt4aUqT0EUJKLWIWlKHQGKshWkiRMBkY3xtVFuu2to1aqJISA1exClUlp23i3tcbuz7zzNn/kjv84HLzJyZe+dl8KMzM2fGv1QVkm58v7LoASTNh7FLTRi71ISxS00Yu9TEr87zzm7O7nofe+Z5l1Ir/8v/8EZdzUbXTRV7kvuAvwVuAv6+qh7d6vbvYw9/kHunuUtJW3imTm963cRP45PcBPwd8HHgLuBwkrsm/XuSZmua1+wHgVeq6tWqegP4FnBonLEkjW2a2G8Hfrzu8oVh2y9JcjTJapLVa1yd4u4kTWOa2Dd6E+Bdn72tqmNVtVJVK7vYPcXdSZrGNLFfAO5Yd/mDwMXpxpE0K9PE/ixwZ5IPJ7kZ+BRwcpyxJI1t4kNvVfVmkoeA77F26O14Vb002mSSRjXVcfaqegp4aqRZJM2QH5eVmjB2qQljl5owdqkJY5eaMHapCWOXmjB2qQljl5owdqkJY5eaMHapCWOXmjB2qQljl5owdqkJY5eaMHapCWOXmjB2qQljl5owdqkJY5eaMHapCWOXmjB2qQljl5owdqkJY5eamGoVV+l7F89sef2f/eaBucyh7U0Ve5LzwOvAW8CbVbUyxlCSxjfGnv1PquqnI/wdSTPka3apiWljL+D7SZ5LcnSjGyQ5mmQ1yeo1rk55d5ImNe3T+Lur6mKSW4FTSf6tqp5ef4OqOgYcA/i17Ksp70/ShKbas1fVxeH0CvAEcHCMoSSNb+LYk+xJ8oG3zwMfA86ONZikcU3zNP424Ikkb/+df6qqfx5lKl03PI5+/Zg49qp6Ffi9EWeRNEMeepOaMHapCWOXmjB2qQljl5owdqkJY5eaMHapCWOXmjB2qQljl5owdqkJY5eaMHapCWOXmjB2qQljl5owdqkJY5eaMHapCWOXmnDJZm3JJZlvHO7ZpSaMXWrC2KUmjF1qwtilJoxdasLYpSY8zq4tLfI4usf4x7Xtnj3J8SRXkpxdt21fklNJXh5O9852TEnT2snT+G8A971j28PA6aq6Ezg9XJa0xLaNvaqeBl57x+ZDwInh/AnggXHHkjS2Sd+gu62qLgEMp7dudsMkR5OsJlm9xtUJ707StGb+bnxVHauqlapa2cXuWd+dpE1MGvvlJPsBhtMr440kaRYmjf0kcGQ4fwR4cpxxJM3KtsfZkzwO3APckuQC8CXgUeDbSR4EfgR8cpZDqiePo49r29ir6vAmV9078iySZsiPy0pNGLvUhLFLTRi71ISxS034Fdcb3LRfE/VrpjcO9+xSE8YuNWHsUhPGLjVh7FITxi41YexSEx5nn4NFHque9m97HP3G4Z5dasLYpSaMXWrC2KUmjF1qwtilJoxdasLj7HNwPR+rnuVnBPyu/Hy5Z5eaMHapCWOXmjB2qQljl5owdqkJY5ea8Di7trTM37XXe7Ptnj3J8SRXkpxdt+2RJD9Jcmb4uX+2Y0qa1k6exn8DuG+D7V+tqgPDz1PjjiVpbNvGXlVPA6/NYRZJMzTNG3QPJXlheJq/d7MbJTmaZDXJ6jWuTnF3kqYxaexfAz4CHAAuAV/e7IZVdayqVqpqZRe7J7w7SdOaKPaqulxVb1XVL4CvAwfHHUvS2CaKPcn+dRc/AZzd7LaSlsO2x9mTPA7cA9yS5ALwJeCeJAeAAs4Dn5ndiJLGsG3sVXV4g82PzWAWSTPkx2WlJoxdasLYpSaMXWrC2KUm/IqrtuT/7vnG4Z5d
asLYpSaMXWrC2KUmjF1qwtilJoxdasLj7NeBRR7r9jj6jcM9u9SEsUtNGLvUhLFLTRi71ISxS00Yu9SEx9mvAx7r1hjcs0tNGLvUhLFLTRi71ISxS00Yu9SEsUtNGLvUxLaxJ7kjyQ+SnEvyUpLPDdv3JTmV5OXhdO/sx5U0qZ3s2d8EvlBVvwP8IfDZJHcBDwOnq+pO4PRwWdKS2jb2qrpUVc8P518HzgG3A4eAE8PNTgAPzGhGSSN4T6/Zk3wI+CjwDHBbVV2CtX8QgFs3+Z2jSVaTrF7j6pTjSprUjmNP8n7gO8Dnq+pnO/29qjpWVStVtbKL3ZPMKGkEO4o9yS7WQv9mVX132Hw5yf7h+v3AldmMKGkMO3k3PsBjwLmq+sq6q04CR4bzR4Anxx9P0lh28n32u4FPAy8mOTNs+yLwKPDtJA8CPwI+OZMJJY1i29ir6odANrn63nHHkTQrfoJOasLYpSaMXWrC2KUmjF1qwtilJoxdasLYpSaMXWrC2KUmjF1qwtilJoxdasLYpSaMXWrC2KUmjF1qwtilJoxdasLYpSaMXWrC2KUmjF1qwtilJoxdasLYpSaMXWrC2KUmjF1qYifrs9+R5AdJziV5Kcnnhu2PJPlJkjPDz/2zH1fSpHayPvubwBeq6vkkHwCeS3JquO6rVfU3sxtP0lh2sj77JeDScP71JOeA22c9mKRxvafX7Ek+BHwUeGbY9FCSF5IcT7J3k985mmQ1yeo1rk43raSJ7Tj2JO8HvgN8vqp+BnwN+AhwgLU9/5c3+r2qOlZVK1W1sovd008saSI7ij3JLtZC/2ZVfRegqi5X1VtV9Qvg68DB2Y0paVo7eTc+wGPAuar6yrrt+9fd7BPA2fHHkzSWnbwbfzfwaeDFJGeGbV8EDic5ABRwHvjMDOaTNJKdvBv/QyAbXPXU+ONImhU/QSc1YexSE8YuNWHsUhPGLjVh7FITxi41YexSE8YuNWHsUhPGLjVh7FITxi41YexSE6mq+d1Z8t/Af63bdAvw07kN8N4s62zLOhc426TGnO23quo3NrpirrG/686T1apaWdgAW1jW2ZZ1LnC2Sc1rNp/GS00Yu9TEomM/tuD738qyzrasc4GzTWousy30Nbuk+Vn0nl3SnBi71MRCYk9yX5J/T/JKkocXMcNmkpxP8uKwDPXqgmc5nuRKkrPrtu1LcirJy8PphmvsLWi2pVjGe4tlxhf62C16+fO5v2ZPchPwH8CfAheAZ4HDVfWvcx1kE0nOAytVtfAPYCT5Y+DnwD9U1e8O2/4aeK2qHh3+odxbVX+5JLM9Avx80ct4D6sV7V+/zDjwAPAXLPCx22KuP2cOj9si9uwHgVeq6tWqegP4FnBoAXMsvap6GnjtHZsPASeG8ydY+49l7jaZbSlU1aWqen44/zrw9jLjC33stphrLhYR++3Aj9ddvsByrfdewPeTPJfk6KKH2cBtVXUJ1v7jAW5d8DzvtO0y3vP0jmXGl+axm2T582ktIvaNlpJapuN/d1fV7wMfBz47PF3VzuxoGe952WCZ8aUw6fLn01pE7BeAO9Zd/iBwcQFzbKiqLg6nV4AnWL6lqC+/vYLucHplwfP8v2VaxnujZcZZgsdukcufLyL2Z4E7k3w4yc3Ap4CTC5jjXZLsGd44Icke4GMs31LUJ4Ejw/kjwJMLnOWXLMsy3pstM86CH7uFL39eVXP/Ae5n7R35/wT+ahEzbDLXbwP/Mvy8tOjZgMdZe1p3jbVnRA8Cvw6cBl4eTvct0Wz/CLwIvMBaWPsXNNsfsfbS8AXgzPBz/6Ifuy3mmsvj5sdlpSb8BJ3UhLFLTRi71ISxS00Yu9SEsUtNGLvUxP8BubFf+o/I5PIAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Index of the largest logit in each row of selector.logits; each row is taken\n",
    "# to be one selection slot (NOTE(review): confirm against ConcreteMask).\n",
    "# .detach() is the supported way to drop autograd tracking, replacing .data.\n",
    "inds = selector.logits.detach().argmax(dim=1).cpu().numpy()\n",
    "\n",
    "# Binary mask over the d_in flattened pixels, drawn on the square image grid.\n",
    "zeros = np.zeros(d_in)\n",
    "zeros[inds] = 1\n",
    "side = int(round(d_in ** 0.5))\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.imshow(zeros.reshape(side, side))\n",
    "# Slots can pick the same pixel, so report how many distinct pixels were chosen.\n",
    "ax.set_title(f'Global FS: {len(np.unique(inds))} distinct pixels selected')\n",
    "ax.set_xlabel('pixel column')\n",
    "ax.set_ylabel('pixel row')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "elder-windows",
   "metadata": {},
   "source": [
    "# Greedy adaptive FS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "portuguese-shark",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--------Epoch 1--------\n",
      "Val loss = -0.2389\n",
      "\n",
      "--------Epoch 2--------\n",
      "Val loss = -0.2710\n",
      "\n",
      "--------Epoch 3--------\n",
      "Val loss = -0.2806\n",
      "\n",
      "--------Epoch 4--------\n",
      "Val loss = -0.2797\n",
      "\n",
      "--------Epoch 5--------\n",
      "Val loss = -0.2881\n",
      "\n",
      "--------Epoch 6--------\n",
      "Val loss = -0.2928\n",
      "\n",
      "--------Epoch 7--------\n",
      "Val loss = -0.2967\n",
      "\n",
      "--------Epoch 8--------\n",
      "Val loss = -0.2992\n",
      "\n",
      "--------Epoch 9--------\n",
      "Val loss = -0.3034\n",
      "\n",
      "--------Epoch 10--------\n",
      "Val loss = -0.2992\n",
      "\n",
      "--------Epoch 11--------\n",
      "Val loss = -0.3032\n",
      "\n",
      "--------Epoch 12--------\n",
      "Val loss = -0.3093\n",
      "\n",
      "--------Epoch 13--------\n",
      "Val loss = -0.3131\n",
      "\n",
      "--------Epoch 14--------\n",
      "Val loss = -0.3082\n",
      "\n",
      "--------Epoch 15--------\n",
      "Val loss = -0.3103\n",
      "\n",
      "--------Epoch 16--------\n",
      "Val loss = -0.3042\n",
      "\n",
      "--------Epoch 17--------\n",
      "Val loss = -0.3058\n",
      "\n",
      "--------Epoch 18--------\n",
      "Val loss = -0.3083\n",
      "\n",
      "--------Epoch 19--------\n",
      "Val loss = -0.3168\n",
      "\n",
      "--------Epoch 20--------\n",
      "Val loss = -0.3193\n",
      "\n",
      "--------Epoch 21--------\n",
      "Val loss = -0.3145\n",
      "\n",
      "--------Epoch 22--------\n",
      "Val loss = -0.3202\n",
      "\n",
      "--------Epoch 23--------\n",
      "Val loss = -0.3199\n",
      "\n",
      "--------Epoch 24--------\n",
      "Val loss = -0.3174\n",
      "\n",
      "--------Epoch 25--------\n",
      "Val loss = -0.3210\n",
      "\n",
      "--------Epoch 26--------\n",
      "Val loss = -0.3189\n",
      "\n",
      "--------Epoch 27--------\n",
      "Val loss = -0.3238\n",
      "\n",
      "--------Epoch 28--------\n",
      "Val loss = -0.3246\n",
      "\n",
      "--------Epoch 29--------\n",
      "Val loss = -0.3190\n",
      "\n",
      "--------Epoch 30--------\n",
      "Val loss = -0.3274\n",
      "\n",
      "--------Epoch 31--------\n",
      "Val loss = -0.3215\n",
      "\n",
      "--------Epoch 32--------\n",
      "Val loss = -0.3162\n",
      "\n",
      "--------Epoch 33--------\n",
      "Val loss = -0.3315\n",
      "\n",
      "--------Epoch 34--------\n",
      "Val loss = -0.3176\n",
      "\n",
      "--------Epoch 35--------\n",
      "Val loss = -0.3207\n",
      "\n",
      "--------Epoch 36--------\n",
      "Val loss = -0.3188\n",
      "\n",
      "--------Epoch 37--------\n",
      "Val loss = -0.3280\n",
      "\n",
      "--------Epoch 38--------\n",
      "Val loss = -0.3223\n",
      "\n",
      "--------Epoch 39--------\n",
      "Val loss = -0.3239\n",
      "\n",
      "--------Epoch 40--------\n",
      "Val loss = -0.3290\n",
      "\n",
      "--------Epoch 41--------\n",
      "Val loss = -0.3342\n",
      "\n",
      "--------Epoch 42--------\n",
      "Val loss = -0.3239\n",
      "\n",
      "--------Epoch 43--------\n",
      "Val loss = -0.3264\n",
      "\n",
      "--------Epoch 44--------\n",
      "Val loss = -0.3269\n",
      "\n",
      "--------Epoch 45--------\n",
      "Val loss = -0.3339\n",
      "\n",
      "--------Epoch 46--------\n",
      "Val loss = -0.3209\n",
      "\n",
      "--------Epoch 47--------\n",
      "Val loss = -0.3263\n",
      "\n",
      "--------Epoch 48--------\n",
      "Val loss = -0.3284\n",
      "\n",
      "--------Epoch 49--------\n",
      "Val loss = -0.3358\n",
      "\n",
      "--------Epoch 50--------\n",
      "Val loss = -0.3287\n",
      "\n",
      "--------Epoch 51--------\n",
      "Val loss = -0.3348\n",
      "\n",
      "--------Epoch 52--------\n",
      "Val loss = -0.3264\n",
      "\n",
      "--------Epoch 53--------\n",
      "Val loss = -0.3282\n",
      "\n",
      "--------Epoch 54--------\n",
      "Val loss = -0.3291\n",
      "\n",
      "--------Epoch 55--------\n",
      "Val loss = -0.3266\n",
      "\n",
      "--------Epoch 56--------\n",
      "Val loss = -0.3329\n",
      "\n",
      "--------Epoch 57--------\n",
      "Val loss = -0.3317\n",
      "\n",
      "--------Epoch 58--------\n",
      "Val loss = -0.3247\n",
      "\n",
      "--------Epoch 59--------\n",
      "Val loss = -0.3324\n",
      "\n",
      "--------Epoch 60--------\n",
      "Val loss = -0.3306\n",
      "\n",
      "--------Epoch 61--------\n",
      "Val loss = -0.3346\n",
      "\n",
      "--------Epoch 62--------\n",
      "Val loss = -0.3327\n",
      "\n",
      "--------Epoch 63--------\n",
      "Val loss = -0.3321\n",
      "\n",
      "--------Epoch 64--------\n",
      "Val loss = -0.3395\n",
      "\n",
      "--------Epoch 65--------\n",
      "Val loss = -0.3301\n",
      "\n",
      "--------Epoch 66--------\n",
      "Val loss = -0.3402\n",
      "\n",
      "--------Epoch 67--------\n",
      "Val loss = -0.3315\n",
      "\n",
      "--------Epoch 68--------\n",
      "Val loss = -0.3406\n",
      "\n",
      "--------Epoch 69--------\n",
      "Val loss = -0.3281\n",
      "\n",
      "--------Epoch 70--------\n",
      "Val loss = -0.3310\n",
      "\n",
      "--------Epoch 71--------\n",
      "Val loss = -0.3329\n",
      "\n",
      "--------Epoch 72--------\n",
      "Val loss = -0.3348\n",
      "\n",
      "--------Epoch 73--------\n",
      "Val loss = -0.3390\n",
      "\n",
      "--------Epoch 74--------\n",
      "Val loss = -0.3408\n",
      "\n",
      "--------Epoch 75--------\n",
      "Val loss = -0.3300\n",
      "\n",
      "--------Epoch 76--------\n",
      "Val loss = -0.3310\n",
      "\n",
      "--------Epoch 77--------\n",
      "Val loss = -0.3373\n",
      "\n",
      "--------Epoch 78--------\n",
      "Val loss = -0.3377\n",
      "\n",
      "--------Epoch 79--------\n",
      "Val loss = -0.3368\n",
      "\n",
      "--------Epoch 80--------\n",
      "Val loss = -0.3423\n",
      "\n",
      "--------Epoch 81--------\n",
      "Val loss = -0.3278\n",
      "\n",
      "--------Epoch 82--------\n",
      "Val loss = -0.3328\n",
      "\n",
      "--------Epoch 83--------\n",
      "Val loss = -0.3371\n",
      "\n",
      "--------Epoch 84--------\n",
      "Val loss = -0.3355\n",
      "\n",
      "--------Epoch 85--------\n",
      "Val loss = -0.3307\n",
      "\n",
      "--------Epoch 86--------\n",
      "Val loss = -0.3399\n",
      "\n",
      "--------Epoch 87--------\n",
      "Val loss = -0.3392\n",
      "\n",
      "--------Epoch 88--------\n",
      "Val loss = -0.3338\n",
      "\n",
      "--------Epoch 89--------\n",
      "Val loss = -0.3469\n",
      "\n",
      "--------Epoch 90--------\n",
      "Val loss = -0.3428\n",
      "\n",
      "--------Epoch 91--------\n",
      "Val loss = -0.3438\n",
      "\n",
      "--------Epoch 92--------\n",
      "Val loss = -0.3385\n",
      "\n",
      "--------Epoch 93--------\n",
      "Val loss = -0.3401\n",
      "\n",
      "--------Epoch 94--------\n",
      "Val loss = -0.3291\n",
      "\n",
      "--------Epoch 95--------\n",
      "Val loss = -0.3353\n",
      "\n",
      "--------Epoch 96--------\n",
      "Val loss = -0.3348\n",
      "\n",
      "--------Epoch 97--------\n",
      "Val loss = -0.3387\n",
      "\n",
      "--------Epoch 98--------\n",
      "Val loss = -0.3329\n",
      "\n",
      "--------Epoch 99--------\n",
      "Val loss = -0.3407\n",
      "\n",
      "--------Epoch 100--------\n",
      "Val loss = -0.3455\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Set up model\n",
    "model = nn.Sequential(\n",
    "    nn.Linear(d_in * 2, 512),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(512, 512),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(512, d_out))\n",
    "mask_layer = MaskLayer(append=True)\n",
    "pretrain = Pretrainer(model, mask_layer).to(device)\n",
    "\n",
    "# Pretrain\n",
    "pretrain.fit(train_dataset,\n",
    "             val_dataset,\n",
    "             mbsize=128,\n",
    "             lr=1e-3,\n",
    "             nepochs=100,\n",
    "             max_features=max_features,\n",
    "             loss_fn=nn.CrossEntropyLoss(),\n",
    "             val_loss_fn=NegAccuracy())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "clinical-recording",
   "metadata": {},
   "source": [
    "# Normal training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "standard-float",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--------Epoch 1--------\n",
      "Val loss = -0.6577\n",
      "\n",
      "--------Epoch 2--------\n",
      "Val loss = -0.6859\n",
      "\n",
      "--------Epoch 3--------\n",
      "Val loss = -0.7033\n",
      "\n",
      "--------Epoch 4--------\n",
      "Val loss = -0.7057\n",
      "\n",
      "--------Epoch 5--------\n",
      "Val loss = -0.7252\n",
      "\n",
      "--------Epoch 6--------\n",
      "Val loss = -0.7240\n",
      "\n",
      "--------Epoch 7--------\n",
      "Val loss = -0.7100\n",
      "\n",
      "--------Epoch 8--------\n",
      "Val loss = -0.7068\n",
      "\n",
      "--------Epoch 9--------\n",
      "Val loss = -0.7192\n",
      "\n",
      "--------Epoch 10--------\n",
      "Val loss = -0.7330\n",
      "\n",
      "--------Epoch 11--------\n",
      "Val loss = -0.7232\n",
      "\n",
      "--------Epoch 12--------\n",
      "Val loss = -0.7165\n",
      "\n",
      "--------Epoch 13--------\n",
      "Val loss = -0.7161\n",
      "\n",
      "--------Epoch 14--------\n",
      "Val loss = -0.7235\n",
      "\n",
      "--------Epoch 15--------\n",
      "Val loss = -0.6824\n",
      "\n",
      "--------Epoch 16--------\n",
      "Val loss = -0.7044\n",
      "\n",
      "--------Epoch 17--------\n",
      "Val loss = -0.6921\n",
      "\n",
      "--------Epoch 18--------\n",
      "Val loss = -0.6895\n",
      "\n",
      "--------Epoch 19--------\n",
      "Val loss = -0.6941\n",
      "\n",
      "--------Epoch 20--------\n",
      "Val loss = -0.6905\n",
      "\n",
      "--------Epoch 21--------\n",
      "Val loss = -0.6812\n",
      "\n",
      "--------Epoch 22--------\n",
      "Val loss = -0.6792\n",
      "\n",
      "--------Epoch 23--------\n",
      "Val loss = -0.6784\n",
      "\n",
      "--------Epoch 24--------\n",
      "Val loss = -0.6632\n",
      "\n",
      "--------Epoch 25--------\n",
      "Val loss = -0.6619\n",
      "\n",
      "--------Epoch 26--------\n",
      "Val loss = -0.6662\n",
      "\n",
      "--------Epoch 27--------\n",
      "Val loss = -0.6768\n",
      "\n",
      "--------Epoch 28--------\n",
      "Val loss = -0.6682\n",
      "\n",
      "--------Epoch 29--------\n",
      "Val loss = -0.6737\n",
      "\n",
      "--------Epoch 30--------\n",
      "Val loss = -0.6921\n",
      "\n",
      "--------Epoch 31--------\n",
      "Val loss = -0.6742\n",
      "\n",
      "--------Epoch 32--------\n",
      "Val loss = -0.6624\n",
      "\n",
      "--------Epoch 33--------\n",
      "Val loss = -0.6688\n",
      "\n",
      "--------Epoch 34--------\n",
      "Val loss = -0.6792\n",
      "\n",
      "--------Epoch 35--------\n",
      "Val loss = -0.6785\n",
      "\n",
      "--------Epoch 36--------\n",
      "Val loss = -0.6706\n",
      "\n",
      "--------Epoch 37--------\n",
      "Val loss = -0.6654\n",
      "\n",
      "--------Epoch 38--------\n",
      "Val loss = -0.6763\n",
      "\n",
      "--------Epoch 39--------\n",
      "Val loss = -0.6646\n",
      "\n",
      "--------Epoch 40--------\n",
      "Val loss = -0.6650\n",
      "\n",
      "--------Epoch 41--------\n",
      "Val loss = -0.6678\n",
      "\n",
      "--------Epoch 42--------\n",
      "Val loss = -0.6715\n",
      "\n",
      "--------Epoch 43--------\n",
      "Val loss = -0.6627\n",
      "\n",
      "--------Epoch 44--------\n",
      "Val loss = -0.6624\n",
      "\n",
      "--------Epoch 45--------\n",
      "Val loss = -0.6679\n",
      "\n",
      "--------Epoch 46--------\n",
      "Val loss = -0.6571\n",
      "\n",
      "--------Epoch 47--------\n",
      "Val loss = -0.6508\n",
      "\n",
      "--------Epoch 48--------\n",
      "Val loss = -0.6541\n",
      "\n",
      "--------Epoch 49--------\n",
      "Val loss = -0.6689\n",
      "\n",
      "--------Epoch 50--------\n",
      "Val loss = -0.6553\n",
      "\n",
      "--------Epoch 51--------\n",
      "Val loss = -0.6550\n",
      "\n",
      "--------Epoch 52--------\n",
      "Val loss = -0.6582\n",
      "\n",
      "--------Epoch 53--------\n",
      "Val loss = -0.6553\n",
      "\n",
      "--------Epoch 54--------\n",
      "Val loss = -0.6614\n",
      "\n",
      "--------Epoch 55--------\n",
      "Val loss = -0.6696\n",
      "\n",
      "--------Epoch 56--------\n",
      "Val loss = -0.6605\n",
      "\n",
      "--------Epoch 57--------\n",
      "Val loss = -0.6628\n",
      "\n",
      "--------Epoch 58--------\n",
      "Val loss = -0.6493\n",
      "\n",
      "--------Epoch 59--------\n",
      "Val loss = -0.6642\n",
      "\n",
      "--------Epoch 60--------\n",
      "Val loss = -0.6582\n",
      "\n",
      "--------Epoch 61--------\n",
      "Val loss = -0.6635\n",
      "\n",
      "--------Epoch 62--------\n",
      "Val loss = -0.6604\n",
      "\n",
      "--------Epoch 63--------\n",
      "Val loss = -0.6640\n",
      "\n",
      "--------Epoch 64--------\n",
      "Val loss = -0.6661\n",
      "\n",
      "--------Epoch 65--------\n",
      "Val loss = -0.6676\n",
      "\n",
      "--------Epoch 66--------\n",
      "Val loss = -0.6697\n",
      "\n",
      "--------Epoch 67--------\n",
      "Val loss = -0.6729\n",
      "\n",
      "--------Epoch 68--------\n",
      "Val loss = -0.6717\n",
      "\n",
      "--------Epoch 69--------\n",
      "Val loss = -0.6708\n",
      "\n",
      "--------Epoch 70--------\n",
      "Val loss = -0.6706\n",
      "\n",
      "--------Epoch 71--------\n",
      "Val loss = -0.6785\n",
      "\n",
      "--------Epoch 72--------\n",
      "Val loss = -0.6706\n",
      "\n",
      "--------Epoch 73--------\n",
      "Val loss = -0.6862\n",
      "\n",
      "--------Epoch 74--------\n",
      "Val loss = -0.6822\n",
      "\n",
      "--------Epoch 75--------\n",
      "Val loss = -0.6808\n",
      "\n",
      "--------Epoch 76--------\n",
      "Val loss = -0.6789\n",
      "\n",
      "--------Epoch 77--------\n",
      "Val loss = -0.6858\n",
      "\n",
      "--------Epoch 78--------\n",
      "Val loss = -0.6972\n",
      "\n",
      "--------Epoch 79--------\n",
      "Val loss = -0.7023\n",
      "\n",
      "--------Epoch 80--------\n",
      "Val loss = -0.7035\n",
      "\n",
      "--------Epoch 81--------\n",
      "Val loss = -0.6993\n",
      "\n",
      "--------Epoch 82--------\n",
      "Val loss = -0.7124\n",
      "\n",
      "--------Epoch 83--------\n",
      "Val loss = -0.7086\n",
      "\n",
      "--------Epoch 84--------\n",
      "Val loss = -0.7119\n",
      "\n",
      "--------Epoch 85--------\n",
      "Val loss = -0.7104\n",
      "\n",
      "--------Epoch 86--------\n",
      "Val loss = -0.7168\n",
      "\n",
      "--------Epoch 87--------\n",
      "Val loss = -0.7262\n",
      "\n",
      "--------Epoch 88--------\n",
      "Val loss = -0.7274\n",
      "\n",
      "--------Epoch 89--------\n",
      "Val loss = -0.7318\n",
      "\n",
      "--------Epoch 90--------\n",
      "Val loss = -0.7337\n",
      "\n",
      "--------Epoch 91--------\n",
      "Val loss = -0.7402\n",
      "\n",
      "--------Epoch 92--------\n",
      "Val loss = -0.7400\n",
      "\n",
      "--------Epoch 93--------\n",
      "Val loss = -0.7406\n",
      "\n",
      "--------Epoch 94--------\n",
      "Val loss = -0.7457\n",
      "\n",
      "--------Epoch 95--------\n",
      "Val loss = -0.7463\n",
      "\n",
      "--------Epoch 96--------\n",
      "Val loss = -0.7572\n",
      "\n",
      "--------Epoch 97--------\n",
      "Val loss = -0.7579\n",
      "\n",
      "--------Epoch 98--------\n",
      "Val loss = -0.7621\n",
      "\n",
      "--------Epoch 99--------\n",
      "Val loss = -0.7733\n",
      "\n",
      "--------Epoch 100--------\n",
      "Val loss = -0.7673\n",
      "\n",
      "--------Epoch 101--------\n",
      "Val loss = -0.7721\n",
      "\n",
      "--------Epoch 102--------\n",
      "Val loss = -0.7798\n",
      "\n",
      "--------Epoch 103--------\n",
      "Val loss = -0.7848\n",
      "\n",
      "--------Epoch 104--------\n",
      "Val loss = -0.7761\n",
      "\n",
      "--------Epoch 105--------\n",
      "Val loss = -0.7857\n",
      "\n",
      "--------Epoch 106--------\n",
      "Val loss = -0.7879\n",
      "\n",
      "--------Epoch 107--------\n",
      "Val loss = -0.7887\n",
      "\n",
      "--------Epoch 108--------\n",
      "Val loss = -0.7962\n",
      "\n",
      "--------Epoch 109--------\n",
      "Val loss = -0.7902\n",
      "\n",
      "--------Epoch 110--------\n",
      "Val loss = -0.7967\n",
      "\n",
      "--------Epoch 111--------\n",
      "Val loss = -0.7953\n",
      "\n",
      "--------Epoch 112--------\n",
      "Val loss = -0.8062\n",
      "\n",
      "--------Epoch 113--------\n",
      "Val loss = -0.8091\n",
      "\n",
      "--------Epoch 114--------\n",
      "Val loss = -0.8039\n",
      "\n",
      "--------Epoch 115--------\n",
      "Val loss = -0.8112\n",
      "\n",
      "--------Epoch 116--------\n",
      "Val loss = -0.8143\n",
      "\n",
      "--------Epoch 117--------\n",
      "Val loss = -0.8081\n",
      "\n",
      "--------Epoch 118--------\n",
      "Val loss = -0.8092\n",
      "\n",
      "--------Epoch 119--------\n",
      "Val loss = -0.8128\n",
      "\n",
      "--------Epoch 120--------\n",
      "Val loss = -0.8198\n",
      "\n",
      "--------Epoch 121--------\n",
      "Val loss = -0.8163\n",
      "\n",
      "--------Epoch 122--------\n",
      "Val loss = -0.8164\n",
      "\n",
      "--------Epoch 123--------\n",
      "Val loss = -0.8177\n",
      "\n",
      "--------Epoch 124--------\n",
      "Val loss = -0.8198\n",
      "\n",
      "--------Epoch 125--------\n",
      "Val loss = -0.8180\n",
      "\n",
      "--------Epoch 126--------\n",
      "Val loss = -0.8224\n",
      "\n",
      "--------Epoch 127--------\n",
      "Val loss = -0.8241\n",
      "\n",
      "--------Epoch 128--------\n",
      "Val loss = -0.8228\n",
      "\n",
      "--------Epoch 129--------\n",
      "Val loss = -0.8255\n",
      "\n",
      "--------Epoch 130--------\n",
      "Val loss = -0.8302\n",
      "\n",
      "--------Epoch 131--------\n",
      "Val loss = -0.8206\n",
      "\n",
      "--------Epoch 132--------\n",
      "Val loss = -0.8254\n",
      "\n",
      "--------Epoch 133--------\n",
      "Val loss = -0.8325\n",
      "\n",
      "--------Epoch 134--------\n",
      "Val loss = -0.8315\n",
      "\n",
      "--------Epoch 135--------\n",
      "Val loss = -0.8289\n",
      "\n",
      "--------Epoch 136--------\n",
      "Val loss = -0.8327\n",
      "\n",
      "--------Epoch 137--------\n",
      "Val loss = -0.8341\n",
      "\n",
      "--------Epoch 138--------\n",
      "Val loss = -0.8354\n",
      "\n",
      "--------Epoch 139--------\n",
      "Val loss = -0.8337\n",
      "\n",
      "--------Epoch 140--------\n",
      "Val loss = -0.8342\n",
      "\n",
      "--------Epoch 141--------\n",
      "Val loss = -0.8372\n",
      "\n",
      "--------Epoch 142--------\n",
      "Val loss = -0.8343\n",
      "\n",
      "--------Epoch 143--------\n",
      "Val loss = -0.8317\n",
      "\n",
      "--------Epoch 144--------\n",
      "Val loss = -0.8377\n",
      "\n",
      "--------Epoch 145--------\n",
      "Val loss = -0.8432\n",
      "\n",
      "--------Epoch 146--------\n",
      "Val loss = -0.8429\n",
      "\n",
      "--------Epoch 147--------\n",
      "Val loss = -0.8391\n",
      "\n",
      "--------Epoch 148--------\n",
      "Val loss = -0.8374\n",
      "\n",
      "--------Epoch 149--------\n",
      "Val loss = -0.8393\n",
      "\n",
      "--------Epoch 150--------\n",
      "Val loss = -0.8421\n",
      "\n",
      "--------Epoch 151--------\n",
      "Val loss = -0.8365\n",
      "\n",
      "--------Epoch 152--------\n",
      "Val loss = -0.8416\n",
      "\n",
      "--------Epoch 153--------\n",
      "Val loss = -0.8383\n",
      "\n",
      "--------Epoch 154--------\n",
      "Val loss = -0.8432\n",
      "\n",
      "--------Epoch 155--------\n",
      "Val loss = -0.8380\n",
      "\n",
      "--------Epoch 156--------\n",
      "Val loss = -0.8395\n",
      "\n",
      "--------Epoch 157--------\n",
      "Val loss = -0.8413\n",
      "\n",
      "--------Epoch 158--------\n",
      "Val loss = -0.8357\n",
      "\n",
      "--------Epoch 159--------\n",
      "Val loss = -0.8396\n",
      "\n",
      "--------Epoch 160--------\n",
      "Val loss = -0.8427\n",
      "\n",
      "--------Epoch 161--------\n",
      "Val loss = -0.8400\n",
      "\n",
      "--------Epoch 162--------\n",
      "Val loss = -0.8489\n",
      "\n",
      "--------Epoch 163--------\n",
      "Val loss = -0.8394\n",
      "\n",
      "--------Epoch 164--------\n",
      "Val loss = -0.8445\n",
      "\n",
      "--------Epoch 165--------\n",
      "Val loss = -0.8413\n",
      "\n",
      "--------Epoch 166--------\n",
      "Val loss = -0.8392\n",
      "\n",
      "--------Epoch 167--------\n",
      "Val loss = -0.8419\n",
      "\n",
      "--------Epoch 168--------\n",
      "Val loss = -0.8434\n",
      "\n",
      "--------Epoch 169--------\n",
      "Val loss = -0.8393\n",
      "\n",
      "--------Epoch 170--------\n",
      "Val loss = -0.8401\n",
      "\n",
      "--------Epoch 171--------\n",
      "Val loss = -0.8397\n",
      "\n",
      "--------Epoch 172--------\n",
      "Val loss = -0.8466\n",
      "\n",
      "--------Epoch 173--------\n",
      "Val loss = -0.8413\n",
      "\n",
      "--------Epoch 174--------\n",
      "Val loss = -0.8441\n",
      "\n",
      "--------Epoch 175--------\n",
      "Val loss = -0.8496\n",
      "\n",
      "--------Epoch 176--------\n",
      "Val loss = -0.8421\n",
      "\n",
      "--------Epoch 177--------\n",
      "Val loss = -0.8406\n",
      "\n",
      "--------Epoch 178--------\n",
      "Val loss = -0.8489\n",
      "\n",
      "--------Epoch 179--------\n",
      "Val loss = -0.8491\n",
      "\n",
      "--------Epoch 180--------\n",
      "Val loss = -0.8420\n",
      "\n",
      "--------Epoch 181--------\n",
      "Val loss = -0.8437\n",
      "\n",
      "--------Epoch 182--------\n",
      "Val loss = -0.8495\n",
      "\n",
      "--------Epoch 183--------\n",
      "Val loss = -0.8408\n",
      "\n",
      "--------Epoch 184--------\n",
      "Val loss = -0.8455\n",
      "\n",
      "--------Epoch 185--------\n",
      "Val loss = -0.8426\n",
      "\n",
      "--------Epoch 186--------\n",
      "Val loss = -0.8451\n",
      "\n",
      "--------Epoch 187--------\n",
      "Val loss = -0.8424\n",
      "\n",
      "--------Epoch 188--------\n",
      "Val loss = -0.8462\n",
      "\n",
      "--------Epoch 189--------\n",
      "Val loss = -0.8477\n",
      "\n",
      "--------Epoch 190--------\n",
      "Val loss = -0.8429\n",
      "\n",
      "--------Epoch 191--------\n",
      "Val loss = -0.8445\n",
      "\n",
      "--------Epoch 192--------\n",
      "Val loss = -0.8389\n",
      "\n",
      "--------Epoch 193--------\n",
      "Val loss = -0.8478\n",
      "\n",
      "--------Epoch 194--------\n",
      "Val loss = -0.8476\n",
      "\n",
      "--------Epoch 195--------\n",
      "Val loss = -0.8424\n",
      "\n",
      "--------Epoch 196--------\n",
      "Val loss = -0.8423\n",
      "\n",
      "--------Epoch 197--------\n",
      "Val loss = -0.8424\n",
      "\n",
      "--------Epoch 198--------\n",
      "Val loss = -0.8453\n",
      "\n",
      "--------Epoch 199--------\n",
      "Val loss = -0.8424\n",
      "\n",
      "--------Epoch 200--------\n",
      "Val loss = -0.8458\n",
      "\n",
      "--------Epoch 201--------\n",
      "Val loss = -0.8490\n",
      "\n",
      "--------Epoch 202--------\n",
      "Val loss = -0.8486\n",
      "\n",
      "--------Epoch 203--------\n",
      "Val loss = -0.8451\n",
      "\n",
      "--------Epoch 204--------\n",
      "Val loss = -0.8471\n",
      "\n",
      "--------Epoch 205--------\n",
      "Val loss = -0.8485\n",
      "\n",
      "--------Epoch 206--------\n",
      "Val loss = -0.8463\n",
      "\n",
      "--------Epoch 207--------\n",
      "Val loss = -0.8523\n",
      "\n",
      "--------Epoch 208--------\n",
      "Val loss = -0.8461\n",
      "\n",
      "--------Epoch 209--------\n",
      "Val loss = -0.8550\n",
      "\n",
      "--------Epoch 210--------\n",
      "Val loss = -0.8526\n",
      "\n",
      "--------Epoch 211--------\n",
      "Val loss = -0.8479\n",
      "\n",
      "--------Epoch 212--------\n",
      "Val loss = -0.8497\n",
      "\n",
      "--------Epoch 213--------\n",
      "Val loss = -0.8463\n",
      "\n",
      "--------Epoch 214--------\n",
      "Val loss = -0.8384\n",
      "\n",
      "--------Epoch 215--------\n",
      "Val loss = -0.8467\n",
      "\n",
      "--------Epoch 216--------\n",
      "Val loss = -0.8470\n",
      "\n",
      "--------Epoch 217--------\n",
      "Val loss = -0.8448\n",
      "\n",
      "--------Epoch 218--------\n",
      "Val loss = -0.8352\n",
      "\n",
      "--------Epoch 219--------\n",
      "Val loss = -0.8374\n",
      "\n",
      "--------Epoch 220--------\n",
      "Val loss = -0.8417\n",
      "\n",
      "--------Epoch 221--------\n",
      "Val loss = -0.8454\n",
      "\n",
      "--------Epoch 222--------\n",
      "Val loss = -0.8433\n",
      "\n",
      "--------Epoch 223--------\n",
      "Val loss = -0.8373\n",
      "\n",
      "--------Epoch 224--------\n",
      "Val loss = -0.8420\n",
      "\n",
      "--------Epoch 225--------\n",
      "Val loss = -0.8504\n",
      "\n",
      "--------Epoch 226--------\n",
      "Val loss = -0.8429\n",
      "\n",
      "--------Epoch 227--------\n",
      "Val loss = -0.8501\n",
      "\n",
      "--------Epoch 228--------\n",
      "Val loss = -0.8525\n",
      "\n",
      "--------Epoch 229--------\n",
      "Val loss = -0.8510\n",
      "\n",
      "--------Epoch 230--------\n",
      "Val loss = -0.8486\n",
      "\n",
      "--------Epoch 231--------\n",
      "Val loss = -0.8463\n",
      "\n",
      "--------Epoch 232--------\n",
      "Val loss = -0.8511\n",
      "\n",
      "--------Epoch 233--------\n",
      "Val loss = -0.8378\n",
      "\n",
      "--------Epoch 234--------\n",
      "Val loss = -0.8408\n",
      "\n",
      "--------Epoch 235--------\n",
      "Val loss = -0.8406\n",
      "\n",
      "--------Epoch 236--------\n",
      "Val loss = -0.8368\n",
      "\n",
      "--------Epoch 237--------\n",
      "Val loss = -0.8418\n",
      "\n",
      "--------Epoch 238--------\n",
      "Val loss = -0.8330\n",
      "\n",
      "--------Epoch 239--------\n",
      "Val loss = -0.8422\n",
      "\n",
      "--------Epoch 240--------\n",
      "Val loss = -0.8412\n",
      "\n",
      "--------Epoch 241--------\n",
      "Val loss = -0.8416\n",
      "\n",
      "--------Epoch 242--------\n",
      "Val loss = -0.8409\n",
      "\n",
      "--------Epoch 243--------\n",
      "Val loss = -0.8422\n",
      "\n",
      "--------Epoch 244--------\n",
      "Val loss = -0.8420\n",
      "\n",
      "--------Epoch 245--------\n",
      "Val loss = -0.8419\n",
      "\n",
      "--------Epoch 246--------\n",
      "Val loss = -0.8312\n",
      "\n",
      "--------Epoch 247--------\n",
      "Val loss = -0.8407\n",
      "\n",
      "--------Epoch 248--------\n",
      "Val loss = -0.8324\n",
      "\n",
      "--------Epoch 249--------\n",
      "Val loss = -0.8377\n",
      "\n",
      "--------Epoch 250--------\n",
      "Val loss = -0.8438\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Set up selector\n",
    "selector = nn.Sequential(\n",
    "    nn.Linear(d_in * 2, 512),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(512, 512),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(512, d_in))\n",
    "selector_layer = ConcreteSelector()\n",
    "gafs = GreedyAdaptiveFS(selector, deepcopy(model), mask_layer, selector_layer).to(device)\n",
    "\n",
    "# Train\n",
    "gafs.fit(train_dataset,\n",
    "         val_dataset,\n",
    "         mbsize=128,\n",
    "         lr=2e-4,\n",
    "         nepochs=250,\n",
    "         max_features=max_features,\n",
    "         loss_fn=nn.CrossEntropyLoss(),\n",
    "         val_loss_fn=NegAccuracy(),\n",
    "         argmax=False,\n",
    "         no_repeats=False,\n",
    "         start_temp=10.0,\n",
    "         end_temp=0.01)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "accepted-cholesterol",
   "metadata": {},
   "outputs": [],
   "source": [
    "self = gafs\n",
    "def evaluate(dataset, max_features, loss_fn, batch_size, argmax=False,\n",
    "             no_repeats=False):\n",
    "    '''\n",
    "    Evaluate mean performance across a dataset.\n",
    "    '''\n",
    "    # Setup.\n",
    "    device = next(self.model.parameters()).device\n",
    "    loader = DataLoader(\n",
    "        dataset, batch_size=batch_size, shuffle=False, pin_memory=True,\n",
    "        drop_last=False, num_workers=4)\n",
    "\n",
    "    # For calculating mean loss.\n",
    "    mean_loss = 0\n",
    "    n = 0\n",
    "\n",
    "    with torch.no_grad():\n",
    "        for x, y in loader:\n",
    "            # Move to GPU.\n",
    "            x = x.to(device)\n",
    "            y = y.to(device)\n",
    "\n",
    "            # Calculate loss.\n",
    "            pred = self.forward(x, max_features, argmax, no_repeats)\n",
    "            loss = loss_fn(pred, y).item()\n",
    "\n",
    "            # Update average.\n",
    "            mean_loss = (mean_loss * n + loss * len(x)) / (n + len(x))\n",
    "            n += len(x)\n",
    "\n",
    "    return mean_loss\n",
    "gafs.evaluate = evaluate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "electoral-rhythm",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sampling acc = 85.53 +/- 0.33\n",
      "Argmax acc = 85.67\n"
     ]
    }
   ],
   "source": [
    "# Accuracy distribution\n",
    "acc_list = []\n",
    "for _ in range(50):\n",
    "    test_acc = gafs.evaluate(test_dataset, max_features, Accuracy(), 1024)\n",
    "    acc_list.append(test_acc)\n",
    "    \n",
    "# Argmax accuracy\n",
    "test_acc = gafs.evaluate(test_dataset, max_features, Accuracy(), 1024, argmax=True)\n",
    "\n",
    "print(f'Sampling acc = {100*np.mean(acc_list):.2f} +/- {100*1.96*np.std(acc_list):.2f}')\n",
    "print(f'Argmax acc = {100*test_acc:.2f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "anonymous-animation",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sampling acc = 85.76 +/- 0.38\n",
      "Argmax acc = 86.01\n"
     ]
    }
   ],
   "source": [
    "# Accuracy with no_repeats\n",
    "acc_list = []\n",
    "for _ in range(50):\n",
    "    test_acc = gafs.evaluate(test_dataset, max_features, Accuracy(), 1024, no_repeats=True)\n",
    "    acc_list.append(test_acc)\n",
    "    \n",
    "# Argmax accuracy\n",
    "test_acc = gafs.evaluate(test_dataset, max_features, Accuracy(), 1024, no_repeats=True, argmax=True)\n",
    "\n",
    "print(f'Sampling acc = {100*np.mean(acc_list):.2f} +/- {100*1.96*np.std(acc_list):.2f}')\n",
    "print(f'Argmax acc = {100*test_acc:.2f}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eligible-witness",
   "metadata": {},
   "source": [
    "# Training with `no_repeats`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "anonymous-truck",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--------Epoch 1--------\n",
      "Val loss = -0.6955\n",
      "\n",
      "--------Epoch 2--------\n",
      "Val loss = -0.7152\n",
      "\n",
      "--------Epoch 3--------\n",
      "Val loss = -0.7254\n",
      "\n",
      "--------Epoch 4--------\n",
      "Val loss = -0.7293\n",
      "\n",
      "--------Epoch 5--------\n",
      "Val loss = -0.7528\n",
      "\n",
      "--------Epoch 6--------\n",
      "Val loss = -0.7731\n",
      "\n",
      "--------Epoch 7--------\n",
      "Val loss = -0.7607\n",
      "\n",
      "--------Epoch 8--------\n",
      "Val loss = -0.7447\n",
      "\n",
      "--------Epoch 9--------\n",
      "Val loss = -0.7557\n",
      "\n",
      "--------Epoch 10--------\n",
      "Val loss = -0.7461\n",
      "\n",
      "--------Epoch 11--------\n",
      "Val loss = -0.7265\n",
      "\n",
      "--------Epoch 12--------\n",
      "Val loss = -0.7287\n",
      "\n",
      "--------Epoch 13--------\n",
      "Val loss = -0.7081\n",
      "\n",
      "--------Epoch 14--------\n",
      "Val loss = -0.7132\n",
      "\n",
      "--------Epoch 15--------\n",
      "Val loss = -0.6935\n",
      "\n",
      "--------Epoch 16--------\n",
      "Val loss = -0.7049\n",
      "\n",
      "--------Epoch 17--------\n",
      "Val loss = -0.7074\n",
      "\n",
      "--------Epoch 18--------\n",
      "Val loss = -0.6823\n",
      "\n",
      "--------Epoch 19--------\n",
      "Val loss = -0.6834\n",
      "\n",
      "--------Epoch 20--------\n",
      "Val loss = -0.6707\n",
      "\n",
      "--------Epoch 21--------\n",
      "Val loss = -0.6834\n",
      "\n",
      "--------Epoch 22--------\n",
      "Val loss = -0.6656\n",
      "\n",
      "--------Epoch 23--------\n",
      "Val loss = -0.6675\n",
      "\n",
      "--------Epoch 24--------\n",
      "Val loss = -0.6755\n",
      "\n",
      "--------Epoch 25--------\n",
      "Val loss = -0.6661\n",
      "\n",
      "--------Epoch 26--------\n",
      "Val loss = -0.6717\n",
      "\n",
      "--------Epoch 27--------\n",
      "Val loss = -0.6474\n",
      "\n",
      "--------Epoch 28--------\n",
      "Val loss = -0.6638\n",
      "\n",
      "--------Epoch 29--------\n",
      "Val loss = -0.6528\n",
      "\n",
      "--------Epoch 30--------\n",
      "Val loss = -0.6514\n",
      "\n",
      "--------Epoch 31--------\n",
      "Val loss = -0.6519\n",
      "\n",
      "--------Epoch 32--------\n",
      "Val loss = -0.6564\n",
      "\n",
      "--------Epoch 33--------\n",
      "Val loss = -0.6682\n",
      "\n",
      "--------Epoch 34--------\n",
      "Val loss = -0.6604\n",
      "\n",
      "--------Epoch 35--------\n",
      "Val loss = -0.6517\n",
      "\n",
      "--------Epoch 36--------\n",
      "Val loss = -0.6497\n",
      "\n",
      "--------Epoch 37--------\n",
      "Val loss = -0.6490\n",
      "\n",
      "--------Epoch 38--------\n",
      "Val loss = -0.6348\n",
      "\n",
      "--------Epoch 39--------\n",
      "Val loss = -0.6516\n",
      "\n",
      "--------Epoch 40--------\n",
      "Val loss = -0.6596\n",
      "\n",
      "--------Epoch 41--------\n",
      "Val loss = -0.6373\n",
      "\n",
      "--------Epoch 42--------\n",
      "Val loss = -0.6418\n",
      "\n",
      "--------Epoch 43--------\n",
      "Val loss = -0.6450\n",
      "\n",
      "--------Epoch 44--------\n",
      "Val loss = -0.6520\n",
      "\n",
      "--------Epoch 45--------\n",
      "Val loss = -0.6535\n",
      "\n",
      "--------Epoch 46--------\n",
      "Val loss = -0.6556\n",
      "\n",
      "--------Epoch 47--------\n",
      "Val loss = -0.6417\n",
      "\n",
      "--------Epoch 48--------\n",
      "Val loss = -0.6415\n",
      "\n",
      "--------Epoch 49--------\n",
      "Val loss = -0.6520\n",
      "\n",
      "--------Epoch 50--------\n",
      "Val loss = -0.6511\n",
      "\n",
      "--------Epoch 51--------\n",
      "Val loss = -0.6614\n",
      "\n",
      "--------Epoch 52--------\n",
      "Val loss = -0.6642\n",
      "\n",
      "--------Epoch 53--------\n",
      "Val loss = -0.6586\n",
      "\n",
      "--------Epoch 54--------\n",
      "Val loss = -0.6673\n",
      "\n",
      "--------Epoch 55--------\n",
      "Val loss = -0.6423\n",
      "\n",
      "--------Epoch 56--------\n",
      "Val loss = -0.6684\n",
      "\n",
      "--------Epoch 57--------\n",
      "Val loss = -0.6609\n",
      "\n",
      "--------Epoch 58--------\n",
      "Val loss = -0.6702\n",
      "\n",
      "--------Epoch 59--------\n",
      "Val loss = -0.6636\n",
      "\n",
      "--------Epoch 60--------\n",
      "Val loss = -0.6724\n",
      "\n",
      "--------Epoch 61--------\n",
      "Val loss = -0.6714\n",
      "\n",
      "--------Epoch 62--------\n",
      "Val loss = -0.6689\n",
      "\n",
      "--------Epoch 63--------\n",
      "Val loss = -0.6744\n",
      "\n",
      "--------Epoch 64--------\n",
      "Val loss = -0.6781\n",
      "\n",
      "--------Epoch 65--------\n",
      "Val loss = -0.6739\n",
      "\n",
      "--------Epoch 66--------\n",
      "Val loss = -0.6740\n",
      "\n",
      "--------Epoch 67--------\n",
      "Val loss = -0.6692\n",
      "\n",
      "--------Epoch 68--------\n",
      "Val loss = -0.6880\n",
      "\n",
      "--------Epoch 69--------\n",
      "Val loss = -0.6754\n",
      "\n",
      "--------Epoch 70--------\n",
      "Val loss = -0.6730\n",
      "\n",
      "--------Epoch 71--------\n",
      "Val loss = -0.6806\n",
      "\n",
      "--------Epoch 72--------\n",
      "Val loss = -0.6901\n",
      "\n",
      "--------Epoch 73--------\n",
      "Val loss = -0.6925\n",
      "\n",
      "--------Epoch 74--------\n",
      "Val loss = -0.6899\n",
      "\n",
      "--------Epoch 75--------\n",
      "Val loss = -0.6978\n",
      "\n",
      "--------Epoch 76--------\n",
      "Val loss = -0.6870\n",
      "\n",
      "--------Epoch 77--------\n",
      "Val loss = -0.6960\n",
      "\n",
      "--------Epoch 78--------\n",
      "Val loss = -0.7016\n",
      "\n",
      "--------Epoch 79--------\n",
      "Val loss = -0.7045\n",
      "\n",
      "--------Epoch 80--------\n",
      "Val loss = -0.7089\n",
      "\n",
      "--------Epoch 81--------\n",
      "Val loss = -0.7220\n",
      "\n",
      "--------Epoch 82--------\n",
      "Val loss = -0.7030\n",
      "\n",
      "--------Epoch 83--------\n",
      "Val loss = -0.7248\n",
      "\n",
      "--------Epoch 84--------\n",
      "Val loss = -0.7269\n",
      "\n",
      "--------Epoch 85--------\n",
      "Val loss = -0.7272\n",
      "\n",
      "--------Epoch 86--------\n",
      "Val loss = -0.7244\n",
      "\n",
      "--------Epoch 87--------\n",
      "Val loss = -0.7272\n",
      "\n",
      "--------Epoch 88--------\n",
      "Val loss = -0.7351\n",
      "\n",
      "--------Epoch 89--------\n",
      "Val loss = -0.7444\n",
      "\n",
      "--------Epoch 90--------\n",
      "Val loss = -0.7385\n",
      "\n",
      "--------Epoch 91--------\n",
      "Val loss = -0.7487\n",
      "\n",
      "--------Epoch 92--------\n",
      "Val loss = -0.7520\n",
      "\n",
      "--------Epoch 93--------\n",
      "Val loss = -0.7540\n",
      "\n",
      "--------Epoch 94--------\n",
      "Val loss = -0.7561\n",
      "\n",
      "--------Epoch 95--------\n",
      "Val loss = -0.7593\n",
      "\n",
      "--------Epoch 96--------\n",
      "Val loss = -0.7693\n",
      "\n",
      "--------Epoch 97--------\n",
      "Val loss = -0.7614\n",
      "\n",
      "--------Epoch 98--------\n",
      "Val loss = -0.7687\n",
      "\n",
      "--------Epoch 99--------\n",
      "Val loss = -0.7690\n",
      "\n",
      "--------Epoch 100--------\n",
      "Val loss = -0.7796\n",
      "\n",
      "--------Epoch 101--------\n",
      "Val loss = -0.7787\n",
      "\n",
      "--------Epoch 102--------\n",
      "Val loss = -0.7890\n",
      "\n",
      "--------Epoch 103--------\n",
      "Val loss = -0.7886\n",
      "\n",
      "--------Epoch 104--------\n",
      "Val loss = -0.7932\n",
      "\n",
      "--------Epoch 105--------\n",
      "Val loss = -0.7948\n",
      "\n",
      "--------Epoch 106--------\n",
      "Val loss = -0.7953\n",
      "\n",
      "--------Epoch 107--------\n",
      "Val loss = -0.7946\n",
      "\n",
      "--------Epoch 108--------\n",
      "Val loss = -0.7928\n",
      "\n",
      "--------Epoch 109--------\n",
      "Val loss = -0.8056\n",
      "\n",
      "--------Epoch 110--------\n",
      "Val loss = -0.8012\n",
      "\n",
      "--------Epoch 111--------\n",
      "Val loss = -0.8079\n",
      "\n",
      "--------Epoch 112--------\n",
      "Val loss = -0.8059\n",
      "\n",
      "--------Epoch 113--------\n",
      "Val loss = -0.8086\n",
      "\n",
      "--------Epoch 114--------\n",
      "Val loss = -0.8104\n",
      "\n",
      "--------Epoch 115--------\n",
      "Val loss = -0.8144\n",
      "\n",
      "--------Epoch 116--------\n",
      "Val loss = -0.8155\n",
      "\n",
      "--------Epoch 117--------\n",
      "Val loss = -0.8157\n",
      "\n",
      "--------Epoch 118--------\n",
      "Val loss = -0.8124\n",
      "\n",
      "--------Epoch 119--------\n",
      "Val loss = -0.8148\n",
      "\n",
      "--------Epoch 120--------\n",
      "Val loss = -0.8188\n",
      "\n",
      "--------Epoch 121--------\n",
      "Val loss = -0.8179\n",
      "\n",
      "--------Epoch 122--------\n",
      "Val loss = -0.8279\n",
      "\n",
      "--------Epoch 123--------\n",
      "Val loss = -0.8213\n",
      "\n",
      "--------Epoch 124--------\n",
      "Val loss = -0.8209\n",
      "\n",
      "--------Epoch 125--------\n",
      "Val loss = -0.8209\n",
      "\n",
      "--------Epoch 126--------\n",
      "Val loss = -0.8299\n",
      "\n",
      "--------Epoch 127--------\n",
      "Val loss = -0.8258\n",
      "\n",
      "--------Epoch 128--------\n",
      "Val loss = -0.8326\n",
      "\n",
      "--------Epoch 129--------\n",
      "Val loss = -0.8293\n",
      "\n",
      "--------Epoch 130--------\n",
      "Val loss = -0.8310\n",
      "\n",
      "--------Epoch 131--------\n",
      "Val loss = -0.8303\n",
      "\n",
      "--------Epoch 132--------\n",
      "Val loss = -0.8356\n",
      "\n",
      "--------Epoch 133--------\n",
      "Val loss = -0.8360\n",
      "\n",
      "--------Epoch 134--------\n",
      "Val loss = -0.8331\n",
      "\n",
      "--------Epoch 135--------\n",
      "Val loss = -0.8361\n",
      "\n",
      "--------Epoch 136--------\n",
      "Val loss = -0.8354\n",
      "\n",
      "--------Epoch 137--------\n",
      "Val loss = -0.8326\n",
      "\n",
      "--------Epoch 138--------\n",
      "Val loss = -0.8358\n",
      "\n",
      "--------Epoch 139--------\n",
      "Val loss = -0.8364\n",
      "\n",
      "--------Epoch 140--------\n",
      "Val loss = -0.8369\n",
      "\n",
      "--------Epoch 141--------\n",
      "Val loss = -0.8435\n",
      "\n",
      "--------Epoch 142--------\n",
      "Val loss = -0.8400\n",
      "\n",
      "--------Epoch 143--------\n",
      "Val loss = -0.8389\n",
      "\n",
      "--------Epoch 144--------\n",
      "Val loss = -0.8449\n",
      "\n",
      "--------Epoch 145--------\n",
      "Val loss = -0.8456\n",
      "\n",
      "--------Epoch 146--------\n",
      "Val loss = -0.8424\n",
      "\n",
      "--------Epoch 147--------\n",
      "Val loss = -0.8389\n",
      "\n",
      "--------Epoch 148--------\n",
      "Val loss = -0.8409\n",
      "\n",
      "--------Epoch 149--------\n",
      "Val loss = -0.8451\n",
      "\n",
      "--------Epoch 150--------\n",
      "Val loss = -0.8410\n",
      "\n",
      "--------Epoch 151--------\n",
      "Val loss = -0.8500\n",
      "\n",
      "--------Epoch 152--------\n",
      "Val loss = -0.8397\n",
      "\n",
      "--------Epoch 153--------\n",
      "Val loss = -0.8413\n",
      "\n",
      "--------Epoch 154--------\n",
      "Val loss = -0.8409\n",
      "\n",
      "--------Epoch 155--------\n",
      "Val loss = -0.8448\n",
      "\n",
      "--------Epoch 156--------\n",
      "Val loss = -0.8437\n",
      "\n",
      "--------Epoch 157--------\n",
      "Val loss = -0.8444\n",
      "\n",
      "--------Epoch 158--------\n",
      "Val loss = -0.8444\n",
      "\n",
      "--------Epoch 159--------\n",
      "Val loss = -0.8467\n",
      "\n",
      "--------Epoch 160--------\n",
      "Val loss = -0.8422\n",
      "\n",
      "--------Epoch 161--------\n",
      "Val loss = -0.8431\n",
      "\n",
      "--------Epoch 162--------\n",
      "Val loss = -0.8441\n",
      "\n",
      "--------Epoch 163--------\n",
      "Val loss = -0.8488\n",
      "\n",
      "--------Epoch 164--------\n",
      "Val loss = -0.8475\n",
      "\n",
      "--------Epoch 165--------\n",
      "Val loss = -0.8490\n",
      "\n",
      "--------Epoch 166--------\n",
      "Val loss = -0.8552\n",
      "\n",
      "--------Epoch 167--------\n",
      "Val loss = -0.8390\n",
      "\n",
      "--------Epoch 168--------\n",
      "Val loss = -0.8438\n",
      "\n",
      "--------Epoch 169--------\n",
      "Val loss = -0.8431\n",
      "\n",
      "--------Epoch 170--------\n",
      "Val loss = -0.8479\n",
      "\n",
      "--------Epoch 171--------\n",
      "Val loss = -0.8490\n",
      "\n",
      "--------Epoch 172--------\n",
      "Val loss = -0.8518\n",
      "\n",
      "--------Epoch 173--------\n",
      "Val loss = -0.8521\n",
      "\n",
      "--------Epoch 174--------\n",
      "Val loss = -0.8465\n",
      "\n",
      "--------Epoch 175--------\n",
      "Val loss = -0.8518\n",
      "\n",
      "--------Epoch 176--------\n",
      "Val loss = -0.8535\n",
      "\n",
      "--------Epoch 177--------\n",
      "Val loss = -0.8582\n",
      "\n",
      "--------Epoch 178--------\n",
      "Val loss = -0.8535\n",
      "\n",
      "--------Epoch 179--------\n",
      "Val loss = -0.8426\n",
      "\n",
      "--------Epoch 180--------\n",
      "Val loss = -0.8460\n",
      "\n",
      "--------Epoch 181--------\n",
      "Val loss = -0.8473\n",
      "\n",
      "--------Epoch 182--------\n",
      "Val loss = -0.8481\n",
      "\n",
      "--------Epoch 183--------\n",
      "Val loss = -0.8517\n",
      "\n",
      "--------Epoch 184--------\n",
      "Val loss = -0.8513\n",
      "\n",
      "--------Epoch 185--------\n",
      "Val loss = -0.8487\n",
      "\n",
      "--------Epoch 186--------\n",
      "Val loss = -0.8511\n",
      "\n",
      "--------Epoch 187--------\n",
      "Val loss = -0.8517\n",
      "\n",
      "--------Epoch 188--------\n",
      "Val loss = -0.8509\n",
      "\n",
      "--------Epoch 189--------\n",
      "Val loss = -0.8467\n",
      "\n",
      "--------Epoch 190--------\n",
      "Val loss = -0.8484\n",
      "\n",
      "--------Epoch 191--------\n",
      "Val loss = -0.8548\n",
      "\n",
      "--------Epoch 192--------\n",
      "Val loss = -0.8529\n",
      "\n",
      "--------Epoch 193--------\n",
      "Val loss = -0.8500\n",
      "\n",
      "--------Epoch 194--------\n",
      "Val loss = -0.8548\n",
      "\n",
      "--------Epoch 195--------\n",
      "Val loss = -0.8541\n",
      "\n",
      "--------Epoch 196--------\n",
      "Val loss = -0.8491\n",
      "\n",
      "--------Epoch 197--------\n",
      "Val loss = -0.8531\n",
      "\n",
      "--------Epoch 198--------\n",
      "Val loss = -0.8585\n",
      "\n",
      "--------Epoch 199--------\n",
      "Val loss = -0.8537\n",
      "\n",
      "--------Epoch 200--------\n",
      "Val loss = -0.8482\n",
      "\n",
      "--------Epoch 201--------\n",
      "Val loss = -0.8485\n",
      "\n",
      "--------Epoch 202--------\n",
      "Val loss = -0.8537\n",
      "\n",
      "--------Epoch 203--------\n",
      "Val loss = -0.8541\n",
      "\n",
      "--------Epoch 204--------\n",
      "Val loss = -0.8517\n",
      "\n",
      "--------Epoch 205--------\n",
      "Val loss = -0.8540\n",
      "\n",
      "--------Epoch 206--------\n",
      "Val loss = -0.8544\n",
      "\n",
      "--------Epoch 207--------\n",
      "Val loss = -0.8594\n",
      "\n",
      "--------Epoch 208--------\n",
      "Val loss = -0.8581\n",
      "\n",
      "--------Epoch 209--------\n",
      "Val loss = -0.8560\n",
      "\n",
      "--------Epoch 210--------\n",
      "Val loss = -0.8533\n",
      "\n",
      "--------Epoch 211--------\n",
      "Val loss = -0.8546\n",
      "\n",
      "--------Epoch 212--------\n",
      "Val loss = -0.8552\n",
      "\n",
      "--------Epoch 213--------\n",
      "Val loss = -0.8536\n",
      "\n",
      "--------Epoch 214--------\n",
      "Val loss = -0.8558\n",
      "\n",
      "--------Epoch 215--------\n",
      "Val loss = -0.8522\n",
      "\n",
      "--------Epoch 216--------\n",
      "Val loss = -0.8564\n",
      "\n",
      "--------Epoch 217--------\n",
      "Val loss = -0.8541\n",
      "\n",
      "--------Epoch 218--------\n",
      "Val loss = -0.8545\n",
      "\n",
      "--------Epoch 219--------\n",
      "Val loss = -0.8556\n",
      "\n",
      "--------Epoch 220--------\n",
      "Val loss = -0.8542\n",
      "\n",
      "--------Epoch 221--------\n",
      "Val loss = -0.8483\n",
      "\n",
      "--------Epoch 222--------\n",
      "Val loss = -0.8436\n",
      "\n",
      "--------Epoch 223--------\n",
      "Val loss = -0.8446\n",
      "\n",
      "--------Epoch 224--------\n",
      "Val loss = -0.8475\n",
      "\n",
      "--------Epoch 225--------\n",
      "Val loss = -0.8471\n",
      "\n",
      "--------Epoch 226--------\n",
      "Val loss = -0.8472\n",
      "\n",
      "--------Epoch 227--------\n",
      "Val loss = -0.8505\n",
      "\n",
      "--------Epoch 228--------\n",
      "Val loss = -0.8481\n",
      "\n",
      "--------Epoch 229--------\n",
      "Val loss = -0.8468\n",
      "\n",
      "--------Epoch 230--------\n",
      "Val loss = -0.8488\n",
      "\n",
      "--------Epoch 231--------\n",
      "Val loss = -0.8507\n",
      "\n",
      "--------Epoch 232--------\n",
      "Val loss = -0.8515\n",
      "\n",
      "--------Epoch 233--------\n",
      "Val loss = -0.8472\n",
      "\n",
      "--------Epoch 234--------\n",
      "Val loss = -0.8466\n",
      "\n",
      "--------Epoch 235--------\n",
      "Val loss = -0.8479\n",
      "\n",
      "--------Epoch 236--------\n",
      "Val loss = -0.8463\n",
      "\n",
      "--------Epoch 237--------\n",
      "Val loss = -0.8407\n",
      "\n",
      "--------Epoch 238--------\n",
      "Val loss = -0.8442\n",
      "\n",
      "--------Epoch 239--------\n",
      "Val loss = -0.8443\n",
      "\n",
      "--------Epoch 240--------\n",
      "Val loss = -0.8381\n",
      "\n",
      "--------Epoch 241--------\n",
      "Val loss = -0.8357\n",
      "\n",
      "--------Epoch 242--------\n",
      "Val loss = -0.8474\n",
      "\n",
      "--------Epoch 243--------\n",
      "Val loss = -0.8560\n",
      "\n",
      "--------Epoch 244--------\n",
      "Val loss = -0.8562\n",
      "\n",
      "--------Epoch 245--------\n",
      "Val loss = -0.8491\n",
      "\n",
      "--------Epoch 246--------\n",
      "Val loss = -0.8546\n",
      "\n",
      "--------Epoch 247--------\n",
      "Val loss = -0.8468\n",
      "\n",
      "--------Epoch 248--------\n",
      "Val loss = -0.8476\n",
      "\n",
      "--------Epoch 249--------\n",
      "Val loss = -0.8506\n",
      "\n",
      "--------Epoch 250--------\n",
      "Val loss = -0.8509\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Selector network: an MLP mapping a 2*d_in input to d_in outputs\n",
    "# (presumably the features concatenated with the mask - TODO confirm).\n",
    "selector = nn.Sequential(\n",
    "    nn.Linear(2 * d_in, 512), nn.ReLU(),\n",
    "    nn.Linear(512, 512), nn.ReLU(),\n",
    "    nn.Linear(512, d_in),\n",
    ")\n",
    "selector_layer = ConcreteSelector()\n",
    "\n",
    "# The predictor is a deep copy of `model`, so fitting here does not modify `model` itself.\n",
    "gafs = GreedyAdaptiveFS(\n",
    "    selector, deepcopy(model), mask_layer, selector_layer\n",
    ").to(device)\n",
    "\n",
    "# Train with no_repeats=True and argmax=False; NegAccuracy is passed as the validation loss.\n",
    "gafs.fit(\n",
    "    train_dataset, val_dataset,\n",
    "    mbsize=128, lr=2e-4, nepochs=250,\n",
    "    max_features=max_features,\n",
    "    loss_fn=nn.CrossEntropyLoss(),\n",
    "    val_loss_fn=NegAccuracy(),\n",
    "    argmax=False, no_repeats=True,\n",
    "    start_temp=10.0, end_temp=0.01,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "appropriate-finland",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sampling acc = 77.38 +/- 0.41\n",
      "Argmax acc = 73.22\n"
     ]
    }
   ],
   "source": [
    "# Accuracy distribution: run the default (non-argmax) evaluation 50 times\n",
    "# and keep every result so the spread can be reported below.\n",
    "# NOTE(review): 1024 is passed positionally - presumably the evaluation batch size; confirm.\n",
    "acc_list = [\n",
    "    gafs.evaluate(test_dataset, max_features, Accuracy(), 1024)\n",
    "    for _ in range(50)\n",
    "]\n",
    "\n",
    "# Single evaluation with argmax=True for comparison\n",
    "test_acc = gafs.evaluate(\n",
    "    test_dataset, max_features, Accuracy(), 1024, argmax=True\n",
    ")\n",
    "\n",
    "# Report mean +/- 1.96 standard deviations of the repeated runs, in percent\n",
    "print(f'Sampling acc = {100*np.mean(acc_list):.2f} +/- {100*1.96*np.std(acc_list):.2f}')\n",
    "print(f'Argmax acc = {100*test_acc:.2f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "brown-horror",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sampling acc = 85.73 +/- 0.41\n",
      "Argmax acc = 85.82\n"
     ]
    }
   ],
   "source": [
    "# Accuracy with no_repeats: same protocol as the plain evaluation, but every\n",
    "# call passes no_repeats=True. The non-argmax evaluation is repeated 50 times.\n",
    "# NOTE(review): 1024 is passed positionally - presumably the evaluation batch size; confirm.\n",
    "acc_list = [\n",
    "    gafs.evaluate(test_dataset, max_features, Accuracy(), 1024, no_repeats=True)\n",
    "    for _ in range(50)\n",
    "]\n",
    "\n",
    "# Single evaluation with argmax=True (still with no_repeats=True)\n",
    "test_acc = gafs.evaluate(\n",
    "    test_dataset, max_features, Accuracy(), 1024, no_repeats=True, argmax=True\n",
    ")\n",
    "\n",
    "# Report mean +/- 1.96 standard deviations of the repeated runs, in percent\n",
    "print(f'Sampling acc = {100*np.mean(acc_list):.2f} +/- {100*1.96*np.std(acc_list):.2f}')\n",
    "print(f'Argmax acc = {100*test_acc:.2f}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "accessible-grade",
   "metadata": {},
   "source": [
    "# Pretraining selector"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "three-crack",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fresh selector for the pretraining experiment: an MLP from 2*d_in inputs to d_in outputs\n",
    "# (presumably the features concatenated with the mask - TODO confirm).\n",
    "selector = nn.Sequential(\n",
    "    nn.Linear(2 * d_in, 512), nn.ReLU(),\n",
    "    nn.Linear(512, 512), nn.ReLU(),\n",
    "    nn.Linear(512, d_in),\n",
    ")\n",
    "selector_layer = ConcreteSelector()\n",
    "\n",
    "# The predictor is a deep copy of `model`, so later fitting does not modify `model` itself.\n",
    "gafs = GreedyAdaptiveFS(\n",
    "    selector, deepcopy(model), mask_layer, selector_layer\n",
    ").to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "confident-wyoming",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--------Epoch 1--------\n",
      "Val loss = -0.5686\n",
      "\n",
      "--------Epoch 2--------\n",
      "Val loss = -0.6928\n",
      "\n",
      "--------Epoch 3--------\n",
      "Val loss = -0.7070\n",
      "\n",
      "--------Epoch 4--------\n",
      "Val loss = -0.7098\n",
      "\n",
      "--------Epoch 5--------\n",
      "Val loss = -0.7208\n",
      "\n",
      "--------Epoch 6--------\n",
      "Val loss = -0.7226\n",
      "\n",
      "--------Epoch 7--------\n",
      "Val loss = -0.7258\n",
      "\n",
      "--------Epoch 8--------\n",
      "Val loss = -0.7266\n",
      "\n",
      "--------Epoch 9--------\n",
      "Val loss = -0.7295\n",
      "\n",
      "--------Epoch 10--------\n",
      "Val loss = -0.7265\n",
      "\n",
      "--------Epoch 11--------\n",
      "Val loss = -0.7242\n",
      "\n",
      "--------Epoch 12--------\n",
      "Val loss = -0.7290\n",
      "\n",
      "--------Epoch 13--------\n",
      "Val loss = -0.7362\n",
      "\n",
      "--------Epoch 14--------\n",
      "Val loss = -0.7282\n",
      "\n",
      "--------Epoch 15--------\n",
      "Val loss = -0.7260\n",
      "\n",
      "--------Epoch 16--------\n",
      "Val loss = -0.7357\n",
      "\n",
      "--------Epoch 17--------\n",
      "Val loss = -0.7329\n",
      "\n",
      "--------Epoch 18--------\n",
      "Val loss = -0.7301\n",
      "\n",
      "--------Epoch 19--------\n",
      "Val loss = -0.7390\n",
      "\n",
      "--------Epoch 20--------\n",
      "Val loss = -0.7310\n",
      "\n",
      "--------Epoch 21--------\n",
      "Val loss = -0.7319\n",
      "\n",
      "--------Epoch 22--------\n",
      "Val loss = -0.7355\n",
      "\n",
      "--------Epoch 23--------\n",
      "Val loss = -0.7394\n",
      "\n",
      "--------Epoch 24--------\n",
      "Val loss = -0.7287\n",
      "\n",
      "--------Epoch 25--------\n",
      "Val loss = -0.7294\n",
      "\n",
      "--------Epoch 26--------\n",
      "Val loss = -0.7289\n",
      "\n",
      "--------Epoch 27--------\n",
      "Val loss = -0.7365\n",
      "\n",
      "--------Epoch 28--------\n",
      "Val loss = -0.7382\n",
      "\n",
      "--------Epoch 29--------\n",
      "Val loss = -0.7313\n",
      "\n",
      "--------Epoch 30--------\n",
      "Val loss = -0.7414\n",
      "\n",
      "--------Epoch 31--------\n",
      "Val loss = -0.7324\n",
      "\n",
      "--------Epoch 32--------\n",
      "Val loss = -0.7316\n",
      "\n",
      "--------Epoch 33--------\n",
      "Val loss = -0.7350\n",
      "\n",
      "--------Epoch 34--------\n",
      "Val loss = -0.7292\n",
      "\n",
      "--------Epoch 35--------\n",
      "Val loss = -0.7311\n",
      "\n",
      "--------Epoch 36--------\n",
      "Val loss = -0.7350\n",
      "\n",
      "--------Epoch 37--------\n",
      "Val loss = -0.7347\n",
      "\n",
      "--------Epoch 38--------\n",
      "Val loss = -0.7403\n",
      "\n",
      "--------Epoch 39--------\n",
      "Val loss = -0.7441\n",
      "\n",
      "--------Epoch 40--------\n",
      "Val loss = -0.7358\n",
      "\n",
      "--------Epoch 41--------\n",
      "Val loss = -0.7365\n",
      "\n",
      "--------Epoch 42--------\n",
      "Val loss = -0.7382\n",
      "\n",
      "--------Epoch 43--------\n",
      "Val loss = -0.7347\n",
      "\n",
      "--------Epoch 44--------\n",
      "Val loss = -0.7380\n",
      "\n",
      "--------Epoch 45--------\n",
      "Val loss = -0.7446\n",
      "\n",
      "--------Epoch 46--------\n",
      "Val loss = -0.7435\n",
      "\n",
      "--------Epoch 47--------\n",
      "Val loss = -0.7481\n",
      "\n",
      "--------Epoch 48--------\n",
      "Val loss = -0.7361\n",
      "\n",
      "--------Epoch 49--------\n",
      "Val loss = -0.7436\n",
      "\n",
      "--------Epoch 50--------\n",
      "Val loss = -0.7436\n",
      "\n",
      "--------Epoch 51--------\n",
      "Val loss = -0.7455\n",
      "\n",
      "--------Epoch 52--------\n",
      "Val loss = -0.7394\n",
      "\n",
      "--------Epoch 53--------\n",
      "Val loss = -0.7471\n",
      "\n",
      "--------Epoch 54--------\n",
      "Val loss = -0.7504\n",
      "\n",
      "--------Epoch 55--------\n",
      "Val loss = -0.7512\n",
      "\n",
      "--------Epoch 56--------\n",
      "Val loss = -0.7487\n",
      "\n",
      "--------Epoch 57--------\n",
      "Val loss = -0.7436\n",
      "\n",
      "--------Epoch 58--------\n",
      "Val loss = -0.7476\n",
      "\n",
      "--------Epoch 59--------\n",
      "Val loss = -0.7416\n",
      "\n",
      "--------Epoch 60--------\n",
      "Val loss = -0.7438\n",
      "\n",
      "--------Epoch 61--------\n",
      "Val loss = -0.7467\n",
      "\n",
      "--------Epoch 62--------\n",
      "Val loss = -0.7392\n",
      "\n",
      "--------Epoch 63--------\n",
      "Val loss = -0.7485\n",
      "\n",
      "--------Epoch 64--------\n",
      "Val loss = -0.7529\n",
      "\n",
      "--------Epoch 65--------\n",
      "Val loss = -0.7501\n",
      "\n",
      "--------Epoch 66--------\n",
      "Val loss = -0.7403\n",
      "\n",
      "--------Epoch 67--------\n",
      "Val loss = -0.7541\n",
      "\n",
      "--------Epoch 68--------\n",
      "Val loss = -0.7493\n",
      "\n",
      "--------Epoch 69--------\n",
      "Val loss = -0.7524\n",
      "\n",
      "--------Epoch 70--------\n",
      "Val loss = -0.7444\n",
      "\n",
      "--------Epoch 71--------\n",
      "Val loss = -0.7454\n",
      "\n",
      "--------Epoch 72--------\n",
      "Val loss = -0.7432\n",
      "\n",
      "--------Epoch 73--------\n",
      "Val loss = -0.7459\n",
      "\n",
      "--------Epoch 74--------\n",
      "Val loss = -0.7367\n",
      "\n",
      "--------Epoch 75--------\n",
      "Val loss = -0.7482\n",
      "\n",
      "--------Epoch 76--------\n",
      "Val loss = -0.7465\n",
      "\n",
      "--------Epoch 77--------\n",
      "Val loss = -0.7467\n",
      "\n",
      "--------Epoch 78--------\n",
      "Val loss = -0.7503\n",
      "\n",
      "--------Epoch 79--------\n",
      "Val loss = -0.7496\n",
      "\n",
      "--------Epoch 80--------\n",
      "Val loss = -0.7397\n",
      "\n",
      "--------Epoch 81--------\n",
      "Val loss = -0.7400\n",
      "\n",
      "--------Epoch 82--------\n",
      "Val loss = -0.7414\n",
      "\n",
      "--------Epoch 83--------\n",
      "Val loss = -0.7345\n",
      "\n",
      "--------Epoch 84--------\n",
      "Val loss = -0.7379\n",
      "\n",
      "--------Epoch 85--------\n",
      "Val loss = -0.7364\n",
      "\n",
      "--------Epoch 86--------\n",
      "Val loss = -0.7414\n",
      "\n",
      "--------Epoch 87--------\n",
      "Val loss = -0.7333\n",
      "\n",
      "--------Epoch 88--------\n",
      "Val loss = -0.7387\n",
      "\n",
      "--------Epoch 89--------\n",
      "Val loss = -0.7425\n",
      "\n",
      "--------Epoch 90--------\n",
      "Val loss = -0.7374\n",
      "\n",
      "--------Epoch 91--------\n",
      "Val loss = -0.7393\n",
      "\n",
      "--------Epoch 92--------\n",
      "Val loss = -0.7334\n",
      "\n",
      "--------Epoch 93--------\n",
      "Val loss = -0.7356\n",
      "\n",
      "--------Epoch 94--------\n",
      "Val loss = -0.7356\n",
      "\n",
      "--------Epoch 95--------\n",
      "Val loss = -0.7247\n",
      "\n",
      "--------Epoch 96--------\n",
      "Val loss = -0.7358\n",
      "\n",
      "--------Epoch 97--------\n",
      "Val loss = -0.7131\n",
      "\n",
      "--------Epoch 98--------\n",
      "Val loss = -0.7239\n",
      "\n",
      "--------Epoch 99--------\n",
      "Val loss = -0.7298\n",
      "\n",
      "--------Epoch 100--------\n",
      "Val loss = -0.6946\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Pretrain: fit with train_model=False, argmax=True and no_repeats=False.\n",
    "# NOTE(review): train_model=False presumably freezes the predictor so only the selector\n",
    "# is updated - confirm against GreedyAdaptiveFS.fit.\n",
    "gafs.fit(\n",
    "    train_dataset, val_dataset,\n",
    "    mbsize=128, lr=2e-4, nepochs=100,\n",
    "    max_features=max_features,\n",
    "    loss_fn=nn.CrossEntropyLoss(),\n",
    "    val_loss_fn=NegAccuracy(),\n",
    "    train_model=False,\n",
    "    argmax=True, no_repeats=False,\n",
    "    start_temp=10.0, end_temp=0.01,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "tropical-security",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--------Epoch 1--------\n",
      "Val loss = -0.7548\n",
      "\n",
      "--------Epoch 2--------\n",
      "Val loss = -0.7460\n",
      "\n",
      "--------Epoch 3--------\n",
      "Val loss = -0.7412\n",
      "\n",
      "--------Epoch 4--------\n",
      "Val loss = -0.7371\n",
      "\n",
      "--------Epoch 5--------\n",
      "Val loss = -0.7336\n",
      "\n",
      "--------Epoch 6--------\n",
      "Val loss = -0.7155\n",
      "\n",
      "--------Epoch 7--------\n",
      "Val loss = -0.7221\n",
      "\n",
      "--------Epoch 8--------\n",
      "Val loss = -0.7114\n",
      "\n",
      "--------Epoch 9--------\n",
      "Val loss = -0.7156\n",
      "\n",
      "--------Epoch 10--------\n",
      "Val loss = -0.7216\n",
      "\n",
      "--------Epoch 11--------\n",
      "Val loss = -0.6793\n",
      "\n",
      "--------Epoch 12--------\n",
      "Val loss = -0.7317\n",
      "\n",
      "--------Epoch 13--------\n",
      "Val loss = -0.7375\n",
      "\n",
      "--------Epoch 14--------\n",
      "Val loss = -0.7447\n",
      "\n",
      "--------Epoch 15--------\n",
      "Val loss = -0.7153\n",
      "\n",
      "--------Epoch 16--------\n",
      "Val loss = -0.7103\n",
      "\n",
      "--------Epoch 17--------\n",
      "Val loss = -0.6918\n",
      "\n",
      "--------Epoch 18--------\n",
      "Val loss = -0.6816\n",
      "\n",
      "--------Epoch 19--------\n",
      "Val loss = -0.7169\n",
      "\n",
      "--------Epoch 20--------\n",
      "Val loss = -0.7002\n",
      "\n",
      "--------Epoch 21--------\n",
      "Val loss = -0.7165\n",
      "\n",
      "--------Epoch 22--------\n",
      "Val loss = -0.6977\n",
      "\n",
      "--------Epoch 23--------\n",
      "Val loss = -0.6655\n",
      "\n",
      "--------Epoch 24--------\n",
      "Val loss = -0.6795\n",
      "\n",
      "--------Epoch 25--------\n",
      "Val loss = -0.6835\n",
      "\n",
      "--------Epoch 26--------\n",
      "Val loss = -0.6709\n",
      "\n",
      "--------Epoch 27--------\n",
      "Val loss = -0.6703\n",
      "\n",
      "--------Epoch 28--------\n",
      "Val loss = -0.6638\n",
      "\n",
      "--------Epoch 29--------\n",
      "Val loss = -0.6768\n",
      "\n",
      "--------Epoch 30--------\n",
      "Val loss = -0.6690\n",
      "\n",
      "--------Epoch 31--------\n",
      "Val loss = -0.6548\n",
      "\n",
      "--------Epoch 32--------\n",
      "Val loss = -0.6992\n",
      "\n",
      "--------Epoch 33--------\n",
      "Val loss = -0.6706\n",
      "\n",
      "--------Epoch 34--------\n",
      "Val loss = -0.6968\n",
      "\n",
      "--------Epoch 35--------\n",
      "Val loss = -0.6658\n",
      "\n",
      "--------Epoch 36--------\n",
      "Val loss = -0.6944\n",
      "\n",
      "--------Epoch 37--------\n",
      "Val loss = -0.6729\n",
      "\n",
      "--------Epoch 38--------\n",
      "Val loss = -0.6911\n",
      "\n",
      "--------Epoch 39--------\n",
      "Val loss = -0.6670\n",
      "\n",
      "--------Epoch 40--------\n",
      "Val loss = -0.6970\n",
      "\n",
      "--------Epoch 41--------\n",
      "Val loss = -0.6706\n",
      "\n",
      "--------Epoch 42--------\n",
      "Val loss = -0.6811\n",
      "\n",
      "--------Epoch 43--------\n",
      "Val loss = -0.6575\n",
      "\n",
      "--------Epoch 44--------\n",
      "Val loss = -0.5639\n",
      "\n",
      "--------Epoch 45--------\n",
      "Val loss = -0.6531\n",
      "\n",
      "--------Epoch 46--------\n",
      "Val loss = -0.6404\n",
      "\n",
      "--------Epoch 47--------\n",
      "Val loss = -0.6705\n",
      "\n",
      "--------Epoch 48--------\n",
      "Val loss = -0.6845\n",
      "\n",
      "--------Epoch 49--------\n",
      "Val loss = -0.6876\n",
      "\n",
      "--------Epoch 50--------\n",
      "Val loss = -0.6986\n",
      "\n",
      "--------Epoch 51--------\n",
      "Val loss = -0.6854\n",
      "\n",
      "--------Epoch 52--------\n",
      "Val loss = -0.7070\n",
      "\n",
      "--------Epoch 53--------\n",
      "Val loss = -0.6697\n",
      "\n",
      "--------Epoch 54--------\n",
      "Val loss = -0.6611\n",
      "\n",
      "--------Epoch 55--------\n",
      "Val loss = -0.6877\n",
      "\n",
      "--------Epoch 56--------\n",
      "Val loss = -0.6682\n",
      "\n",
      "--------Epoch 57--------\n",
      "Val loss = -0.6801\n",
      "\n",
      "--------Epoch 58--------\n",
      "Val loss = -0.6478\n",
      "\n",
      "--------Epoch 59--------\n",
      "Val loss = -0.6780\n",
      "\n",
      "--------Epoch 60--------\n",
      "Val loss = -0.6734\n",
      "\n",
      "--------Epoch 61--------\n",
      "Val loss = -0.6993\n",
      "\n",
      "--------Epoch 62--------\n",
      "Val loss = -0.6851\n",
      "\n",
      "--------Epoch 63--------\n",
      "Val loss = -0.7004\n",
      "\n",
      "--------Epoch 64--------\n",
      "Val loss = -0.7037\n",
      "\n",
      "--------Epoch 65--------\n",
      "Val loss = -0.6542\n",
      "\n",
      "--------Epoch 66--------\n",
      "Val loss = -0.7089\n",
      "\n",
      "--------Epoch 67--------\n",
      "Val loss = -0.7471\n",
      "\n",
      "--------Epoch 68--------\n",
      "Val loss = -0.7486\n",
      "\n",
      "--------Epoch 69--------\n",
      "Val loss = -0.7458\n",
      "\n",
      "--------Epoch 70--------\n",
      "Val loss = -0.7647\n",
      "\n",
      "--------Epoch 71--------\n",
      "Val loss = -0.6727\n",
      "\n",
      "--------Epoch 72--------\n",
      "Val loss = -0.7672\n",
      "\n",
      "--------Epoch 73--------\n",
      "Val loss = -0.7171\n",
      "\n",
      "--------Epoch 74--------\n",
      "Val loss = -0.7558\n",
      "\n",
      "--------Epoch 75--------\n",
      "Val loss = -0.7526\n",
      "\n",
      "--------Epoch 76--------\n",
      "Val loss = -0.7670\n",
      "\n",
      "--------Epoch 77--------\n",
      "Val loss = -0.7831\n",
      "\n",
      "--------Epoch 78--------\n",
      "Val loss = -0.7738\n",
      "\n",
      "--------Epoch 79--------\n",
      "Val loss = -0.7935\n",
      "\n",
      "--------Epoch 80--------\n",
      "Val loss = -0.7612\n",
      "\n",
      "--------Epoch 81--------\n",
      "Val loss = -0.7877\n",
      "\n",
      "--------Epoch 82--------\n",
      "Val loss = -0.7945\n",
      "\n",
      "--------Epoch 83--------\n",
      "Val loss = -0.8015\n",
      "\n",
      "--------Epoch 84--------\n",
      "Val loss = -0.7830\n",
      "\n",
      "--------Epoch 85--------\n",
      "Val loss = -0.7968\n",
      "\n",
      "--------Epoch 86--------\n",
      "Val loss = -0.8116\n",
      "\n",
      "--------Epoch 87--------\n",
      "Val loss = -0.8102\n",
      "\n",
      "--------Epoch 88--------\n",
      "Val loss = -0.8150\n",
      "\n",
      "--------Epoch 89--------\n",
      "Val loss = -0.8118\n",
      "\n",
      "--------Epoch 90--------\n",
      "Val loss = -0.8158\n",
      "\n",
      "--------Epoch 91--------\n",
      "Val loss = -0.8162\n",
      "\n",
      "--------Epoch 92--------\n",
      "Val loss = -0.8173\n",
      "\n",
      "--------Epoch 93--------\n",
      "Val loss = -0.8271\n",
      "\n",
      "--------Epoch 94--------\n",
      "Val loss = -0.8246\n",
      "\n",
      "--------Epoch 95--------\n",
      "Val loss = -0.8198\n",
      "\n",
      "--------Epoch 96--------\n",
      "Val loss = -0.8301\n",
      "\n",
      "--------Epoch 97--------\n",
      "Val loss = -0.8215\n",
      "\n",
      "--------Epoch 98--------\n",
      "Val loss = -0.8273\n",
      "\n",
      "--------Epoch 99--------\n",
      "Val loss = -0.8284\n",
      "\n",
      "--------Epoch 100--------\n",
      "Val loss = -0.8378\n",
      "\n",
      "--------Epoch 101--------\n",
      "Val loss = -0.8302\n",
      "\n",
      "--------Epoch 102--------\n",
      "Val loss = -0.8413\n",
      "\n",
      "--------Epoch 103--------\n",
      "Val loss = -0.8332\n",
      "\n",
      "--------Epoch 104--------\n",
      "Val loss = -0.8355\n",
      "\n",
      "--------Epoch 105--------\n",
      "Val loss = -0.8383\n",
      "\n",
      "--------Epoch 106--------\n",
      "Val loss = -0.8427\n",
      "\n",
      "--------Epoch 107--------\n",
      "Val loss = -0.8387\n",
      "\n",
      "--------Epoch 108--------\n",
      "Val loss = -0.8365\n",
      "\n",
      "--------Epoch 109--------\n",
      "Val loss = -0.8473\n",
      "\n",
      "--------Epoch 110--------\n",
      "Val loss = -0.8408\n",
      "\n",
      "--------Epoch 111--------\n",
      "Val loss = -0.8424\n",
      "\n",
      "--------Epoch 112--------\n",
      "Val loss = -0.8433\n",
      "\n",
      "--------Epoch 113--------\n",
      "Val loss = -0.8455\n",
      "\n",
      "--------Epoch 114--------\n",
      "Val loss = -0.8433\n",
      "\n",
      "--------Epoch 115--------\n",
      "Val loss = -0.8461\n",
      "\n",
      "--------Epoch 116--------\n",
      "Val loss = -0.8464\n",
      "\n",
      "--------Epoch 117--------\n",
      "Val loss = -0.8472\n",
      "\n",
      "--------Epoch 118--------\n",
      "Val loss = -0.8437\n",
      "\n",
      "--------Epoch 119--------\n",
      "Val loss = -0.8412\n",
      "\n",
      "--------Epoch 120--------\n",
      "Val loss = -0.8445\n",
      "\n",
      "--------Epoch 121--------\n",
      "Val loss = -0.8489\n",
      "\n",
      "--------Epoch 122--------\n",
      "Val loss = -0.8446\n",
      "\n",
      "--------Epoch 123--------\n",
      "Val loss = -0.8433\n",
      "\n",
      "--------Epoch 124--------\n",
      "Val loss = -0.8473\n",
      "\n",
      "--------Epoch 125--------\n",
      "Val loss = -0.8452\n",
      "\n",
      "--------Epoch 126--------\n",
      "Val loss = -0.8488\n",
      "\n",
      "--------Epoch 127--------\n",
      "Val loss = -0.8429\n",
      "\n",
      "--------Epoch 128--------\n",
      "Val loss = -0.8445\n",
      "\n",
      "--------Epoch 129--------\n",
      "Val loss = -0.8486\n",
      "\n",
      "--------Epoch 130--------\n",
      "Val loss = -0.8392\n",
      "\n",
      "--------Epoch 131--------\n",
      "Val loss = -0.8484\n",
      "\n",
      "--------Epoch 132--------\n",
      "Val loss = -0.8461\n",
      "\n",
      "--------Epoch 133--------\n",
      "Val loss = -0.8443\n",
      "\n",
      "--------Epoch 134--------\n",
      "Val loss = -0.8407\n",
      "\n",
      "--------Epoch 135--------\n",
      "Val loss = -0.8559\n",
      "\n",
      "--------Epoch 136--------\n",
      "Val loss = -0.8447\n",
      "\n",
      "--------Epoch 137--------\n",
      "Val loss = -0.8533\n",
      "\n",
      "--------Epoch 138--------\n",
      "Val loss = -0.8502\n",
      "\n",
      "--------Epoch 139--------\n",
      "Val loss = -0.8460\n",
      "\n",
      "--------Epoch 140--------\n",
      "Val loss = -0.8502\n",
      "\n",
      "--------Epoch 141--------\n",
      "Val loss = -0.8553\n",
      "\n",
      "--------Epoch 142--------\n",
      "Val loss = -0.8506\n",
      "\n",
      "--------Epoch 143--------\n",
      "Val loss = -0.8471\n",
      "\n",
      "--------Epoch 144--------\n",
      "Val loss = -0.8561\n",
      "\n",
      "--------Epoch 145--------\n",
      "Val loss = -0.8516\n",
      "\n",
      "--------Epoch 146--------\n",
      "Val loss = -0.8491\n",
      "\n",
      "--------Epoch 147--------\n",
      "Val loss = -0.8491\n",
      "\n",
      "--------Epoch 148--------\n",
      "Val loss = -0.8488\n",
      "\n",
      "--------Epoch 149--------\n",
      "Val loss = -0.8519\n",
      "\n",
      "--------Epoch 150--------\n",
      "Val loss = -0.8545\n",
      "\n",
      "--------Epoch 151--------\n",
      "Val loss = -0.8493\n",
      "\n",
      "--------Epoch 152--------\n",
      "Val loss = -0.8513\n",
      "\n",
      "--------Epoch 153--------\n",
      "Val loss = -0.8521\n",
      "\n",
      "--------Epoch 154--------\n",
      "Val loss = -0.8533\n",
      "\n",
      "--------Epoch 155--------\n",
      "Val loss = -0.8529\n",
      "\n",
      "--------Epoch 156--------\n",
      "Val loss = -0.8504\n",
      "\n",
      "--------Epoch 157--------\n",
      "Val loss = -0.8474\n",
      "\n",
      "--------Epoch 158--------\n",
      "Val loss = -0.8539\n",
      "\n",
      "--------Epoch 159--------\n",
      "Val loss = -0.8527\n",
      "\n",
      "--------Epoch 160--------\n",
      "Val loss = -0.8497\n",
      "\n",
      "--------Epoch 161--------\n",
      "Val loss = -0.8523\n",
      "\n",
      "--------Epoch 162--------\n",
      "Val loss = -0.8477\n",
      "\n",
      "--------Epoch 163--------\n",
      "Val loss = -0.8461\n",
      "\n",
      "--------Epoch 164--------\n",
      "Val loss = -0.8554\n",
      "\n",
      "--------Epoch 165--------\n",
      "Val loss = -0.8508\n",
      "\n",
      "--------Epoch 166--------\n",
      "Val loss = -0.8523\n",
      "\n",
      "--------Epoch 167--------\n",
      "Val loss = -0.8507\n",
      "\n",
      "--------Epoch 168--------\n",
      "Val loss = -0.8493\n",
      "\n",
      "--------Epoch 169--------\n",
      "Val loss = -0.8519\n",
      "\n",
      "--------Epoch 170--------\n",
      "Val loss = -0.8476\n",
      "\n",
      "--------Epoch 171--------\n",
      "Val loss = -0.8512\n",
      "\n",
      "--------Epoch 172--------\n",
      "Val loss = -0.8537\n",
      "\n",
      "--------Epoch 173--------\n",
      "Val loss = -0.8515\n",
      "\n",
      "--------Epoch 174--------\n",
      "Val loss = -0.8536\n",
      "\n",
      "--------Epoch 175--------\n",
      "Val loss = -0.8521\n",
      "\n",
      "--------Epoch 176--------\n",
      "Val loss = -0.8579\n",
      "\n",
      "--------Epoch 177--------\n",
      "Val loss = -0.8547\n",
      "\n",
      "--------Epoch 178--------\n",
      "Val loss = -0.8526\n",
      "\n",
      "--------Epoch 179--------\n",
      "Val loss = -0.8518\n",
      "\n",
      "--------Epoch 180--------\n",
      "Val loss = -0.8615\n",
      "\n",
      "--------Epoch 181--------\n",
      "Val loss = -0.8525\n",
      "\n",
      "--------Epoch 182--------\n",
      "Val loss = -0.8635\n",
      "\n",
      "--------Epoch 183--------\n",
      "Val loss = -0.8522\n",
      "\n",
      "--------Epoch 184--------\n",
      "Val loss = -0.8545\n",
      "\n",
      "--------Epoch 185--------\n",
      "Val loss = -0.8532\n",
      "\n",
      "--------Epoch 186--------\n",
      "Val loss = -0.8475\n",
      "\n",
      "--------Epoch 187--------\n",
      "Val loss = -0.8510\n",
      "\n",
      "--------Epoch 188--------\n",
      "Val loss = -0.8477\n",
      "\n",
      "--------Epoch 189--------\n",
      "Val loss = -0.8486\n",
      "\n",
      "--------Epoch 190--------\n",
      "Val loss = -0.8548\n",
      "\n",
      "--------Epoch 191--------\n",
      "Val loss = -0.8569\n",
      "\n",
      "--------Epoch 192--------\n",
      "Val loss = -0.8558\n",
      "\n",
      "--------Epoch 193--------\n",
      "Val loss = -0.8497\n",
      "\n",
      "--------Epoch 194--------\n",
      "Val loss = -0.8500\n",
      "\n",
      "--------Epoch 195--------\n",
      "Val loss = -0.8562\n",
      "\n",
      "--------Epoch 196--------\n",
      "Val loss = -0.8487\n",
      "\n",
      "--------Epoch 197--------\n",
      "Val loss = -0.8531\n",
      "\n",
      "--------Epoch 198--------\n",
      "Val loss = -0.8509\n",
      "\n",
      "--------Epoch 199--------\n",
      "Val loss = -0.8497\n",
      "\n",
      "--------Epoch 200--------\n",
      "Val loss = -0.8525\n",
      "\n",
      "--------Epoch 201--------\n",
      "Val loss = -0.8445\n",
      "\n",
      "--------Epoch 202--------\n",
      "Val loss = -0.8376\n",
      "\n",
      "--------Epoch 203--------\n",
      "Val loss = -0.8454\n",
      "\n",
      "--------Epoch 204--------\n",
      "Val loss = -0.8480\n",
      "\n",
      "--------Epoch 205--------\n",
      "Val loss = -0.8540\n",
      "\n",
      "--------Epoch 206--------\n",
      "Val loss = -0.8379\n",
      "\n",
      "--------Epoch 207--------\n",
      "Val loss = -0.8520\n",
      "\n",
      "--------Epoch 208--------\n",
      "Val loss = -0.8411\n",
      "\n",
      "--------Epoch 209--------\n",
      "Val loss = -0.8479\n",
      "\n",
      "--------Epoch 210--------\n",
      "Val loss = -0.8424\n",
      "\n",
      "--------Epoch 211--------\n",
      "Val loss = -0.8418\n",
      "\n",
      "--------Epoch 212--------\n",
      "Val loss = -0.8465\n",
      "\n",
      "--------Epoch 213--------\n",
      "Val loss = -0.8490\n",
      "\n",
      "--------Epoch 214--------\n",
      "Val loss = -0.8483\n",
      "\n",
      "--------Epoch 215--------\n",
      "Val loss = -0.8473\n",
      "\n",
      "--------Epoch 216--------\n",
      "Val loss = -0.8510\n",
      "\n",
      "--------Epoch 217--------\n",
      "Val loss = -0.8516\n",
      "\n",
      "--------Epoch 218--------\n",
      "Val loss = -0.8484\n",
      "\n",
      "--------Epoch 219--------\n",
      "Val loss = -0.8473\n",
      "\n",
      "--------Epoch 220--------\n",
      "Val loss = -0.8467\n",
      "\n",
      "--------Epoch 221--------\n",
      "Val loss = -0.8476\n",
      "\n",
      "--------Epoch 222--------\n",
      "Val loss = -0.8464\n",
      "\n",
      "--------Epoch 223--------\n",
      "Val loss = -0.8342\n",
      "\n",
      "--------Epoch 224--------\n",
      "Val loss = -0.8487\n",
      "\n",
      "--------Epoch 225--------\n",
      "Val loss = -0.8483\n",
      "\n",
      "--------Epoch 226--------\n",
      "Val loss = -0.8448\n",
      "\n",
      "--------Epoch 227--------\n",
      "Val loss = -0.8463\n",
      "\n",
      "--------Epoch 228--------\n",
      "Val loss = -0.8456\n",
      "\n",
      "--------Epoch 229--------\n",
      "Val loss = -0.8410\n",
      "\n",
      "--------Epoch 230--------\n",
      "Val loss = -0.8390\n",
      "\n",
      "--------Epoch 231--------\n",
      "Val loss = -0.8371\n",
      "\n",
      "--------Epoch 232--------\n",
      "Val loss = -0.8308\n",
      "\n",
      "--------Epoch 233--------\n",
      "Val loss = -0.8485\n",
      "\n",
      "--------Epoch 234--------\n",
      "Val loss = -0.8464\n",
      "\n",
      "--------Epoch 235--------\n",
      "Val loss = -0.8380\n",
      "\n",
      "--------Epoch 236--------\n",
      "Val loss = -0.8385\n",
      "\n",
      "--------Epoch 237--------\n",
      "Val loss = -0.8370\n",
      "\n",
      "--------Epoch 238--------\n",
      "Val loss = -0.8417\n",
      "\n",
      "--------Epoch 239--------\n",
      "Val loss = -0.8436\n",
      "\n",
      "--------Epoch 240--------\n",
      "Val loss = -0.8447\n",
      "\n",
      "--------Epoch 241--------\n",
      "Val loss = -0.8412\n",
      "\n",
      "--------Epoch 242--------\n",
      "Val loss = -0.8410\n",
      "\n",
      "--------Epoch 243--------\n",
      "Val loss = -0.8397\n",
      "\n",
      "--------Epoch 244--------\n",
      "Val loss = -0.8416\n",
      "\n",
      "--------Epoch 245--------\n",
      "Val loss = -0.8418\n",
      "\n",
      "--------Epoch 246--------\n",
      "Val loss = -0.8373\n",
      "\n",
      "--------Epoch 247--------\n",
      "Val loss = -0.8419\n",
      "\n",
      "--------Epoch 248--------\n",
      "Val loss = -0.8444\n",
      "\n",
      "--------Epoch 249--------\n",
      "Val loss = -0.8270\n",
      "\n",
      "--------Epoch 250--------\n",
      "Val loss = -0.8268\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Train jointly\n",
    "gafs.fit(train_dataset,\n",
    "         val_dataset,\n",
    "         mbsize=128,\n",
    "         lr=2e-4,\n",
    "         nepochs=250,\n",
    "         max_features=max_features,\n",
    "         loss_fn=nn.CrossEntropyLoss(),\n",
    "         val_loss_fn=NegAccuracy(),\n",
    "         train_model=True,\n",
    "         argmax=True,\n",
    "         no_repeats=False,\n",
    "         start_temp=10.0,\n",
    "         end_temp=0.01)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "nuclear-andrews",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sampling acc = 85.91 +/- 0.43\n",
      "Argmax acc = 86.69\n"
     ]
    }
   ],
   "source": [
    "# Accuracy with no_repeats\n",
    "acc_list = []\n",
    "for _ in range(50):\n",
    "    test_acc = gafs.evaluate(test_dataset, max_features, Accuracy(), 1024, no_repeats=True)\n",
    "    acc_list.append(test_acc)\n",
    "    \n",
    "# Argmax accuracy\n",
    "test_acc = gafs.evaluate(test_dataset, max_features, Accuracy(), 1024, no_repeats=True, argmax=True)\n",
    "\n",
    "print(f'Sampling acc = {100*np.mean(acc_list):.2f} +/- {100*1.96*np.std(acc_list):.2f}')\n",
    "print(f'Argmax acc = {100*test_acc:.2f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "practical-hopkins",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
