{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "from tqdm.notebook import tqdm\n",
    "from utils import ToyDataset, HyperDiffusion\n",
    "import seaborn as sns\n",
    "import pandas as pd\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Epistemic Problem\n",
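    "Train 4 hyper-diffusion models on 4 different dataset sizes $|\\mathcal{D}| \\in \\{400, 600, 800, 1000\\}$. For each model, compute the aleatoric uncertainty $\\mathbb{E}_{\\tilde{\\theta}}[\\sigma^2(x)]$ (mean over weights of the per-weight sample variance) and the epistemic uncertainty $\\mathrm{Var}_{\\tilde{\\theta}}(\\mathbb{E}[x])$ (variance over weights of the per-weight sample mean), using 10 weights and 10000 samples per weight."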
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# torch.seed() draws and sets a fresh non-deterministic seed; the printed value is what reproduces this run.\n",
    "print(\"SEED:\", torch.seed())\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "# Experiment settings\n",
    "B = 32  # batch size\n",
    "T = 100  # number of diffusion timesteps; t is drawn uniformly from [0, T)\n",
    "embed_dim = 5  # size of the random weight embedding passed to the model\n",
    "num_epochs = 500\n",
    "sigma = 0.05  # noise level passed to ToyDataset, held fixed across runs\n",
    "Ns = [400, 600, 800, 1000]  # dataset sizes |D| being compared\n",
    "\n",
    "models = []\n",
    "for N in Ns:\n",
    "    dset = ToyDataset(N, sigma)\n",
    "    loader = torch.utils.data.DataLoader(dset, batch_size=B, shuffle=True)\n",
    "\n",
    "    model = HyperDiffusion(T, embed_dim, device)\n",
    "    model.train()\n",
    "    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)\n",
    "\n",
    "    # Label each bar with its dataset size so the four runs can be told apart.\n",
    "    pbar = tqdm(range(num_epochs), desc=f\"|D|={N}\")\n",
    "    for epoch in pbar:\n",
    "        for x, y in loader:\n",
    "            x = x.to(device)\n",
    "            y = y.to(device)\n",
    "\n",
    "            t = torch.randint(0, T, (len(x),)).to(device)\n",
    "            # One fresh embedding per batch, shared by every sample in that batch.\n",
    "            embedding = torch.randn(embed_dim).to(device)\n",
    "            loss = model.p_loss(x, t, y, embedding)\n",
    "\n",
    "            optimizer.zero_grad()\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "        # Show the last batch loss once per epoch: loss.item() forces a\n",
    "        # device-to-host sync, so calling it every batch stalls GPU training.\n",
    "        pbar.set_postfix(loss=f\"{loss.item():.3f}\")\n",
    "    models.append(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_results(model, y, M=10, N=10000):\n",
    "    \"\"\"Draw M weight embeddings and summarise N samples generated under each.\n",
    "\n",
    "    Args:\n",
    "        model: trained HyperDiffusion model exposing `embed_dim` and `p_sample_loop`.\n",
    "        y: conditioning tensor, tiled with `y.repeat(N)` (assumes a single element,\n",
    "            so the condition has length N -- confirm for other inputs).\n",
    "        M: number of weight embeddings, each drawn from a standard normal.\n",
    "        N: number of samples generated per embedding.\n",
    "\n",
    "    Returns:\n",
    "        (means, vars): tensors of shape (M,) holding, per embedding, the mean and\n",
    "        the (unbiased) variance over all generated samples.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    cond = y.repeat(N).to(device)  # uses the notebook-level `device`\n",
    "    weights = torch.randn(M, model.embed_dim).to(device)\n",
    "    sample_means, sample_vars = [], []\n",
    "    # Sampling needs no autograd graph; without no_grad every intermediate of the\n",
    "    # sampling loop may be retained for backprop (harmless if p_sample_loop already disables grad).\n",
    "    with torch.no_grad():\n",
    "        for weight in weights:\n",
    "            samples = model.p_sample_loop(N, cond, weight)\n",
    "            sample_means.append(samples.mean())\n",
    "            sample_vars.append(samples.var())\n",
    "    return torch.stack(sample_means), torch.stack(sample_vars)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "y = torch.tensor([0])  # conditioning value passed to get_results\n",
    "M = 10  # number of weight embeddings per model\n",
    "N = 10000  # samples generated per embedding\n",
    "\n",
    "means = []\n",
    "variances = []  # not named `vars`, which would shadow the builtin\n",
    "for model in models:\n",
    "    mean, var = get_results(model, y, M, N)\n",
    "    # Centre each model's per-weight means on 0 so the KDEs overlay; their variance is unchanged.\n",
    "    means.append(mean - mean.mean())\n",
    "    variances.append(var)\n",
    "means = torch.stack(means)  # (len(Ns), M)\n",
    "variances = torch.stack(variances)  # (len(Ns), M)\n",
    "df = pd.DataFrame(\n",
    "    {\n",
    "        \"N\": [f\"|D|={n}\" for n in np.repeat(Ns, M)],\n",
    "        \"mean\": means.flatten().detach().cpu().numpy(),\n",
    "        \"var\": variances.flatten().detach().cpu().numpy(),\n",
    "    }\n",
    ")\n",
    "\n",
    "# Epistemic uncertainty per dataset size: variance over weights of the per-weight sample mean.\n",
    "print(\"Epistemic Var(E[x]) per |D|:\", means.var(dim=1).detach().cpu().numpy())\n",
    "sns.set(font=\"serif\", style=\"ticks\")\n",
    "fig, ax = plt.subplots()\n",
    "sns.kdeplot(data=df, x=\"mean\", hue=\"N\", fill=True, palette=\"colorblind\", ax=ax)\n",
    "ax.set(xlabel=\"Mean\", ylabel=\"Density\", title=\"Sample Mean for Weight Distribution $\\\\tilde{\\\\theta}$\", xlim=(-0.1, 0.1))\n",
    "sns.move_legend(ax, \"upper right\", title=\"Legend\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Aleatoric Problem\n",
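    "Train 4 hyper-diffusion models on 4 different noise levels $\\sigma \\in \\{0.1, 0.2, 0.4, 0.8\\}$ with a fixed dataset size $|\\mathcal{D}| = 500$. For each model, compute the aleatoric uncertainty $\\mathbb{E}_{\\tilde{\\theta}}[\\sigma^2(x)]$ (mean over weights of the per-weight sample variance) and the epistemic uncertainty $\\mathrm{Var}_{\\tilde{\\theta}}(\\mathbb{E}[x])$ (variance over weights of the per-weight sample mean), using 20 weights and 10000 samples per weight."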
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# torch.seed() draws and sets a fresh non-deterministic seed; the printed value is what reproduces this run.\n",
    "print(\"SEED:\", torch.seed())\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "# Experiment settings\n",
    "N = 500  # dataset size |D|, held fixed across runs\n",
    "B = 32  # batch size\n",
    "T = 100  # number of diffusion timesteps; t is drawn uniformly from [0, T)\n",
    "embed_dim = 5  # size of the random weight embedding passed to the model\n",
    "num_epochs = 500\n",
    "sigmas = [0.1, 0.2, 0.4, 0.8]  # noise levels passed to ToyDataset\n",
    "\n",
    "models = []\n",
    "for sigma in sigmas:\n",
    "    dset = ToyDataset(N, sigma)\n",
    "    loader = torch.utils.data.DataLoader(dset, batch_size=B, shuffle=True)\n",
    "\n",
    "    model = HyperDiffusion(T, embed_dim, device)\n",
    "    model.train()\n",
    "    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)\n",
    "\n",
    "    # Label each bar with its noise level so the four runs can be told apart.\n",
    "    pbar = tqdm(range(num_epochs), desc=f\"sigma={sigma}\")\n",
    "    for epoch in pbar:\n",
    "        for x, y in loader:\n",
    "            x = x.to(device)\n",
    "            y = y.to(device)\n",
    "\n",
    "            t = torch.randint(0, T, (len(x),)).to(device)\n",
    "            # One fresh embedding per batch, shared by every sample in that batch.\n",
    "            embedding = torch.randn(embed_dim).to(device)\n",
    "            loss = model.p_loss(x, t, y, embedding)\n",
    "\n",
    "            optimizer.zero_grad()\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "        # Show the last batch loss once per epoch: loss.item() forces a\n",
    "        # device-to-host sync, so calling it every batch stalls GPU training.\n",
    "        pbar.set_postfix(loss=f\"{loss.item():.3f}\")\n",
    "    models.append(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "M = 20  # number of weight embeddings per model\n",
    "N = 10000  # samples generated per embedding\n",
    "y = torch.tensor([0])  # conditioning value passed to get_results\n",
    "\n",
    "means = []\n",
    "variances = []  # not named `vars`, which would shadow the builtin\n",
    "for model in models:\n",
    "    mean, var = get_results(model, y, M, N)\n",
    "    means.append(mean - mean.mean())\n",
    "    variances.append(var)\n",
    "means = torch.stack(means)  # (len(sigmas), M)\n",
    "variances = torch.stack(variances)  # (len(sigmas), M)\n",
    "\n",
    "# Legend labels show the noise variance sigma^2. The rf-string keeps \\sigma and \\eta\n",
    "# as literal LaTeX; in a plain string they are invalid escape sequences (SyntaxWarning).\n",
    "df = pd.DataFrame(\n",
    "    {\n",
    "        \"mean\": means.flatten().detach().cpu().numpy(),\n",
    "        \"var\": variances.flatten().detach().cpu().numpy(),\n",
    "        \"sigma\": [rf\"$\\sigma_\\eta^2$ = {s2:.2f}\" for s2 in np.repeat(sigmas, M) ** 2],\n",
    "    }\n",
    ")\n",
    "sns.set(font=\"serif\", style=\"ticks\")\n",
    "# Aleatoric uncertainty per noise level: mean over weights of the per-weight sample variance.\n",
    "print(\"Aleatoric E[Var(x)] per sigma:\", variances.mean(dim=1).detach().cpu().numpy())\n",
    "fig, ax = plt.subplots()\n",
    "# Reversed colorblind palette as a list; seaborn expects a sequence, not a one-shot iterator.\n",
    "palette = sns.color_palette(\"colorblind\", len(sigmas))[::-1]\n",
    "sns.kdeplot(data=df, x=\"var\", hue=\"sigma\", fill=True, palette=palette, ax=ax)\n",
    "ax.set(xlabel=\"Variance\", ylabel=\"Density\", title=\"Sample Variance for Weight Distribution $\\\\tilde{\\\\theta}$\", xlim=(-0.05, 0.8))\n",
    "sns.move_legend(ax, \"upper right\", title=\"Legend\")\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
