{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Misspecified MA(1) example: posterior density plots"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import jax.numpy as jnp\n",
    "import jax.random as random\n",
    "from scipy.stats import gaussian_kde\n",
    "from rsnl.metrics import plot_and_save_coverage\n",
    "from rsnl.examples.misspec_ma1 import calculate_summary_statistics, true_dgp\n",
    "import matplotlib.pyplot as plt\n",
    "import pickle as pkl\n",
    "import arviz as az\n",
    "import matplotlib.colors as mcolors"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model parameter posterior: RSNL vs SNL"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Regenerate x_obs for this seed and load the saved RSNL/SNL posterior draws for theta.\n",
    "seed = 6\n",
    "rng_key = random.PRNGKey(seed)\n",
    "# sub_key1 is unused, but the 3-way split is kept so sub_key2 (and hence x_obs) is unchanged.\n",
    "rng_key, sub_key1, sub_key2 = random.split(rng_key, 3)\n",
    "true_params = jnp.array([0.0])\n",
    "x_obs = calculate_summary_statistics(true_dgp(key=sub_key2))\n",
    "print('x_obs: ', x_obs)\n",
    "\n",
    "RES_DIR = \"../res/misspec_ma1\"\n",
    "\n",
    "\n",
    "def load_draws(method, run_seed, name=\"thetas\"):\n",
    "    \"\"\"Load `RES_DIR/<method>/seed_<run_seed>/<name>.pkl` and stack it along axis 0.\n",
    "\n",
    "    The unpickled object is converted to an array and concatenated over its\n",
    "    leading axis (presumably one entry per MCMC chain -- TODO confirm), giving\n",
    "    a single flat set of draws.\n",
    "\n",
    "    NOTE: unpickling can execute arbitrary code; only load result files you produced.\n",
    "    \"\"\"\n",
    "    path = f\"{RES_DIR}/{method}/seed_{run_seed}/{name}.pkl\"\n",
    "    with open(path, \"rb\") as f:\n",
    "        draws = jnp.array(pkl.load(f))\n",
    "    return jnp.concatenate(draws, axis=0)\n",
    "\n",
    "\n",
    "thetas_rsnl = jnp.squeeze(load_draws(\"rsnl\", seed))\n",
    "thetas_snl = jnp.squeeze(load_draws(\"snl\", seed))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.rcParams['font.family'] = 'Times New Roman'\n",
    "plt.rcParams.update({'font.size': 25})\n",
    "plt.rcParams['xtick.labelsize'] = 25\n",
    "\n",
    "# Overlay the RSNL and SNL (dashed) marginal posteriors of theta on a single axes.\n",
    "fig, ax = plt.subplots()\n",
    "az.plot_dist(thetas_rsnl, label=\"RSNL\", ax=ax)\n",
    "az.plot_dist(thetas_snl,\n",
    "             color=mcolors.CSS4_COLORS['darkorange'],\n",
    "             plot_kwargs={'linestyle': 'dashed'},\n",
    "             label=\"SNL\",\n",
    "             ax=ax)\n",
    "\n",
    "ax.set_xlabel(r\"$\\theta$\", fontsize=25)\n",
    "ax.set_xlim([-1, 1])\n",
    "ax.set_ylim(bottom=0)\n",
    "ax.legend(fontsize=25, borderpad=0.1, labelspacing=0.1, handletextpad=0.1)\n",
    "ax.set_ylabel(\"Density\", fontsize=30)\n",
    "# Reference line at true_params (defined alongside x_obs) rather than a hard-coded 0.\n",
    "ax.axvline(x=float(true_params[0]), color='red', linestyle='dashed')\n",
    "fig.tight_layout()\n",
    "fig.savefig('misspec_ma1_posterior.pdf', bbox_inches='tight')"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Adjustment parameters: RSNL posterior vs prior"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# RSNL adjustment-parameter draws for the same seed, stacked over the leading axis.\n",
    "# NOTE: unpickling can execute arbitrary code; only load result files you produced.\n",
    "with open(f\"../res/misspec_ma1/rsnl/seed_{seed}/adj_params.pkl\", \"rb\") as f:\n",
    "    adj_param_draws = jnp.array(pkl.load(f))\n",
    "\n",
    "adj_params = jnp.concatenate(adj_param_draws, axis=0)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.rcParams['mathtext.fontset'] = 'cm'\n",
    "plt.rcParams['font.family'] = 'serif'\n",
    "plt.rcParams['font.serif'] = ['Times New Roman']\n",
    "plt.rcParams.update({'font.size': 25})\n",
    "\n",
    "# Standard Laplace draws used as the reference prior for each adjustment parameter.\n",
    "rng_key = random.PRNGKey(0)\n",
    "prior_samples = random.laplace(rng_key, shape=(10000, 2))\n",
    "\n",
    "for i in range(2):\n",
    "    # One figure per parameter; closed after saving so figures do not pile up\n",
    "    # and no empty (cleared) figure is left behind for display.\n",
    "    fig, ax = plt.subplots()\n",
    "    az.plot_dist(adj_params[:, i],\n",
    "                 label='Posterior',\n",
    "                 color='black',\n",
    "                 ax=ax)\n",
    "    az.plot_dist(prior_samples[:, i],\n",
    "                 color=mcolors.CSS4_COLORS['limegreen'],\n",
    "                 plot_kwargs={'linestyle': 'dashed'},\n",
    "                 label='Prior',\n",
    "                 ax=ax)\n",
    "    # Raw f-string: \"\\g\" is an invalid escape sequence in a normal string literal.\n",
    "    ax.set_xlabel(rf\"$\\gamma_{i+1}$\", fontsize=25)\n",
    "    ax.set_ylabel(\"Density\", fontsize=25)\n",
    "    ax.set_xlim([-10, 10])\n",
    "    ax.set_ylim(bottom=0)\n",
    "    ax.tick_params(axis='x', labelsize=25)\n",
    "    ax.legend(fontsize=25,\n",
    "              loc='upper left',\n",
    "              borderpad=0.1, labelspacing=0.1, handletextpad=0.1)\n",
    "\n",
    "    fig.tight_layout()\n",
    "    fig.savefig(f'misspec_ma1_adj_param_{i+1}.pdf', bbox_inches='tight')\n",
    "    plt.close(fig)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
