{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "from tqdm.notebook import tqdm\n",
    "import torch.nn.functional as F\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import seaborn as sns\n",
    "from utils import *"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Epistemic Problem\n",
    "Train 4 diffusion models with MC Dropout, one for each $N \\in \\{100, 200, 400, 800\\}$. Compute the aleatoric uncertainty $\\mathbb{E}(\\sigma^2(x))$ and epistemic uncertainty $Var(\\mathbb{E}(x))$ over 1000 samples."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-seed torch with the reported seed so that re-running this cell in the\n",
    "# same kernel repeats the run the printed value refers to.\n",
    "# NOTE(review): only torch is seeded here; confirm ToyDataset uses no other RNG.\n",
    "init_seed = torch.random.initial_seed()\n",
    "torch.manual_seed(init_seed)\n",
    "print(\"SEED:\", init_seed)\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "p = 0.1  # second argument of DropoutDDPM\n",
    "\n",
    "# Experiment settings\n",
    "Ns = [100, 200, 400, 800]  # one model is trained per entry (first argument of ToyDataset)\n",
    "sigma = 0.1  # second argument of ToyDataset, shared by all runs\n",
    "B = 32  # DataLoader batch size\n",
    "T = 100  # first argument of DropoutDDPM; timesteps are drawn from [0, T)\n",
    "num_epochs = 1000\n",
    "\n",
    "models = []\n",
    "for N in Ns:\n",
    "    dset = ToyDataset(N, sigma)\n",
    "    loader = torch.utils.data.DataLoader(dset, batch_size=B, shuffle=True)\n",
    "\n",
    "    model = DropoutDDPM(T, p, device).to(device)\n",
    "    model.train()\n",
    "    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)\n",
    "\n",
    "    pbar = tqdm(range(num_epochs))\n",
    "    for epoch in pbar:\n",
    "        for x, y in loader:\n",
    "            x = x.to(device)\n",
    "            y = y.to(device)\n",
    "\n",
    "            # One random timestep in [0, T) per element of the batch.\n",
    "            t = torch.randint(0, T, (len(x),)).to(device)\n",
    "            loss = model.p_loss(x, t, y)\n",
    "\n",
    "            optimizer.zero_grad()\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            pbar.set_description(f\"Loss: {loss.item():.3f}\")\n",
    "    # The model is stored without leaving train mode.\n",
    "    models.append(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_results(model, y, M=10, N=10000):\n",
    "    \"\"\"Summarise N conditional samples for each of M sampling seeds.\n",
    "\n",
    "    Args:\n",
    "        model: object exposing ``p_sample_loop(N, cond, seed)``.\n",
    "        y: conditioning tensor; it is tiled with ``y.repeat(N)`` and moved to\n",
    "            the notebook-level ``device``.\n",
    "        M: number of seeds (one ``p_sample_loop`` call per seed).\n",
    "        N: number of samples requested per call.\n",
    "\n",
    "    Returns:\n",
    "        ``(means, variances)``: two stacked tensors with one entry per seed,\n",
    "        holding the mean and the variance of that seed's samples.\n",
    "    \"\"\"\n",
    "    cond = y.repeat(N).to(device)\n",
    "    # Draw the seeds without replacement: randint can hand out the same seed\n",
    "    # twice, which would repeat a run instead of adding a new one.\n",
    "    # NOTE(review): assumes the seed selects the dropout mask -- confirm in utils.\n",
    "    seeds = torch.randperm(max(M, 1000), device=device)[:M].tolist()\n",
    "    means, variances = [], []\n",
    "    # Only summary statistics are kept, so no autograd graph is needed.\n",
    "    with torch.no_grad():\n",
    "        for seed in seeds:\n",
    "            samples = model.p_sample_loop(N, cond, seed)\n",
    "            means.append(samples.mean())\n",
    "            variances.append(samples.var())\n",
    "    return torch.stack(means), torch.stack(variances)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Aleatoric Problem\n",
    "Train 4 diffusion models with dropout, one for each of 4 noise levels $\\sigma \\in \\{0.1, 0.2, 0.4, 0.8\\}$. Compute the aleatoric uncertainty $\\mathbb{E}(\\sigma^2(x))$ and epistemic uncertainty $Var(\\mathbb{E}(x))$ over 100 weights and 10000 samples per weight."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-seed torch with the reported seed so this experiment does not depend on\n",
    "# how much randomness earlier cells have already consumed.\n",
    "# NOTE(review): only torch is seeded here; confirm ToyDataset uses no other RNG.\n",
    "init_seed = torch.random.initial_seed()\n",
    "torch.manual_seed(init_seed)\n",
    "print(\"SEED:\", init_seed)\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "p = 0.1  # second argument of DropoutDDPM\n",
    "\n",
    "# Experiment settings\n",
    "N = 500  # first argument of ToyDataset, shared by all runs\n",
    "B = 32  # DataLoader batch size\n",
    "T = 100  # first argument of DropoutDDPM; timesteps are drawn from [0, T)\n",
    "num_epochs = 1000\n",
    "sigmas = [0.1, 0.2, 0.4, 0.8]  # one model is trained per noise level\n",
    "\n",
    "# NOTE: this rebinds `models`; anything stored under that name before is dropped.\n",
    "models = []\n",
    "for sigma in sigmas:\n",
    "    dset = ToyDataset(N, sigma)\n",
    "    loader = torch.utils.data.DataLoader(dset, batch_size=B, shuffle=True)\n",
    "\n",
    "    model = DropoutDDPM(T, p, device).to(device)\n",
    "    model.train()\n",
    "    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)\n",
    "\n",
    "    pbar = tqdm(range(num_epochs))\n",
    "    for epoch in pbar:\n",
    "        for x, y in loader:\n",
    "            x = x.to(device)\n",
    "            y = y.to(device)\n",
    "\n",
    "            # One random timestep in [0, T) per element of the batch.\n",
    "            t = torch.randint(0, T, (len(x),)).to(device)\n",
    "            loss = model.p_loss(x, t, y)\n",
    "\n",
    "            optimizer.zero_grad()\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            pbar.set_description(f\"Loss: {loss.item():.3f}\")\n",
    "    # The model is stored without leaving train mode.\n",
    "    models.append(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "y = torch.tensor([4])  # conditioning value handed to get_results\n",
    "M = 100  # forwarded to get_results as M (number of seeds)\n",
    "N = 10000  # forwarded to get_results as N (samples per seed)\n",
    "\n",
    "# Named `variances` rather than `vars` so the builtin vars() stays usable in the kernel.\n",
    "means = []\n",
    "variances = []\n",
    "preds = []\n",
    "for model in models:\n",
    "    mean, var = get_results(model, y, M, N)\n",
    "    means.append(mean - mean.mean())  # per-seed means centred on their average\n",
    "    variances.append(var)\n",
    "    preds.append(mean.mean())\n",
    "means = torch.stack(means)\n",
    "variances = torch.stack(variances)\n",
    "preds = torch.stack(preds)\n",
    "\n",
    "# One row per (model, seed). The labels show the squared entries of `sigmas`, each\n",
    "# repeated M times to line up with the row-major flatten of the stacked tensors.\n",
    "# Backslashes are written doubled so the label text does not rely on invalid\n",
    "# string escapes.\n",
    "df = pd.DataFrame({\n",
    "    \"mean\": means.flatten().detach().cpu().numpy(),\n",
    "    \"var\": variances.flatten().detach().cpu().numpy(),\n",
    "    \"sigma\": [f\"$\\\\sigma_\\\\eta^2$ = {s:.2f}\" for s in np.repeat(sigmas, M) ** 2],\n",
    "})\n",
    "# Average sample variance per model.\n",
    "print(variances.mean(dim=1).detach().cpu().numpy())\n",
    "sns.set(font=\"serif\", style=\"ticks\")\n",
    "ax = sns.kdeplot(data=df, x=\"var\", hue=\"sigma\", fill=True, palette=\"colorblind\")\n",
    "ax.set(xlabel=\"Variance\", ylabel=\"Density\", title=\"Sample Variance for Weight Distribution $\\\\tilde{\\\\theta}$\")\n",
    "sns.move_legend(\n",
    "    ax,\n",
    "    \"upper right\",\n",
    "    title=\"Legend\",\n",
    ")\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
