{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "os.environ[\"JAX_ENABLE_X64\"] = \"False\"\n",
    "os.environ[\"KERAS_BACKEND\"] = \"jax\"\n",
    "\n",
    "sys.path.append(\"../BayesFlow/\")\n",
    "sys.path.append(\"../\")\n",
    "\n",
    "import pytensor\n",
    "from jax import config\n",
    "\n",
    "pytensor.config.floatX = \"float32\"\n",
    "config.floatX = \"float32\"\n",
    "config.update(\"jax_enable_x64\", False)\n",
    "\n",
    "import warnings\n",
    "from pathlib import Path\n",
    "from pprint import pprint\n",
    "\n",
    "import bayesflow as bf\n",
    "import keras\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from sbi_mcmc.utils.experiment_utils import *\n",
    "from sbi_mcmc.utils.utils import *\n",
    "\n",
    "warnings.filterwarnings(\n",
    "    \"ignore\", message=\"The figure layout has changed to tight\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sbi_mcmc.tasks import (\n",
    "    BernoulliGLMTask,\n",
    "    CustomDDM,\n",
    "    GeneralizedExtremeValue,\n",
    "    PsychometricTask,\n",
    ")\n",
    "\n",
    "# task = GeneralizedExtremeValue()\n",
    "task = None  # If task is None, the task will be created in the get_stuff function according to the `config.yaml` file.\n",
    "stuff = get_stuff(task, job=\"training\")\n",
    "paths = stuff[\"paths\"]\n",
    "task = stuff[\"task\"]\n",
    "config = stuff[\"config\"]\n",
    "SMOKE_TEST = config.get(\"smoke_test\", False)\n",
    "print(task.var_names)\n",
    "\n",
    "result_logger = PickleStatLogger(\n",
    "    paths[\"training_result_dir\"] / \"training_record.pkl\", overwrite=False\n",
    ")\n",
    "print(\n",
    "    f\"save_dir: {paths['save_dir']},\\nresult_logger: {result_logger.filepath}\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "task_config = get_task_configs(task)\n",
    "dataset_size_dict = task_config[\"dataset_size_dict\"]\n",
    "epochs = task_config[\"epochs\"]\n",
    "batch_size = task_config[\"batch_size\"]\n",
    "pprint(task_config)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Generate or read prior simulations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import logging\n",
    "\n",
    "logging.getLogger(\"sbi_mcmc.tasks.tasks\").setLevel(logging.CRITICAL)\n",
    "logging.getLogger(\"sbi_mcmc.tasks.ddm\").setLevel(logging.CRITICAL)\n",
    "\n",
    "for name, logger in logging.root.manager.loggerDict.items():\n",
    "    if name.startswith(\"sbi_mcmc.tasks.\"):\n",
    "        logger.setLevel(logging.CRITICAL)\n",
    "REGENERATE = [\"train\", \"val\", \"diagnostic\", \"lc2st_cal\"]\n",
    "# REGENERATE = []\n",
    "if SMOKE_TEST:\n",
    "    assert len(REGENERATE) == 0, \"REGENERATE should be empty for smoke test. \"\n",
    "prior_simulations_N = {}\n",
    "for dataset_name in [\"train\", \"val\", \"diagnostic\", \"lc2st_cal\"]:\n",
    "    filepath = paths[\"dataset_dir\"] / f\"{dataset_name}_dataset.pkl\"\n",
    "    num_simulations = dataset_size_dict[dataset_name]\n",
    "    if dataset_name in REGENERATE:\n",
    "        with result_logger.timer(f\"simulation_{dataset_name}\"):\n",
    "            simulations = task.sample(num_simulations)\n",
    "        save_to_file(simulations, filepath)\n",
    "    else:\n",
    "        print(f\"Loading {dataset_name} dataset from {filepath}\")\n",
    "        simulations = read_from_file(filepath)\n",
    "    print(f\"{filepath.name}: {filepath.stat().st_size / 1e6:.2f} MB\")\n",
    "    prior_simulations_N[dataset_name] = simulations\n",
    "\n",
    "if not SMOKE_TEST:\n",
    "    check_dataset_size_consistency(prior_simulations_N, dataset_size_dict)\n",
    "print(\"=== \\nTrain dataset:\")\n",
    "for k, v in prior_simulations_N[\"train\"].items():\n",
    "    print(k, v.shape)\n",
    "check_prior_simulations(prior_simulations_N, task)\n",
    "\n",
    "prior_simulations = prior_simulations_N"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bf_workflow_kwargs, bf_info = get_bf_configs(task, smoke_test=SMOKE_TEST)\n",
    "amortized_training_workflow = bf.BasicWorkflow(\n",
    "    **bf_workflow_kwargs,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# # For resuming training from a checkpoints\n",
    "# amortized_training_workflow.approximator = keras.saving.load_model(\n",
    "#     bf_info[\"save_model_path\"]\n",
    "# )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with result_logger.timer(\"training\"):\n",
    "    history = amortized_training_workflow.fit_offline(\n",
    "        prior_simulations[\"train\"],\n",
    "        epochs=epochs,\n",
    "        batch_size=batch_size,\n",
    "        validation_data=prior_simulations[\"val\"],\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = bf.diagnostics.plots.loss(history)\n",
    "fig.savefig(paths[\"figure_dir\"] / \"loss.png\", dpi=300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_path = bf_info[\"save_model_path\"]\n",
    "if SMOKE_TEST:\n",
    "    model_path = Path(\n",
    "        *[p for p in bf_info[\"save_model_path\"].parts if p != \"smoke_test\"]\n",
    "    )\n",
    "approximator = keras.saving.load_model(model_path)\n",
    "print(f\"Approximator loaded from {model_path}, \")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with result_logger.timer(\"diagnostic_post_draws\"):\n",
    "    post_draws = approximator.sample(\n",
    "        conditions=prior_simulations[\"diagnostic\"], num_samples=1000\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with result_logger.timer(\"diagnostic_transform_prior_sims_params\"):\n",
    "    prior_simulations[\"diagnostic\"][\"parameters_original\"] = (\n",
    "        task.transform_to_constrained_space(\n",
    "            prior_simulations[\"diagnostic\"][\"parameters\"]\n",
    "        )\n",
    "    )\n",
    "with result_logger.timer(\"diagnostic_transform_post_draws_params\"):\n",
    "    abi_samples = post_draws[\"parameters\"]\n",
    "    abi_samples_constrained = np.zeros_like(abi_samples)\n",
    "    for i in range(abi_samples.shape[0]):\n",
    "        abi_samples_constrained[i] = task.transform_to_constrained_space(\n",
    "            abi_samples[i]\n",
    "        )\n",
    "    post_draws[\"parameters_original\"] = abi_samples_constrained"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from bayesflow.diagnostics import plots as bf_plots\n",
    "\n",
    "display_names = {\n",
    "    \"CustomDDM(dt-0.0001)\": (\n",
    "        \"$v_1$\",\n",
    "        \"$v_2$\",\n",
    "        \"$a_1$\",\n",
    "        \"$a_2$\",\n",
    "        r\"$\\tau_c$\",\n",
    "        r\"$\\tau_n$\",\n",
    "    ),\n",
    "    \"BernoulliGLM\": [rf\"$\\theta_{i + 1}$\" for i in range(10)],\n",
    "    \"psychometric_curve_overdispersion\": [\n",
    "        r\"$\\tilde{m}$\",\n",
    "        \"$w$\",\n",
    "        r\"$\\gamma$\",\n",
    "        r\"$\\lambda$\",\n",
    "        r\"$\\eta$\",\n",
    "    ],\n",
    "    \"GEV\": [r\"$\\mu$\", r\"$\\sigma$\", r\"$\\xi$\"],\n",
    "}\n",
    "\n",
    "plot_fns = {\n",
    "    \"recovery\": bf_plots.recovery,\n",
    "    \"calibration_ecdf\": bf_plots.calibration_ecdf,\n",
    "    \"z_score_contraction\": bf_plots.z_score_contraction,\n",
    "    \"calibration_histogram\": bf.diagnostics.plots.calibration_histogram,\n",
    "}\n",
    "kwargs = {\"calibration_ecdf_kwargs\": {\"difference\": True}}\n",
    "\n",
    "for constrained_space in [True, False]:\n",
    "    param_key = \"parameters\"\n",
    "    if constrained_space:\n",
    "        param_key += \"_original\"\n",
    "        figure_dir = paths[\"figure_dir\"] / \"constrained_space\"\n",
    "    else:\n",
    "        figure_dir = paths[\"figure_dir\"] / \"unconstrained_space\"\n",
    "    figure_dir.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "    figures = {}\n",
    "    for k, plot_fn in plot_fns.items():\n",
    "        figures[k] = plot_fn(\n",
    "            estimates=post_draws[param_key],\n",
    "            targets=prior_simulations[\"diagnostic\"][param_key],\n",
    "            variable_names=display_names.get(\n",
    "                task.name, task.var_info.var_names_flatten\n",
    "            ),\n",
    "            **kwargs.get(f\"{k}_kwargs\", {}),\n",
    "        )\n",
    "        filepath = figure_dir / f\"{k}({bf_info['flow_type']}).pdf\"\n",
    "        figures[k].savefig(filepath)\n",
    "        print(f\"Saved {filepath}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "abw_review",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
