{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pickle\n",
    "import functools\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib\n",
    "import seaborn as sns\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "from jax.scipy import stats\n",
    "from jax.config import config\n",
    "config.update('jax_platform_name', 'cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.style.use(\"default\")\n",
    "matplotlib.rcParams['pdf.fonttype'] = 42\n",
    "matplotlib.rcParams['ps.fonttype'] = 42"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "downstream_posteriors_dirname = \"../../../results/toy-data/downstream-posteriors/\"\n",
    "# dp_glm_posteriors_dirname = \"../../../results/toy-data/dp-glm-posteriors/\"\n",
    "real_data_results_dirname = \"../../../results/toy-data/real-data-results/\"\n",
    "figdir = \"../../../figures/toy-data/\"\n",
    "\n",
    "\n",
    "epsilon = 1.0\n",
    "repeat = 8\n",
    "\n",
    "filename = downstream_posteriors_dirname + \"{}_{}.p\".format(repeat, epsilon)\n",
    "with open(filename, \"rb\") as file:\n",
    "    posterior_obj = pickle.load(file)\n",
    "\n",
    "filename = real_data_results_dirname + \"{}.p\".format(repeat)\n",
    "with open(filename, \"rb\") as file:\n",
    "    real_data_obj = pickle.load(file)\n",
    "\n",
    "# filename = dp_glm_posteriors_dirname + \"{}_{}.p\".format(repeat, epsilon)\n",
    "# with open(filename, \"rb\") as file:\n",
    "#     dp_glm_obj = pickle.load(file)\n",
    "\n",
    "marginalised_laplace_approxes = posterior_obj[\"marginalised_laplace_approxes\"]\n",
    "s_posteriors = posterior_obj[\"s_posteriors\"]\n",
    "\n",
    "true_params = np.array(real_data_obj[\"true_params\"])\n",
    "nondp_post = real_data_obj[\"nondp_post\"]\n",
    "\n",
    "# dp_glm_inf_data = dp_glm_obj[\"dp_glm_inf_data\"]\n",
    "# dp_glm_posterior = dp_glm_inf_data.posterior.stack(draws=(\"chain\", \"draw\"))\n",
    "# dp_glm_posterior = dp_glm_posterior.theta_DP_scaled.values.transpose()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "syn_data_name = \"Syn. data $\\\\bar{p}_n(Q)$\"\n",
    "s_posterior_name = \"Syn. data $p(X | \\\\tilde{s})$\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@jax.jit\n",
    "def laplace_approx_mix_pdf(xs, laplace_approxes, i):\n",
    "    means, covs = laplace_approxes\n",
    "\n",
    "    if len(means.shape) == 1:\n",
    "        return stats.norm.pdf(xs, loc=means[i], scale=jnp.sqrt(covs[i, i]))\n",
    "\n",
    "    variances = covs.diagonal(axis1=1, axis2=2)\n",
    "    ys = jax.vmap(\n",
    "        lambda mean, var: stats.norm.pdf(xs, loc=mean, scale=var**0.5),\n",
    "        (0, 0), 0\n",
    "    )(means[:, i], variances[:, i])\n",
    "    return ys.mean(axis=0)\n",
    "\n",
    "def sample_laplace_approx(rng, size, laplace_approxes):\n",
    "    means, covs = laplace_approxes\n",
    "    d = means[0].shape[0]\n",
    "    rng, ind_key = jax.random.split(rng)\n",
    "    rng, sample_key = jax.random.split(rng, 2)\n",
    "    inds = jax.random.choice(ind_key, len(means), (size,), replace=True)\n",
    "    sample_means = means[inds]\n",
    "    sample_covs = covs[inds]\n",
    "    samples = jax.random.multivariate_normal(sample_key, sample_means, sample_covs)\n",
    "    return samples\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_syn_datasets = 400\n",
    "n_syn_dataset_mul = 20"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, axes = plt.subplots(1, 2, figsize=(10, 2))\n",
    "for i in range(2):\n",
    "    ax = axes[i]\n",
    "    xlim = np.array((-1.5, 1.5)) + true_params[i]\n",
    "    ax.set_xlim(xlim)\n",
    "    xs = np.linspace(*xlim, 200)\n",
    "\n",
    "    laplace_approx = marginalised_laplace_approxes[n_syn_datasets, n_syn_dataset_mul]\n",
    "    ax.plot(xs, laplace_approx_mix_pdf(xs, laplace_approx, i), label=syn_data_name)\n",
    "\n",
    "    ax.plot(xs, laplace_approx_mix_pdf(xs, s_posteriors, i), label=s_posterior_name)\n",
    "\n",
    "    ax.plot(xs, laplace_approx_mix_pdf(xs, nondp_post, i), label=\"Non-DP\")\n",
    "\n",
    "    # sns.kdeplot(dp_glm_posterior[:, i], ax=ax, label=\"DP-GLM\")\n",
    "\n",
    "    ax.axvline(true_params[i], color=\"grey\", linestyle=\"dashed\")\n",
    "    \n",
    "    ax.set_title(\"Coefficient: {}\".format(true_params[i]))\n",
    "    ymin, ymax = ax.get_ylim()\n",
    "    ax.set_ylim((0, ymax))\n",
    "    ax.set_xlabel(\"Parameter\")\n",
    "\n",
    "leg_h, leg_l = axes[0].get_legend_handles_labels()\n",
    "# plt.tight_layout()\n",
    "fig.legend(leg_h, leg_l, loc=\"upper center\", bbox_to_anchor=(0.5, -0.1), ncol=4)\n",
    "plt.savefig(figdir + \"toy_data_results_eps_{}.pdf\".format(epsilon), bbox_inches=\"tight\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_syn_datasets_vals = [25, 50, 100, 200, 400]\n",
    "n_syn_dataset_mul_vals = [1, 2, 5, 10, 20]\n",
    "\n",
    "fig, axes = plt.subplots(\n",
    "    len(n_syn_datasets_vals), len(n_syn_dataset_mul_vals), \n",
    "    figsize=(2 * len(n_syn_dataset_mul_vals), 2 * len(n_syn_datasets_vals))\n",
    ")\n",
    "for i, n_syn_datasets_val in enumerate(n_syn_datasets_vals):\n",
    "    for j, n_syn_dataset_mul_val in enumerate(n_syn_dataset_mul_vals):\n",
    "        for k in range(2):\n",
    "            ax = axes[i, j]\n",
    "            xlim = (-0.5, 1.5)\n",
    "            ax.set_xlim(xlim)\n",
    "            xs = np.linspace(*xlim, 200)\n",
    "\n",
    "            laplace_approx = marginalised_laplace_approxes[n_syn_datasets_val, n_syn_dataset_mul_val]\n",
    "            label = syn_data_name if k == 0 else None\n",
    "            ax.plot(xs, laplace_approx_mix_pdf(xs, laplace_approx, k), label=label, color=\"C0\")\n",
    "\n",
    "            label = s_posterior_name if k == 0 else None\n",
    "            ax.plot(xs, laplace_approx_mix_pdf(xs, s_posteriors, k), label=label, color=\"C1\")\n",
    "\n",
    "            ax.axvline(true_params[k], color=\"grey\", linestyle=\"dashed\")\n",
    "\n",
    "        ax.set_title(\"$n_{{X^*}} / n_X$: {}, $m$: {}\".format(n_syn_dataset_mul_val, n_syn_datasets_val))\n",
    "        ymin, ymax = ax.get_ylim()\n",
    "        ax.set_ylim((0, ymax))\n",
    "\n",
    "leg_h, leg_l = axes[0, 0].get_legend_handles_labels()\n",
    "plt.tight_layout()\n",
    "fig.legend(leg_h, leg_l, loc=\"upper center\", bbox_to_anchor=(0.5, -0.00), ncol=2)\n",
    "plt.savefig(figdir + \"hyperparameter_comparison.pdf\", bbox_inches=\"tight\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@jax.jit\n",
    "def monte_carlo_posterior_intervals(rng, laplace_approximations, conf_levels, mc_samples=1000):\n",
    "    samples = sample_laplace_approx(rng, mc_samples, laplace_approximations)\n",
    "    lbs = (1 - conf_levels) / 2\n",
    "    ubs = 1 - lbs\n",
    "    bounds = jnp.stack((lbs, ubs), axis=1)\n",
    "    quantiles = jax.vmap(\n",
    "        lambda bounds: jnp.quantile(samples, bounds, axis=0)\n",
    "        , 0, 0\n",
    "    )(bounds) # (conf_level, lower/upper, dimension)\n",
    "    return quantiles\n",
    "\n",
    "@jax.jit\n",
    "def has_coverage_width(rng, laplace_approximations, conf_levels, true_params, mc_samples=1000):\n",
    "    intervals = monte_carlo_posterior_intervals(rng, laplace_approximations, conf_levels)\n",
    "    has_coverage = (intervals[:, 0, :] <= true_params) & (intervals[:, 1, :] >= true_params).astype(int)\n",
    "    width = intervals[:, 1, :] - intervals[:, 0, :]\n",
    "    return has_coverage, width\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "\n",
    "conf_levels = jnp.linspace(0.05, 0.95, 19)\n",
    "rng = jax.random.PRNGKey(4628368)\n",
    "all_records = []\n",
    "def conf_int_record(obj, conf_level, type, width, coverage, dim):\n",
    "    return {\n",
    "        \"epsilon\": obj[\"epsilon\"],\n",
    "        \"delta\": obj[\"delta\"],\n",
    "        \"repeat\": obj[\"repeat\"],\n",
    "        \"width\": float(width),\n",
    "        \"has_coverage\": int(coverage),\n",
    "        \"dim\": dim,\n",
    "        \"conf_level\": float(conf_level),\n",
    "        \"type\": type,\n",
    "    }\n",
    "\n",
    "for path in Path(downstream_posteriors_dirname).glob(\"*.p\"):\n",
    "    with open(path, \"rb\") as file:\n",
    "        obj = pickle.load(file)\n",
    "    marginalised_posteriors = obj[\"marginalised_laplace_approxes\"][n_syn_datasets, n_syn_dataset_mul]\n",
    "    s_posteriors = obj[\"s_posteriors\"]\n",
    "\n",
    "    rng, key = jax.random.split(rng)\n",
    "    marginalised_coverage, marginalised_width = has_coverage_width(key, marginalised_posteriors, conf_levels, true_params)\n",
    "    rng, key = jax.random.split(rng)\n",
    "    s_coverage, s_width = has_coverage_width(key, s_posteriors, conf_levels, true_params)\n",
    "\n",
    "    for j, (coverage, width) in enumerate(zip(marginalised_coverage, marginalised_width)):\n",
    "        for i in range(2):\n",
    "            all_records.append(conf_int_record(obj, conf_levels[j], syn_data_name, width[i], coverage[i], i))\n",
    "    for j, (coverage, width) in enumerate(zip(s_coverage, s_width)):\n",
    "        for i in range(2):\n",
    "            all_records.append(conf_int_record(obj, conf_levels[j], s_posterior_name, width[i], coverage[i], i))\n",
    "\n",
    "\n",
    "for path in Path(dp_glm_posteriors_dirname).glob(\"*.p\"):\n",
    "    with open(path, \"rb\") as file:\n",
    "        obj = pickle.load(file)\n",
    "    dp_glm_inf_data = obj[\"dp_glm_inf_data\"]\n",
    "    dp_glm_posterior = dp_glm_inf_data.posterior.stack(draws=(\"chain\", \"draw\"))\n",
    "    dp_glm_posterior = dp_glm_posterior.theta_DP_scaled.values.transpose()\n",
    "\n",
    "    for conf_level in conf_levels:\n",
    "        lb = (1 - conf_level) / 2\n",
    "        ub = 1 - lb\n",
    "        interval = np.quantile(dp_glm_posterior, [lb, ub], axis=0)\n",
    "        coverage = ((interval[0, :] <= true_params) & (interval[1, :] >= true_params)).astype(int)\n",
    "        width = interval[1, :] - interval[0, :]\n",
    "\n",
    "        for i in range(2):\n",
    "            all_records.append(conf_int_record(obj, conf_level, \"DP-GLM\", width[i], coverage[i], i))\n",
    "\n",
    "df = pd.DataFrame.from_records(all_records)\n",
    "df[\"coefficient\"] = true_params[df.dim]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cdf = df.copy()\n",
    "g = sns.FacetGrid(cdf, row=\"epsilon\", col=\"coefficient\", hue=\"type\", col_order=true_params, xlim=(0, 1.01), ylim=(0, 1.01), height=2.25, aspect=1)\n",
    "g.map_dataframe(sns.lineplot, x=\"conf_level\", y=\"has_coverage\")\n",
    "diag = (0.0, 1.01)\n",
    "g.map(lambda **kws: plt.gca().plot(diag, diag, linestyle=\"dashed\", color=\"black\"))\n",
    "g.add_legend(title=\"\", loc=\"upper center\", bbox_to_anchor=(0.3, -0.00), ncol=2)\n",
    "g.set_titles(template=\"$\\epsilon$: {row_name}, Coefficient: {col_name}\")\n",
    "g.set_xlabels(\"Conf Level\")\n",
    "g.set_ylabels(\"Coverage\")\n",
    "plt.savefig(figdir + \"coverages.pdf\", bbox_inches=\"tight\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "g = sns.FacetGrid(cdf, row=\"epsilon\", col=\"coefficient\", hue=\"type\", col_order=true_params, sharey=False, height=2.25, aspect=1)\n",
    "g.map_dataframe(sns.lineplot, x=\"conf_level\", y=\"width\")\n",
    "g.add_legend(title=\"\", loc=\"upper center\", bbox_to_anchor=(0.3, -0.00), ncol=2)\n",
    "g.set_titles(template=\"$\\epsilon$: {row_name}, Coefficient: {col_name}\")\n",
    "g.set_xlabels(\"Conf Level\")\n",
    "g.set_ylabels(\"Width\")\n",
    "plt.savefig(figdir + \"widths.pdf\", bbox_inches=\"tight\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.3"
  },
  "vscode": {
   "interpreter": {
    "hash": "fa1919baa51ca66e7ec489754979eb11625f9177f6ce3c8d3d9d35570b67b133"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
