{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import jax.numpy as jnp\n",
    "from scipy.stats import gaussian_kde\n",
    "import matplotlib.pyplot as plt\n",
    "import pickle as pkl"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Render all figure text in a serif face, preferring Times New Roman.\n",
    "plt.rcParams.update({\n",
    "    'font.family': 'serif',\n",
    "    'font.serif': ['Times New Roman'],\n",
    "})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def calculate_coverage(directory, true_param, num_levels=21, thin=10,\n",
    "                       verbose=True):\n",
    "    \"\"\"Estimate how often ``true_param`` is covered across a folder of runs.\n",
    "\n",
    "    Every directory below ``directory`` (found recursively with ``os.walk``,\n",
    "    the root itself excluded) is tried as one run and must hold a\n",
    "    ``thetas.pkl`` file. For each run the unpickled draws are concatenated\n",
    "    along their first axis and transposed, a Gaussian KDE with Silverman\n",
    "    bandwidth is fitted to every ``thin``-th draw, and the log-density at\n",
    "    ``true_param`` is compared with the sorted log-densities of a second\n",
    "    thinned subset that starts one draw later.\n",
    "\n",
    "    Args:\n",
    "        directory: Root folder whose sub-directories are scanned.\n",
    "        true_param: Point passed to ``gaussian_kde.logpdf``.\n",
    "        num_levels: Number of evenly spaced levels on [0, 1]; the default 21\n",
    "            gives steps of 0.05.\n",
    "        thin: Stride used to thin the draws (default 10).\n",
    "        verbose: Print per-run diagnostics when True (the default).\n",
    "\n",
    "    Returns:\n",
    "        numpy array of length ``num_levels``: per level, the fraction of\n",
    "        successfully processed runs whose threshold log-density was below the\n",
    "        log-density at ``true_param``. Per run the first entry is forced to 0\n",
    "        and the last to 1. All entries are NaN when no run could be processed.\n",
    "    \"\"\"\n",
    "    # os.walk yields the root first; only its descendants are candidate runs.\n",
    "    sub_dirs = [entry[0] for entry in os.walk(directory)][1:]\n",
    "    levels = jnp.linspace(0, 1, num_levels)\n",
    "    total_coverage = np.zeros(num_levels)\n",
    "    count_successful = 0\n",
    "    for sub_dir in sub_dirs:\n",
    "        try:\n",
    "            # NOTE(security): unpickling can execute arbitrary code; only point\n",
    "            # this at result folders you produced yourself.\n",
    "            with open(f'{sub_dir}/thetas.pkl', 'rb') as f:\n",
    "                raw_thetas = pkl.load(f)\n",
    "            if verbose:\n",
    "                print('sub_dir: ', sub_dir)\n",
    "            thetas = jnp.array(raw_thetas)\n",
    "            thetas = jnp.concatenate(thetas, axis=0)\n",
    "            # gaussian_kde expects one column per draw.\n",
    "            thetas = jnp.transpose(thetas)\n",
    "            # Two thinned subsets offset by one draw: the first fits the KDE,\n",
    "            # the second supplies the density thresholds.\n",
    "            thetas_kernel = thetas[:, ::thin]\n",
    "            thetas_kernel_eval = thetas[:, 1::thin]\n",
    "            kde = gaussian_kde(thetas_kernel, bw_method='silverman')\n",
    "            true_theta_pdf = kde.logpdf(true_param)\n",
    "            # Log-densities of the second subset, highest first.\n",
    "            theta_draws_pdf = jnp.sort(kde.logpdf(thetas_kernel_eval))[::-1]\n",
    "            N = len(theta_draws_pdf)\n",
    "            # Level 0 maps to index -1, which wraps to the last draw; that\n",
    "            # entry is overwritten below, so the wrap-around is harmless.\n",
    "            indices = jnp.round(levels * N).astype(int) - 1\n",
    "            coverage = np.array(theta_draws_pdf[indices] < true_theta_pdf,\n",
    "                                dtype=int)\n",
    "            coverage[0] = 0\n",
    "            coverage[-1] = 1\n",
    "            if verbose:\n",
    "                print('true_theta_pdf: ', true_theta_pdf)\n",
    "                print('N: ', N)\n",
    "                print('indices: ', indices)\n",
    "                print('theta_draws_pdf[indices]: ', theta_draws_pdf[indices])\n",
    "                print('coverage: ', coverage)\n",
    "            total_coverage += coverage\n",
    "            count_successful += 1\n",
    "        except Exception as e:\n",
    "            # Best effort: name the run that failed and move on to the next.\n",
    "            print(f'skipping {sub_dir}: {e!r}')\n",
    "            continue\n",
    "    if verbose:\n",
    "        print('count_successful: ', count_successful)\n",
    "    if count_successful == 0:\n",
    "        # Make the no-usable-runs outcome explicit instead of dividing 0 by 0.\n",
    "        print(f'no usable thetas.pkl found under {directory}')\n",
    "        return np.full(num_levels, np.nan)\n",
    "    return total_coverage / count_successful"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_and_save_coverages(empirical_coverage_rsnl=None, empirical_coverage_snl=None,\n",
    "                            title=\"\",\n",
    "                            folder_name=\"\"):\n",
    "    \"\"\"Plot empirical coverage against credibility level and save it as a PDF.\n",
    "\n",
    "    The current figure is cleared and a dashed gray diagonal is drawn as the\n",
    "    reference line. Each coverage array that is not None is plotted over an\n",
    "    evenly spaced grid on [0, 1] with as many points as the array has entries.\n",
    "    The output path is ``folder_name`` immediately followed by\n",
    "    ``empirical_coverage.pdf``, so ``folder_name`` acts as a plain prefix.\n",
    "    \"\"\"\n",
    "    # (values, line style) per curve, in drawing order.\n",
    "    curves = (\n",
    "        (empirical_coverage_rsnl, {'label': 'RSNL'}),\n",
    "        (empirical_coverage_snl, {'label': 'SNL', 'linestyle': 'dashed'}),\n",
    "    )\n",
    "    text_size = 35\n",
    "\n",
    "    plt.clf()\n",
    "    plt.plot([0, 1], [0, 1], color='gray', linestyle='dashed')\n",
    "    for values, line_style in curves:\n",
    "        if values is None:\n",
    "            continue\n",
    "        grid = np.linspace(0, 1, len(values))\n",
    "        plt.plot(grid, values, **line_style)\n",
    "\n",
    "    plt.xlim([0, 1])\n",
    "    plt.ylim([0, 1])\n",
    "    plt.xticks([0, 1], fontsize=text_size)\n",
    "    plt.yticks([0, 1], fontsize=text_size)\n",
    "    plt.xlabel('Credibility level', fontsize=text_size)\n",
    "    plt.ylabel('Empirical coverage', fontsize=text_size)\n",
    "    plt.legend(fontsize=text_size, borderpad=0.1, labelspacing=0.1,\n",
    "               handletextpad=0.1)\n",
    "    plt.title(title)\n",
    "    plt.tight_layout()\n",
    "    output_path = f\"{folder_name}empirical_coverage.pdf\"\n",
    "    plt.savefig(output_path, bbox_inches='tight')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Contaminated normal results: coverage of the RSNL and SNL runs,\n",
    "# both evaluated at the same true parameter value.\n",
    "true_param = jnp.array([1.0])\n",
    "\n",
    "directory = '../res/contaminated_normal/rsnl/'\n",
    "empirical_coverage_rsnl = calculate_coverage(directory=directory,\n",
    "                                             true_param=true_param)\n",
    "\n",
    "directory = '../res/contaminated_normal/snl/'\n",
    "empirical_coverage_snl = calculate_coverage(directory=directory,\n",
    "                                            true_param=true_param)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# No title is passed, so the function's default (empty) title is used.\n",
    "plot_and_save_coverages(\n",
    "    empirical_coverage_rsnl=empirical_coverage_rsnl,\n",
    "    empirical_coverage_snl=empirical_coverage_snl,\n",
    "    folder_name='contaminated_normal_',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# misspec_ma1 results: compute RSNL and SNL coverage at the same true\n",
    "# parameter value, then plot both curves.\n",
    "true_param = jnp.array([0.0])\n",
    "\n",
    "directory = '../res/misspec_ma1/rsnl/'\n",
    "empirical_coverage_rsnl = calculate_coverage(directory=directory,\n",
    "                                             true_param=true_param)\n",
    "\n",
    "directory = '../res/misspec_ma1/snl/'\n",
    "empirical_coverage_snl = calculate_coverage(directory=directory,\n",
    "                                            true_param=true_param)\n",
    "\n",
    "plot_and_save_coverages(\n",
    "    empirical_coverage_rsnl=empirical_coverage_rsnl,\n",
    "    empirical_coverage_snl=empirical_coverage_snl,\n",
    "    folder_name='misspec_ma1_',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# true_param = jnp.array([0.7, -2.9, -1.0, -0.9, 0.6])\n",
    "\n",
    "# directory = '../res/contaminated_slcp/rsnl/'\n",
    "# empirical_coverage_rsnl = calculate_coverage(directory, true_param)\n",
    "\n",
    "# directory = '../res/contaminated_slcp/snl/'\n",
    "# empirical_coverage_snl = calculate_coverage(directory, true_param)\n",
    "\n",
    "# plot_and_save_coverages(empirical_coverage_rsnl, empirical_coverage_snl,\n",
    "#                         # title=\"Contaminated SLCP\",\n",
    "#                         folder_name='contaminated_slcp_'\n",
    "#                         )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# SIR results: two-component true parameter, RSNL and SNL coverage,\n",
    "# followed by the combined plot.\n",
    "directory = '../res/sir/rsnl/'\n",
    "true_param = jnp.array([.1, .15])\n",
    "empirical_coverage_rsnl = calculate_coverage(directory=directory,\n",
    "                                             true_param=true_param)\n",
    "\n",
    "directory = '../res/sir/snl/'\n",
    "empirical_coverage_snl = calculate_coverage(directory=directory,\n",
    "                                            true_param=true_param)\n",
    "\n",
    "plot_and_save_coverages(\n",
    "    empirical_coverage_rsnl=empirical_coverage_rsnl,\n",
    "    empirical_coverage_snl=empirical_coverage_snl,\n",
    "    folder_name='sir_',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Runs stored in the *_well_specified folders of the contaminated normal\n",
    "# results; results are kept in their own variables.\n",
    "directory = \"../res/contaminated_normal/rsnl_well_specified/\"\n",
    "true_param = jnp.array([1.0])\n",
    "empirical_coverage_rsnl_well_specified = calculate_coverage(\n",
    "    directory=directory,\n",
    "    true_param=true_param,\n",
    ")\n",
    "\n",
    "directory = \"../res/contaminated_normal/snl_well_specified/\"\n",
    "empirical_coverage_snl_well_specified = calculate_coverage(\n",
    "    directory=directory,\n",
    "    true_param=true_param,\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# No title is passed, so the function's default (empty) title is used.\n",
    "plot_and_save_coverages(\n",
    "    empirical_coverage_rsnl=empirical_coverage_rsnl_well_specified,\n",
    "    empirical_coverage_snl=empirical_coverage_snl_well_specified,\n",
    "    folder_name='well_specified_normal_',\n",
    ")\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
