{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import jax.numpy as jnp\n",
    "import jax.random as random\n",
    "from scipy.stats import gaussian_kde\n",
    "# from rsnl.metrics import plot_and_save_coverage\n",
    "from rsnl.examples.contaminated_slcp import calculate_summary_statistics, true_dgp\n",
    "import matplotlib.pyplot as plt\n",
    "import pickle as pkl\n",
    "import arviz as az\n",
    "import matplotlib.colors as mcolors\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# check x_obs for seed_9\n",
    "seed = 1\n",
    "rng_key = random.PRNGKey(seed)\n",
    "rng_key, sub_key  = random.split(rng_key)\n",
    "true_params = jnp.array([0.7, -2.9, -1.0, -0.9, 0.6])\n",
    "# true_params = prior.sample(sub_key1)\n",
    "x_obs = calculate_summary_statistics(true_dgp(sub_key, *true_params))\n",
    "x_obs = jnp.around(x_obs, 2)\n",
    "print(f'x_obs: {x_obs}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(\"../res/contaminated_slcp/rsnl/seed_1/theta.pkl\", \"rb\") as f:\n",
    "    theta_draws_rsnl = jnp.array(pd.read_pickle(f))\n",
    "\n",
    "thetas_rsnl = jnp.concatenate(theta_draws_rsnl, axis=0)\n",
    "\n",
    "with open(\"../res/contaminated_slcp/snl/seed_1/theta.pkl\", \"rb\") as f:\n",
    "    theta_draws_snl = jnp.array(pd.read_pickle(f))\n",
    "\n",
    "thetas_snl = jnp.concatenate(theta_draws_snl, axis=0)\n",
    "\n",
    "with open(\"../res/contaminated_slcp/well_specified_snl/seed_1/theta.pkl\", \"rb\") as f:\n",
    "    theta_draws_well_specified_snl = jnp.array(pd.read_pickle(f))\n",
    "\n",
    "theta_well_specified_snl = jnp.concatenate(theta_draws_well_specified_snl, axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "rsnl_theta_plot = {}\n",
    "snl_theta_plot = {}\n",
    "snl_well_specified_theta_plot = {}\n",
    "\n",
    "for i in range(5):\n",
    "    rsnl_theta_plot['theta' + str(i+1)] = thetas_rsnl[:, i]\n",
    "    snl_theta_plot['theta' + str(i+1)] = thetas_snl[:, i]\n",
    "    snl_well_specified_theta_plot['theta' + str(i+1)] = theta_well_specified_snl[:, i]\n",
    "\n",
    "var_name_map = {}\n",
    "reference_values = {}\n",
    "for ii, k in enumerate(rsnl_theta_plot):\n",
    "    var_name_map[k] = fr'$\\{k[:-1]}_{k[-1]}$'\n",
    "    reference_values[var_name_map[k]] = true_params[ii]  # why does ref_vals match labels and not data? ah well"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.rcParams['font.family'] = 'serif'\n",
    "plt.rcParams['font.serif'] = ['Times New Roman']\n",
    "plt.rcParams.update({'font.size': 25})\n",
    "plt.rcParams['xtick.labelsize'] = 25\n",
    "plt.rcParams['axes.labelsize'] = 25\n",
    "plt.rcParams[\"axes.unicode_minus\"] = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, axes = plt.subplots(5, 5, sharey=False, figsize=(16, 16))\n",
    "\n",
    "old_axes = np.empty((5, 5), dtype='object')\n",
    "\n",
    "for i in range(5):\n",
    "    for j in range(5):\n",
    "        old_axes[i][j] = axes[i][j].axes\n",
    "axes = az.plot_pair(rsnl_theta_plot,\n",
    "                    kind='kde',\n",
    "                    reference_values=reference_values,\n",
    "                    reference_values_kwargs={'color': 'red', 'marker': 'X', 'markersize': 12},\n",
    "                    kde_kwargs={'hdi_probs': [0.05, 0.25, 0.5, 0.75, 0.95],\n",
    "                                'contour_kwargs': {\"colors\":None, \"cmap\":plt.cm.viridis},\n",
    "                                'contourf_kwargs': {\"alpha\":0}},\n",
    "                    ax=axes,\n",
    "                    labeller=az.labels.MapLabeller(var_name_map=var_name_map,),\n",
    "                    # textsize=18,\n",
    "                    marginals=True,\n",
    "                    marginal_kwargs={'label': 'RSNL'},\n",
    "                    # show=False\n",
    "                    # figsize=(64, 64)\n",
    "                    )\n",
    "\n",
    "curr_fig = plt.gcf()\n",
    "for ii, ax_ii in enumerate(axes):\n",
    "    for jj, ax_jj in enumerate(ax_ii):\n",
    "        if ii == jj:  # only marginal for now\n",
    "            # az.plot_density(snl_misspec_theta_plot['theta' + str(ii+1)], ax=ax_jj, colors='orange', point_estimate=None)\n",
    "            # az.plot_density(snl_correct_theta_plot['theta' + str(ii+1)], ax=ax_jj, colors='black', point_estimate=None)\n",
    "            # xs = np.arange(-3, 3, 100)\n",
    "            # kde_jj = kde(snl_correct_theta_plot['theta' + str(ii+1)])\n",
    "            # ax_jj.plot(xs, kde_jj(xs), color='black')\n",
    "            az.plot_dist(snl_theta_plot['theta' + str(ii+1)],ax=ax_jj, color='orange', plot_kwargs={'linestyle': 'dashed'}, label='SNL-incompatible')\n",
    "            az.plot_dist(snl_well_specified_theta_plot['theta' + str(ii+1)],ax=ax_jj, color='black', plot_kwargs={'linestyle': 'dotted'}, label='SNL-compatible')\n",
    "            if ii != 0:\n",
    "                ax_jj.get_legend().remove()\n",
    "            else:\n",
    "                ax_jj.legend(bbox_to_anchor=(5.5, 1.5), ncol=3,\n",
    "                            #  mode=\"expand\", borderaxespad=0.\n",
    "                             )\n",
    "            ax_jj.axvline(x=true_params[jj], color='red', linestyle='dashed')\n",
    "        if ii < jj:\n",
    "            print('jj')\n",
    "            # fig.add_subplot()\n",
    "            ax_jj._remove_method = None\n",
    "            ax_jj.figure = curr_fig\n",
    "            ax_jj.set_xlim(-3, 3)\n",
    "            ax_jj.set_ylim(-3, 3)\n",
    "            # print(\"Axis type: \", type(ax_jj))\n",
    "            ax_jj.set_visible(True)\n",
    "            # ax_jj.axis('on')\n",
    "            ax_jj = az.plot_kde(snl_theta_plot['theta' + str(jj+1)],\n",
    "                                snl_theta_plot['theta' + str(ii+1)],\n",
    "                                ax=ax_jj,\n",
    "                                hdi_probs=[0.05, 0.25, 0.5, 0.75, 0.95],\n",
    "                                contour_kwargs={\"colors\": None, \"cmap\":plt.cm.cividis, 'linestyles': 'dashed'},\n",
    "                                contourf_kwargs={\"alpha\":0}, show=False)\n",
    "            ax_jj.plot(true_params[jj], true_params[ii], color='red', marker= 'X', markersize=12, markeredgecolor='k')\n",
    "            ax_jj.axes = old_axes[ii, jj]\n",
    "            ax_jj.axes.get_xaxis().set_visible(False)\n",
    "            ax_jj.axes.get_yaxis().set_visible(False)\n",
    "            # ax_jj.set_visible(False)\n",
    "            fig.add_subplot(ax_jj)\n",
    "            # plt.add_axes(ax)\n",
    "            # axes[ii, jj] = ax\n",
    "        #     axes[ii, jj] = axes2[jj, ii]\n",
    "            # ax_jj.plot(x=0, y=0)\n",
    "# plt.axes().get_xaxis().set_visible(True)\n",
    "# plt.axes().get_yaxis().set_visible(True)\n",
    "for i in range(5):\n",
    "    for j in range(5):\n",
    "        if i != j:\n",
    "            axes[i][j].set_ylim(-3, 3)\n",
    "            axes[i][j].set_yticks([-3, 0, 3])\n",
    "        else:\n",
    "            if i == 0:\n",
    "                axes[i][j].set_yticks([0, 3, 6])\n",
    "                axes[i][j].set_ylim(0, 6)\n",
    "        axes[i][j].set_xlim(-3, 3)\n",
    "        axes[i][j].set_xticks([-3, 0, 3])\n",
    "\n",
    "# plt.show()\n",
    "plt.subplots_adjust(wspace=0.25)  # adjust the space between the subplots\n",
    "plt.savefig(\"slcp_joint_all.pdf\", bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(\"../res/contaminated_slcp/rsnl/seed_1/adj_params.pkl\", \"rb\") as f:\n",
    "    adj_params = jnp.array(pkl.load(f))\n",
    "\n",
    "adj_params = jnp.concatenate(adj_params, axis=0)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set the default font to Times New Roman\n",
    "plt.rcParams['font.family'] = 'serif'\n",
    "plt.rcParams['font.serif'] = ['Times New Roman']\n",
    "plt.rcParams['mathtext.fontset'] = 'cm'\n",
    "plt.rcParams.update({'font.size': 35})"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Adjustment Plots"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "rng_key = random.PRNGKey(0)\n",
    "prior_samples = random.laplace(rng_key, shape=(10000, 2))\n",
    "\n",
    "for i in range(10):\n",
    "    az.plot_dist(adj_params[:, i],\n",
    "                 label='Posterior',\n",
    "                 color='black')\n",
    "    az.plot_dist(prior_samples[:, i],\n",
    "                 color=mcolors.CSS4_COLORS['limegreen'],\n",
    "                 plot_kwargs={'linestyle': 'dashed'},\n",
    "                 label='Prior')\n",
    "    plt.xlabel(\"$\\gamma_{%s}$\" % (i+1), fontsize=35)\n",
    "    plt.ylabel(\"Density\", fontsize=35)\n",
    "    plt.xlim([-10, 10])\n",
    "    plt.ylim(bottom=0)\n",
    "    plt.xticks([-10, -5, 0, 5, 10], fontsize=35)\n",
    "    plt.yticks(fontsize=35)\n",
    "    plt.legend(fontsize=35,\n",
    "               loc='upper left',\n",
    "               borderpad=0.1, labelspacing=0.1, handletextpad=0.1)\n",
    "\n",
    "    plt.tight_layout()\n",
    "    plt.savefig(f'contaminated_slcp_adj_param_{i+1}.pdf', bbox_inches='tight')\n",
    "    plt.clf()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
