{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "97b0ce48-6442-4850-a85f-7e7ef95feea7",
   "metadata": {},
   "source": [
    "# Automatic Calibration Diagnosis: Interpreting Probability Integral Transform (PIT) Histograms"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e2b9cc4-4b87-4333-8c82-d2a03aab5255",
   "metadata": {},
   "outputs": [],
   "source": [
    "import functools\n",
    "import math\n",
    "import random\n",
    "\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib import ticker\n",
    "from matplotlib import gridspec\n",
    "import pandas as pd\n",
    "from scipy import stats\n",
    "from sklearn.metrics.pairwise import euclidean_distances\n",
    "import torch\n",
    "import wandb\n",
    "\n",
    "from calibration import data\n",
    "from calibration import dist\n",
    "from calibration import method\n",
    "from calibration import pit\n",
    "from calibration import plot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "17307a07-131a-4ab1-b8d4-7d0b89c97401",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Publication-style figure defaults: serif font, small text, high resolution and\n",
    "# a golden-ratio aspect (height = (PHI - 1) * width = width / PHI).\n",
    "PHI = (1 + math.sqrt(5)) / 2\n",
    "WIDTH = 5.5  # figure width in inches\n",
    "matplotlib.rcParams.update({\n",
    "    \"font.family\": \"Times New Roman\",\n",
    "    \"font.size\": 8,\n",
    "    \"axes.titlesize\": 10,\n",
    "    \"figure.dpi\": 300,\n",
    "    'figure.figsize': (WIDTH, (PHI - 1) * WIDTH),\n",
    "})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "465f3a40-2875-407a-84e5-44a2cab8fe53",
   "metadata": {},
   "outputs": [],
   "source": [
    "def param2pdf(weight, mean, variance):\n",
    "    \"\"\"Return the density function x -> p(x) of a Gaussian mixture.\n",
    "\n",
    "    `weight`, `mean` and `variance` are forwarded unchanged to\n",
    "    `dist.pdf_gaussian_mixture` each time the returned function is called.\n",
    "    \"\"\"\n",
    "    def pdf(x):\n",
    "        return dist.pdf_gaussian_mixture(x, weight, mean, variance)\n",
    "    return pdf\n",
    "\n",
    "def y2pdf(w, s, v1, v2):\n",
    "    \"\"\"Return the density of a two-component Gaussian mixture.\n",
    "\n",
    "    The components have weights `w` and `1 - w`, means `-s / 2` and `s / 2`\n",
    "    (so their centres are `s` apart) and variances `v1` and `v2`.\n",
    "    \"\"\"\n",
    "    return param2pdf(torch.tensor([w, 1 - w]),\n",
    "                     torch.tensor([-s / 2, s / 2]),\n",
    "                     torch.tensor([v1, v2]))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1253c4b4-f6c7-4ed9-a0e3-98ee74dcae03",
   "metadata": {},
   "source": [
    "## PIT histogram"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "97c428af-1dc5-448f-9653-552905bcacf1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Three miscalibrated cases: observations are drawn from N(mean, variance) and\n",
    "# compared with the predictive distribution `dist.pdf_gaussian` at its default\n",
    "# parameters (presumably the standard normal), whose PIT values are histogrammed.\n",
    "bias = (torch.tensor(1.0), torch.tensor(1.0), \"biased\")\n",
    "under = (torch.tensor(0.0), torch.tensor(2.0), \"under-dispersed\")\n",
    "over = (torch.tensor(0.0), torch.tensor(0.5), \"over-dispersed\")\n",
    "\n",
    "# Seed here so the figure does not depend on which cells consumed random numbers before.\n",
    "torch.manual_seed(0)\n",
    "fig, axes = plt.subplots(nrows=3, ncols=2,\n",
    "                         sharex=\"col\", constrained_layout=True)\n",
    "for (ax_dist, ax_hist), (mean, variance, label) in zip(axes, [bias, under, over]):\n",
    "    ax_dist.set_title(f\"{label} predictive distribution\")\n",
    "    y = mean + torch.sqrt(variance) * torch.randn(pit.SAMPLES)\n",
    "    handle_pred = plot.density(ax_dist, dist.pdf_gaussian, color=\"C2\")\n",
    "    dist_obs = functools.partial(dist.pdf_gaussian, mean=mean, variance=variance)\n",
    "    handle_obs = plot.density(ax_dist, dist_obs, color=\"C3\", linestyle=\"--\")\n",
    "    handle_pit = plot.pit_hist(ax_hist, pit.pit_hist(pit.pit_gaussian(y)))\n",
    "    ax_dist.set_ylabel(\"density\")\n",
    "    ax_hist.set_ylabel(\"density\")\n",
    "axes[2, 0].set_xlabel(\"y\", style=\"italic\")\n",
    "axes[2, 1].set_xlabel(\"PIT\")\n",
    "# The handles are identical in every row, so those from the last row serve the legend.\n",
    "axes[0, 1].legend([handle_pit, handle_pred, handle_obs],\n",
    "                  [\"PIT histogram\",\n",
    "                   \"predictive distribution\",\n",
    "                   \"observation-generating\\ndistribution\"])\n",
    "fig.savefig(\"figures/types.pdf\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e31d5ac0-7708-4854-a87a-69ab32d6ecc1",
   "metadata": {},
   "source": [
    "## Automatically interpreting PIT histograms"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "548c6534-51b8-4ebe-a6d1-b0801c16ce5b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Concept figure: sample observations (steps 1 and 2), evaluate the predictive\n",
    "# CDF at them (step 3), form the PIT histogram (steps 4 and 5), pass it to the\n",
    "# interpreter, and compare via the mean negative log-likelihood.\n",
    "fig, axes = plt.subplots(nrows=3, ncols=2,\n",
    "                         constrained_layout=True,\n",
    "                         figsize=(WIDTH, (2 / 3) * WIDTH))\n",
    "[[ax_obs, ax_pre], [ax_nll, ax_pit], [ax_pro, ax_int]] = axes\n",
    "# Seed here so the drawn sample does not depend on which cells ran before.\n",
    "torch.manual_seed(0)\n",
    "\n",
    "ax_obs.set_title(\"step 1 and 2\")\n",
    "mean = torch.tensor(1.0)\n",
    "variance = torch.tensor(1.0)\n",
    "pdf_obs = functools.partial(dist.pdf_gaussian, mean=mean, variance=variance)\n",
    "plot.density(ax_obs, pdf_obs,\n",
    "             color=\"C3\", linestyle=\"--\",\n",
    "             label=\"observation-\\ngenerating\\ndistribution\")\n",
    "y = mean + torch.sqrt(variance) * torch.randn(100)\n",
    "ax_obs.scatter(y, torch.zeros_like(y),\n",
    "               color=\"C4\", marker=\"|\",\n",
    "               label=\"sample\")\n",
    "ax_obs.legend(loc=\"upper left\")\n",
    "ax_obs.set_xlabel(\"y\", style=\"italic\")\n",
    "ax_obs.set_ylabel(\"density\")\n",
    "\n",
    "ax_pre.set_title(\"step 3\")\n",
    "cdf_pre = dist.cdf_gaussian\n",
    "plot.cumulative_density(ax_pre, cdf_pre,\n",
    "                        color=\"C2\",\n",
    "                        label=\"predictive\\ndistribution\")\n",
    "ax_pre.scatter(y, torch.zeros_like(y),\n",
    "               color=\"C4\", marker=\"|\",\n",
    "               label=\"sample\")\n",
    "ax_pre.legend(loc=\"upper left\")\n",
    "ax_pre.set_xlabel(\"y\", style=\"italic\")\n",
    "ax_pre.set_ylabel(\"cumulative density\")\n",
    "\n",
    "ax_pit.set_title(\"step 4 and 5\")\n",
    "plot.pit_hist(ax_pit, pit.pit_hist(pit.pit_gaussian(y)),\n",
    "              label=\"PIT histogram\")\n",
    "ax_pit.legend(loc=\"upper left\")\n",
    "ax_pit.set_xlabel(\"PIT\")\n",
    "ax_pit.set_ylabel(\"density\")\n",
    "\n",
    "# The two text-only panels carry arrows between the plotted panels. In each, the\n",
    "# second annotation repeats the label with alpha=0.0, presumably so that only its\n",
    "# arrow (a second arrow from the same centre point) shows.\n",
    "ax_int.set_axis_off()\n",
    "textstyle = dict(va=\"center\", ha=\"center\", fontsize=10)\n",
    "ax_int.annotate(\"interpreter\",\n",
    "                xy=(0.5, 1), xycoords=\"data\",\n",
    "                xytext=(0.5, 0.5), textcoords=\"data\",\n",
    "                arrowprops=dict(arrowstyle=\"<-\"),\n",
    "                **textstyle)\n",
    "ax_int.annotate(\"interpreter\",\n",
    "                xy=(0, 0.5), xycoords=\"data\",\n",
    "                xytext=(0.5, 0.5), textcoords=\"data\",\n",
    "                arrowprops=dict(arrowstyle=\"->\"),\n",
    "                **textstyle, alpha=0.0)\n",
    "\n",
    "plot.density(ax_pro, pdf_obs,\n",
    "             color=\"C3\", linestyle=\"--\",\n",
    "             label=\"predicted\\nobservation-\\ngenerating\\ndistribution\")\n",
    "ax_pro.legend(loc=\"upper left\")\n",
    "ax_pro.set_xlabel(\"y\", style=\"italic\")\n",
    "ax_pro.set_ylabel(\"density\")\n",
    "\n",
    "ax_nll.set_axis_off()\n",
    "ax_nll.annotate(\"mean negative log-likelihood\",\n",
    "                xy=(0.5, 1), xycoords=\"data\",\n",
    "                xytext=(0.5, 0.5), textcoords=\"data\",\n",
    "                arrowprops=dict(arrowstyle=\"->\"),\n",
    "                **textstyle)\n",
    "ax_nll.annotate(\"mean negative log-likelihood\",\n",
    "                xy=(0.5, 0), xycoords=\"data\",\n",
    "                xytext=(0.5, 0.5), textcoords=\"data\",\n",
    "                arrowprops=dict(arrowstyle=\"->\"),\n",
    "                **textstyle, alpha=0.0)\n",
    "\n",
    "fig.savefig(\"figures/concept.pdf\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2eda3670-2e75-4025-b5ef-82e56433b7c5",
   "metadata": {},
   "source": [
    "## Interpreter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f6abe712-cd0a-4182-ad61-550171064cf9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rebuild the interpreter (an MDN whose input size is the number of PIT histogram\n",
    "# bins) from its checkpoint. map_location=\"cpu\" lets a checkpoint that was saved\n",
    "# on a GPU load on a CPU-only machine; load_state_dict copies into CPU parameters anyway.\n",
    "checkpoint = torch.load(\"models/eternal-smoke-51.pt\", map_location=\"cpu\")\n",
    "hyperparams = checkpoint[\"hyperparams\"]\n",
    "interpreter = method.MDN(inputs=hyperparams[\"bins\"],\n",
    "                         neurons=hyperparams[\"neurons\"],\n",
    "                         components=hyperparams[\"components\"])\n",
    "interpreter.load_state_dict(checkpoint[\"model_state_dict\"])\n",
    "interpreter"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6359ed33-295a-4fd3-9350-d96a53c99b81",
   "metadata": {},
   "source": [
    "## Synthetic data set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "22f63ce3-9216-419b-8ba9-cafac1fdf169",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Synthetic test set of PIT histograms with known annotations, seeded so the\n",
    "# same examples are generated on every run.\n",
    "torch.manual_seed(78)\n",
    "TESTS = 1000\n",
    "testset = pit.PITDataset(TESTS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a34fb8d-fa51-4796-9740-f5ce3a2aa709",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reference set searched by the nearest-neighbour baseline; display its size.\n",
    "refset = pit.PITReference()\n",
    "len(refset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "99b54034-9093-4c89-a598-530c749f160b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Interpreter output (mixture weights, means, variances) for every test histogram.\n",
    "pred_mdn = method.predict(interpreter, testset.X)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a608e375-aaf3-417a-b397-7978d3c0710c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pick three test examples to inspect; the fixed seed keeps the choice reproducible.\n",
    "random.seed(54)\n",
    "examples = random.sample(range(len(testset)), k=3)\n",
    "examples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "07f4d5d3-d362-40a2-983e-ee276914df8a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# For each selected example, compare the true density with the interpreter's\n",
    "# prediction and with the annotation of the nearest reference histogram.\n",
    "distances = euclidean_distances(testset.X, refset.X)\n",
    "js = distances.argmin(axis=1)  # index of the nearest reference histogram per test example\n",
    "for i in examples:\n",
    "    y = testset.annotation[i]\n",
    "    weight = pred_mdn[0][i]\n",
    "    mean = pred_mdn[1][i]\n",
    "    variance = pred_mdn[2][i]\n",
    "    # Index the single neighbour directly instead of gathering all of them first.\n",
    "    neighbour = refset.annotation[js[i]]\n",
    "    fig, ax = plt.subplots()\n",
    "    plot.density(ax, y2pdf(*y), label=\"data generating\")\n",
    "    plot.density(ax, param2pdf(weight, mean, variance), label=\"MDN\")\n",
    "    plot.density(ax, y2pdf(*neighbour), label=\"nearest neighbor\")\n",
    "    ax.legend()\n",
    "    plt.show()\n",
    "    plt.close(fig)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c1ba470f-68c9-4a0e-81f8-5765dca510a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean negative log-likelihood of `testset.y` under the mixtures predicted by the interpreter.\n",
    "nll_per_example = dist.nll_gaussian_mixture(testset.y, *pred_mdn)\n",
    "nll_mdn = nll_per_example.mean()\n",
    "nll_mdn.item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "839477cc-cb24-4d28-bcc9-620ffd670490",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Nearest-neighbour baseline: for reference sets built with increasing `steps`,\n",
    "# annotate each test histogram with its nearest reference histogram's parameters\n",
    "# and record the mean NLL. Loop-local names keep `refset`, `js`, `weight`, etc.\n",
    "# from the cells above intact, so those cells still work when re-run afterwards.\n",
    "steps = range(5, 18)\n",
    "nll_neighbour = []\n",
    "for s in steps:\n",
    "    refset_s = pit.PITReference(steps=s)\n",
    "    nearest = euclidean_distances(testset.X, refset_s.X).argmin(axis=1)\n",
    "    # Annotation columns follow the (w, s, v1, v2) layout that y2pdf expects.\n",
    "    params = refset_s.annotation[nearest]\n",
    "    weight_s = torch.stack([params[:, 0], 1 - params[:, 0]], dim=1)\n",
    "    mean_s = torch.stack((-params[:, 1] / 2, params[:, 1] / 2), dim=1)\n",
    "    variance_s = params[:, 2:]\n",
    "    nll = dist.nll_gaussian_mixture(testset.y, weight_s, mean_s, variance_s)\n",
    "    nll_neighbour.append(nll.mean().item())\n",
    "    print(f\"{s:2d} {nll_neighbour[-1]:f} {s ** 4}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5dfe9fc3-e110-4b3e-8ec7-ac4190422c1f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NLL of the nearest-neighbour baseline against its reference-set size, with the\n",
    "# interpreter's NLL as a dashed horizontal line.\n",
    "fig, ax = plt.subplots(figsize=(WIDTH, (2 / 3) * (PHI - 1) * WIDTH), constrained_layout=True)\n",
    "size = [s ** 4 for s in steps]  # assumes a reference set built with steps=s holds s**4 entries\n",
    "ax.scatter(size, nll_neighbour, marker=\"+\", label=\"nearest neighbour algorithm\")\n",
    "ax.axhline(round(nll_mdn.item(), 3), ls=\"--\", label=\"our interpreter\")\n",
    "ax.set_xlabel(\"training set size of nearest neighbour algorithm\")\n",
    "ax.set_ylabel(\"negative log-likelihood\")\n",
    "ax.legend()\n",
    "ax.yaxis.set_major_formatter(ticker.StrMethodFormatter(\"{x:.3f}\"))\n",
    "fig.savefig(\"figures/neighbour.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2d177e6f-5e2a-4338-a7b7-0f64070165ae",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Nearest-neighbour results per number of steps. A descriptive name avoids\n",
    "# reusing the generic `df` for unrelated tables in this notebook.\n",
    "df_neighbour = (\n",
    "    pd.DataFrame({\"steps\": steps, \"size\": size, \"nll\": nll_neighbour})\n",
    "    .set_index(\"steps\")\n",
    ")\n",
    "df_neighbour"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ee8794f1-152c-4d84-b886-309494d7f059",
   "metadata": {},
   "source": [
    "## Uniform PIT histogram"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ecb48364-7a1d-4568-a14d-1c70d4779a2e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check: a uniform PIT histogram (density 1 in every bin) indicates a\n",
    "# calibrated predictive distribution, so ideally the interpreter's predicted\n",
    "# observation-generating distribution matches `dist.pdf_gaussian` itself.\n",
    "pit_hist_uniform = torch.full((pit.BINS, ), 1.0)\n",
    "pred_uniform = method.predict(interpreter, pit_hist_uniform)\n",
    "fig, ax = plt.subplots()\n",
    "plot.density(ax, dist.pdf_gaussian, label=\"observation-generating distribution\")\n",
    "plot.density(ax, param2pdf(*pred_uniform), label=\"predicted observation-generating distribution\")\n",
    "ax.set_xlabel(\"y\", style=\"italic\")\n",
    "ax.set_ylabel(\"density\")\n",
    "ax.legend();"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fb892d67-3b39-43b2-b4a1-9f3921cbab2b",
   "metadata": {},
   "source": [
    "## UCI ML repository data sets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ed24af42-6435-4d8a-8478-bf370825c303",
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_model(modelfile, Model, keys):\n",
    "    \"\"\"Load a checkpoint and rebuild the model it was saved from.\n",
    "\n",
    "    `Model` is called positionally with the stored hyperparameters `inputs`,\n",
    "    `neurons` and then those named in `keys` (in that order), after which the\n",
    "    saved state dict is loaded. The checkpoint is mapped to the CPU so that\n",
    "    GPU-saved checkpoints also load on CPU-only machines.\n",
    "    \"\"\"\n",
    "    checkpoint = torch.load(modelfile, map_location=\"cpu\")\n",
    "    hyperparams = checkpoint[\"hyperparams\"]\n",
    "    model = Model(*[hyperparams[k] for k in (\"inputs\", \"neurons\") + keys])\n",
    "    model.load_state_dict(checkpoint[\"model_state_dict\"])\n",
    "    return model\n",
    "\n",
    "def load_mdn(modelfile):\n",
    "    \"\"\"Load a mixture density network checkpoint.\"\"\"\n",
    "    return load_model(modelfile, method.MDN, (\"components\", ))\n",
    "\n",
    "def load_dn(modelfile):\n",
    "    \"\"\"Load a density network checkpoint.\n",
    "\n",
    "    Density networks are stored as `method.MDN` models (presumably with a\n",
    "    single component), so this delegates to `load_mdn`.\n",
    "    \"\"\"\n",
    "    return load_mdn(modelfile)\n",
    "\n",
    "def load_de(modelfile):\n",
    "    \"\"\"Load a deep ensemble checkpoint.\"\"\"\n",
    "    return load_model(modelfile, method.DE, (\"members\", ))\n",
    "\n",
    "def plot_interpretation(ax, interpretation):\n",
    "    \"\"\"Plot `dist.pdf_gaussian` (the predictive distribution, default parameters)\n",
    "    and the mixture `interpretation` predicted by the interpreter; return both handles.\"\"\"\n",
    "    density_predictive = plot.density(ax, dist.pdf_gaussian, color=\"C2\")\n",
    "    density_interpret = plot.density(ax, param2pdf(*interpretation), color=\"C3\", ls=\"--\")\n",
    "    return density_predictive, density_interpret\n",
    "\n",
    "def pit_hist(model, dataset):\n",
    "    \"\"\"PIT histogram of `dataset.y` under the Gaussian mixtures `model` predicts from `dataset.X`.\"\"\"\n",
    "    # Named as in param2pdf: the third predicted quantity is a variance.\n",
    "    weight, mean, variance = method.predict(model, dataset.X)\n",
    "    pit_values = pit.pit_gaussian_mixture(dataset.y, weight, mean, variance)\n",
    "    return pit.pit_hist(pit_values)\n",
    "\n",
    "def diagnose(pit_hist, ax_pit, ax_dist, interpreter=interpreter):\n",
    "    \"\"\"Interpret a PIT histogram and plot the diagnosis.\n",
    "\n",
    "    On `ax_pit` the given histogram (filled) is overlaid with the PIT histogram\n",
    "    of samples drawn from the interpreted mixture; on `ax_dist` the predictive\n",
    "    and interpreted densities are drawn. Inside this body the parameter\n",
    "    `pit_hist` shadows the function of the same name. The default `interpreter`\n",
    "    is the one bound when this cell was executed.\n",
    "\n",
    "    Returns the handles (true histogram, predicted histogram, predictive\n",
    "    density, interpreted density).\n",
    "    \"\"\"\n",
    "    interpretation = method.predict(interpreter, pit_hist)\n",
    "    samples = dist.sample_gaussian_mixture(*interpretation)\n",
    "    pit_hist_interpreter = pit.pit_hist(pit.pit_gaussian(samples))\n",
    "    hist_true = plot.pit_hist(ax_pit, pit_hist, fill=True)\n",
    "    hist_pred = plot.pit_hist(ax_pit, pit_hist_interpreter)\n",
    "    density_predictive, density_interpret = plot_interpretation(ax_dist, interpretation)\n",
    "    return hist_true, hist_pred, density_predictive, density_interpret\n",
    "\n",
    "def visualise(pit_hist_dn, pit_hist_de, pit_hist_mdn):\n",
    "    \"\"\"Draw the 3x2 diagnosis figure for one data set and return the figure.\n",
    "\n",
    "    The density network and deep ensemble rows are diagnosed by the interpreter;\n",
    "    the mixture density network row shows only its PIT histogram, and the panel\n",
    "    beside it holds the shared legend.\n",
    "    \"\"\"\n",
    "    fig, axes = plt.subplots(3, 2,\n",
    "                             constrained_layout=True,\n",
    "                             figsize=(WIDTH, (PHI - 1) * WIDTH))\n",
    "    axes[0, 0].set_title(\"density network\")\n",
    "    axes[1, 0].set_title(\"deep ensemble\")\n",
    "    axes[2, 0].set_title(\"mixture density network\")\n",
    "    # Handles from the first row are reused for the shared legend.\n",
    "    _, hist_pred, density_pred, density_interpret = diagnose(pit_hist_dn, axes[0, 0], axes[0, 1])\n",
    "    diagnose(pit_hist_de, axes[1, 0], axes[1, 1])\n",
    "    hist_mdn = plot.pit_hist(axes[2, 0], pit_hist_mdn, fill=True)\n",
    "    axes[0, 0].set_xticklabels([])\n",
    "    axes[0, 1].set_xticklabels([])\n",
    "    axes[2, 1].set_axis_off()\n",
    "    axes[2, 1].legend([hist_mdn, hist_pred, density_pred, density_interpret],\n",
    "                      [\"true PIT histogram\",\n",
    "                       \"predicted PIT histogram\",\n",
    "                       \"predictive distribution\",\n",
    "                       \"predicted observation-generating distribution\"],\n",
    "                      loc=\"center\")\n",
    "    axes[1, 0].set_xlabel(\"PIT\")\n",
    "    axes[2, 0].set_xlabel(\"PIT\")\n",
    "    axes[1, 1].set_xlabel(\"y\", style=\"italic\")\n",
    "    for i in range(3):\n",
    "        axes[i, 0].set_ylabel(\"density\")\n",
    "        axes[i, 1].set_ylabel(\"density\")\n",
    "    return fig"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6ad565a5-d0c0-44ee-8dd7-263c7bd44e67",
   "metadata": {},
   "source": [
    "### Metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "244c6f87-97ca-4004-afa7-57e4d2dc1f27",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect configuration and test metrics of the runs in the `calibration` W&B project.\n",
    "api = wandb.Api()\n",
    "runs = api.runs(\"calibration\")\n",
    "\n",
    "keys = [\"dataname\", \"method\", \"seed\", \"neurons\"]\n",
    "dicts, names, skipped = [], [], []\n",
    "for run in runs:\n",
    "    try:\n",
    "        dictionary = {k: run.config[k] for k in keys}\n",
    "        dictionary[\"nll\"] = run.summary[\"test.nll\"]\n",
    "        dictionary[\"crps\"] = run.summary[\"test.crps\"]\n",
    "    except KeyError:\n",
    "        # Runs lacking any of these keys (presumably unfinished or unrelated runs)\n",
    "        # are left out, but counted so that silent data loss is visible.\n",
    "        skipped.append(run.name)\n",
    "        continue\n",
    "    dicts.append(dictionary)\n",
    "    names.append(run.name)\n",
    "print(f\"skipped {len(skipped)} run(s) without the required config/summary keys\")\n",
    "\n",
    "df_runs = pd.DataFrame(data=dicts, index=names)\n",
    "gdf = df_runs.groupby([\"dataname\", \"method\"])\n",
    "df_runs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c6bc8a08-17be-4060-9492-33ca74b49961",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean and standard error of the mean of the test NLL per (data set, method).\n",
    "gdf[\"nll\"].aggregate([\"mean\", \"sem\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "98ca5692-47f4-49ae-8507-9db9195fb257",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The same summary for the test CRPS.\n",
    "gdf[\"crps\"].aggregate([\"mean\", \"sem\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ecd0a840-f695-4a25-9df7-079d6bceb05f",
   "metadata": {},
   "source": [
    "### Year"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "78fae5e7-43cc-4b86-9d73-9f9f1fcbb003",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep only the third part returned by data.split (presumably the test split).\n",
    "_, _, yearset = data.split(*data.year(), seed=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d716b2c2-789c-4340-baf7-ef4b64851e28",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Checkpoints evaluated on this data set, one per method.\n",
    "year_models = {\n",
    "    \"dn\": \"models/rich-dragon-8.pt\",\n",
    "    \"de\": \"models/generous-valley-7.pt\",\n",
    "    \"mdn\": \"models/chocolate-sound-9.pt\",\n",
    "}\n",
    "pit_hist_dn_year = pit_hist(load_dn(year_models[\"dn\"]), yearset)\n",
    "pit_hist_de_year = pit_hist(load_de(year_models[\"de\"]), yearset)\n",
    "pit_hist_mdn_year = pit_hist(load_mdn(year_models[\"mdn\"]), yearset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "974643b6-a363-4b2b-9d5d-b800b928324a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# visualise creates a new current figure, which plt.savefig then writes out.\n",
    "visualise(pit_hist_dn_year, pit_hist_de_year, pit_hist_mdn_year)\n",
    "plt.savefig(\"figures/year.pdf\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "446d3265-8705-4828-bb66-fa532770e713",
   "metadata": {},
   "source": [
    "### Protein"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2a26adca-69c2-48ef-a7b2-49dbc4952e48",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep only the third part returned by data.split (presumably the test split).\n",
    "_, _, proteinset = data.split(*data.protein(), seed=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f0ace03b-e83e-4584-b32e-8ee6d3540460",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Checkpoints evaluated on this data set, one per method.\n",
    "protein_models = {\n",
    "    \"dn\": \"models/super-durian-4.pt\",\n",
    "    \"de\": \"models/golden-snow-6.pt\",\n",
    "    \"mdn\": \"models/lucky-moon-5.pt\",\n",
    "}\n",
    "pit_hist_dn_protein = pit_hist(load_dn(protein_models[\"dn\"]), proteinset)\n",
    "pit_hist_de_protein = pit_hist(load_de(protein_models[\"de\"]), proteinset)\n",
    "pit_hist_mdn_protein = pit_hist(load_mdn(protein_models[\"mdn\"]), proteinset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "48439e3d-6a43-4026-a6f4-4afafb2738cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# visualise creates a new current figure, which plt.savefig then writes out.\n",
    "visualise(pit_hist_dn_protein, pit_hist_de_protein, pit_hist_mdn_protein)\n",
    "plt.savefig(\"figures/protein.pdf\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f982b520-8fc6-47a2-87f9-5088c0879db4",
   "metadata": {},
   "source": [
    "### Power"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bc0f0410-bdf4-4fc6-9047-e8f9fe06498e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep only the third part returned by data.split (presumably the test split).\n",
    "_, _, powerset = data.split(*data.power(), seed=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8a23f344-7a8e-45dd-ae92-adb21b220b47",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Checkpoints evaluated on this data set, one per method.\n",
    "power_models = {\n",
    "    \"dn\": \"models/polished-star-1.pt\",\n",
    "    \"de\": \"models/elated-surf-3.pt\",\n",
    "    \"mdn\": \"models/effortless-firefly-2.pt\",\n",
    "}\n",
    "pit_hist_dn_power = pit_hist(load_dn(power_models[\"dn\"]), powerset)\n",
    "pit_hist_de_power = pit_hist(load_de(power_models[\"de\"]), powerset)\n",
    "pit_hist_mdn_power = pit_hist(load_mdn(power_models[\"mdn\"]), powerset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3256af68-ff6b-4d89-8088-24f5131c2aa7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# visualise creates a new current figure, which plt.savefig then writes out.\n",
    "visualise(pit_hist_dn_power, pit_hist_de_power, pit_hist_mdn_power)\n",
    "plt.savefig(\"figures/power.pdf\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
