{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "np.set_printoptions(suppress=True)\n",
    "\n",
    "# Will also use just-in-time (JIT) compilation via numba for speeding up simulations\n",
    "from numba import njit\n",
    "\n",
    "\n",
    "import bayesflow as bf"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Introduction\n",
    "\n",
    "Modeling cognitive processes using both behavioral and neural data requires integrative neurocognitive models that jointly describe decision-making processes and neural dynamics. Traditional approaches, such as sequential sampling models (SSMs) like drift-diffusion models (DDMs), have successfully linked behavioral responses with underlying cognitive parameters. However, extending these models to simultaneously account for neural signals, such as EEG-derived event-related potentials (ERPs), presents significant challenges [1].\n",
    "\n",
    "A key limitation of standard Bayesian inference methods, such as Markov chain Monte Carlo (MCMC), is their computational inefficiency when applied to single-trial data, where each trial introduces new latent variables and dependencies between neural and behavioral measures. This complexity makes traditional likelihood-based inference impractical.\n",
    "\n",
    "Amortized Bayesian inference (ABI), as implemented in `BayesFlow` provides an efficient alternative by leveraging deep learning to approximate posterior distributions across many possible model configurations. By training on simulated data, our networks quickly learn to map observed EEG and behavioral measures directly to posterior distributions of cognitive parameters, avoiding the need for case-by-case inference. This approach not only accelerates parameter estimation but also enables robust model validation and generalization across datasets.\n",
    "\n",
    "[1] Ghaderi-Kangavari, A., Rad, J. A., & Nunez, M. D. (2023). A general integrative neurocognitive modeling framework to jointly describe EEG and decision-making on single trials. *Computational Brain & Behavior, 6(3)*, 317-376."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Joint Neurocognitive Model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For this demonstration, we will focus on model $\\mathcal{M}_{1c}$ from Ghaderi-Kangavari et al. (2023), which defines the following generative process for behavavioral data $(x_n, c_n)$ and neural (EEG) data $(z_n, \\tau_n)$:\n",
    "$$\n",
    "\\begin{align*}\n",
    "    (x_n, c_n) &\\sim \\text{DDM}(\\alpha, \\tau_n, \\delta, \\beta)\\\\\n",
    "    z_n &\\sim \\mathcal{TN}(\\tau_n - \\tau_m, \\sigma^2, a, b)\\\\\n",
    "    \\tau_n &\\sim \\mathcal{U}(f(\\tau_e,  s_{\\tau}), f(\\tau_e,  -s_{\\tau}))\n",
    "\\end{align*}\n",
    "$$\n",
    "\n",
    "The model has seven parameters in total, $\\theta = (\\alpha, \\delta, \\beta, \\tau_e, \\tau_m, s_{\\tau}, \\sigma)$, indexing different neurcognitive aspects of the decision process. See Ghaderi-Kangavari et al. (2023) for more details."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Defining the Prior"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def draw_prior():\n",
    "    \"\"\"Draws samples from the prior distributions of m1c, as described in Ghaderi-Kangavari et al. (2023)\"\"\"\n",
    "\n",
    "    prior_draws = np.random.uniform(\n",
    "        low=(-3.0, 0.5, 0.1, 0.05, 0.06, 0.0, 0.0),\n",
    "        high=(3.0, 2.0, 0.9, 0.4, 0.6, 0.1, 0.1)\n",
    "    )\n",
    "\n",
    "    prior_draws = {\n",
    "        \"drift\": prior_draws[0],\n",
    "        \"boundary\": prior_draws[1],\n",
    "        \"beta\": prior_draws[2],\n",
    "        \"mu_tau_e\": prior_draws[3],\n",
    "        \"tau_m\": prior_draws[4],\n",
    "        \"sigma\": prior_draws[5],\n",
    "        \"varsigma\": prior_draws[6]\n",
    "    }\n",
    "\n",
    "    return prior_draws"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Definining the Simulator\n",
    "\n",
    "We use `numba` s just-in-time (JIT) compiler to speed up the for loops of the simulators."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "@njit\n",
    "def diffusion_trial(drift, boundary, beta, mu_tau_e, tau_m, sigma, varsigma, s=1.0, dt=5e-3):\n",
    "    \"\"\"Simulates a trial from the joint diffusion model m1b.\"\"\"\n",
    "\n",
    "    c = np.sqrt(dt) * s\n",
    "    n_steps = 0.0\n",
    "    evidence = boundary * beta\n",
    "\n",
    "    while evidence > 0 and evidence < boundary:\n",
    "        evidence += drift * dt + c * np.random.normal()\n",
    "        n_steps += 1.0\n",
    "\n",
    "    z = 0\n",
    "    while True:\n",
    "        # visual encoding\n",
    "        tau_encoding = mu_tau_e + np.random.uniform(-0.5 * np.sqrt(12) * varsigma, 0.5 * np.sqrt(12) * varsigma)\n",
    "        z = np.random.normal(tau_encoding, sigma)\n",
    "        if z > 0 and z < 0.5:\n",
    "            break\n",
    "\n",
    "    rt = n_steps * dt + tau_encoding + tau_m\n",
    "\n",
    "    if evidence >= boundary:\n",
    "        return (rt, 1.0, z)\n",
    "    return (rt, 0.0, z)\n",
    "\n",
    "\n",
    "def simulate_trials(drift, boundary, beta, mu_tau_e, tau_m, sigma, varsigma, num_trials=120):\n",
    "    \"\"\"Simulates a diffusion process for trials.\"\"\"\n",
    "\n",
    "    data = np.empty((num_trials, 3))\n",
    "    for i in range(num_trials):\n",
    "        data[i] = diffusion_trial(drift, boundary, beta, mu_tau_e, tau_m, sigma, varsigma)\n",
    "    return dict(data=data)"
   ]
  },
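  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a reading aid, the two stochastic steps in `diffusion_trial` can be written out explicitly. This is a sketch derived from the code above, not a formula quoted from Ghaderi-Kangavari et al. (2023). The symbols $x_t$, $\\Delta t$, $\\epsilon_t$, $u$, and $\\varsigma$ are notation introduced here to mirror `evidence`, `dt`, the standard normal draw, the uniform jitter, and `varsigma`; we also assume that $\\alpha$ denotes the boundary and $\\delta$ the drift.\n",
    "\n",
    "The evidence loop is an Euler-Maruyama discretization of a diffusion process with drift $\\delta$ and scale $s$:\n",
    "$$\n",
    "x_{t + \\Delta t} = x_t + \\delta \\, \\Delta t + s \\sqrt{\\Delta t} \\, \\epsilon_t, \\qquad \\epsilon_t \\sim \\mathcal{N}(0, 1), \\qquad x_0 = \\alpha \\beta\n",
    "$$\n",
    "\n",
    "The encoding time adds uniform jitter with half-width $\\tfrac{\\sqrt{12}}{2} \\varsigma$ to `mu_tau_e`. Since a uniform distribution on $[a, b]$ has variance $(b - a)^2 / 12$, this choice makes `varsigma` the standard deviation of the jitter:\n",
    "$$\n",
    "u \\sim \\mathcal{U}\\left(-\\tfrac{\\sqrt{12}}{2} \\varsigma, \\tfrac{\\sqrt{12}}{2} \\varsigma\\right) \\quad \\Rightarrow \\quad \\mathrm{Var}(u) = \\frac{(\\sqrt{12} \\, \\varsigma)^2}{12} = \\varsigma^2\n",
    "$$"
   ]
  },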
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Connecting Prior and Simulator"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For now, we will use a fixed number of trials. In practice, different data sets (e.g., participants) will have different numbers of valid trials. We will demonstrate the application to a variable number of trials further down the line.\n",
    "\n",
    "**Note**: You can also completely bypass the BayesFlow simulation utilities. As long as your simulator returns dictionaries and you define a suitable adapter (see section **Custom Adapter**), you should be good to go."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "simulator = bf.make_simulator([draw_prior, simulate_trials])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Output drift has a shape of (2, 1)\n",
      "Output boundary has a shape of (2, 1)\n",
      "Output beta has a shape of (2, 1)\n",
      "Output mu_tau_e has a shape of (2, 1)\n",
      "Output tau_m has a shape of (2, 1)\n",
      "Output sigma has a shape of (2, 1)\n",
      "Output varsigma has a shape of (2, 1)\n",
      "Output data has a shape of (2, 120, 3)\n"
     ]
    }
   ],
   "source": [
    "# Always a good idea to test the fidelity of the outputs\n",
    "test_sims = simulator.sample(2)\n",
    "for k, v in test_sims.items():\n",
    "    print(f\"Output {k} has a shape of {v.shape}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Offline Dataset\n",
    "To demonstrate the utility of ABI, we will use a rather small set of offline simulations. Simulations are usually fast for cognitive modeling, so we recommend using online training in your applications."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "offline_sims = simulator.sample(6000)\n",
    "\n",
    "validation_sims = simulator.sample(200)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Workflow Components"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training hyperparameters\n",
    "batch_size = 32\n",
    "epochs = 100\n",
    "\n",
    "# The set transformer will compress data from N trials of shape (N, 3) into a vector of shape (16,)\n",
    "summary_network = bf.networks.SetTransformer(summary_dim=16)\n",
    "\n",
    "# We use the good old coupling flow. You can try some of the latest generative architectures as well (e.g., FlowMatching)\n",
    "inference_network = bf.networks.CouplingFlow()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The quickest way to create a workflow is demonstrated below. For more complex pipelines, you should create a data adapter explicitly, as demonstrated below.\n",
    "\n",
    "In practice, you should also save the networks using the `checkpoint_filepath` and `checkpoint_name` keyword arguments. The approximator can later ba loaded as\n",
    "- If you want to use the workflow functionality: `workflow.approximator = keras.saving.load(checkpoint_filepath/checkpoint_name.keras)`\n",
    "- If you just want the standalone pre-trained approximator: `approximator = keras.saving.load(checkpoint_filepath/checkpoint_name.keras)`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "workflow = bf.BasicWorkflow(\n",
    "    simulator=simulator, \n",
    "    inference_network=inference_network, \n",
    "    summary_network=summary_network,\n",
    "    inference_variables=[\"drift\", \"boundary\", \"beta\", \"mu_tau_e\", \"tau_m\", \"sigma\", \"varsigma\"],\n",
    "    inference_conditions=None,\n",
    "    summary_variables=[\"data\"]\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can inspect the default adapter like so:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Adapter([0: ConvertDType -> 1: Concatenate(['drift', 'boundary', 'beta', 'mu_tau_e', 'tau_m', 'sigma', 'varsigma'] -> 'inference_variables') -> 2: Concatenate(['data'] -> 'summary_variables') -> 3: Standardize(include=['inference_variables'])])"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "workflow.adapter"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Training then proceeds as follows. Without a GPU, training should take around 10 minutes; using a GPU, you can get down to less than a minute."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "history = workflow.fit_offline(offline_sims, epochs=epochs, batch_size=batch_size, validation_data=validation_sims)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## In Silico Validation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "f = bf.diagnostics.plots.loss(history)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Visual Diagnostics\n",
    "You can easily plot all checks of computational faithfulness pertaining to a principled Bayesian workflow [2]:\n",
    "\n",
    "[2] Schad, D. J., Betancourt, M., & Vasishth, S. (2021). Toward a principled Bayesian workflow in cognitive science. *Psychological methods, 26(1)*, 103."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = workflow.plot_default_diagnostics(\n",
    "    test_data=validation_sims, \n",
    "    num_samples=500,\n",
    "    loss_kwargs={\"figsize\": (15, 3), \"label_fontsize\": 12},\n",
    "    recovery_kwargs={\"figsize\": (15, 6), \"label_fontsize\": 12},\n",
    "    calibration_ecdf_kwargs={\"figsize\": (15, 6), \"legend_fontsize\": 8, \"difference\": True, \"label_fontsize\": 12},\n",
    "    z_score_contraction_kwargs={\"figsize\": (15, 6), \"label_fontsize\": 12}  \n",
    ")\n",
    "fig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>drift</th>\n",
       "      <th>boundary</th>\n",
       "      <th>beta</th>\n",
       "      <th>mu_tau_e</th>\n",
       "      <th>tau_m</th>\n",
       "      <th>sigma</th>\n",
       "      <th>varsigma</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>NRMSE</th>\n",
       "      <td>0.071242</td>\n",
       "      <td>0.103692</td>\n",
       "      <td>0.094913</td>\n",
       "      <td>0.048024</td>\n",
       "      <td>0.060045</td>\n",
       "      <td>0.321079</td>\n",
       "      <td>0.178052</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Posterior Contraction</th>\n",
       "      <td>0.978946</td>\n",
       "      <td>0.967126</td>\n",
       "      <td>0.969568</td>\n",
       "      <td>0.991798</td>\n",
       "      <td>0.986764</td>\n",
       "      <td>0.376899</td>\n",
       "      <td>0.861550</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Calibration Error</th>\n",
       "      <td>0.035526</td>\n",
       "      <td>0.080395</td>\n",
       "      <td>0.068158</td>\n",
       "      <td>0.045789</td>\n",
       "      <td>0.081447</td>\n",
       "      <td>0.033421</td>\n",
       "      <td>0.039737</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                          drift  boundary      beta  mu_tau_e     tau_m  \\\n",
       "NRMSE                  0.071242  0.103692  0.094913  0.048024  0.060045   \n",
       "Posterior Contraction  0.978946  0.967126  0.969568  0.991798  0.986764   \n",
       "Calibration Error      0.035526  0.080395  0.068158  0.045789  0.081447   \n",
       "\n",
       "                          sigma  varsigma  \n",
       "NRMSE                  0.321079  0.178052  \n",
       "Posterior Contraction  0.376899  0.861550  \n",
       "Calibration Error      0.033421  0.039737  "
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "table = workflow.compute_default_diagnostics(validation_sims, num_samples=500)\n",
    "table"
   ]
  },
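  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For orientation when reading the table: posterior contraction is commonly defined as one minus the ratio of posterior to prior variance. We assume this is the definition behind the row above; check the BayesFlow diagnostics documentation for the exact estimator and how it is aggregated across data sets.\n",
    "$$\n",
    "\\text{PC} = 1 - \\frac{\\mathrm{Var}_{\\text{post}}(\\theta)}{\\mathrm{Var}_{\\text{prior}}(\\theta)}\n",
    "$$\n",
    "\n",
    "Under this reading, values near 1 mean the data are highly informative about a parameter. In the table, `sigma` stands out with a contraction of about 0.38 and the largest NRMSE (about 0.32), whereas all other parameters contract by more than 0.86."
   ]
  },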
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Variable Number of Trials"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "def random_num_trials():\n",
    "    return dict(num_trials=np.random.randint(low=10, high=200+1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "var_trials_simulator = bf.make_simulator([draw_prior, simulate_trials], meta_fn=random_num_trials)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of trials: 169\n",
      "Output drift has a shape of (2, 1)\n",
      "Output boundary has a shape of (2, 1)\n",
      "Output beta has a shape of (2, 1)\n",
      "Output mu_tau_e has a shape of (2, 1)\n",
      "Output tau_m has a shape of (2, 1)\n",
      "Output sigma has a shape of (2, 1)\n",
      "Output varsigma has a shape of (2, 1)\n",
      "Output data has a shape of (2, 169, 3)\n"
     ]
    }
   ],
   "source": [
    "# Always a good idea to test the fidelity of the outputs\n",
    "test_sims = var_trials_simulator.sample(2)\n",
    "for k, v in test_sims.items():\n",
    "    print(f\"Output {k} has a shape of {v.shape}\" if type(v) is not int else f\"Number of trials: {v}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Custom Adapter\n",
    "\n",
    "The custom adapter now separates all simulator outputs into three categories:\n",
    "1. `inference_variables` - the targets of inference;\n",
    "2. `summary_variables` - the conditioning variables that will be first summarized by the summary network;\n",
    "3. `inference_conditions` - the conditioning variables that will bypass the summary network and directly inform the inference network."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 81,
   "metadata": {},
   "outputs": [],
   "source": [
    "adapter = (\n",
    "    bf.Adapter()\n",
    "    .sqrt(\"num_trials\")\n",
    "    .broadcast(\"num_trials\", to=\"data\", exclude=(1, 2), squeeze=-1)\n",
    "    .concatenate([\"drift\", \"boundary\", \"beta\", \"mu_tau_e\", \"tau_m\", \"sigma\", \"varsigma\"], into=\"inference_variables\")\n",
    "    .rename(\"data\", \"summary_variables\")\n",
    "    .rename(\"num_trials\", \"inference_conditions\")\n",
    "    .standardize(include=\"inference_variables\")\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Use a larger network for generalizing across different numbers of observations."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 94,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training hyperparameters\n",
    "batch_size = 32\n",
    "epochs = 200\n",
    "\n",
    "# The set transformer will compress data from N trials of shape (N, 3) into a vector of shape (16,)\n",
    "summary_network = bf.networks.SetTransformer(summary_dim=16, embed_dims=(128, 128), num_seeds=3)\n",
    "\n",
    "# We use the good old coupling flow. You can try some of the latest generative architectures as well (e.g., FlowMatching)\n",
    "inference_network = bf.networks.CouplingFlow()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 95,
   "metadata": {},
   "outputs": [],
   "source": [
    "var_trials_workflow = bf.BasicWorkflow(\n",
    "    simulator=var_trials_simulator,\n",
    "    inference_network=inference_network, \n",
    "    summary_network=summary_network,\n",
    "    initial_learning_rate=1e-4,\n",
    "    adapter=adapter,\n",
    "    inference_variables=[\"drift\", \"boundary\", \"beta\", \"mu_tau_e\", \"tau_m\", \"sigma\", \"varsigma\"],\n",
    "    inference_conditions=None,\n",
    "    summary_variables=[\"data\"]\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We will now demonstrate online training, since offline training with variable number of trials can be tricky (it would require simulating with maximum number of trials and then taking subsets)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "history2 = var_trials_workflow.fit_online(epochs=epochs, num_batches_per_epoch=batch_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 88,
   "metadata": {},
   "outputs": [],
   "source": [
    "validation_sims = validation_sims.copy()\n",
    "validation_sims[\"num_trials\"] = 120"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = var_trials_workflow.plot_default_diagnostics(\n",
    "    test_data=validation_sims, \n",
    "    num_samples=500,\n",
    "    loss_kwargs={\"figsize\": (15, 3), \"label_fontsize\": 12},\n",
    "    recovery_kwargs={\"figsize\": (15, 6), \"label_fontsize\": 12},\n",
    "    calibration_ecdf_kwargs={\"figsize\": (15, 6), \"legend_fontsize\": 8, \"difference\": True, \"label_fontsize\": 12},\n",
    "    z_score_contraction_kwargs={\"figsize\": (15, 6), \"label_fontsize\": 12}  \n",
    ")\n",
    "fig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "bayeskeras",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
