{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import jax.numpy as jnp\n",
    "import jax.random as random\n",
    "from scipy.stats import gaussian_kde\n",
    "from rsnl.metrics import plot_and_save_coverage\n",
    "from rsnl.examples.sir import calculate_summary_statistics, true_dgp, assumed_dgp\n",
    "import matplotlib.pyplot as plt\n",
    "import pickle as pkl\n",
    "import arviz as az\n",
    "import matplotlib.colors as mcolors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set the default font to Times New Roman\n",
    "plt.rcParams['font.family'] = 'serif'\n",
    "plt.rcParams['font.serif'] = ['Times New Roman']\n",
    "plt.rcParams['mathtext.fontset'] = 'cm'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "seed = 4\n",
    "rng_key = random.PRNGKey(seed)\n",
    "rng_key, sub_key1, sub_key2 = random.split(rng_key, 3)\n",
    "true_params = jnp.array([0.1, 0.15])\n",
    "x_obs = true_dgp(sub_key2, *true_params)\n",
    "x_obs = calculate_summary_statistics(x_obs)\n",
    "print('x_obs', x_obs)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for i in range(10):\n",
    "    rng_key, sub_key1, sub_key2 = random.split(rng_key, 3)\n",
    "    print(calculate_summary_statistics(assumed_dgp(sub_key1, *true_params)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(\"../res/sir/tmp_rsnl/rsnl/seed_4/thetas.pkl\", \"rb\") as f:\n",
    "    theta_draws_rsnl = jnp.array(pkl.load(f))\n",
    "\n",
    "thetas_rsnl = jnp.concatenate(theta_draws_rsnl, axis=0)\n",
    "thetas_rsnl = jnp.squeeze(thetas_rsnl)\n",
    "\n",
    "with open(\"../res/sir/tmp_rsnl/rsnl/seed_4/adj_params.pkl\", \"rb\") as f:\n",
    "    adj_params = jnp.array(pkl.load(f))\n",
    "\n",
    "adj_params = jnp.concatenate(adj_params, axis=0)\n",
    "\n",
    "with open(\"../res/sir/snl/seed_4/thetas.pkl\", \"rb\") as f:\n",
    "    theta_draws_snl = jnp.array(pkl.load(f))\n",
    "\n",
    "thetas_snl = jnp.concatenate(theta_draws_snl, axis=0)\n",
    "thetas_snl = jnp.squeeze(thetas_snl)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "rsnl_theta_plot = {}\n",
    "snl_correct_theta_plot = {}\n",
    "snl_theta_plot = {}\n",
    "\n",
    "for i in range(2):\n",
    "    rsnl_theta_plot['theta' + str(i+1)] = thetas_rsnl[ :, i]\n",
    "    snl_theta_plot['theta' + str(i+1)] = thetas_snl[:, i]\n",
    "\n",
    "var_name_map = {}\n",
    "reference_values = {}\n",
    "labels = [r'$\\eta$', r'$\\beta$']\n",
    "for ii, k in enumerate(rsnl_theta_plot):\n",
    "    var_name_map[k] = labels[ii]\n",
    "    reference_values[var_name_map[k]] = true_params[ii]  # why does ref_vals match labels and not data? ah well"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "az.plot_pair(rsnl_theta_plot,\n",
    "             kind='kde',\n",
    "             reference_values=reference_values,\n",
    "             marginals=True,\n",
    "             labeller=az.labels.MapLabeller(var_name_map=var_name_map),\n",
    "             reference_values_kwargs={'color': 'red', 'marker': 'X', 'markersize': 12},\n",
    "             kde_kwargs={'hdi_probs': [0.05, 0.25, 0.5, 0.75, 0.95],\n",
    "                         'contour_kwargs': {\"colors\":None, \"cmap\":plt.cm.viridis},\n",
    "                         'contourf_kwargs': {\"alpha\":0}},\n",
    "             textsize=24,\n",
    "            )\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"sir_theta_posterior.pdf\", bbox_inches='tight')\n",
    "# plt.xlabel(rf\"$\\theta_1$\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "az.plot_pair(snl_theta_plot,\n",
    "             kind='kde',\n",
    "             reference_values=reference_values,\n",
    "             marginals=True,\n",
    "             labeller=az.labels.MapLabeller(var_name_map=var_name_map),\n",
    "             reference_values_kwargs={'color': 'red', 'marker': 'X', 'markersize': 12},\n",
    "             kde_kwargs={'hdi_probs': [0.05, 0.25, 0.5, 0.75, 0.95],\n",
    "                         'contour_kwargs': {\"colors\":None, \"cmap\":plt.cm.cividis},\n",
    "                         'contourf_kwargs': {\"alpha\":0}},\n",
    "             marginal_kwargs={'color': 'orange'},\n",
    "             textsize=24,\n",
    "            )\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"sir_snl_theta_posterior.pdf\", bbox_inches='tight')\n",
    "# plt.xlabel(rf\"$\\theta_1$\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set the default font to Times New Roman\n",
    "plt.rcParams['font.family'] = 'serif'\n",
    "plt.rcParams['font.serif'] = ['Times New Roman']\n",
    "plt.rcParams['mathtext.fontset'] = 'cm'\n",
    "\n",
    "\n",
    "rng_key = random.PRNGKey(6)\n",
    "prior_samples = random.laplace(rng_key, shape=(10000, 2))\n",
    "\n",
    "for i in range(6):\n",
    "    az.plot_dist(adj_params[:, i],\n",
    "                 label='Posterior',\n",
    "                 color='black')\n",
    "    az.plot_dist(prior_samples[:, i],\n",
    "                 color=mcolors.CSS4_COLORS['limegreen'],\n",
    "                 plot_kwargs={'linestyle': 'dashed'},\n",
    "                 label='Prior')\n",
    "\n",
    "    plt.xlabel(f\"$\\gamma_{i+1}$\", fontsize=25)\n",
    "    plt.ylabel(\"Density\", fontsize=25)\n",
    "    plt.xlim([-10, 10])\n",
    "    plt.xticks([-10, -5, 0, 5, 10], fontsize=25)\n",
    "    plt.yticks(fontsize=25)\n",
    "    plt.ylim(bottom=0)\n",
    "    plt.legend(fontsize=25,\n",
    "               loc='upper left',\n",
    "               borderpad=0.1, labelspacing=0.1, handletextpad=0.1)\n",
    "\n",
    "    plt.tight_layout()\n",
    "    plt.savefig(f'sir_adj_param_{i+1}.pdf', bbox_inches='tight')\n",
    "    plt.clf()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
