{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "b1f2fb3a-c221-4578-a73f-5e1865bcdb5e",
   "metadata": {},
   "source": [
    "# Train an RBF network via gradient descent\n",
    "\n",
    "In this notebook, we show how to instantiate and train an RBF network (and an MLP network). We will test the OOD capabilities of the trained deterministic discriminative models by looking at the softmax entropy and some chosen OOD datasets."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "08fff8f7-ca4b-4783-91ac-1e03eacca769",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "curr_dir = os.path.basename(os.path.abspath(os.curdir))\n",
    "# See __init__.py in folder \"toy_example\" for an explanation.\n",
    "if curr_dir == 'tutorials' and '..' not in sys.path:\n",
    "    sys.path.insert(0, '..')\n",
    "\n",
    "from hypnettorch.data.mnist_data import MNISTData\n",
    "from hypnettorch.data.fashion_mnist import FashionMNISTData\n",
    "from hypnettorch.mnets import MLP\n",
    "from hypnettorch.utils import misc\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from sklearn.metrics import roc_auc_score\n",
    "from time import time\n",
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "from finite_width.rbf_net import StackedRBFNet\n",
    "\n",
    "from IPython.display import display, Markdown, Latex\n",
    "#display(Markdown('*some markdown* $\\phi$'))\n",
    "\n",
    "%matplotlib inline\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "device = 'cuda'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "890f7f9e-7b50-4c19-8862-4a7a897db65d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Reading MNIST dataset ...\n",
      "Elapsed time to read dataset: 0.165256 sec\n"
     ]
    }
   ],
   "source": [
    "mnist = MNISTData('.', use_one_hot=True)\n",
    "fmnist = FashionMNISTData('.', use_one_hot=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "eb9bfa5c-8d6f-4155-9c89-6bda70795048",
   "metadata": {},
   "outputs": [],
   "source": [
    "def test_net(net, data, use_test=True):\n",
    "    with torch.no_grad():\n",
    "        test_in = data.input_to_torch_tensor( \\\n",
    "            data.get_test_inputs() if use_test else data.get_val_inputs(), \\\n",
    "            device, mode='inference')\n",
    "        test_out = data.input_to_torch_tensor( \\\n",
    "            data.get_test_outputs() if use_test else data.get_val_outputs(),\n",
    "            device, mode='inference')\n",
    "        test_lbls = test_out.max(dim=1)[1]\n",
    "\n",
    "        logits = net(test_in)\n",
    "        pred_lbls = logits.max(dim=1)[1]\n",
    "\n",
    "        acc = torch.sum(test_lbls == pred_lbls) / test_lbls.numel() * 100.\n",
    "    return acc\n",
    "\n",
    "def train_net(net, data, lr=1e-3, nepochs=10):\n",
    "    criterion = nn.CrossEntropyLoss()\n",
    "    optimizer = torch.optim.Adam(net.internal_params, lr=lr)\n",
    "\n",
    "    for epoch in range(nepochs): \n",
    "\n",
    "        i = 0\n",
    "        for batch_size, x, y in data.train_iterator(32):\n",
    "            i += 1\n",
    "\n",
    "            x_t = data.input_to_torch_tensor(x, device, mode='train')\n",
    "            y_t = data.output_to_torch_tensor(y, device, mode='train')\n",
    "\n",
    "            # zero the parameter gradients\n",
    "            optimizer.zero_grad()\n",
    "\n",
    "            # forward + backward + optimize\n",
    "            p_t = net(x_t)\n",
    "            loss = criterion(p_t, y_t.max(dim=1)[1])\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "            if i % 500 == 0:            \n",
    "                print('[%d, %5d] loss: %.3f, val-acc: %.2f%%' %\n",
    "                      (epoch + 1, i + 1, loss.item(), \n",
    "                       test_net(net, data, use_test=False)))\n",
    "\n",
    "    print('Training finished with test-acc: %.2f%%' % (test_net(net, mnist)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "f2fd82dc-279f-4f80-ac95-435e67ed7b8f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Creating a \"1-layer RBF network\" with 79410 weights\n",
      "[1,   501] loss: 2.587, val-acc: 11.00%\n",
      "[1,  1001] loss: 2.187, val-acc: 28.56%\n",
      "[1,  1501] loss: 2.159, val-acc: 24.38%\n",
      "[2,   501] loss: 1.987, val-acc: 39.70%\n",
      "[2,  1001] loss: 1.888, val-acc: 47.42%\n",
      "[2,  1501] loss: 1.718, val-acc: 39.80%\n",
      "[3,   501] loss: 1.456, val-acc: 51.36%\n",
      "[3,  1001] loss: 1.399, val-acc: 53.96%\n",
      "[3,  1501] loss: 1.538, val-acc: 51.54%\n",
      "[4,   501] loss: 1.544, val-acc: 57.52%\n",
      "[4,  1001] loss: 1.398, val-acc: 59.42%\n",
      "[4,  1501] loss: 1.339, val-acc: 63.46%\n",
      "[5,   501] loss: 1.316, val-acc: 48.32%\n",
      "[5,  1001] loss: 1.395, val-acc: 58.98%\n",
      "[5,  1501] loss: 1.275, val-acc: 68.08%\n",
      "[6,   501] loss: 1.446, val-acc: 65.60%\n",
      "[6,  1001] loss: 1.090, val-acc: 58.60%\n",
      "[6,  1501] loss: 0.989, val-acc: 71.60%\n",
      "[7,   501] loss: 1.137, val-acc: 68.28%\n",
      "[7,  1001] loss: 1.019, val-acc: 72.46%\n",
      "[7,  1501] loss: 1.086, val-acc: 67.06%\n",
      "[8,   501] loss: 1.128, val-acc: 71.94%\n",
      "[8,  1001] loss: 1.096, val-acc: 72.88%\n",
      "[8,  1501] loss: 0.954, val-acc: 70.28%\n",
      "[9,   501] loss: 1.170, val-acc: 73.34%\n",
      "[9,  1001] loss: 1.105, val-acc: 69.54%\n",
      "[9,  1501] loss: 0.905, val-acc: 74.72%\n",
      "[10,   501] loss: 0.714, val-acc: 76.48%\n",
      "[10,  1001] loss: 0.917, val-acc: 71.10%\n",
      "[10,  1501] loss: 0.843, val-acc: 77.12%\n",
      "[11,   501] loss: 1.088, val-acc: 79.94%\n",
      "[11,  1001] loss: 0.871, val-acc: 78.24%\n",
      "[11,  1501] loss: 1.049, val-acc: 79.10%\n",
      "[12,   501] loss: 0.720, val-acc: 80.10%\n",
      "[12,  1001] loss: 0.769, val-acc: 81.68%\n",
      "[12,  1501] loss: 0.621, val-acc: 79.78%\n",
      "[13,   501] loss: 0.881, val-acc: 79.58%\n",
      "[13,  1001] loss: 0.725, val-acc: 79.42%\n",
      "[13,  1501] loss: 0.744, val-acc: 76.04%\n",
      "[14,   501] loss: 0.674, val-acc: 82.52%\n",
      "[14,  1001] loss: 0.663, val-acc: 78.34%\n",
      "[14,  1501] loss: 0.587, val-acc: 82.50%\n",
      "[15,   501] loss: 0.645, val-acc: 78.20%\n",
      "[15,  1001] loss: 0.685, val-acc: 81.84%\n",
      "[15,  1501] loss: 0.706, val-acc: 76.34%\n",
      "[16,   501] loss: 1.004, val-acc: 76.36%\n",
      "[16,  1001] loss: 0.736, val-acc: 82.60%\n",
      "[16,  1501] loss: 0.772, val-acc: 85.40%\n",
      "[17,   501] loss: 0.761, val-acc: 83.62%\n",
      "[17,  1001] loss: 0.770, val-acc: 79.62%\n",
      "[17,  1501] loss: 0.607, val-acc: 82.02%\n",
      "[18,   501] loss: 0.666, val-acc: 81.28%\n",
      "[18,  1001] loss: 0.510, val-acc: 83.58%\n",
      "[18,  1501] loss: 0.700, val-acc: 81.14%\n",
      "[19,   501] loss: 0.934, val-acc: 84.24%\n",
      "[19,  1001] loss: 0.565, val-acc: 81.22%\n",
      "[19,  1501] loss: 0.562, val-acc: 83.08%\n",
      "[20,   501] loss: 0.680, val-acc: 79.96%\n",
      "[20,  1001] loss: 0.688, val-acc: 84.32%\n",
      "[20,  1501] loss: 0.691, val-acc: 84.88%\n",
      "[21,   501] loss: 0.424, val-acc: 81.54%\n",
      "[21,  1001] loss: 0.687, val-acc: 82.84%\n",
      "[21,  1501] loss: 0.684, val-acc: 82.88%\n",
      "[22,   501] loss: 0.492, val-acc: 84.20%\n",
      "[22,  1001] loss: 0.469, val-acc: 84.28%\n",
      "[22,  1501] loss: 0.411, val-acc: 85.02%\n",
      "[23,   501] loss: 0.841, val-acc: 80.74%\n",
      "[23,  1001] loss: 0.718, val-acc: 83.04%\n",
      "[23,  1501] loss: 0.731, val-acc: 85.46%\n",
      "[24,   501] loss: 0.508, val-acc: 85.12%\n",
      "[24,  1001] loss: 0.409, val-acc: 85.74%\n",
      "[24,  1501] loss: 0.349, val-acc: 83.16%\n",
      "[25,   501] loss: 0.710, val-acc: 85.64%\n",
      "[25,  1001] loss: 0.879, val-acc: 86.02%\n",
      "[25,  1501] loss: 0.483, val-acc: 87.44%\n",
      "[26,   501] loss: 0.547, val-acc: 86.46%\n",
      "[26,  1001] loss: 0.502, val-acc: 86.80%\n",
      "[26,  1501] loss: 0.669, val-acc: 86.54%\n",
      "[27,   501] loss: 0.620, val-acc: 83.32%\n",
      "[27,  1001] loss: 0.388, val-acc: 83.74%\n",
      "[27,  1501] loss: 0.616, val-acc: 84.72%\n",
      "[28,   501] loss: 0.386, val-acc: 85.18%\n",
      "[28,  1001] loss: 0.582, val-acc: 86.06%\n",
      "[28,  1501] loss: 0.627, val-acc: 87.18%\n",
      "[29,   501] loss: 0.728, val-acc: 86.88%\n",
      "[29,  1001] loss: 0.544, val-acc: 86.32%\n",
      "[29,  1501] loss: 0.336, val-acc: 87.60%\n",
      "[30,   501] loss: 0.486, val-acc: 87.54%\n",
      "[30,  1001] loss: 0.489, val-acc: 88.02%\n",
      "[30,  1501] loss: 0.388, val-acc: 85.34%\n",
      "Training finished with test-acc: 86.55%\n"
     ]
    }
   ],
   "source": [
    "rbf_net = StackedRBFNet(n_in=np.prod(mnist.in_shape), n_nonlin_units=(100,), \n",
    "                        n_lin_units=(10,), use_bias=True,\n",
    "                        bandwidth=5000).to(device)\n",
    "\n",
    "train_net(rbf_net, mnist, lr=1e-2, nepochs=30)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "6e2ca3cb-0b34-4382-b06b-3990b0f2bcd5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Creating an MLP with 478410 weights.\n",
      "[1,   501] loss: 0.027, val-acc: 94.70%\n",
      "[1,  1001] loss: 0.169, val-acc: 95.18%\n",
      "[1,  1501] loss: 0.046, val-acc: 96.66%\n",
      "[2,   501] loss: 0.017, val-acc: 96.38%\n",
      "[2,  1001] loss: 0.204, val-acc: 96.76%\n",
      "[2,  1501] loss: 0.110, val-acc: 97.58%\n",
      "[3,   501] loss: 0.132, val-acc: 97.70%\n",
      "[3,  1001] loss: 0.155, val-acc: 97.64%\n",
      "[3,  1501] loss: 0.002, val-acc: 97.40%\n",
      "[4,   501] loss: 0.120, val-acc: 98.08%\n",
      "[4,  1001] loss: 0.009, val-acc: 97.66%\n",
      "[4,  1501] loss: 0.002, val-acc: 97.50%\n",
      "[5,   501] loss: 0.001, val-acc: 97.28%\n",
      "[5,  1001] loss: 0.002, val-acc: 97.74%\n",
      "[5,  1501] loss: 0.049, val-acc: 97.74%\n",
      "[6,   501] loss: 0.001, val-acc: 98.12%\n",
      "[6,  1001] loss: 0.004, val-acc: 97.16%\n",
      "[6,  1501] loss: 0.016, val-acc: 97.84%\n",
      "[7,   501] loss: 0.117, val-acc: 98.16%\n",
      "[7,  1001] loss: 0.070, val-acc: 97.64%\n",
      "[7,  1501] loss: 0.009, val-acc: 97.84%\n",
      "[8,   501] loss: 0.004, val-acc: 98.08%\n",
      "[8,  1001] loss: 0.023, val-acc: 98.06%\n",
      "[8,  1501] loss: 0.000, val-acc: 97.94%\n",
      "[9,   501] loss: 0.000, val-acc: 97.88%\n",
      "[9,  1001] loss: 0.065, val-acc: 98.12%\n",
      "[9,  1501] loss: 0.000, val-acc: 98.18%\n",
      "[10,   501] loss: 0.133, val-acc: 97.90%\n",
      "[10,  1001] loss: 0.001, val-acc: 97.84%\n",
      "[10,  1501] loss: 0.000, val-acc: 97.94%\n",
      "Training finished with test-acc: 98.00%\n"
     ]
    }
   ],
   "source": [
    "mlp_net = MLP(n_in=np.prod(mnist.in_shape), n_out=10,\n",
    "              hidden_layers=(400,400)).to(device)\n",
    "\n",
    "train_net(mlp_net, mnist, lr=1e-3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "a5978166-81ac-4175-8ca2-e44c958fe07f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MLP AUROC: 0.876\n",
      "RBF Net AUROC: 0.792\n"
     ]
    }
   ],
   "source": [
    "def calc_auroc(net, ind_data, ood_data):\n",
    "    with torch.no_grad():\n",
    "        ind_inps = ind_data.input_to_torch_tensor( \\\n",
    "            ind_data.get_test_inputs(), device, mode='inference')\n",
    "        ind_logits = net(ind_inps)\n",
    "        ind_softmax = nn.functional.softmax(ind_logits, dim=1).\\\n",
    "            cpu().detach().numpy()\n",
    "        ind_entropies = - np.sum(ind_softmax * \\\n",
    "                                 np.log(np.maximum(ind_softmax, 1e-5)), axis=1)\n",
    "        \n",
    "        ood_inps = ood_data.input_to_torch_tensor( \\\n",
    "            ood_data.get_test_inputs(), device, mode='inference')\n",
    "        ood_logits = net(ood_inps)\n",
    "        ood_softmax = nn.functional.softmax(ood_logits, dim=1).\\\n",
    "            cpu().detach().numpy()\n",
    "        ood_entropies = - np.sum(ood_softmax * \\\n",
    "                                 np.log(np.maximum(ood_softmax, 1e-5)), axis=1)\n",
    "        \n",
    "        y_true = [0]*len(ind_entropies) + [1]*len(ood_entropies)\n",
    "        y_score = ind_entropies.tolist() + ood_entropies.tolist()\n",
    "        auroc = roc_auc_score(y_true, y_score)\n",
    "        \n",
    "        return auroc\n",
    "\n",
    "print('MLP AUROC: %.3f' % (calc_auroc(mlp_net, mnist, fmnist)))\n",
    "print('RBF Net AUROC: %.3f' % (calc_auroc(rbf_net, mnist, fmnist)))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
