{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "f78af6b7-569c-4f35-ae99-511121e304e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import pickle\n",
    "import argparse\n",
    "import numpy as np\n",
    "import torch.nn as nn\n",
    "from torchmetrics import AUROC\n",
    "from sklearn.metrics import accuracy_score, roc_auc_score\n",
    "from adaptive import AdaptiveSelection, MaskLayer, MaskingPretrainer\n",
    "\n",
    "import sys\n",
    "sys.path.append('../')\n",
    "from data import DenseDatasetSelected, data_split, get_xy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "fcfcf277-648a-45d6-b057-f8d2ece3820b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train samples = 2761, val samples = 920, test samples = 920\n"
     ]
    }
   ],
   "source": [
    "# Load dataset\n",
    "dataset = DenseDatasetSelected('../../datasets/spam.csv')\n",
    "d_in = dataset.X.shape[1]  # 57\n",
    "d_out = len(np.unique(dataset.Y))  # 2\n",
    "\n",
    "# Split dataset\n",
    "train_dataset, val_dataset, test_dataset = data_split(dataset, random_state=0)\n",
    "print(f'Train samples = {len(train_dataset)}, val samples = {len(val_dataset)}, test samples = {len(test_dataset)}')\n",
    "\n",
    "# Find mean/variance for normalizing\n",
    "x, y = get_xy(train_dataset)\n",
    "mean = np.mean(x, axis=0)\n",
    "std = np.std(y, axis=0)\n",
    "\n",
    "# Normalize via the original dataset\n",
    "dataset.X = dataset.X - mean\n",
    "\n",
    "# Setup\n",
    "max_features = 25\n",
    "device = torch.device('cuda', 7)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "22a9623a-9d59-449f-beac-2feaa0a79dcf",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--------Epoch 1--------\n",
      "Val loss = 0.6375\n",
      "\n",
      "--------Epoch 2--------\n",
      "Val loss = 0.7703\n",
      "\n",
      "--------Epoch 3--------\n",
      "Val loss = 0.6535\n",
      "\n",
      "--------Epoch 4--------\n",
      "Val loss = 0.5549\n",
      "\n",
      "--------Epoch 5--------\n",
      "Val loss = 0.5090\n",
      "\n",
      "--------Epoch 6--------\n",
      "Val loss = 0.4676\n",
      "\n",
      "--------Epoch 7--------\n",
      "Val loss = 0.4583\n",
      "\n",
      "--------Epoch 8--------\n",
      "Val loss = 0.4233\n",
      "\n",
      "--------Epoch 9--------\n",
      "Val loss = 0.4225\n",
      "\n",
      "--------Epoch 10--------\n",
      "Val loss = 0.4156\n",
      "\n",
      "--------Epoch 11--------\n",
      "Val loss = 0.4039\n",
      "\n",
      "--------Epoch 12--------\n",
      "Val loss = 0.4179\n",
      "\n",
      "--------Epoch 13--------\n",
      "Val loss = 0.3943\n",
      "\n",
      "--------Epoch 14--------\n",
      "Val loss = 0.3956\n",
      "\n",
      "--------Epoch 15--------\n",
      "Val loss = 0.3873\n",
      "\n",
      "--------Epoch 16--------\n",
      "Val loss = 0.3784\n",
      "\n",
      "--------Epoch 17--------\n",
      "Val loss = 0.4122\n",
      "\n",
      "--------Epoch 18--------\n",
      "Val loss = 0.3984\n",
      "\n",
      "--------Epoch 19--------\n",
      "Val loss = 0.3725\n",
      "\n",
      "--------Epoch 20--------\n",
      "Val loss = 0.3844\n",
      "\n",
      "--------Epoch 21--------\n",
      "Val loss = 0.3549\n",
      "\n",
      "--------Epoch 22--------\n",
      "Val loss = 0.3975\n",
      "\n",
      "--------Epoch 23--------\n",
      "Val loss = 0.3656\n",
      "\n",
      "--------Epoch 24--------\n",
      "Val loss = 0.3954\n",
      "\n",
      "--------Epoch 25--------\n",
      "Val loss = 0.3481\n",
      "\n",
      "--------Epoch 26--------\n",
      "Val loss = 0.3590\n",
      "\n",
      "--------Epoch 27--------\n",
      "Val loss = 0.3560\n",
      "\n",
      "--------Epoch 28--------\n",
      "Val loss = 0.3734\n",
      "\n",
      "--------Epoch 29--------\n",
      "Val loss = 0.3603\n",
      "\n",
      "--------Epoch 30--------\n",
      "Val loss = 0.3501\n",
      "\n",
      "--------Epoch 31--------\n",
      "Val loss = 0.3574\n",
      "\n",
      "Epoch    31: reducing learning rate of group 0 to 2.0000e-04.\n",
      "--------Epoch 32--------\n",
      "Val loss = 0.3731\n",
      "\n",
      "Stopping early at epoch 32\n",
      "Starting training with temp = 1.0000\n",
      "\n",
      "--------Epoch 1 (1 total)--------\n",
      "Val loss = 0.3499, Zero-temp loss = 0.3695\n",
      "\n",
      "--------Epoch 2 (2 total)--------\n",
      "Val loss = 0.3231, Zero-temp loss = 0.3393\n",
      "\n",
      "--------Epoch 3 (3 total)--------\n",
      "Val loss = 0.3113, Zero-temp loss = 0.3308\n",
      "\n",
      "--------Epoch 4 (4 total)--------\n",
      "Val loss = 0.2966, Zero-temp loss = 0.3136\n",
      "\n",
      "--------Epoch 5 (5 total)--------\n",
      "Val loss = 0.2633, Zero-temp loss = 0.2820\n",
      "\n",
      "--------Epoch 6 (6 total)--------\n",
      "Val loss = 0.3022, Zero-temp loss = 0.3248\n",
      "\n",
      "--------Epoch 7 (7 total)--------\n",
      "Val loss = 0.2731, Zero-temp loss = 0.2943\n",
      "\n",
      "--------Epoch 8 (8 total)--------\n",
      "Val loss = 0.2694, Zero-temp loss = 0.2901\n",
      "\n",
      "--------Epoch 9 (9 total)--------\n",
      "Val loss = 0.2525, Zero-temp loss = 0.2735\n",
      "\n",
      "--------Epoch 10 (10 total)--------\n",
      "Val loss = 0.2497, Zero-temp loss = 0.2706\n",
      "\n",
      "--------Epoch 11 (11 total)--------\n",
      "Val loss = 0.2524, Zero-temp loss = 0.2740\n",
      "\n",
      "--------Epoch 12 (12 total)--------\n",
      "Val loss = 0.2572, Zero-temp loss = 0.2831\n",
      "\n",
      "--------Epoch 13 (13 total)--------\n",
      "Val loss = 0.2436, Zero-temp loss = 0.2673\n",
      "\n",
      "--------Epoch 14 (14 total)--------\n",
      "Val loss = 0.2695, Zero-temp loss = 0.2967\n",
      "\n",
      "--------Epoch 15 (15 total)--------\n",
      "Val loss = 0.2592, Zero-temp loss = 0.2878\n",
      "\n",
      "--------Epoch 16 (16 total)--------\n",
      "Val loss = 0.2503, Zero-temp loss = 0.2768\n",
      "\n",
      "--------Epoch 17 (17 total)--------\n",
      "Val loss = 0.2408, Zero-temp loss = 0.2678\n",
      "\n",
      "--------Epoch 18 (18 total)--------\n",
      "Val loss = 0.2555, Zero-temp loss = 0.2859\n",
      "\n",
      "--------Epoch 19 (19 total)--------\n",
      "Val loss = 0.2495, Zero-temp loss = 0.2772\n",
      "\n",
      "--------Epoch 20 (20 total)--------\n",
      "Val loss = 0.2365, Zero-temp loss = 0.2643\n",
      "\n",
      "--------Epoch 21 (21 total)--------\n",
      "Val loss = 0.2406, Zero-temp loss = 0.2720\n",
      "\n",
      "--------Epoch 22 (22 total)--------\n",
      "Val loss = 0.2375, Zero-temp loss = 0.2661\n",
      "\n",
      "--------Epoch 23 (23 total)--------\n",
      "Val loss = 0.2352, Zero-temp loss = 0.2676\n",
      "\n",
      "--------Epoch 24 (24 total)--------\n",
      "Val loss = 0.2422, Zero-temp loss = 0.2741\n",
      "\n",
      "--------Epoch 25 (25 total)--------\n",
      "Val loss = 0.2341, Zero-temp loss = 0.2656\n",
      "\n",
      "--------Epoch 26 (26 total)--------\n",
      "Val loss = 0.2203, Zero-temp loss = 0.2487\n",
      "\n",
      "--------Epoch 27 (27 total)--------\n",
      "Val loss = 0.2341, Zero-temp loss = 0.2625\n",
      "\n",
      "--------Epoch 28 (28 total)--------\n",
      "Val loss = 0.2526, Zero-temp loss = 0.2850\n",
      "\n",
      "--------Epoch 29 (29 total)--------\n",
      "Val loss = 0.2357, Zero-temp loss = 0.2665\n",
      "\n",
      "--------Epoch 30 (30 total)--------\n",
      "Val loss = 0.2264, Zero-temp loss = 0.2577\n",
      "\n",
      "--------Epoch 31 (31 total)--------\n",
      "Val loss = 0.2229, Zero-temp loss = 0.2553\n",
      "\n",
      "--------Epoch 32 (32 total)--------\n",
      "Val loss = 0.2300, Zero-temp loss = 0.2597\n",
      "\n",
      "Epoch    32: reducing learning rate of group 0 to 2.0000e-04.\n",
      "--------Epoch 33 (33 total)--------\n",
      "Val loss = 0.2242, Zero-temp loss = 0.2589\n",
      "\n",
      "Stopping temp = 1.0000 at epoch 33\n",
      "\n",
      "Starting training with temp = 0.5623\n",
      "\n",
      "--------Epoch 1 (34 total)--------\n",
      "Val loss = 0.2359, Zero-temp loss = 0.2535\n",
      "\n",
      "--------Epoch 2 (35 total)--------\n",
      "Val loss = 0.2417, Zero-temp loss = 0.2586\n",
      "\n",
      "--------Epoch 3 (36 total)--------\n",
      "Val loss = 0.2184, Zero-temp loss = 0.2321\n",
      "\n",
      "--------Epoch 4 (37 total)--------\n",
      "Val loss = 0.2419, Zero-temp loss = 0.2590\n",
      "\n",
      "--------Epoch 5 (38 total)--------\n",
      "Val loss = 0.2445, Zero-temp loss = 0.2629\n",
      "\n",
      "--------Epoch 6 (39 total)--------\n",
      "Val loss = 0.2294, Zero-temp loss = 0.2453\n",
      "\n",
      "--------Epoch 7 (40 total)--------\n",
      "Val loss = 0.2269, Zero-temp loss = 0.2433\n",
      "\n",
      "--------Epoch 8 (41 total)--------\n",
      "Val loss = 0.2228, Zero-temp loss = 0.2386\n",
      "\n",
      "--------Epoch 9 (42 total)--------\n",
      "Val loss = 0.2192, Zero-temp loss = 0.2376\n",
      "\n",
      "Epoch     9: reducing learning rate of group 0 to 2.0000e-04.\n",
      "--------Epoch 10 (43 total)--------\n",
      "Val loss = 0.2191, Zero-temp loss = 0.2375\n",
      "\n",
      "Stopping temp = 0.5623 at epoch 10\n",
      "\n",
      "Starting training with temp = 0.3162\n",
      "\n",
      "--------Epoch 1 (44 total)--------\n",
      "Val loss = 0.2436, Zero-temp loss = 0.2529\n",
      "\n",
      "--------Epoch 2 (45 total)--------\n",
      "Val loss = 0.2405, Zero-temp loss = 0.2499\n",
      "\n",
      "--------Epoch 3 (46 total)--------\n",
      "Val loss = 0.2396, Zero-temp loss = 0.2477\n",
      "\n",
      "--------Epoch 4 (47 total)--------\n",
      "Val loss = 0.2266, Zero-temp loss = 0.2319\n",
      "\n",
      "--------Epoch 5 (48 total)--------\n",
      "Val loss = 0.2316, Zero-temp loss = 0.2380\n",
      "\n",
      "--------Epoch 6 (49 total)--------\n",
      "Val loss = 0.2283, Zero-temp loss = 0.2360\n",
      "\n",
      "--------Epoch 7 (50 total)--------\n",
      "Val loss = 0.2259, Zero-temp loss = 0.2327\n",
      "\n",
      "--------Epoch 8 (51 total)--------\n",
      "Val loss = 0.2210, Zero-temp loss = 0.2277\n",
      "\n",
      "--------Epoch 9 (52 total)--------\n",
      "Val loss = 0.2235, Zero-temp loss = 0.2312\n",
      "\n",
      "--------Epoch 10 (53 total)--------\n",
      "Val loss = 0.2260, Zero-temp loss = 0.2327\n",
      "\n",
      "--------Epoch 11 (54 total)--------\n",
      "Val loss = 0.2223, Zero-temp loss = 0.2304\n",
      "\n",
      "--------Epoch 12 (55 total)--------\n",
      "Val loss = 0.2324, Zero-temp loss = 0.2388\n",
      "\n",
      "--------Epoch 13 (56 total)--------\n",
      "Val loss = 0.2228, Zero-temp loss = 0.2288\n",
      "\n",
      "--------Epoch 14 (57 total)--------\n",
      "Val loss = 0.2251, Zero-temp loss = 0.2318\n",
      "\n",
      "Epoch    14: reducing learning rate of group 0 to 2.0000e-04.\n",
      "--------Epoch 15 (58 total)--------\n",
      "Val loss = 0.2253, Zero-temp loss = 0.2325\n",
      "\n",
      "Stopping temp = 0.3162 at epoch 15\n",
      "\n",
      "Starting training with temp = 0.1778\n",
      "\n",
      "--------Epoch 1 (59 total)--------\n",
      "Val loss = 0.2266, Zero-temp loss = 0.2306\n",
      "\n",
      "--------Epoch 2 (60 total)--------\n",
      "Val loss = 0.2377, Zero-temp loss = 0.2419\n",
      "\n",
      "--------Epoch 3 (61 total)--------\n",
      "Val loss = 0.2443, Zero-temp loss = 0.2478\n",
      "\n",
      "--------Epoch 4 (62 total)--------\n",
      "Val loss = 0.2442, Zero-temp loss = 0.2467\n",
      "\n",
      "--------Epoch 5 (63 total)--------\n",
      "Val loss = 0.2348, Zero-temp loss = 0.2367\n",
      "\n",
      "--------Epoch 6 (64 total)--------\n",
      "Val loss = 0.2363, Zero-temp loss = 0.2387\n",
      "\n",
      "--------Epoch 7 (65 total)--------\n",
      "Val loss = 0.2320, Zero-temp loss = 0.2346\n",
      "\n",
      "Epoch     7: reducing learning rate of group 0 to 2.0000e-04.\n",
      "--------Epoch 8 (66 total)--------\n",
      "Val loss = 0.2232, Zero-temp loss = 0.2261\n",
      "\n",
      "--------Epoch 9 (67 total)--------\n",
      "Val loss = 0.2252, Zero-temp loss = 0.2276\n",
      "\n",
      "--------Epoch 10 (68 total)--------\n",
      "Val loss = 0.2267, Zero-temp loss = 0.2296\n",
      "\n",
      "--------Epoch 11 (69 total)--------\n",
      "Val loss = 0.2230, Zero-temp loss = 0.2256\n",
      "\n",
      "--------Epoch 12 (70 total)--------\n",
      "Val loss = 0.2214, Zero-temp loss = 0.2241\n",
      "\n",
      "--------Epoch 13 (71 total)--------\n",
      "Val loss = 0.2146, Zero-temp loss = 0.2175\n",
      "\n",
      "--------Epoch 14 (72 total)--------\n",
      "Val loss = 0.2227, Zero-temp loss = 0.2257\n",
      "\n",
      "--------Epoch 15 (73 total)--------\n",
      "Val loss = 0.2195, Zero-temp loss = 0.2224\n",
      "\n",
      "--------Epoch 16 (74 total)--------\n",
      "Val loss = 0.2198, Zero-temp loss = 0.2223\n",
      "\n",
      "--------Epoch 17 (75 total)--------\n",
      "Val loss = 0.2276, Zero-temp loss = 0.2296\n",
      "\n",
      "--------Epoch 18 (76 total)--------\n",
      "Val loss = 0.2249, Zero-temp loss = 0.2276\n",
      "\n",
      "--------Epoch 19 (77 total)--------\n",
      "Val loss = 0.2219, Zero-temp loss = 0.2241\n",
      "\n",
      "Epoch    19: reducing learning rate of group 0 to 4.0000e-05.\n",
      "--------Epoch 20 (78 total)--------\n",
      "Val loss = 0.2201, Zero-temp loss = 0.2215\n",
      "\n",
      "Stopping temp = 0.1778 at epoch 20\n",
      "\n",
      "Starting training with temp = 0.1000\n",
      "\n",
      "--------Epoch 1 (79 total)--------\n",
      "Val loss = 0.2310, Zero-temp loss = 0.2328\n",
      "\n",
      "--------Epoch 2 (80 total)--------\n",
      "Val loss = 0.2215, Zero-temp loss = 0.2228\n",
      "\n",
      "--------Epoch 3 (81 total)--------\n",
      "Val loss = 0.2377, Zero-temp loss = 0.2388\n",
      "\n",
      "--------Epoch 4 (82 total)--------\n",
      "Val loss = 0.2317, Zero-temp loss = 0.2339\n",
      "\n",
      "--------Epoch 5 (83 total)--------\n",
      "Val loss = 0.2295, Zero-temp loss = 0.2306\n",
      "\n",
      "--------Epoch 6 (84 total)--------\n",
      "Val loss = 0.2500, Zero-temp loss = 0.2518\n",
      "\n",
      "--------Epoch 7 (85 total)--------\n",
      "Val loss = 0.2231, Zero-temp loss = 0.2242\n",
      "\n",
      "--------Epoch 8 (86 total)--------\n",
      "Val loss = 0.2562, Zero-temp loss = 0.2578\n",
      "\n",
      "Epoch     8: reducing learning rate of group 0 to 2.0000e-04.\n",
      "--------Epoch 9 (87 total)--------\n",
      "Val loss = 0.2311, Zero-temp loss = 0.2322\n",
      "\n",
      "Stopping temp = 0.1000 at epoch 9\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Set up architecture\n",
    "hidden = 128\n",
    "dropout = 0.3\n",
    "\n",
    "# Predictor\n",
    "predictor = nn.Sequential(\n",
    "    nn.Linear(d_in * 2, hidden),\n",
    "    nn.ReLU(),\n",
    "    nn.Dropout(dropout),\n",
    "    nn.Linear(hidden, hidden),\n",
    "    nn.ReLU(),\n",
    "    nn.Dropout(dropout),\n",
    "    nn.Linear(hidden, d_out))\n",
    "\n",
    "# Selector\n",
    "selector = nn.Sequential(\n",
    "    nn.Linear(d_in * 2, hidden),\n",
    "    nn.ReLU(),\n",
    "    nn.Dropout(dropout),\n",
    "    nn.Linear(hidden, hidden),\n",
    "    nn.ReLU(),\n",
    "    nn.Dropout(dropout),\n",
    "    nn.Linear(hidden, d_in))\n",
    "\n",
    "# Tie weights\n",
    "# selector[0] = predictor[0]\n",
    "# selector[3] = predictor[3]\n",
    "\n",
    "# Pretrain predictor\n",
    "mask_layer = MaskLayer(append=True)\n",
    "pretrain = MaskingPretrainer(predictor, mask_layer).to(device)\n",
    "pretrain.fit(train_dataset,\n",
    "             val_dataset,\n",
    "             mbsize=32,\n",
    "             lr=1e-3,\n",
    "             nepochs=100,\n",
    "             loss_fn=nn.CrossEntropyLoss(),\n",
    "             # val_loss_fn=AUROC(num_classes=2),\n",
    "             # val_loss_mode='max',\n",
    "             patience=5,\n",
    "             verbose=True)\n",
    "\n",
    "# Train adaptive selection\n",
    "gafs = AdaptiveSelection(selector, predictor, mask_layer).to(device)\n",
    "gafs.fit(train_dataset,\n",
    "         val_dataset,\n",
    "         mbsize=32,\n",
    "         lr=1e-3,\n",
    "         nepochs=250,\n",
    "         max_features=max_features,\n",
    "         loss_fn=nn.CrossEntropyLoss(),\n",
    "         # val_loss_fn=AUROC(num_classes=2),\n",
    "         # val_loss_mode='max',\n",
    "         patience=5,\n",
    "         verbose=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "13661ab0-a357-463c-ba1e-b24dfe41b211",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "val\n",
      "Num = 1, AUROC = 80.53, Acc = 76.41\n",
      "Num = 2, AUROC = 86.86, Acc = 83.37\n",
      "Num = 3, AUROC = 90.83, Acc = 87.39\n",
      "Num = 4, AUROC = 93.62, Acc = 90.11\n",
      "Num = 5, AUROC = 94.63, Acc = 90.00\n",
      "Num = 6, AUROC = 95.33, Acc = 90.11\n",
      "Num = 7, AUROC = 96.46, Acc = 91.74\n",
      "Num = 8, AUROC = 96.76, Acc = 92.50\n",
      "Num = 9, AUROC = 97.04, Acc = 93.37\n",
      "Num = 10, AUROC = 97.44, Acc = 93.37\n",
      "Num = 15, AUROC = 97.79, Acc = 93.26\n",
      "Num = 20, AUROC = 97.55, Acc = 92.61\n",
      "Num = 25, AUROC = 96.74, Acc = 89.78\n",
      "test\n",
      "Num = 1, AUROC = 82.87, Acc = 78.04\n",
      "Num = 2, AUROC = 91.43, Acc = 85.00\n",
      "Num = 3, AUROC = 93.81, Acc = 90.22\n",
      "Num = 4, AUROC = 95.39, Acc = 91.09\n",
      "Num = 5, AUROC = 96.31, Acc = 91.63\n",
      "Num = 6, AUROC = 96.95, Acc = 92.39\n",
      "Num = 7, AUROC = 96.94, Acc = 92.28\n",
      "Num = 8, AUROC = 97.05, Acc = 92.39\n",
      "Num = 9, AUROC = 97.31, Acc = 92.39\n",
      "Num = 10, AUROC = 97.37, Acc = 92.72\n",
      "Num = 15, AUROC = 97.45, Acc = 93.48\n",
      "Num = 20, AUROC = 97.76, Acc = 93.15\n",
      "Num = 25, AUROC = 97.29, Acc = 91.52\n"
     ]
    }
   ],
   "source": [
    "# TODO: DELETE THIS BLOCK LATER\n",
    "num_features = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 20, 25]\n",
    "\n",
    "# Val\n",
    "print('val')\n",
    "x, y = get_xy(val_dataset)\n",
    "for num in num_features:\n",
    "    pred = gafs(torch.tensor(x, device=device), max_features=num).softmax(dim=1).cpu().data.numpy()\n",
    "    auroc = roc_auc_score(y, pred[:,1])\n",
    "    acc = accuracy_score(y, pred.argmax(axis=1))\n",
    "    print(f'Num = {num}, AUROC = {100*auroc:.2f}, Acc = {100*acc:.2f}')\n",
    "\n",
    "# Test\n",
    "print('test')\n",
    "x, y = get_xy(test_dataset)\n",
    "for num in num_features:\n",
    "    pred = gafs(torch.tensor(x, device=device), max_features=num).softmax(dim=1).cpu().data.numpy()\n",
    "    auroc = roc_auc_score(y, pred[:,1])\n",
    "    acc = accuracy_score(y, pred.argmax(axis=1))\n",
    "    print(f'Num = {num}, AUROC = {100*auroc:.2f}, Acc = {100*acc:.2f}')\n",
    "# TODO DELETE THIS BLOCK LATER"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "ff6fe3c8-37ec-436e-8055-ef6865f5255f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--------Epoch 1--------\n",
      "Val loss = 0.5005\n",
      "\n",
      "--------Epoch 2--------\n",
      "Val loss = 0.3419\n",
      "\n",
      "--------Epoch 3--------\n",
      "Val loss = 0.3059\n",
      "\n",
      "--------Epoch 4--------\n",
      "Val loss = 0.2920\n",
      "\n",
      "--------Epoch 5--------\n",
      "Val loss = 0.2709\n",
      "\n",
      "--------Epoch 6--------\n",
      "Val loss = 0.2693\n",
      "\n",
      "--------Epoch 7--------\n",
      "Val loss = 0.2497\n",
      "\n",
      "--------Epoch 8--------\n",
      "Val loss = 0.2647\n",
      "\n",
      "--------Epoch 9--------\n",
      "Val loss = 0.2550\n",
      "\n",
      "--------Epoch 10--------\n",
      "Val loss = 0.2419\n",
      "\n",
      "--------Epoch 11--------\n",
      "Val loss = 0.2446\n",
      "\n",
      "--------Epoch 12--------\n",
      "Val loss = 0.2428\n",
      "\n",
      "--------Epoch 13--------\n",
      "Val loss = 0.2388\n",
      "\n",
      "--------Epoch 14--------\n",
      "Val loss = 0.2344\n",
      "\n",
      "--------Epoch 15--------\n",
      "Val loss = 0.2324\n",
      "\n",
      "--------Epoch 16--------\n",
      "Val loss = 0.2443\n",
      "\n",
      "--------Epoch 17--------\n",
      "Val loss = 0.2377\n",
      "\n",
      "--------Epoch 18--------\n",
      "Val loss = 0.2365\n",
      "\n",
      "--------Epoch 19--------\n",
      "Val loss = 0.2414\n",
      "\n",
      "Epoch    19: reducing learning rate of group 0 to 2.0000e-04.\n",
      "--------Epoch 20--------\n",
      "Val loss = 0.2294\n",
      "\n",
      "--------Epoch 21--------\n",
      "Val loss = 0.2362\n",
      "\n",
      "--------Epoch 22--------\n",
      "Val loss = 0.2289\n",
      "\n",
      "--------Epoch 23--------\n",
      "Val loss = 0.2283\n",
      "\n",
      "--------Epoch 24--------\n",
      "Val loss = 0.2295\n",
      "\n",
      "--------Epoch 25--------\n",
      "Val loss = 0.2294\n",
      "\n",
      "--------Epoch 26--------\n",
      "Val loss = 0.2309\n",
      "\n",
      "--------Epoch 27--------\n",
      "Val loss = 0.2305\n",
      "\n",
      "Epoch    27: reducing learning rate of group 0 to 4.0000e-05.\n",
      "--------Epoch 28--------\n",
      "Val loss = 0.2273\n",
      "\n",
      "--------Epoch 29--------\n",
      "Val loss = 0.2306\n",
      "\n",
      "--------Epoch 30--------\n",
      "Val loss = 0.2302\n",
      "\n",
      "--------Epoch 31--------\n",
      "Val loss = 0.2297\n",
      "\n",
      "--------Epoch 32--------\n",
      "Val loss = 0.2271\n",
      "\n",
      "--------Epoch 33--------\n",
      "Val loss = 0.2330\n",
      "\n",
      "--------Epoch 34--------\n",
      "Val loss = 0.2284\n",
      "\n",
      "--------Epoch 35--------\n",
      "Val loss = 0.2385\n",
      "\n",
      "--------Epoch 36--------\n",
      "Val loss = 0.2332\n",
      "\n",
      "Epoch    36: reducing learning rate of group 0 to 1.0000e-05.\n",
      "--------Epoch 37--------\n",
      "Val loss = 0.2338\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Reset predictor and train with frozen selector.\n",
    "predictor = nn.Sequential(\n",
    "    nn.Linear(d_in * 2, hidden),\n",
    "    nn.ReLU(),\n",
    "    nn.Dropout(dropout),\n",
    "    nn.Linear(hidden, hidden),\n",
    "    nn.ReLU(),\n",
    "    nn.Dropout(dropout),\n",
    "    nn.Linear(hidden, d_out))\n",
    "gafs = AdaptiveSelection(selector, predictor, mask_layer).to(device)\n",
    "gafs.fit_predictor(train_dataset,\n",
    "                   val_dataset,\n",
    "                   mbsize=128,\n",
    "                   lr=1e-3,\n",
    "                   nepochs=250,\n",
    "                   max_features=max_features,\n",
    "                   loss_fn=nn.CrossEntropyLoss(),\n",
    "                   # val_loss_fn=AUROC(num_classes=2),\n",
    "                   # val_loss_mode='max',\n",
    "                   patience=3,\n",
    "                   verbose=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "d3f05c74-6830-4d6b-ba77-d1a681c01863",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "val\n",
      "Num = 1, AUROC = 80.77, Acc = 77.61\n",
      "Num = 2, AUROC = 86.69, Acc = 83.04\n",
      "Num = 3, AUROC = 90.46, Acc = 86.52\n",
      "Num = 4, AUROC = 93.69, Acc = 88.80\n",
      "Num = 5, AUROC = 94.68, Acc = 89.02\n",
      "Num = 6, AUROC = 95.15, Acc = 89.78\n",
      "Num = 7, AUROC = 96.25, Acc = 91.09\n",
      "Num = 8, AUROC = 96.71, Acc = 91.63\n",
      "Num = 9, AUROC = 96.98, Acc = 92.28\n",
      "Num = 10, AUROC = 97.28, Acc = 92.50\n",
      "Num = 15, AUROC = 97.63, Acc = 92.61\n",
      "Num = 20, AUROC = 97.32, Acc = 92.83\n",
      "Num = 25, AUROC = 94.05, Acc = 87.28\n",
      "test\n",
      "Num = 1, AUROC = 82.85, Acc = 78.26\n",
      "Num = 2, AUROC = 90.52, Acc = 83.59\n",
      "Num = 3, AUROC = 92.86, Acc = 88.48\n",
      "Num = 4, AUROC = 95.10, Acc = 89.78\n",
      "Num = 5, AUROC = 96.19, Acc = 90.33\n",
      "Num = 6, AUROC = 96.65, Acc = 90.87\n",
      "Num = 7, AUROC = 96.66, Acc = 90.98\n",
      "Num = 8, AUROC = 96.76, Acc = 91.30\n",
      "Num = 9, AUROC = 97.07, Acc = 91.85\n",
      "Num = 10, AUROC = 97.18, Acc = 92.39\n",
      "Num = 15, AUROC = 97.53, Acc = 93.04\n",
      "Num = 20, AUROC = 97.49, Acc = 93.15\n",
      "Num = 25, AUROC = 95.16, Acc = 90.54\n"
     ]
    }
   ],
   "source": [
    "# TODO: DELETE THIS BLOCK LATER\n",
    "num_features = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 20, 25]\n",
    "\n",
    "# Val\n",
    "print('val')\n",
    "x, y = get_xy(val_dataset)\n",
    "for num in num_features:\n",
    "    pred = gafs(torch.tensor(x, device=device), max_features=num).softmax(dim=1).cpu().data.numpy()\n",
    "    auroc = roc_auc_score(y, pred[:,1])\n",
    "    acc = accuracy_score(y, pred.argmax(axis=1))\n",
    "    print(f'Num = {num}, AUROC = {100*auroc:.2f}, Acc = {100*acc:.2f}')\n",
    "\n",
    "# Test\n",
    "print('test')\n",
    "x, y = get_xy(test_dataset)\n",
    "for num in num_features:\n",
    "    pred = gafs(torch.tensor(x, device=device), max_features=num).softmax(dim=1).cpu().data.numpy()\n",
    "    auroc = roc_auc_score(y, pred[:,1])\n",
    "    acc = accuracy_score(y, pred.argmax(axis=1))\n",
    "    print(f'Num = {num}, AUROC = {100*auroc:.2f}, Acc = {100*acc:.2f}')\n",
    "# TODO DELETE THIS BLOCK LATER"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c7a98ee3-617c-410f-9478-1482dfa89ee8",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
