{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13a616c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload imported modules (e.g. exp_utils, sbi_mcmc) before each cell runs,\n",
    "# so edits to them are picked up without restarting the kernel.\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "54f50031",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "from tqdm import tqdm\n",
    "\n",
    "os.environ[\"KERAS_BACKEND\"] = \"jax\"\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"\"\n",
    "os.environ[\"JAX_PLATFORMS\"] = \"cpu\"\n",
    "import keras\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from exp_utils import *\n",
    "from omegaconf import OmegaConf\n",
    "from sbi_mcmc.tasks import *\n",
    "from sbi_mcmc.tasks.tasks_utils import get_task_logp_func\n",
    "from sbi_mcmc.utils.experiment_utils import *\n",
    "from sbi_mcmc.utils.psis_utils import _sir, sampling_importance_resampling\n",
    "from sbi_mcmc.utils.utils import *\n",
    "from tqdm.autonotebook import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d8526cc",
   "metadata": {
    "tags": [
     "parameters"
    ]
   },
   "outputs": [],
   "source": [
    "# Notebook parameters. This cell carries the `parameters` tag, so a notebook\n",
    "# executor (e.g. papermill) can override the values below.\n",
    "\n",
    "# Upper bound on the number of runs per category.\n",
    "# NOTE(review): confirm where this cap is meant to be applied.\n",
    "max_num_runs = 100\n",
    "\n",
    "# Tasks to summarise; each entry is passed to `get_stuff` as `task_name`.\n",
    "task_names = [\n",
    "    \"GEV\",\n",
    "    \"BernoulliGLM\",\n",
    "    \"psychometric_curve_overdispersion\",\n",
    "    \"CustomDDM(dt=0.0001)\",\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4a573010",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect metric values for every task, split into four observation categories:\n",
    "# ABI(accepted), ABI(rejected), PSIS (PSIS-accepted) and ChEES-HMC (ChEES-HMC-accepted).\n",
    "# metrics_results_tasks[task][metric][category] holds the metric values and\n",
    "# metrics_results_tasks_ids[task][metric][category] the matching observation ids.\n",
    "metric_names = [\"W1\", \"mmtv_FFTKDE\", \"GsKL\"]\n",
    "metrics_results_tasks = {k: {} for k in task_names}\n",
    "metrics_results_tasks_ids = {k: {} for k in task_names}\n",
    "for task_name in task_names:\n",
    "    print(\"=\" * 10)\n",
    "    print(task_name)\n",
    "    stuff = get_stuff(\n",
    "        task_name=task_name,\n",
    "        test_dataset_name=\"test_dataset_chunk_1\",\n",
    "        job=None,\n",
    "        overwrite_stats=False,\n",
    "    )\n",
    "    task = stuff[\"task\"]\n",
    "    paths = stuff[\"paths\"]\n",
    "    test_dataset_name = stuff[\"test_dataset_name\"]\n",
    "    config = stuff[\"config\"]\n",
    "\n",
    "    # Load the stored stats of each pipeline stage.\n",
    "    stats = {}\n",
    "    for job in [\"ood\", \"psis\", \"abi\", \"chees_hmc\"]:\n",
    "        stats_logger = PickleStatLogger(\n",
    "            paths[f\"{job}_stats\"], overwrite=False, verbose=True\n",
    "        )\n",
    "        stats[job] = stats_logger.data\n",
    "\n",
    "    stats[\"training\"] = PickleStatLogger(\n",
    "        paths[\"training_result_dir\"] / \"training_record.pkl\", overwrite=False\n",
    "    ).data\n",
    "\n",
    "    # ABI accepted/rejected split, read from the Mahalanobis OOD stats of this dataset.\n",
    "    ood_stats = stats[\"ood\"][f\"Mahalanobis_{test_dataset_name}\"]\n",
    "    abi_accept_inds = set(ood_stats[\"ood_accept_inds\"])\n",
    "    abi_reject_inds = set(ood_stats[\"ood_failed_inds\"])\n",
    "\n",
    "    # Sanity checks: PSIS accepted + rejected must cover exactly the\n",
    "    # ABI-rejected set, and ChEES-HMC accepted + rejected exactly the\n",
    "    # PSIS-rejected set.\n",
    "    psis_accept_inds = set(stats[\"psis\"][\"accept_inds\"])\n",
    "    psis_reject_inds = set(stats[\"psis\"][\"reject_inds\"])\n",
    "    assert abi_reject_inds == psis_accept_inds | psis_reject_inds\n",
    "    assert abi_reject_inds.issuperset(psis_accept_inds)\n",
    "\n",
    "    chees_hmc_reject_inds = set(stats[\"chees_hmc\"][\"chees_hmc_reject_inds\"])\n",
    "    chees_hmc_accept_inds = set(stats[\"chees_hmc\"][\"chees_hmc_accept_inds\"])\n",
    "    assert psis_reject_inds == chees_hmc_reject_inds | chees_hmc_accept_inds\n",
    "    assert psis_reject_inds.issuperset(chees_hmc_reject_inds)\n",
    "\n",
    "    metrics_logger = PickleStatLogger(\n",
    "        paths[\"metrics_stats\"], overwrite=False, verbose=True\n",
    "    )\n",
    "\n",
    "    # Category -> sorted observation ids. Each category name is also the\n",
    "    # prefix of its metric record keys: \"<category>-<observation_id>\".\n",
    "    inds_dict = {\n",
    "        \"ABI(accepted)\": sorted(abi_accept_inds),\n",
    "        \"ABI(rejected)\": sorted(abi_reject_inds),\n",
    "        \"PSIS\": sorted(psis_accept_inds),\n",
    "        \"ChEES-HMC\": sorted(chees_hmc_accept_inds),\n",
    "    }\n",
    "\n",
    "    for metric_name in tqdm(metric_names):\n",
    "        metric_records = metrics_logger.data[metric_name]\n",
    "        metric_values_dict = {k: [] for k in inds_dict}\n",
    "        corresponding_ids = {k: [] for k in inds_dict}\n",
    "        for id_type, inds in inds_dict.items():\n",
    "            for observation_id in inds:\n",
    "                m_value = metric_records.get(f\"{id_type}-{observation_id}\")\n",
    "                # Observations with no stored value for this metric are skipped.\n",
    "                if m_value is not None:\n",
    "                    metric_values_dict[id_type].append(m_value)\n",
    "                    corresponding_ids[id_type].append(observation_id)\n",
    "\n",
    "        metrics_results_tasks[task_name][metric_name] = metric_values_dict\n",
    "        metrics_results_tasks_ids[task_name][metric_name] = corresponding_ids\n",
    "\n",
    "    # Report per-category counts for every metric, capped at max_num_runs.\n",
    "    for metric_name, values_by_type in metrics_results_tasks[task_name].items():\n",
    "        for id_type, values in values_by_type.items():\n",
    "            print(metric_name, id_type, min(len(values), max_num_runs))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "77b8ae6f",
   "metadata": {},
   "source": [
    "## Paper figure: boxplots of the W1 and MMTV metrics for each task"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "df7135ef",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib as mpl\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.ticker as mticker\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "from exp_utils import set_default_plot_settings\n",
    "from matplotlib.ticker import MaxNLocator\n",
    "\n",
    "# Start from matplotlib's defaults, then apply set_default_plot_settings() and\n",
    "# the seaborn \"paper\" theme; later settings take precedence where they overlap.\n",
    "mpl.rcParams.update(mpl.rcParamsDefault)\n",
    "set_default_plot_settings()\n",
    "sns.set(style=\"whitegrid\", context=\"paper\", font_scale=0.9)\n",
    "\n",
    "# Compact publication settings. LaTeX rendering is enabled and the preamble\n",
    "# loads pifont so that \\cmark / \\xmark (used in the x-tick labels) are defined.\n",
    "plt.rcParams.update(\n",
    "    {\n",
    "        \"text.usetex\": True,\n",
    "        \"text.latex.preamble\": r\"\"\"\n",
    "\\usepackage{pifont}\n",
    "\\newcommand{\\cmark}{\\ding{51}}\n",
    "\\newcommand{\\xmark}{\\ding{55}}\n",
    "\"\"\",\n",
    "        \"font.family\": \"serif\",\n",
    "        \"figure.dpi\": 300,\n",
    "        \"axes.labelsize\": 9,\n",
    "        \"axes.titlesize\": 8.5,\n",
    "        \"xtick.labelsize\": 7,\n",
    "        \"ytick.labelsize\": 8,\n",
    "        \"lines.linewidth\": 0.8,\n",
    "    }\n",
    ")\n",
    "\n",
    "# Metrics to plot (one figure row each) and display names for metrics and tasks.\n",
    "metrics = [\"W1\", \"mmtv_FFTKDE\"]\n",
    "metrics_display_names = {\n",
    "    \"W1\": \"W1\",\n",
    "    \"mmtv_FFTKDE\": \"MMTV\",\n",
    "}\n",
    "tasks_display_names = {\n",
    "    \"GEV\": \"GEV\",\n",
    "    \"BernoulliGLM\": \"Bernoulli GLM\",\n",
    "    \"psychometric_curve_overdispersion\": \"Psychometric curve\",\n",
    "    \"CustomDDM(dt=0.0001)\": \"Decision model\",\n",
    "}\n",
    "\n",
    "\n",
    "# Replace the ABI category names with LaTeX check/cross-mark symbols.\n",
    "def format_abi_labels(labels):\n",
    "    \"\"\"Return a list of display labels for the boxplot categories.\n",
    "\n",
    "    For each label the first matching rule wins:\n",
    "    - contains \"ABI(accepted)\": that text becomes the check-mark form,\n",
    "    - contains \"ABI(rejected)\": that text becomes the cross-mark form\n",
    "      (written with a leading space),\n",
    "    - contains \"HMC\": \"ChEES-HMC\" is shortened to \"C-HMC\",\n",
    "    - otherwise the label is returned unchanged.\n",
    "    \"\"\"\n",
    "    # (substring that selects the rule, text to replace, replacement)\n",
    "    rules = (\n",
    "        (\"ABI(accepted)\", \"ABI(accepted)\", r\"ABI(\\cmark)\"),\n",
    "        # NOTE(review): the leading space before ABI(\\xmark) may be deliberate\n",
    "        # padding - confirm.\n",
    "        (\"ABI(rejected)\", \"ABI(rejected)\", r\" ABI(\\xmark)\"),\n",
    "        (\"HMC\", \"ChEES-HMC\", \"C-HMC\"),\n",
    "    )\n",
    "\n",
    "    def _format_one(label):\n",
    "        for trigger, old, new in rules:\n",
    "            if trigger in label:\n",
    "                return label.replace(old, new)\n",
    "        return label\n",
    "\n",
    "    return [_format_one(label) for label in labels]\n",
    "\n",
    "\n",
    "# One row per metric and one column per task; rows share their y-axis.\n",
    "# squeeze=False keeps `axes` 2-D even with a single metric or task.\n",
    "fig, axes = plt.subplots(\n",
    "    len(metrics),\n",
    "    len(task_names),\n",
    "    figsize=(5.5, 2.3),\n",
    "    sharey=\"row\",\n",
    "    squeeze=False,\n",
    ")\n",
    "\n",
    "# Box colours, assigned to the categories in order.\n",
    "colors = sns.color_palette(\"Set2\", 8)\n",
    "\n",
    "for row, metric in enumerate(metrics):\n",
    "    for col, task_name in enumerate(task_names):\n",
    "        ax = axes[row, col]\n",
    "        data_dict = metrics_results_tasks[task_name][metric]\n",
    "\n",
    "        # One box per category, in the dict's insertion order.\n",
    "        bplot = ax.boxplot(\n",
    "            list(data_dict.values()),\n",
    "            patch_artist=True,\n",
    "            widths=0.55,\n",
    "            showfliers=True,\n",
    "            medianprops={\"color\": \"black\", \"linewidth\": 1.0},\n",
    "            flierprops={\n",
    "                \"marker\": \".\",\n",
    "                \"markersize\": 2,\n",
    "                \"alpha\": 0.7,\n",
    "            },  # Smaller, transparent outliers\n",
    "        )\n",
    "\n",
    "        for patch, color in zip(bplot[\"boxes\"], colors, strict=False):\n",
    "            patch.set_facecolor(color)\n",
    "            patch.set_edgecolor(\"black\")\n",
    "            patch.set_linewidth(0.5)\n",
    "            patch.set_alpha(0.85)\n",
    "\n",
    "        # Task titles on the top row only.\n",
    "        if row == 0:\n",
    "            ax.set_title(tasks_display_names.get(task_name, task_name))\n",
    "\n",
    "        # Metric label and y locator on the leftmost column only; the y-axis\n",
    "        # is shared along each row.\n",
    "        if col == 0:\n",
    "            ax.set_ylabel(metrics_display_names.get(metric, metric))\n",
    "            ax.yaxis.set_major_locator(MaxNLocator(nbins=4, prune=None))\n",
    "\n",
    "        # Category tick labels (with check/cross symbols) on the bottom row only.\n",
    "        if row == len(metrics) - 1:\n",
    "            formatted_labels = format_abi_labels(list(data_dict.keys()))\n",
    "            ax.set_xticklabels(\n",
    "                formatted_labels, rotation=0, ha=\"center\", fontsize=5\n",
    "            )\n",
    "        else:\n",
    "            ax.set_xticklabels([])\n",
    "\n",
    "        ax.margins(0.03)\n",
    "        ax.grid(axis=\"y\", linestyle=\"--\", alpha=0.6, linewidth=0.5)\n",
    "        ax.spines[\"right\"].set_visible(False)\n",
    "        ax.spines[\"top\"].set_visible(False)\n",
    "\n",
    "        # Plain one-decimal y tick labels for every metric.\n",
    "        ax.yaxis.set_major_formatter(mticker.FormatStrFormatter(\"%.1f\"))\n",
    "\n",
    "        if \"mmtv\" in metric:\n",
    "            # NOTE(review): 0.2 appears to be a reference level for MMTV - confirm.\n",
    "            ax.axhline(\n",
    "                y=0.2, color=\"k\", linestyle=\"--\", alpha=0.7, linewidth=0.8\n",
    "            )\n",
    "            ax.set_yticks([0.0, 0.2, 0.5, 0.8])\n",
    "\n",
    "# tight_layout recomputes all subplot spacing, so no subplots_adjust is needed.\n",
    "fig.tight_layout(pad=0.2, h_pad=0.5, w_pad=0.1)\n",
    "# Create the output directory so savefig does not fail on a fresh checkout.\n",
    "os.makedirs(\"figures\", exist_ok=True)\n",
    "fig.savefig(\"figures/metrics_boxplots.pdf\", bbox_inches=\"tight\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ffd31544",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "sbi_mcmc_bf_forge",
   "language": "python",
   "name": "sbi_mcmc_bf_forge"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
