{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "059e963e-9acd-491a-a415-224464dbdef9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "_StoreAction(option_strings=['--gpu'], dest='gpu', nargs=None, const=None, default=0, type=<class 'int'>, choices=None, help=None, metavar=None)"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "import pickle\n",
    "import argparse\n",
    "import numpy as np\n",
    "import torch.nn as nn\n",
    "from torchmetrics import AUROC\n",
    "from sklearn.metrics import accuracy_score, roc_auc_score\n",
    "from adaptive import AdaptiveSelection, MaskLayer, MaskingPretrainer\n",
    "\n",
    "import sys\n",
    "sys.path.append('../')\n",
    "from data import DenseDatasetSelected, data_split, get_xy\n",
    "\n",
    "# # Set up command line arguments\n",
    "# parser = argparse.ArgumentParser()\n",
    "# parser.add_argument('--gpu', type=int, default=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "4a522fcb-e9b2-49d4-8bc6-21389995f140",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "usage: ipykernel_launcher.py [-h] [--gpu GPU]\n",
      "ipykernel_launcher.py: error: unrecognized arguments: -f /homes/gws/icovert/.local/share/jupyter/runtime/kernel-37b93811-9ab6-4216-8991-f7eae7510bff.json\n"
     ]
    },
    {
     "ename": "SystemExit",
     "evalue": "2",
     "output_type": "error",
     "traceback": [
      "An exception has occurred, use %tb to see the full traceback.\n",
      "\u001b[0;31mSystemExit\u001b[0m\u001b[0;31m:\u001b[0m 2\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/homes/gws/icovert/anaconda3/envs/numbaenv/lib/python3.6/site-packages/IPython/core/interactiveshell.py:3351: UserWarning: To exit: use 'exit', 'quit', or Ctrl-D.\n",
      "  warn(\"To exit: use 'exit', 'quit', or Ctrl-D.\", stacklevel=1)\n"
     ]
    }
   ],
   "source": [
    "# args = parser.parse_args()\n",
    "# args.gpu = 7\n",
    "\n",
    "# Load dataset\n",
    "dataset = DenseDatasetSelected('../../datasets/miniboone.csv')\n",
    "d_in = dataset.X.shape[1]  # 57\n",
    "d_out = len(np.unique(dataset.Y))  # 2\n",
    "\n",
    "# Split dataset\n",
    "train_dataset, val_dataset, test_dataset = data_split(dataset, random_state=0)\n",
    "print(f'Train samples = {len(train_dataset)}, val samples = {len(val_dataset)}, test samples = {len(test_dataset)}')\n",
    "\n",
    "# Per-feature normalization statistics, computed from the training inputs only\n",
    "# so that no validation/test information enters the preprocessing.\n",
    "x, y = get_xy(train_dataset)\n",
    "mean = np.mean(x, axis=0)\n",
    "# The spread must come from the features x (not the labels y); clip it so a\n",
    "# constant feature cannot trigger a division by zero below.\n",
    "std = np.clip(np.std(x, axis=0), 1e-8, None)\n",
    "\n",
    "# Normalize via the original dataset: center and scale every feature.\n",
    "# NOTE(review): this relies on the three splits reading from dataset.X rather\n",
    "# than holding their own copies - confirm against data_split.\n",
    "dataset.X = (dataset.X - mean) / std\n",
    "\n",
    "# Setup\n",
    "max_features = 25\n",
    "\n",
    "# Preferred GPU index. Fall back to the first GPU, and then to the CPU, when\n",
    "# that device does not exist so the notebook also runs on smaller machines.\n",
    "gpu_index = 7\n",
    "if torch.cuda.is_available():\n",
    "    device = torch.device('cuda', gpu_index if gpu_index < torch.cuda.device_count() else 0)\n",
    "else:\n",
    "    device = torch.device('cpu')\n",
    "\n",
    "# Set up architecture\n",
    "hidden = 128\n",
    "dropout = 0.3\n",
    "\n",
    "\n",
    "def make_mlp(n_outputs):\n",
    "    '''Build the MLP shared by predictor and selector.\n",
    "\n",
    "    The network takes d_in * 2 inputs, has two hidden ReLU layers of width\n",
    "    `hidden`, each followed by dropout, and ends in a linear layer with\n",
    "    `n_outputs` units.\n",
    "    '''\n",
    "    # NOTE(review): the d_in * 2 input width presumably matches\n",
    "    # MaskLayer(append=True) concatenating the mask to the features - confirm.\n",
    "    return nn.Sequential(\n",
    "        nn.Linear(d_in * 2, hidden),\n",
    "        nn.ReLU(),\n",
    "        nn.Dropout(dropout),\n",
    "        nn.Linear(hidden, hidden),\n",
    "        nn.ReLU(),\n",
    "        nn.Dropout(dropout),\n",
    "        nn.Linear(hidden, n_outputs))\n",
    "\n",
    "\n",
    "# Predictor: one output per class\n",
    "predictor = make_mlp(d_out)\n",
    "\n",
    "# Selector: one output per input feature\n",
    "selector = make_mlp(d_in)\n",
    "\n",
    "# Tie weights\n",
    "# selector[0] = predictor[0]\n",
    "# selector[3] = predictor[3]\n",
    "\n",
    "# Optimisation settings shared by both training stages below\n",
    "common_fit_args = dict(mbsize=128, lr=1e-3, patience=3, verbose=True)\n",
    "\n",
    "# Stage 1: pretrain the predictor through MaskingPretrainer\n",
    "mask_layer = MaskLayer(append=True)\n",
    "pretrain = MaskingPretrainer(predictor, mask_layer).to(device)\n",
    "pretrain.fit(\n",
    "    train_dataset,\n",
    "    val_dataset,\n",
    "    nepochs=100,\n",
    "    loss_fn=nn.CrossEntropyLoss(),\n",
    "    # val_loss_fn=AUROC(num_classes=2),\n",
    "    # val_loss_mode='max',\n",
    "    **common_fit_args)\n",
    "\n",
    "# Stage 2: train selector and predictor together via AdaptiveSelection\n",
    "gafs = AdaptiveSelection(selector, predictor, mask_layer).to(device)\n",
    "gafs.fit(\n",
    "    train_dataset,\n",
    "    val_dataset,\n",
    "    nepochs=250,\n",
    "    max_features=max_features,\n",
    "    loss_fn=nn.CrossEntropyLoss(),\n",
    "    # val_loss_fn=AUROC(num_classes=2),\n",
    "    # val_loss_mode='max',\n",
    "    **common_fit_args)\n",
    "\n",
    "# TODO: DELETE THIS BLOCK LATER\n",
    "print('evaluating after first training run')\n",
    "num_features = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 20, 25]\n",
    "\n",
    "# Test metrics stay under 'auroc'/'acc'; validation metrics get their own\n",
    "# 'val_' keys so the test pass no longer overwrites them.\n",
    "results = {\n",
    "    'auroc': {},\n",
    "    'acc': {},\n",
    "    'val_auroc': {},\n",
    "    'val_acc': {}\n",
    "}\n",
    "\n",
    "# Score the validation split first, then the test split, for each feature budget.\n",
    "# NOTE(review): train/eval mode is left as gafs.fit() set it - confirm dropout is\n",
    "# disabled here.\n",
    "for split_name, split_dataset, key_prefix in (('val', val_dataset, 'val_'),\n",
    "                                              ('test', test_dataset, '')):\n",
    "    print(split_name)\n",
    "    x, y = get_xy(split_dataset)\n",
    "    # Move the inputs to the device once instead of once per feature budget.\n",
    "    x_device = torch.tensor(x, device=device)\n",
    "    for num in num_features:\n",
    "        # Inference only: skip autograd bookkeeping for the full-split forward pass.\n",
    "        with torch.no_grad():\n",
    "            pred = gafs(x_device, max_features=num).softmax(dim=1).cpu().numpy()\n",
    "        auroc = roc_auc_score(y, pred[:, 1])\n",
    "        acc = accuracy_score(y, pred.argmax(axis=1))\n",
    "        print(f'Num = {num}, AUROC = {100*auroc:.2f}, Acc = {100*acc:.2f}')\n",
    "        results[key_prefix + 'auroc'][num] = auroc\n",
    "        results[key_prefix + 'acc'][num] = acc\n",
    "# TODO DELETE THIS BLOCK LATER\n",
    "\n",
    "# Start over with a freshly initialised predictor and fit only the predictor\n",
    "# (frozen selector) through AdaptiveSelection.fit_predictor.\n",
    "fresh_predictor_layers = [\n",
    "    nn.Linear(d_in * 2, hidden),\n",
    "    nn.ReLU(),\n",
    "    nn.Dropout(dropout),\n",
    "    nn.Linear(hidden, hidden),\n",
    "    nn.ReLU(),\n",
    "    nn.Dropout(dropout),\n",
    "    nn.Linear(hidden, d_out),\n",
    "]\n",
    "predictor = nn.Sequential(*fresh_predictor_layers)\n",
    "gafs = AdaptiveSelection(selector, predictor, mask_layer).to(device)\n",
    "gafs.fit_predictor(\n",
    "    train_dataset,\n",
    "    val_dataset,\n",
    "    max_features=max_features,\n",
    "    nepochs=250,\n",
    "    mbsize=128,\n",
    "    lr=1e-3,\n",
    "    patience=3,\n",
    "    loss_fn=nn.CrossEntropyLoss(),\n",
    "    # val_loss_fn=AUROC(num_classes=2),\n",
    "    # val_loss_mode='max',\n",
    "    verbose=True)\n",
    "\n",
    "# TODO: DELETE THIS BLOCK LATER\n",
    "print('evaluating after predictor re-training')\n",
    "num_features = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 20, 25]\n",
    "\n",
    "# Test metrics stay under 'auroc'/'acc'; validation metrics get their own\n",
    "# 'val_' keys so the test pass no longer overwrites them.\n",
    "results = {\n",
    "    'auroc': {},\n",
    "    'acc': {},\n",
    "    'val_auroc': {},\n",
    "    'val_acc': {}\n",
    "}\n",
    "\n",
    "# Score the validation split first, then the test split, for each feature budget.\n",
    "# NOTE(review): train/eval mode is left as gafs.fit_predictor() set it - confirm\n",
    "# dropout is disabled here.\n",
    "for split_name, split_dataset, key_prefix in (('val', val_dataset, 'val_'),\n",
    "                                              ('test', test_dataset, '')):\n",
    "    print(split_name)\n",
    "    x, y = get_xy(split_dataset)\n",
    "    # Move the inputs to the device once instead of once per feature budget.\n",
    "    x_device = torch.tensor(x, device=device)\n",
    "    for num in num_features:\n",
    "        # Inference only: skip autograd bookkeeping for the full-split forward pass.\n",
    "        with torch.no_grad():\n",
    "            pred = gafs(x_device, max_features=num).softmax(dim=1).cpu().numpy()\n",
    "        auroc = roc_auc_score(y, pred[:, 1])\n",
    "        acc = accuracy_score(y, pred.argmax(axis=1))\n",
    "        print(f'Num = {num}, AUROC = {100*auroc:.2f}, Acc = {100*acc:.2f}')\n",
    "        results[key_prefix + 'auroc'][num] = auroc\n",
    "        results[key_prefix + 'acc'][num] = acc\n",
    "# TODO DELETE THIS BLOCK LATER\n",
    "\n",
    "# # For saving results\n",
    "# results = {\n",
    "#     'auroc': {},\n",
    "#     'acc': {}\n",
    "# }\n",
    "\n",
    "# # Generate results\n",
    "# num_features = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 20, 25]\n",
    "# x, y = get_xy(test_dataset)\n",
    "# for num in num_features:\n",
    "#     pred = gafs(torch.tensor(x, device=device), max_features=num).softmax(dim=1).cpu().data.numpy()\n",
    "#     auroc = roc_auc_score(y, pred[:,1])\n",
    "#     acc = accuracy_score(y, pred.argmax(axis=1))\n",
    "#     print(f'Num = {num}, AUROC = {100*auroc:.2f}, Acc = {100*acc:.2f}')\n",
    "#     results['auroc'][num] = auroc\n",
    "#     results['acc'][num] = acc\n",
    "\n",
    "# with open('results/adaptive_results.pkl', 'wb') as f:\n",
    "#     pickle.dump(results, f)\n",
    "\n",
    "# # Find most common selections at each step\n",
    "# top_list = [[]]\n",
    "# for num in range(d_in):\n",
    "#     x, y = get_xy(test_dataset)\n",
    "#     x, m = gafs.select_features(torch.tensor(x, device=device), max_features=num + 1)\n",
    "#     p = m.mean(dim=0).cpu().data.numpy()\n",
    "#     top_list.append(np.sort(np.argsort(p)[-(num + 1):]))\n",
    "\n",
    "# # Save results\n",
    "# with open('results/adaptive_frequent_selections.pkl', 'wb') as f:\n",
    "#     pickle.dump(top_list, f)\n",
    "\n",
    "# # Save model\n",
    "# gafs.cpu()\n",
    "# torch.save(gafs, 'results/adaptive_trained.pt')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4244a8af-360f-4f2e-9570-5318cc284550",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc38297a-f10a-47df-b7db-b9fad7b92367",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ab7e235b-347f-4ac0-9b02-1d52e9dfff08",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
