{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/marvin/miniforge3/envs/cmpe/lib/python3.10/site-packages/bayesflow/trainers.py:27: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from tqdm.autonotebook import tqdm\n"
     ]
    }
   ],
   "source": [
    "# Standard library\n",
    "import logging\n",
    "import pickle\n",
    "import sys\n",
    "\n",
    "# Third-party\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "from cmdstanpy import CmdStanModel\n",
    "from tqdm import tqdm\n",
    "\n",
    "# Local modules; sys.path is extended afterwards so that packages under ../../ can be imported\n",
    "from train import get_setup\n",
    "\n",
    "sys.path.append(\"../../\")\n",
    "from inverse_kinematics import InverseKinematicsModel\n",
    "from reference_posteriors.gmm_bimodal import GMM, GMMSimulator\n",
    "from reference_posteriors.two_moons.two_moons_lueckmann_numpy import analytic_posterior_numpy\n",
    "\n",
    "# Mute cmdstanpy: drop records below CRITICAL, keep them away from the root logger, swallow the rest\n",
    "logger = logging.getLogger(\"cmdstanpy\")\n",
    "logger.setLevel(logging.CRITICAL)\n",
    "logger.propagate = False\n",
    "logger.addHandler(logging.NullHandler())\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Posterior draws requested from every estimator; also sizes the sampled reference sets\n",
    "num_posterior_samples = 1_000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "tasks = ['gmm', 'twomoons', 'invkinematics']\n",
    "task_names = {'gmm': 'GMM', 'twomoons': 'Two Moons', 'invkinematics': 'Kinematics'}\n",
    "simulation_budgets = [1024]\n",
    "estimators = ['ac', 'nsf', 'cmpe', 'fmpe']\n",
    "\n",
    "colors = ['#440154', '#3b528b', '#21908dff', '#5dc962ff', '#fde725ff', '#fde725ff', '#fde725ff']\n",
    "\n",
    "gmm_idx = 1\n",
    "invkinematics_idx = 0\n",
    "\n",
    "tf.random.set_seed(1234)\n",
    "\n",
    "# Observed GMM data set, simulated at a fixed ground-truth parameter\n",
    "gmm_theta = np.array([-1.6, -1.0])\n",
    "gmm_y = GMMSimulator(GMM)(tf.convert_to_tensor([gmm_theta]))[0].numpy().astype(np.float32)\n",
    "n_obs, data_dim = gmm_y.shape\n",
    "\n",
    "# Reference GMM posterior from Stan. A single chain started at gmm_theta is mirrored (theta -> -theta)\n",
    "# to populate the second mode; this assumes the posterior is symmetric and that the chain stays in\n",
    "# one mode (TODO confirm). Draw ceil(n / 2) per mode and trim, so that exactly num_posterior_samples\n",
    "# reference draws remain even when num_posterior_samples is odd.\n",
    "iter_warmup = 2000\n",
    "iter_sampling = (num_posterior_samples + 1) // 2\n",
    "\n",
    "stan_data = {\"n_obs\": n_obs, \"data_dim\": data_dim, \"x\": gmm_y}\n",
    "model = CmdStanModel(stan_file=\"../../reference_posteriors/gmm_bimodal/gmm.stan\")\n",
    "fit = model.sample(\n",
    "    data=stan_data,\n",
    "    iter_warmup=iter_warmup,\n",
    "    iter_sampling=iter_sampling,\n",
    "    chains=1,\n",
    "    inits={\"theta\": gmm_theta.tolist()},\n",
    "    show_progress=False,\n",
    ")\n",
    "posterior_samples_chain = fit.stan_variable(\"theta\")\n",
    "gmm_reference_samples = np.concatenate(\n",
    "    [posterior_samples_chain, -1.0 * posterior_samples_chain], axis=0\n",
    ")[:num_posterior_samples]\n",
    "\n",
    "test_instances = {\n",
    "    'gmm': {'summary_conditions': gmm_y[np.newaxis, ...]},\n",
    "    'twomoons': {'direct_conditions': np.array([[0, 0]]).astype(np.float32)},\n",
    "    'invkinematics': {'direct_conditions': np.array([[0, 1.5]]).astype(np.float32)},\n",
    "}\n",
    "\n",
    "reference_posteriors = {\n",
    "    'gmm': gmm_reference_samples,\n",
    "    'twomoons': analytic_posterior_numpy(test_instances['twomoons']['direct_conditions'][0], num_posterior_samples, rng=np.random.default_rng(seed=1234)),\n",
    "    # Not a sample set: the reversed observation, later handed to the kinematics plot as its reference\n",
    "    'invkinematics': test_instances['invkinematics']['direct_conditions'][0][::-1],\n",
    "}\n",
    "\n",
    "# ABC samples shown in the reference column of the kinematics task.\n",
    "# NOTE: pickle can execute arbitrary code on load; only open trusted, locally generated files.\n",
    "with open('./data/invkinematics_showcase_abc.pkl', 'rb') as abc_file:\n",
    "    inverse_kinematics_abc = pickle.load(abc_file)[:num_posterior_samples]\n",
    "\n",
    "plot_settings = {\n",
    "    'ac': {'name': 'ACF', 'color': colors[0]},\n",
    "    'nsf': {'name': 'NSF', 'color': colors[1]},\n",
    "    'fmpe10': {'name': 'FMPE 10#', 'color': colors[3]},\n",
    "    'fmpe30': {'name': 'FMPE 30#', 'color': colors[4]},\n",
    "    'fmpe': {'name': 'FMPE 1000#', 'color': colors[2]},\n",
    "    'cmpe10': {'name': 'CMPE 10#', 'color': colors[5]},\n",
    "    'cmpe30': {'name': 'CMPE 30#', 'color': colors[6]},\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "def sample_timed(trainer, num_runs=3, **kwargs):\n",
    "    \"\"\"Call `trainer.amortizer.sample(**kwargs)` `num_runs` times and keep the fastest run.\n",
    "\n",
    "    Taking the minimum over repetitions reduces noise from one-off overheads in the timing.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    trainer : object whose `amortizer` provides `sample(**kwargs)`\n",
    "    num_runs : int, default 3\n",
    "        Number of repetitions; must be at least 1.\n",
    "    **kwargs\n",
    "        Forwarded unchanged to `trainer.amortizer.sample`.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    tuple\n",
    "        `(samples, seconds)`: the samples of the fastest run and its wall-clock duration in seconds.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If `num_runs` is smaller than 1 (there would be no run whose samples could be returned).\n",
    "    \"\"\"\n",
    "    if num_runs < 1:\n",
    "        raise ValueError(f'num_runs must be at least 1, got {num_runs}')\n",
    "\n",
    "    best_samples, best_time = None, float('inf')\n",
    "    for _ in range(num_runs):\n",
    "        # perf_counter is monotonic and high-resolution; time.time() can jump with clock adjustments\n",
    "        tic = time.perf_counter()\n",
    "        samples = trainer.amortizer.sample(**kwargs)\n",
    "        elapsed = time.perf_counter() - tic\n",
    "        if elapsed < best_time:\n",
    "            best_samples, best_time = samples, elapsed\n",
    "\n",
    "    return best_samples, best_time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate every (task, budget, estimator) combination on the single test instance of each task.\n",
    "# Sampling variants per estimator: result key -> extra keyword arguments for amortizer.sample.\n",
    "# Estimators without an entry are sampled once, with default settings, under their own name.\n",
    "# For 'cmpe' only the 10/30-step variants are stored; eval_dict[task][budget]['cmpe'] stays empty.\n",
    "sampling_variants = {\n",
    "    'cmpe': {\n",
    "        'cmpe10': {'n_steps': 10},\n",
    "        'cmpe30': {'n_steps': 30},\n",
    "    },\n",
    "    'fmpe': {\n",
    "        'fmpe': {},\n",
    "        'fmpe10': {'step_size': 1.0 / 10.0},\n",
    "        'fmpe30': {'step_size': 1.0 / 30.0},\n",
    "    },\n",
    "}\n",
    "\n",
    "total = len(tasks) * len(simulation_budgets) * len(estimators)\n",
    "run_idx = 0\n",
    "num_runs_timed = 3\n",
    "eval_dict = {task: {budget: {estimator: {} for estimator in estimators} for budget in simulation_budgets} for task in tasks}\n",
    "\n",
    "with tqdm(total=total) as pbar:\n",
    "    for task in tasks:\n",
    "        # The training file depends only on the task, so read it once (and close it) per task.\n",
    "        # NOTE: pickle can execute arbitrary code on load; only open trusted, locally generated files.\n",
    "        with open(f'./data/{task}_train_data.pkl', 'rb') as train_file:\n",
    "            train_data_full = pickle.load(train_file)\n",
    "        eval_data = test_instances[task]\n",
    "\n",
    "        for budget in simulation_budgets:\n",
    "            train_data = {\n",
    "                'sim_data': train_data_full.get('sim_data')[:budget],\n",
    "                'prior_draws': train_data_full.get('prior_draws')[:budget],\n",
    "            }\n",
    "            # Per-dimension variance of the first `budget` prior draws, handed to get_setup\n",
    "            sigma2 = tf.math.reduce_variance(tf.constant(train_data['prior_draws'], dtype=tf.float32), axis=0, keepdims=True)\n",
    "\n",
    "            for estimator in estimators:\n",
    "                ckpt_path = f'./checkpoints/{task}_{estimator}_{budget}_run{run_idx}'\n",
    "                trainer, settings = get_setup(task, estimator, sigma2, budget, ckpt_path)\n",
    "                for result_key, extra_kwargs in sampling_variants.get(estimator, {estimator: {}}).items():\n",
    "                    eval_dict[task][budget].setdefault(result_key, {})['posterior_samples'] = sample_timed(\n",
    "                        trainer,\n",
    "                        num_runs=num_runs_timed,\n",
    "                        input_dict=eval_data,\n",
    "                        n_samples=num_posterior_samples,\n",
    "                        to_numpy=False,\n",
    "                        **extra_kwargs,\n",
    "                    )\n",
    "                pbar.update(1)"
   ]
  },
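  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Example grid\n",
    "\n",
    "Rows are tasks. The first column shows the reference posterior (ABC samples for the kinematics task); each further column shows one estimator's posterior samples, annotated with its sampling time per 1000 draws in milliseconds."
   ]
  },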
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot(ax, reference, approximate, task, reference_color=(0.8, 0.4, 0.4), approximate_color=(1, 1, 1), **kwargs):\n",
    "    \"\"\"Draw reference and/or approximate posterior samples of `task` onto `ax` and return `ax`.\n",
    "\n",
    "    - 'twomoons' / 'gmm': scatter of the first two columns; `None` inputs are skipped and\n",
    "      `kwargs` are forwarded to `ax.scatter`. Axis limits are fixed per task.\n",
    "    - 'invkinematics': delegates to `InverseKinematicsModel.update_plot_ax(ax, approximate, reference)`;\n",
    "      nothing is drawn when `approximate` is None and `kwargs` are ignored. Here `reference_color`\n",
    "      acts as a switch: the default value selects white/grey line colors, any other value\n",
    "      (e.g. 'custom') selects the default reference red for all lines.\n",
    "    - any other task raises ValueError.\n",
    "    \"\"\"\n",
    "    default_reference_color = (0.8, 0.4, 0.4)\n",
    "    axis_half_width = {'twomoons': 0.5, 'gmm': 2.5}\n",
    "\n",
    "    if task in ('twomoons', 'gmm'):\n",
    "        for samples, color in ((reference, reference_color), (approximate, approximate_color)):\n",
    "            if samples is not None:\n",
    "                ax.scatter(samples[:, 0], samples[:, 1], color=color, **kwargs)\n",
    "    elif task == 'invkinematics':\n",
    "        if approximate is not None:\n",
    "            if reference_color == default_reference_color:\n",
    "                linecolors = [(1, 1, 1), (0.8, 0.8, 0.8), (0.7, 0.7, 0.7)]\n",
    "            else:\n",
    "                linecolors = [default_reference_color] * 3\n",
    "            InverseKinematicsModel(linecolors=linecolors).update_plot_ax(ax, approximate, reference)\n",
    "    else:\n",
    "        raise ValueError(f'Unknown task {task}')\n",
    "\n",
    "    half_width = axis_half_width.get(task)\n",
    "    if half_width is not None:\n",
    "        ax.set_xlim([-half_width, half_width])\n",
    "        ax.set_ylim([-half_width, half_width])\n",
    "\n",
    "    return ax\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Budget whose results are plotted, stated explicitly rather than relying on the `budget` loop\n",
    "# variable leaking out of the evaluation cell (whose final value is simulation_budgets[-1]).\n",
    "plot_budget = simulation_budgets[-1]\n",
    "\n",
    "nrows = len(tasks)\n",
    "ncols = len(plot_settings) + 1  # first column holds the reference\n",
    "\n",
    "scatter_kws = {\n",
    "    \"alpha\": 0.20,\n",
    "    \"rasterized\": True,\n",
    "    \"s\": 0.7,\n",
    "    \"marker\": \"D\",\n",
    "}\n",
    "# Shared placement of the per-panel sampling-time label (bottom-right corner)\n",
    "timing_label_kws = dict(xy=(0.95, 0.06), xycoords='axes fraction', horizontalalignment='right', fontsize='x-large')\n",
    "\n",
    "fig, axes = plt.subplots(nrows, ncols, figsize=(ncols*2, nrows*2), subplot_kw=dict(box_aspect=1), squeeze=False)\n",
    "\n",
    "for i, task in enumerate(tasks):\n",
    "    axes[i, 0].set_ylabel(task_names[task], rotation=90, size='xx-large')\n",
    "\n",
    "    # Reference column; for the kinematics task the ABC samples are drawn there\n",
    "    if task == 'invkinematics':\n",
    "        plot(axes[i, 0], reference_posteriors[task], inverse_kinematics_abc, task, reference_color='custom', **scatter_kws)\n",
    "    else:\n",
    "        plot(axes[i, 0], reference_posteriors[task], None, task, **scatter_kws)\n",
    "\n",
    "    # One column per entry of plot_settings\n",
    "    for j, estimator in enumerate(plot_settings, 1):\n",
    "        if i == 0:\n",
    "            axes[i, j].set_title(plot_settings[estimator]['name'], size='xx-large')\n",
    "        posterior_samples, sampling_time = eval_dict[task][plot_budget][estimator]['posterior_samples']\n",
    "        # seconds per 1000 draws (sampling_time is in seconds)\n",
    "        sampling_time_1000 = sampling_time / posterior_samples.shape[0] * 1000\n",
    "        plot(axes[i, j], reference_posteriors[task], posterior_samples.numpy(), task, **scatter_kws)\n",
    "\n",
    "        # White-on-white placeholder: presumably gives every timing label the same fixed-size white\n",
    "        # backdrop regardless of how many digits the actual value has.\n",
    "        axes[i, j].annotate(text='00000ms', color='white', bbox=dict(facecolor='white'), **timing_label_kws)\n",
    "        axes[i, j].annotate(text=f'{int(sampling_time_1000 * 1000)}ms', **timing_label_kws)\n",
    "\n",
    "axes[0, 0].set_title('Reference', size='xx-large')\n",
    "\n",
    "for ax in axes.flat:\n",
    "    ax.grid(False)\n",
    "    ax.set_facecolor((0 / 255, 32 / 255, 64 / 255, 1.0))\n",
    "    ax.get_xaxis().set_ticks([])\n",
    "    ax.get_yaxis().set_ticks([])\n",
    "    for side in ('bottom', 'top', 'right', 'left'):\n",
    "        ax.spines[side].set_alpha(0.0)\n",
    "    ax.set_aspect('equal')\n",
    "\n",
    "fig.tight_layout()\n",
    "fig.savefig('./figures/benchmark_showcase.pdf', dpi=300, bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "cmpe",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
