{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "from lib import rubin\n",
    "from lib import privacy_accounting\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import arviz as az\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib\n",
    "import seaborn as sns\n",
    "import scipy.stats as stats\n",
    "from tqdm import tqdm\n",
    "import os\n",
    "import warnings\n",
    "import functools\n",
    "import itertools"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reset to matplotlib's default style, then embed fonts as TrueType (Type 42)\n",
    "# rather than Type 3 in saved PDF/PS figures.\n",
    "plt.style.use(\"default\")\n",
    "matplotlib.rcParams.update({\"pdf.fonttype\": 42, \"ps.fonttype\": 42})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "figdir = \"figures/gaussian-test/\"\n",
    "# savefig() in the cells below fails on a fresh checkout if the directory is missing.\n",
    "os.makedirs(figdir, exist_ok=True)\n",
    "\n",
    "# Legend labels shared by the figures. Raw strings keep the LaTeX backslashes from\n",
    "# being parsed as (invalid) Python escape sequences; the label text is unchanged.\n",
    "analyst_label = r\"Analyst's $p(\\mu | X, I_A)$\"\n",
    "data_provider_label = r\"Data Provider's $p(\\mu | X, I_S)$\"\n",
    "syn_data_label = r\"With syn. data $\\bar{p}_n(\\mu)$\"\n",
    "\n",
    "# Seed NumPy's global RNG: np.random.* and scipy.stats .rvs() calls without an explicit\n",
    "# random_state both draw from it, so every experiment below becomes reproducible.\n",
    "random_seed = 0\n",
    "np.random.seed(random_seed)\n",
    "\n",
    "# Ground truth: the real dataset X holds n draws from N(mu_true, sigma_true^2).\n",
    "mu_true = 1\n",
    "sigma_true = 2\n",
    "n = 100\n",
    "X = np.random.normal(mu_true, sigma_true, n)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def known_variance_inference(X, mu_0, sigma_0, known_sigma):\n",
    "    \"\"\"Conjugate posterior of a Gaussian mean when the noise scale is known.\n",
    "\n",
    "    Model: x_i ~ N(mu, known_sigma^2) i.i.d., prior mu ~ N(mu_0, sigma_0^2).\n",
    "\n",
    "    Args:\n",
    "        X: observations (array-like; every element is one observation).\n",
    "        mu_0, sigma_0: prior mean and prior standard deviation of mu.\n",
    "        known_sigma: assumed noise standard deviation.\n",
    "\n",
    "    Returns:\n",
    "        Frozen scipy.stats.norm posterior p(mu | X); equals the prior when X is empty.\n",
    "    \"\"\"\n",
    "    X = np.asarray(X)\n",
    "    n = X.size\n",
    "    # The sample mean gets zero weight when n == 0, so any finite placeholder works\n",
    "    # (np.mean of an empty array would return nan and poison the result).\n",
    "    X_mean = np.mean(X) if n > 0 else 0.0\n",
    "    data_precision = n / known_sigma**2\n",
    "    prior_precision = 1 / sigma_0**2\n",
    "    posterior_precision = data_precision + prior_precision\n",
    "    posterior_mu = (data_precision * X_mean + prior_precision * mu_0) / posterior_precision\n",
    "    posterior_sigma = np.sqrt(1 / posterior_precision)\n",
    "    return stats.norm(loc=posterior_mu, scale=posterior_sigma)\n",
    "\n",
    "def unknown_variance_inference(X, mu_0, sigma_0, nu_0, kappa_0):\n",
    "    \"\"\"Conjugate posterior for a Gaussian with unknown mean and variance.\n",
    "\n",
    "    Prior (Normal-scaled-inverse-chi^2):\n",
    "        sigma^2 ~ Scaled-Inv-chi^2(nu_0, sigma_0^2),  mu | sigma^2 ~ N(mu_0, sigma^2 / kappa_0).\n",
    "\n",
    "    Args:\n",
    "        X: observations (array-like; every element is one observation).\n",
    "        mu_0, kappa_0: prior location of mu and its number of pseudo-observations.\n",
    "        sigma_0, nu_0: prior scale (sigma_0^2 is the prior guess of sigma^2) and its\n",
    "            degrees of freedom.\n",
    "\n",
    "    Returns:\n",
    "        sigma2_post: frozen scipy.stats.invgamma, the marginal posterior of sigma^2.\n",
    "        mu_post: function sigma2 -> frozen scipy.stats.norm, the conditional posterior\n",
    "            p(mu | sigma^2, X); sigma2 may be an array (one distribution per entry).\n",
    "        mu_post_marginal: frozen scipy.stats.t, the marginal posterior p(mu | X).\n",
    "    \"\"\"\n",
    "    X = np.asarray(X)\n",
    "    n = X.size\n",
    "    # Both statistics get zero weight when n == 0 (mean) or n <= 1 (variance), but\n",
    "    # np.mean / np.var(ddof=1) would return nan there, so use finite placeholders.\n",
    "    X_mean = np.mean(X) if n > 0 else 0.0\n",
    "    X_var = np.var(X, ddof=1) if n > 1 else 0.0\n",
    "\n",
    "    kappa = kappa_0 + n\n",
    "    nu = nu_0 + n\n",
    "    mu = kappa_0 / kappa * mu_0 + n / kappa * X_mean\n",
    "    # nu_n * sigma_n^2 = nu_0 * sigma_0^2 + (n - 1) * s^2 + kappa_0 * n / kappa_n * (xbar - mu_0)^2.\n",
    "    # The prior contributes the product nu_0 * sigma_0^2, not a sum.\n",
    "    sigma2 = (nu_0 * sigma_0**2 + (n - 1) * X_var + kappa_0 * n / kappa * (X_mean - mu_0)**2) / nu\n",
    "\n",
    "    # Scaled-Inv-chi^2(nu, sigma2) is InvGamma(shape=nu / 2, scale=nu * sigma2 / 2).\n",
    "    sigma2_post = stats.invgamma(a=nu / 2, scale=nu * sigma2 / 2)\n",
    "\n",
    "    def mu_post(sigma2):\n",
    "        return stats.norm(loc=mu, scale=np.sqrt(sigma2 / kappa))\n",
    "\n",
    "    mu_post_marginal = stats.t(df=nu, loc=mu, scale=np.sqrt(sigma2 / kappa))\n",
    "    return sigma2_post, mu_post, mu_post_marginal\n",
    "\n",
    "def combined_posterior_pdf(xs, syn_data_posteriors):\n",
    "    \"\"\"Density of the equal-weight mixture of the given posteriors, evaluated at xs.\n",
    "\n",
    "    This is the synthetic-data estimate bar{p}_n(mu) = mean_i p(mu | X*_i). Every element\n",
    "    of syn_data_posteriors must provide .pdf(xs), e.g. a frozen scipy.stats distribution.\n",
    "    Raises ValueError (from np.stack) if syn_data_posteriors is empty.\n",
    "    \"\"\"\n",
    "    return np.stack([posterior.pdf(xs) for posterior in syn_data_posteriors], axis=0).mean(axis=0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Known Variance Up- and Downstream"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Analyst (downstream) model: known-variance Gaussian with prior mu ~ N(down_mu_0, down_sigma_0^2),\n",
    "# fitted under the true noise scale and under a deliberately wrong (halved) one.\n",
    "down_mu_0, down_sigma_0 = 0, 10\n",
    "down_known_sigma = 1.0 * sigma_true\n",
    "down_known_sigma_diff = 0.5 * sigma_true\n",
    "true_down_posterior, true_down_posterior_diff = (\n",
    "    known_variance_inference(X, down_mu_0, down_sigma_0, known_sigma)\n",
    "    for known_sigma in (down_known_sigma, down_known_sigma_diff)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Data provider (upstream) model: known-variance Gaussian fitted to the real data X.\n",
    "syn_mu_0, syn_sigma_0 = 0, 10\n",
    "syn_known_sigma = 1.0 * sigma_true\n",
    "syn_posterior = known_variance_inference(X, syn_mu_0, syn_sigma_0, syn_known_sigma)\n",
    "\n",
    "def sample_syn_post(size):\n",
    "    \"\"\"Draw `size` values of mu from the data provider's posterior syn_posterior.\"\"\"\n",
    "    return syn_posterior.rvs(size=size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# m synthetic datasets, each 20x the size of the real data.\n",
    "n_syn_datasets = 400\n",
    "n_syn_dataset = 20 * n\n",
    "# One upstream posterior draw mu_i per dataset, then X*_i ~ N(mu_i, syn_known_sigma^2).\n",
    "syn_mu_sample = sample_syn_post(n_syn_datasets)\n",
    "syn_datasets = np.zeros((n_syn_datasets, n_syn_dataset))\n",
    "for i in range(n_syn_datasets):\n",
    "    mu = syn_mu_sample[i]\n",
    "    syn_datasets[i, :] = stats.norm.rvs(loc=mu, scale=syn_known_sigma, size=n_syn_dataset)\n",
    "\n",
    "# Analyst's known-variance posterior on each synthetic dataset, once with the\n",
    "# noise scale down_known_sigma and once with down_known_sigma_diff.\n",
    "syn_data_posteriors = [\n",
    "    known_variance_inference(syn_datasets[i, :], down_mu_0, down_sigma_0, down_known_sigma) \n",
    "    for i in range(n_syn_datasets)\n",
    "]\n",
    "syn_data_posteriors_diff = [\n",
    "    known_variance_inference(syn_datasets[i, :], down_mu_0, down_sigma_0, down_known_sigma_diff) \n",
    "    for i in range(n_syn_datasets)\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "xs = np.linspace(0, 2, 200)\n",
    "\n",
    "# Left: the analyst assumes the true noise scale; right: half of it.\n",
    "# Raw-string axis labels avoid an invalid Python escape sequence in the LaTeX.\n",
    "fig, axes = plt.subplots(1, 2, figsize=(10, 2))\n",
    "ax = axes[0]\n",
    "ax.plot(xs, true_down_posterior.pdf(xs), label=analyst_label)\n",
    "ax.plot(xs, syn_posterior.pdf(xs), label=data_provider_label, linestyle=\"dashed\")\n",
    "ax.plot(xs, combined_posterior_pdf(xs, syn_data_posteriors), label=syn_data_label)\n",
    "ax.set_xlabel(r\"$\\mu$\")\n",
    "ax.set_title(\"Correct Downstream Variance\")\n",
    "\n",
    "ax = axes[1]\n",
    "ax.plot(xs, true_down_posterior_diff.pdf(xs), label=analyst_label)\n",
    "ax.plot(xs, syn_posterior.pdf(xs), label=data_provider_label, linestyle=\"dashed\")\n",
    "ax.plot(xs, combined_posterior_pdf(xs, syn_data_posteriors_diff), label=syn_data_label)\n",
    "ax.set_xlabel(r\"$\\mu$\")\n",
    "ax.set_title(\"Incorrect Downstream Variance\")\n",
    "\n",
    "# One shared legend below both panels.\n",
    "leg_h, leg_l = axes[0].get_legend_handles_labels()\n",
    "fig.legend(leg_h, leg_l, loc=\"upper center\", ncol=3, bbox_to_anchor=(0.5, -0.1))\n",
    "plt.savefig(figdir + \"known-known-variance-results.pdf\", bbox_inches=\"tight\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyperparameter grid: m = number of synthetic datasets, and the synthetic dataset\n",
    "# size as a multiple of the real data size n.\n",
    "n_syn_datasets_vals = [25, 50, 100, 200, 400]\n",
    "n_syn_dataset_mul_vals = [1, 2, 5, 10, 20]\n",
    "\n",
    "def run_synthetic_data_inference(n_syn_dataset, n_syn_datasets):\n",
    "    \"\"\"Sample n_syn_datasets synthetic datasets of size n_syn_dataset and fit the analyst's model.\n",
    "\n",
    "    Reads the notebook globals sample_syn_post and syn_known_sigma (upstream model) and\n",
    "    down_mu_0, down_sigma_0, down_known_sigma (downstream model) at call time.\n",
    "\n",
    "    Returns:\n",
    "        List of known-variance posteriors, one per synthetic dataset.\n",
    "    \"\"\"\n",
    "    syn_mu_sample = sample_syn_post(n_syn_datasets)\n",
    "    syn_datasets = np.zeros((n_syn_datasets, n_syn_dataset))\n",
    "    for i in range(n_syn_datasets):\n",
    "        syn_datasets[i, :] = stats.norm.rvs(loc=syn_mu_sample[i], scale=syn_known_sigma, size=n_syn_dataset)\n",
    "\n",
    "    return [\n",
    "        known_variance_inference(syn_dataset, down_mu_0, down_sigma_0, down_known_sigma)\n",
    "        for syn_dataset in syn_datasets\n",
    "    ]\n",
    "\n",
    "# Keyed by (m, size multiplier). itertools.product iterates m-major, exactly like nested\n",
    "# loops over the two lists, so random draws are consumed in the same sequence; unlike\n",
    "# bare loops it does not leave loop variables behind that overwrite n_syn_datasets.\n",
    "syn_data_posteriors_all = {\n",
    "    (m, mul): run_synthetic_data_inference(n * mul, m)\n",
    "    for m, mul in itertools.product(n_syn_datasets_vals, n_syn_dataset_mul_vals)\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rows: number of synthetic datasets m; columns: synthetic dataset size relative to n.\n",
    "# Uses the xs grid defined in the known-variance plot cell.\n",
    "fig, axes = plt.subplots(\n",
    "    len(n_syn_datasets_vals), len(n_syn_dataset_mul_vals),\n",
    "    figsize=(2 * len(n_syn_dataset_mul_vals), 2 * len(n_syn_datasets_vals))\n",
    ")\n",
    "for i, n_syn_datasets in enumerate(n_syn_datasets_vals):\n",
    "    for j, n_syn_dataset_mul in enumerate(n_syn_dataset_mul_vals):\n",
    "        ax = axes[i, j]\n",
    "        syn_data_posteriors = syn_data_posteriors_all[n_syn_datasets, n_syn_dataset_mul]\n",
    "\n",
    "        ax.plot(xs, true_down_posterior.pdf(xs), label=analyst_label)\n",
    "        ax.plot(xs, combined_posterior_pdf(xs, syn_data_posteriors), label=syn_data_label)\n",
    "        ax.set_xlabel(r\"$\\mu$\")\n",
    "        ax.set_title(\"$n_{{X^*}} / n_X = {}, m = {}$\".format(n_syn_dataset_mul, n_syn_datasets))\n",
    "\n",
    "leg_h, leg_l = axes[0, 0].get_legend_handles_labels()\n",
    "fig.legend(leg_h, leg_l, loc=\"upper center\", ncol=3, bbox_to_anchor=(0.5, 0.01))\n",
    "plt.tight_layout()\n",
    "plt.savefig(figdir + \"known-variance-hyperparameter-results.pdf\", bbox_inches=\"tight\")\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compact version of the grid: largest m, with three of the dataset-size multipliers.\n",
    "plot_diagonal_indices = [0, 2, 4]\n",
    "fig, axes = plt.subplots(1, len(plot_diagonal_indices), figsize=(2.5 * len(plot_diagonal_indices), 2))\n",
    "for i, ind in enumerate(plot_diagonal_indices):\n",
    "    ax = axes[i]\n",
    "    n_syn_datasets = n_syn_datasets_vals[-1]\n",
    "    n_syn_dataset_mul = n_syn_dataset_mul_vals[ind]\n",
    "    syn_data_posteriors = syn_data_posteriors_all[n_syn_datasets, n_syn_dataset_mul]\n",
    "\n",
    "    ax.plot(xs, true_down_posterior.pdf(xs), label=analyst_label)\n",
    "    ax.plot(xs, combined_posterior_pdf(xs, syn_data_posteriors), label=syn_data_label)\n",
    "    ax.set_xlabel(r\"$\\mu$\")\n",
    "    ax.set_title(\"$n_{{X^*}} / n_X = {}, m = {}$\".format(n_syn_dataset_mul, n_syn_datasets))\n",
    "\n",
    "leg_h, leg_l = axes[0].get_legend_handles_labels()\n",
    "fig.legend(leg_h, leg_l, loc=\"upper center\", ncol=3, bbox_to_anchor=(0.5, 0.09))\n",
    "plt.tight_layout()\n",
    "plt.savefig(figdir + \"known-variance-hyperparameter-results-small.pdf\", bbox_inches=\"tight\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Unknown Variance Up- and Downstream"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Analyst (downstream) model: unknown mean and variance, with the conjugate prior\n",
    "# hyperparameters consumed by unknown_variance_inference.\n",
    "down_mu_0, down_nu_0, down_sigma_0, down_kappa_0 = 0, 2, 1, 2\n",
    "\n",
    "true_sigma2_post, true_mu_post, true_mu_post_marginal = unknown_variance_inference(\n",
    "    X, down_mu_0, down_sigma_0, down_nu_0, down_kappa_0\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Data provider (upstream) model: unknown mean and variance, conjugate prior below.\n",
    "syn_mu_0, syn_nu_0, syn_sigma_0, syn_kappa_0 = 0, 2, 1, 2\n",
    "\n",
    "syn_sigma2_post, syn_mu_post, syn_mu_post_marginal = unknown_variance_inference(\n",
    "    X, syn_mu_0, syn_sigma_0, syn_nu_0, syn_kappa_0\n",
    ")\n",
    "\n",
    "def sample_syn_post(size):\n",
    "    \"\"\"Draw `size` joint samples (mu, sigma^2) from the data provider's posterior.\n",
    "\n",
    "    Each sigma^2 comes from its marginal posterior; each mu is then drawn from\n",
    "    p(mu | sigma^2, X) for its own sigma^2 (one vectorized call, one mu per entry).\n",
    "\n",
    "    Returns:\n",
    "        (mus, sigma2s): arrays of length `size`.\n",
    "    \"\"\"\n",
    "    sigma2s = syn_sigma2_post.rvs(size=size)\n",
    "    mus = syn_mu_post(sigma2s).rvs(size=size)\n",
    "    return mus, sigma2s"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# m = 1000 synthetic datasets, each 100x the real data size.\n",
    "# NOTE: syn_datasets is a 1000 x 10000 float64 array (about 80 MB).\n",
    "n_syn_datasets = 1000\n",
    "n_syn_dataset = 100 * n\n",
    "# Joint upstream draws (mu_i, sigma2_i), then X*_i ~ N(mu_i, sigma2_i).\n",
    "syn_mu_sample, syn_sigma2_sample = sample_syn_post(n_syn_datasets)\n",
    "syn_datasets = np.zeros((n_syn_datasets, n_syn_dataset))\n",
    "for i in range(n_syn_datasets):\n",
    "    mu = syn_mu_sample[i]\n",
    "    sigma2 = syn_sigma2_sample[i]\n",
    "    syn_datasets[i, :] = stats.norm.rvs(loc=mu, scale=np.sqrt(sigma2), size=n_syn_dataset)\n",
    "\n",
    "# Analyst's unknown-variance inference on each synthetic dataset; [2] keeps only the\n",
    "# marginal posterior of mu (the Student-t).\n",
    "syn_data_posteriors = [\n",
    "    unknown_variance_inference(syn_datasets[i, :], down_mu_0, down_sigma_0, down_nu_0, down_kappa_0)[2]\n",
    "    for i in range(n_syn_datasets)\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "xs = np.linspace(-1, 3, 200)\n",
    "\n",
    "# Marginal posteriors of mu: analyst on real data, data provider, and the synthetic-data mixture.\n",
    "fig, ax = plt.subplots(figsize=(7, 7))\n",
    "ax.plot(xs, true_mu_post_marginal.pdf(xs), label=analyst_label)\n",
    "ax.plot(xs, syn_mu_post_marginal.pdf(xs), label=data_provider_label)\n",
    "ax.plot(xs, combined_posterior_pdf(xs, syn_data_posteriors), label=syn_data_label)\n",
    "ax.set_xlabel(r\"$\\mu$\")\n",
    "ax.legend()\n",
    "fig.savefig(figdir + \"unknown-unknown-variance-results.pdf\", bbox_inches=\"tight\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Unknown Variance Upstream, Known Variance Downstream"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Analyst (downstream) model: known-variance Gaussian with prior mu ~ N(down_mu_0, down_sigma_0^2),\n",
    "# fitted under the true noise scale and under a deliberately wrong (halved) one.\n",
    "down_mu_0, down_sigma_0 = 0, 10\n",
    "down_known_sigma = 1.0 * sigma_true\n",
    "down_known_sigma_diff = 0.5 * sigma_true\n",
    "true_down_posterior, true_down_posterior_diff = (\n",
    "    known_variance_inference(X, down_mu_0, down_sigma_0, known_sigma)\n",
    "    for known_sigma in (down_known_sigma, down_known_sigma_diff)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Data provider (upstream) model: unknown mean and variance, conjugate prior below.\n",
    "syn_mu_0, syn_nu_0, syn_sigma_0, syn_kappa_0 = 0, 2, 1, 2\n",
    "\n",
    "syn_sigma2_post, syn_mu_post, syn_mu_post_marginal = unknown_variance_inference(\n",
    "    X, syn_mu_0, syn_sigma_0, syn_nu_0, syn_kappa_0\n",
    ")\n",
    "\n",
    "def sample_syn_post(size):\n",
    "    \"\"\"Draw `size` joint samples (mu, sigma^2) from the data provider's posterior.\n",
    "\n",
    "    Each sigma^2 comes from its marginal posterior; each mu is then drawn from\n",
    "    p(mu | sigma^2, X) for its own sigma^2 (one vectorized call, one mu per entry).\n",
    "\n",
    "    Returns:\n",
    "        (mus, sigma2s): arrays of length `size`.\n",
    "    \"\"\"\n",
    "    sigma2s = syn_sigma2_post.rvs(size=size)\n",
    "    mus = syn_mu_post(sigma2s).rvs(size=size)\n",
    "    return mus, sigma2s"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# m synthetic datasets, each 20x the size of the real data.\n",
    "n_syn_datasets = 400\n",
    "n_syn_dataset = 20 * n\n",
    "# Joint upstream draws (mu_i, sigma2_i) from the unknown-variance model, then X*_i ~ N(mu_i, sigma2_i).\n",
    "syn_mu_sample, syn_sigma2_sample = sample_syn_post(n_syn_datasets)\n",
    "syn_datasets = np.zeros((n_syn_datasets, n_syn_dataset))\n",
    "for i in range(n_syn_datasets):\n",
    "    mu = syn_mu_sample[i]\n",
    "    sigma2 = syn_sigma2_sample[i]\n",
    "    syn_datasets[i, :] = stats.norm.rvs(loc=mu, scale=np.sqrt(sigma2), size=n_syn_dataset)\n",
    "\n",
    "# Analyst's known-variance posterior on each synthetic dataset, once with the\n",
    "# noise scale down_known_sigma and once with down_known_sigma_diff.\n",
    "syn_data_posteriors = [\n",
    "    known_variance_inference(syn_datasets[i, :], down_mu_0, down_sigma_0, down_known_sigma) \n",
    "    for i in range(n_syn_datasets)\n",
    "]\n",
    "syn_data_posteriors_diff = [\n",
    "    known_variance_inference(syn_datasets[i, :], down_mu_0, down_sigma_0, down_known_sigma_diff) \n",
    "    for i in range(n_syn_datasets)\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "xs = np.linspace(0, 2, 200)\n",
    "\n",
    "# Left: the analyst assumes the true noise scale; right: half of it.\n",
    "# Raw-string axis labels avoid an invalid Python escape sequence in the LaTeX.\n",
    "fig, axes = plt.subplots(1, 2, figsize=(10, 3))\n",
    "ax = axes[0]\n",
    "ax.plot(xs, true_down_posterior.pdf(xs), label=analyst_label)\n",
    "ax.plot(xs, syn_mu_post_marginal.pdf(xs), label=data_provider_label)\n",
    "ax.plot(xs, combined_posterior_pdf(xs, syn_data_posteriors), label=syn_data_label)\n",
    "ax.set_xlabel(r\"$\\mu$\")\n",
    "ax.set_title(\"Correct Downstream Variance\")\n",
    "\n",
    "ax = axes[1]\n",
    "ax.plot(xs, true_down_posterior_diff.pdf(xs), label=analyst_label)\n",
    "ax.plot(xs, syn_mu_post_marginal.pdf(xs), label=data_provider_label)\n",
    "ax.plot(xs, combined_posterior_pdf(xs, syn_data_posteriors_diff), label=syn_data_label)\n",
    "ax.set_xlabel(r\"$\\mu$\")\n",
    "ax.set_title(\"Incorrect Downstream Variance\")\n",
    "\n",
    "# One shared legend below both panels.\n",
    "leg_h, leg_l = axes[0].get_legend_handles_labels()\n",
    "fig.legend(leg_h, leg_l, loc=\"upper center\", ncol=3, bbox_to_anchor=(0.5, -0.03))\n",
    "plt.savefig(figdir + \"unknown-known-variance-results.pdf\", bbox_inches=\"tight\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Known Variance Upstream, Unknown Variance Downstream"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Analyst (downstream) model: unknown mean and variance, with the conjugate prior\n",
    "# hyperparameters consumed by unknown_variance_inference.\n",
    "down_mu_0, down_nu_0, down_sigma_0, down_kappa_0 = 0, 2, 1, 2\n",
    "\n",
    "true_sigma2_post, true_mu_post, true_mu_post_marginal = unknown_variance_inference(\n",
    "    X, down_mu_0, down_sigma_0, down_nu_0, down_kappa_0\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Data provider (upstream) model: known-variance Gaussian fitted to the real data X.\n",
    "syn_mu_0, syn_sigma_0 = 0, 10\n",
    "syn_known_sigma = 1.0 * sigma_true\n",
    "syn_posterior = known_variance_inference(X, syn_mu_0, syn_sigma_0, syn_known_sigma)\n",
    "\n",
    "def sample_syn_post(size):\n",
    "    \"\"\"Draw `size` values of mu from the data provider's posterior syn_posterior.\"\"\"\n",
    "    return syn_posterior.rvs(size=size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# m = 1000 synthetic datasets, each 100x the real data size.\n",
    "# NOTE: syn_datasets is a 1000 x 10000 float64 array (about 80 MB).\n",
    "n_syn_datasets = 1000\n",
    "n_syn_dataset = 100 * n\n",
    "# One known-variance upstream draw mu_i per dataset, then X*_i ~ N(mu_i, syn_known_sigma^2).\n",
    "syn_mu_sample = sample_syn_post(n_syn_datasets)\n",
    "syn_datasets = np.zeros((n_syn_datasets, n_syn_dataset))\n",
    "for i in range(n_syn_datasets):\n",
    "    mu = syn_mu_sample[i]\n",
    "    syn_datasets[i, :] = stats.norm.rvs(loc=mu, scale=syn_known_sigma, size=n_syn_dataset)\n",
    "\n",
    "# Analyst's unknown-variance inference on each synthetic dataset; [2] keeps only the\n",
    "# marginal posterior of mu (the Student-t).\n",
    "syn_data_posteriors = [\n",
    "    unknown_variance_inference(syn_datasets[i, :], down_mu_0, down_sigma_0, down_nu_0, down_kappa_0)[2]\n",
    "    for i in range(n_syn_datasets)\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "xs = np.linspace(0, 2, 200)\n",
    "\n",
    "# Marginal posteriors of mu: analyst on real data, data provider, and the synthetic-data mixture.\n",
    "fig, ax = plt.subplots(figsize=(7, 7))\n",
    "ax.plot(xs, true_mu_post_marginal.pdf(xs), label=analyst_label)\n",
    "ax.plot(xs, syn_posterior.pdf(xs), label=data_provider_label)\n",
    "ax.plot(xs, combined_posterior_pdf(xs, syn_data_posteriors), label=syn_data_label)\n",
    "ax.set_xlabel(r\"$\\mu$\")\n",
    "ax.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Testing Variance Approximation Accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def variance_correction(post_means, post_variances, n_syn, n_real):\n",
    "    \"\"\"Gaussian approximation of the analyst's posterior from synthetic-data posteriors.\n",
    "\n",
    "    The mixture of the m downstream posteriors has variance var(means) + mean(variances).\n",
    "    Removing mean(variances) leaves var(means), which is then shrunk by\n",
    "    (1 + 1 / c)^-1 = n_syn / (n_syn + n_real), where c = n_syn / n_real.\n",
    "\n",
    "    Args:\n",
    "        post_means: means of the downstream posteriors (array-like, one per synthetic dataset).\n",
    "        post_variances: their variances (array-like); kept for interface compatibility,\n",
    "            since the correction reduces to an expression that does not depend on them.\n",
    "        n_syn: size of each synthetic dataset.\n",
    "        n_real: size of the real dataset.\n",
    "\n",
    "    Returns:\n",
    "        Frozen scipy.stats.norm with mean mean(post_means) and the corrected variance.\n",
    "    \"\"\"\n",
    "    post_means = np.asarray(post_means, dtype=float)\n",
    "    total_mean = post_means.mean()\n",
    "    # Compute var(means) directly instead of (var(means) + mean(variances)) - mean(variances),\n",
    "    # which is the same quantity but loses precision to cancellation.\n",
    "    corrected_var = n_syn / (n_syn + n_real) * post_means.var()\n",
    "    return stats.norm(loc=total_mean, scale=np.sqrt(corrected_var))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Analyst's downstream model for the approximation test: known variance, correctly specified.\n",
    "down_mu_0, down_sigma_0 = 0, 10\n",
    "down_known_sigma = 1.0 * sigma_true\n",
    "true_down_posterior = known_variance_inference(\n",
    "    X, down_mu_0, down_sigma_0, down_known_sigma\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Unknown Variance Synthetic Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Data provider (upstream) model: unknown mean and variance, conjugate prior below.\n",
    "syn_mu_0, syn_nu_0, syn_sigma_0, syn_kappa_0 = 0, 2, 1, 2\n",
    "\n",
    "syn_sigma2_post, syn_mu_post, syn_mu_post_marginal = unknown_variance_inference(\n",
    "    X, syn_mu_0, syn_sigma_0, syn_nu_0, syn_kappa_0\n",
    ")\n",
    "\n",
    "def sample_syn_post(size):\n",
    "    \"\"\"Draw `size` joint samples (mu, sigma^2) from the data provider's posterior.\n",
    "\n",
    "    Each sigma^2 comes from its marginal posterior; each mu is then drawn from\n",
    "    p(mu | sigma^2, X) for its own sigma^2 (one vectorized call, one mu per entry).\n",
    "\n",
    "    Returns:\n",
    "        (mus, sigma2s): arrays of length `size`.\n",
    "    \"\"\"\n",
    "    sigma2s = syn_sigma2_post.rvs(size=size)\n",
    "    mus = syn_mu_post(sigma2s).rvs(size=size)\n",
    "    return mus, sigma2s"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Synthetic datasets the same size as the real data, so variance_correction uses c = 1.\n",
    "n_syn_datasets = 400\n",
    "n_syn_dataset = 1 * n\n",
    "\n",
    "# Joint upstream draws (mu_i, sigma2_i), then X*_i ~ N(mu_i, sigma2_i).\n",
    "syn_mu_sample, syn_sigma2_sample = sample_syn_post(n_syn_datasets)\n",
    "syn_datasets = np.zeros((n_syn_datasets, n_syn_dataset))\n",
    "for i in range(n_syn_datasets):\n",
    "    mu = syn_mu_sample[i]\n",
    "    sigma2 = syn_sigma2_sample[i]\n",
    "    syn_datasets[i, :] = stats.norm.rvs(loc=mu, scale=np.sqrt(sigma2), size=n_syn_dataset)\n",
    "\n",
    "# Analyst's known-variance posterior on each synthetic dataset.\n",
    "syn_data_posteriors_unknown_var = [\n",
    "    known_variance_inference(syn_datasets[i, :], down_mu_0, down_sigma_0, down_known_sigma) \n",
    "    for i in range(n_syn_datasets)\n",
    "]\n",
    "\n",
    "# Moment-based Gaussian approximation built from the per-dataset posterior means and variances.\n",
    "syn_data_post_means = np.array([post.mean() for post in syn_data_posteriors_unknown_var])\n",
    "syn_data_post_vars = np.array([post.var() for post in syn_data_posteriors_unknown_var])\n",
    "combined_approximation_unknown_var = variance_correction(syn_data_post_means, syn_data_post_vars, n_syn_dataset, n)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Known Variance Synthetic Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Data provider (upstream) model: known-variance Gaussian fitted to the real data X.\n",
    "syn_mu_0, syn_sigma_0 = 0, 10\n",
    "syn_known_sigma = 1.0 * sigma_true\n",
    "syn_posterior = known_variance_inference(X, syn_mu_0, syn_sigma_0, syn_known_sigma)\n",
    "\n",
    "def sample_syn_post(size):\n",
    "    \"\"\"Draw `size` values of mu from the data provider's posterior syn_posterior.\"\"\"\n",
    "    return syn_posterior.rvs(size=size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): n_syn_datasets and n_syn_dataset are not set here; they come from the\n",
    "# unknown-variance cell above (400 datasets of size n), so run that cell first.\n",
    "# One known-variance upstream draw mu_i per dataset, then X*_i ~ N(mu_i, syn_known_sigma^2).\n",
    "syn_mu_sample = sample_syn_post(n_syn_datasets)\n",
    "syn_datasets = np.zeros((n_syn_datasets, n_syn_dataset))\n",
    "for i in range(n_syn_datasets):\n",
    "    mu = syn_mu_sample[i]\n",
    "    syn_datasets[i, :] = stats.norm.rvs(loc=mu, scale=syn_known_sigma, size=n_syn_dataset)\n",
    "\n",
    "# Analyst's known-variance posterior on each synthetic dataset.\n",
    "syn_data_posteriors_known_var = [\n",
    "    known_variance_inference(syn_datasets[i, :], down_mu_0, down_sigma_0, down_known_sigma) \n",
    "    for i in range(n_syn_datasets)\n",
    "]\n",
    "\n",
    "# Moment-based Gaussian approximation built from the per-dataset posterior means and variances.\n",
    "syn_data_post_means = np.array([post.mean() for post in syn_data_posteriors_known_var])\n",
    "syn_data_post_vars = np.array([post.var() for post in syn_data_posteriors_known_var])\n",
    "combined_approximation_known_var = variance_correction(syn_data_post_means, syn_data_post_vars, n_syn_dataset, n)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "xs = np.linspace(0, 2, 200)\n",
    "\n",
    "# Left: synthetic data from the known-variance upstream model; right: unknown-variance model.\n",
    "# Raw-string axis labels avoid an invalid Python escape sequence in the LaTeX.\n",
    "fig, axes = plt.subplots(1, 2, figsize=(10, 2))\n",
    "ax = axes[0]\n",
    "ax.plot(xs, true_down_posterior.pdf(xs), label=analyst_label)\n",
    "ax.plot(xs, syn_posterior.pdf(xs), label=data_provider_label, linestyle=\"dashed\")\n",
    "ax.plot(xs, combined_posterior_pdf(xs, syn_data_posteriors_known_var), label=syn_data_label)\n",
    "ax.plot(xs, combined_approximation_known_var.pdf(xs), label=\"Gaussian approximation\", linestyle=\"dashed\")\n",
    "ax.set_xlabel(r\"$\\mu$\")\n",
    "ax.set_title(\"Known Variance $p(X^* | X)$\")\n",
    "\n",
    "ax = axes[1]\n",
    "ax.plot(xs, true_down_posterior.pdf(xs), label=analyst_label)\n",
    "ax.plot(xs, syn_mu_post_marginal.pdf(xs), label=data_provider_label, linestyle=\"dashed\")\n",
    "ax.plot(xs, combined_posterior_pdf(xs, syn_data_posteriors_unknown_var), label=syn_data_label)\n",
    "ax.plot(xs, combined_approximation_unknown_var.pdf(xs), label=\"Gaussian approximation\", linestyle=\"dashed\")\n",
    "ax.set_xlabel(r\"$\\mu$\")\n",
    "ax.set_title(\"Unknown Variance $p(X^* | X)$\")\n",
    "\n",
    "# One shared legend below both panels.\n",
    "leg_h, leg_l = axes[0].get_legend_handles_labels()\n",
    "fig.legend(leg_h, leg_l, loc=\"upper center\", ncol=2, bbox_to_anchor=(0.5, -0.1))\n",
    "plt.savefig(figdir + \"gaussian-approximation-results.pdf\", bbox_inches=\"tight\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.9.12 ('max-ent-env2')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.3"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "fa1919baa51ca66e7ec489754979eb11625f9177f6ce3c8d3d9d35570b67b133"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
