{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "f78af6b7-569c-4f35-ae99-511121e304e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import pickle\n",
    "import argparse\n",
    "import numpy as np\n",
    "import torch.nn as nn\n",
    "from torchmetrics import AUROC\n",
    "from sklearn.metrics import accuracy_score, roc_auc_score\n",
    "from adaptive import AdaptiveSelection, MaskLayer, MaskingPretrainer\n",
    "\n",
    "import sys\n",
    "sys.path.append('../')\n",
    "from data import DenseDatasetSelected, data_split, get_xy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "fcfcf277-648a-45d6-b057-f8d2ece3820b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train samples = 2761, val samples = 920, test samples = 920\n"
     ]
    }
   ],
   "source": [
    "# Load dataset\n",
    "dataset = DenseDatasetSelected('../../datasets/spam.csv')\n",
    "d_in = dataset.X.shape[1]  # 57\n",
    "d_out = len(np.unique(dataset.Y))  # 2\n",
    "\n",
    "# Split dataset\n",
    "train_dataset, val_dataset, test_dataset = data_split(dataset, random_state=0)\n",
    "print(f'Train samples = {len(train_dataset)}, val samples = {len(val_dataset)}, test samples = {len(test_dataset)}')\n",
    "\n",
    "# Find mean/variance for normalizing\n",
    "x, y = get_xy(train_dataset)\n",
    "mean = np.mean(x, axis=0)\n",
    "std = np.std(y, axis=0)\n",
    "\n",
    "# Normalize via the original dataset\n",
    "dataset.X = dataset.X - mean\n",
    "\n",
    "# Setup\n",
    "max_features = 35\n",
    "device = torch.device('cuda', 7)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "22a9623a-9d59-449f-beac-2feaa0a79dcf",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--------Epoch 1--------\n",
      "Val loss = 0.7453\n",
      "\n",
      "--------Epoch 2--------\n",
      "Val loss = 0.8113\n",
      "\n",
      "--------Epoch 3--------\n",
      "Val loss = 0.5626\n",
      "\n",
      "--------Epoch 4--------\n",
      "Val loss = 0.5103\n",
      "\n",
      "--------Epoch 5--------\n",
      "Val loss = 0.4666\n",
      "\n",
      "--------Epoch 6--------\n",
      "Val loss = 0.4646\n",
      "\n",
      "--------Epoch 7--------\n",
      "Val loss = 0.4354\n",
      "\n",
      "--------Epoch 8--------\n",
      "Val loss = 0.4472\n",
      "\n",
      "--------Epoch 9--------\n",
      "Val loss = 0.4418\n",
      "\n",
      "--------Epoch 10--------\n",
      "Val loss = 0.4107\n",
      "\n",
      "--------Epoch 11--------\n",
      "Val loss = 0.4089\n",
      "\n",
      "--------Epoch 12--------\n",
      "Val loss = 0.4051\n",
      "\n",
      "--------Epoch 13--------\n",
      "Val loss = 0.3911\n",
      "\n",
      "--------Epoch 14--------\n",
      "Val loss = 0.4072\n",
      "\n",
      "--------Epoch 15--------\n",
      "Val loss = 0.4068\n",
      "\n",
      "--------Epoch 16--------\n",
      "Val loss = 0.4099\n",
      "\n",
      "--------Epoch 17--------\n",
      "Val loss = 0.3949\n",
      "\n",
      "--------Epoch 18--------\n",
      "Val loss = 0.4310\n",
      "\n",
      "--------Epoch 19--------\n",
      "Val loss = 0.3840\n",
      "\n",
      "--------Epoch 20--------\n",
      "Val loss = 0.3799\n",
      "\n",
      "--------Epoch 21--------\n",
      "Val loss = 0.3787\n",
      "\n",
      "--------Epoch 22--------\n",
      "Val loss = 0.3746\n",
      "\n",
      "--------Epoch 23--------\n",
      "Val loss = 0.3764\n",
      "\n",
      "--------Epoch 24--------\n",
      "Val loss = 0.3654\n",
      "\n",
      "--------Epoch 25--------\n",
      "Val loss = 0.3961\n",
      "\n",
      "--------Epoch 26--------\n",
      "Val loss = 0.3820\n",
      "\n",
      "--------Epoch 27--------\n",
      "Val loss = 0.3496\n",
      "\n",
      "--------Epoch 28--------\n",
      "Val loss = 0.3663\n",
      "\n",
      "--------Epoch 29--------\n",
      "Val loss = 0.3475\n",
      "\n",
      "--------Epoch 30--------\n",
      "Val loss = 0.3669\n",
      "\n",
      "--------Epoch 31--------\n",
      "Val loss = 0.3771\n",
      "\n",
      "--------Epoch 32--------\n",
      "Val loss = 0.3494\n",
      "\n",
      "--------Epoch 33--------\n",
      "Val loss = 0.3420\n",
      "\n",
      "--------Epoch 34--------\n",
      "Val loss = 0.4071\n",
      "\n",
      "--------Epoch 35--------\n",
      "Val loss = 0.3509\n",
      "\n",
      "--------Epoch 36--------\n",
      "Val loss = 0.3415\n",
      "\n",
      "--------Epoch 37--------\n",
      "Val loss = 0.3509\n",
      "\n",
      "--------Epoch 38--------\n",
      "Val loss = 0.3366\n",
      "\n",
      "--------Epoch 39--------\n",
      "Val loss = 0.3351\n",
      "\n",
      "--------Epoch 40--------\n",
      "Val loss = 0.3563\n",
      "\n",
      "--------Epoch 41--------\n",
      "Val loss = 0.3763\n",
      "\n",
      "--------Epoch 42--------\n",
      "Val loss = 0.3583\n",
      "\n",
      "--------Epoch 43--------\n",
      "Val loss = 0.3609\n",
      "\n",
      "--------Epoch 44--------\n",
      "Val loss = 0.3364\n",
      "\n",
      "--------Epoch 45--------\n",
      "Val loss = 0.3430\n",
      "\n",
      "Epoch    45: reducing learning rate of group 0 to 2.0000e-04.\n",
      "--------Epoch 46--------\n",
      "Val loss = 0.3489\n",
      "\n",
      "Stopping early at epoch 46\n",
      "Starting training with temp = 1.0000\n",
      "\n",
      "--------Epoch 1 (1 total)--------\n",
      "Val loss = 0.2865, Zero-temp loss = 0.2994\n",
      "\n",
      "--------Epoch 2 (2 total)--------\n",
      "Val loss = 0.2864, Zero-temp loss = 0.3036\n",
      "\n",
      "--------Epoch 3 (3 total)--------\n",
      "Val loss = 0.2610, Zero-temp loss = 0.2752\n",
      "\n",
      "--------Epoch 4 (4 total)--------\n",
      "Val loss = 0.2656, Zero-temp loss = 0.2815\n",
      "\n",
      "--------Epoch 5 (5 total)--------\n",
      "Val loss = 0.2406, Zero-temp loss = 0.2555\n",
      "\n",
      "--------Epoch 6 (6 total)--------\n",
      "Val loss = 0.2353, Zero-temp loss = 0.2486\n",
      "\n",
      "--------Epoch 7 (7 total)--------\n",
      "Val loss = 0.2348, Zero-temp loss = 0.2486\n",
      "\n",
      "--------Epoch 8 (8 total)--------\n",
      "Val loss = 0.2147, Zero-temp loss = 0.2270\n",
      "\n",
      "--------Epoch 9 (9 total)--------\n",
      "Val loss = 0.2304, Zero-temp loss = 0.2456\n",
      "\n",
      "--------Epoch 10 (10 total)--------\n",
      "Val loss = 0.2256, Zero-temp loss = 0.2440\n",
      "\n",
      "--------Epoch 11 (11 total)--------\n",
      "Val loss = 0.2270, Zero-temp loss = 0.2448\n",
      "\n",
      "--------Epoch 12 (12 total)--------\n",
      "Val loss = 0.2209, Zero-temp loss = 0.2375\n",
      "\n",
      "--------Epoch 13 (13 total)--------\n",
      "Val loss = 0.2472, Zero-temp loss = 0.2681\n",
      "\n",
      "--------Epoch 14 (14 total)--------\n",
      "Val loss = 0.2358, Zero-temp loss = 0.2581\n",
      "\n",
      "Epoch    14: reducing learning rate of group 0 to 2.0000e-04.\n",
      "--------Epoch 15 (15 total)--------\n",
      "Val loss = 0.2217, Zero-temp loss = 0.2411\n",
      "\n",
      "Stopping temp = 1.0000 at epoch 15\n",
      "\n",
      "Starting training with temp = 0.5623\n",
      "\n",
      "--------Epoch 1 (16 total)--------\n",
      "Val loss = 0.2406, Zero-temp loss = 0.2504\n",
      "\n",
      "--------Epoch 2 (17 total)--------\n",
      "Val loss = 0.2263, Zero-temp loss = 0.2355\n",
      "\n",
      "--------Epoch 3 (18 total)--------\n",
      "Val loss = 0.2253, Zero-temp loss = 0.2337\n",
      "\n",
      "--------Epoch 4 (19 total)--------\n",
      "Val loss = 0.2198, Zero-temp loss = 0.2280\n",
      "\n",
      "--------Epoch 5 (20 total)--------\n",
      "Val loss = 0.2277, Zero-temp loss = 0.2373\n",
      "\n",
      "--------Epoch 6 (21 total)--------\n",
      "Val loss = 0.2357, Zero-temp loss = 0.2440\n",
      "\n",
      "--------Epoch 7 (22 total)--------\n",
      "Val loss = 0.2225, Zero-temp loss = 0.2319\n",
      "\n",
      "--------Epoch 8 (23 total)--------\n",
      "Val loss = 0.2358, Zero-temp loss = 0.2472\n",
      "\n",
      "--------Epoch 9 (24 total)--------\n",
      "Val loss = 0.2122, Zero-temp loss = 0.2227\n",
      "\n",
      "--------Epoch 10 (25 total)--------\n",
      "Val loss = 0.2281, Zero-temp loss = 0.2386\n",
      "\n",
      "--------Epoch 11 (26 total)--------\n",
      "Val loss = 0.2397, Zero-temp loss = 0.2505\n",
      "\n",
      "--------Epoch 12 (27 total)--------\n",
      "Val loss = 0.2211, Zero-temp loss = 0.2308\n",
      "\n",
      "--------Epoch 13 (28 total)--------\n",
      "Val loss = 0.2224, Zero-temp loss = 0.2326\n",
      "\n",
      "--------Epoch 14 (29 total)--------\n",
      "Val loss = 0.2275, Zero-temp loss = 0.2383\n",
      "\n",
      "--------Epoch 15 (30 total)--------\n",
      "Val loss = 0.2230, Zero-temp loss = 0.2324\n",
      "\n",
      "Epoch    15: reducing learning rate of group 0 to 2.0000e-04.\n",
      "--------Epoch 16 (31 total)--------\n",
      "Val loss = 0.2087, Zero-temp loss = 0.2177\n",
      "\n",
      "--------Epoch 17 (32 total)--------\n",
      "Val loss = 0.2106, Zero-temp loss = 0.2209\n",
      "\n",
      "--------Epoch 18 (33 total)--------\n",
      "Val loss = 0.2170, Zero-temp loss = 0.2281\n",
      "\n",
      "--------Epoch 19 (34 total)--------\n",
      "Val loss = 0.2120, Zero-temp loss = 0.2216\n",
      "\n",
      "--------Epoch 20 (35 total)--------\n",
      "Val loss = 0.2074, Zero-temp loss = 0.2170\n",
      "\n",
      "--------Epoch 21 (36 total)--------\n",
      "Val loss = 0.2104, Zero-temp loss = 0.2201\n",
      "\n",
      "--------Epoch 22 (37 total)--------\n",
      "Val loss = 0.2120, Zero-temp loss = 0.2233\n",
      "\n",
      "--------Epoch 23 (38 total)--------\n",
      "Val loss = 0.2097, Zero-temp loss = 0.2215\n",
      "\n",
      "--------Epoch 24 (39 total)--------\n",
      "Val loss = 0.2082, Zero-temp loss = 0.2187\n",
      "\n",
      "--------Epoch 25 (40 total)--------\n",
      "Val loss = 0.2079, Zero-temp loss = 0.2183\n",
      "\n",
      "--------Epoch 26 (41 total)--------\n",
      "Val loss = 0.2108, Zero-temp loss = 0.2209\n",
      "\n",
      "Epoch    26: reducing learning rate of group 0 to 4.0000e-05.\n",
      "--------Epoch 27 (42 total)--------\n",
      "Val loss = 0.2104, Zero-temp loss = 0.2199\n",
      "\n",
      "Stopping temp = 0.5623 at epoch 27\n",
      "\n",
      "Starting training with temp = 0.3162\n",
      "\n",
      "--------Epoch 1 (43 total)--------\n",
      "Val loss = 0.2334, Zero-temp loss = 0.2389\n",
      "\n",
      "--------Epoch 2 (44 total)--------\n",
      "Val loss = 0.2163, Zero-temp loss = 0.2212\n",
      "\n",
      "--------Epoch 3 (45 total)--------\n",
      "Val loss = 0.2211, Zero-temp loss = 0.2253\n",
      "\n",
      "--------Epoch 4 (46 total)--------\n",
      "Val loss = 0.2112, Zero-temp loss = 0.2151\n",
      "\n",
      "--------Epoch 5 (47 total)--------\n",
      "Val loss = 0.2236, Zero-temp loss = 0.2297\n",
      "\n",
      "--------Epoch 6 (48 total)--------\n",
      "Val loss = 0.2133, Zero-temp loss = 0.2182\n",
      "\n",
      "--------Epoch 7 (49 total)--------\n",
      "Val loss = 0.2154, Zero-temp loss = 0.2195\n",
      "\n",
      "--------Epoch 8 (50 total)--------\n",
      "Val loss = 0.2210, Zero-temp loss = 0.2248\n",
      "\n",
      "--------Epoch 9 (51 total)--------\n",
      "Val loss = 0.2230, Zero-temp loss = 0.2269\n",
      "\n",
      "--------Epoch 10 (52 total)--------\n",
      "Val loss = 0.2154, Zero-temp loss = 0.2196\n",
      "\n",
      "Epoch    10: reducing learning rate of group 0 to 2.0000e-04.\n",
      "--------Epoch 11 (53 total)--------\n",
      "Val loss = 0.2178, Zero-temp loss = 0.2230\n",
      "\n",
      "Stopping temp = 0.3162 at epoch 11\n",
      "\n",
      "Starting training with temp = 0.1778\n",
      "\n",
      "--------Epoch 1 (54 total)--------\n",
      "Val loss = 0.2129, Zero-temp loss = 0.2147\n",
      "\n",
      "--------Epoch 2 (55 total)--------\n",
      "Val loss = 0.2222, Zero-temp loss = 0.2243\n",
      "\n",
      "--------Epoch 3 (56 total)--------\n",
      "Val loss = 0.2226, Zero-temp loss = 0.2248\n",
      "\n",
      "--------Epoch 4 (57 total)--------\n",
      "Val loss = 0.2280, Zero-temp loss = 0.2299\n",
      "\n",
      "--------Epoch 5 (58 total)--------\n",
      "Val loss = 0.2236, Zero-temp loss = 0.2248\n",
      "\n",
      "--------Epoch 6 (59 total)--------\n",
      "Val loss = 0.2220, Zero-temp loss = 0.2242\n",
      "\n",
      "--------Epoch 7 (60 total)--------\n",
      "Val loss = 0.2361, Zero-temp loss = 0.2383\n",
      "\n",
      "Epoch     7: reducing learning rate of group 0 to 2.0000e-04.\n",
      "--------Epoch 8 (61 total)--------\n",
      "Val loss = 0.2141, Zero-temp loss = 0.2162\n",
      "\n",
      "Stopping temp = 0.1778 at epoch 8\n",
      "\n",
      "Starting training with temp = 0.1000\n",
      "\n",
      "--------Epoch 1 (62 total)--------\n",
      "Val loss = 0.2270, Zero-temp loss = 0.2282\n",
      "\n",
      "--------Epoch 2 (63 total)--------\n",
      "Val loss = 0.2223, Zero-temp loss = 0.2233\n",
      "\n",
      "--------Epoch 3 (64 total)--------\n",
      "Val loss = 0.2352, Zero-temp loss = 0.2362\n",
      "\n",
      "--------Epoch 4 (65 total)--------\n",
      "Val loss = 0.2335, Zero-temp loss = 0.2343\n",
      "\n",
      "--------Epoch 5 (66 total)--------\n",
      "Val loss = 0.2375, Zero-temp loss = 0.2385\n",
      "\n",
      "--------Epoch 6 (67 total)--------\n",
      "Val loss = 0.2330, Zero-temp loss = 0.2337\n",
      "\n",
      "--------Epoch 7 (68 total)--------\n",
      "Val loss = 0.2239, Zero-temp loss = 0.2248\n",
      "\n",
      "--------Epoch 8 (69 total)--------\n",
      "Val loss = 0.2210, Zero-temp loss = 0.2213\n",
      "\n",
      "--------Epoch 9 (70 total)--------\n",
      "Val loss = 0.2319, Zero-temp loss = 0.2326\n",
      "\n",
      "--------Epoch 10 (71 total)--------\n",
      "Val loss = 0.2328, Zero-temp loss = 0.2334\n",
      "\n",
      "--------Epoch 11 (72 total)--------\n",
      "Val loss = 0.2213, Zero-temp loss = 0.2216\n",
      "\n",
      "--------Epoch 12 (73 total)--------\n",
      "Val loss = 0.2144, Zero-temp loss = 0.2150\n",
      "\n",
      "--------Epoch 13 (74 total)--------\n",
      "Val loss = 0.2260, Zero-temp loss = 0.2266\n",
      "\n",
      "--------Epoch 14 (75 total)--------\n",
      "Val loss = 0.2250, Zero-temp loss = 0.2256\n",
      "\n",
      "--------Epoch 15 (76 total)--------\n",
      "Val loss = 0.2303, Zero-temp loss = 0.2306\n",
      "\n",
      "--------Epoch 16 (77 total)--------\n",
      "Val loss = 0.2318, Zero-temp loss = 0.2323\n",
      "\n",
      "--------Epoch 17 (78 total)--------\n",
      "Val loss = 0.2366, Zero-temp loss = 0.2374\n",
      "\n",
      "--------Epoch 18 (79 total)--------\n",
      "Val loss = 0.2116, Zero-temp loss = 0.2122\n",
      "\n",
      "--------Epoch 19 (80 total)--------\n",
      "Val loss = 0.2246, Zero-temp loss = 0.2253\n",
      "\n",
      "--------Epoch 20 (81 total)--------\n",
      "Val loss = 0.2249, Zero-temp loss = 0.2258\n",
      "\n",
      "--------Epoch 21 (82 total)--------\n",
      "Val loss = 0.2167, Zero-temp loss = 0.2175\n",
      "\n",
      "--------Epoch 22 (83 total)--------\n",
      "Val loss = 0.2426, Zero-temp loss = 0.2431\n",
      "\n",
      "--------Epoch 23 (84 total)--------\n",
      "Val loss = 0.2218, Zero-temp loss = 0.2226\n",
      "\n",
      "--------Epoch 24 (85 total)--------\n",
      "Val loss = 0.2158, Zero-temp loss = 0.2167\n",
      "\n",
      "Epoch    24: reducing learning rate of group 0 to 2.0000e-04.\n",
      "--------Epoch 25 (86 total)--------\n",
      "Val loss = 0.2129, Zero-temp loss = 0.2136\n",
      "\n",
      "Stopping temp = 0.1000 at epoch 25\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Set up architecture\n",
    "hidden = 128\n",
    "dropout = 0.3\n",
    "\n",
    "# Predictor\n",
    "predictor = nn.Sequential(\n",
    "    nn.Linear(d_in * 2, hidden),\n",
    "    nn.ReLU(),\n",
    "    nn.Dropout(dropout),\n",
    "    nn.Linear(hidden, hidden),\n",
    "    nn.ReLU(),\n",
    "    nn.Dropout(dropout),\n",
    "    nn.Linear(hidden, d_out))\n",
    "\n",
    "# Selector\n",
    "selector = nn.Sequential(\n",
    "    nn.Linear(d_in * 2, hidden),\n",
    "    nn.ReLU(),\n",
    "    nn.Dropout(dropout),\n",
    "    nn.Linear(hidden, hidden),\n",
    "    nn.ReLU(),\n",
    "    nn.Dropout(dropout),\n",
    "    nn.Linear(hidden, d_in))\n",
    "\n",
    "# Tie weights\n",
    "# selector[0] = predictor[0]\n",
    "# selector[3] = predictor[3]\n",
    "\n",
    "# Pretrain predictor\n",
    "mask_layer = MaskLayer(append=True)\n",
    "pretrain = MaskingPretrainer(predictor, mask_layer).to(device)\n",
    "pretrain.fit(train_dataset,\n",
    "             val_dataset,\n",
    "             mbsize=32,\n",
    "             lr=1e-3,\n",
    "             nepochs=100,\n",
    "             loss_fn=nn.CrossEntropyLoss(),\n",
    "             # val_loss_fn=AUROC(num_classes=2),\n",
    "             # val_loss_mode='max',\n",
    "             patience=5,\n",
    "             verbose=True)\n",
    "\n",
    "# Train adaptive selection\n",
    "gafs = AdaptiveSelection(selector, predictor, mask_layer).to(device)\n",
    "gafs.fit(train_dataset,\n",
    "         val_dataset,\n",
    "         mbsize=32,\n",
    "         lr=1e-3,\n",
    "         nepochs=250,\n",
    "         max_features=max_features,\n",
    "         loss_fn=nn.CrossEntropyLoss(),\n",
    "         # val_loss_fn=AUROC(num_classes=2),\n",
    "         # val_loss_mode='max',\n",
    "         patience=5,\n",
    "         verbose=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "13661ab0-a357-463c-ba1e-b24dfe41b211",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "val\n",
      "Num = 1, AUROC = 80.57, Acc = 76.74\n",
      "Num = 2, AUROC = 87.13, Acc = 82.28\n",
      "Num = 3, AUROC = 91.25, Acc = 88.04\n",
      "Num = 4, AUROC = 93.94, Acc = 89.24\n",
      "Num = 5, AUROC = 95.02, Acc = 90.22\n",
      "Num = 6, AUROC = 95.87, Acc = 91.09\n",
      "Num = 7, AUROC = 96.29, Acc = 91.96\n",
      "Num = 8, AUROC = 96.84, Acc = 92.93\n",
      "Num = 9, AUROC = 97.09, Acc = 93.48\n",
      "Num = 10, AUROC = 97.19, Acc = 93.15\n",
      "Num = 15, AUROC = 97.80, Acc = 93.48\n",
      "Num = 20, AUROC = 97.82, Acc = 93.70\n",
      "Num = 25, AUROC = 97.72, Acc = 92.83\n",
      "test\n",
      "Num = 1, AUROC = 82.87, Acc = 78.26\n",
      "Num = 2, AUROC = 88.39, Acc = 82.28\n",
      "Num = 3, AUROC = 93.29, Acc = 88.80\n",
      "Num = 4, AUROC = 95.47, Acc = 90.43\n",
      "Num = 5, AUROC = 96.35, Acc = 91.41\n",
      "Num = 6, AUROC = 96.50, Acc = 92.50\n",
      "Num = 7, AUROC = 96.88, Acc = 91.85\n",
      "Num = 8, AUROC = 97.25, Acc = 92.61\n",
      "Num = 9, AUROC = 97.06, Acc = 93.15\n",
      "Num = 10, AUROC = 97.07, Acc = 93.48\n",
      "Num = 15, AUROC = 97.44, Acc = 93.91\n",
      "Num = 20, AUROC = 97.64, Acc = 93.80\n",
      "Num = 25, AUROC = 97.72, Acc = 93.15\n"
     ]
    }
   ],
   "source": [
    "# TODO: DELETE THIS BLOCK LATER\n",
    "num_features = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 20, 25]\n",
    "\n",
    "# Val\n",
    "print('val')\n",
    "x, y = get_xy(val_dataset)\n",
    "for num in num_features:\n",
    "    pred = gafs(torch.tensor(x, device=device), max_features=num).softmax(dim=1).cpu().data.numpy()\n",
    "    auroc = roc_auc_score(y, pred[:,1])\n",
    "    acc = accuracy_score(y, pred.argmax(axis=1))\n",
    "    print(f'Num = {num}, AUROC = {100*auroc:.2f}, Acc = {100*acc:.2f}')\n",
    "\n",
    "# Test\n",
    "print('test')\n",
    "x, y = get_xy(test_dataset)\n",
    "for num in num_features:\n",
    "    pred = gafs(torch.tensor(x, device=device), max_features=num).softmax(dim=1).cpu().data.numpy()\n",
    "    auroc = roc_auc_score(y, pred[:,1])\n",
    "    acc = accuracy_score(y, pred.argmax(axis=1))\n",
    "    print(f'Num = {num}, AUROC = {100*auroc:.2f}, Acc = {100*acc:.2f}')\n",
    "# TODO DELETE THIS BLOCK LATER"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "ff6fe3c8-37ec-436e-8055-ef6865f5255f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--------Epoch 1--------\n",
      "Val loss = 0.5074\n",
      "\n",
      "--------Epoch 2--------\n",
      "Val loss = 0.3388\n",
      "\n",
      "--------Epoch 3--------\n",
      "Val loss = 0.2939\n",
      "\n",
      "--------Epoch 4--------\n",
      "Val loss = 0.2745\n",
      "\n",
      "--------Epoch 5--------\n",
      "Val loss = 0.2804\n",
      "\n",
      "--------Epoch 6--------\n",
      "Val loss = 0.2578\n",
      "\n",
      "--------Epoch 7--------\n",
      "Val loss = 0.2715\n",
      "\n",
      "--------Epoch 8--------\n",
      "Val loss = 0.2455\n",
      "\n",
      "--------Epoch 9--------\n",
      "Val loss = 0.2455\n",
      "\n",
      "--------Epoch 10--------\n",
      "Val loss = 0.2504\n",
      "\n",
      "--------Epoch 11--------\n",
      "Val loss = 0.2358\n",
      "\n",
      "--------Epoch 12--------\n",
      "Val loss = 0.2294\n",
      "\n",
      "--------Epoch 13--------\n",
      "Val loss = 0.2354\n",
      "\n",
      "--------Epoch 14--------\n",
      "Val loss = 0.2286\n",
      "\n",
      "--------Epoch 15--------\n",
      "Val loss = 0.2374\n",
      "\n",
      "--------Epoch 16--------\n",
      "Val loss = 0.2304\n",
      "\n",
      "--------Epoch 17--------\n",
      "Val loss = 0.2254\n",
      "\n",
      "--------Epoch 18--------\n",
      "Val loss = 0.2376\n",
      "\n",
      "--------Epoch 19--------\n",
      "Val loss = 0.2224\n",
      "\n",
      "--------Epoch 20--------\n",
      "Val loss = 0.2255\n",
      "\n",
      "--------Epoch 21--------\n",
      "Val loss = 0.2391\n",
      "\n",
      "--------Epoch 22--------\n",
      "Val loss = 0.2372\n",
      "\n",
      "--------Epoch 23--------\n",
      "Val loss = 0.2233\n",
      "\n",
      "Epoch    23: reducing learning rate of group 0 to 2.0000e-04.\n",
      "--------Epoch 24--------\n",
      "Val loss = 0.2262\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Reset predictor and train with frozen selector.\n",
    "predictor = nn.Sequential(\n",
    "    nn.Linear(d_in * 2, hidden),\n",
    "    nn.ReLU(),\n",
    "    nn.Dropout(dropout),\n",
    "    nn.Linear(hidden, hidden),\n",
    "    nn.ReLU(),\n",
    "    nn.Dropout(dropout),\n",
    "    nn.Linear(hidden, d_out))\n",
    "gafs = AdaptiveSelection(selector, predictor, mask_layer).to(device)\n",
    "gafs.fit_predictor(train_dataset,\n",
    "                   val_dataset,\n",
    "                   mbsize=128,\n",
    "                   lr=1e-3,\n",
    "                   nepochs=250,\n",
    "                   max_features=max_features,\n",
    "                   loss_fn=nn.CrossEntropyLoss(),\n",
    "                   # val_loss_fn=AUROC(num_classes=2),\n",
    "                   # val_loss_mode='max',\n",
    "                   patience=3,\n",
    "                   verbose=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "d3f05c74-6830-4d6b-ba77-d1a681c01863",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "val\n",
      "Num = 1, AUROC = 80.76, Acc = 77.50\n",
      "Num = 2, AUROC = 87.05, Acc = 81.96\n",
      "Num = 3, AUROC = 90.86, Acc = 85.98\n",
      "Num = 4, AUROC = 93.62, Acc = 88.59\n",
      "Num = 5, AUROC = 94.83, Acc = 89.57\n",
      "Num = 6, AUROC = 95.49, Acc = 90.22\n",
      "Num = 7, AUROC = 95.90, Acc = 91.52\n",
      "Num = 8, AUROC = 96.56, Acc = 91.96\n",
      "Num = 9, AUROC = 96.97, Acc = 92.72\n",
      "Num = 10, AUROC = 97.03, Acc = 92.93\n",
      "Num = 15, AUROC = 97.32, Acc = 92.93\n",
      "Num = 20, AUROC = 97.52, Acc = 92.83\n",
      "Num = 25, AUROC = 97.20, Acc = 91.52\n",
      "test\n",
      "Num = 1, AUROC = 82.87, Acc = 78.48\n",
      "Num = 2, AUROC = 88.14, Acc = 81.52\n",
      "Num = 3, AUROC = 92.45, Acc = 86.30\n",
      "Num = 4, AUROC = 94.97, Acc = 88.80\n",
      "Num = 5, AUROC = 95.85, Acc = 90.87\n",
      "Num = 6, AUROC = 95.84, Acc = 90.11\n",
      "Num = 7, AUROC = 96.56, Acc = 91.09\n",
      "Num = 8, AUROC = 96.91, Acc = 90.98\n",
      "Num = 9, AUROC = 96.75, Acc = 91.85\n",
      "Num = 10, AUROC = 96.77, Acc = 92.83\n",
      "Num = 15, AUROC = 97.42, Acc = 93.26\n",
      "Num = 20, AUROC = 97.53, Acc = 93.15\n",
      "Num = 25, AUROC = 97.18, Acc = 91.74\n"
     ]
    }
   ],
   "source": [
    "# TODO: DELETE THIS BLOCK LATER\n",
    "num_features = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 20, 25]\n",
    "\n",
    "# Val\n",
    "print('val')\n",
    "x, y = get_xy(val_dataset)\n",
    "for num in num_features:\n",
    "    pred = gafs(torch.tensor(x, device=device), max_features=num).softmax(dim=1).cpu().data.numpy()\n",
    "    auroc = roc_auc_score(y, pred[:,1])\n",
    "    acc = accuracy_score(y, pred.argmax(axis=1))\n",
    "    print(f'Num = {num}, AUROC = {100*auroc:.2f}, Acc = {100*acc:.2f}')\n",
    "\n",
    "# Test\n",
    "print('test')\n",
    "x, y = get_xy(test_dataset)\n",
    "for num in num_features:\n",
    "    pred = gafs(torch.tensor(x, device=device), max_features=num).softmax(dim=1).cpu().data.numpy()\n",
    "    auroc = roc_auc_score(y, pred[:,1])\n",
    "    acc = accuracy_score(y, pred.argmax(axis=1))\n",
    "    print(f'Num = {num}, AUROC = {100*auroc:.2f}, Acc = {100*acc:.2f}')\n",
    "# TODO DELETE THIS BLOCK LATER"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c7a98ee3-617c-410f-9478-1482dfa89ee8",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
