{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "from tqdm.notebook import tqdm\n",
    "import seaborn as sns\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from utils import *"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Epistemic Problem\n",
    "Train a 4 ensembles of diffusion models. Compute the aleatoric uncertainty $\\mathbb{E}(\\sigma^2(x))$ and epistemic uncertainty $Var(\\mathbb{E}(x))$ over 1000 samples."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"SEED:\", torch.random.initial_seed())\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "# Experiment settings\n",
    "sigma = 0.1\n",
    "M = 10 # ensemble size\n",
    "B = 32\n",
    "T = 100\n",
    "embed_dim = 5\n",
    "num_epochs = 500\n",
    "Ns = [100, 200, 400, 800]\n",
    "\n",
    "ensembles = []\n",
    "for N in Ns:\n",
    "    dset = ToyDataset(N, sigma)\n",
    "    loader = torch.utils.data.DataLoader(dset, batch_size=B, shuffle=True)\n",
    "\n",
    "    ensemble = []\n",
    "    for i in range(M):\n",
    "        model = DDPM(T, device)\n",
    "        model.train()\n",
    "        optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)\n",
    "\n",
    "        losses = []\n",
    "        pbar = tqdm(range(num_epochs))\n",
    "        for epoch in pbar:\n",
    "            for x, y in loader:\n",
    "                x = x.to(device)\n",
    "                y = y.to(device)\n",
    "\n",
    "                t = torch.randint(0, T, (len(x),)).to(device)\n",
    "                loss = model.p_loss(x, t, y)\n",
    "\n",
    "                optimizer.zero_grad()\n",
    "                loss.backward()\n",
    "                optimizer.step()\n",
    "                pbar.set_description(f\"Loss: {loss.item():.3f}\")\n",
    "            losses.append(loss.item())\n",
    "        ensemble.append(model)\n",
    "    ensembles.append(ensemble)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def get_results(ensemble, y, M=10, N=10000):\n",
    "    cond = y.repeat(N).to(device)\n",
    "    means = []\n",
    "    vars = []\n",
    "    for model in ensemble:\n",
    "        full_results = model.p_sample_loop(N, cond)\n",
    "        means.append(full_results.mean())\n",
    "        vars.append(full_results.var())\n",
    "    means = torch.stack(means)\n",
    "    vars = torch.stack(vars)\n",
    "    return means, vars\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "N = 10000\n",
    "y = torch.tensor([0])\n",
    "\n",
    "means = []\n",
    "vars = []\n",
    "for i, ensemble in enumerate(ensembles):\n",
    "    mean, var = get_results(ensemble, y, M, N)\n",
    "    means.append(mean - mean.mean())\n",
    "    vars.append(var)\n",
    "means = torch.stack(means)\n",
    "vars = torch.stack(vars)\n",
    "df = pd.DataFrame({\"mean\": means.flatten().detach().cpu().numpy(), \"var\": vars.flatten().detach().cpu().numpy(), \"size\": [f\"$|\\mathcal{{D}}|$ = {N}\" for N in np.repeat(Ns, M)]})\n",
    "\n",
    "sns.set(font=\"serif\", style=\"ticks\")\n",
    "print(means.var(dim=1).detach().cpu().numpy())\n",
    "ax = sns.kdeplot(data=df, x=\"mean\", hue=\"size\", fill=True, palette=\"colorblind\")\n",
    "ax.set(xlabel=\"Mean\", ylabel=\"Density\", title=\"Sample Mean for Weight Distribution $\\\\tilde{\\\\theta}$\", xlim=(-0.5, 0.5))\n",
    "sns.move_legend(ax, \"upper right\", title=\"Legend\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Aleatoric Problem\n",
    "Train a 4 ensembles of diffusion models. Compute the aleatoric uncertainty $\\mathbb{E}(\\sigma^2(x))$ and epistemic uncertainty $Var(\\mathbb{E}(x))$ over 1000 samples."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"SEED:\", torch.random.initial_seed())\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "# Experiment settings\n",
    "N = 500 # dataset size\n",
    "M = 10 # ensemble size\n",
    "B = 32\n",
    "T = 100\n",
    "embed_dim = 5\n",
    "num_epochs = 500\n",
    "sigmas = [0.1, 0.2, 0.4, 0.8]\n",
    "\n",
    "ensembles = []\n",
    "for sigma in sigmas:\n",
    "    dset = ToyDataset(N, sigma)\n",
    "    loader = torch.utils.data.DataLoader(dset, batch_size=B, shuffle=True)\n",
    "\n",
    "    ensemble = []\n",
    "    for i in range(M):\n",
    "        model = DDPM(T, device)\n",
    "        model.train()\n",
    "        optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)\n",
    "\n",
    "        losses = []\n",
    "        pbar = tqdm(range(num_epochs))\n",
    "        for epoch in pbar:\n",
    "            for x, y in loader:\n",
    "                x = x.to(device)\n",
    "                y = y.to(device)\n",
    "\n",
    "                t = torch.randint(0, T, (len(x),)).to(device)\n",
    "                loss = model.p_loss(x, t, y)\n",
    "\n",
    "                optimizer.zero_grad()\n",
    "                loss.backward()\n",
    "                optimizer.step()\n",
    "                pbar.set_description(f\"Loss: {loss.item():.3f}\")\n",
    "            losses.append(loss.item())\n",
    "        ensemble.append(model)\n",
    "    ensembles.append(ensemble)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "N = 10000\n",
    "y = torch.tensor([4])\n",
    "\n",
    "means = []\n",
    "vars = []\n",
    "for i, ensemble in enumerate(ensembles):\n",
    "    mean, var = get_results(ensemble, y, M, N)\n",
    "    means.append(mean - mean.mean())\n",
    "    vars.append(var)\n",
    "means = torch.stack(means)\n",
    "vars = torch.stack(vars)\n",
    "df = pd.DataFrame({\"mean\": means.flatten().detach().cpu().numpy(), \"var\": vars.flatten().detach().cpu().numpy(), \"sigma\": [f\"$\\sigma_\\eta^2$ = {sigma:.2f}\" for sigma in np.repeat(sigmas, M)**2]})\n",
    "\n",
    "sns.set(font=\"serif\", style=\"ticks\")\n",
    "print(vars.mean(dim=1).detach().cpu().numpy())\n",
    "ax = sns.kdeplot(data=df, x=\"var\", hue=\"sigma\", fill=True, palette=reversed(sns.color_palette('colorblind', len(sigmas))))\n",
    "ax.set(xlabel=\"Variance\", ylabel=\"Density\", title=\"Sample Variance for Weight Distribution $\\\\tilde{\\\\theta}$\", xlim=(-0.05,0.8))\n",
    "sns.move_legend(ax, \"upper right\", title=\"Legend\")\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
