{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "olympic-maximum",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "import torch.nn as nn\n",
    "import matplotlib.pyplot as plt\n",
    "from torchvision import transforms\n",
    "from torchvision.datasets import MNIST\n",
    "from copy import deepcopy\n",
    "from utils import *\n",
    "from models import *\n",
    "import os\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "fixed-decline",
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device('cuda')\n",
    "# device = torch.device('cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "9f1d5fc1-0e03-4bfd-8362-0295643fda3b",
   "metadata": {},
   "outputs": [],
   "source": [
    "class DenseDataset(torch.utils.data.Dataset):\n",
    "    def __init__(self, data_dir, split, transform=None):\n",
    "        super(DenseDataset, self).__init__()\n",
    "        self.data_dir = os.path.expanduser(data_dir)\n",
    "        data = pd.read_csv(self.data_dir)\n",
    "        data = data[data['split']==split]\n",
    "        self.Y = np.array(data['outcome']).astype('int64')   \n",
    "        self.X = np.array(data.drop(['split', 'outcome'], axis=1))\n",
    "        \n",
    "    def __len__(self):\n",
    "        return self.X.shape[0]\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        x = self.X[index,:]\n",
    "        y = self.Y[index]\n",
    "        return x, y\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "0d88f9d8-2532-4855-9f02-b3b2b36a11f7",
   "metadata": {},
   "outputs": [],
   "source": [
    "data_name = 'spam_split'\n",
    "data_dir = './UCI_datasets/'+data_name+'.csv'\n",
    "train_dataset = DenseDataset(data_dir, 'train')\n",
    "val_dataset = DenseDataset(data_dir, 'valid')\n",
    "test_dataset = DenseDataset(data_dir, 'test')\n",
    "np.random.seed(7) \n",
    "if len(np.unique(train_dataset.Y)) > 2:\n",
    "    multi_label = 1\n",
    "else:\n",
    "    multi_label = 0\n",
    "d_in = train_dataset.X.shape[1]\n",
    "d_out = len(np.unique(train_dataset.Y))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "earlier-coffee",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of features to select\n",
    "max_features = [train_dataset.X.shape[1]]+[i for i in range(15, train_dataset.X.shape[1], 5)][::-1]+[10,9,8,7,6,5,4,3,2,1]\n",
    "max_features = max_features[::-1]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "perceived-blind",
   "metadata": {},
   "source": [
    "# Pretrain model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "separate-going",
   "metadata": {
    "scrolled": true,
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--------Epoch 1--------\n",
      "Val loss = -0.6332\n",
      "\n",
      "--------Epoch 2--------\n",
      "Val loss = -0.6780\n",
      "\n",
      "--------Epoch 3--------\n",
      "Val loss = -0.7052\n",
      "\n",
      "--------Epoch 4--------\n",
      "Val loss = -0.7079\n",
      "\n",
      "--------Epoch 5--------\n",
      "Val loss = -0.7052\n",
      "\n",
      "--------Epoch 6--------\n",
      "Val loss = -0.7215\n",
      "\n",
      "--------Epoch 7--------\n",
      "Val loss = -0.7160\n",
      "\n",
      "--------Epoch 8--------\n",
      "Val loss = -0.7432\n",
      "\n",
      "--------Epoch 9--------\n",
      "Val loss = -0.7582\n",
      "\n",
      "--------Epoch 10--------\n",
      "Val loss = -0.7677\n",
      "\n",
      "--------Epoch 11--------\n",
      "Val loss = -0.8016\n",
      "\n",
      "--------Epoch 12--------\n",
      "Val loss = -0.8003\n",
      "\n",
      "--------Epoch 13--------\n",
      "Val loss = -0.7948\n",
      "\n",
      "--------Epoch 14--------\n",
      "Val loss = -0.8193\n",
      "\n",
      "--------Epoch 15--------\n",
      "Val loss = -0.8111\n",
      "\n",
      "--------Epoch 16--------\n",
      "Val loss = -0.8410\n",
      "\n",
      "--------Epoch 17--------\n",
      "Val loss = -0.8234\n",
      "\n",
      "--------Epoch 18--------\n",
      "Val loss = -0.8125\n",
      "\n",
      "--------Epoch 19--------\n",
      "Val loss = -0.8492\n",
      "\n",
      "--------Epoch 20--------\n",
      "Val loss = -0.8424\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Set up model, need to change to the best archetecture\n",
    "model = nn.Sequential(\n",
    "    nn.Linear(d_in * 2, d_in * 2),\n",
    "    nn.ReLU(),\n",
    "    nn.Dropout(0.5),\n",
    "    nn.Linear(d_in * 2, d_in),\n",
    "    nn.ReLU(),\n",
    "    nn.Dropout(0.5),\n",
    "    nn.Linear(d_in, d_out))\n",
    "mask_layer = MaskLayer(append=True)\n",
    "pretrain = Pretrainer(model, mask_layer).to(device)\n",
    "\n",
    "# Pretrain\n",
    "pretrain.fit(train_dataset,\n",
    "             val_dataset,\n",
    "             mbsize=128,\n",
    "             lr=1e-3,   ### change to the best one\n",
    "             nepochs=20,  ### need to tune\n",
    "             max_features=max(max_features),\n",
    "             loss_fn=nn.CrossEntropyLoss(),   ### change to the best one\n",
    "             val_loss_fn=NegAccuracy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "c0949379-4aa5-435e-9678-343a31fa6491",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Sequential(\n",
       "  (0): Linear(in_features=114, out_features=114, bias=True)\n",
       "  (1): ReLU()\n",
       "  (2): Dropout(p=0.5, inplace=False)\n",
       "  (3): Linear(in_features=114, out_features=57, bias=True)\n",
       "  (4): ReLU()\n",
       "  (5): Dropout(p=0.5, inplace=False)\n",
       "  (6): Linear(in_features=57, out_features=2, bias=True)\n",
       ")"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "gafs.model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ethical-terry",
   "metadata": {},
   "source": [
    "# Weight tying"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "ce7d93fc-e608-49b5-a7a5-09758f32ecf7",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Set up accuracy\n",
    "acc_list = []\n",
    "auroc_list=[]\n",
    "\n",
    "# Set up selector\n",
    "selector = nn.Sequential(\n",
    "    nn.Linear(d_in * 2, d_in * 2),\n",
    "    nn.ReLU(),\n",
    "    nn.Dropout(0.5),    ### the hparam need to tune\n",
    "    nn.Linear(d_in * 2, d_in),\n",
    "    nn.ReLU(),\n",
    "    nn.Dropout(0.5),\n",
    "    nn.Linear(d_in, d_in))\n",
    "selector_layer = ConcreteSelector()\n",
    "gafs = GreedyAdaptiveFS(selector, deepcopy(model), mask_layer, selector_layer).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "7febebad-31a4-41bf-8f0e-7015a5cfa3b2",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--------Epoch 1--------\n",
      "Val loss = -0.6250\n",
      "\n",
      "--------Epoch 2--------\n",
      "Val loss = -0.6236\n",
      "\n",
      "--------Epoch 3--------\n",
      "Val loss = -0.6223\n",
      "\n",
      "--------Epoch 4--------\n",
      "Val loss = -0.6087\n",
      "\n",
      "--------Epoch 5--------\n",
      "Val loss = -0.6046\n",
      "\n",
      "--------Epoch 6--------\n",
      "Val loss = -0.6209\n",
      "\n",
      "--------Epoch 7--------\n",
      "Val loss = -0.6223\n",
      "\n",
      "--------Epoch 8--------\n",
      "Val loss = -0.6182\n",
      "\n",
      "--------Epoch 9--------\n",
      "Val loss = -0.6345\n",
      "\n",
      "--------Epoch 10--------\n",
      "Val loss = -0.6372\n",
      "\n",
      "--------Epoch 11--------\n",
      "Val loss = -0.6467\n",
      "\n",
      "--------Epoch 12--------\n",
      "Val loss = -0.6345\n",
      "\n",
      "--------Epoch 13--------\n",
      "Val loss = -0.6576\n",
      "\n",
      "--------Epoch 14--------\n",
      "Val loss = -0.6590\n",
      "\n",
      "--------Epoch 15--------\n",
      "Val loss = -0.6644\n",
      "\n",
      "--------Epoch 16--------\n",
      "Val loss = -0.6807\n",
      "\n",
      "--------Epoch 17--------\n",
      "Val loss = -0.6807\n",
      "\n",
      "--------Epoch 18--------\n",
      "Val loss = -0.6766\n",
      "\n",
      "--------Epoch 19--------\n",
      "Val loss = -0.6902\n",
      "\n",
      "--------Epoch 20--------\n",
      "Val loss = -0.7092\n",
      "\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m/tmp/ipykernel_40118/1031864555.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     24\u001b[0m              \u001b[0mno_repeats\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     25\u001b[0m              \u001b[0mstart_temp\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m10.0\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 26\u001b[0;31m              end_temp=0.01)\n\u001b[0m\u001b[1;32m     27\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     28\u001b[0m     \u001b[0;31m# Get accuracy\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/adaptive_selection/models.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, train, val, mbsize, lr, nepochs, max_features, loss_fn, val_loss_fn, train_model, train_selector, start_temp, end_temp, argmax, no_repeats, validation_mode, verbose)\u001b[0m\n\u001b[1;32m    328\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    329\u001b[0m         \u001b[0;32mfor\u001b[0m \u001b[0mepoch\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnepochs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 330\u001b[0;31m             \u001b[0;32mfor\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mtrain_loader\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    331\u001b[0m                 \u001b[0;31m# Move to device.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    332\u001b[0m                 \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/torch/utils/data/dataloader.py\u001b[0m in \u001b[0;36m__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    519\u001b[0m             \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_sampler_iter\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    520\u001b[0m                 \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_reset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 521\u001b[0;31m             \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_next_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    522\u001b[0m             \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_num_yielded\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    523\u001b[0m             \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_dataset_kind\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0m_DatasetKind\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mIterable\u001b[0m \u001b[0;32mand\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m\\\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/torch/utils/data/dataloader.py\u001b[0m in \u001b[0;36m_next_data\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m   1173\u001b[0m                 \u001b[0;31m# no valid `self._rcvd_idx` is found (i.e., didn't break)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1174\u001b[0m                 \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_persistent_workers\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1175\u001b[0;31m                     \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_shutdown_workers\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1176\u001b[0m                 \u001b[0;32mraise\u001b[0m \u001b[0mStopIteration\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1177\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/torch/utils/data/dataloader.py\u001b[0m in \u001b[0;36m_shutdown_workers\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m   1299\u001b[0m                     \u001b[0;31m# wrong, we set a timeout and if the workers fail to join,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1300\u001b[0m                     \u001b[0;31m# they are killed in the `finally` block.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1301\u001b[0;31m                     \u001b[0mw\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjoin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtimeout\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0m_utils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mMP_STATUS_CHECK_INTERVAL\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1302\u001b[0m                 \u001b[0;32mfor\u001b[0m \u001b[0mq\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_index_queues\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1303\u001b[0m                     \u001b[0mq\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcancel_join_thread\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/lib/python3.7/multiprocessing/process.py\u001b[0m in \u001b[0;36mjoin\u001b[0;34m(self, timeout)\u001b[0m\n\u001b[1;32m    138\u001b[0m         \u001b[0;32massert\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_parent_pid\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgetpid\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'can only join a child process'\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    139\u001b[0m         \u001b[0;32massert\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_popen\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'can only join a started process'\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 140\u001b[0;31m         \u001b[0mres\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_popen\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwait\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtimeout\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    141\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mres\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    142\u001b[0m             \u001b[0m_children\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdiscard\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/lib/python3.7/multiprocessing/popen_fork.py\u001b[0m in \u001b[0;36mwait\u001b[0;34m(self, timeout)\u001b[0m\n\u001b[1;32m     43\u001b[0m             \u001b[0;32mif\u001b[0m \u001b[0mtimeout\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     44\u001b[0m                 \u001b[0;32mfrom\u001b[0m \u001b[0mmultiprocessing\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconnection\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mwait\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 45\u001b[0;31m                 \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mwait\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msentinel\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtimeout\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     46\u001b[0m                     \u001b[0;32mreturn\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     47\u001b[0m             \u001b[0;31m# This shouldn't block if wait() returned successfully.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/lib/python3.7/multiprocessing/connection.py\u001b[0m in \u001b[0;36mwait\u001b[0;34m(object_list, timeout)\u001b[0m\n\u001b[1;32m    918\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    919\u001b[0m             \u001b[0;32mwhile\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 920\u001b[0;31m                 \u001b[0mready\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mselector\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mselect\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtimeout\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    921\u001b[0m                 \u001b[0;32mif\u001b[0m \u001b[0mready\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    922\u001b[0m                     \u001b[0;32mreturn\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfileobj\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mevents\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mready\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/lib/python3.7/selectors.py\u001b[0m in \u001b[0;36mselect\u001b[0;34m(self, timeout)\u001b[0m\n\u001b[1;32m    413\u001b[0m         \u001b[0mready\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    414\u001b[0m         \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 415\u001b[0;31m             \u001b[0mfd_event_list\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_selector\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpoll\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtimeout\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    416\u001b[0m         \u001b[0;32mexcept\u001b[0m \u001b[0mInterruptedError\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    417\u001b[0m             \u001b[0;32mreturn\u001b[0m \u001b[0mready\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "# Tie weights\n",
    "selector[0].weight = nn.Parameter(gafs.model[0].weight)   ### may need to change the index\n",
    "selector[3].weight = nn.Parameter(gafs.model[3].weight)   ### may need to change the index\n",
    "\n",
    "res_dict_auroc = {}\n",
    "# left_res_dict = {}\n",
    "# right_res_dict = {}\n",
    "\n",
    "res_dict_acc = {}\n",
    "# left_res_dict = {}\n",
    "# right_res_dict = {}\n",
    "\n",
    "for num in max_features:\n",
    "    # Train\n",
    "    gafs.fit(train_dataset,\n",
    "             val_dataset,\n",
    "             mbsize=512,   ### the hparam need to tune, not that important, can have a try\n",
    "             lr=1e-3,    ### the hparam need to tune\n",
    "             nepochs=50,   ### the hparam need to tune, how many epoch needed to converge\n",
    "             max_features=num,\n",
    "             loss_fn=nn.CrossEntropyLoss(),\n",
    "             val_loss_fn=NegAccuracy(),\n",
    "             argmax=False,\n",
    "             no_repeats=False,\n",
    "             start_temp=10.0,   ### the hparam need to tune, how many epoch needed to converge\n",
    "             end_temp=0.01)\n",
    "    \n",
    "    # Get accuracy\n",
    "    test_acc,y_pre, y_true = gafs.evaluate(test_dataset, num, Accuracy(), 1024)\n",
    "    #acc_list.append(test_acc)\n",
    "    if multi_label == 1:\n",
    "        test_auroc = roc_auc_score(y_true.cpu().numpy(), nn.functional.softmax(y_pre, dim=1).cpu().numpy(), average='macro', multi_class = 'ovo')\n",
    "    else:\n",
    "        test_auroc = roc_auc_score(y_true.cpu().numpy(), nn.functional.softmax(y_pre, dim=1).cpu().numpy()[:,1])\n",
    "    res_dict_auroc[num] = test_auroc\n",
    "    test_acc = accuracy_score(y_true.cpu().numpy(), y_pre.cpu().numpy().argmax(axis=1))\n",
    "    res_dict_acc[num] = test_acc\n",
    "    # acc_list.append(test_acc)\n",
    "    # auroc_list.append(test_auroc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sharp-layout",
   "metadata": {},
   "outputs": [],
   "source": [
    "result_dict = {}\n",
    "result_dict['res_dict_auroc'] = res_dict_auroc\n",
    "result_dict['res_dict_acc'] = res_dict_acc\n",
    "# result_dict['left_res_dict'] = left_res_dict\n",
    "# result_dict['right_res_dict'] = right_res_dict\n",
    "pickle.dump(result_dict, open('./UCI_datasets/results/'+data_name+'_adaptive_results.pkl', 'wb'))"
   ]
  },
  {
   "cell_type": "raw",
   "id": "c0d64cf9-decb-468a-8e55-2274f312e391",
   "metadata": {},
   "source": [
    "# Plot results"
   ]
  },
  {
   "cell_type": "raw",
   "id": "841df4f5-729d-4f44-bfe1-971b5fb1c3c0",
   "metadata": {},
   "source": [
    "global_acc = [0.5562999975204468, 0.7678000000953674, 0.8604000005722046, 0.9107000004768372, 0.9193000002861023]\n",
    "adaptive_acc = [0.7062999967575073, 0.8753999994277954, 0.9239000005722046, 0.9518999997138977, 0.9589999966621399]"
   ]
  },
  {
   "cell_type": "raw",
   "id": "00298c33-8d8f-4cfe-b658-232f593e99f0",
   "metadata": {},
   "source": [
    "plt.figure(figsize=(9, 6))\n",
    "\n",
    "# Plot\n",
    "plt.plot(max_features, global_acc, color='tab:blue', marker='o', markersize=20, label='Global')\n",
    "plt.plot(max_features, adaptive_acc, color='tab:green', marker='*', markersize=20, label='Ours')\n",
    "\n",
    "# # Plot and scatter\n",
    "# plt.plot(max_features, global_acc, color='tab:blue')\n",
    "# plt.plot(max_features, adaptive_acc, color='tab:green')\n",
    "# plt.scatter(max_features, global_acc, color='tab:blue', marker='o', s=20, label='Global')\n",
    "# plt.scatter(max_features, adaptive_acc, color='tab:green', marker='*', s=20, label='Ours')\n",
    "\n",
    "# Legend\n",
    "plt.legend(loc='lower right', frameon=False)\n",
    "\n",
    "# Labels\n",
    "plt.xlabel('# Features', fontsize=18)\n",
    "plt.ylabel('Accuracy', fontsize=18)\n",
    "plt.tick_params(labelsize=16)\n",
    "plt.title('MNIST Feature Selection')\n",
    "\n",
    "# Axis spines\n",
    "plt.gca().spines['right'].set_visible(False)\n",
    "plt.gca().spines['top'].set_visible(False)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
