{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "195563ea",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload imported modules (e.g. sbi_mcmc) automatically before executing code,\n",
    "# so library edits are picked up without restarting the kernel.\n",
    "# `%reload_ext` loads the extension on the first run and, unlike `%load_ext`,\n",
    "# does not print an \"already loaded\" notice when this cell is re-run.\n",
    "%reload_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "913f870b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "from tqdm import tqdm\n",
    "\n",
    "# os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"1\"\n",
    "\n",
    "os.environ[\"KERAS_BACKEND\"] = \"jax\"\n",
    "import keras\n",
    "import numpy as np\n",
    "from sbi_mcmc.tasks import *\n",
    "from sbi_mcmc.utils.bf_utils import bf_log_prob_posterior\n",
    "from sbi_mcmc.utils.experiment_utils import *\n",
    "from sbi_mcmc.utils.utils import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f34a5c6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Resolve everything the \"abi\" job needs in one call, then unpack the task,\n",
    "# paths, test data, timing logger and run config used by the cells below.\n",
    "stuff = get_stuff(job=\"abi\")\n",
    "task, paths, test_dataset, test_dataset_name, stats_logger, config = (\n",
    "    stuff[key]\n",
    "    for key in (\n",
    "        \"task\",\n",
    "        \"paths\",\n",
    "        \"test_dataset\",\n",
    "        \"test_dataset_name\",\n",
    "        \"stats_logger\",\n",
    "        \"config\",\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6dfbe24f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Print the checkpoint path *before* loading, so a failed or slow load still\n",
    "# shows which file was being read.\n",
    "print(paths[\"save_model_path\"])\n",
    "approximator = keras.saving.load_model(paths[\"save_model_path\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f1ab2952",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw `config.target_num_draws` samples from the trained approximator,\n",
    "# conditioned on the test dataset. Only the sampler call is timed.\n",
    "with stats_logger.timer(\"sampling\"):\n",
    "    abi_samples = approximator.sample(\n",
    "        conditions=test_dataset, num_samples=config.target_num_draws\n",
    "    )[\"parameters\"]\n",
    "\n",
    "# Sanity check kept outside the timer so it does not inflate the \"sampling\" time.\n",
    "assert not np.isnan(abi_samples).any(), \"NaN in abi_samples. Exiting.\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7752e7a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate log-densities of the ABI draws via bf_log_prob_posterior, given the\n",
    "# test observables; the whole evaluation is timed under \"log_prob\".\n",
    "with stats_logger.timer(\"log_prob\"):\n",
    "    abi_observables = test_dataset[\"observables\"]\n",
    "    log_pdfs_abi = bf_log_prob_posterior(\n",
    "        approximator,\n",
    "        abi_samples,\n",
    "        abi_observables,\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0ca0a46a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Bundle the ABI draws with their log-densities and write them to the\n",
    "# configured result path.\n",
    "result_record = {}\n",
    "result_record[\"abi_samples_batch\"] = abi_samples\n",
    "result_record[\"log_pdfs_abi_batch\"] = log_pdfs_abi\n",
    "save_to_file(result_record, paths[\"abi_result\"])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "abw_review",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
