{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13a616c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "54f50031",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Select the JAX backend and force CPU execution before Keras/JAX are imported.\n",
    "os.environ[\"KERAS_BACKEND\"] = \"jax\"\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"\"\n",
    "os.environ[\"JAX_PLATFORMS\"] = \"cpu\"\n",
    "import keras\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from exp_utils import read_reference_posterior\n",
    "from omegaconf import OmegaConf\n",
    "from sbi_mcmc.tasks import *\n",
    "from sbi_mcmc.tasks.tasks_utils import get_task_logp_func\n",
    "from sbi_mcmc.utils.experiment_utils import *\n",
    "from sbi_mcmc.utils.psis_utils import _sir, sampling_importance_resampling\n",
    "from sbi_mcmc.utils.utils import *\n",
    "from tqdm.autonotebook import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d8526cc",
   "metadata": {
    "tags": [
     "parameters"
    ]
   },
   "outputs": [],
   "source": [
    "task_name = \"psychometric_curve_overdispersion\"\n",
    "# task_name = \"CustomDDM(dt=0.0001)\"\n",
    "# task_name = \"GEV\"\n",
    "# task_name = \"BernoulliGLM\"\n",
    "max_num_runs = 100"
   ]
  },
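  {
   "cell_type": "markdown",
   "id": "7b3e51aa",
   "metadata": {},
   "source": [
    "The cell above is tagged `parameters` (the papermill convention), presumably so that `task_name` and `max_num_runs` can be overridden when the notebook is executed programmatically."
   ]
  },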
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e80281cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "stuff = get_stuff(\n",
    "    task_name=task_name,\n",
    "    test_dataset_name=\"test_dataset_chunk_1\",\n",
    "    job=None,\n",
    "    overwrite_stats=False,\n",
    ")\n",
    "task = stuff[\"task\"]\n",
    "paths = stuff[\"paths\"]\n",
    "N_testdata = stuff[\"N_testdata\"]\n",
    "test_dataset = stuff[\"test_dataset\"]\n",
    "test_dataset_name = stuff[\"test_dataset_name\"]\n",
    "config = stuff[\"config\"]\n",
    "\n",
    "stats = {}\n",
    "for job in [\"ood\", \"psis\", \"abi\", \"chees_hmc\"]:\n",
    "    stats_logger = PickleStatLogger(\n",
    "        paths[f\"{job}_stats\"], overwrite=False, verbose=True\n",
    "    )\n",
    "    stats[job] = stats_logger.data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4a573010",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity checks on the workflow cascade: PSIS is run on exactly the\n",
    "# ABI-rejected observations, and ChEES-HMC on exactly the PSIS-rejected ones.\n",
    "abi_accept_inds = set(\n",
    "    stats[\"ood\"][f\"Mahalanobis_{test_dataset_name}\"][\"ood_accept_inds\"]\n",
    ")\n",
    "abi_reject_inds = set(\n",
    "    stats[\"ood\"][f\"Mahalanobis_{test_dataset_name}\"][\"ood_failed_inds\"]\n",
    ")\n",
    "psis_accept_inds = set(stats[\"psis\"][\"accept_inds\"])\n",
    "psis_reject_inds = set(stats[\"psis\"][\"reject_inds\"])\n",
    "assert abi_reject_inds == psis_accept_inds | psis_reject_inds\n",
    "assert abi_reject_inds.issuperset(psis_accept_inds)\n",
    "\n",
    "chees_hmc_reject_inds = set(stats[\"chees_hmc\"][\"chees_hmc_reject_inds\"])\n",
    "chees_hmc_accept_inds = set(stats[\"chees_hmc\"][\"chees_hmc_accept_inds\"])\n",
    "assert psis_reject_inds == chees_hmc_reject_inds | chees_hmc_accept_inds\n",
    "assert psis_reject_inds.issuperset(chees_hmc_reject_inds)"
   ]
  },
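  {
   "cell_type": "markdown",
   "id": "2f9c4d17",
   "metadata": {},
   "source": [
    "The workflow triages observations in three stages: amortized inference (ABI) with a Mahalanobis OOD check, PSIS on the ABI-rejected observations, and ChEES-HMC on the PSIS-rejected ones. The convenience cell below (not part of the original pipeline) simply prints how many observations land in each group."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e8a7c90",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convenience summary of the cascade, using the sets defined above.\n",
    "stage_counts = {\n",
    "    \"ABI accepted\": len(abi_accept_inds),\n",
    "    \"ABI rejected -> PSIS\": len(abi_reject_inds),\n",
    "    \"PSIS accepted\": len(psis_accept_inds),\n",
    "    \"PSIS rejected -> ChEES-HMC\": len(psis_reject_inds),\n",
    "    \"ChEES-HMC accepted\": len(chees_hmc_accept_inds),\n",
    "    \"ChEES-HMC rejected\": len(chees_hmc_reject_inds),\n",
    "}\n",
    "for stage, count in stage_counts.items():\n",
    "    print(f\"{stage}: {count}\")"
   ]
  },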
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4db77bb1",
   "metadata": {},
   "outputs": [],
   "source": [
    "pymc_source = f\"pymc_runs/{test_dataset_name}\"\n",
    "batch_results = read_from_file(paths[\"abi_result\"])\n",
    "\n",
    "metrics_logger = PickleStatLogger(\n",
    "    paths[\"metrics_stats\"], overwrite=False, verbose=True\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8dabe49c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Deterministic iteration order: store each group's indices as a sorted list.\n",
    "inds_dict = {\n",
    "    \"ABI(accepted)\": sorted(abi_accept_inds),\n",
    "    \"ABI(rejected)\": sorted(abi_reject_inds),\n",
    "    \"PSIS\": sorted(psis_accept_inds),\n",
    "    \"ChEES-HMC\": sorted(chees_hmc_accept_inds),\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5bbba625",
   "metadata": {},
   "outputs": [],
   "source": [
    "def read_posterior_draws(id_type, observation_id):\n",
    "    \"\"\"Load the posterior draws produced by the given workflow stage.\"\"\"\n",
    "    if \"ABI\" in id_type:\n",
    "        posterior_draws = batch_results[\"abi_samples_batch\"][observation_id]\n",
    "\n",
    "    elif id_type == \"PSIS\":\n",
    "        psis_results = read_from_file(paths[\"psis_result\"](observation_id))\n",
    "        posterior_draws = psis_results[\"abi_psis_resamples\"]\n",
    "\n",
    "    elif id_type == \"ChEES-HMC\":\n",
    "        chees_hmc_results = read_from_file(\n",
    "            paths[\"chees_hmc_result\"](observation_id)\n",
    "        )\n",
    "        # Keep only index 0 along the second axis (presumably a single chain)\n",
    "        # and truncate to config.target_num_draws.\n",
    "        posterior_draws = chees_hmc_results[\"chees_draws_tfp\"][\n",
    "            : config.target_num_draws, 0, :\n",
    "        ]\n",
    "    else:\n",
    "        posterior_draws = None\n",
    "    return posterior_draws"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a8531783",
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import OrderedDict\n",
    "from functools import partial\n",
    "\n",
    "from sbi_mcmc.metrics import gskl, mtv, wasserstein_distance\n",
    "from sbi_mcmc.utils.plot_utils import corner_plot\n",
    "\n",
    "mtv_FFTKDE = partial(mtv, kde=\"FFTKDE\")\n",
    "metric_dict = OrderedDict(\n",
    "    {\"W1\": wasserstein_distance, \"mtv_FFTKDE\": mtv_FFTKDE, \"GsKL\": gskl}\n",
    ")\n",
    "\n",
    "# Cap on how many runs are processed per group. Pre-filling every group with\n",
    "# max_num_runs makes the loop below skip all (re)computation; switch to the\n",
    "# commented-out empty dict to compute the metrics from scratch.\n",
    "computed_runs = {\n",
    "    \"ABI(accepted)\": max_num_runs,\n",
    "    \"ABI(rejected)\": max_num_runs,\n",
    "    \"PSIS\": max_num_runs,\n",
    "    \"ChEES-HMC\": max_num_runs,\n",
    "}\n",
    "# computed_runs = {}\n",
    "\n",
    "all_observation_ids = list(range(stuff[\"N_testdata\"]))\n",
    "\n",
    "pymc_missing_ids = []\n",
    "# for observation_id in tqdm(all_observation_ids):\n",
    "for id_type, inds in tqdm(inds_dict.items(), desc=\"Groups\"):\n",
    "    for observation_id in tqdm(inds, desc=f\"Obs ({id_type})\", leave=False):\n",
    "        if id_type not in computed_runs.keys():\n",
    "            computed_runs[id_type] = 0\n",
    "\n",
    "        if computed_runs[id_type] >= max_num_runs:\n",
    "            continue  # Skip unnecessary file reading\n",
    "\n",
    "        try:\n",
    "            samples_pymc_unconstrained = read_reference_posterior(\n",
    "                task, observation_id, pymc_source, raise_error=True\n",
    "            )\n",
    "        except ValueError as e:\n",
    "            if \"Rhat is too large\" in str(e):\n",
    "                print(f\"Skipping {observation_id} due to Rhat error.\")\n",
    "                continue\n",
    "            elif \"not found\" in str(e):\n",
    "                print(\n",
    "                    f\"Skipping {observation_id} due to PyMC result not found.\"\n",
    "                )\n",
    "                pymc_missing_ids.append(observation_id)\n",
    "                continue\n",
    "            else:\n",
    "                raise\n",
    "\n",
    "        posterior_draws = read_posterior_draws(id_type, observation_id)\n",
    "        if posterior_draws is None:\n",
    "            continue\n",
    "\n",
    "        # Compute metrics that are not yet stored in the logger\n",
    "        for metric_name, metric_fn in metric_dict.items():\n",
    "            record_key = f\"{id_type}-{observation_id}\"\n",
    "            if (\n",
    "                metrics_logger.data.get(metric_name, {}).get(record_key)\n",
    "                is None\n",
    "            ):\n",
    "                metric_value = metric_fn(\n",
    "                    posterior_draws, samples_pymc_unconstrained\n",
    "                )\n",
    "                metrics_logger.update(metric_name, {record_key: metric_value})\n",
    "                if metric_name in (\"mtv\", \"mtv_FFTKDE\"):\n",
    "                    # Also log the mean MTV under the m-prefixed key.\n",
    "                    metrics_logger.update(\n",
    "                        f\"m{metric_name}\", {record_key: np.mean(metric_value)}\n",
    "                    )\n",
    "\n",
    "        computed_runs[id_type] += 1"
   ]
  },
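  {
   "cell_type": "markdown",
   "id": "9d4b6e21",
   "metadata": {},
   "source": [
    "As a rough illustration of what the `W1` entries measure, the sketch below computes a per-marginal 1-Wasserstein distance between two draw sets with `scipy.stats.wasserstein_distance` on synthetic Gaussian samples. It is a self-contained sketch only; the actual `sbi_mcmc.metrics.wasserstein_distance` used above may differ, e.g. in how it aggregates across dimensions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c1d2e3f4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative sketch only: per-marginal 1-Wasserstein distance between two\n",
    "# sample sets. The sbi_mcmc.metrics implementation used above may differ.\n",
    "import numpy as np\n",
    "from scipy.stats import wasserstein_distance as w1_1d\n",
    "\n",
    "\n",
    "def marginal_w1(draws_a, draws_b):\n",
    "    \"\"\"1-Wasserstein distance for each marginal (column) of two draw sets.\"\"\"\n",
    "    draws_a, draws_b = np.asarray(draws_a), np.asarray(draws_b)\n",
    "    return np.array(\n",
    "        [w1_1d(draws_a[:, d], draws_b[:, d]) for d in range(draws_a.shape[1])]\n",
    "    )\n",
    "\n",
    "\n",
    "# Synthetic example: two Gaussian draw sets, shifted by 0.5 in dimension 0.\n",
    "rng = np.random.default_rng(0)\n",
    "a = rng.normal(size=(1000, 2))\n",
    "b = rng.normal(loc=[0.5, 0.0], size=(1000, 2))\n",
    "print(marginal_w1(a, b))"
   ]
  },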
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "28edcf12",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "metric_names = [\"W1\", \"mmtv_FFTKDE\", \"GsKL\"]\n",
    "\n",
    "metrics_results = {}\n",
    "metrics_results_ids = {}\n",
    "for metric_name in tqdm(metric_names):\n",
    "    metric_values_dict = {k: [] for k in inds_dict.keys()}\n",
    "    corresponding_ids = {k: [] for k in inds_dict.keys()}\n",
    "    for id_type, inds in inds_dict.items():\n",
    "        print(id_type)\n",
    "        n_plots = 0\n",
    "        fig_dir = paths[\"metrics_result_dir\"] / id_type\n",
    "        fig_dir.mkdir(parents=True, exist_ok=True)\n",
    "        num_files = sum(1 for f in fig_dir.iterdir() if f.is_file())\n",
    "\n",
    "        for observation_id in inds:\n",
    "            record_key = f\"{id_type}-{observation_id}\"\n",
    "            m_value = metrics_logger.data[metric_name].get(record_key)\n",
    "            if m_value is not None:\n",
    "                metric_values_dict[id_type].append(m_value)\n",
    "                corresponding_ids[id_type].append(observation_id)\n",
    "\n",
    "                max_plots = 10\n",
    "                if n_plots <= max_plots and num_files <= max_plots:\n",
    "                    samples_pymc_unconstrained = read_reference_posterior(\n",
    "                        task, observation_id, pymc_source, raise_error=False\n",
    "                    )\n",
    "                    posterior_draws = read_posterior_draws(\n",
    "                        id_type, observation_id\n",
    "                    )\n",
    "                    fig = corner_plot(\n",
    "                        samples_pymc_unconstrained,\n",
    "                        posterior_draws,\n",
    "                        save_as=fig_dir / f\"{observation_id}.png\",\n",
    "                        dpi=100,\n",
    "                    )\n",
    "                    plt.close(fig)\n",
    "                    n_plots += 1\n",
    "\n",
    "    metrics_results[metric_name] = metric_values_dict\n",
    "    metrics_results_ids[metric_name] = corresponding_ids\n",
    "\n",
    "    for k, v in metric_values_dict.items():\n",
    "        print(k, len(v))\n",
    "\n",
    "    data_dict = metric_values_dict\n",
    "\n",
    "    fig = plt.figure(figsize=(5, 4))\n",
    "    plt.boxplot(\n",
    "        data_dict.values(),\n",
    "        labels=data_dict.keys(),\n",
    "        widths=0.6,\n",
    "        showfliers=True,\n",
    "    )\n",
    "    plt.ylabel(metric_name)\n",
    "    if metric_name in [\"GsKL\"]:\n",
    "        plt.yscale(\"log\")\n",
    "    fig_dir = paths[\"metrics_result_dir\"] / \"figures\"\n",
    "    fig_dir.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "    fig.savefig(fig_dir / f\"{metric_name}.png\", bbox_inches=\"tight\", dpi=100)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c87629ff",
   "metadata": {},
   "source": [
    "## Summary table for accepted datasets"
   ]
  },
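  {
   "cell_type": "markdown",
   "id": "6a2b9c44",
   "metadata": {},
   "source": [
    "The cell below aggregates results over all test-data chunks of each task. Step 1 time covers OOD screening plus amortized inference, with the training/simulation phase added once per task and the training wall time replaced by the separately measured V100 value; Step 2 adds PSIS on the ABI-rejected observations; Step 3 adds ChEES-HMC on the PSIS-rejected observations. `TPA` is the time per accepted dataset, and the `Direct ChEES-HMC` row extrapolates the average per-run ChEES-HMC time to every dataset the workflow accepts; the reported speed-up is that baseline divided by the workflow total. The small cell after this note walks through the speed-up arithmetic with made-up numbers."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8f5d3a72",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Worked example of the speed-up arithmetic with made-up numbers (not results\n",
    "# from this study): if ChEES-HMC spent 1200 s on 40 PSIS-rejected datasets,\n",
    "# the average per-run cost is 30 s; extrapolated to 90 accepted datasets the\n",
    "# direct-ChEES-HMC baseline is 2700 s, and a workflow total of 900 s gives a\n",
    "# 3x speed-up (floor division, as in the cell below).\n",
    "chees_time, n_psis_rejected = 1200.0, 40\n",
    "n_accepted_total, workflow_total = 90, 900.0\n",
    "direct_estimate = (chees_time / n_psis_rejected) * n_accepted_total\n",
    "print(direct_estimate, direct_estimate // workflow_total)"
   ]
  },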
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ad34b3d9",
   "metadata": {},
   "outputs": [],
   "source": [
    "id_chunks_dict = {\n",
    "    \"GEV\": [1],\n",
    "    \"BernoulliGLM\": [1, 2],\n",
    "    \"psychometric_curve_overdispersion\": [1, 2],\n",
    "    \"CustomDDM(dt=0.0001)\": [1, 3, 4],\n",
    "}\n",
    "# We estimated the training time for V100 separately\n",
    "V100_training_times = {\n",
    "    \"GEV\": 141,\n",
    "    \"CustomDDM(dt=0.0001)\": 2200,\n",
    "    \"psychometric_curve_overdispersion\": 267,\n",
    "    \"BernoulliGLM\": 38,\n",
    "}\n",
    "speed_ups = []\n",
    "for task_name, chunks in id_chunks_dict.items():\n",
    "    print(\"===\")\n",
    "    print(task_name)\n",
    "\n",
    "    # Initialize accumulators\n",
    "    total_step_times = [0.0, 0.0, 0.0]\n",
    "    total_num_accepted = [0, 0, 0]\n",
    "    total_N_testdata = 0\n",
    "    total_abi_reject = 0\n",
    "    total_psis_reject = 0\n",
    "\n",
    "    for id_chunk in chunks:\n",
    "        stuff = get_stuff(\n",
    "            task_name=task_name,\n",
    "            test_dataset_name=f\"test_dataset_chunk_{id_chunk}\",\n",
    "        )\n",
    "        task = stuff[\"task\"]\n",
    "        paths = stuff[\"paths\"]\n",
    "        N_testdata = stuff[\"N_testdata\"]\n",
    "        test_dataset_name = stuff[\"test_dataset_name\"]\n",
    "\n",
    "        total_N_testdata += N_testdata\n",
    "\n",
    "        stats = {}\n",
    "        for job in [\"ood\", \"psis\", \"abi\", \"chees_hmc\"]:\n",
    "            stats_logger = PickleStatLogger(\n",
    "                paths[f\"{job}_stats\"], overwrite=False, verbose=True\n",
    "            )\n",
    "            stats[job] = stats_logger.data\n",
    "\n",
    "        stats[\"training\"] = PickleStatLogger(\n",
    "            paths[\"training_result_dir\"] / \"training_record.pkl\",\n",
    "            overwrite=False,\n",
    "        ).data\n",
    "\n",
    "        abi_accept_inds = set(\n",
    "            stats[\"ood\"][f\"Mahalanobis_{test_dataset_name}\"][\"ood_accept_inds\"]\n",
    "        )\n",
    "        abi_reject_inds = set(\n",
    "            stats[\"ood\"][f\"Mahalanobis_{test_dataset_name}\"][\"ood_failed_inds\"]\n",
    "        )\n",
    "        psis_accept_inds = set(stats[\"psis\"][\"accept_inds\"])\n",
    "        psis_reject_inds = set(stats[\"psis\"][\"reject_inds\"])\n",
    "        chees_hmc_accept_inds = set(\n",
    "            stats[\"chees_hmc\"][\"chees_hmc_accept_inds\"]\n",
    "        )\n",
    "        chees_hmc_reject_inds = set(\n",
    "            stats[\"chees_hmc\"][\"chees_hmc_reject_inds\"]\n",
    "        )\n",
    "        assert abi_reject_inds == psis_accept_inds | psis_reject_inds\n",
    "        assert abi_reject_inds.issuperset(psis_accept_inds)\n",
    "        assert (\n",
    "            psis_reject_inds == chees_hmc_reject_inds | chees_hmc_accept_inds\n",
    "        )\n",
    "        assert psis_reject_inds.issuperset(chees_hmc_reject_inds)\n",
    "\n",
    "        # Step 0\n",
    "        step_0_time_total = 0\n",
    "        for k, t in stats[\"training\"][\"wall_time\"].items():\n",
    "            if k == \"training\":\n",
    "                print(\"Use training time for V100\")\n",
    "                t = V100_training_times[task_name]\n",
    "            if \"lc2st_cal\" not in k:\n",
    "                step_0_time_total += t\n",
    "\n",
    "        if task_name == \"CustomDDM(dt=0.0001)\":\n",
    "            assert \"simulation_train\" not in stats[\"training\"]\n",
    "            step_0_time_total += (\n",
    "                2208 / 8 * 10.0\n",
    "            )  # Simulation cost for training dataset\n",
    "\n",
    "        # Step 1\n",
    "        ood_time = sum(stats[\"ood\"][\"wall_time\"].values())\n",
    "        abi_time = sum(stats[\"abi\"][\"wall_time\"].values())\n",
    "        total_step_times[0] += ood_time + abi_time\n",
    "\n",
    "        # Step 2\n",
    "        for obs_id in abi_reject_inds:\n",
    "            total_step_times[1] += stats[\"psis\"][\"wall_time\"][obs_id]\n",
    "\n",
    "        # Step 3\n",
    "        for obs_id in psis_reject_inds:\n",
    "            if obs_id not in stats[\"chees_hmc\"][\"wall_time\"].keys():\n",
    "                # print(f\"{obs_id} not in ChEES-HMC, probably failed\")\n",
    "                continue\n",
    "            total_step_times[2] += stats[\"chees_hmc\"][\"wall_time\"][obs_id]\n",
    "\n",
    "        # Accumulate counts\n",
    "        total_num_accepted[0] += len(abi_accept_inds)\n",
    "        total_num_accepted[1] += len(psis_accept_inds)\n",
    "        total_num_accepted[2] += len(chees_hmc_accept_inds)\n",
    "        total_abi_reject += len(abi_reject_inds)\n",
    "        total_psis_reject += len(psis_reject_inds)\n",
    "\n",
    "    total_step_times[0] += step_0_time_total  # Plus the training phase time\n",
    "    total_time = sum(total_step_times)\n",
    "    times = total_step_times + [total_time]\n",
    "    num_accepted = total_num_accepted + [sum(total_num_accepted)]\n",
    "\n",
    "    def format_g(i):\n",
    "        if i < 1:\n",
    "            return f\"{i:.1g}\"\n",
    "        return f\"{i:.0f}\"\n",
    "\n",
    "    estimate_total_time_chees_hmc = (\n",
    "        times[2] / total_psis_reject\n",
    "    ) * num_accepted[3]\n",
    "    speed_ups.append(estimate_total_time_chees_hmc // total_time)\n",
    "\n",
    "    data = {\n",
    "        \"name\": [\n",
    "            \"& Step 1: Amortized inference\",\n",
    "            \"& Step 2: Amortized + PSIS\",\n",
    "            \"& Step 3: ChEES-HMC w/ inits\",\n",
    "            \"& Workflow total\",\n",
    "            \"& Direct ChEES-HMC\",\n",
    "        ],\n",
    "        \"Accepted\": [\n",
    "            f\"${num_accepted[0]}/{total_N_testdata}$\",\n",
    "            f\"${num_accepted[1]}/{total_abi_reject}$\",\n",
    "            f\"${num_accepted[2]}/{total_psis_reject}$\",\n",
    "            f\"${num_accepted[3]}/{total_N_testdata}$\",\n",
    "            r\"\\NA\",\n",
    "        ],\n",
    "        \"Time (minutes)\": [format_g(t / 60) for t in times]\n",
    "        + [format_g(estimate_total_time_chees_hmc / 60)],\n",
    "        \"TPA\": [\n",
    "            format_g(t / a) for t, a in zip(times, num_accepted, strict=False)\n",
    "        ]\n",
    "        + [r\"\\NA\"],\n",
    "    }\n",
    "\n",
    "    df = pd.DataFrame(data)\n",
    "    latex_table = df.to_latex(index=False)\n",
    "    print(latex_table)\n",
    "print(\"Speed up: \", speed_ups)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "sbi_mcmc_bf_forge",
   "language": "python",
   "name": "sbi_mcmc_bf_forge"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
