{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import jax.numpy as jnp\n",
    "import jax.random as random\n",
    "from scipy.stats import gaussian_kde\n",
    "# from rsnl.metrics import plot_and_save_coverage\n",
    "from rsnl.examples.contaminated_normal import assumed_dgp, calculate_summary_statistics, true_dgp\n",
    "import matplotlib.pyplot as plt\n",
    "import pickle as pkl\n",
    "import arviz as az\n",
    "import matplotlib.colors as mcolors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Recompute the observed summary statistics (x_obs) for the seed-9 run.\n",
    "seed = 9\n",
    "true_params = jnp.array([1.0])\n",
    "# true_params = prior.sample(sub_key1)\n",
    "\n",
    "# Keep the three-way split: only sub_key2 is consumed here, but changing the\n",
    "# number of splits would change which key the data are simulated with.\n",
    "rng_key, sub_key1, sub_key2 = random.split(random.PRNGKey(seed), 3)\n",
    "observed_data = true_dgp(sub_key2, true_params)\n",
    "x_obs = calculate_summary_statistics(observed_data)\n",
    "print('x_obs: ', x_obs)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model param posterior plot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# RSNL draws: keep the array as loaded (theta_draws_rsnl, reused by the PPC\n",
    "# cell) plus a concatenated, squeezed copy (thetas_rsnl) for plotting.\n",
    "with open(\"../res/contaminated_normal/rsnl/seed_9/thetas.pkl\", \"rb\") as rsnl_file:\n",
    "    theta_draws_rsnl = jnp.array(pkl.load(rsnl_file))\n",
    "thetas_rsnl = jnp.squeeze(jnp.concatenate(theta_draws_rsnl, axis=0))\n",
    "\n",
    "# SNL draws, handled the same way.\n",
    "with open(\"../res/contaminated_normal/snl/seed_9/thetas.pkl\", \"rb\") as snl_file:\n",
    "    theta_draws_snl = jnp.array(pkl.load(snl_file))\n",
    "thetas_snl = jnp.squeeze(jnp.concatenate(theta_draws_snl, axis=0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the concatenated, squeezed RSNL draws built in the previous cell.\n",
    "thetas_rsnl"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Font settings for the figure.\n",
    "# plt.rcParams['text.usetex'] = True\n",
    "plt.rcParams.update({\n",
    "    'font.family': 'Times New Roman',\n",
    "    'font.size': 25,\n",
    "    'xtick.labelsize': 25,\n",
    "})\n",
    "\n",
    "# RSNL density with the default style, SNL density as a dashed dark-orange line.\n",
    "az.plot_dist(thetas_rsnl, label=\"RSNL\")\n",
    "snl_style = {\n",
    "    'color': mcolors.CSS4_COLORS['darkorange'],\n",
    "    'plot_kwargs': {'linestyle': 'dashed'},\n",
    "    'label': \"SNL\",\n",
    "}\n",
    "az.plot_dist(thetas_snl, **snl_style)\n",
    "\n",
    "plt.xlabel(r\"$\\theta$\", fontsize=25)\n",
    "plt.ylabel(\"Density\", fontsize=30)\n",
    "plt.xlim([-1, 3])\n",
    "plt.ylim(bottom=0)\n",
    "plt.legend(fontsize=25, borderpad=0.1, labelspacing=0.1, handletextpad=0.1)\n",
    "# Dashed red reference line at theta = 1.\n",
    "plt.axvline(x=1, color='red', linestyle='dashed')\n",
    "# plt.title(\"$b_0 = 0.01$\")\n",
    "plt.tight_layout()\n",
    "plt.savefig('contaminated_normal_posterior.pdf', bbox_inches='tight')"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Adjustment param posterior plot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Adjustment-parameter draws from the RSNL seed-9 run, joined along the\n",
    "# first axis of the loaded array.\n",
    "with open(\"../res/contaminated_normal/rsnl/seed_9/adj_params.pkl\", \"rb\") as adj_file:\n",
    "    adj_params = jnp.concatenate(jnp.array(pkl.load(adj_file)), axis=0)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# plt.rcParams['text.usetex'] = True\n",
    "plt.rcParams['mathtext.fontset'] = 'cm'\n",
    "plt.rcParams['font.family'] = 'serif'\n",
    "plt.rcParams['font.serif'] = ['Times New Roman']\n",
    "plt.rcParams.update({'font.size': 25})\n",
    "\n",
    "# Laplace draws (one column per adjustment parameter) shown below as the\n",
    "# 'Prior' curve.\n",
    "# NOTE: this rebinds the notebook-level rng_key, which the posterior\n",
    "# predictive cells further down keep splitting.\n",
    "rng_key = random.PRNGKey(0)\n",
    "prior_samples = random.laplace(rng_key, shape=(10000, 2))\n",
    "\n",
    "# One saved figure per adjustment parameter: posterior draws vs. prior draws.\n",
    "for i in range(2):\n",
    "    az.plot_dist(adj_params[:, i],\n",
    "                 label='Posterior',\n",
    "                 color='black')\n",
    "    az.plot_dist(prior_samples[:, i],\n",
    "                 color=mcolors.CSS4_COLORS['limegreen'],\n",
    "                 plot_kwargs={'linestyle': 'dashed'},\n",
    "                 label='Prior')\n",
    "    # Raw f-string: '\\g' is not a valid Python escape, so a plain f-string\n",
    "    # emits an invalid-escape warning. The rendered label is unchanged.\n",
    "    plt.xlabel(rf\"$\\gamma_{i+1}$\", fontsize=25)\n",
    "    plt.ylabel(\"Density\", fontsize=25)\n",
    "    plt.xlim([-10, 10])\n",
    "    plt.ylim(bottom=0)\n",
    "    plt.xticks(fontsize=25)\n",
    "    plt.legend(fontsize=25,\n",
    "               loc='upper left',\n",
    "               borderpad=0.1, labelspacing=0.1, handletextpad=0.1)\n",
    "\n",
    "    plt.tight_layout()\n",
    "    plt.savefig(f'contaminated_normal_adj_param_{i+1}.pdf', bbox_inches='tight')\n",
    "    # Clear the figure so the next parameter starts from empty axes.\n",
    "    plt.clf()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Posterior Predictive Checks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# RSNL PPC\n",
    "# Join the loaded draws along the first axis into one entry per posterior\n",
    "# sample. The result gets its own name so that re-running this cell does not\n",
    "# concatenate an already flattened array a second time.\n",
    "thetas_rsnl_ppc = np.concatenate(theta_draws_rsnl, axis=0)\n",
    "# Size the output from the actual number of draws (previously a hard-coded\n",
    "# 40000) so no rows are left as zeros and no index runs out of bounds.\n",
    "x_rsnl_ppc = np.zeros((len(thetas_rsnl_ppc), 2))\n",
    "for ii, theta_rsnl in enumerate(thetas_rsnl_ppc):\n",
    "    if ii % 1000 == 0:\n",
    "        print('ii: ', ii)\n",
    "    # Fresh sub-key per simulation; rng_key continues the notebook-level stream.\n",
    "    rng_key, sub_key = random.split(rng_key)\n",
    "    x_rsnl_ppc[ii, :] = calculate_summary_statistics(assumed_dgp(sub_key, theta_rsnl))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# gaussian_kde comes from the notebook's import cell.\n",
    "summ_labels = ['Sample mean', 'Sample variance']\n",
    "\n",
    "# One saved figure per summary statistic: KDE of the RSNL posterior\n",
    "# predictive draws, with the observed value marked by a dashed red line.\n",
    "for i in range(2):\n",
    "    ppc_draws = x_rsnl_ppc[:, i]\n",
    "    # Array min/max instead of the Python builtins, which walk the NumPy\n",
    "    # column one element at a time.\n",
    "    xs = np.linspace(ppc_draws.min(), ppc_draws.max(), 200)\n",
    "    kde = gaussian_kde(ppc_draws)\n",
    "    plt.plot(xs, kde(xs), label='Posterior')\n",
    "    plt.axvline(x=x_obs[i], color='red', linestyle='dashed')\n",
    "    plt.xlabel(summ_labels[i])\n",
    "    plt.ylabel(\"Density\")\n",
    "    plt.ylim(bottom=0)\n",
    "    plt.xticks(fontsize=25)\n",
    "    plt.savefig(f'contaminated_normal_rsnl_ppc_{i+1}.pdf', bbox_inches='tight')\n",
    "    # Clear the figure before drawing the next statistic.\n",
    "    plt.clf()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# SNL PPC\n",
    "# Join the loaded draws along the first axis into one entry per posterior\n",
    "# sample. The result gets its own name so that re-running this cell does not\n",
    "# concatenate an already flattened array a second time.\n",
    "thetas_snl_ppc = np.concatenate(theta_draws_snl, axis=0)\n",
    "# Size the output from the actual number of draws (previously a hard-coded\n",
    "# 40000) so no rows are left as zeros and no index runs out of bounds.\n",
    "x_snl_ppc = np.zeros((len(thetas_snl_ppc), 2))\n",
    "for ii, theta_snl in enumerate(thetas_snl_ppc):\n",
    "    if ii % 1000 == 0:\n",
    "        print('ii: ', ii)\n",
    "    # Fresh sub-key per simulation; rng_key continues the notebook-level stream.\n",
    "    rng_key, sub_key = random.split(rng_key)\n",
    "    x_snl_ppc[ii, :] = calculate_summary_statistics(assumed_dgp(sub_key, theta_snl))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Defined here so this cell does not rely on a value left behind by another\n",
    "# plotting cell. gaussian_kde comes from the notebook's import cell.\n",
    "summ_labels = ['Sample mean', 'Sample variance']\n",
    "\n",
    "# One saved figure per summary statistic: KDE of the SNL posterior\n",
    "# predictive draws, with the observed value marked by a dashed red line.\n",
    "for i in range(2):\n",
    "    ppc_draws = x_snl_ppc[:, i]\n",
    "    # Array min/max instead of the Python builtins, which walk the NumPy\n",
    "    # column one element at a time.\n",
    "    xs = np.linspace(ppc_draws.min(), ppc_draws.max(), 200)\n",
    "    kde = gaussian_kde(ppc_draws)\n",
    "    plt.plot(xs, kde(xs), color='orange',\n",
    "             label='Posterior')\n",
    "    plt.axvline(x=x_obs[i], color='red', linestyle='dashed')\n",
    "    plt.xlabel(summ_labels[i])\n",
    "    plt.ylabel(\"Density\")\n",
    "    plt.ylim(bottom=0)\n",
    "    plt.xticks(fontsize=25)\n",
    "    plt.savefig(f'contaminated_normal_snl_ppc_{i+1}.pdf', bbox_inches='tight')\n",
    "    # Clear the figure before drawing the next statistic.\n",
    "    plt.clf()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle as pkl\n",
    "\n",
    "# Draws from the 'v4_seed_0' RSNL run (labelled 'Data-driven prior' in the\n",
    "# comparison plot), concatenated along the first axis and squeezed.\n",
    "with open('../res/contaminated_normal/rsnl/v4_seed_0/thetas.pkl', 'rb') as v4_file:\n",
    "    theta_draws_v4 = jnp.array(pkl.load(v4_file))\n",
    "theta_draws_v4 = jnp.squeeze(jnp.concatenate(theta_draws_v4, axis=0))\n",
    "\n",
    "# comparison between laplace(0, 0.5) and data-driven prior\n",
    "with open('../res/contaminated_normal/rsnl/laplace05_v4_seed_0/thetas.pkl', 'rb') as laplace_file:\n",
    "    theta_draws_laplace05_v4 = jnp.array(pkl.load(laplace_file))\n",
    "theta_draws_laplace05_v4 = jnp.squeeze(\n",
    "    jnp.concatenate(theta_draws_laplace05_v4, axis=0)\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shape of the concatenated, squeezed draws loaded in the previous cell.\n",
    "theta_draws_v4.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import arviz as az\n",
    "import matplotlib.colors as mcolors\n",
    "\n",
    "# Posterior densities of theta under the two adjustment-parameter priors:\n",
    "# default style for the data-driven run, dashed dark orange for Laplace(0, 0.5).\n",
    "az.plot_dist(theta_draws_v4, label=\"RSNL - Data-driven prior\")\n",
    "laplace_style = {\n",
    "    'color': mcolors.CSS4_COLORS['darkorange'],\n",
    "    'plot_kwargs': {'linestyle': 'dashed'},\n",
    "    'label': \"RSNL - Laplace(0, 0.5)\",\n",
    "}\n",
    "az.plot_dist(theta_draws_laplace05_v4, **laplace_style)\n",
    "\n",
    "plt.xlabel(r\"$\\theta$\", fontsize=25)\n",
    "plt.ylabel(\"Density\", fontsize=30)\n",
    "plt.xlim([-3, 5])\n",
    "plt.ylim(bottom=0)\n",
    "plt.legend(fontsize=15)#, borderpad=0.1, labelspacing=0.1, handletextpad=0.1)\n",
    "# Dashed red reference line at theta = 1.\n",
    "plt.axvline(x=1, color='red', linestyle='dashed')\n",
    "# plt.title(\"$b_0 = 0.01$\")\n",
    "plt.tight_layout()\n",
    "plt.savefig('contaminated_normal_posterior_prior_comparison.pdf', bbox_inches='tight')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
