{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "3f270402-de95-4602-b9d0-eef0ac035f95",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "import torch.nn as nn\n",
    "import matplotlib.pyplot as plt\n",
    "from torchvision import transforms\n",
    "from torchvision.datasets import CIFAR10\n",
    "from copy import deepcopy\n",
    "from torchmetrics import Accuracy\n",
    "from adaptive import AdaptiveSelection, MaskLayer2d, MaskingPretrainer, BaseModel\n",
    "from resnet import ResNet18Backbone, ResNet18ClassifierHead, ResNet18SelectorHead"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "f15fc6f9-556e-4564-bb9c-0a457b794048",
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device('cuda', 7)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "dff0fcc2-926d-4da3-83b7-1bd23ceec021",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n"
     ]
    }
   ],
   "source": [
    "# Transformations\n",
    "transform_train = transforms.Compose([\n",
    "    transforms.RandomCrop(32, padding=4),\n",
    "    transforms.RandomHorizontalFlip(),\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),\n",
    "])\n",
    "transform_test = transforms.Compose([\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),\n",
    "])\n",
    "\n",
    "# Determine train/val split\n",
    "np.random.seed(0)\n",
    "val_inds = np.sort(np.random.choice(50000, size=10000, replace=False))\n",
    "train_inds = np.setdiff1d(np.arange(50000), val_inds)\n",
    "\n",
    "# Training dataset\n",
    "dataset = CIFAR10('/tmp/cifar/', download=True, train=True, transform=transform_train)\n",
    "train_dataset = torch.utils.data.Subset(dataset, train_inds)\n",
    "\n",
    "# Validation dataset\n",
    "dataset = CIFAR10('/tmp/cifar/', download=True, train=True, transform=transform_test)\n",
    "val_dataset = torch.utils.data.Subset(dataset, val_inds)\n",
    "\n",
    "# Test dataset\n",
    "test_dataset = CIFAR10('/tmp/cifar/', download=True, train=False, transform=transform_test)\n",
    "\n",
    "# Set input/output dimensions\n",
    "d_in = 32 * 32\n",
    "d_out = 10"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bd751395-ef27-4721-ba5a-eb0e161f532c",
   "metadata": {},
   "source": [
    "# Set up architecture"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "12de6b30-5086-463d-a2bc-c48304ee5f6a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shared backbone\n",
    "backbone = ResNet18Backbone()\n",
    "\n",
    "# Classifier head\n",
    "classifier_head = ResNet18ClassifierHead()\n",
    "\n",
    "# Selector head\n",
    "selector_head = ResNet18SelectorHead()\n",
    "\n",
    "# Create classifier and selector networks\n",
    "predictor = nn.Sequential(backbone, classifier_head)\n",
    "selector = nn.Sequential(backbone, selector_head)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "76c782c2-a15a-4b4b-b973-4bc191f6b76b",
   "metadata": {},
   "source": [
    "# Train predictor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "b6b3c118-8c79-4257-bc20-e4c9192bc89f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# # Set up predictor pretraining\n",
    "# basemodel = BaseModel(predictor).to(device)\n",
    "\n",
    "# # Pretrain\n",
    "# basemodel.fit(train_dataset,\n",
    "#               val_dataset,\n",
    "#               mbsize=128,\n",
    "#               lr=1e-3,\n",
    "#               nepochs=100,\n",
    "#               loss_fn=nn.CrossEntropyLoss(),\n",
    "#               val_loss_fn=Accuracy(),\n",
    "#               val_loss_mode='max')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f5942351-ef43-4129-b316-5e5fea373e2d",
   "metadata": {},
   "source": [
    "# Pretrain predictor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "cc30ce24-f01d-4a2a-8f55-658ad1c42e96",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--------Epoch 1--------\n",
      "Val loss = 0.4148\n",
      "\n",
      "--------Epoch 2--------\n",
      "Val loss = 0.4937\n",
      "\n",
      "--------Epoch 3--------\n",
      "Val loss = 0.5105\n",
      "\n",
      "--------Epoch 4--------\n",
      "Val loss = 0.5574\n",
      "\n",
      "--------Epoch 5--------\n",
      "Val loss = 0.5739\n",
      "\n",
      "--------Epoch 6--------\n",
      "Val loss = 0.5850\n",
      "\n",
      "--------Epoch 7--------\n",
      "Val loss = 0.6128\n",
      "\n",
      "--------Epoch 8--------\n",
      "Val loss = 0.6225\n",
      "\n",
      "--------Epoch 9--------\n",
      "Val loss = 0.6337\n",
      "\n",
      "--------Epoch 10--------\n",
      "Val loss = 0.6402\n",
      "\n",
      "--------Epoch 11--------\n",
      "Val loss = 0.6575\n",
      "\n",
      "--------Epoch 12--------\n",
      "Val loss = 0.6699\n",
      "\n",
      "--------Epoch 13--------\n",
      "Val loss = 0.6784\n",
      "\n",
      "--------Epoch 14--------\n",
      "Val loss = 0.6641\n",
      "\n",
      "--------Epoch 15--------\n",
      "Val loss = 0.6840\n",
      "\n",
      "--------Epoch 16--------\n",
      "Val loss = 0.6953\n",
      "\n",
      "--------Epoch 17--------\n",
      "Val loss = 0.6983\n",
      "\n",
      "--------Epoch 18--------\n",
      "Val loss = 0.7004\n",
      "\n",
      "--------Epoch 19--------\n",
      "Val loss = 0.6904\n",
      "\n",
      "--------Epoch 20--------\n",
      "Val loss = 0.7024\n",
      "\n",
      "--------Epoch 21--------\n",
      "Val loss = 0.7059\n",
      "\n",
      "--------Epoch 22--------\n",
      "Val loss = 0.7056\n",
      "\n",
      "--------Epoch 23--------\n",
      "Val loss = 0.7190\n",
      "\n",
      "--------Epoch 24--------\n",
      "Val loss = 0.7165\n",
      "\n",
      "--------Epoch 25--------\n",
      "Val loss = 0.7319\n",
      "\n",
      "--------Epoch 26--------\n",
      "Val loss = 0.7279\n",
      "\n",
      "--------Epoch 27--------\n",
      "Val loss = 0.7176\n",
      "\n",
      "--------Epoch 28--------\n",
      "Val loss = 0.7292\n",
      "\n",
      "Epoch    28: reducing learning rate of group 0 to 2.0000e-04.\n",
      "--------Epoch 29--------\n",
      "Val loss = 0.7560\n",
      "\n",
      "--------Epoch 30--------\n",
      "Val loss = 0.7504\n",
      "\n",
      "--------Epoch 31--------\n",
      "Val loss = 0.7554\n",
      "\n",
      "--------Epoch 32--------\n",
      "Val loss = 0.7597\n",
      "\n",
      "--------Epoch 33--------\n",
      "Val loss = 0.7574\n",
      "\n",
      "--------Epoch 34--------\n",
      "Val loss = 0.7518\n",
      "\n",
      "--------Epoch 35--------\n",
      "Val loss = 0.7568\n",
      "\n",
      "Epoch    35: reducing learning rate of group 0 to 4.0000e-05.\n",
      "--------Epoch 36--------\n",
      "Val loss = 0.7654\n",
      "\n",
      "--------Epoch 37--------\n",
      "Val loss = 0.7634\n",
      "\n",
      "--------Epoch 38--------\n",
      "Val loss = 0.7583\n",
      "\n",
      "--------Epoch 39--------\n",
      "Val loss = 0.7614\n",
      "\n",
      "Epoch    39: reducing learning rate of group 0 to 8.0000e-06.\n",
      "--------Epoch 40--------\n",
      "Val loss = 0.7658\n",
      "\n",
      "--------Epoch 41--------\n",
      "Val loss = 0.7645\n",
      "\n",
      "--------Epoch 42--------\n",
      "Val loss = 0.7672\n",
      "\n",
      "--------Epoch 43--------\n",
      "Val loss = 0.7592\n",
      "\n",
      "--------Epoch 44--------\n",
      "Val loss = 0.7653\n",
      "\n",
      "--------Epoch 45--------\n",
      "Val loss = 0.7684\n",
      "\n",
      "--------Epoch 46--------\n",
      "Val loss = 0.7658\n",
      "\n",
      "--------Epoch 47--------\n",
      "Val loss = 0.7706\n",
      "\n",
      "--------Epoch 48--------\n",
      "Val loss = 0.7639\n",
      "\n",
      "--------Epoch 49--------\n",
      "Val loss = 0.7710\n",
      "\n",
      "--------Epoch 50--------\n",
      "Val loss = 0.7600\n",
      "\n",
      "--------Epoch 51--------\n",
      "Val loss = 0.7610\n",
      "\n",
      "--------Epoch 52--------\n",
      "Val loss = 0.7638\n",
      "\n",
      "Epoch    52: reducing learning rate of group 0 to 1.6000e-06.\n",
      "--------Epoch 53--------\n",
      "Val loss = 0.7657\n",
      "\n",
      "Stopping early at epoch 53\n"
     ]
    }
   ],
   "source": [
    "# Set up predictor pretraining\n",
    "mask_layer = MaskLayer2d(append=False, mask_width=8, patch_size=4)\n",
    "pretrain = MaskingPretrainer(predictor, mask_layer).to(device)\n",
    "\n",
    "# Pretrain\n",
    "pretrain.fit(train_dataset,\n",
    "             val_dataset,\n",
    "             mbsize=128,\n",
    "             lr=1e-3,\n",
    "             nepochs=100,\n",
    "             loss_fn=nn.CrossEntropyLoss(),\n",
    "             val_loss_fn=Accuracy(),\n",
    "             val_loss_mode='max')\n",
    "\n",
    "# Save model\n",
    "predictor.cpu()\n",
    "torch.save(predictor, 'results/cifar_predictor_pretrained_small.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "696776b3-a873-4c61-b07a-725bbebe2521",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load pretrained modules\n",
    "predictor = torch.load('results/cifar_predictor_pretrained_small.pt').to(device)\n",
    "backbone = predictor[0]\n",
    "classifier_head = predictor[1]\n",
    "selector = nn.Sequential(backbone, selector_head)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "551628a0-6d86-4878-9c6f-06b6842f6b1f",
   "metadata": {},
   "source": [
    "# Adaptive selection"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "142b3ab5-b4f7-4ebb-b848-bb2b3c8ca8b5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Setup\n",
    "max_features = 10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "311f4d42-1507-4c02-89c5-a131095c0beb",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Starting training with temp = 1.0000\n",
      "\n",
      "--------Epoch 1 (1 total)--------\n",
      "Val loss = 0.2672, Zero-temp loss = 0.2646\n",
      "\n",
      "--------Epoch 2 (2 total)--------\n",
      "Val loss = 0.2317, Zero-temp loss = 0.2311\n",
      "\n",
      "--------Epoch 3 (3 total)--------\n",
      "Val loss = 0.2629, Zero-temp loss = 0.2629\n",
      "\n",
      "--------Epoch 4 (4 total)--------\n",
      "Val loss = 0.3056, Zero-temp loss = 0.3047\n",
      "\n",
      "--------Epoch 5 (5 total)--------\n",
      "Val loss = 0.3443, Zero-temp loss = 0.3418\n",
      "\n",
      "--------Epoch 6 (6 total)--------\n",
      "Val loss = 0.4066, Zero-temp loss = 0.4046\n",
      "\n",
      "--------Epoch 7 (7 total)--------\n",
      "Val loss = 0.4401, Zero-temp loss = 0.4372\n",
      "\n",
      "--------Epoch 8 (8 total)--------\n",
      "Val loss = 0.4571, Zero-temp loss = 0.4568\n",
      "\n",
      "--------Epoch 9 (9 total)--------\n",
      "Val loss = 0.4638, Zero-temp loss = 0.4624\n",
      "\n",
      "--------Epoch 10 (10 total)--------\n",
      "Val loss = 0.4425, Zero-temp loss = 0.4415\n",
      "\n",
      "--------Epoch 11 (11 total)--------\n",
      "Val loss = 0.4632, Zero-temp loss = 0.4622\n",
      "\n",
      "--------Epoch 12 (12 total)--------\n",
      "Val loss = 0.4976, Zero-temp loss = 0.4961\n",
      "\n",
      "--------Epoch 13 (13 total)--------\n",
      "Val loss = 0.5053, Zero-temp loss = 0.5042\n",
      "\n",
      "--------Epoch 14 (14 total)--------\n",
      "Val loss = 0.4986, Zero-temp loss = 0.4969\n",
      "\n",
      "--------Epoch 15 (15 total)--------\n",
      "Val loss = 0.5034, Zero-temp loss = 0.5026\n",
      "\n",
      "--------Epoch 16 (16 total)--------\n",
      "Val loss = 0.5295, Zero-temp loss = 0.5289\n",
      "\n",
      "--------Epoch 17 (17 total)--------\n",
      "Val loss = 0.5244, Zero-temp loss = 0.5233\n",
      "\n",
      "--------Epoch 18 (18 total)--------\n",
      "Val loss = 0.5246, Zero-temp loss = 0.5235\n",
      "\n",
      "--------Epoch 19 (19 total)--------\n",
      "Val loss = 0.5155, Zero-temp loss = 0.5144\n",
      "\n",
      "Epoch    19: reducing learning rate of group 0 to 2.0000e-04.\n",
      "--------Epoch 20 (20 total)--------\n",
      "Val loss = 0.5292, Zero-temp loss = 0.5280\n",
      "\n",
      "Stopping temp = 1.0000 at epoch 20\n",
      "\n",
      "Starting training with temp = 0.5623\n",
      "\n",
      "--------Epoch 1 (21 total)--------\n",
      "Val loss = 0.4927, Zero-temp loss = 0.4927\n",
      "\n",
      "--------Epoch 2 (22 total)--------\n",
      "Val loss = 0.4866, Zero-temp loss = 0.4866\n",
      "\n",
      "--------Epoch 3 (23 total)--------\n",
      "Val loss = 0.4446, Zero-temp loss = 0.4444\n",
      "\n",
      "--------Epoch 4 (24 total)--------\n",
      "Val loss = 0.5054, Zero-temp loss = 0.5050\n",
      "\n",
      "--------Epoch 5 (25 total)--------\n",
      "Val loss = 0.5159, Zero-temp loss = 0.5145\n",
      "\n",
      "--------Epoch 6 (26 total)--------\n",
      "Val loss = 0.5171, Zero-temp loss = 0.5165\n",
      "\n",
      "--------Epoch 7 (27 total)--------\n",
      "Val loss = 0.5088, Zero-temp loss = 0.5087\n",
      "\n",
      "--------Epoch 8 (28 total)--------\n",
      "Val loss = 0.5054, Zero-temp loss = 0.5049\n",
      "\n",
      "--------Epoch 9 (29 total)--------\n",
      "Val loss = 0.4921, Zero-temp loss = 0.4915\n",
      "\n",
      "Epoch     9: reducing learning rate of group 0 to 2.0000e-04.\n",
      "--------Epoch 10 (30 total)--------\n",
      "Val loss = 0.5090, Zero-temp loss = 0.5084\n",
      "\n",
      "Stopping temp = 0.5623 at epoch 10\n",
      "\n",
      "Starting training with temp = 0.3162\n",
      "\n",
      "--------Epoch 1 (31 total)--------\n",
      "Val loss = 0.4906, Zero-temp loss = 0.4887\n",
      "\n",
      "--------Epoch 2 (32 total)--------\n",
      "Val loss = 0.4962, Zero-temp loss = 0.4961\n",
      "\n",
      "--------Epoch 3 (33 total)--------\n",
      "Val loss = 0.5040, Zero-temp loss = 0.5028\n",
      "\n",
      "--------Epoch 4 (34 total)--------\n",
      "Val loss = 0.5103, Zero-temp loss = 0.5098\n",
      "\n",
      "--------Epoch 5 (35 total)--------\n",
      "Val loss = 0.5078, Zero-temp loss = 0.5076\n",
      "\n",
      "--------Epoch 6 (36 total)--------\n",
      "Val loss = 0.4959, Zero-temp loss = 0.4956\n",
      "\n",
      "--------Epoch 7 (37 total)--------\n",
      "Val loss = 0.4925, Zero-temp loss = 0.4921\n",
      "\n",
      "Epoch     7: reducing learning rate of group 0 to 2.0000e-04.\n",
      "--------Epoch 8 (38 total)--------\n",
      "Val loss = 0.5005, Zero-temp loss = 0.5005\n",
      "\n",
      "Stopping temp = 0.3162 at epoch 8\n",
      "\n",
      "Starting training with temp = 0.1778\n",
      "\n",
      "--------Epoch 1 (39 total)--------\n",
      "Val loss = 0.4760, Zero-temp loss = 0.4759\n",
      "\n",
      "--------Epoch 2 (40 total)--------\n",
      "Val loss = 0.4948, Zero-temp loss = 0.4945\n",
      "\n",
      "--------Epoch 3 (41 total)--------\n",
      "Val loss = 0.4963, Zero-temp loss = 0.4962\n",
      "\n",
      "--------Epoch 4 (42 total)--------\n",
      "Val loss = 0.4933, Zero-temp loss = 0.4934\n",
      "\n",
      "--------Epoch 5 (43 total)--------\n",
      "Val loss = 0.4942, Zero-temp loss = 0.4936\n",
      "\n",
      "--------Epoch 6 (44 total)--------\n",
      "Val loss = 0.4961, Zero-temp loss = 0.4961\n",
      "\n",
      "Epoch     6: reducing learning rate of group 0 to 2.0000e-04.\n",
      "--------Epoch 7 (45 total)--------\n",
      "Val loss = 0.4981, Zero-temp loss = 0.4982\n",
      "\n",
      "--------Epoch 8 (46 total)--------\n",
      "Val loss = 0.5080, Zero-temp loss = 0.5078\n",
      "\n",
      "--------Epoch 9 (47 total)--------\n",
      "Val loss = 0.5072, Zero-temp loss = 0.5071\n",
      "\n",
      "--------Epoch 10 (48 total)--------\n",
      "Val loss = 0.5052, Zero-temp loss = 0.5052\n",
      "\n",
      "--------Epoch 11 (49 total)--------\n",
      "Val loss = 0.5048, Zero-temp loss = 0.5049\n",
      "\n",
      "Epoch    11: reducing learning rate of group 0 to 4.0000e-05.\n",
      "--------Epoch 12 (50 total)--------\n",
      "Val loss = 0.5097, Zero-temp loss = 0.5096\n",
      "\n",
      "--------Epoch 13 (51 total)--------\n",
      "Val loss = 0.5122, Zero-temp loss = 0.5120\n",
      "\n",
      "--------Epoch 14 (52 total)--------\n",
      "Val loss = 0.5159, Zero-temp loss = 0.5158\n",
      "\n",
      "--------Epoch 15 (53 total)--------\n",
      "Val loss = 0.5156, Zero-temp loss = 0.5156\n",
      "\n",
      "--------Epoch 16 (54 total)--------\n",
      "Val loss = 0.5087, Zero-temp loss = 0.5086\n",
      "\n",
      "--------Epoch 17 (55 total)--------\n",
      "Val loss = 0.5121, Zero-temp loss = 0.5120\n",
      "\n",
      "Epoch    17: reducing learning rate of group 0 to 1.0000e-05.\n",
      "--------Epoch 18 (56 total)--------\n",
      "Val loss = 0.5105, Zero-temp loss = 0.5103\n",
      "\n",
      "Stopping temp = 0.1778 at epoch 18\n",
      "\n",
      "Starting training with temp = 0.1000\n",
      "\n",
      "--------Epoch 1 (57 total)--------\n",
      "Val loss = 0.4896, Zero-temp loss = 0.4894\n",
      "\n",
      "--------Epoch 2 (58 total)--------\n",
      "Val loss = 0.4478, Zero-temp loss = 0.4478\n",
      "\n",
      "--------Epoch 3 (59 total)--------\n",
      "Val loss = 0.4850, Zero-temp loss = 0.4849\n",
      "\n",
      "--------Epoch 4 (60 total)--------\n",
      "Val loss = 0.4735, Zero-temp loss = 0.4735\n",
      "\n",
      "Epoch     4: reducing learning rate of group 0 to 2.0000e-04.\n",
      "--------Epoch 5 (61 total)--------\n",
      "Val loss = 0.4815, Zero-temp loss = 0.4814\n",
      "\n",
      "Stopping temp = 0.1000 at epoch 5\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Set up adaptive selection\n",
    "mask_layer = MaskLayer2d(append=False, mask_width=8, patch_size=4)\n",
    "gafs = AdaptiveSelection(selector, predictor, mask_layer).to(device)\n",
    "\n",
    "# Train\n",
    "gafs.fit(torch.utils.data.Subset(train_dataset, np.arange(128 * 10)),  # train_dataset,\n",
    "         val_dataset,\n",
    "         mbsize=128,\n",
    "         lr=1e-3,\n",
    "         nepochs=250,\n",
    "         max_features=max_features,\n",
    "         loss_fn=nn.CrossEntropyLoss(),\n",
    "         val_loss_fn=Accuracy(),\n",
    "         val_loss_mode='max')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "c48b58fa-56c2-4828-995e-59646ecc0794",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Num = 1, Acc = 22.0700\n",
      "Num = 2, Acc = 32.1900\n",
      "Num = 3, Acc = 42.3300\n",
      "Num = 4, Acc = 46.5400\n",
      "Num = 5, Acc = 49.9900\n",
      "Num = 10, Acc = 61.1300\n"
     ]
    }
   ],
   "source": [
    "# Calculate validation acc\n",
    "for num in (1, 2, 3, 4, 5, 10):\n",
    "    acc = gafs.evaluate(val_dataset, max_features=num, metric=Accuracy(), batch_size=128)\n",
    "    print(f'Num = {num}, Acc = {100*acc:.2f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "579b2615-6d29-46dd-ac76-af8becf84635",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save trained model\n",
    "gafs.cpu()\n",
    "torch.save(gafs, 'results/adaptive_trained.pt')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eed95c0b-9dff-4986-bb73-7350767aceed",
   "metadata": {},
   "source": [
    "# More training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "b2640f7c-9b94-4965-a962-6185074a109c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load model\n",
    "gafs = torch.load('results/adaptive_trained.pt')\n",
    "gafs.selector[0] = gafs.predictor[0]\n",
    "gafs = gafs.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "8752d258-c838-44ec-951c-d50f924254b9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Starting training with temp = 1.0000\n",
      "\n",
      "--------Epoch 1 (1 total)--------\n",
      "Val loss = 0.5488, Zero-temp loss = 0.5479\n",
      "\n",
      "--------Epoch 2 (2 total)--------\n",
      "Val loss = 0.6359, Zero-temp loss = 0.5331\n",
      "\n",
      "--------Epoch 3 (3 total)--------\n",
      "Val loss = 0.6619, Zero-temp loss = 0.4969\n",
      "\n",
      "--------Epoch 4 (4 total)--------\n",
      "Val loss = 0.6681, Zero-temp loss = 0.4966\n",
      "\n",
      "--------Epoch 5 (5 total)--------\n",
      "Val loss = 0.6756, Zero-temp loss = 0.4777\n",
      "\n",
      "--------Epoch 6 (6 total)--------\n",
      "Val loss = 0.7161, Zero-temp loss = 0.4901\n",
      "\n",
      "--------Epoch 7 (7 total)--------\n",
      "Val loss = 0.7397, Zero-temp loss = 0.4706\n",
      "\n",
      "--------Epoch 8 (8 total)--------\n",
      "Val loss = 0.7392, Zero-temp loss = 0.4046\n",
      "\n",
      "--------Epoch 9 (9 total)--------\n",
      "Val loss = 0.7580, Zero-temp loss = 0.4050\n",
      "\n",
      "--------Epoch 10 (10 total)--------\n",
      "Val loss = 0.7721, Zero-temp loss = 0.4147\n",
      "\n",
      "--------Epoch 11 (11 total)--------\n",
      "Val loss = 0.7735, Zero-temp loss = 0.4081\n",
      "\n",
      "--------Epoch 12 (12 total)--------\n",
      "Val loss = 0.7853, Zero-temp loss = 0.4026\n",
      "\n",
      "--------Epoch 13 (13 total)--------\n",
      "Val loss = 0.7843, Zero-temp loss = 0.4254\n",
      "\n",
      "--------Epoch 14 (14 total)--------\n",
      "Val loss = 0.7243, Zero-temp loss = 0.2970\n",
      "\n",
      "--------Epoch 15 (15 total)--------\n",
      "Val loss = 0.7913, Zero-temp loss = 0.3846\n",
      "\n",
      "--------Epoch 16 (16 total)--------\n",
      "Val loss = 0.8132, Zero-temp loss = 0.4428\n",
      "\n",
      "--------Epoch 17 (17 total)--------\n",
      "Val loss = 0.8055, Zero-temp loss = 0.4266\n",
      "\n",
      "--------Epoch 18 (18 total)--------\n",
      "Val loss = 0.8152, Zero-temp loss = 0.4439\n",
      "\n",
      "--------Epoch 19 (19 total)--------\n",
      "Val loss = 0.8005, Zero-temp loss = 0.3408\n",
      "\n",
      "--------Epoch 20 (20 total)--------\n",
      "Val loss = 0.7488, Zero-temp loss = 0.2251\n",
      "\n",
      "--------Epoch 21 (21 total)--------\n",
      "Val loss = 0.8114, Zero-temp loss = 0.3042\n",
      "\n",
      "Epoch    21: reducing learning rate of group 0 to 2.0000e-04.\n",
      "--------Epoch 22 (22 total)--------\n",
      "Val loss = 0.8273, Zero-temp loss = 0.3399\n",
      "\n",
      "--------Epoch 23 (23 total)--------\n",
      "Val loss = 0.8351, Zero-temp loss = 0.3320\n",
      "\n",
      "--------Epoch 24 (24 total)--------\n",
      "Val loss = 0.8295, Zero-temp loss = 0.3505\n",
      "\n",
      "--------Epoch 25 (25 total)--------\n",
      "Val loss = 0.8375, Zero-temp loss = 0.3803\n",
      "\n",
      "--------Epoch 26 (26 total)--------\n",
      "Val loss = 0.8354, Zero-temp loss = 0.3351\n",
      "\n",
      "--------Epoch 27 (27 total)--------\n",
      "Val loss = 0.8322, Zero-temp loss = 0.2932\n",
      "\n",
      "--------Epoch 28 (28 total)--------\n",
      "Val loss = 0.8334, Zero-temp loss = 0.3609\n",
      "\n",
      "Epoch    28: reducing learning rate of group 0 to 4.0000e-05.\n",
      "--------Epoch 29 (29 total)--------\n",
      "Val loss = 0.8400, Zero-temp loss = 0.3749\n",
      "\n",
      "--------Epoch 30 (30 total)--------\n",
      "Val loss = 0.8427, Zero-temp loss = 0.4232\n",
      "\n",
      "--------Epoch 31 (31 total)--------\n",
      "Val loss = 0.8291, Zero-temp loss = 0.3108\n",
      "\n",
      "--------Epoch 32 (32 total)--------\n",
      "Val loss = 0.8368, Zero-temp loss = 0.3868\n",
      "\n",
      "--------Epoch 33 (33 total)--------\n",
      "Val loss = 0.8400, Zero-temp loss = 0.4013\n",
      "\n",
      "Epoch    33: reducing learning rate of group 0 to 1.0000e-05.\n",
      "--------Epoch 34 (34 total)--------\n",
      "Val loss = 0.8334, Zero-temp loss = 0.3246\n",
      "\n",
      "Stopping temp = 1.0000 at epoch 34\n",
      "\n",
      "Starting training with temp = 0.5623\n",
      "\n",
      "--------Epoch 1 (35 total)--------\n",
      "Val loss = 0.6691, Zero-temp loss = 0.3887\n",
      "\n",
      "--------Epoch 2 (36 total)--------\n",
      "Val loss = 0.7277, Zero-temp loss = 0.4755\n",
      "\n",
      "--------Epoch 3 (37 total)--------\n",
      "Val loss = 0.6335, Zero-temp loss = 0.3184\n",
      "\n",
      "--------Epoch 4 (38 total)--------\n",
      "Val loss = 0.6668, Zero-temp loss = 0.4332\n",
      "\n",
      "--------Epoch 5 (39 total)--------\n",
      "Val loss = 0.6817, Zero-temp loss = 0.3934\n",
      "\n",
      "Epoch     5: reducing learning rate of group 0 to 2.0000e-04.\n",
      "--------Epoch 6 (40 total)--------\n",
      "Val loss = 0.7053, Zero-temp loss = 0.4096\n",
      "\n",
      "Stopping temp = 0.5623 at epoch 6\n",
      "\n",
      "Starting training with temp = 0.3162\n",
      "\n",
      "--------Epoch 1 (41 total)--------\n",
      "Val loss = 0.6085, Zero-temp loss = 0.5078\n",
      "\n",
      "--------Epoch 2 (42 total)--------\n",
      "Val loss = 0.6289, Zero-temp loss = 0.5211\n",
      "\n",
      "--------Epoch 3 (43 total)--------\n",
      "Val loss = 0.6451, Zero-temp loss = 0.5545\n",
      "\n",
      "--------Epoch 4 (44 total)--------\n",
      "Val loss = 0.5882, Zero-temp loss = 0.4679\n",
      "\n",
      "--------Epoch 5 (45 total)--------\n",
      "Val loss = 0.6272, Zero-temp loss = 0.5277\n",
      "\n",
      "--------Epoch 6 (46 total)--------\n",
      "Val loss = 0.5751, Zero-temp loss = 0.4599\n",
      "\n",
      "Epoch     6: reducing learning rate of group 0 to 2.0000e-04.\n",
      "--------Epoch 7 (47 total)--------\n",
      "Val loss = 0.6394, Zero-temp loss = 0.5204\n",
      "\n",
      "Stopping temp = 0.3162 at epoch 7\n",
      "\n",
      "Starting training with temp = 0.1778\n",
      "\n",
      "--------Epoch 1 (48 total)--------\n",
      "Val loss = 0.6351, Zero-temp loss = 0.5976\n",
      "\n",
      "--------Epoch 2 (49 total)--------\n",
      "Val loss = 0.6337, Zero-temp loss = 0.5988\n",
      "\n",
      "--------Epoch 3 (50 total)--------\n",
      "Val loss = 0.6350, Zero-temp loss = 0.5977\n",
      "\n",
      "--------Epoch 4 (51 total)--------\n",
      "Val loss = 0.6268, Zero-temp loss = 0.5931\n",
      "\n",
      "Epoch     4: reducing learning rate of group 0 to 2.0000e-04.\n",
      "--------Epoch 5 (52 total)--------\n",
      "Val loss = 0.6570, Zero-temp loss = 0.6210\n",
      "\n",
      "--------Epoch 6 (53 total)--------\n",
      "Val loss = 0.6475, Zero-temp loss = 0.6078\n",
      "\n",
      "--------Epoch 7 (54 total)--------\n",
      "Val loss = 0.6530, Zero-temp loss = 0.6180\n",
      "\n",
      "--------Epoch 8 (55 total)--------\n",
      "Val loss = 0.6581, Zero-temp loss = 0.6223\n",
      "\n",
      "--------Epoch 9 (56 total)--------\n",
      "Val loss = 0.6543, Zero-temp loss = 0.6151\n",
      "\n",
      "--------Epoch 10 (57 total)--------\n",
      "Val loss = 0.6438, Zero-temp loss = 0.6061\n",
      "\n",
      "--------Epoch 11 (58 total)--------\n",
      "Val loss = 0.6368, Zero-temp loss = 0.5986\n",
      "\n",
      "Epoch    11: reducing learning rate of group 0 to 4.0000e-05.\n",
      "--------Epoch 12 (59 total)--------\n",
      "Val loss = 0.6599, Zero-temp loss = 0.6225\n",
      "\n",
      "--------Epoch 13 (60 total)--------\n",
      "Val loss = 0.6433, Zero-temp loss = 0.6027\n",
      "\n",
      "--------Epoch 14 (61 total)--------\n",
      "Val loss = 0.6322, Zero-temp loss = 0.5911\n",
      "\n",
      "--------Epoch 15 (62 total)--------\n",
      "Val loss = 0.6504, Zero-temp loss = 0.6098\n",
      "\n",
      "Epoch    15: reducing learning rate of group 0 to 1.0000e-05.\n",
      "--------Epoch 16 (63 total)--------\n",
      "Val loss = 0.6377, Zero-temp loss = 0.5985\n",
      "\n",
      "Stopping temp = 0.1778 at epoch 16\n",
      "\n",
      "Starting training with temp = 0.1000\n",
      "\n",
      "--------Epoch 1 (64 total)--------\n",
      "Val loss = 0.6241, Zero-temp loss = 0.6073\n",
      "\n",
      "--------Epoch 2 (65 total)--------\n",
      "Val loss = 0.6415, Zero-temp loss = 0.6263\n",
      "\n",
      "--------Epoch 3 (66 total)--------\n",
      "Val loss = 0.6301, Zero-temp loss = 0.6160\n",
      "\n",
      "--------Epoch 4 (67 total)--------\n",
      "Val loss = 0.6386, Zero-temp loss = 0.6242\n",
      "\n",
      "--------Epoch 5 (68 total)--------\n",
      "Val loss = 0.6344, Zero-temp loss = 0.6228\n",
      "\n",
      "Epoch     5: reducing learning rate of group 0 to 2.0000e-04.\n",
      "--------Epoch 6 (69 total)--------\n",
      "Val loss = 0.6577, Zero-temp loss = 0.6445\n",
      "\n",
      "--------Epoch 7 (70 total)--------\n",
      "Val loss = 0.6507, Zero-temp loss = 0.6358\n",
      "\n",
      "--------Epoch 8 (71 total)--------\n",
      "Val loss = 0.6631, Zero-temp loss = 0.6506\n",
      "\n",
      "--------Epoch 9 (72 total)--------\n",
      "Val loss = 0.6655, Zero-temp loss = 0.6511\n",
      "\n",
      "--------Epoch 10 (73 total)--------\n",
      "Val loss = 0.6636, Zero-temp loss = 0.6509\n",
      "\n",
      "--------Epoch 11 (74 total)--------\n",
      "Val loss = 0.6639, Zero-temp loss = 0.6491\n",
      "\n",
      "--------Epoch 12 (75 total)--------\n",
      "Val loss = 0.6650, Zero-temp loss = 0.6517\n",
      "\n",
      "Epoch    12: reducing learning rate of group 0 to 4.0000e-05.\n",
      "--------Epoch 13 (76 total)--------\n",
      "Val loss = 0.6711, Zero-temp loss = 0.6558\n",
      "\n",
      "--------Epoch 14 (77 total)--------\n",
      "Val loss = 0.6674, Zero-temp loss = 0.6539\n",
      "\n",
      "--------Epoch 15 (78 total)--------\n",
      "Val loss = 0.6743, Zero-temp loss = 0.6618\n",
      "\n",
      "--------Epoch 16 (79 total)--------\n",
      "Val loss = 0.6712, Zero-temp loss = 0.6574\n",
      "\n",
      "--------Epoch 17 (80 total)--------\n",
      "Val loss = 0.6720, Zero-temp loss = 0.6581\n",
      "\n",
      "--------Epoch 18 (81 total)--------\n",
      "Val loss = 0.6687, Zero-temp loss = 0.6561\n",
      "\n",
      "Epoch    18: reducing learning rate of group 0 to 1.0000e-05.\n",
      "--------Epoch 19 (82 total)--------\n",
      "Val loss = 0.6704, Zero-temp loss = 0.6570\n",
      "\n",
      "Stopping temp = 0.1000 at epoch 19\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Train with whole dataset\n",
    "gafs.fit(train_dataset,\n",
    "         val_dataset,\n",
    "         mbsize=128,\n",
    "         lr=1e-3,\n",
    "         nepochs=250,\n",
    "         max_features=max_features,\n",
    "         loss_fn=nn.CrossEntropyLoss(),\n",
    "         val_loss_fn=Accuracy(),\n",
    "         val_loss_mode='max')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "2aa08f9b-eb46-4a80-8a1a-2258d3bfb6a5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Num = 1, Acc = 33.89\n",
      "Num = 2, Acc = 50.32\n",
      "Num = 3, Acc = 58.60\n",
      "Num = 4, Acc = 65.81\n",
      "Num = 5, Acc = 70.83\n",
      "Num = 6, Acc = 74.05\n",
      "Num = 7, Acc = 76.66\n",
      "Num = 8, Acc = 77.57\n",
      "Num = 9, Acc = 78.57\n",
      "Num = 10, Acc = 79.33\n",
      "Num = 15, Acc = 82.42\n",
      "Num = 20, Acc = 83.14\n",
      "Num = 25, Acc = 84.22\n",
      "Num = 30, Acc = 84.89\n"
     ]
    }
   ],
   "source": [
    "# Calculate validation acc\n",
    "for num in (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 20, 25, 30):\n",
    "    acc = gafs.evaluate(val_dataset, max_features=num, metric=Accuracy(), batch_size=128)\n",
    "    print(f'Num = {num}, Acc = {100*acc:.2f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "5e8c5479-f5b1-4135-ac67-611711566632",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save trained model\n",
    "gafs.cpu()\n",
    "torch.save(gafs, 'results/adaptive_trained2.pt')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7783d3b4-89d0-47b9-b1bf-fdba990ef005",
   "metadata": {},
   "source": [
    "# Even more training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "36e7c390-cb3d-436d-89fd-1c7dd7104766",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load model\n",
    "gafs = torch.load('results/adaptive_trained2.pt')\n",
    "gafs.selector[0] = gafs.predictor[0]\n",
    "gafs = gafs.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6e07dbc5-ce99-4f22-83be-f7742faca50c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Starting training with temp = 1.0000\n",
      "\n",
      "--------Epoch 1 (1 total)--------\n",
      "Val loss = 0.7939, Zero-temp loss = 0.7144\n",
      "\n",
      "--------Epoch 2 (2 total)--------\n",
      "Val loss = 0.8170, Zero-temp loss = 0.7286\n",
      "\n",
      "--------Epoch 3 (3 total)--------\n",
      "Val loss = 0.8089, Zero-temp loss = 0.7240\n",
      "\n",
      "--------Epoch 4 (4 total)--------\n",
      "Val loss = 0.8187, Zero-temp loss = 0.7310\n",
      "\n",
      "--------Epoch 5 (5 total)--------\n",
      "Val loss = 0.8161, Zero-temp loss = 0.7191\n",
      "\n",
      "--------Epoch 6 (6 total)--------\n",
      "Val loss = 0.8226, Zero-temp loss = 0.7287\n",
      "\n",
      "--------Epoch 7 (7 total)--------\n",
      "Val loss = 0.8252, Zero-temp loss = 0.7277\n",
      "\n",
      "--------Epoch 8 (8 total)--------\n",
      "Val loss = 0.8209, Zero-temp loss = 0.7200\n",
      "\n",
      "--------Epoch 9 (9 total)--------\n",
      "Val loss = 0.8219, Zero-temp loss = 0.7177\n",
      "\n",
      "--------Epoch 10 (10 total)--------\n",
      "Val loss = 0.8282, Zero-temp loss = 0.7308\n",
      "\n",
      "--------Epoch 11 (11 total)--------\n",
      "Val loss = 0.8168, Zero-temp loss = 0.7137\n",
      "\n",
      "--------Epoch 12 (12 total)--------\n",
      "Val loss = 0.8355, Zero-temp loss = 0.7289\n",
      "\n",
      "--------Epoch 13 (13 total)--------\n",
      "Val loss = 0.8357, Zero-temp loss = 0.7366\n",
      "\n",
      "--------Epoch 14 (14 total)--------\n",
      "Val loss = 0.8285, Zero-temp loss = 0.7256\n",
      "\n",
      "--------Epoch 15 (15 total)--------\n",
      "Val loss = 0.8374, Zero-temp loss = 0.7356\n",
      "\n",
      "--------Epoch 16 (16 total)--------\n",
      "Val loss = 0.8310, Zero-temp loss = 0.7245\n",
      "\n",
      "--------Epoch 17 (17 total)--------\n",
      "Val loss = 0.8331, Zero-temp loss = 0.7231\n",
      "\n",
      "--------Epoch 18 (18 total)--------\n",
      "Val loss = 0.8373, Zero-temp loss = 0.7248\n",
      "\n",
      "Epoch    18: reducing learning rate of group 0 to 2.0000e-04.\n",
      "--------Epoch 19 (19 total)--------\n",
      "Val loss = 0.8575, Zero-temp loss = 0.7525\n",
      "\n",
      "--------Epoch 20 (20 total)--------\n",
      "Val loss = 0.8535, Zero-temp loss = 0.7473\n",
      "\n",
      "--------Epoch 21 (21 total)--------\n",
      "Val loss = 0.8481, Zero-temp loss = 0.7391\n",
      "\n",
      "--------Epoch 22 (22 total)--------\n",
      "Val loss = 0.8535, Zero-temp loss = 0.7449\n",
      "\n",
      "Epoch    22: reducing learning rate of group 0 to 4.0000e-05.\n",
      "--------Epoch 23 (23 total)--------\n",
      "Val loss = 0.8473, Zero-temp loss = 0.7245\n",
      "\n",
      "Stopping temp = 1.0000 at epoch 23\n",
      "\n",
      "Starting training with temp = 0.5623\n",
      "\n",
      "--------Epoch 1 (24 total)--------\n",
      "Val loss = 0.7891, Zero-temp loss = 0.7360\n",
      "\n",
      "--------Epoch 2 (25 total)--------\n",
      "Val loss = 0.7940, Zero-temp loss = 0.7420\n",
      "\n",
      "--------Epoch 3 (26 total)--------\n",
      "Val loss = 0.7908, Zero-temp loss = 0.7407\n",
      "\n",
      "--------Epoch 4 (27 total)--------\n",
      "Val loss = 0.7916, Zero-temp loss = 0.7439\n",
      "\n",
      "--------Epoch 5 (28 total)--------\n",
      "Val loss = 0.7963, Zero-temp loss = 0.7484\n",
      "\n",
      "--------Epoch 6 (29 total)--------\n",
      "Val loss = 0.7905, Zero-temp loss = 0.7404\n",
      "\n",
      "--------Epoch 7 (30 total)--------\n",
      "Val loss = 0.7984, Zero-temp loss = 0.7456\n",
      "\n",
      "--------Epoch 8 (31 total)--------\n",
      "Val loss = 0.7960, Zero-temp loss = 0.7473\n",
      "\n",
      "--------Epoch 9 (32 total)--------\n",
      "Val loss = 0.8043, Zero-temp loss = 0.7497\n",
      "\n",
      "--------Epoch 10 (33 total)--------\n",
      "Val loss = 0.8016, Zero-temp loss = 0.7517\n",
      "\n",
      "--------Epoch 11 (34 total)--------\n",
      "Val loss = 0.8016, Zero-temp loss = 0.7484\n",
      "\n",
      "--------Epoch 12 (35 total)--------\n",
      "Val loss = 0.8033, Zero-temp loss = 0.7496\n",
      "\n",
      "Epoch    12: reducing learning rate of group 0 to 2.0000e-04.\n",
      "--------Epoch 13 (36 total)--------\n",
      "Val loss = 0.8215, Zero-temp loss = 0.7656\n",
      "\n",
      "--------Epoch 14 (37 total)--------\n",
      "Val loss = 0.8219, Zero-temp loss = 0.7663\n",
      "\n",
      "--------Epoch 15 (38 total)--------\n",
      "Val loss = 0.8225, Zero-temp loss = 0.7659\n",
      "\n",
      "--------Epoch 16 (39 total)--------\n",
      "Val loss = 0.8178, Zero-temp loss = 0.7590\n",
      "\n",
      "--------Epoch 17 (40 total)--------\n",
      "Val loss = 0.8225, Zero-temp loss = 0.7672\n",
      "\n",
      "--------Epoch 18 (41 total)--------\n",
      "Val loss = 0.8126, Zero-temp loss = 0.7527\n",
      "\n",
      "Epoch    18: reducing learning rate of group 0 to 4.0000e-05.\n",
      "--------Epoch 19 (42 total)--------\n",
      "Val loss = 0.8274, Zero-temp loss = 0.7686\n",
      "\n",
      "--------Epoch 20 (43 total)--------\n",
      "Val loss = 0.8205, Zero-temp loss = 0.7613\n",
      "\n",
      "--------Epoch 21 (44 total)--------\n",
      "Val loss = 0.8230, Zero-temp loss = 0.7653\n",
      "\n",
      "--------Epoch 22 (45 total)--------\n",
      "Val loss = 0.8308, Zero-temp loss = 0.7734\n",
      "\n",
      "--------Epoch 23 (46 total)--------\n",
      "Val loss = 0.8300, Zero-temp loss = 0.7707\n",
      "\n",
      "--------Epoch 24 (47 total)--------\n",
      "Val loss = 0.8301, Zero-temp loss = 0.7703\n",
      "\n",
      "--------Epoch 25 (48 total)--------\n",
      "Val loss = 0.8269, Zero-temp loss = 0.7696\n",
      "\n",
      "Epoch    25: reducing learning rate of group 0 to 1.0000e-05.\n",
      "--------Epoch 26 (49 total)--------\n",
      "Val loss = 0.8299, Zero-temp loss = 0.7703\n",
      "\n",
      "Stopping temp = 0.5623 at epoch 26\n",
      "\n",
      "Starting training with temp = 0.3162\n",
      "\n",
      "--------Epoch 1 (50 total)--------\n",
      "Val loss = 0.7808, Zero-temp loss = 0.7601\n",
      "\n",
      "--------Epoch 2 (51 total)--------\n",
      "Val loss = 0.7779, Zero-temp loss = 0.7543\n",
      "\n",
      "--------Epoch 3 (52 total)--------\n",
      "Val loss = 0.7790, Zero-temp loss = 0.7566\n",
      "\n",
      "--------Epoch 4 (53 total)--------\n",
      "Val loss = 0.7816, Zero-temp loss = 0.7627\n",
      "\n",
      "--------Epoch 5 (54 total)--------\n",
      "Val loss = 0.7814, Zero-temp loss = 0.7599\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Train with whole dataset\n",
    "gafs.fit(train_dataset,\n",
    "         val_dataset,\n",
    "         mbsize=128,\n",
    "         lr=1e-3,\n",
    "         nepochs=250,\n",
    "         max_features=32,\n",
    "         loss_fn=nn.CrossEntropyLoss(),\n",
    "         val_loss_fn=Accuracy(),\n",
    "         val_loss_mode='max')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "58aea653-dcb5-47a5-95fc-bf7bddc03af5",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "456d9b6d-75ad-4b47-a445-3156ad4e195c",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a0fb29ef-e808-4744-b062-4dfcf2045b98",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
