{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7fe13184",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4dc6f067",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "from tqdm import tqdm\n",
    "\n",
    "os.environ[\"KERAS_BACKEND\"] = \"jax\"\n",
    "import keras\n",
    "import numpy as np\n",
    "from sbi_mcmc.tasks import *\n",
    "from sbi_mcmc.utils.bf_utils import compute_summary_statistics\n",
    "from sbi_mcmc.utils.experiment_utils import *\n",
    "from sbi_mcmc.utils.utils import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3e044d92",
   "metadata": {},
   "outputs": [],
   "source": [
    "stuff = get_stuff(\n",
    "    job=\"ood\",\n",
    ")\n",
    "task = stuff[\"task\"]\n",
    "paths = stuff[\"paths\"]\n",
    "test_dataset = stuff[\"test_dataset\"]\n",
    "test_dataset_name = stuff[\"test_dataset_name\"]\n",
    "stats_logger = stuff[\"stats_logger\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5e085b74",
   "metadata": {},
   "source": [
    "Out of distribution test for the test datasets, against the training dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "221b1fea",
   "metadata": {},
   "outputs": [],
   "source": [
    "approximator = keras.saving.load_model(paths[\"save_model_path\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "81d5c3bc",
   "metadata": {},
   "outputs": [],
   "source": [
    "s_train_file = paths[\"inference_diagnostic_dir\"] / f\"train_summary_outputs.pkl\"\n",
    "s_test_file = (\n",
    "    paths[\"inference_diagnostic_dir\"]\n",
    "    / f\"{test_dataset_name}_summary_outputs.pkl\"\n",
    ")\n",
    "\n",
    "train_dataset = read_from_file(paths[\"dataset_dir\"] / f\"train_dataset.pkl\")\n",
    "# use at most 10000 simulations for train summary outputs\n",
    "train_dataset[\"observables\"] = train_dataset[\"observables\"][:10000]\n",
    "with stats_logger.timer(\"compute_train_summary_outputs\"):\n",
    "    s_train = compute_summary_statistics(approximator, train_dataset)\n",
    "save_to_file(s_train, s_train_file)\n",
    "with stats_logger.timer(f\"compute_test_summary_outputs_{test_dataset_name}\"):\n",
    "    s_test = compute_summary_statistics(approximator, test_dataset)\n",
    "save_to_file(s_test, s_test_file)\n",
    "print(s_train.shape, s_test.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0398b876",
   "metadata": {},
   "outputs": [],
   "source": [
    "import seaborn as sns\n",
    "from sklearn.covariance import EmpiricalCovariance\n",
    "\n",
    "summary_test = \"Mahalanobis\"\n",
    "with stats_logger.timer(f\"compute_{summary_test}_{test_dataset_name}\"):\n",
    "    if summary_test == \"Mahalanobis\":\n",
    "        cov = EmpiricalCovariance().fit(s_train)\n",
    "        mahalanobis_from_sims = cov.mahalanobis(s_train)\n",
    "        test_mahalanobis_from_sims = cov.mahalanobis(s_test)\n",
    "\n",
    "        train_t_statistics = mahalanobis_from_sims\n",
    "        test_t_statistics = test_mahalanobis_from_sims\n",
    "    else:\n",
    "        raise NotImplementedError"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "87f5f008",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "with stats_logger.timer(f\"thresholding_{test_dataset_name}\"):\n",
    "    quantile = 0.95\n",
    "    cutoff = np.quantile(train_t_statistics, quantile)\n",
    "\n",
    "    print(f\"Cutoff: {cutoff}\")\n",
    "\n",
    "    # How much percent of the test samples are above the cutoff?\n",
    "    print(\n",
    "        f\"Test samples above cutoff: {np.mean(test_t_statistics > cutoff) * 100:.2f}%\"\n",
    "    )\n",
    "    ood_failed_inds = list(np.where(test_t_statistics > cutoff)[0])\n",
    "    ood_accept_inds = list(np.where(test_t_statistics <= cutoff)[0])\n",
    "    print(f\"{len(ood_failed_inds)}/{len(test_t_statistics)}\")\n",
    "\n",
    "stats_logger.update(\n",
    "    f\"{summary_test}_{test_dataset_name}\",\n",
    "    {\n",
    "        \"cutoff\": cutoff,\n",
    "        \"quantile\": quantile,\n",
    "        \"ood_failed_inds\": ood_failed_inds,\n",
    "        \"ood_accept_inds\": ood_accept_inds,\n",
    "        \"train_t_statistics\": train_t_statistics,\n",
    "        \"test_t_statistics\": test_t_statistics,\n",
    "    },\n",
    ")\n",
    "plt.hist(train_t_statistics, bins=50, alpha=0.5, label=\"Train\", density=True)\n",
    "plt.hist(test_t_statistics, bins=50, alpha=0.5, label=\"Test\", density=True)\n",
    "\n",
    "plt.axvline(\n",
    "    cutoff,\n",
    "    color=\"red\",\n",
    "    label=f\"{quantile * 100:.0f}% quantile of {summary_test} distance in train set ($H_0$)\",\n",
    ")\n",
    "\n",
    "plt.legend()\n",
    "plt.xlabel(\"Test statistic\")\n",
    "\n",
    "# Save the figure\n",
    "figure_path = (\n",
    "    paths[\"inference_diagnostic_dir\"]\n",
    "    / \"ood\"\n",
    "    / f\"{summary_test}_distribution_{test_dataset_name}.png\"\n",
    ")\n",
    "figure_path.parent.mkdir(parents=True, exist_ok=True)\n",
    "plt.savefig(figure_path, dpi=300)\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "abw_review",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
