{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "abdb5474",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2fa85038",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import time\n",
    "\n",
    "from tqdm import tqdm\n",
    "\n",
    "# os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"1\"\n",
    "os.environ[\"KERAS_BACKEND\"] = \"jax\"\n",
    "import keras\n",
    "import numpy as np\n",
    "from omegaconf import OmegaConf\n",
    "from sbi_mcmc.tasks import *\n",
    "from sbi_mcmc.tasks.tasks_utils import get_task_logp_func\n",
    "from sbi_mcmc.utils.bf_utils import bf_log_prob_posterior\n",
    "from sbi_mcmc.utils.experiment_utils import *\n",
    "from sbi_mcmc.utils.psis_utils import _sir, sampling_importance_resampling\n",
    "from sbi_mcmc.utils.tf_chees_hmc_utils import run_chees_hmc\n",
    "from sbi_mcmc.utils.utils import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7875d070",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fetch the shared experiment context for the `chees_hmc` job and bind the\n",
    "# entries used by the rest of the notebook to top-level names.\n",
    "stuff = get_stuff(job=\"chees_hmc\")\n",
    "task, paths = stuff[\"task\"], stuff[\"paths\"]\n",
    "test_dataset, test_dataset_name = (\n",
    "    stuff[\"test_dataset\"],\n",
    "    stuff[\"test_dataset_name\"],\n",
    ")\n",
    "stats_logger, config = stuff[\"stats_logger\"], stuff[\"config\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "20ca5fd1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sampler settings read from the experiment config.\n",
    "target_num_draws = config.target_num_draws\n",
    "chees_cfg = config.chees_hmc\n",
    "K, M = chees_cfg.num_superchains, chees_cfg.num_subchains_per_superchain\n",
    "init_step_size = chees_cfg.init_step_size\n",
    "num_warmup = chees_cfg.num_warmup\n",
    "\n",
    "# Total chain count is superchains times subchains per superchain; each\n",
    "# chain draws enough samples for all chains together to reach at least\n",
    "# `target_num_draws`.\n",
    "num_chains = K * M\n",
    "draws_per_chain = np.ceil(target_num_draws / num_chains)\n",
    "num_sampling = int(draws_per_chain)\n",
    "print(f\"Number of sampling: {num_sampling}\")\n",
    "\n",
    "D = task.D\n",
    "print(f\"Dimension of the task: {D}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2946fc0d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Maps observation id -> error message for observations whose ChEES-HMC\n",
    "# run raises; filled in by the sampling loop.\n",
    "chees_abi_psis_failed = {}\n",
    "\n",
    "# The observations to re-sample are the ids stored under `reject_inds` in\n",
    "# the saved PSIS statistics.\n",
    "psis_stats = read_from_file(paths[\"psis_stats\"])\n",
    "psis_failed_observation_ids = psis_stats[\"reject_inds\"]\n",
    "num_psis_failed = len(psis_failed_observation_ids)\n",
    "print(num_psis_failed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0b0eacfc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional mode (off unless the config sets `dynamic_logp`): build a single\n",
    "# log-density function up front; the sampling loop then calls it as\n",
    "# `lp_fn_dynamic(x, obs)` with each observation's data.\n",
    "dynamic_logp = config.get(\"dynamic_logp\", False)\n",
    "if dynamic_logp:\n",
    "    pymc_model = task.setup_pymc_model()\n",
    "    lp_fn_dynamic = get_task_logp_func(\n",
    "        task, static=False, pymc_model=pymc_model\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7ce6e5b5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run ChEES-HMC for every observation listed in `psis_failed_observation_ids`,\n",
    "# starting the chains from resamples of the stored PSIS results. An error for\n",
    "# one observation is printed and recorded instead of stopping the sweep.\n",
    "for observation_id in tqdm(psis_failed_observation_ids):\n",
    "    try:\n",
    "        with stats_logger.timer(observation_id):\n",
    "            tic = time.time()\n",
    "            buffer = 5 * K\n",
    "            result_record = {\"time\": {}}\n",
    "\n",
    "            # BernoulliGLM observations are read from `observables_raw`;\n",
    "            # every other task uses `observables`.\n",
    "            observables_key = (\n",
    "                \"observables_raw\"\n",
    "                if task.name == \"BernoulliGLM\"\n",
    "                else \"observables\"\n",
    "            )\n",
    "            observation = test_dataset[observables_key][observation_id]\n",
    "\n",
    "            if dynamic_logp:\n",
    "                observation_data = task.observation_to_pymc_data(observation)\n",
    "\n",
    "                # The data is bound as a default argument so the function\n",
    "                # keeps this iteration's observation.\n",
    "                def lp_fn(x, obs=observation_data):\n",
    "                    return lp_fn_dynamic(x, obs)\n",
    "\n",
    "            else:\n",
    "                lp_fn = get_task_logp_func(task, observation=observation)\n",
    "\n",
    "            # Draw `buffer` resamples without replacement, weighted by the\n",
    "            # stored Pareto log-weights, to use as initial positions.\n",
    "            psis_results = read_from_file(paths[\"psis_result\"](observation_id))\n",
    "            abi_psis_resamples_unique = _sir(\n",
    "                psis_results[\"abi_samples\"],\n",
    "                log_weights=psis_results[\"pareto_log_weights\"],\n",
    "                num_samples=buffer,\n",
    "                with_replacement=False,\n",
    "            )\n",
    "            initial_positions = abi_psis_resamples_unique[:buffer]\n",
    "\n",
    "            result = run_chees_hmc(\n",
    "                initial_positions,\n",
    "                K,\n",
    "                M,\n",
    "                lp_fn,\n",
    "                num_warmup,\n",
    "                num_sampling,\n",
    "                D,\n",
    "                init_step_size=init_step_size,\n",
    "            )\n",
    "        result_record.update(result)\n",
    "        result_record[\"init_option\"] = \"abi_psis\"\n",
    "        result_record[\"time\"][\"time_chees_hmc\"] = time.time() - tic\n",
    "        save_to_file(result_record, paths[\"chees_hmc_result\"](observation_id))\n",
    "    except Exception as e:\n",
    "        error_message = str(e)\n",
    "        print(error_message)\n",
    "        chees_abi_psis_failed[observation_id] = error_message"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "81a517a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hand the per-observation error messages collected by the sampling loop\n",
    "# to the stats logger.\n",
    "stats_logger.update(\n",
    "    \"chees_abi_psis_failed\",\n",
    "    chees_abi_psis_failed,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a6671555",
   "metadata": {},
   "source": [
    "Check the nested $\\hat{R}$ values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f769aa57",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect the nested R-hat of every re-sampled observation, both as a list\n",
    "# (stacked into one array at the end) and keyed by observation id.\n",
    "# Observations with no saved ChEES-HMC result get an all-inf placeholder.\n",
    "n_rhats = []\n",
    "n_rhats_dict = {}\n",
    "for observation_id in psis_failed_observation_ids:\n",
    "    try:\n",
    "        result = read_from_file(paths[\"chees_hmc_result\"](observation_id))\n",
    "        n_rhat = result[\"n_rhat\"]\n",
    "    except FileNotFoundError:\n",
    "        print(\n",
    "            f\"ChEES-HMC failed to process id {observation_id} for test dataset `{test_dataset_name}`\"\n",
    "        )\n",
    "        # `np.Inf` was removed in NumPy 2.0; `np.inf` works on every version.\n",
    "        n_rhat = np.full(task.D, np.inf)\n",
    "    n_rhats.append(n_rhat)\n",
    "    n_rhats_dict[observation_id] = n_rhat\n",
    "# `np.stack` raises ValueError on an empty list, which happens when there\n",
    "# are no observations to process; fall back to an empty array whose\n",
    "# trailing shape matches the placeholder.\n",
    "n_rhats = (\n",
    "    np.stack(n_rhats)\n",
    "    if n_rhats\n",
    "    else np.empty((0, *np.full(task.D, np.inf).shape))\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fc1e82ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Split the observations by their worst nested R-hat: accepted when it is\n",
    "# strictly below the threshold, rejected otherwise.\n",
    "n_rhat_threshold = 1.01\n",
    "chees_hmc_reject_inds = []\n",
    "chees_hmc_accept_inds = []\n",
    "for i, value in n_rhats_dict.items():\n",
    "    # Written as a positive `<` test on purpose: a NaN maximum compares\n",
    "    # False to everything, so an undefined diagnostic lands in the reject\n",
    "    # list instead of being silently accepted.\n",
    "    if value.max() < n_rhat_threshold:\n",
    "        chees_hmc_accept_inds.append(i)\n",
    "    else:\n",
    "        chees_hmc_reject_inds.append(i)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d28fa87",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Summary counts: rejected / total, then the remainder / total.\n",
    "num_psis_failed = len(psis_failed_observation_ids)\n",
    "num_chees_rejected = len(chees_hmc_reject_inds)\n",
    "print(f\"{num_chees_rejected}/{num_psis_failed}\")\n",
    "print(f\"{num_psis_failed - num_chees_rejected}/{num_psis_failed}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "285a2118",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hand the reject list and then the accept list to the stats logger.\n",
    "for stat_name, observation_ids in (\n",
    "    (\"chees_hmc_reject_inds\", chees_hmc_reject_inds),\n",
    "    (\"chees_hmc_accept_inds\", chees_hmc_accept_inds),\n",
    "):\n",
    "    stats_logger.update(stat_name, observation_ids)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "abw_review",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
