{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "93e7317b",
   "metadata": {},
   "source": [
    "## Train a VGG model with the PAC-Bayes information bottleneck (PIB)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "56fdf3eb",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import os\n",
    "import pdb\n",
    "\n",
    "from src.dataset import load_data\n",
    "from src.utils import img_preprocess, setup_seed, predict, eval_metric, feature_map_size\n",
    "from src.utils import train\n",
    "from src.models import VGG\n",
    "from src.pib_utils import train_pib"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "91daa8d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment configuration: dataset name, checkpoint locations and the\n",
    "# hyper-parameters forwarded to train_pib further down in this notebook.\n",
    "__data_set__ = 'cifar10'\n",
    "\n",
    "# prior checkpoint is read (or created) by the prior cell; the PIB\n",
    "# checkpoint path is handed to train_pib as its early-stopping target\n",
    "__prior_ckpt__ = './checkpoints/{}/vgg_prior.pt'.format(__data_set__)\n",
    "__save_ckpt__ = './checkpoints/{}/vgg_pib.pt'.format(__data_set__)\n",
    "\n",
    "opt = {\n",
    "    'num_epoch':100,         # upper bound on training epochs\n",
    "    'batch_size':32,\n",
    "    'lr':1e-4, \n",
    "    'weight_decay':0,\n",
    "    'beta':1e-1,             # passed to train_pib as `beta`\n",
    "    'noise_scale':1e-10,     # passed to train_pib as `noise_scale`\n",
    "    'schedule': [50, 80],    # presumably the epochs at which lr is decayed -- TODO confirm in src.pib_utils\n",
    "    'early_stop': 10,        # passed to train_pib as `early_stop_tolerance`\n",
    "}\n",
    "\n",
    "# exist_ok=True makes directory creation idempotent and avoids the\n",
    "# check-then-create race of a separate os.path.exists() test\n",
    "ckpt_dir = './checkpoints/{}'.format(__data_set__)\n",
    "os.makedirs(ckpt_dir, exist_ok=True)\n",
    "\n",
    "# set random seed\n",
    "setup_seed(2020)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "888592b5",
   "metadata": {},
   "source": [
    "## load data & preprocess"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "9874ee16",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load from CIFAR-10.\n"
     ]
    }
   ],
   "source": [
    "# fetch the train / validation / test splits for the configured dataset\n",
    "x_tr, y_tr, x_va, y_va, x_te, y_te = load_data(__data_set__)\n",
    "\n",
    "# both values are derived from the raw splits, before preprocessing:\n",
    "# the index of every training sample and the number of distinct labels\n",
    "# observed in the validation split\n",
    "all_tr_idx = np.arange(len(x_tr))\n",
    "num_class = len(np.unique(y_va))\n",
    "\n",
    "# run the same preprocessing over each split\n",
    "x_tr, y_tr = img_preprocess(x_tr, y_tr)\n",
    "x_va, y_va = img_preprocess(x_va, y_va)\n",
    "x_te, y_te = img_preprocess(x_te, y_te)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8c14509b",
   "metadata": {},
   "source": [
    "## train PIB-based VGG model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "1489af66",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "VGG(\n",
       "  (extract_feature): Sequential(\n",
       "    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    (1): ReLU()\n",
       "    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    (3): ReLU()\n",
       "    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    (6): ReLU()\n",
       "    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    (8): ReLU()\n",
       "    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    (11): ReLU()\n",
       "    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    (13): ReLU()\n",
       "    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    (15): ReLU()\n",
       "    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "  )\n",
       "  (classifier): Sequential(\n",
       "    (0): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "    (1): ReLU()\n",
       "    (2): Dropout(p=0.0, inplace=False)\n",
       "    (3): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "    (4): ReLU()\n",
       "    (5): Dropout(p=0.0, inplace=False)\n",
       "    (6): Linear(in_features=1024, out_features=10, bias=True)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# build the VGG classifier (dropout disabled) for this dataset\n",
    "last_fmap_size = feature_map_size(__data_set__)\n",
    "model = VGG(\n",
    "    num_classes=num_class,\n",
    "    dropout_rate=0.0,\n",
    "    last_feature_map_size=last_fmap_size,\n",
    ")\n",
    "# move the parameters to the GPU; as the last expression this also displays the architecture\n",
    "model.cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "3147464c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load prior.\n",
      "initialize model weights.\n",
      "done get prior weights\n"
     ]
    }
   ],
   "source": [
    "# obtain the prior weights: reuse a saved checkpoint when there is one,\n",
    "# otherwise fit the model on the validation split\n",
    "if not os.path.exists(__prior_ckpt__):\n",
    "    # positional hyper-parameters are presumably (num_epoch, batch_size, lr,\n",
    "    # weight_decay, ckpt_path, early_stop_tolerance) -- TODO confirm against src.utils.train\n",
    "    va_idx = np.arange(len(y_va))\n",
    "    train(model, va_idx, x_va, y_va, x_va, y_va, 10, 32, 5e-5, 0, __prior_ckpt__, 5)\n",
    "else:\n",
    "    print(\"load prior.\")\n",
    "    model.load_state_dict(torch.load(__prior_ckpt__))\n",
    "\n",
    "# snapshot every named parameter as the prior; the copies are detached\n",
    "# from autograd but stay on the same device as the model\n",
    "w0_dict = {\n",
    "    name: weight.clone().detach()\n",
    "    for name, weight in model.named_parameters()\n",
    "}\n",
    "model.w0_dict = w0_dict\n",
    "\n",
    "# re-initialise the live weights now that the prior snapshot is stored on the model\n",
    "model._initialize_weights()\n",
    "print(\"done get prior weights\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "0954bb87",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|                                                                                                  | 0/100 [00:00<?, ?it/s]D:\\python\\lib\\site-packages\\torch\\nn\\functional.py:718: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at  ..\\c10/core/TensorImpl.h:1156.)\n",
      "  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)\n",
      "D:\\PAC-Bayes IB\\PAC-Bayes-IB\\src\\pib_utils.py:115: UserWarning: This overload of add_ is deprecated:\n",
      "\tadd_(Number alpha, Tensor other)\n",
      "Consider using one of the following signatures instead:\n",
      "\tadd_(Tensor other, *, Number alpha) (Triggered internally at  ..\\torch\\csrc\\utils\\python_arg_parser.cpp:1025.)\n",
      "  p.data.add_(-group['lr'], d_p)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 0, info: {'extract_feature.0.weight': tensor(3.6269e-06, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.2.weight': tensor(2.5803e-07, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.5.weight': tensor(9.1582e-07, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.7.weight': tensor(5.9072e-07, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.10.weight': tensor(7.7142e-07, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.12.weight': tensor(7.5887e-07, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.14.weight': tensor(9.1890e-07, device='cuda:0', grad_fn=<PowBackward0>), 'classifier.0.weight': tensor(1.3563e-06, device='cuda:0', grad_fn=<PowBackward0>), 'classifier.3.weight': tensor(9.2734e-07, device='cuda:0', grad_fn=<PowBackward0>), 'classifier.6.weight': tensor(9.8919e-07, device='cuda:0', grad_fn=<PowBackward0>)}\n",
      "epoch: 0, tr loss: 0.027935729106155428, lr: 0.0001, e_decay: 1.1113430446130224e-05\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  1%|▉                                                                                         | 1/100 [00:18<29:47, 18.06s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 0, va acc: 0.8075999617576599\n",
      "epoch: 1, info: {'extract_feature.0.weight': tensor(1.0749e-07, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.2.weight': tensor(2.6885e-07, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.5.weight': tensor(6.0926e-07, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.7.weight': tensor(6.9039e-07, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.10.weight': tensor(4.3970e-07, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.12.weight': tensor(5.7491e-07, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.14.weight': tensor(8.2988e-07, device='cuda:0', grad_fn=<PowBackward0>), 'classifier.0.weight': tensor(8.3694e-07, device='cuda:0', grad_fn=<PowBackward0>), 'classifier.3.weight': tensor(6.8296e-07, device='cuda:0', grad_fn=<PowBackward0>), 'classifier.6.weight': tensor(8.9229e-07, device='cuda:0', grad_fn=<PowBackward0>)}\n",
      "epoch: 1, tr loss: 0.016418349326268174, lr: 0.0001, e_decay: 5.9326757764210925e-06\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  2%|█▊                                                                                        | 2/100 [00:36<29:35, 18.12s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 1, va acc: 0.8077999949455261\n",
      "epoch: 2, info: {'extract_feature.0.weight': tensor(1.7195e-06, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.2.weight': tensor(7.2301e-09, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.5.weight': tensor(1.9345e-07, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.7.weight': tensor(1.1030e-07, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.10.weight': tensor(1.7715e-07, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.12.weight': tensor(3.6482e-07, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.14.weight': tensor(7.6794e-07, device='cuda:0', grad_fn=<PowBackward0>), 'classifier.0.weight': tensor(7.6276e-07, device='cuda:0', grad_fn=<PowBackward0>), 'classifier.3.weight': tensor(5.6456e-07, device='cuda:0', grad_fn=<PowBackward0>), 'classifier.6.weight': tensor(4.5528e-07, device='cuda:0', grad_fn=<PowBackward0>)}\n",
      "epoch: 2, tr loss: 0.011348072461196668, lr: 0.0001, e_decay: 5.122946731717093e-06\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  3%|██▋                                                                                       | 3/100 [00:54<29:21, 18.16s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 2, va acc: 0.8079999685287476\n",
      "epoch: 3, info: {'extract_feature.0.weight': tensor(8.0290e-07, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.2.weight': tensor(7.3828e-07, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.5.weight': tensor(7.1980e-07, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.7.weight': tensor(3.6335e-07, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.10.weight': tensor(3.1333e-07, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.12.weight': tensor(4.1059e-07, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.14.weight': tensor(5.3655e-07, device='cuda:0', grad_fn=<PowBackward0>), 'classifier.0.weight': tensor(4.2526e-07, device='cuda:0', grad_fn=<PowBackward0>), 'classifier.3.weight': tensor(3.8709e-07, device='cuda:0', grad_fn=<PowBackward0>), 'classifier.6.weight': tensor(4.7218e-07, device='cuda:0', grad_fn=<PowBackward0>)}\n",
      "epoch: 3, tr loss: 0.008637504745757777, lr: 0.0001, e_decay: 5.169328687770758e-06\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  4%|███▌                                                                                      | 4/100 [01:12<28:51, 18.03s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 3, va acc: 0.8082000017166138\n",
      "epoch: 4, info: {'extract_feature.0.weight': tensor(1.4442e-06, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.2.weight': tensor(5.8814e-07, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.5.weight': tensor(9.8539e-08, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.7.weight': tensor(4.5269e-07, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.10.weight': tensor(2.1504e-07, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.12.weight': tensor(2.6243e-07, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.14.weight': tensor(3.7067e-07, device='cuda:0', grad_fn=<PowBackward0>), 'classifier.0.weight': tensor(2.8635e-07, device='cuda:0', grad_fn=<PowBackward0>), 'classifier.3.weight': tensor(2.7926e-07, device='cuda:0', grad_fn=<PowBackward0>), 'classifier.6.weight': tensor(3.8497e-07, device='cuda:0', grad_fn=<PowBackward0>)}\n",
      "epoch: 4, tr loss: 0.00703760145727012, lr: 0.0001, e_decay: 4.382288352644537e-06\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  5%|████▌                                                                                     | 5/100 [01:29<28:20, 17.90s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 4, va acc: 0.8075999617576599\n",
      "epoch: 5, info: {'extract_feature.0.weight': tensor(1.5320e-06, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.2.weight': tensor(4.1105e-07, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.5.weight': tensor(1.6433e-07, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.7.weight': tensor(3.2256e-07, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.10.weight': tensor(1.9278e-07, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.12.weight': tensor(2.9345e-07, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.14.weight': tensor(2.4463e-07, device='cuda:0', grad_fn=<PowBackward0>), 'classifier.0.weight': tensor(3.2504e-07, device='cuda:0', grad_fn=<PowBackward0>), 'classifier.3.weight': tensor(2.3115e-07, device='cuda:0', grad_fn=<PowBackward0>), 'classifier.6.weight': tensor(3.5107e-07, device='cuda:0', grad_fn=<PowBackward0>)}\n",
      "epoch: 5, tr loss: 0.005617788851103892, lr: 0.0001, e_decay: 4.068099315190921e-06\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  6%|█████▍                                                                                    | 6/100 [01:47<27:53, 17.80s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 5, va acc: 0.8079999685287476\n",
      "epoch: 6, info: {'extract_feature.0.weight': tensor(7.4037e-07, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.2.weight': tensor(3.4044e-08, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.5.weight': tensor(9.5743e-08, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.7.weight': tensor(1.0320e-07, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.10.weight': tensor(1.3083e-07, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.12.weight': tensor(9.7484e-08, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.14.weight': tensor(1.4233e-07, device='cuda:0', grad_fn=<PowBackward0>), 'classifier.0.weight': tensor(2.1150e-07, device='cuda:0', grad_fn=<PowBackward0>), 'classifier.3.weight': tensor(1.5491e-07, device='cuda:0', grad_fn=<PowBackward0>), 'classifier.6.weight': tensor(1.3699e-07, device='cuda:0', grad_fn=<PowBackward0>)}\n",
      "epoch: 6, tr loss: 0.0046768031552514725, lr: 0.0001, e_decay: 1.8473979253030848e-06\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  7%|██████▎                                                                                   | 7/100 [02:05<27:38, 17.83s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 6, va acc: 0.8082000017166138\n",
      "epoch: 7, info: {'extract_feature.0.weight': tensor(1.0159e-07, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.2.weight': tensor(1.6380e-08, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.5.weight': tensor(6.6610e-08, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.7.weight': tensor(7.6061e-08, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.10.weight': tensor(7.1090e-08, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.12.weight': tensor(1.1268e-07, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.14.weight': tensor(1.3695e-07, device='cuda:0', grad_fn=<PowBackward0>), 'classifier.0.weight': tensor(1.3392e-07, device='cuda:0', grad_fn=<PowBackward0>), 'classifier.3.weight': tensor(1.4565e-07, device='cuda:0', grad_fn=<PowBackward0>), 'classifier.6.weight': tensor(1.2916e-07, device='cuda:0', grad_fn=<PowBackward0>)}\n",
      "epoch: 7, tr loss: 0.003995783957970366, lr: 0.0001, e_decay: 9.900943496177206e-07\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  8%|███████▏                                                                                  | 8/100 [02:23<27:36, 18.00s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 7, va acc: 0.8083999752998352\n",
      "epoch: 8, info: {'extract_feature.0.weight': tensor(1.0610e-06, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.2.weight': tensor(1.4661e-08, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.5.weight': tensor(5.3935e-08, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.7.weight': tensor(7.7831e-08, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.10.weight': tensor(5.1174e-08, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.12.weight': tensor(8.9589e-08, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.14.weight': tensor(1.0874e-07, device='cuda:0', grad_fn=<PowBackward0>), 'classifier.0.weight': tensor(1.5383e-07, device='cuda:0', grad_fn=<PowBackward0>), 'classifier.3.weight': tensor(1.4675e-07, device='cuda:0', grad_fn=<PowBackward0>), 'classifier.6.weight': tensor(1.2853e-07, device='cuda:0', grad_fn=<PowBackward0>)}\n",
      "epoch: 8, tr loss: 0.003457858437012232, lr: 0.0001, e_decay: 1.8860222326111398e-06\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  9%|████████                                                                                  | 9/100 [02:42<27:23, 18.06s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 8, va acc: 0.8087999820709229\n",
      "epoch: 9, info: {'extract_feature.0.weight': tensor(3.1340e-11, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.2.weight': tensor(1.2467e-07, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.5.weight': tensor(9.9016e-08, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.7.weight': tensor(6.8468e-08, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.10.weight': tensor(5.3269e-08, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.12.weight': tensor(2.0001e-07, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.14.weight': tensor(1.1074e-07, device='cuda:0', grad_fn=<PowBackward0>), 'classifier.0.weight': tensor(1.4707e-07, device='cuda:0', grad_fn=<PowBackward0>), 'classifier.3.weight': tensor(1.3023e-07, device='cuda:0', grad_fn=<PowBackward0>), 'classifier.6.weight': tensor(1.0829e-07, device='cuda:0', grad_fn=<PowBackward0>)}\n",
      "epoch: 9, tr loss: 0.0030516785294881074, lr: 0.0001, e_decay: 1.041796053868893e-06\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      " 10%|████████▉                                                                                | 10/100 [02:59<27:00, 18.01s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 9, va acc: 0.8065999746322632\n",
      "epoch: 10, info: {'extract_feature.0.weight': tensor(1.7619e-08, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.2.weight': tensor(7.9625e-08, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.5.weight': tensor(1.2105e-07, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.7.weight': tensor(9.8592e-08, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.10.weight': tensor(9.0488e-08, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.12.weight': tensor(4.7998e-08, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.14.weight': tensor(9.0422e-08, device='cuda:0', grad_fn=<PowBackward0>), 'classifier.0.weight': tensor(9.3600e-08, device='cuda:0', grad_fn=<PowBackward0>), 'classifier.3.weight': tensor(8.1775e-08, device='cuda:0', grad_fn=<PowBackward0>), 'classifier.6.weight': tensor(9.0230e-08, device='cuda:0', grad_fn=<PowBackward0>)}\n",
      "epoch: 10, tr loss: 0.0026741414986563303, lr: 0.0001, e_decay: 8.114002412185073e-07\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      " 11%|█████████▊                                                                               | 11/100 [03:18<26:46, 18.05s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 10, va acc: 0.8068000078201294\n",
      "epoch: 11, info: {'extract_feature.0.weight': tensor(5.0358e-07, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.2.weight': tensor(1.3053e-07, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.5.weight': tensor(7.6545e-08, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.7.weight': tensor(8.6383e-08, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.10.weight': tensor(6.4748e-08, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.12.weight': tensor(7.4786e-08, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.14.weight': tensor(6.3328e-08, device='cuda:0', grad_fn=<PowBackward0>), 'classifier.0.weight': tensor(9.4890e-08, device='cuda:0', grad_fn=<PowBackward0>), 'classifier.3.weight': tensor(6.9934e-08, device='cuda:0', grad_fn=<PowBackward0>), 'classifier.6.weight': tensor(9.2716e-08, device='cuda:0', grad_fn=<PowBackward0>)}\n",
      "epoch: 11, tr loss: 0.002352243481530861, lr: 0.0001, e_decay: 1.2574425909406273e-06\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      " 12%|██████████▋                                                                              | 12/100 [03:36<26:26, 18.03s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 11, va acc: 0.8047999739646912\n",
      "epoch: 12, info: {'extract_feature.0.weight': tensor(1.0690e-06, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.2.weight': tensor(1.9843e-07, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.5.weight': tensor(1.1100e-07, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.7.weight': tensor(1.3773e-07, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.10.weight': tensor(7.9274e-08, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.12.weight': tensor(8.1656e-08, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.14.weight': tensor(7.2513e-08, device='cuda:0', grad_fn=<PowBackward0>), 'classifier.0.weight': tensor(1.1701e-07, device='cuda:0', grad_fn=<PowBackward0>), 'classifier.3.weight': tensor(9.6662e-08, device='cuda:0', grad_fn=<PowBackward0>), 'classifier.6.weight': tensor(1.1362e-07, device='cuda:0', grad_fn=<PowBackward0>)}\n",
      "epoch: 12, tr loss: 0.002109021629888209, lr: 0.0001, e_decay: 2.0769430193467997e-06\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      " 13%|███████████▌                                                                             | 13/100 [03:53<26:05, 18.00s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 12, va acc: 0.8069999814033508\n",
      "epoch: 13, info: {'extract_feature.0.weight': tensor(2.4711e-09, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.2.weight': tensor(3.9272e-08, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.5.weight': tensor(2.5752e-08, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.7.weight': tensor(5.7371e-08, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.10.weight': tensor(3.2324e-08, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.12.weight': tensor(5.5729e-08, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.14.weight': tensor(3.0857e-08, device='cuda:0', grad_fn=<PowBackward0>), 'classifier.0.weight': tensor(7.6772e-08, device='cuda:0', grad_fn=<PowBackward0>), 'classifier.3.weight': tensor(6.1762e-08, device='cuda:0', grad_fn=<PowBackward0>), 'classifier.6.weight': tensor(5.9133e-08, device='cuda:0', grad_fn=<PowBackward0>)}\n",
      "epoch: 13, tr loss: 0.0019043863307165315, lr: 0.0001, e_decay: 4.414425802679034e-07\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      " 14%|████████████▍                                                                            | 14/100 [04:12<25:54, 18.08s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 13, va acc: 0.8057999610900879\n",
      "epoch: 14, info: {'extract_feature.0.weight': tensor(3.9987e-07, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.2.weight': tensor(1.6150e-08, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.5.weight': tensor(3.1180e-08, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.7.weight': tensor(3.1791e-08, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.10.weight': tensor(4.8599e-08, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.12.weight': tensor(3.9306e-08, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.14.weight': tensor(6.4064e-08, device='cuda:0', grad_fn=<PowBackward0>), 'classifier.0.weight': tensor(5.0017e-08, device='cuda:0', grad_fn=<PowBackward0>), 'classifier.3.weight': tensor(5.0270e-08, device='cuda:0', grad_fn=<PowBackward0>), 'classifier.6.weight': tensor(4.6382e-08, device='cuda:0', grad_fn=<PowBackward0>)}\n",
      "epoch: 14, tr loss: 0.0017193997835093222, lr: 0.0001, e_decay: 7.776333745823649e-07\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      " 15%|█████████████▎                                                                           | 15/100 [04:30<25:35, 18.07s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 14, va acc: 0.8051999807357788\n",
      "epoch: 15, info: {'extract_feature.0.weight': tensor(1.6645e-07, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.2.weight': tensor(4.0766e-08, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.5.weight': tensor(3.3827e-08, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.7.weight': tensor(5.2978e-08, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.10.weight': tensor(1.5316e-08, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.12.weight': tensor(3.8095e-08, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.14.weight': tensor(6.6978e-08, device='cuda:0', grad_fn=<PowBackward0>), 'classifier.0.weight': tensor(6.1674e-08, device='cuda:0', grad_fn=<PowBackward0>), 'classifier.3.weight': tensor(6.4675e-08, device='cuda:0', grad_fn=<PowBackward0>), 'classifier.6.weight': tensor(4.7283e-08, device='cuda:0', grad_fn=<PowBackward0>)}\n",
      "epoch: 15, tr loss: 0.0015213592719182396, lr: 0.0001, e_decay: 5.880411890757387e-07\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      " 16%|██████████████▏                                                                          | 16/100 [04:48<25:15, 18.05s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 15, va acc: 0.8069999814033508\n",
      "epoch: 16, info: {'extract_feature.0.weight': tensor(1.0852e-08, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.2.weight': tensor(6.3073e-09, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.5.weight': tensor(1.5155e-08, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.7.weight': tensor(2.3850e-08, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.10.weight': tensor(1.8594e-08, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.12.weight': tensor(2.4377e-08, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.14.weight': tensor(2.1956e-08, device='cuda:0', grad_fn=<PowBackward0>), 'classifier.0.weight': tensor(3.1862e-08, device='cuda:0', grad_fn=<PowBackward0>), 'classifier.3.weight': tensor(2.6284e-08, device='cuda:0', grad_fn=<PowBackward0>), 'classifier.6.weight': tensor(3.2692e-08, device='cuda:0', grad_fn=<PowBackward0>)}\n",
      "epoch: 16, tr loss: 0.0014166381018200862, lr: 0.0001, e_decay: 2.1193011434661457e-07\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      " 17%|███████████████▏                                                                         | 17/100 [05:06<24:55, 18.02s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 16, va acc: 0.8057999610900879\n",
      "epoch: 17, info: {'extract_feature.0.weight': tensor(1.1708e-07, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.2.weight': tensor(3.3227e-08, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.5.weight': tensor(3.2608e-08, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.7.weight': tensor(2.8221e-08, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.10.weight': tensor(2.4811e-08, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.12.weight': tensor(2.0055e-08, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.14.weight': tensor(2.9486e-08, device='cuda:0', grad_fn=<PowBackward0>), 'classifier.0.weight': tensor(4.0754e-08, device='cuda:0', grad_fn=<PowBackward0>), 'classifier.3.weight': tensor(2.6790e-08, device='cuda:0', grad_fn=<PowBackward0>), 'classifier.6.weight': tensor(3.0805e-08, device='cuda:0', grad_fn=<PowBackward0>)}\n",
      "epoch: 17, tr loss: 0.0012668993128090184, lr: 0.0001, e_decay: 3.8383828382393403e-07\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      " 18%|████████████████                                                                         | 18/100 [05:24<24:37, 18.02s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 17, va acc: 0.8057999610900879\n",
      "epoch: 18, info: {'extract_feature.0.weight': tensor(1.8303e-07, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.2.weight': tensor(7.4612e-09, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.5.weight': tensor(1.5474e-08, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.7.weight': tensor(1.7633e-08, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.10.weight': tensor(2.4692e-08, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.12.weight': tensor(2.2989e-08, device='cuda:0', grad_fn=<PowBackward0>), 'extract_feature.14.weight': tensor(2.0370e-08, device='cuda:0', grad_fn=<PowBackward0>), 'classifier.0.weight': tensor(5.2487e-08, device='cuda:0', grad_fn=<PowBackward0>), 'classifier.3.weight': tensor(2.9942e-08, device='cuda:0', grad_fn=<PowBackward0>), 'classifier.6.weight': tensor(2.6500e-08, device='cuda:0', grad_fn=<PowBackward0>)}\n",
      "epoch: 18, tr loss: 0.0011806567544521713, lr: 0.0001, e_decay: 4.005814844276756e-07\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      " 18%|████████████████                                                                         | 18/100 [05:42<25:58, 19.00s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 18, va acc: 0.8068000078201294\n",
      "early stop on epoch 18, val acc 0.8087999820709229\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# gather the PIB hyper-parameters from the config cell in one place\n",
    "pib_kwargs = dict(\n",
    "    num_epoch=opt['num_epoch'],\n",
    "    batch_size=opt['batch_size'],\n",
    "    lr=opt['lr'],\n",
    "    weight_decay=opt['weight_decay'],\n",
    "    beta=opt['beta'],\n",
    "    early_stop_ckpt_path=__save_ckpt__,\n",
    "    early_stop_tolerance=opt['early_stop'],\n",
    "    noise_scale=opt['noise_scale'],\n",
    "    schedule=opt['schedule'],\n",
    ")\n",
    "\n",
    "# train on the full training split, using the validation split for early stopping\n",
    "info_dict, loss_acc_dict = train_pib(\n",
    "    model, all_tr_idx, x_tr, y_tr, x_va, y_va, **pib_kwargs\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a5c07ff4",
   "metadata": {},
   "source": [
    "## Bootstrap confidence interval on the test set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "5b249a0a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test acc: tensor(0.8002, device='cuda:0')\n",
      "test acc: tensor(0.8009, device='cuda:0')\n",
      "test acc: tensor(0.8128, device='cuda:0')\n",
      "test acc: tensor(0.8011, device='cuda:0')\n",
      "test acc: tensor(0.8118, device='cuda:0')\n",
      "test acc: tensor(0.8077, device='cuda:0')\n",
      "test acc: tensor(0.8032, device='cuda:0')\n",
      "test acc: tensor(0.7972, device='cuda:0')\n",
      "test acc: tensor(0.8063, device='cuda:0')\n",
      "test acc: tensor(0.8019, device='cuda:0')\n",
      "95.0 confidence interval 79.79% and 81.26%\n",
      "average: tensor(0.8052, device='cuda:0')\n",
      "interval: tensor(0.0073, device='cuda:0')\n"
     ]
    }
   ],
   "source": [
    "# Bootstrap the test accuracy: repeatedly resample the test set with\n",
    "# replacement, score the trained model on each resample, and report a\n",
    "# percentile confidence interval over the resulting accuracies.\n",
    "n_bootstrap = 10  # number of resampled test sets\n",
    "n_test = len(x_te)\n",
    "stats = []\n",
    "for _ in range(n_bootstrap):\n",
    "    # draw n_test indices with replacement (an int population is treated as np.arange(n_test))\n",
    "    sub_idx = np.random.choice(n_test, size=n_test, replace=True)\n",
    "    x_te_sub, y_te_sub = x_te[sub_idx], y_te[sub_idx]\n",
    "    pred_te = predict(model, x_te_sub)\n",
    "    # eval_metric yields a 0-d CUDA tensor here; convert it to a Python float so\n",
    "    # np.percentile receives plain numbers instead of device tensors\n",
    "    acc_te = float(eval_metric(pred_te, y_te_sub, num_class))\n",
    "    stats.append(acc_te)\n",
    "    print(\"test acc:\", acc_te)\n",
    "\n",
    "# compute the 95% percentile confidence interval, clipped to [0, 1]\n",
    "alpha = 0.95\n",
    "lower_pct = ((1 - alpha) / 2) * 100\n",
    "upper_pct = (alpha + (1.0 - alpha) / 2.0) * 100\n",
    "lower = max(0.0, float(np.percentile(stats, lower_pct)))\n",
    "upper = min(1.0, float(np.percentile(stats, upper_pct)))\n",
    "print('%.1f confidence interval %.2f%% and %.2f%%' % (alpha*100, lower*100, upper*100))\n",
    "# note: 'average' is the midpoint of the interval and 'interval' its half-width\n",
    "print('average:', (upper+lower)/2)\n",
    "print('interval:', (upper-lower)/2)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
