{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13a616c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Enable IPython's autoreload extension. Mode 2 reloads all modules (except\n",
    "# those excluded via %aimport) before executing each cell, so edits to\n",
    "# imported project code are picked up without restarting the kernel.\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "54f50031",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import time\n",
    "\n",
    "from tqdm import tqdm\n",
    "\n",
    "# os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"1\"\n",
    "# os.environ[\"JAX_PLATFORMS\"] = \"cpu\"\n",
    "os.environ[\"KERAS_BACKEND\"] = \"jax\"\n",
    "import keras\n",
    "import numpy as np\n",
    "from omegaconf import OmegaConf\n",
    "from sbi_mcmc.tasks import *\n",
    "from sbi_mcmc.tasks.tasks_utils import get_task_logp_func\n",
    "from sbi_mcmc.utils.experiment_utils import *\n",
    "from sbi_mcmc.utils.psis_utils import _sir, sampling_importance_resampling\n",
    "from sbi_mcmc.utils.utils import *\n",
    "from tqdm.autonotebook import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e80281cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fetch everything the PSIS job needs in one call, then unpack the entries\n",
    "# into notebook-level names.\n",
    "stuff = get_stuff(job=\"psis\")\n",
    "(\n",
    "    task,\n",
    "    paths,\n",
    "    test_dataset,\n",
    "    test_dataset_name,\n",
    "    stats_logger,\n",
    "    config,\n",
    ") = (\n",
    "    stuff[key]\n",
    "    for key in (\n",
    "        \"task\",\n",
    "        \"paths\",\n",
    "        \"test_dataset\",\n",
    "        \"test_dataset_name\",\n",
    "        \"stats_logger\",\n",
    "        \"config\",\n",
    "    )\n",
    ")\n",
    "\n",
    "# Starts empty; observation ids are added to it further down the notebook.\n",
    "step_1_failed_inds = set()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "632c6094",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the stored ABI results; the PSIS loop reads `abi_samples_batch` and\n",
    "# `log_pdfs_abi_batch` from this object, indexed by observation id.\n",
    "abi_result_path = paths[\"abi_result\"]\n",
    "batch_results = read_from_file(abi_result_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "abcb873e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the OOD statistics and collect the observation ids that failed the\n",
    "# Mahalanobis check for the current test dataset.\n",
    "ood_stats = read_from_file(paths[\"ood_stats\"])\n",
    "ood_failed_inds = sorted(\n",
    "    ood_stats[f\"Mahalanobis_{test_dataset_name}\"][\"ood_failed_inds\"]\n",
    ")\n",
    "# Build a fresh set instead of updating in place with `|=`, so this cell can\n",
    "# be re-run even if `step_1_failed_inds` has since been rebound to a list\n",
    "# (in-place `|=` raises TypeError on a list).\n",
    "step_1_failed_inds = set(step_1_failed_inds) | set(ood_failed_inds)\n",
    "print(len(step_1_failed_inds))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "229b3d2a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Opt-in switch read from the config; off when the entry is missing.\n",
    "dynamic_logp = config.get(\"dynamic_logp\", False)\n",
    "if dynamic_logp:\n",
    "    print(\"Dynamic logp\")\n",
    "    # Build the non-static log-density function once, before the loop; the\n",
    "    # loop passes each observation's data to it at call time.\n",
    "    dynamic_pymc_model = task.setup_pymc_model()\n",
    "    lp_fn_dynamic = get_task_logp_func(\n",
    "        task, static=False, pymc_model=dynamic_pymc_model\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "46f80055",
   "metadata": {},
   "outputs": [],
   "source": [
    "check_strict = False  # If False, replace NaN log densities with -inf\n",
    "\n",
    "# Run PSIS for every observation collected in `step_1_failed_inds`. For each\n",
    "# one: evaluate the task log density on the stored ABI samples, run\n",
    "# sampling-importance-resampling, record the Pareto-k value with the stats\n",
    "# logger and save the full result record to disk.\n",
    "#\n",
    "# Iterate over a sorted copy instead of rebinding `step_1_failed_inds`, so the\n",
    "# name keeps its original type and earlier cells remain re-runnable.\n",
    "for observation_id in tqdm(sorted(step_1_failed_inds)):\n",
    "    tic = time.time()\n",
    "    with stats_logger.timer(observation_id):\n",
    "        # BernoulliGLM reads the \"observables_raw\" entry; every other task\n",
    "        # reads \"observables\".\n",
    "        if task.name == \"BernoulliGLM\":\n",
    "            observation = test_dataset[\"observables_raw\"][observation_id]\n",
    "        else:\n",
    "            observation = test_dataset[\"observables\"][observation_id]\n",
    "\n",
    "        # The static log-density function is rebuilt per observation; the\n",
    "        # dynamic one was built once in an earlier cell.\n",
    "        if not dynamic_logp:\n",
    "            lp_fn = get_task_logp_func(task, observation=observation)\n",
    "        result_record = {\"time\": {}}\n",
    "        abi_samples = batch_results[\"abi_samples_batch\"][observation_id]\n",
    "        log_pdfs_abi = batch_results[\"log_pdfs_abi_batch\"][observation_id]\n",
    "        assert abi_samples.shape[0] >= config[\"target_num_draws\"]\n",
    "        assert log_pdfs_abi.shape[0] >= config[\"target_num_draws\"]\n",
    "        assert not np.isnan(abi_samples).any(), \"NaN in abi_samples. Exiting.\"\n",
    "\n",
    "        with stats_logger.timer(f\"{observation_id}_logp-task\"):\n",
    "            if dynamic_logp:\n",
    "                log_pdfs_task = lp_fn_dynamic(\n",
    "                    abi_samples, task.observation_to_pymc_data(observation)\n",
    "                )\n",
    "            else:\n",
    "                log_pdfs_task = lp_fn(abi_samples)\n",
    "\n",
    "        # Compute the NaN mask once and count with the array's own `.sum()`\n",
    "        # rather than Python's built-in `sum`, which iterates element-wise.\n",
    "        nan_mask = np.isnan(log_pdfs_task)\n",
    "        if nan_mask.any():\n",
    "            if check_strict:\n",
    "                raise ValueError(\"NaN in log_pdfs_task. Exiting.\")\n",
    "            else:\n",
    "                num_nans = int(nan_mask.sum())\n",
    "                print(\n",
    "                    f\"observation_id: {observation_id}. {num_nans} NaNs in log_pdfs_task. Replacing with -inf.\"\n",
    "                )\n",
    "                log_pdfs_task = np.nan_to_num(log_pdfs_task, nan=-np.inf)\n",
    "        assert not np.isinf(log_pdfs_task).all(), (\n",
    "            \"log_pdfs_task are all inf. Exiting.\"\n",
    "        )\n",
    "        assert log_pdfs_abi.ndim == log_pdfs_task.ndim == 1\n",
    "        with stats_logger.timer(f\"{observation_id}_sir\"):\n",
    "            abi_psis_resamples, k_stat, pareto_log_weights = (\n",
    "                sampling_importance_resampling(\n",
    "                    log_pdfs_task,\n",
    "                    log_pdfs_abi,\n",
    "                    abi_samples,\n",
    "                    return_weights=True,\n",
    "                    num_samples=config[\"target_num_draws\"],\n",
    "                )\n",
    "            )\n",
    "        result_record[\"log_pdfs_task\"] = log_pdfs_task\n",
    "        result_record[\"log_pdfs_abi\"] = log_pdfs_abi\n",
    "        result_record[\"pareto_log_weights\"] = pareto_log_weights\n",
    "        result_record[\"abi_samples\"] = abi_samples\n",
    "        result_record[\"abi_psis_resamples\"] = abi_psis_resamples\n",
    "        result_record[\"pareto_k\"] = k_stat\n",
    "        # Wall-clock time from the top of this iteration up to this point.\n",
    "        result_record[\"time\"][\"time_psis(exclude_abi_logp)\"] = (\n",
    "            time.time() - tic\n",
    "        )\n",
    "\n",
    "    # Logging and saving happen outside the per-observation timer.\n",
    "    stats_logger.update(\"pareto_k\", {observation_id: k_stat})\n",
    "    result_save_path = paths[\"psis_result\"](observation_id)\n",
    "    save_to_file(result_record, result_save_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bbc24bca",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Bucket every logged observation by its Pareto-k value. Each bucket holds\n",
    "# [count, list of observation ids]. Note that the \"k>0.7\" bucket also\n",
    "# receives k == 0.7, because the comparison is `>=`.\n",
    "psis_counts = {\"0.5<=k<0.7\": [0, []], \"k<0.5\": [0, []], \"k>0.7\": [0, []]}\n",
    "for observation_id, pareto_k in stats_logger.data[\"pareto_k\"].items():\n",
    "    if pareto_k >= 0.7:\n",
    "        bucket = \"k>0.7\"\n",
    "    elif pareto_k >= 0.5:\n",
    "        bucket = \"0.5<=k<0.7\"\n",
    "    elif pareto_k < 0.5:\n",
    "        bucket = \"k<0.5\"\n",
    "    else:\n",
    "        # Only reachable when every comparison is False, e.g. a NaN value.\n",
    "        raise ValueError(\n",
    "            f\"Cannot classify pareto_k={pareto_k!r} for observation_id {observation_id}.\"\n",
    "        )\n",
    "    psis_counts[bucket][0] += 1\n",
    "    psis_counts[bucket][1].append(observation_id)\n",
    "\n",
    "# Ids in the top bucket are treated as PSIS failures; the other two buckets\n",
    "# are accepted.\n",
    "psis_failed_observation_ids = psis_counts[\"k>0.7\"][1]\n",
    "psis_accept_inds = psis_counts[\"k<0.5\"][1] + psis_counts[\"0.5<=k<0.7\"][1]\n",
    "print(len(psis_failed_observation_ids))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0cd79e99",
   "metadata": {},
   "outputs": [],
   "source": [
    "## Visualize the results\n",
    "# from sbi_mcmc.utils.plot_utils import corner_plot\n",
    "\n",
    "# # observation_id = sorted(step_1_accept_inds)[2]\n",
    "# # observation_id = sorted(step_1_failed_inds)[1]\n",
    "# # observation_id = ood_failed_inds[1]\n",
    "# observation_id = psis_failed_observation_ids[1]\n",
    "# print(f\"observation_id: {observation_id}\")\n",
    "# psis_results = read_from_file(paths[\"psis_result\"](observation_id))\n",
    "# print(psis_results[\"pareto_k\"])\n",
    "# abi_samples = psis_results[\"abi_samples\"]\n",
    "# abi_psis_resamples = psis_results[\"abi_psis_resamples\"]\n",
    "# transform = None\n",
    "# transform = task.transform_to_constrained_space\n",
    "# corner_plot(\n",
    "#     abi_samples,\n",
    "#     abi_psis_resamples,\n",
    "#     labels=[\"ABI\", \"ABI(PSIS)\"],\n",
    "#     transform=transform,\n",
    "#     var_names=task.var_info.var_names_flatten,\n",
    "# );"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4cddccb5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hand the Pareto-k buckets and the resulting reject/accept id lists to the\n",
    "# stats logger.\n",
    "psis_summary = {\n",
    "    \"psis_counts\": psis_counts,\n",
    "    \"reject_inds\": psis_failed_observation_ids,\n",
    "    \"accept_inds\": psis_accept_inds,\n",
    "}\n",
    "stats_logger.update(None, psis_summary)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa43c4b8",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "abw_review",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
