{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "from tqdm.notebook import tqdm\n",
    "from utils import *\n",
    "import seaborn as sns\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import torchbnn as bnn\n",
    "import torch.nn as nn"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Epistemic Problem\n",
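    "Train four Bayesian neural networks (BNNs), one per training-set size N = 100, 200, 400, 800. For each, compute the epistemic uncertainty, i.e. the variance of its predictions over 10000 posterior samples (stochastic forward passes)."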
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"SEED:\", torch.seed())\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "# Experiment settings\n",
    "B = 32\n",
    "T = 100\n",
    "embed_dim = 5\n",
    "num_epochs = 500\n",
    "sigma = 0.04\n",
    "Ns = [100, 200, 400, 800]\n",
    "\n",
    "kl_loss = bnn.BKLLoss(reduction='mean', last_layer_only=False)\n",
    "kl_weight = 0.01\n",
    "\n",
    "# Layer widths of the BNN, from input to output (scalar in, scalar out).\n",
    "layer_widths = [1, 16, 32, 64, 32, 16, 1]\n",
    "\n",
    "\n",
    "def build_bnn():\n",
    "    \"\"\"Build a fresh MLP of BayesLinear layers (ReLU in between, none after the last) on `device`.\"\"\"\n",
    "    layers = []\n",
    "    for fan_in, fan_out in zip(layer_widths[:-1], layer_widths[1:]):\n",
    "        if layers:\n",
    "            layers.append(nn.ReLU())\n",
    "        layers.append(bnn.BayesLinear(prior_mu=0, prior_sigma=0.1, in_features=fan_in, out_features=fan_out))\n",
    "    return nn.Sequential(*layers).to(device)\n",
    "\n",
    "\n",
    "# Train one BNN per dataset size; trained models are collected in `bnns`, in the same order as `Ns`.\n",
    "bnns = []\n",
    "for N in Ns:\n",
    "    dset = ToyDataset(N, sigma)\n",
    "    loader = torch.utils.data.DataLoader(dset, batch_size=B, shuffle=True)\n",
    "\n",
    "    bnn_model = build_bnn()\n",
    "    bnn_model.train()\n",
    "    bnn_optim = torch.optim.Adam(bnn_model.parameters(), lr=1e-4)\n",
    "\n",
    "    for _epoch in tqdm(range(num_epochs)):\n",
    "        # `x` and `y` stay notebook-level names on purpose: the last `y` batch is reused\n",
    "        # by the evaluation cell further down.\n",
    "        for x, y in loader:\n",
    "            x, y = x.to(device), y.to(device)\n",
    "\n",
    "            # The network is conditioned on y and regresses x.\n",
    "            x_pred = bnn_model(y.unsqueeze(-1))\n",
    "            loss = (\n",
    "                torch.nn.functional.mse_loss(x_pred, x.unsqueeze(-1))\n",
    "                + kl_weight * kl_loss(bnn_model)\n",
    "            )\n",
    "\n",
    "            bnn_optim.zero_grad()\n",
    "            loss.backward()\n",
    "            bnn_optim.step()\n",
    "\n",
    "    bnns.append(bnn_model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_bnn_results(model, y, N=10000):\n",
    "    \"\"\"Run `N` forward passes of `model` on inputs `y` and return every prediction.\n",
    "\n",
    "    Args:\n",
    "        model: network whose forward pass is expected to be stochastic (in this notebook, a stack\n",
    "            of ``bnn.BayesLinear`` layers), so that each pass acts as one posterior sample.\n",
    "        y: conditioning inputs; a trailing size-1 feature axis is appended before the forward pass.\n",
    "        N: number of forward passes (samples) to draw.\n",
    "\n",
    "    Returns:\n",
    "        1-D tensor: the ``(N, *output_shape)`` stack of predictions, flattened with the sample\n",
    "        index as the leading (slowest-varying) axis, so ``result.view(N, -1)`` recovers it.\n",
    "\n",
    "    Side effect: puts `model` in eval mode.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    # Send the inputs to wherever the model's weights live (falling back to y's own device for a\n",
    "    # parameter-free model) instead of relying on a notebook-level `device` global.\n",
    "    first_param = next(model.parameters(), None)\n",
    "    cond = y.to(first_param.device if first_param is not None else y.device).unsqueeze(-1)\n",
    "    # no_grad: otherwise each of the N passes keeps its autograd graph alive in the list, so\n",
    "    # memory grows with N (10000 by default) for outputs that are never backpropagated.\n",
    "    with torch.no_grad():\n",
    "        samples = [model(cond) for _ in range(N)]\n",
    "    return torch.stack(samples).flatten()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n = 10000\n",
    "\n",
    "# NOTE(review): `y` is not defined in this cell -- it is the last mini-batch left over from the\n",
    "# training loop (the final dataset in `Ns`), so the evaluation inputs depend on loader shuffling\n",
    "# and on running the cells in order. TODO: build an explicit, fixed evaluation input instead.\n",
    "y_eval = y\n",
    "\n",
    "# Epistemic uncertainty per model: variance over the n posterior samples at each input, averaged\n",
    "# over inputs. Pooling samples from *different* inputs (a plain `preds.var()`) would also add the\n",
    "# spread of the mean prediction across inputs, which is not epistemic uncertainty.\n",
    "epistemic_vars = []\n",
    "with torch.no_grad():\n",
    "    for model in bnns:\n",
    "        preds = get_bnn_results(model, y_eval, n)\n",
    "        # get_bnn_results returns the flattened (n, ...) stack; restore the sample axis first.\n",
    "        per_input_var = preds.view(n, -1).var(dim=0)\n",
    "        epistemic_vars.append(per_input_var.mean())\n",
    "epistemic_vars = torch.stack(epistemic_vars)\n",
    "epistemic_vars"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
