{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import jax.numpy as jnp\n",
    "import matplotlib.pyplot as plt\n",
    "import pickle as pkl\n",
    "from scipy.stats import gaussian_kde"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_true_log_probs(directory, true_param):\n",
    "    sub_dirs = [x[0] for x in os.walk(directory)]\n",
    "    sub_dirs = sub_dirs[1:]\n",
    "    log_probs = []\n",
    "    for sub_dir in sub_dirs:\n",
    "        try:\n",
    "            with open(f'{sub_dir}/thetas.pkl', 'rb') as f:\n",
    "                thetas = jnp.array(pkl.load(f))\n",
    "                thetas = jnp.concatenate(thetas, axis=0)\n",
    "                thetas = jnp.transpose(thetas)\n",
    "                kde = gaussian_kde(thetas, bw_method='silverman')\n",
    "                log_prob = kde.logpdf(true_param)\n",
    "                log_probs.append(log_prob)\n",
    "        except Exception as e:\n",
    "            print(e)\n",
    "            continue\n",
    "    return log_probs\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_true_log_probs_slcp(directory, true_param):\n",
    "    # same but with theta.pkl ... my bad\n",
    "    sub_dirs = [x[0] for x in os.walk(directory)]\n",
    "    sub_dirs = sub_dirs[1:]\n",
    "    log_probs = []\n",
    "    for sub_dir in sub_dirs:\n",
    "        try:\n",
    "            with open(f'{sub_dir}/theta.pkl', 'rb') as f:\n",
    "                thetas = jnp.array(pkl.load(f))\n",
    "                thetas = jnp.concatenate(thetas, axis=0)\n",
    "                thetas = jnp.transpose(thetas)\n",
    "                kde = gaussian_kde(thetas, bw_method='silverman')\n",
    "                log_prob = kde.logpdf(true_param)\n",
    "                log_probs.append(log_prob)\n",
    "        except Exception as e:\n",
    "            print(e)\n",
    "            continue\n",
    "    return log_probs\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "true_param = jnp.array([1.0])\n",
    "\n",
    "directory = '../res/contaminated_normal/rsnl/'\n",
    "true_log_probs_rsnl = get_true_log_probs(directory, true_param)\n",
    "true_log_probs_rsnl = np.squeeze(np.array(true_log_probs_rsnl))\n",
    "true_log_probs_rsnl_normal = true_log_probs_rsnl\n",
    "\n",
    "directory = '../res/contaminated_normal/snl/'\n",
    "true_log_probs_snl = get_true_log_probs(directory, true_param)\n",
    "true_log_probs_snl = np.squeeze(np.array(true_log_probs_snl))\n",
    "true_log_probs_snl_normal = true_log_probs_snl\n",
    "# plot_and_save_coverages(empirical_coverage_rsnl, empirical_coverage_snl,\n",
    "#                         title=\"Contaminated normal\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import seaborn as sns\n",
    "import pandas as pd\n",
    "\n",
    "df = pd.DataFrame()\n",
    "df['logprob'] = np.concatenate([true_log_probs_rsnl, true_log_probs_snl])\n",
    "df['method'] = ['RSNL'] * len(true_log_probs_rsnl) + ['SNL'] * len(true_log_probs_snl)\n",
    "# with sns.plotting_context({'font.size': 20}):\n",
    "plt.figure(figsize=(6, 9))  # Adjust the figure size as needed\n",
    "plt.rcParams[\"axes.labelsize\"] = 15\n",
    "\n",
    "sns.set(font_scale=3.5, font='Times New Roman', style='white')\n",
    "fig, ax = plt.subplots()\n",
    "ax = sns.boxplot(x='method', y='logprob', data=df, showfliers = False)\n",
    "ax.set(xlabel='', ylabel='Log density')\n",
    "# ax.legend(loc='upper left', bbox_to_anchor=(0, 1.1))\n",
    "plt.plot()\n",
    "plt.savefig('contaminated_normal_boxplots.pdf', bbox_inches='tight')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "true_param = jnp.array([0.0])\n",
    "\n",
    "directory = '../res/misspec_ma1/rsnl/'\n",
    "true_log_probs_rsnl = get_true_log_probs(directory, true_param)\n",
    "true_log_probs_rsnl = np.squeeze(np.array(true_log_probs_rsnl))\n",
    "true_log_probs_rsnl_ma1 = true_log_probs_rsnl\n",
    "\n",
    "directory = '../res/misspec_ma1/snl/'\n",
    "true_log_probs_snl = get_true_log_probs(directory, true_param)\n",
    "true_log_probs_snl = np.squeeze(np.array(true_log_probs_snl))\n",
    "true_log_probs_snl_ma1 = true_log_probs_snl\n",
    "# plt.boxplot([true_log_probs_rsnl, true_log_probs_snl], labels=['RSNL', 'SNL'],\n",
    "#             showfliers=False)\n",
    "\n",
    "# plot_and_save_coverages(empirical_coverage_rsnl, empirical_coverage_snl,\n",
    "#                         title=\"Contaminated normal\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import seaborn as sns\n",
    "import pandas as pd\n",
    "\n",
    "df = pd.DataFrame()\n",
    "df['logprob'] = np.concatenate([true_log_probs_rsnl, true_log_probs_snl])\n",
    "df['method'] = ['RSNL'] * len(true_log_probs_rsnl) + ['SNL'] * len(true_log_probs_snl)\n",
    "# with sns.plotting_context({'font.size': 20}):\n",
    "plt.rcParams[\"axes.labelsize\"] = 15\n",
    "plt.figure(figsize=(6, 9))  # Adjust the figure size as needed\n",
    "\n",
    "sns.set(font_scale=3.5, font='Times New Roman', style='white')\n",
    "fig, ax = plt.subplots()\n",
    "ax = sns.boxplot(x='method', y='logprob', data=df, showfliers = False)\n",
    "ax.set(xlabel='', ylabel='Log density')\n",
    "# ax.legend(loc='upper left', bbox_to_anchor=(0, 1.1))\n",
    "plt.plot()\n",
    "plt.savefig('misspec_ma1_boxplots.pdf', bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "true_param = jnp.array([0.7, -2.9, -1.0, -0.9, 0.6])\n",
    "\n",
    "directory = '../res/contaminated_slcp/rsnl/'\n",
    "true_log_probs_rsnl = get_true_log_probs_slcp(directory, true_param)\n",
    "true_log_probs_rsnl = np.squeeze(np.array(true_log_probs_rsnl))\n",
    "true_log_probs_rsnl_slcp = true_log_probs_rsnl\n",
    "\n",
    "directory = '../res/contaminated_slcp/snl/'\n",
    "true_log_probs_snl = get_true_log_probs_slcp(directory, true_param)\n",
    "true_log_probs_snl = np.squeeze(np.array(true_log_probs_snl))\n",
    "true_log_probs_snl_slcp = true_log_probs_snl\n",
    "# plt.boxplot([true_log_probs_rsnl, true_log_probs_snl], labels=['RSNL', 'SNL'],\n",
    "#             showfliers=False)\n",
    "\n",
    "# plot_and_save_coverages(empirical_coverage_rsnl, empirical_coverage_snl,\n",
    "#                         title=\"Contaminated normal\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import seaborn as sns\n",
    "import pandas as pd\n",
    "\n",
    "df = pd.DataFrame()\n",
    "df['logprob'] = np.concatenate([true_log_probs_rsnl, true_log_probs_snl])\n",
    "df['method'] = ['RSNL'] * len(true_log_probs_rsnl) + ['SNL'] * len(true_log_probs_snl)\n",
    "# with sns.plotting_context({'font.size': 20}):\n",
    "plt.rcParams[\"axes.labelsize\"] = 15\n",
    "plt.figure(figsize=(6, 9))  # Adjust the figure size as needed\n",
    "\n",
    "sns.set(font_scale=3.5, font='Times New Roman', style='white')\n",
    "fig, ax = plt.subplots()\n",
    "ax = sns.boxplot(x='method', y='logprob', data=df, showfliers = False)\n",
    "ax.set(xlabel='', ylabel='Log density')\n",
    "# ax.legend(loc='upper left', bbox_to_anchor=(0, 1.1))\n",
    "plt.plot()\n",
    "plt.savefig('contaminated_slcp_boxplots.pdf', bbox_inches='tight')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "true_param = jnp.array([0.1, 0.15])\n",
    "\n",
    "directory = '../res/sir/rsnl/'\n",
    "true_log_probs_rsnl = get_true_log_probs(directory, true_param)\n",
    "true_log_probs_rsnl = np.squeeze(np.array(true_log_probs_rsnl))\n",
    "true_log_probs_rsnl_sir = true_log_probs_rsnl\n",
    "\n",
    "directory = '../res/sir/snl/'\n",
    "true_log_probs_snl = get_true_log_probs(directory, true_param)\n",
    "true_log_probs_snl = np.squeeze(np.array(true_log_probs_snl))\n",
    "true_log_probs_snl_sir = true_log_probs_snl"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import seaborn as sns\n",
    "import pandas as pd\n",
    "\n",
    "df = pd.DataFrame()\n",
    "df['logprob'] = np.concatenate([true_log_probs_rsnl, true_log_probs_snl])\n",
    "df['method'] = ['RSNL'] * len(true_log_probs_rsnl) + ['SNL'] * len(true_log_probs_snl)\n",
    "# with sns.plotting_context({'font.size': 20}):\n",
    "plt.figure(figsize=(6, 9))  # Adjust the figure size as needed\n",
    "plt.rcParams[\"axes.labelsize\"] = 15\n",
    "\n",
    "sns.set(font_scale=3.5, font='Times New Roman', style='white')\n",
    "fig, ax = plt.subplots()\n",
    "ax = sns.boxplot(x='method', y='logprob', data=df, showfliers = False)\n",
    "ax.set(xlabel='', ylabel='Log density')\n",
    "# ax.legend(loc='upper left', bbox_to_anchor=(0, 1.1))\n",
    "plt.plot()\n",
    "plt.savefig('sir_boxplots.pdf', bbox_inches='tight')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, axs = plt.subplots(1, 4, figsize=(48, 8), sharey=False)  #  6 .. 1\n",
    "plt.rcParams['font.family'] = 'Times New Roman'\n",
    "plt.rcParams.update({'font.size': 35})\n",
    "plt.rcParams['xtick.labelsize'] = 35\n",
    "\n",
    "\n",
    "true_log_probs_rsnl_all = [true_log_probs_rsnl_normal, true_log_probs_rsnl_ma1, true_log_probs_rsnl_slcp, true_log_probs_rsnl_sir]\n",
    "true_log_probs_snl_all = [true_log_probs_snl_normal, true_log_probs_snl_ma1, true_log_probs_snl_slcp, true_log_probs_snl_sir]\n",
    "\n",
    "for i in range(4):\n",
    "    df = pd.DataFrame()\n",
    "    df['logprob'] = np.concatenate([true_log_probs_rsnl_all[i], true_log_probs_snl_all[i]])\n",
    "    df['method'] = ['RSNL'] * len(true_log_probs_rsnl_all[i]) + ['SNL'] * len(true_log_probs_snl_all[i])\n",
    "    # with sns.plotting_context({'font.size': 20}):\n",
    "    plt.rcParams[\"axes.labelsize\"] = 35\n",
    "\n",
    "    sns.set(font_scale=5, font='Times New Roman', style='white')\n",
    "    # fig, ax = plt.subplots()\n",
    "    ax = sns.boxplot(x='method', y='logprob', data=df, showfliers = False, ax=axs[i])\n",
    "    ax.set(xlabel='', ylabel='Log density')\n",
    "    # plt.savefig(\"boxplots_all.pdf\", bbox_inches='tight')\n",
    "    # axs[i] = ax\n",
    "plt.subplots_adjust(wspace=0.5)  # adjust the space between the subplots\n",
    "plt.savefig(\"boxplots_all.pdf\", bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "true_param = jnp.array([1.0])\n",
    "\n",
    "directory = '../res/contaminated_normal/rsnl_well_specified/'\n",
    "true_log_probs_rsnl = get_true_log_probs(directory, true_param)\n",
    "true_log_probs_rsnl = np.squeeze(np.array(true_log_probs_rsnl))\n",
    "true_log_probs_rsnl_sir = true_log_probs_rsnl\n",
    "\n",
    "directory = '../res/contaminated_normal/snl_well_specified/'\n",
    "true_log_probs_snl = get_true_log_probs(directory, true_param)\n",
    "true_log_probs_snl = np.squeeze(np.array(true_log_probs_snl))\n",
    "true_log_probs_snl_sir = true_log_probs_snl"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import seaborn as sns\n",
    "import pandas as pd\n",
    "\n",
    "df = pd.DataFrame()\n",
    "df['logprob'] = np.concatenate([true_log_probs_rsnl, true_log_probs_snl])\n",
    "df['method'] = ['RSNL'] * len(true_log_probs_rsnl) + ['SNL'] * len(true_log_probs_snl)\n",
    "# with sns.plotting_context({'font.size': 20}):\n",
    "plt.figure(figsize=(6, 9))  # Adjust the figure size as needed\n",
    "plt.rcParams[\"axes.labelsize\"] = 15\n",
    "\n",
    "sns.set(font_scale=3.5, font='Times New Roman', style='white')\n",
    "fig, ax = plt.subplots()\n",
    "ax = sns.boxplot(x='method', y='logprob', data=df, showfliers = False)\n",
    "ax.set(xlabel='', ylabel='Log density')\n",
    "# ax.legend(loc='upper left', bbox_to_anchor=(0, 1.1))\n",
    "plt.plot()\n",
    "plt.savefig('well_specified_normal_boxplots.pdf', bbox_inches='tight')\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
