{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import jax.numpy as jnp\n",
    "import jax.random as random\n",
    "from scipy.stats import gaussian_kde\n",
    "from rsnl.examples.toad import calculate_summary_statistics, dgp\n",
    "import matplotlib.pyplot as plt\n",
    "import pickle as pkl\n",
    "import arviz as az\n",
    "import matplotlib.colors as mcolors\n",
    "import pickle as pkl"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set the default font to Times New Roman\n",
    "plt.rcParams['font.family'] = 'serif'\n",
    "plt.rcParams['font.serif'] = ['Times New Roman']\n",
    "plt.rcParams['mathtext.fontset'] = 'cm'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(\"../res/toad/rsnl/seed_0/thetas.pkl\", \"rb\") as f:\n",
    "    theta_draws_rsnl = jnp.array(pkl.load(f))\n",
    "\n",
    "thetas_rsnl = jnp.concatenate(theta_draws_rsnl, axis=0)\n",
    "thetas_rsnl = jnp.squeeze(thetas_rsnl)\n",
    "\n",
    "with open(\"../res/toad/rsnl/seed_0/adj_params.pkl\", \"rb\") as f:\n",
    "    adj_params = jnp.array(pkl.load(f))\n",
    "\n",
    "adj_params = jnp.concatenate(adj_params, axis=0)\n",
    "\n",
    "with open(\"../res/toad/snl/seed_0/thetas.pkl\", \"rb\") as f:\n",
    "    theta_draws_snl = jnp.array(pkl.load(f))\n",
    "\n",
    "thetas_snl = jnp.concatenate(theta_draws_snl, axis=0)\n",
    "thetas_snl = jnp.squeeze(thetas_snl)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "true_params = jnp.array([1.8, 45.0, 0.8])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "rsnl_theta_plot = {}\n",
    "snl_theta_plot = {}\n",
    "\n",
    "for i in range(3):\n",
    "    rsnl_theta_plot['theta' + str(i+1)] = thetas_rsnl[ :, i]\n",
    "    snl_theta_plot['theta' + str(i+1)] = thetas_snl[:, i]\n",
    "\n",
    "\n",
    "var_name_map = {}\n",
    "reference_values = {}\n",
    "labels = [r'$\\alpha_{\\mathrm{toad}}$', r'$\\delta$', r'$p_0$']\n",
    "for ii, k in enumerate(rsnl_theta_plot):\n",
    "    var_name_map[k] = labels[ii]\n",
    "    reference_values[var_name_map[k]] = true_params[ii]  # why does ref_vals match labels and not data? ah well"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "az.plot_pair(rsnl_theta_plot,\n",
    "             kind='kde',\n",
    "            #  reference_values=reference_values,\n",
    "             marginals=True,\n",
    "             labeller=az.labels.MapLabeller(var_name_map=var_name_map),\n",
    "            #  reference_values_kwargs={'color': 'red', 'marker': 'X', 'markersize': 12},\n",
    "             kde_kwargs={'hdi_probs': [0.05, 0.25, 0.5, 0.75, 0.95],\n",
    "                         'contour_kwargs': {\"colors\":None, \"cmap\":plt.cm.viridis},\n",
    "                         'contourf_kwargs': {\"alpha\":0}},\n",
    "             textsize=24,\n",
    "            )\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"plots/toad_theta_posterior.pdf\", bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "az.plot_pair(snl_theta_plot,\n",
    "             kind='kde',\n",
    "            #  reference_values=reference_values,\n",
    "             marginals=True,\n",
    "             labeller=az.labels.MapLabeller(var_name_map=var_name_map),\n",
    "            #  reference_values_kwargs={'color': 'red', 'marker': 'X', 'markersize': 12},\n",
    "             kde_kwargs={'hdi_probs': [0.05, 0.25, 0.5, 0.75, 0.95],\n",
    "                         'contour_kwargs': {\"colors\":None, \"cmap\":plt.cm.cividis},\n",
    "                         'contourf_kwargs': {\"alpha\":0}},\n",
    "             marginal_kwargs={'color': 'orange'},\n",
    "             textsize=24,\n",
    "            )\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"plots/toad_snl_theta_posterior.pdf\", bbox_inches='tight')\n",
    "# plt.xlabel(rf\"$\\theta_1$\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('../res/external_res/rnpe_toad_res.pkl', 'rb') as f:\n",
    "    res = pkl.load(f)\n",
    "\n",
    "import scipy.io\n",
    "thetas_rbsl = scipy.io.loadmat('../res/external_res/results_bsl_model2_realdata_mean_n500.mat')['theta']\n",
    "\n",
    "thetas_rnpe = res['posterior_samples']['RNPE']\n",
    "\n",
    "rnpe_theta_plot = {}\n",
    "rbsl_theta_plot = {}\n",
    "for i in range(3):\n",
    "    rnpe_theta_plot['theta' + str(i+1)] = thetas_rnpe[ :, 0, i]\n",
    "    rbsl_theta_plot['theta' + str(i+1)] = thetas_rbsl[ :, i]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "az.plot_pair(rnpe_theta_plot,\n",
    "             kind='kde',\n",
    "            #  reference_values=reference_values,\n",
    "             marginals=True,\n",
    "             labeller=az.labels.MapLabeller(var_name_map=var_name_map),\n",
    "            #  reference_values_kwargs={'color': 'red', 'marker': 'X', 'markersize': 12},\n",
    "             kde_kwargs={'hdi_probs': [0.05, 0.25, 0.5, 0.75, 0.95],\n",
    "                         'contour_kwargs': {\"colors\":None, \"cmap\":plt.cm.viridis},\n",
    "                         'contourf_kwargs': {\"alpha\":0}},\n",
    "             textsize=24,\n",
    "            )\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"plots/toad_rnpe_theta_posterior.pdf\", bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "az.plot_pair(rbsl_theta_plot,\n",
    "             kind='kde',\n",
    "            #  reference_values=reference_values,\n",
    "             marginals=True,\n",
    "             labeller=az.labels.MapLabeller(var_name_map=var_name_map),\n",
    "            #  reference_values_kwargs={'color': 'red', 'marker': 'X', 'markersize': 12},\n",
    "             kde_kwargs={'hdi_probs': [0.05, 0.25, 0.5, 0.75, 0.95],\n",
    "                         'contour_kwargs': {\"colors\":None, \"cmap\":plt.cm.viridis},\n",
    "                         'contourf_kwargs': {\"alpha\":0}},\n",
    "             textsize=24,\n",
    "            )\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"plots/toad_rbsl_theta_posterior.pdf\", bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "az.plot_pair(snl_theta_plot,\n",
    "             kind='kde',\n",
    "             reference_values=reference_values,\n",
    "             marginals=True,\n",
    "             labeller=az.labels.MapLabeller(var_name_map=var_name_map),\n",
    "             reference_values_kwargs={'color': 'red', 'marker': 'X', 'markersize': 12},\n",
    "             kde_kwargs={'hdi_probs': [0.05, 0.25, 0.5, 0.75, 0.95],\n",
    "                         'contour_kwargs': {\"colors\":None, \"cmap\":plt.cm.cividis},\n",
    "                         'contourf_kwargs': {\"alpha\":0}},\n",
    "             marginal_kwargs={'color': 'orange'},\n",
    "             textsize=24,\n",
    "            )\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"plots/toad_snl_theta_posterior.pdf\", bbox_inches='tight')\n",
    "# plt.xlabel(rf\"$\\theta_1$\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set the default font to Times New Roman\n",
    "plt.rcParams['font.family'] = 'serif'\n",
    "plt.rcParams['font.serif'] = ['Times New Roman']\n",
    "plt.rcParams['mathtext.fontset'] = 'cm'\n",
    "\n",
    "\n",
    "rng_key = random.PRNGKey(48)\n",
    "prior_samples = random.laplace(rng_key, shape=(10000, 2))\n",
    "\n",
    "for i in range(48):\n",
    "    az.plot_dist(adj_params[:, i],\n",
    "                 label='Posterior',\n",
    "                 color='black')\n",
    "    az.plot_dist(prior_samples[:, i],\n",
    "                 color=mcolors.CSS4_COLORS['limegreen'],\n",
    "                 plot_kwargs={'linestyle': 'dashed'},\n",
    "                 label='Prior')\n",
    "\n",
    "    plt.xlabel(\"$\\gamma_{%s}$\" % (i+1), fontsize=25)\n",
    "    plt.ylabel(\"Density\", fontsize=25)\n",
    "    plt.xlim([-10, 10])\n",
    "    plt.xticks([-10, -5, 0, 5, 10], fontsize=25)\n",
    "    plt.yticks(fontsize=25)\n",
    "    plt.ylim(bottom=0)\n",
    "    plt.legend(fontsize=25,\n",
    "               loc='upper left',\n",
    "               borderpad=0.1, labelspacing=0.1, handletextpad=0.1)\n",
    "\n",
    "    plt.tight_layout()\n",
    "    plt.savefig(f'plots/toad_adj_param_{i+1}.pdf', bbox_inches='tight')\n",
    "    plt.clf()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.rcParams['font.family'] = 'serif'\n",
    "plt.rcParams['font.serif'] = ['Times New Roman']\n",
    "plt.rcParams.update({'font.size': 25})\n",
    "plt.rcParams['xtick.labelsize'] = 25\n",
    "plt.rcParams['axes.labelsize'] = 25\n",
    "plt.rcParams[\"axes.unicode_minus\"] = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, axes = plt.subplots(3, 3, sharey=False, figsize=(16, 16))\n",
    "\n",
    "old_axes = np.empty((3, 3), dtype='object')\n",
    "\n",
    "lower_bounds = [1.0, 20.0, 0.4]\n",
    "upper_bounds = [2.0, 70.0, 0.9]\n",
    "\n",
    "for i in range(3):\n",
    "    for j in range(3):\n",
    "        axes[i][j].set_xlim([lower_bounds[j]-0.1, upper_bounds[j]+0.1])\n",
    "        old_axes[i][j] = axes[i][j].axes\n",
    "\n",
    "axes = az.plot_pair(rsnl_theta_plot,\n",
    "                    kind='kde',\n",
    "                    kde_kwargs={'hdi_probs': [0.05, 0.25, 0.5, 0.75, 0.95],\n",
    "                                'contour_kwargs': {\"colors\":None, \"cmap\":plt.cm.viridis},\n",
    "                                'contourf_kwargs': {\"alpha\":0}},\n",
    "                    ax=axes,\n",
    "                    labeller=az.labels.MapLabeller(var_name_map=var_name_map,),\n",
    "                    marginals=True,\n",
    "                    marginal_kwargs={'label': 'RSNL'},\n",
    "                    )\n",
    "\n",
    "curr_fig = plt.gcf()\n",
    "for ii, ax_ii in enumerate(axes):\n",
    "    for jj, ax_jj in enumerate(ax_ii):\n",
    "        if ii == jj:  # only marginal for now\n",
    "            az.plot_dist(snl_theta_plot['theta' + str(ii+1)],\n",
    "                         ax=ax_jj,\n",
    "                         color='orange',\n",
    "                         plot_kwargs={'linestyle': 'dashed'},\n",
    "                         label='SNL'\n",
    "                         )\n",
    "            az.plot_dist(rbsl_theta_plot['theta' + str(ii+1)],\n",
    "                         ax=ax_jj,\n",
    "                         bw='silverman',\n",
    "                         color='black',\n",
    "                         plot_kwargs={'linestyle': 'dotted'},\n",
    "                         label='RBSL')\n",
    "            if ii != 0:\n",
    "                ax_jj.get_legend().remove()\n",
    "            else:\n",
    "                ax_jj.legend(bbox_to_anchor=(2.75, 1.25), ncol=3)\n",
    "        if ii < jj:\n",
    "            ax_jj._remove_method = None\n",
    "            ax_jj.figure = curr_fig\n",
    "            ax_jj.set_visible(True)\n",
    "            ax_jj = az.plot_kde(snl_theta_plot['theta' + str(jj+1)],\n",
    "                                snl_theta_plot['theta' + str(ii+1)],\n",
    "                                ax=ax_jj,\n",
    "                                hdi_probs=[0.05, 0.25, 0.5, 0.75, 0.95],\n",
    "                                contour_kwargs={\"colors\": None, \"cmap\":plt.cm.cividis, 'linestyles': 'dashed'},\n",
    "                                contourf_kwargs={\"alpha\":0}, show=False)\n",
    "            ax_jj.axes = old_axes[ii, jj]\n",
    "            ax_jj.axes.get_xaxis().set_visible(False)\n",
    "            ax_jj.axes.get_yaxis().set_visible(False)\n",
    "            fig.add_subplot(ax_jj)\n",
    "for i in range(3):\n",
    "    for j in range(3):\n",
    "        if i != j:\n",
    "            axes[i][j].set_xticks([lower_bounds[j], (lower_bounds[j] + upper_bounds[j])/2, upper_bounds[j]])\n",
    "            axes[i][j].set_yticks([lower_bounds[i], (lower_bounds[i] + upper_bounds[i])/2, upper_bounds[i]])\n",
    "        if i == 0 and j == 1:\n",
    "            axes[i][j].get_yaxis().set_visible(True)\n",
    "            axes[i][j].tick_params(left=True, labelleft=True)\n",
    "        else:\n",
    "            if i == 0:\n",
    "                axes[i][j].set_yticks([])\n",
    "                axes[i][j].tick_params(left=False, labelleft=False)\n",
    "            if i == 2:\n",
    "                axes[i][j].set_xticks([lower_bounds[j], (lower_bounds[j] + upper_bounds[j])/2, upper_bounds[j]])\n",
    "\n",
    "plt.subplots_adjust(wspace=0.25)  # adjust the space between the subplots\n",
    "plt.savefig(\"plots/toad_joint_real_data.pdf\", bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set up sim data\n",
    "with open(\"../res/toad/rsnl_sim/seed_0/thetas.pkl\", \"rb\") as f:\n",
    "    theta_sim_draws_rsnl = jnp.array(pkl.load(f))\n",
    "\n",
    "thetas_sim_rsnl = jnp.concatenate(theta_sim_draws_rsnl, axis=0)\n",
    "thetas_sim_rsnl = jnp.squeeze(thetas_sim_rsnl)\n",
    "\n",
    "with open(\"../res/toad/rsnl_sim/seed_0/adj_params.pkl\", \"rb\") as f:\n",
    "    adj_params_sim = jnp.array(pkl.load(f))\n",
    "\n",
    "adj_params_sim = jnp.concatenate(adj_params_sim, axis=0)\n",
    "\n",
    "with open(\"../res/toad/snl_sim/seed_0/thetas.pkl\", \"rb\") as f:\n",
    "    theta_sim_draws_snl = jnp.array(pkl.load(f))\n",
    "\n",
    "thetas_sim_snl = jnp.concatenate(theta_sim_draws_snl, axis=0)\n",
    "thetas_sim_snl = jnp.squeeze(thetas_sim_snl)\n",
    "\n",
    "thetas_sim_rbsl = scipy.io.loadmat('../res/external_res/results_bsl_model2_simdata_variance_n300.mat')['theta']\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "rsnl_sim_theta_plot = {}\n",
    "snl_sim_theta_plot = {}\n",
    "rbsl_sim_theta_plot = {}\n",
    "\n",
    "for i in range(3):\n",
    "    rsnl_sim_theta_plot['theta' + str(i+1)] = thetas_sim_rsnl[ :, i]\n",
    "    snl_sim_theta_plot['theta' + str(i+1)] = thetas_sim_snl[:, i]\n",
    "    rbsl_sim_theta_plot['theta' + str(i+1)] = thetas_sim_rbsl[:, i]\n",
    "\n",
    "var_name_map = {}\n",
    "reference_values = {}\n",
    "labels = [r'$\\alpha_{\\mathrm{toad}}$', r'$\\delta$', r'$p_0$']\n",
    "for ii, k in enumerate(rsnl_sim_theta_plot):\n",
    "    var_name_map[k] = labels[ii]\n",
    "    reference_values[var_name_map[k]] = true_params[ii]  # why does ref_vals match labels and not data? ah well"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# sim posterior plot\n",
    "fig, axes = plt.subplots(3, 3, sharey=False, figsize=(16, 16))\n",
    "\n",
    "old_axes = np.empty((3, 3), dtype='object')\n",
    "lower_bounds = [1.6, 35.0, 0.5]\n",
    "upper_bounds = [2.0, 55.0, 0.7]\n",
    "\n",
    "for i in range(3):\n",
    "    for j in range(3):\n",
    "        axes[i][j].set_xlim([lower_bounds[j]-0.05, upper_bounds[j]+0.05])\n",
    "        old_axes[i][j] = axes[i][j].axes\n",
    "\n",
    "\n",
    "axes = az.plot_pair(rsnl_sim_theta_plot,\n",
    "                    kind='kde',\n",
    "                    reference_values=reference_values,\n",
    "                    reference_values_kwargs={'color': 'red', 'marker': 'X', 'markersize': 12},\n",
    "                    kde_kwargs={'hdi_probs': [0.05, 0.25, 0.5, 0.75, 0.95],\n",
    "                                'contour_kwargs': {\"colors\":None, \"cmap\":plt.cm.viridis},\n",
    "                                'contourf_kwargs': {\"alpha\":0}},\n",
    "                    ax=axes,\n",
    "                    labeller=az.labels.MapLabeller(var_name_map=var_name_map,),\n",
    "                    marginals=True,\n",
    "                    marginal_kwargs={'label': 'RSNL'},\n",
    "                    )\n",
    "\n",
    "curr_fig = plt.gcf()\n",
    "for ii, ax_ii in enumerate(axes):\n",
    "    for jj, ax_jj in enumerate(ax_ii):\n",
    "        if ii == jj:  # only marginal for now\n",
    "            az.plot_dist(snl_sim_theta_plot['theta' + str(ii+1)],ax=ax_jj, color='orange', plot_kwargs={'linestyle': 'dashed'}, label='SNL')\n",
    "            az.plot_dist(rbsl_sim_theta_plot['theta' + str(ii+1)],ax=ax_jj, bw='silverman', color='black', plot_kwargs={'linestyle': 'dotted'}, label='RBSL')\n",
    "            if ii != 0:\n",
    "                ax_jj.get_legend().remove()\n",
    "            else:\n",
    "                ax_jj.legend(bbox_to_anchor=(2.75, 1.25), ncol=3)\n",
    "            ax_jj.axvline(x=true_params[jj], color='red', linestyle='dashed')\n",
    "        if ii < jj:\n",
    "            print('jj')\n",
    "            ax_jj._remove_method = None\n",
    "            ax_jj.figure = curr_fig\n",
    "            ax_jj.set_visible(True)\n",
    "            ax_jj = az.plot_kde(snl_sim_theta_plot['theta' + str(jj+1)],\n",
    "                                snl_sim_theta_plot['theta' + str(ii+1)],\n",
    "                                ax=ax_jj,\n",
    "                                hdi_probs=[0.05, 0.25, 0.5, 0.75, 0.95],\n",
    "                                contour_kwargs={\"colors\": None, \"cmap\":plt.cm.cividis, 'linestyles': 'dashed'},\n",
    "                                contourf_kwargs={\"alpha\":0}, show=False)\n",
    "            ax_jj.plot(true_params[jj], true_params[ii], color='red', marker= 'X', markersize=12, markeredgecolor='k')\n",
    "            ax_jj.axes = old_axes[ii, jj]\n",
    "            ax_jj.axes.get_xaxis().set_visible(False)\n",
    "            ax_jj.axes.get_yaxis().set_visible(False)\n",
    "            fig.add_subplot(ax_jj)\n",
    "for i in range(3):\n",
    "    for j in range(3):\n",
    "        if i != j:\n",
    "            print('i: ', i, 'j: ', j)\n",
    "            axes[i][j].set_xticks([lower_bounds[j], (lower_bounds[j] + upper_bounds[j])/2, upper_bounds[j]])\n",
    "            axes[i][j].set_yticks([lower_bounds[i], (lower_bounds[i] + upper_bounds[i])/2, upper_bounds[i]])\n",
    "        if i == 0 and j == 1:\n",
    "            axes[i][j].get_yaxis().set_visible(True)\n",
    "            axes[i][j].tick_params(left=True, labelleft=True)\n",
    "        else:\n",
    "            if i == 0:\n",
    "                axes[i][j].set_yticks([])\n",
    "                axes[i][j].tick_params(left=False, labelleft=False)\n",
    "            if i == 2:\n",
    "                axes[i][j].set_xticks([lower_bounds[j], (lower_bounds[j] + upper_bounds[j])/2, upper_bounds[j]])\n",
    "\n",
    "plt.subplots_adjust(wspace=0.25)  # adjust the space between the subplots\n",
    "plt.savefig(\"plots/toad_joint_sim_data.pdf\", bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Posterior Predictive Plots"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import scipy.io\n",
    "from functools import partial"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "thetas_rbsl = scipy.io.loadmat('../res/external_res/results_bsl_model2_realdata_mean_n500.mat')['theta']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_real_xobs():\n",
    "    df = scipy.io.loadmat('../rsnl/examples/data/radio_converted.mat')['Y']\n",
    "    nan_idx = jnp.isnan(df)\n",
    "    df = jnp.array(df)\n",
    "\n",
    "    x_obs = calculate_summary_statistics(df, real_data=True, nan_idx=nan_idx)\n",
    "\n",
    "    sum_fn = partial(calculate_summary_statistics, real_data=True,\n",
    "                     nan_idx=nan_idx)\n",
    "    return x_obs, sum_fn\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "rng_key = random.PRNGKey(0)\n",
    "rng_key, sub_key1, sub_key2 = random.split(rng_key, 3)\n",
    "sum_fn = calculate_summary_statistics\n",
    "x_obs, sum_fn = get_real_xobs()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# # first just do first summary\n",
    "# NOTE: CAN COMMENT THIS OUT IF YOU HAVE ALREADY RUN IT\n",
    "thetas_rnpe = thetas_rnpe.reshape((10000, -1))\n",
    "\n",
    "n_sims = 1_000\n",
    "rsnl_ssx = np.zeros((48, n_sims))\n",
    "rnpe_ssx = np.zeros((48, n_sims))\n",
    "snl_ssx = np.zeros((48, n_sims))\n",
    "rbsl_ssx = np.zeros((48, n_sims))\n",
    "\n",
    "stride = 10\n",
    "\n",
    "for i in range(n_sims):\n",
    "    if i % 100 == 0:\n",
    "        print('i: ', i)\n",
    "    key, sub_key1, sub_key2, sub_key3, sub_key4 = random.split(rng_key, 5)\n",
    "    idx = i * stride\n",
    "    x_rsnl = dgp(sub_key1, thetas_rsnl[idx, 0], thetas_rsnl[idx, 1], thetas_rsnl[idx,2], model=2)\n",
    "    rsnl_ssx[:, i] = sum_fn(x_rsnl)\n",
    "\n",
    "    x_rnpe = dgp(sub_key2, thetas_rnpe[idx, 0], thetas_rnpe[idx, 1], thetas_rnpe[idx,2], model=2)\n",
    "    rnpe_ssx[:, i] = sum_fn(x_rnpe)\n",
    "\n",
    "    x_snl = dgp(sub_key3, thetas_snl[idx, 0], thetas_snl[idx, 1], thetas_snl[idx,2], model=2)\n",
    "    snl_ssx[:, i] = sum_fn(x_snl)\n",
    "\n",
    "    x_rbsl = dgp(sub_key4, thetas_rbsl[idx, 0], thetas_rbsl[idx, 1], thetas_rbsl[idx,2], model=2)\n",
    "    rbsl_ssx[:, i] = sum_fn(x_rbsl)\n",
    "\n",
    "\n",
    "pkl.dump(rsnl_ssx, open('rsnl_ssx.pkl', 'wb'))\n",
    "pkl.dump(rnpe_ssx, open('rnpe_ssx.pkl', 'wb'))\n",
    "pkl.dump(snl_ssx, open('snl_ssx.pkl', 'wb'))\n",
    "pkl.dump(rbsl_ssx, open('rbsl_ssx.pkl', 'wb'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "rsnl_ssx = pkl.load(open('rsnl_ssx.pkl', 'rb'))\n",
    "rnpe_ssx = pkl.load(open('rnpe_ssx.pkl', 'rb'))\n",
    "snl_ssx = pkl.load(open('snl_ssx.pkl', 'rb'))\n",
    "rbsl_ssx = pkl.load(open('rbsl_ssx.pkl', 'rb'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for i in range(48):\n",
    "    rsnl_first_summary = rsnl_ssx[i, :].flatten()\n",
    "    rsnl_first_summary = rsnl_first_summary[np.where(np.abs(rsnl_first_summary) < 10000)]\n",
    "    rsnl_xs = np.linspace(np.min(rsnl_first_summary), np.max(rsnl_first_summary), 100)\n",
    "    kde = gaussian_kde(rsnl_first_summary)\n",
    "    plt.plot(rsnl_xs, kde(rsnl_xs), label='RSNL')\n",
    "\n",
    "    rnpe_first_summary = rnpe_ssx[i, :].flatten()\n",
    "    rnpe_first_summary = rnpe_first_summary[np.where(np.abs(rnpe_first_summary) < 10000)]\n",
    "    rnpe_xs = np.linspace(np.min(rnpe_first_summary), np.max(rnpe_first_summary), 100)\n",
    "    kde = gaussian_kde(rnpe_first_summary)\n",
    "    plt.plot(rnpe_xs, kde(rnpe_xs), label='RNPE')\n",
    "\n",
    "    snl_first_summary = snl_ssx[i, :].flatten()\n",
    "    snl_first_summary = snl_first_summary[np.where(np.abs(snl_first_summary) < 10000)]\n",
    "    snl_xs = np.linspace(np.min(snl_first_summary), np.max(snl_first_summary), 100)\n",
    "    kde = gaussian_kde(snl_first_summary)\n",
    "    plt.plot(snl_xs, kde(snl_xs), label='SNL')\n",
    "\n",
    "    rbsl_summary = rbsl_ssx[i, :].flatten()\n",
    "    rbsl_summary = rbsl_summary[np.where(np.abs(rbsl_summary) < 10000)]\n",
    "    rbsl_xs = np.linspace(np.min(rbsl_summary), np.max(rbsl_summary), 100)\n",
    "    kde = gaussian_kde(rbsl_summary)\n",
    "    plt.plot(rbsl_xs, kde(rbsl_xs), label='RBSL')\n",
    "\n",
    "    plt.plot(x_obs[i], 0, 'x', color='red')\n",
    "    plt.legend()\n",
    "    plt.savefig(f\"plots/ppc_{str(i)}.pdf\", bbox_inches=\"tight\")\n",
    "    plt.clf()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# RBF kernel function\n",
    "def rbf_kernel(x, y, lengthscale=1.0):\n",
    "    return jnp.exp(-jnp.sum((x - y) ** 2) / (2 * lengthscale ** 2))\n",
    "\n",
    "# Median heuristic for lengthscale\n",
    "def median_heuristic(x):\n",
    "    pairwise_dists = jnp.sqrt(jnp.sum((x[:, :, None] - x[:, None, :]) ** 2, axis=0))\n",
    "    return jnp.sqrt(jnp.median(pairwise_dists) / 2)\n",
    "\n",
    "# MMD calculation\n",
    "def MMD_unweighted(simulated, observed, lengthscale=1):\n",
    "    l = 1000  # Number of simulated statistics to use\n",
    "    simulated = simulated[:, :l]  # Take only the first l samples\n",
    "\n",
    "    # Compute pairwise kernel values for simulated statistics\n",
    "    k_simulated = jnp.array([[rbf_kernel(x, y, lengthscale) for x in simulated.T] for y in simulated.T])\n",
    "\n",
    "    # Compute kernel values between simulated and observed statistics\n",
    "    k_sim_obs = jnp.array([rbf_kernel(x, observed, lengthscale) for x in simulated.T])\n",
    "\n",
    "    # Calculate MMD\n",
    "    mmd_value = (jnp.sum(k_simulated) / (l ** 2)) - (2 * jnp.sum(k_sim_obs) / l)\n",
    "\n",
    "    return mmd_value\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "lengthscale_beta = median_heuristic(rsnl_ssx)\n",
    "\n",
    "print('lengthscale_beta: ', lengthscale_beta)\n",
    "\n",
    "rsnl_mmd = MMD_unweighted(rsnl_ssx, x_obs.reshape((-1, 1)), lengthscale_beta)\n",
    "print(\"RSNL: \", rsnl_mmd)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "lengthscale_beta = median_heuristic(snl_ssx)\n",
    "MMD_unweighted(snl_ssx, x_obs.reshape((-1, 1)), lengthscale_beta)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "lengthscale_beta = median_heuristic(rnpe_ssx)\n",
    "MMD_unweighted(rnpe_ssx, x_obs.reshape((-1, 1)), lengthscale_beta)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "rbsl_ssx = scipy.io.loadmat('../res/external_res/ssx_rbslm.mat')['ssx_all'].T\n",
    "# ssx_bsl.shape\n",
    "rbsl_ssx = rbsl_ssx[0:10000:10]\n",
    "# lengthscale_beta = median_heuristic(rbsl_ssx)\n",
    "# MMD_unweighted(rbsl_ssx, x_obs.reshape((-1, 1)), lengthscale_beta)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# PPC for log non-ret\n",
    "# Start with observed data\n",
    "observed_df = scipy.io.loadmat('../rsnl/examples/data/radio_converted.mat')['Y']\n",
    "nan_idx = jnp.isnan(observed_df)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib import cm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "lags = [1, 2, 4, 8]\n",
    "X = observed_df\n",
    "x_obs = calculate_summary_statistics(observed_df)\n",
    "non_ret = {}\n",
    "thd = 10\n",
    "n_lag_sims = 100\n",
    "key = random.PRNGKey(0)\n",
    "colormap = cm.get_cmap('twilight', n_lag_sims)  # Here, n_lag_sims specifies the number of colours needed\n",
    "\n",
    "num_ret_rsnl = np.zeros((len(lags), n_lag_sims))\n",
    "for ii, lag in enumerate(lags):\n",
    "    disp = X[lag:, :] - X[:-lag, :]\n",
    "    disp = jnp.abs(disp)\n",
    "    non_ret[f'lag_{str(lag)}'] = sorted(np.log([step for step in disp.flatten() if step > thd]))\n",
    "    print(len(non_ret[f'lag_{str(lag)}']))\n",
    "    buffer = 1.0\n",
    "    if lag == 8:\n",
    "        buffer = 2.0\n",
    "    xs_obs = np.linspace(np.min(non_ret[f'lag_{str(lag)}'])-0.5, np.max(non_ret[f'lag_{str(lag)}'])+buffer, 30)\n",
    "    kde_obs = gaussian_kde(non_ret[f'lag_{str(lag)}'], bw_method='silverman')\n",
    "    for j in range(n_lag_sims):\n",
    "        key, sub_key = random.split(key)\n",
    "        X_i = dgp(sub_key, thetas_rsnl[j, 0], thetas_rsnl[j, 1], thetas_rsnl[j, 2], model=2)\n",
    "        X_i = X_i.at[nan_idx].set(jnp.nan)\n",
    "        disp_i = X_i[lag:, :] - X_i[:-lag, :]\n",
    "        disp_i = jnp.abs(disp_i)\n",
    "        ret = disp_i < thd\n",
    "        num_ret_j = np.sum(ret)\n",
    "        num_ret_rsnl[ii, j] = num_ret_j\n",
    "        non_ret_i = sorted(np.log([step for step in jnp.abs(disp_i).flatten() if step > thd]))\n",
    "        xs = np.linspace(np.min(non_ret_i)-0.5, np.max(non_ret_i)+0.5, 30)\n",
    "        kde = gaussian_kde(non_ret_i, bw_method='silverman')\n",
    "        color = colormap(j)\n",
    "        plt.plot(xs, kde(xs), label=f'lag {lag}', alpha=0.1, color=color)\n",
    "    plt.xlim([0, 10])\n",
    "    plt.xticks([0, 5, 10])\n",
    "    plt.ylim([0, 0.7])\n",
    "    plt.yticks([0, 0.2, 0.4, 0.6])\n",
    "    plt.ylabel(\"Density\")\n",
    "    plt.xlabel(\"Log distance\")\n",
    "    plt.title(f\"Lag {str(lag)}\")\n",
    "    plt.plot(xs_obs, kde(xs_obs), label=f'lag {lag}', color='black', linewidth=3)\n",
    "    plt.savefig(f'plots/ppc_lag_{str(lag)}_rsnl.pdf', bbox_inches='tight')\n",
    "    plt.clf()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.boxplot(num_ret_rsnl.T, positions=range(1, 5), widths=0.6)\n",
    "plt.scatter(range(1, 5), x_obs[::12], color='red', zorder=2, label='Observed')\n",
    "plt.xticks(ticks=range(1, 5), labels=[\"1\", \"2\", \"4\", \"8\"])  # Adding custom x-axis labels\n",
    "plt.xlabel(\"Lag\")\n",
    "plt.ylabel(\"Number returned\")\n",
    "plt.legend()\n",
    "plt.savefig(f'plots/ppc_numret_rsnl.pdf', bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Posterior predictive check for SNL, per lag: overlay densities of log\n",
    "# non-return distances (|displacement| > thd) simulated from posterior draws on\n",
    "# the observed density, and record the number of returns (|displacement| < thd).\n",
    "lags = [1, 2, 4, 8]\n",
    "X = observed_df\n",
    "non_ret = {}\n",
    "thd = 10\n",
    "n_lag_sims = 100\n",
    "# Seed locally so the cell is reproducible on re-run and does not depend on a\n",
    "# `key` left over from previously executed cells.\n",
    "key = random.PRNGKey(0)\n",
    "num_ret_snl = np.zeros((len(lags), n_lag_sims))\n",
    "for ii, lag in enumerate(lags):\n",
    "    # Vectorised with NumPy masking instead of iterating over JAX scalars in Python.\n",
    "    disp_obs = np.abs(np.asarray(X[lag:, :] - X[:-lag, :])).ravel()\n",
    "    non_ret[f'lag_{lag}'] = np.sort(np.log(disp_obs[disp_obs > thd]))\n",
    "    print(len(non_ret[f'lag_{lag}']))\n",
    "    xs_obs = np.linspace(np.min(non_ret[f'lag_{lag}']), np.max(non_ret[f'lag_{lag}']), 30)\n",
    "    kde_obs = gaussian_kde(non_ret[f'lag_{lag}'])\n",
    "    # NOTE(review): uses the first n_lag_sims rows of thetas_snl; confirm they are representative posterior draws.\n",
    "    for j in range(n_lag_sims):\n",
    "        key, sub_key = random.split(key)\n",
    "        X_i = dgp(sub_key, thetas_snl[j, 0], thetas_snl[j, 1], thetas_snl[j, 2], model=2)\n",
    "        # Apply the observed missingness; NaN displacements fail both `< thd` and `> thd`.\n",
    "        X_i = X_i.at[nan_idx].set(jnp.nan)\n",
    "        disp_i = np.abs(np.asarray(X_i[lag:, :] - X_i[:-lag, :])).ravel()\n",
    "        num_ret_snl[ii, j] = np.sum(disp_i < thd)\n",
    "        non_ret_i = np.log(disp_i[disp_i > thd])\n",
    "        # np.min fails on an empty array and gaussian_kde needs at least two points.\n",
    "        if non_ret_i.size < 2:\n",
    "            continue\n",
    "        xs = np.linspace(np.min(non_ret_i), np.max(non_ret_i), 30)\n",
    "        plt.plot(xs, gaussian_kde(non_ret_i)(xs), alpha=0.1)\n",
    "    # Observed density drawn last so it is not hidden under the simulated curves.\n",
    "    plt.plot(xs_obs, kde_obs(xs_obs), label=f'lag {lag}', color='black', linewidth=3)\n",
    "    plt.ylabel('Density')\n",
    "    plt.xlabel('Log distance')\n",
    "    plt.title(f'Lag {lag}')\n",
    "    plt.savefig(f'plots/ppc_lag_{lag}_snl.pdf', bbox_inches='tight')\n",
    "    plt.clf()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of returns (|displacement| < thd) per lag: SNL posterior predictive vs observed.\n",
    "# NOTE(review): assumes x_obs[::12] selects the number-of-returns statistic for each lag; confirm.\n",
    "plt.boxplot(num_ret_snl.T, positions=range(1, 5), widths=0.6)\n",
    "plt.scatter(range(1, 5), x_obs[::12], color='red', zorder=2, label='Observed')\n",
    "plt.xticks(ticks=range(1, 5), labels=['1', '2', '4', '8'])  # tick labels are the lags\n",
    "plt.xlabel('Lag')\n",
    "plt.ylabel('Number returned')\n",
    "# Show the legend so the labelled observed points are identified in the saved figure.\n",
    "plt.legend()\n",
    "plt.savefig('plots/ppc_numret_snl.pdf', bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Flatten the RNPE samples to 10000 rows (one row per draw, parameters along columns).\n",
    "# NOTE(review): 10000 is assumed to be the number of RNPE posterior draws; confirm.\n",
    "thetas_rnpe = thetas_rnpe.reshape(10000, -1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Posterior predictive check for RNPE, per lag: overlay densities of log\n",
    "# non-return distances (|displacement| > thd) simulated from posterior draws on\n",
    "# the observed density, and record the number of returns (|displacement| < thd).\n",
    "lags = [1, 2, 4, 8]\n",
    "X = observed_df\n",
    "non_ret = {}\n",
    "thd = 10\n",
    "n_lag_sims = 100\n",
    "# Seed locally so the cell is reproducible on re-run and does not depend on a\n",
    "# `key` left over from previously executed cells.\n",
    "key = random.PRNGKey(0)\n",
    "\n",
    "# Filled here and read by the next boxplot cell (the RBSL cell reuses this name).\n",
    "num_ret = np.zeros((len(lags), n_lag_sims))\n",
    "for ii, lag in enumerate(lags):\n",
    "    # Vectorised with NumPy masking instead of iterating over JAX scalars in Python.\n",
    "    disp_obs = np.abs(np.asarray(X[lag:, :] - X[:-lag, :])).ravel()\n",
    "    non_ret[f'lag_{lag}'] = np.sort(np.log(disp_obs[disp_obs > thd]))\n",
    "    print(len(non_ret[f'lag_{lag}']))\n",
    "    xs_obs = np.linspace(np.min(non_ret[f'lag_{lag}']), np.max(non_ret[f'lag_{lag}']), 30)\n",
    "    kde_obs = gaussian_kde(non_ret[f'lag_{lag}'])\n",
    "    # NOTE(review): uses the first n_lag_sims rows of thetas_rnpe; confirm they are representative posterior draws.\n",
    "    for j in range(n_lag_sims):\n",
    "        key, sub_key = random.split(key)\n",
    "        X_i = dgp(sub_key, thetas_rnpe[j, 0], thetas_rnpe[j, 1], thetas_rnpe[j, 2], model=2)\n",
    "        # Apply the observed missingness; NaN displacements fail both `< thd` and `> thd`.\n",
    "        X_i = X_i.at[nan_idx].set(jnp.nan)\n",
    "        disp_i = np.abs(np.asarray(X_i[lag:, :] - X_i[:-lag, :])).ravel()\n",
    "        num_ret[ii, j] = np.sum(disp_i < thd)\n",
    "        non_ret_i = np.log(disp_i[disp_i > thd])\n",
    "        # np.min fails on an empty array and gaussian_kde needs at least two points.\n",
    "        if non_ret_i.size < 2:\n",
    "            continue\n",
    "        xs = np.linspace(np.min(non_ret_i), np.max(non_ret_i), 30)\n",
    "        plt.plot(xs, gaussian_kde(non_ret_i)(xs), alpha=0.1)\n",
    "    # Observed density drawn last so it is not hidden under the simulated curves.\n",
    "    plt.plot(xs_obs, kde_obs(xs_obs), label=f'lag {lag}', color='black', linewidth=3)\n",
    "    plt.ylabel('Density')\n",
    "    plt.xlabel('Log distance')\n",
    "    plt.title(f'Lag {lag}')\n",
    "    plt.savefig(f'plots/ppc_lag_{lag}_rnpe.pdf', bbox_inches='tight')\n",
    "    plt.clf()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of returns (|displacement| < thd) per lag: RNPE posterior predictive vs observed.\n",
    "# `num_ret` comes from the RNPE cell directly above and is overwritten by the RBSL cell,\n",
    "# so run this cell immediately after the RNPE simulation cell.\n",
    "# NOTE(review): assumes x_obs[::12] selects the number-of-returns statistic for each lag; confirm.\n",
    "plt.boxplot(num_ret.T, positions=range(1, 5), widths=0.6)\n",
    "plt.scatter(range(1, 5), x_obs[::12], color='red', zorder=2, label='Observed')\n",
    "plt.xticks(ticks=range(1, 5), labels=['1', '2', '4', '8'])  # tick labels are the lags\n",
    "plt.xlabel('Lag')\n",
    "plt.ylabel('Number returned')\n",
    "# Show the legend so the labelled observed points are identified in the saved figure.\n",
    "plt.legend()\n",
    "plt.savefig('plots/ppc_numret_rnpe.pdf', bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `from scipy.stats import gaussian_kde` does not bind the name `scipy`, so import\n",
    "# scipy.io explicitly before calling scipy.io.loadmat.\n",
    "import scipy.io\n",
    "\n",
    "# RBSL parameter draws loaded from an external MATLAB results file (variable 'theta').\n",
    "thetas_rbsl = scipy.io.loadmat('../res/external_res/results_bsl_model2_realdata_mean_n500.mat')['theta']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Posterior predictive check for RBSL, per lag: overlay densities of log\n",
    "# non-return distances (|displacement| > thd) simulated from posterior draws on\n",
    "# the observed density, and record the number of returns (|displacement| < thd).\n",
    "lags = [1, 2, 4, 8]\n",
    "X = observed_df\n",
    "non_ret = {}\n",
    "thd = 10\n",
    "n_lag_sims = 100\n",
    "# Seeded locally so the cell is reproducible on re-run.\n",
    "key = random.PRNGKey(0)\n",
    "# Filled here and read by the next boxplot cell (overwrites the RNPE counts).\n",
    "num_ret = np.zeros((len(lags), n_lag_sims))\n",
    "for ii, lag in enumerate(lags):\n",
    "    # Vectorised with NumPy masking instead of iterating over JAX scalars in Python.\n",
    "    disp_obs = np.abs(np.asarray(X[lag:, :] - X[:-lag, :])).ravel()\n",
    "    non_ret[f'lag_{lag}'] = np.sort(np.log(disp_obs[disp_obs > thd]))\n",
    "    print(len(non_ret[f'lag_{lag}']))\n",
    "    xs_obs = np.linspace(np.min(non_ret[f'lag_{lag}']), np.max(non_ret[f'lag_{lag}']), 30)\n",
    "    kde_obs = gaussian_kde(non_ret[f'lag_{lag}'])\n",
    "    # NOTE(review): uses the first n_lag_sims rows of thetas_rbsl; confirm they are representative posterior draws.\n",
    "    for j in range(n_lag_sims):\n",
    "        key, sub_key = random.split(key)\n",
    "        X_i = dgp(sub_key, thetas_rbsl[j, 0], thetas_rbsl[j, 1], thetas_rbsl[j, 2], model=2)\n",
    "        # Apply the observed missingness; NaN displacements fail both `< thd` and `> thd`.\n",
    "        X_i = X_i.at[nan_idx].set(jnp.nan)\n",
    "        disp_i = np.abs(np.asarray(X_i[lag:, :] - X_i[:-lag, :])).ravel()\n",
    "        num_ret[ii, j] = np.sum(disp_i < thd)\n",
    "        non_ret_i = np.log(disp_i[disp_i > thd])\n",
    "        # np.min fails on an empty array and gaussian_kde needs at least two points.\n",
    "        if non_ret_i.size < 2:\n",
    "            continue\n",
    "        xs = np.linspace(np.min(non_ret_i), np.max(non_ret_i), 30)\n",
    "        plt.plot(xs, gaussian_kde(non_ret_i)(xs), alpha=0.1)\n",
    "    # Observed density drawn last so it is not hidden under the simulated curves.\n",
    "    plt.plot(xs_obs, kde_obs(xs_obs), label=f'lag {lag}', color='black', linewidth=3)\n",
    "    plt.ylabel('Density')\n",
    "    plt.xlabel('Log distance')\n",
    "    plt.title(f'Lag {lag}')\n",
    "    plt.savefig(f'plots/ppc_lag_{lag}_rbsl.pdf', bbox_inches='tight')\n",
    "    plt.clf()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of returns (|displacement| < thd) per lag: RBSL posterior predictive vs observed.\n",
    "# `num_ret` comes from the RBSL cell directly above (it reuses the name used by the RNPE cell).\n",
    "# NOTE(review): assumes x_obs[::12] selects the number-of-returns statistic for each lag; confirm.\n",
    "plt.boxplot(num_ret.T, positions=range(1, 5), widths=0.6)\n",
    "plt.scatter(range(1, 5), x_obs[::12], color='red', zorder=2, label='Observed')\n",
    "plt.xticks(ticks=range(1, 5), labels=['1', '2', '4', '8'])  # tick labels are the lags\n",
    "plt.xlabel('Lag')\n",
    "plt.ylabel('Number returned')\n",
    "# Show the legend so the labelled observed points are identified in the saved figure.\n",
    "plt.legend()\n",
    "plt.savefig('plots/ppc_numret_rbsl.pdf', bbox_inches='tight')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
