{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "abdb5474",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload edited Python modules (e.g. the sbi_mcmc package) before each cell executes\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "2fa85038",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "from tqdm import tqdm\n",
    "\n",
    "# Keras selects its backend from KERAS_BACKEND when it is first imported,\n",
    "# so the variable must be set before `import keras` below.\n",
    "os.environ[\"KERAS_BACKEND\"] = \"jax\"\n",
    "import keras\n",
    "import numpy as np\n",
    "from omegaconf import OmegaConf\n",
    "# NOTE(review): the wildcard imports presumably provide helpers used later without an\n",
    "# explicit import (get_stuff, read_from_file); consider importing them by name.\n",
    "from sbi_mcmc.tasks import *\n",
    "from sbi_mcmc.tasks.tasks_utils import get_task_logp_func\n",
    "from sbi_mcmc.utils.bf_utils import bf_log_prob_posterior\n",
    "from sbi_mcmc.utils.experiment_utils import *\n",
    "from sbi_mcmc.utils.psis_utils import _sir, sampling_importance_resampling\n",
    "from sbi_mcmc.utils.tf_chees_hmc_utils import run_chees_hmc\n",
    "from sbi_mcmc.utils.utils import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "94ff1014",
   "metadata": {
    "tags": [
     "parameters"
    ]
   },
   "outputs": [],
   "source": [
    "# Notebook parameters (cell tagged \"parameters\", e.g. for papermill injection).\n",
    "# Sampler whose warm-up results are compared: \"NUTS\" or \"ChEES-HMC\".\n",
    "mcmc_method = \"NUTS\"\n",
    "# NOTE: the aggregation cell below loops over every entry of task_names and\n",
    "# rebinds task_name, so this value does not select the plotted tasks.\n",
    "# Other tasks: \"CustomDDM(dt=0.0001)\", \"GEV\", \"BernoulliGLM\"\n",
    "task_name = \"psychometric_curve_overdispersion\"\n",
    "processed = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cd8ccd8a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import arviz as az\n",
    "import jax\n",
    "from sbi_mcmc.tasks.tasks import ndarray_values_as_dict\n",
    "\n",
    "# Experiment grid: warm-up lengths and chain-initialisation strategies compared below\n",
    "num_runs = 20\n",
    "num_warmup_values = [10, 50, 100, 200, 300, 500]\n",
    "init_options = [\"abi_psis\", \"abi\", \"stan-like\"]\n",
    "sort = False  # when True, result file names carry a \"_sorted\" suffix\n",
    "rng_key = jax.random.key(42)\n",
    "\n",
    "\n",
    "def get_save_path(\n",
    "    observation_id,\n",
    "    num_warmup,\n",
    "    init_option,\n",
    "    sort=False,\n",
    "    *,\n",
    "    dataset_name=None,\n",
    "    result_paths=None,\n",
    "    method=None,\n",
    "    create_dir=True,\n",
    "):\n",
    "    \"\"\"Return the pickle path of one warm-up experiment result.\n",
    "\n",
    "    The path is ``result_paths[\"chees_hmc_result_dir\"] / \"warmup_tests_<method>/<file>\"``\n",
    "    where ``<file>`` is ``<dataset>_<observation_id>_num_warmup_<num_warmup>_init_option_<init_option>``\n",
    "    followed by an optional ``_sorted`` suffix and the ``.pkl`` extension.\n",
    "\n",
    "    Args:\n",
    "        observation_id: Index of the test observation.\n",
    "        num_warmup: Number of warm-up iterations.\n",
    "        init_option: Initialisation strategy label, e.g. \"abi_psis\".\n",
    "        sort: If True, append ``_sorted`` to the file stem.\n",
    "        dataset_name: Test dataset name. Defaults to the notebook global\n",
    "            ``test_dataset_name``, read at call time.\n",
    "        result_paths: Mapping holding a ``\"chees_hmc_result_dir\"`` directory that\n",
    "            supports ``/`` joining. Defaults to the notebook global ``paths``.\n",
    "        method: MCMC method label. Defaults to the notebook global ``mcmc_method``.\n",
    "        create_dir: If True (original behaviour), create the parent directory.\n",
    "            Pass False for pure existence checks to avoid creating empty folders.\n",
    "\n",
    "    Returns:\n",
    "        The result file path.\n",
    "    \"\"\"\n",
    "    # Globals are looked up at call time, so callers that rebind paths /\n",
    "    # test_dataset_name per task keep working unchanged.\n",
    "    if dataset_name is None:\n",
    "        dataset_name = test_dataset_name\n",
    "    if result_paths is None:\n",
    "        result_paths = paths\n",
    "    if method is None:\n",
    "        method = mcmc_method\n",
    "\n",
    "    suffix = \"_sorted\" if sort else \"\"\n",
    "    filename = (\n",
    "        f\"{dataset_name}_{observation_id}_num_warmup_{num_warmup}\"\n",
    "        f\"_init_option_{init_option}{suffix}.pkl\"\n",
    "    )\n",
    "    result_save_path = (\n",
    "        result_paths[\"chees_hmc_result_dir\"] / f\"warmup_tests_{method}/{filename}\"\n",
    "    )\n",
    "    if create_dir:\n",
    "        result_save_path.parent.mkdir(parents=True, exist_ok=True)\n",
    "    return result_save_path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7875d070",
   "metadata": {},
   "outputs": [],
   "source": [
    "task_names = [\n",
    "    \"GEV\",\n",
    "    \"BernoulliGLM\",\n",
    "    \"psychometric_curve_overdispersion\",\n",
    "    \"CustomDDM(dt=0.0001)\",\n",
    "]\n",
    "# Initialisation strategies to aggregate. The order is deliberate: it fixes the\n",
    "# insertion order of each n_rhats dict and hence the x-offsets in the plot cell\n",
    "# (it differs from init_options, which is not used here).\n",
    "method_order = [\"abi_psis\", \"stan-like\", \"abi\"]\n",
    "# A task with fewer fully-computed observations than this is treated as an error.\n",
    "min_valid_observations = 20\n",
    "\n",
    "# n_rhats_tasks[task_name][init_option][num_warmup] -> list with one\n",
    "# (max n_rhat - 1) value per valid observation.\n",
    "n_rhats_tasks = {}\n",
    "for task_name in tqdm(task_names, desc=\"tasks\"):\n",
    "    stuff = get_stuff(\n",
    "        task_name=task_name,\n",
    "        test_dataset_name=\"test_dataset_chunk_1\",\n",
    "        overwrite_stats=False,\n",
    "    )\n",
    "    task = stuff[\"task\"]\n",
    "    # paths and test_dataset_name are globals read by get_save_path, so they\n",
    "    # must be rebound for each task before it is called.\n",
    "    paths = stuff[\"paths\"]\n",
    "    test_dataset = stuff[\"test_dataset\"]\n",
    "    test_dataset_name = stuff[\"test_dataset_name\"]\n",
    "    config = stuff[\"config\"]\n",
    "\n",
    "    # Only observations listed in psis_stats[\"reject_inds\"] (PSIS failures) are studied.\n",
    "    psis_stats = read_from_file(paths[\"psis_stats\"])\n",
    "    psis_failed_observation_ids = psis_stats[\"reject_inds\"]\n",
    "    test_observation_ids = psis_failed_observation_ids\n",
    "\n",
    "    # Keep observations that have a result file for every (init_option, num_warmup)\n",
    "    # pair; all() stops at the first missing file.\n",
    "    valid_inds = [\n",
    "        observation_id\n",
    "        for observation_id in test_observation_ids\n",
    "        if all(\n",
    "            get_save_path(observation_id, num_warmup, init_option, sort=sort).exists()\n",
    "            for init_option in method_order\n",
    "            for num_warmup in num_warmup_values\n",
    "        )\n",
    "    ]\n",
    "    assert len(valid_inds) >= min_valid_observations, (\n",
    "        f\"{task_name}: only {len(valid_inds)} observations have complete results \"\n",
    "        f\"(need at least {min_valid_observations})\"\n",
    "    )\n",
    "\n",
    "    n_rhats = {}\n",
    "    for init_option in method_order:\n",
    "        n_rhats[init_option] = {}\n",
    "        for num_warmup in num_warmup_values:\n",
    "            max_nrhats = []\n",
    "            for observation_id in tqdm(\n",
    "                valid_inds, desc=f\"{init_option}, warmup={num_warmup}\", leave=False\n",
    "            ):\n",
    "                file_path = get_save_path(\n",
    "                    observation_id, num_warmup, init_option, sort=sort\n",
    "                )\n",
    "                result = read_from_file(file_path)\n",
    "                # Largest n_rhat value minus 1 (np.mean is presumably a no-op on\n",
    "                # this scalar beyond converting it to a NumPy scalar).\n",
    "                max_nrhats.append(np.mean(result[\"n_rhat\"].max() - 1))\n",
    "            n_rhats[init_option][num_warmup] = max_nrhats\n",
    "    n_rhats_tasks[task_name] = n_rhats"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1602b404",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Subplot titles: internal task name -> human-readable display name\n",
    "tasks_display_names = dict(\n",
    "    [\n",
    "        (\"GEV\", \"GEV\"),\n",
    "        (\"BernoulliGLM\", \"Bernoulli GLM\"),\n",
    "        (\"psychometric_curve_overdispersion\", \"Psychometric curve\"),\n",
    "        (\"CustomDDM(dt=0.0001)\", \"Decision model\"),\n",
    "    ]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a8013e80",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "# Figure size in inches, chosen to fit the NeurIPS page width\n",
    "fig_width = 5\n",
    "fig_height = 3  # Slightly shorter to accommodate the bottom legend\n",
    "\n",
    "# Per-initialisation styling, keyed like the init_option keys of n_rhats_tasks\n",
    "colors = {\n",
    "    \"abi_psis\": \"#DDAA33\",\n",
    "    \"stan-like\": \"#BB5566\",\n",
    "    \"abi\": \"#004488\",\n",
    "}\n",
    "labels = {\n",
    "    \"abi_psis\": \"Amortized + PSIS\",\n",
    "    \"stan-like\": \"Random Init.\",\n",
    "    \"abi\": \"Amortized\",\n",
    "}\n",
    "markers = {\n",
    "    \"abi_psis\": \"D\",\n",
    "    \"stan-like\": \"s\",\n",
    "    \"abi\": \"^\",\n",
    "}\n",
    "\n",
    "# Y-axis label per sampler. Resolved before plotting so an unsupported\n",
    "# mcmc_method fails with a clear message instead of a NameError (or silently\n",
    "# reusing a stale label left over from an earlier run).\n",
    "ylabels = {\n",
    "    \"NUTS\": r\"$\\widehat{R} - 1$\",\n",
    "    \"ChEES-HMC\": r\"Nested $\\widehat{R} - 1$\",\n",
    "}\n",
    "if mcmc_method not in ylabels:\n",
    "    raise ValueError(\n",
    "        f\"No y-axis label defined for mcmc_method={mcmc_method!r}; \"\n",
    "        f\"expected one of {sorted(ylabels)}\"\n",
    "    )\n",
    "ylabel = ylabels[mcmc_method]\n",
    "\n",
    "# Configure global plot settings for publication quality\n",
    "plt.rcParams.update(\n",
    "    {\n",
    "        \"font.family\": \"serif\",\n",
    "        \"font.size\": 9,\n",
    "        \"axes.labelsize\": 10,\n",
    "        \"axes.titlesize\": 10,\n",
    "        \"xtick.labelsize\": 8,\n",
    "        \"ytick.labelsize\": 8,\n",
    "        \"legend.fontsize\": 9,\n",
    "        \"axes.linewidth\": 0.8,  # Thinner spines\n",
    "        \"grid.linewidth\": 0.6,  # Thinner grid lines\n",
    "        \"lines.linewidth\": 1.0,  # Line thickness\n",
    "        \"lines.markersize\": 5,  # Default marker size\n",
    "        \"xtick.major.width\": 0.8,  # Tick width\n",
    "        \"ytick.major.width\": 0.8,\n",
    "        \"xtick.direction\": \"out\",  # Ticks facing outward\n",
    "        \"ytick.direction\": \"out\",\n",
    "    }\n",
    ")\n",
    "plt.rcParams[\"text.usetex\"] = True  # text is rendered by LaTeX, which must be installed\n",
    "\n",
    "# 2x2 grid with one panel per task (assumes task_names has four entries)\n",
    "fig, axs = plt.subplots(\n",
    "    2, 2, figsize=(fig_width, fig_height), sharex=True, sharey=False\n",
    ")\n",
    "axs = axs.flatten()\n",
    "\n",
    "# Error bars span the interquartile range, or min-max when disabled\n",
    "interquartile = True\n",
    "if interquartile:\n",
    "    lower = 25\n",
    "    upper = 75\n",
    "else:\n",
    "    lower = 0\n",
    "    upper = 100\n",
    "\n",
    "for task_idx, task_name in enumerate(task_names):\n",
    "    ax = axs[task_idx]\n",
    "\n",
    "    data = n_rhats_tasks[task_name]\n",
    "\n",
    "    # One error-bar series per initialisation strategy, in dict insertion order\n",
    "    for method_idx, method_key in enumerate(data):\n",
    "        x_values = sorted(data[method_key].keys())\n",
    "        # Shift each series by 5 x-units so overlapping markers stay visible\n",
    "        x_offset = -5 + 5 * method_idx\n",
    "\n",
    "        # rhat[i, j]: value for warm-up length x_values[i] and observation j\n",
    "        rhat = np.array([data[method_key][x] for x in x_values])\n",
    "        x = np.array(x_values) + x_offset\n",
    "\n",
    "        # Median over observations with asymmetric percentile error bars\n",
    "        medians = np.median(rhat, axis=1)\n",
    "        lower_errs = medians - np.percentile(rhat, lower, axis=1)\n",
    "        upper_errs = np.percentile(rhat, upper, axis=1) - medians\n",
    "\n",
    "        ax.errorbar(\n",
    "            x,\n",
    "            medians,\n",
    "            yerr=[lower_errs, upper_errs],\n",
    "            fmt=markers[method_key],\n",
    "            lw=1.0,\n",
    "            color=colors[method_key],\n",
    "            capsize=2.5,\n",
    "            linestyle=\"-\",\n",
    "            label=labels[method_key],\n",
    "            markeredgecolor=\"black\",\n",
    "            markeredgewidth=0.5,\n",
    "            markersize=4.5,\n",
    "            zorder=3,  # Ensure data points are on top\n",
    "        )\n",
    "\n",
    "    # Fall back to the raw task name (underscores escaped for usetex) when no\n",
    "    # display name is registered, instead of titling the panel \"None\".\n",
    "    ax.set_title(tasks_display_names.get(task_name, task_name.replace(\"_\", r\"\\_\")))\n",
    "    ax.set_yscale(\"log\")\n",
    "    ax.grid(True, alpha=0.25, linestyle=\"-\", which=\"major\", zorder=0)\n",
    "\n",
    "    # Reference line at y = 0.01, i.e. R-hat = 1.01\n",
    "    ax.axhline(\n",
    "        y=0.01, color=\"k\", linestyle=\"--\", alpha=0.7, linewidth=0.8, zorder=2\n",
    "    )\n",
    "\n",
    "    ax.set_yticks([0.001, 0.01, 0.1])\n",
    "\n",
    "    # Remove top and right spines\n",
    "    ax.spines[\"right\"].set_visible(False)\n",
    "    ax.spines[\"top\"].set_visible(False)\n",
    "\n",
    "    # x-label and ticks only on the bottom row\n",
    "    if task_idx >= 2:\n",
    "        ax.set_xlabel(\"Warmup iterations\")\n",
    "        ax.set_xticks(x_values)\n",
    "\n",
    "    # y-label only on the left column\n",
    "    if task_idx % 2 == 0:\n",
    "        ax.set_ylabel(ylabel)\n",
    "\n",
    "# Shared legend centred below the subplots, entries sorted by label text\n",
    "handles, labels_list = axs[0].get_legend_handles_labels()\n",
    "if handles:\n",
    "    order = sorted(range(len(labels_list)), key=lambda i: labels_list[i])\n",
    "    fig.legend(\n",
    "        [handles[i] for i in order],\n",
    "        [labels_list[i] for i in order],\n",
    "        loc=\"lower center\",\n",
    "        bbox_to_anchor=(0.5, -0.05),\n",
    "        ncol=3,\n",
    "        frameon=False,\n",
    "        handlelength=1.2,\n",
    "        handletextpad=0.5,\n",
    "        borderaxespad=0.2,\n",
    "    )\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.subplots_adjust(bottom=0.17, wspace=0.25, hspace=0.35)\n",
    "# File-name spelling (\"comparision\") kept so existing references still resolve\n",
    "fig_path = f\"figures/warmup_comparision_{mcmc_method}\"\n",
    "# savefig does not create missing directories\n",
    "os.makedirs(os.path.dirname(fig_path), exist_ok=True)\n",
    "plt.savefig(f\"{fig_path}.pdf\", dpi=300, bbox_inches=\"tight\")\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d57a031a",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "sbi_mcmc_bf_forge",
   "language": "python",
   "name": "sbi_mcmc_bf_forge"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
