{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "abdb5474",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 84,
   "id": "2fa85038",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "from tqdm import tqdm\n",
    "\n",
    "os.environ[\"KERAS_BACKEND\"] = \"jax\"\n",
    "import keras\n",
    "import numpy as np\n",
    "from omegaconf import OmegaConf\n",
    "from sbi_mcmc.tasks import *\n",
    "from sbi_mcmc.tasks.tasks_utils import get_task_logp_func\n",
    "from sbi_mcmc.utils.bf_utils import bf_log_prob_posterior\n",
    "from sbi_mcmc.utils.experiment_utils import *\n",
    "from sbi_mcmc.utils.psis_utils import _sir, sampling_importance_resampling\n",
    "from sbi_mcmc.utils.tf_chees_hmc_utils import run_chees_hmc\n",
    "from sbi_mcmc.utils.utils import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "94ff1014",
   "metadata": {
    "tags": [
     "parameters"
    ]
   },
   "outputs": [],
   "source": [
    "# Run parameters. The cell carries the \"parameters\" tag, so these values are\n",
    "# presumably meant to be overridden by a parameterizing runner - TODO confirm.\n",
    "# Sampler to benchmark; later cells only handle \"ChEES-HMC\" and \"NUTS\".\n",
    "mcmc_method = \"ChEES-HMC\"\n",
    "# mcmc_method = \"NUTS\"\n",
    "# NOTE: task_name is assigned several times in a row; only the last\n",
    "# uncommented assignment (\"GEV\") takes effect.\n",
    "task_name = \"psychometric_curve_overdispersion\"\n",
    "task_name = \"CustomDDM(dt=0.0001)\"\n",
    "task_name = \"GEV\"\n",
    "# task_name = \"BernoulliGLM\"\n",
    "# When True, the cells guarded by `if not processed` (log-density setup and\n",
    "# the sampling loop) are skipped.\n",
    "processed = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7875d070",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fetch task, paths, test data and configuration for the selected task in one\n",
    "# call, then expose the individual entries as notebook-level names.\n",
    "stuff = get_stuff(\n",
    "    task_name=task_name,\n",
    "    overwrite_stats=False,\n",
    "    test_dataset_name=\"test_dataset_chunk_1\",\n",
    ")\n",
    "task, paths, config = stuff[\"task\"], stuff[\"paths\"], stuff[\"config\"]\n",
    "test_dataset, test_dataset_name = (\n",
    "    stuff[\"test_dataset\"],\n",
    "    stuff[\"test_dataset_name\"],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "20ca5fd1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sampling budget read from the task configuration.\n",
    "target_num_draws = config.target_num_draws\n",
    "init_step_size = config.chees_hmc.init_step_size\n",
    "print(f\"MCMC method: {mcmc_method}, Task: {task_name}\")\n",
    "\n",
    "# Only the two samplers below are supported.\n",
    "if mcmc_method not in (\"ChEES-HMC\", \"NUTS\"):\n",
    "    raise NotImplementedError\n",
    "if mcmc_method == \"NUTS\":\n",
    "    # NUTS runs K independent chains.\n",
    "    K = 4\n",
    "    num_chains = K\n",
    "else:\n",
    "    # ChEES-HMC runs K superchains with M subchains each.\n",
    "    K = config.chees_hmc.num_superchains\n",
    "    M = config.chees_hmc.num_subchains_per_superchain\n",
    "    num_chains = K * M\n",
    "\n",
    "# Draws per chain, rounded up so that the total is at least target_num_draws.\n",
    "draws_per_chain = np.ceil(target_num_draws / num_chains)\n",
    "num_sampling = int(draws_per_chain)\n",
    "print(f\"Number of sampling: {num_sampling}\")\n",
    "D = task.D\n",
    "print(f\"Dimension of the task: {D}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2946fc0d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The study runs on the observation ids stored under \"reject_inds\" in the\n",
    "# saved PSIS statistics.\n",
    "psis_stats = read_from_file(paths[\"psis_stats\"])\n",
    "test_observation_ids = psis_failed_observation_ids = psis_stats[\"reject_inds\"]\n",
    "print(len(psis_failed_observation_ids))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0b0eacfc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# With `dynamic_logp` enabled in the config, one log-density function that\n",
    "# takes the observation data as an argument is built here and reused for\n",
    "# every observation in the sampling loop.\n",
    "if not processed:\n",
    "    dynamic_logp = config.get(\"dynamic_logp\", False)\n",
    "if not processed and dynamic_logp:\n",
    "    lp_fn_dynamic = get_task_logp_func(\n",
    "        task,\n",
    "        pymc_model=task.setup_pymc_model(),\n",
    "        static=False,\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cd8ccd8a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import arviz as az\n",
    "import jax\n",
    "from sbi_mcmc.tasks.tasks import ndarray_values_as_dict\n",
    "\n",
    "# Experiment grid: the sampling loop stops after `num_runs` observations and\n",
    "# tries every warmup length with every initialization scheme.\n",
    "num_runs = 20\n",
    "num_warmup_values = [10, 50, 100, 200, 300, 500]\n",
    "init_options = [\"abi_psis\", \"abi\", \"stan-like\"]\n",
    "# Forwarded to the samplers' helpers and appended to the result file names.\n",
    "sort = False\n",
    "# Key that is split to draw the \"stan-like\" uniform initial positions.\n",
    "rng_key = jax.random.key(42)\n",
    "\n",
    "\n",
    "def get_save_path(observation_id, num_warmup, init_option, sort=False):\n",
    "    \"\"\"Return the pickle path for one (observation, warmup, init) run.\n",
    "\n",
    "    The path is built from the notebook globals `test_dataset_name`, `paths`\n",
    "    and `mcmc_method`. The parent directory is created if it is missing.\n",
    "    \"\"\"\n",
    "    suffix = \"_sorted\" if sort else \"\"\n",
    "    filename = (\n",
    "        f\"{test_dataset_name}_{observation_id}\"\n",
    "        f\"_num_warmup_{num_warmup}_init_option_{init_option}\"\n",
    "        f\"{suffix}.pkl\"\n",
    "    )\n",
    "    relative_path = f\"warmup_tests_{mcmc_method}/{filename}\"\n",
    "    target = paths[\"chees_hmc_result_dir\"] / relative_path\n",
    "    target.parent.mkdir(parents=True, exist_ok=True)\n",
    "    return target"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a2c6b380",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Debug switch. While True, the loop below only builds and records the initial\n",
    "# positions in `init_positions_record`: no sampler is run and no result file\n",
    "# is written.\n",
    "# NOTE(review): the cells further down read result files, so this presumably\n",
    "# has to be False for a fresh end-to-end run - confirm.\n",
    "check_init_positions_only = (\n",
    "    True  # Debug check: init positions versus reference posterior samples\n",
    ")\n",
    "\n",
    "init_positions_record = {}\n",
    "if not processed:\n",
    "    processed_inds = []\n",
    "    exceptions = []  # messages for skipped observations and failed runs\n",
    "    buffer = 5 * K  # number of candidate initial positions per run\n",
    "    for observation_id in test_observation_ids:\n",
    "        psis_results = read_from_file(paths[\"psis_result\"](observation_id))\n",
    "        pareto_log_weights = psis_results[\"pareto_log_weights\"]\n",
    "        abi_samples = psis_results[\"abi_samples\"]\n",
    "        if task.name == \"BernoulliGLM\":\n",
    "            observation = test_dataset[\"observables_raw\"][observation_id]\n",
    "        else:\n",
    "            observation = test_dataset[\"observables\"][observation_id]\n",
    "        if not dynamic_logp:\n",
    "            lp_fn = get_task_logp_func(task, observation=observation)\n",
    "        else:\n",
    "            observation_data = task.observation_to_pymc_data(observation)\n",
    "            # Bind the data as a default argument so the function keeps this\n",
    "            # observation after the loop variable has moved on.\n",
    "            lp_fn = lambda x, obs=observation_data: lp_fn_dynamic(x, obs)\n",
    "        try:\n",
    "            # Resample without replacement so the PSIS-based initial\n",
    "            # positions are distinct.\n",
    "            abi_psis_resamples_unique = _sir(\n",
    "                abi_samples,\n",
    "                log_weights=pareto_log_weights,\n",
    "                num_samples=buffer,\n",
    "                with_replacement=False,\n",
    "            )\n",
    "        except Exception as e:\n",
    "            # Skip the whole observation: the \"abi_psis\" option is unavailable.\n",
    "            exceptions.append(\n",
    "                f\"id: {observation_id}: can't use abi_psis\" + str(e)\n",
    "            )\n",
    "            continue\n",
    "\n",
    "        for num_warmup in num_warmup_values:\n",
    "            for init_option in init_options:\n",
    "                print(observation_id, num_warmup, init_option)\n",
    "                if init_option == \"abi\":\n",
    "                    initial_positions = abi_samples[:buffer]\n",
    "                elif init_option == \"abi_psis\":\n",
    "                    initial_positions = abi_psis_resamples_unique[:buffer]\n",
    "                elif init_option == \"stan-like\":\n",
    "                    # Uniform draws on [-2, 2] in the sampler's parameter space.\n",
    "                    rng_key, init_key = jax.random.split(rng_key)\n",
    "                    initial_positions = jax.random.uniform(\n",
    "                        init_key, (buffer, D), minval=-2, maxval=2\n",
    "                    )\n",
    "\n",
    "                    if \"DDM\" in task.name:\n",
    "                        print(\"Make initial values for ndt meaningful\")\n",
    "                        # Copy into a writable NumPy array before editing in\n",
    "                        # place: the draws come from jax.random, and JAX arrays\n",
    "                        # do not support item assignment.\n",
    "                        initial_positions_constrained = np.array(\n",
    "                            task.transform_to_constrained_space(\n",
    "                                initial_positions\n",
    "                            )\n",
    "                        )\n",
    "                        # Set the last two constrained parameters to\n",
    "                        # min_rts / 2. np.ones (not np.ones_like) is needed to\n",
    "                        # build an array of the given shape.\n",
    "                        initial_positions_constrained[:, -2:] = (\n",
    "                            np.ones(\n",
    "                                (initial_positions_constrained.shape[0], 2)\n",
    "                            )\n",
    "                            * task.min_rts\n",
    "                            / 2\n",
    "                        )\n",
    "                        initial_positions = (\n",
    "                            task.transform_to_unconstrained_space(\n",
    "                                initial_positions_constrained\n",
    "                            )\n",
    "                        )\n",
    "                else:\n",
    "                    raise ValueError(f\"Unknown init_option: {init_option}\")\n",
    "\n",
    "                # Keep the positions for inspection; a later warmup value\n",
    "                # overwrites the entry stored for the same init_option.\n",
    "                init_positions_record.setdefault(observation_id, {})[\n",
    "                    init_option\n",
    "                ] = initial_positions\n",
    "                if check_init_positions_only:\n",
    "                    continue\n",
    "                try:\n",
    "                    if mcmc_method == \"ChEES-HMC\":\n",
    "                        result = run_chees_hmc(\n",
    "                            initial_positions,\n",
    "                            K,\n",
    "                            M,\n",
    "                            lp_fn,\n",
    "                            num_warmup,\n",
    "                            num_sampling,\n",
    "                            D,\n",
    "                            init_step_size=init_step_size,\n",
    "                            sort=sort,\n",
    "                        )\n",
    "                    elif mcmc_method == \"NUTS\":\n",
    "                        from sbi_mcmc.utils.tf_chees_hmc_utils import (\n",
    "                            filter_invalid_init_positions,\n",
    "                        )\n",
    "\n",
    "                        initial_positions, _ = filter_invalid_init_positions(\n",
    "                            initial_positions, lp_fn, sort=sort\n",
    "                        )\n",
    "                        if len(initial_positions) < num_chains:\n",
    "                            raise ValueError(\n",
    "                                \"Not enough valid and unique initial positions\"\n",
    "                            )\n",
    "                        initial_positions = initial_positions[:num_chains]\n",
    "\n",
    "                        initial_positions = (\n",
    "                            task.transform_to_constrained_space(\n",
    "                                initial_positions\n",
    "                            )\n",
    "                        )\n",
    "                        # One initial-value dict per chain.\n",
    "                        initvals = [\n",
    "                            ndarray_values_as_dict(\n",
    "                                initial_positions[i : i + 1], task.var_dims\n",
    "                            )\n",
    "                            for i in range(len(initial_positions))\n",
    "                        ]\n",
    "                        sampler_kwargs = {\n",
    "                            \"nuts_sampler\": \"numpyro\",\n",
    "                            \"nuts_sampler_kwargs\": {\"jitter\": False},\n",
    "                            \"target_accept\": 0.99,\n",
    "                        }\n",
    "                        pymc_model = task.setup_pymc_model(\n",
    "                            observation=observation\n",
    "                        )\n",
    "                        # NOTE(review): `pm` is not imported explicitly in this\n",
    "                        # notebook; presumably it comes from one of the star\n",
    "                        # imports - confirm.\n",
    "                        idata_post = pm.sample(\n",
    "                            tune=num_warmup,\n",
    "                            draws=num_sampling,\n",
    "                            chains=num_chains,\n",
    "                            model=pymc_model,\n",
    "                            initvals=initvals,\n",
    "                            progressbar=False,\n",
    "                            **sampler_kwargs,\n",
    "                        )\n",
    "                        rhat_az = az.rhat(\n",
    "                            idata_post, var_names=list(task.var_names)\n",
    "                        )\n",
    "                        _, rhat = az.sel_utils.xarray_to_ndarray(\n",
    "                            rhat_az, var_names=list(task.var_names)\n",
    "                        )\n",
    "                        # Stored under the same key that the analysis cells\n",
    "                        # read from saved results.\n",
    "                        result = {\"n_rhat\": rhat.squeeze()}\n",
    "                    else:\n",
    "                        continue\n",
    "                except ValueError as e:\n",
    "                    # Record the failure and move on to the next configuration.\n",
    "                    exceptions.append(\n",
    "                        f\"id: {observation_id}_{num_warmup}_{init_option}: \"\n",
    "                        + str(e)\n",
    "                    )\n",
    "                    continue\n",
    "                result_save_path = get_save_path(\n",
    "                    observation_id, num_warmup, init_option, sort=sort\n",
    "                )\n",
    "                # print(result[\"n_rhat\"])\n",
    "                save_to_file(result, result_save_path)\n",
    "        processed_inds.append(observation_id)\n",
    "        print(\"total processed: \", len(processed_inds))\n",
    "        # Stop once enough observations have been handled.\n",
    "        if len(processed_inds) >= num_runs:\n",
    "            break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4aaf9f1e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep an observation only when a result file exists for every\n",
    "# (init_option, num_warmup) combination.\n",
    "valid_inds = []\n",
    "for observation_id in test_observation_ids:\n",
    "    result_files_exist = [\n",
    "        get_save_path(\n",
    "            observation_id, num_warmup, init_option, sort=sort\n",
    "        ).exists()\n",
    "        for init_option in [\"abi_psis\", \"stan-like\", \"abi\"]\n",
    "        for num_warmup in num_warmup_values\n",
    "    ]\n",
    "    if all(result_files_exist):\n",
    "        valid_inds.append(observation_id)\n",
    "print(len(valid_inds))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0ea3ab5a",
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import OrderedDict\n",
    "\n",
    "from sbi_mcmc.metrics import gskl, mtv, wasserstein_distance\n",
    "\n",
    "\n",
    "def _max_rhat_excess(observation_id, num_warmup, init_option):\n",
    "    \"\"\"Load one saved run and return its largest \"n_rhat\" entry minus one.\"\"\"\n",
    "    saved = read_from_file(\n",
    "        get_save_path(observation_id, num_warmup, init_option, sort=sort)\n",
    "    )\n",
    "    return np.mean(saved[\"n_rhat\"].max() - 1)\n",
    "\n",
    "\n",
    "# n_rhats[init_option][num_warmup] -> one value per valid observation.\n",
    "n_rhats = {\n",
    "    init_option: {\n",
    "        num_warmup: [\n",
    "            _max_rhat_excess(observation_id, num_warmup, init_option)\n",
    "            for observation_id in valid_inds\n",
    "        ]\n",
    "        for num_warmup in num_warmup_values\n",
    "    }\n",
    "    for init_option in [\"abi_psis\", \"stan-like\", \"abi\"]\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "293aee4b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Plot, per initialization scheme, the median of the collected values against\n",
    "# the number of warmup iterations, with a percentile band as error bars.\n",
    "data = n_rhats\n",
    "fig = plt.figure(figsize=(5, 3))\n",
    "\n",
    "# serif font\n",
    "plt.rcParams[\"font.family\"] = \"serif\"\n",
    "\n",
    "# math serif\n",
    "plt.rcParams[\"mathtext.fontset\"] = \"dejavuserif\"\n",
    "\n",
    "# larger font\n",
    "plt.rcParams.update({\"font.size\": 12})\n",
    "# colors = plt.rcParams[\"axes.prop_cycle\"].by_key()[\"color\"]\n",
    "colors = {\n",
    "    \"abi_psis\": \"#DDAA33\",\n",
    "    \"stan-like\": \"#BB5566\",\n",
    "    \"abi\": \"#004488\",\n",
    "}\n",
    "labels = {\n",
    "    \"abi_psis\": \"Amortized + PSIS\",\n",
    "    \"stan-like\": \"Random Initialization\",\n",
    "    \"abi\": \"Amortized\",\n",
    "}\n",
    "\n",
    "markers = {\n",
    "    \"abi_psis\": \"D\",\n",
    "    \"stan-like\": \"s\",\n",
    "    \"abi\": \"^\",\n",
    "}\n",
    "# Choose the percentile band drawn around the median.\n",
    "interquartile = True\n",
    "# interquartile = False\n",
    "if interquartile:\n",
    "    title = \"median, interquartile\"\n",
    "    lower = 25\n",
    "    upper = 75\n",
    "else:\n",
    "    title = \"median, [Min,Max]\"\n",
    "    lower = 0\n",
    "    upper = 100\n",
    "# Loop through the outer keys in the dictionary\n",
    "for idx, key in enumerate(data):\n",
    "    x_values = sorted(data[key].keys())\n",
    "    # Shift each series sideways so the error bars do not overlap.\n",
    "    x_offset = -4 + 4 * idx\n",
    "\n",
    "    # One row per warmup value, one column per observation.\n",
    "    rhat = np.array([data[key][num_warmup] for num_warmup in x_values])\n",
    "    x = np.array(x_values) + x_offset\n",
    "\n",
    "    center = np.median(rhat, axis=1)\n",
    "    # errorbar reads a two-row yerr as distances below and above the plotted\n",
    "    # value, so convert the percentiles into offsets from the median. The\n",
    "    # clamp guards against tiny negative values from floating-point rounding.\n",
    "    yerr = [\n",
    "        np.maximum(center - np.percentile(rhat, lower, axis=1), 0),\n",
    "        np.maximum(np.percentile(rhat, upper, axis=1) - center, 0),\n",
    "    ]\n",
    "\n",
    "    plt.errorbar(\n",
    "        x,\n",
    "        center,\n",
    "        yerr=yerr,\n",
    "        fmt=markers[key],\n",
    "        lw=1.5,\n",
    "        color=colors[key],\n",
    "        capsize=4,\n",
    "        linestyle=\"-\",\n",
    "        label=labels[key],\n",
    "        # marker edges black\n",
    "        markeredgecolor=\"black\",\n",
    "    )\n",
    "\n",
    "\n",
    "plt.xlabel(\"Number of warmup iterations\", fontsize=\"x-large\")\n",
    "plt.xticks(x_values)\n",
    "if mcmc_method == \"NUTS\":\n",
    "    ylabel = r\"$\\widehat{R} - 1$\"\n",
    "elif mcmc_method == \"ChEES-HMC\":\n",
    "    ylabel = r\"Nested $\\widehat{R} - 1$\"\n",
    "else:\n",
    "    raise ValueError\n",
    "plt.ylabel(ylabel, fontsize=\"x-large\")\n",
    "plt.yscale(\"log\")\n",
    "plt.grid(True, alpha=0.5)\n",
    "\n",
    "\n",
    "# get handles and labels (separate names so the `labels` dict above survives)\n",
    "legend_handles, legend_labels = plt.gca().get_legend_handles_labels()\n",
    "# keep only the first element of each errorbar container for the legend\n",
    "legend_handles = [h[0] for h in legend_handles]\n",
    "\n",
    "# specify order of items in legend (alphabetical by label)\n",
    "order = sorted(range(len(legend_labels)), key=lambda i: legend_labels[i])\n",
    "plt.legend(\n",
    "    [legend_handles[idx] for idx in order],\n",
    "    [legend_labels[idx] for idx in order],\n",
    "    handlelength=1,\n",
    "    borderpad=0.2,\n",
    "    fontsize=9,\n",
    "    handletextpad=0.3,\n",
    ")\n",
    "\n",
    "plt.gca().spines[\"right\"].set_visible(False)\n",
    "plt.gca().spines[\"top\"].set_visible(False)\n",
    "plt.title(title)\n",
    "fig_path = (\n",
    "    paths[\"chees_hmc_result_dir\"]\n",
    "    / f\"figures/amortized_inits_mcmc_{mcmc_method}_{'_sorted' if sort else ''}.pdf\"\n",
    ")\n",
    "fig_path.parent.mkdir(parents=True, exist_ok=True)\n",
    "fig.savefig(fig_path, dpi=300, bbox_inches=\"tight\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a8013e80",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "sbi_mcmc_bf_forge",
   "language": "python",
   "name": "sbi_mcmc_bf_forge"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
