{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "sys.path.append('../src')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import diffrax\n",
    "import distrax\n",
    "import equinox as eqx\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import matplotlib as mpl\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import optax\n",
    "from tqdm.notebook import tqdm\n",
    "\n",
    "np.random.seed(0)\n",
    "key = jax.random.PRNGKey(0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data Generation\n",
    "\n",
    "We generate data from a family of donuts, each with a randomly sampled mean, width, and radius."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_data(num_dists, num_samples, mean=None, radius=None, width=None):\n",
    "    if mean is None:\n",
    "        mean = np.random.normal(0.0, 1.0, size=(num_dists, 1, 2))\n",
    "    if radius is None:\n",
    "        radius = np.random.uniform(0.5, 2.0, size=(num_dists, 1, 1))\n",
    "    if width is None:\n",
    "        width = np.random.uniform(0.01, 0.75, size=(num_dists, 1, 1))\n",
    "\n",
    "    a = np.random.normal(0.0, 1.0, size=(num_dists, num_samples, 2))\n",
    "    b = np.random.uniform(-0.5, 0.5, size=(num_dists, num_samples, 1))\n",
    "\n",
    "    return mean + radius * (a / np.linalg.norm(a, keepdims=True, axis=-1)) * (1.0 + width * b)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = plt.figure()\n",
    "\n",
    "# Parameters used for the figures in the paper\n",
    "mean = np.array([[[-0.26, -0.70]], [[0.38, 1.20]], [[1.18, -1.88]]])\n",
    "radius = np.array([[[1.88]], [[1.72]], [[1.12]]])\n",
    "width = np.array([[[0.31]], [[0.26]], [[0.31]]])\n",
    "\n",
    "data = generate_data(3, 768, mean, radius, width)\n",
    "for i in range(3):\n",
    "    plt.scatter(data[i, :, 0], data[i, :, 1], s=20.0, marker='.')\n",
    "\n",
    "plt.xlabel(r'$x_1$')\n",
    "plt.ylabel(r'$x_2$')\n",
    "plt.grid()\n",
    "\n",
    "# Save these limits so that we can use them for Figure 2\n",
    "xlims = plt.gca().get_xlim()\n",
    "ylims = plt.gca().get_ylim()\n",
    "\n",
    "plt.savefig('donuts-example-fm.pdf', bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Flow Matching Model\n",
    "\n",
    "We define the Flow Matching model. This first one uses basic independent coupling from Tong et al. 2023 for defining the vector fields that represent the conditional probability paths. The second one is a more refined implementation that uses optimal transport from Lipman et al. 2023. See Table 1 in Tong et al. for a summary."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ConditionalFlowMatcher(eqx.Module):\n",
    "    sigma: float = 0.0\n",
    "\n",
    "    def compute_mu_t(self, x0, x1, t):\n",
    "        # Compute the mean of the Gaussian probability path\n",
    "        return t * x1 + (1 - t) * x0\n",
    "\n",
    "    def compute_sigma_t(self, t):\n",
    "        # Compute the standard deviation of the Gaussian probability path\n",
    "        return self.sigma\n",
    "    \n",
    "    def sample_xt(self, x0, x1, t, epsilon):\n",
    "        # Draw a sample from the Gaussian probability path defined by `mu_t` and `sigma_t`\n",
    "        mu_t = self.compute_mu_t(x0, x1, t)\n",
    "        sigma_t = self.compute_sigma_t(t)\n",
    "        return mu_t + sigma_t * epsilon\n",
    "    \n",
    "    def compute_conditional_flow(self, x0, x1, t, xt):\n",
    "        # Compute the true vector field u_t(x_t | x0, x1)\n",
    "        return x1 - x0\n",
    "    \n",
    "    def sample_location_and_conditional_flow(self, x0, x1, key, t=None, return_noise=False):\n",
    "        # Compute a sample x_t and the conditional vector field u_t needed for training\n",
    "        if t is None:\n",
    "            key, time_key = jax.random.split(key, 2)\n",
    "            t = jax.random.uniform(time_key)\n",
    "\n",
    "        key, noise_key = jax.random.split(key, 2)\n",
    "        noise = jax.random.uniform(noise_key)\n",
    "\n",
    "        xt = self.sample_xt(x0, x1, t, noise)\n",
    "        ut = self.compute_conditional_flow(x0, x1, t, xt)\n",
    "\n",
    "        return (t, xt, ut, noise) if return_noise else (t, xt, ut)\n",
    "\n",
    "\n",
    "class TargetConditionalFlowMatcher(ConditionalFlowMatcher):\n",
    "    def compute_mu_t(self, x0, x1, t):\n",
    "        return t * x1\n",
    "\n",
    "    def compute_sigma_t(self, t):\n",
    "        return 1 - (1 - self.sigma) * t\n",
    "\n",
    "    def sample_xt(self, x0, x1, t, epsilon):\n",
    "        return self.compute_mu_t(x0, x1, t) + x0 * self.compute_sigma_t(t) \n",
    "\n",
    "    def compute_conditional_flow(self, x0, x1, t, xt):\n",
    "        return x1 - (1 - self.sigma) * x0"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Flow Matching & Embedding Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ContinuousFlowEmbedding(eqx.Module):\n",
    "    embedding_network: eqx.Module\n",
    "    flow_network: eqx.Module\n",
    "    sample_dimension: int = 2\n",
    "\n",
    "    @eqx.filter_jit\n",
    "    def embed(self, samples):\n",
    "        embeddings = jax.vmap(self.embedding_network)(samples)\n",
    "\n",
    "        return jnp.mean(embeddings, axis=0)\n",
    "\n",
    "    @eqx.filter_jit\n",
    "    def predict(self, xt, z, t):\n",
    "        return self.flow_network(\n",
    "            jnp.concatenate([xt, z, jnp.expand_dims(t, axis=0)], axis=-1)\n",
    "        )\n",
    "\n",
    "    @eqx.filter_jit\n",
    "    def sample(self, embedding, num_samples, num_steps, use_ode, key):\n",
    "        return (self._sample_ode if use_ode else self._sample_euler)(\n",
    "            embedding, num_samples, num_steps, key\n",
    "        )\n",
    "\n",
    "    @eqx.filter_jit\n",
    "    def _sample_euler(self, z, num_samples, num_steps, key):\n",
    "        xt = jax.random.normal(key, [num_samples, self.sample_dimension])\n",
    "        dt = 1 / num_steps\n",
    "\n",
    "        def _step(i, xt):\n",
    "            t = (i + 1) * dt\n",
    "            vt = jax.vmap(self.predict, in_axes=(0, None, None))(xt, z, t)\n",
    "            xt = xt + dt * vt\n",
    "\n",
    "            return xt\n",
    "\n",
    "        return jax.lax.fori_loop(0, num_steps, _step, xt)\n",
    "\n",
    "    @eqx.filter_jit\n",
    "    def _sample_ode(self, z, num_samples, num_steps, key):\n",
    "        del num_steps\n",
    "\n",
    "        x0 = jax.random.normal(key, [num_samples, self.sample_dimension])\n",
    "\n",
    "        @eqx.filter_jit\n",
    "        def diff_func(t, x, z):\n",
    "            return self.predict(x, z, t)\n",
    "\n",
    "        solver = diffrax.Dopri5()\n",
    "        solution = jax.vmap(diffrax.diffeqsolve, in_axes=(None,) * 5 + (0, None))(\n",
    "            # Vector field, solver, t0, t1, dt0, x0, args\n",
    "            diffrax.ODETerm(diff_func), solver, 0.0, 1.0, 1e-3, x0, z\n",
    "        )\n",
    "\n",
    "        # Return samples at t = 1.0\n",
    "        return solution.ys[:, 0]\n",
    "\n",
    "    @eqx.filter_jit\n",
    "    def log_prob(self, x1, z, num_steps=100, num_hutchinson_steps=10, *, key):\n",
    "        \"\"\"\n",
    "        See Equations (27-32) in Appendix C of Lipman et al. (2023) for details on why\n",
    "        integrating the divergence lets us compute the log prob of the sample from p1.\n",
    "        \"\"\"\n",
    "\n",
    "        prior_dist = distrax.Normal(0.0, 1.0)\n",
    "        dt = 1.0 / num_steps\n",
    "\n",
    "        def _step(i, args):\n",
    "            xt, log_prob, key = args\n",
    "            t = (num_steps - i - 1) * dt\n",
    "\n",
    "            vt = self.predict(xt, z, t)\n",
    "\n",
    "            key, div_key = jax.random.split(key, 2)\n",
    "            div_est = self._compute_divergence(\n",
    "                self.predict, xt, z, t, num_hutchinson_steps, key=div_key\n",
    "            )\n",
    "\n",
    "            xt = xt - vt * dt\n",
    "            log_prob = log_prob + div_est * dt\n",
    "\n",
    "            return xt, log_prob, key\n",
    "\n",
    "        x0, log_prob, _ = jax.lax.fori_loop(0, num_steps, _step, (x1, 0.0, key))\n",
    "\n",
    "        return jnp.sum(prior_dist.log_prob(x0)) - log_prob\n",
    "\n",
    "    @eqx.filter_jit\n",
    "    def _compute_divergence(self, func, x, z, t, num_steps=10, *, key):\n",
    "        \"\"\"\n",
    "        We need to approximate the divergence of the vector field v_t at x, which is\n",
    "        the trace of its Jacobian at x. We use the Hutchinson's trace estimator to\n",
    "        estimate this, which takes the product Jacobian vector product (jvp) with a\n",
    "        random vector.\n",
    "        \"\"\"\n",
    "\n",
    "        def _step(i, args):\n",
    "            div_est, key = args\n",
    "\n",
    "            key, noise_key = jax.random.split(key, 2)\n",
    "            noise = jax.random.normal(noise_key, x.shape)\n",
    "            _, jvp = jax.jvp(func, (x, z, t), (noise, jnp.zeros_like(z), 0.0))\n",
    "\n",
    "            return div_est + jnp.dot(noise, jvp), key\n",
    "\n",
    "        div_est, _ = jax.lax.fori_loop(0, num_steps, _step, (0.0, key))\n",
    "\n",
    "        return div_est / num_steps\n",
    "\n",
    "    @eqx.filter_jit\n",
    "    def compute_loss(self, samples, flow_matcher, key=None):\n",
    "        key, shuffle_key, noise_key = jax.random.split(key, 3)\n",
    "\n",
    "        indices = jax.random.permutation(shuffle_key, samples.shape[0])\n",
    "        samples = samples[indices]\n",
    "\n",
    "        train_samples, test_samples = jnp.split(samples, 2, axis=0)\n",
    "        embedding = self.embed(train_samples)\n",
    "\n",
    "        noise = jax.random.normal(noise_key, [test_samples.shape[0], 2])\n",
    "\n",
    "        x0, x1, z = noise, test_samples, embedding\n",
    "\n",
    "        t, xt, ut = jax.vmap(flow_matcher.sample_location_and_conditional_flow)(\n",
    "            x0, x1, jax.random.split(key, x0.shape[0])\n",
    "        )\n",
    "\n",
    "        vt = jax.vmap(self.predict, in_axes=(0, None, 0))(xt, z, t)\n",
    "\n",
    "        return jnp.mean((vt - ut) ** 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "input_dim = 2\n",
    "output_dim = 2\n",
    "embedding_dim = 8\n",
    "hidden_size = 64\n",
    "num_layers = 3\n",
    "\n",
    "key, embedding_key = jax.random.split(key, 2)\n",
    "embedding_network = eqx.nn.MLP(input_dim, embedding_dim, hidden_size, num_layers, key=embedding_key)\n",
    "\n",
    "hidden_size = 64\n",
    "num_layers = 4\n",
    "\n",
    "key, flow_key = jax.random.split(key)\n",
    "flow_network = eqx.nn.MLP(input_dim + embedding_dim + 1, output_dim, hidden_size, num_layers, key=flow_key)\n",
    "\n",
    "flow = ContinuousFlowEmbedding(embedding_network, flow_network)\n",
    "\n",
    "optimizer = optax.adam(1e-3)\n",
    "optimizer_state = optimizer.init(eqx.filter(flow, eqx.is_inexact_array))\n",
    "\n",
    "sigma = 1e-3\n",
    "# flow_matcher = ConditionalFlowMatcher(sigma)\n",
    "flow_matcher = TargetConditionalFlowMatcher(sigma)\n",
    "\n",
    "print(\n",
    "    'Model Parameters:', sum(\n",
    "        l.size for l in jax.tree_util.tree_leaves(eqx.filter(flow, eqx.is_inexact_array))\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@eqx.filter_jit\n",
    "def train(flow, flow_matcher, optimizer, optimizer_state, batch, key):\n",
    "    @eqx.filter_value_and_grad\n",
    "    def compute_loss(flow, flow_matcher, samples, key):\n",
    "        loss = jax.vmap(flow.compute_loss, in_axes=(0, None))(\n",
    "            samples, flow_matcher, key=jax.random.split(key, samples.shape[0])\n",
    "        )\n",
    "\n",
    "        return jnp.mean(loss)\n",
    "\n",
    "    key, loss_key = jax.random.split(key, 2)\n",
    "    loss, grads = compute_loss(flow, flow_matcher, batch, key=loss_key)\n",
    "\n",
    "    updates, optimizer_state = optimizer.update(grads, optimizer_state)\n",
    "    flow = eqx.apply_updates(flow, updates)\n",
    "\n",
    "    return flow, optimizer_state, loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "losses = []\n",
    "num_dists = 32\n",
    "num_samples = 128\n",
    "\n",
    "progress_bar = tqdm(range(30000))\n",
    "for i in progress_bar:\n",
    "    key, train_key = jax.random.split(key, 2)\n",
    "\n",
    "    data = generate_data(num_dists, num_samples)\n",
    "    flow, optimizer_state, loss = train(\n",
    "        flow, flow_matcher, optimizer, optimizer_state, data, train_key\n",
    "    )\n",
    "\n",
    "    losses.append(loss)\n",
    "\n",
    "    if i % 100 == 0:\n",
    "        progress_bar.set_postfix({'loss': np.mean(losses)})\n",
    "\n",
    "# Plot a smoothed sequence of loss values\n",
    "plt.plot(np.convolve(losses, np.ones(16), 'valid') / 16)\n",
    "plt.grid()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Qualitative Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_gradient_cmap(color):\n",
    "    if isinstance(color, str):\n",
    "        color = mpl.colors.to_rgb(color)\n",
    "\n",
    "    cmap_name = f\"gradient_from_{color}\"\n",
    "    colors = [(1.0, 1.0, 1.0, 0.0), (*color, 1.0)]\n",
    "\n",
    "    cmap = mpl.colors.LinearSegmentedColormap.from_list(cmap_name, colors)\n",
    "\n",
    "    return cmap\n",
    "\n",
    "mcmc_samples = generate_data(3, 256, mean, radius, width)\n",
    "\n",
    "grid_size = 512\n",
    "x = np.linspace(*xlims, grid_size)\n",
    "y = np.linspace(*ylims, grid_size)\n",
    "x, y = np.meshgrid(x, y)\n",
    "grid = np.stack([x.ravel(), y.ravel()], axis=-1)\n",
    "\n",
    "color_cycler = plt.rcParams['axes.prop_cycle'].by_key()['color']\n",
    "for samples, color in zip(mcmc_samples, color_cycler):\n",
    "    embedding = flow.embed(samples)\n",
    "\n",
    "    key, *prob_keys = jax.random.split(key, 1 + grid.shape[0])\n",
    "    log_probs = jax.vmap(flow.log_prob, in_axes=(0, None, None, None))(\n",
    "        grid, embedding, 100, 10, key=jnp.array(prob_keys)\n",
    "    )\n",
    "\n",
    "    density = np.exp(log_probs)\n",
    "    density = np.where(density >= 0.02, density, 0)\n",
    "    density = np.reshape(density, [grid_size, grid_size])\n",
    "\n",
    "    cmap = create_gradient_cmap(color)\n",
    "\n",
    "    plt.contourf(x, y, density, levels=10, cmap=cmap, norm='logit')\n",
    "    plt.scatter(samples[:, 0], samples[:, 1], s=20.0, color=color, marker='.')\n",
    "\n",
    "plt.xlabel(r'$x_1$')\n",
    "plt.ylabel(r'$x_2$')\n",
    "plt.grid()\n",
    "\n",
    "plt.savefig('donuts-densities-fm.pdf', bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### View of the vector fields over time\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x = np.linspace(-4, 4, 20)\n",
    "y = np.linspace(-4, 4, 20)\n",
    "x, y = np.meshgrid(x, y)\n",
    "data = np.stack([x.ravel(), y.ravel()], axis=-1)\n",
    "\n",
    "num_plots = 5\n",
    "ts = np.linspace(0, 1, 85)\n",
    "\n",
    "interval = len(ts) // (num_plots - 1)\n",
    "xt = np.random.normal((0, 0), 1.5, size=(10, 2))\n",
    "\n",
    "# Generate a couple of conditioning samples from a specific distribution and embed them\n",
    "cond_samples = generate_data(1, 64, mean=jnp.array([0.0, 0.0]), radius=2.0, width=0.1)[0]\n",
    "z = flow.embed(cond_samples)\n",
    "\n",
    "fig, axes = plt.subplots(1, num_plots, figsize=(20, 4))\n",
    "for i, t in enumerate(ts):\n",
    "    # Predict direction for every grid point\n",
    "    vt = jax.vmap(flow.predict, in_axes=(0, None, None))(data, z, t)\n",
    "    u = vt[:, 0].reshape(x.shape)\n",
    "    v = vt[:, 1].reshape(y.shape)\n",
    "\n",
    "    # Move our data points according to the predicted direction\n",
    "    xt = xt + 1.0 / len(ts) * jax.vmap(flow.predict, in_axes=(0, None, None))(xt, z, t)\n",
    "\n",
    "    if (i % interval) == 0:\n",
    "        axis_index = i // interval\n",
    "        axes[axis_index].set_title(f'Time = {t:0.2f}', y=-0.18)\n",
    "        axes[axis_index].quiver(x, y, u, v, color='blue')\n",
    "        axes[axis_index].plot(xt[:, 0], xt[:, 1], 'r.')\n",
    "        axes[axis_index].grid()\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Quantitative Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_dists = 64\n",
    "num_embedding_samples = 64\n",
    "num_generated_samples = 256\n",
    "\n",
    "\n",
    "def rmse(x, y, axis=-1):\n",
    "    return np.mean(np.sqrt(np.sum((x - y) ** 2, axis=axis)))\n",
    "\n",
    "\n",
    "def evaluate(flow, num_dists, num_embedding_samples, num_generated_samples, use_ode, key):\n",
    "    data = generate_data(num_dists, num_embedding_samples + num_generated_samples)\n",
    "\n",
    "    key, *sample_keys = jax.random.split(key, 1 + num_dists)\n",
    "    embeddings = jax.vmap(flow.embed)(data[:, :num_embedding_samples])\n",
    "    samples = jax.vmap(flow.sample, in_axes=(0, None, None, None, 0))(\n",
    "        embeddings, num_generated_samples, 500, use_ode, jnp.array(sample_keys)\n",
    "    )\n",
    "\n",
    "    data_mean = np.mean(data, keepdims=True, axis=1)\n",
    "    samples_mean = np.mean(samples, keepdims=True, axis=1)\n",
    "\n",
    "    mean_error = rmse(data_mean, samples_mean)\n",
    "    print(f'Mean Error: {mean_error:.5f}')\n",
    "\n",
    "    data_radius = np.linalg.norm(data - data_mean, axis=2)\n",
    "    samples_radius = np.linalg.norm(samples - samples_mean, axis=2)\n",
    "\n",
    "    mean_data_radius = np.mean(data_radius, keepdims=True, axis=1)\n",
    "    mean_samples_radius = np.mean(samples_radius, keepdims=True, axis=1)\n",
    "\n",
    "    radius_error = rmse(mean_data_radius, mean_samples_radius)\n",
    "    print(f'Radius Error: {radius_error:.5f}')\n",
    "\n",
    "    data_radius_diff = data_radius - mean_data_radius\n",
    "    samples_radius_diff = samples_radius - mean_samples_radius\n",
    "\n",
    "    std_data_radius_diff = np.std(data_radius_diff, keepdims=True, axis=1)\n",
    "    std_samples_radius_diff = np.std(samples_radius_diff, keepdims=True, axis=1)\n",
    "\n",
    "    width_error = rmse(std_data_radius_diff, std_samples_radius_diff)\n",
    "    print(f'Width Error: {width_error:.5f}')\n",
    "\n",
    "    width_error = rmse(\n",
    "        std_data_radius_diff / mean_data_radius,\n",
    "        std_samples_radius_diff / mean_samples_radius\n",
    "    )\n",
    "    print(f'Width / Radius Error: {width_error:.5f}')\n",
    "\n",
    "\n",
    "key, eval_key = jax.random.split(key, 2)\n",
    "evaluate(flow, num_dists, num_embedding_samples, num_generated_samples, use_ode=False, key=eval_key)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Density Estimation\n",
    "\n",
    "Below we approximate the divergence and use Euler's method to approximate the ODE needed to do density estimation for a point from the target distribution."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cond_samples = generate_data(1, 128, mean=[0.0, 0.0], radius=2.0, width=0.1)[0]\n",
    "embedding = flow.embed(cond_samples)\n",
    "\n",
    "key, *prob_keys = jax.random.split(key, 3)\n",
    "\n",
    "in_sample_prob = flow.log_prob(\n",
    "    generate_data(1, 1, mean=[0.0, 0.0], radius=2.0, width=0.1)[0][0], embedding, key=prob_keys[0]\n",
    ")\n",
    "out_sample_prob = flow.log_prob(\n",
    "    np.random.normal([4.0, 4.0], 1.5, size=(1, 2))[0], embedding, key=prob_keys[1]\n",
    ")\n",
    "\n",
    "print(f'In-donut sample probability: {jnp.exp(in_sample_prob)}')\n",
    "print(f'Out-of-donut sample probability: {jnp.exp(out_sample_prob)}')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
